Source code for rtnn.diagnostics

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

"""
Diagnostics and visualization utilities for RTnn radiative transfer model.

This module provides comprehensive visualization tools for analyzing and
evaluating neural network emulators of atmospheric radiative transfer. It
includes functions for creating diagnostic plots of model predictions,
training histories, and data statistics.

The module is designed to support:
- Comparison of model predictions against ground truth targets
- Visualization of radiative fluxes (direct/diffuse, upwelling/downwelling)
- Absorption rate analysis across vertical levels and spectral bands
- Per-plant functional type (PFT) and per-spectral band diagnostics
- Training and validation metric tracking over epochs
- Spatial and temporal sampling distribution analysis
- Statistical characterization of input variables for normalization

Key visualization types
-----------------------
- **Hexbin plots**: Density scatter plots with color mapping for high-volume data
- **Line plots**: Time series or vertical profile comparisons
- **Histograms**: Distribution analysis with optional log scaling
- **Marginal histograms**: 2D density plots with side distributions
- **Multi-panel layouts**: Systematic comparison across flux channels

The module follows the matplotlib architecture using Figure and FigureCanvasAgg
for non-interactive, file-based rendering suitable for batch processing and
headless environments.

Notes
-----
All plotting functions are designed to work with PyTorch tensors or numpy arrays
and save figures directly to disk. The module uses custom matplotlib parameters
optimized for scientific publication quality.

Default figure parameters:
- Font family: DejaVu Sans
- Font sizes: axes labels (15), titles (15), ticks (12), legends (15)
- Line widths: 2
- Tick direction: outward
- Legend: frame off, best location

Examples
--------
Basic usage for generating diagnostic plots:

>>> import torch
>>> from rtnn.diagnostics import plot_flux_and_abs, plot_metric_histories
>>>
>>> # Assuming you have model predictions and targets
>>> predicts = torch.randn(32, 4, 10)  # (batch, channels, levels)
>>> targets = torch.randn(32, 4, 10)
>>>
>>> # Create hexbin plot
>>> plot_flux_and_abs(
...     predicts=predicts,
...     targets=targets,
...     filename="diagnostics.png",
...     logger=logger
... )
>>>
>>> # Plot training history
>>> train_history = {"nmae": [0.1, 0.08, 0.06], "r2": [0.85, 0.88, 0.91]}
>>> valid_history = {"nmae": [0.12, 0.10, 0.08], "r2": [0.82, 0.85, 0.89]}
>>> plot_metric_histories(
...     train_history=train_history,
...     valid_history=valid_history,
...     filename="metrics.png"
... )

See Also
--------
rtnn.dataset.DataPreprocessor : Data loading and preprocessing
rtnn.evaluater : Metrics and loss functions for evaluation
rtnn.models : Neural network architectures for radiative transfer
"""

from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
import matplotlib.gridspec as gridspec
import numpy as np
import mpltex
import math
from matplotlib import rcParams as mpl
from sklearn.metrics import r2_score
import matplotlib.ticker as ticker
import random
import xarray as xr
import os
import collections

params = {
    "font.family": "DejaVu Sans",
    #    'figure.dpi': 300,
    #    'savefig.dpi': 300,
    "lines.linewidth": 2,
    "lines.dashed_pattern": [4, 2],
    "lines.dashdot_pattern": [6, 3, 2, 3],
    "lines.dotted_pattern": [2, 3],
    "mathtext.rm": "arial",
    "axes.labelsize": 15,
    "axes.titlesize": 15,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "xtick.major.size": 6,
    "ytick.major.size": 6,
    "legend.fontsize": 15,
    "legend.loc": "best",
    "legend.frameon": False,
    "xtick.direction": "out",
    "ytick.direction": "out",
}
mpl.update(params)


