From 92cdbcf12f24d7ce837be519ef27995f994c862e Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Wed, 14 Dec 2022 17:53:44 -0500 Subject: [PATCH 1/6] Add tests for duplicate dims in batch_dims and input_dims --- xbatcher/testing.py | 71 +++++++++++++++++++++++-------- xbatcher/tests/test_generators.py | 32 ++++++++++++++ 2 files changed, 86 insertions(+), 17 deletions(-) diff --git a/xbatcher/testing.py b/xbatcher/testing.py index b034901..c8645ea 100644 --- a/xbatcher/testing.py +++ b/xbatcher/testing.py @@ -24,9 +24,10 @@ def _get_non_specified_dims(generator: BatchGenerator) -> Dict[Hashable, int]: in the input_dims or batch_dims attributes of the batch generator. """ return { - k: v - for k, v in generator.ds.sizes.items() - if (generator.input_dims.get(k) is None and generator.batch_dims.get(k) is None) + dim: length + for dim, length in generator.ds.sizes.items() + if generator.input_dims.get(dim) is None + and generator.batch_dims.get(dim) is None } @@ -46,9 +47,31 @@ def _get_non_input_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]: not also in input_dims """ return { - k: v - for k, v in generator.batch_dims.items() - if (generator.input_dims.get(k) is None) + dim: length + for dim, length in generator.batch_dims.items() + if generator.input_dims.get(dim) is None + } + + +def _get_duplicate_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]: + """ + Return all dimensions that are in batch_dims as well as input_dims. + + Parameters + ---------- + generator : xbatcher.BatchGenerator + The batch generator object. 
+ + Returns + ------- + d : dict + Dict containing all dimensions in specified in batch_dims that are + not also in input_dims + """ + return { + dim: length + for dim, length in generator.batch_dims.items() + if generator.input_dims.get(dim) is not None } @@ -188,17 +211,18 @@ def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int: """ nbatches_from_input_dims = np.product( [ - generator.ds.sizes[k] // generator.input_dims[k] - for k in generator.input_dims.keys() - if generator.input_overlap.get(k) is None + generator.ds.sizes[dim] // length + for dim, length in generator.input_dims.items() + if generator.input_overlap.get(dim) is None + and generator.batch_dims.get(dim) is None ] ) if generator.input_overlap: nbatches_from_input_overlap = np.product( [ - (generator.ds.sizes[k] - generator.input_overlap[k]) - // (generator.input_dims[k] - generator.input_overlap[k]) - for k in generator.input_overlap + (generator.ds.sizes[dim] - overlap) + // (generator.input_dims[dim] - overlap) + for dim, overlap in generator.input_overlap.items() ] ) return int(nbatches_from_input_overlap * nbatches_from_input_dims) @@ -217,17 +241,30 @@ def validate_generator_length(generator: BatchGenerator) -> None: The batch generator object. 
""" non_input_batch_dims = _get_non_input_batch_dims(generator) - nbatches_from_batch_dims = np.product( + duplicate_batch_dims = _get_duplicate_batch_dims(generator) + nbatches_from_unique_batch_dims = np.product( [ - generator.ds.sizes[k] // non_input_batch_dims[k] - for k in non_input_batch_dims.keys() + generator.ds.sizes[dim] // length + for dim, length in non_input_batch_dims.items() + ] + ) + nbatches_from_duplicate_batch_dims = np.product( + [ + generator.ds.sizes[dim] // length + for dim, length in duplicate_batch_dims.items() ] ) if generator.concat_input_dims: - expected_length = int(nbatches_from_batch_dims) + expected_length = int( + nbatches_from_unique_batch_dims * nbatches_from_duplicate_batch_dims + ) else: nbatches_from_input_dims = _get_nbatches_from_input_dims(generator) - expected_length = int(nbatches_from_batch_dims * nbatches_from_input_dims) + expected_length = int( + nbatches_from_unique_batch_dims + * nbatches_from_duplicate_batch_dims + * nbatches_from_input_dims + ) TestCase().assertEqual( expected_length, len(generator), diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index bf2b796..4ea54b3 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -115,6 +115,21 @@ def test_batch_1d_concat(sample_ds_1d, input_size): assert "x" in ds_batch.coords +def test_batch_1d_concat_duplicate_dim(sample_ds_1d): + """ + Test batch generation for a 1D dataset using ``concat_input_dims`` when + the same dimension occurs in ``input_dims`` and ``batch_dims`` + """ + bg = BatchGenerator( + sample_ds_1d, input_dims={"x": 5}, batch_dims={"x": 10}, concat_input_dims=True + ) + validate_generator_length(bg) + expected_dims = get_batch_dimensions(bg) + for n, ds_batch in enumerate(bg): + assert isinstance(ds_batch, xr.Dataset) + validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) + + @pytest.mark.parametrize("input_size", [5, 10]) def 
test_batch_1d_no_coordinate(sample_ds_1d, input_size): """ @@ -218,6 +233,23 @@ def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) +def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d): + """ + Test batch generation for a 3D dataset using ``concat_input_dims`` when + the same dimension occurs in ``input_dims`` and ``batch_dims``. + """ + bg = BatchGenerator( + sample_ds_3d, + input_dims={"x": 5, "y": 10}, + batch_dims={"x": 10, "y": 20}, + concat_input_dims=True, + ) + validate_generator_length(bg) + expected_dims = get_batch_dimensions(bg) + for ds_batch in bg: + validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) + + @pytest.mark.parametrize("input_size", [5, 10]) def test_batch_3d_2d_input(sample_ds_3d, input_size): """ From 090d924902ecef500044b31099299e4f3261113d Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Wed, 14 Dec 2022 20:35:08 -0500 Subject: [PATCH 2/6] Update docstring --- xbatcher/testing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xbatcher/testing.py b/xbatcher/testing.py index c8645ea..4cb224d 100644 --- a/xbatcher/testing.py +++ b/xbatcher/testing.py @@ -55,7 +55,7 @@ def _get_non_input_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]: def _get_duplicate_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]: """ - Return all dimensions that are in batch_dims as well as input_dims. + Return all dimensions that are in both batch_dims and input_dims. Parameters ---------- @@ -65,8 +65,7 @@ def _get_duplicate_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]: Returns ------- d : dict - Dict containing all dimensions in specified in batch_dims that are - not also in input_dims + Dict containing all dimensions duplicated between batch_dims and input_dims. 
""" return { dim: length From d7fde92ef95464600887d2a51db65de9678dda25 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Fri, 16 Dec 2022 10:31:29 -0500 Subject: [PATCH 3/6] Update xbatcher/tests/test_generators.py Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> --- xbatcher/tests/test_generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 4ea54b3..8c5b48c 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -125,7 +125,7 @@ def test_batch_1d_concat_duplicate_dim(sample_ds_1d): ) validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) - for n, ds_batch in enumerate(bg): + for ds_batch in bg: assert isinstance(ds_batch, xr.Dataset) validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) From 9f3f67998b119123d43dde96259fd67adfab701e Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Fri, 16 Dec 2022 10:35:11 -0500 Subject: [PATCH 4/6] Remove more unnecessary enumerate() --- xbatcher/tests/test_generators.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 8c5b48c..0c252e3 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -109,7 +109,7 @@ def test_batch_1d_concat(sample_ds_1d, input_size): ) validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) - for n, ds_batch in enumerate(bg): + for ds_batch in bg: assert isinstance(ds_batch, xr.Dataset) validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) assert "x" in ds_batch.coords @@ -163,7 +163,7 @@ def test_batch_1d_concat_no_coordinate(sample_ds_1d, input_size): ) validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) - for n, ds_batch in enumerate(bg): 
+ for ds_batch in bg: assert isinstance(ds_batch, xr.Dataset) validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) assert "x" not in ds_batch.coords @@ -289,7 +289,7 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, input_size): ) validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) - for n, ds_batch in enumerate(bg): + for ds_batch in bg: assert isinstance(ds_batch, xr.Dataset) validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) @@ -300,7 +300,7 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, input_size): ) validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) - for n, ds_batch in enumerate(bg): + for ds_batch in bg: assert isinstance(ds_batch, xr.Dataset) validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) From 7d413f98e2480cf34d497dda2fc358c6c9a6e3f1 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Fri, 16 Dec 2022 14:29:29 -0500 Subject: [PATCH 5/6] Update xbatcher/tests/test_generators.py Co-authored-by: Raphael Hagen --- xbatcher/tests/test_generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 0c252e3..2ab665d 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -114,7 +114,7 @@ def test_batch_1d_concat(sample_ds_1d, input_size): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) assert "x" in ds_batch.coords - +@pytest.mark.xfail def test_batch_1d_concat_duplicate_dim(sample_ds_1d): """ Test batch generation for a 1D dataset using ``concat_input_dims`` when From ecec5d9dc1821129e2f4ff90d099f519b4e6004e Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Fri, 16 Dec 2022 14:36:10 -0500 Subject: [PATCH 6/6] Add more xfail markers --- xbatcher/tests/test_generators.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 
deletion(-) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 2ab665d..bba84dd 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -114,7 +114,10 @@ def test_batch_1d_concat(sample_ds_1d, input_size): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) assert "x" in ds_batch.coords -@pytest.mark.xfail + +@pytest.mark.xfail( + reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131" +) def test_batch_1d_concat_duplicate_dim(sample_ds_1d): """ Test batch generation for a 1D dataset using ``concat_input_dims`` when @@ -216,6 +219,9 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) +@pytest.mark.xfail( + reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131" +) @pytest.mark.parametrize("concat", [True, False]) def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat): """ @@ -233,6 +239,9 @@ def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) +@pytest.mark.xfail( + reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131" +) def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d): """ Test batch generation for a 3D dataset using ``concat_input_dims`` when