add test to check if lr_scheduler is working as expected
Jacob Mathias Schreiner committed Jan 27, 2025
1 parent 14ab353 commit 9824d82
Showing 2 changed files with 69 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/conftest.py
@@ -4,13 +4,17 @@

# Third-party
import pooch
import pytest
import yaml
from torch.utils.data import DataLoader

# First-party
from neural_lam.create_graph import create_graph_from_datastore
from neural_lam.datastore import DATASTORES, init_datastore
from neural_lam.datastore.npyfilesmeps import (
    compute_standardization_stats as compute_standardization_stats_meps,
)
from neural_lam.weather_dataset import WeatherDataset

# Local
from .dummy_datastore import DummyDatastore
@@ -106,6 +110,10 @@ def init_datastore_example(datastore_kind):
    return datastore


graph_name = "1level"


@pytest.fixture
def model_args():
    class ModelArgs:
        output_std = False
@@ -121,3 +129,25 @@ class ModelArgs:
        num_future_forcing_steps = 1

    return ModelArgs()


@pytest.fixture
def datastore():
    datastore = init_datastore_example("dummydata")
    graph_dir_path = Path(datastore.root_path) / "graph" / graph_name
    # build the graph for the dummy datastore only if it is not already on disk
    if not graph_dir_path.exists():
        create_graph_from_datastore(
            datastore=datastore,
            output_root_path=str(graph_dir_path),
            n_max_levels=1,
        )

    return datastore


@pytest.fixture
def batch(datastore):
    # a single batch drawn from a WeatherDataset built on the dummy datastore
    dataset = WeatherDataset(datastore=datastore)
    data_loader = DataLoader(dataset, batch_size=1)
    batch = next(iter(data_loader))
    return batch
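
A hypothetical consumer of these fixtures (sketch only, not part of this commit; the test name and assertion are illustrative) could look as follows. pytest injects `datastore` by fixture name, and the assertion relies only on what the fixture above guarantees, namely that a "1level" graph exists under the datastore root.

# Sketch only: a hypothetical test, not included in this commit.
from pathlib import Path


def test_datastore_graph_exists(datastore):
    # the `datastore` fixture builds the "1level" graph if it is missing
    assert (Path(datastore.root_path) / "graph" / "1level").exists()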
39 changes: 39 additions & 0 deletions tests/test_ar_model.py
@@ -0,0 +1,39 @@
# Third-party
import pytorch_lightning as pl
import torch

# First-party
from neural_lam import config as nlconfig
from neural_lam.models.ar_model import ARModel


class ARModelWithParams(ARModel):
    def __init__(self, args, datastore, config):
        super().__init__(args=args, datastore=datastore, config=config)
        # give the model at least one trainable parameter for the optimizer
        self.layer = torch.nn.Linear(1, 1)


def test_lr_scheduler_reduces_lr(model_args, datastore):
    yaml_str = """
datastore:
  kind: mdp
  config_path: ""
training:
  optimization:
    lr: 1
    lr_scheduler: ExponentialLR
    lr_scheduler_kwargs:
      gamma: 0.5
"""
    config = nlconfig.NeuralLAMConfig.from_yaml(yaml_str)

    model = ARModelWithParams(
        args=model_args, datastore=datastore, config=config
    )
    [optimizer], [lr_scheduler] = model.configure_optimizers()

    assert optimizer.param_groups[0]["lr"] == 1
    lr_scheduler.step()
    assert optimizer.param_groups[0]["lr"] == 0.5
    lr_scheduler.step()
    assert optimizer.param_groups[0]["lr"] == 0.25
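
The unpacking `[optimizer], [lr_scheduler] = model.configure_optimizers()` pins down the contract the test relies on: `configure_optimizers` must return a pair of lists built from the `training.optimization` section of the config. Below is a minimal sketch of such a method, assuming AdamW as the optimizer class and a `self._config` accessor; neither appears in this diff, so both are assumptions, not the repository's actual implementation.

import torch


def configure_optimizers(self):
    # hypothetical config accessor; the real attribute names may differ
    opt_cfg = self._config.training.optimization
    optimizer = torch.optim.AdamW(self.parameters(), lr=opt_cfg.lr)
    # resolve the scheduler class by name, e.g. "ExponentialLR"
    scheduler_cls = getattr(torch.optim.lr_scheduler, opt_cfg.lr_scheduler)
    scheduler = scheduler_cls(optimizer, **opt_cfg.lr_scheduler_kwargs)
    return [optimizer], [scheduler]

With lr=1 and gamma=0.5, ExponentialLR multiplies the learning rate by gamma on every step, so the three asserts trace 1.0 -> 0.5 -> 0.25. The new test can be run in isolation with `pytest tests/test_ar_model.py::test_lr_scheduler_reduces_lr`.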
