# Source code for phenonn.prediction.predict

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Stefan Barbu, Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

"""
LAI Prediction — Year-by-Year Inference

Loads a trained model checkpoint and predicts the full annual LAI curve.
Supports both site-split and year-split models.

Usage
-----
# Predict on validation sites, all available years:
python -m phenonn.prediction.predict \
    --checkpoint ./runs/exp01/checkpoints/best_model.pth \
    --data_dir ./data/DB/

# Predict on all sites, specific years:
python -m phenonn.prediction.predict \
    --checkpoint ./runs/exp01/checkpoints/best_model.pth \
    --data_dir ./data/DB/ \
    --predict_sites all \
    --predict_years 2022,2023

# Predict on training sites only:
python -m phenonn.prediction.predict \
    --checkpoint ./runs/exp01/checkpoints/best_model.pth \
    --data_dir ./data/DB/ \
    --predict_sites train \
    --predict_years 2022
"""

import os
import argparse
import glob
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from phenonn.data.dataset import (
    PhenoCamDataset,
)
from phenonn.utils.model_loader import load_model
from phenonn.utils.diagnostics import (
    plot_pred_vs_obs,
    plot_gcc_curves,
    plot_gcc_curves_all,
)
from phenonn.utils.utils import EasyDict


def parse_args(argv=None):
    """Parse the command-line options for year-by-year inference.

    Parameters
    ----------
    argv : sequence of str, optional
        Arguments to parse, without the program name. ``None`` (the default)
        makes argparse read ``sys.argv[1:]``, which is the previous behavior.
        Passing an explicit list lets tests and notebooks drive the CLI.

    Returns
    -------
    argparse.Namespace
        With attributes ``checkpoint``, ``data_dir``, ``predict_years``,
        ``predict_sites``, ``output_csv`` and ``batch_size``.
    """
    p = argparse.ArgumentParser(description="PhenoCam year-by-year inference")
    p.add_argument(
        "--checkpoint", type=str, required=True, help="Path to best_model.pth"
    )
    p.add_argument(
        "--data_dir", type=str, required=True, help="Directory with site CSVs"
    )
    p.add_argument(
        "--predict_years",
        type=str,
        default="all",
        help="Comma-separated years (e.g. '2022,2023') or 'all'",
    )
    p.add_argument(
        "--predict_sites",
        type=str,
        default="val",
        choices=["val", "train", "all"],
        help="Which sites to predict on: 'val' (validation sites from "
        "site-split), 'train', or 'all'. Default 'val'.",
    )
    p.add_argument(
        "--output_csv",
        type=str,
        default="predictions.csv",
        help="Where to save predictions",
    )
    p.add_argument("--batch_size", type=int, default=128)
    return p.parse_args(argv)
def _format_metric(value, spec=".4f"):
    """Format *value* with *spec*, falling back to ``str(value)``.

    Checkpoints may lack a metric. Formatting a placeholder such as ``"?"``
    (or ``None``) with a float spec raises ``ValueError``/``TypeError``,
    which would otherwise abort inference just to print a log line.
    """
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


def _remap_saved_files(saved_files, all_site_files, data_dir):
    """Re-root checkpoint-saved site paths onto *data_dir*.

    Paths stored at training time may point to a different directory, so
    files are matched by basename. Saved sites that are absent from
    *data_dir* are dropped, and a warning is printed for them.
    """
    available = {os.path.basename(f) for f in all_site_files}
    names = [os.path.basename(f) for f in saved_files]
    missing = [n for n in names if n not in available]
    if missing:
        preview = ", ".join(missing[:5]) + (" ..." if len(missing) > 5 else "")
        print(
            f"WARNING: {len(missing)} site file(s) listed in the checkpoint "
            f"are not in {data_dir}: {preview}"
        )
    return [os.path.join(data_dir, n) for n in names if n in available]


