# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
"""
Diagnostics and visualization utilities for RTnn radiative transfer model.
This module provides comprehensive visualization tools for analyzing and
evaluating neural network emulators of atmospheric radiative transfer. It
includes functions for creating diagnostic plots of model predictions,
training histories, and data statistics.
The module is designed to support:
- Comparison of model predictions against ground truth targets
- Visualization of radiative fluxes (direct/diffuse, upwelling/downwelling)
- Absorption rate analysis across vertical levels and spectral bands
- Per-plant functional type (PFT) and per-spectral band diagnostics
- Training and validation metric tracking over epochs
- Spatial and temporal sampling distribution analysis
- Statistical characterization of input variables for normalization
Key visualization types
-----------------------
- **Hexbin plots**: Density scatter plots with color mapping for high-volume data
- **Line plots**: Time series or vertical profile comparisons
- **Histograms**: Distribution analysis with optional log scaling
- **Marginal histograms**: 2D density plots with side distributions
- **Multi-panel layouts**: Systematic comparison across flux channels
The module follows the matplotlib architecture using Figure and FigureCanvasAgg
for non-interactive, file-based rendering suitable for batch processing and
headless environments.
Notes
-----
All plotting functions are designed to work with PyTorch tensors or numpy arrays
and save figures directly to disk. The module uses custom matplotlib parameters
optimized for scientific publication quality.
Default figure parameters:
- Font family: DejaVu Sans
- Font sizes: axes labels (15), titles (15), ticks (12), legends (15)
- Line widths: 2
- Tick direction: outward
- Legend: frame off, best location
Examples
--------
Basic usage for generating diagnostic plots:
>>> import torch
>>> from rtnn.diagnostics import plot_flux_and_abs, plot_metric_histories
>>>
>>> # Assuming you have model predictions and targets
>>> predicts = torch.randn(32, 4, 10) # (batch, channels, levels)
>>> targets = torch.randn(32, 4, 10)
>>>
>>> # Create hexbin plot
>>> plot_flux_and_abs(
... predicts=predicts,
... targets=targets,
... filename="diagnostics.png",
... logger=logger
... )
>>>
>>> # Plot training history
>>> train_history = {"nmae": [0.1, 0.08, 0.06], "r2": [0.85, 0.88, 0.91]}
>>> valid_history = {"nmae": [0.12, 0.10, 0.08], "r2": [0.82, 0.85, 0.89]}
>>> plot_metric_histories(
... train_history=train_history,
... valid_history=valid_history,
... filename="metrics.png"
... )
See Also
--------
rtnn.dataset.DataPreprocessor : Data loading and preprocessing
rtnn.evaluater : Metrics and loss functions for evaluation
rtnn.models : Neural network architectures for radiative transfer
"""
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
import matplotlib.gridspec as gridspec
import numpy as np
import mpltex
import math
from matplotlib import rcParams as mpl
from sklearn.metrics import r2_score
import matplotlib.ticker as ticker
import random
import xarray as xr
import os
import collections
# Global matplotlib style applied once, at import time, through ``rcParams``
# (imported above as ``mpl``). rcParams is global matplotlib state, so these
# settings also apply to any other matplotlib figure created in the same
# interpreter after this module is imported.
params = {
    "font.family": "DejaVu Sans",
    # 'figure.dpi': 300,
    # 'savefig.dpi': 300,
    "lines.linewidth": 2,
    # Custom on/off dash sequences for the dashed, dash-dot and dotted styles.
    "lines.dashed_pattern": [4, 2],
    "lines.dashdot_pattern": [6, 3, 2, 3],
    "lines.dotted_pattern": [2, 3],
    # NOTE(review): mathtext.rm is presumably only honoured when
    # mathtext.fontset is "custom", which is not set here -- confirm intent.
    "mathtext.rm": "arial",
    "axes.labelsize": 15,
    "axes.titlesize": 15,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "xtick.major.size": 6,
    "ytick.major.size": 6,
    "legend.fontsize": 15,
    "legend.loc": "best",
    "legend.frameon": False,
    "xtick.direction": "out",
    "ytick.direction": "out",
}
mpl.update(params)
[docs]
def stats(file_list, logger, output_dir, norm_mapping=None, plots=False):
    """
    Compute statistics and generate histograms for variables in NetCDF files.

    Reads a collection of NetCDF files, computes descriptive statistics for
    each data variable over all files combined, and optionally saves one
    histogram per variable. In addition to raw statistics, statistics of the
    logarithmic (log1p) and square-root transformed data are also computed.

    Parameters
    ----------
    file_list : list of str
        Paths to the NetCDF files to process. A file that cannot be opened or
        read is skipped with a warning, and none of its variables contribute
        to the statistics.
    logger : logging.Logger
        Logger used to report progress, skipped files and skipped variables.
    output_dir : str
        Directory where histogram plots will be saved. Created if missing
        when ``plots`` is True.
    norm_mapping : dict, optional
        Dictionary to update in place with computed statistics. If None, a new
        dictionary is created. Default is None.
    plots : bool, optional
        If True, generate and save histogram plots for each variable.
        Default is False.

    Returns
    -------
    dict
        ``norm_mapping`` with one entry per numeric variable. Each entry
        contains the following float values:

        Raw statistics:
            - vmin, vmax, vmean, vstd : minimum, maximum, mean, standard deviation
            - q1, q3 : 25th and 75th percentiles
            - iqr : q3 - q1, or 1.0 when q3 == q1 (safe to divide by)
            - median
        Log-transformed statistics (log1p of the data clipped at 0):
            - log_min, log_max, log_mean, log_std,
              log_q1, log_q3, log_iqr, log_median
        Square-root-transformed statistics (sqrt of the data clipped at 0):
            - sqrt_min, sqrt_max, sqrt_mean, sqrt_std,
              sqrt_q1, sqrt_q3, sqrt_iqr, sqrt_median

    Examples
    --------
    >>> norm_mapping = stats(
    ...     file_list=["data_1995.nc", "data_1996.nc"],
    ...     logger=logger,
    ...     output_dir="./stats",
    ...     plots=True
    ... )
    >>> norm_mapping["coszang"]["vmean"]
    0.5
    >>> norm_mapping["lai"]["log_median"]
    1.23

    Notes
    -----
    - Non-finite values (NaN, +/-inf, e.g. decoded fill values) are dropped
      before any statistic is computed; a variable with no finite values left
      is skipped with a warning.
    - Variables whose dtype is not integer or floating point are skipped.
    - Log transformation uses np.log1p (log(1+x)) to handle zero values.
    - Square root transformation uses np.sqrt with clipping to non-negative.
    - Histograms use 200 bins and a log-scaled y-axis.
    """
    variable_data = collections.defaultdict(list)
    if norm_mapping is None:
        norm_mapping = {}
    logger.info("Starting statistics computation for normalization")
    for fpath in file_list:
        try:
            # The context manager closes the file even if reading a variable fails.
            with xr.open_dataset(fpath) as ds:
                logger.info(f"Processing file: {fpath}")
                file_arrays = {
                    var_name: np.ravel(ds[var_name].values)
                    for var_name in ds.data_vars
                }
        except Exception as e:
            # Deliberate best effort: one unreadable file must not abort the run.
            logger.warning(f"Skipping file {fpath} due to error: {e}")
            continue
        # Merge only once the whole file was read, so a failing file never
        # contributes a partial subset of its variables.
        for var_name, flat in file_arrays.items():
            variable_data[var_name].append(flat)

    if plots and output_dir:
        os.makedirs(output_dir, exist_ok=True)

    stat_suffixes = ("min", "max", "mean", "std", "q1", "q3", "iqr", "median")
    raw_keys = ("vmin", "vmax", "vmean", "vstd", "q1", "q3", "iqr", "median")
    for var_name, arrays in variable_data.items():
        full_data = np.concatenate(arrays)
        if full_data.dtype.kind not in "iuf":
            logger.warning(
                f"{var_name} has non-numeric dtype {full_data.dtype}, skipping."
            )
            continue
        # NaN/inf would turn every statistic below into NaN/inf.
        full_data = full_data[np.isfinite(full_data)]
        if full_data.size == 0:
            logger.warning(f"{var_name} is empty after filtering, skipping.")
            continue

        entry = dict(zip(raw_keys, _describe(full_data)))
        # Both transforms are only defined for non-negative input, hence the clip.
        non_negative = np.clip(full_data, a_min=0, a_max=None)
        for prefix, transformed in (
            ("log", np.log1p(non_negative)),
            ("sqrt", np.sqrt(non_negative)),
        ):
            keys = [f"{prefix}_{suffix}" for suffix in stat_suffixes]
            entry.update(zip(keys, _describe(transformed)))
        norm_mapping[var_name] = entry

        if plots:
            _save_histogram(full_data, var_name, output_dir)
    return norm_mapping


