Skip to content

Commit

Permalink
Restore mpi in rollout_to_netcdf.py
Browse files — browse the repository at this point in the history
  • Loading branch information
kanz76 committed Nov 24, 2024
1 parent c6787b6 commit 554c15f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions applications/rollout_to_netcdf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -26,6 +26,7 @@
# credit
from credit.models import load_model
from credit.seed import seed_everything
+ from credit.distributed import get_rank_info

from credit.data import (
concat_and_reshape,
Expand Down Expand Up @@ -951,11 +952,11 @@ def predict(rank, world_size, conf, p):
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

+ local_rank, world_rank, world_size = get_rank_info(conf["trainer"]["mode"])
+
with mp.Pool(num_cpus) as p:
if conf["predict"]["mode"] in ["fsdp", "ddp"]: # multi-gpu inference
- _ = predict(
- int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]), conf, p=p
- )
+ _ = predict(world_rank, world_size, conf, p=p)
else: # single device inference
_ = predict(0, 1, conf, p=p)

Expand Down

0 comments on commit 554c15f

Please sign in to comment.