Source code for rtnn.models.rnn

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

"""
Bidirectional recurrent neural network models for sequence modeling.

This module provides implementations of bidirectional recurrent neural
networks (RNNs) using Long Short-Term Memory (LSTM) and Gated Recurrent
Unit (GRU) cells. These models are designed for sequence-based data,
such as time series or vertically structured physical profiles.

The module includes:

- BaseRNN: A flexible base class supporting both LSTM and GRU
  architectures with bidirectional processing.
- RNN_LSTM: A specialized LSTM-based model built on BaseRNN.
- RNN_GRU: A specialized GRU-based model built on BaseRNN.

Features
--------

- Bidirectional sequence processing for improved context awareness
- Support for stacked recurrent layers
- Unified interface for LSTM and GRU architectures
- Automatic hidden state initialization
- Final 1D convolution layer for channel-wise output projection
- Compatible with batched inputs and GPU acceleration

Notes
-----

- Inputs are expected in the shape (batch_size, feature_channel, seq_length).
- Internally, inputs are permuted to (batch_size, seq_length, feature_channel)
  to match PyTorch RNN requirements.
- Bidirectional RNNs double the hidden state size, which is handled
  automatically in the final projection layer.
- Hidden states are initialized to zeros at each forward pass.
- The final Conv1d layer maps hidden representations to the desired
  output channels while preserving sequence length.

Dependencies
------------

- torch
- torch.nn
- typing

Examples
--------

Using LSTM-based model::

    >>> model = RNN_LSTM(
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     hidden_size=128,
    ...     num_layers=3
    ... )
    >>> x = torch.randn(16, 6, 10)
    >>> y = model(x)

Using GRU-based model::

    >>> model = RNN_GRU(
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     hidden_size=128,
    ...     num_layers=3
    ... )
    >>> x = torch.randn(16, 6, 10)
    >>> y = model(x)

Using BaseRNN directly::

    >>> model = BaseRNN(
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     hidden_size=64,
    ...     num_layers=2,
    ...     rnn_type="lstm"
    ... )
"""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class BaseRNN(nn.Module):
    """
    Base class for bidirectional RNN modules (LSTM/GRU).

    Wraps a stacked, bidirectional ``nn.LSTM`` or ``nn.GRU`` followed by a
    kernel-size-1 ``nn.Conv1d`` that projects the concatenated forward and
    backward hidden features onto the requested output channels.

    Parameters
    ----------
    feature_channel : int
        Number of input features per sequence step.
    output_channel : int
        Number of output channels (target variables).
    hidden_size : int
        Number of hidden units per direction in the RNN layers.
    num_layers : int
        Number of stacked RNN layers.
    rnn_type : str
        Type of RNN cell, either 'lstm' or 'gru'.

    Attributes
    ----------
    rnn : nn.LSTM or nn.GRU
        The bidirectional, batch-first recurrent layer.
    final : nn.Conv1d
        Pointwise convolution mapping ``2 * hidden_size`` channels to
        ``output_channel`` channels.
    hidden_size : int
        Number of hidden units per direction.
    num_layers : int
        Number of stacked layers.
    output_channel : int
        Number of output channels.
    rnn_type : str
        The validated cell type ('lstm' or 'gru').

    Examples
    --------
    >>> model = BaseRNN(
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     hidden_size=64,
    ...     num_layers=2,
    ...     rnn_type='lstm'
    ... )
    >>> x = torch.randn(32, 6, 10)  # (batch, features, sequence)
    >>> y = model(x)
    >>> y.shape
    torch.Size([32, 4, 10])
    """

    def __init__(
        self,
        feature_channel: int,
        output_channel: int,
        hidden_size: int,
        num_layers: int,
        rnn_type: str,
    ) -> None:
        """
        Initialize the BaseRNN module.

        Parameters
        ----------
        feature_channel : int
            Number of input features.
        output_channel : int
            Number of output channels.
        hidden_size : int
            Size of hidden state.
        num_layers : int
            Number of RNN layers.
        rnn_type : str
            Type of RNN ('lstm' or 'gru').

        Raises
        ------
        ValueError
            If rnn_type is not 'lstm' or 'gru'.
        """
        super().__init__()

        # Reject unknown cell types before any layer is constructed.
        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"rnn_type must be 'lstm' or 'gru', got {rnn_type}")

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_channel = output_channel
        self.rnn_type = rnn_type

        rnn_class = nn.LSTM if rnn_type == "lstm" else nn.GRU

        self.rnn = rnn_class(
            input_size=feature_channel,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Forward and backward features are concatenated by the RNN, hence
        # the factor of two on the input channels of the projection.
        self.final = nn.Conv1d(
            in_channels=2 * hidden_size,
            out_channels=output_channel,
            kernel_size=1,
            padding=0,
            bias=True,
        )

    def init_hidden(
        self,
        batch_size: int,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Create the all-zero initial state for the RNN.

        Parameters
        ----------
        batch_size : int
            Batch size for the input.
        device : torch.device
            Device to create the hidden state on.
        dtype : torch.dtype, optional
            Floating-point type of the state. ``None`` (the default) uses
            PyTorch's current default dtype, as before this parameter
            existed.

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            For GRU: hidden state tensor of shape
            (2 * num_layers, batch_size, hidden_size).
            For LSTM: tuple (hidden, cell) of two distinct tensors, both of
            that shape.
        """
        hidden = torch.zeros(
            2 * self.num_layers,
            batch_size,
            self.hidden_size,
            device=device,
            dtype=dtype,
            requires_grad=False,
        )

        if self.rnn_type == "lstm":
            # The cell state must be its own tensor, not an alias of hidden.
            return hidden, torch.zeros_like(hidden)
        return hidden

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the bidirectional RNN.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, feature_channel, seq_length).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, output_channel, seq_length).

        Notes
        -----
        The input is permuted to (batch_size, seq_length, feature_channel)
        for the RNN, then the output is permuted back for the convolution.
        The zero initial state is created with the input's device and dtype
        so the module keeps working after ``.double()`` / ``.half()``.
        """
        # (batch, features, seq) -> (batch, seq, features) for batch_first RNN
        x = x.permute(0, 2, 1)

        # Match the state to the input's dtype; a float32 state fed to a
        # float64/float16 model would otherwise be rejected by the RNN.
        hidden = self.init_hidden(x.size(0), x.device, x.dtype)

        out, _ = self.rnn(x, hidden)

        # (batch, seq, 2 * hidden) -> (batch, 2 * hidden, seq) for Conv1d
        out = out.permute(0, 2, 1)

        return self.final(out)
class RNN_LSTM(BaseRNN):
    """
    Bidirectional recurrent model built on LSTM cells.

    A thin specialization of BaseRNN that fixes ``rnn_type`` to ``"lstm"``;
    everything else (layers, state initialization, forward pass) is
    inherited unchanged.

    Parameters
    ----------
    feature_channel : int
        Number of input features.
    output_channel : int
        Number of output channels.
    hidden_size : int
        Size of hidden state.
    num_layers : int
        Number of LSTM layers.

    Examples
    --------
    >>> model = RNN_LSTM(
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     hidden_size=128,
    ...     num_layers=3
    ... )
    >>> x = torch.randn(16, 6, 10)
    >>> y = model(x)
    >>> print(y.shape)
    torch.Size([16, 4, 10])
    """

    def __init__(
        self,
        feature_channel: int,
        output_channel: int,
        hidden_size: int,
        num_layers: int,
    ) -> None:
        """
        Build the model by delegating to BaseRNN with LSTM cells.

        Parameters
        ----------
        feature_channel : int
            Number of input features.
        output_channel : int
            Number of output channels.
        hidden_size : int
            Size of hidden state.
        num_layers : int
            Number of LSTM layers.
        """
        super().__init__(
            feature_channel=feature_channel,
            output_channel=output_channel,
            hidden_size=hidden_size,
            num_layers=num_layers,
            rnn_type="lstm",
        )
class RNN_GRU(BaseRNN):
    """
    Bidirectional recurrent model built on GRU cells.

    A thin specialization of BaseRNN that fixes ``rnn_type`` to ``"gru"``;
    everything else (layers, state initialization, forward pass) is
    inherited unchanged.

    Parameters
    ----------
    feature_channel : int
        Number of input features.
    output_channel : int
        Number of output channels.
    hidden_size : int
        Size of hidden state.
    num_layers : int
        Number of GRU layers.

    Examples
    --------
    >>> model = RNN_GRU(
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     hidden_size=128,
    ...     num_layers=3
    ... )
    >>> x = torch.randn(16, 6, 10)
    >>> y = model(x)
    >>> print(y.shape)
    torch.Size([16, 4, 10])
    """

    def __init__(
        self,
        feature_channel: int,
        output_channel: int,
        hidden_size: int,
        num_layers: int,
    ) -> None:
        """
        Build the model by delegating to BaseRNN with GRU cells.

        Parameters
        ----------
        feature_channel : int
            Number of input features.
        output_channel : int
            Number of output channels.
        hidden_size : int
            Size of hidden state.
        num_layers : int
            Number of GRU layers.
        """
        super().__init__(
            feature_channel=feature_channel,
            output_channel=output_channel,
            hidden_size=hidden_size,
            num_layers=num_layers,
            rnn_type="gru",
        )