From 870265a0474eeadef32ed015f33da28e104c95eb Mon Sep 17 00:00:00 2001 From: Lachlan Deakin Date: Thu, 13 Feb 2025 22:06:05 +1100 Subject: [PATCH] fix: sharding codec with fancy indexing (#2817) * fix: sharding codec with fancy indexing * changelog * add a better test * proper fix * fix: ArrayOfIntOrBool typing * Revert "fix: ArrayOfIntOrBool typing" This reverts commit 1a30563a2b6d67c74357e14b091c144ab6befe46. * ignore typing error in test --------- Co-authored-by: Deepak Cherian --- changes/2817.bugfix.rst | 1 + src/zarr/codecs/sharding.py | 6 +++++- tests/test_array.py | 15 +++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 changes/2817.bugfix.rst diff --git a/changes/2817.bugfix.rst b/changes/2817.bugfix.rst new file mode 100644 index 0000000000..b1c0fa9220 --- /dev/null +++ b/changes/2817.bugfix.rst @@ -0,0 +1 @@ +Fix fancy indexing (e.g. arr[5, [0, 1]]) with the sharding codec \ No newline at end of file diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 459805d808..42b1313fac 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -531,7 +531,11 @@ async def _decode_partial_single( ], out, ) - return out + + if hasattr(indexer, "sel_shape"): + return out.reshape(indexer.sel_shape) + else: + return out async def _encode_single( self, diff --git a/tests/test_array.py b/tests/test_array.py index 6aaf1072ba..4838129561 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1429,3 +1429,18 @@ def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkser results = pool.starmap(_index_array, [(arr, slice(len(data)))]) assert all(np.array_equal(r, data) for r in results) + + +async def test_sharding_coordinate_selection() -> None: + store = MemoryStore() + g = zarr.open_group(store, mode="w") + arr = g.create_array( + name="a", + shape=(2, 3, 4), + chunks=(1, 2, 2), + overwrite=True, + dtype=np.float32, + shards=(2, 4, 4), + ) + arr[:] = np.arange(2 * 3 * 4).reshape((2, 3, 4)) + assert (arr[1, [0, 1]] == np.array([[12, 13, 14, 15], [16, 17, 18, 19]])).all() # type: ignore[index]