From 9824d82da04f0f66a1eb51652636508c60f4dbac Mon Sep 17 00:00:00 2001
From: Jacob Mathias Schreiner
Date: Mon, 27 Jan 2025 15:30:40 +0100
Subject: [PATCH] add test to check if lr_scheduler is working as expected

---
 tests/conftest.py      | 30 ++++++++++++++++++++++++++++++
 tests/test_ar_model.py | 39 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 69 insertions(+)
 create mode 100644 tests/test_ar_model.py

diff --git a/tests/conftest.py b/tests/conftest.py
index 63efcce1..2b9e7e8c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,13 +4,17 @@
 
 # Third-party
 import pooch
+import pytest
 import yaml
+from torch.utils.data import DataLoader
 
 # First-party
+from neural_lam.create_graph import create_graph_from_datastore
 from neural_lam.datastore import DATASTORES, init_datastore
 from neural_lam.datastore.npyfilesmeps import (
     compute_standardization_stats as compute_standardization_stats_meps,
 )
+from neural_lam.weather_dataset import WeatherDataset
 
 # Local
 from .dummy_datastore import DummyDatastore
@@ -106,6 +110,10 @@ def init_datastore_example(datastore_kind):
     return datastore
 
 
+graph_name = "1level"
+
+
+@pytest.fixture
 def model_args():
     class ModelArgs:
         output_std = False
@@ -121,3 +129,25 @@ class ModelArgs:
         num_future_forcing_steps = 1
 
     return ModelArgs()
+
+
+@pytest.fixture
+def datastore():
+    datastore = init_datastore_example("dummydata")
+    graph_dir_path = Path(datastore.root_path) / "graph" / graph_name
+    if not graph_dir_path.exists():
+        create_graph_from_datastore(
+            datastore=datastore,
+            output_root_path=str(graph_dir_path),
+            n_max_levels=1,
+        )
+
+    return datastore
+
+
+@pytest.fixture
+def batch(datastore):
+    dataset = WeatherDataset(datastore=datastore)
+    data_loader = DataLoader(dataset, batch_size=1)
+    batch = next(iter(data_loader))
+    return batch
diff --git a/tests/test_ar_model.py b/tests/test_ar_model.py
new file mode 100644
index 00000000..b6d02d88
--- /dev/null
+++ b/tests/test_ar_model.py
@@ -0,0 +1,39 @@
+# Third-party
+import pytorch_lightning as pl
+import torch
+
+# First-party
+from neural_lam import config as nlconfig
+from neural_lam.models.ar_model import ARModel
+
+
+class ARModelWithParams(ARModel):
+    def __init__(self, args, datastore, config):
+        super().__init__(args=args, datastore=datastore, config=config)
+        self.layer = torch.nn.Linear(1, 1)
+
+
+def test_lr_scheduler_reduces_lr(model_args, datastore):
+    yaml_str = """
+    datastore:
+        kind: mdp
+        config_path: ""
+    training:
+        optimization:
+            lr: 1
+            lr_scheduler: ExponentialLR
+            lr_scheduler_kwargs:
+                gamma: 0.5
+    """
+    config = nlconfig.NeuralLAMConfig.from_yaml(yaml_str)
+
+    model = ARModelWithParams(
+        args=model_args, datastore=datastore, config=config
+    )
+    [optimizer], [lr_scheduler] = model.configure_optimizers()
+
+    assert optimizer.param_groups[0]["lr"] == 1
+    lr_scheduler.step()
+    assert optimizer.param_groups[0]["lr"] == 0.5
+    lr_scheduler.step()
+    assert optimizer.param_groups[0]["lr"] == 0.25
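
Note for reviewers: the test unpacks configure_optimizers() as ([optimizer], [lr_scheduler]) and expects both the learning rate and the scheduler to be taken from the training.optimization block of the config. The patch does not touch ARModel itself; the sketch below is only an illustration of an implementation that would satisfy the test. The class name SketchModel, the self._config attribute, and the use of AdamW are hypothetical and not taken from neural_lam.

# Minimal sketch (NOT part of the patch) of the configure_optimizers
# behaviour the test assumes. Attribute names such as self._config and the
# choice of AdamW are illustrative assumptions only.
import pytorch_lightning as pl
import torch


class SketchModel(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self._config = config  # assumed: the model keeps the parsed config
        self.layer = torch.nn.Linear(1, 1)

    def configure_optimizers(self):
        opt_cfg = self._config.training.optimization
        optimizer = torch.optim.AdamW(self.parameters(), lr=opt_cfg.lr)

        if opt_cfg.lr_scheduler is None:
            return [optimizer], []

        # Resolve the scheduler class by name, e.g. "ExponentialLR", and
        # pass through any keyword arguments from the config.
        scheduler_cls = getattr(torch.optim.lr_scheduler, opt_cfg.lr_scheduler)
        lr_scheduler = scheduler_cls(
            optimizer, **(opt_cfg.lr_scheduler_kwargs or {})
        )

        return [optimizer], [lr_scheduler]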