Skip to content

Commit

Permalink
test for stats
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Jul 23, 2024
1 parent 9b88160 commit a9fdad5
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 86 deletions.
6 changes: 3 additions & 3 deletions neural_lam/datastore/multizarr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def __getattr__(self, name):
keys = name.split(".")
value = self.values
for key in keys:
if key in value:
try:
value = value[key]
else:
return None
except KeyError:
raise AttributeError(f"Key '{key}' not found in {value}")
if isinstance(value, dict):
return Config(values=value)
return value
Expand Down
9 changes: 6 additions & 3 deletions neural_lam/datastore/multizarr/create_auxiliary_forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,14 @@ def calculate_datetime_forcing(da_time: xr.DataArray):

def main():
"""Main function for creating the datetime forcing and boundary mask."""
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
description="Create the datetime forcing for neural LAM.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--data-config",
"data_config",
type=str,
default="tests/datastore_configs/multizarr.danra.yaml",
help="Path to data config file",
)
parser.add_argument(
"--zarr_path",
Expand Down
61 changes: 0 additions & 61 deletions neural_lam/datastore/multizarr/create_grid_features.py

This file was deleted.

21 changes: 11 additions & 10 deletions neural_lam/datastore/multizarr/create_normalization_stats.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
# Standard library
from argparse import ArgumentParser
import argparse

# Third-party
import xarray as xr

# First-party
from neural_lam.datastore.multizarr import MultiZarrDatastore

DEFAULT_PATH = "tests/datastore_configs/multizarr.danra.yaml"


def compute_stats(da):
mean = da.mean(dim=("time", "grid_index"))
Expand All @@ -17,17 +15,19 @@ def compute_stats(da):


def main():
parser = ArgumentParser(description="Training arguments")
parser = argparse.ArgumentParser(
description="Training arguments",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--data_config",
"data_config",
type=str,
default=DEFAULT_PATH,
help=f"Path to data config file (default: {DEFAULT_PATH})",
help="Path to data config file",
)
parser.add_argument(
"--zarr_path",
type=str,
default="data/normalization.zarr",
default="normalization.zarr",
help="Directory where data is stored",
)
args = parser.parse_args()
Expand All @@ -49,6 +49,7 @@ def main():
if combined_stats is not None:
for group in combined_stats:
vars_to_combine = group["vars"]

means = da_forcing_mean.sel(variable=vars_to_combine)
stds = da_forcing_std.sel(variable=vars_to_combine)

Expand Down Expand Up @@ -85,8 +86,8 @@ def main():
{
"state_mean": da_state_mean,
"state_std": da_state_std,
"diff_mean": diff_mean,
"diff_std": diff_std,
"state_diff_mean": diff_mean,
"state_diff_std": diff_std,
}
)
if da_forcing is not None:
Expand Down
18 changes: 15 additions & 3 deletions neural_lam/datastore/multizarr/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,19 @@ class MultiZarrDatastore(BaseCartesianDatastore):
DIMS_TO_KEEP = {"time", "grid_index", "variable"}

def __init__(self, config_path):
self.config_path = config_path
with open(config_path, encoding="utf-8", mode="r") as file:
self._config = yaml.safe_load(file)

def _normalize_path(self, path):
    """Resolve a path taken from the datastore config.

    A path that contains ``"://"`` (e.g. ``s3://``) or that starts with
    ``"/"`` is returned unchanged. Any other path is treated as relative
    to the directory holding the config file (``self.config_path``).

    Parameters
    ----------
    path : str
        Path string as written in the config.

    Returns
    -------
    str
        The unchanged path, or the path joined onto the config directory.
    """
    # protocol-qualified and rooted paths already identify their target
    keep_as_is = "://" in path or path.startswith("/")
    if keep_as_is:
        return path

    # everything else is anchored at the directory of the config file
    config_dir = os.path.dirname(self.config_path)
    return os.path.join(config_dir, path)

