Commit 07c744b: ruff reformat
yingkaisha committed Dec 8, 2024
1 parent 829525f
Showing 8 changed files with 164 additions and 122 deletions.
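This commit is a pure formatting pass: ruff's formatter normalizes blank lines (two before top-level definitions), splits long call signatures one argument per line with a trailing comma, and puts a space after the # of inline comments. A small before/after sketch of the dominant rewrite, taken from the credit/models/fuxi.py hunks below:

# before: arguments packed onto shared continuation lines, no trailing comma
self.layer = SwinTransformerV2Stage(
    embed_dim, embed_dim,
    input_resolution, depth, num_heads,
    window_size[0],
    proj_drop=proj_drop,
    attn_drop=attn_drop,
    drop_path=drop_path)

# after ruff format: one argument per line, trailing comma, dedented closing paren
self.layer = SwinTransformerV2Stage(
    embed_dim,
    embed_dim,
    input_resolution,
    depth,
    num_heads,
    window_size[0],
    proj_drop=proj_drop,
    attn_drop=attn_drop,
    drop_path=drop_path,
)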
13 changes: 7 additions & 6 deletions applications/rollout_metrics.py
@@ -50,6 +50,7 @@
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"


def predict(rank, world_size, conf, p):
# setup rank and world size for GPU-based rollout
if conf["predict"]["mode"] in ["fsdp", "ddp"]:
@@ -127,7 +128,7 @@ def predict(rank, world_size, conf, p):
+ len(conf["data"]["forcing_variables"])
+ len(conf["data"]["static_variables"])
)

# ------------------------------------------------------- #
# clamp to remove outliers
if conf["data"]["data_clamp"] is None:
@@ -136,7 +137,7 @@ def predict(rank, world_size, conf, p):
flag_clamp = True
clamp_min = float(conf["data"]["data_clamp"][0])
clamp_max = float(conf["data"]["data_clamp"][1])

# ====================================================== #
# postblock opts outside of model
post_conf = conf["model"]["post_conf"]
@@ -188,7 +189,7 @@ def predict(rank, world_size, conf, p):
rollout_p=0.0,
which_forecast=None,
)

# set up the dataloader
data_loader = torch.utils.data.DataLoader(
dataset,
@@ -282,12 +283,12 @@ def predict(rank, world_size, conf, p):

# -------------------------------------------------------------------------------------- #
# start prediction

# --------------------------------------------- #
# clamp
if flag_clamp:
x = torch.clamp(x, min=clamp_min, max=clamp_max)

y_pred = model(x)

# ============================================= #
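For context, the clamp guard in this hunk reads an optional [min, max] pair from the config and clips model inputs before each forward pass. A minimal self-contained sketch of the same pattern (the config values here are invented for illustration):

import torch

conf = {"data": {"data_clamp": [-16.0, 16.0]}}  # None would disable clamping

flag_clamp = conf["data"]["data_clamp"] is not None
if flag_clamp:
    clamp_min = float(conf["data"]["data_clamp"][0])
    clamp_max = float(conf["data"]["data_clamp"][1])

x = torch.randn(1, 4, 8, 8)  # stand-in for a model input batch
if flag_clamp:
    # clip outliers to the configured range before calling model(x)
    x = torch.clamp(x, min=clamp_min, max=clamp_max)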
@@ -332,7 +333,7 @@ def predict(rank, world_size, conf, p):
.unsqueeze(2)
.cpu()
)

# Compute metrics
metrics_dict = metrics(
y_pred.float(), y.float(), forecast_datetime=forecast_hour
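The metrics call at the end of this file scores each rollout step against the target, keyed by forecast hour. The actual credit metrics implementation is not part of this diff; a hypothetical stand-in with the same call shape could look like:

import torch

def metrics(y_pred: torch.Tensor, y: torch.Tensor, forecast_datetime: int = 0) -> dict:
    # illustrative only: batch-mean squared and absolute error tagged with the hour
    err = y_pred - y
    return {
        "forecast_hour": forecast_datetime,
        "mse": err.pow(2).mean().item(),
        "mae": err.abs().mean().item(),
    }

metrics_dict = metrics(torch.zeros(1, 4, 2, 2), torch.ones(1, 4, 2, 2), forecast_datetime=6)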
9 changes: 5 additions & 4 deletions applications/rollout_to_netcdf.py
@@ -50,6 +50,7 @@
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"


def predict(rank, world_size, conf, p):
# setup rank and world size for GPU-based rollout
if conf["predict"]["mode"] in ["fsdp", "ddp"]:
@@ -127,7 +128,7 @@ def predict(rank, world_size, conf, p):
+ len(conf["data"]["forcing_variables"])
+ len(conf["data"]["static_variables"])
)

# ------------------------------------------------------- #
# clamp to remove outliers
if conf["data"]["data_clamp"] is None:
@@ -136,7 +137,7 @@ def predict(rank, world_size, conf, p):
flag_clamp = True
clamp_min = float(conf["data"]["data_clamp"][0])
clamp_max = float(conf["data"]["data_clamp"][1])

# ====================================================== #
# postblock opts outside of model
post_conf = conf["model"]["post_conf"]
@@ -291,8 +292,8 @@ def predict(rank, world_size, conf, p):
# clamp
if flag_clamp:
x = torch.clamp(x, min=clamp_min, max=clamp_max)
#y = torch.clamp(y, min=clamp_min, max=clamp_max)
# y = torch.clamp(y, min=clamp_min, max=clamp_max)

y_pred = model(x)

# ============================================= #
2 changes: 1 addition & 1 deletion applications/train_multistep.py
@@ -721,7 +721,7 @@ def train(self, trial, conf):
# track hyperparameters and run metadata
config=conf,
)

seed = conf["seed"]
seed_everything(seed)

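seed_everything comes from the repo's utilities and is not shown in this diff; the conventional recipe behind such a helper (an assumption) is to seed every RNG the training loop may touch:

import os
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    # make python, numpy, and torch (CPU + all GPUs) reproducible
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)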
12 changes: 6 additions & 6 deletions credit/loss.py
@@ -395,11 +395,11 @@ def variable_weights(conf, channels, frames):

# surface + diag channels
N_channels_single = len(varname_surface) + len(varname_diagnostics)

weights_upper_air = torch.tensor(
[conf["loss"]["variable_weights"][var] for var in varname_upper_air]
).view(1, channels * frames, 1, 1)

weights_single = torch.tensor(
[
conf["loss"]["variable_weights"][var]
@@ -460,19 +460,19 @@ def __init__(self, conf, validation=False):
self.var_weights = None
if conf["loss"]["use_variable_weights"]:
logger.info("Using variable weights in loss calculations")

var_weights = [
value if isinstance(value, list) else [value]
for value in conf["loss"]["variable_weights"].values()
]

var_weights = np.array(
[item for sublist in var_weights for item in sublist]
)

self.var_weights = torch.from_numpy(var_weights)
# ------------------------------------------------------------- #

self.use_spectral_loss = conf["loss"]["use_spectral_loss"]
if self.use_spectral_loss:
self.spectral_lambda_reg = conf["loss"]["spectral_lambda_reg"]
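Both loss.py hunks build one flat per-channel weight vector: an upper-air variable contributes one weight per vertical level (a list in the config), while a surface or diagnostic variable contributes a single scalar. A short sketch of the flattening and of the broadcast it enables (the config values and the final multiply are assumptions; the reduction code sits outside this diff):

import numpy as np
import torch

variable_weights = {"U": [0.9, 1.0], "V": [0.9, 1.0], "t2m": 1.0}  # hypothetical config

var_weights = [
    value if isinstance(value, list) else [value]
    for value in variable_weights.values()
]
var_weights = np.array([item for sublist in var_weights for item in sublist])
w = torch.from_numpy(var_weights)  # shape: (channels,) == (5,) here

loss = torch.rand(2, w.numel(), 4, 4)  # per-pixel loss, [batch, channel, lat, lon]
weighted = loss * w.view(1, -1, 1, 1)  # per-channel weights broadcast over the grid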
52 changes: 28 additions & 24 deletions credit/models/fuxi.py
@@ -226,15 +226,16 @@ class UTransformer(nn.Module):
"""

def __init__(
self, embed_dim,
num_groups,
input_resolution,
num_heads,
window_size,
self,
embed_dim,
num_groups,
input_resolution,
num_heads,
window_size,
depth,
proj_drop,
attn_drop,
drop_path
drop_path,
):
super().__init__()
num_groups = to_2tuple(num_groups)
@@ -256,15 +257,15 @@ def __init__(

# SwinT block
self.layer = SwinTransformerV2Stage(
embed_dim,
embed_dim,
input_resolution,
depth,
num_heads,
embed_dim,
embed_dim,
input_resolution,
depth,
num_heads,
window_size[0],
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path
drop_path=drop_path,
) # <--- window_size[0] takes the int window size from the tuple

# up-sampling block
@@ -342,15 +343,15 @@ def __init__(

self.use_interp = interp
self.use_spectral_norm = use_spectral_norm

if padding_conf is None:
padding_conf = {"activate": False}

self.use_padding = padding_conf["activate"]

if post_conf is None:
post_conf = {"activate": False}

self.use_post_block = post_conf["activate"]

# input tensor size (time, lat, lon)
@@ -385,17 +386,20 @@ def __init__(
self.cube_embedding = CubeEmbedding(img_size, patch_size, in_chans, dim)

# Downsampling --> SwinTransformerV2 stacks --> Upsampling
logger.info(f"Define UTransforme with proj_drop={proj_drop}, attn_drop={attn_drop}, drop_path={drop_path}")

logger.info(
f"Define UTransforme with proj_drop={proj_drop}, attn_drop={attn_drop}, drop_path={drop_path}"
)

self.u_transformer = UTransformer(
dim, num_groups,
input_resolution,
num_heads,
window_size,
dim,
num_groups,
input_resolution,
num_heads,
window_size,
depth=depth,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path
drop_path=drop_path,
)

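For orientation, UTransformer is the U-shaped core being reformatted here: a down-sampling block, the SwinTransformerV2Stage stack, and an up-sampling block. A heavily condensed sketch of that flow (whether and how the real up block fuses the skip path is an assumption):

import torch.nn as nn

class UShapedSketch(nn.Module):
    # hypothetical condensation of UTransformer: down -> Swin stages -> up
    def __init__(self, down: nn.Module, stage: nn.Module, up: nn.Module):
        super().__init__()
        self.down, self.layer, self.up = down, stage, up

    def forward(self, x):
        skip = x                 # keep pre-downsample features for the decoder
        x = self.down(x)         # down-sampling block
        x = self.layer(x)        # SwinTransformerV2 stack
        return self.up(x, skip)  # up-sampling block, fed the skip connection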
# dense layer applied on channel dimension
@@ -418,7 +422,7 @@ def __init__(
# Move the model to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(device)

if self.use_spectral_norm:
logger.info("Adding spectral norm to all conv and linear layers")
apply_spectral_norm(self)
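apply_spectral_norm is the repo's own helper and is not shown here; the usual recipe behind such a function is PyTorch's built-in parametrization applied to every conv and linear layer (a sketch under that assumption):

import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

def apply_spectral_norm(model: nn.Module) -> None:
    # wrap the weight of each conv / linear layer with a spectral-norm parametrization
    for module in model.modules():
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            spectral_norm(module, name="weight")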