# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Stefan Barbu, Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
"""
LAI Prediction Dataset.
PyTorch dataset utilities for predicting Leaf Area Index (LAI) from
daily meteorological, site, and vegetation characteristics derived from
PhenoCam observations.
The module provides functionality to:
* Load and preprocess site-level time series data.
* Engineer derived phenological predictors such as growing degree days
(GDD), chilling degree days (CDD), forcing metrics, and related indices.
* Compute and apply feature normalization statistics.
* Optionally normalize LAI targets globally or using site-specific
minimum/maximum values.
* Construct sliding-window datasets for supervised learning.
* Support multiple feature configurations, including meteorological-only,
site-only, or combined models.
* Generate datasets for single-day prediction, multi-day sequence
prediction, full-year forecasting, or residual learning workflows.
* Create train/validation splits by site and/or year.
Dataset Structure
-----------------
Input samples consist of a fixed-length sequence of daily observations
ending on a target day.
Typical configuration:
* Input:
``(n_features, seq_length)``
where ``seq_length`` is typically 365 days.
* Output:
``(1,)``
LAI value on the final day of the sequence.
Alternative modes support:
* Multi-day targets:
``(1, n_target_days)``
* Full-year prediction:
``(1, 365)``
Feature Groups
--------------
Dynamic features
Daily meteorological variables and derived phenological metrics,
including temperature, precipitation, radiation, snow water
equivalent, vapor pressure deficit, and growing degree day indices.
Static features
Site-level attributes such as climatology, geographic coordinates,
soil properties, elevation, and terrain characteristics.
Cyclic features
Sine/cosine encodings of day-of-year.
PFT features
One-hot encoding of plant functional type (PFT).
Normalization
-------------
Feature normalization statistics are computed from the training sites.
Dynamic variables listed in ``LOG_TRANSFORM_FEATURES`` are clipped at zero
and ``log1p``-transformed before their means and standard deviations are
computed. Static variables are normalized using cross-site statistics
(one value per site).
Targets may be normalized either:
* Globally using dataset-wide LAI mean and standard deviation.
* Per site using externally supplied LAI minimum and maximum values.
Main Components
---------------
extract_pft_and_site
Parse the plant functional type (PFT) code and site name from a file name
such as ``DB_asuhighlands.csv``.
compute_norm_stats
Compute feature normalization statistics from training sites.
load_norm_stats
Load previously computed normalization statistics.
load_lai_norms
Load per-site LAI minimum/maximum values for site-specific target
normalization.
load_site
Load and preprocess a site-level CSV file.
PhenoCamDataset
PyTorch ``Dataset`` implementation providing sliding-window LAI
prediction samples.
split_sites_by_fraction
Create leave-site-out train/validation splits.
split_by_year
Create train/validation datasets using year-based target splits.
Notes
-----
This implementation is adapted from the RTnn data preprocessing workflow
for daily PhenoCam time series and is designed for ecological forecasting
and vegetation phenology modeling.
"""
import os
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from typing import Dict, List, Optional, Tuple
from .feature_engineering import add_derived_features
# ── Feature definitions ──────────────────────────────────────────────────────
# Daily meteorological drivers (always used as dynamic channels).
DYNAMIC_FEATURES = ["tmin", "tmax", "daylength", "vpd", "prcp", "srad", "swe"]

# Phenological predictors; presumably the columns created by
# feature_engineering.add_derived_features (load_site reads them after calling it).
DERIVED_FEATURES = [
    "gdd_0",
    "gdd_5",
    "gdd_10",
    "cdd",
    "botta_threshold",
    "botta_forcing",
    "ncd",
]

# Optional per-site descriptors: latitude, longitude, elevation.
SITE_FEATURES = ["lat", "lon", "elev"]
# Site climatology: mean annual temperature ("mat") and precipitation ("map").
STATIC_FEATURES = ["mat", "map"]
SOIL_FEATURES = ["clay", "sand", "silt", "ph"]
SLOPE_FEATURES = ["slope"]

# Feature-group switches, applied once when the module is imported.
add_site_features = False
add_GDD_features = True
add_soil_features = False
add_slope_features = False
add_cyclic_features = False

# Extend the base groups in place, in a fixed order: site, GDD, soil, slope.
for _enabled, _group, _extra in (
    (add_site_features, STATIC_FEATURES, SITE_FEATURES),
    (add_GDD_features, DYNAMIC_FEATURES, DERIVED_FEATURES),
    (add_soil_features, STATIC_FEATURES, SOIL_FEATURES),
    (add_slope_features, STATIC_FEATURES, SLOPE_FEATURES),
):
    if _enabled:
        _group.extend(_extra)
