# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Christian Reimers
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
"""
Neural network architectures for phenological and vegetation dynamics modeling.
This module contains a collection of recurrent, transformer-based, and
feed-forward neural network models originally developed for vegetation
greenness prediction from meteorological and environmental forcing data.
The architectures range from simple baselines to multi-stage transformer
models that explicitly separate environmental stress estimation from
phenological state prediction.
The implementations are adapted from work by Christian Reimers and
collaborators at the Max Planck Institute for Biogeochemistry and are
used within the PhenoCam LAI modeling framework.
Implemented Models
------------------
CombinedModel
    Two-stage architecture that first estimates an intermediate
    environmental stress signal using a transformer module and then
    combines this signal with plant functional type (PFT) information
    before phenological prediction.
BiTransformer
    Improved bidirectional transformer architecture that combines
    environmental stress estimation with causal temporal encoding and
    configurable plant functional type inputs. This model serves as the
    primary transformer implementation for vegetation prediction tasks.
Supporting Components
---------------------
PositionalEncoding
    Sinusoidal positional encoding module modeled on the formulation
    introduced in "Attention Is All You Need" (Vaswani et al., 2017).
    Provides temporal position information to transformer-based models.
Model Inputs
------------
Most models expect input tensors of shape:
(batch_size, sequence_length, n_features)
where each timestep contains meteorological, environmental, and
optionally plant functional type (PFT) information.
Model Outputs
-------------
Models generally return predictions of shape:
(batch_size, sequence_length, output_dim)
where ``output_dim`` typically represents one or more vegetation
state variables such as greenness, LAI, or intermediate stress
indicators.
Notes
-----
* Several architectures support a ``return_stress`` mode that exposes
intermediate latent stress representations used internally by the model.
* Transformer-based models employ causal attention masks to prevent
future information leakage during sequence prediction.
* Plant Functional Type (PFT) variables are assumed to occupy the final
feature channels of the input tensor when required.
* Within the PhenoCam LAI workflow, wrapper modules may permute tensor
dimensions to match the project's standard convention:
(batch_size, feature_channels, sequence_length)
while the original models operate on:
(batch_size, sequence_length, feature_channels)
References
----------
Original vegetation prediction architectures developed at the
Max Planck Institute for Biogeochemistry.
"""
import torch
class PositionalEncoding(torch.nn.Module):
    """
    Sinusoidal positional encoding module.

    Adds fixed positional encodings to token embeddings. The encodings are
    computed with sine and cosine functions of different frequencies and
    stored via ``register_buffer``. They are therefore not trainable, but
    they are saved in (and restored from) the module's ``state_dict``.

    Parameters
    ----------
    dim_model : int
        Dimensionality of the token embeddings. Both even and odd values
        are supported.
    dropout_p : float
        Dropout probability applied after adding positional encodings.
    max_len : int
        Maximum sequence length for which positional encodings are
        precomputed.

    Attributes
    ----------
    dropout : torch.nn.Dropout
        Dropout layer applied to the sum of token embeddings and
        positional encodings.
    pos_encoding : torch.Tensor
        Buffer of shape ``(max_len, 1, dim_model)`` containing the
        precomputed positional encodings.

    Notes
    -----
    For position ``pos`` and index ``i`` (``0 <= 2i < dim_model``)::

        PE(pos, 2i)     = sin(pos * scale_i)
        PE(pos, 2i + 1) = cos(pos * scale_i)
        scale_i         = 10000 ** (2i / dim_model)

    NOTE(review): Vaswani et al. (2017) *divide* the position by
    ``10000 ** (2i / dim_model)``, whereas this implementation multiplies
    by it, so frequencies grow with ``i`` instead of shrinking. The original
    behaviour is kept because existing trained models were built with it.
    Confirm whether this was intended.

    Examples
    --------
    >>> pe = PositionalEncoding(dim_model=512, dropout_p=0.1, max_len=5000)
    >>> x = torch.randn(20, 32, 512)
    >>> pe(x).shape
    torch.Size([20, 32, 512])
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout_p)

        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)
        division_term = 10000 ** (torch.arange(0, dim_model, 2).float() / dim_model)
        # Shape (max_len, ceil(dim_model / 2)): one column per sin/cos pair.
        angles = positions_list * division_term
        pos_encoding[:, 0::2] = torch.sin(angles)
        # With an odd dim_model there is one fewer odd column than even
        # column, so the angle table is trimmed to fit the cosine slots.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : dim_model // 2])
        # (max_len, dim_model) -> (max_len, 1, dim_model) so the encoding
        # broadcasts over the second (batch) axis of the input in forward().
        self.register_buffer("pos_encoding", pos_encoding.unsqueeze(1))

    def forward(self, token_embeding: torch.Tensor) -> torch.Tensor:
        """
        Add positional encodings to token embeddings.

        Parameters
        ----------
        token_embeding : torch.Tensor
            Input embeddings of shape ``(sequence_length, batch_size,
            dim_model)`` with ``sequence_length <= max_len``. (The
            misspelled parameter name is kept for keyword-argument
            compatibility.)

        Returns
        -------
        torch.Tensor
            Embeddings with positional encodings added and dropout applied,
            with the same shape as the input.
        """
        return self.dropout(
            token_embeding + self.pos_encoding[: token_embeding.size(0), :]
        )
class CombinedModel(torch.nn.Module):
    """
    Two-stage Transformer model: stress estimation followed by causal
    sequence prediction with plant functional type (PFT) features.

    Processing steps:

    1. ``lin1`` projects the input features to ``d_model``.
    2. ``trans`` (a ``torch.nn.Transformer``) processes the projection, and
       ``lin2`` reduces it to a single-channel "stress" signal.
    3. The stress signal is concatenated with the last ``n_pft`` input
       channels (the PFT features) and projected to ``d_model`` by ``lin3``.
    4. A stack of ``TransformerEncoderLayer`` blocks with a causal mask
       processes the sequence, and ``lin4`` maps it to ``output_dim``.

    Parameters
    ----------
    input_dim : int, default=26
        Number of input features per timestep.
    hidden_dim : int, default=1024
        Feedforward dimension inside the TransformerEncoderLayer blocks.
    hidden_dim_trans : int, default=1024
        Feedforward dimension inside ``trans``.
    output_dim : int, default=2
        Number of output features per timestep.
    d_model : int, default=32
        Internal embedding dimension. It must be divisible by 8, because
        ``trans`` uses PyTorch's default ``nhead=8``. It must also be
        divisible by 4, because the encoder blocks use ``nhead=4``.
    nr_blocks : int, default=3
        Number of stacked TransformerEncoderLayer blocks.
    n_pft : int, default=10
        Number of trailing input channels treated as PFT features.

    Attributes
    ----------
    lin1 : torch.nn.Linear
        Projects input features to ``d_model``.
    trans : torch.nn.Transformer
        One-encoder/one-decoder Transformer applied to the projected input.
    lin2 : torch.nn.Linear
        Reduces the ``trans`` output to one (stress) channel.
    lin3 : torch.nn.Linear
        Projects ``[stress, pft]`` (``1 + n_pft`` channels) to ``d_model``.
    encoder : torch.nn.ModuleList
        Stack of TransformerEncoderLayer blocks (dropout 0, 4 heads).
    lin4 : torch.nn.Linear
        Final projection from ``d_model`` to ``output_dim``.
    positional_encoder : PositionalEncoding
        Sinusoidal positional encoding (``max_len=730``).

    Notes
    -----
    - Inputs are ``(batch_size, sequence_length, input_dim)``. The encoder
      stack runs on ``(sequence_length, batch_size, d_model)`` tensors with
      a causal additive mask (``-inf`` above the diagonal).
    - NOTE(review): ``trans`` is created with PyTorch's default
      ``batch_first=False`` but is fed ``(batch, time, d_model)`` tensors
      without masks. It therefore attends across the *batch* axis at each
      timestep rather than across time, so a sample's stress depends on
      the other samples in the batch. This is kept unchanged because
      trained weights rely on it. Confirm whether this was intended.
    - NOTE(review): ``positional_encoder`` is constructed (its buffer is part
      of the ``state_dict``) but is never applied in ``forward``.
    """

    def __init__(
        self,
        input_dim=26,
        hidden_dim=1024,
        hidden_dim_trans=1024,
        output_dim=2,
        d_model=32,
        nr_blocks=3,
        n_pft=10,
    ):
        super().__init__()
        # Submodules are created in the original order so that seeded
        # parameter initialisation is unchanged.
        self.lin1 = torch.nn.Linear(input_dim, d_model)
        self.trans = torch.nn.Transformer(
            num_encoder_layers=1,
            num_decoder_layers=1,
            d_model=d_model,
            dim_feedforward=hidden_dim_trans,
        )
        self.lin2 = torch.nn.Linear(d_model, 1)
        # One stress channel plus the PFT channels.
        self.lin3 = torch.nn.Linear(1 + n_pft, d_model)
        self.encoder = torch.nn.ModuleList(
            [
                torch.nn.TransformerEncoderLayer(
                    d_model=d_model,
                    nhead=4,
                    dim_feedforward=hidden_dim,
                    dropout=0,
                )
                for _ in range(nr_blocks)
            ]
        )
        self.lin4 = torch.nn.Linear(d_model, output_dim)
        self.positional_encoder = PositionalEncoding(
            dim_model=d_model, dropout_p=0, max_len=730
        )

    @staticmethod
    def _causal_mask(size, device, dtype):
        """Return a ``(size, size)`` additive mask: 0 on/below the diagonal, -inf above."""
        return torch.triu(
            torch.full((size, size), float("-inf"), device=device, dtype=dtype),
            diagonal=1,
        )

    def forward(self, x, return_stress=False):
        """
        Forward pass of the model using the full input sequence.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_size, sequence_length,
            input_dim)``. Its last ``n_pft`` channels are used as PFT
            features.
        return_stress : bool, default=False
            If True, also return the intermediate stress signal (the
            ``lin2`` output).

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            If ``return_stress`` is False, the output has shape
            ``(batch_size, sequence_length, output_dim)``. Otherwise a tuple
            ``(output, stress)`` is returned, where ``stress`` has shape
            ``(batch_size, sequence_length, 1)``.
        """
        # The PFT channel count is derived from lin3, so models pickled
        # before ``n_pft`` existed keep working.
        n_pft = self.lin3.in_features - 1
        pft = x[:, :, x.size(-1) - n_pft :]
        hidden = self.lin1(x)
        hidden = self.trans(hidden, hidden)
        stress = self.lin2(hidden)
        out = self.forward_from_stress(stress, pft)
        if return_stress:
            # clone (not deepcopy) keeps the result attached to the autograd graph.
            return out, stress.clone()
        return out

    def forward_from_stress(self, x, pft):
        """
        Forward pass starting from the intermediate stress signal.

        This skips ``lin1``, ``trans`` and ``lin2``. It concatenates the
        stress signal with the PFT features and runs ``lin3``, the causal
        encoder stack, and ``lin4``.

        Parameters
        ----------
        x : torch.Tensor
            Stress signal of shape ``(batch_size, sequence_length, 1)``.
        pft : torch.Tensor
            PFT features of shape ``(batch_size, sequence_length, n_pft)``.

        Returns
        -------
        torch.Tensor
            Output tensor of shape ``(batch_size, sequence_length,
            output_dim)``.
        """
        # Build the mask on the input's device and dtype; a CPU/float32
        # mask fails for models moved to GPU or to another precision.
        mask = self._causal_mask(x.size(1), x.device, x.dtype)
        hidden = torch.cat([x, pft], dim=-1)
        # Encoder layers use the default (sequence, batch, feature) layout.
        hidden = self.lin3(hidden.permute(1, 0, 2))
        for layer in self.encoder:
            hidden = layer(hidden, src_mask=mask)
        return self.lin4(hidden).permute(1, 0, 2)
# ___Bi Transformer model________