def _describe(values):
    """Return (min, max, mean, std, q1, q3, iqr, median) of ``values`` as floats.

    The IQR falls back to 1.0 when q1 == q3 so it is always safe to divide by.
    """
    q1 = float(np.percentile(values, 25))
    q3 = float(np.percentile(values, 75))
    return (
        float(np.min(values)),
        float(np.max(values)),
        float(np.mean(values)),
        float(np.std(values)),
        q1,
        q3,
        q3 - q1 if q3 != q1 else 1.0,
        float(np.median(values)),
    )


def _save_histogram(values, var_name, output_dir):
    """Save a 200-bin, log-count histogram to ``<output_dir>/<var_name>_histogram.png``."""
    fig = Figure(figsize=(8, 5))
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111)
    ax.hist(values, bins=200)
    ax.set_yscale("log")
    ax.set_title(f"Histogram of {var_name}")
    ax.set_xlabel(var_name)
    ax.set_ylabel("Log Count")
    ax.grid(True)
    out_path = os.path.join(output_dir, f"{var_name}_histogram.png")
    canvas.print_figure(out_path, bbox_inches="tight")
[docs]
def subplots(nrows, ncols, figsize):
    """
    Build a figure holding an ``nrows`` x ``ncols`` grid of axes, without pyplot.

    Drop-in replacement for ``matplotlib.pyplot.subplots()`` that talks to the
    ``Figure`` / ``FigureCanvasAgg`` API directly, so it works in headless
    environments.

    Parameters
    ----------
    nrows : int
        Number of rows of axes.
    ncols : int
        Number of columns of axes.
    figsize : tuple
        Figure size in inches, ``(width, height)``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The newly created figure (with an Agg canvas attached).
    axes : numpy.ndarray
        Object array of axes with shape ``(nrows, ncols)``, filled row by row.
    """
    figure = Figure(figsize=figsize)
    # Attaching an Agg canvas makes the figure renderable without a GUI backend.
    FigureCanvasAgg(figure)
    grid = [
        [
            figure.add_subplot(nrows, ncols, row * ncols + col + 1)
            for col in range(ncols)
        ]
        for row in range(nrows)
    ]
    return figure, np.array(grid)