del _enabled, _group, _extra

# Sine/cosine day-of-year channels (columns created in load_site).
CYCLIC_FEATURES = ["doy_sin", "doy_cos"] if add_cyclic_features else []

# Features that are clipped at 0 and log1p-transformed before z-scoring.
# NOTE(review): only DYNAMIC_FEATURES are checked against this set, so the
# static "elev" entry currently has no effect.
LOG_TRANSFORM_FEATURES = {
    "swe",
    "vpd",
    "prcp",
    "gdd_0",
    "gdd_5",
    "gdd_10",
    "cdd",
    "elev",
}
# ── Helpers ──────────────────────────────────────────────────────────────────
def extract_pft_and_site(filepath: str) -> Tuple[str, str]:
    """
    Split a site file name into its PFT code and site name.

    The file stem is split at its first underscore, so
    ``'DB_asuhighlands.csv'`` gives ``('DB', 'asuhighlands')`` and
    ``'EN_site_a.csv'`` gives ``('EN', 'site_a')``. A stem without any
    underscore yields ``('UNK', stem)``.

    Parameters
    ----------
    filepath : str
        Path to a site CSV file; directories and the extension are ignored.

    Returns
    -------
    tuple of (str, str)
        ``(pft_code, site_name)``.
    """
    stem = os.path.splitext(os.path.basename(filepath))[0]
    pft, separator, site = stem.partition("_")
    return (pft, site) if separator else ("UNK", stem)
def load_site(filepath: str) -> pd.DataFrame:
    """
    Read one site CSV and prepare it for feature extraction.

    Processing steps, in order:

    1. sort rows chronologically by ``(year, doy)`` and reset the index;
    2. add ``doy_sin`` / ``doy_cos`` (period of 365.25 days);
    3. add derived phenological columns via ``add_derived_features``;
    4. fill the remaining gaps of every ``DYNAMIC_FEATURES`` column, first
       forward then backward so that leading gaps are filled too. Gaps of
       any length are filled, not only short ones.

    Parameters
    ----------
    filepath : str
        Path to the CSV file; it must contain at least ``year`` and ``doy``.

    Returns
    -------
    pd.DataFrame
        Cleaned dataframe with all feature and target columns.
    """
    site_df = (
        pd.read_csv(filepath)
        .sort_values(["year", "doy"])
        .reset_index(drop=True)
    )

    phase = 2 * np.pi * site_df["doy"] / 365.25
    site_df["doy_sin"] = np.sin(phase)
    site_df["doy_cos"] = np.cos(phase)

    site_df = add_derived_features(site_df)

    # NOTE(review): derived features are computed before the meteo gaps are
    # filled, so they see the unfilled series and are then filled themselves.
    # It is unclear why meteo gaps remain at all -- check the data loader.
    filled = {col: site_df[col].ffill().bfill() for col in DYNAMIC_FEATURES}
    return site_df.assign(**filled)
