Skip to content

Commit

Permalink
Implementation PR-review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Adamov committed May 21, 2024
1 parent 244284c commit 0ba441b
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 89 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Still, some restrictions are inevitable:
## A note on the limited area setting
Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
There are still some parts of the code that is quite specific for the MEPS area use case.
This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/data_config.yaml`).
This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified in `train_model.py --data_config` ).
If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
We would be happy to support such enhancements.
See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
Expand Down
4 changes: 2 additions & 2 deletions create_grid_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

# First-party
from neural_lam import utils
from neural_lam import config


def main():
Expand All @@ -22,7 +22,7 @@ def main():
help="Path to data config file (default: neural_lam/data_config.yaml)",
)
args = parser.parse_args()
config_loader = utils.ConfigLoader(args.data_config)
config_loader = config.Config.from_file(args.data_config)

static_dir_path = os.path.join("data", config_loader.dataset.name, "static")

Expand Down
4 changes: 2 additions & 2 deletions create_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch_geometric.utils.convert import from_networkx

# First-party
from neural_lam import utils
from neural_lam import config


def plot_graph(graph, title=None):
Expand Down Expand Up @@ -189,7 +189,7 @@ def main():
args = parser.parse_args()

# Load grid positions
config_loader = utils.ConfigLoader(args.data_config)
config_loader = config.Config.from_file(args.data_config)
static_dir_path = os.path.join("data", config_loader.dataset.name, "static")
graph_dir_path = os.path.join("graphs", args.graph)
os.makedirs(graph_dir_path, exist_ok=True)
Expand Down
9 changes: 6 additions & 3 deletions create_parameter_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tqdm import tqdm

# First-party
from neural_lam import utils
from neural_lam import config
from neural_lam.weather_dataset import WeatherDataset


Expand Down Expand Up @@ -43,7 +43,7 @@ def main():
)
args = parser.parse_args()

config_loader = utils.ConfigLoader(args.data_config)
config_loader = config.Config.from_file(args.data_config)
static_dir_path = os.path.join("data", config_loader.dataset.name, "static")

# Create parameter weights based on height
Expand All @@ -57,7 +57,10 @@ def main():
"500": 0.03,
}
w_list = np.array(
[w_dict[par.split("_")[-2]] for par in config_loader.dataset.var_names]
[
w_dict[par.split("_")[-2]]
for par in config_loader.dataset.var_longnames
]
)
print("Saving parameter weights...")
np.save(
Expand Down
59 changes: 59 additions & 0 deletions neural_lam/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import functools
from pathlib import Path

import cartopy.crs as ccrs
import yaml


class Config:
"""
Class for loading configuration files.
This class loads a configuration file and provides a way to access its
values as attributes.
"""

def __init__(self, values):
self.values = values

@classmethod
def from_file(cls, filepath):
if filepath.endswith(".yaml"):
with open(filepath, encoding="utf-8", mode="r") as file:
return cls(values=yaml.safe_load(file))
else:
raise NotImplementedError(Path(filepath).suffix)

def __getattr__(self, name):
keys = name.split(".")
value = self.values
for key in keys:
if key in value:
value = value[key]
else:
return None
if isinstance(value, dict):
return Config(values=value)
return value

def __getitem__(self, key):
value = self.values[key]
if isinstance(value, dict):
return Config(values=value)
return value

def __contains__(self, key):
return key in self.values

def num_data_vars(self):
"""Return the number of data variables for a given key."""
return len(self.dataset.var_names)

@functools.cached_property
def coords_projection(self):
"""Return the projection."""
proj_config = self.values["projection"]
proj_class_name = proj_config["class"]
proj_class = getattr(ccrs, proj_class_name)
proj_params = proj_config.get("kwargs", {})
return proj_class(**proj_params)
8 changes: 4 additions & 4 deletions neural_lam/data_config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
dataset:
name: meps_example
vars:
var_names:
- pres_0g
- pres_0s
- nlwrs_0
Expand All @@ -18,7 +18,7 @@ dataset:
- wvint_0
- z_1000
- z_500
units:
var_units:
- Pa
- Pa
- r"$\mathrm{W}/\mathrm{m}^2$"
Expand All @@ -36,7 +36,7 @@ dataset:
- r"$\mathrm{kg}/\mathrm{m}^2$"
- r"$\mathrm{m}^2/\mathrm{s}^2$"
- r"$\mathrm{m}^2/\mathrm{s}^2$"
var_names:
var_longnames:
- pres_heightAboveGround_0_instant
- pres_heightAboveSea_0_instant
- nlwrs_heightAboveGround_0_accum
Expand All @@ -54,7 +54,7 @@ dataset:
- wvint_entireAtmosphere_0_instant
- z_isobaricInhPa_1000_instant
- z_isobaricInhPa_500_instant
forcing_dim: 16
num_forcing_features: 16
grid_shape_state: [268, 238]
projection:
class: LambertConformal
Expand Down
24 changes: 12 additions & 12 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import wandb

# First-party
from neural_lam import metrics, utils, vis
from neural_lam import config, metrics, utils, vis


class ARModel(pl.LightningModule):
Expand All @@ -25,7 +25,7 @@ def __init__(self, args):
super().__init__()
self.save_hyperparameters()
self.args = args
self.config_loader = utils.ConfigLoader(args.data_config)
self.config_loader = config.Config.from_file(args.data_config)

# Load static features for grid/data
static_data_dict = utils.load_static_data(
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(self, args):
self.grid_dim = (
2 * self.config_loader.num_data_vars()
+ grid_static_dim
+ self.config_loader.dataset.forcing_dim
+ self.config_loader.dataset.num_forcing_features
)

# Instantiate loss function
Expand Down Expand Up @@ -246,7 +246,7 @@ def validation_step(self, batch, batch_idx):
# Log loss per time step forward and mean
val_log_dict = {
f"val_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_log
for step in self.args.val_steps_to_log
}
val_log_dict["val_mean_loss"] = mean_loss
self.log_dict(
Expand Down Expand Up @@ -294,7 +294,7 @@ def test_step(self, batch, batch_idx):
# Log loss per time step forward and mean
test_log_dict = {
f"test_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_log
for step in self.args.val_steps_to_log
}
test_log_dict["test_mean_loss"] = mean_loss

Expand Down Expand Up @@ -329,7 +329,7 @@ def test_step(self, batch, batch_idx):
prediction, target, pred_std, average_grid=False
) # (B, pred_steps, num_grid_nodes)
log_spatial_losses = spatial_loss[
:, [step - 1 for step in self.args.val_steps_log]
:, [step - 1 for step in self.args.val_steps_to_log]
]
self.spatial_loss_maps.append(log_spatial_losses)
# (B, N_log, num_grid_nodes)
Expand Down Expand Up @@ -408,8 +408,8 @@ def plot_examples(self, batch, n_examples, prediction=None):
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
self.config_loader.dataset.vars,
self.config_loader.dataset.units,
self.config_loader.dataset.var_names,
self.config_loader.dataset.var_units,
var_vranges,
)
)
Expand All @@ -420,7 +420,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
for var_name, fig in zip(
self.config_loader.dataset.vars, var_figs
self.config_loader.dataset.var_names, var_figs
)
}
)
Expand Down Expand Up @@ -476,7 +476,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
# Check if metrics are watched, log exact values for specific vars
if full_log_name in self.args.metrics_watch:
for var_i, timesteps in self.args.var_leads_metrics_watch.items():
var = self.config_loader.dataset.vars[var_i]
var = self.config_loader.dataset.var_nums[var_i]
log_dict.update(
{
f"{full_log_name}_{var}_step_{step}": metric_tensor[
Expand Down Expand Up @@ -549,7 +549,7 @@ def on_test_epoch_end(self):
title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
)
for t_i, loss_map in zip(
self.args.val_steps_log, mean_spatial_loss
self.args.val_steps_to_log, mean_spatial_loss
)
]

Expand All @@ -566,7 +566,7 @@ def on_test_epoch_end(self):
]
pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
os.makedirs(pdf_loss_maps_dir, exist_ok=True)
for t_i, fig in zip(self.args.val_steps_log, pdf_loss_map_figs):
for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
# save mean spatial loss as .pt file also
torch.save(
Expand Down
56 changes: 0 additions & 56 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
import os

# Third-party
import cartopy.crs as ccrs
import numpy as np
import torch
import yaml
from torch import nn
from tueplots import bundles, figsizes

Expand Down Expand Up @@ -270,57 +268,3 @@ def init_wandb_metrics(wandb_logger, val_steps):
experiment.define_metric("val_mean_loss", summary="min")
for step in val_steps:
experiment.define_metric(f"val_loss_unroll{step}", summary="min")


class ConfigLoader:
"""
Class for loading configuration files.
This class loads a YAML configuration file and provides a way to access
its values as attributes.
"""

def __init__(self, config_path, values=None):
self.config_path = config_path
if values is None:
self.values = self.load_config()
else:
self.values = values

def load_config(self):
"""Load configuration file."""
with open(self.config_path, encoding="utf-8", mode="r") as file:
return yaml.safe_load(file)

def __getattr__(self, name):
keys = name.split(".")
value = self.values
for key in keys:
if key in value:
value = value[key]
else:
return None
if isinstance(value, dict):
return ConfigLoader(None, values=value)
return value

def __getitem__(self, key):
value = self.values[key]
if isinstance(value, dict):
return ConfigLoader(None, values=value)
return value

def __contains__(self, key):
return key in self.values

def num_data_vars(self):
"""Return the number of data variables for a given key."""
return len(self.dataset.vars)

def projection(self):
"""Return the projection."""
proj_config = self.values["projection"]
proj_class_name = proj_config["class"]
proj_class = getattr(ccrs, proj_class_name)
proj_params = proj_config.get("kwargs", {})
return proj_class(**proj_params)
7 changes: 4 additions & 3 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def plot_error_map(errors, data_config, title=None, step_length=3):
y_ticklabels = [
f"{name} ({unit})"
for name, unit in zip(
data_config.dataset.vars, data_config.dataset.units
data_config.dataset.var_names, data_config.dataset.var_units
)
]
ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
Expand Down Expand Up @@ -87,7 +87,7 @@ def plot_prediction(
1,
2,
figsize=(13, 7),
subplot_kw={"projection": data_config.projection()},
subplot_kw={"projection": data_config.coords_projection()},
)

# Plot pred and target
Expand Down Expand Up @@ -135,7 +135,8 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
) # Faded border region

fig, ax = plt.subplots(
figsize=(5, 4.8), subplot_kw={"projection": data_config.projection()}
figsize=(5, 4.8),
subplot_kw={"projection": data_config.coords_projection()},
)

ax.coastlines() # Add coastline outlines
Expand Down
4 changes: 2 additions & 2 deletions plot_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch_geometric as pyg

# First-party
from neural_lam import utils
from neural_lam import config, utils

MESH_HEIGHT = 0.1
MESH_LEVEL_DIST = 0.2
Expand Down Expand Up @@ -44,7 +44,7 @@ def main():
)

args = parser.parse_args()
config_loader = utils.ConfigLoader(args.data_config)
config_loader = config.Config.from_file(args.data_config)

# Load graph data
hierarchical, graph_ldict = utils.load_graph(args.graph)
Expand Down
Loading

0 comments on commit 0ba441b

Please sign in to comment.