diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index ec3b2b6..52ff542 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -2,7 +2,7 @@
 import itertools
 from collections import OrderedDict
-from typing import Any, Dict, Hashable, Iterator
+from typing import Any, Callable, Dict, Hashable, Iterator, Optional
 
 import xarray as xr
 
@@ -111,6 +111,8 @@ def __init__(
         batch_dims: Dict[Hashable, int] = {},
         concat_input_dims: bool = False,
         preload_batch: bool = True,
+        cache: Optional[Dict[str, Any]] = None,
+        cache_preprocess: Optional[Callable] = None,
     ):
 
         self.ds = ds
@@ -120,6 +122,9 @@
         self.batch_dims = OrderedDict(batch_dims)
         self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch
+        self.cache = cache
+        self.cache_preprocess = cache_preprocess
+
         self._batches: Dict[int, Any] = self._gen_batches()  # dict cache for batches
 
     def __iter__(self) -> Iterator[xr.Dataset]:
@@ -139,6 +144,9 @@ def __getitem__(self, idx: int) -> xr.Dataset:
         if idx < 0:
             idx = list(self._batches)[idx]
 
+        if self.cache and self._batch_in_cache(idx):
+            return self._get_cached_batch(idx)
+
         if idx in self._batches:
 
             if self.concat_input_dims:
@@ -153,15 +161,34 @@
                 ]
                 dsc = xr.concat(all_dsets, dim="input_batch")
                 new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
-                return _maybe_stack_batch_dims(dsc, new_input_dims)
+                batch = _maybe_stack_batch_dims(dsc, new_input_dims)
             else:
-                return _maybe_stack_batch_dims(
+                batch = _maybe_stack_batch_dims(
                     self.ds.isel(**self._batches[idx]), list(self.input_dims)
                 )
         else:
             raise IndexError("list index out of range")
 
+        if self.cache is not None and self.cache_preprocess is not None:
+            batch = self.cache_preprocess(batch)
+        if self.cache is not None:
+            self._cache_batch(idx, batch)
+
+        return batch
+
+    def _batch_in_cache(self, idx: int) -> bool:
+        return self.cache is not None and f"{idx}/.zgroup" in self.cache
+
+    def _cache_batch(self, idx: int, batch: xr.Dataset) -> None:
+        batch.to_zarr(self.cache, group=str(idx), mode="a")
+
+    def _get_cached_batch(self, idx: int) -> xr.Dataset:
+        ds = xr.open_zarr(self.cache, group=str(idx))
+        if self.preload_batch:
+            ds = ds.load()
+        return ds
+
     def _gen_batches(self) -> dict:
         # in the future, we will want to do the batch generation lazily
         # going the eager route for now is allowing me to fill out the loader api
diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py
index c51700b..cc115b8 100644
--- a/xbatcher/tests/test_generators.py
+++ b/xbatcher/tests/test_generators.py
@@ -1,3 +1,5 @@
+from typing import Any, Dict
+
 import numpy as np
 import pytest
 import xarray as xr
@@ -222,3 +224,55 @@ def test_batch_exceptions(sample_ds_1d):
     with pytest.raises(ValueError) as e:
         BatchGenerator(sample_ds_1d, input_dims={"x": 10}, input_overlap={"x": 20})
     assert len(e) == 1
+
+
+@pytest.mark.parametrize("preload", [True, False])
+def test_batcher_cached_getitem(sample_ds_1d, preload) -> None:
+    pytest.importorskip("zarr")
+    cache: Dict[str, Any] = {}
+
+    def preproc(ds):
+        processed = ds.load().chunk(-1)
+        processed.attrs["foo"] = "bar"
+        return processed
+
+    bg = BatchGenerator(
+        sample_ds_1d,
+        input_dims={"x": 10},
+        cache=cache,
+        cache_preprocess=preproc,
+        preload_batch=preload,
+    )
+
+    # first batch
+    assert bg[0].dims["x"] == 10
+    ds_no_cache = bg[1]
+    # last batch
+    assert bg[-1].dims["x"] == 10
+
+    assert "0/.zgroup" in cache
+
+    # now from cache
+    # first batch
+    assert bg[0].dims["x"] == 10
+    # last batch
+    assert bg[-1].dims["x"] == 10
+    ds_cache = bg[1]
+
+    assert ds_no_cache.attrs["foo"] == "bar"
+    assert ds_cache.attrs["foo"] == "bar"
+
+    xr.testing.assert_equal(ds_no_cache, ds_cache)
+    xr.testing.assert_identical(ds_no_cache, ds_cache)
+
+    # without preprocess func
+    bg = BatchGenerator(
+        sample_ds_1d, input_dims={"x": 10}, cache=cache, preload_batch=preload
+    )
+    assert bg.cache_preprocess is None
+    assert bg[0].dims["x"] == 10
+    ds_no_cache = bg[1]
+    assert "1/.zgroup" in cache
+    ds_cache = bg[1]
+    xr.testing.assert_equal(ds_no_cache, ds_cache)
+    xr.testing.assert_identical(ds_no_cache, ds_cache)
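
Not part of the patch: a minimal usage sketch of the new cache / cache_preprocess keywords, assuming an in-memory dict as the zarr store and a synthetic dataset (the "foo" variable and the preprocess helper below are illustrative only, mirroring what the new test exercises).

import numpy as np
import xarray as xr
from xbatcher import BatchGenerator

ds = xr.Dataset({"foo": ("x", np.arange(100))})

cache = {}  # any zarr-compatible mutable mapping can hold the cached batches

def preprocess(batch):
    # called once per batch, just before the batch is written to the cache
    batch.attrs["normalized"] = True
    return batch

bg = BatchGenerator(
    ds,
    input_dims={"x": 10},
    cache=cache,
    cache_preprocess=preprocess,
)

first = bg[0]   # generated from ds, then written to the cache as zarr group "0"
again = bg[0]   # subsequent accesses are read back via xr.open_zarr
assert "0/.zgroup" in cache

Because the preprocess function runs before the batch is cached, repeated accesses return the already-processed result without recomputing it.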