Skip to content

Commit

Permalink
addressing reviewer comments
Browse files — browse the repository at this point in the history
Branch information:
dkimpara committed Jan 16, 2025
1 parent 190afc9 commit 9b3f7c8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion config/test_cesm_ensemble.yml
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ model:
pad_lat: 48 # number of grids to pad on -90 and 90 deg lat

post_conf:
activate: True
activate: False
#this scaling maps your variables to the ERA5 units
#make sure to adjust if your timestep is not 6 hours
requires_scaling: True
Expand Down
7 changes: 4 additions & 3 deletions credit/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ def __init__(self, conf, predict_mode=False):
# DO NOT apply these weights during metrics computations, only on the loss during
self.w_var = None

self.ensemble_size = conf["trainer"]["ensemble_size"]
self.ensemble_size = conf["trainer"].get("ensemble_size", 1) # default value of 1 if not set

def __call__(self, pred, y, clim=None, transform=None, forecast_datetime=0):
if transform is not None:
pred = transform(pred)
y = transform(y)

# calculate ensemble mean, if ensemble_size=1, does nothing
pred = pred.view(y.shape[0], self.ensemble_size, *y.shape[1:]) #b, ensemble, c, t, lat, lon
pred = pred.mean(dim=1)
if self.ensemble_size > 1:
pred = pred.view(y.shape[0], self.ensemble_size, *y.shape[1:]) #b, ensemble, c, t, lat, lon
pred = pred.mean(dim=1)

# Get latitude and variable weights
w_lat = (
Expand Down
6 changes: 4 additions & 2 deletions credit/trainers/trainerERA5_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ def train_one_epoch(
# copies each sample in the batch ensemble_size number of times.
# if samples in the batch are ordered (x,y,z) then the result tensor is (x, x, ..., y, y, ..., z,z ...)
# WARNING: needs to be used with a loss that can handle x with b * ensemble_size samples and y with b samples
x = torch.repeat_interleave(x, conf["trainer"]["ensemble_size"], 0)
if conf["trainer"].get("ensemble_size", 1) > 1: # gets value if set, otherwise is 1
x = torch.repeat_interleave(x, conf["trainer"]["ensemble_size"], 0)

# single step predict
y_pred = self.model(x)
Expand Down Expand Up @@ -476,7 +477,8 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
# copies each sample in the batch ensemble_size number of times.
# if samples in the batch are ordered (x,y,z) then the result tensor is (x, x, ..., y, y, ..., z,z ...)
# WARNING: needs to be used with a loss that can handle x with b * ensemble_size samples and y with b samples
x = torch.repeat_interleave(x, conf["trainer"]["ensemble_size"], 0)
if conf["trainer"].get("ensemble_size", 1) > 1: # gets value if set, otherwise is 1
x = torch.repeat_interleave(x, conf["trainer"]["ensemble_size"], 0)

y_pred = self.model(x)

Expand Down

0 comments on commit 9b3f7c8

Please sign in to comment.