def load_lai_norms(
    norms_csv: str,
    site_files: Optional[List[str]] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Load per-site LAI minimum and maximum values from an external CSV.

    The values enable per-site min-max normalization of the target:
    ``lai_norm = (lai - lai_min) / (lai_max - lai_min)``.

    Parameters
    ----------
    norms_csv : str
        Path to a CSV whose FIRST column holds the site key and which has
        ``lai_min`` and ``lai_max`` columns. PhenoCamDataset looks sites up
        by ``"{pft}_{site}"`` (e.g. ``"DB_asuhighlands"``).
    site_files : list of str, optional
        Currently unused; kept for API compatibility. Every row of the CSV
        is returned regardless of this argument.

    Returns
    -------
    dict
        ``{site_key: {"lai_min": float, "lai_max": float}}``. A key that
        appears on several rows keeps the values of its last row.

    Raises
    ------
    KeyError
        If ``lai_min`` or ``lai_max`` is missing from the CSV (checked up
        front, so an empty file with the wrong header is rejected too).
    """
    table = pd.read_csv(norms_csv)
    missing = [col for col in ("lai_min", "lai_max") if col not in table.columns]
    if missing:
        raise KeyError(f"{norms_csv}: missing required column(s) {missing}")

    # Iterate column-wise: DataFrame.iterrows upcasts each row to a common
    # dtype, which turned integer site keys into floats.
    keys = table[table.columns[0]]
    return {
        key: {"lai_min": float(lo), "lai_max": float(hi)}
        for key, lo, hi in zip(keys, table["lai_min"], table["lai_max"])
    }
# ── Normalization stats ──────────────────────────────────────────────────────
def _mean_std(values: np.ndarray, min_std: float = 1e-8) -> Dict[str, float]:
    """Return ``{"mean", "std"}`` of *values*; a std below *min_std* becomes 1.0."""
    std = float(np.std(values))
    if std < min_std:
        # Avoid division by (near) zero when the statistics are applied.
        std = 1.0
    return {"mean": float(np.mean(values)), "std": std}


def compute_norm_stats(
    site_files: List[str],
    save_path: Optional[str] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Compute per-feature mean and std across all training sites.

    * Dynamic features are pooled over all days of all sites (NaNs dropped).
      Features in LOG_TRANSFORM_FEATURES are clipped at 0 and
      log1p-transformed first, matching what PhenoCamDataset applies.
    * Cyclic features are pooled over all days of all sites.
    * Static features contribute one value per site (the first non-NaN
      value), so repeated daily rows do not deflate their variance.
    * The ``"LAI"`` target is pooled over all non-NaN days.

    Every std below 1e-8 (including the target's) is replaced by 1.0. A
    feature with no values at all gets NaN statistics, which is numpy's result
    for the mean/std of an empty array.

    Parameters
    ----------
    site_files : list of str
        Paths to training site CSVs (loaded with ``load_site``).
    save_path : str, optional
        If provided, write the stats as JSON to this path, creating parent
        directories as needed.

    Returns
    -------
    dict
        ``{feature_name: {"mean": float, "std": float}}`` for every dynamic,
        cyclic and static feature, plus ``"LAI"``.
    """
    dynamic_values: Dict[str, List[np.ndarray]] = {
        feat: [] for feat in DYNAMIC_FEATURES + CYCLIC_FEATURES
    }
    static_values: Dict[str, list] = {feat: [] for feat in STATIC_FEATURES}
    target_values: List[np.ndarray] = []

    for site_path in site_files:
        df = load_site(site_path)
        for feat in DYNAMIC_FEATURES:
            vals = df[feat].dropna().values
            if feat in LOG_TRANSFORM_FEATURES:
                vals = np.log1p(np.clip(vals, 0, None))
            dynamic_values[feat].append(vals)
        for feat in CYCLIC_FEATURES:
            dynamic_values[feat].append(df[feat].values)
        # Static: a single value per site.
        for feat in STATIC_FEATURES:
            present = df[feat].dropna()
            if len(present) > 0:
                static_values[feat].append(present.iloc[0])
        target_values.append(df["LAI"].dropna().values)

    stats: Dict[str, Dict[str, float]] = {}
    for feat in DYNAMIC_FEATURES + CYCLIC_FEATURES:
        stats[feat] = _mean_std(np.concatenate(dynamic_values[feat]))
    for feat in STATIC_FEATURES:
        stats[feat] = _mean_std(np.array(static_values[feat]))
    # The target gets the same zero-std guard as the features; previously a
    # constant LAI produced std=0 and normalize_target divided by zero.
    stats["LAI"] = _mean_std(np.concatenate(target_values))

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        with open(save_path, "w") as out_file:
            json.dump(stats, out_file, indent=2)
    return stats
def load_norm_stats(path: str) -> Dict[str, Dict[str, float]]:
    """
    Read normalization statistics from a JSON file.

    Parameters
    ----------
    path : str
        Path to a JSON file, e.g. one written by ``compute_norm_stats``.

    Returns
    -------
    dict
        The parsed JSON content, expected to be
        ``{feature_name: {"mean": float, "std": float}}``.
    """
    with open(path) as stats_file:
        stats = json.load(stats_file)
    return stats
# ── Dataset ──────────────────────────────────────────────────────────────────
class PhenoCamDataset(Dataset):
    """
    PyTorch dataset for LAI prediction from daily PhenoCam site time series.

    Each sample is a window of ``seq_length`` consecutive days (365 by
    default) of features ending on a target day ``d``. By default the target
    is LAI on day ``d``. With ``n_target_days > 1`` it is the last
    ``n_target_days`` days of the window. With ``full_year=True`` it is the
    last 365 days of the window.

    The feature tensor has shape ``(feature_channels, seq_length)``. The
    channel groups, in this order, are selected by ``feature_mode``:

    * ``"all"`` (default): DYNAMIC_FEATURES, CYCLIC_FEATURES,
      STATIC_FEATURES, PFT one-hot
    * ``"site_only"``: CYCLIC_FEATURES, STATIC_FEATURES, PFT one-hot
    * ``"meteo_only"``: DYNAMIC_FEATURES, CYCLIC_FEATURES

    Parameters
    ----------
    site_files : list of str
        Paths to site CSV files named ``{PFT}_{site}.csv``
        (e.g. ``'DB_asuhighlands.csv'``).
    norm_stats : dict
        Normalization statistics from ``compute_norm_stats()`` or
        ``load_norm_stats()``.
    seq_length : int
        Input window length in days. Default 365.
    target_col : str
        Target column name. Default ``'LAI'``. Global target normalization
        always uses ``norm_stats["LAI"]``, whatever this column is.
    normalize_target : bool
        Z-score the target with ``norm_stats["LAI"]``. Default True. Not
        applied to sites that are min-max normalized through ``lai_norms``.
    pft_list : list of str, optional
        Ordered PFT codes for the one-hot encoding. If None, discovered from
        ``site_files``. A site whose PFT is not listed gets an all-zero
        encoding.
    stride : int
        Step between consecutive windows in the standard mode. Default 1.
        Ignored when ``random_stride > 0`` or ``full_year=True``.
    lai_norms : dict, optional
        ``{"{PFT}_{site}": {"lai_min": float, "lai_max": float}}``, as returned
        by ``load_lai_norms()``. Sites found here are min-max normalized;
        other sites fall back to ``normalize_target``.
    years : list of int, optional
        If provided, only keep windows whose target day falls in these years.
        Use this to split "predict year by year".
    n_target_days : int
        Number of trailing days returned as target. Default 1.
    residual_csv : str, optional
        Predictions CSV with columns ``site``, ``day_index`` and either
        ``lai_obs_norm``/``lai_pred_norm`` or ``lai_obs``/``lai_pred``. When
        given, targets become ``obs - pred`` for matching
        ``(site, day_index)`` pairs; unmatched days become NaN and are skipped.
    random_stride : int
        If > 0, keep ``random_stride`` randomly chosen valid windows per
        (site, year); call ``resample()`` to redraw them. Default 0.
    feature_mode : {"all", "site_only", "meteo_only"}
        Feature-group selection (see above). Default ``"all"``.
    full_year : bool
        One sample per calendar year, ending on that year's last available
        day. The target is the last 365 days of the window. A year is skipped
        unless the window fits and all 365 target days are non-NaN.
        Default False.

    Raises
    ------
    ValueError
        If ``feature_mode`` is not a supported mode.
    RuntimeError
        In residual mode, if no valid samples remain.

    Examples
    --------
    Channel counts depend on the module-level feature switches. With this
    module's defaults (14 dynamic, 0 cyclic and 2 static channels) and 3 PFTs:

    > stats = compute_norm_stats(train_files)
    > train_ds = PhenoCamDataset(train_files, stats, pft_list=['DB', 'EN', 'GR'])
    > features, target = train_ds[0]
    > features.shape
    torch.Size([19, 365])  # 14 dynamic + 0 cyclic + 2 static + 3 PFT
    > target.shape
    torch.Size([1])
    """

    # Supported values of ``feature_mode``.
    _FEATURE_MODES = ("all", "site_only", "meteo_only")

    def __init__(
        self,
        site_files: List[str],
        norm_stats: Dict[str, Dict[str, float]],
        seq_length: int = 365,
        target_col: str = "LAI",
        normalize_target: bool = True,
        pft_list: Optional[List[str]] = None,
        stride: int = 1,
        lai_norms: Optional[Dict[str, Dict[str, float]]] = None,
        years: Optional[List[int]] = None,
        n_target_days: int = 1,
        residual_csv: Optional[str] = None,
        random_stride: int = 0,
        feature_mode: str = "all",
        full_year: bool = False,
    ) -> None:
        super().__init__()
        # Fail fast: an unknown mode used to silently disable the dynamic,
        # static and PFT groups, leaving only the cyclic channels.
        if feature_mode not in self._FEATURE_MODES:
            raise ValueError(
                f"Unknown feature_mode {feature_mode!r}; "
                f"expected one of {self._FEATURE_MODES}"
            )
        self.seq_length = seq_length
        self.target_col = target_col
        self.norm_stats = norm_stats
        self.normalize_target = normalize_target
        self.stride = stride
        self.n_target_days = n_target_days
        self.residual_mode = residual_csv is not None
        self.random_stride = random_stride
        self.feature_mode = feature_mode
        self.full_year = full_year
        # Trailing target days per sample in full-year mode.
        self.pred_length = 365 if full_year else 1

        # (site_name, day_index) -> residual left by a previous model run.
        self._residual_lookup: Dict[Tuple[str, int], float] = (
            self._load_residuals(residual_csv) if residual_csv is not None else {}
        )

        # PFT vocabulary for the one-hot encoding.
        if pft_list is None:
            pft_list = sorted({extract_pft_and_site(f)[0] for f in site_files})
        self.pft_list = pft_list
        self.pft_to_idx = {pft: i for i, pft in enumerate(pft_list)}
        self.n_pfts = len(pft_list)

        self.use_dynamic = feature_mode in ("all", "meteo_only")
        self.use_static = feature_mode in ("all", "site_only")
        self.use_pft = feature_mode in ("all", "site_only")

        # Cyclic channels are always present (possibly zero of them).
        self.n_features = len(CYCLIC_FEATURES)
        if self.use_dynamic:
            self.n_features += len(DYNAMIC_FEATURES)
        if self.use_static:
            self.n_features += len(STATIC_FEATURES)
        if self.use_pft:
            self.n_features += self.n_pfts

        # samples: (site_idx, target_day_position) for every valid window.
        self.samples: List[Tuple[int, int]] = []
        self.site_data: List[Dict] = []
        for file_idx, filepath in enumerate(site_files):
            pft_code, site_name = extract_pft_and_site(filepath)
            df = load_site(filepath)
            features = self._build_features(df, pft_code)
            target_vals, lai_min, lai_max = self._normalize_target(
                df[target_col].values.astype(np.float32),
                f"{pft_code}_{site_name}",
                lai_norms,
            )
            year_arr = df["year"].values
            if self.residual_mode:
                # Learn what the first model missed instead of LAI itself.
                target_vals = self._residual_targets(site_name, target_vals)
            self.site_data.append(
                {
                    "features": features,  # (n_days, n_features)
                    "target": target_vals,  # (n_days,)
                    "years": year_arr,
                    "pft": pft_code,
                    "site": site_name,
                    "lai_min": lai_min,
                    "lai_max": lai_max,
                }
            )
            self.samples.extend(
                self._valid_positions(file_idx, year_arr, target_vals, years)
            )

        # random_stride: group valid samples by (site_idx, year); resample()
        # then draws up to random_stride samples from each group.
        self._samples_by_site_year: Optional[
            Dict[Tuple[int, int], List[Tuple[int, int]]]
        ] = None
        if self.random_stride > 0:
            groups: Dict[Tuple[int, int], List[Tuple[int, int]]] = {}
            for site_idx, d in self.samples:
                yr = int(self.site_data[site_idx]["years"][d])
                groups.setdefault((site_idx, yr), []).append((site_idx, d))
            self._samples_by_site_year = groups
            self._all_valid_count = len(self.samples)
            self.resample()  # initial random selection

        if self.residual_mode and len(self.samples) == 0:
            dataset_sites = [extract_pft_and_site(f)[1] for f in site_files]
            residual_sites = set(k[0] for k in self._residual_lookup.keys())
            overlap = set(dataset_sites) & residual_sites
            raise RuntimeError(
                f"Residual mode: 0 valid samples! "
                f"Dataset sites ({len(dataset_sites)}): {sorted(dataset_sites)[:5]}... "
                f"Residual CSV sites ({len(residual_sites)}): {sorted(residual_sites)[:5]}... "
                f"Overlap: {len(overlap)} sites. "
                f"Did you run predict.py with --predict_sites all?"
            )

    # ── construction helpers ────────────────────────────────────────────────

    @staticmethod
    def _load_residuals(residual_csv: str) -> Dict[Tuple[str, int], float]:
        """Read a predictions CSV into ``{(site, day_index): obs - pred}``."""
        res_df = pd.read_csv(residual_csv)
        # Prefer normalized columns when both are present.
        if "lai_obs_norm" in res_df.columns and "lai_pred_norm" in res_df.columns:
            obs_col, pred_col = "lai_obs_norm", "lai_pred_norm"
        else:
            obs_col, pred_col = "lai_obs", "lai_pred"
        sites = [str(s) for s in res_df["site"]]
        lookup = {
            (site, int(day)): float(obs - pred)
            for site, day, obs, pred in zip(
                sites, res_df["day_index"], res_df[obs_col], res_df[pred_col]
            )
        }
        residual_sites = set(sites)
        print(
            f"[residual] Loaded {len(lookup)} entries "
            f"from {len(residual_sites)} sites: {sorted(residual_sites)}"
        )
        return lookup

    def _build_features(self, df: pd.DataFrame, pft_code: str) -> np.ndarray:
        """
        Normalize one site's features into an ``(n_days, n_features)`` float32 array.

        Channel order: dynamic (if used), cyclic, static (if used), PFT
        one-hot (if used). Dynamic features in LOG_TRANSFORM_FEATURES are
        clipped at 0 and log1p-transformed before z-scoring, as in
        compute_norm_stats(). Each static feature takes the site's first
        non-NaN value and repeats it on every day. A missing value falls back
        to the training mean, i.e. 0 after normalization.
        """
        n_days = len(df)
        stats = self.norm_stats

        def zscore(values, feat):
            return (values - stats[feat]["mean"]) / stats[feat]["std"]

        parts = []
        if self.use_dynamic:
            dynamic = np.zeros((n_days, len(DYNAMIC_FEATURES)), dtype=np.float32)
            for i, feat in enumerate(DYNAMIC_FEATURES):
                vals = df[feat].values
                if feat in LOG_TRANSFORM_FEATURES:
                    vals = np.log1p(np.clip(vals, 0, None))
                dynamic[:, i] = zscore(vals, feat)
            parts.append(dynamic)

        cyclic = np.zeros((n_days, len(CYCLIC_FEATURES)), dtype=np.float32)
        for i, feat in enumerate(CYCLIC_FEATURES):
            cyclic[:, i] = zscore(df[feat].values, feat)
        parts.append(cyclic)

        if self.use_static:
            static = np.zeros((n_days, len(STATIC_FEATURES)), dtype=np.float32)
            for i, feat in enumerate(STATIC_FEATURES):
                present = df[feat].dropna()
                val = present.iloc[0] if len(present) > 0 else stats[feat]["mean"]
                static[:, i] = zscore(val, feat)
            parts.append(static)

        if self.use_pft:
            pft_onehot = np.zeros((n_days, self.n_pfts), dtype=np.float32)
            if pft_code in self.pft_to_idx:
                pft_onehot[:, self.pft_to_idx[pft_code]] = 1.0
            parts.append(pft_onehot)

        return np.concatenate(parts, axis=1)

    def _normalize_target(
        self,
        target_vals: np.ndarray,
        site_key: str,
        lai_norms: Optional[Dict[str, Dict[str, float]]],
    ) -> Tuple[np.ndarray, Optional[float], Optional[float]]:
        """
        Normalize one site's target; return ``(target, lai_min, lai_max)``.

        Sites found in *lai_norms* are min-max normalized. All other sites are
        z-scored with ``norm_stats["LAI"]`` when ``normalize_target`` is set.
        ``lai_min``/``lai_max`` are None unless the site is in *lai_norms*.
        """
        lai_min = None
        lai_max = None
        if lai_norms is not None and site_key in lai_norms:
            lai_min = lai_norms[site_key]["lai_min"]
            lai_max = lai_norms[site_key]["lai_max"]
            denom = lai_max - lai_min
            # NOTE(review): with a degenerate range the target is left in raw
            # units while lai_min/lai_max are still reported -- confirm intent.
            if denom > 1e-8:
                target_vals = (target_vals - lai_min) / denom
        elif self.normalize_target:
            lai_stats = self.norm_stats["LAI"]
            target_vals = (target_vals - lai_stats["mean"]) / lai_stats["std"]
        return target_vals, lai_min, lai_max

    def _residual_targets(self, site_name: str, target_vals: np.ndarray) -> np.ndarray:
        """Return per-day residuals for *site_name* (NaN where no residual exists)."""
        residual_vals = np.full_like(target_vals, np.nan)
        n_matched = 0
        for d in range(len(target_vals)):
            value = self._residual_lookup.get((site_name, d))
            if value is not None:
                residual_vals[d] = value
                n_matched += 1
        if n_matched == 0:
            print(
                f"[residual] WARNING: site '{site_name}' — 0 matches "
                f"(day_index range 0..{len(target_vals)-1}). "
                f"Is this site in the predictions CSV?"
            )
        else:
            print(
                f"[residual] site '{site_name}': "
                f"{n_matched}/{len(target_vals)} days matched"
            )
        return residual_vals

    def _valid_positions(
        self,
        file_idx: int,
        year_arr: np.ndarray,
        target_check: np.ndarray,
        years: Optional[List[int]],
    ) -> List[Tuple[int, int]]:
        """List ``(file_idx, d)`` for every day *d* usable as a window end."""
        positions: List[Tuple[int, int]] = []
        if self.full_year:
            # One sample per year, at that year's last available day.
            for yr in sorted(pd.unique(year_arr)):
                yr_indices = np.where(year_arr == yr)[0]
                if len(yr_indices) == 0:
                    continue
                d = yr_indices[-1]
                if d < self.seq_length - 1:  # not enough input history
                    continue
                target_start = d - self.pred_length + 1
                if target_start < 0:
                    continue
                target_slice = target_check[target_start : d + 1]
                if len(target_slice) < self.pred_length or np.any(
                    np.isnan(target_slice)
                ):
                    continue
                if years is not None and yr not in years:
                    continue
                positions.append((file_idx, d))
            return positions

        # Standard sliding window. With random_stride every valid day is a
        # candidate; resample() subsamples later.
        step = self.stride if self.random_stride == 0 else 1
        for d in range(self.seq_length - 1, len(target_check), step):
            if np.isnan(target_check[d]):
                continue
            if self.n_target_days > 1:
                t_start = d - self.n_target_days + 1
                if t_start < self.seq_length - 1 or np.any(
                    np.isnan(target_check[t_start : d + 1])
                ):
                    continue
            if years is not None and year_arr[d] not in years:
                continue
            positions.append((file_idx, d))
        return positions

    # ── Dataset API ─────────────────────────────────────────────────────────

    @property
    def feature_channels(self) -> int:
        """Number of input feature channels."""
        return self.n_features

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a single sample.

        Parameters
        ----------
        index : int
            Sample index.

        Returns
        -------
        features : torch.Tensor
            Feature tensor of shape ``(feature_channels, seq_length)``.
        target : torch.Tensor
            Shape ``(1, pred_length)`` when ``full_year``, ``(1, n_target_days)``
            when ``n_target_days > 1``, otherwise ``(1,)``.
        """
        site_idx, day_idx = self.samples[index]
        site = self.site_data[site_idx]

        # Window covers days [day_idx - seq_length + 1, day_idx].
        start = day_idx - self.seq_length + 1
        window = site["features"][start : day_idx + 1]  # (seq_length, n_features)
        # Transpose to (n_features, seq_length) to match the RTnn convention.
        features = torch.from_numpy(window.T.copy())

        if self.full_year:
            target_start = day_idx - self.pred_length + 1
            target_seq = site["target"][target_start : day_idx + 1]
            target = torch.from_numpy(target_seq.copy()).unsqueeze(0)
        elif self.n_target_days > 1:
            # Last n days, e.g. [LAI(t-1), LAI(t)] for a gradient loss.
            t_start = day_idx - self.n_target_days + 1
            target_days = site["target"][t_start : day_idx + 1]
            target = torch.from_numpy(target_days.copy()).unsqueeze(0)
        else:
            target = torch.tensor([site["target"][day_idx]], dtype=torch.float32)
        return features, target

    def resample(self, seed: Optional[int] = None) -> None:
        """
        Randomly select up to ``random_stride`` samples per (site, year).

        Call this at the start of each training epoch for fresh samples.
        No-op if ``random_stride`` was not set.

        Parameters
        ----------
        seed : int, optional
            Seed for a reproducible draw. None (default) draws fresh entropy
            on every call, as before.
        """
        if self._samples_by_site_year is None:
            return
        rng = np.random.RandomState(seed)
        new_samples = []
        for candidates in self._samples_by_site_year.values():
            n = min(self.random_stride, len(candidates))
            chosen = rng.choice(len(candidates), size=n, replace=False)
            new_samples.extend(candidates[i] for i in chosen)
        self.samples = new_samples

    def get_site_info(self, index: int) -> Dict:
        """
        Get metadata for a sample (useful for evaluation / plotting).

        Returns
        -------
        dict
            Keys ``site``, ``pft``, ``year`` (year of the target day),
            ``day_index`` (row position within the site), ``lai_min`` and
            ``lai_max``. The last two are None unless the site was min-max
            normalized.
        """
        site_idx, day_idx = self.samples[index]
        site = self.site_data[site_idx]
        return {
            "site": site["site"],
            "pft": site["pft"],
            "year": int(site["years"][day_idx]),
            "day_index": day_idx,
            "lai_min": site["lai_min"],
            "lai_max": site["lai_max"],
        }
# ── Convenience: build train/val split by site ──────────────────────────────
def split_sites_by_fraction(
    site_files: List[str],
    val_fraction: float = 0.2,
    seed: int = 42,
) -> Tuple[List[str], List[str]]:
    """
    Partition site files into train and validation sets (leave-site-out).

    The number of validation sites is ``int(len(site_files) * val_fraction)``,
    but never less than one. A request for 0 % therefore still holds out one
    site, and with a single site the training list is empty.

    Parameters
    ----------
    site_files : list of str
        All site CSV paths.
    val_fraction : float
        Fraction of sites to hold out for validation.
    seed : int
        Seed of the permutation, for reproducible splits.

    Returns
    -------
    tuple of (list, list)
        ``(train_files, val_files)``.
    """
    n_sites = len(site_files)
    n_val = max(1, int(n_sites * val_fraction))
    order = np.random.RandomState(seed).permutation(n_sites)
    val_files = [site_files[i] for i in order[:n_val]]
    train_files = [site_files[i] for i in order[n_val:]]
    return train_files, val_files
def split_by_year(
    site_files: List[str],
    norm_stats: Dict,
    train_years: List[int],
    val_years: List[int],
    pft_list: List[str],
    **kwargs,
) -> Tuple["PhenoCamDataset", "PhenoCamDataset"]:
    """
    Build train/validation datasets that share every site but differ in
    which target years they keep.

    Parameters
    ----------
    site_files : list of str
        All site CSV paths (used by both datasets).
    norm_stats : dict
        Normalization statistics.
    train_years : list of int
        Target years kept in the training dataset.
    val_years : list of int
        Target years kept in the validation dataset.
    pft_list : list of str
        Ordered PFT codes, shared so both datasets use the same one-hot layout.
    **kwargs
        Forwarded unchanged to both PhenoCamDataset constructors.

    Returns
    -------
    tuple of (PhenoCamDataset, PhenoCamDataset)
        ``(train_ds, val_ds)``; the training dataset is built first.
    """

    def _for_years(target_years: List[int]) -> "PhenoCamDataset":
        # Identical construction for both splits; only the year filter differs.
        return PhenoCamDataset(
            site_files, norm_stats, pft_list=pft_list, years=target_years, **kwargs
        )

    train_ds = _for_years(train_years)
    val_ds = _for_years(val_years)
    return train_ds, val_ds
def split_by_sites_years(
    site_files: List[str],
    norm_stats: Dict,
    train_years: List[int],
    val_years: List[int],
    pft_list: List[str],
    val_fraction: float = 0.1,
    seed: int = 42,
    **kwargs,
) -> Tuple["PhenoCamDataset", Dataset]:
    """
    Create train and validation datasets split by both sites and years.

    A fraction ``val_fraction`` of the sites is held out at random (via
    ``split_sites_by_fraction``, so at least one site is always held out).

    * Training set: the remaining sites, keeping only windows whose target
      day falls in ``train_years``.
    * Validation set: every year of the held-out sites, plus the
      ``val_years`` windows of the remaining sites.

    Because the validation set combines two datasets, it is returned as a
    ``torch.utils.data.ConcatDataset`` of PhenoCamDataset parts. Per-part
    helpers such as ``get_site_info``/``resample`` are reached through its
    ``.datasets`` attribute. This function previously was a stub returning
    ``(None, None)``.

    Parameters
    ----------
    site_files : list of str
        All site CSV paths.
    norm_stats : dict
        Normalization statistics.
    train_years : list of int
        Years to include in training targets.
    val_years : list of int
        Years to include in validation targets for the non-held-out sites.
    pft_list : list of str
        Ordered PFT codes, shared by all parts.
    val_fraction : float
        Fraction of sites to hold out for validation.
    seed : int
        Random seed for reproducibility of the site split.
    **kwargs
        Forwarded to every PhenoCamDataset constructor.

    Returns
    -------
    tuple of (PhenoCamDataset, ConcatDataset)
        ``(train_ds, val_ds)``.
    """
    from torch.utils.data import ConcatDataset

    train_sites, heldout_sites = split_sites_by_fraction(
        site_files, val_fraction=val_fraction, seed=seed
    )
    train_ds = PhenoCamDataset(
        train_sites, norm_stats, pft_list=pft_list, years=train_years, **kwargs
    )
    # Held-out sites contribute every year (years=None disables the filter).
    val_parts = [
        PhenoCamDataset(
            heldout_sites, norm_stats, pft_list=pft_list, years=None, **kwargs
        )
    ]
    if train_sites:
        # Remaining sites contribute only their validation years.
        val_parts.append(
            PhenoCamDataset(
                train_sites, norm_stats, pft_list=pft_list, years=val_years, **kwargs
            )
        )
    return train_ds, ConcatDataset(val_parts)