diff --git a/README.md b/README.md
index fc5675e8..ba0bb3fe 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,7 @@ Still, some restrictions are inevitable:
 ## A note on the limited area setting
 Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
 There are still some parts of the code that is quite specific for the MEPS area use case.
-This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/data_config.yaml`).
+This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified with the `--data_config` argument to `train_model.py`).
 If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
 We would be happy to support such enhancements.
 See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
diff --git a/create_grid_features.py b/create_grid_features.py
index e5b9c49a..c3714368 100644
--- a/create_grid_features.py
+++ b/create_grid_features.py
@@ -7,7 +7,7 @@
 import torch
 
 # First-party
-from neural_lam import utils
+from neural_lam import config
 
 
 def main():
@@ -22,7 +22,7 @@ def main():
         help="Path to data config file (default: neural_lam/data_config.yaml)",
     )
     args = parser.parse_args()
-    config_loader = utils.ConfigLoader(args.data_config)
+    config_loader = config.Config.from_file(args.data_config)
 
     static_dir_path = os.path.join("data", config_loader.dataset.name, "static")
 
diff --git a/create_mesh.py b/create_mesh.py
index 477ddf55..f04b4d4b 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -13,7 +13,7 @@
 from torch_geometric.utils.convert import from_networkx
 
 # First-party
-from neural_lam import utils
+from neural_lam import config
 
 
 def plot_graph(graph, title=None):
@@ -189,7 +189,7 @@ def main():
     args = parser.parse_args()
 
     # Load grid positions
-    config_loader = utils.ConfigLoader(args.data_config)
+    config_loader = config.Config.from_file(args.data_config)
     static_dir_path = os.path.join("data", config_loader.dataset.name, "static")
     graph_dir_path = os.path.join("graphs", args.graph)
     os.makedirs(graph_dir_path, exist_ok=True)
diff --git a/create_parameter_weights.py b/create_parameter_weights.py
index fd8c38cd..cae1ae3e 100644
--- a/create_parameter_weights.py
+++ b/create_parameter_weights.py
@@ -8,7 +8,7 @@
 from tqdm import tqdm
 
 # First-party
-from neural_lam import utils
+from neural_lam import config
 from neural_lam.weather_dataset import WeatherDataset
 
 
@@ -43,7 +43,7 @@ def main():
     )
     args = parser.parse_args()
 
-    config_loader = utils.ConfigLoader(args.data_config)
+    config_loader = config.Config.from_file(args.data_config)
     static_dir_path = os.path.join("data", config_loader.dataset.name, "static")
 
     # Create parameter weights based on height
@@ -57,7 +57,10 @@ def main():
         "500": 0.03,
     }
     w_list = np.array(
-        [w_dict[par.split("_")[-2]] for par in config_loader.dataset.var_names]
+        [
+            w_dict[par.split("_")[-2]]
+            for par in config_loader.dataset.var_longnames
+        ]
    )
     print("Saving parameter weights...")
     np.save(
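The weighting hunk above keys `w_dict` on the second-to-last underscore-separated token of each entry in `var_longnames`, which is the vertical level. A minimal sketch of that lookup (the weights for levels `"0"` and `"1000"` are illustrative assumptions; only `"500": 0.03` is visible in the diff):

```python
# Level-token lookup as in create_parameter_weights.py.
# Weights for "0" and "1000" are assumed; "500": 0.03 appears in the diff.
w_dict = {"0": 1.0, "500": 0.03, "1000": 0.1}

var_longnames = [
    "pres_heightAboveGround_0_instant",
    "z_isobaricInhPa_1000_instant",
    "z_isobaricInhPa_500_instant",
]

for par in var_longnames:
    level = par.split("_")[-2]  # e.g. "z_isobaricInhPa_500_instant" -> "500"
    print(f"{par}: level={level}, weight={w_dict[level]}")
```
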
diff --git a/neural_lam/config.py b/neural_lam/config.py
new file mode 100644
index 00000000..e758e09c
--- /dev/null
+++ b/neural_lam/config.py
@@ -0,0 +1,59 @@
+import functools
+from pathlib import Path
+
+import cartopy.crs as ccrs
+import yaml
+
+
+class Config:
+    """
+    Class for loading configuration files.
+
+    This class loads a configuration file and provides a way to access its
+    values as attributes.
+    """
+
+    def __init__(self, values):
+        self.values = values
+
+    @classmethod
+    def from_file(cls, filepath):
+        if filepath.endswith(".yaml"):
+            with open(filepath, encoding="utf-8", mode="r") as file:
+                return cls(values=yaml.safe_load(file))
+        else:
+            raise NotImplementedError(Path(filepath).suffix)
+
+    def __getattr__(self, name):
+        keys = name.split(".")
+        value = self.values
+        for key in keys:
+            if key in value:
+                value = value[key]
+            else:
+                return None
+        if isinstance(value, dict):
+            return Config(values=value)
+        return value
+
+    def __getitem__(self, key):
+        value = self.values[key]
+        if isinstance(value, dict):
+            return Config(values=value)
+        return value
+
+    def __contains__(self, key):
+        return key in self.values
+
+    def num_data_vars(self):
+        """Return the number of data variables."""
+        return len(self.dataset.var_names)
+
+    @functools.cached_property
+    def coords_projection(self):
+        """Return the cartopy projection specified in the config."""
+        proj_config = self.values["projection"]
+        proj_class_name = proj_config["class"]
+        proj_class = getattr(ccrs, proj_class_name)
+        proj_params = proj_config.get("kwargs", {})
+        return proj_class(**proj_params)
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 213825ff..f16a4a30 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -1,6 +1,6 @@
 dataset:
   name: meps_example
-  vars:
+  var_names:
   - pres_0g
   - pres_0s
   - nlwrs_0
@@ -18,7 +18,7 @@ dataset:
   - wvint_0
   - z_1000
   - z_500
-  units:
+  var_units:
   - Pa
   - Pa
   - r"$\mathrm{W}/\mathrm{m}^2$"
@@ -36,7 +36,7 @@ dataset:
   - r"$\mathrm{kg}/\mathrm{m}^2$"
   - r"$\mathrm{m}^2/\mathrm{s}^2$"
   - r"$\mathrm{m}^2/\mathrm{s}^2$"
-  var_names:
+  var_longnames:
   - pres_heightAboveGround_0_instant
   - pres_heightAboveSea_0_instant
   - nlwrs_heightAboveGround_0_accum
@@ -54,7 +54,7 @@ dataset:
   - wvint_entireAtmosphere_0_instant
   - z_isobaricInhPa_1000_instant
   - z_isobaricInhPa_500_instant
-  forcing_dim: 16
+  num_forcing_features: 16
   grid_shape_state: [268, 238]
 projection:
   class: LambertConformal
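The renamed YAML keys pair with the new `Config` class above. A minimal usage sketch, assuming the `meps_example` config shipped in the repo (the variable count of 17 comes from the full example file; only part of the list is visible in the hunks):

```python
# Sketch of Config usage with the renamed keys in data_config.yaml.
from neural_lam import config

config_loader = config.Config.from_file("neural_lam/data_config.yaml")

print(config_loader.dataset.name)          # "meps_example"
print(config_loader.num_data_vars())       # 17 in the full example config
print(config_loader.dataset.var_units[0])  # "Pa"

# Nested dicts come back wrapped in Config; missing keys return None.
print(config_loader.dataset.not_a_key)     # None (hypothetical key)

# coords_projection is a functools.cached_property: attribute access,
# no call parentheses; built once, then reused.
proj = config_loader.coords_projection     # a ccrs.LambertConformal instance
```

Note that `Config.from_file` only handles `.yaml` paths and raises `NotImplementedError` for any other suffix.
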
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index da2654f0..9cda9fc2 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -9,7 +9,7 @@
 import wandb
 
 # First-party
-from neural_lam import metrics, utils, vis
+from neural_lam import config, metrics, utils, vis
 
 
 class ARModel(pl.LightningModule):
@@ -25,7 +25,7 @@ def __init__(self, args):
         super().__init__()
         self.save_hyperparameters()
         self.args = args
-        self.config_loader = utils.ConfigLoader(args.data_config)
+        self.config_loader = config.Config.from_file(args.data_config)
 
         # Load static features for grid/data
         static_data_dict = utils.load_static_data(
@@ -61,7 +61,7 @@ def __init__(self, args):
         self.grid_dim = (
             2 * self.config_loader.num_data_vars()
             + grid_static_dim
-            + self.config_loader.dataset.forcing_dim
+            + self.config_loader.dataset.num_forcing_features
         )
 
         # Instantiate loss function
@@ -246,7 +246,7 @@ def validation_step(self, batch, batch_idx):
         # Log loss per time step forward and mean
         val_log_dict = {
             f"val_loss_unroll{step}": time_step_loss[step - 1]
-            for step in self.args.val_steps_log
+            for step in self.args.val_steps_to_log
         }
         val_log_dict["val_mean_loss"] = mean_loss
         self.log_dict(
@@ -294,7 +294,7 @@ def test_step(self, batch, batch_idx):
         # Log loss per time step forward and mean
         test_log_dict = {
             f"test_loss_unroll{step}": time_step_loss[step - 1]
-            for step in self.args.val_steps_log
+            for step in self.args.val_steps_to_log
         }
         test_log_dict["test_mean_loss"] = mean_loss
 
@@ -329,7 +329,7 @@ def test_step(self, batch, batch_idx):
             prediction, target, pred_std, average_grid=False
         )  # (B, pred_steps, num_grid_nodes)
         log_spatial_losses = spatial_loss[
-            :, [step - 1 for step in self.args.val_steps_log]
+            :, [step - 1 for step in self.args.val_steps_to_log]
         ]
         self.spatial_loss_maps.append(log_spatial_losses)
         # (B, N_log, num_grid_nodes)
@@ -408,8 +408,8 @@ def plot_examples(self, batch, n_examples, prediction=None):
             )
             for var_i, (var_name, var_unit, var_vrange) in enumerate(
                 zip(
-                    self.config_loader.dataset.vars,
-                    self.config_loader.dataset.units,
+                    self.config_loader.dataset.var_names,
+                    self.config_loader.dataset.var_units,
                     var_vranges,
                 )
             )
@@ -420,7 +420,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
                 {
                     f"{var_name}_example_{example_i}": wandb.Image(fig)
                     for var_name, fig in zip(
-                        self.config_loader.dataset.vars, var_figs
+                        self.config_loader.dataset.var_names, var_figs
                     )
                 }
             )
@@ -476,7 +476,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
         # Check if metrics are watched, log exact values for specific vars
         if full_log_name in self.args.metrics_watch:
             for var_i, timesteps in self.args.var_leads_metrics_watch.items():
-                var = self.config_loader.dataset.vars[var_i]
+                var = self.config_loader.dataset.var_names[var_i]
                 log_dict.update(
                     {
                         f"{full_log_name}_{var}_step_{step}": metric_tensor[
@@ -549,7 +549,7 @@ def on_test_epoch_end(self):
                 title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
             )
             for t_i, loss_map in zip(
-                self.args.val_steps_log, mean_spatial_loss
+                self.args.val_steps_to_log, mean_spatial_loss
             )
         ]
 
@@ -566,7 +566,7 @@ def on_test_epoch_end(self):
         ]
         pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
         os.makedirs(pdf_loss_maps_dir, exist_ok=True)
-        for t_i, fig in zip(self.args.val_steps_log, pdf_loss_map_figs):
+        for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
             fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
         # save mean spatial loss as .pt file also
         torch.save(
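For the `grid_dim` hunk: the per-node input concatenates two state time steps (hence the factor 2), static grid features and forcing features. A back-of-envelope check with the example config; `grid_static_dim` is read from the loaded static data, not the config, so the value 4 here is an assumption:

```python
# Sizes from neural_lam/data_config.yaml; grid_static_dim is assumed.
num_data_vars = 17          # len(var_names) in the example config
grid_static_dim = 4         # assumption: comes from static features on disk
num_forcing_features = 16   # num_forcing_features in the example config

grid_dim = 2 * num_data_vars + grid_static_dim + num_forcing_features
print(grid_dim)  # 54 under these assumptions
```
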
- """ - - def __init__(self, config_path, values=None): - self.config_path = config_path - if values is None: - self.values = self.load_config() - else: - self.values = values - - def load_config(self): - """Load configuration file.""" - with open(self.config_path, encoding="utf-8", mode="r") as file: - return yaml.safe_load(file) - - def __getattr__(self, name): - keys = name.split(".") - value = self.values - for key in keys: - if key in value: - value = value[key] - else: - return None - if isinstance(value, dict): - return ConfigLoader(None, values=value) - return value - - def __getitem__(self, key): - value = self.values[key] - if isinstance(value, dict): - return ConfigLoader(None, values=value) - return value - - def __contains__(self, key): - return key in self.values - - def num_data_vars(self): - """Return the number of data variables for a given key.""" - return len(self.dataset.vars) - - def projection(self): - """Return the projection.""" - proj_config = self.values["projection"] - proj_class_name = proj_config["class"] - proj_class = getattr(ccrs, proj_class_name) - proj_params = proj_config.get("kwargs", {}) - return proj_class(**proj_params) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 7a4d3730..2b6abf15 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -51,7 +51,7 @@ def plot_error_map(errors, data_config, title=None, step_length=3): y_ticklabels = [ f"{name} ({unit})" for name, unit in zip( - data_config.dataset.vars, data_config.dataset.units + data_config.dataset.var_names, data_config.dataset.var_units ) ] ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) @@ -87,7 +87,7 @@ def plot_prediction( 1, 2, figsize=(13, 7), - subplot_kw={"projection": data_config.projection()}, + subplot_kw={"projection": data_config.coords_projection()}, ) # Plot pred and target @@ -135,7 +135,8 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): ) # Faded border region fig, ax = plt.subplots( - figsize=(5, 4.8), subplot_kw={"projection": data_config.projection()} + figsize=(5, 4.8), + subplot_kw={"projection": data_config.coords_projection()}, ) ax.coastlines() # Add coastline outlines diff --git a/plot_graph.py b/plot_graph.py index 0670963f..40b2b41d 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -7,7 +7,7 @@ import torch_geometric as pyg # First-party -from neural_lam import utils +from neural_lam import config, utils MESH_HEIGHT = 0.1 MESH_LEVEL_DIST = 0.2 @@ -44,7 +44,7 @@ def main(): ) args = parser.parse_args() - config_loader = utils.ConfigLoader(args.data_config) + config_loader = config.Config.from_file(args.data_config) # Load graph data hierarchical, graph_ldict = utils.load_graph(args.graph) diff --git a/train_model.py b/train_model.py index da109fdf..390da6d4 100644 --- a/train_model.py +++ b/train_model.py @@ -9,7 +9,7 @@ from lightning_fabric.utilities import seed # First-party -from neural_lam import utils +from neural_lam import config, utils from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM from neural_lam.models.hi_lam_parallel import HiLAMParallel @@ -189,7 +189,7 @@ def main(): help="Wandb project name (default: neural_lam)", ) parser.add_argument( - "--val_steps_log", + "--val_steps_to_log", type=list, default=[1, 2, 3, 5, 10, 15, 19], help="Steps to log val loss for (default: [1, 2, 3, 5, 10, 15, 19])", @@ -208,7 +208,7 @@ def main(): ) args = parser.parse_args() - config_loader = utils.ConfigLoader(args.data_config) + config_loader = 
diff --git a/plot_graph.py b/plot_graph.py
index 0670963f..40b2b41d 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -7,7 +7,7 @@
 import torch_geometric as pyg
 
 # First-party
-from neural_lam import utils
+from neural_lam import config, utils
 
 MESH_HEIGHT = 0.1
 MESH_LEVEL_DIST = 0.2
@@ -44,7 +44,7 @@ def main():
     )
     args = parser.parse_args()
 
-    config_loader = utils.ConfigLoader(args.data_config)
+    config_loader = config.Config.from_file(args.data_config)
 
     # Load graph data
     hierarchical, graph_ldict = utils.load_graph(args.graph)
diff --git a/train_model.py b/train_model.py
index da109fdf..390da6d4 100644
--- a/train_model.py
+++ b/train_model.py
@@ -9,7 +9,7 @@
 from lightning_fabric.utilities import seed
 
 # First-party
-from neural_lam import utils
+from neural_lam import config, utils
 from neural_lam.models.graph_lam import GraphLAM
 from neural_lam.models.hi_lam import HiLAM
 from neural_lam.models.hi_lam_parallel import HiLAMParallel
@@ -189,7 +189,7 @@ def main():
         help="Wandb project name (default: neural_lam)",
     )
     parser.add_argument(
-        "--val_steps_log",
+        "--val_steps_to_log",
         type=list,
         default=[1, 2, 3, 5, 10, 15, 19],
         help="Steps to log val loss for (default: [1, 2, 3, 5, 10, 15, 19])",
@@ -208,7 +208,7 @@ def main():
     )
     args = parser.parse_args()
 
-    config_loader = utils.ConfigLoader(args.data_config)
+    config_loader = config.Config.from_file(args.data_config)
 
     # Asserts for arguments
     assert args.model in MODELS, f"Unknown model: {args.model}"
@@ -306,7 +306,7 @@ def main():
         # Only init once, on rank 0 only
         if trainer.global_rank == 0:
             utils.init_wandb_metrics(
-                logger, args.val_steps_log
+                logger, args.val_steps_to_log
             )  # Do after wandb.init
 
     if args.eval:
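One caveat on the renamed `--val_steps_to_log` flag: `argparse` applies `type` to the raw command-line string, so `type=list` splits any CLI override into individual characters, and the list default is really only usable as-is. A standalone demonstration (not repo code):

```python
# Shows why --val_steps_to_log is normally left at its default:
# argparse applies type=list to the string of any CLI override,
# while non-string defaults are passed through untouched.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
    "--val_steps_to_log",
    type=list,
    default=[1, 2, 3, 5, 10, 15, 19],
)

print(parser.parse_args([]).val_steps_to_log)
# [1, 2, 3, 5, 10, 15, 19]  (default used unchanged)

print(parser.parse_args(["--val_steps_to_log", "135"]).val_steps_to_log)
# ['1', '3', '5']  (override split into characters)
```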