Climate-PINN Documentation

Comprehensive guide to using the Physics-Informed Neural Network for Climate Modeling

Installation

Setting up the Climate-PINN environment requires Python 3.12 and several dependencies. Follow these steps to get started:

# Clone the repository
git clone https://github.com/enzolvd/PINN_Climate.git
cd PINN_Climate

# Create and activate a conda environment
conda create -n climate_pinn python=3.12
conda activate climate_pinn

# Install dependencies
pip install -r requirements.txt
Note: This project requires PyTorch with CUDA support for GPU acceleration. If you need to install a specific version of PyTorch compatible with your CUDA version, visit the PyTorch installation page.
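To confirm that the installed PyTorch build can actually see your GPU, a quick check along these lines can help. It is a sketch, not a script shipped with the repository, and it relies only on standard PyTorch calls (torch.cuda.is_available, torch.version.cuda, torch.cuda.get_device_name).

```python
import importlib.util


def describe_torch_setup() -> str:
    """Return a one-line summary of the installed PyTorch / CUDA setup."""
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed"
    import torch  # imported lazily so the check also runs without PyTorch

    if torch.cuda.is_available():
        return (f"PyTorch {torch.__version__} with CUDA {torch.version.cuda} "
                f"on {torch.cuda.get_device_name(0)}")
    return f"PyTorch {torch.__version__} (CPU only, CUDA not available)"


print(describe_torch_setup())
```

If this reports "CPU only" on a machine with an NVIDIA GPU, the installed wheel most likely does not match your CUDA version.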

Data Preparation

The model expects ERA5 climate data in NetCDF format. You can download ERA5 data from the Copernicus Climate Data Store or, more conveniently, from WeatherBench, which distributes ERA5 fields already regridded to coarser resolutions.

Data should be organized in the following structure:

data/
└── era_5_data/
    ├── constants/
    │   └── constants.nc      # Contains orography, land-sea mask, and soil type
    ├── geopotential_500/
    ├── 2m_temperature/
    ├── temperature_850/
    ├── 10m_u_component_of_wind/
    └── 10m_v_component_of_wind/

Each variable directory should contain yearly NetCDF files (e.g., geopotential_500_1979_data.nc).
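Before training, it can save time to verify that the data folder matches this layout. The helper below is a hypothetical sketch (check_era5_layout is not part of the repository); the directory names are taken from the tree above, and it only checks that each variable directory contains at least one .nc file, not the exact file naming.

```python
from pathlib import Path

# Variable directories expected under era_5_data/ (from the tree above).
VARIABLE_DIRS = [
    "geopotential_500",
    "2m_temperature",
    "temperature_850",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
]


def check_era5_layout(root_dir: str) -> list[str]:
    """Return a list of problems found under root_dir (empty means OK)."""
    root = Path(root_dir)
    problems = []
    if not (root / "constants" / "constants.nc").is_file():
        problems.append("missing constants/constants.nc")
    for name in VARIABLE_DIRS:
        var_dir = root / name
        if not var_dir.is_dir():
            problems.append(f"missing directory {name}/")
        elif not any(var_dir.glob("*.nc")):
            problems.append(f"no .nc files in {name}/")
    return problems


if __name__ == "__main__":
    issues = check_era5_layout("./data/era_5_data")
    print("Layout OK" if not issues else "\n".join(issues))
```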

Dataset Module

The dataset module provides utilities for loading and preprocessing ERA5 data.

ERADataset Class

The ERADataset class handles loading and preprocessing ERA5 climate data:

from dataset import ERADataset

# Create a dataset for specific years
dataset = ERADataset(
    root_dir="./data/era_5_data",
    years=[1979, 1980, 1981],
    normalize=True
)

# Access data from the dataset
batch = dataset[0]
inputs = batch['input']        # Shape: [channels, height, width]
targets = batch['target']      # Shape: [channels, height, width]
coords = batch['coords']       # List of [lon, lat, time]
masks = batch['constant_masks'] # Constant masks (orography, land-sea, soil)

With normalize=True, the dataset normalizes the data and splits the variables into inputs, targets, and constants as follows:

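| Type      | Variables                                                         |
|-----------|-------------------------------------------------------------------|
| Inputs    | Geopotential at 500 hPa, Temperature at 850 hPa                   |
| Targets   | Temperature at 2 m, 10 m U wind component, 10 m V wind component  |
| Constants | Orography, Land-sea mask, Soil type                               |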
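The exact normalization ERADataset applies is not specified here. A common choice, assumed in this sketch, is per-variable z-scoring, with one mean and one standard deviation per channel pooled over time and space; keeping those statistics lets predictions be mapped back to physical units. The helper names are hypothetical.

```python
import numpy as np


def normalize_channels(data: np.ndarray, eps: float = 1e-8):
    """Z-score each channel of a [time, channels, height, width] array.

    Statistics are pooled over time and space (axes 0, 2, 3), giving one
    mean and one standard deviation per variable.
    """
    mean = data.mean(axis=(0, 2, 3), keepdims=True)
    std = data.std(axis=(0, 2, 3), keepdims=True)
    return (data - mean) / (std + eps), mean, std


def denormalize(normalized: np.ndarray, mean: np.ndarray, std: np.ndarray,
                eps: float = 1e-8) -> np.ndarray:
    """Map normalized values back to the original units."""
    return normalized * (std + eps) + mean


# Toy data: 4 timesteps of 2 channels with very different scales on a 32 x 64 grid.
rng = np.random.default_rng(0)
offsets = np.array([54000.0, 270.0]).reshape(1, 2, 1, 1)
raw = rng.normal(size=(4, 2, 32, 64)) * 50.0 + offsets

norm, mean, std = normalize_channels(raw)
# Each channel of `norm` now has mean close to 0 and std close to 1.
```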

load_dataset Function

The load_dataset function provides a convenient way to create training and validation splits:

from dataset import load_dataset
from torch.utils.data import DataLoader

# Load dataset with train/validation split
datasets = load_dataset(
    nb_file=10,               # Number of years to use
    train_val_split=0.8,      # 80% training, 20% validation
    year0=1979,               # Starting year
    root_dir="./data/era_5_data",
    normalize=True
)

# Create data loaders
train_loader = DataLoader(datasets['train'], batch_size=32, shuffle=True)
val_loader = DataLoader(datasets['val'], batch_size=32, shuffle=False)
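How load_dataset divides the data is not spelled out above. One plausible scheme, assumed in this hypothetical helper (split_years is not part of the repository), assigns whole years to each split, with the earliest years used for training.

```python
def split_years(year0: int, nb_file: int,
                train_val_split: float) -> tuple[list[int], list[int]]:
    """Hypothetical helper: split nb_file consecutive years into train/val lists.

    Assumes whole years are assigned to each split, earliest years first.
    """
    if nb_file < 2:
        raise ValueError("need at least two years to make both splits non-empty")
    years = list(range(year0, year0 + nb_file))
    # Round, then clamp so that both splits keep at least one year.
    n_train = min(max(round(nb_file * train_val_split), 1), nb_file - 1)
    return years[:n_train], years[n_train:]


train_years, val_years = split_years(year0=1979, nb_file=10, train_val_split=0.8)
print(len(train_years), val_years)  # 8 [1987, 1988]
```

A chronological split like this keeps the validation years entirely unseen during training, whereas a random split of timesteps would place highly correlated neighboring hours on both sides.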

Training

The project supports several training workflows, from single experiments to distributed training on SLURM clusters.

Single Experiment

To train a single model, run train.py and pass the hyperparameters as command-line flags:

python train.py \
    --model=model_2 \
    --experiment_name=climate_run_1 \
    --wandb_project=climate_pinn \
    --hidden_dim=64 \
    --initial_re=100.0 \
    --nb_years=10 \
    --train_val_split=0.8 \
    --batch_size=128 \
    --epochs=100 \
    --learning_rate=1e-3 \
    --physics_weight=0.5 \
    --data_weight=1.0 \
    --data_dir=./data/era_5_data

Training runs are tracked with Weights & Biases, which logs:

  • Training and validation losses
  • Physics constraint residuals
  • Reynolds number evolution
  • Prediction visualizations
  • Model checkpoints

Visualization Tools

The repository includes comprehensive visualization tools for model predictions.

Video Generation

To generate visualizations of model predictions:

python video_gen.py

This creates visualizations for every run listed in the runs variable at the bottom of video_gen.py, including:

  • Temperature field predictions vs. ground truth
  • Temperature prediction error maps
  • Wind field predictions vs. ground truth
  • Wind field prediction error maps

You can customize the visualization by modifying parameters at the bottom of video_gen.py:

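if __name__ == "__main__":
    runs = ['run_1', 'run_3', 'run_4', 'run_8', 'run_9']  # runs to visualize
    fps = 48        # frames per second of the output video
    year = 2000     # year of ERA5 data to visualize
    duration = 20   # video length

    # Render one video per run, with a progress bar over the runs.
    for run in tqdm(runs):
        visualize_predictions(run, year, fps=fps, duration=duration)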
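With these values, each video has fps × duration = 48 × 20 = 960 frames, assuming duration is in seconds. How video_gen.py maps frames to timesteps of the year is not stated here; one hypothetical approach is to sample the timesteps evenly, as in the illustrative helper frame_timesteps below (not part of the repository). The example assumes hourly data, so the leap year 2000 has 366 × 24 = 8784 timesteps.

```python
def frame_timesteps(n_available: int, fps: int, duration: int) -> list[int]:
    """Pick fps * duration evenly spaced timestep indices out of n_available."""
    n_frames = min(fps * duration, n_available)
    if n_frames <= 1:
        return list(range(n_frames))
    # Spacing that places the first frame at index 0 and the last at n_available - 1.
    step = (n_available - 1) / (n_frames - 1)
    return [round(i * step) for i in range(n_frames)]


steps = frame_timesteps(8784, fps=48, duration=20)
print(len(steps), steps[0], steps[-1])  # 960 0 8783
```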

Static Image Generation

For static comparison images at specific timesteps, use image_gen.py:

python image_gen.py

This script generates PDF files showing side-by-side comparisons of temperature and wind predictions vs. ground truth at specified timesteps.
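A side-by-side comparison of this kind can be sketched with Matplotlib as follows. save_comparison is a hypothetical helper, not the code in image_gen.py; it expects 2D [lat, lon] arrays and uses origin="lower", which assumes latitude increases with the row index.

```python
import matplotlib

matplotlib.use("Agg")  # render without a display
import matplotlib.pyplot as plt
import numpy as np


def save_comparison(prediction: np.ndarray, truth: np.ndarray,
                    path: str, title: str) -> None:
    """Save prediction, ground truth and their difference side by side."""
    error = prediction - truth
    # Shared color scale for prediction and truth so the panels are comparable.
    vmin = min(prediction.min(), truth.min())
    vmax = max(prediction.max(), truth.max())
    # Symmetric limits so that zero error sits at the center of the colormap.
    err_lim = float(np.abs(error).max()) or 1.0

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    panels = [
        (prediction, "Prediction", "viridis", vmin, vmax),
        (truth, "Ground truth", "viridis", vmin, vmax),
        (error, "Error (prediction - truth)", "RdBu_r", -err_lim, err_lim),
    ]
    for ax, (field, label, cmap, lo, hi) in zip(axes, panels):
        im = ax.imshow(field, origin="lower", cmap=cmap, vmin=lo, vmax=hi)
        ax.set_title(label)
        fig.colorbar(im, ax=ax, shrink=0.8)
    fig.suptitle(title)
    fig.savefig(path, bbox_inches="tight")  # the .pdf extension selects PDF output
    plt.close(fig)
```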

Uncertainty Visualization with MC Dropout

For models with dropout (e.g., model_2), you can visualize prediction uncertainty using Monte Carlo dropout:

python uncertainty.py run_8 2000 30

Arguments:

  1. Run name (e.g., run_8)
  2. Year to analyze (e.g., 2000)
  3. Number of MC samples (e.g., 30)

This generates uncertainty visualizations including:

  • Static uncertainty maps at specific timesteps
  • Animated visualizations of mean predictions with confidence intervals
  • Spatially-resolved uncertainty maps
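Monte Carlo dropout keeps dropout active at inference time and treats the spread of repeated stochastic forward passes as an uncertainty estimate. The PyTorch sketch below shows the idea on a toy network; mc_dropout_predict is a hypothetical helper, and it assumes the model can be called as model(x), which may differ from model_2's actual forward signature. Only the dropout layers are switched back to training mode, so layers such as batch normalization keep their inference behavior.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, n_samples: int = 30):
    """Return the mean and std of n_samples stochastic forward passes."""
    model.eval()
    # Re-enable dropout only; every other layer stays in eval mode.
    for module in model.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)):
            module.train()
    samples = torch.stack([model(x) for _ in range(n_samples)])
    model.eval()  # switch dropout off again for normal use
    return samples.mean(dim=0), samples.std(dim=0)


# Toy usage: 8 inputs with 3 features, 2 outputs.
toy = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Dropout(p=0.2), nn.Linear(16, 2))
mean, std = mc_dropout_predict(toy, torch.randn(8, 3), n_samples=30)
print(mean.shape, std.shape)  # torch.Size([8, 2]) torch.Size([8, 2])
```

The mean gives the prediction and the standard deviation gives a per-pixel uncertainty map; more samples make both estimates less noisy at a proportional cost in compute.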

Physics Constraints

The PINN is constrained by fluid dynamics principles from the Navier-Stokes equations.

Reynolds Number

The Reynolds number is a dimensionless quantity that represents the ratio of inertial forces to viscous forces within a fluid. In the PINN models, it is treated as a learnable parameter.
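For reference, the Reynolds number is conventionally defined from a characteristic velocity $U$, a length scale $L$, and the kinematic viscosity $\nu$. In the non-dimensional incompressible Navier-Stokes equations it appears as the inverse coefficient of the viscous term, with $\mathbf{u}$ the velocity and $p$ the non-dimensional pressure. This is the textbook form; whether the models' residuals use exactly these equations is an assumption here.

$$
\mathrm{Re} = \frac{U L}{\nu}
$$

$$
\frac{\partial \mathbf{u}}{\partial t} + (\mathbf{u} \cdot \nabla)\mathbf{u} = -\nabla p + \frac{1}{\mathrm{Re}} \nabla^{2} \mathbf{u}, \qquad \nabla \cdot \mathbf{u} = 0
$$

A larger Re therefore weakens the viscous (diffusion) term relative to advection.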

Different model variants handle the Reynolds number differently:

  • model_0: Basic implementation with unconstrained Reynolds number
  • model_0_Re: Introduces clamping and momentum-based smoothing
  • model_3: Uses a neural network to predict spatially-varying Reynolds numbers

Example implementation from model_2:

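import math
import torch

def get_reynolds_number(self):
    # Re is learned in log space (self.log_re) so it always stays positive.
    # Clamp it to [50, 1e5] before exponentiating.
    clamped_log_re = torch.clamp(self.log_re,
                                 min=math.log(50.0),
                                 max=math.log(1e5))
    current_re = torch.exp(clamped_log_re)

    # First call: no history yet, so the average starts at the current value.
    if self.previous_re is None:
        self.previous_re = current_re

    # Exponential moving average: re_momentum sets how much of the previous
    # (detached) value is kept, which damps sudden jumps in Re.
    smoothed_re = self.re_momentum * self.previous_re + (1 - self.re_momentum) * current_re
    self.previous_re = smoothed_re.detach()

    return smoothed_re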
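To see what the clamp-then-average logic of get_reynolds_number does numerically, the sketch below replays it on plain floats. It is an illustration, not repository code, and momentum=0.9 is an assumed value, since the default of re_momentum is not given on this page.

```python
import math


def smoothed_re_sequence(raw_values, momentum=0.9, re_min=50.0, re_max=1e5):
    """Replay the clamp-in-log-space + EMA update on a sequence of raw Re values."""
    previous = None
    out = []
    for value in raw_values:
        # Clamp in log space, as get_reynolds_number does.
        log_re = min(max(math.log(value), math.log(re_min)), math.log(re_max))
        current = math.exp(log_re)
        if previous is None:
            previous = current
        smoothed = momentum * previous + (1 - momentum) * current
        previous = smoothed
        out.append(smoothed)
    return out


print([round(v, 1) for v in smoothed_re_sequence([100.0, 1000.0, 1e7])])
# [100.0, 190.0, 10171.0]
```

The jump from 100 to 1000 moves the smoothed value only to 0.9 × 100 + 0.1 × 1000 = 190, and the raw value of 1e7 is first capped at 1e5 before being averaged in.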

Hyperparameters

Key hyperparameters for model training:

| Parameter       | Description                          | Default  |
|-----------------|--------------------------------------|----------|
| model           | Model variant to use                 | Required |
| hidden_dim      | Dimension of hidden layers           | 64       |
| initial_re      | Initial Reynolds number              | 100.0    |
| physics_weight  | Weight of physics loss terms         | 0.5      |
| data_weight     | Weight of data loss terms            | 1.0      |
| batch_size      | Training batch size                  | 128      |
| learning_rate   | Initial learning rate                | 1e-3     |
| nb_years        | Number of years to use for training  | 10       |
| train_val_split | Fraction of the data used for training | 0.8    |
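To make these defaults concrete, here is a hypothetical argparse parser that mirrors the table. It is not train.py's actual parser (which also accepts flags such as --experiment_name, --epochs and --data_dir); only the flag names and default values come from this page.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the hyperparameter table above."""
    parser = argparse.ArgumentParser(description="Train a Climate-PINN model")
    parser.add_argument("--model", required=True, help="model variant, e.g. model_2")
    parser.add_argument("--hidden_dim", type=int, default=64)
    parser.add_argument("--initial_re", type=float, default=100.0)
    parser.add_argument("--physics_weight", type=float, default=0.5)
    parser.add_argument("--data_weight", type=float, default=1.0)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--nb_years", type=int, default=10)
    parser.add_argument("--train_val_split", type=float, default=0.8)
    return parser


# Override only what differs from the defaults, e.g. physics_weight for model_3.
args = build_parser().parse_args(["--model=model_3", "--physics_weight=1.0"])
print(args.model, args.physics_weight, args.hidden_dim)  # model_3 1.0 64
```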

Recommended parameter combinations for different model variants:

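| Model      | hidden_dim | physics_weight | learning_rate |
|------------|------------|----------------|---------------|
| model_0    | 64         | 1.0            | 1e-3          |
| model_0_Re | 64         | 1.0            | 1e-3          |
| model_2    | 64         | 0.5            | 1e-3          |
| model_3    | 64         | 1.0            | 1e-3          |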

Frequently Asked Questions

How do I get the pre-trained model weights?

Pre-trained model weights are available on Hugging Face. You can download them using:

# Using git lfs
git lfs install
git clone https://huggingface.co/enzolouv/PINN_Climate

# Or download directly from the web interface
# https://huggingface.co/enzolouv/PINN_Climate

What's the difference between model variants?

The repository includes several model variants with different architectures:

  • model_0: Original baseline model
  • model_0_Re: Enhanced with clipped gradient and momentum on the Reynolds number
  • model_1: Model with modified dropout placement (suboptimal)
  • model_2: Model with correctly placed dropout and improved Reynolds number handling
  • model_3: Advanced model with a neural network for Reynolds number estimation

For most applications, model_2 or model_3 are recommended as they provide the best performance.

How do I generate custom visualizations?

The visualization scripts can be customized for specific models, years, or parameters:

# For video generation
python video_gen.py

# For static images
python image_gen.py

# For uncertainty visualization
python uncertainty.py run_name year num_samples

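For video_gen.py and image_gen.py, edit the parameters at the bottom of the script; uncertainty.py takes the run name, year, and number of MC samples as command-line arguments.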

How do physics constraints work in the model?

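Physics constraints are incorporated into the loss function as residuals of the fluid dynamics equations. The model is trained to minimize a weighted sum of a data-driven loss (MSE between predictions and true values, weighted by data_weight) and a physics-based loss (residuals of the Navier-Stokes equations, weighted by physics_weight). Because the physics terms are penalties rather than hard constraints, they encourage, but do not guarantee, predictions that respect the governing equations.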
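The weighted objective can be sketched with NumPy as follows. This is an illustration under stated assumptions, not the repository's loss: the physics term here is the squared 2D continuity residual ∂u/∂x + ∂v/∂y computed with finite differences on a flat, regularly spaced grid, whereas a PINN typically differentiates the network with respect to its input coordinates via autograd, and the models' actual Navier-Stokes residuals are not reproduced here. data_weight and physics_weight play the role of the training flags with the same names.

```python
import numpy as np


def divergence(u: np.ndarray, v: np.ndarray, dx: float, dy: float) -> np.ndarray:
    """Finite-difference du/dx + dv/dy on a regular [lat, lon] grid."""
    du_dx = np.gradient(u, dx, axis=1)  # x (longitude) varies along axis 1
    dv_dy = np.gradient(v, dy, axis=0)  # y (latitude) varies along axis 0
    return du_dx + dv_dy


def total_loss(pred, target, u, v, dx, dy, data_weight=1.0, physics_weight=0.5):
    """Weighted sum of a data MSE and the mean squared continuity residual."""
    data_loss = np.mean((pred - target) ** 2)
    physics_loss = np.mean(divergence(u, v, dx, dy) ** 2)
    return data_weight * data_loss + physics_weight * physics_loss
```

Raising physics_weight shifts the balance toward fields with smaller residuals, possibly at the cost of a larger data misfit.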

Acknowledgments

This project builds upon several key resources: