Source code for phenonn.prediction.predict_flat

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Stefan Barbu, Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

"""
LAI Prediction — Flat CSV Inference

Loads a trained checkpoint from main_flat.py and predicts LAI at the 36
observation days (5th, 15th, 25th of each month) for every (site_id, year)
pair in the chosen site set.

Output
------
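predictions_flat.csv : str
    One row per (site, year, observation day). The path is set by
    ``--output_csv``; the plot files below are named ``<base>_<suffix>.png``,
    where ``<base>`` is the ``--output_csv`` path without its extension
    (shown here for the default ``predictions_flat.csv``).
predictions_flat_pred_vs_obs.png : str
    Scatter plot on real LAI values
predictions_flat_pred_vs_obs_norm.png : str
    Scatter plot on normalized values
predictions_flat_lai_curves.png : str
    Annual curves for 3 representative sites (skipped when fewer than
    3 sites are predicted)
predictions_flat_lai_curves_all.png : str
    Annual curves for all sites (grid)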

Usage
-----
Basic prediction:
    python -m phenonn.prediction.predict_flat \\
        --checkpoint runs/exp_flat/checkpoints/best_model.pth \\
        --features_csv data/features.csv \\
        --target_csv data/targets.csv

Predict on all sites for specific years:
    python -m phenonn.prediction.predict_flat \\
        --checkpoint runs/exp_flat/checkpoints/best_model.pth \\
        --predict_sites all \\
        --predict_years 2002,2003

Examples
--------
Run prediction on validation sites:
    >>> python -m phenonn.prediction.predict_flat \\
    ...     --checkpoint runs/exp_flat/checkpoints/best_model.pth \\
    ...     --features_csv data/features.csv \\
    ...     --target_csv data/targets.csv

Run prediction on all sites for years 2020-2022:
    >>> python -m phenonn.prediction.predict_flat \\
    ...     --checkpoint runs/exp_flat/checkpoints/best_model.pth \\
    ...     --features_csv data/features.csv \\
    ...     --target_csv data/targets.csv \\
    ...     --predict_sites all \\
    ...     --predict_years 2020,2021,2022

Notes
-----
The input features CSV must contain daily data with columns:
    site_id, date, year, month, day, pft1_frac..pft15_frac,
    tmin, tmax, daylength, prcp, srad, vpd, swe

The target CSV must contain LAI observations for days 5, 15, 25 of each month:
    site_id, date, year, month, day, LAI
"""

import os
import argparse
import datetime
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from phenonn.data.dataset_flat import LAIDataset, get_site_ids
from phenonn.training.train_flat import build_model
from phenonn.utils.diagnostics import (
    plot_pred_vs_obs,
    plot_gcc_curves,
    plot_gcc_curves_all,
)
from phenonn.utils.utils import EasyDict

# The 36 yearly observation slots as (month, day, day-of-year) tuples, ordered
# January → December with days 5, 15 and 25 inside each month. Day-of-year is
# taken from the non-leap reference year 2001, counted from 31 Dec 2000.
_OBS_DATES = [
    (m, d, (datetime.date(2001, m, d) - datetime.date(2000, 12, 31)).days)
    for m in range(1, 13)
    for d in (5, 15, 25)
]


# ── CLI ───────────────────────────────────────────────────────────────────────


def parse_args(argv=None):
    """Parse the command-line options for flat-CSV LAI inference.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) argparse reads
        ``sys.argv[1:]``, which is the behaviour callers passing no argument
        already rely on. Supplying a list allows programmatic use and testing
        without touching ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Namespace with attributes ``checkpoint``, ``features_csv``,
        ``target_csv``, ``predict_sites``, ``predict_years``, ``output_csv``
        and ``batch_size``.
    """
    p = argparse.ArgumentParser(description="LAI flat-CSV inference")
    p.add_argument(
        "--checkpoint", required=True, help="Path to best_model.pth from main_flat.py"
    )
    p.add_argument(
        "--features_csv",
        default="",
        help="Features CSV (defaults to path stored in checkpoint)",
    )
    p.add_argument(
        "--target_csv",
        default="",
        help="Targets CSV (defaults to path stored in checkpoint)",
    )
    p.add_argument(
        "--predict_sites",
        default="val",
        choices=["val", "train", "all"],
        help="Which sites to predict on: 'val' (held-out), "
        "'train', or 'all'. Default: 'val'.",
    )
    p.add_argument(
        "--predict_years",
        default="all",
        help="Comma-separated years (e.g. '2001,2002') or 'all'",
    )
    p.add_argument(
        "--output_csv", default="predictions_flat.csv", help="Path for the output CSV"
    )
    p.add_argument(
        "--batch_size", type=int, default=64, help="Inference batch size"
    )
    return p.parse_args(argv)
# ── Main ──────────────────────────────────────────────────────────────────────
def _parse_year_list(spec):
    """Turn the ``--predict_years`` string into a list of ints, or None for 'all'."""
    if spec.strip().lower() == "all":
        return None
    # Drop blank tokens so "2001, 2002," is accepted instead of failing on int("").
    tokens = [tok.strip() for tok in spec.split(",")]
    tokens = [tok for tok in tokens if tok]
    try:
        years = [int(tok) for tok in tokens]
    except ValueError as err:
        raise ValueError(
            f"Invalid --predict_years value {spec!r}: expected comma-separated "
            "integer years (e.g. '2001,2002') or 'all'."
        ) from err
    if not years:
        # An empty list would be passed to the dataset as a year filter while
        # being reported to the user as "all available".
        raise ValueError(
            f"Invalid --predict_years value {spec!r}: no years given."
        )
    return years


def _select_site_ids(choice, features_csv, train_ids, val_ids):
    """Return ``(site_ids, message)`` for the requested site set.

    Falls back to every site found by ``get_site_ids(features_csv)`` when the
    checkpoint holds no IDs for the requested split.
    """
    if choice == "all":
        site_ids = get_site_ids(features_csv)
        return site_ids, f"Predicting on ALL {len(site_ids)} sites"
    if choice == "val":
        stored, label, fallback = val_ids, "VALIDATION", "ALL (no val split found)"
    else:  # train
        stored, label, fallback = train_ids, "TRAINING", "ALL (no train split found)"
    if stored:
        site_ids = stored
    else:
        site_ids, label = get_site_ids(features_csv), fallback
    return site_ids, f"Predicting on {len(site_ids)} {label} sites"


def _r_squared(obs, pred):
    """Return R² of ``pred`` against ``obs``, or None when the total sum of squares is not > 0."""
    ss_tot = float(np.sum((obs - obs.mean()) ** 2))
    if not ss_tot > 0:
        return None
    ss_res = float(np.sum((obs - pred) ** 2))
    return 1.0 - ss_res / ss_tot


def run_prediction_flat():
    """Run flat-CSV LAI inference from a trained checkpoint.

    Reads the command line via ``parse_args()``, loads the checkpoint, rebuilds
    the model with ``build_model``, builds a ``LAIDataset`` for the selected
    sites and years, and predicts one value per entry of ``_OBS_DATES`` for
    every sample. Predictions and targets are de-normalised with the ``LAI``
    mean/std stored in the checkpoint's ``norm_stats``.

    Side effects: prints checkpoint info and metrics to stdout, writes the
    predictions CSV to ``--output_csv`` (creating its directory if needed) and
    writes the diagnostic plots next to it as ``<base>_pred_vs_obs.png``,
    ``<base>_pred_vs_obs_norm.png``, ``<base>_lai_curves.png`` (only with at
    least 3 sites) and ``<base>_lai_curves_all.png``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the CSV paths cannot be determined from the CLI or the checkpoint,
        or if ``--predict_years`` is not 'all' or a list of integers.
    RuntimeError
        If the dataset built for the requested sites/years is empty.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ── Load checkpoint ──
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from a trusted
    # source. It is kept here because the checkpoint stores more than tensors
    # ("args", "norm_stats", site-id lists).
    ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
    train_args = EasyDict(ckpt["args"])
    norm_stats = ckpt["norm_stats"]
    train_ids = ckpt.get("train_site_ids", [])
    val_ids = ckpt.get("val_site_ids", [])

    print(f"Checkpoint : {args.checkpoint}")
    print(f"Epoch : {ckpt['epoch']}")
    print(f"val_loss : {ckpt.get('val_loss', float('nan')):.6f}")
    print(f"val_R² : {ckpt.get('val_r2', float('nan')):.4f}")
    print(f"val_RMSE : {ckpt.get('val_rmse', float('nan')):.5f}")

    # CSV paths: CLI overrides take precedence over checkpoint paths
    features_csv = args.features_csv or train_args.get("features_csv", "")
    target_csv = args.target_csv or train_args.get("target_csv", "")
    if not features_csv or not target_csv:
        raise ValueError(
            "Could not determine CSV paths. Provide --features_csv and --target_csv."
        )

    # ── Rebuild model ──
    model = build_model(train_args).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(
        f"Model : {train_args.get('type')} + Every10DaysWrapper "
        f"({n_params:,} parameters)"
    )

    # ── Resolve site IDs ──
    site_ids, site_msg = _select_site_ids(
        args.predict_sites, features_csv, train_ids, val_ids
    )
    print(site_msg)

    # ── Resolve years ──
    predict_years = _parse_year_list(args.predict_years)
    year_label = predict_years if predict_years else "all available"

    # ── Build dataset ──
    dataset = LAIDataset(
        features_csv,
        target_csv,
        norm_stats,
        site_ids=site_ids,
        seq_length=train_args.get("seq_length", 720),
        years=predict_years,
        normalize_target=True,
    )
    print(f"Samples : {len(dataset)} (years: {year_label})")
    if len(dataset) == 0:
        raise RuntimeError(
            "No prediction samples. Check that the requested sites and years "
            "are present in the CSVs and that seq_length <= available history."
        )

    # shuffle=False keeps loader order equal to dataset order, which the
    # batch_start + j index arithmetic below relies on.
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
    )

    # ── Inference ──
    lai_mean = norm_stats["LAI"]["mean"]
    lai_std = norm_stats["LAI"]["std"]
    rows = []
    print("\nRunning inference...")
    with torch.no_grad():
        for i_batch, (features, targets) in enumerate(loader):
            preds = model(features.to(device)).cpu()  # (B, 1, 36)
            # De-normalise once per batch rather than once per sample.
            pred_norm = preds[:, 0, :].numpy()  # (B, 36)
            tgt_norm = targets[:, 0, :].numpy()  # (B, 36)
            pred_real = pred_norm * lai_std + lai_mean
            tgt_real = tgt_norm * lai_std + lai_mean
            batch_start = i_batch * args.batch_size
            for j in range(preds.size(0)):
                meta = dataset.get_site_info(batch_start + j)
                for k, (month, day, doy) in enumerate(_OBS_DATES):
                    rows.append(
                        {
                            "site_id": meta["site_id"],
                            "year": meta["year"],
                            "month": month,
                            "day": day,
                            "doy": doy,
                            "lai_pred": float(pred_real[j, k]),
                            "lai_obs": float(tgt_real[j, k]),
                            "lai_pred_norm": float(pred_norm[j, k]),
                            "lai_obs_norm": float(tgt_norm[j, k]),
                        }
                    )

    df = pd.DataFrame(rows)
    df["error"] = df["lai_pred"] - df["lai_obs"]

    # ── Per-site, per-year metrics ──
    print("\n── Per-site, per-year metrics ──")
    for (site, year), g in df.groupby(["site_id", "year"]):
        rmse = float(np.sqrt(np.mean(g["error"] ** 2)))
        r2 = _r_squared(g["lai_obs"], g["lai_pred"])
        if r2 is None:
            r2 = float("nan")
        # str(site): the 's' format code raises ValueError for non-string IDs
        # (e.g. integer site IDs read from the CSV).
        print(f" {str(site):30s} {year} RMSE={rmse:.4f} R²={r2:.4f} n={len(g)}")

    # Per-site R² (pooled across all years); sites with undefined R² are left out
    site_r2s = []
    for _, g in df.groupby("site_id"):
        r2 = _r_squared(g["lai_obs"], g["lai_pred"])
        if r2 is not None:
            site_r2s.append(r2)
    if site_r2s:
        print(f"\nMedian per-site R² : {np.median(site_r2s):.4f}")
        print(f"Mean per-site R² : {np.mean(site_r2s):.4f}")

    overall_rmse = float(np.sqrt(np.mean(df["error"] ** 2)))
    print(f"Overall RMSE : {overall_rmse:.4f}")
    print(
        f"Total predictions : {len(df):,} "
        f"({df['site_id'].nunique()} sites × {df['year'].nunique()} years × 36 obs)"
    )

    # ── Save CSV ──
    os.makedirs(os.path.dirname(args.output_csv) or ".", exist_ok=True)
    df.to_csv(args.output_csv, index=False)
    print(f"\nPredictions saved to {args.output_csv}")

    # Plot files share the CSV path minus its extension as a prefix.
    base = os.path.splitext(args.output_csv)[0]
    n_pts = len(df)
    n_sites = df["site_id"].nunique()

    # ── Pred vs obs scatter (real values) ──
    plot_pred_vs_obs(
        df["lai_pred"].values,
        df["lai_obs"].values,
        filename=base + "_pred_vs_obs.png",
        title=f"Predicted vs observed LAI — {n_pts:,} points",
    )

    # ── Pred vs obs scatter (normalized values) ──
    plot_pred_vs_obs(
        df["lai_pred_norm"].values,
        df["lai_obs_norm"].values,
        filename=base + "_pred_vs_obs_norm.png",
        title=f"Predicted vs observed LAI (normalized) — {n_pts:,} points",
        xlabel="Observed LAI (normalized)",
        ylabel="Predicted LAI (normalized)",
    )

    # ── Annual curves — 3 representative sites (low/medium/high R²) ──
    if n_sites >= 3:
        selected = plot_gcc_curves(
            df,
            filename=base + "_lai_curves.png",
            site_col="site_id",
            year_col="year",
            doy_col="doy",
        )
        print(f"Curve plot sites: {selected}")
    else:
        print("Fewer than 3 sites — skipping 3-site curve plot")

    # ── Annual curves — all sites grid ──
    if n_sites >= 1:
        plot_gcc_curves_all(
            df,
            filename=base + "_lai_curves_all.png",
            site_col="site_id",
            year_col="year",
            doy_col="doy",
        )

    print("\nDone.")
# Script entry point: run inference only when executed directly
# (e.g. ``python -m phenonn.prediction.predict_flat``), not on import.
if __name__ == "__main__":
    run_prediction_flat()