[docs]
def plot_flux_and_abs_lines(
    predicts,
    targets,
    abs12_predict=None,
    abs12_target=None,
    abs34_predict=None,
    abs34_target=None,
    filename="output_lines.png",
    logger=None,
    n_samples=10,
):
    """
    Create line plots for fluxes and absorption rates.

    Generates a multi-panel figure with line plots for four flux channels and
    optionally two absorption panels. Each panel shows predictions vs targets
    across vertical levels for a random subset of samples.

    Parameters
    ----------
    predicts : np.ndarray or CPU torch.Tensor
        Model predictions for fluxes of shape (batch_size, 4, seq_length).
    targets : np.ndarray or CPU torch.Tensor
        Ground truth fluxes of shape (batch_size, 4, seq_length).
    abs12_predict : np.ndarray or CPU torch.Tensor, optional
        Predicted absorption for channels 1-2 (direct beam); channel 0 of the
        second axis is plotted.
    abs12_target : np.ndarray or CPU torch.Tensor, optional
        True absorption for channels 1-2 (direct beam).
    abs34_predict : np.ndarray or CPU torch.Tensor, optional
        Predicted absorption for channels 3-4 (diffuse); channel 0 of the
        second axis is plotted.
    abs34_target : np.ndarray or CPU torch.Tensor, optional
        True absorption for channels 3-4 (diffuse).
    filename : str, optional
        Output filename. Default is "output_lines.png".
    logger : logging.Logger, optional
        Logger for the "saved" message. If None, the message is printed.
    n_samples : int, optional
        Maximum number of randomly chosen samples to draw. Default is 10;
        fewer are drawn when the batch is smaller.

    Notes
    -----
    Figure layout:
    - 2x2 grid for fluxes:
        - (0,0): Direct upwelling (Flux_direct^u)
        - (0,1): Direct downwelling (Flux_direct^d)
        - (1,0): Diffuse upwelling (Flux_diffusion^u)
        - (1,1): Diffuse downwelling (Flux_diffusion^d)
    - Extra row for absorption rates, only when all four absorption
      arrays are provided:
        - (2,0): Direct absorption (Abs_direct)
        - (2,1): Diffuse absorption (Abs_diffusion)

    Every prediction and target line gets its own style from
    ``mpltex.linestyle_generator``. The figure legend is built from the lines
    of panel (0,0). All y-axes are fixed to [0, 1].

    Examples
    --------
    >>> plot_flux_and_abs_lines(
    ...     predicts=model_outputs,
    ...     targets=ground_truth,
    ...     filename="line_comparison.png",
    ...     logger=logger
    ... )
    """
    include_abs = (
        abs12_predict is not None
        and abs12_target is not None
        and abs34_predict is not None
        and abs34_target is not None
    )
    if include_abs:
        fig, axes = subplots(3, 2, figsize=(10, 15))
    else:
        fig, axes = subplots(2, 2, figsize=(10, 10))
    fig.subplots_adjust(
        hspace=0.3, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )
    canvas = FigureCanvasAgg(fig)

    num_samples = predicts.shape[0]
    # Clamp so small batches do not make random.sample raise ValueError.
    num_to_draw = min(max(n_samples, 0), num_samples)
    sample_indices = random.sample(range(num_samples), num_to_draw)

    flux_panels = (
        ((0, 0), 0, r"$\mathrm{Flux_{direct}^{u}}$"),
        ((0, 1), 1, r"$\mathrm{Flux_{direct}^{d}}$"),
        ((1, 0), 2, r"$\mathrm{Flux_{diffusion}^{u}}$"),
        ((1, 1), 3, r"$\mathrm{Flux_{diffusion}^{d}}$"),
    )
    legend_lines = []
    for (r, c), flux_idx, ylabel in flux_panels:
        handles = _plot_sample_lines(
            axes[r, c], predicts, targets, sample_indices, flux_idx, ylabel
        )
        if (r, c) == (0, 0):
            legend_lines = handles

    if include_abs:
        _plot_sample_lines(
            axes[2, 0],
            abs12_predict,
            abs12_target,
            sample_indices,
            0,
            r"$\mathrm{Abs_{direct}}$",
        )
        _plot_sample_lines(
            axes[2, 1],
            abs34_predict,
            abs34_target,
            sample_indices,
            0,
            r"$\mathrm{Abs_{diffusion}}$",
        )

    fig.legend(
        handles=legend_lines,
        labels=[line.get_label() for line in legend_lines],
        loc="center right",
        bbox_to_anchor=(1.1, 0.5),
        borderaxespad=0.5,
        frameon=False,
        ncol=1,
    )
    canvas.print_figure(filename, bbox_inches="tight")
    if logger:
        logger.info(f"Saved line plot to {filename}")
    else:
        print(f"Saved line plot to {filename}")


