
Add Rank-Zero Printing and Improve Wandb Initialization #16

Closed · wants to merge 14 commits
remove double wandb initialization
sadamov committed Jun 6, 2024
commit f2a818093d1b5a05db7363d91a839b2baf94d8b0
4 changes: 4 additions & 0 deletions neural_lam/models/ar_model.py
@@ -597,3 +597,7 @@ def on_load_checkpoint(self, checkpoint):
         if not self.restore_opt:
             opt = self.configure_optimizers()
             checkpoint["optimizer_states"] = [opt.state_dict()]
+
+    def on_run_end(self):
+        if self.trainer.is_global_zero:
+            wandb.save("neural_lam/data_config.yaml")
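For context on the guard: under multi-GPU (DDP) training every rank executes the same hooks, so an unguarded `wandb.save` would run once per process. A minimal sketch of the pattern, assuming a standard PyTorch Lightning setup (the class and hook name below are illustrative, not taken from the PR):

```python
import pytorch_lightning as pl
import wandb


class ExampleModel(pl.LightningModule):
    def on_train_end(self):
        # Only global rank 0 talks to wandb; the other ranks either have no
        # active wandb run or would upload the same file a second time.
        if self.trainer.is_global_zero:
            wandb.save("neural_lam/data_config.yaml")
```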
18 changes: 3 additions & 15 deletions neural_lam/utils.py
@@ -1,14 +1,13 @@
 # Standard library
 import os
-import shutil
 import random
+import shutil
 import time

 # Third-party
 import numpy as np
 import pytorch_lightning as pl
 import torch
-import wandb  # pylint: disable=wrong-import-order
 from pytorch_lightning.utilities import rank_zero_only
 from torch import nn
 from tueplots import bundles, figsizes
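The `rank_zero_only` import above is the usual building block for the rank-zero printing the PR title mentions. A minimal sketch of how it is typically applied (the helper name is an assumption, not necessarily what neural_lam defines):

```python
from pytorch_lightning.utilities import rank_zero_only


@rank_zero_only
def rank_zero_print(*args, **kwargs):
    """A print() that runs only on global rank 0 in distributed training."""
    print(*args, **kwargs)
```

Decorated this way, calls from any rank other than 0 become no-ops, so log lines appear once per job instead of once per GPU.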
@@ -134,7 +133,8 @@ def loads_file(fn):
     hierarchical = n_levels > 1  # Not just single level mesh graph

     # Load static edge features
-    m2m_features = loads_file("m2m_features.pt")  # List of (M_m2m[l], d_edge_f)
+    # List of (M_m2m[l], d_edge_f)
+    m2m_features = loads_file("m2m_features.pt")
     g2m_features = loads_file("g2m_features.pt")  # (M_g2m, d_edge_f)
     m2g_features = loads_file("m2g_features.pt")  # (M_m2g, d_edge_f)
@@ -288,24 +288,12 @@ def init_wandb(args):
             f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
             f"{time.strftime('%m_%d_%H_%M_%S')}-{random_int}"
         )
-        wandb.init(
-            name=run_name,
-            project=args.wandb_project,
-            config=args,
-        )
         logger = pl.loggers.WandbLogger(
             project=args.wandb_project,
             name=run_name,
             config=args,
         )
-        wandb.save("neural_lam/data_config.yaml")
     else:
-        wandb.init(
-            project=args.wandb_project,
-            config=args,
-            id=args.resume_run,
-            resume="must",
-        )
         logger = pl.loggers.WandbLogger(
             project=args.wandb_project,
             id=args.resume_run,
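Why removing the explicit `wandb.init` is safe: `pl.loggers.WandbLogger` calls `wandb.init` itself when the run is first accessed, so the extra call here created a second wandb run per training job. A minimal sketch of the logger-only pattern that remains after this commit, assuming the argument names as they appear in the diff (the wrapper function and the resume condition are illustrative):

```python
import pytorch_lightning as pl


def make_wandb_logger(args, run_name):
    """Create a single wandb run via WandbLogger; no direct wandb.init call."""
    if args.resume_run is None:
        # Fresh run: WandbLogger initializes wandb lazily with these settings
        return pl.loggers.WandbLogger(
            project=args.wandb_project, name=run_name, config=args
        )
    # Resuming: id= reattaches the logger to the existing wandb run
    return pl.loggers.WandbLogger(project=args.wandb_project, id=args.resume_run)
```

The `wandb.save("neural_lam/data_config.yaml")` call that used to live in `init_wandb` moves into the rank-zero-guarded hook in `ar_model.py`, where it runs once per job rather than alongside a duplicate `wandb.init`.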