From 0cdc04b82503d23ad2e0a32b55e08b09856ab183 Mon Sep 17 00:00:00 2001 From: Norman Rzepka Date: Fri, 18 Oct 2024 14:52:22 +0200 Subject: [PATCH] fix reading partial shards --- src/zarr/codecs/sharding.py | 4 +-- tests/v3/test_codecs/test_sharding.py | 36 +++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 2181e9eb76..d01e116f9a 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -129,7 +129,7 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None: if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64): return None else: - return (int(chunk_start), int(chunk_start) + int(chunk_len)) + return (int(chunk_start), int(chunk_len)) def set_chunk_slice(self, chunk_coords: ChunkCoords, chunk_slice: slice | None) -> None: localized_chunk = self._localize_chunk(chunk_coords) @@ -203,7 +203,7 @@ def create_empty( def __getitem__(self, chunk_coords: ChunkCoords) -> Buffer: chunk_byte_slice = self.index.get_chunk_slice(chunk_coords) if chunk_byte_slice: - return self.buf[chunk_byte_slice[0] : chunk_byte_slice[1]] + return self.buf[chunk_byte_slice[0] : (chunk_byte_slice[0] + chunk_byte_slice[1])] raise KeyError def __len__(self) -> int: diff --git a/tests/v3/test_codecs/test_sharding.py b/tests/v3/test_codecs/test_sharding.py index c0dcfbf350..f827a0720e 100644 --- a/tests/v3/test_codecs/test_sharding.py +++ b/tests/v3/test_codecs/test_sharding.py @@ -118,6 +118,42 @@ def test_sharding_partial( assert np.array_equal(data, read_data) +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +@pytest.mark.parametrize( + "array_fixture", + [ + ArrayRequest(shape=(128,) * 3, dtype="uint16", order="F"), + ], + indirect=["array_fixture"], +) +def test_sharding_partial_readwrite( + store: Store, array_fixture: npt.NDArray[Any], index_location: ShardingCodecIndexLocation +) -> None: + data = array_fixture + spath = StorePath(store) + a = Array.create( + spath, + shape=data.shape, + chunk_shape=data.shape, + dtype=data.dtype, + fill_value=0, + codecs=[ + ShardingCodec( + chunk_shape=(1, data.shape[1], data.shape[2]), + codecs=[BytesCodec()], + index_location=index_location, + ) + ], + ) + + a[:] = data + + for x in range(data.shape[0]): + read_data = a[x, :, :] + assert np.array_equal(data[x], read_data) + + @pytest.mark.parametrize( "array_fixture", [