def _plot_sample_lines(ax, predicted, target, sample_indices, channel, ylabel):
    """Overlay predicted and true profiles of ``channel`` for each sample on ``ax``.

    Returns the [pred_line, true_line, ...] handles in plotting order.
    """
    linestyles = mpltex.linestyle_generator()
    handles = []
    for sample_index in sample_indices:
        (pred_line,) = ax.plot(
            predicted[sample_index, channel, :], label="Predict", **next(linestyles)
        )
        (true_line,) = ax.plot(
            target[sample_index, channel, :], label="True", **next(linestyles)
        )
        handles.extend([pred_line, true_line])
    ax.set_xlabel(r"$\mathrm{Vertical\ Level}$")
    ax.set_ylabel(ylabel)
    ax.set_xlim(0, predicted.shape[-1] - 1)
    ax.set_ylim(0, 1)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(0.25))
    return handles
[docs]
def plot_flux_and_abs(
    predicts,
    targets,
    abs12_predict=None,
    abs12_target=None,
    abs34_predict=None,
    abs34_target=None,
    filename="output.png",
    logger=None,
):
    """
    Create hexbin plots for fluxes and absorption rates.

    Generates a multi-panel figure with hexbin density plots of predicted vs
    true values (all samples and levels flattened together). This shows
    prediction accuracy across the entire dataset.

    Parameters
    ----------
    predicts : np.ndarray or CPU torch.Tensor
        Model predictions for fluxes; axis 1 indexes the 4 flux channels,
        e.g. shape (batch_size, 4, seq_length).
    targets : np.ndarray or CPU torch.Tensor
        Ground truth fluxes, same layout as ``predicts``.
    abs12_predict : np.ndarray or CPU torch.Tensor, optional
        Predicted absorption for channels 1-2 (direct beam).
    abs12_target : np.ndarray or CPU torch.Tensor, optional
        True absorption for channels 1-2 (direct beam).
    abs34_predict : np.ndarray or CPU torch.Tensor, optional
        Predicted absorption for channels 3-4 (diffuse).
    abs34_target : np.ndarray or CPU torch.Tensor, optional
        True absorption for channels 3-4 (diffuse).
    filename : str, optional
        Output filename. Default is "output.png".
    logger : logging.Logger, optional
        Logger for the "saved" message. If None, the message is printed.

    Notes
    -----
    - Absorption panels are added only when all four absorption arrays are given
    - Hexbin plots use a logarithmic color scale (bins='log'), fixed to
      vmin=1, vmax=1e6 so every panel shares the same color mapping
    - A y=x reference line (red, dotted) spans the observed value range
    - The R² score is shown in the top-left corner of each panel
    - One shared colorbar on the right, labelled log10(count)
    - Grid size of 100 hexagons for the density calculation

    The hexbin representation is effective for large datasets where scatter
    plots would suffer from overplotting.

    Examples
    --------
    >>> plot_flux_and_abs(
    ...     predicts=predictions,
    ...     targets=targets,
    ...     filename="hexbin_diagnostics.png",
    ...     logger=logger
    ... )
    """
    include_abs = (
        abs12_predict is not None
        and abs12_target is not None
        and abs34_predict is not None
        and abs34_target is not None
    )
    if include_abs:
        fig, axes = subplots(3, 2, figsize=(10, 15))
    else:
        fig, axes = subplots(2, 2, figsize=(10, 10))
    fig.subplots_adjust(
        hspace=0.3, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )
    canvas = FigureCanvasAgg(fig)

    flux_panels = (
        ((0, 0), 0, r"$\mathrm{Flux_{direct}^{u}}$"),
        ((0, 1), 1, r"$\mathrm{Flux_{direct}^{d}}$"),
        ((1, 0), 2, r"$\mathrm{Flux_{diffusion}^{u}}$"),
        ((1, 1), 3, r"$\mathrm{Flux_{diffusion}^{d}}$"),
    )
    for (r, c), flux_idx, title in flux_panels:
        hb = _hexbin_panel(
            axes[r, c],
            targets[:, flux_idx, :].reshape(-1),
            predicts[:, flux_idx, :].reshape(-1),
            title,
        )

    if include_abs:
        _hexbin_panel(
            axes[2, 0],
            abs12_target.reshape(-1),
            abs12_predict.reshape(-1),
            r"$\mathrm{Abs_{direct}}$",
        )
        hb = _hexbin_panel(
            axes[2, 1],
            abs34_target.reshape(-1),
            abs34_predict.reshape(-1),
            r"$\mathrm{Abs_{diffusion}}$",
        )

    # Shared colorbar on the right; valid for every panel because all hexbins
    # use the same fixed vmin/vmax.
    cbar_ax = fig.add_axes([0.92, 0.1, 0.015, 0.8])
    fig.colorbar(hb, cax=cbar_ax, label=r"$\mathrm{\log_{10}[Count]}$")
    canvas.print_figure(filename, bbox_inches="tight")
    if logger:
        logger.info(f"Saved hexbin plot to {filename}")
    else:
        print(f"Saved hexbin plot to {filename}")


