From c4245852ea99fb69adea56e1b3c61a619b3e82ec Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Fri, 21 Oct 2022 17:03:40 -0700 Subject: [PATCH 1/4] basic pluggable cache implementation --- xbatcher/generators.py | 30 +++++++++++++++++++++++--- xbatcher/tests/test_generators.py | 36 +++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index ec3b2b6..b569d6a 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -2,7 +2,7 @@ import itertools from collections import OrderedDict -from typing import Any, Dict, Hashable, Iterator +from typing import Any, Callable, Dict, Hashable, Iterator import xarray as xr @@ -111,6 +111,8 @@ def __init__( batch_dims: Dict[Hashable, int] = {}, concat_input_dims: bool = False, preload_batch: bool = True, + cache: dict[str, Any] | None = None, + cache_preprocess: Callable | None = None, ): self.ds = ds @@ -120,6 +122,9 @@ def __init__( self.batch_dims = OrderedDict(batch_dims) self.concat_input_dims = concat_input_dims self.preload_batch = preload_batch + self.cache = cache + self.cache_preprocess = cache_preprocess + self._batches: Dict[int, Any] = self._gen_batches() # dict cache for batches def __iter__(self) -> Iterator[xr.Dataset]: @@ -139,6 +144,9 @@ def __getitem__(self, idx: int) -> xr.Dataset: if idx < 0: idx = list(self._batches)[idx] + if self.cache and idx in self.cache: + return self._get_cached_batch(idx) + if idx in self._batches: if self.concat_input_dims: @@ -153,15 +161,31 @@ def __getitem__(self, idx: int) -> xr.Dataset: ] dsc = xr.concat(all_dsets, dim="input_batch") new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims] - return _maybe_stack_batch_dims(dsc, new_input_dims) + batch = _maybe_stack_batch_dims(dsc, new_input_dims) else: - return _maybe_stack_batch_dims( + batch = _maybe_stack_batch_dims( self.ds.isel(**self._batches[idx]), list(self.input_dims) ) else: raise IndexError("list index out of range") + if self.cache is not None: + if self.cache_preprocess is not None: + batch = self.cache_preprocess(batch) + self._cache_batch(idx, batch) + + return batch + + def _cache_batch(self, idx: int, batch: xr.Dataset): + batch.to_zarr(self.cache, group=str(idx), mode="a") + + def _get_cached_batch(self, idx: int) -> xr.Dataset: + ds = xr.open_zarr(self.cache, group=str(idx)) + if self.preload_batch: + ds = ds.load() + return ds + def _gen_batches(self) -> dict: # in the future, we will want to do the batch generation lazily # going the eager route for now is allowing me to fill out the loader api diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index c51700b..dd4f33c 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -1,3 +1,5 @@ +from typing import Any + import numpy as np import pytest import xarray as xr @@ -222,3 +224,37 @@ def test_batch_exceptions(sample_ds_1d): with pytest.raises(ValueError) as e: BatchGenerator(sample_ds_1d, input_dims={"x": 10}, input_overlap={"x": 20}) assert len(e) == 1 + + +def test_batcher_cached_getitem(sample_ds_1d) -> None: + cache: dict[str, Any] = {} + + def preproc(ds): + processed = ds.load().chunk(-1) + processed.attrs["foo"] = "bar" + return processed + + bg = BatchGenerator( + sample_ds_1d, input_dims={"x": 10}, cache=cache, cache_preprocess=preproc + ) + + # first batch + assert bg[0].dims["x"] == 10 + ds_no_cache = bg[1] + # last batch + assert bg[-1].dims["x"] == 10 + + assert "0/.zgroup" in cache + + # now from cache + # first batch + assert bg[0].dims["x"] == 10 + # last batch + assert bg[-1].dims["x"] == 10 + ds_cache = bg[1] + + assert ds_no_cache.attrs["foo"] == "bar" + assert ds_cache.attrs["foo"] == "bar" + + xr.testing.assert_equal(ds_no_cache, ds_cache) + xr.testing.assert_identical(ds_no_cache, ds_cache) From 358bd36060862e006bcc2466a0a683a2fe64aa71 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Fri, 21 Oct 2022 20:18:10 -0700 Subject: [PATCH 2/4] typing for older pythons --- xbatcher/generators.py | 10 +++++++--- xbatcher/tests/test_generators.py | 5 +++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index b569d6a..1bf7342 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -111,7 +111,7 @@ def __init__( batch_dims: Dict[Hashable, int] = {}, concat_input_dims: bool = False, preload_batch: bool = True, - cache: dict[str, Any] | None = None, + cache: Dict[str, Any] | None = None, cache_preprocess: Callable | None = None, ): @@ -144,7 +144,7 @@ def __getitem__(self, idx: int) -> xr.Dataset: if idx < 0: idx = list(self._batches)[idx] - if self.cache and idx in self.cache: + if self.cache and self._batch_in_cache(idx): return self._get_cached_batch(idx) if idx in self._batches: @@ -177,7 +177,11 @@ def __getitem__(self, idx: int) -> xr.Dataset: return batch - def _cache_batch(self, idx: int, batch: xr.Dataset): + def _batch_in_cache(self, idx: int) -> bool: + gkey = f"{idx}/.zgroup" + return gkey in self.cache + + def _cache_batch(self, idx: int, batch: xr.Dataset) -> None: batch.to_zarr(self.cache, group=str(idx), mode="a") def _get_cached_batch(self, idx: int) -> xr.Dataset: diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index dd4f33c..4111008 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict import numpy as np import pytest @@ -227,7 +227,8 @@ def test_batch_exceptions(sample_ds_1d): def test_batcher_cached_getitem(sample_ds_1d) -> None: - cache: dict[str, Any] = {} + pytest.importorskip("zarr") + cache: Dict[str, Any] = {} def preproc(ds): processed = ds.load().chunk(-1) From 23579ae98e7655c0660ee35b95ce2d280eb6abf9 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Fri, 21 Oct 2022 20:24:24 -0700 Subject: [PATCH 3/4] typing for older pythons --- xbatcher/generators.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 1bf7342..7d774e0 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -2,7 +2,7 @@ import itertools from collections import OrderedDict -from typing import Any, Callable, Dict, Hashable, Iterator +from typing import Any, Callable, Dict, Hashable, Iterator, Optional import xarray as xr @@ -111,8 +111,8 @@ def __init__( batch_dims: Dict[Hashable, int] = {}, concat_input_dims: bool = False, preload_batch: bool = True, - cache: Dict[str, Any] | None = None, - cache_preprocess: Callable | None = None, + cache: Optional[Dict[str, Any]] = None, + cache_preprocess: Optional[Callable] = None, ): self.ds = ds @@ -178,8 +178,7 @@ def __getitem__(self, idx: int) -> xr.Dataset: return batch def _batch_in_cache(self, idx: int) -> bool: - gkey = f"{idx}/.zgroup" - return gkey in self.cache + return self.cache is not None and f"{idx}/.zgroup" in self.cache def _cache_batch(self, idx: int, batch: xr.Dataset) -> None: batch.to_zarr(self.cache, group=str(idx), mode="a") From ea7f1280a4d831ab7c3b483e822a8b62c562fa2c Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Sat, 22 Oct 2022 10:10:22 -0700 Subject: [PATCH 4/4] more test coverage --- xbatcher/generators.py | 4 ++-- xbatcher/tests/test_generators.py | 21 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 7d774e0..52ff542 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -170,9 +170,9 @@ def __getitem__(self, idx: int) -> xr.Dataset: else: raise IndexError("list index out of range") + if self.cache is not None and self.cache_preprocess is not None: + batch = self.cache_preprocess(batch) if self.cache is not None: - if self.cache_preprocess is not None: - batch = self.cache_preprocess(batch) self._cache_batch(idx, batch) return batch diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 4111008..cc115b8 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -226,7 +226,8 @@ def test_batch_exceptions(sample_ds_1d): assert len(e) == 1 -def test_batcher_cached_getitem(sample_ds_1d) -> None: +@pytest.mark.parametrize("preload", [True, False]) +def test_batcher_cached_getitem(sample_ds_1d, preload) -> None: pytest.importorskip("zarr") cache: Dict[str, Any] = {} @@ -236,7 +237,11 @@ def preproc(ds): return processed bg = BatchGenerator( - sample_ds_1d, input_dims={"x": 10}, cache=cache, cache_preprocess=preproc + sample_ds_1d, + input_dims={"x": 10}, + cache=cache, + cache_preprocess=preproc, + preload_batch=preload, ) # first batch @@ -259,3 +264,15 @@ def preproc(ds): xr.testing.assert_equal(ds_no_cache, ds_cache) xr.testing.assert_identical(ds_no_cache, ds_cache) + + # without preprocess func + bg = BatchGenerator( + sample_ds_1d, input_dims={"x": 10}, cache=cache, preload_batch=preload + ) + assert bg.cache_preprocess is None + assert bg[0].dims["x"] == 10 + ds_no_cache = bg[1] + assert "1/.zgroup" in cache + ds_cache = bg[1] + xr.testing.assert_equal(ds_no_cache, ds_cache) + xr.testing.assert_identical(ds_no_cache, ds_cache)