def open_zarrs(self, category):
"""Open the zarr dataset for the given category.
Expand All @@ -33,7 +43,8 @@ def open_zarrs(self, category):

datasets = []
for config in zarr_configs:
dataset_path = config["path"]
dataset_path = self._normalize_path(config["path"])

try:
dataset = xr.open_zarr(dataset_path, consolidated=True)
except Exception as e:
Expand Down Expand Up @@ -359,7 +370,7 @@ def _load_and_merge_stats(self):
for i, zarr_config in enumerate(
self._config["utilities"]["normalization"]["zarrs"]
):
stats_path = zarr_config["path"]
stats_path = self._normalize_path(zarr_config["path"])
if not os.path.exists(stats_path):
raise FileNotFoundError(
f"Normalization statistics not found at path: {stats_path}"
Expand Down Expand Up @@ -612,9 +623,10 @@ def boundary_mask(self):
xr.DataArray
The boundary mask for the dataset, with dimensions `('grid_index',)`.
"""
ds_boundary_mask = xr.open_zarr(
boundary_mask_path = self._normalize_path(
self._config["boundary"]["mask"]["path"]
)
ds_boundary_mask = xr.open_zarr(boundary_mask_path)
return ds_boundary_mask.mask.stack(grid_index=("y", "x")).reset_index(
"grid_index"
)
Expand Down
2 changes: 2 additions & 0 deletions tests/datastore_configs/mllam/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.zarr/
graph/
2 changes: 1 addition & 1 deletion tests/datastore_configs/mllam/example.danra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ output:
start: 1990-09-03T00:00
end: 1990-09-06T00:00
compute_statistics:
ops: [mean, std]
ops: [mean, std, diff_mean, diff_std]
dims: [grid_index, time]
validation:
start: 1990-09-06T00:00
Expand Down
6 changes: 3 additions & 3 deletions tests/datastore_configs/multizarr/data_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ forcing:
lat_lon_names:
lon: lon
lat: lat
- path: "tests/config_examples/multizarr/datetime_forcings.zarr"
- path: "datetime_forcings.zarr"
dims:
time: time
level: null
Expand Down Expand Up @@ -111,7 +111,7 @@ boundary:
lon: longitude
lat: latitude
mask:
path: "data/boundary_mask.zarr"
path: "boundary_mask.zarr"
dims:
x: x
y: y
Expand All @@ -126,7 +126,7 @@ boundary:
utilities:
normalization:
zarrs:
- path: "tests/datastore_configs/multizarr/normalization.zarr"
- path: "normalization.zarr"
stats_vars:
state_mean: state_mean
state_std: state_std
Expand Down
2 changes: 0 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# First-party
import neural_lam
import neural_lam.create_graph
import neural_lam.datastore.multizarr.create_grid_features
import neural_lam.train_model


Expand All @@ -10,5 +9,4 @@ def test_import():
now, eventually we should test their execution too."""
assert neural_lam is not None
assert neural_lam.create_graph is not None
assert neural_lam.datastore.multizarr.create_grid_features is not None
assert neural_lam.train_model is not None
28 changes: 28 additions & 0 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# Third-party
import cartopy.crs as ccrs
import pytest
import xarray as xr

# First-party
from neural_lam.datastore.mllam import MLLAMDatastore
Expand Down Expand Up @@ -105,3 +106,30 @@ def test_get_vars(datastore_name):
assert isinstance(units, list)
assert isinstance(names, list)
assert isinstance(num_vars, int)


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_normalization_dataarray(datastore_name):
    """Verify `datastore.get_normalization_dataarray` for each category.

    For every category ("state", "forcing", "static") the call must return
    an `xr.Dataset` that contains one `<category>_<op>` variable for each
    statistic expected for that category.
    """
    # statistics expected per category; the dict's insertion order fixes
    # the order in which the categories are queried
    expected_ops = {
        "state": ["mean", "std", "diff_mean", "diff_std"],
        "forcing": ["mean", "std"],
        "static": [],
    }
    datastore = _init_datastore(datastore_name)

    for category, ops in expected_ops.items():
        ds_stats = datastore.get_normalization_dataarray(category=category)

        # the statistics are expected to be returned as an xarray Dataset
        assert isinstance(ds_stats, xr.Dataset)

        # every expected statistic must be present as a data variable
        for op in ops:
            assert f"{category}_{op}" in ds_stats.data_vars

0 comments on commit a9fdad5

Please sign in to comment.