def _hexbin_panel(ax, y_true, y_pred, title):
    """Draw one predicted-vs-observed hexbin panel (y=x line, R² label); return the hexbin."""
    hb = ax.hexbin(
        y_true, y_pred, gridsize=100, cmap="jet", bins="log", vmin=1, vmax=1e6
    )
    ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f"))
    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f"))
    lo, hi = y_true.min(), y_true.max()
    ax.plot([lo, hi], [lo, hi], "r:", linewidth=0.5)
    r2 = r2_score(y_true, y_pred)
    ax.text(0.05, 0.9, f"$R^2$: {r2:.5f}", transform=ax.transAxes)
    ax.set_title(title)
    ax.set_xlabel("Observed")
    ax.set_ylabel("Predicted")
    return hb
[docs]
def plot_all_diagnostics(
    predicts,
    targets,
    abs12_predict=None,
    abs12_target=None,
    abs34_predict=None,
    abs34_target=None,
    n_pft=15,
    n_bands=2,
    n_chans=4,
    output_dir="./results",
    prefix="diagnostics",
    logger=None,
):
    """
    Generate all diagnostic plots: aggregated, per PFT, per band.

    Creates one aggregated hexbin plot (all PFTs and bands flattened together),
    then, for a random subset of PFTs, one hexbin plot and one line plot per
    spectral band.

    Parameters
    ----------
    predicts : torch.Tensor or np.ndarray
        Model predictions for fluxes. Shape: (batch, 4, n_pft, n_bands, seq_length).
        Torch tensors are detached and moved to CPU before plotting.
    targets : torch.Tensor or np.ndarray
        Ground truth targets. Same shape as predicts.
    abs12_predict : torch.Tensor or np.ndarray, optional
        Predicted absorption for channels 1-2 (direct beam), with the PFT and
        band on axes 2 and 3 as for ``predicts``.
    abs12_target : torch.Tensor or np.ndarray, optional
        True absorption for channels 1-2.
    abs34_predict : torch.Tensor or np.ndarray, optional
        Predicted absorption for channels 3-4 (diffuse).
    abs34_target : torch.Tensor or np.ndarray, optional
        True absorption for channels 3-4.
    n_pft : int, optional
        Number of Plant Functional Types. Default is 15.
    n_bands : int, optional
        Number of spectral bands. Default is 2 (VIS and NIR). Bands beyond
        the second are named ``band<idx>`` in filenames.
    n_chans : int, optional
        Number of output channels. Default is 4. Currently unused.
    output_dir : str, optional
        Directory to save diagnostic plots (created if missing).
        Default is "./results".
    prefix : str, optional
        Prefix for output filenames. Default is "diagnostics".
    logger : logging.Logger, optional
        Logger for progress messages. If None, the summary is printed.

    Notes
    -----
    Absorption panels are drawn only when all four absorption arrays are given.
    Up to 8 PFTs are selected at random. The total number of files written is
    1 + 2 * selected_pfts * n_bands, where selected_pfts = min(8, n_pft).

    Examples
    --------
    >>> plot_all_diagnostics(
    ...     predicts=model_outputs,
    ...     targets=targets,
    ...     n_pft=15,
    ...     n_bands=2,
    ...     output_dir="./diagnostics",
    ...     prefix="experiment_1"
    ... )
    """
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    predicts_np = _as_numpy(predicts)
    targets_np = _as_numpy(targets)
    # Order matches the (abs12_predict, abs12_target, abs34_predict,
    # abs34_target) positional parameters of the plotting functions.
    abs_np = [
        _as_numpy(a)
        for a in (abs12_predict, abs12_target, abs34_predict, abs34_target)
    ]

    # 1. Aggregated plot (all PFTs, all bands)
    plot_flux_and_abs(
        predicts_np,
        targets_np,
        *abs_np,
        filename=os.path.join(output_dir, f"{prefix}_aggregated.png"),
        logger=logger,
    )

    # 2. Randomly select 8 PFTs (or fewer if n_pft < 8)
    num_pft_to_plot = min(8, n_pft)
    selected_pfts = random.sample(range(n_pft), num_pft_to_plot)
    if logger:
        logger.info(f"Selected PFTs for detailed plots: {selected_pfts}")

    # 3. For each selected PFT, create plots for each band
    band_labels = ("VIS", "NIR")
    for pft_idx in selected_pfts:
        for band_idx in range(n_bands):
            # Extract data for specific PFT and band - shape (batch, channels, seq)
            predicts_pft_band = predicts_np[:, :, pft_idx, band_idx, :]
            targets_pft_band = targets_np[:, :, pft_idx, band_idx, :]
            abs_pft_band = [
                None if a is None else a[:, :, pft_idx, band_idx, :] for a in abs_np
            ]
            # Unique name per band so extra bands do not overwrite "NIR" files.
            band_name = (
                band_labels[band_idx]
                if band_idx < len(band_labels)
                else f"band{band_idx}"
            )
            base_filename = os.path.join(
                output_dir, f"{prefix}_pft{pft_idx:02d}_{band_name}"
            )
            # Hexbin plot
            plot_flux_and_abs(
                predicts_pft_band,
                targets_pft_band,
                *abs_pft_band,
                filename=f"{base_filename}_hexbin.png",
                logger=logger,
            )
            # Line plot
            plot_flux_and_abs_lines(
                predicts_pft_band,
                targets_pft_band,
                *abs_pft_band,
                filename=f"{base_filename}_lines.png",
                logger=logger,
            )

    # One aggregated hexbin plus a hexbin and a line plot per (PFT, band).
    num_plots = 1 + 2 * num_pft_to_plot * n_bands
    message = f"Generated {num_plots} diagnostic plots in {output_dir}"
    if logger:
        logger.info(message)
    else:
        print(message)


