Merge pull request #146 from NCAR/multistep_bugs
Bug fix on the original multi-step dataset
jsschreck authored Jan 9, 2025
2 parents 58fbbd6 + b65b7ab commit 85d6810
Showing 4 changed files with 235 additions and 346 deletions.
34 changes: 23 additions & 11 deletions config/example-v2025.2.0.yml
@@ -27,14 +27,14 @@ data:
# files must have the listed variable names
# upper-air variables will be normalized by the dataloader, users do not need to normalize them
variables: ['U','V','T','Q']
save_loc: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/SixHourly_TOTAL_*'
save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y_TOTAL_*'

# surface variables, must be YEARLY zarr or nc with (time, latitude, longitude) dims
# the time dimension MUST be the same as upper-air variables
# files must have the listed variable names
# surface variables will be normalized by the dataloader, users do not need to normalize them
surface_variables: ['sp', 't2m']
save_loc_surface: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/SixHourly_TOTAL_*'
surface_variables: ['SP', 't2m', 'Z500', 'T500', 'U500', 'V500', 'Q500']
save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y_TOTAL_*'

# dynamic forcing variables, must be YEARLY zarr or nc with (time, latitude, longitude) dims
# the time dimension MUST be the same as upper-air variables
@@ -46,30 +46,32 @@ data:
# diagnostic variables, must be YEARLY zarr or nc with (time, latitude, longitude) dims
# the time dimension MUST be the same as upper-air variables
# files must have the listed variable names
# THESE ARE NOT LOADING CORRECTLY ON X
# diagnostic variables will be normalized by the dataloader, users do not need to normalize them
diagnostic_variables: ['Z500', 'T500']
save_loc_diagnostic: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/SixHourly_TOTAL_*'
# diagnostic_variables: ['Z500', 'T500', 'U500', 'V500', 'Q500']
# save_loc_diagnostic: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/SixHourly_y_TOTAL_*'

# periodic forcing variables, must be a single zarr or nc with (time, latitude, longitude) dims
# the time dimension should cover an entire LEAP YEAR
#
# e.g., periodic forcing variables can be provided for the year 2000,
# and their time coords should have 24*366 hours for an hourly model
#
# THESE ARE REDUNDANT?
# periodic forcing variables MUST be normalized BY USER
forcing_variables: ['TSI']
save_loc_forcing: '/glade/campaign/cisl/aiml/ksha/CREDIT/forcing_norm_6h.nc'
# forcing_variables: ['TSI']
# save_loc_forcing: '/glade/campaign/cisl/aiml/ksha/CREDIT/forcing_norm_6h.nc'

# static variables must be a single zarr or nc with (latitude, longitude) coords
# static variables must be normalized BY USER
static_variables: ['Z_GDS4_SFC', 'LSM']
save_loc_static: '/glade/campaign/cisl/aiml/ksha/CREDIT/static_norm_old.nc'
save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'

# z-score files, they must be zarr or nc with (level,) coords
# they MUST include all the
# 'variables', 'surface_variables', 'dynamic_forcing_variables', 'diagnostic_variables' above
mean_path: '/glade/campaign/cisl/aiml/ksha/CREDIT/mean_6h_0.25deg.nc'
std_path: '/glade/campaign/cisl/aiml/ksha/CREDIT/std_6h_0.25deg.nc'
mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'

# years to form the training / validation set [first_year, last_year (not included)]
train_years: [1979, 2014] # 1979 - 2013
@@ -100,7 +102,7 @@ data:
# this option speeds up multi-step training; it ONLY works with trainer type: standard
# use one_shot --> True
# do not use one_shot --> False or null
one_shot: True
one_shot: False

# number of hours for each forecast step
# lead_time_periods = 6 for a 6-hourly model and 6-hourly training data
@@ -383,6 +385,16 @@ predict:
# A directory called metrics will be added to the save_forecast field below.
# To instead save the states to file (netcdf), run rollout_to_metrics.py. You will need to post-process
# all the data with this option. A directory called netcdf will be added to the save_forecast field


# the keyword that controls GPU usage
# fsdp: fully sharded data parallel
# ddp: distributed data parallel
# none: single-GPU training
mode: none

# Set the batch_size for the prediction inference mode (default is 1)
batch_size: 1

forecasts:
type: "custom" # keep it as "custom". See also credit.forecast
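As a quick sanity check on the data section above, here is a minimal sketch (editor's addition, not part of this PR) that loads the example config and verifies the z-score files cover every variable listed under variables, surface_variables, dynamic_forcing_variables, and diagnostic_variables. The local config path and the use of yaml/xarray here are assumptions for illustration, not credit's own tooling.

import xarray as xr
import yaml

with open("config/example-v2025.2.0.yml") as f:  # hypothetical local path
    conf = yaml.safe_load(f)

data = conf["data"]
needed = (
    data.get("variables", [])
    + data.get("surface_variables", [])
    + data.get("dynamic_forcing_variables", [])
    + data.get("diagnostic_variables", [])
)

# mean_path / std_path must contain a z-score entry for every listed variable
for path in (data["mean_path"], data["std_path"]):
    ds = xr.open_dataset(path)
    missing = [v for v in needed if v not in ds]
    assert not missing, f"{path} is missing z-score entries for: {missing}"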
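The mode keyword added under predict: selects how the model is distributed across GPUs during inference (fsdp, ddp, or none). As a rough, hedged illustration of what those values usually map to with standard PyTorch wrappers (this is not credit's actual implementation), a hypothetical wrap_model helper might look like:

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_model(model, mode, device):
    """Hypothetical helper: map the config 'mode' string to a parallel wrapper."""
    if mode in (None, "none"):
        # single-GPU (or CPU) inference, no wrapper
        return model.to(device)
    if mode == "ddp":
        # replicate the model on every rank
        return DDP(model.to(device), device_ids=[device.index])
    if mode == "fsdp":
        # shard model parameters across ranks
        return FSDP(model.to(device))
    raise ValueError(f"unknown mode: {mode!r}")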
173 changes: 136 additions & 37 deletions credit/datasets/era5_multistep.py
@@ -207,9 +207,9 @@ def worker(

# for multi-input cases, use time=-1 ocean SKT for all times
if history_len > 1:
input_skt[: history_len - 1] = input_skt[
: history_len - 1
].where(~ocean_mask_bool, input_skt.isel(time=-1))
input_skt[: history_len - 1] = input_skt[: history_len - 1].where(
~ocean_mask_bool, input_skt.isel(time=-1)
)

# for target skt, replace ocean values using time=-1 input SKT
target_skt = target_skt.where(~ocean_mask_bool, input_skt.isel(time=-1))
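# --- Editor's note (illustrative sketch, not part of this diff) ---
# The fix above uses xarray's DataArray.where(cond, other): values are kept
# where `cond` is True and replaced by `other` elsewhere, so ~ocean_mask_bool
# keeps land points and substitutes the time=-1 skin temperature over ocean.
# A tiny standalone example of the same pattern:
#
#     import numpy as np
#     import xarray as xr
#
#     skt = xr.DataArray(
#         np.arange(12.0).reshape(3, 2, 2),
#         dims=("time", "latitude", "longitude"),
#     )
#     ocean = xr.DataArray(
#         np.array([[True, False], [False, True]]),
#         dims=("latitude", "longitude"),
#     )
#     # earlier steps: ocean points take the last-time-step value, land unchanged
#     early = skt.isel(time=slice(0, -1)).where(~ocean, skt.isel(time=-1))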
@@ -250,6 +250,85 @@ def worker(
return sample


class RepeatingIndexSampler(torch.utils.data.Sampler):
def __init__(
self,
dataset,
forecast_len,
skip_periods=1,
shuffle=True,
seed=42,
rank=0,
num_replicas=1,
):
"""
Sampler that yields each starting index repeated (forecast_len + 1) times,
ensuring indices don't exceed the valid range for a full forecast sequence.
Supports distributed sampling.
Args:
- dataset (Dataset): The dataset to sample from.
- forecast_len (int): Length of each forecast sequence minus one.
- skip_periods (int): Number of periods to skip between sequences.
- shuffle (bool): Whether to shuffle the starting indices.
- seed (int): Random seed for reproducibility.
- rank (int): Rank of the current process (for distributed training).
- num_replicas (int): Total number of processes (for distributed training).
"""
self.dataset = dataset
self.forecast_len = forecast_len + 1 # Total steps in the forecast sequence
self.skip_periods = skip_periods
self.shuffle = shuffle
self.seed = seed
self.rank = rank
self.num_replicas = num_replicas

# Compute valid starting indices ensuring full sequences fit
all_start_indices = list(range(0, len(self.dataset), skip_periods))

# Trim the number of indices so they divide evenly across replicas
num_indices = len(all_start_indices)
num_indices_per_rank = num_indices // self.num_replicas
all_start_indices = all_start_indices[
: num_indices_per_rank * self.num_replicas
]
self.all_start_indices = all_start_indices
self.num_indices_per_rank = num_indices_per_rank

if self.shuffle:
self.rng = np.random.default_rng(seed)
# rng.shuffle(self.start_indices)

def __len__(self):
"""Returns the total number of indices for this rank."""
# return len(self.start_indices) * self.forecast_len
return self.num_indices_per_rank * self.forecast_len

def __iter__(self):
"""
Yields each start index repeated (forecast_len + 1) times.
"""
all_indices = self.all_start_indices
if self.shuffle:
all_indices = self.rng.permutation(all_indices)

self.start_indices = all_indices[self.rank :: self.num_replicas]
assert len(self.start_indices) == self.num_indices_per_rank
for idx in self.start_indices:
for _ in range(self.forecast_len):
yield idx

def batches_per_epoch(self):
"""
Computes the number of batches per epoch (one forecast-sequence start index per batch).
Returns:
- int: Number of batches per epoch.
"""
return self.num_indices_per_rank
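# --- Editor's note (illustrative usage, not part of this diff) ---
# RepeatingIndexSampler only needs len(dataset), so a plain list is enough to
# see the repetition pattern. With forecast_len=2 (three steps per sequence),
# skip_periods=1, a single rank and no shuffling:
#
#     demo = RepeatingIndexSampler(list(range(4)), forecast_len=2, shuffle=False)
#     list(demo)                 # -> [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
#     len(demo)                  # -> 12 (4 start indices x 3 repeats)
#     demo.batches_per_epoch()   # -> 4 distinct forecast sequences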


class ERA5_and_Forcing_MultiStep(torch.utils.data.Dataset):
"""
A Pytorch Dataset class that works on:
@@ -284,7 +363,7 @@ def __init__(
skip_periods=None,
one_shot=None,
max_forecast_len=None,
sst_forcing=None
sst_forcing=None,
):
"""
Initialize the ERA5_and_Forcing_Dataset
@@ -493,7 +572,7 @@ def __init__(
forecast_len=self.forecast_len,
skip_periods=self.skip_periods,
transform=self.transform,
sst_forcing=self.sst_forcing
sst_forcing=self.sst_forcing,
)

self.total_length = len(self.ERA5_indices)
@@ -549,17 +628,16 @@ def __getitem__(self, index):


if __name__ == "__main__":

import torch
import yaml
from torch.utils.data import DataLoader
from credit.transforms import load_transforms
from credit.parser import credit_main_parser, training_data_check
from credit.datasets import setup_data_loading, set_globals

with open(
"/glade/derecho/scratch/schreck/repos/miles-credit/production/multistep/wxformer_6h/model.yml"
) as cf:
# filename = "/glade/derecho/scratch/schreck/finetune/arnold/model_xform.yml"
filename = "../../config/example-v2025.2.0.yml"
with open(filename) as cf:
conf = yaml.load(cf, Loader=yaml.FullLoader)

conf = credit_main_parser(
@@ -569,44 +647,65 @@ def __getitem__(self, index):

data_config = setup_data_loading(conf)

data_config["forecast_len"] = 6
batch_size = 2
data_config["forecast_len"] = 5
batch_size = 1
training_type = "train"

set_globals(data_config, namespace=globals())

dataset_multi = ERA5_and_Forcing_MultiStep(
varname_upper_air=data_config['varname_upper_air'],
varname_surface=data_config['varname_surface'],
varname_dyn_forcing=data_config['varname_dyn_forcing'],
varname_forcing=data_config['varname_forcing'],
varname_static=data_config['varname_static'],
varname_diagnostic=data_config['varname_diagnostic'],
filenames=data_config['all_ERA_files'],
filename_surface=data_config['surface_files'],
filename_dyn_forcing=data_config['dyn_forcing_files'],
filename_forcing=data_config['forcing_files'],
filename_static=data_config['static_files'],
filename_diagnostic=data_config['diagnostic_files'],
history_len=data_config['history_len'],
forecast_len=data_config['forecast_len'],
skip_periods=data_config['skip_periods'],
varname_upper_air=data_config["varname_upper_air"],
varname_surface=data_config["varname_surface"],
varname_dyn_forcing=data_config["varname_dyn_forcing"],
varname_forcing=data_config["varname_forcing"],
varname_static=data_config["varname_static"],
varname_diagnostic=data_config["varname_diagnostic"],
filenames=data_config["all_ERA_files"],
filename_surface=data_config["surface_files"],
filename_dyn_forcing=data_config["dyn_forcing_files"],
filename_forcing=data_config["forcing_files"],
filename_static=data_config["static_files"],
filename_diagnostic=data_config["diagnostic_files"],
history_len=data_config["history_len"],
forecast_len=data_config["forecast_len"],
skip_periods=data_config["skip_periods"],
one_shot=False,
max_forecast_len=data_config['max_forecast_len'],
sst_forcing=data_config['sst_forcing'],
transform=load_transforms(conf)
max_forecast_len=data_config["max_forecast_len"],
sst_forcing=data_config["sst_forcing"],
transform=load_transforms(conf),
)

sampler = RepeatingIndexSampler(
dataset_multi,
forecast_len=data_config["forecast_len"],
num_replicas=1,
rank=0,
seed=1000,
shuffle=True,
)

dataloader = DataLoader(
dataset_multi,
batch_size=1, # Adjust the batch size as needed
shuffle=True, # Shuffle the dataset if needed
num_workers=1, # Number of subprocesses to use for data loading (adjust as needed)
drop_last=True, # Drop the last incomplete batch if not divisible by batch_size,
prefetch_factor=4
batch_size=1,
shuffle=False,
sampler=sampler,
pin_memory=True,
persistent_workers=False,
num_workers=1, # set to one so prefetching works
prefetch_factor=4,
)

dataloader.dataset.set_epoch(0)
for (k, sample) in enumerate(dataloader):
print(k, sample['index'], sample['datetime'], sample['forecast_step'], sample['stop_forecast'], sample["x"].shape, sample["x_surf"].shape)
if k == 20:
for k, sample in enumerate(dataloader):
print(
k,
sample["index"],
sample["datetime"],
sample["forecast_step"],
sample["stop_forecast"],
sample["x"].shape,
sample["x_surf"].shape,
# sample["x_forcing_static"].shape,
)
if k == 500:
break
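One closing note on the demo loop, offered as a sketch of the expected bookkeeping rather than anything asserted by this PR: with batch_size=1, the DataLoader yields len(sampler) samples per epoch (start indices times forecast steps), while sampler.batches_per_epoch() counts distinct forecast sequences, so the two differ by a factor of forecast_len + 1.

# Rough bookkeeping for the loop above (values depend on the config in use)
steps_per_sequence = data_config["forecast_len"] + 1   # 6 when forecast_len = 5
sequences_per_epoch = sampler.batches_per_epoch()
samples_per_epoch = len(sampler)                       # sequences * steps
assert samples_per_epoch == sequences_per_epoch * steps_per_sequence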
