Source code for rtnn.model_loader

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

"""
rtnn.model_loader - Model Factory for Radiative Transfer Neural Networks

This module provides a factory function ``load_model()`` that instantiates
various neural network architectures for radiative transfer calculations
in climate modeling. It serves as the central point for model creation
across the RTnn framework.

Module Overview
---------------
The model loader module implements a factory pattern to abstract away the
details of model instantiation. Based on a configuration object (typically
from command-line arguments), it returns an initialized PyTorch model ready
for training or inference.

Supported Model Architectures
-----------------------------
1. **Recurrent Neural Networks (RNN)**

   - LSTM (Long Short-Term Memory): Bidirectional LSTM with Conv1d projection
   - GRU (Gated Recurrent Unit): Bidirectional GRU with Conv1d projection
   - Best for: Sequential data with temporal dependencies

2. **Transformer Models**

   - EncoderTorch: PyTorch-native transformer encoder
   - Features: Multi-head self-attention, positional embeddings
   - Best for: Long-range dependencies and parallel processing

3. **Fully Connected Networks (FCN)**

   - Standard FCN: Multi-layer perceptron with configurable depth
   - PINN: Physics-Informed Neural Network with two-stream architecture for vertical profiles
   - Best for: Non-sequential, feature-based inputs

4. **Multi-Layer Perceptrons (MLP)**

   - MLP: Standard MLP with batch/layer normalization options
   - MLPResidual: MLP with residual connections between layers
   - Features: Positional embeddings, dropout, multiple activation functions
   - Best for: Flexible architecture with skip connections

Architecture Details
--------------------
All models are designed to handle the specific requirements of radiative
transfer problems:

- Input shape: (batch, seq_len, feature_channel)
- Output shape: (batch, seq_len, output_channel)
- Physical constraints: Conservation of energy, positive outputs

The models output four channels corresponding to:

- collim_alb (collimated albedo)
- collim_tran (collimated transmittance)
- isotrop_alb (isotropic albedo)
- isotrop_tran (isotropic transmittance)
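
As an illustration of the channel layout, the sketch below splits a
nested-list output of shape (batch, seq_len, 4) into named channels. It
assumes the channel order follows the list above, which this module does
not itself confirm; ``CHANNEL_NAMES`` and ``split_channels`` are
hypothetical names that are not part of rtnn.

```python
# Assumed channel order (taken from the list above, not verified against rtnn).
CHANNEL_NAMES = ("collim_alb", "collim_tran", "isotrop_alb", "isotrop_tran")


def split_channels(output):
    # output: nested list shaped (batch, seq_len, 4).
    # Returns a dict mapping each channel name to a (batch, seq_len) nested list.
    named = {name: [] for name in CHANNEL_NAMES}
    for profile in output:
        for idx, name in enumerate(CHANNEL_NAMES):
            named[name].append([level[idx] for level in profile])
    return named


# One profile (batch=1) with two vertical levels (seq_len=2).
example = [[[0.1, 0.8, 0.2, 0.7], [0.3, 0.6, 0.4, 0.5]]]
named = split_channels(example)
```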

Data Flow
---------
1. Parse configuration (args)
2. Determine model type from args.type
3. Extract model-specific parameters
4. Instantiate appropriate model class
5. Return initialized model
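
The steps above can be sketched with stand-in classes. This is a minimal
illustration, not the rtnn implementation: ``TinyLSTM``, ``TinyMLP``,
``REGISTRY``, and ``build`` are hypothetical, and a dict lookup stands in
for the if/elif chain that ``load_model()`` uses.

```python
from types import SimpleNamespace


class TinyLSTM:
    # Hypothetical stand-in for a real model class; it only stores its config.
    def __init__(self, feature_channel, output_channel):
        self.feature_channel = feature_channel
        self.output_channel = output_channel


class TinyMLP(TinyLSTM):
    # Second hypothetical stand-in so the dispatch has two targets.
    pass


# Maps lower-cased type names to model classes.
REGISTRY = {"lstm": TinyLSTM, "mlp": TinyMLP}


def build(args):
    # Step 2: determine the model type, ignoring case.
    key = args.type.lower()
    if key not in REGISTRY:
        raise ValueError(f"Model type '{args.type}' is not implemented.")
    # Steps 3-5: extract parameters, instantiate the class, return the model.
    return REGISTRY[key](
        feature_channel=args.feature_channel,
        output_channel=args.output_channel,
    )


# Step 1: a configuration object, here built by hand instead of parsed.
args = SimpleNamespace(type="LSTM", feature_channel=6, output_channel=4)
model = build(args)
```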
"""

from rtnn.models.rnn import RNN_LSTM, RNN_GRU
from rtnn.models.transformer import EncoderTorch
from rtnn.models.fcn import FCN
from rtnn.models.pinn import PINN
from rtnn.models.mlp import MLP, MLPResidual


def load_model(args):
    """
    Factory function to instantiate a neural network model from configuration.

    This function builds and returns a PyTorch model based on the value of
    ``args.type``, which is matched case-insensitively. It supports recurrent,
    transformer-based, physics-informed, and fully connected architectures.

    Supported model families
    ------------------------
    - LSTM / GRU: Bidirectional recurrent models with Conv1d projection head
    - EncoderTorch / Transformer: PyTorch Transformer encoder with positional embeddings
    - PINN: Physics-informed vertical column model
    - FCN / FullyConnected: Fully connected feedforward network for sequences
    - MLP: Flexible multilayer perceptron with optional embeddings and residuals
    - MLP_Residual: Deep residual MLP with layer-wise skip connections

    Parameters
    ----------
    args : object
        Configuration object (typically parsed command-line arguments).
        ``args.type`` selects the architecture; the remaining attributes
        read depend on the selected model.

    Returns
    -------
    torch.nn.Module
        Instantiated PyTorch model corresponding to the requested architecture.

    Raises
    ------
    ValueError
        If ``args.type`` does not match any supported model.
    """
    model_type = args.type.lower()

    if model_type in ["lstm", "gru"]:
        model_class = RNN_LSTM if model_type == "lstm" else RNN_GRU
        model = model_class(
            feature_channel=args.feature_channel,
            output_channel=args.output_channel,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
        )

    elif model_type in ["encodertorch", "transformer"]:
        # New EncoderTorch implementation (PyTorch native transformer)
        model = EncoderTorch(
            feature_channel=args.feature_channel,
            output_channel=args.output_channel,
            embed_size=args.embed_size,
            num_layers=args.num_layers,
            heads=args.nhead,
            # Fall back to an expansion factor of 4 when none is configured
            forward_expansion=(
                args.forward_expansion if args.forward_expansion is not None else 4
            ),
            seq_length=args.seq_length,
            dropout=args.dropout,
        )

    elif model_type in ["pinn"]:
        # The PINN uses one layer per vertical level, hence n_layers=seq_length
        model = PINN(
            feature_channel=args.feature_channel,
            hidden=args.hidden_size,
            out_channel=args.output_channel,
            n_layers=args.seq_length,
        )

    elif model_type in ["fcn", "fullyconnected"]:
        model = FCN(
            feature_channel=args.feature_channel,
            output_channel=args.output_channel,
            num_layers=args.num_layers,
            hidden_size=args.hidden_size,
            seq_length=args.seq_length,
            dim_expand=0,
        )

    elif model_type in ["mlp"]:
        # Standard MLP with configurable hidden layers
        hidden_sizes = getattr(args, "hidden_sizes", [512, 256, 128])
        if isinstance(hidden_sizes, str):
            # Accept a comma-separated string such as "512,256,128"
            hidden_sizes = [int(x) for x in hidden_sizes.split(",")]

        model = MLP(
            feature_channel=args.feature_channel,
            output_channel=args.output_channel,
            seq_length=args.seq_length,
            hidden_sizes=hidden_sizes,
            dropout=getattr(args, "dropout", 0.1),
            use_batch_norm=getattr(args, "use_batch_norm", True),
            use_layer_norm=getattr(args, "use_layer_norm", False),
            use_residual=getattr(args, "use_residual", False),
            activation=getattr(args, "activation", "relu"),
            use_positional_embedding=getattr(args, "use_positional_embedding", True),
            positional_embed_dim=getattr(args, "positional_embed_dim", 16),
        )

    elif model_type in ["mlp_residual"]:
        # MLP with residual connections between layers
        model = MLPResidual(
            feature_channel=args.feature_channel,
            output_channel=args.output_channel,
            seq_length=args.seq_length,
            hidden_size=getattr(args, "hidden_size", 256),
            num_layers=getattr(args, "num_layers", 4),
            dropout=getattr(args, "dropout", 0.1),
        )

    else:
        raise ValueError(f"Model type '{args.type}' is not implemented.")

    return model