Skip to content

Commit

Permalink
Merge branch 'main' into fix_rollout_netcdf
Browse files Browse the repository at this point in the history
  • Loading branch information
kanz76 committed Nov 24, 2024
2 parents 554c15f + 6fe8c7b commit 1c3d77d
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 44 deletions.
18 changes: 17 additions & 1 deletion applications/rollout_to_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,16 @@ def predict(rank, world_size, conf, p):
+ len(conf["data"]["forcing_variables"])
+ len(conf["data"]["static_variables"])
)


# ------------------------------------------------------- #
# clamp to remove outliers
if conf["data"]["data_clamp"] is None:
flag_clamp = False
else:
flag_clamp = True
clamp_min = float(conf["data"]["data_clamp"][0])
clamp_max = float(conf["data"]["data_clamp"][1])

# ====================================================== #
# postblock opts outside of model
post_conf = conf["model"]["post_conf"]
Expand Down Expand Up @@ -647,6 +656,13 @@ def predict(rank, world_size, conf, p):

# -------------------------------------------------------------------------------------- #
# start prediction

# --------------------------------------------- #
# clamp
if flag_clamp:
x = torch.clamp(x, min=clamp_min, max=clamp_max)
#y = torch.clamp(y, min=clamp_min, max=clamp_max)

y_pred = model(x)

# ============================================= #
Expand Down
6 changes: 2 additions & 4 deletions config/example_physics_single.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# --------------------------------------------------------------------------------------------------------------------- #
# This yaml file implements 6 hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
# the FuXi architecture has been modified to reduce the overall model size
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
# example
# --------------------------------------------------------------------------------------------------------------------- #
save_loc: '/glade/work/$USER/CREDIT_runs/fuxi_conserve/'
seed: 1000
Expand Down Expand Up @@ -47,6 +44,7 @@ data:

# data workflow
scaler_type: 'std_new'
data_clamp: [-16, 16]

# number of input states
# FuXi has 2 input states
Expand Down
46 changes: 18 additions & 28 deletions credit/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def latitude_weights(conf):
return L


def variable_weights(conf, channels, surface_channels, frames):
def variable_weights(conf, channels, frames):
"""Create variable-specific weights for different atmospheric
and surface channels.
Expand All @@ -382,7 +382,6 @@ def variable_weights(conf, channels, surface_channels, frames):
conf (dict): Configuration dictionary containing the
variable weights.
channels (int): Number of channels for atmospheric variables.
surface_channels (int): Number of channels for surface variables.
frames (int): Number of time frames.
Returns:
Expand All @@ -393,41 +392,25 @@ def variable_weights(conf, channels, surface_channels, frames):
varname_upper_air = conf["data"]["variables"]
varname_surface = conf["data"]["surface_variables"]
varname_diagnostics = conf["data"]["diagnostic_variables"]
# N_levels = conf['data']['levels']

# weights_UVTQ = torch.tensor([
# conf["loss"]["variable_weights"]["U"],
# conf["loss"]["variable_weights"]["V"],
# conf["loss"]["variable_weights"]["T"],
# conf["loss"]["variable_weights"]["Q"]
# ]).view(1, channels * frames, 1, 1)

weights_UVTQ = torch.tensor(
# surface + diag channels
N_channels_single = len(varname_surface) + len(varname_diagnostics)

weights_upper_air = torch.tensor(
[conf["loss"]["variable_weights"][var] for var in varname_upper_air]
).view(1, channels * frames, 1, 1)

# Load weights for SP, t2m, V500, U500, T500, Z500, Q500
# weights_sfc = torch.tensor([
# conf["loss"]["variable_weights"]["SP"],
# conf["loss"]["variable_weights"]["t2m"],
# conf["loss"]["variable_weights"]["V500"],
# conf["loss"]["variable_weights"]["U500"],
# conf["loss"]["variable_weights"]["T500"],
# conf["loss"]["variable_weights"]["Z500"],
# conf["loss"]["variable_weights"]["Q500"]
# ]).view(1, surface_channels, 1, 1)

weights_sfc = torch.tensor(

weights_single = torch.tensor(
[
conf["loss"]["variable_weights"][var]
for var in (varname_surface + varname_diagnostics)
]
).view(1, surface_channels, 1, 1)
).view(1, N_channels_single, 1, 1)

# Combine all weights along the color channel
variable_weights = torch.cat([weights_UVTQ, weights_sfc], dim=1)
var_weights = torch.cat([weights_upper_air, weights_single], dim=1)

return variable_weights
return var_weights


class VariableTotalLoss2D(torch.nn.Module):
Expand Down Expand Up @@ -471,18 +454,25 @@ def __init__(self, conf, validation=False):
logger.info("Using latitude weights in loss calculations")
self.lat_weights = latitude_weights(conf)[:, 10].unsqueeze(0).unsqueeze(-1)

# ------------------------------------------------------------- #
# variable weights
# order: upper air --> surface --> diagnostics
self.var_weights = None
if conf["loss"]["use_variable_weights"]:
logger.info("Using variable weights in loss calculations")

var_weights = [
value if isinstance(value, list) else [value]
for value in conf["loss"]["variable_weights"].values()
]

var_weights = np.array(
[item for sublist in var_weights for item in sublist]
)

self.var_weights = torch.from_numpy(var_weights)

# ------------------------------------------------------------- #

self.use_spectral_loss = conf["loss"]["use_spectral_loss"]
if self.use_spectral_loss:
self.spectral_lambda_reg = conf["loss"]["spectral_lambda_reg"]
Expand Down
45 changes: 39 additions & 6 deletions credit/models/fuxi.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,15 @@ class UTransformer(nn.Module):
"""

