# Source code for rtnn.dataset

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

"""
Dataset module for RTnn radiative transfer neural network framework.

This module provides the DataPreprocessor class for loading, preprocessing,
and normalizing NetCDF climate data for training and evaluating neural network
emulators of atmospheric radiative transfer. The module handles complex data
structures including multiple years, spatial and temporal batching, and various
normalization techniques.

The DataPreprocessor is designed to work with Land Surface Model (LSM) data
containing variables such as:
- Cosine of solar zenith angle (cosz)
- Leaf area index (lai)
- Single scattering albedo (ssa)
- Surface reflectance (rs)
- Radiative transfer outputs (collimated albedo/transmittance, isotropic albedo/transmittance)

The module supports data augmentation during training through random spatial
sampling and random time offsets within each temporal block, and provides
flexible normalization methods including min-max scaling, standardization,
robust scaling, and transformations like log1p and sqrt.

Notes
-----
The expected input data structure consists of NetCDF4 files with dimensions:
- time: Temporal dimension
- dim_1: Spatial points (e.g., grid cells)
- dim_2: Vertical levels (sequence length, typically 10)
- dim_3: Plant functional types (PFTs, typically 15)
- dim_4: Spectral bands (typically 2)

Feature channels (121 total):
- cosz: 1 channel
- lai: 2 variables × 15 PFTs = 30 channels
- ssa: 2 variables × 2 bands × 15 PFTs = 60 channels
- rs: 1 variable × 2 bands × 15 PFTs = 30 channels

Output channels (120 total):
- 4 variables × 2 bands × 15 PFTs = 120 channels

Examples
--------
Basic usage for training:

>>> from rtnn.logger import Logger
>>> logger = Logger()
>>> train_files = ["data_1995.nc", "data_1996.nc", "data_1997.nc"]
>>> dataset = DataPreprocessor(
...     logger=logger,
...     dfs=train_files,
...     stime=0,
...     tbatch=24,
...     training=True,
...     sblock_perc=0.6,
...     norm_mapping=norm_stats,
...     normalization_type=norm_types
... )
>>> features, targets = dataset[0]
>>> features.shape
torch.Size([158, 121, 10])
>>> targets.shape
torch.Size([158, 120, 10])

For validation/testing:

>>> val_dataset = DataPreprocessor(
...     logger=logger,
...     dfs=val_files,
...     stime=0,
...     tbatch=24,
...     training=False,
...     norm_mapping=norm_stats,
...     normalization_type=norm_types
... )

See Also
--------
rtnn.models : Contains the neural network architectures for radiative transfer
rtnn.evaluater : Provides metrics and loss functions for model evaluation
rtnn.diagnostics : Visualization tools for model predictions
"""

import random
import re
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import xarray as xr
from torch.utils.data import Dataset