def _resolve_site_files(
    predict_sites, data_dir, split_mode, saved_train_files, saved_val_files
):
    """Return the site CSV paths to predict on.

    Parameters
    ----------
    predict_sites : {"val", "train", "all"}
        Which subset to use. "val"/"train" use the file lists saved in the
        checkpoint when present. Otherwise they fall back to all sites, with
        a notice for year-split models and a warning otherwise.
    data_dir : str
        Directory searched for ``*.csv`` site files.
    split_mode : str
        Training split mode read from the checkpoint args.
    saved_train_files, saved_val_files : list of str or None
        File lists stored in the checkpoint, if any.

    Raises
    ------
    FileNotFoundError
        If *data_dir* contains no CSV files.
    RuntimeError
        If the selection ends up empty.
    ValueError
        If *predict_sites* is not one of the known choices.
    """
    all_site_files = sorted(glob.glob(os.path.join(data_dir, "*.csv")))
    if not all_site_files:
        raise FileNotFoundError(f"No CSV files in {data_dir}")

    if predict_sites == "all":
        print(f"Predicting on ALL {len(all_site_files)} sites")
        return all_site_files

    if predict_sites == "val":
        saved, label, key = saved_val_files, "VALIDATION", "val_files"
    elif predict_sites == "train":
        saved, label, key = saved_train_files, "TRAINING", "train_files"
    else:
        raise ValueError(f"Unknown predict_sites: {predict_sites!r}")

    if saved is not None:
        site_files = _remap_saved_files(saved, all_site_files, data_dir)
        print(f"Predicting on {len(site_files)} {label} sites (from checkpoint)")
    elif split_mode == "year":
        site_files = all_site_files
        print(
            f"Year-split model: all {len(site_files)} sites used (no site holdout)"
        )
    else:
        print(f"WARNING: checkpoint has no saved {key}. Using all sites.")
        site_files = all_site_files

    if not site_files:
        raise RuntimeError(
            "No site files matched. Check --data_dir and --predict_sites."
        )
    return site_files


def _parse_years(spec):
    """Parse ``--predict_years``.

    Returns ``None`` for ``"all"`` (case-insensitive), meaning every
    available year. Otherwise returns a list of ints. Empty tokens such as
    the one in ``"2022,"`` are ignored.

    Raises
    ------
    ValueError
        If a token is not an integer or no year remains.
    """
    if spec.strip().lower() == "all":
        return None
    years = [int(tok) for tok in spec.split(",") if tok.strip()]
    if not years:
        raise ValueError(
            f"--predict_years must be 'all' or comma-separated years, got {spec!r}"
        )
    return years


def _denormalize(value, lai_min, lai_max):
    """Map *value* back through ``value * (lai_max - lai_min) + lai_min``.

    If either bound is missing (no per-site normalization), *value* is
    returned unchanged.
    """
    if lai_min is None or lai_max is None:
        return value
    return value * (lai_max - lai_min) + lai_min


def _collect_predictions(model, loader, dataset, device, batch_size, full_year):
    """Run *model* over *loader* and return one DataFrame row per predicted day.

    *loader* must iterate *dataset* in order (``shuffle=False``). Dataset
    indices are reconstructed from the batch number and *batch_size* to
    fetch per-sample metadata via ``dataset.get_site_info``.

    The columns are the metadata keys, then ``lai_pred``, ``lai_obs``,
    ``error`` (pred - obs), ``lai_pred_norm`` and ``lai_obs_norm``.
    """
    meta_rows = []
    preds_out, targets_out = [], []
    preds_norm_out, targets_norm_out = [], []

    def add_row(row_meta, pred_norm, tgt_norm):
        # Keep both the model-space values and the de-normalized ones.
        lai_min = row_meta.get("lai_min")
        lai_max = row_meta.get("lai_max")
        meta_rows.append(row_meta)
        preds_norm_out.append(pred_norm)
        targets_norm_out.append(tgt_norm)
        preds_out.append(_denormalize(pred_norm, lai_min, lai_max))
        targets_out.append(_denormalize(tgt_norm, lai_min, lai_max))

    with torch.no_grad():
        for i_batch, (features, targets) in enumerate(loader):
            preds = model(features.to(device)).cpu()
            batch_start = i_batch * batch_size
            for j in range(preds.size(0)):
                idx = batch_start + j
                meta = dataset.get_site_info(idx)

                if not full_year:
                    # Standard mode: preds (B, 1), targets (B, 1)
                    add_row(meta, preds[j, 0].item(), targets[j, 0].item())
                    continue

                # Full-year mode: preds/targets are (B, 1, n_days); expand
                # each sample into one row per day of its window.
                site_idx, day_idx = dataset.samples[idx]
                site_years = dataset.site_data[site_idx]["years"]
                pred_seq = preds[j, 0, :].numpy()
                tgt_seq = targets[j, 0, :].numpy()
                n_days = len(pred_seq)
                # The window ends at day_idx, so it starts n_days - 1 earlier.
                first_day = day_idx - n_days + 1
                for k in range(n_days):
                    d = first_day + k
                    row_meta = {
                        "site": meta["site"],
                        "pft": meta["pft"],
                        "year": int(site_years[d]),
                        "day_index": d,
                        "lai_min": meta.get("lai_min"),
                        "lai_max": meta.get("lai_max"),
                    }
                    add_row(row_meta, float(pred_seq[k]), float(tgt_seq[k]))

    df = pd.DataFrame(meta_rows)
    df["lai_pred"] = preds_out
    df["lai_obs"] = targets_out
    df["error"] = df["lai_pred"] - df["lai_obs"]
    df["lai_pred_norm"] = preds_norm_out
    df["lai_obs_norm"] = targets_norm_out
    return df


