rtnn package

Subpackages

Submodules

rtnn.dataset module

Dataset module for RTnn radiative transfer neural network framework.

This module provides the DataPreprocessor class for loading, preprocessing, and normalizing NetCDF climate data for training and evaluating neural network emulators of atmospheric radiative transfer. The module handles complex data structures including multiple years, spatial and temporal batching, and various normalization techniques.

The DataPreprocessor is designed to work with Land Surface Model (LSM) data containing variables such as:
  • Cosine of solar zenith angle (cosz)

  • Leaf area index (lai)

  • Single scattering albedo (ssa)

  • Surface reflectance (rs)

  • Radiative transfer outputs (collimated albedo/transmittance, isotropic albedo/transmittance)

The module supports data augmentation through random spatial sampling during training and provides flexible normalization methods including min-max scaling, standardization, robust scaling, and transformations like log1p and sqrt.

Notes

The expected input data structure consists of NetCDF4 files with dimensions:
  • time: Temporal dimension

  • dim_1: Spatial points (e.g., grid cells)

  • dim_2: Vertical levels (sequence length, typically 10)

  • dim_3: Plant functional types (PFTs, typically 15)

  • dim_4: Spectral bands (typically 2)

Feature channels (121 total):
  • cosz: 1 channel

  • lai: 2 variables × 15 PFTs = 30 channels

  • ssa: 2 variables × 2 bands × 15 PFTs = 60 channels

  • rs: 1 variable × 2 bands × 15 PFTs = 30 channels

Output channels (120 total):
  • 4 variables × 2 bands × 15 PFTs = 120 channels
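The channel totals above follow from multiplying variables, bands, and PFTs. The following is only an arithmetic check using the typical sizes quoted in these Notes; the names N_PFT, N_BANDS, and feature_channels are illustrative and are not part of rtnn.

```python
# Typical sizes quoted in the Notes (assumed here, not read from a file).
N_PFT = 15    # plant functional types (dim_3)
N_BANDS = 2   # spectral bands (dim_4)

# Channels contributed by each input group.
feature_channels = {
    "cosz": 1,
    "lai": 2 * N_PFT,
    "ssa": 2 * N_BANDS * N_PFT,
    "rs": 1 * N_BANDS * N_PFT,
}
# Four radiative transfer outputs per band and PFT.
output_channels = 4 * N_BANDS * N_PFT

total_features = sum(feature_channels.values())
print(total_features, output_channels)  # 121 120
```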

Examples

Basic usage for training:

>>> from rtnn.dataset import DataPreprocessor
>>> from rtnn.logger import Logger
>>> logger = Logger()
>>> train_files = ["data_1995.nc", "data_1996.nc", "data_1997.nc"]
>>> dataset = DataPreprocessor(
...     logger=logger,
...     dfs=train_files,
...     stime=0,
...     tbatch=24,
...     training=True,
...     sblock_perc=0.6,
...     norm_mapping=norm_stats,
...     normalization_type=norm_types
... )
>>> features, targets = dataset[0]
>>> features.shape
torch.Size([158, 121, 10])
>>> targets.shape
torch.Size([158, 120, 10])

For validation/testing:

>>> val_dataset = DataPreprocessor(
...     logger=logger,
...     dfs=val_files,
...     stime=0,
...     tbatch=24,
...     training=False,
...     norm_mapping=norm_stats,
...     normalization_type=norm_types
... )

See also

rtnn.models

Contains the neural network architectures for radiative transfer

rtnn.evaluater

Provides metrics and loss functions for model evaluation

rtnn.diagnostics

Visualization tools for model predictions

class rtnn.dataset.DataPreprocessor(*args: Any, **kwargs: Any)

Bases: Dataset

Dataset class for preprocessing LSM (Land Surface Model) data.

This class handles loading and preprocessing of NetCDF files containing climate data, with support for multiple years, spatial and temporal batching, and various normalization techniques.

Parameters:
  • logger (object) – Logger instance for logging messages.

  • dfs (List[str]) – List of file paths to NetCDF files.

  • stime (int) – Start time index.

  • tstep (int) – Number of time steps per file.

  • tbatch (int) – Temporal batch size.

  • norm_mapping (Dict, optional) – Dictionary containing normalization statistics for each variable. Default is empty dict.

  • normalization_type (Dict, optional) – Dictionary specifying normalization type for each variable. Default is empty dict.

logger

Logger instance.

Type:

object

stime

Start time index.

Type:

int

tstep

Time steps per file.

Type:

int

tbatch

Temporal batch size.

Type:

int

norm_mapping

Normalization statistics.

Type:

Dict

normalization_type

Normalization types per variable.

Type:

Dict

sbatch

Number of spatial batches.

Type:

int

years

Sorted list of years in the dataset.

Type:

List[int]

etime

End time index.

Type:

int

dfs

List of (year, spatial_index, file_path) tuples.

Type:

List[Tuple[int, int, str]]

time_blocks

Shuffled time blocks.

Type:

np.ndarray

min_dims

Minimum dimensions across files.

Type:

Dict[str, int]

cosz

Cosine of solar zenith angle variable names.

Type:

List[str]

lai

Leaf area index variable names.

Type:

List[str]

ssa

Single scattering albedo variable names.

Type:

List[str]

rs

Surface reflectance variable names.

Type:

List[str]

ov

Output variable names.

Type:

List[str]

Examples

>>> from rtnn.dataset import DataPreprocessor
>>> from rtnn.logger import Logger
>>> logger = Logger()
>>> files = ["data_1995.nc", "data_1996.nc"]
>>> dataset = DataPreprocessor(
...     logger=logger,
...     dfs=files,
...     stime=0,
...     tbatch=24,
...     norm_mapping={},
...     normalization_type={}
... )
>>> len(dataset)
100
>>> features, targets = dataset[0]
>>> features.shape
torch.Size([schunk, feature_channels, seq_length])
>>> targets.shape
torch.Size([schunk, output_channels, seq_length])

__init__(logger: Any, dfs: List[str], stime: int, tbatch: int, training: bool = True, sblock_perc: float = 0.6, norm_mapping: Dict = {}, normalization_type: Dict = {}, debug: bool = False) → None

Initialize the DataPreprocessor.

Parameters:
  • logger (Any) – Logger instance for logging messages.

  • dfs (List[str]) – List of file paths to NetCDF files.

  • stime (int) – Start time index.

  • tbatch (int) – Temporal batch size.

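  • training (bool, optional) – If True, use a random subset of the spatial batches (60% with the default sblock_perc) as data augmentation. If False, use 100% of spatial batches (full evaluation). Default is True.

  • sblock_perc (float, optional) – Fraction of spatial batches sampled when training is True. Default is 0.6.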

  • norm_mapping (Dict, optional) – Dictionary containing normalization statistics for each variable.

  • normalization_type (Dict, optional) – Dictionary specifying normalization type for each variable.

  • debug (bool, optional) – If True, print debug information.

normalize(data: numpy.ndarray, var_name: str) → numpy.ndarray

Normalize data using the specified normalization method.

Parameters:
  • data (np.ndarray) – Input data array to normalize.

  • var_name (str) – Name of the variable for which to retrieve normalization statistics.

Returns:

Normalized data array.

Return type:

np.ndarray

Raises:

ValueError – If the normalization type is not supported.

Notes

Supported normalization types:
  • minmax: (x - min) / (max - min)

  • standard: (x - mean) / std

  • robust: (x - median) / IQR

  • log1p_minmax: log1p(x) normalized

  • log1p_standard: log1p(x) standardized

  • log1p_robust: log1p(x) robust normalized

  • sqrt_minmax: sqrt(x) normalized

  • sqrt_standard: sqrt(x) standardized

  • sqrt_robust: sqrt(x) robust normalized
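The nine types above combine an optional transform (log1p or sqrt) with one of three scalings. The following NumPy sketch illustrates that composition; it is not the rtnn implementation. It assumes a per-variable statistics dictionary that uses the key names returned by rtnn.diagnostics.stats (vmin, vmax, log_min, and so on), and the helper names normalize_sketch and _stat are hypothetical. Clipping negative inputs before the square root is also an assumption.

```python
import numpy as np

# Raw statistics use a "v" prefix for min/max/mean/std (assumed key layout).
_RAW_KEYS = {"min": "vmin", "max": "vmax", "mean": "vmean", "std": "vstd"}


def _stat(stats, prefix, name):
    """Look up a statistic for the raw ("") or transformed ("log"/"sqrt") data."""
    if prefix:
        return stats[f"{prefix}_{name}"]
    return stats[_RAW_KEYS.get(name, name)]


def normalize_sketch(data, stats, kind):
    """Apply one normalization type, e.g. "minmax" or "log1p_standard"."""
    transform, _, scaling = kind.rpartition("_")  # "log1p_minmax" -> "log1p", "minmax"
    if transform == "log1p":
        x, prefix = np.log1p(data), "log"
    elif transform == "sqrt":
        x, prefix = np.sqrt(np.clip(data, 0, None)), "sqrt"
    elif transform == "":
        x, prefix = np.asarray(data, dtype=float), ""
    else:
        raise ValueError(f"Unsupported normalization type: {kind}")

    if scaling == "minmax":
        lo, hi = _stat(stats, prefix, "min"), _stat(stats, prefix, "max")
        return (x - lo) / (hi - lo)
    if scaling == "standard":
        return (x - _stat(stats, prefix, "mean")) / _stat(stats, prefix, "std")
    if scaling == "robust":
        return (x - _stat(stats, prefix, "median")) / _stat(stats, prefix, "iqr")
    raise ValueError(f"Unsupported normalization type: {kind}")


lai = np.array([0.0, 1.0, 3.0])
stats = {"vmin": 0.0, "vmax": 3.0, "log_min": 0.0, "log_max": float(np.log1p(3.0))}
raw_scaled = normalize_sketch(lai, stats, "minmax")        # 0, 1/3, 1
log_scaled = normalize_sketch(lai, stats, "log1p_minmax")  # 0, 0.5, 1
```

Splitting the type string into transform and scaling keeps the nine cases down to three plus three branches.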

rtnn.diagnostics module

Diagnostics and visualization utilities for RTnn radiative transfer model.

This module provides comprehensive visualization tools for analyzing and evaluating neural network emulators of atmospheric radiative transfer. It includes functions for creating diagnostic plots of model predictions, training histories, and data statistics.

The module is designed to support:
  • Comparison of model predictions against ground truth targets

  • Visualization of radiative fluxes (direct/diffuse, upwelling/downwelling)

  • Absorption rate analysis across vertical levels and spectral bands

  • Per-plant functional type (PFT) and per-spectral band diagnostics

  • Training and validation metric tracking over epochs

  • Spatial and temporal sampling distribution analysis

  • Statistical characterization of input variables for normalization

Key visualization types

  • Hexbin plots: Density scatter plots with color mapping for high-volume data

  • Line plots: Time series or vertical profile comparisons

  • Histograms: Distribution analysis with optional log scaling

  • Marginal histograms: 2D density plots with side distributions

  • Multi-panel layouts: Systematic comparison across flux channels

The module follows the matplotlib architecture using Figure and FigureCanvasAgg for non-interactive, file-based rendering suitable for batch processing and headless environments.

Notes

All plotting functions are designed to work with PyTorch tensors or numpy arrays and save figures directly to disk. The module uses custom matplotlib parameters optimized for scientific publication quality.

Default figure parameters:
  • Font family: DejaVu Sans

  • Font sizes: axes labels (15), titles (15), ticks (12), legends (15)

  • Line widths: 2

  • Tick direction: outward

  • Legend: frame off, best location

Examples

Basic usage for generating diagnostic plots:

>>> import torch
>>> from rtnn.diagnostics import plot_flux_and_abs, plot_metric_histories
>>>
>>> # Assuming you have model predictions and targets
>>> predicts = torch.randn(32, 4, 10)  # (batch, channels, levels)
>>> targets = torch.randn(32, 4, 10)
>>>
>>> # Create hexbin plot
>>> plot_flux_and_abs(
...     predicts=predicts,
...     targets=targets,
...     filename="diagnostics.png",
...     logger=logger
... )
>>>
>>> # Plot training history
>>> train_history = {"nmae": [0.1, 0.08, 0.06], "r2": [0.85, 0.88, 0.91]}
>>> valid_history = {"nmae": [0.12, 0.10, 0.08], "r2": [0.82, 0.85, 0.89]}
>>> plot_metric_histories(
...     train_history=train_history,
...     valid_history=valid_history,
...     filename="metrics.png"
... )

See also

rtnn.dataset.DataPreprocessor

Data loading and preprocessing

rtnn.evaluater

Metrics and loss functions for evaluation

rtnn.models

Neural network architectures for radiative transfer

rtnn.diagnostics.stats(file_list, logger, output_dir, norm_mapping=None, plots=False)

Compute statistics and generate histograms for variables in NetCDF files.

Reads a collection of NetCDF files, computes descriptive statistics for each variable, and generates histogram plots saved to disk. In addition to raw statistics, transformed statistics using logarithmic (log1p) and square-root transformations are also computed.

Parameters:
  • file_list (list of str) – Paths to the NetCDF files to process.

  • logger (logging.Logger) – Logger used to report progress and informational messages.

  • output_dir (str) – Directory where histogram plots will be saved.

  • norm_mapping (dict, optional) – Dictionary to update with computed statistics. If None, a new dictionary is created. Default is None.

  • plots (bool, optional) – If True, generate and save histogram plots for each variable. Default is False.

Returns:

Dictionary mapping variable names to their computed statistics. Each variable contains the following entries:

Raw statistics:
  • vmin : float

    Minimum value

  • vmax : float

    Maximum value

  • vmean : float

    Mean value

  • vstd : float

    Standard deviation

Robust statistics:
  • q1 : float

    First quartile (25th percentile)

  • q3 : float

    Third quartile (75th percentile)

  • iqr : float

    Interquartile range (q3 - q1)

  • median : float

    Median value

Log-transformed statistics (log1p):
  • log_min : float

  • log_max : float

  • log_mean : float

  • log_std : float

  • log_q1 : float

  • log_q3 : float

  • log_iqr : float

  • log_median : float

Square-root-transformed statistics:
  • sqrt_min : float

  • sqrt_max : float

  • sqrt_mean : float

  • sqrt_std : float

  • sqrt_q1 : float

  • sqrt_q3 : float

  • sqrt_iqr : float

  • sqrt_median : float

Return type:

dict

Examples

>>> norm_mapping = stats(
...     file_list=["data_1995.nc", "data_1996.nc"],
...     logger=logger,
...     output_dir="./stats",
...     plots=True
... )
>>> norm_mapping["coszang"]["vmean"]
0.5
>>> norm_mapping["lai"]["log_median"]
1.23

Notes

  • Log transformation uses np.log1p (log(1+x)) to handle zero values

  • Square root transformation uses np.sqrt, with inputs clipped to non-negative values

  • Histograms use 200 bins and log-scaled y-axis
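The statistics listed under Returns can be sketched with NumPy for a single in-memory array. This is only an illustration of the raw, log1p, and square-root groups and their key names; the function name variable_stats is hypothetical, and the real stats function reads NetCDF files and may accumulate values differently.

```python
import numpy as np


def variable_stats(values):
    """Compute the 24 statistics (raw, log1p, sqrt) for one variable."""
    values = np.asarray(values, dtype=float).ravel()
    transformed = (
        ("", values),
        ("log", np.log1p(values)),
        ("sqrt", np.sqrt(np.clip(values, 0, None))),  # clip before sqrt
    )
    out = {}
    for prefix, x in transformed:
        q1, median, q3 = np.percentile(x, [25, 50, 75])
        block = {
            "min": x.min(), "max": x.max(), "mean": x.mean(), "std": x.std(),
            "q1": q1, "q3": q3, "iqr": q3 - q1, "median": median,
        }
        for name, val in block.items():
            if prefix:
                key = f"{prefix}_{name}"          # e.g. log_min, sqrt_iqr
            elif name in ("min", "max", "mean", "std"):
                key = "v" + name                  # vmin, vmax, vmean, vstd
            else:
                key = name                        # q1, q3, iqr, median
            out[key] = float(val)
    return out


example = variable_stats([0.0, 1.0, 2.0, 3.0, 4.0])
print(sorted(example)[:4])  # first few keys in alphabetical order
```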

rtnn.diagnostics.subplots(nrows, ncols, figsize)

Create a figure and grid of subplots without pyplot.

This is a replacement for matplotlib.pyplot.subplots() that uses the Figure and FigureCanvasAgg API directly, suitable for headless environments.

Parameters:
  • nrows (int) – Number of rows in the subplot grid.

  • ncols (int) – Number of columns in the subplot grid.

  • figsize (tuple) – Figure size in inches as (width, height).

Returns:

  • fig (matplotlib.figure.Figure) – The created figure object.

  • axes (numpy.ndarray) – Array of axes objects with shape (nrows, ncols).
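A pyplot-free helper of this kind can be sketched with the public matplotlib Figure and FigureCanvasAgg classes. The function name subplots_sketch is hypothetical, and passing squeeze=False to always get a 2-D axes array is an assumption made to match the documented (nrows, ncols) shape; the rtnn implementation may differ.

```python
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def subplots_sketch(nrows, ncols, figsize):
    """Create a figure and a 2-D grid of axes without importing pyplot."""
    fig = Figure(figsize=figsize)
    FigureCanvasAgg(fig)  # attach a non-interactive Agg canvas to the figure
    axes = fig.subplots(nrows, ncols, squeeze=False)  # always shape (nrows, ncols)
    return fig, axes


fig, axes = subplots_sketch(2, 3, figsize=(9, 6))
print(axes.shape)  # (2, 3)
```

Because no pyplot state is involved, the figure is released as soon as it goes out of scope, which suits batch jobs that write many files.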

rtnn.diagnostics.plot_flux_and_abs_lines(predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, filename='output_lines.png', logger=None)

Create line plots for fluxes and absorption rates.

Generates a multi-panel figure with line plots for four flux channels and optionally two absorption panels. Each panel shows predictions vs targets across vertical levels.

Parameters:
  • predicts (torch.Tensor or np.ndarray) – Model predictions for fluxes of shape (batch_size, 4, seq_length).

  • targets (torch.Tensor or np.ndarray) – Ground truth fluxes of shape (batch_size, 4, seq_length).

  • abs12_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 1-2 (direct beam).

  • abs12_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 1-2 (direct beam).

  • abs34_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 3-4 (diffuse).

  • abs34_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 3-4 (diffuse).

  • filename (str, optional) – Output filename. Default is “output_lines.png”.

  • logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.

Notes

Figure layout:
  • 2x2 grid for fluxes:
    • (0,0): Direct upwelling (Flux_direct^u)

    • (0,1): Direct downwelling (Flux_direct^d)

    • (1,0): Diffuse upwelling (Flux_diffusion^u)

    • (1,1): Diffuse downwelling (Flux_diffusion^d)

  • Optional 1x2 grid for absorption rates (if provided):
    • (2,0): Direct absorption (Abs_direct)

    • (2,1): Diffuse absorption (Abs_diffusion)

The function randomly selects 10 samples from the batch for plotting. Each sample is shown with a different line style.

Examples

>>> plot_flux_and_abs_lines(
...     predicts=model_outputs,
...     targets=ground_truth,
...     filename="line_comparison.png",
...     logger=logger
... )

rtnn.diagnostics.plot_flux_and_abs(predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, filename='output.png', logger=None)

Create hexbin plots for fluxes and absorption rates.

Generates a multi-panel figure with hexbin density plots showing the relationship between predicted and true values. Useful for assessing prediction accuracy across the entire dataset.

Parameters:
  • predicts (torch.Tensor or np.ndarray) – Model predictions for fluxes of shape (batch_size, 4, seq_length).

  • targets (torch.Tensor or np.ndarray) – Ground truth fluxes of shape (batch_size, 4, seq_length).

  • abs12_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 1-2 (direct beam).

  • abs12_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 1-2 (direct beam).

  • abs34_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 3-4 (diffuse).

  • abs34_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 3-4 (diffuse).

  • filename (str, optional) – Output filename. Default is “output.png”.

  • logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.

Notes

  • Hexbin plots use logarithmic color scale (bins='log')

  • Includes diagonal reference line (y=x) in red dotted style

  • Displays R² score in the top-left corner of each panel

  • Shared colorbar on the right with log10(count) label

  • Grid size of 100 bins for density calculation

The hexbin representation is particularly effective for large datasets where scatter plots would suffer from overplotting.

Examples

>>> plot_flux_and_abs(
...     predicts=predictions,
...     targets=targets,
...     filename="hexbin_diagnostics.png",
...     logger=logger
... )

rtnn.diagnostics.plot_all_diagnostics(predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, n_pft=15, n_bands=2, n_chans=4, output_dir='./results', prefix='diagnostics', logger=None)

Generate all diagnostic plots: aggregated, per PFT, per band.

This function creates a comprehensive set of diagnostic plots including one aggregated plot (all PFTs, all bands combined), per-band plots for each selected PFT (VIS and NIR bands), and line plots and hexbin plots for each combination.

Parameters:
  • predicts (torch.Tensor or np.ndarray) – Model predictions for fluxes. Shape: (batch, 4, n_pft, n_bands, seq_length)

  • targets (torch.Tensor or np.ndarray) – Ground truth targets. Same shape as predicts.

  • abs12_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 1-2 (direct beam).

  • abs12_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 1-2.

  • abs34_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 3-4 (diffuse).

  • abs34_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 3-4.

  • n_pft (int, optional) – Number of Plant Functional Types. Default is 15.

  • n_bands (int, optional) – Number of spectral bands. Default is 2 (VIS and NIR).

  • n_chans (int, optional) – Number of output channels. Default is 4.

  • output_dir (str, optional) – Directory to save diagnostic plots. Default is “./results”.

  • prefix (str, optional) – Prefix for output filenames. Default is “diagnostics”.

  • logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.

Notes

The function generates 1 aggregated hexbin plot, 1 aggregated line plot, and, for each selected PFT (up to 8 randomly selected), a hexbin plot and a line plot per band (VIS and NIR). Total plots = 2 + (selected_pfts * n_bands * 2), where selected_pfts = min(8, n_pft). With the defaults this gives 2 + 8 * 2 * 2 = 34 plots.

Examples

>>> plot_all_diagnostics(
...     predicts=model_outputs,
...     targets=targets,
...     n_pft=15,
...     n_bands=2,
...     output_dir="./diagnostics",
...     prefix="experiment_1"
... )

rtnn.diagnostics.plot_metric_histories(train_history, valid_history, filename='training_validation_metrics.png', logger=None)

Plot training and validation metrics over epochs.

Creates a multi-panel figure showing the evolution of various metrics (e.g., NMAE, NMSE, R2, bias) over training epochs.

Parameters:
  • train_history (dict) – Dictionary with metric names as keys and lists of training values. Example: {"nmae": [0.1, 0.08, 0.06], "r2": [0.85, 0.88, 0.91]}

  • valid_history (dict) – Dictionary with metric names as keys and lists of validation values. Same structure as train_history.

  • filename (str, optional) – Output filename. Default is “training_validation_metrics.png”.

  • logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.

Notes

  • Metrics are plotted on a logarithmic scale for better visualization of convergence

  • Each metric gets its own panel arranged in a grid (3 columns)

  • Blue lines: training metrics

  • Orange lines: validation metrics

  • Grid is enabled for all panels

Examples

>>> train_history = {"nmae": [0.15, 0.12, 0.09, 0.07],
...                  "nmse": [0.08, 0.06, 0.04, 0.03]}
>>> valid_history = {"nmae": [0.16, 0.13, 0.10, 0.08],
...                  "nmse": [0.09, 0.07, 0.05, 0.04]}
>>> plot_metric_histories(
...     train_history=train_history,
...     valid_history=valid_history,
...     filename="metrics.png",
...     logger=logger
... )

rtnn.diagnostics.plot_loss_histories(train_loss, valid_loss, filename='training_validation_loss.png', logger=None)

Plot training and validation loss over epochs.

Creates a single-panel figure showing the loss evolution during training.

Parameters:
  • train_loss (list or array) – Training loss values over epochs.

  • valid_loss (list or array) – Validation loss values over epochs.

  • filename (str, optional) – Output filename. Default is “training_validation_loss.png”.

  • logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.

Notes

  • Uses logarithmic scale for y-axis to visualize exponential decay

  • Blue line: training loss

  • Orange line: validation loss

  • Includes grid for better readability

  • Both lines share the same x-axis (epoch number)

Examples

>>> train_losses = [0.5, 0.3, 0.2, 0.15, 0.12]
>>> valid_losses = [0.55, 0.35, 0.25, 0.18, 0.15]
>>> plot_loss_histories(
...     train_loss=train_losses,
...     valid_loss=valid_losses,
...     filename="loss_curves.png",
...     logger=logger
... )

rtnn.diagnostics.plot_spatial_temporal_density(sindex_tracker, tindex_tracker, mode='train', save_dir='./tests_plots', filename='density_scatter', figsize=(10, 10), logger=None)

Plot a density scatter plot of spatial index vs temporal index with marginal histograms.

This function creates a 2D density scatter plot (hexbin) showing the distribution of spatial indices (processor ranks) across temporal indices, with:
  • Right plot: Histogram of temporal index distribution (horizontal bars)

  • Top plot: Histogram of spatial index distribution (vertical bars)

Parameters:
  • sindex_tracker (list or array-like) – List of spatial indices (processor ranks) for each data sample.

  • tindex_tracker (list or array-like) – List of temporal indices for each data sample.

  • mode (str, optional) – Dataset mode identifier (“train”, “validation”, “test”). Used in filename. Default is “train”.

  • save_dir (str, optional) – Directory path where the plot will be saved. Default is “./tests_plots”.

  • filename (str, optional) – Base name for the output file. Default is “density_scatter”.

  • figsize (tuple, optional) – Figure size as (width, height) in inches. Default is (10, 10).

  • logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.

Returns:

Path to the saved plot file, or None if no data to plot.

Return type:

str or None

Notes

The plot layout uses GridSpec with the following structure:
  • Top-left: Histogram of spatial index distribution (vertical bars)

  • Top-right: Empty (no plot)

  • Bottom-left: Main hexbin density scatter plot

  • Bottom-right: Histogram of temporal index distribution (horizontal bars)

A colorbar is placed below the main plot with log-scaled counts.

This visualization is useful for:
  • Verifying uniform sampling across spatial processors

  • Checking temporal coverage of the dataset

  • Detecting sampling biases in training/validation splits

Examples

>>> # After training, check sampling distribution
>>> plot_spatial_temporal_density(
...     sindex_tracker=dataset.sindex_tracker,
...     tindex_tracker=dataset.tindex_tracker,
...     mode="train",
...     save_dir="./analysis",
...     logger=logger
... )

rtnn.evaluater module

Evaluation utilities for RTnn model assessment.

This module provides comprehensive evaluation tools for radiative transfer neural network models, including custom loss functions, metric computation, and visualization helpers. It is designed to support both training diagnostics and rigorous model validation against physical constraints.

The module implements several key capabilities:

Custom Loss Functions
  • Normalized losses (NMAE, NMSE) for scale-invariant error measurement

  • Standard losses (MSE, MAE, Huber, Smooth L1) for baseline comparison

  • Weighted and physics-informed losses for multi-objective optimization

Evaluation Metrics
  • NMAE: Normalized Mean Absolute Error

  • NMSE: Normalized Mean Squared Error

  • R²: Coefficient of determination

  • MBE: Mean Bias Error

  • MARE: Mean Absolute Relative Error

  • GMRAE: Geometric Mean Relative Absolute Error

Physical Consistency
  • Absorption rate calculation from flux divergence

  • Energy conservation penalty (albedo + transmittance + absorptance = 1)

  • Heating rate computation from net flux profiles

Data Handling
  • Normalization/de-normalization for all supported transformation types

  • Multi-dimensional tensor reshaping (batch, channels, PFTs, bands, levels)

  • Metric tracking with running statistics

The module follows a modular design where loss functions and metrics are implemented as separate callable classes/functions, allowing easy extension and composition.

Notes

Flux Variable Ordering

  • Channel 0: collimated albedo (direct upwelling)

  • Channel 1: collimated transmittance (direct downwelling)

  • Channel 2: isotropic albedo (diffuse upwelling)

  • Channel 3: isotropic transmittance (diffuse downwelling)

This ordering matches the ov list in rtnn.dataset.DataPreprocessor.

Absorption Calculation

  • For collimated: absorption = -d(net_flux)/dz

  • For isotropic: absorption = -d(net_flux)/dz

Supported Normalization Types

  • Linear: minmax, standard, robust

  • Log1p-based: log1p_minmax, log1p_standard, log1p_robust

  • Sqrt-based: sqrt_minmax, sqrt_standard, sqrt_robust
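De-normalization undoes the scaling first and then inverts the transform (expm1 for log1p, squaring for sqrt). The following NumPy sketch shows this for two of the nine types. The function names are hypothetical and the statistics are computed inline for the demonstration; it is not the rtnn API.

```python
import numpy as np


def normalize_log1p_minmax(x, log_min, log_max):
    """Forward transform: log1p followed by min-max scaling."""
    return (np.log1p(x) - log_min) / (log_max - log_min)


def denormalize_log1p_minmax(z, log_min, log_max):
    """Inverse: undo the min-max scaling, then invert log1p with expm1."""
    return np.expm1(z * (log_max - log_min) + log_min)


def denormalize_sqrt_standard(z, sqrt_mean, sqrt_std):
    """Inverse of sqrt_standard: undo the standardization, then square."""
    return (z * sqrt_std + sqrt_mean) ** 2


x = np.array([0.0, 0.5, 2.0, 7.0])
log_min, log_max = np.log1p(x.min()), np.log1p(x.max())
z = normalize_log1p_minmax(x, log_min, log_max)
x_back = denormalize_log1p_minmax(z, log_min, log_max)
print(np.allclose(x, x_back))  # True
```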

Examples

Basic usage for model evaluation:

>>> import argparse
>>> import torch
>>> from rtnn.evaluater import get_loss_function, run_validation
>>>
>>> # Create loss function
>>> args = argparse.Namespace(loss_type='huber', beta_delta=1.0)
>>> criterion = get_loss_function('huber', args)
>>>
>>> # Evaluate model
>>> valid_loss, metrics = run_validation(
...     loader=val_loader,
...     model=my_model,
...     norm_mapping=norm_stats,
...     normalization_type=norm_types,
...     index_mapping=idxmap,
...     device=device,
...     args=args,
...     epoch=10,
...     logger=logger
... )
>>>
>>> print(f"Validation NMAE: {metrics['fluxes_NMAE']:.4f}")
>>> print(f"R² score: {metrics['fluxes_R2']:.4f}")

Using custom metric tracking:

>>> from rtnn.evaluater import MetricTracker, nmae_all
>>>
>>> tracker = MetricTracker()
>>> for batch in dataloader:
...     pred, target = model(batch)
...     count, value = nmae_all(pred, target)
...     tracker.update(value.item(), count)
>>>
>>> mean_nmae = tracker.getmean()

See also

rtnn.dataset.DataPreprocessor

Data loading and normalization

rtnn.diagnostics

Visualization tools for model predictions

rtnn.models

Neural network architectures

class rtnn.evaluater.NMSELoss(*args: Any, **kwargs: Any)

Bases: Module

Normalized Mean Squared Error Loss.

Computes MSE normalized by the mean square of the target values. Useful when the scale of the target variable varies across samples or when comparing models trained on different datasets.

Parameters:

eps (float, optional) – Small constant for numerical stability. Default is 1e-8.

Notes

The loss is calculated as:

NMSE = MSE(pred, target) / (mean(target²) + eps)

This normalization makes the loss scale-invariant, with values typically in the range [0, 1].
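The formula above can be checked with a few lines of NumPy. This is a sketch of the stated formula only, not the torch module itself; nmse_sketch is a hypothetical name.

```python
import numpy as np


def nmse_sketch(pred, target, eps=1e-8):
    """NMSE = MSE(pred, target) / (mean(target**2) + eps)."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    mse = np.mean((pred - target) ** 2)
    return mse / (np.mean(target ** 2) + eps)


target = np.array([1.0, 2.0, 4.0])
pred = np.array([1.0, 2.0, 3.0])

# MSE = 1/3 and mean(target**2) = 7, so the ratio is 1/21 (about 0.0476).
base = nmse_sketch(pred, target)
# Rescaling both arrays by 10 leaves the ratio essentially unchanged.
scaled = nmse_sketch(10 * pred, 10 * target)
```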

Examples

>>> criterion = NMSELoss()
>>> predictions = torch.tensor([[2.0, 3.0], [1.0, 2.0]])
>>> targets = torch.tensor([[2.0, 4.0], [1.0, 2.5]])
>>> loss = criterion(predictions, targets)
>>> print(loss.item())
0.0459  # approximately; MSE = 0.3125, mean(target²) = 6.8125

__init__(eps=1e-08)

forward(pred, target)

class rtnn.evaluater.NMAELoss(*args: Any, **kwargs: Any)

Bases: Module

Normalized Mean Absolute Error Loss.

Computes MAE normalized by the mean absolute value of the target. Provides a scale-invariant error metric that is more robust to outliers than NMSE.

Parameters:

eps (float, optional) – Small constant for numerical stability. Default is 1e-8.

Notes

Values typically range from 0 to 1, with 0 representing perfect predictions and values >1 indicating predictions worse than the trivial zero predictor.
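The role of the zero predictor as the reference point can be seen in a small NumPy sketch. It assumes the loss is the mean absolute error divided by the mean absolute target plus eps, mirroring the NMSE formula; the exact placement of eps in rtnn is an assumption, and nmae_sketch is a hypothetical name.

```python
import numpy as np


def nmae_sketch(pred, target, eps=1e-8):
    """NMAE = mean(|pred - target|) / (mean(|target|) + eps)."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    mae = np.mean(np.abs(pred - target))
    return mae / (np.mean(np.abs(target)) + eps)


target = np.array([1.0, 2.0, 4.0])

# MAE = 1/3 and mean(|target|) = 7/3, so the ratio is 1/7 (about 0.143).
good = nmae_sketch(np.array([1.0, 2.0, 3.0]), target)
# Predicting zero everywhere gives a value of (almost exactly) 1.
zero = nmae_sketch(np.zeros_like(target), target)
```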

Examples

>>> criterion = NMAELoss()
>>> predictions = torch.tensor([[2.0, 3.0], [1.0, 2.0]])
>>> targets = torch.tensor([[2.0, 4.0], [1.0, 2.5]])
>>> loss = criterion(predictions, targets)
>>> print(loss.item())
0.1579  # approximately; MAE = 0.375, mean(|target|) = 2.375

__init__(eps=1e-08)

forward(pred, target)

rtnn.evaluater.physics_loss(pred: torch.Tensor, target: torch.Tensor, conservation_penalty: torch.Tensor | None = None, lambda_phys: float = 0.1, delta: float = 1.0) → torch.Tensor

Combined Huber loss + energy conservation penalty.

The four output variables satisfy:

albedo + transmittance + absorptance = 1

For collimated: collim_alb + collim_tran + collim_abs = 1

For isotropic: isotrop_alb + isotrop_tran + isotrop_abs = 1

This function enforces the constraint as a soft penalty. You can pass pred_abs if your model also predicts absorptance; otherwise the penalty is computed implicitly as (1 - alb - tran).

Parameters:
  • pred (torch.Tensor shape (B, 4, L)) – Model predictions: [collim_alb, collim_tran, isotrop_alb, isotrop_tran] (channel order matches your ov list in DataPreprocessor).

  • target (torch.Tensor shape (B, 4, L)) – Ground-truth targets.

  • pred_abs (torch.Tensor or None shape (B, 2, L)) – If provided, predicted absorptance for [collimated, isotropic]. If None, absorptance is inferred as (1 - alb - tran).

  • lambda_phys (float) – Weight of the energy conservation penalty relative to Huber loss.

  • delta (float) – Huber loss delta parameter.

Return type:

torch.Tensor scalar loss value.
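One way to read the description above is a Huber data term plus a weighted penalty on the residual of albedo + transmittance + absorptance - 1. The NumPy sketch below follows that reading for the case where absorptance is supplied explicitly. Using the mean squared residual as the penalty is an assumption, as are the names huber and physics_loss_sketch; the rtnn function operates on torch tensors and may define the penalty differently.

```python
import numpy as np


def huber(residual, delta=1.0):
    """Mean Huber loss: quadratic for |r| <= delta, linear beyond it."""
    abs_r = np.abs(residual)
    quadratic = 0.5 * residual ** 2
    linear = delta * (abs_r - 0.5 * delta)
    return np.where(abs_r <= delta, quadratic, linear).mean()


def physics_loss_sketch(pred, target, pred_abs, lambda_phys=0.1, delta=1.0):
    """Huber term plus a soft energy-conservation penalty (assumed squared form)."""
    data_term = huber(pred - target, delta)
    # Channel order: 0 collim_alb, 1 collim_tran, 2 isotrop_alb, 3 isotrop_tran.
    collim_residual = pred[:, 0] + pred[:, 1] + pred_abs[:, 0] - 1.0
    isotrop_residual = pred[:, 2] + pred[:, 3] + pred_abs[:, 1] - 1.0
    penalty = np.mean(collim_residual ** 2) + np.mean(isotrop_residual ** 2)
    return data_term + lambda_phys * penalty


pred = np.empty((2, 4, 3))          # (B, 4, L)
pred[:, [0, 2]] = 0.3               # albedos
pred[:, [1, 3]] = 0.5               # transmittances
conserving = physics_loss_sketch(pred, pred, np.full((2, 2, 3), 0.2))
violating = physics_loss_sketch(pred, pred, np.full((2, 2, 3), 0.3))
```

With a perfect data fit, the first call returns essentially zero because 0.3 + 0.5 + 0.2 = 1, while the second is penalized for the 0.1 excess in both streams.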

class rtnn.evaluater.MetricTracker

Bases: object

A utility class for tracking and computing statistics of metric values.

This class maintains running sums of metric values and their squares, allowing incremental updates and computation of mean and standard deviation. It is particularly useful for aggregating metrics across multiple batches during evaluation.

value

Cumulative weighted sum of metric values

Type:

float

count

Total number of samples processed

Type:

int

value_sq

Cumulative weighted sum of squared metric values

Type:

float

Examples

>>> tracker = MetricTracker()
>>> tracker.update(10.0, 5)  # value=10.0, count=5 samples
>>> tracker.update(20.0, 3)  # value=20.0, count=3 samples
>>> print(tracker.getmean())  # (10*5 + 20*3) / (5+3) = 110/8 = 13.75
13.75
>>> print(tracker.getstd())
4.8412  # approximately; sqrt(212.5 - 13.75**2) = sqrt(23.4375)
>>> print(tracker.getsqrtmean())
3.7080992435478315

__init__()

Initialize MetricTracker with zero values.

reset()

Reset all tracked values to zero.

Return type:

None

update(value, count)

Update the tracker with new metric values.

Parameters:
  • value (float) – The metric value to add

  • count (int) – Number of samples this value represents (weight)

Return type:

None

getmean()

Calculate the mean of all tracked values.

Returns:

Weighted mean of all values: total_value / total_count

Return type:

float

Raises:

ZeroDivisionError – If no values have been added (count == 0)

getstd()

Calculate the standard deviation of all tracked values.

Returns:

Weighted standard deviation of all values: sqrt(E(x^2) - (E(x))^2)

Return type:

float

Raises:

ZeroDivisionError – If no values have been added (count == 0)

getsqrtmean()

Calculate the square root of the mean of all tracked values.

Returns:

Square root of the weighted mean: sqrt(total_value / total_count)

Return type:

float

Raises:

ZeroDivisionError – If no values have been added (count == 0)
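The running-sum bookkeeping described for this class can be sketched in pure Python. The class below is a minimal stand-in written from the attribute and method descriptions above (weighted sums of values and squared values); it is named MetricTrackerSketch to make clear that it is not the rtnn class, and clamping the variance at zero is an added assumption to guard against rounding.

```python
import math


class MetricTrackerSketch:
    """Accumulate count-weighted sums to report mean, std, and sqrt(mean)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.value = 0.0     # weighted sum of values
        self.count = 0       # total number of samples
        self.value_sq = 0.0  # weighted sum of squared values

    def update(self, value, count):
        self.value += value * count
        self.value_sq += value ** 2 * count
        self.count += count

    def getmean(self):
        return self.value / self.count

    def getstd(self):
        mean = self.getmean()
        variance = self.value_sq / self.count - mean ** 2  # E(x^2) - E(x)^2
        return math.sqrt(max(variance, 0.0))

    def getsqrtmean(self):
        return math.sqrt(self.getmean())


tracker = MetricTrackerSketch()
tracker.update(2.0, 1)
tracker.update(4.0, 3)
print(tracker.getmean())  # (2*1 + 4*3) / 4 = 3.5
```

Tracking the squared sum alongside the plain sum is what allows the standard deviation to be computed without storing per-batch values.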

rtnn.evaluater.get_loss_function(loss_type, args, logger=None)

Factory function to instantiate the requested loss function.

Parameters:
  • loss_type (str) –

    Type of loss function. Options:

    Standard losses:
    • 'mse': Mean Squared Error

    • 'mae': Mean Absolute Error

    Normalized losses:
    • 'nmae': Normalized Mean Absolute Error

    • 'nmse': Normalized Mean Squared Error

    Robust losses:
    • 'smoothl1': Smooth L1 Loss (Huber with beta)

    • 'huber': Huber Loss with delta parameter

  • args (argparse.Namespace) –

    Arguments containing loss-specific parameters:
    • For 'huber'/'smoothl1': requires args.beta_delta

    • For composite losses: may require args.beta

  • logger (logging.Logger, optional) – Logger for informational messages. If None, no logging occurs.

Returns:

Initialized loss function.

Return type:

torch.nn.Module

Raises:

ValueError – If loss_type is not supported or required parameters are missing.

Examples

>>> import argparse
>>> args = argparse.Namespace(beta_delta=1.0)
>>> criterion = get_loss_function('huber', args)
>>> loss = criterion(predictions, targets)
>>> args = argparse.Namespace()
>>> criterion = get_loss_function('mse', args)
rtnn.evaluater.mse_all(pred, true)[source]

Compute Mean Squared Error.

Parameters:
  • pred (torch.Tensor) – Predicted values

  • true (torch.Tensor) – Ground-truth (target) values

Returns:

(num_elements, mse_value)

Return type:

tuple

rtnn.evaluater.mbe_all(pred, true)[source]

Compute Mean Bias Error.

Parameters:
  • pred (torch.Tensor) – Predicted values

  • true (torch.Tensor) – Ground-truth (target) values

Returns:

(num_elements, mbe_value)

Return type:

tuple

rtnn.evaluater.mae_all(pred, true)[source]

Compute Mean Absolute Error.

Parameters:
  • pred (torch.Tensor) – Predicted values

  • true (torch.Tensor) – Ground-truth (target) values

Returns:

(num_elements, mae_value)

Return type:

tuple
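
The three basic error metrics share the same (num_elements, value) return convention. The sketch below illustrates the arithmetic on plain Python lists so that it runs without dependencies; the library functions presumably operate on torch tensors instead. The helper name basic_errors and the sign convention for the bias (pred minus true) are assumptions.

```python
def basic_errors(pred, true):
    """Return (n, mse, mbe, mae) for two equal-length sequences."""
    if len(pred) != len(true):
        raise ValueError("pred and true must have the same length")
    n = len(pred)
    diffs = [p - t for p, t in zip(pred, true)]
    mse = sum(d * d for d in diffs) / n   # mean squared error
    mbe = sum(diffs) / n                  # mean bias error (assumed pred - true)
    mae = sum(abs(d) for d in diffs) / n  # mean absolute error
    return n, mse, mbe, mae


n, mse, mbe, mae = basic_errors([1.0, 2.0, 4.0], [1.0, 3.0, 2.0])
print(n, mse, mbe, mae)
```

A positive bias with this convention means the predictions overestimate the targets on average, while MAE and MSE ignore the sign of the error.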

rtnn.evaluater.r2_all(pred, true)[source]

Calculate R2 (coefficient of determination) between predicted and true values.

Computes the R2 metric and returns both the number of elements and the R2 value.

Parameters:
  • pred (torch.Tensor) – Predicted values

  • true (torch.Tensor) – Ground-truth (target) values

Returns:

(num_elements, r2_value) where:

  • num_elements (int): Total number of elements in the tensors

  • r2_value (torch.Tensor): R2 score

Return type:

tuple

Notes

R2 is calculated as:

R2 = 1 - sum((true - pred)^2) / sum((true - mean(true))^2)

This implementation is fully torch-based and works on CPU and GPU.
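
The formula in the Notes can be checked by hand with a small dependency-free sketch. The function name r2_score is hypothetical, and plain lists stand in for the tensors used by the library.

```python
def r2_score(pred, true):
    """Return (num_elements, r2) following the formula in the Notes."""
    n = len(true)
    mean_true = sum(true) / n
    ss_res = sum((t - p) ** 2 for p, t in zip(pred, true))  # residual sum of squares
    ss_tot = sum((t - mean_true) ** 2 for t in true)        # total sum of squares
    return n, 1.0 - ss_res / ss_tot


print(r2_score([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # (3, 1.0)
print(r2_score([2.0, 2.0, 2.0], [1.0, 2.0, 3.0]))  # (3, 0.0)
```

A perfect prediction gives 1, and predicting the mean of the targets everywhere gives 0. The denominator vanishes when all targets are identical, so this sketch raises ZeroDivisionError in that case.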

rtnn.evaluater.nmae_all(pred, true)[source]

Compute Normalized Mean Absolute Error.

Parameters:
  • pred (torch.Tensor) – Predicted values

  • true (torch.Tensor) – Ground-truth (target) values

Returns:

(num_elements, nmae_value)

Return type:

tuple

rtnn.evaluater.nmse_all(pred, true)[source]

Compute Normalized Mean Squared Error.

Parameters:
  • pred (torch.Tensor) – Predicted values

  • true (torch.Tensor) – Ground-truth (target) values

Returns:

(num_elements, nmse_value)

Return type:

tuple

rtnn.evaluater.mare_all(pred, true)[source]

Compute Mean Absolute Relative Error.

Parameters:
  • pred (torch.Tensor) – Predicted values

  • true (torch.Tensor) – Ground-truth (target) values

Returns:

(num_elements, mare_value)

Return type:

tuple

rtnn.evaluater.gmrae_all(pred, true)[source]

Compute Geometric Mean Relative Absolute Error.

Parameters:
  • pred (torch.Tensor) – Predicted values

  • true (torch.Tensor) – Ground-truth (target) values

Returns:

(num_elements, gmrae_value)

Return type:

tuple

rtnn.evaluater.unnorm_mpas(pred, targ, norm_mapping, normalization_type, idxmap)[source]

Reverse normalization for 5D tensors.

This function converts normalized predictions and targets back to physical units using stored normalization statistics. It supports all normalization types defined in DataPreprocessor.

Parameters:
  • pred (torch.Tensor) – Normalized predictions. Shape: (batch, 4, n_pft, n_bands, seq_length)

  • targ (torch.Tensor) – Normalized targets. Same shape as pred.

  • norm_mapping (dict) – Dictionary containing normalization statistics for each variable.

  • normalization_type (dict) – Dictionary specifying normalization type for each variable.

  • idxmap (dict) – Mapping from channel indices (0-3) to variable names.

Returns:

(upred, utarg) where:

  • upred (torch.Tensor): Unnormalized predictions

  • utarg (torch.Tensor): Unnormalized targets

Return type:

tuple

Raises:

ValueError – If normalization type is not supported.

Notes

The function handles the following transformations:
  • Linear: x_norm = (x - mean)/std or x_norm = (x - min)/(max - min)

  • Log1p: x_norm = (log(1+x) - log_mean)/log_std

  • Sqrt: x_norm = (sqrt(x) - sqrt_mean)/sqrt_std

The reverse operation is applied to recover physical values.
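
To make the reverse operation concrete, the sketch below inverts each transformation listed above for a single scalar. It is an illustration only: the type names ("standard", "minmax", "log1p_standard", "sqrt_standard") and the statistics keys are assumptions, not the keys used by norm_mapping, and the library works on 5D tensors rather than floats.

```python
import math


def normalize(x, stats, kind):
    if kind == "standard":
        return (x - stats["mean"]) / stats["std"]
    if kind == "minmax":
        return (x - stats["min"]) / (stats["max"] - stats["min"])
    if kind == "log1p_standard":
        return (math.log1p(x) - stats["log_mean"]) / stats["log_std"]
    if kind == "sqrt_standard":
        return (math.sqrt(x) - stats["sqrt_mean"]) / stats["sqrt_std"]
    raise ValueError(f"Unsupported normalization type: {kind}")


def denormalize(x_norm, stats, kind):
    """Invert normalize(): undo the scaling first, then the transform."""
    if kind == "standard":
        return x_norm * stats["std"] + stats["mean"]
    if kind == "minmax":
        return x_norm * (stats["max"] - stats["min"]) + stats["min"]
    if kind == "log1p_standard":
        return math.expm1(x_norm * stats["log_std"] + stats["log_mean"])
    if kind == "sqrt_standard":
        return (x_norm * stats["sqrt_std"] + stats["sqrt_mean"]) ** 2
    raise ValueError(f"Unsupported normalization type: {kind}")


# Illustrative statistics; real values come from the training data.
stats = {"mean": 0.3, "std": 0.1, "min": 0.0, "max": 1.0,
         "log_mean": 0.2, "log_std": 0.05, "sqrt_mean": 0.5, "sqrt_std": 0.2}
for kind in ("standard", "minmax", "log1p_standard", "sqrt_standard"):
    roundtrip = denormalize(normalize(0.42, stats, kind), stats, kind)
    print(kind, round(roundtrip, 6))
```

Each round trip recovers the original value up to floating-point error, which is a convenient sanity check when adding a new normalization type.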

rtnn.evaluater.conservation_residual(alb, tran, abs_flux)[source]

Compute energy conservation residual for a given level.

Parameters:
  • alb (torch.Tensor) – Albedo (upwelling flux). Shape (batch, 1, n_pft, n_bands, seq)

  • tran (torch.Tensor) – Transmittance (downwelling flux). Same shape as alb.

  • abs_flux (torch.Tensor) – Absorptance (absorbed flux). Shape (batch, 1, n_pft, n_bands, seq-1)

Returns:

Squared conservation residual. Shape (batch, 1, n_pft, n_bands, seq-1)

Return type:

torch.Tensor

Notes

The function averages alb and tran to layer centers before computing:

residual = (alb_center + tran_center + abs_flux - 1)²

This enforces the physical constraint that:

upwelling + downwelling + absorption = total incoming radiation = 1
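
For a single vertical column the residual can be sketched without tensors. The version below assumes that "averaging to layer centers" means taking the mean of adjacent levels along the last axis, which is consistent with abs_flux having one fewer entry than alb and tran. The function name conservation_residual_1d is hypothetical.

```python
def conservation_residual_1d(alb, tran, abs_flux):
    """Squared conservation residual per layer for one vertical column."""
    if not (len(alb) == len(tran) == len(abs_flux) + 1):
        raise ValueError("alb and tran need one more entry than abs_flux")
    residuals = []
    for i, absorbed in enumerate(abs_flux):
        alb_center = 0.5 * (alb[i] + alb[i + 1])     # level values -> layer center
        tran_center = 0.5 * (tran[i] + tran[i + 1])
        residuals.append((alb_center + tran_center + absorbed - 1.0) ** 2)
    return residuals


alb = [0.2, 0.3, 0.4]
tran = [0.7, 0.5, 0.3]
abs_flux = [0.15, 0.25]
print(conservation_residual_1d(alb, tran, abs_flux))
```

The sample profile is constructed so that the three terms sum to 1 in every layer, so both residuals are zero up to rounding. Any imbalance shows up as a positive penalty because the residual is squared.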

rtnn.evaluater.calc_abs(pred, targ, p=None)[source]

Calculate absorption rates from flux predictions.

Computes absorption rates for both collimated (direct) and isotropic (diffuse) components by calculating the divergence of net flux.

Parameters:
  • pred (torch.Tensor) – Predicted fluxes. Shape (batch, 4, n_pft, n_bands, seq_length). Channel order: [collim_alb, collim_tran, isotrop_alb, isotrop_tran]

  • targ (torch.Tensor) – Target fluxes. Same shape as pred.

  • p (torch.Tensor, optional) – Pressure levels. If provided, computes heating rates. Shape can be (seq_length,) or (batch, seq_length). Default is None.

Returns:

(abs12_pred, abs12_targ, abs34_pred, abs34_targ, conservation_penalty)

Return type:

tuple

rtnn.evaluater.heating_rate(up, down, p=None)[source]

Calculate heating rate from upwelling and downwelling fluxes.

Parameters:
  • up (torch.Tensor) – Upwelling flux. Shape (batch, 1, n_pft, n_bands, seq_length)

  • down (torch.Tensor) – Downwelling flux. Same shape as up.

  • p (torch.Tensor, optional) – Pressure levels. If provided, computes heating rate in K/day. If None, computes negative flux divergence. Default is None.

Returns:

Heating rate or flux divergence. Shape (batch, 1, n_pft, n_bands, seq_length-1)

Return type:

torch.Tensor

Notes

The calculation involves:

  1. net = up - down (net flux)

  2. dnet = net - roll(net, 1) (vertical flux divergence)

  3. If p is provided, convert to heating rate using:

     hr = -dnet/dp * (g * 8.64e4) / (cp * 100)

     where g = 9.8066 m/s² and cp ≈ 1004 J/(kg·K). The factor 8.64e4 is the number of seconds per day, which converts the result to K/day.
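
The three steps can be sketched for one column of level fluxes as shown below. This is an illustration under stated assumptions: the wrap-around element produced by roll is dropped by differencing adjacent levels directly, dp is taken as the difference between adjacent pressure levels, and the function name heating_rate_1d is hypothetical.

```python
G = 9.8066              # gravitational acceleration, m/s^2
CP = 1004.0             # specific heat of air at constant pressure, J/(kg K)
SECONDS_PER_DAY = 8.64e4


def heating_rate_1d(up, down, p=None):
    """Return seq_length - 1 values for one column of level fluxes."""
    net = [u - d for u, d in zip(up, down)]                    # step 1: net flux
    dnet = [net[i] - net[i - 1] for i in range(1, len(net))]   # step 2: level difference
    if p is None:
        return [-d for d in dnet]                              # negative flux divergence
    dp = [p[i] - p[i - 1] for i in range(1, len(p))]
    # step 3: scale by g / cp and convert to a per-day rate
    return [-d / q * (G * SECONDS_PER_DAY) / (CP * 100.0) for d, q in zip(dnet, dp)]


up = [0.1, 0.2, 0.3]
down = [1.0, 0.8, 0.5]
print(heating_rate_1d(up, down))
print(heating_rate_1d(up, down, p=[1000.0, 900.0, 800.0]))
```

Without p the function returns the negative flux difference per layer; with p the same difference is divided by the pressure thickness and scaled by the constants from the Notes.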

rtnn.evaluater.run_validation(loader, model, norm_mapping, normalization_type, index_mapping, device, args, epoch, logger=None, base_dir='./results', n_pft=15, n_bands=2, n_chans=4)[source]

Evaluate model accuracy on LSM dataset.

Performs comprehensive evaluation including:

  • Loss computation for main fluxes and absorption rates

  • Metric calculation (NMAE, NMSE, R²) for fluxes and absorption

  • Optional plotting of predictions vs targets (every 10 epochs)

  • Energy conservation verification

Parameters:
  • loader (torch.utils.data.DataLoader) – Data loader for evaluation dataset.

  • model (torch.nn.Module) – Trained model to evaluate.

  • norm_mapping (dict) – Normalization statistics for variables.

  • normalization_type (dict) – Normalization types per variable.

  • index_mapping (dict) – Mapping from channel indices (0-3) to variable names.

  • device (torch.device) – Device to run evaluation on (cuda or cpu).

  • args (argparse.Namespace) – Arguments containing loss_type, beta, beta_delta, and num_epochs.

  • epoch (int) – Current epoch number (for plotting schedule).

  • logger (logging.Logger, optional) – Logger for informational messages. If None, no logging occurs.

  • base_dir (str, optional) – Directory to save diagnostic plots. Default is “./results”.

  • n_pft (int, optional) – Number of Plant Functional Types. Default is 15.

  • n_bands (int, optional) – Number of spectral bands. Default is 2 (VIS, NIR).

  • n_chans (int, optional) – Number of output channels. Default is 4.

Returns:

(valid_loss, valid_metrics)

Return type:

tuple

Notes

The evaluation performs the following steps:

  1. Iterates through the validation loader

  2. Computes predictions and reshapes them to 5D tensors

  3. De-normalizes predictions and targets to physical units

  4. Calculates absorption rates and conservation penalties

  5. Computes metrics for fluxes and absorption

  6. Optionally generates diagnostic plots (epoch % 10 == 0 or final epoch)

The combined loss includes both flux and absorption terms weighted by β:

total_loss = (1-β)*loss_fluxes + β*(loss_abs12 + loss_abs34)
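
As a quick check of the weighting, the expression can be written as a small helper. The function name combined_loss is hypothetical and plain floats stand in for the loss tensors.

```python
def combined_loss(loss_fluxes, loss_abs12, loss_abs34, beta):
    """Blend the flux loss and the two absorption losses with weight beta."""
    return (1.0 - beta) * loss_fluxes + beta * (loss_abs12 + loss_abs34)


print(combined_loss(0.2, 0.05, 0.03, beta=0.0))  # flux term only
print(combined_loss(0.2, 0.05, 0.03, beta=0.5))  # equal weighting
print(combined_loss(0.2, 0.05, 0.03, beta=1.0))  # absorption terms only
```

Setting β to 0 recovers a pure flux loss, while values closer to 1 shift the emphasis to the absorption terms.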

Examples

>>> valid_loss, metrics = run_validation(
...     loader=val_loader,
...     model=my_model,
...     norm_mapping=norm_stats,
...     normalization_type=norm_types,
...     index_mapping=idxmap,
...     device=torch.device('cuda'),
...     args=args,
...     epoch=10,
...     logger=logger
... )
>>> print(f"Validation NMAE: {metrics['fluxes_NMAE']:.4f}")
>>> print(f"R² score: {metrics['fluxes_R2']:.4f}")

rtnn.logger module

Logging utilities for RTnn model training and evaluation.

This module provides a rich, customizable logging system for radiative transfer neural network workflows. It combines Python’s standard logging with the Rich library to deliver visually appealing, informative console output while maintaining file-based logging for archival purposes.

The Logger class offers multiple log levels and specialized formatting for different types of messages:

  • Info messages: General information (cyan)

  • Warning messages: Non-critical issues (yellow with warning emoji)

  • Success messages: Successful completions (green with checkmark)

  • Error messages: Critical failures (red panels)

  • Exception messages: Full traceback displays (multi-panel format)

  • Step messages: Pipeline stage transitions (magenta)

  • Task start messages: Formatted task initiation banners

Key features include:

  • Dual output: console (Rich-formatted) and file (plain text)

  • Progress bars for long-running operations

  • Structured task tracking with start/end logging

  • Metrics aggregation and display

  • Exception traceback visualization with code snippets

  • Configurable output (console/file/both)

  • Pretty printing support

class rtnn.logger.Logger(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]

Bases: object

__init__(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]
clear_logs()[source]

Clear the stored Rich logs if record=True.

show_header(module_name)[source]

Display startup banner.

start_task(task_name: str, description: str = '', **meta)[source]

Display a clearly formatted ‘task start’ message with good spacing.

log_metrics()[source]

Log pipeline metrics.

info(message)[source]

Log a formatted info message.

warning(message)[source]

Log a formatted warning message.

success(message)[source]

Log a success message. This is a custom level, not one of the standard logging levels.

step(step_name, message)[source]

Highlight a pipeline step event.

exception(message, exception=None)[source]

Display a formatted exception message with visual stack trace.

error(message, exception=None)[source]

Display a formatted error log, optionally including exception trace.

rtnn.main module

RTnn (Radiative Transfer Neural Network) Training Pipeline

This module provides the main entry point for training neural network models for radiative transfer calculations in climate modeling. It supports various model architectures including LSTM, GRU, Transformer, and FCN.

The training pipeline includes:

  • Data loading and preprocessing from NetCDF files

  • Model initialization and configuration

  • Training loop with progress tracking

  • Validation and metric computation

  • Checkpoint saving and model persistence

  • Visualization and logging

Module Overview

This module implements a complete machine learning pipeline for radiative transfer modeling in climate science. The pipeline is designed to handle spatio-temporal data from NetCDF files, apply appropriate normalization, train various neural network architectures, and evaluate model performance using physics-informed metrics.

Key Features

  1. Data Handling

    • Automatic reading of NetCDF files with year-based filtering

    • Support for multiple normalization schemes (log1p_standard, standard, minmax)

    • Temporal batching and sequence generation

    • Multi-year training and held-out year testing

  2. Model Architectures

    • LSTM (Long Short-Term Memory)

    • GRU (Gated Recurrent Unit)

    • Transformer with attention mechanism

    • FCN (Fully Connected Networks)

    • MLP with residual connections

    • Custom RT-specific architectures

  3. Training Features

    • Configurable loss functions (MSE, MAE, NMSE, NMAE, SmoothL1, Huber)

    • Physics-informed weighted loss combining flux and absorption terms

    • Learning rate scheduling with ReduceLROnPlateau

    • Multi-GPU support with DataParallel

    • Checkpoint saving and resumption

  4. Evaluation Metrics

    • NMAE (Normalized Mean Absolute Error)

    • NMSE (Normalized Mean Squared Error)

    • R² (Coefficient of Determination)

    • Conservation penalty for physical consistency

  5. Output and Visualization

    • TensorBoard integration for real-time monitoring

    • Training/validation loss plots

    • Metric history visualization

    • Spatial-temporal density scatter plots

    • Automatic checkpoint saving (best, epoch, final)

Data Flow

  1. Parse command-line arguments (parse_args)

  2. Setup directory structure and logging (setup_directories_and_logging)

  3. Configure device and random seeds (setup_device_and_seed)

  4. Load and preprocess data (get_data_files, create_datasets_and_loaders)

  5. Compute normalization statistics (create_normalization_mapping)

  6. Initialize model architecture (initialize_model)

  7. Load checkpoint if resuming (load_checkpoint_if_requested)

  8. Run inference or training loop

  9. Generate plots and save results

rtnn.main.print_version()[source]

Print detailed version information.

rtnn.main.parse_years(year_str)[source]

Parse a year string into a list of integers.

Supports hyphen-separated ranges (e.g., “1995-1999”) and comma-separated lists (e.g., “1995,1997,1999”). Returns a list of integers.

Parameters:

year_str (str) – String containing years in range or comma-separated format.

Returns:

List of parsed years.

Return type:

list of int

Examples

>>> parse_years("1995-1999")
[1995, 1996, 1997, 1998, 1999]
>>> parse_years("1995,1997,1999")
[1995, 1997, 1999]
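
One way such parsing could be implemented is sketched below. It is not the package's code: the function name parse_years_sketch is hypothetical, and the handling of a single year or stray whitespace is an assumption.

```python
def parse_years_sketch(year_str):
    """Parse '1995-1999' or '1995,1997,1999' into a list of ints."""
    year_str = year_str.strip()
    if "-" in year_str:
        start, end = (int(part) for part in year_str.split("-"))
        return list(range(start, end + 1))  # ranges are inclusive
    return [int(part) for part in year_str.split(",") if part.strip()]


print(parse_years_sketch("1995-1999"))       # [1995, 1996, 1997, 1998, 1999]
print(parse_years_sketch("1995,1997,1999"))  # [1995, 1997, 1999]
```
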
rtnn.main.parse_args()[source]

Parse command-line arguments for RTnn model training.

Defines and parses command-line arguments required to configure and run the Radiative Transfer Neural Network (RTnn) training pipeline. This includes model architecture parameters, training hyperparameters, data configuration, and output settings.

Returns:

Object containing parsed command-line arguments, grouped as follows:

Model architecture
  • type (str) – Model type (e.g., “lstm”, “gru”, “fcn”, “fullyconnected”, “transformer”, “cnn”, “mlp”).

  • hidden_size (int) – Size of hidden layers.

  • num_layers (int) – Number of model layers.

  • seq_length (int) – Length of input sequence.

  • feature_channel (int) – Number of input feature channels.

  • output_channel (int) – Number of output channels.

  • embed_size (int) – Embedding dimension for transformer models.

  • nhead (int) – Number of attention heads (transformer).

  • forward_expansion (int) – Expansion factor for feed-forward layers.

  • dropout (float) – Dropout rate.

Training hyperparameters
  • batch_size (int) – Number of samples per batch.

  • tbatch (int) – Temporal batch length.

  • num_epochs (int) – Number of training epochs.

  • learning_rate (float) – Initial learning rate.

  • loss_type (str) – Loss function (e.g., “mse”, “mae”, “nmae”, “nmse”, “wmse”, “logcosh”, “smoothl1”, “huber”).

  • beta (float) – Weighting factor for loss components.

  • beta_delta (float) – Delta parameter for Huber or SmoothL1 loss.

  • num_workers (int) – Number of data loader workers.

Data configuration
  • train_data_files (str) – Path or pattern for training data files.

  • test_data_files (str) – Path or pattern for testing data files.

  • train_years (str) – Training years (comma-separated or range, e.g., “1995-1999”).

  • test_year (str) – Test year or range.

  • norm (str) – Normalization scheme (e.g., “log1p_standard”, “standard”, “minmax”, “none”).

  • dataset_type (str) – Dataset type (e.g., “LSM”, “RTM”).

Output configuration
  • root_dir (str) – Root directory for all operations.

  • main_folder (str) – Main output folder name.

  • sub_folder (str) – Sub-folder name for the current run.

  • prefix (str) – Prefix for saved files.

  • model_name (str) – Custom model name (auto-generated if empty).

  • save_model (bool) – Whether to save model checkpoints.

  • save_checkpoint_name (str) – Base name for saved checkpoints.

  • save_per_samples (int) – Save checkpoint every N samples.

  • load_model (bool) – Whether to load an existing model.

  • load_checkpoint_name (str) – Checkpoint file to load.

  • inference (bool) – Run in inference-only mode.

Return type:

argparse.Namespace

Examples

>>> args = parse_args()
>>> args.type
'lstm'
>>> args.batch_size
16

Command line usage

$ rtnn --type lstm --hidden_size 128 --num_layers 3 --batch_size 32
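
For readers unfamiliar with argparse, the sketch below shows how a handful of these options could be declared and parsed. It is not the package's parser: the defaults for --type and --batch_size follow the doctest example above, while the other defaults are placeholders chosen for illustration.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(prog="rtnn")
    # Defaults for --type and --batch_size follow the documented example;
    # the remaining defaults are placeholders, not the package's values.
    parser.add_argument("--type", type=str, default="lstm")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--train_years", type=str, default="1995-1999")
    return parser


args = build_parser().parse_args(
    ["--type", "lstm", "--hidden_size", "128", "--num_layers", "3", "--batch_size", "32"]
)
print(args.type, args.hidden_size, args.num_layers, args.batch_size)  # lstm 128 3 32
```

Passing an explicit list to parse_args() mirrors the command line shown above and makes the parser easy to exercise in tests.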

rtnn.main.setup_directories_and_logging(args)[source]

Set up directory structure and logging infrastructure for experiments.

Parameters:

args (argparse.Namespace) – Parsed command-line arguments.

Returns:

  • paths (EasyDict) – Dictionary containing paths to created directories.

  • logger (Logger) – Configured logger instance.

rtnn.main.log_configuration(args, paths, logger)[source]

Log all configuration parameters to the provided logger.

Parameters:
  • args (argparse.Namespace) – Configuration object containing all experiment parameters.

  • paths (EasyDict) – Dictionary containing paths to various experiment directories.

  • logger (Logger) – Logger instance for outputting configuration information.

rtnn.main.setup_device_and_seed(args, logger)[source]

Set up device (GPU/CPU) and random seeds for reproducibility.

Parameters:
  • args (argparse.Namespace) – Parsed command-line arguments.

  • logger (Logger) – Logger instance.

Returns:

Device to use for computations.

Return type:

torch.device

rtnn.main.get_data_files(args, logger)[source]

Get training and testing data files based on year specifications.

Parameters:
  • args (argparse.Namespace) – Parsed command-line arguments.

  • logger (Logger) – Logger instance.

Returns:

(train_files, test_files) lists of file paths.

Return type:

tuple
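
Year-based file selection could look like the sketch below. It assumes that each file name contains a four-digit year, as in the data_1995.nc names used elsewhere in this documentation; the helper name filter_files_by_year is hypothetical and the package's actual matching rule may differ.

```python
import re


def filter_files_by_year(files, years):
    """Keep files whose name contains one of the requested four-digit years."""
    wanted = set(years)
    selected = []
    for path in files:
        # Match standalone runs of exactly four digits.
        found = {int(y) for y in re.findall(r"(?<!\d)(\d{4})(?!\d)", path)}
        if found & wanted:
            selected.append(path)
    return sorted(selected)


files = ["data_1995.nc", "data_1996.nc", "data_1997.nc", "data_2000.nc"]
train_files = filter_files_by_year(files, [1995, 1996])
test_files = filter_files_by_year(files, [2000])
print(train_files, test_files)
```

Keeping the training and test years disjoint, as here, gives the held-out-year evaluation described in the module overview.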

rtnn.main.create_normalization_mapping(train_files, paths, logger)[source]

Create normalization mapping from training data.

Parameters:
  • train_files (list) – List of training file paths.

  • paths (EasyDict) – Directory paths.

  • logger (Logger) – Logger instance.

Returns:

Normalization statistics for each variable.

Return type:

dict

rtnn.main.create_datasets_and_loaders(args, train_files, test_files, norm_mapping, logger)[source]

Create datasets and data loaders for training and validation.

Parameters:
  • args (argparse.Namespace) – Parsed command-line arguments.

  • train_files (list) – Training file paths.

  • test_files (list) – Test file paths.

  • norm_mapping (dict) – Normalization statistics.

  • logger (Logger) – Logger instance.

Returns:

(train_loader, test_loader, train_dataset, test_dataset)

Return type:

tuple

rtnn.main.initialize_model(args, device, logger)[source]

Initialize the model architecture.

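Parameters:
  • args (argparse.Namespace) – Parsed command-line arguments.

  • device (torch.device) – Device to use for computations.

  • logger (Logger) – Logger instance.
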
Returns:

Initialized model.

Return type:

torch.nn.Module

rtnn.main.load_checkpoint_if_requested(args, model, optimizer, paths, device, logger)[source]

Load model checkpoint if requested using ModelUtils.load_training_checkpoint().

This function leverages ModelUtils.load_training_checkpoint(), which handles:

  • DataParallel compatibility

  • Loading model and optimizer states

  • Extracting training state (epoch, samples, metrics, etc.)

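Parameters:
  • args (argparse.Namespace) – Parsed command-line arguments.

  • model (torch.nn.Module) – Model to load weights into.

  • optimizer (torch.optim.Optimizer) – Optimizer to restore state.

  • paths (EasyDict) – Directory paths.

  • device (torch.device) – Device to use for computations.

  • logger (Logger) – Logger instance.
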
Returns:

(start_epoch, samples_processed, batches_processed, best_val_loss, best_epoch, checkpoint, train_loss_history, valid_loss_history, valid_metrics_history)

Return type:

tuple

rtnn.main.train_epoch(model, train_loader, optimizer, loss_func, metric_funcs, metric_names, output_keys, train_metrics, train_loss_tracker, norm_mapping, normalization_type, index_mapping, device, args, epoch, writer, global_step, logger, n_pft=15, n_bands=2, n_chans=4)[source]

Train for one epoch.

Returns:

(average_train_loss, updated_global_step)

Return type:

tuple

rtnn.main.main()[source]

Main entry point for training the RTnn model.

rtnn.model_loader module

rtnn.model_loader - Model Factory for Radiative Transfer Neural Networks

This module provides a factory function load_model() that instantiates various neural network architectures for radiative transfer calculations in climate modeling. It serves as the central point for model creation across the RTnn framework.

Module Overview

The model loader module implements a factory pattern to abstract away the details of model instantiation. Based on a configuration object (typically from command-line arguments), it returns an initialized PyTorch model ready for training or inference.

Supported Model Architectures

  1. Recurrent Neural Networks (RNN)

    • LSTM (Long Short-Term Memory): Bidirectional LSTM with Conv1d projection

    • GRU (Gated Recurrent Unit): Bidirectional GRU with Conv1d projection

    • Best for: Sequential data with temporal dependencies

  2. Transformer Models

    • EncoderTorch: PyTorch-native transformer encoder

    • Features: Multi-head self-attention, positional embeddings

    • Best for: Long-range dependencies and parallel processing

  3. Fully Connected Networks (FCN)

    • Standard FCN: Multi-layer perceptron with configurable depth

    • PINN: Physics-Informed Neural Network with two-stream architecture for vertical profiles

    • Best for: Non-sequential, feature-based inputs

  4. Multi-Layer Perceptrons (MLP)

    • MLP: Standard MLP with batch/layer normalization options

    • MLPResidual: MLP with residual connections between layers

    • Features: Positional embeddings, dropout, multiple activation functions

    • Best for: Flexible architecture with skip connections

Architecture Details

All models are designed to handle the specific requirements of radiative transfer problems:

  • Input shape: (batch, seq_len, feature_channel)

  • Output shape: (batch, seq_len, output_channel)

  • Physical constraints: Conservation of energy, positive outputs

The models output four channels corresponding to:

  • collim_alb (collimated albedo)

  • collim_tran (collimated transmittance)

  • isotrop_alb (isotropic albedo)

  • isotrop_tran (isotropic transmittance)

Data Flow

  1. Parse configuration (args)

  2. Determine model type from args.type

  3. Extract model-specific parameters

  4. Instantiate appropriate model class

  5. Return initialized model

rtnn.model_loader.load_model(args)[source]

Factory function to instantiate a neural network model from configuration.

This function builds and returns a PyTorch model based on the value of args.type. It supports multiple architectures including recurrent, convolutional, transformer-based, and fully connected models.

Supported model families

  • LSTM / GRU: Bidirectional recurrent models with Conv1d projection head

  • Transformer: PyTorch Transformer encoder with positional embeddings

  • FCN / FullyConnected: Fully connected feedforward network for sequences

  • VRT / VerticalRT: Physics-inspired vertical column model

  • MLP: Flexible multilayer perceptron with optional embeddings and residuals

  • MLPResidual: Deep residual MLP with layer-wise skip connections

Returns:

Instantiated PyTorch model corresponding to the requested architecture.

Return type:

torch.nn.Module

Raises:

ValueError – If args.type does not match any supported model.
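
The factory pattern described here can be sketched with a small registry. Everything in the block is illustrative: the stub classes stand in for the real PyTorch models, the registry name MODEL_BUILDERS is hypothetical, and case-insensitive matching of args.type is an assumption.

```python
import argparse


class RecurrentStub:
    def __init__(self, hidden_size, num_layers):
        self.hidden_size = hidden_size
        self.num_layers = num_layers


class TransformerStub:
    def __init__(self, embed_size, nhead):
        self.embed_size = embed_size
        self.nhead = nhead


# Map each accepted type string to a builder that pulls its own parameters.
MODEL_BUILDERS = {
    "lstm": lambda a: RecurrentStub(a.hidden_size, a.num_layers),
    "gru": lambda a: RecurrentStub(a.hidden_size, a.num_layers),
    "transformer": lambda a: TransformerStub(a.embed_size, a.nhead),
}


def load_model_sketch(args):
    key = args.type.lower()
    if key not in MODEL_BUILDERS:
        raise ValueError(f"Unsupported model type: {args.type}")
    return MODEL_BUILDERS[key](args)


args = argparse.Namespace(type="lstm", hidden_size=128, num_layers=3)
model = load_model_sketch(args)
print(type(model).__name__, model.hidden_size, model.num_layers)
```

A registry keeps the dispatch in one place, so supporting a new architecture only requires adding one entry rather than extending a chain of conditionals.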

rtnn.model_utils module

Model utility functions for PyTorch training workflows.

This module provides a collection of helper utilities for inspecting models, analyzing parameter distributions, and managing checkpoints during training. It is designed to standardize common operations such as saving/loading model states, logging architecture details, and maintaining reproducible training artifacts.

The utilities support both single-GPU and multi-GPU (DataParallel) setups and include safeguards for compatibility when loading checkpoints across different hardware configurations.

Features

  • Parameter counting (total and trainable)

  • Layer-wise parameter inspection

  • Model structure logging

  • Checkpoint saving and loading

  • Full training state persistence

  • Emergency checkpointing for crash recovery

  • DataParallel-aware state dictionary handling

Notes

  • Checkpoints include both model weights and optimizer states to enable seamless training resumption.

  • File naming conventions are automatically adapted based on checkpoint type (e.g., epoch, best, final, emergency).

  • When using torch.nn.DataParallel, the module automatically adjusts state dictionary keys to ensure compatibility between wrapped and unwrapped models.
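
The key adjustment mentioned in the last note typically comes down to the "module." prefix that torch.nn.DataParallel places in front of every parameter name. The helpers below are a minimal sketch of that idea, not the module's implementation; their names are hypothetical and integers stand in for tensors.

```python
def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that DataParallel adds to parameter names."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def add_module_prefix(state_dict):
    """Add the 'module.' prefix expected by a DataParallel-wrapped model."""
    return {
        (key if key.startswith("module.") else "module." + key): value
        for key, value in state_dict.items()
    }


wrapped = {"module.lstm.weight": 1, "module.head.bias": 2}
plain = strip_module_prefix(wrapped)
print(plain)
print(add_module_prefix(plain) == wrapped)
```

Stripping the prefix lets a checkpoint saved from a multi-GPU run load into an unwrapped model, and adding it covers the opposite direction.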

Dependencies

  • torch

  • os

  • datetime

Examples

Basic usage:

>>> from rtnn.model_utils import ModelUtils
>>> counts = ModelUtils.get_parameter_number(model)
>>> ModelUtils.print_model_layers(model)

Saving and loading checkpoints:

>>> state = {
...     "state_dict": model.state_dict(),
...     "optimizer": optimizer.state_dict(),
...     "epoch": epoch,
... }
>>> ModelUtils.save_checkpoint(state, "checkpoint.pth.tar")
>>> checkpoint = torch.load("checkpoint.pth.tar")
>>> ModelUtils.load_checkpoint(checkpoint, model, optimizer)

Saving a full training checkpoint:

>>> ModelUtils.save_training_checkpoint(
...     model, optimizer, epoch, samples_processed, batches_processed,
...     train_loss_history, valid_loss_history, valid_metrics_history,
...     best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss,
...     args, paths, logger, checkpoint_type="best"
... )
class rtnn.model_utils.ModelUtils[source]

Bases: object

Utility class for model inspection, checkpointing, and memory profiling.

This class provides static methods for common model operations including parameter counting, memory usage analysis, checkpoint management, and model inspection.

Examples

>>> utils = ModelUtils()
>>> param_counts = ModelUtils.get_parameter_number(model)
>>> ModelUtils.save_checkpoint(state, "checkpoint.pth.tar", logger)
__init__()[source]

Initialize ModelUtils instance.

static get_parameter_number(model, logger=None)[source]

Calculate the total and trainable number of parameters in a model.

Parameters:
  • model (torch.nn.Module) – PyTorch model to inspect

  • logger (Logger, optional) – Logger instance for output, by default None

Returns:

Dictionary containing:

  • 'Total': Total number of parameters

  • 'Trainable': Number of trainable parameters

Return type:

dict

Examples

>>> model = torch.nn.Linear(10, 5)
>>> counts = ModelUtils.get_parameter_number(model, logger)
static print_model_layers(model, logger=None)[source]

Print model parameter names along with their gradient requirements.

Parameters:
  • model (torch.nn.Module) – PyTorch model to inspect

  • logger (Logger, optional) – Logger instance for output, by default None

Examples

>>> model = torch.nn.Sequential(
...     torch.nn.Linear(10, 5),
...     torch.nn.ReLU(),
...     torch.nn.Linear(5, 1)
... )
>>> ModelUtils.print_model_layers(model, logger)
static save_checkpoint(state, filename='checkpoint.pth.tar', logger=None)[source]

Save model and optimizer state to a file.

Parameters:
  • state (dict) –

    Dictionary containing model state_dict and other training information. Typically includes:

    • 'state_dict': Model parameters

    • 'optimizer': Optimizer state

    • 'epoch': Current epoch

    • 'loss': Current loss value

  • filename (str, optional) – File path to save the checkpoint, by default “checkpoint.pth.tar”

  • logger (Logger, optional) – Logger instance for output, by default None

Examples

>>> state = {
...     'state_dict': model.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'epoch': epoch,
...     'loss': loss
... }
>>> ModelUtils.save_checkpoint(state, 'model_checkpoint.pth.tar', logger)
static load_checkpoint(checkpoint, model, optimizer=None, logger=None)[source]

Load model and optimizer state from an already loaded checkpoint dictionary.

Parameters:
  • checkpoint (dict) – Loaded checkpoint dictionary

  • model (torch.nn.Module) – Model to load weights into

  • optimizer (torch.optim.Optimizer, optional) – Optimizer to restore state, by default None

  • logger (Logger, optional) – Logger instance for output, by default None

Examples

>>> checkpoint = torch.load('model_checkpoint.pth.tar')
>>> ModelUtils.load_checkpoint(checkpoint, model, optimizer, logger)
static load_training_checkpoint(checkpoint_path, model, optimizer, device, logger=None)[source]

Load comprehensive training checkpoint.

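Parameters:
  • checkpoint_path (str) – Path to the checkpoint file

  • model (torch.nn.Module) – Model to load weights into

  • optimizer (torch.optim.Optimizer) – Optimizer to restore state

  • device (torch.device) – Device to use for computations

  • logger (Logger, optional) – Logger instance for output, by default None
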
Returns:

(epoch, samples_processed, batches_processed, best_val_loss, best_epoch, checkpoint)

Return type:

tuple

static count_parameters_by_layer(model, logger=None)[source]

Count parameters for each layer in the model.

Parameters:
  • model (torch.nn.Module) – PyTorch model to analyze

  • logger (Logger, optional) – Logger instance for output, by default None

Returns:

Dictionary with layer names as keys and parameter counts as values

Return type:

dict

Examples

>>> layer_params = ModelUtils.count_parameters_by_layer(model, logger)
static log_model_summary(model, input_shape=None, logger=None)[source]

Log comprehensive model summary including parameters and architecture.

Parameters:
  • model (torch.nn.Module) – PyTorch model to summarize

  • input_shape (tuple, optional) – Input shape for memory analysis, by default None

  • logger (Logger, optional) – Logger instance for output, by default None

static save_training_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss_history, valid_loss_history, valid_metrics_history, best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss, args, paths, logger, checkpoint_type='epoch', save_full_model=True)[source]

Save comprehensive training checkpoint with consistent formatting.

Parameters:
  • model (torch.nn.Module) – Model to save

  • optimizer (torch.optim.Optimizer) – Optimizer to save

  • epoch (int) – Current epoch

  • samples_processed (int) – Number of samples processed so far

  • batches_processed (int) – Number of batches processed so far

  • train_loss_history (list) – History of training losses

  • valid_loss_history (list) – History of validation losses

  • valid_metrics_history (dict) – History of validation metrics

  • best_val_loss (float) – Best validation loss so far

  • best_epoch (int) – Epoch with best validation loss

  • avg_val_loss (float) – Current epoch validation loss

  • avg_epoch_loss (float) – Current epoch training loss

  • args (argparse.Namespace) – Command line arguments

  • paths (EasyDict) – Directory paths

  • logger (Logger) – Logger instance

  • checkpoint_type (str) – Type of checkpoint: “samples”, “epoch”, “best”, “final”

  • save_full_model (bool) – Whether to also save the full model separately

Returns:

(checkpoint_filename, full_model_filename)

Return type:

tuple

Examples

>>> checkpoint_file, full_model_file = ModelUtils.save_training_checkpoint(
...     model, optimizer, epoch, samples_processed, batches_processed,
...     train_loss_history, valid_loss_history, valid_metrics_history,
...     best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss,
...     args, paths, logger, checkpoint_type="best"
... )
static save_emergency_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss_history, valid_loss_history, valid_metrics_history, args, paths, logger, reason='emergency')[source]

Save emergency checkpoint for recovery.

Parameters:

reason (str) – Reason for emergency save (e.g., “crash”, “interrupt”, “error”)

Returns:

(checkpoint_filename, full_model_filename)

Return type:

tuple

rtnn.utils module

File and lightweight configuration utilities.

This module provides helper classes for simplified file system operations and convenient dictionary handling with attribute-style access. It is designed to reduce boilerplate code when working with directories, files, and configuration-like data structures in Python projects.

The module includes:

  • EasyDict: A dictionary subclass enabling attribute-style access (e.g., cfg.key instead of cfg["key"]).

  • FileUtils: Static utility methods for creating directories and files.

Features

  • Attribute-style access to dictionary keys

  • Minimal and dependency-free implementation

  • Safe directory creation (no error if directory already exists)

  • Simple file creation utility

  • Lightweight and suitable for configuration management

Notes

  • EasyDict raises AttributeError when accessing missing keys, making it behave more like standard Python objects.

  • FileUtils.makedir will create nested directories if needed.

  • FileUtils.makefile creates an empty file if it does not exist, and does nothing if it already exists.

Dependencies

  • os

  • typing

Examples

Using EasyDict:

>>> cfg = EasyDict()
>>> cfg.learning_rate = 0.001
>>> cfg.batch_size = 32
>>> print(cfg.learning_rate)
0.001
>>> print(cfg["batch_size"])
32

Using FileUtils:

>>> FileUtils.makedir("outputs/logs")
>>> FileUtils.makefile("outputs/logs", "train.log")

Combined usage:

>>> paths = EasyDict()
>>> paths.output_dir = "outputs"
>>> FileUtils.makedir(paths.output_dir)
class rtnn.utils.EasyDict[source]

Bases: dict

A dictionary subclass that allows attribute-style access to its items. This class extends the built-in dict and overrides the __getattr__, __setattr__, and __delattr__ methods so that dictionary keys can be accessed as attributes.

Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm
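
The three overridden methods can be sketched in a few lines. The class below is a minimal illustration of the technique under the hypothetical name EasyDictSketch; it is not a copy of the packaged class.

```python
class EasyDictSketch(dict):
    """Dictionary whose keys can also be read and written as attributes."""

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


cfg = EasyDictSketch()
cfg.learning_rate = 0.001
print(cfg["learning_rate"], cfg.learning_rate)
```

Converting KeyError into AttributeError keeps hasattr() and getattr() with a default working as they do for ordinary objects.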

class rtnn.utils.FileUtils[source]

Bases: object

Utility class for file and directory operations.

__init__()[source]

Initialize the FileUtils class. This class does not maintain any state, so the constructor is empty.

static makedir(dirs)[source]

Create a directory if it does not exist.

Parameters:

dirs (str) – The path of the directory to be created.

static makefile(dirs, filename)[source]

Create an empty file in the specified directory.

Parameters:
  • dirs (str) – The path of the directory where the file will be created.

  • filename (str) – The name of the file to be created.
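
The behavior described in the module Notes (safe directory creation, and file creation that leaves an existing file alone) can be sketched with the standard library. These free functions are an illustration rather than the FileUtils source, and returning the file path from makefile is an assumption added for convenience.

```python
import os
import tempfile


def makedir(dirs):
    # exist_ok=True makes the call a no-op when the directory already exists;
    # intermediate directories are created as needed.
    os.makedirs(dirs, exist_ok=True)


def makefile(dirs, filename):
    path = os.path.join(dirs, filename)
    if not os.path.exists(path):
        # Append mode creates an empty file without truncating anything.
        with open(path, "a", encoding="utf-8"):
            pass
    return path


with tempfile.TemporaryDirectory() as root:
    log_dir = os.path.join(root, "outputs", "logs")
    makedir(log_dir)
    makedir(log_dir)  # second call does not raise
    log_path = makefile(log_dir, "train.log")
    print(os.path.isfile(log_path))  # True
```
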

rtnn.version module

Version information for rtnn.

rtnn.version.get_version()[source]

Return the version string.