Remove batch-static tensor from dataset class and models (#13)
* Bake the batch-static features into the normal forcing in the MEPS Dataset class.
* Change the Dataset class to only return 3 tensors per sample (init, target, forcing).
* Remove the batch-static tensor from being extracted from the batch and passed around in the graph-based models, while making sure that input dimensions line up so older checkpoints can still be loaded correctly.
joeloskarsson authored Mar 18, 2024
1 parent 0669ff4 commit b0050b9
Showing 5 changed files with 29 additions and 34 deletions.
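
The shape of the change, as a runnable toy sketch (hypothetical sizes, with torch.utils.data.TensorDataset standing in for the MEPS WeatherDataset): each sample is now just three tensors, with the batch-static water cover baked into channel 0 of the forcing.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the dataset after this change: 3 tensors per sample.
# Sizes are illustrative, not the real MEPS dimensions.
n_samples, n_grid, d_features, d_forcing_total = 4, 10, 17, 16
dataset = TensorDataset(
    torch.rand(n_samples, 2, n_grid, d_features),        # init_states
    torch.rand(n_samples, 5, n_grid, d_features),        # target_states
    torch.rand(n_samples, 5, n_grid, d_forcing_total),   # forcing
)

for init_batch, target_batch, forcing_batch in DataLoader(dataset, batch_size=2):
    # Water cover is no longer a separate tensor; it is channel 0 of forcing
    water_cover = forcing_batch[..., 0:1]  # (B, pred_steps, N_grid, 1)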
7 changes: 4 additions & 3 deletions create_parameter_weights.py

@@ -82,7 +82,7 @@ def main():
     squares = []
     flux_means = []
     flux_squares = []
-    for init_batch, target_batch, _, forcing_batch in tqdm(loader):
+    for init_batch, target_batch, forcing_batch in tqdm(loader):
         batch = torch.cat(
             (init_batch, target_batch), dim=1
         )  # (N_batch, N_t, N_grid, d_features)
@@ -91,7 +91,8 @@ def main():
             torch.mean(batch**2, dim=(1, 2))
         )  # (N_batch, d_features,)

-        flux_batch = forcing_batch[:, :, :, 0]  # Flux is first index
+        # Flux at 1st windowed position is index 1 in forcing
+        flux_batch = forcing_batch[:, :, :, 1]
         flux_means.append(torch.mean(flux_batch))  # (,)
         flux_squares.append(torch.mean(flux_batch**2))  # (,)

@@ -125,7 +126,7 @@ def main():

     diff_means = []
     diff_squares = []
-    for init_batch, target_batch, _, _ in tqdm(loader_standard):
+    for init_batch, target_batch, _ in tqdm(loader_standard):
         batch = torch.cat(
             (init_batch, target_batch), dim=1
         )  # (N_batch, N_t', N_grid, d_features)
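
Why the flux moved from index 0 to index 1: the batch-static water cover is now concatenated in front of the windowed forcing, shifting every raw forcing channel up by one. A minimal sketch with assumed toy sizes:

import torch

# Assumed toy sizes; in MEPS, d_forcing = 5 with flux as the first raw feature
n_batch, n_t, n_grid, d_forcing = 2, 7, 10, 5
forcing_batch = torch.rand(n_batch, n_t, n_grid, 1 + 3 * d_forcing)

# Channel 0 is the prepended batch-static water cover, so the flux at the
# first windowed position is channel 1 (it was channel 0 before this change)
flux_batch = forcing_batch[:, :, :, 1]  # (N_batch, N_t, N_grid)
flux_mean = torch.mean(flux_batch)      # (,)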
3 changes: 1 addition & 2 deletions neural_lam/constants.py

@@ -116,6 +116,5 @@
 )

 # Data dimensions
-BATCH_STATIC_FEATURE_DIM = 1  # Only open water
-GRID_FORCING_DIM = 5 * 3  # 5 features for 3 time-step window
+GRID_FORCING_DIM = 5 * 3 + 1  # 5 feat. for 3 time-step window + 1 batch-static
 GRID_STATE_DIM = 17
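
A quick sanity check of the bookkeeping (constants from the diff; grid_static_dim is a hypothetical value, read from grid_static_features in practice): folding the batch-static dimension into GRID_FORCING_DIM leaves the total per-node input width unchanged, which is what lets older checkpoints keep loading.

# Consistency check of the dimension arithmetic
GRID_STATE_DIM = 17
GRID_FORCING_DIM_OLD = 5 * 3        # 15, with BATCH_STATIC_FEATURE_DIM = 1
GRID_FORCING_DIM_NEW = 5 * 3 + 1    # 16, batch-static folded in

grid_static_dim = 4  # hypothetical
old = 2 * GRID_STATE_DIM + grid_static_dim + GRID_FORCING_DIM_OLD + 1
new = 2 * GRID_STATE_DIM + grid_static_dim + GRID_FORCING_DIM_NEW
assert old == new  # same total width, so input layers keep their shapes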
23 changes: 6 additions & 17 deletions neural_lam/models/ar_model.py

@@ -53,7 +53,7 @@ def __init__(self, args):
             persistent=False,
         )

-        # grid_dim from data + static + batch_static
+        # grid_dim from data + static
         (
             self.num_grid_nodes,
             grid_static_dim,
@@ -62,7 +62,6 @@ def __init__(self, args):
             2 * constants.GRID_STATE_DIM
             + grid_static_dim
             + constants.GRID_FORCING_DIM
-            + constants.BATCH_STATIC_FEATURE_DIM
         )
@@ -117,25 +116,19 @@ def expand_to_batch(x, batch_size):
         """
         return x.unsqueeze(0).expand(batch_size, -1, -1)

-    def predict_step(
-        self, prev_state, prev_prev_state, batch_static_features, forcing
-    ):
+    def predict_step(self, prev_state, prev_prev_state, forcing):
         """
         Step state one step ahead using prediction model, X_{t-1}, X_t -> X_{t+1}
         prev_state: (B, num_grid_nodes, feature_dim), X_t
         prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1}
-        batch_static_features: (B, num_grid_nodes, batch_static_feature_dim)
         forcing: (B, num_grid_nodes, forcing_dim)
         """
         raise NotImplementedError("No prediction step implemented")

-    def unroll_prediction(
-        self, init_states, batch_static_features, forcing_features, true_states
-    ):
+    def unroll_prediction(self, init_states, forcing_features, true_states):
         """
         Roll out prediction taking multiple autoregressive steps with model
         init_states: (B, 2, num_grid_nodes, d_f)
-        batch_static_features: (B, num_grid_nodes, d_static_f)
         forcing_features: (B, pred_steps, num_grid_nodes, d_forcing)
         true_states: (B, pred_steps, num_grid_nodes, d_f)
         """
@@ -150,7 +143,7 @@ def unroll_prediction(
             border_state = true_states[:, i]

             pred_state, pred_std = self.predict_step(
-                prev_state, prev_prev_state, batch_static_features, forcing
+                prev_state, prev_prev_state, forcing
             )
             # state: (B, num_grid_nodes, d_f)
             # pred_std: (B, num_grid_nodes, d_f) or None
@@ -184,24 +177,20 @@ def unroll_prediction(
     def common_step(self, batch):
         """
         Predict on single batch
-        batch = time_series, batch_static_features, forcing_features
+        batch consists of:
         init_states: (B, 2, num_grid_nodes, d_features)
         target_states: (B, pred_steps, num_grid_nodes, d_features)
-        batch_static_features: (B, num_grid_nodes, d_static_f),
-            for example open water
         forcing_features: (B, pred_steps, num_grid_nodes, d_forcing),
             where index 0 corresponds to index 1 of init_states
         """
         (
             init_states,
             target_states,
-            batch_static_features,
             forcing_features,
         ) = batch

         prediction, pred_std = self.unroll_prediction(
-            init_states, batch_static_features, forcing_features, target_states
+            init_states, forcing_features, target_states
         )  # (B, pred_steps, num_grid_nodes, d_f)
         # prediction: (B, pred_steps, num_grid_nodes, d_f)
         # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)
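
A minimal, self-contained sketch of the slimmed-down rollout interface (a trivial persistence model stands in for a real ARModel subclass; the border overwriting from true_states in the actual unroll loop is omitted for brevity):

import torch

class PersistenceModel:
    """Toy stand-in assuming the post-commit predict_step signature."""

    def predict_step(self, prev_state, prev_prev_state, forcing):
        # "Prediction" is just persistence, to keep the sketch runnable
        return prev_state, None

    def unroll_prediction(self, init_states, forcing_features, true_states):
        prev_prev_state = init_states[:, 0]
        prev_state = init_states[:, 1]
        prediction_list = []
        for i in range(forcing_features.shape[1]):
            forcing = forcing_features[:, i]
            # No batch_static_features argument anymore
            pred_state, _ = self.predict_step(prev_state, prev_prev_state, forcing)
            prediction_list.append(pred_state)
            prev_prev_state, prev_state = prev_state, pred_state
        return torch.stack(prediction_list, dim=1)

model = PersistenceModel()
init_states = torch.rand(1, 2, 10, 17)        # (B, 2, num_grid_nodes, d_f)
forcing_features = torch.rand(1, 4, 10, 16)   # (B, pred_steps, num_grid_nodes, d_forcing)
true_states = torch.rand(1, 4, 10, 17)
prediction = model.unroll_prediction(init_states, forcing_features, true_states)
assert prediction.shape == (1, 4, 10, 17)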
8 changes: 2 additions & 6 deletions neural_lam/models/base_graph_model.py

@@ -34,7 +34,7 @@ def __init__(self, args):
             f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)"
         )

-        # grid_dim from data + static + batch_static
+        # grid_dim from data + static
         self.g2m_edges, g2m_dim = self.g2m_features.shape
         self.m2g_edges, m2g_dim = self.m2g_features.shape
@@ -98,14 +98,11 @@ def process_step(self, mesh_rep):
         """
         raise NotImplementedError("process_step not implemented")

-    def predict_step(
-        self, prev_state, prev_prev_state, batch_static_features, forcing
-    ):
+    def predict_step(self, prev_state, prev_prev_state, forcing):
         """
         Step state one step ahead using prediction model, X_{t-1}, X_t -> X_{t+1}
         prev_state: (B, num_grid_nodes, feature_dim), X_t
         prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1}
-        batch_static_features: (B, num_grid_nodes, batch_static_feature_dim)
         forcing: (B, num_grid_nodes, forcing_dim)
         """
         batch_size = prev_state.shape[0]
@@ -115,7 +112,6 @@ def predict_step(
         (
             prev_state,
             prev_prev_state,
-            batch_static_features,
             forcing,
             self.expand_to_batch(self.grid_static_features, batch_size),
         ),
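
What predict_step now concatenates per grid node, as a standalone sketch with hypothetical sizes: the two previous states, the forcing (which already carries the water cover), and the grid-static features expanded over the batch.

import torch

# Hypothetical sizes for illustration
B, num_grid_nodes = 2, 10
d_state, d_static, d_forcing = 17, 4, 16
prev_state = torch.rand(B, num_grid_nodes, d_state)
prev_prev_state = torch.rand(B, num_grid_nodes, d_state)
forcing = torch.rand(B, num_grid_nodes, d_forcing)
grid_static_features = torch.rand(num_grid_nodes, d_static)

# Concatenate per-node inputs; batch-static no longer appears as its own term
grid_features = torch.cat(
    (
        prev_state,
        prev_prev_state,
        forcing,
        grid_static_features.unsqueeze(0).expand(B, -1, -1),
    ),
    dim=-1,
)  # (B, num_grid_nodes, 2*d_state + d_forcing + d_static), i.e. grid_dim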
22 changes: 16 additions & 6 deletions neural_lam/weather_dataset.py

@@ -156,22 +156,27 @@ def __getitem__(self, idx):
         init_states = sample[:2]  # (2, N_grid, d_features)
         target_states = sample[2:]  # (sample_length-2, N_grid, d_features)

-        # === Static batch features ===
+        # === Forcing features ===
+        # Now batch-static features are just part of forcing,
+        # repeated over temporal dimension
+        # Load water coverage
         sample_datetime = sample_name[:10]
         water_path = os.path.join(
             self.sample_dir_path, f"wtr_{sample_datetime}.npy"
         )
-        static_features = torch.tensor(
+        water_cover_features = torch.tensor(
             np.load(water_path), dtype=torch.float32
         ).unsqueeze(
             -1
         )  # (dim_x, dim_y, 1)
         # Flatten
-        static_features = static_features.flatten(0, 1)  # (N_grid, 1)
+        water_cover_features = water_cover_features.flatten(0, 1)  # (N_grid, 1)
+        # Expand over temporal dimension
+        water_cover_expanded = water_cover_features.unsqueeze(0).expand(
+            self.sample_length - 2, -1, -1  # -2 as added on after windowing
+        )  # (sample_len-2, N_grid, 1)

-        # === Forcing features ===
+        # Forcing features
         # TOA flux
         flux_path = os.path.join(
             self.sample_dir_path,
             f"nwp_toa_downwelling_shortwave_flux_{sample_datetime}.npy",
@@ -247,4 +252,9 @@ def __getitem__(self, idx):
         )  # (sample_len-2, N_grid, 3*d_forcing)
         # Now index 0 of ^ corresponds to forcing at index 0-2 of sample

-        return init_states, target_states, static_features, forcing_windowed
+        # batch-static water cover is added after windowing,
+        # as it is static over time
+        forcing = torch.cat((water_cover_expanded, forcing_windowed), dim=2)
+        # (sample_len-2, N_grid, forcing_dim)
+
+        return init_states, target_states, forcing
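
How the pieces fit together, sketched with toy sizes (the 3-step windowing itself happens earlier in __getitem__ and is assumed here): the water cover has no time axis, so it is expanded over time and prepended only after windowing.

import torch

# Assumed toy sizes; forcing_features: (sample_len, N_grid, d_forcing)
sample_len, n_grid, d_forcing = 7, 10, 5
forcing_features = torch.rand(sample_len, n_grid, d_forcing)

# Hypothetical 3-step window: index 0 covers sample indices 0-2
forcing_windowed = torch.cat(
    (
        forcing_features[:-2],   # step t-1
        forcing_features[1:-1],  # step t
        forcing_features[2:],    # step t+1
    ),
    dim=2,
)  # (sample_len-2, N_grid, 3*d_forcing)

# Water cover is static over time: expand, then prepend as channel 0
water_cover = torch.rand(n_grid, 1)
water_cover_expanded = water_cover.unsqueeze(0).expand(sample_len - 2, -1, -1)
forcing = torch.cat((water_cover_expanded, forcing_windowed), dim=2)
assert forcing.shape == (sample_len - 2, n_grid, 1 + 3 * d_forcing)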
