Skip to content

Commit

Permalink
bugfix variable weights in credit.loss
Browse files Browse the repository at this point in the history
  • Loading branch information
yingkaisha committed Nov 16, 2024
1 parent b843a9d commit 41bb000
Showing 1 changed file with 18 additions and 28 deletions.
46 changes: 18 additions & 28 deletions credit/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def latitude_weights(conf):
return L


def variable_weights(conf, channels, surface_channels, frames):
def variable_weights(conf, channels, frames):
"""Create variable-specific weights for different atmospheric
and surface channels.
Expand All @@ -382,7 +382,6 @@ def variable_weights(conf, channels, surface_channels, frames):
conf (dict): Configuration dictionary containing the
variable weights.
channels (int): Number of channels for atmospheric variables.
surface_channels (int): Number of channels for surface variables.
frames (int): Number of time frames.
Returns:
Expand All @@ -393,41 +392,25 @@ def variable_weights(conf, channels, surface_channels, frames):
varname_upper_air = conf["data"]["variables"]
varname_surface = conf["data"]["surface_variables"]
varname_diagnostics = conf["data"]["diagnostic_variables"]
# N_levels = conf['data']['levels']

# weights_UVTQ = torch.tensor([
# conf["loss"]["variable_weights"]["U"],
# conf["loss"]["variable_weights"]["V"],
# conf["loss"]["variable_weights"]["T"],
# conf["loss"]["variable_weights"]["Q"]
# ]).view(1, channels * frames, 1, 1)

weights_UVTQ = torch.tensor(
# surface + diag channels
N_channels_single = len(varname_surface) + len(varname_diagnostics)

weights_upper_air = torch.tensor(
[conf["loss"]["variable_weights"][var] for var in varname_upper_air]
).view(1, channels * frames, 1, 1)

# Load weights for SP, t2m, V500, U500, T500, Z500, Q500
# weights_sfc = torch.tensor([
# conf["loss"]["variable_weights"]["SP"],
# conf["loss"]["variable_weights"]["t2m"],
# conf["loss"]["variable_weights"]["V500"],
# conf["loss"]["variable_weights"]["U500"],
# conf["loss"]["variable_weights"]["T500"],
# conf["loss"]["variable_weights"]["Z500"],
# conf["loss"]["variable_weights"]["Q500"]
# ]).view(1, surface_channels, 1, 1)

weights_sfc = torch.tensor(

weights_single = torch.tensor(
[
conf["loss"]["variable_weights"][var]
for var in (varname_surface + varname_diagnostics)
]
).view(1, surface_channels, 1, 1)
).view(1, N_channels_single, 1, 1)

# Combine all weights along the color channel
variable_weights = torch.cat([weights_UVTQ, weights_sfc], dim=1)
var_weights = torch.cat([weights_upper_air, weights_single], dim=1)

return variable_weights
return var_weights


class VariableTotalLoss2D(torch.nn.Module):
Expand Down Expand Up @@ -471,18 +454,25 @@ def __init__(self, conf, validation=False):
logger.info("Using latitude weights in loss calculations")
self.lat_weights = latitude_weights(conf)[:, 10].unsqueeze(0).unsqueeze(-1)

# ------------------------------------------------------------- #
# variable weights
# order: upper air --> surface --> diagnostics
self.var_weights = None
if conf["loss"]["use_variable_weights"]:
logger.info("Using variable weights in loss calculations")

var_weights = [
value if isinstance(value, list) else [value]
for value in conf["loss"]["variable_weights"].values()
]

var_weights = np.array(
[item for sublist in var_weights for item in sublist]
)

self.var_weights = torch.from_numpy(var_weights)

# ------------------------------------------------------------- #

self.use_spectral_loss = conf["loss"]["use_spectral_loss"]
if self.use_spectral_loss:
self.spectral_lambda_reg = conf["loss"]["spectral_lambda_reg"]
Expand Down

0 comments on commit 41bb000

Please sign in to comment.