Skip to content

Commit

Permalink
add y_diag to y when computing metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
yingkaisha committed Dec 9, 2024
1 parent 8daf8ab commit 672bb82
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
7 changes: 7 additions & 0 deletions applications/rollout_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,13 @@ def predict(rank, world_size, conf, p):
# no y_surf
y = reshape_only(batch["y"]).to(device).float()

# adding diagnostic vars to y
if "y_diag" in batch:
y_diag_batch = (
batch["y_diag"].to(device).permute(0, 2, 1, 3, 4)
)
y = torch.cat((y, y_diag_batch), dim=1).to(device).float()

# -------------------------------------------------------------------------------------- #
# start prediction

Expand Down
9 changes: 8 additions & 1 deletion applications/rollout_to_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,14 @@ def predict(rank, world_size, conf, p):
else:
# no y_surf
y = reshape_only(batch["y"]).to(device).float()


# adding diagnostic vars to y
if "y_diag" in batch:
y_diag_batch = (
batch["y_diag"].to(device).permute(0, 2, 1, 3, 4)
)
y = torch.cat((y, y_diag_batch), dim=1).to(device).float()

# -------------------------------------------------------------------------------------- #
# start prediction

Expand Down

0 comments on commit 672bb82

Please sign in to comment.