rtnn package
Subpackages
Submodules
rtnn.dataset module
Dataset module for RTnn radiative transfer neural network framework.
This module provides the DataPreprocessor class for loading, preprocessing, and normalizing NetCDF climate data for training and evaluating neural network emulators of atmospheric radiative transfer. The module handles complex data structures including multiple years, spatial and temporal batching, and various normalization techniques.
The DataPreprocessor is designed to work with Land Surface Model (LSM) data containing variables such as:
- Cosine of solar zenith angle (cosz)
- Leaf area index (lai)
- Single scattering albedo (ssa)
- Surface reflectance (rs)
- Radiative transfer outputs (collimated albedo/transmittance, isotropic albedo/transmittance)
The module supports data augmentation through random spatial sampling during training and provides flexible normalization methods including min-max scaling, standardization, robust scaling, and transformations like log1p and sqrt.
Notes
The expected input data structure consists of NetCDF4 files with dimensions:
- time: Temporal dimension
- dim_1: Spatial points (e.g., grid cells)
- dim_2: Vertical levels (sequence length, typically 10)
- dim_3: Plant functional types (PFTs, typically 15)
- dim_4: Spectral bands (typically 2)
Feature channels (121 total):
- cosz: 1 channel
- lai: 2 variables × 15 PFTs = 30 channels
- ssa: 2 variables × 2 bands × 15 PFTs = 60 channels
- rs: 1 variable × 2 bands × 15 PFTs = 30 channels
Output channels (120 total):
- 4 variables × 2 bands × 15 PFTs = 120 channels
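The channel totals above can be checked with a few lines of arithmetic. This is only a bookkeeping sketch; the names N_PFT, N_BANDS, and feature_channels are illustrative and are not part of rtnn.

```python
# Channel bookkeeping for the layout described above (illustrative only).
N_PFT = 15    # plant functional types (dim_3)
N_BANDS = 2   # spectral bands (dim_4)

feature_channels = {
    "cosz": 1,                    # single channel
    "lai": 2 * N_PFT,             # 2 variables per PFT
    "ssa": 2 * N_BANDS * N_PFT,   # 2 variables per band and PFT
    "rs": 1 * N_BANDS * N_PFT,    # 1 variable per band and PFT
}
n_features = sum(feature_channels.values())
n_outputs = 4 * N_BANDS * N_PFT   # 4 radiative transfer outputs per band and PFT

print(n_features, n_outputs)  # 121 120
```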
Examples
Basic usage for training:
>>> from rtnn.logger import Logger
>>> logger = Logger()
>>> train_files = ["data_1995.nc", "data_1996.nc", "data_1997.nc"]
>>> dataset = DataPreprocessor(
... logger=logger,
... dfs=train_files,
... stime=0,
... tbatch=24,
... training=True,
... sblock_perc=0.6,
... norm_mapping=norm_stats,
... normalization_type=norm_types
... )
>>> features, targets = dataset[0]
>>> features.shape
torch.Size([158, 121, 10])
>>> targets.shape
torch.Size([158, 120, 10])
For validation/testing:
>>> val_dataset = DataPreprocessor(
... logger=logger,
... dfs=val_files,
... stime=0,
... tbatch=24,
... training=False,
... norm_mapping=norm_stats,
... normalization_type=norm_types
... )
See also
rtnn.models: Contains the neural network architectures for radiative transfer
rtnn.evaluater: Provides metrics and loss functions for model evaluation
rtnn.diagnostics: Visualization tools for model predictions
- class rtnn.dataset.DataPreprocessor(*args: Any, **kwargs: Any)[source]
Bases: Dataset
Dataset class for preprocessing LSM (Land Surface Model) data.
This class handles loading and preprocessing of NetCDF files containing climate data, with support for multiple years, spatial and temporal batching, and various normalization techniques.
- Parameters:
logger (object) – Logger instance for logging messages.
dfs (List[str]) – List of file paths to NetCDF files.
stime (int) – Start time index.
tstep (int) – Number of time steps per file.
tbatch (int) – Temporal batch size.
norm_mapping (Dict, optional) – Dictionary containing normalization statistics for each variable. Default is empty dict.
normalization_type (Dict, optional) – Dictionary specifying normalization type for each variable. Default is empty dict.
- norm_mapping
Normalization statistics.
- Type:
Dict
- normalization_type
Normalization types per variable.
- Type:
Dict
- time_blocks
Shuffled time blocks.
- Type:
np.ndarray
Examples
>>> from rtnn.logger import Logger
>>> logger = Logger()
>>> files = ["data_1995.nc", "data_1996.nc"]
>>> dataset = DataPreprocessor(
... logger=logger,
... dfs=files,
... stime=0,
... tstep=100,
... tbatch=24,
... norm_mapping={},
... normalization_type={}
... )
>>> len(dataset)
100
>>> features, targets = dataset[0]
>>> features.shape
torch.Size([schunk, feature_channels, seq_length])
>>> targets.shape
torch.Size([schunk, output_channels, seq_length])
- __init__(logger: Any, dfs: List[str], stime: int, tbatch: int, training: bool = True, sblock_perc: float = 0.6, norm_mapping: Dict = {}, normalization_type: Dict = {}, debug: bool = False) None[source]
Initialize the DataPreprocessor.
- Parameters:
logger (Any) – Logger instance for logging messages.
dfs (List[str]) – List of file paths to NetCDF files.
stime (int) – Start time index.
tbatch (int) – Temporal batch size.
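training (bool, optional) – If True, use a random subset of the spatial batches (data augmentation); the fraction is set by sblock_perc and is 60% by default. If False, use 100% of spatial batches (full evaluation). Default is True.
sblock_perc (float, optional) – Fraction of spatial batches used when training is True. Default is 0.6.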
norm_mapping (Dict, optional) – Dictionary containing normalization statistics for each variable.
normalization_type (Dict, optional) – Dictionary specifying normalization type for each variable.
debug (bool, optional) – If True, print debug information.
- normalize(data: numpy.ndarray, var_name: str) numpy.ndarray[source]
Normalize data using the specified normalization method.
- Parameters:
data (np.ndarray) – Input data array to normalize.
var_name (str) – Name of the variable for which to retrieve normalization statistics.
- Returns:
Normalized data array.
- Return type:
np.ndarray
- Raises:
ValueError – If the normalization type is not supported.
Notes
Supported normalization types:
- minmax: (x - min) / (max - min)
- standard: (x - mean) / std
- robust: (x - median) / IQR
- log1p_minmax: log1p(x) normalized
- log1p_standard: log1p(x) standardized
- log1p_robust: log1p(x) robust normalized
- sqrt_minmax: sqrt(x) normalized
- sqrt_standard: sqrt(x) standardized
- sqrt_robust: sqrt(x) robust normalized
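The nine types combine an optional transform (log1p or sqrt) with one of three scalers. A minimal NumPy sketch of that dispatch follows. It is not the rtnn implementation: normalize_sketch, RAW_KEYS, and PREFIX are hypothetical names, and the statistic keys are assumed to follow the naming returned by rtnn.diagnostics.stats (vmin, vmean, log_mean, sqrt_std, and so on).

```python
import numpy as np

# Assumed key names for raw statistics; transformed ones use a "log_"/"sqrt_" prefix.
RAW_KEYS = {"min": "vmin", "max": "vmax", "mean": "vmean", "std": "vstd",
            "median": "median", "iqr": "iqr"}
PREFIX = {"": None, "log1p": "log_", "sqrt": "sqrt_"}


def normalize_sketch(data, stats, norm_type):
    """Hypothetical stand-in for DataPreprocessor.normalize."""
    # "log1p_minmax" -> ("log1p", "minmax"); "minmax" -> ("", "minmax")
    transform, _, scaler = norm_type.rpartition("_")
    if transform not in PREFIX or scaler not in ("minmax", "standard", "robust"):
        raise ValueError(f"Unsupported normalization type: {norm_type}")

    x = np.asarray(data, dtype=float)
    if transform == "log1p":
        x = np.log1p(x)
    elif transform == "sqrt":
        x = np.sqrt(np.clip(x, 0.0, None))  # clip negatives before the root

    def key(name):
        # Pick the statistic that matches the applied transform.
        return RAW_KEYS[name] if transform == "" else PREFIX[transform] + name

    if scaler == "minmax":
        return (x - stats[key("min")]) / (stats[key("max")] - stats[key("min")])
    if scaler == "standard":
        return (x - stats[key("mean")]) / stats[key("std")]
    return (x - stats[key("median")]) / stats[key("iqr")]


demo_stats = {"vmin": 0.0, "vmax": 10.0, "log_mean": 1.0, "log_std": 0.5}
print(normalize_sketch(np.array([0.0, 5.0, 10.0]), demo_stats, "minmax"))
```

Splitting the type string into transform and scaler keeps the nine cases down to two small branches; an unknown string raises ValueError, matching the documented behavior.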
rtnn.diagnostics module
Diagnostics and visualization utilities for RTnn radiative transfer model.
This module provides comprehensive visualization tools for analyzing and evaluating neural network emulators of atmospheric radiative transfer. It includes functions for creating diagnostic plots of model predictions, training histories, and data statistics.
The module is designed to support:
- Comparison of model predictions against ground truth targets
- Visualization of radiative fluxes (direct/diffuse, upwelling/downwelling)
- Absorption rate analysis across vertical levels and spectral bands
- Per-plant functional type (PFT) and per-spectral band diagnostics
- Training and validation metric tracking over epochs
- Spatial and temporal sampling distribution analysis
- Statistical characterization of input variables for normalization
Key visualization types
Hexbin plots: Density scatter plots with color mapping for high-volume data
Line plots: Time series or vertical profile comparisons
Histograms: Distribution analysis with optional log scaling
Marginal histograms: 2D density plots with side distributions
Multi-panel layouts: Systematic comparison across flux channels
The module follows the matplotlib architecture using Figure and FigureCanvasAgg for non-interactive, file-based rendering suitable for batch processing and headless environments.
Notes
All plotting functions are designed to work with PyTorch tensors or numpy arrays and save figures directly to disk. The module uses custom matplotlib parameters optimized for scientific publication quality.
Default figure parameters:
- Font family: DejaVu Sans
- Font sizes: axes labels (15), titles (15), ticks (12), legends (15)
- Line widths: 2
- Tick direction: outward
- Legend: frame off, best location
Examples
Basic usage for generating diagnostic plots:
>>> import torch
>>> from rtnn.diagnostics import plot_flux_and_abs, plot_metric_histories
>>>
>>> # Assuming you have model predictions and targets
>>> predicts = torch.randn(32, 4, 10) # (batch, channels, levels)
>>> targets = torch.randn(32, 4, 10)
>>>
>>> # Create hexbin plot
>>> plot_flux_and_abs(
... predicts=predicts,
... targets=targets,
... filename="diagnostics.png",
... logger=logger
... )
>>>
>>> # Plot training history
>>> train_history = {"nmae": [0.1, 0.08, 0.06], "r2": [0.85, 0.88, 0.91]}
>>> valid_history = {"nmae": [0.12, 0.10, 0.08], "r2": [0.82, 0.85, 0.89]}
>>> plot_metric_histories(
... train_history=train_history,
... valid_history=valid_history,
... filename="metrics.png"
... )
See also
rtnn.dataset.DataPreprocessor: Data loading and preprocessing
rtnn.evaluater: Metrics and loss functions for evaluation
rtnn.models: Neural network architectures for radiative transfer
- rtnn.diagnostics.stats(file_list, logger, output_dir, norm_mapping=None, plots=False)[source]
Compute statistics and generate histograms for variables in NetCDF files.
Reads a collection of NetCDF files, computes descriptive statistics for each variable, and generates histogram plots saved to disk. In addition to raw statistics, transformed statistics using logarithmic (log1p) and square-root transformations are also computed.
- Parameters:
file_list (list of str) – Paths to the NetCDF files to process.
logger (logging.Logger) – Logger used to report progress and informational messages.
output_dir (str) – Directory where histogram plots will be saved.
norm_mapping (dict, optional) – Dictionary to update with computed statistics. If None, a new dictionary is created. Default is None.
plots (bool, optional) – If True, generate and save histogram plots for each variable. Default is False.
- Returns:
Dictionary mapping variable names to their computed statistics. Each variable contains the following entries:
- Raw statistics:
vmin : float
Minimum value
vmax : float
Maximum value
vmean : float
Mean value
vstd : float
Standard deviation
- Robust statistics:
q1 : float
First quartile (25th percentile)
q3 : float
Third quartile (75th percentile)
iqr : float
Interquartile range (q3 - q1)
median : float
Median value
- Log-transformed statistics (log1p):
log_min : float
log_max : float
log_mean : float
log_std : float
log_q1 : float
log_q3 : float
log_iqr : float
log_median : float
- Square-root-transformed statistics:
sqrt_min : float
sqrt_max : float
sqrt_mean : float
sqrt_std : float
sqrt_q1 : float
sqrt_q3 : float
sqrt_iqr : float
sqrt_median : float
- Return type:
dict
Examples
>>> norm_mapping = stats(
... file_list=["data_1995.nc", "data_1996.nc"],
... logger=logger,
... output_dir="./stats",
... plots=True
... )
>>> norm_mapping["coszang"]["vmean"]
0.5
>>> norm_mapping["lai"]["log_median"]
1.23
Notes
Log transformation uses np.log1p (log(1+x)) to handle zero values
Square root transformation uses np.sqrt with clipping to non-negative
Histograms use 200 bins and log-scaled y-axis
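For orientation, the per-variable statistics listed above can be reproduced for a single array with plain NumPy. This is a sketch under assumptions, not the stats() implementation: describe is a hypothetical helper, the standard deviation is taken as the population value (NumPy's default), and the input is assumed to be non-negative so that log1p is well defined.

```python
import numpy as np


def describe(values):
    """Return raw, robust, log1p and sqrt statistics for one variable."""
    values = np.asarray(values, dtype=float).ravel()
    base = ["min", "max", "mean", "std", "q1", "q3", "iqr", "median"]

    def summary(x, names):
        # Quartiles and median in a single percentile call.
        q1, median, q3 = np.percentile(x, [25, 50, 75])
        vals = (x.min(), x.max(), x.mean(), x.std(), q1, q3, q3 - q1, median)
        return dict(zip(names, (float(v) for v in vals)))

    # Raw min/max/mean/std carry a "v" prefix; robust statistics do not.
    out = summary(values, ["vmin", "vmax", "vmean", "vstd",
                           "q1", "q3", "iqr", "median"])
    out.update(summary(np.log1p(values), ["log_" + n for n in base]))
    out.update(summary(np.sqrt(np.clip(values, 0.0, None)),
                       ["sqrt_" + n for n in base]))
    return out


s = describe([0.0, 1.0, 2.0, 3.0, 4.0])
print(s["vmean"], s["iqr"], s["sqrt_max"])  # 2.0 2.0 2.0
```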
- rtnn.diagnostics.subplots(nrows, ncols, figsize)[source]
Create a figure and grid of subplots without pyplot.
This is a replacement for matplotlib.pyplot.subplots() that uses the Figure and FigureCanvasAgg API directly, suitable for headless environments.
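- Parameters:
nrows (int) – Number of subplot rows.
ncols (int) – Number of subplot columns.
figsize (tuple) – Figure size as (width, height) in inches.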
- Returns:
fig (matplotlib.figure.Figure) – The created figure object.
axes (numpy.ndarray) – Array of axes objects with shape (nrows, ncols).
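A pyplot-free helper of this kind can be written with matplotlib's object-oriented API. The sketch below is one plausible way to do it and is not necessarily how rtnn implements it; subplots_sketch is a hypothetical name, and squeeze=False is used so the returned array is always two-dimensional.

```python
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def subplots_sketch(nrows, ncols, figsize):
    """Create a figure and a 2D grid of axes without importing pyplot."""
    fig = Figure(figsize=figsize)
    FigureCanvasAgg(fig)  # attach a non-interactive Agg canvas for off-screen rendering
    axes = fig.subplots(nrows, ncols, squeeze=False)  # always shape (nrows, ncols)
    return fig, axes


fig, axes = subplots_sketch(2, 3, figsize=(9, 6))
print(axes.shape)  # (2, 3)
```

Because no pyplot state is involved, the figure is released as soon as it goes out of scope, which is convenient in long batch jobs.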
- rtnn.diagnostics.plot_flux_and_abs_lines(predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, filename='output_lines.png', logger=None)[source]
Create line plots for fluxes and absorption rates.
Generates a multi-panel figure with line plots for four flux channels and optionally two absorption panels. Each panel shows predictions vs targets across vertical levels.
- Parameters:
predicts (torch.Tensor or np.ndarray) – Model predictions for fluxes of shape (batch_size, 4, seq_length).
targets (torch.Tensor or np.ndarray) – Ground truth fluxes of shape (batch_size, 4, seq_length).
abs12_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 1-2 (direct beam).
abs12_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 1-2 (direct beam).
abs34_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 3-4 (diffuse).
abs34_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 3-4 (diffuse).
filename (str, optional) – Output filename. Default is “output_lines.png”.
logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.
Notes
- Figure layout:
- 2x2 grid for fluxes:
(0,0): Direct upwelling (Flux_direct^u)
(0,1): Direct downwelling (Flux_direct^d)
(1,0): Diffuse upwelling (Flux_diffusion^u)
(1,1): Diffuse downwelling (Flux_diffusion^d)
- Optional 1x2 grid for absorption rates (if provided):
(2,0): Direct absorption (Abs_direct)
(2,1): Diffuse absorption (Abs_diffusion)
The function randomly selects 10 samples from the batch for plotting. Each sample is shown with a different line style.
Examples
>>> plot_flux_and_abs_lines(
... predicts=model_outputs,
... targets=ground_truth,
... filename="line_comparison.png",
... logger=logger
... )
- rtnn.diagnostics.plot_flux_and_abs(predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, filename='output.png', logger=None)[source]
Create hexbin plots for fluxes and absorption rates.
Generates a multi-panel figure with hexbin density plots showing the relationship between predicted and true values. Useful for assessing prediction accuracy across the entire dataset.
- Parameters:
predicts (torch.Tensor or np.ndarray) – Model predictions for fluxes of shape (batch_size, 4, seq_length).
targets (torch.Tensor or np.ndarray) – Ground truth fluxes of shape (batch_size, 4, seq_length).
abs12_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 1-2 (direct beam).
abs12_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 1-2 (direct beam).
abs34_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 3-4 (diffuse).
abs34_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 3-4 (diffuse).
filename (str, optional) – Output filename. Default is “output.png”.
logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.
Notes
Hexbin plots use logarithmic color scale (bins=’log’)
Includes diagonal reference line (y=x) in red dotted style
Displays R² score in the top-left corner of each panel
Shared colorbar on the right with log10(count) label
Grid size of 100 bins for density calculation
The hexbin representation is particularly effective for large datasets where scatter plots would suffer from overplotting.
Examples
>>> plot_flux_and_abs(
... predicts=predictions,
... targets=targets,
... filename="hexbin_diagnostics.png",
... logger=logger
... )
- rtnn.diagnostics.plot_all_diagnostics(predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, n_pft=15, n_bands=2, n_chans=4, output_dir='./results', prefix='diagnostics', logger=None)[source]
Generate all diagnostic plots: aggregated, per PFT, per band.
This function creates a comprehensive set of diagnostic plots including one aggregated plot (all PFTs, all bands combined), per-band plots for each selected PFT (VIS and NIR bands), and line plots and hexbin plots for each combination.
- Parameters:
predicts (torch.Tensor or np.ndarray) – Model predictions for fluxes. Shape: (batch, 4, n_pft, n_bands, seq_length)
targets (torch.Tensor or np.ndarray) – Ground truth targets. Same shape as predicts.
abs12_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 1-2 (direct beam).
abs12_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 1-2.
abs34_predict (torch.Tensor or np.ndarray, optional) – Predicted absorption for channels 3-4 (diffuse).
abs34_target (torch.Tensor or np.ndarray, optional) – True absorption for channels 3-4.
n_pft (int, optional) – Number of Plant Functional Types. Default is 15.
n_bands (int, optional) – Number of spectral bands. Default is 2 (VIS and NIR).
n_chans (int, optional) – Number of output channels. Default is 4.
output_dir (str, optional) – Directory to save diagnostic plots. Default is “./results”.
prefix (str, optional) – Prefix for output filenames. Default is “diagnostics”.
logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.
Notes
The function generates 1 aggregated hexbin plot, 1 aggregated line plot, and, for each selected PFT (up to 8 randomly selected), hexbin and line plots per band (VIS and NIR). Total plots = 2 + (selected_pfts * n_bands * 2), where selected_pfts = min(8, n_pft); with the defaults this is 2 + 8 * 2 * 2 = 34.
Examples
>>> plot_all_diagnostics(
... predicts=model_outputs,
... targets=targets,
... n_pft=15,
... n_bands=2,
... output_dir="./diagnostics",
... prefix="experiment_1"
... )
- rtnn.diagnostics.plot_metric_histories(train_history, valid_history, filename='training_validation_metrics.png', logger=None)[source]
Plot training and validation metrics over epochs.
Creates a multi-panel figure showing the evolution of various metrics (e.g., NMAE, NMSE, R2, bias) over training epochs.
- Parameters:
train_history (dict) – Dictionary with metric names as keys and lists of training values. Example: {“nmae”: [0.1, 0.08, 0.06], “r2”: [0.85, 0.88, 0.91]}
valid_history (dict) – Dictionary with metric names as keys and lists of validation values. Same structure as train_history.
filename (str, optional) – Output filename. Default is “training_validation_metrics.png”.
logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.
Notes
Metrics are plotted on a logarithmic scale for better visualization of convergence
Each metric gets its own panel arranged in a grid (3 columns)
Blue lines: training metrics
Orange lines: validation metrics
Grid is enabled for all panels
Examples
>>> train_history = {"nmae": [0.15, 0.12, 0.09, 0.07],
... "nmse": [0.08, 0.06, 0.04, 0.03]}
>>> valid_history = {"nmae": [0.16, 0.13, 0.10, 0.08],
... "nmse": [0.09, 0.07, 0.05, 0.04]}
>>> plot_metric_histories(
... train_history=train_history,
... valid_history=valid_history,
... filename="metrics.png",
... logger=logger
... )
- rtnn.diagnostics.plot_loss_histories(train_loss, valid_loss, filename='training_validation_loss.png', logger=None)[source]
Plot training and validation loss over epochs.
Creates a single-panel figure showing the loss evolution during training.
- Parameters:
train_loss (list or array) – Training loss values over epochs.
valid_loss (list or array) – Validation loss values over epochs.
filename (str, optional) – Output filename. Default is “training_validation_loss.png”.
logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.
Notes
Uses logarithmic scale for y-axis to visualize exponential decay
Blue line: training loss
Orange line: validation loss
Includes grid for better readability
Both lines share the same x-axis (epoch number)
Examples
>>> train_losses = [0.5, 0.3, 0.2, 0.15, 0.12]
>>> valid_losses = [0.55, 0.35, 0.25, 0.18, 0.15]
>>> plot_loss_histories(
... train_loss=train_losses,
... valid_loss=valid_losses,
... filename="loss_curves.png",
... logger=logger
... )
- rtnn.diagnostics.plot_spatial_temporal_density(sindex_tracker, tindex_tracker, mode='train', save_dir='./tests_plots', filename='density_scatter', figsize=(10, 10), logger=None)[source]
Plot a density scatter plot of spatial index vs temporal index with marginal histograms.
This function creates a 2D density scatter plot (hexbin) showing the distribution of spatial indices (processor ranks) across temporal indices, with:
- Right plot: Histogram of temporal index distribution (horizontal bars)
- Top plot: Histogram of spatial index distribution (vertical bars)
- Parameters:
sindex_tracker (list or array-like) – List of spatial indices (processor ranks) for each data sample.
tindex_tracker (list or array-like) – List of temporal indices for each data sample.
mode (str, optional) – Dataset mode identifier (“train”, “validation”, “test”). Used in filename. Default is “train”.
save_dir (str, optional) – Directory path where the plot will be saved. Default is “./tests_plots”.
filename (str, optional) – Base name for the output file. Default is “density_scatter”.
figsize (tuple, optional) – Figure size as (width, height) in inches. Default is (10, 10).
logger (logging.Logger, optional) – Logger instance for logging messages. If None, no logging is performed.
- Returns:
Path to the saved plot file, or None if no data to plot.
- Return type:
str or None
Notes
The plot layout uses GridSpec with the following structure:
- Top-left: Histogram of spatial index distribution (vertical bars)
- Top-right: Empty (no plot)
- Bottom-left: Main hexbin density scatter plot
- Bottom-right: Histogram of temporal index distribution (horizontal bars)
A colorbar is placed below the main plot with log-scaled counts.
This visualization is useful for:
- Verifying uniform sampling across spatial processors
- Checking temporal coverage of the dataset
- Detecting sampling biases in training/validation splits
Examples
>>> # After training, check sampling distribution
>>> plot_spatial_temporal_density(
... sindex_tracker=dataset.sindex_tracker,
... tindex_tracker=dataset.tindex_tracker,
... mode="train",
... save_dir="./analysis",
... logger=logger
... )
rtnn.evaluater module
Evaluation utilities for RTnn model assessment.
This module provides comprehensive evaluation tools for radiative transfer neural network models, including custom loss functions, metric computation, and visualization helpers. It is designed to support both training diagnostics and rigorous model validation against physical constraints.
The module implements several key capabilities:
- Custom Loss Functions
Normalized losses (NMAE, NMSE) for scale-invariant error measurement
Standard losses (MSE, MAE, Huber, Smooth L1) for baseline comparison
Weighted and physics-informed losses for multi-objective optimization
- Evaluation Metrics
NMAE: Normalized Mean Absolute Error
NMSE: Normalized Mean Squared Error
R²: Coefficient of determination
MBE: Mean Bias Error
MARE: Mean Absolute Relative Error
GMRAE: Geometric Mean Relative Absolute Error
- Physical Consistency
Absorption rate calculation from flux divergence
Energy conservation penalty (albedo + transmittance + absorptance = 1)
Heating rate computation from net flux profiles
- Data Handling
Normalization/de-normalization for all supported transformation types
Multi-dimensional tensor reshaping (batch, channels, PFTs, bands, levels)
Metric tracking with running statistics
The module follows a modular design where loss functions and metrics are implemented as separate callable classes/functions, allowing easy extension and composition.
Notes
Flux Variable Ordering
Channel 0: collimated albedo (direct upwelling)
Channel 1: collimated transmittance (direct downwelling)
Channel 2: isotropic albedo (diffuse upwelling)
Channel 3: isotropic transmittance (diffuse downwelling)
This ordering matches the ov list in rtnn.dataset.DataPreprocessor.
Absorption Calculation
For collimated: absorption = -d(net_flux)/dz
For isotropic: absorption = -d(net_flux)/dz
Supported Normalization Types
Linear: minmax, standard, robust
Log1p-based: log1p_minmax, log1p_standard, log1p_robust
Sqrt-based: sqrt_minmax, sqrt_standard, sqrt_robust
Examples
Basic usage for model evaluation:
>>> import argparse
>>> import torch
>>> from rtnn.evaluater import get_loss_function, run_validation
>>>
>>> # Create loss function
>>> args = argparse.Namespace(loss_type='huber', beta_delta=1.0)
>>> criterion = get_loss_function('huber', args)
>>>
>>> # Evaluate model
>>> valid_loss, metrics = run_validation(
... loader=val_loader,
... model=my_model,
... norm_mapping=norm_stats,
... normalization_type=norm_types,
... index_mapping=idxmap,
... device=device,
... args=args,
... epoch=10,
... logger=logger
... )
>>>
>>> print(f"Validation NMAE: {metrics['fluxes_NMAE']:.4f}")
>>> print(f"R² score: {metrics['fluxes_R2']:.4f}")
Using custom metric tracking:
>>> from rtnn.evaluater import MetricTracker, nmae_all
>>>
>>> tracker = MetricTracker()
>>> for batch in dataloader:
... pred, target = model(batch)
... count, value = nmae_all(pred, target)
... tracker.update(value.item(), count)
>>>
>>> mean_nmae = tracker.getmean()
See also
rtnn.dataset.DataPreprocessor: Data loading and normalization
rtnn.diagnostics: Visualization tools for model predictions
rtnn.models: Neural network architectures
- class rtnn.evaluater.NMSELoss(*args: Any, **kwargs: Any)[source]
Bases: Module
Normalized Mean Squared Error Loss.
Computes MSE normalized by the mean square of the target values. Useful when the scale of the target variable varies across samples or when comparing models trained on different datasets.
- Parameters:
eps (float, optional) – Small constant for numerical stability. Default is 1e-8.
Notes
The loss is calculated as:
NMSE = MSE(pred, target) / (mean(target²) + eps)
This normalization makes the loss scale-invariant, with values typically in the range [0, 1].
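The formula can be checked by hand with a small NumPy sketch. The function nmse below is a hypothetical stand-in that follows the stated formula; it is not the NMSELoss class itself.

```python
import numpy as np


def nmse(pred, target, eps=1e-8):
    """NMSE = MSE(pred, target) / (mean(target^2) + eps)."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    mse = np.mean((pred - target) ** 2)
    return float(mse / (np.mean(target ** 2) + eps))


# MSE = (1 + 0) / 2 = 0.5 and mean(target^2) = 4, so NMSE = 0.125.
print(round(nmse([1.0, 2.0], [2.0, 2.0]), 4))  # 0.125

# Rescaling predictions and targets together leaves the value essentially unchanged.
print(round(nmse([10.0, 20.0], [20.0, 20.0]), 4))  # 0.125
```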
Examples
>>> criterion = NMSELoss()
>>> predictions = torch.tensor([[2.0, 3.0], [1.0, 2.0]])
>>> targets = torch.tensor([[2.0, 4.0], [1.0, 2.5]])
>>> loss = criterion(predictions, targets)
>>> print(round(loss.item(), 4)) # MSE = 0.3125, mean(target²) = 6.8125
0.0459
- class rtnn.evaluater.NMAELoss(*args: Any, **kwargs: Any)[source]
Bases: Module
Normalized Mean Absolute Error Loss.
Computes MAE normalized by the mean absolute value of the target. Provides a scale-invariant error metric that is more robust to outliers than NMSE.
- Parameters:
eps (float, optional) – Small constant for numerical stability. Default is 1e-8.
Notes
Values typically range from 0 to 1, with 0 representing perfect predictions and values >1 indicating predictions worse than the trivial zero predictor.
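The two reference points mentioned above (0 for a perfect prediction, 1 for the all-zero predictor) follow directly from the definition. The sketch below assumes NMAE = mean(|pred - target|) / (mean(|target|) + eps), as described in the text; nmae is a hypothetical helper, not the NMAELoss class.

```python
import numpy as np


def nmae(pred, target, eps=1e-8):
    """MAE normalized by the mean absolute value of the target."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    return float(np.mean(np.abs(pred - target)) / (np.mean(np.abs(target)) + eps))


target = np.array([2.0, 4.0, 1.0, 2.5])
print(round(nmae(target, target), 4))                 # 0.0, perfect prediction
print(round(nmae(np.zeros_like(target), target), 4))  # 1.0, trivial zero predictor
```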
Examples
>>> criterion = NMAELoss()
>>> predictions = torch.tensor([[2.0, 3.0], [1.0, 2.0]])
>>> targets = torch.tensor([[2.0, 4.0], [1.0, 2.5]])
>>> loss = criterion(predictions, targets)
>>> print(round(loss.item(), 4)) # MAE = 0.375, mean(|target|) = 2.375
0.1579
- rtnn.evaluater.physics_loss(pred: torch.Tensor, target: torch.Tensor, conservation_penalty: torch.Tensor | None = None, lambda_phys: float = 0.1, delta: float = 1.0) torch.Tensor[source]
Combined Huber loss + energy conservation penalty (improvement E).
- The four output variables satisfy:
albedo + transmittance + absorptance = 1
For collimated: collim_alb + collim_tran + collim_abs = 1
For isotropic: isotrop_alb + isotrop_tran + isotrop_abs = 1
This function enforces the constraint as a soft penalty. You can pass pred_abs if your model also predicts absorptance; otherwise the penalty is computed implicitly as (1 - alb - tran).
- Parameters:
pred (torch.Tensor shape (B, 4, L)) – Model predictions: [collim_alb, collim_tran, isotrop_alb, isotrop_tran] (channel order matches your ov list in DataPreprocessor).
target (torch.Tensor shape (B, 4, L)) – Ground-truth targets.
pred_abs (torch.Tensor or None shape (B, 2, L)) – If provided, predicted absorptance for [collimated, isotropic]. If None, absorptance is inferred as (1 - alb - tran).
lambda_phys (float) – Weight of the energy conservation penalty relative to Huber loss.
delta (float) – Huber loss delta parameter.
- Returns:
Scalar loss value.
- Return type:
torch.Tensor
- class rtnn.evaluater.MetricTracker[source]
Bases: object
A utility class for tracking and computing statistics of metric values.
This class maintains running sums of metric values and their squares, allowing incremental updates and computation of mean and standard deviation. It is particularly useful for aggregating metrics across multiple batches during evaluation.
Examples
>>> tracker = MetricTracker()
>>> tracker.update(10.0, 5) # value=10.0, count=5 samples
>>> tracker.update(20.0, 3) # value=20.0, count=3 samples
>>> print(tracker.getmean()) # (10*5 + 20*3) / (5+3) = 110/8 = 13.75
13.75
>>> print(round(tracker.getstd(), 4)) # sqrt(212.5 - 13.75^2), if squares are weighted by count like the values
4.8412
>>> print(tracker.getsqrtmean())
3.7080992435478315
- getmean()[source]
Calculate the mean of all tracked values.
- Returns:
Weighted mean of all values: total_value / total_count
- Return type:
float
- Raises:
ZeroDivisionError – If no values have been added (count == 0)
- getstd()[source]
Calculate the standard deviation of all tracked values.
- Returns:
Weighted standard deviation of all values: sqrt(E(x^2) - (E(x))^2)
- Return type:
float
- Raises:
ZeroDivisionError – If no values have been added (count == 0)
- getsqrtmean()[source]
Calculate the square root of the mean of all tracked values.
- Returns:
Square root of the weighted mean: sqrt(total_value / total_count)
- Return type:
float
- Raises:
ZeroDivisionError – If no values have been added (count == 0)
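The running-sum idea behind these three methods can be sketched in a few lines. RunningTracker below is a hypothetical class, not the rtnn MetricTracker: it assumes that update(value, count) adds value * count to the running sum and value² * count to the running sum of squares, which matches the weighted mean described above but is only an assumption for the standard deviation.

```python
import math


class RunningTracker:
    """Hypothetical count-weighted tracker of a batch-level metric."""

    def __init__(self):
        self.count = 0       # total number of samples seen
        self.total = 0.0     # sum of value * count
        self.total_sq = 0.0  # sum of value^2 * count (assumed weighting)

    def update(self, value, count):
        self.count += count
        self.total += value * count
        self.total_sq += value * value * count

    def mean(self):
        # Raises ZeroDivisionError when nothing has been added.
        return self.total / self.count

    def std(self):
        m = self.mean()
        # sqrt(E(x^2) - (E(x))^2), clamped at zero against round-off.
        return math.sqrt(max(self.total_sq / self.count - m * m, 0.0))


tracker = RunningTracker()
tracker.update(10.0, 5)
tracker.update(20.0, 3)
print(tracker.mean())  # 13.75
```

Keeping only three accumulators means per-batch values never need to be stored, so the tracker's memory use is constant regardless of the number of batches.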
- rtnn.evaluater.get_loss_function(loss_type, args, logger=None)[source]
Factory function to instantiate the requested loss function.
- Parameters:
loss_type (str) –
Type of loss function. Options:
- Standard losses:
’mse’: Mean Squared Error
’mae’: Mean Absolute Error
- Normalized losses:
’nmae’: Normalized Mean Absolute Error
’nmse’: Normalized Mean Squared Error
- Robust losses:
’smoothl1’: Smooth L1 Loss (Huber with beta)
’huber’: Huber Loss with delta parameter
args (argparse.Namespace) – Arguments containing loss-specific parameters:
- For ‘huber’/’smoothl1’: requires args.beta_delta
- For composite losses: may require args.beta
logger (logging.Logger, optional) – Logger for informational messages. If None, no logging occurs.
- Returns:
Initialized loss function.
- Return type:
- Raises:
ValueError – If loss_type is not supported or required parameters are missing.
Examples
>>> import argparse
>>> args = argparse.Namespace(beta_delta=1.0)
>>> criterion = get_loss_function('huber', args)
>>> loss = criterion(predictions, targets)
>>> args = argparse.Namespace()
>>> criterion = get_loss_function('mse', args)
- rtnn.evaluater.mse_all(pred, true)[source]
Compute Mean Squared Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, mse_value)
- Return type:
tuple
- rtnn.evaluater.mbe_all(pred, true)[source]
Compute Mean Bias Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, mbe_value)
- Return type:
tuple
- rtnn.evaluater.mae_all(pred, true)[source]
Compute Mean Absolute Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, mae_value)
- Return type:
tuple
- rtnn.evaluater.r2_all(pred, true)[source]
Calculate R2 (coefficient of determination) between predicted and true values.
Computes the R2 metric and returns both the number of elements and the R2 value.
- Parameters:
pred (torch.Tensor) – Predicted values from the model
true (torch.Tensor) – Ground truth values
- Returns:
(num_elements, r2_value) where:
- num_elements (int): Total number of elements in the tensors
- r2_value (torch.Tensor): R2 score
- Return type:
tuple
Notes
R2 is calculated as:
R2 = 1 - sum((true - pred)^2) / sum((true - mean(true))^2)
This implementation is fully torch-based and works on CPU and GPU.
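A NumPy version of the same formula makes the arithmetic easy to verify by hand. This is an illustrative sketch, not the torch-based r2_all; the helper name r2 is hypothetical.

```python
import numpy as np


def r2(pred, true):
    """Return (num_elements, R2) following the formula above."""
    pred = np.asarray(pred, dtype=float)
    true = np.asarray(true, dtype=float)
    ss_res = np.sum((true - pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((true - true.mean()) ** 2)  # total sum of squares
    return true.size, float(1.0 - ss_res / ss_tot)


# ss_res = 1 and ss_tot = 2, so R2 = 0.5.
n, score = r2([1.0, 2.0, 4.0], [1.0, 2.0, 3.0])
print(n, score)  # 3 0.5
```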
- rtnn.evaluater.nmae_all(pred, true)[source]
Compute Normalized Mean Absolute Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, nmae_value)
- Return type:
tuple
- rtnn.evaluater.nmse_all(pred, true)[source]
Compute Normalized Mean Squared Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, nmse_value)
- Return type:
tuple
- rtnn.evaluater.mare_all(pred, true)[source]
Compute Mean Absolute Relative Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, mare_value)
- Return type:
tuple
- rtnn.evaluater.gmrae_all(pred, true)[source]
Compute Geometric Mean Relative Absolute Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, gmrae_value)
- Return type:
tuple
- rtnn.evaluater.unnorm_mpas(pred, targ, norm_mapping, normalization_type, idxmap)[source]
Reverse normalization for 5D tensors.
This function converts normalized predictions and targets back to physical units using stored normalization statistics. It supports all normalization types defined in DataPreprocessor.
- Parameters:
pred (torch.Tensor) – Normalized predictions. Shape: (batch, 4, n_pft, n_bands, seq_length)
targ (torch.Tensor) – Normalized targets. Same shape as pred.
norm_mapping (dict) – Dictionary containing normalization statistics for each variable.
normalization_type (dict) – Dictionary specifying normalization type for each variable.
idxmap (dict) – Mapping from channel indices (0-3) to variable names.
- Returns:
(upred, utarg) where: - upred (torch.Tensor): Unnormalized predictions - utarg (torch.Tensor): Unnormalized targets
- Return type:
tuple
- Raises:
ValueError – If normalization type is not supported.
Notes
- The function handles the following transformations:
Linear: x_norm = (x - mean)/std or x_norm = (x - min)/(max - min)
Log1p: x_norm = (log(1+x) - log_mean)/log_std
Sqrt: x_norm = (sqrt(x) - sqrt_mean)/sqrt_std
The reverse operation is applied to recover physical values.
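The inverse of each forward transform listed above can be sketched as follows. This is an illustrative, scalar, plain-Python version: the function name `unnormalize`, the statistics keys (`mean`, `std`, `min`, `max`, `log_mean`, `log_std`, `sqrt_mean`, `sqrt_std`) and the type label `sqrt_standard` are assumptions, not the package's actual identifiers.

```python
import math


def unnormalize(x_norm, stats, kind):
    """Invert one of the forward transforms for a single normalized value."""
    if kind == "standard":
        # x_norm = (x - mean) / std
        return x_norm * stats["std"] + stats["mean"]
    if kind == "minmax":
        # x_norm = (x - min) / (max - min)
        return x_norm * (stats["max"] - stats["min"]) + stats["min"]
    if kind == "log1p_standard":
        # x_norm = (log(1 + x) - log_mean) / log_std
        return math.expm1(x_norm * stats["log_std"] + stats["log_mean"])
    if kind == "sqrt_standard":  # hypothetical label for the sqrt transform
        # x_norm = (sqrt(x) - sqrt_mean) / sqrt_std
        return (x_norm * stats["sqrt_std"] + stats["sqrt_mean"]) ** 2
    raise ValueError(f"Unsupported normalization type: {kind}")


# Round trip for the log1p case: normalize, then recover the physical value.
demo_stats = {"log_mean": 0.1, "log_std": 0.5}
x = 0.3
x_norm = (math.log1p(x) - demo_stats["log_mean"]) / demo_stats["log_std"]
print(unnormalize(x_norm, demo_stats, "log1p_standard"))
```

Using `math.expm1` rather than `exp(...) - 1` keeps the inverse of `log1p` accurate for values close to zero.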
- rtnn.evaluater.conservation_residual(alb, tran, abs_flux)[source]
Compute energy conservation residual for a given level.
- Parameters:
alb (torch.Tensor) – Albedo (upwelling flux). Shape (batch, 1, n_pft, n_bands, seq)
tran (torch.Tensor) – Transmittance (downwelling flux). Same shape as alb.
abs_flux (torch.Tensor) – Absorptance (absorbed flux). Shape (batch, 1, n_pft, n_bands, seq-1)
- Returns:
Squared conservation residual. Shape (batch, 1, n_pft, n_bands, seq-1)
- Return type:
torch.Tensor
Notes
- The function averages alb and tran to layer centers before computing:
residual = (alb_center + tran_center + abs_flux - 1)²
- This enforces the physical constraint that:
upwelling + downwelling + absorption = total incoming radiation = 1
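The residual described above can be sketched for a single vertical profile. This plain-Python version assumes that "averaging to layer centers" means the arithmetic mean of two adjacent levels; the name `conservation_residual_sketch` and the list inputs are illustrative, whereas the library function works on 5D torch tensors.

```python
def conservation_residual_sketch(alb, tran, abs_flux):
    """Squared conservation residual per layer for one profile.

    alb and tran hold one value per level (length seq);
    abs_flux holds one value per layer (length seq - 1).
    """
    if len(alb) != len(tran) or len(abs_flux) != len(alb) - 1:
        raise ValueError("expected len(alb) == len(tran) == len(abs_flux) + 1")
    residuals = []
    for k in range(len(abs_flux)):
        # Assumption: layer-center value is the mean of the bounding levels.
        alb_center = 0.5 * (alb[k] + alb[k + 1])
        tran_center = 0.5 * (tran[k] + tran[k + 1])
        residuals.append((alb_center + tran_center + abs_flux[k] - 1.0) ** 2)
    return residuals


print(conservation_residual_sketch([0.2, 0.4], [0.6, 0.4], [0.3]))
```

The residual vanishes only when the three terms sum to one in every layer, so it can be used directly as a penalty term.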
- rtnn.evaluater.calc_abs(pred, targ, p=None)[source]
Calculate absorption rates from flux predictions.
Computes absorption rates for both collimated (direct) and isotropic (diffuse) components by calculating the divergence of net flux.
- Parameters:
pred (torch.Tensor) – Predicted fluxes. Shape (batch, 4, n_pft, n_bands, seq_length) Channel order: [collim_alb, collim_tran, isotrop_alb, isotrop_tran]
targ (torch.Tensor) – Target fluxes. Same shape as pred.
p (torch.Tensor, optional) – Pressure levels. If provided, computes heating rates. Shape can be (seq_length,) or (batch, seq_length). Default is None.
- Returns:
(abs12_pred, abs12_targ, abs34_pred, abs34_targ, conservation_penalty)
- Return type:
tuple
- rtnn.evaluater.heating_rate(up, down, p=None)[source]
Calculate heating rate from upwelling and downwelling fluxes.
- Parameters:
up (torch.Tensor) – Upwelling flux. Shape (batch, 1, n_pft, n_bands, seq_length)
down (torch.Tensor) – Downwelling flux. Same shape as up.
p (torch.Tensor, optional) – Pressure levels. If provided, computes heating rate in K/day. If None, computes negative flux divergence. Default is None.
- Returns:
torch.Tensor – Heating rate or flux divergence. Shape (batch, 1, n_pft, n_bands, seq_length-1)
Notes
The calculation involves:
net = up - down (net flux)
dnet = net - roll(net, 1) (vertical flux divergence)
If p is provided, the divergence is converted to a heating rate using:
hr = -dnet/dp * (g * 8.64e4) / (cp * 100)
where g = 9.8066 m/s² and cp ≈ 1004 J/(kg·K). The factor 8.64e4 is the number of seconds per day and converts K/s to K/day; the division by 100 is consistent with pressure being supplied in hPa rather than Pa.
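As a rough illustration of the calculation above, the following plain-Python sketch handles one vertical profile. It assumes that the wrapped-around first element produced by the roll is discarded (which matches the documented output length of seq_length-1) and that dp is differenced the same way as the net flux; the name `heating_rate_sketch` is hypothetical.

```python
G = 9.8066               # gravitational acceleration, m/s^2
CP = 1004.0              # specific heat of dry air, J/(kg K)
SECONDS_PER_DAY = 8.64e4


def heating_rate_sketch(up, down, p=None):
    """Heating rate, or negative flux divergence when p is None."""
    net = [u - d for u, d in zip(up, down)]
    # Difference between each level and the previous one; the first
    # (wrapped) element of a roll-based difference is dropped.
    dnet = [net[k] - net[k - 1] for k in range(1, len(net))]
    if p is None:
        return [-d for d in dnet]
    dp = [p[k] - p[k - 1] for k in range(1, len(p))]
    factor = (G * SECONDS_PER_DAY) / (CP * 100.0)
    return [-d / q * factor for d, q in zip(dnet, dp)]


up = [0.1, 0.2, 0.3]
down = [1.0, 0.8, 0.5]
print(heating_rate_sketch(up, down))
print(heating_rate_sketch(up, down, p=[1000.0, 900.0, 800.0]))
```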
- rtnn.evaluater.run_validation(loader, model, norm_mapping, normalization_type, index_mapping, device, args, epoch, logger=None, base_dir='./results', n_pft=15, n_bands=2, n_chans=4)[source]
Evaluate model accuracy on LSM dataset.
Performs comprehensive evaluation including:
Loss computation for main fluxes and absorption rates
Metric calculation (NMAE, NMSE, R²) for fluxes and absorption
Optional plotting of predictions vs targets (every 10 epochs)
Energy conservation verification
- Parameters:
loader (torch.utils.data.DataLoader) – Data loader for evaluation dataset.
model (torch.nn.Module) – Trained model to evaluate.
norm_mapping (dict) – Normalization statistics for variables.
normalization_type (dict) – Normalization types per variable.
index_mapping (dict) – Mapping from channel indices (0-3) to variable names.
device (torch.device) – Device to run evaluation on (cuda or cpu).
args (argparse.Namespace) – Arguments containing loss_type, beta, beta_delta, and num_epochs.
epoch (int) – Current epoch number (for plotting schedule).
logger (logging.Logger, optional) – Logger for informational messages. If None, no logging occurs.
base_dir (str, optional) – Directory to save diagnostic plots. Default is “./results”.
n_pft (int, optional) – Number of Plant Functional Types. Default is 15.
n_bands (int, optional) – Number of spectral bands. Default is 2 (VIS, NIR).
n_chans (int, optional) – Number of output channels. Default is 4.
- Returns:
(valid_loss, valid_metrics)
- Return type:
tuple
Notes
The evaluation performs the following steps:
1. Iterates through the validation loader
2. Computes predictions and reshapes them to 5D tensors
3. De-normalizes predictions and targets to physical units
4. Calculates absorption rates and conservation penalties
5. Computes metrics for fluxes and absorption
6. Optionally generates diagnostic plots (epoch % 10 == 0 or final epoch)
- The combined loss includes both flux and absorption terms weighted by β:
total_loss = (1-β)*loss_fluxes + β*(loss_abs12 + loss_abs34)
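The weighting above amounts to a simple convex blend of the flux loss with the sum of the two absorption losses. A minimal sketch, assuming β lies in [0, 1] (the source does not state a range) and using a hypothetical helper name:

```python
def combined_loss(loss_fluxes, loss_abs12, loss_abs34, beta):
    """Blend flux and absorption losses; beta = 0 ignores absorption terms."""
    if not 0.0 <= beta <= 1.0:
        # Assumption: beta is a weight between 0 and 1.
        raise ValueError("beta is expected to lie in [0, 1]")
    return (1.0 - beta) * loss_fluxes + beta * (loss_abs12 + loss_abs34)


print(combined_loss(0.2, 0.1, 0.3, beta=0.25))
```

With β = 0 the objective reduces to the flux loss alone, and with β = 1 only the absorption terms contribute.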
Examples
>>> valid_loss, metrics = run_validation(
... loader=val_loader,
... model=my_model,
... norm_mapping=norm_stats,
... normalization_type=norm_types,
... index_mapping=idxmap,
... device=torch.device('cuda'),
... args=args,
... epoch=10,
... logger=logger
... )
>>> print(f"Validation NMAE: {metrics['fluxes_NMAE']:.4f}")
>>> print(f"R² score: {metrics['fluxes_R2']:.4f}")
rtnn.logger module
Logging utilities for RTnn model training and evaluation.
This module provides a rich, customizable logging system for radiative transfer neural network workflows. It combines Python’s standard logging with the Rich library to deliver visually appealing, informative console output while maintaining file-based logging for archival purposes.
The Logger class offers multiple log levels and specialized formatting for different types of messages:
Info messages: General information (cyan)
Warning messages: Non-critical issues (yellow with warning emoji)
Success messages: Successful completions (green with checkmark)
Error messages: Critical failures (red panels)
Exception messages: Full traceback displays (multi-panel format)
Step messages: Pipeline stage transitions (magenta)
Task start messages: Formatted task initiation banners
Key features include:
Dual output: console (Rich-formatted) and file (plain text)
Progress bars for long-running operations
Structured task tracking with start/end logging
Metrics aggregation and display
Exception traceback visualization with code snippets
Configurable output (console/file/both)
Pretty printing support
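The dual-output idea (console plus plain-text file) can be sketched with the standard library alone. This is not the Rich-based Logger described here: the function name `build_logger_sketch`, its parameters, and the message format are assumptions, and Rich styling, progress bars, and the success/step message types are omitted.

```python
import logging
import sys


def build_logger_sketch(name, log_file=None, console_output=True):
    """Return a logger that writes to stdout and, optionally, to a file."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.handlers.clear()   # avoid duplicate handlers on repeated calls
    logger.propagate = False  # keep messages out of the root logger
    formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
    if console_output:
        console = logging.StreamHandler(sys.stdout)
        console.setFormatter(formatter)
        logger.addHandler(console)
    if log_file is not None:
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger


demo_logger = build_logger_sketch("rtnn_demo")
demo_logger.info("Training started")
demo_logger.warning("Learning rate is unusually high")
```

Attaching separate handlers keeps the two outputs independent, so either one can be switched off without touching the call sites.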
- class rtnn.logger.Logger(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]
Bases:
object
- __init__(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]
- start_task(task_name: str, description: str = '', **meta)[source]
Display a clearly formatted ‘task start’ message with good spacing.
rtnn.main module
RTnn (Radiative Transfer Neural Network) Training Pipeline
This module provides the main entry point for training neural network models for radiative transfer calculations in climate modeling. It supports various model architectures including LSTM, GRU, Transformer, and FCN.
The training pipeline includes:
Data loading and preprocessing from NetCDF files
Model initialization and configuration
Training loop with progress tracking
Validation and metric computation
Checkpoint saving and model persistence
Visualization and logging
Module Overview
This module implements a complete machine learning pipeline for radiative transfer modeling in climate science. The pipeline is designed to handle spatio-temporal data from NetCDF files, apply appropriate normalization, train various neural network architectures, and evaluate model performance using physics-informed metrics.
Key Features
Data Handling
Automatic reading of NetCDF files with year-based filtering
Support for multiple normalization schemes (log1p_standard, standard, minmax)
Temporal batching and sequence generation
Multi-year training and held-out year testing
Model Architectures
LSTM (Long Short-Term Memory)
GRU (Gated Recurrent Unit)
Transformer with attention mechanism
FCN (Fully Connected Networks)
MLP with residual connections
Custom RT-specific architectures
Training Features
Configurable loss functions (MSE, MAE, NMSE, NMAE, SmoothL1, Huber)
Physics-informed weighted loss combining flux and absorption terms
Learning rate scheduling with ReduceLROnPlateau
Multi-GPU support with DataParallel
Checkpoint saving and resumption
Evaluation Metrics
NMAE (Normalized Mean Absolute Error)
NMSE (Normalized Mean Squared Error)
R² (Coefficient of Determination)
Conservation penalty for physical consistency
Output and Visualization
TensorBoard integration for real-time monitoring
Training/validation loss plots
Metric history visualization
Spatial-temporal density scatter plots
Automatic checkpoint saving (best, epoch, final)
Data Flow
Parse command-line arguments (parse_args)
Setup directory structure and logging (setup_directories_and_logging)
Configure device and random seeds (setup_device_and_seed)
Load and preprocess data (get_data_files, create_datasets_and_loaders)
Compute normalization statistics (create_normalization_mapping)
Initialize model architecture (initialize_model)
Load checkpoint if resuming (load_checkpoint_if_requested)
Run inference or training loop
Generate plots and save results
- rtnn.main.parse_years(year_str)[source]
Parse a year string into a list of integers.
Supports hyphen-separated ranges (e.g., “1995-1999”) and comma-separated lists (e.g., “1995,1997,1999”). Returns a list of integers.
- Parameters:
year_str (str) – String containing years in range or comma-separated format.
- Returns:
List of parsed years.
- Return type:
list
Examples
>>> parse_years("1995-1999")
[1995, 1996, 1997, 1998, 1999]
>>> parse_years("1995,1997,1999")
[1995, 1997, 1999]
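One way such a parser could be written is sketched below. It reproduces the documented behavior for the two input forms but is not the package's source; the name `parse_years_sketch` is hypothetical, and mixed inputs such as "1995-1997,1999" are deliberately not handled.

```python
def parse_years_sketch(year_str):
    """Parse "1995-1999" or "1995,1997,1999" into a list of ints."""
    year_str = year_str.strip()
    if "-" in year_str:
        # Inclusive range: "start-end"
        start, end = (int(part) for part in year_str.split("-"))
        return list(range(start, end + 1))
    # Comma-separated list; a single year also ends up here.
    return [int(part) for part in year_str.split(",")]


print(parse_years_sketch("1995-1999"))
print(parse_years_sketch("1995,1997,1999"))
```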
- rtnn.main.parse_args()[source]
Parse command-line arguments for RTnn model training.
Defines and parses command-line arguments required to configure and run the Radiative Transfer Neural Network (RTnn) training pipeline. This includes model architecture parameters, training hyperparameters, data configuration, and output settings.
- Returns:
Object containing parsed command-line arguments, grouped as follows:
- Model architecture
type (str) – Model type (e.g., “lstm”, “gru”, “fcn”, “fullyconnected”, “transformer”, “cnn”, “mlp”).
hidden_size (int) – Size of hidden layers.
num_layers (int) – Number of model layers.
seq_length (int) – Length of input sequence.
feature_channel (int) – Number of input feature channels.
output_channel (int) – Number of output channels.
embed_size (int) – Embedding dimension for transformer models.
nhead (int) – Number of attention heads (transformer).
forward_expansion (int) – Expansion factor for feed-forward layers.
dropout (float) – Dropout rate.
- Training hyperparameters
batch_size (int) – Number of samples per batch.
tbatch (int) – Temporal batch length.
num_epochs (int) – Number of training epochs.
learning_rate (float) – Initial learning rate.
loss_type (str) – Loss function (e.g., “mse”, “mae”, “nmae”, “nmse”, “wmse”, “logcosh”, “smoothl1”, “huber”).
beta (float) – Weighting factor for loss components.
beta_delta (float) – Delta parameter for Huber or SmoothL1 loss.
num_workers (int) – Number of data loader workers.
- Data configuration
train_data_files (str) – Path or pattern for training data files.
test_data_files (str) – Path or pattern for testing data files.
train_years (str) – Training years (comma-separated or range, e.g., “1995-1999”).
test_year (str) – Test year or range.
norm (str) – Normalization scheme (e.g., “log1p_standard”, “standard”, “minmax”, “none”).
dataset_type (str) – Dataset type (e.g., “LSM”, “RTM”).
- Output configuration
root_dir (str) – Root directory for all operations.
main_folder (str) – Main output folder name.
sub_folder (str) – Sub-folder name for the current run.
prefix (str) – Prefix for saved files.
model_name (str) – Custom model name (auto-generated if empty).
save_model (bool) – Whether to save model checkpoints.
save_checkpoint_name (str) – Base name for saved checkpoints.
save_per_samples (int) – Save checkpoint every N samples.
load_model (bool) – Whether to load an existing model.
load_checkpoint_name (str) – Checkpoint file to load.
inference (bool) – Run in inference-only mode.
- Return type:
argparse.Namespace
Examples
>>> args = parse_args()
>>> args.type
'lstm'
>>> args.batch_size
16
Command line usage
$ rtnn --type lstm --hidden_size 128 --num_layers 3 --batch_size 32
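For readers unfamiliar with argparse, a reduced sketch of how a few of these options might be declared follows. Only four of the documented arguments are shown; the defaults for `type` and `batch_size` mirror the example above, while the defaults for `hidden_size` and `num_layers` are placeholders chosen for illustration, and `build_parser_sketch` is a hypothetical name.

```python
import argparse


def build_parser_sketch():
    """Build a parser covering a small subset of the documented options."""
    parser = argparse.ArgumentParser(
        prog="rtnn", description="RTnn training (illustrative subset)"
    )
    parser.add_argument("--type", type=str, default="lstm",
                        help="Model type, e.g. lstm, gru, transformer")
    parser.add_argument("--hidden_size", type=int, default=128,
                        help="Size of hidden layers (placeholder default)")
    parser.add_argument("--num_layers", type=int, default=3,
                        help="Number of model layers (placeholder default)")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Number of samples per batch")
    return parser


# Passing an explicit list avoids reading sys.argv during experimentation.
args = build_parser_sketch().parse_args(["--type", "gru", "--batch_size", "32"])
print(args.type, args.hidden_size, args.num_layers, args.batch_size)
```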
- rtnn.main.setup_directories_and_logging(args)[source]
Set up directory structure and logging infrastructure for experiments.
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
- Returns:
paths (EasyDict) – Dictionary containing paths to created directories.
logger (Logger) – Configured logger instance.
- rtnn.main.log_configuration(args, paths, logger)[source]
Log all configuration parameters to the provided logger.
- Parameters:
args (argparse.Namespace) – Configuration object containing all experiment parameters.
paths (EasyDict) – Dictionary containing paths to various experiment directories.
logger (Logger) – Logger instance for outputting configuration information.
- rtnn.main.setup_device_and_seed(args, logger)[source]
Set up device (GPU/CPU) and random seeds for reproducibility.
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
logger (Logger) – Logger instance.
- Returns:
Device to use for computations.
- Return type:
torch.device
- rtnn.main.get_data_files(args, logger)[source]
Get training and testing data files based on year specifications.
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
logger (Logger) – Logger instance.
- Returns:
(train_files, test_files) lists of file paths.
- Return type:
tuple
- rtnn.main.create_normalization_mapping(train_files, paths, logger)[source]
Create normalization mapping from training data.
- rtnn.main.create_datasets_and_loaders(args, train_files, test_files, norm_mapping, logger)[source]
Create datasets and data loaders for training and validation.
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
train_files (list) – Training file paths.
test_files (list) – Test file paths.
norm_mapping (dict) – Normalization statistics.
logger (Logger) – Logger instance.
- Returns:
(train_loader, test_loader, train_dataset, test_dataset)
- Return type:
tuple
- rtnn.main.initialize_model(args, device, logger)[source]
Initialize the model architecture.
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
device (torch.device) – Device to place model on.
logger (Logger) – Logger instance.
- Returns:
Initialized model.
- Return type:
torch.nn.Module
- rtnn.main.load_checkpoint_if_requested(args, model, optimizer, paths, device, logger)[source]
Load model checkpoint if requested using ModelUtils.load_training_checkpoint().
This function leverages ModelUtils.load_training_checkpoint() which handles: - DataParallel compatibility - Loading model and optimizer states - Extracting training state (epoch, samples, metrics, etc.)
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
model (torch.nn.Module) – Model to load weights into.
optimizer (torch.optim.Optimizer) – Optimizer to restore state.
paths (EasyDict) – Directory paths.
device (torch.device) – Device for loading.
logger (Logger) – Logger instance.
- Returns:
(start_epoch, samples_processed, batches_processed, best_val_loss, best_epoch, checkpoint, train_loss_history, valid_loss_history, valid_metrics_history)
- Return type:
tuple
- rtnn.main.train_epoch(model, train_loader, optimizer, loss_func, metric_funcs, metric_names, output_keys, train_metrics, train_loss_tracker, norm_mapping, normalization_type, index_mapping, device, args, epoch, writer, global_step, logger, n_pft=15, n_bands=2, n_chans=4)[source]
Train for one epoch.
- Returns:
(average_train_loss, updated_global_step)
- Return type:
tuple
rtnn.model_loader module
rtnn.model_loader - Model Factory for Radiative Transfer Neural Networks
This module provides a factory function load_model() that instantiates various neural network architectures for radiative transfer calculations in climate modeling. It serves as the central point for model creation across the RTnn framework.
Module Overview
The model loader module implements a factory pattern to abstract away the details of model instantiation. Based on a configuration object (typically from command-line arguments), it returns an initialized PyTorch model ready for training or inference.
Supported Model Architectures
Recurrent Neural Networks (RNN)
LSTM (Long Short-Term Memory): Bidirectional LSTM with Conv1d projection
GRU (Gated Recurrent Unit): Bidirectional GRU with Conv1d projection
Best for: Sequential data with temporal dependencies
Transformer Models
EncoderTorch: PyTorch-native transformer encoder
Features: Multi-head self-attention, positional embeddings
Best for: Long-range dependencies and parallel processing
Fully Connected Networks (FCN)
Standard FCN: Multi-layer perceptron with configurable depth
PINN: Physics-Informed Neural Network with two-stream architecture for vertical profiles
Best for: Non-sequential, feature-based inputs
Multi-Layer Perceptrons (MLP)
MLP: Standard MLP with batch/layer normalization options
MLPResidual: MLP with residual connections between layers
Features: Positional embeddings, dropout, multiple activation functions
Best for: Flexible architecture with skip connections
Architecture Details
All models are designed to handle the specific requirements of radiative transfer problems: - Input shape: (batch, seq_len, feature_channel) - Output shape: (batch, seq_len, output_channel) - Physical constraints: Conservation of energy, positive outputs
The models output four channels corresponding to: - collim_alb (collimated albedo) - collim_tran (collimated transmittance) - isotrop_alb (isotropic albedo) - isotrop_tran (isotropic transmittance)
Data Flow
Parse configuration (args)
Determine model type from args.type
Extract model-specific parameters
Instantiate appropriate model class
Return initialized model
- rtnn.model_loader.load_model(args)[source]
Factory function to instantiate a neural network model from configuration.
This function builds and returns a PyTorch model based on the value of args.type. It supports multiple architectures including recurrent, convolutional, transformer-based, and fully connected models.
Supported model families
LSTM / GRU: Bidirectional recurrent models with Conv1d projection head
Transformer: PyTorch Transformer encoder with positional embeddings
FCN / FullyConnected: Fully connected feedforward network for sequences
VRT / VerticalRT: Physics-inspired vertical column model
MLP: Flexible multilayer perceptron with optional embeddings and residuals
MLPResidual: Deep residual MLP with layer-wise skip connections
- Returns:
Instantiated PyTorch model corresponding to the requested architecture.
- Return type:
torch.nn.Module
- Raises:
ValueError – If args.type does not match any supported model.
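The factory pattern described for load_model can be sketched with a small registry. Everything here is illustrative: the placeholder classes stand in for real torch.nn.Module subclasses, the registry contents cover only a few of the documented type strings, and the case-insensitive lookup is an assumption.

```python
from types import SimpleNamespace


class _PlaceholderModel:
    """Stand-in for a real torch.nn.Module subclass (hypothetical)."""

    def __init__(self, args):
        self.args = args


class RecurrentPlaceholder(_PlaceholderModel):
    pass


class TransformerPlaceholder(_PlaceholderModel):
    pass


class FullyConnectedPlaceholder(_PlaceholderModel):
    pass


# Maps the value of args.type to the class that should be instantiated.
_REGISTRY = {
    "lstm": RecurrentPlaceholder,
    "gru": RecurrentPlaceholder,
    "transformer": TransformerPlaceholder,
    "fcn": FullyConnectedPlaceholder,
    "fullyconnected": FullyConnectedPlaceholder,
}


def load_model_sketch(args):
    """Instantiate the class registered for args.type or raise ValueError."""
    key = args.type.lower()  # assumption: lookup is case-insensitive
    try:
        model_cls = _REGISTRY[key]
    except KeyError:
        raise ValueError(f"Unsupported model type: {args.type!r}") from None
    return model_cls(args)


model = load_model_sketch(SimpleNamespace(type="lstm", hidden_size=128))
print(type(model).__name__)
```

Keeping the mapping in one dictionary means a new architecture only needs a new registry entry, while callers continue to depend on a single function.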
rtnn.model_utils module
Model utility functions for PyTorch training workflows.
This module provides a collection of helper utilities for inspecting models, analyzing parameter distributions, and managing checkpoints during training. It is designed to standardize common operations such as saving/loading model states, logging architecture details, and maintaining reproducible training artifacts.
The utilities support both single-GPU and multi-GPU (DataParallel) setups and include safeguards for compatibility when loading checkpoints across different hardware configurations.
Features
Parameter counting (total and trainable)
Layer-wise parameter inspection
Model structure logging
Checkpoint saving and loading
Full training state persistence
Emergency checkpointing for crash recovery
DataParallel-aware state dictionary handling
Notes
Checkpoints include both model weights and optimizer states to enable seamless training resumption.
File naming conventions are automatically adapted based on checkpoint type (e.g., epoch, best, final, emergency).
When using torch.nn.DataParallel, the module automatically adjusts state dictionary keys to ensure compatibility between wrapped and unwrapped models.
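The key adjustment mentioned in the last note comes from the fact that torch.nn.DataParallel stores the wrapped model under the attribute `module`, so every state-dictionary key gains a `module.` prefix. A minimal sketch of converting between the two layouts, using plain dictionaries and hypothetical helper names (the package's own handling may differ):

```python
PREFIX = "module."


def strip_module_prefix(state_dict):
    """Return a copy whose keys have any leading "module." removed."""
    return {
        (key[len(PREFIX):] if key.startswith(PREFIX) else key): value
        for key, value in state_dict.items()
    }


def add_module_prefix(state_dict):
    """Return a copy whose keys all start with "module."."""
    return {
        (key if key.startswith(PREFIX) else PREFIX + key): value
        for key, value in state_dict.items()
    }


wrapped = {"module.fc.weight": [1.0, 2.0], "module.fc.bias": [0.0]}
plain = strip_module_prefix(wrapped)
print(sorted(plain))
print(sorted(add_module_prefix(plain)))
```

Both helpers leave the values untouched, so they work equally on real state dictionaries whose values are tensors.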
Dependencies
torch
os
datetime
Examples
Basic usage:
>>> from rtnn.model_utils import ModelUtils
>>> counts = ModelUtils.get_parameter_number(model)
>>> ModelUtils.print_model_layers(model)
Saving and loading checkpoints:
>>> state = {
... "state_dict": model.state_dict(),
... "optimizer": optimizer.state_dict(),
... "epoch": epoch,
... }
>>> ModelUtils.save_checkpoint(state, "checkpoint.pth.tar")
>>> checkpoint = torch.load("checkpoint.pth.tar")
>>> ModelUtils.load_checkpoint(checkpoint, model, optimizer)
Saving a full training checkpoint:
>>> ModelUtils.save_training_checkpoint(
... model, optimizer, epoch, samples_processed, batches_processed,
... train_loss_history, valid_loss_history, valid_metrics_history,
... best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss,
... args, paths, logger, checkpoint_type="best"
... )
- class rtnn.model_utils.ModelUtils[source]
Bases:
object
Utility class for model inspection, checkpointing, and memory profiling.
This class provides static methods for common model operations including parameter counting, memory usage analysis, checkpoint management, and model inspection.
Examples
>>> utils = ModelUtils()
>>> param_counts = ModelUtils.get_parameter_number(model)
>>> ModelUtils.save_checkpoint(state, "checkpoint.pth.tar", logger)
- static get_parameter_number(model, logger=None)[source]
Calculate the total and trainable number of parameters in a model.
- Parameters:
model (torch.nn.Module) – PyTorch model to inspect
logger (Logger, optional) – Logger instance for output, by default None
- Returns:
Dictionary containing: - ‘Total’: Total number of parameters - ‘Trainable’: Number of trainable parameters
- Return type:
dict
Examples
>>> model = torch.nn.Linear(10, 5)
>>> counts = ModelUtils.get_parameter_number(model, logger)
- static print_model_layers(model, logger=None)[source]
Print model parameter names along with their gradient requirements.
- Parameters:
model (torch.nn.Module) – PyTorch model to inspect
logger (Logger, optional) – Logger instance for output, by default None
Examples
>>> model = torch.nn.Sequential(
... torch.nn.Linear(10, 5),
... torch.nn.ReLU(),
... torch.nn.Linear(5, 1)
... )
>>> ModelUtils.print_model_layers(model, logger)
- static save_checkpoint(state, filename='checkpoint.pth.tar', logger=None)[source]
Save model and optimizer state to a file.
- Parameters:
state (dict) – Dictionary containing model state_dict and other training information. Typically includes: - ‘state_dict’: Model parameters - ‘optimizer’: Optimizer state - ‘epoch’: Current epoch - ‘loss’: Current loss value
filename (str, optional) – File path to save the checkpoint, by default “checkpoint.pth.tar”
logger (Logger, optional) – Logger instance for output, by default None
Examples
>>> state = {
... 'state_dict': model.state_dict(),
... 'optimizer': optimizer.state_dict(),
... 'epoch': epoch,
... 'loss': loss
... }
>>> ModelUtils.save_checkpoint(state, 'model_checkpoint.pth.tar', logger)
- static load_checkpoint(checkpoint, model, optimizer=None, logger=None)[source]
Load model and optimizer state from a checkpoint file.
- Parameters:
checkpoint (dict) – Loaded checkpoint dictionary
model (torch.nn.Module) – Model to load weights into
optimizer (torch.optim.Optimizer, optional) – Optimizer to restore state, by default None
logger (Logger, optional) – Logger instance for output, by default None
Examples
>>> checkpoint = torch.load('model_checkpoint.pth.tar')
>>> ModelUtils.load_checkpoint(checkpoint, model, optimizer, logger)
- static load_training_checkpoint(checkpoint_path, model, optimizer, device, logger=None)[source]
Load comprehensive training checkpoint.
- Parameters:
checkpoint_path (str) – Path to checkpoint file
model (torch.nn.Module) – Model to load weights into
optimizer (torch.optim.Optimizer) – Optimizer to restore state
device (torch.device) – Device to load checkpoint to
logger (Logger, optional) – Logger instance for output
- Returns:
(epoch, samples_processed, batches_processed, best_val_loss, best_epoch, checkpoint)
- Return type:
tuple
- static count_parameters_by_layer(model, logger=None)[source]
Count parameters for each layer in the model.
- Parameters:
model (torch.nn.Module) – PyTorch model to analyze
logger (Logger, optional) – Logger instance for output, by default None
- Returns:
Dictionary with layer names as keys and parameter counts as values
- Return type:
dict
Examples
>>> layer_params = ModelUtils.count_parameters_by_layer(model, logger)
- static log_model_summary(model, input_shape=None, logger=None)[source]
Log comprehensive model summary including parameters and architecture.
- Parameters:
model (torch.nn.Module) – PyTorch model to summarize
input_shape (tuple, optional) – Input shape for memory analysis, by default None
logger (Logger, optional) – Logger instance for output, by default None
- static save_training_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss_history, valid_loss_history, valid_metrics_history, best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss, args, paths, logger, checkpoint_type='epoch', save_full_model=True)[source]
Save comprehensive training checkpoint with consistent formatting.
- Parameters:
model (torch.nn.Module) – Model to save
optimizer (torch.optim.Optimizer) – Optimizer to save
epoch (int) – Current epoch
samples_processed (int) – Number of samples processed so far
batches_processed (int) – Number of batches processed so far
train_loss_history (list) – History of training losses
valid_loss_history (list) – History of validation losses
valid_metrics_history (dict) – History of validation metrics
best_val_loss (float) – Best validation loss so far
best_epoch (int) – Epoch with best validation loss
avg_val_loss (float) – Current epoch validation loss
avg_epoch_loss (float) – Current epoch training loss
args (argparse.Namespace) – Command line arguments
paths (EasyDict) – Directory paths
logger (Logger) – Logger instance
checkpoint_type (str) – Type of checkpoint: “samples”, “epoch”, “best”, “final”
save_full_model (bool) – Whether to also save the full model separately
- Returns:
(checkpoint_filename, full_model_filename)
- Return type:
tuple
Examples
>>> checkpoint_file, full_model_file = ModelUtils.save_training_checkpoint(
... model, optimizer, epoch, samples_processed, batches_processed,
... train_loss_history, valid_loss_history, valid_metrics_history,
... best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss,
... args, paths, logger, checkpoint_type="best"
... )
rtnn.utils module
File and lightweight configuration utilities.
This module provides helper classes for simplified file system operations and convenient dictionary handling with attribute-style access. It is designed to reduce boilerplate code when working with directories, files, and configuration-like data structures in Python projects.
The module includes:
EasyDict: A dictionary subclass enabling attribute-style access (e.g., cfg.key instead of cfg["key"]).
FileUtils: Static utility methods for creating directories and files.
Features
Attribute-style access to dictionary keys
Minimal and dependency-free implementation
Safe directory creation (no error if directory already exists)
Simple file creation utility
Lightweight and suitable for configuration management
Notes
EasyDict raises AttributeError when accessing missing keys, making it behave more like standard Python objects.
FileUtils.makedir will create nested directories if needed.
FileUtils.makefile creates an empty file if it does not exist, and does nothing if it already exists.
Dependencies
os
typing
Examples
Using EasyDict:
>>> cfg = EasyDict()
>>> cfg.learning_rate = 0.001
>>> cfg.batch_size = 32
>>> print(cfg.learning_rate)
0.001
>>> print(cfg["batch_size"])
32
Using FileUtils:
>>> FileUtils.makedir("outputs/logs")
>>> FileUtils.makefile("outputs/logs", "train.log")
Combined usage:
>>> paths = EasyDict()
>>> paths.output_dir = "outputs"
>>> FileUtils.makedir(paths.output_dir)
- class rtnn.utils.EasyDict[source]
Bases:
dict
A dictionary subclass that allows for attribute-style access to its items. This class extends the built-in dict and overrides the __getattr__, __setattr__, and __delattr__ methods to enable accessing dictionary keys as attributes.
Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm
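The three overrides named above are enough to get attribute-style access. The following is a sketch of the general technique under a hypothetical class name, not a copy of the package's or NVIDIA's code:

```python
class EasyDictSketch(dict):
    """dict whose keys can also be read, set, and deleted as attributes."""

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, so dict methods
        # such as keys() and items() keep working as usual.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


cfg = EasyDictSketch()
cfg.learning_rate = 0.001
cfg["batch_size"] = 32
print(cfg.learning_rate, cfg.batch_size, sorted(cfg.keys()))
```

Raising AttributeError instead of KeyError for missing names keeps `hasattr` and `getattr` with a default behaving as they do for ordinary objects.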
- class rtnn.utils.FileUtils[source]
Bases:
object
Utility class for file and directory operations.
- __init__()[source]
Initialize the FileUtils class. This class does not maintain any state, so the constructor is empty.
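The behavior documented for makedir and makefile (nested creation, no error or change if the target already exists) can be sketched with the standard library. The function names, the `(directory, filename)` argument order taken from the usage example earlier in this module, and the returned path are assumptions for illustration.

```python
import os
import tempfile


def makedir_sketch(path):
    """Create path, including parents; do nothing if it already exists."""
    os.makedirs(path, exist_ok=True)


def makefile_sketch(directory, filename):
    """Create an empty file if missing; leave an existing file untouched."""
    path = os.path.join(directory, filename)
    if not os.path.exists(path):
        with open(path, "w"):
            pass  # opening in write mode creates the empty file
    return path  # returning the path is an assumption, for convenience


# Demonstrate inside a temporary directory so nothing is left behind.
with tempfile.TemporaryDirectory() as root:
    log_dir = os.path.join(root, "outputs", "logs")
    makedir_sketch(log_dir)
    makedir_sketch(log_dir)  # second call is a no-op
    created = makefile_sketch(log_dir, "train.log")
    print(os.path.isfile(created))
```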
rtnn.version module
Version information for rtnn.