# Source code for phenonn.utils.diagnostics

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Stefan Barbu, Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

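# PhenoCam diagnostic plots
#
# Figures for monitoring training and evaluating results:
#   - plot_loss_histories:        train/val loss curve over epochs
#   - plot_metric_histories:      multi-metric panel (R², RMSE, ...) over epochs
#   - plot_pred_vs_obs:           predicted vs observed hexbin scatter with R² / RMSE
#   - plot_gcc_curves:            annual LAI curves for low / medium / high R² sites
#   - plot_gcc_curves_all:        annual LAI curves for every site, sorted by R²
#   - plot_feature_distributions: histograms of the input features and the LAI target
#   - plot_feature_distributions_per_site: one distribution figure per site CSV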

import math
import os
from typing import Dict, List, Optional, Sequence, Union
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams as mpl
import mpltex
import matplotlib.gridspec as gridspec
import matplotlib.ticker as ticker
import pandas as pd
from phenonn.data.dataset import (
    DYNAMIC_FEATURES,
    CYCLIC_FEATURES,
    STATIC_FEATURES,
    LOG_TRANSFORM_FEATURES,
)
from phenonn.data.feature_engineering import add_derived_features

# Module-wide matplotlib defaults, pushed into rcParams once at import time so
# every figure produced below shares the same fonts, line styles, tick and
# legend layout. Output resolution is not configured here; several savefig
# calls in this module pass their own ``dpi`` instead.
params = {
    # Text
    "font.family": "DejaVu Sans",
    # Lines and their dash patterns
    "lines.linewidth": 2,
    "lines.dashed_pattern": [4, 2],
    "lines.dashdot_pattern": [6, 3, 2, 3],
    "lines.dotted_pattern": [2, 3],
    # Math text
    "mathtext.rm": "arial",
    # Axis labels, titles and tick labels
    "axes.labelsize": 15,
    "axes.titlesize": 15,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "xtick.major.size": 6,
    "ytick.major.size": 6,
    # Legend
    "legend.fontsize": 15,
    "legend.loc": "best",
    "legend.frameon": False,
    # Ticks point outwards from the axes
    "xtick.direction": "out",
    "ytick.direction": "out",
}
mpl.update(params)


# ── Helpers ──────────────────────────────────────────────────────────────────


def _compute_metrics(pred: np.ndarray, obs: np.ndarray) -> Dict[str, float]:
    """Compute RMSE, MAE, bias, and R² on 1-D arrays. Ignores NaNs."""
    pred = np.asarray(pred, dtype=np.float64).ravel()
    obs = np.asarray(obs, dtype=np.float64).ravel()
    mask = np.isfinite(pred) & np.isfinite(obs)
    pred, obs = pred[mask], obs[mask]
    if len(pred) == 0:
        return {"rmse": np.nan, "mae": np.nan, "bias": np.nan, "r2": np.nan, "n": 0}

    err = pred - obs
    rmse = float(np.sqrt(np.mean(err**2)))
    mae = float(np.mean(np.abs(err)))
    bias = float(np.mean(err))
    ss_res = np.sum(err**2)
    ss_tot = np.sum((obs - obs.mean()) ** 2)
    r2 = float(1 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
    return {"rmse": rmse, "mae": mae, "bias": bias, "r2": r2, "n": len(pred)}


def _save(fig, filename: str, logger=None):
    """
    Write *fig* to *filename*, close it, and report where it went.

    The parent directory of *filename* is created first if it does not exist.
    The figure is saved with a tight bounding box at 150 dpi and then closed
    with ``plt.close``. The confirmation message goes to ``logger.info`` when
    a logger is given, otherwise it is printed to stdout.
    """
    target_dir = os.path.dirname(filename) or "."
    os.makedirs(target_dir, exist_ok=True)

    fig.savefig(filename, bbox_inches="tight", dpi=150)
    plt.close(fig)

    report = print if logger is None else logger.info
    report(f"Saved plot to {filename}")


# ── 1. Loss history ──────────────────────────────────────────────────────────
[docs] def plot_loss_histories( train_loss: Sequence[float], valid_loss: Sequence[float], filename: str = "loss_history.png", logger=None, log_scale: bool = True, title: str = "Training / Validation Loss", ) -> None: """ Plot training and validation loss curves over epochs. Parameters ---------- train_loss : sequence of float Per-epoch training loss. valid_loss : sequence of float Per-epoch validation loss (same length as train_loss). filename : str Output path. logger : phenonn.utils.logger.Logger, optional If provided, log the save location instead of printing. log_scale : bool Use log y-axis (default True — losses usually span orders of magnitude). title : str Figure title. Notes ----- - Uses mpltex linestyles for consistent styling - Gray vertical dashed line: best validation epoch - Includes grid for better readability """ epochs = np.arange(1, len(train_loss) + 1) fig = plt.figure(figsize=(8, 5)) ax = fig.add_subplot(111) # Generate linestyles linestyles = mpltex.linestyle_generator(markers=[]) # Plot training loss ax.plot(epochs, train_loss, label="train", **next(linestyles)) # Plot validation loss ax.plot(epochs, valid_loss, label="validation", **next(linestyles)) # Set y-scale (log or linear) if log_scale and min(min(train_loss), min(valid_loss)) > 0: ax.set_yscale("log") # Labels and title ax.set_xlabel("Epoch") ax.set_ylabel("Loss") ax.set_title(title) # Grid for better readability ax.grid(True) # Mark the best validation epoch with a vertical dashed line best_epoch = int(np.argmin(valid_loss)) + 1 ax.axvline(best_epoch, label=f"Best epoch: {best_epoch}", **next(linestyles)) # Legend ax.legend() # Save the figure plt.savefig(filename, bbox_inches="tight") plt.close(fig) # Log or print the save location if logger: logger.info(f"Saved loss history plot to {filename}") else: print(f"Saved loss history plot to {filename}")
[docs] def plot_metric_histories( train_history: Dict[str, Sequence[float]], valid_history: Dict[str, Sequence[float]], filename: str = "metric_history.png", logger=None, log_metrics: Optional[Sequence[str]] = None, cols: int = 3, ) -> None: """ Multi-panel figure of metric evolution over epochs. One panel per metric. Train and validation plotted together on each panel. Parameters ---------- train_history : dict {metric_name: per-epoch values}. Example: {"rmse": [...], "r2": [...]}. valid_history : dict Same keys as train_history, same lengths. filename : str Output path. logger : phenonn.utils.logger.Logger, optional log_metrics : sequence of str, optional Metric names that should use log y-scale. Default: all except R²-like metrics (anything with "r2" or "r²" in the name). cols : int Number of columns in the grid. Notes ----- - Uses mpltex linestyles for consistent styling - Missing/NaN values in a metric series are skipped — useful when you record metrics conditionally and some epochs lack certain values. - Includes grid for better readability """ metric_names = list(train_history.keys()) if not metric_names: raise ValueError("train_history is empty — nothing to plot") if set(valid_history.keys()) != set(metric_names): raise ValueError( f"train/valid histories must have same keys. " f"Train: {sorted(metric_names)}. 
" f"Valid: {sorted(valid_history.keys())}" ) if log_metrics is None: log_metrics = [m for m in metric_names] n = len(metric_names) rows = math.ceil(n / cols) fig = plt.figure(figsize=(5 * cols, 4 * rows), tight_layout=True) gs = gridspec.GridSpec(rows, cols) for i, name in enumerate(metric_names): r, c = divmod(i, cols) ax = fig.add_subplot(gs[r, c]) # Generate linestyles for this subplot linestyles = mpltex.linestyle_generator(markers=[]) tr = np.asarray(train_history[name], dtype=float) vl = np.asarray(valid_history[name], dtype=float) ep = np.arange(1, max(len(tr), len(vl)) + 1) # Skip NaNs — plot only finite values tr_mask = np.isfinite(tr) vl_mask = np.isfinite(vl) ax.plot(ep[: len(tr)][tr_mask], tr[tr_mask], label="train", **next(linestyles)) ax.plot(ep[: len(vl)][vl_mask], vl[vl_mask], label="valid", **next(linestyles)) ax.set_xlabel("Epoch") ax.set_ylabel(name.replace("_", " ").upper()) # Log-scale only if strictly positive and metric in log_metrics all_vals = np.concatenate([tr[tr_mask], vl[vl_mask]]) if name in log_metrics and len(all_vals) > 0 and all_vals.min() > 0: ax.set_yscale("log") ax.legend() ax.grid(True) # Hide unused axes for j in range(n, rows * cols): r, c = divmod(j, cols) fig.add_subplot(gs[r, c]).set_visible(False) # Save the figure plt.savefig(filename, bbox_inches="tight", dpi=150) plt.close(fig) # Log or print the save location if logger: logger.info(f"Saved metric history plot to {filename}") else: print(f"Saved metric history plot to {filename}")
# ── 3. Predicted vs observed hexbin ──────────────────────────────────────────
[docs] def plot_pred_vs_obs( pred: Union[np.ndarray, Sequence[float]], obs: Union[np.ndarray, Sequence[float]], filename: str = "pred_vs_obs.png", logger=None, title: str = "Predicted vs Observed LAI", xlabel: str = "Observed LAI", ylabel: str = "Predicted LAI", gridsize: int = 40, hexbin: bool = True, ) -> Dict[str, float]: """ Scatter plot of predictions vs observations with y=x reference and metrics. Uses hexbin density by default (fast, readable for >10k points). Falls back to a regular scatter for small sample sets where hexbin would be mostly empty. Parameters ---------- pred : array-like Predicted values (any shape — will be flattened). obs : array-like Observed values (same shape). filename : str Output path. logger : phenonn.utils.logger.Logger, optional title, xlabel, ylabel : str Plot labels. gridsize : int Hexbin resolution (number of hexagons along each axis). hexbin : bool If True, use hexbin density. If False, use scatter. Default True. Returns ------- dict Computed metrics: rmse, mae, bias, r2, n. Notes ----- Errors ignored in computation: NaN and Inf pairs are dropped. Plot axis limits are set to cover both predictions and observations with a small margin, so the y=x line always appears diagonal. 
""" pred = np.asarray(pred, dtype=np.float64).ravel() obs = np.asarray(obs, dtype=np.float64).ravel() mask = np.isfinite(pred) & np.isfinite(obs) pred, obs = pred[mask], obs[mask] if len(pred) == 0: raise ValueError("No finite pred/obs pairs to plot") metrics = _compute_metrics(pred, obs) fig, ax = plt.subplots(figsize=(7, 7)) # Plot density or scatter if hexbin and len(pred) >= 200: hb = ax.hexbin( obs, pred, gridsize=gridsize, cmap="viridis", mincnt=1, bins="log" ) cbar_ax = fig.add_axes([0.92, 0.1, 0.02, 0.8]) fig.colorbar(hb, cax=cbar_ax, label=r"$\mathrm{\log_{10}[Count]}$") else: ax.scatter(obs, pred, alpha=0.5, s=12, color="#1f77b4", edgecolor="none") # y=x reference line — span the combined data range lo = float(min(obs.min(), pred.min())) hi = float(max(obs.max(), pred.max())) span = hi - lo margin = 0.05 * span if span > 0 else 0.01 axmin, axmax = lo - margin, hi + margin ax.plot([axmin, axmax], [axmin, axmax], "k--") ax.set_xlim(axmin, axmax) ax.set_ylim(axmin, axmax) ax.set_aspect("equal") ax.xaxis.set_major_locator(ticker.MultipleLocator(span / 4)) ax.yaxis.set_major_locator(ticker.MultipleLocator(span / 4)) ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f")) ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f")) # Metrics text box (upper left) txt = ( f"$R^2$ = {metrics['r2']:.4f}\n" f"RMSE = {metrics['rmse']:.4f}\n" f"MAE = {metrics['mae']:.4f}\n" f"Bias = {metrics['bias']:+.4f}\n" f"N = {metrics['n']:,}" ) ax.text(0.03, 0.97, txt, transform=ax.transAxes, va="top", ha="left") ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) ax.legend() plt.savefig(filename, bbox_inches="tight") plt.close(fig) if logger: logger.info(f"Saved plot to {filename}") else: print(f"Saved plot to {filename}") return metrics
# ── 4. LAI annual curves for low / medium / high R² sites ────────────────────
[docs] def plot_gcc_curves( df: "pd.DataFrame", filename: str = "gcc_curves_by_r2.png", logger=None, seed: int = 42, site_col: str = "site", year_col: str = "year", doy_col: str = "day_index", pred_col: str = "lai_pred", obs_col: str = "lai_obs", ) -> Dict[str, float]: """ Plot observed vs predicted annual LAI curves for three representative sites. Picks one random site with low R², one with medium R², and one with high R² relative to the mean across all sites. Each site gets a subplot showing all validation years overlaid, with observed LAI as a solid line and predicted LAI as a dashed line. Parameters ---------- df : pd.DataFrame Predictions dataframe as produced by predict.py, with columns for site, year, day index, lai_pred, and lai_obs. filename : str Output path for the figure. logger : phenonn.utils.logger.Logger, optional seed : int Random seed for reproducible site selection. site_col, year_col, doy_col, pred_col, obs_col : str Column names in df. Defaults match predict.py output. Returns ------- dict {site_name: r2_value} for the three selected sites. Notes ----- Site selection: all sites are ranked by R². The site pool is split into three terciles (low / mid / high). One site is drawn randomly from each tercile. Using terciles rather than the absolute best/worst avoids always showing the same outlier sites. 
""" # ── Compute per-site R² ── site_metrics = {} for site, g in df.groupby(site_col): p = g[pred_col].values o = g[obs_col].values mask = np.isfinite(p) & np.isfinite(o) if mask.sum() < 10: continue p, o = p[mask], o[mask] ss_res = np.sum((o - p) ** 2) ss_tot = np.sum((o - o.mean()) ** 2) r2 = float(1 - ss_res / ss_tot) if ss_tot > 0 else float("nan") site_metrics[site] = r2 if len(site_metrics) < 3: raise ValueError( f"Need at least 3 sites with enough data, got {len(site_metrics)}" ) # ── Rank sites and pick one from each tercile ── sorted_sites = sorted(site_metrics.items(), key=lambda x: x[1]) n = len(sorted_sites) tercile_size = max(1, n // 3) low_pool = sorted_sites[:tercile_size] mid_pool = sorted_sites[tercile_size : 2 * tercile_size] high_pool = sorted_sites[2 * tercile_size :] rng = np.random.RandomState(seed) pick_low = low_pool[rng.randint(len(low_pool))] pick_mid = mid_pool[rng.randint(len(mid_pool))] pick_high = high_pool[rng.randint(len(high_pool))] picks = [ ("Low R²", pick_low[0], pick_low[1]), ("Medium R²", pick_mid[0], pick_mid[1]), ("High R²", pick_high[0], pick_high[1]), ] # ── Plot ── fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=False) plt.subplots_adjust(hspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1) fig.suptitle("LAI annual curves — low / medium / high R² sites", fontsize=14) for ax, (label, site_name, r2_val) in zip(axes, picks): site_df = df[df[site_col] == site_name].copy() years = sorted(site_df[year_col].unique()) linestyles = mpltex.linestyle_generator(markers=[]) for i, year in enumerate(years): ydf = site_df[site_df[year_col] == year].sort_values(doy_col) # DOY axis: use day_index modulo 365 if it's an absolute index, # or use it directly if it's already 1-365 doys = ydf[doy_col].values if doys.max() > 400: # Absolute index — convert to within-year DOY doys = doys - doys.min() + 1 base_style = next(linestyles) obs_style = base_style.copy() obs_style["linestyle"] = "-" ax.plot( doys, ydf[obs_col].values, 
label=f"{year} obs" if i == 0 else f"{year} obs", **obs_style, ) pred_style = base_style.copy() pred_style["linestyle"] = "--" ax.plot( doys, ydf[pred_col].values, label=f"{year} pred" if i == 0 else f"{year} pred", **pred_style, ) ax.set_title(f"{label}: {site_name} (R² = {r2_val:.4f})", fontsize=12) ax.set_ylabel("LAI") ax.legend(fontsize=8, ncol=min(len(years), 4), loc="upper right") ax.grid(True, alpha=0.3) axes[-1].set_xlabel("Day of year") # Save the figure plt.savefig(filename, bbox_inches="tight") plt.close(fig) # Log or print the save location if logger: logger.info(f"Saved GCC curves plot to {filename}") else: print(f"Saved GCC curves plot to {filename}") return {site: r2 for _, site, r2 in picks}
# ── 5. LAI annual curves for ALL sites ───────────────────────────────────────
[docs] def plot_gcc_curves_all( df: "pd.DataFrame", filename: str = "gcc_curves_all.png", logger=None, cols: int = 4, site_col: str = "site", year_col: str = "year", doy_col: str = "day_index", pred_col: str = "lai_pred", obs_col: str = "lai_obs", ) -> Dict[str, float]: """ Plot observed vs predicted annual LAI curves for every site. One small subplot per site, arranged in a grid, sorted by R² from best (top-left) to worst (bottom-right). Each subplot overlays all validation years with observed (solid) and predicted (dashed) lines. Parameters ---------- df : pd.DataFrame Predictions dataframe with columns for site, year, day index, lai_pred, and lai_obs. filename : str Output path. logger : phenonn.utils.logger.Logger, optional cols : int Number of columns in the grid. Default 4. site_col, year_col, doy_col, pred_col, obs_col : str Column names in df. Returns ------- dict {site_name: r2_value} for all sites. """ # ── Compute per-site R² ── site_metrics = {} for site, g in df.groupby(site_col): p = g[pred_col].values o = g[obs_col].values mask = np.isfinite(p) & np.isfinite(o) if mask.sum() < 10: continue p, o = p[mask], o[mask] ss_res = np.sum((o - p) ** 2) ss_tot = np.sum((o - o.mean()) ** 2) r2 = float(1 - ss_res / ss_tot) if ss_tot > 0 else float("nan") site_metrics[site] = r2 # Sort by R² descending (best first) sorted_sites = sorted(site_metrics.items(), key=lambda x: -x[1]) n_sites = len(sorted_sites) if n_sites == 0: raise ValueError("No sites with enough data to plot") rows = math.ceil(n_sites / cols) fig, axes = plt.subplots( rows, cols, figsize=(4 * cols, 2.8 * rows), squeeze=False, ) fig.suptitle(f"LAI predicted vs observed — {n_sites} sites (sorted by R²)") plt.subplots_adjust(hspace=0.3, left=0.3, right=0.9, top=0.9, bottom=0.1) legend_handles = [] legend_labels = [] for idx, (site_name, r2_val) in enumerate(sorted_sites): r, c = divmod(idx, cols) ax = axes[r][c] site_df = df[df[site_col] == site_name] years = sorted(site_df[year_col].unique()) 
linestyles = mpltex.linestyle_generator(markers=[]) for i, year in enumerate(years): ydf = site_df[site_df[year_col] == year].sort_values(doy_col) doys = ydf[doy_col].values if len(doys) > 0 and doys.max() > 400: doys = doys - doys.min() + 1 line_obs = ax.plot(doys, ydf[obs_col].values, **next(linestyles)) line_pred = ax.plot(doys, ydf[pred_col].values, **next(linestyles)) # Collect legend handles and labels from the first site only if idx == 0: legend_handles.append(line_obs[0]) legend_labels.append(f"{year} obs") legend_handles.append(line_pred[0]) legend_labels.append(f"{year} pred") # Title with site name and R² r2_str = f"{r2_val:.4f}" if np.isfinite(r2_val) else "N/A" ax.set_title(f"{site_name}\nR²={r2_str}", fontsize=8, pad=3) ax.tick_params(labelsize=7) ax.grid(True, alpha=0.2) # Hide unused axes for idx in range(n_sites, rows * cols): r, c = divmod(idx, cols) axes[r][c].set_visible(False) # Shared legend (one for the whole figure) # Build legend entries from the first site's years if legend_handles: fig.legend( handles=legend_handles, labels=legend_labels, loc="center right", bbox_to_anchor=(1.1, 0.5), ncol=1, ) # Save the figure plt.savefig(filename, bbox_inches="tight") plt.close(fig) # Log or print the save location if logger: logger.info(f"Saved GCC curves plot to {filename}") else: print(f"Saved GCC curves plot to {filename}") return dict(sorted_sites)
# ── 6. Feature and target distributions ──────────────────────────────────────
[docs] def plot_feature_distributions( site_files: List[str], filename: str = "feature_distributions.png", logger=None, cols: int = 4, n_bins: int = 60, max_sites: int = 20, ) -> None: """ Plot histograms of all input features and the target variable. Shows the raw distribution and, for log-transformed features, the distribution after log1p. Useful for checking skewness, spotting outliers, and verifying that normalization choices are appropriate. Parameters ---------- site_files : list of str Paths to site CSVs. A random subset of max_sites is used to keep computation fast. filename : str Output path for the figure. logger : phenonn.utils.logger.Logger, optional cols : int Number of columns in the grid. n_bins : int Number of histogram bins. max_sites : int Maximum number of sites to sample (for speed). Notes ----- For each feature, the histogram shows: - Blue: raw values (pooled across all sampled sites) - Orange (if applicable): values after log1p transform - Vertical red dashed lines: mean ± 1 std - Title includes skewness and % of near-zero values The target (LAI) is shown in green in the last panel. 
""" # Sample sites for speed rng = np.random.RandomState(42) if len(site_files) > max_sites: indices = rng.choice(len(site_files), max_sites, replace=False) sampled = [site_files[i] for i in indices] else: sampled = site_files # Load and concatenate data all_dfs = [] for f in sampled: df = pd.read_csv(f) df = df.sort_values(["year", "doy"]).reset_index(drop=True) df["doy_sin"] = np.sin(2 * np.pi * df["doy"] / 365.25) df["doy_cos"] = np.cos(2 * np.pi * df["doy"] / 365.25) df = add_derived_features(df) all_dfs.append(df) data = pd.concat(all_dfs, ignore_index=True) # Features to plot features = list(DYNAMIC_FEATURES) + list(CYCLIC_FEATURES) + list(STATIC_FEATURES) targets = ["LAI"] all_vars = features + targets n_vars = len(all_vars) rows = math.ceil(n_vars / cols) fig, axes = plt.subplots(rows, cols, figsize=(4.5 * cols, 3 * rows)) plt.subplots_adjust(hspace=0.3, left=0.3, right=0.9, top=0.9, bottom=0.1) fig.suptitle( f"Feature & target distributions ({len(sampled)} sites, " f"{len(data):,} days)", fontsize=14, ) for idx, var_name in enumerate(all_vars): r, c = divmod(idx, cols) ax = axes[r][c] if rows > 1 else axes[c] if var_name not in data.columns: ax.set_title(f"{var_name}\n(not found)", fontsize=9) ax.set_visible(False) continue vals = data[var_name].dropna().values if len(vals) == 0: ax.set_title(f"{var_name}\n(all NaN)", fontsize=9) continue # Compute stats skewness = ( float(((vals - vals.mean()) ** 3).mean() / (vals.std() ** 3)) if vals.std() > 0 else 0 ) pct_zero = float(100 * (np.abs(vals) < 1e-6).mean()) mean_val = vals.mean() std_val = vals.std() # Choose color based on type if var_name in targets: color = "#2ca02c" # green for target label = "target" else: color = "#1f77b4" # blue for features label = "raw" # Plot raw histogram ax.hist( vals, bins=n_bins, color=color, alpha=0.7, density=True, label=label, edgecolor="none", ) # If log-transformed, overlay log1p distribution if var_name in LOG_TRANSFORM_FEATURES: vals_log = np.log1p(np.clip(vals, 0, 
None)) ax.hist( vals_log, bins=n_bins, color="#ff7f0e", alpha=0.5, density=True, label="log1p", edgecolor="none", ) # Mean ± std lines ax.axvline(mean_val, color="red", linestyle="--", linewidth=1, alpha=0.7) ax.axvline( mean_val - std_val, color="red", linestyle=":", linewidth=0.8, alpha=0.5 ) ax.axvline( mean_val + std_val, color="red", linestyle=":", linewidth=0.8, alpha=0.5 ) # Title with stats ax.set_title( f"{var_name}\n" f"skew={skewness:+.2f} zero={pct_zero:.0f}% " f"μ={mean_val:.2g} σ={std_val:.2g}", fontsize=8, ) ax.tick_params(labelsize=7) ax.legend(fontsize=7, loc="upper right") # Hide unused axes for idx in range(n_vars, rows * cols): r, c = divmod(idx, cols) ax = axes[r][c] if rows > 1 else axes[c] ax.set_visible(False) # Save the figure plt.savefig(filename, bbox_inches="tight", dpi=150) plt.close(fig) if logger: logger.info(f"Saved feature distributions plot to {filename}") else: print(f"Saved feature distributions plot to {filename}")
# ── 7. Feature distributions per site ────────────────────────────────────────
def plot_feature_distributions_per_site(
    site_files: List[str],
    output_dir: str = "./feature_distributions",
    logger=None,
    cols: int = 4,
    n_bins: int = 60,
) -> None:
    """
    Write one feature-distribution figure per site CSV.

    For every path in *site_files*, ``plot_feature_distributions`` is called
    on that single file, and the PNG is written to
    ``{output_dir}/{sitename}_feature_distribution.png``, where *sitename* is
    the file's base name without extension. Errors raised for a site are
    reported and that site is skipped. A summary listing any failed sites is
    emitted at the end.

    Parameters
    ----------
    site_files : list of str
        Paths to site CSVs.
    output_dir : str
        Directory for the individual PNGs (created if missing).
    logger : phenonn.utils.logger.Logger, optional
        Receives progress (``info``), per-site failures (``error``) and the
        summary (``success`` if available, otherwise ``info``). Without a
        logger, everything is printed.
    cols : int
        Number of columns in each histogram grid.
    n_bins : int
        Number of histogram bins.

    Example
    -------
    ::

        site_files = sorted(glob.glob("./data/DB/*.csv"))
        plot_feature_distributions_per_site(site_files, output_dir="./runs/dist_per_site")
        # -> ./runs/dist_per_site/DB_asuhighlands_feature_distribution.png
        #    ./runs/dist_per_site/DB_bartlett_feature_distribution.png
        #    ... (one per site)
    """
    os.makedirs(output_dir, exist_ok=True)
    total = len(site_files)
    failures = []

    for position, path in enumerate(site_files, start=1):
        key = os.path.splitext(os.path.basename(path))[0]
        target = os.path.join(output_dir, f"{key}_feature_distribution.png")

        status = f" [{position}/{total}] {key}"
        if not logger:
            print(status)
        else:
            logger.info(status)

        try:
            plot_feature_distributions(
                site_files=[path],
                filename=target,
                logger=logger,
                cols=cols,
                n_bins=n_bins,
                max_sites=1,
            )
        except Exception as exc:
            # Best effort: report the failure and move on to the next site.
            problem = f" Failed to plot {key}: {exc!s}"
            if not logger:
                print(problem)
            else:
                logger.error(problem)
            failures.append(key)

    if not failures:
        summary = f"Successfully saved {total} distribution plots to {output_dir}/"
    else:
        summary = (
            f"Completed {total - len(failures)}/{total} sites. \n"
            f"Failed: {', '.join(failures)}"
        )

    if not logger:
        print(summary)
    elif hasattr(logger, "success"):
        logger.success(summary)
    else:
        logger.info(summary)


# ── Convenience: build histories from a training loop ────────────────────────
def make_history_dicts() -> tuple:
    """
    Create a fresh, matching pair of train/validation history dicts.

    Returns
    -------
    (train_hist, valid_hist) : tuple of dict
        Two independent dicts, each mapping 'loss', 'rmse' and 'r2' to its
        own empty list. They are shaped for ``plot_metric_histories``.

    Examples
    --------
    >>> train_hist, valid_hist = make_history_dicts()
    >>> sorted(train_hist)
    ['loss', 'r2', 'rmse']
    >>> train_hist["loss"].append(0.5)
    >>> valid_hist["loss"]
    []
    """
    metric_keys = ("loss", "rmse", "r2")

    def _empty_history():
        # A new list per key so no two histories ever share storage.
        return {key: [] for key in metric_keys}

    return _empty_history(), _empty_history()