Merge branch 'main' into feat/mlflow
khintz committed Jan 23, 2025
2 parents d88c23c + d2e26a9 commit d54e3d8
Showing 8 changed files with 84 additions and 11 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add support for MLFlow logging and metrics tracking. [\#77](https://github.com/mllam/neural-lam/pull/77)
@khintz

- Add support for multi-node training.
[\#103](https://github.com/mllam/neural-lam/pull/103) @simonkamuk @sadamov

### Fixed
- Only print on rank 0 to avoid duplicated print statements across processes.
[\#103](https://github.com/mllam/neural-lam/pull/103) @simonkamuk @sadamov

## [v0.3.0](https://github.com/mllam/neural-lam/releases/tag/v0.3.0)

This release introduces Datastores to represent input data from different sources (including zarr and numpy) while keeping graph generation within neural-lam.
Expand Down
30 changes: 30 additions & 0 deletions README.md
@@ -457,6 +457,36 @@ python -m neural_lam.train_model --model hi_lam_parallel --graph hierarchical ..
Checkpoint files for our models trained on the MEPS data are available upon request.
### High Performance Computing
The training script can be run on a cluster with multiple GPU nodes. Neural LAM is set up to use PyTorch Lightning's `DDP` backend for distributed training.
The code can be used on systems both with and without SLURM. If the cluster has multiple nodes, set the `--num_nodes` argument accordingly.
Using SLURM, the job can be started with `sbatch slurm_job.sh`, using a shell script like the following.
```bash
#!/bin/bash -l
#SBATCH --job-name=Neural-LAM
#SBATCH --time=24:00:00
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
#SBATCH --partition=normal
#SBATCH --mem=444G
#SBATCH --no-requeue
#SBATCH --exclusive
#SBATCH --output=lightning_logs/neurallam_out_%j.log
#SBATCH --error=lightning_logs/neurallam_err_%j.log

# Load necessary modules or activate environment, for example:
conda activate neural-lam

srun -ul python -m neural_lam.train_model \
--config_path /path/to/config.yaml \
--num_nodes $SLURM_JOB_NUM_NODES
```
When running on a system without SLURM, where all GPUs are visible, it is possible to select a subset of GPUs for training with the `--devices` CLI argument, e.g. `--devices 0 1` to use the first two GPUs, as in the example below.
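A minimal sketch of such a run (the config path is a placeholder for your own configuration) could look like:
```bash
# Train on the first two visible GPUs of a single machine, without SLURM.
# /path/to/config.yaml is a placeholder for your own configuration file.
python -m neural_lam.train_model \
    --config_path /path/to/config.yaml \
    --devices 0 1
```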
## Evaluate Models
Evaluation is also done using `python -m neural_lam.train_model --config_path <config-path>`, but with the `--eval` option.
Use `--eval val` to evaluate the model on the validation set and `--eval test` to evaluate on test data, as in the sketch below.
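As a sketch, assuming a trained checkpoint is loaded with the `--load` argument (check `python -m neural_lam.train_model --help` for the exact flags in your version), a test-set evaluation could look like:
```bash
# Evaluate a previously trained model on the test set.
# Both paths are placeholders; --load is assumed to accept a saved checkpoint.
python -m neural_lam.train_model \
    --config_path /path/to/config.yaml \
    --load /path/to/checkpoint.ckpt \
    --eval test
```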
9 changes: 5 additions & 4 deletions neural_lam/datastore/mdp.py
@@ -13,6 +13,7 @@
from numpy import ndarray

# Local
from ..utils import rank_zero_print
from .base import BaseRegularGridDatastore, CartesianGridShape


@@ -72,11 +73,11 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
self._ds.to_zarr(fp_ds)
self._n_boundary_points = n_boundary_points

print("The loaded datastore contains the following features:")
rank_zero_print("The loaded datastore contains the following features:")
for category in ["state", "forcing", "static"]:
if len(self.get_vars_names(category)) > 0:
var_names = self.get_vars_names(category)
print(f" {category:<8s}: {' '.join(var_names)}")
rank_zero_print(f" {category:<8s}: {' '.join(var_names)}")

# check that all three train/val/test splits are available
required_splits = ["train", "val", "test"]
@@ -87,12 +88,12 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
f"splits: {available_splits}"
)

print("With the following splits (over time):")
rank_zero_print("With the following splits (over time):")
for split in required_splits:
da_split = self._ds.splits.sel(split_name=split)
da_split_start = da_split.sel(split_part="start").load().item()
da_split_end = da_split.sel(split_part="end").load().item()
print(f" {split:<8s}: {da_split_start} to {da_split_end}")
rank_zero_print(f" {split:<8s}: {da_split_start} to {da_split_end}")

# find out the dimension order for the stacking to grid-index
dim_order = None
2 changes: 1 addition & 1 deletion neural_lam/models/base_graph_model.py
@@ -34,7 +34,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):

# Specify dimensions of data
self.num_mesh_nodes, _ = self.get_num_mesh()
print(
utils.rank_zero_print(
f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} "
f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)"
)
10 changes: 6 additions & 4 deletions neural_lam/models/base_hi_graph_model.py
@@ -27,19 +27,21 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
] # Needs as python list for later

# Print some useful info
print("Loaded hierarchical graph with structure:")
utils.rank_zero_print("Loaded hierarchical graph with structure:")
for level_index, level_mesh_size in enumerate(self.level_mesh_sizes):
same_level_edges = self.m2m_features[level_index].shape[0]
print(
utils.rank_zero_print(
f"level {level_index} - {level_mesh_size} nodes, "
f"{same_level_edges} same-level edges"
)

if level_index < (self.num_levels - 1):
up_edges = self.mesh_up_features[level_index].shape[0]
down_edges = self.mesh_down_features[level_index].shape[0]
print(f" {level_index}<->{level_index + 1}")
print(f" - {up_edges} up edges, {down_edges} down edges")
utils.rank_zero_print(f" {level_index}<->{level_index + 1}")
utils.rank_zero_print(
f" - {up_edges} up edges, {down_edges} down edges"
)
# Embedders
# Assume all levels have same static feature dimensionality
mesh_dim = self.mesh_static_features[0].shape[1]
2 changes: 1 addition & 1 deletion neural_lam/models/graph_lam.py
@@ -27,7 +27,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
# grid_dim from data + static + batch_static
mesh_dim = self.mesh_static_features.shape[1]
m2m_edges, m2m_dim = self.m2m_features.shape
print(
utils.rank_zero_print(
f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, "
f"m2g={self.m2g_edges}"
)
27 changes: 27 additions & 0 deletions neural_lam/train_model.py
@@ -50,6 +50,22 @@ def main(input_args=None):
default=4,
help="Number of workers in data loader (default: 4)",
)
parser.add_argument(
"--num_nodes",
type=int,
default=1,
help="Number of nodes to use in DDP (default: 1)",
)
parser.add_argument(
"--devices",
nargs="+",
type=str,
default=["auto"],
help="Devices to use for training. Can be the string 'auto' or a list "
"of integer id's corresponding to the desired devices, e.g. "
"'--devices 0 1'. Note that this cannot be used with SLURM, instead "
"set 'ntasks-per-node' in the slurm setup (default: auto)",
)
parser.add_argument(
"--epochs",
type=int,
@@ -257,6 +273,15 @@ def main(input_args=None):
else:
device_name = "cpu"

# Set devices to use
if args.devices == ["auto"]:
devices = "auto"
else:
try:
devices = [int(i) for i in args.devices]
except ValueError:
raise ValueError("devices should be 'auto' or a list of integers")

# Load model parameters, use new args for model
ModelClass = MODELS[args.model]
model = ModelClass(args, config=config, datastore=datastore)
@@ -286,6 +311,8 @@ def main(input_args=None):
deterministic=True,
strategy="ddp",
accelerator=device_name,
num_nodes=args.num_nodes,
devices=devices,
logger=training_logger,
log_every_n_steps=1,
callbacks=[checkpoint_callback],
8 changes: 7 additions & 1 deletion neural_lam/utils.py
@@ -7,6 +7,7 @@
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import MLFlowLogger, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
from torch import nn
from tueplots import bundles, figsizes

@@ -238,6 +239,11 @@ def fractional_plot_bundle(fraction):
)
return bundle

@rank_zero_only
def rank_zero_print(*args, **kwargs):
"""Print only from rank 0 process"""
print(*args, **kwargs)


def init_training_logger_metrics(training_logger, val_steps):
"""
@@ -257,7 +263,7 @@ def init_training_logger_metrics(training_logger, val_steps):
)


@pl.utilities.rank_zero.rank_zero_only
@rank_zero_only
def setup_training_logger(datastore, args, run_name):
"""