def _as_numpy(array):
    """Return ``array`` as a NumPy array (torch tensors detached and moved to CPU); None passes through."""
    if array is None:
        return None
    if hasattr(array, "detach"):
        return array.detach().cpu().numpy()
    return np.asarray(array)
[docs]
def plot_metric_histories(
    train_history,
    valid_history,
    filename="training_validation_metrics.png",
    logger=None,
):
    """
    Plot training and validation metrics over epochs.

    Creates a multi-panel figure showing the evolution of each metric
    (e.g., NMAE, NMSE, R2, bias) over training epochs.

    Parameters
    ----------
    train_history : dict
        Dictionary with metric names as keys and lists of training values.
        Example: {"nmae": [0.1, 0.08, 0.06], "r2": [0.85, 0.88, 0.91]}
        One panel is drawn per key, in iteration order.
    valid_history : dict
        Dictionary with metric names as keys and lists of validation values.
        Same structure as train_history. Keys missing here are plotted
        with the training curve only.
    filename : str, optional
        Output filename. Default is "training_validation_metrics.png".
    logger : logging.Logger, optional
        Logger for status messages. If None, messages are printed.

    Notes
    -----
    - A panel uses a logarithmic y-axis when all of its finite values are
      strictly positive; otherwise (e.g. negative bias or R2) a linear axis
      is used so those values stay visible
    - Panels are arranged in a grid with 3 columns
    - The first line style is used for training and the second for validation
    - Grid is enabled for all panels
    - Returns without writing a file when ``train_history`` is empty

    Examples
    --------
    >>> train_history = {"nmae": [0.15, 0.12, 0.09, 0.07],
    ...                  "nmse": [0.08, 0.06, 0.04, 0.03]}
    >>> valid_history = {"nmae": [0.16, 0.13, 0.10, 0.08],
    ...                  "nmse": [0.09, 0.07, 0.05, 0.04]}
    >>> plot_metric_histories(
    ...     train_history=train_history,
    ...     valid_history=valid_history,
    ...     filename="metrics.png",
    ...     logger=logger
    ... )
    """
    num_metrics = len(train_history)
    if num_metrics == 0:
        if logger:
            logger.warning("No metrics to plot")
        else:
            print("No metrics to plot")
        return
    cols = 3
    rows = math.ceil(num_metrics / cols)
    fig = Figure(figsize=(5 * cols, 4 * rows))
    fig.set_tight_layout(True)
    canvas = FigureCanvasAgg(fig)
    gs = gridspec.GridSpec(rows, cols)
    for idx, key in enumerate(train_history):
        row, col = divmod(idx, cols)
        ax = fig.add_subplot(gs[row, col])
        linestyles = mpltex.linestyle_generator(markers=[])
        plotted = [train_history[key]]
        ax.plot(train_history[key], label="train", **next(linestyles))
        if key in valid_history:
            ax.plot(valid_history[key], label="valid", **next(linestyles))
            plotted.append(valid_history[key])
        elif logger:
            logger.warning(f"No validation history for metric {key}")
        ax.set_yscale("log" if _log_scale_ok(*plotted) else "linear")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(key.replace("_", " ").upper())
        ax.legend()
        ax.grid(True)
    canvas.print_figure(filename, bbox_inches="tight")
    if logger:
        logger.info(f"Saved metric history plot to {filename}")
    else:
        print(f"Saved metric history plot to {filename}")