def _print_metrics(df):
    """Print per-(site, year) RMSE/R², the median per-site R² and overall RMSE.

    R² is reported as 0 for groups with constant observations.
    """
    print("\n── Per-site, per-year metrics ──")
    for (site, year), g in df.groupby(["site", "year"]):
        rmse = np.sqrt(np.mean(g["error"] ** 2))
        ss_res = np.sum(g["error"] ** 2)
        ss_tot = np.sum((g["lai_obs"] - g["lai_obs"].mean()) ** 2)
        r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
        print(f" {site:25s} {year} RMSE={rmse:.5f} R²={r2:.4f} n={len(g)}")

    # Sites with constant observations have no defined R² and are skipped.
    site_r2s = []
    for _, g in df.groupby("site"):
        ss_res = np.sum((g.lai_obs - g.lai_pred) ** 2)
        ss_tot = np.sum((g.lai_obs - g.lai_obs.mean()) ** 2)
        if ss_tot > 0:
            site_r2s.append(1 - ss_res / ss_tot)
    if site_r2s:
        print(f"Median per-site R²: {np.median(site_r2s):.3f}")
    else:
        print("Median per-site R²: n/a (no site with non-constant observations)")

    rmse_all = float(np.sqrt(np.mean(df["error"].to_numpy() ** 2)))
    print(f"\nOverall RMSE: {rmse_all:.5f}")


def _save_plots(df, output_csv, lai_norms):
    """Write diagnostic PNGs next to *output_csv* (same stem, suffixed names)."""
    stem = os.path.splitext(output_csv)[0]
    n_sites = df["site"].nunique()

    # Pred-vs-obs scatter
    plot_pred_vs_obs(
        df["lai_pred"].values,
        df["lai_obs"].values,
        filename=stem + "_pred_vs_obs.png",
        title=f"Predicted vs observed — {len(df):,} points",
    )

    # LAI annual curves for low / medium / high R² sites
    if n_sites >= 3:
        selected = plot_gcc_curves(
            df,
            filename=stem + "_gcc_curves.png",
            site_col="site",
            year_col="year",
            doy_col="day_index",
        )
        print(f"Selected sites for curve plot: {selected}")

    # LAI curves for ALL sites (grid sorted by R²)
    if n_sites >= 1:
        plot_gcc_curves_all(
            df,
            filename=stem + "_gcc_curves_all.png",
            site_col="site",
            year_col="year",
            doy_col="day_index",
        )

    # Normalized-space diagnostics only make sense with per-site norms.
    if lai_norms:
        plot_pred_vs_obs(
            df["lai_pred_norm"].values,
            df["lai_obs_norm"].values,
            filename=stem + "_pred_vs_obs_norm.png",
            title=f"Predicted vs observed (normalized) — {len(df):,} points",
        )
        if n_sites >= 1:
            plot_gcc_curves_all(
                df,
                filename=stem + "_gcc_curves_all_norm.png",
                site_col="site",
                year_col="year",
                doy_col="day_index",
                pred_col="lai_pred_norm",
                obs_col="lai_obs_norm",
            )