class DataPreprocessor(Dataset):
    """
    Dataset class for preprocessing LSM (Land Surface Model) data.

    This class handles loading and preprocessing of NetCDF files containing
    climate data, with support for multiple years, spatial and temporal
    batching, and various normalization techniques. File names must end in
    ``_YYYY.nc``; every year contributes one file per spatial batch, and the
    files of a year are ordered by sorted path.

    Parameters
    ----------
    logger : Any
        Logger instance for logging messages (its ``info`` method is used).
    dfs : List[str]
        List of file paths to NetCDF files. Paths that do not end in
        ``_YYYY.nc`` are skipped.
    stime : int
        Start time index, applied as an offset inside every yearly file.
    tbatch : int
        Temporal batch size (time steps per temporal block).
    training : bool, optional
        If True, each temporal block uses a random subset (``sblock_perc``)
        of the spatial batches and a random time offset inside the block.
        If False, all spatial batches are used with a deterministic time
        index. Default is True.
    sblock_perc : float, optional
        Fraction of spatial batches used per temporal block when training.
        Default is 0.6.
    norm_mapping : Dict, optional
        Normalization statistics for each variable. Default is an empty dict.
    normalization_type : Dict, optional
        Normalization type for each variable; variables not listed use
        ``"log1p_minmax"``. Default is an empty dict.
    debug : bool, optional
        If True, log detailed information for every sample. Default is False.

    Attributes
    ----------
    logger : Any
        Logger instance.
    stime : int
        Start time index inside each yearly file.
    tstep : int
        Minimum number of time steps across files.
    tbatch : int
        Temporal batch size.
    training : bool
        Whether training-mode sampling is used.
    sblock_perc : float
        Fraction of spatial batches used when training.
    norm_mapping : Dict
        Normalization statistics.
    normalization_type : Dict
        Normalization types per variable.
    total_sbatch : int
        Number of spatial-batch files per year.
    sbatch : int
        Number of spatial batches used per temporal block.
    years : List[int]
        Sorted list of years in the dataset.
    year_to_index : Dict[int, int]
        Position of each year in ``years``.
    etime : int
        ``tstep * len(years)`` (total time steps over all years).
    blocks_per_year : int
        Number of complete temporal blocks in each yearly file after ``stime``.
    dfs : List[Tuple[int, int, str]]
        List of (year, spatial_index, file_path) tuples, year-major.
    time_blocks : np.ndarray
        Indices of all temporal blocks (``blocks_per_year * len(years)``).
    min_dims : Dict[str, int]
        Minimum dimensions across files.
    cosz, lai, ssa, rs : List[str]
        Input variable names per group.
    ov : List[str]
        Output variable names.
    sindex_tracker, tindex_tracker : List[int]
        Spatial index and temporal block of every sample served.

    Examples
    --------
    >>> from rtnn.logger import Logger
    >>> logger = Logger()
    >>> files = ["data_1995.nc", "data_1996.nc"]
    >>> dataset = DataPreprocessor(
    ...     logger=logger,
    ...     dfs=files,
    ...     stime=0,
    ...     tbatch=24,
    ...     norm_mapping=norm_stats,
    ...     normalization_type=norm_types,
    ... )
    >>> features, targets = dataset[0]
    >>> features.shape  # (schunk, feature_channels, seq_length)
    >>> targets.shape  # (schunk, output_channels, seq_length)
    """

    def __init__(
        self,
        logger: Any,
        dfs: List[str],
        stime: int,
        tbatch: int,
        training: bool = True,
        sblock_perc: float = 0.6,
        norm_mapping: Optional[Dict] = None,
        normalization_type: Optional[Dict] = None,
        debug: bool = False,
    ) -> None:
        """
        Initialize the DataPreprocessor.

        Parameters
        ----------
        logger : Any
            Logger instance for logging messages.
        dfs : List[str]
            List of file paths to NetCDF files (``*_YYYY.nc``).
        stime : int
            Start time index inside each yearly file.
        tbatch : int
            Temporal batch size.
        training : bool, optional
            If True, use a random ``sblock_perc`` fraction of the spatial
            batches per temporal block (data augmentation). If False, use
            all spatial batches (full evaluation).
        sblock_perc : float, optional
            Fraction of spatial batches used when training.
        norm_mapping : Dict, optional
            Dictionary containing normalization statistics for each variable.
        normalization_type : Dict, optional
            Dictionary specifying normalization type for each variable.
        debug : bool, optional
            If True, log debug information.

        Raises
        ------
        ValueError
            If no file matches ``*_YYYY.nc``, if the years do not all have the
            same number of files, or if no complete temporal block fits into
            a yearly file.
        """
        self.logger = logger
        self.stime = stime
        self.tbatch = tbatch
        self.training = training
        # Fresh dicts per instance instead of shared mutable default arguments.
        self.norm_mapping = {} if norm_mapping is None else norm_mapping
        self.normalization_type = (
            {} if normalization_type is None else normalization_type
        )
        self.debug = debug
        self.sblock_perc = sblock_perc

        # Group files by the year encoded in their "_YYYY.nc" suffix.
        self.train_sbatch_files_by_year = defaultdict(list)
        for f in dfs:
            match = re.search(r"_(\d{4})\.nc$", f)
            if match:
                self.train_sbatch_files_by_year[int(match.group(1))].append(f)
            else:
                self.logger.info(f"Skipping file without a '_YYYY.nc' suffix: {f}")
        if not self.train_sbatch_files_by_year:
            raise ValueError("No input file matches the expected '*_YYYY.nc' pattern")

        self.years = sorted(self.train_sbatch_files_by_year.keys())
        self.year_to_index = {y: i for i, y in enumerate(self.years)}

        # Files are located by ``year_index * total_sbatch + sindex``, so every
        # year must provide the same number of spatial-batch files.
        files_per_year = {
            year: len(self.train_sbatch_files_by_year[year]) for year in self.years
        }
        if len(set(files_per_year.values())) != 1:
            raise ValueError(
                f"All years must have the same number of files, got {files_per_year}"
            )
        self.total_sbatch = files_per_year[self.years[0]]

        if self.training:
            # Training: a random subset of spatial batches (at least one, at
            # most all of them) per temporal block.
            self.sbatch = min(
                self.total_sbatch, max(1, int(self.total_sbatch * self.sblock_perc))
            )
            # Tracking for the random spatial mapping of the current block.
            self.last_tindex = -1
            self.current_spatial_mapping = None
        else:
            # Validation/Testing: use 100% of spatial batches.
            self.sbatch = self.total_sbatch

        # (year, spatial_index, path) for every file, year-major.
        self.dfs = [
            (year, sindex, path)
            for year in self.years
            for sindex, path in enumerate(sorted(self.train_sbatch_files_by_year[year]))
        ]

        # Minimum size of each dimension across all files; every sample is
        # cropped to these sizes so all samples share one shape.
        self.min_dims = {
            "time": np.inf,
            "dim_1": np.inf,
            "dim_2": np.inf,
            "dim_3": np.inf,
            "dim_4": np.inf,
        }
        for _, _, file_path in self.dfs:
            with xr.open_dataset(file_path) as ds:
                for dim in self.min_dims:
                    if dim in ds.sizes:
                        self.min_dims[dim] = min(self.min_dims[dim], ds.sizes[dim])
        for dim, size in self.min_dims.items():
            self.logger.info(f"Minimum {dim} across files: {size}")

        self.tstep = self.min_dims["time"]
        self.etime = self.tstep * len(self.years)

        # ``stime`` is an offset inside each yearly file (see __getitem__), so
        # each year holds this many complete temporal blocks.
        self.blocks_per_year = (self.tstep - self.stime) // self.tbatch
        if self.blocks_per_year <= 0:
            raise ValueError(
                f"Invalid blocks_per_year: {self.blocks_per_year}. "
                f"tstep={self.tstep}, stime={self.stime}, tbatch={self.tbatch}"
            )
        self.time_blocks = np.arange(self.blocks_per_year * len(self.years))

        # Define variable groups
        self.cosz = ["coszang"]  # Cosine of solar zenith angle
        self.lai = ["laieff_collim", "laieff_isotrop"]  # Leaf area index
        self.ssa = ["leaf_ssa", "leaf_psd"]  # Single scattering albedo
        self.rs = ["rs_surface_emu"]  # Surface reflectance
        self.ov = [
            "collim_alb",
            "collim_tran",
            "isotrop_alb",
            "isotrop_tran",
        ]  # Output variables

        self.sindex_tracker = []  # spatial index of every served sample
        self.tindex_tracker = []  # temporal block of every served sample

        self.logger.info(f"Time range: {self.stime} ... {self.etime}")
        self.logger.info(f"Time steps per file: {self.tstep}")
        self.logger.info(f"Temporal batch size: {self.tbatch}")
        self.logger.info(f"Spatial batch size: {self.sbatch}")
        self.logger.info(f"Time blocks: {self.time_blocks}")
        self.logger.info(f"Years: {self.years}")
        self.logger.info(f"Year to index: {self.year_to_index}")
        self.logger.info(
            f"Variable groups: {self.cosz}, {self.lai}, {self.ssa}, {self.rs}, {self.ov}"
        )
        self.logger.info(
            "The list of file info:\n"
            + "\n".join(f"{year}, {sindex}, {path}" for year, sindex, path in self.dfs)
        )
        # NOTE: seeds the global ``random`` module (used for spatial mappings),
        # which also affects any other code using ``random``.
        random.seed(42)  # Set a fixed seed for reproducibility

    def _get_random_spatial_mapping(self) -> List[int]:
        """
        Generate a random spatial mapping for training.

        Returns
        -------
        List[int]
            ``self.sbatch`` distinct spatial-batch indices drawn from
            ``range(self.total_sbatch)``.
        """
        return random.sample(range(self.total_sbatch), self.sbatch)

    def normalize(self, data: np.ndarray, var_name: str) -> np.ndarray:
        """
        Normalize data using the specified normalization method.

        Parameters
        ----------
        data : np.ndarray
            Input data array to normalize.
        var_name : str
            Name of the variable for which to retrieve normalization statistics.

        Returns
        -------
        np.ndarray
            Normalized data array.

        Raises
        ------
        KeyError
            If ``var_name`` (or a required statistic) is missing from
            ``norm_mapping``.
        ValueError
            If the normalization type is not supported.

        Notes
        -----
        Supported normalization types (default ``log1p_minmax``):
        - minmax: (x - min) / (max - min)
        - standard: (x - mean) / std
        - robust: (x - median) / IQR
        - log1p_minmax / log1p_standard / log1p_robust: same, applied to log1p(x)
        - sqrt_minmax / sqrt_standard / sqrt_robust: same, applied to
          sqrt(max(x, 0))
        """
        norm_type = self.normalization_type.get(var_name, "log1p_minmax")
        stats = self.norm_mapping[var_name]

        if self.debug:
            self.logger.info(
                f"Normalizing variable '{var_name}' using method '{norm_type}' with stats: {stats}"
            )

        if norm_type == "minmax":
            return (data - stats["vmin"]) / (stats["vmax"] - stats["vmin"])
        elif norm_type == "standard":
            return (data - stats["vmean"]) / stats["vstd"]
        elif norm_type == "robust":
            return (data - stats["median"]) / stats["iqr"]
        elif norm_type == "log1p_minmax":
            data = np.log1p(data)
            return (data - stats["log_min"]) / (stats["log_max"] - stats["log_min"])
        elif norm_type == "log1p_standard":
            data = np.log1p(data)
            return (data - stats["log_mean"]) / stats["log_std"]
        elif norm_type == "log1p_robust":
            data = np.log1p(data)
            return (data - stats["log_median"]) / stats["log_iqr"]
        elif norm_type == "sqrt_minmax":
            data = np.sqrt(np.clip(data, a_min=0, a_max=None))
            return (data - stats["sqrt_min"]) / (stats["sqrt_max"] - stats["sqrt_min"])
        elif norm_type == "sqrt_standard":
            data = np.sqrt(np.clip(data, a_min=0, a_max=None))
            return (data - stats["sqrt_mean"]) / stats["sqrt_std"]
        elif norm_type == "sqrt_robust":
            data = np.sqrt(np.clip(data, a_min=0, a_max=None))
            return (data - stats["sqrt_median"]) / stats["sqrt_iqr"]
        else:
            raise ValueError(
                f"Unsupported normalization type '{norm_type}' for variable '{var_name}'"
            )

    def __len__(self) -> int:
        """
        Get the total number of samples in the dataset.

        Returns
        -------
        int
            ``blocks_per_year * len(years) * sbatch``; this matches exactly the
            set of indices ``__getitem__`` can serve.
        """
        return self.blocks_per_year * len(self.years) * self.sbatch

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a sample from the dataset.

        Parameters
        ----------
        index : int
            Index of the sample to retrieve; negative indices count from the end.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            - **features** : float32 tensor of shape
              (dim_1, feature_channels, seq_length)
            - **targets** : float32 tensor of shape
              (dim_1, output_channels, seq_length)

        Raises
        ------
        IndexError
            If the index is out of range.

        Notes
        -----
        The index is split into a spatial part (``index % sbatch``) and a
        temporal block (``index // sbatch``); the block selects a year and a
        time index inside that year's file. In training mode the spatial part
        goes through a random mapping and a random offset in
        ``[0, tbatch)`` is added to the time index.

        **Feature processing details:**
        - COSZ: single value per point, repeated on every vertical level
        - LAI: (dim_1, n_pft, seq_len) - varies with vertical level
        - SSA: (n_bands × n_pft) values broadcast to every point and level
        - RS: (dim_1, n_bands × n_pft), repeated on every vertical level
        - Outputs: (dim_1, 4 × n_bands × n_pft, seq_len)
        """
        n_samples = len(self)
        # Support negative indices like a regular Python sequence.
        position = index + n_samples if index < 0 else index
        if not 0 <= position < n_samples:
            raise IndexError(f"Index {index} out of range [0, {n_samples})")

        tindex, path = self._locate_sample(position)

        if self.debug:
            seq_len, dim_1, n_pft, n_bands = self._dims()
            self.logger.info(
                f"Dimensions for processing:\n"
                f" |- sequence_length_dim: {seq_len}\n"
                f" |- dim_1: {dim_1}\n"
                f" |- n_pft: {n_pft}\n"
                f" |- n_bands: {n_bands}"
            )

        # Context manager so the file handle is released after every sample.
        with xr.open_dataset(path) as ds:
            features = self._build_features(ds, tindex)
            outputs = self._build_outputs(ds, tindex)

        # Both arrays are freshly allocated float32, so sharing memory is safe.
        return torch.from_numpy(features), torch.from_numpy(outputs)

    def _locate_sample(self, index: int) -> Tuple[int, str]:
        """Map an in-range sample index to (time index, file path); updates trackers."""
        index_spatial_mapping = index % self.sbatch
        tblock = index // self.sbatch
        year_index, local_tblock = divmod(tblock, self.blocks_per_year)

        # First time step of the temporal block inside the yearly file.
        tindex = local_tblock * self.tbatch + self.stime

        if self.training:
            # Draw a new random subset of spatial batches whenever the block
            # start time changes.
            if self.last_tindex != tindex:
                self.current_spatial_mapping = self._get_random_spatial_mapping()
                self.last_tindex = tindex
                if self.debug:
                    self.logger.info(
                        f"New spatial mapping for tindex {tindex}: {self.current_spatial_mapping}"
                    )
            sindex = self.current_spatial_mapping[index_spatial_mapping]
            # Random offset inside the temporal block (data augmentation);
            # stays below tstep because blocks_per_year accounts for stime.
            tindex += np.random.randint(self.tbatch)
        else:
            # For validation/testing: direct mapping.
            sindex = index_spatial_mapping

        self.tindex_tracker.append(tblock)
        self.sindex_tracker.append(sindex)

        # self.dfs stores total_sbatch files per year, so the stride between
        # years is total_sbatch (not sbatch, which is smaller when training).
        _, _, path = self.dfs[year_index * self.total_sbatch + sindex]

        if self.debug:
            self.logger.info("------------------- GET ITEM INFO -------------------")
            self.logger.info(
                f"\nTorch batch index: {index}\n"
                f"Spatial index before mapping: {index_spatial_mapping}, and Spatial index after mapping: {sindex}\n"
                f"Temporal block index: {tblock}\n"
                f"Year index: {year_index}\n"
                f"Local time block: {local_tblock}\n"
                f"Time index: {tindex}\n"
                f"Loading file: {path}"
            )
        return tindex, path

    def _dims(self) -> Tuple[int, int, int, int]:
        """Return (seq_len, dim_1, n_pft, n_bands) from ``min_dims``."""
        return (
            self.min_dims["dim_2"],  # vertical levels (sequence length)
            self.min_dims["dim_1"],  # spatial points
            self.min_dims["dim_3"],  # plant functional types
            self.min_dims["dim_4"],  # spectral bands
        )

    def _load_normalized(self, ds: "xr.Dataset", var_name: str, tindex: int) -> np.ndarray:
        """Read ``var_name`` at ``tindex``, cropped to ``min_dims``, and normalize it."""
        da = ds[var_name]
        # Crop every non-time dimension the variable has to the minimum size
        # found across files so all samples share the same shape.
        crop = {
            dim: slice(0, self.min_dims[dim])
            for dim in ("dim_1", "dim_2", "dim_3", "dim_4")
            if dim in da.dims
        }
        return self.normalize(da.isel(time=tindex, **crop).values, var_name)

    def _build_features(self, ds: "xr.Dataset", tindex: int) -> np.ndarray:
        """Assemble the normalized feature array (dim_1, n_features, seq_len)."""
        seq_len, dim_1, n_pft, n_bands = self._dims()
        n_band_pft = n_bands * n_pft
        # cosz: 1 channel per variable; lai: n_pft per variable;
        # ssa and rs: n_bands * n_pft per variable.
        n_features = (
            len(self.cosz)
            + len(self.lai) * n_pft
            + (len(self.ssa) + len(self.rs)) * n_band_pft
        )
        features = np.zeros([dim_1, n_features, seq_len], dtype=np.float32)
        f_idx = 0

        # 1. COSZ - expected layout (dim_1,); same value on every vertical level.
        for var_name in self.cosz:
            temp = self._load_normalized(ds, var_name, tindex)
            features[:, f_idx, :] = temp[:, np.newaxis]
            f_idx += 1

        # 2. LAI - expected layout (dim_3, dim_2, dim_1) -> (dim_1, n_pft, seq_len).
        for var_name in self.lai:
            temp = self._load_normalized(ds, var_name, tindex)
            features[:, f_idx : f_idx + n_pft, :] = temp.transpose(2, 0, 1)
            f_idx += n_pft

        # 3. SSA - expected layout (dim_4, dim_3); no spatial or vertical
        #    dependence, so broadcast to every point and level.
        for var_name in self.ssa:
            temp = self._load_normalized(ds, var_name, tindex)
            features[:, f_idx : f_idx + n_band_pft, :] = temp.reshape(1, n_band_pft, 1)
            f_idx += n_band_pft

        # 4. RS - expected layout (dim_4, dim_3, dim_1); same on every level.
        for var_name in self.rs:
            temp = self._load_normalized(ds, var_name, tindex)
            temp = temp.transpose(2, 0, 1).reshape(dim_1, n_band_pft)
            features[:, f_idx : f_idx + n_band_pft, :] = temp[:, :, np.newaxis]
            f_idx += n_band_pft

        assert f_idx == n_features, f"Feature mismatch: {f_idx} vs {n_features}"
        return features

    def _build_outputs(self, ds: "xr.Dataset", tindex: int) -> np.ndarray:
        """Assemble the normalized target array (dim_1, n_outputs, seq_len)."""
        seq_len, dim_1, n_pft, n_bands = self._dims()
        n_band_pft = n_bands * n_pft
        n_outputs = len(self.ov) * n_band_pft
        outputs = np.zeros([dim_1, n_outputs, seq_len], dtype=np.float32)
        o_idx = 0

        for var_name in self.ov:
            # Expected layout (dim_4, dim_2, dim_3, dim_1) -> (dim_1, dim_4, dim_3, dim_2).
            temp = self._load_normalized(ds, var_name, tindex)
            temp = temp.transpose(3, 0, 2, 1).reshape(dim_1, n_band_pft, seq_len)
            outputs[:, o_idx : o_idx + n_band_pft, :] = temp
            o_idx += n_band_pft

        assert o_idx == n_outputs, f"Output mismatch: {o_idx} vs {n_outputs}"
        return outputs