Source code for rtnn.evaluater

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

"""
Evaluation utilities for RTnn model assessment.

This module provides comprehensive evaluation tools for radiative transfer
neural network models, including custom loss functions, metric computation,
and visualization helpers. It is designed to support both training diagnostics
and rigorous model validation against physical constraints.

The module implements several key capabilities:

**Custom Loss Functions**
    - Normalized losses (NMAE, NMSE) for scale-invariant error measurement
    - Standard losses (MSE, MAE, Huber, Smooth L1) for baseline comparison
    - Weighted and physics-informed losses for multi-objective optimization

**Evaluation Metrics**
    - NMAE: Normalized Mean Absolute Error
    - NMSE: Normalized Mean Squared Error
    - R²: Coefficient of determination
    - MBE: Mean Bias Error
    - MARE: Mean Absolute Relative Error
    - GMRAE: Geometric Mean Relative Absolute Error

**Physical Consistency**
    - Absorption rate calculation from flux divergence
    - Energy conservation penalty (albedo + transmittance + absorptance = 1)
    - Heating rate computation from net flux profiles

**Data Handling**
    - Normalization/de-normalization for all supported transformation types
    - Multi-dimensional tensor reshaping (batch, channels, PFTs, bands, levels)
    - Metric tracking with running statistics

The module follows a modular design where loss functions and metrics are
implemented as separate callable classes/functions, allowing easy extension
and composition.

Notes
-----
**Flux Variable Ordering**

    - Channel 0: collimated albedo (direct upwelling)
    - Channel 1: collimated transmittance (direct downwelling)
    - Channel 2: isotropic albedo (diffuse upwelling)
    - Channel 3: isotropic transmittance (diffuse downwelling)

This ordering matches the ``ov`` list in :class:`rtnn.dataset.DataPreprocessor`.

**Absorption Calculation**

    - For collimated: absorption = -d(net_flux)/dz
    - For isotropic: absorption = -d(net_flux)/dz


**Supported Normalization Types**

    - Linear: minmax, standard, robust
    - Log1p-based: log1p_minmax, log1p_standard, log1p_robust
    - Sqrt-based: sqrt_minmax, sqrt_standard, sqrt_robust

Examples
--------
Basic usage for model evaluation:

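>>> import argparse
>>> import torch
>>> from rtnn.evaluater import get_loss_function, run_validation
>>>
>>> # Create loss function (run_validation also reads args.beta and args.num_epochs)
>>> args = argparse.Namespace(
...     loss_type='huber', beta_delta=1.0, beta=0.1, num_epochs=100
... )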
>>> criterion = get_loss_function('huber', args)
>>>
>>> # Evaluate model
>>> valid_loss, metrics = run_validation(
...     loader=val_loader,
...     model=my_model,
...     norm_mapping=norm_stats,
...     normalization_type=norm_types,
...     index_mapping=idxmap,
...     device=device,
...     args=args,
...     epoch=10,
...     logger=logger
... )
>>>
>>> print(f"Validation NMAE: {metrics['fluxes_NMAE']:.4f}")
>>> print(f"R² score: {metrics['fluxes_R2']:.4f}")

Using custom metric tracking:

>>> from rtnn.evaluater import MetricTracker, nmae_all
>>>
>>> tracker = MetricTracker()
>>> for batch in dataloader:
...     pred, target = model(batch)
...     count, value = nmae_all(pred, target)
...     tracker.update(value.item(), count)
>>>
>>> mean_nmae = tracker.getmean()

See Also
--------
rtnn.dataset.DataPreprocessor : Data loading and normalization
rtnn.diagnostics : Visualization tools for model predictions
rtnn.models : Neural network architectures
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
from tqdm import tqdm
import os
from rtnn.diagnostics import plot_all_diagnostics
from typing import Optional

# Extend the module search path with "..", which Python resolves relative to
# the current working directory of the running process (not this file).
sys.path.append("..")


class NMSELoss(nn.Module):
    """
    Mean squared error scaled by the mean square of the target.

    Dividing the MSE by ``mean(target**2)`` makes the loss independent of the
    overall magnitude of the target, which helps when comparing runs whose
    targets live on different scales.

    Parameters
    ----------
    eps : float, optional
        Constant added to the denominator to avoid division by zero.
        Default is 1e-8.

    Notes
    -----
    The value returned by :meth:`forward` is::

        NMSE = MSE(pred, target) / (mean(target**2) + eps)

    A perfect prediction gives 0; predicting zeros everywhere gives a value
    of approximately 1.

    Examples
    --------
    >>> criterion = NMSELoss()
    >>> predictions = torch.tensor([[2.0, 3.0], [1.0, 2.0]])
    >>> targets = torch.tensor([[2.0, 4.0], [1.0, 2.5]])
    >>> loss = criterion(predictions, targets)
    >>> print(round(loss.item(), 4))  # 0.3125 / 6.8125
    0.0459
    """

    def __init__(self, eps=1e-8):
        """Store ``eps`` and create the underlying ``nn.MSELoss`` module."""
        super().__init__()
        self.eps = eps
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        """Return the scalar NMSE between ``pred`` and ``target``."""
        numerator = self.mse(pred, target)
        denominator = torch.mean(target**2) + self.eps
        return numerator / denominator
class NMAELoss(nn.Module):
    """
    Mean absolute error scaled by the mean absolute value of the target.

    A scale-invariant counterpart of the L1 loss; because it relies on
    absolute rather than squared differences it is less sensitive to
    outliers than the normalized MSE.

    Parameters
    ----------
    eps : float, optional
        Constant added to the denominator to avoid division by zero.
        Default is 1e-8.

    Notes
    -----
    The value returned by :meth:`forward` is::

        NMAE = MAE(pred, target) / (mean(|target|) + eps)

    A perfect prediction gives 0; predicting zeros everywhere gives a value
    of approximately 1, so values above 1 are worse than that trivial
    predictor.

    Examples
    --------
    >>> criterion = NMAELoss()
    >>> predictions = torch.tensor([[2.0, 3.0], [1.0, 2.0]])
    >>> targets = torch.tensor([[2.0, 4.0], [1.0, 2.5]])
    >>> loss = criterion(predictions, targets)
    >>> print(round(loss.item(), 4))  # 0.375 / 2.375
    0.1579
    """

    def __init__(self, eps=1e-8):
        """Store ``eps`` and create the underlying ``nn.L1Loss`` module."""
        super().__init__()
        self.eps = eps
        self.l1 = nn.L1Loss()

    def forward(self, pred, target):
        """Return the scalar NMAE between ``pred`` and ``target``."""
        numerator = self.l1(pred, target)
        denominator = torch.mean(torch.abs(target)) + self.eps
        return numerator / denominator
def physics_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    conservation_penalty: Optional[torch.Tensor] = None,
    lambda_phys: float = 0.1,
    delta: float = 1.0,
) -> torch.Tensor:
    """
    Huber loss plus an optional, pre-computed energy-conservation penalty.

    The penalty itself is *not* computed here: the caller supplies it
    (``calc_abs`` in this module returns one as its last element) and this
    function only adds it, weighted by ``lambda_phys``, to the Huber term.

    Parameters
    ----------
    pred : torch.Tensor
        Model predictions.
    target : torch.Tensor
        Ground-truth targets, broadcastable against ``pred``.
    conservation_penalty : torch.Tensor or None, optional
        Scalar penalty measuring the violation of
        ``albedo + transmittance + absorptance = 1``. When ``None`` (the
        default) no penalty is applied and the plain Huber loss is returned.
    lambda_phys : float, optional
        Weight of the conservation penalty relative to the Huber loss.
        Default is 0.1.
    delta : float, optional
        Huber loss ``delta`` parameter. Default is 1.0.

    Returns
    -------
    torch.Tensor
        Scalar loss: ``huber`` or ``huber + lambda_phys * conservation_penalty``.
    """
    # Primary data-fit term.
    huber = F.huber_loss(pred, target, delta=delta, reduction="mean")

    # The default of None used to reach the multiplication below and raise
    # TypeError; treat "no penalty supplied" as "no penalty applied".
    if conservation_penalty is None:
        return huber

    return huber + lambda_phys * conservation_penalty
class MetricTracker:
    """
    Accumulate a weighted running mean and standard deviation of a metric.

    Each call to :meth:`update` contributes a metric value together with the
    number of samples it represents. Only three running totals are stored,
    so statistics over many batches can be obtained without keeping the
    individual values.

    Attributes
    ----------
    value : float
        Sum of ``value * count`` over all updates.
    count : int
        Sum of ``count`` over all updates.
    value_sq : float
        Sum of ``value**2 * count`` over all updates.

    Examples
    --------
    >>> tracker = MetricTracker()
    >>> tracker.update(10.0, 5)
    >>> tracker.update(20.0, 3)
    >>> print(tracker.getmean())  # (10*5 + 20*3) / 8
    13.75
    >>> print(round(float(tracker.getstd()), 4))  # sqrt(1700/8 - 13.75**2)
    4.8412
    >>> print(round(float(tracker.getsqrtmean()), 4))  # sqrt(13.75)
    3.7081
    """

    def __init__(self):
        """Create a tracker with all running totals set to zero."""
        self.reset()

    def reset(self):
        """Zero the running totals, discarding everything tracked so far."""
        self.count = 0
        self.value = 0.0
        self.value_sq = 0.0

    def update(self, value, count):
        """
        Fold a new observation into the running totals.

        Parameters
        ----------
        value : float
            Metric value to record.
        count : int
            Weight of ``value`` (number of samples it summarizes).
        """
        self.count += count
        self.value += value * count
        self.value_sq += (value**2) * count

    def getmean(self):
        """
        Return the weighted mean ``value / count``.

        Raises
        ------
        ZeroDivisionError
            If nothing has been recorded (``count == 0``).
        """
        if self.count == 0:
            raise ZeroDivisionError("Cannot compute mean with zero samples")
        return self.value / self.count

    def getstd(self):
        """
        Return the weighted standard deviation ``sqrt(E[x^2] - E[x]^2)``.

        Raises
        ------
        ZeroDivisionError
            If nothing has been recorded (``count == 0``).
        """
        if self.count == 0:
            raise ZeroDivisionError("Cannot compute std with zero samples")
        first_moment = self.getmean()
        second_moment = self.value_sq / self.count
        spread = second_moment - first_moment**2
        # Round-off can push the variance slightly below zero; clamp it.
        return np.sqrt(max(spread, 0.0))

    def getsqrtmean(self):
        """
        Return the square root of the weighted mean.

        Raises
        ------
        ZeroDivisionError
            If nothing has been recorded (``count == 0``).
        """
        return np.sqrt(self.getmean())
def get_loss_function(loss_type, args, logger=None):
    """
    Build the loss module selected by ``loss_type``.

    Parameters
    ----------
    loss_type : str
        One of:

        - ``'mse'``: ``nn.MSELoss``
        - ``'mae'``: ``nn.L1Loss``
        - ``'nmae'``: :class:`NMAELoss`
        - ``'nmse'``: :class:`NMSELoss`
        - ``'smoothl1'``: ``nn.SmoothL1Loss(beta=args.beta_delta)``
        - ``'huber'``: ``nn.HuberLoss(delta=args.beta_delta)``
    args : argparse.Namespace
        Only consulted for ``'smoothl1'`` and ``'huber'``, which read
        ``args.beta_delta``.
    logger : logging.Logger, optional
        When given, the chosen loss is reported through ``logger.info``.

    Returns
    -------
    torch.nn.Module
        A freshly constructed loss module.

    Raises
    ------
    ValueError
        If ``loss_type`` is not listed above, or if ``'smoothl1'``/``'huber'``
        is requested and ``args`` has no ``beta_delta`` attribute.

    Examples
    --------
    >>> import argparse
    >>> criterion = get_loss_function('huber', argparse.Namespace(beta_delta=1.0))
    >>> criterion = get_loss_function('mse', argparse.Namespace())
    """
    # Losses that need no hyper-parameters: (name, log message, constructor).
    # Constructors are wrapped in lambdas so each class is looked up only
    # when that particular loss is requested.
    parameter_free = (
        ("mse", "Using Mean Squared Error (MSE) loss", lambda: nn.MSELoss()),
        ("mae", "Using Mean Absolute Error (MAE) loss", lambda: nn.L1Loss()),
        (
            "nmae",
            "Using Normalized Mean Absolute Error (NMAE) loss",
            lambda: NMAELoss(),
        ),
        (
            "nmse",
            "Using Normalized Mean Squared Error (NMSE) loss",
            lambda: NMSELoss(),
        ),
    )
    for name, message, build in parameter_free:
        if loss_type == name:
            if logger:
                logger.info(message)
            return build()

    if loss_type not in ["smoothl1", "huber"]:
        raise ValueError(f"Unsupported loss type: {loss_type}")

    # Both remaining losses share the same threshold hyper-parameter.
    if not hasattr(args, "beta_delta"):
        raise ValueError(f"{loss_type.capitalize()}Loss requires --beta_delta")
    if logger:
        logger.info(
            f"Using {loss_type.capitalize()} loss with delta={args.beta_delta}"
        )
    if loss_type == "smoothl1":
        return nn.SmoothL1Loss(beta=args.beta_delta)
    return nn.HuberLoss(delta=args.beta_delta)
def mse_all(pred, true):
    """
    Mean squared error over every element.

    Parameters
    ----------
    pred : torch.Tensor
        Predictions.
    true : torch.Tensor
        Ground truth.

    Returns
    -------
    tuple
        ``(num_elements, mse_value)`` where ``num_elements`` is
        ``pred.numel()`` and ``mse_value`` is a scalar tensor.
    """
    squared_error = (pred - true) ** 2
    n_elements = pred.numel()
    return n_elements, torch.mean(squared_error)
def mbe_all(pred, true):
    """
    Mean bias error (signed mean of ``pred - true``) over every element.

    Parameters
    ----------
    pred : torch.Tensor
        Predictions.
    true : torch.Tensor
        Ground truth.

    Returns
    -------
    tuple
        ``(num_elements, mbe_value)`` where ``num_elements`` is
        ``pred.numel()`` and ``mbe_value`` is a scalar tensor; positive
        values mean the predictions are too large on average.
    """
    signed_error = pred - true
    n_elements = pred.numel()
    return n_elements, torch.mean(signed_error)
def mae_all(pred, true):
    """
    Mean absolute error over every element.

    Parameters
    ----------
    pred : torch.Tensor
        Predictions.
    true : torch.Tensor
        Ground truth.

    Returns
    -------
    tuple
        ``(num_elements, mae_value)`` where ``num_elements`` is
        ``pred.numel()`` and ``mae_value`` is a scalar tensor.
    """
    absolute_error = torch.abs(pred - true)
    n_elements = pred.numel()
    return n_elements, torch.mean(absolute_error)
def r2_all(pred, true):
    """
    Coefficient of determination (R2) between predictions and ground truth.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values.
    true : torch.Tensor
        Ground-truth values; must have exactly the same shape as ``pred``.

    Returns
    -------
    tuple
        ``(num_elements, r2_value)`` where ``num_elements`` is
        ``pred.numel()`` and ``r2_value`` is a scalar tensor.

    Raises
    ------
    RuntimeError
        If ``pred`` and ``true`` differ in shape.

    Notes
    -----
    Both tensors are flattened and the score is::

        R2 = 1 - sum((true - pred)^2) / (sum((true - mean(true))^2) + 1e-12)

    The small constant keeps the ratio finite when ``true`` is constant.
    Only torch operations are used, so the result stays on the input device.
    """
    if pred.shape != true.shape:
        raise RuntimeError(f"Shape mismatch: pred {pred.shape} vs true {true.shape}")

    guard = 1e-12  # keeps the division finite for zero-variance targets
    n_elements = pred.numel()

    y_hat = pred.reshape(-1)
    y = true.reshape(-1)

    residual_ss = torch.sum((y - y_hat) ** 2)
    total_ss = torch.sum((y - torch.mean(y)) ** 2)

    return n_elements, 1.0 - residual_ss / (total_ss + guard)
def nmae_all(pred, true):
    """
    Normalized mean absolute error over every element.

    Computed as ``mean(|pred - true|) / (mean(|true|) + 1e-8)``.

    Parameters
    ----------
    pred : torch.Tensor
        Predictions.
    true : torch.Tensor
        Ground truth.

    Returns
    -------
    tuple
        ``(num_elements, nmae_value)`` where ``num_elements`` is
        ``pred.numel()`` and ``nmae_value`` is a scalar tensor.
    """
    abs_error_mean = torch.mean(torch.abs(pred - true))
    target_scale = torch.mean(torch.abs(true)) + 1e-8
    return pred.numel(), abs_error_mean / target_scale
def nmse_all(pred, true):
    """
    Normalized mean squared error over every element.

    Computed as ``mean((pred - true)^2) / (mean(true^2) + 1e-8)``.

    Parameters
    ----------
    pred : torch.Tensor
        Predictions.
    true : torch.Tensor
        Ground truth.

    Returns
    -------
    tuple
        ``(num_elements, nmse_value)`` where ``num_elements`` is
        ``pred.numel()`` and ``nmse_value`` is a scalar tensor.
    """
    sq_error_mean = torch.mean((pred - true) ** 2)
    target_power = torch.mean(true**2) + 1e-8
    return pred.numel(), sq_error_mean / target_power
def mare_all(pred, true):
    """
    Mean absolute relative error over every element.

    Each element contributes ``|pred - true| / (|true| + 1e-8)``; the small
    constant keeps the ratio finite where ``true`` is zero.

    Parameters
    ----------
    pred : torch.Tensor
        Predictions.
    true : torch.Tensor
        Ground truth.

    Returns
    -------
    tuple
        ``(num_elements, mare_value)`` where ``num_elements`` is
        ``pred.numel()`` and ``mare_value`` is a scalar tensor.
    """
    per_element = torch.abs(pred - true) / (torch.abs(true) + 1e-8)
    return pred.numel(), torch.mean(per_element)
def gmrae_all(pred, true):
    """
    Geometric mean of the relative absolute errors.

    Each element's relative error is ``|pred - true| / (|true| + eps)``; the
    geometric mean is taken as ``exp(mean(log(rel + eps)))`` with
    ``eps = 1e-8`` guarding both the division and the logarithm.

    Parameters
    ----------
    pred : torch.Tensor
        Predictions.
    true : torch.Tensor
        Ground truth.

    Returns
    -------
    tuple
        ``(num_elements, gmrae_value)`` where ``num_elements`` is
        ``pred.numel()`` and ``gmrae_value`` is a scalar tensor.
    """
    eps = 1e-8
    rel = torch.abs(pred - true) / (torch.abs(true) + eps)
    mean_log = torch.mean(torch.log(rel + eps))
    return pred.numel(), torch.exp(mean_log)
def unnorm_mpas(pred, targ, norm_mapping, normalization_type, idxmap):
    """
    Undo per-variable normalization on 5-D prediction and target tensors.

    For every ``(channel index, variable name)`` pair in ``idxmap`` the
    matching channel of ``pred`` and ``targ`` is mapped back through the
    inverse of the normalization recorded for that variable.

    Parameters
    ----------
    pred : torch.Tensor
        Normalized predictions, indexed as
        ``(batch, channel, n_pft, n_bands, seq_length)``.
    targ : torch.Tensor
        Normalized targets, same layout as ``pred``.
    norm_mapping : dict
        ``variable name -> dict of statistics``. The statistics read depend
        on the normalization type (see Notes).
    normalization_type : dict
        ``variable name -> normalization type``. Variables missing from this
        dict are treated as ``"log1p_minmax"``.
    idxmap : dict
        ``channel index -> variable name``.

    Returns
    -------
    tuple
        ``(upred, utarg)``, tensors shaped like ``pred``/``targ`` and both
        allocated on ``pred``'s device. Channels not listed in ``idxmap``
        are left at zero.

    Raises
    ------
    ValueError
        If a variable's normalization type is not one of the nine supported.

    Notes
    -----
    With ``x`` the normalized value, the affine step is undone first and the
    pre-transform (none, ``log1p`` or ``sqrt``) second:

    ================  =======================  =========================
    type              affine inverse           statistics used
    ================  =======================  =========================
    standard          x * vstd + vmean         vmean, vstd
    minmax            x * (vmax-vmin) + vmin   vmin, vmax
    robust            x * iqr + median         median, iqr
    log1p_*           as above, then expm1     log_min/log_max, ...
    sqrt_*            as above, then square    sqrt_min/sqrt_max, ...
    ================  =======================  =========================
    """

    def _identity(x):
        return x

    def _square(x):
        return x**2

    # norm type -> (first stat key, second stat key, is_range, inverse transform)
    # is_range=True : first=lower bound, second=upper bound (scale = upper - lower)
    # is_range=False: first=offset,      second=scale
    recipes = {
        "standard": ("vmean", "vstd", False, _identity),
        "minmax": ("vmin", "vmax", True, _identity),
        "robust": ("median", "iqr", False, _identity),
        "log1p_minmax": ("log_min", "log_max", True, torch.expm1),
        "log1p_standard": ("log_mean", "log_std", False, torch.expm1),
        "log1p_robust": ("log_median", "log_iqr", False, torch.expm1),
        "sqrt_minmax": ("sqrt_min", "sqrt_max", True, _square),
        "sqrt_standard": ("sqrt_mean", "sqrt_std", False, _square),
        "sqrt_robust": ("sqrt_median", "sqrt_iqr", False, _square),
    }

    out_device = pred.device
    upred = torch.zeros_like(pred, device=out_device)
    utarg = torch.zeros_like(targ, device=out_device)

    for var_idx, var_name in idxmap.items():
        norm_type = normalization_type.get(var_name, "log1p_minmax")
        stats = norm_mapping[var_name]

        # Keep the channel axis (length 1) so the result assigns back in place.
        channel = (slice(None), slice(var_idx, var_idx + 1)) + (slice(None),) * 3
        pred_var = pred[channel]
        targ_var = targ[channel]

        if norm_type not in recipes:
            raise ValueError(
                f"Unsupported normalization type '{norm_type}' for variable '{var_name}'"
            )

        first_key, second_key, is_range, invert = recipes[norm_type]
        first = stats[first_key]
        second = stats[second_key]
        if is_range:
            offset, scale = first, second - first
        else:
            offset, scale = first, second

        upred[channel] = invert(pred_var * scale + offset)
        utarg[channel] = invert(targ_var * scale + offset)

    return upred, utarg
def conservation_residual(alb, tran, abs_flux):
    """
    Squared violation of ``albedo + transmittance + absorptance = 1``.

    ``alb`` and ``tran`` have one more entry along the last axis than
    ``abs_flux``; neighbouring entries are averaged so that all three terms
    have the same length before they are combined.

    Parameters
    ----------
    alb : torch.Tensor
        Albedo (upwelling flux), last axis of length ``seq``.
    tran : torch.Tensor
        Transmittance (downwelling flux), same shape as ``alb``.
    abs_flux : torch.Tensor
        Absorptance, last axis of length ``seq - 1``.

    Returns
    -------
    torch.Tensor
        ``(alb_center + tran_center + abs_flux - 1) ** 2``, element-wise,
        with a last axis of length ``seq - 1``.
    """

    def _midpoints(field):
        # Mean of each pair of neighbours along the last axis (N -> N-1).
        return (field[..., :-1] + field[..., 1:]) / 2.0

    imbalance = _midpoints(alb) + _midpoints(tran) + abs_flux - 1.0
    return imbalance**2
def calc_abs(pred, targ, p=None):
    """
    Derive absorption profiles and a conservation penalty from fluxes.

    Channels 0/1 (collimated) and 2/3 (isotropic) of ``pred`` and ``targ``
    are each passed to :func:`heating_rate` as an (up, down) pair. The
    predicted pairs and their absorption are then fed to
    :func:`conservation_residual`, and the two residual fields are summed
    and averaged into a single scalar penalty.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted fluxes, indexed as
        ``(batch, channel, n_pft, n_bands, seq_length)`` with channel order
        ``[collim_alb, collim_tran, isotrop_alb, isotrop_tran]``.
    targ : torch.Tensor
        Target fluxes, same layout as ``pred``.
    p : torch.Tensor, optional
        Pressure levels, forwarded unchanged to :func:`heating_rate`.
        Default is None.

    Returns
    -------
    tuple
        ``(abs12_pred, abs12_targ, abs34_pred, abs34_targ,
        conservation_penalty)``; the first four come from
        :func:`heating_rate`, the last is a scalar tensor computed from the
        predictions only.
    """

    def _up_down(fluxes, first):
        # Two consecutive channels, each kept as a length-1 channel axis.
        upward = fluxes[:, first : first + 1, :, :, :]
        downward = fluxes[:, first + 1 : first + 2, :, :, :]
        return upward, downward

    collim_pred = _up_down(pred, 0)
    collim_targ = _up_down(targ, 0)
    isotrop_pred = _up_down(pred, 2)
    isotrop_targ = _up_down(targ, 2)

    abs12_pred = heating_rate(*collim_pred, p)
    abs12_targ = heating_rate(*collim_targ, p)
    abs34_pred = heating_rate(*isotrop_pred, p)
    abs34_targ = heating_rate(*isotrop_targ, p)

    # The penalty only measures how well the *predictions* conserve energy.
    collim_resid = conservation_residual(*collim_pred, abs12_pred)
    isotrop_resid = conservation_residual(*isotrop_pred, abs34_pred)
    conservation_penalty = (collim_resid + isotrop_resid).mean()

    return abs12_pred, abs12_targ, abs34_pred, abs34_targ, conservation_penalty
def heating_rate(up, down, p=None):
    """
    Flux divergence, or heating rate, from upwelling and downwelling fluxes.

    Parameters
    ----------
    up : torch.Tensor
        Upwelling flux, shape ``(batch, 1, n_pft, n_bands, seq_length)``.
    down : torch.Tensor
        Downwelling flux, same shape as ``up``.
    p : torch.Tensor, optional
        Pressure levels, either ``(seq_length,)`` (shared by every sample)
        or ``(batch, seq_length)`` (one profile per sample). Tensors of any
        other rank are used as given and must broadcast against the flux
        differences. Default is None.

    Returns
    -------
    torch.Tensor
        Shape ``(batch, 1, n_pft, n_bands, seq_length - 1)``.

        - ``p is None``: ``-(net[k+1] - net[k])`` with ``net = up - down``.
        - ``p`` given: ``(net[k+1] - net[k]) / (p[k+1] - p[k]) * fac``.

    Notes
    -----
    ``fac = g * 8.64e4 / (cp * 100)`` with ``g = 9.8066`` and
    ``cp = 7/2 * 287 = 1004.5``. The factor 8.64e4 is the number of seconds
    per day; the factor 100 presumably converts hPa to Pa (to be confirmed
    against the units of ``p`` supplied by callers).
    """
    net = up - down
    # Difference between consecutive levels along the last axis.
    dnet = net[..., 1:] - net[..., :-1]

    if p is None:
        return -dnet

    g = 9.8066
    r = 287.0
    cp = 7.0 * r / 2.0
    fac = g * 8.64e4 / (cp * 100)

    # Align the level axis of p with the last axis of the 5-D fluxes.
    if p.dim() == 1:
        p = p.reshape(1, 1, 1, 1, -1)
    elif p.dim() == 2:
        # Per-sample profiles: keep the batch axis first. Without this the
        # (batch, seq) tensor is broadcast against (n_bands, seq) instead.
        p = p.reshape(p.shape[0], 1, 1, 1, p.shape[-1])

    dp = p[..., 1:] - p[..., :-1]
    return dnet / dp * fac
def run_validation(
    loader,
    model,
    norm_mapping,
    normalization_type,
    index_mapping,
    device,
    args,
    epoch,
    logger=None,
    base_dir="./results",
    n_pft=15,
    n_bands=2,
    n_chans=4,
):
    """
    Evaluate a model on a validation loader and collect loss and metrics.

    Parameters
    ----------
    loader : torch.utils.data.DataLoader
        Yields ``(feature, targets)`` pairs of 4-D tensors; their first two
        axes are merged into a single batch axis before the forward pass.
    model : torch.nn.Module
        Model to evaluate; switched to ``eval()`` mode here.
    norm_mapping : dict
        Normalization statistics per variable (see :func:`unnorm_mpas`).
    normalization_type : dict
        Normalization type per variable (see :func:`unnorm_mpas`).
    index_mapping : dict
        Channel index -> variable name (see :func:`unnorm_mpas`).
    device : torch.device
        Device the batches are moved to.
    args : argparse.Namespace
        Must provide ``loss_type``, ``beta`` and ``num_epochs``, plus
        ``beta_delta`` for the ``huber``/``smoothl1`` losses.
    epoch : int
        Current epoch; plots are produced when ``epoch % 10 == 0`` or
        ``epoch == args.num_epochs - 1``.
    logger : logging.Logger, optional
        Receives progress and first-batch shape diagnostics. When None, the
        two plotting progress messages are printed and the first-batch
        diagnostics are skipped.
    base_dir : str, optional
        Directory for diagnostic plots. Default is "./results".
    n_pft : int, optional
        Number of Plant Functional Types. Default is 15.
    n_bands : int, optional
        Number of spectral bands. Default is 2.
    n_chans : int, optional
        Number of output channels. Default is 4.

    Returns
    -------
    tuple
        ``(valid_loss, valid_metrics)``:

        - ``valid_loss`` (float): mean over batches of the combined loss.
        - ``valid_metrics`` (dict): keys ``"{fluxes|abs12|abs34}_{NMAE|NMSE|
          R2|MAE|MSE}"``. Entries whose key ends in "mse" (``MSE`` and
          ``NMSE``) hold the square root of the tracked mean; all others
          hold the tracked mean itself.

    Raises
    ------
    AssertionError
        If ``args.loss_type`` is not in the accepted list, or if an
        intermediate tensor has an unexpected shape.
    ValueError
        Propagated from :func:`get_loss_function`; note that ``"wmse"`` and
        ``"logcosh"`` pass the assertion here but are rejected there.

    Notes
    -----
    Per batch:

    1. Predictions and targets are reshaped to
       ``(batch, n_chans, n_pft, n_bands, seq)`` and de-normalized with
       :func:`unnorm_mpas`.
    2. :func:`calc_abs` derives absorption profiles from the de-normalized
       fluxes.
    3. The ``fluxes`` metrics and loss term use the model output and targets
       as fed to the model (before de-normalization); the ``abs12``/``abs34``
       terms use the absorption profiles from step 2.
    4. The combined loss is an element-count-weighted average::

           total = ((1-β)·L_flux·N_flux + β·(L_12·N_12 + L_34·N_34))
                   / ((1-β)·N_flux + β·(N_12 + N_34))

    Examples
    --------
    >>> valid_loss, metrics = run_validation(
    ...     loader=val_loader,
    ...     model=my_model,
    ...     norm_mapping=norm_stats,
    ...     normalization_type=norm_types,
    ...     index_mapping=idxmap,
    ...     device=torch.device('cuda'),
    ...     args=args,
    ...     epoch=10,
    ...     logger=logger
    ... )
    >>> print(f"Validation NMAE: {metrics['fluxes_NMAE']:.4f}")
    >>> print(f"R² score: {metrics['fluxes_R2']:.4f}")
    """
    model.eval()

    valid_loss_types = [
        "mse",
        "mae",
        "nmae",
        "nmse",
        "wmse",
        "logcosh",
        "smoothl1",
        "huber",
    ]
    loss_type = args.loss_type.lower()
    assert (
        loss_type in valid_loss_types
    ), f"Invalid loss_type (should be one of {valid_loss_types})"
    func = get_loss_function(loss_type, args)

    metric_names = ["NMAE", "NMSE", "R2", "MAE", "MSE"]
    metric_funcs = {
        "NMAE": nmae_all,
        "NMSE": nmse_all,
        "R2": r2_all,
        "MAE": mae_all,
        "MSE": mse_all,
    }
    output_keys = ["fluxes", "abs12", "abs34"]
    valid_metrics = {
        f"{k}_{m}": MetricTracker() for k in output_keys for m in metric_names
    }
    valid_loss = MetricTracker()

    save_plots = (epoch % 10 == 0) or (epoch == args.num_epochs - 1)
    if save_plots:
        if logger:
            logger.info("Collecting data for final epoch")
        else:
            print("Collecting data for final epoch")
        all_predicts_unnorm = []
        all_targets_unnorm = []
        all_abs12_predict = []
        all_abs12_target = []
        all_abs34_predict = []
        all_abs34_target = []

    # Progress bar for validation
    loop = tqdm(
        enumerate(loader),
        total=len(loader),
        desc=f"Validation Epoch {epoch}",
        leave=False,
    )

    with torch.no_grad():
        for batch_idx, (feature, targets) in loop:
            feature_shape = feature.shape
            target_shape = targets.shape
            # Merge the two leading axes into one batch axis.
            inner_batch_size = feature_shape[0] * feature_shape[1]
            feature = feature.reshape(
                inner_batch_size, feature_shape[2], feature_shape[3]
            ).to(device=device)
            targets = targets.reshape(
                inner_batch_size, target_shape[2], target_shape[3]
            ).to(device=device)

            predicts = model(feature)

            # Reshape to (batch, n_chans, n_pft, n_bands, seq)
            pred_reshaped = predicts.reshape(
                inner_batch_size, n_chans, n_pft, n_bands, target_shape[3]
            )
            targ_reshaped = targets.reshape(
                inner_batch_size, n_chans, n_pft, n_bands, target_shape[3]
            )

            predicts_unnorm, targets_unnorm = unnorm_mpas(
                pred_reshaped,
                targ_reshaped,
                norm_mapping,
                normalization_type,
                index_mapping,
            )
            assert (
                predicts_unnorm.shape == pred_reshaped.shape
            ), f"Expected predicts_unnorm shape {pred_reshaped.shape}, got {predicts_unnorm.shape}"
            assert (
                targets_unnorm.shape == targ_reshaped.shape
            ), f"Expected targets_unnorm shape {targ_reshaped.shape}, got {targets_unnorm.shape}"

            (
                abs12_predict,
                abs12_target,
                abs34_predict,
                abs34_target,
                conservation_penalty,
            ) = calc_abs(predicts_unnorm, targets_unnorm)

            expected_abs_shape = (
                inner_batch_size,
                1,
                n_pft,
                n_bands,
                target_shape[3] - 1,
            )
            for abs_name, abs_tensor in (
                ("abs12_predict", abs12_predict),
                ("abs12_target", abs12_target),
                ("abs34_predict", abs34_predict),
                ("abs34_target", abs34_target),
            ):
                assert (
                    abs_tensor.shape == expected_abs_shape
                ), f"Expected {abs_name} shape {expected_abs_shape}, got {abs_tensor.shape}"

            # One-off shape diagnostics. logger defaults to None, so this
            # must be guarded or the default call crashes on the first batch.
            if batch_idx == 0 and logger is not None:
                logger.info(f"Feature shape: {feature.shape}")
                logger.info(f"Targets shape: {targets.shape}")
                logger.info(f"abs12_predict shape: {abs12_predict.shape}")
                logger.info(f"abs12_target shape: {abs12_target.shape}")
                logger.info(f"abs34_predict shape: {abs34_predict.shape}")
                logger.info(f"abs34_target shape: {abs34_target.shape}")
                logger.info(f"Conservation penalty: {conservation_penalty.item():.6f}")

            if save_plots:
                all_predicts_unnorm.append(predicts_unnorm.cpu())
                all_targets_unnorm.append(targets_unnorm.cpu())
                all_abs12_predict.append(abs12_predict.cpu())
                all_abs12_target.append(abs12_target.cpu())
                all_abs34_predict.append(abs34_predict.cpu())
                all_abs34_target.append(abs34_target.cpu())

            output_dict = {
                "fluxes": (predicts, targets),
                "abs12": (abs12_predict, abs12_target),
                "abs34": (abs34_predict, abs34_target),
            }
            for key in output_keys:
                pred, tgt = output_dict[key]
                for metric in metric_names:
                    count, value = metric_funcs[metric](pred, tgt)
                    valid_metrics[f"{key}_{metric}"].update(value.item(), count)

            main_count, main_val = predicts.numel(), func(predicts, targets)
            abs12_count, abs12_val = (
                abs12_predict.numel(),
                func(abs12_predict, abs12_target),
            )
            abs34_count, abs34_val = (
                abs34_predict.numel(),
                func(abs34_predict, abs34_target),
            )

            # Element-count-weighted blend of flux and absorption losses.
            weighted_loss = (1.0 - args.beta) * main_val * main_count + args.beta * (
                abs12_val * abs12_count + abs34_val * abs34_count
            )
            total_count = (1.0 - args.beta) * main_count + args.beta * (
                abs12_count + abs34_count
            )
            total_loss = weighted_loss / total_count

            valid_loss.update(total_loss.item(), 1)
            loop.set_postfix(loss=total_loss.item())

    if save_plots:
        if logger:
            logger.info("Doing plot for final epoch")
        else:
            print("Doing plot for final epoch")
        os.makedirs(base_dir, exist_ok=True)

        all_predicts_unnorm = torch.cat(all_predicts_unnorm, dim=0)
        all_targets_unnorm = torch.cat(all_targets_unnorm, dim=0)
        all_abs12_predict = torch.cat(all_abs12_predict, dim=0)
        all_abs12_target = torch.cat(all_abs12_target, dim=0)
        all_abs34_predict = torch.cat(all_abs34_predict, dim=0)
        all_abs34_target = torch.cat(all_abs34_target, dim=0)

        plot_all_diagnostics(
            all_predicts_unnorm,
            all_targets_unnorm,
            abs12_predict=all_abs12_predict,
            abs12_target=all_abs12_target,
            abs34_predict=all_abs34_predict,
            abs34_target=all_abs34_target,
            n_pft=n_pft,
            n_bands=n_bands,
            output_dir=base_dir,
            prefix=f"validation_epoch{epoch}",
            logger=logger,
        )

    # Keys ending in "mse" (MSE, NMSE) are reported as root-mean values.
    return valid_loss.getmean(), {
        k: (tracker.getsqrtmean() if k.lower().endswith("mse") else tracker.getmean())
        for k, tracker in valid_metrics.items()
    }