Skip to content

Commit

Permalink
Always skip reads when completely overwriting chunks (#2784)
Browse files Browse the repository at this point in the history
* Skip reads when completely overwriting boundary chunks

Uses `slice(..., None)` to indicate that a `chunk_selection`
ends at the boundary of the current chunk. The same convention is
applied to a final chunk that is shorter than the chunk size.

`is_total_slice` now understands this convention, and correctly
detects boundary chunks as total slices.

Closes #757

* normalize in codec_pipeline

* Revert "normalize in codec_pipeline"

This reverts commit 234431cd6efb661c53e2a832a0e4ea4dca772c1b.

* Partially Revert "Skip reads when completely overwriting boundary chunks"

This reverts commit edbba37.

* Different approach

* fix bug

* add oindex property test

* more complex oindex test

* cleanup

* more oindex

* Add changelog entry

* [revert] note

* fix for numpy 1.25

---------

Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
  • Loading branch information
dcherian and d-v-b authored Feb 12, 2025
1 parent c66f32b commit feeb08f
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 83 deletions.
1 change: 1 addition & 0 deletions changes/2784.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid reading chunks during writes where possible. :issue:`757`
4 changes: 2 additions & 2 deletions src/zarr/abc/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ async def encode(
@abstractmethod
async def read(
self,
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
out: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
Expand All @@ -379,7 +379,7 @@ async def read(
@abstractmethod
async def write(
self,
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
value: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
Expand Down
14 changes: 9 additions & 5 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,9 @@ async def _decode_single(
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
out,
)
Expand Down Expand Up @@ -486,7 +487,7 @@ async def _decode_partial_single(
)

indexed_chunks = list(indexer)
all_chunk_coords = {chunk_coords for chunk_coords, _, _ in indexed_chunks}
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

# reading bytes of all requested chunks
shard_dict: ShardMapping = {}
Expand Down Expand Up @@ -524,8 +525,9 @@ async def _decode_partial_single(
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
out,
)
Expand Down Expand Up @@ -558,8 +560,9 @@ async def _encode_single(
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
shard_array,
)
Expand Down Expand Up @@ -601,8 +604,9 @@ async def _encode_partial_single(
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
shard_array,
)
Expand Down
6 changes: 4 additions & 2 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,8 +1290,9 @@ async def _get_selection(
self.metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype),
chunk_selection,
out_selection,
is_complete_chunk,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
],
out_buffer,
drop_axes=indexer.drop_axes,
Expand Down Expand Up @@ -1417,8 +1418,9 @@ async def _set_selection(
self.metadata.get_chunk_spec(chunk_coords, _config, prototype),
chunk_selection,
out_selection,
is_complete_chunk,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
],
value_buffer,
drop_axes=indexer.drop_axes,
Expand Down
60 changes: 34 additions & 26 deletions src/zarr/core/codec_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from zarr.core.common import ChunkCoords, concurrent_map
from zarr.core.config import config
from zarr.core.indexing import SelectorTuple, is_scalar, is_total_slice
from zarr.core.indexing import SelectorTuple, is_scalar
from zarr.core.metadata.v2 import _default_fill_value
from zarr.registry import register_pipeline

Expand Down Expand Up @@ -243,18 +243,18 @@ async def encode_partial_batch(

async def read_batch(
self,
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
out: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
if self.supports_partial_decode:
chunk_array_batch = await self.decode_partial_batch(
[
(byte_getter, chunk_selection, chunk_spec)
for byte_getter, chunk_spec, chunk_selection, _ in batch_info
for byte_getter, chunk_spec, chunk_selection, *_ in batch_info
]
)
for chunk_array, (_, chunk_spec, _, out_selection) in zip(
for chunk_array, (_, chunk_spec, _, out_selection, _) in zip(
chunk_array_batch, batch_info, strict=False
):
if chunk_array is not None:
Expand All @@ -263,22 +263,19 @@ async def read_batch(
out[out_selection] = fill_value_or_default(chunk_spec)
else:
chunk_bytes_batch = await concurrent_map(
[
(byte_getter, array_spec.prototype)
for byte_getter, array_spec, _, _ in batch_info
],
[(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info],
lambda byte_getter, prototype: byte_getter.get(prototype),
config.get("async.concurrency"),
)
chunk_array_batch = await self.decode_batch(
[
(chunk_bytes, chunk_spec)
for chunk_bytes, (_, chunk_spec, _, _) in zip(
for chunk_bytes, (_, chunk_spec, *_) in zip(
chunk_bytes_batch, batch_info, strict=False
)
],
)
for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip(
chunk_array_batch, batch_info, strict=False
):
if chunk_array is not None:
Expand All @@ -296,9 +293,10 @@ def _merge_chunk_array(
out_selection: SelectorTuple,
chunk_spec: ArraySpec,
chunk_selection: SelectorTuple,
is_complete_chunk: bool,
drop_axes: tuple[int, ...],
) -> NDBuffer:
if is_total_slice(chunk_selection, chunk_spec.shape) and value.shape == chunk_spec.shape:
if is_complete_chunk and value.shape == chunk_spec.shape:
return value
if existing_chunk_array is None:
chunk_array = chunk_spec.prototype.nd_buffer.create(
Expand Down Expand Up @@ -327,7 +325,7 @@ def _merge_chunk_array(

async def write_batch(
self,
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
value: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
Expand All @@ -337,14 +335,14 @@ async def write_batch(
await self.encode_partial_batch(
[
(byte_setter, value, chunk_selection, chunk_spec)
for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
],
)
else:
await self.encode_partial_batch(
[
(byte_setter, value[out_selection], chunk_selection, chunk_spec)
for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
],
)

Expand All @@ -361,33 +359,43 @@ async def _read_key(
chunk_bytes_batch = await concurrent_map(
[
(
None if is_total_slice(chunk_selection, chunk_spec.shape) else byte_setter,
None if is_complete_chunk else byte_setter,
chunk_spec.prototype,
)
for byte_setter, chunk_spec, chunk_selection, _ in batch_info
for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info
],
_read_key,
config.get("async.concurrency"),
)
chunk_array_decoded = await self.decode_batch(
[
(chunk_bytes, chunk_spec)
for chunk_bytes, (_, chunk_spec, _, _) in zip(
for chunk_bytes, (_, chunk_spec, *_) in zip(
chunk_bytes_batch, batch_info, strict=False
)
],
)

chunk_array_merged = [
self._merge_chunk_array(
chunk_array, value, out_selection, chunk_spec, chunk_selection, drop_axes
)
for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
chunk_array_decoded, batch_info, strict=False
chunk_array,
value,
out_selection,
chunk_spec,
chunk_selection,
is_complete_chunk,
drop_axes,
)
for chunk_array, (
_,
chunk_spec,
chunk_selection,
out_selection,
is_complete_chunk,
) in zip(chunk_array_decoded, batch_info, strict=False)
]
chunk_array_batch: list[NDBuffer | None] = []
for chunk_array, (_, chunk_spec, _, _) in zip(
for chunk_array, (_, chunk_spec, *_) in zip(
chunk_array_merged, batch_info, strict=False
):
if chunk_array is None:
Expand All @@ -403,7 +411,7 @@ async def _read_key(
chunk_bytes_batch = await self.encode_batch(
[
(chunk_array, chunk_spec)
for chunk_array, (_, chunk_spec, _, _) in zip(
for chunk_array, (_, chunk_spec, *_) in zip(
chunk_array_batch, batch_info, strict=False
)
],
Expand All @@ -418,7 +426,7 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non
await concurrent_map(
[
(byte_setter, chunk_bytes)
for chunk_bytes, (byte_setter, _, _, _) in zip(
for chunk_bytes, (byte_setter, *_) in zip(
chunk_bytes_batch, batch_info, strict=False
)
],
Expand Down Expand Up @@ -446,7 +454,7 @@ async def encode(

async def read(
self,
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
out: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
Expand All @@ -461,7 +469,7 @@ async def read(

async def write(
self,
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
value: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
Expand Down
Loading

0 comments on commit feeb08f

Please sign in to comment.