[docs] def stats(file_list, logger, output_dir, norm_mapping=None, plots=False): """ Compute statistics and generate histograms for variables in NetCDF files. Reads a collection of NetCDF files, computes descriptive statistics for each variable, and generates histogram plots saved to disk. In addition to raw statistics, transformed statistics using logarithmic (log1p) and square-root transformations are also computed. Parameters ---------- file_list : list of str Paths to the NetCDF files to process. logger : logging.Logger Logger used to report progress and informational messages. output_dir : str Directory where histogram plots will be saved. norm_mapping : dict, optional Dictionary to update with computed statistics. If None, a new dictionary is created. Default is None. plots : bool, optional If True, generate and save histogram plots for each variable. Default is False. Returns ------- dict Dictionary mapping variable names to their computed statistics. Each variable contains the following entries: Raw statistics: - vmin : float Minimum value - vmax : float Maximum value - vmean : float Mean value - vstd : float Standard deviation Robust statistics: - q1 : float First quartile (25th percentile) - q3 : float Third quartile (75th percentile) - iqr : float Interquartile range (q3 - q1) - median : float Median value Log-transformed statistics (log1p): - log_min : float - log_max : float - log_mean : float - log_std : float - log_q1 : float - log_q3 : float - log_iqr : float - log_median : float Square-root-transformed statistics: - sqrt_min : float - sqrt_max : float - sqrt_mean : float - sqrt_std : float - sqrt_q1 : float - sqrt_q3 : float - sqrt_iqr : float - sqrt_median : float Examples -------- >>> norm_mapping = stats( ... file_list=["data_1995.nc", "data_1996.nc"], ... logger=logger, ... output_dir="./stats", ... plots=True ... ) >>> norm_mapping["coszang"]["vmean"] 0.5 >>> norm_mapping["lai"]["log_median"] 1.23 Notes ----- - Log transformation uses np.log1p (log(1+x)) to handle zero values - Square root transformation uses np.sqrt with clipping to non-negative - Histograms use 200 bins and log-scaled y-axis """ variable_data = collections.defaultdict(list) if norm_mapping is None: norm_mapping = {} logger.info("Starting statistics computation for normalization") for fpath in file_list: try: ds = xr.open_dataset(fpath) logger.info(f"Processing file: {fpath}") for var_name in ds.data_vars: data = ds[var_name].values variable_data[var_name].append(data.flatten()) ds.close() except Exception as e: logger.warning(f"Skipping file {fpath} due to error: {e}") continue for var_name, arrays in variable_data.items(): full_data = np.concatenate(arrays) if full_data.size == 0: logger.warning(f"{var_name} is empty after filtering, skipping.") continue vmin = float(np.min(full_data)) vmax = float(np.max(full_data)) vmean = float(np.mean(full_data)) vstd = float(np.std(full_data)) q1 = float(np.percentile(full_data, 25)) q3 = float(np.percentile(full_data, 75)) iqr = q3 - q1 if q3 != q1 else 1.0 median = float(np.median(full_data)) log_data = np.log1p(np.clip(full_data, a_min=0, a_max=None)) log_min = float(log_data.min()) log_max = float(log_data.max()) log_mean = float(log_data.mean()) log_std = float(log_data.std()) log_q1 = float(np.percentile(log_data, 25)) log_q3 = float(np.percentile(log_data, 75)) log_iqr = log_q3 - log_q1 if log_q3 != log_q1 else 1.0 log_median = float(np.median(log_data)) sqrt_data = np.sqrt(np.clip(full_data, a_min=0, a_max=None)) sqrt_min = float(sqrt_data.min()) sqrt_max = float(sqrt_data.max()) sqrt_mean = float(sqrt_data.mean()) sqrt_std = float(sqrt_data.std()) sqrt_q1 = float(np.percentile(sqrt_data, 25)) sqrt_q3 = float(np.percentile(sqrt_data, 75)) sqrt_iqr = sqrt_q3 - sqrt_q1 if sqrt_q3 != sqrt_q1 else 1.0 sqrt_median = float(np.median(sqrt_data)) norm_mapping[var_name] = { "vmin": vmin, "vmax": vmax, "vmean": vmean, "vstd": vstd, "q1": q1, "q3": q3, "iqr": iqr, "median": median, "log_min": log_min, "log_max": log_max, "log_mean": log_mean, "log_std": log_std, "log_q1": log_q1, "log_q3": log_q3, "log_iqr": log_iqr, "log_median": log_median, "sqrt_min": sqrt_min, "sqrt_max": sqrt_max, "sqrt_mean": sqrt_mean, "sqrt_std": sqrt_std, "sqrt_q1": sqrt_q1, "sqrt_q3": sqrt_q3, "sqrt_iqr": sqrt_iqr, "sqrt_median": sqrt_median, } if plots: norm_label = "" file_suffix = "_histogram.png" fig = Figure(figsize=(8, 5)) canvas = FigureCanvasAgg(fig) ax = fig.add_subplot(111) ax.hist(full_data, bins=200) ax.set_yscale("log") ax.set_title(f"Histogram of {var_name}{norm_label}") ax.set_xlabel(var_name + norm_label) ax.set_ylabel("Log Count") ax.grid(True) out_path = os.path.join(output_dir, f"{var_name}{file_suffix}") canvas.print_figure(out_path, bbox_inches="tight") return norm_mapping
[docs] def subplots(nrows, ncols, figsize): """ Create a figure and grid of subplots without pyplot. This is a replacement for matplotlib.pyplot.subplots() that uses the Figure and FigureCanvasAgg API directly, suitable for headless environments. Parameters ---------- nrows : int Number of rows in the subplot grid. ncols : int Number of columns in the subplot grid. figsize : tuple Figure size in inches as (width, height). Returns ------- fig : matplotlib.figure.Figure The created figure object. axes : numpy.ndarray Array of axes objects with shape (nrows, ncols). """ fig = Figure(figsize=figsize) FigureCanvasAgg(fig) axes = [] for i in range(nrows): row_axes = [] for j in range(ncols): ax = fig.add_subplot(nrows, ncols, i * ncols + j + 1) row_axes.append(ax) axes.append(row_axes) return fig, np.array(axes)
[docs] def plot_flux_and_abs_lines( predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, filename="output_lines.png", logger=None, ): """ Create line plots for fluxes and absorption rates. Generates a multi-panel figure with line plots for four flux channels and optionally two absorption panels. Each panel shows predictions vs targets across vertical levels. Parameters ---------- predicts : torch.Tensor or np.ndarray Model predictions for fluxes of shape (batch_size, 4, seq_length). targets : torch.Tensor or np.ndarray Ground truth fluxes of shape (batch_size, 4, seq_length). abs12_predict : torch.Tensor or np.ndarray, optional Predicted absorption for channels 1-2 (direct beam). abs12_target : torch.Tensor or np.ndarray, optional True absorption for channels 1-2 (direct beam). abs34_predict : torch.Tensor or np.ndarray, optional Predicted absorption for channels 3-4 (diffuse). abs34_target : torch.Tensor or np.ndarray, optional True absorption for channels 3-4 (diffuse). filename : str, optional Output filename. Default is "output_lines.png". logger : logging.Logger, optional Logger instance for logging messages. If None, no logging is performed. Notes ----- Figure layout: - 2x2 grid for fluxes: - (0,0): Direct upwelling (Flux_direct^u) - (0,1): Direct downwelling (Flux_direct^d) - (1,0): Diffuse upwelling (Flux_diffusion^u) - (1,1): Diffuse downwelling (Flux_diffusion^d) - Optional 1x2 grid for absorption rates (if provided): - (2,0): Direct absorption (Abs_direct) - (2,1): Diffuse absorption (Abs_diffusion) The function randomly selects 10 samples from the batch for plotting. Each sample is shown with a different line style. Examples -------- >>> plot_flux_and_abs_lines( ... predicts=model_outputs, ... targets=ground_truth, ... filename="line_comparison.png", ... logger=logger ... ) """ include_abs12 = abs12_predict is not None and abs12_target is not None include_abs34 = abs34_predict is not None and abs34_target is not None include_abs = include_abs12 and include_abs34 if include_abs: fig, axes = subplots(3, 2, figsize=(10, 15)) else: fig, axes = subplots(2, 2, figsize=(10, 10)) fig.subplots_adjust( hspace=0.3, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1 ) canvas = FigureCanvasAgg(fig) num_samples = predicts.shape[0] sample_indices = random.sample(range(num_samples), 10) index_map = {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3} name_dict = { (0, 0): {"plotname": r"$\mathrm{Flux_{direct}^{u}}$"}, (0, 1): {"plotname": r"$\mathrm{Flux_{direct}^{d}}$"}, (1, 0): {"plotname": r"$\mathrm{Flux_{diffusion}^{u}}$"}, (1, 1): {"plotname": r"$\mathrm{Flux_{diffusion}^{d}}$"}, } legend_lines = [] legend_labels = [] for (r, c), props in name_dict.items(): flux_idx = index_map[(r, c)] ax = axes[r, c] linestyles = mpltex.linestyle_generator() for sample_index in sample_indices: (pred_line,) = ax.plot( predicts[sample_index, flux_idx, :], label="Predict", **next(linestyles) ) (true_line,) = ax.plot( targets[sample_index, flux_idx, :], label="True", **next(linestyles) ) if (r, c) == (0, 0): legend_lines.extend([pred_line, true_line]) legend_labels.extend([pred_line.get_label(), true_line.get_label()]) ax.set_xlabel(r"$\mathrm{Vertical\ Level}$") ax.set_ylabel(props["plotname"]) ax.set_xlim(0, predicts.shape[-1] - 1) ax.set_ylim(0, 1) ax.xaxis.set_major_locator(ticker.MultipleLocator(3)) ax.yaxis.set_major_locator(ticker.MultipleLocator(0.25)) if include_abs: ax = axes[2, 0] linestyles = mpltex.linestyle_generator() for sample_index in sample_indices: (pred_line,) = ax.plot( abs12_predict[sample_index, 0, :], label="Predict)", **next(linestyles), ) (true_line,) = ax.plot( abs12_target[sample_index, 0, :], label="True", **next(linestyles), ) ax.set_ylabel(r"$\mathrm{Abs_{direct}}$") ax.set_xlabel(r"$\mathrm{Vertical\ Level}$") ax.set_xlim(0, abs12_predict.shape[-1] - 1) ax.set_ylim(0, 1) ax.xaxis.set_major_locator(ticker.MultipleLocator(3)) ax.yaxis.set_major_locator(ticker.MultipleLocator(0.25)) ax = axes[2, 1] linestyles = mpltex.linestyle_generator() for sample_index in sample_indices: (pred_line,) = ax.plot( abs34_predict[sample_index, 0, :], label="Predict", **next(linestyles), ) (true_line,) = ax.plot( abs34_target[sample_index, 0, :], label="True", **next(linestyles), ) ax.set_ylabel(r"$\mathrm{Abs_{diffusion}}$") ax.set_xlabel(r"$\mathrm{Vertical\ Level}$") ax.set_xlim(0, abs34_predict.shape[-1] - 1) ax.set_ylim(0, 1) ax.xaxis.set_major_locator(ticker.MultipleLocator(3)) ax.yaxis.set_major_locator(ticker.MultipleLocator(0.25)) fig.legend( handles=legend_lines, labels=legend_labels, loc="center right", bbox_to_anchor=(1.1, 0.5), borderaxespad=0.5, frameon=False, ncol=1, ) canvas.print_figure(filename, bbox_inches="tight") if logger: logger.info(f"Saved line plot to {filename}") else: print(f"Saved line plot to {filename}")
[docs] def plot_flux_and_abs( predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, filename="output.png", logger=None, ): """ Create hexbin plots for fluxes and absorption rates. Generates a multi-panel figure with hexbin density plots showing the relationship between predicted and true values. Useful for assessing prediction accuracy across the entire dataset. Parameters ---------- predicts : torch.Tensor or np.ndarray Model predictions for fluxes of shape (batch_size, 4, seq_length). targets : torch.Tensor or np.ndarray Ground truth fluxes of shape (batch_size, 4, seq_length). abs12_predict : torch.Tensor or np.ndarray, optional Predicted absorption for channels 1-2 (direct beam). abs12_target : torch.Tensor or np.ndarray, optional True absorption for channels 1-2 (direct beam). abs34_predict : torch.Tensor or np.ndarray, optional Predicted absorption for channels 3-4 (diffuse). abs34_target : torch.Tensor or np.ndarray, optional True absorption for channels 3-4 (diffuse). filename : str, optional Output filename. Default is "output.png". logger : logging.Logger, optional Logger instance for logging messages. If None, no logging is performed. Notes ----- - Hexbin plots use logarithmic color scale (bins='log') - Includes diagonal reference line (y=x) in red dotted style - Displays R² score in the top-left corner of each panel - Shared colorbar on the right with log10(count) label - Grid size of 100 bins for density calculation The hexbin representation is particularly effective for large datasets where scatter plots would suffer from overplotting. Examples -------- >>> plot_flux_and_abs( ... predicts=predictions, ... targets=targets, ... filename="hexbin_diagnostics.png", ... logger=logger ... ) """ include_abs12 = abs12_predict is not None and abs12_target is not None include_abs34 = abs34_predict is not None and abs34_target is not None include_abs = include_abs12 and include_abs34 if include_abs: fig, axes = subplots(3, 2, figsize=(10, 15)) else: fig, axes = subplots(2, 2, figsize=(10, 10)) fig.subplots_adjust( hspace=0.3, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1 ) canvas = FigureCanvasAgg(fig) index_map = {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3} name_dict = { (0, 0): {"name": "Flux1u", "plotname": r"$\mathrm{Flux_{direct}^{u}}$"}, (0, 1): {"name": "Flux1d", "plotname": r"$\mathrm{Flux_{direct}^{d}}$"}, (1, 0): {"name": "Flux2u", "plotname": r"$\mathrm{Flux_{diffusion}^{u}}$"}, (1, 1): {"name": "Flux2d", "plotname": r"$\mathrm{Flux_{diffusion}^{d}}$"}, } for (r, c), props in name_dict.items(): flux_idx = index_map[(r, c)] y_pred = predicts[:, flux_idx, :].reshape(-1) y_true = targets[:, flux_idx, :].reshape(-1) ax = axes[r, c] hb = ax.hexbin( y_true, y_pred, gridsize=100, cmap="jet", bins="log", vmin=1, vmax=1e6 ) ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f")) ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f")) ax.plot( [y_true.min(), y_true.max()], [y_true.min(), y_true.max()], "r:", linewidth=0.5, ) r2 = r2_score(y_true, y_pred) ax.text(0.05, 0.9, f"$R^2$: {r2:.5f}", transform=ax.transAxes) flux_name = props["plotname"] ax.set_title(rf"{flux_name}") ax.set_xlabel("Observed") ax.set_ylabel("Predicted") if include_abs: ax = axes[2, 0] y_pred = abs12_predict.reshape(-1) y_true = abs12_target.reshape(-1) hb = ax.hexbin( y_true, y_pred, gridsize=100, cmap="jet", bins="log", vmin=1, vmax=1e6 ) ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f")) ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f")) ax.plot( [y_true.min(), y_true.max()], [y_true.min(), y_true.max()], "r:", linewidth=0.5, ) r2 = r2_score(y_true, y_pred) ax.text(0.05, 0.9, f"$R^2$: {r2:.5f}", transform=ax.transAxes) ax.set_title(r"$\mathrm{Abs_{direct}}$") ax.set_xlabel("Observed") ax.set_ylabel("Predicted") ax = axes[2, 1] y_pred = abs34_predict.reshape(-1) y_true = abs34_target.reshape(-1) hb = ax.hexbin( y_true, y_pred, gridsize=100, cmap="jet", bins="log", vmin=1, vmax=1e6 ) ax.plot( [y_true.min(), y_true.max()], [y_true.min(), y_true.max()], "r:", linewidth=0.5, ) ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f")) ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f")) r2 = r2_score(y_true, y_pred) ax.text(0.05, 0.9, f"$R^2$: {r2:.5f}", transform=ax.transAxes) ax.set_title(r"$\mathrm{Abs_{diffusion}}$") ax.set_xlabel("Observed") ax.set_ylabel("Predicted") # Shared colorbar on the right cbar_ax = fig.add_axes([0.92, 0.1, 0.015, 0.8]) fig.colorbar(hb, cax=cbar_ax, label=r"$\mathrm{\log_{10}[Count]}$") canvas.print_figure(filename, bbox_inches="tight") if logger: logger.info(f"Saved hexbin plot to {filename}") else: print(f"Saved hexbin plot to {filename}")
[docs] def plot_all_diagnostics( predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, n_pft=15, n_bands=2, n_chans=4, output_dir="./results", prefix="diagnostics", logger=None, ): """ Generate all diagnostic plots: aggregated, per PFT, per band. This function creates a comprehensive set of diagnostic plots including one aggregated plot (all PFTs, all bands combined), per-band plots for each selected PFT (VIS and NIR bands), and line plots and hexbin plots for each combination. Parameters ---------- predicts : torch.Tensor or np.ndarray Model predictions for fluxes. Shape: (batch, 4, n_pft, n_bands, seq_length) targets : torch.Tensor or np.ndarray Ground truth targets. Same shape as predicts. abs12_predict : torch.Tensor or np.ndarray, optional Predicted absorption for channels 1-2 (direct beam). abs12_target : torch.Tensor or np.ndarray, optional True absorption for channels 1-2. abs34_predict : torch.Tensor or np.ndarray, optional Predicted absorption for channels 3-4 (diffuse). abs34_target : torch.Tensor or np.ndarray, optional True absorption for channels 3-4. n_pft : int, optional Number of Plant Functional Types. Default is 15. n_bands : int, optional Number of spectral bands. Default is 2 (VIS and NIR). n_chans : int, optional Number of output channels. Default is 4. output_dir : str, optional Directory to save diagnostic plots. Default is "./results". prefix : str, optional Prefix for output filenames. Default is "diagnostics". logger : logging.Logger, optional Logger instance for logging messages. If None, no logging is performed. Notes ----- The function generates 1 aggregated hexbin plot, 1 aggregated line plot, and for each selected PFT (up to 8 randomly selected): hexbin and line plots per band (VIS and NIR). Total plots = 1 + (selected_pfts * n_bands * 2) where selected_pfts = min(8, n_pft). Examples -------- >>> plot_all_diagnostics( ... predicts=model_outputs, ... targets=targets, ... n_pft=15, ... n_bands=2, ... output_dir="./diagnostics", ... prefix="experiment_1" ... ) """ predicts_np = predicts.detach().numpy() targets_np = targets.detach().numpy() if abs12_predict is not None: abs12_predict_np = abs12_predict.detach().numpy() abs12_target_np = abs12_target.detach().numpy() abs34_predict_np = abs34_predict.detach().numpy() abs34_target_np = abs34_target.detach().numpy() else: abs12_predict_np = None abs12_target_np = None abs34_predict_np = None abs34_target_np = None # 1. Aggregated plot (all PFTs, all bands) plot_flux_and_abs( predicts_np, targets_np, abs12_predict_np, abs12_target_np, abs34_predict_np, abs34_target_np, filename=os.path.join(output_dir, f"{prefix}_aggregated.png"), logger=logger, ) # 2. Randomly select 8 PFTs (or fewer if n_pft < 8) num_pft_to_plot = min(8, n_pft) selected_pfts = random.sample(range(n_pft), num_pft_to_plot) if logger: logger.info(f"Selected PFTs for detailed plots: {selected_pfts}") # 3. For each selected PFT, create plots for each band for pft_idx in selected_pfts: for band_idx in range(n_bands): # Extract data for specific PFT and band - shape (batch, 4, seq) predicts_pft_band = predicts_np[:, :, pft_idx, band_idx, :] targets_pft_band = targets_np[:, :, pft_idx, band_idx, :] if abs12_predict_np is not None: abs12_pft_band = abs12_predict_np[:, :, pft_idx, band_idx, :] abs12_target_pft_band = abs12_target_np[:, :, pft_idx, band_idx, :] abs34_pft_band = abs34_predict_np[:, :, pft_idx, band_idx, :] abs34_target_pft_band = abs34_target_np[:, :, pft_idx, band_idx, :] else: abs12_pft_band = None abs12_target_pft_band = None abs34_pft_band = None abs34_target_pft_band = None band_name = "VIS" if band_idx == 0 else "NIR" base_filename = os.path.join( output_dir, f"{prefix}_pft{pft_idx:02d}_{band_name}" ) # Hexbin plot plot_flux_and_abs( predicts_pft_band, targets_pft_band, abs12_pft_band, abs12_target_pft_band, abs34_pft_band, abs34_target_pft_band, filename=f"{base_filename}_hexbin.png", logger=logger, ) # Line plot plot_flux_and_abs_lines( predicts_pft_band, targets_pft_band, abs12_pft_band, abs12_target_pft_band, abs34_pft_band, abs34_target_pft_band, filename=f"{base_filename}_lines.png", logger=logger, ) if logger: logger.info( f"Generated {1 + num_pft_to_plot * n_bands} diagnostic plots in {output_dir}" ) else: print( f"Generated {1 + num_pft_to_plot * n_bands} diagnostic plots in {output_dir}" )
[docs] def plot_metric_histories( train_history, valid_history, filename="training_validation_metrics.png", logger=None, ): """ Plot training and validation metrics over epochs. Creates a multi-panel figure showing the evolution of various metrics (e.g., NMAE, NMSE, R2, bias) over training epochs. Parameters ---------- train_history : dict Dictionary with metric names as keys and lists of training values. Example: {"nmae": [0.1, 0.08, 0.06], "r2": [0.85, 0.88, 0.91]} valid_history : dict Dictionary with metric names as keys and lists of validation values. Same structure as train_history. filename : str, optional Output filename. Default is "training_validation_metrics.png". logger : logging.Logger, optional Logger instance for logging messages. If None, no logging is performed. Notes ----- - Metrics are plotted on a logarithmic scale for better visualization of convergence - Each metric gets its own panel arranged in a grid (3 columns) - Blue lines: training metrics - Orange lines: validation metrics - Grid is enabled for all panels Examples -------- >>> train_history = {"nmae": [0.15, 0.12, 0.09, 0.07], ... "nmse": [0.08, 0.06, 0.04, 0.03]} >>> valid_history = {"nmae": [0.16, 0.13, 0.10, 0.08], ... "nmse": [0.09, 0.07, 0.05, 0.04]} >>> plot_metric_histories( ... train_history=train_history, ... valid_history=valid_history, ... filename="metrics.png", ... logger=logger ... ) """ num_metrics = len(train_history) if num_metrics == 0: if logger: logger.warning("No metrics to plot") return cols = 3 rows = math.ceil(num_metrics / cols) fig = Figure(figsize=(5 * cols, 4 * rows)) fig.set_tight_layout(True) canvas = FigureCanvasAgg(fig) gs = gridspec.GridSpec(rows, cols) for idx, key in enumerate(train_history): row, col = divmod(idx, cols) ax = fig.add_subplot(gs[row, col]) linestyles = mpltex.linestyle_generator(markers=[]) ax.plot(train_history[key], label="train", **next(linestyles)) ax.plot(valid_history[key], label="valid", **next(linestyles)) ax.set_yscale("log") # ax.set_title(key.replace('_', ' ').upper()) ax.set_xlabel("Epoch") ax.set_ylabel(key.replace("_", " ").upper()) ax.legend() ax.grid(True) canvas.print_figure(filename, bbox_inches="tight") if logger: logger.info(f"Saved metric history plot to {filename}") else: print(f"Saved metric history plot to {filename}")
[docs] def plot_loss_histories( train_loss, valid_loss, filename="training_validation_loss.png", logger=None ): """ Plot training and validation loss over epochs. Creates a single-panel figure showing the loss evolution during training. Parameters ---------- train_loss : list or array Training loss values over epochs. valid_loss : list or array Validation loss values over epochs. filename : str, optional Output filename. Default is "training_validation_loss.png". logger : logging.Logger, optional Logger instance for logging messages. If None, no logging is performed. Notes ----- - Uses logarithmic scale for y-axis to visualize exponential decay - Blue line: training loss - Orange line: validation loss - Includes grid for better readability - Both lines share the same x-axis (epoch number) Examples -------- >>> train_losses = [0.5, 0.3, 0.2, 0.15, 0.12] >>> valid_losses = [0.55, 0.35, 0.25, 0.18, 0.15] >>> plot_loss_histories( ... train_loss=train_losses, ... valid_loss=valid_losses, ... filename="loss_curves.png", ... logger=logger ... ) """ fig = Figure(figsize=(8, 5)) canvas = FigureCanvasAgg(fig) ax = fig.add_subplot(111) linestyles = mpltex.linestyle_generator(markers=[]) ax.plot(train_loss, label="train", **next(linestyles)) ax.plot(valid_loss, label="valid", **next(linestyles)) ax.set_yscale("log") ax.set_title("LOSS") ax.set_xlabel("Epoch") ax.set_ylabel("Loss Value") ax.legend() ax.grid(True) canvas.print_figure(filename, bbox_inches="tight") if logger: logger.info(f"Saved loss history plot to {filename}") else: print(f"Saved loss history plot to {filename}")
[docs] def plot_spatial_temporal_density( sindex_tracker, tindex_tracker, mode="train", save_dir="./tests_plots", filename="density_scatter", figsize=(10, 10), logger=None, ): """ Plot a density scatter plot of spatial index vs temporal index with marginal histograms. This function creates a 2D density scatter plot (hexbin) showing the distribution of spatial indices (processor ranks) across temporal indices, with: - Right plot: Histogram of temporal index distribution (horizontal bars) - Top plot: Histogram of spatial index distribution (vertical bars) Parameters ---------- sindex_tracker : list or array-like List of spatial indices (processor ranks) for each data sample. tindex_tracker : list or array-like List of temporal indices for each data sample. mode : str, optional Dataset mode identifier ("train", "validation", "test"). Used in filename. Default is "train". save_dir : str, optional Directory path where the plot will be saved. Default is "./tests_plots". filename : str, optional Base name for the output file. Default is "density_scatter". figsize : tuple, optional Figure size as (width, height) in inches. Default is (10, 10). logger : logging.Logger, optional Logger instance for logging messages. If None, no logging is performed. Returns ------- str or None Path to the saved plot file, or None if no data to plot. Notes ----- The plot layout uses GridSpec with the following structure: - Top-left: Histogram of spatial index distribution (vertical bars) - Top-right: Empty (no plot) - Bottom-left: Main hexbin density scatter plot - Bottom-right: Histogram of temporal index distribution (horizontal bars) A colorbar is placed below the main plot with log-scaled counts. This visualization is useful for: - Verifying uniform sampling across spatial processors - Checking temporal coverage of the dataset - Detecting sampling biases in training/validation splits Examples -------- >>> # After training, check sampling distribution >>> plot_spatial_temporal_density( ... sindex_tracker=dataset.sindex_tracker, ... tindex_tracker=dataset.tindex_tracker, ... mode="train", ... save_dir="./analysis", ... logger=logger ... ) """ if len(sindex_tracker) == 0 or len(tindex_tracker) == 0: print(f"No data to plot for {mode} mode") return None # Convert to numpy arrays sindex_tracker = np.array(sindex_tracker) tindex_tracker = np.array(tindex_tracker) # Get limits min_sindex = int(sindex_tracker.min()) max_sindex = int(sindex_tracker.max()) min_time = int(tindex_tracker.min()) max_time = int(tindex_tracker.max()) # Create figure with GridSpec for custom layout fig = Figure(figsize=figsize) canvas = FigureCanvasAgg(fig) # Define grid: # - Top histogram takes 20% height # - Main hexbin takes 80% height # - Right histogram takes 20% width # - Left main takes 80% width # Add space between panels gs = gridspec.GridSpec( 2, 2, figure=fig, height_ratios=[0.2, 0.8], width_ratios=[0.8, 0.2], hspace=0.2, wspace=0.2, ) # Main hexbin plot (bottom-left) ax_main = fig.add_subplot(gs[1, 0]) # Right histogram (bottom-right) - temporal index distribution ax_right = fig.add_subplot(gs[1, 1]) # Top histogram (top-left) - spatial index distribution ax_top = fig.add_subplot(gs[0, 0]) # Top-right corner is empty ax_empty = fig.add_subplot(gs[0, 1]) ax_empty.axis("off") # Create density scatter plot (hexbin) in main axis hb = ax_main.hexbin( sindex_tracker, tindex_tracker, gridsize=100, extent=[min_sindex - 0.5, max_sindex + 0.5, min_time, max_time], cmap="jet", bins="log", mincnt=1, edgecolors="none", ) ax_main.set_xlabel("Spatial Index (Processor Rank)") ax_main.set_ylabel("Temporal Index") ax_main.set_xlim(min_sindex - 0.5, max_sindex + 0.5) ax_main.set_ylim(min_time, max_time) ax_main.grid(True, alpha=0.3, linestyle="--") # Right plot: Histogram of temporal index distribution (horizontal bars) unique_tindices = np.arange(min_time, max_time + 1) temporal_counts = [np.sum(tindex_tracker == t) for t in unique_tindices] ax_right.barh( unique_tindices, temporal_counts, height=0.8, color="coral", alpha=0.7, edgecolor="black", ) ax_right.set_xlabel("Frequency") ax_right.set_ylim(min_time, max_time) ax_right.grid(True, alpha=0.3, linestyle="--", axis="x") ax_right.tick_params(axis="both") # Top plot: Histogram of spatial index distribution (vertical bars) unique_sindices = np.arange(min_sindex, max_sindex + 1) spatial_counts = [np.sum(sindex_tracker == s) for s in unique_sindices] ax_top.bar( unique_sindices, spatial_counts, width=0.8, color="steelblue", alpha=0.7, edgecolor="black", ) ax_top.set_ylabel("Frequency") ax_top.set_xlim(min_sindex - 0.5, max_sindex + 0.5) ax_top.grid(True, alpha=0.3, linestyle="--", axis="y") ax_top.tick_params(axis="both") # Add colorbar at the bottom, spanning the width of the main plot # Get the position of the main plot main_pos = ax_main.get_position() # Create colorbar axes below the main plot, with same width cbar_ax = fig.add_axes([main_pos.x0, main_pos.y0 - 0.1, main_pos.width, 0.02]) cbar = fig.colorbar(hb, cax=cbar_ax, orientation="horizontal") cbar.set_label(r"$\log_{10}[\mathrm{Count}]$") # Save figure os.makedirs(save_dir, exist_ok=True) save_path = os.path.join(save_dir, f"{filename}_{mode}.png") canvas.print_figure(save_path, bbox_inches="tight") if logger: logger.info( f"Saved density scatter plot with marginal histograms to: {save_path}" ) else: print(f"Saved density scatter plot with marginal histograms to: {save_path}")