def _log_scale_ok(*series):
    """Return True when every finite value across ``series`` is strictly positive."""
    try:
        values = np.concatenate(
            [np.ravel(np.asarray(s, dtype=float)) for s in series]
        )
    except (TypeError, ValueError):
        # Values we cannot inspect: keep the historical log-scale behaviour.
        return True
    finite = values[np.isfinite(values)]
    return finite.size == 0 or bool(np.all(finite > 0))
[docs]
def plot_loss_histories(
    train_loss, valid_loss, filename="training_validation_loss.png", logger=None
):
    """
    Plot the training and validation loss curves against epoch.

    Draws both curves on a single log-scaled panel and writes it to disk.

    Parameters
    ----------
    train_loss : list or array
        Per-epoch training loss.
    valid_loss : list or array
        Per-epoch validation loss.
    filename : str, optional
        Destination image path. Default is "training_validation_loss.png".
    logger : logging.Logger, optional
        Receives the "saved" message; when None (or falsy) the message is
        printed to stdout instead.

    Notes
    -----
    - The y-axis is logarithmic
    - The training curve uses the first style of
      ``mpltex.linestyle_generator(markers=[])`` and validation the second
    - A grid is drawn; the x-axis is the list index (epoch)

    Examples
    --------
    >>> plot_loss_histories(
    ...     train_loss=[0.5, 0.3, 0.2, 0.15, 0.12],
    ...     valid_loss=[0.55, 0.35, 0.25, 0.18, 0.15],
    ...     filename="loss_curves.png",
    ...     logger=logger
    ... )
    """
    figure = Figure(figsize=(8, 5))
    agg_canvas = FigureCanvasAgg(figure)
    axis = figure.add_subplot(111)
    style_cycle = mpltex.linestyle_generator(markers=[])
    for curve, curve_label in ((train_loss, "train"), (valid_loss, "valid")):
        axis.plot(curve, label=curve_label, **next(style_cycle))
    axis.set_yscale("log")
    axis.set_title("LOSS")
    axis.set_xlabel("Epoch")
    axis.set_ylabel("Loss Value")
    axis.legend()
    axis.grid(True)
    agg_canvas.print_figure(filename, bbox_inches="tight")
    report = logger.info if logger else print
    report(f"Saved loss history plot to {filename}")
