Merge pull request #133 from NCAR/ksha
Bugfix on random seed inconsistencies and keyword `grad_max_norm`
djgagne authored Dec 4, 2024
2 parents 6814c8d + 6180dcb commit fc2f8d1
Showing 4 changed files with 16 additions and 14 deletions.
10 changes: 4 additions & 6 deletions applications/train.py
@@ -53,7 +53,7 @@
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


-def load_dataset_and_sampler(conf, files, world_size, rank, is_train, seed=42):
+def load_dataset_and_sampler(conf, files, world_size, rank, is_train):
"""
Load the dataset and sampler for training or validation.
@@ -63,13 +63,13 @@ def load_dataset_and_sampler(conf, files, world_size, rank, is_train, seed=42):
world_size (int): Number of processes participating in the job.
rank (int): Rank of the current process.
is_train (bool): Flag indicating whether the dataset is for training or validation.
-seed (int, optional): Seed for random number generation. Defaults to 42.
Returns:
tuple: A tuple containing the dataset and the distributed sampler.
"""

# convert $USER to the actual user name
+seed = conf["seed"]
conf["save_loc"] = os.path.expandvars(conf["save_loc"])

# number of previous lead time inputs
@@ -149,7 +149,6 @@ def load_dataset_and_sampler_zscore_only(
world_size,
rank,
is_train,
-seed=42,
):
"""
Load the Z-score only dataset and sampler for training or validation.
@@ -163,12 +162,11 @@ def load_dataset_and_sampler_zscore_only(
world_size (int): Number of processes participating in the job.
rank (int): Rank of the current process.
is_train (bool): Flag indicating whether the dataset is for training or validation.
-seed (int, optional): Seed for random number generation. Defaults to 42.
Returns:
tuple: A tuple containing the dataset and the distributed sampler.
"""

+seed = conf["seed"]
# --------------------------------------------------- #
# separate training set and validation set cases
if is_train:
@@ -460,7 +458,7 @@ def main(rank, world_size, conf, backend, trial=False):
torch.cuda.set_device(rank % torch.cuda.device_count())

# Config settings
-seed = 1000 if "seed" not in conf else conf["seed"]
+seed = conf["seed"]
seed_everything(seed)

train_batch_size = conf["trainer"]["train_batch_size"]
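The train.py changes drop both the hard-coded `seed=42` keyword default and the `seed = 1000` fallback in `main`, so every entry point now reads the seed from the config; a config that omits `seed` fails fast with a `KeyError` instead of silently training with a different seed than the data samplers use. For context, a minimal sketch of the kind of helper `seed_everything` typically is (the actual implementation imported by these scripts is not shown in this diff, so treat the body as an assumption):

```python
import os
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs.

    Sketch only: the seed_everything used by the training scripts may also
    set cuDNN/determinism flags that are not shown here.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# After this change the seed always comes from the config; there is no
# hard-coded 42 or 1000 fallback left in the entry points.
conf = {"seed": 1000}
seed_everything(conf["seed"])
```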
9 changes: 4 additions & 5 deletions applications/train_multistep.py
@@ -58,7 +58,6 @@ def load_dataset_and_sampler(
world_size,
rank,
is_train=True,
-seed=42,
):
"""
Load the dataset and sampler for training or validation.
@@ -72,11 +71,11 @@
world_size (int): Number of processes participating in the job.
rank (int): Rank of the current process.
is_train (bool): Flag indicating whether the dataset is for training or validation.
-seed (int, optional): Seed for random number generation. Defaults to 42.
Returns:
tuple: A tuple containing the dataset and the distributed sampler.
"""
+seed = conf["seed"]
# --------------------------------------------------- #
# separate training set and validation set cases
if is_train:
@@ -355,7 +354,7 @@ def main(rank, world_size, conf, backend, trial=False):
torch.cuda.set_device(rank % torch.cuda.device_count())

# Config settings
-seed = 1000 if "seed" not in conf else conf["seed"]
+seed = conf["seed"]
seed_everything(seed)

train_batch_size = conf["trainer"]["train_batch_size"]
@@ -722,8 +721,8 @@ def train(self, trial, conf):
# track hyperparameters and run metadata
config=conf,
)

-seed = 1000 if "seed" not in conf else conf["seed"]
+seed = conf["seed"]
seed_everything(seed)

local_rank, world_rank, world_size = get_rank_info(conf["trainer"]["mode"])
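train_multistep.py gets the same treatment: the sampler loader no longer accepts a `seed` argument and instead reads `conf["seed"]`, so the data sampler and `seed_everything` can no longer disagree. A rough sketch of how a config-driven seed typically feeds a `DistributedSampler` (the `TensorDataset` below is a hypothetical stand-in, not the ERA5 dataset class credit actually builds):

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

conf = {"seed": 1000}                  # every rank reads the same config value
world_size, rank, is_train = 4, 0, True

# Placeholder dataset standing in for the one built inside
# load_dataset_and_sampler.
dataset = TensorDataset(torch.arange(128, dtype=torch.float32).unsqueeze(1))

# With the seed taken from conf, all ranks shuffle with the same seed, so the
# per-epoch shards stay consistent across processes.
sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=is_train,
    seed=conf["seed"],
)
```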
6 changes: 4 additions & 2 deletions credit/trainers/trainerERA5_multistep_grad_accum.py
@@ -78,8 +78,9 @@ def train_one_epoch(
Returns:
dict: Dictionary containing training metrics and loss for the epoch.
"""

batches_per_epoch = conf["trainer"]["batches_per_epoch"]
+grad_max_norm = conf["trainer"]["grad_max_norm"]
amp = conf["trainer"]["amp"]
distributed = True if conf["trainer"]["mode"] in ["fsdp", "ddp"] else False
forecast_length = conf["data"]["forecast_len"]
@@ -312,7 +313,8 @@ def train_one_epoch(

if distributed:
torch.distributed.barrier()


+torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=grad_max_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
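This trainer now reads `grad_max_norm` from `conf["trainer"]` and clips with `torch.nn.utils.clip_grad_norm_` right before the optimizer step. For reference, a self-contained sketch of gradient clipping inside an AMP step (assumes a CUDA device; the model, data, and `grad_max_norm` value are placeholders, and the `scaler.unscale_` call follows PyTorch's AMP examples rather than the trainer code above):

```python
import torch

grad_max_norm = 1.0                       # stands in for conf["trainer"]["grad_max_norm"]
model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(16, 8, device="cuda")
y = torch.randn(16, 1, device="cuda")

with torch.cuda.amp.autocast():
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
# Unscale first so the norm is computed on true (unscaled) gradients, then
# clip to grad_max_norm before stepping the optimizer.
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_max_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```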
5 changes: 4 additions & 1 deletion credit/trainers/trainerERA5_v2.py
@@ -46,6 +46,7 @@ def train_one_epoch(
# training hyperparameters
batches_per_epoch = conf["trainer"]["batches_per_epoch"]
grad_accum_every = conf["trainer"]["grad_accum_every"]
+grad_max_norm = conf["trainer"]["grad_max_norm"]
forecast_len = conf["data"]["forecast_len"]
amp = conf["trainer"]["amp"]
distributed = True if conf["trainer"]["mode"] in ["fsdp", "ddp"] else False
@@ -121,6 +122,7 @@
results_dict = defaultdict(list)

for i, batch in batch_group_generator:

# training log
logs = {}
# loss
@@ -264,7 +266,8 @@ def train_one_epoch(

if distributed:
torch.distributed.barrier()


+torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=grad_max_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
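Taken together, the changed files now require `seed` at the top level of the config and `grad_max_norm` under `trainer`; neither key has a fallback anymore. A minimal conf dictionary covering just the keys referenced in these diffs (values are illustrative, not defaults shipped with the repository):

```python
# Keys the changed code reads; a missing "seed" or "grad_max_norm" now raises
# KeyError instead of falling back to a hard-coded default.
conf = {
    "seed": 1000,
    "save_loc": "$HOME/credit_runs",   # hypothetical path, expanded with expandvars
    "data": {
        "forecast_len": 0,
    },
    "trainer": {
        "mode": "ddp",                 # "fsdp" or "ddp" enables the distributed branch
        "amp": True,
        "train_batch_size": 1,
        "batches_per_epoch": 100,
        "grad_accum_every": 1,
        "grad_max_norm": 1.0,          # new required key, passed to clip_grad_norm_
    },
}
```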
