# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
"""
Multi-layer perceptron architectures for structured and sequence-based modeling.
This module provides flexible and extensible implementations of multi-layer
perceptrons (MLPs) tailored for tasks such as radiative transfer emulation
and other scientific machine learning applications involving structured inputs.
The module includes:

- MLPBlock: A configurable fully connected block with optional normalization,
  activation, and dropout.
- MLP: A flexible MLP architecture supporting positional embeddings,
  residual connections, and customizable depth.
- MLPResidual: A residual MLP with skip connections across all hidden layers
  for improved gradient flow and training stability.
Features
--------
- Configurable hidden layer sizes and depth
- Support for multiple normalization strategies (batch norm, layer norm)
- Choice of activation functions (ReLU, GELU, SiLU)
- Optional dropout for regularization
- Residual connections for improved optimization
- Learnable positional embeddings for sequence-aware modeling
- Designed for flattened sequence inputs and structured data
Notes
-----
- Inputs are expected in the shape (batch_size, feature_channel, seq_length)
  and are internally flattened before processing.
- Positional embeddings, when enabled, are concatenated to the input features
  before passing through the network.
- The residual connection in ``MLP`` is applied once, from the (optionally
  projected) flattened input to the output of the last hidden block, while
  ``MLPResidual`` adds a skip connection around every hidden block.
- Layer normalization is applied to outputs for improved numerical stability.
Dependencies
------------
- torch
- torch.nn
- typing
Examples
--------
Basic MLP usage:
>>> model = MLP(
... feature_channel=6,
... output_channel=4,
... seq_length=10,
... hidden_sizes=[512, 256, 128]
... )
>>> x = torch.randn(32, 6, 10)
>>> y = model(x)
Using MLP with positional embeddings and residuals:
>>> model = MLP(
... feature_channel=6,
... output_channel=4,
... seq_length=10,
... use_positional_embedding=True,
... use_residual=True
... )
Using MLPResidual:
>>> model = MLPResidual(
... feature_channel=6,
... output_channel=4,
... seq_length=10,
... hidden_size=256,
... num_layers=4
... )
>>> x = torch.randn(16, 6, 10)
>>> y = model(x)
"""
import warnings
from typing import List, Optional

import torch
import torch.nn as nn
class MLPBlock(nn.Module):
    """
    A single MLP block: linear layer, optional normalization, activation, dropout.

    The forward pass applies, in order: ``Linear`` -> ``BatchNorm1d`` (if
    enabled) -> ``LayerNorm`` (if enabled) -> activation -> ``Dropout``.
    Both normalizations are applied when both flags are set.

    Parameters
    ----------
    in_features : int
        Number of input features.
    out_features : int
        Number of output features.
    dropout : float, optional
        Dropout rate. Default is 0.1.
    use_batch_norm : bool, optional
        Whether to use batch normalization. Default is True.
    use_layer_norm : bool, optional
        Whether to use layer normalization. Default is False.
    activation : str, optional
        Activation function ('relu', 'gelu', 'silu'). Default is 'relu'.
        Any other value falls back to ReLU and emits a ``UserWarning``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout: float = 0.1,
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        activation: str = "relu",
    ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.use_batch_norm = use_batch_norm
        self.use_layer_norm = use_layer_norm

        # Normalization submodules are only created when requested, so the
        # ``bn`` / ``ln`` attributes (and their state-dict keys) exist only
        # for the enabled options.
        if use_batch_norm:
            self.bn = nn.BatchNorm1d(out_features)
        if use_layer_norm:
            self.ln = nn.LayerNorm(out_features)

        activations = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}
        if activation not in activations:
            # Keep the historical ReLU fallback, but make a typo visible
            # instead of silently training with a different activation.
            warnings.warn(
                f"Unknown activation {activation!r}; falling back to 'relu'. "
                f"Supported values: {sorted(activations)}.",
                UserWarning,
                stacklevel=2,
            )
        self.act = activations.get(activation, nn.ReLU)()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply linear -> (batch norm) -> (layer norm) -> activation -> dropout.

        Parameters
        ----------
        x : torch.Tensor
            Input whose last dimension has size ``in_features``.

        Returns
        -------
        torch.Tensor
            Tensor whose last dimension has size ``out_features``.
        """
        x = self.linear(x)
        if self.use_batch_norm:
            x = self.bn(x)
        if self.use_layer_norm:
            x = self.ln(x)
        x = self.act(x)
        return self.dropout(x)
