Commit

only print on rank 0
Simon Kamuk Christiansen committed Jan 22, 2025
1 parent ded70b6 commit 2e469aa
Showing 5 changed files with 20 additions and 10 deletions.
9 changes: 5 additions & 4 deletions neural_lam/datastore/mdp.py
@@ -13,6 +13,7 @@
 from numpy import ndarray
 
 # Local
+from ..utils import rank_zero_print
 from .base import BaseRegularGridDatastore, CartesianGridShape
 
 
@@ -72,11 +73,11 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
             self._ds.to_zarr(fp_ds)
         self._n_boundary_points = n_boundary_points
 
-        print("The loaded datastore contains the following features:")
+        rank_zero_print("The loaded datastore contains the following features:")
         for category in ["state", "forcing", "static"]:
             if len(self.get_vars_names(category)) > 0:
                 var_names = self.get_vars_names(category)
-                print(f" {category:<8s}: {' '.join(var_names)}")
+                rank_zero_print(f" {category:<8s}: {' '.join(var_names)}")
 
         # check that all three train/val/test splits are available
         required_splits = ["train", "val", "test"]
@@ -87,12 +88,12 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
                 f"splits: {available_splits}"
             )
 
-        print("With the following splits (over time):")
+        rank_zero_print("With the following splits (over time):")
         for split in required_splits:
             da_split = self._ds.splits.sel(split_name=split)
             da_split_start = da_split.sel(split_part="start").load().item()
             da_split_end = da_split.sel(split_part="end").load().item()
-            print(f" {split:<8s}: {da_split_start} to {da_split_end}")
+            rank_zero_print(f" {split:<8s}: {da_split_start} to {da_split_end}")
 
         # find out the dimension order for the stacking to grid-index
         dim_order = None
2 changes: 1 addition & 1 deletion neural_lam/models/base_graph_model.py
@@ -34,7 +34,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
 
         # Specify dimensions of data
         self.num_mesh_nodes, _ = self.get_num_mesh()
-        print(
+        utils.rank_zero_print(
             f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} "
             f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)"
         )
10 changes: 6 additions & 4 deletions neural_lam/models/base_hi_graph_model.py
@@ -27,19 +27,21 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
         ]  # Needs as python list for later
 
         # Print some useful info
-        print("Loaded hierarchical graph with structure:")
+        utils.rank_zero_print("Loaded hierarchical graph with structure:")
         for level_index, level_mesh_size in enumerate(self.level_mesh_sizes):
             same_level_edges = self.m2m_features[level_index].shape[0]
-            print(
+            utils.rank_zero_print(
                 f"level {level_index} - {level_mesh_size} nodes, "
                 f"{same_level_edges} same-level edges"
             )
 
             if level_index < (self.num_levels - 1):
                 up_edges = self.mesh_up_features[level_index].shape[0]
                 down_edges = self.mesh_down_features[level_index].shape[0]
-                print(f" {level_index}<->{level_index + 1}")
-                print(f" - {up_edges} up edges, {down_edges} down edges")
+                utils.rank_zero_print(f" {level_index}<->{level_index + 1}")
+                utils.rank_zero_print(
+                    f" - {up_edges} up edges, {down_edges} down edges"
+                )
         # Embedders
         # Assume all levels have same static feature dimensionality
         mesh_dim = self.mesh_static_features[0].shape[1]
2 changes: 1 addition & 1 deletion neural_lam/models/graph_lam.py
@@ -27,7 +27,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
         # grid_dim from data + static + batch_static
         mesh_dim = self.mesh_static_features.shape[1]
         m2m_edges, m2m_dim = self.m2m_features.shape
-        print(
+        utils.rank_zero_print(
             f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, "
             f"m2g={self.m2g_edges}"
         )
7 changes: 7 additions & 0 deletions neural_lam/utils.py
@@ -4,6 +4,7 @@
 
 # Third-party
 import torch
+from pytorch_lightning.utilities import rank_zero_only
 from torch import nn
 from tueplots import bundles, figsizes
 
@@ -233,6 +234,12 @@ def fractional_plot_bundle(fraction):
     return bundle
 
 
+@rank_zero_only
+def rank_zero_print(*args, **kwargs):
+    """Print only from rank 0 process"""
+    print(*args, **kwargs)
+
+
 def init_wandb_metrics(wandb_logger, val_steps):
     """
     Set up wandb metrics to track
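
A note on the new helper: pytorch_lightning's rank_zero_only decorator wraps a function so that its body executes only in the process with global rank 0; on every other rank the call returns None without running. Under multi-GPU training (e.g. DDP), each process previously executed the print calls above, duplicating the log output once per GPU; after this commit each message appears once. A minimal sketch of the behavior, assuming pytorch_lightning is installed (the example call is illustrative, not part of the commit):

    # Minimal sketch: print helper that is a no-op on non-zero ranks.
    from pytorch_lightning.utilities import rank_zero_only

    @rank_zero_only
    def rank_zero_print(*args, **kwargs):
        """Print only from the rank-0 process."""
        print(*args, **kwargs)

    # In a single-process run the global rank is 0, so this prints normally.
    # When launched under a multi-process strategy (e.g. DDP on 4 GPUs),
    # only the rank-0 process executes the body; the others skip it and
    # the decorated call returns None.
    rank_zero_print("The loaded datastore contains the following features:")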