[docs]
def plot_spatial_temporal_density(
    sindex_tracker,
    tindex_tracker,
    mode="train",
    save_dir="./tests_plots",
    filename="density_scatter",
    figsize=(10, 10),
    logger=None,
):
    """
    Plot a density scatter plot of spatial index vs temporal index with marginal histograms.

    This function creates a 2D density scatter plot (hexbin) showing the
    distribution of spatial indices (processor ranks) across temporal indices,
    with:
    - Right plot: Histogram of temporal index distribution (horizontal bars)
    - Top plot: Histogram of spatial index distribution (vertical bars)

    Parameters
    ----------
    sindex_tracker : list or array-like
        Spatial index (processor rank) of each data sample.
    tindex_tracker : list or array-like
        Temporal index of each data sample (same length as ``sindex_tracker``).
    mode : str, optional
        Dataset mode identifier ("train", "validation", "test"). Used in filename.
        Default is "train".
    save_dir : str, optional
        Directory path where the plot will be saved (created if missing).
        Default is "./tests_plots".
    filename : str, optional
        Base name for the output file. Default is "density_scatter".
    figsize : tuple, optional
        Figure size as (width, height) in inches. Default is (10, 10).
    logger : logging.Logger, optional
        Logger for status messages. If None, messages are printed.

    Returns
    -------
    str or None
        Path to the saved plot file, or None if there is no data to plot.

    Notes
    -----
    The plot layout uses GridSpec with the following structure:
    - Top-left: Histogram of spatial index distribution (vertical bars)
    - Top-right: Empty (no plot)
    - Bottom-left: Main hexbin density scatter plot
    - Bottom-right: Histogram of temporal index distribution (horizontal bars)
    A colorbar is placed below the main plot with log-scaled counts.
    Both axes extend half an index beyond the observed min/max, so edge bars
    are fully visible and a single distinct index still gives a valid extent.

    This visualization is useful for:
    - Verifying uniform sampling across spatial processors
    - Checking temporal coverage of the dataset
    - Detecting sampling biases in training/validation splits

    Examples
    --------
    >>> # After training, check sampling distribution
    >>> plot_spatial_temporal_density(
    ...     sindex_tracker=dataset.sindex_tracker,
    ...     tindex_tracker=dataset.tindex_tracker,
    ...     mode="train",
    ...     save_dir="./analysis",
    ...     logger=logger
    ... )
    """
    if len(sindex_tracker) == 0 or len(tindex_tracker) == 0:
        message = f"No data to plot for {mode} mode"
        if logger:
            logger.warning(message)
        else:
            print(message)
        return None

    # Convert to numpy arrays
    sindex_tracker = np.array(sindex_tracker)
    tindex_tracker = np.array(tindex_tracker)

    # Get limits
    min_sindex = int(sindex_tracker.min())
    max_sindex = int(sindex_tracker.max())
    min_time = int(tindex_tracker.min())
    max_time = int(tindex_tracker.max())
    # Pad by half an index so each integer index is centred in its bar/bin and
    # a single distinct index does not yield a zero-height hexbin extent.
    s_lo, s_hi = min_sindex - 0.5, max_sindex + 0.5
    t_lo, t_hi = min_time - 0.5, max_time + 0.5

    # Create figure with GridSpec for custom layout
    fig = Figure(figsize=figsize)
    canvas = FigureCanvasAgg(fig)
    # Top row / right column hold the marginal histograms (20% each).
    gs = gridspec.GridSpec(
        2,
        2,
        figure=fig,
        height_ratios=[0.2, 0.8],
        width_ratios=[0.8, 0.2],
        hspace=0.2,
        wspace=0.2,
    )
    ax_main = fig.add_subplot(gs[1, 0])  # hexbin density
    ax_right = fig.add_subplot(gs[1, 1])  # temporal index histogram
    ax_top = fig.add_subplot(gs[0, 0])  # spatial index histogram
    ax_empty = fig.add_subplot(gs[0, 1])  # unused corner
    ax_empty.axis("off")

    # Density scatter plot (hexbin) in main axis
    hb = ax_main.hexbin(
        sindex_tracker,
        tindex_tracker,
        gridsize=100,
        extent=[s_lo, s_hi, t_lo, t_hi],
        cmap="jet",
        bins="log",
        mincnt=1,
        edgecolors="none",
    )
    ax_main.set_xlabel("Spatial Index (Processor Rank)")
    ax_main.set_ylabel("Temporal Index")
    ax_main.set_xlim(s_lo, s_hi)
    ax_main.set_ylim(t_lo, t_hi)
    ax_main.grid(True, alpha=0.3, linestyle="--")

    # Right plot: Histogram of temporal index distribution (horizontal bars)
    ax_right.barh(
        np.arange(min_time, max_time + 1),
        _count_per_index(tindex_tracker, min_time, max_time),
        height=0.8,
        color="coral",
        alpha=0.7,
        edgecolor="black",
    )
    ax_right.set_xlabel("Frequency")
    ax_right.set_ylim(t_lo, t_hi)
    ax_right.grid(True, alpha=0.3, linestyle="--", axis="x")
    ax_right.tick_params(axis="both")

    # Top plot: Histogram of spatial index distribution (vertical bars)
    ax_top.bar(
        np.arange(min_sindex, max_sindex + 1),
        _count_per_index(sindex_tracker, min_sindex, max_sindex),
        width=0.8,
        color="steelblue",
        alpha=0.7,
        edgecolor="black",
    )
    ax_top.set_ylabel("Frequency")
    ax_top.set_xlim(s_lo, s_hi)
    ax_top.grid(True, alpha=0.3, linestyle="--", axis="y")
    ax_top.tick_params(axis="both")

    # Colorbar below the main plot, spanning its width
    main_pos = ax_main.get_position()
    cbar_ax = fig.add_axes([main_pos.x0, main_pos.y0 - 0.1, main_pos.width, 0.02])
    cbar = fig.colorbar(hb, cax=cbar_ax, orientation="horizontal")
    cbar.set_label(r"$\log_{10}[\mathrm{Count}]$")

    # Save figure
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"{filename}_{mode}.png")
    canvas.print_figure(save_path, bbox_inches="tight")
    if logger:
        logger.info(
            f"Saved density scatter plot with marginal histograms to: {save_path}"
        )
    else:
        print(f"Saved density scatter plot with marginal histograms to: {save_path}")
    return save_path


def _count_per_index(values, lo, hi):
    """Return how many entries of ``values`` equal each integer in ``lo..hi`` (inclusive)."""
    if values.dtype.kind in "iu":
        # Single O(n) pass instead of one full-array comparison per index.
        return np.bincount((values - lo).astype(np.intp), minlength=hi - lo + 1)
    return np.array([np.sum(values == v) for v in range(lo, hi + 1)])