class MLP(nn.Module):
    """
    Multi-Layer Perceptron for radiative transfer emulation.

    The input is flattened per sample, optionally concatenated with learned
    positional embeddings, passed through a stack of ``MLPBlock`` layers and
    a final linear layer, then dropout and layer normalization, and reshaped
    to ``(batch_size, output_channel, seq_length)``.

    Parameters
    ----------
    feature_channel : int
        Number of input features per time step.
    output_channel : int
        Number of output channels.
    seq_length : int
        Length of the input sequence.
    hidden_sizes : List[int], optional
        List of hidden layer sizes. Default is [512, 256, 128]. An empty
        list is allowed and yields no hidden blocks.
    dropout : float, optional
        Dropout rate. Default is 0.1.
    use_batch_norm : bool, optional
        Whether to use batch normalization. Default is True.
    use_layer_norm : bool, optional
        Whether to use layer normalization. Default is False.
    use_residual : bool, optional
        Whether to use residual connections. Default is False.
    activation : str, optional
        Activation function ('relu', 'gelu', 'silu'). Default is 'relu'.
    use_positional_embedding : bool, optional
        Whether to add positional embeddings. Default is True.
    positional_embed_dim : int, optional
        Dimension of positional embeddings. Default is 16.
    """

    def __init__(
        self,
        feature_channel: int,
        output_channel: int,
        seq_length: int = 10,
        hidden_sizes: Optional[List[int]] = None,
        dropout: float = 0.1,
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        use_residual: bool = False,
        activation: str = "relu",
        use_positional_embedding: bool = True,
        positional_embed_dim: int = 16,
    ):
        super().__init__()
        if hidden_sizes is None:
            hidden_sizes = [512, 256, 128]

        self.feature_channel = feature_channel
        self.output_channel = output_channel
        self.seq_length = seq_length
        self.hidden_sizes = hidden_sizes
        self.use_residual = use_residual
        self.use_positional_embedding = use_positional_embedding

        # Width of the flattened input, including one embedding vector per
        # sequence position when positional embeddings are enabled.
        input_size = feature_channel * seq_length
        if use_positional_embedding:
            self.positional_embed = nn.Embedding(seq_length, positional_embed_dim)
            input_size += positional_embed_dim * seq_length
        else:
            self.positional_embed = None

        self.layers = nn.ModuleList()
        prev_size = input_size
        for hidden_size in hidden_sizes:
            self.layers.append(
                MLPBlock(
                    prev_size,
                    hidden_size,
                    dropout=dropout,
                    use_batch_norm=use_batch_norm,
                    use_layer_norm=use_layer_norm,
                    activation=activation,
                )
            )
            prev_size = hidden_size

        output_size = output_channel * seq_length
        self.output_layer = nn.Linear(prev_size, output_size)
        self.output_dropout = nn.Dropout(dropout)
        self.final_norm = nn.LayerNorm(output_size)

        # The residual is added to the output of the last hidden block, whose
        # width is ``prev_size``. Using ``prev_size`` rather than
        # ``hidden_sizes[-1]`` keeps this valid when ``hidden_sizes`` is empty.
        if use_residual and input_size != prev_size:
            self.residual_proj = nn.Linear(input_size, prev_size)
        else:
            self.residual_proj = None

    def _add_positional_embedding(self, x: torch.Tensor) -> torch.Tensor:
        """
        Concatenate positional embeddings to the flattened input.

        Parameters
        ----------
        x : torch.Tensor
            Flattened input of shape ``(batch_size, N)`` with ``N`` divisible
            by ``seq_length``.

        Returns
        -------
        torch.Tensor
            Tensor of shape
            ``(batch_size, N + seq_length * positional_embed_dim)``.

        Notes
        -----
        NOTE(review): the flat vector is regrouped row-major into
        ``seq_length`` rows. If the original input was laid out as
        (batch, feature_channel, seq_length), these rows are not per-position
        feature vectors. Because the result is flattened again and fed to a
        fully connected layer, this only fixes a particular input column
        order; it is kept unchanged so existing trained weights stay
        compatible. Confirm the intended layout with the callers.
        """
        batch_size = x.shape[0]
        x_reshaped = x.reshape(batch_size, self.seq_length, -1)

        positions = torch.arange(self.seq_length, device=x.device)
        pos_embed = self.positional_embed(positions)
        # The same embedding table is shared by every sample in the batch.
        pos_embed = pos_embed.unsqueeze(0).expand(batch_size, -1, -1)

        x_combined = torch.cat([x_reshaped, pos_embed], dim=-1)
        return x_combined.reshape(batch_size, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the network on a batch.

        Parameters
        ----------
        x : torch.Tensor
            Input with ``feature_channel * seq_length`` elements per sample;
            it is flattened to ``(batch_size, -1)`` before processing.

        Returns
        -------
        torch.Tensor
            Output of shape ``(batch_size, output_channel, seq_length)``.
        """
        batch_size = x.shape[0]
        x = x.reshape(batch_size, -1)

        if self.use_positional_embedding and self.positional_embed is not None:
            x = self._add_positional_embedding(x)

        # Capture the skip path before the hidden stack, projecting it to the
        # width of the last hidden block when the sizes differ.
        residual = None
        if self.use_residual:
            residual = x if self.residual_proj is None else self.residual_proj(x)

        for layer in self.layers:
            x = layer(x)

        if residual is not None:
            x = x + residual

        x = self.output_layer(x)
        x = self.output_dropout(x)
        x = self.final_norm(x)
        return x.reshape(batch_size, self.output_channel, self.seq_length)
class MLPResidual(nn.Module):
    """
    MLP with residual connections between all layers.

    The flattened input is projected to ``hidden_size`` (linear -> layer norm
    -> ReLU -> dropout), passed through ``num_layers`` residual blocks
    (``x + block(x)``, each block being linear -> layer norm -> ReLU ->
    dropout), and projected to the output size. The output projection is
    followed by dropout and then layer normalization, so the returned values
    are the normalized ones in both training and evaluation mode.

    Parameters
    ----------
    feature_channel : int
        Number of input features.
    output_channel : int
        Number of output channels.
    seq_length : int
        Sequence length.
    hidden_size : int
        Size of hidden layers.
    num_layers : int
        Number of hidden layers.
    dropout : float, optional
        Dropout rate. Default is 0.1.
    """

    def __init__(
        self,
        feature_channel: int,
        output_channel: int,
        seq_length: int = 10,
        hidden_size: int = 256,
        num_layers: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.feature_channel = feature_channel
        self.output_channel = output_channel
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        input_size = feature_channel * seq_length
        output_size = output_channel * seq_length

        # Input stem: brings the flattened input to the constant hidden width
        # so every residual block can add its input to its output.
        self.input_proj = nn.Linear(input_size, hidden_size)
        self.input_norm = nn.LayerNorm(hidden_size)
        self.input_act = nn.ReLU()
        self.input_dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList()
        for _ in range(num_layers):
            self.blocks.append(
                nn.Sequential(
                    nn.Linear(hidden_size, hidden_size),
                    nn.LayerNorm(hidden_size),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                )
            )

        self.output_proj = nn.Linear(hidden_size, output_size)
        self.output_norm = nn.LayerNorm(output_size)
        self.output_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with residual connections.

        Parameters
        ----------
        x : torch.Tensor
            Input with ``feature_channel * seq_length`` elements per sample;
            it is flattened to ``(batch_size, -1)`` before processing.

        Returns
        -------
        torch.Tensor
            Output of shape ``(batch_size, output_channel, seq_length)``.
        """
        batch_size = x.shape[0]
        x = x.reshape(batch_size, -1)

        x = self.input_proj(x)
        x = self.input_norm(x)
        x = self.input_act(x)
        x = self.input_dropout(x)

        for block in self.blocks:
            x = x + block(x)

        x = self.output_proj(x)
        # Dropout is applied before the final normalization (same order as
        # ``MLP``): applying it after the norm would zero and rescale the
        # returned predictions themselves during training. In eval mode
        # dropout is the identity, so inference results are unaffected by
        # the ordering.
        x = self.output_dropout(x)
        x = self.output_norm(x)

        return x.reshape(batch_size, self.output_channel, self.seq_length)