def run_prediction():
    """CLI entry point: predict LAI for the selected sites/years and report.

    Steps:

    1. Load the checkpoint given by ``--checkpoint``. It must contain
       ``args``, ``norm_stats``, ``pft_list`` and ``model_state_dict``. The
       keys ``train_files``, ``val_files``, ``lai_norms``, ``epoch`` and
       ``val_r2`` are optional. Rebuild the model with ``n_target_days``
       forced to 1.
    2. Select site CSVs (``--predict_sites``) and years (``--predict_years``).
    3. Predict every sample of a ``PhenoCamDataset``. Values are
       de-normalized per sample when ``lai_min``/``lai_max`` are available.
    4. Print metrics, save the table to ``--output_csv`` and write
       diagnostic plots next to it.

    Raises
    ------
    FileNotFoundError
        If ``--data_dir`` contains no CSV files.
    RuntimeError
        If no site files match or the dataset yields no samples.
    ValueError
        If ``--predict_years`` is malformed.
    """
    args = parse_args()

    # ── Load checkpoint ──
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # (the checkpoint stores args and stats next to tensors). Only load
    # checkpoints from trusted sources.
    ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
    train_args = EasyDict(ckpt["args"])
    norm_stats = ckpt["norm_stats"]
    pft_list = ckpt["pft_list"]
    split_mode = train_args.get("split_mode", "year")

    lai_norms = ckpt.get("lai_norms", None)
    if lai_norms:
        print(f"Per-site LAI normalization: {len(lai_norms)} sites")
    else:
        print("No per-site LAI normalization")

    # ── Rebuild model ──
    # At inference we always want single-day output, so force n_target_days=1
    train_args.n_target_days = 1
    model = load_model(train_args).to(device)
    # Original note: both SingleDayWrapper and LastNDaysWrapper store weights
    # under base_model.*, so the saved state_dict loads into either.
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    epoch = ckpt.get("epoch", "?")
    val_r2 = _format_metric(ckpt.get("val_r2", "?"))
    print(f"Model loaded from {args.checkpoint} (epoch {epoch}, val_R²={val_r2})")
    print(f"Training split_mode: {split_mode}")

    # ── Resolve which site files and years to use ──
    site_files = _resolve_site_files(
        args.predict_sites,
        args.data_dir,
        split_mode,
        ckpt.get("train_files", None),
        ckpt.get("val_files", None),
    )
    predict_years = _parse_years(args.predict_years)  # None = all available

    # ── Build dataset ──
    feature_mode = train_args.get("feature_mode", "all")
    full_year = train_args.get("full_year", False)
    if feature_mode != "all":
        print(f"Feature mode: {feature_mode}")
    if full_year:
        print("Full-year prediction mode (365 days per sample)")

    dataset = PhenoCamDataset(
        site_files,
        norm_stats,
        seq_length=train_args.seq_length,
        pft_list=pft_list,
        years=predict_years,
        stride=1,
        lai_norms=lai_norms,
        feature_mode=feature_mode,
        full_year=full_year,
    )
    year_str = predict_years if predict_years else "all"
    print(
        f"Prediction samples: {len(dataset)} "
        f"({len(site_files)} sites × years {year_str})"
    )
    if len(dataset) == 0:
        raise RuntimeError(
            "No prediction samples. Check year availability and seq_length."
        )

    # shuffle=False is required: _collect_predictions maps batch positions
    # back to dataset indices. Pinned memory only helps host->CUDA copies.
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=device.type == "cuda",
    )

    # ── Predict and assemble results ──
    df = _collect_predictions(
        model, loader, dataset, device, args.batch_size, full_year
    )

    _print_metrics(df)

    # Save CSV
    os.makedirs(os.path.dirname(args.output_csv) or ".", exist_ok=True)
    df.to_csv(args.output_csv, index=False)
    print(f"Predictions saved to {args.output_csv}")

    _save_plots(df, args.output_csv, lai_norms)
# Script entry point: only run inference when executed directly, not on import.
if __name__ == "__main__":
    run_prediction()