def __init__(
self, embed_dim, num_groups, input_resolution, num_heads, window_size, depth
self, embed_dim,
num_groups,
input_resolution,
num_heads,
window_size,
depth,
proj_drop,
attn_drop,
drop_path
):
super().__init__()
num_groups = to_2tuple(num_groups)
Expand All @@ -248,7 +256,15 @@ def __init__(

# SwinT block
self.layer = SwinTransformerV2Stage(
embed_dim, embed_dim, input_resolution, depth, num_heads, window_size[0]
embed_dim,
embed_dim,
input_resolution,
depth,
num_heads,
window_size[0],
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path
) # <--- window_size[0] get window_size[int] from tuple

# up-sampling block
Expand Down Expand Up @@ -315,6 +331,9 @@ def __init__(
window_size=7,
use_spectral_norm=True,
interp=True,
proj_drop=0,
attn_drop=0,
drop_path=0,
padding_conf=None,
post_conf=None,
**kwargs,
Expand All @@ -323,11 +342,15 @@ def __init__(

self.use_interp = interp
self.use_spectral_norm = use_spectral_norm

if padding_conf is None:
padding_conf = {"activate": False}

self.use_padding = padding_conf["activate"]

if post_conf is None:
post_conf = {"activate": False}

self.use_post_block = post_conf["activate"]

# input tensor size (time, lat, lon)
Expand Down Expand Up @@ -362,8 +385,17 @@ def __init__(
self.cube_embedding = CubeEmbedding(img_size, patch_size, in_chans, dim)

# Downsampling --> SwinTransformerV2 stacks --> Upsampling
logger.info(f"Define UTransforme with proj_drop={proj_drop}, attn_drop={attn_drop}, drop_path={drop_path}")

self.u_transformer = UTransformer(
dim, num_groups, input_resolution, num_heads, window_size, depth=depth
dim, num_groups,
input_resolution,
num_heads,
window_size,
depth=depth,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path
)

# dense layer applied on channel dmension
Expand All @@ -383,11 +415,12 @@ def __init__(
if self.use_padding:
self.padding_opt = TensorPadding(**padding_conf)

# Move the model to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(device)

if self.use_spectral_norm:
logger.info("Adding spectral norm to all conv and linear layers")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to the device
self.to(device)
apply_spectral_norm(self)

if self.use_post_block:
Expand Down
47 changes: 43 additions & 4 deletions credit/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ def credit_main_parser(
)

## I/O data sizes

conf["data"].setdefault("data_clamp", None)

if parse_training:
assert (
"train_years" in conf["data"]
Expand All @@ -264,7 +267,7 @@ def credit_main_parser(
assert (
"forecast_len" in conf["data"]
), "Number of time frames for loss compute ('forecast_len') is missing from conf['data']"

if "valid_history_len" not in conf["data"]:
# use "history_len" for "valid_history_len"
conf["data"]["valid_history_len"] = conf["data"]["history_len"]
Expand Down Expand Up @@ -420,8 +423,8 @@ def credit_main_parser(
)

# # debug only
# conf['model']['post_conf']['varname_input'] = varname_input
# conf['model']['post_conf']['varname_output'] = varname_output
conf['model']['post_conf']['varname_input'] = varname_input
conf['model']['post_conf']['varname_output'] = varname_output
# --------------------------------------------------------------------- #

# SKEBS
Expand Down Expand Up @@ -478,6 +481,7 @@ def credit_main_parser(
conf["model"]["post_conf"]["global_mass_fixer"].setdefault("denorm", True)
conf["model"]["post_conf"]["global_mass_fixer"].setdefault("simple_demo", False)
conf["model"]["post_conf"]["global_mass_fixer"].setdefault("midpoint", False)
conf['model']['post_conf']['global_mass_fixer'].setdefault('grid_type', 'pressure')

assert (
"fix_level_num" in conf["model"]["post_conf"]["global_mass_fixer"]
Expand All @@ -487,7 +491,11 @@ def credit_main_parser(
assert (
"lon_lat_level_name" in conf["model"]["post_conf"]["global_mass_fixer"]
), "Must specifiy var names for lat/lon/level in physics reference file"


if conf['model']['post_conf']['global_mass_fixer']['grid_type'] == 'sigma':
assert 'surface_pressure_name' in conf['model']['post_conf']['global_mass_fixer'], (
'Must specifiy surface pressure var name when using hybrid sigma-pressure coordinates')

q_inds = [
i_var
for i_var, var in enumerate(varname_output)
Expand All @@ -498,6 +506,13 @@ def credit_main_parser(
]
conf["model"]["post_conf"]["global_mass_fixer"]["q_inds"] = q_inds

if conf['model']['post_conf']['global_mass_fixer']['grid_type'] == 'sigma':
sp_inds = [
i_var for i_var, var in enumerate(varname_output)
if var in conf['model']['post_conf']['global_mass_fixer']['surface_pressure_name']
]
conf['model']['post_conf']['global_mass_fixer']['sp_inds'] = sp_inds[0]

# --------------------------------------------------------------------- #
# global water fixer
flag_water = (
Expand All @@ -518,12 +533,17 @@ def credit_main_parser(
"simple_demo", False
)
conf["model"]["post_conf"]["global_water_fixer"].setdefault("midpoint", False)
conf['model']['post_conf']['global_water_fixer'].setdefault('grid_type', 'pressure')

if conf["model"]["post_conf"]["global_water_fixer"]["simple_demo"] is False:
assert (
"lon_lat_level_name" in conf["model"]["post_conf"]["global_water_fixer"]
), "Must specifiy var names for lat/lon/level in physics reference file"

if conf['model']['post_conf']['global_water_fixer']['grid_type'] == 'sigma':
assert 'surface_pressure_name' in conf['model']['post_conf']['global_water_fixer'], (
'Must specifiy surface pressure var name when using hybrid sigma-pressure coordinates')

q_inds = [
i_var
for i_var, var in enumerate(varname_output)
Expand Down Expand Up @@ -551,6 +571,13 @@ def credit_main_parser(
conf["model"]["post_conf"]["global_water_fixer"]["precip_ind"] = precip_inds[0]
conf["model"]["post_conf"]["global_water_fixer"]["evapor_ind"] = evapor_inds[0]

if conf['model']['post_conf']['global_water_fixer']['grid_type'] == 'sigma':
sp_inds = [
i_var for i_var, var in enumerate(varname_output)
if var in conf['model']['post_conf']['global_water_fixer']['surface_pressure_name']
]
conf['model']['post_conf']['global_water_fixer']['sp_inds'] = sp_inds[0]

# --------------------------------------------------------------------- #
# global energy fixer
flag_energy = (
Expand All @@ -571,13 +598,18 @@ def credit_main_parser(
"simple_demo", False
)
conf["model"]["post_conf"]["global_energy_fixer"].setdefault("midpoint", False)
conf['model']['post_conf']['global_energy_fixer'].setdefault('grid_type', 'pressure')

if conf["model"]["post_conf"]["global_energy_fixer"]["simple_demo"] is False:
assert (
"lon_lat_level_name"
in conf["model"]["post_conf"]["global_energy_fixer"]
), "Must specifiy var names for lat/lon/level in physics reference file"

if conf['model']['post_conf']['global_energy_fixer']['grid_type'] == 'sigma':
assert 'surface_pressure_name' in conf['model']['post_conf']['global_energy_fixer'], (
'Must specifiy surface pressure var name when using hybrid sigma-pressure coordinates')

T_inds = [
i_var
for i_var, var in enumerate(varname_output)
Expand Down Expand Up @@ -645,6 +677,13 @@ def credit_main_parser(
surf_flux_inds
)

if conf['model']['post_conf']['global_energy_fixer']['grid_type'] == 'sigma':
sp_inds = [
i_var for i_var, var in enumerate(varname_output)
if var in conf['model']['post_conf']['global_energy_fixer']['surface_pressure_name']
]
conf['model']['post_conf']['global_energy_fixer']['sp_inds'] = sp_inds[0]

# --------------------------------------------------------- #
# conf['trainer'] section

Expand Down
Loading

0 comments on commit 1c3d77d

Please sign in to comment.