Fix a bug when setting complete chunks (#2851)
* Fix a bug when setting complete chunks

Closes #2849

* much simpler test

* add release note

* Update strategy priorities:

1. Emphasize arrays of side > 1,
2. Emphasize indexing the last chunk for both setitem & getitem

* Use short node names

* bug fix

* Add scalar tests

* [revert]

* Add unit test

* one more test

* Add xfails

* switch to skip, XPASS is not allowed

* Fix test

* cleanup

---------

Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
dcherian and d-v-b authored Feb 23, 2025
1 parent 96c9677 commit 8b59a38
Showing 7 changed files with 133 additions and 36 deletions.
1 change: 1 addition & 0 deletions changes/2851.bugfix.rst
@@ -0,0 +1 @@
Fix a bug when setting values of a smaller last chunk.
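
For context, a minimal reproduction of the symptom, adapted from the regression test added in tests/test_indexing.py below (the "repro.zarr" path is illustrative): with shape (10,) and chunks of 3, the chunk grid is [0:3], [3:6], [6:9], [9:10], so the final chunk holds a single element even though the nominal chunk shape is (3,). Writing a selection that ends in that chunk previously misbehaved (GH2849).

import numpy as np
import zarr

g = zarr.open_group("repro.zarr", zarr_format=3, mode="w")
a = g.create_array("bar", shape=(10,), chunks=(3,), dtype=int)
a[7:10] = np.array([7, 8, 9])  # touches the smaller last chunk
np.testing.assert_array_equal(a[7:10], [7, 8, 9])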
25 changes: 14 additions & 11 deletions src/zarr/core/codec_pipeline.py
@@ -296,17 +296,6 @@ def _merge_chunk_array(
is_complete_chunk: bool,
drop_axes: tuple[int, ...],
) -> NDBuffer:
if is_complete_chunk and value.shape == chunk_spec.shape:
return value
if existing_chunk_array is None:
chunk_array = chunk_spec.prototype.nd_buffer.create(
shape=chunk_spec.shape,
dtype=chunk_spec.dtype,
order=chunk_spec.order,
fill_value=fill_value_or_default(chunk_spec),
)
else:
chunk_array = existing_chunk_array.copy() # make a writable copy
if chunk_selection == () or is_scalar(value.as_ndarray_like(), chunk_spec.dtype):
chunk_value = value
else:
@@ -320,6 +309,20 @@
for idx in range(chunk_spec.ndim)
)
chunk_value = chunk_value[item]
if is_complete_chunk and chunk_value.shape == chunk_spec.shape:
# TODO: For the last chunk, is_complete_chunk can be True even though the
# chunk is smaller than chunk_spec.shape; returning early in that case
# throws an error in _decode_single.
return chunk_value
if existing_chunk_array is None:
chunk_array = chunk_spec.prototype.nd_buffer.create(
shape=chunk_spec.shape,
dtype=chunk_spec.dtype,
order=chunk_spec.order,
fill_value=fill_value_or_default(chunk_spec),
)
else:
chunk_array = existing_chunk_array.copy() # make a writable copy
chunk_array[chunk_selection] = chunk_value
return chunk_array

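The fix moves the complete-chunk early return so that it compares the per-chunk slice of the incoming value, not the whole value. A hedged sketch of the shape arithmetic for the example above (variable names are illustrative, not the real internals):

import numpy as np

chunk_spec_shape = (3,)  # nominal chunk shape
value = np.array([7, 8, 9])  # data being written to a[7:10] on shape (10,), chunks (3,)
out_selection = (slice(2, 3),)  # portion of `value` destined for the smaller last chunk

# Old check: the whole value matches the nominal chunk shape, so the 3-element
# buffer was returned as the 1-element last chunk.
assert value.shape == chunk_spec_shape

# New check: compare only after slicing out this chunk's portion.
chunk_value = value[out_selection]
assert chunk_value.shape != chunk_spec_shape  # (1,) != (3,): fall through and merge normally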
2 changes: 1 addition & 1 deletion src/zarr/storage/_fsspec.py
@@ -178,7 +178,7 @@ def from_url(
try:
from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper

fs = AsyncFileSystemWrapper(fs)
fs = AsyncFileSystemWrapper(fs, asynchronous=True)
except ImportError as e:
raise ImportError(
f"The filesystem for URL '{url}' is synchronous, and the required "
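A hedged sketch of the code path above: a URL backed by a synchronous filesystem is wrapped so that zarr's asynchronous store machinery can drive it. This assumes an fsspec release that ships AsyncFileSystemWrapper; the local filesystem is a convenient synchronous example.

import fsspec
from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper

sync_fs = fsspec.filesystem("file")  # a synchronous implementation
async_fs = AsyncFileSystemWrapper(sync_fs, asynchronous=True)
# async_fs now exposes the asynchronous fsspec interface that FsspecStore expects.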
8 changes: 4 additions & 4 deletions src/zarr/testing/stateful.py
@@ -325,7 +325,7 @@ def __init__(self, store: Store) -> None:
def init_store(self) -> None:
self.store.clear()

@rule(key=zarr_keys, data=st.binary(min_size=0, max_size=MAX_BINARY_SIZE))
@rule(key=zarr_keys(), data=st.binary(min_size=0, max_size=MAX_BINARY_SIZE))
def set(self, key: str, data: DataObject) -> None:
note(f"(set) Setting {key!r} with {data}")
assert not self.store.read_only
@@ -334,7 +334,7 @@ def set(self, key: str, data: DataObject) -> None:
self.model[key] = data_buf

@precondition(lambda self: len(self.model.keys()) > 0)
@rule(key=zarr_keys, data=st.data())
@rule(key=zarr_keys(), data=st.data())
def get(self, key: str, data: DataObject) -> None:
key = data.draw(
st.sampled_from(sorted(self.model.keys()))
@@ -344,7 +344,7 @@ def get(self, key: str, data: DataObject) -> None:
# to bytes here necessary because data_buf set to model in set()
assert self.model[key] == store_value

@rule(key=zarr_keys, data=st.data())
@rule(key=zarr_keys(), data=st.data())
def get_invalid_zarr_keys(self, key: str, data: DataObject) -> None:
note("(get_invalid)")
assume(key not in self.model)
@@ -408,7 +408,7 @@ def is_empty(self) -> None:
# make sure they either both are or both aren't empty (same state)
assert self.store.is_empty("") == (not self.model)

@rule(key=zarr_keys)
@rule(key=zarr_keys())
def exists(self, key: str) -> None:
note("(exists)")

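The added parentheses are required because zarr_keys is now an @st.composite function, i.e. a strategy factory rather than a strategy object. A minimal hedged illustration (the node_names alphabet here is a stand-in, not the real strategy):

import hypothesis.strategies as st

node_names = st.text("abc", min_size=1, max_size=3)  # illustrative stand-in

@st.composite
def keys(draw, *, max_num_nodes=None):
    return draw(st.lists(node_names, min_size=1, max_size=max_num_nodes).map("/".join))

print(keys().example())  # e.g. "ab/c"; `keys` without parentheses is just the function object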
86 changes: 71 additions & 15 deletions src/zarr/testing/strategies.py
@@ -1,10 +1,11 @@
import math
import sys
from typing import Any, Literal

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings # noqa: F401
from hypothesis import event, given, settings # noqa: F401
from hypothesis.strategies import SearchStrategy

import zarr
@@ -28,6 +29,16 @@
)


@st.composite # type: ignore[misc]
def keys(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> Any:
return draw(st.lists(node_names, min_size=1, max_size=max_num_nodes).map("/".join))


@st.composite # type: ignore[misc]
def paths(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> Any:
return draw(st.just("/") | keys(max_num_nodes=max_num_nodes))


def v3_dtypes() -> st.SearchStrategy[np.dtype]:
return (
npst.boolean_dtypes()
@@ -87,17 +98,19 @@ def clear_store(x: Store) -> Store:
node_names = st.text(zarr_key_chars, min_size=1).filter(
lambda t: t not in (".", "..") and not t.startswith("__")
)
short_node_names = st.text(zarr_key_chars, max_size=3, min_size=1).filter(
lambda t: t not in (".", "..") and not t.startswith("__")
)
array_names = node_names
attrs = st.none() | st.dictionaries(_attr_keys, _attr_values)
keys = st.lists(node_names, min_size=1).map("/".join)
paths = st.just("/") | keys
# st.builds will only call a new store constructor for different keyword arguments
# i.e. stores.examples() will always return the same object per Store class.
# So we map a clear to reset the store.
stores = st.builds(MemoryStore, st.just({})).map(clear_store)
compressors = st.sampled_from([None, "default"])
zarr_formats: st.SearchStrategy[ZarrFormat] = st.sampled_from([3, 2])
array_shapes = npst.array_shapes(max_dims=4, min_side=0)
# We de-prioritize arrays having dim sizes 0, 1, 2
array_shapes = npst.array_shapes(max_dims=4, min_side=3) | npst.array_shapes(max_dims=4, min_side=0)


@st.composite # type: ignore[misc]
@@ -152,13 +165,15 @@ def numpy_arrays(
draw: st.DrawFn,
*,
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
zarr_formats: st.SearchStrategy[ZarrFormat] = zarr_formats,
dtype: np.dtype[Any] | None = None,
zarr_formats: st.SearchStrategy[ZarrFormat] | None = zarr_formats,
) -> Any:
"""
Generate numpy arrays that can be saved in the provided Zarr format.
"""
zarr_format = draw(zarr_formats)
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
if dtype is None:
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
if np.issubdtype(dtype, np.str_):
safe_unicode_strings = safe_unicode_for_dtype(dtype)
return draw(npst.arrays(dtype=dtype, shape=shapes, elements=safe_unicode_strings))
@@ -174,11 +189,19 @@ def chunk_shapes(draw: st.DrawFn, *, shape: tuple[int, ...]) -> tuple[int, ...]:
st.tuples(*[st.integers(min_value=0 if size == 0 else 1, max_value=size) for size in shape])
)
# 2. and now generate the chunks tuple
return tuple(
chunks = tuple(
size // nchunks if nchunks > 0 else 0
for size, nchunks in zip(shape, numchunks, strict=True)
)

for c in chunks:
event("chunk size", c)

if any((c != 0 and s % c != 0) for s, c in zip(shape, chunks, strict=True)):
event("smaller last chunk")

return chunks
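
The events above tag draws whose chunk grid leaves a smaller last chunk. Worked arithmetic for one such draw (plain Python rather than Hypothesis, just to show the condition):

shape = (10,)
numchunks = (3,)  # drawn between 1 and the axis size
chunks = tuple(s // n for s, n in zip(shape, numchunks))  # -> (3,)

# 10 % 3 != 0, so the grid is [0:3], [3:6], [6:9], [9:10] and the final chunk
# holds one element; this draw is tagged with the "smaller last chunk" event.
assert any(c != 0 and s % c != 0 for s, c in zip(shape, chunks))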


@st.composite # type: ignore[misc]
def shard_shapes(
@@ -211,7 +234,7 @@ def arrays(
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
compressors: st.SearchStrategy = compressors,
stores: st.SearchStrategy[StoreLike] = stores,
paths: st.SearchStrategy[str | None] = paths,
paths: st.SearchStrategy[str | None] = paths(), # noqa: B008
array_names: st.SearchStrategy = array_names,
arrays: st.SearchStrategy | None = None,
attrs: st.SearchStrategy = attrs,
@@ -267,23 +290,56 @@ def arrays(
return a


@st.composite # type: ignore[misc]
def simple_arrays(
draw: st.DrawFn,
*,
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
) -> Any:
return draw(
arrays(
shapes=shapes,
paths=paths(max_num_nodes=2),
array_names=short_node_names,
attrs=st.none(),
compressors=st.sampled_from([None, "default"]),
)
)


def is_negative_slice(idx: Any) -> bool:
return isinstance(idx, slice) and idx.step is not None and idx.step < 0


@st.composite # type: ignore[misc]
def end_slices(draw: st.DrawFn, *, shape: tuple[int]) -> Any:
"""
A strategy that slices ranges that include the last chunk.
This is intended to stress-test handling of a possibly smaller last chunk.
"""
slicers = []
for size in shape:
start = draw(st.integers(min_value=size // 2, max_value=size - 1))
length = draw(st.integers(min_value=0, max_value=size - start))
slicers.append(slice(start, start + length))
event("drawing end slice")
return tuple(slicers)


@st.composite # type: ignore[misc]
def basic_indices(draw: st.DrawFn, *, shape: tuple[int], **kwargs: Any) -> Any:
"""Basic indices without unsupported negative slices."""
return draw(
npst.basic_indices(shape=shape, **kwargs).filter(
lambda idxr: (
not (
is_negative_slice(idxr)
or (isinstance(idxr, tuple) and any(is_negative_slice(idx) for idx in idxr))
)
strategy = npst.basic_indices(shape=shape, **kwargs).filter(
lambda idxr: (
not (
is_negative_slice(idxr)
or (isinstance(idxr, tuple) and any(is_negative_slice(idx) for idx in idxr))
)
)
)
if math.prod(shape) >= 3:
strategy = end_slices(shape=shape) | strategy
return draw(strategy)


@st.composite # type: ignore[misc]
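To see what end_slices contributes, here is a hedged single-axis re-implementation (end_slices_1d is illustrative, not the real helper): the start is drawn from the upper half of the axis, so selections such as slice(7, 10) that reach into the possibly smaller last chunk become common, and basic_indices mixes them in whenever the array has at least three elements.

from hypothesis import strategies as st

@st.composite
def end_slices_1d(draw, *, size):
    start = draw(st.integers(min_value=size // 2, max_value=size - 1))
    length = draw(st.integers(min_value=0, max_value=size - start))
    return slice(start, start + length)

print(end_slices_1d(size=10).example())  # e.g. slice(7, 10): always starts in the upper half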
31 changes: 31 additions & 0 deletions tests/test_indexing.py
@@ -424,6 +424,18 @@ def test_orthogonal_indexing_fallback_on_getitem_2d(
np.testing.assert_array_equal(z[index], expected_result)


@pytest.mark.skip(reason="fails on ubuntu, windows; numpy=2.2; in CI")
def test_setitem_repeated_index():
array = zarr.array(data=np.zeros((4,)), chunks=(1,))
indexer = np.array([-1, -1, 0, 0])
array.oindex[(indexer,)] = [0, 1, 2, 3]
np.testing.assert_array_equal(array[:], np.array([3, 0, 0, 1]))

indexer = np.array([-1, 0, 0, -1])
array.oindex[(indexer,)] = [0, 1, 2, 3]
np.testing.assert_array_equal(array[:], np.array([2, 0, 0, 3]))


Index = list[int] | tuple[slice | int | list[int], ...]
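
On test_setitem_repeated_index above: NumPy does not guarantee the result of fancy-index assignment when an index is repeated, which is also why the property test in tests/test_properties.py filters such indexers out. A small hedged demonstration in plain NumPy of the typical, but unspecified, outcome:

import numpy as np

arr = np.zeros(4)
arr[np.array([-1, -1, 0, 0])] = [0, 1, 2, 3]
print(arr)  # typically [3., 0., 0., 1.] because the last write to each position wins,
            # but NumPy does not promise this ordering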


@@ -815,6 +827,25 @@ def test_set_orthogonal_selection_1d(store: StorePath) -> None:
_test_set_orthogonal_selection(v, a, z, selection)


def test_set_item_1d_last_two_chunks():
# regression test for GH2849
g = zarr.open_group("foo.zarr", zarr_format=3, mode="w")
a = g.create_array("bar", shape=(10,), chunks=(3,), dtype=int)
data = np.array([7, 8, 9])
a[slice(7, 10)] = data
np.testing.assert_array_equal(a[slice(7, 10)], data)

z = zarr.open_group("foo.zarr", mode="w")
z.create_array("zoo", dtype=float, shape=())
z["zoo"][...] = np.array(1) # why doesn't [:] work?
np.testing.assert_equal(z["zoo"][()], np.array(1))

z = zarr.open_group("foo.zarr", mode="w")
z.create_array("zoo", dtype=float, shape=())
z["zoo"][...] = 1 # why doesn't [:] work?
np.testing.assert_equal(z["zoo"][()], np.array(1))


def _test_set_orthogonal_selection_2d(
v: npt.NDArray[np.int_],
a: npt.NDArray[np.int_],
16 changes: 11 additions & 5 deletions tests/test_properties.py
@@ -1,3 +1,4 @@
import numpy as np
import pytest
from numpy.testing import assert_array_equal

@@ -18,6 +19,7 @@
basic_indices,
numpy_arrays,
orthogonal_indices,
simple_arrays,
stores,
zarr_formats,
)
@@ -50,13 +52,13 @@ def test_array_creates_implicit_groups(array):

@given(data=st.data())
def test_basic_indexing(data: st.DataObject) -> None:
zarray = data.draw(arrays())
zarray = data.draw(simple_arrays())
nparray = zarray[:]
indexer = data.draw(basic_indices(shape=nparray.shape))
actual = zarray[indexer]
assert_array_equal(nparray[indexer], actual)

new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype))
new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
zarray[indexer] = new_data
nparray[indexer] = new_data
assert_array_equal(nparray, zarray[:])
@@ -65,15 +67,19 @@ def test_basic_indexing(data: st.DataObject) -> None:
@given(data=st.data())
def test_oindex(data: st.DataObject) -> None:
# integer_array_indices can't handle 0-size dimensions.
zarray = data.draw(arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
nparray = zarray[:]

zindexer, npindexer = data.draw(orthogonal_indices(shape=nparray.shape))
actual = zarray.oindex[zindexer]
assert_array_equal(nparray[npindexer], actual)

assume(zarray.shards is None) # GH2834
new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype))
for idxr in npindexer:
if isinstance(idxr, np.ndarray) and idxr.size != np.unique(idxr).size:
# behaviour of setitem with repeated indices is not guaranteed in practice
assume(False)
new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
nparray[npindexer] = new_data
zarray.oindex[zindexer] = new_data
assert_array_equal(nparray, zarray[:])
@@ -82,7 +88,7 @@ def test_oindex(data: st.DataObject) -> None:
@given(data=st.data())
def test_vindex(data: st.DataObject) -> None:
# integer_array_indices can't handle 0-size dimensions.
zarray = data.draw(arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
nparray = zarray[:]

indexer = data.draw(
