diff --git a/zarr/core.py b/zarr/core.py
index 43ccdbaf7d..06d848558b 100644
--- a/zarr/core.py
+++ b/zarr/core.py
@@ -2067,13 +2067,15 @@ def _process_chunk(
         fields,
         out_selection,
         partial_read_decode=False,
+        coords=None,
     ):
         """Take binary data from storage and fill output array"""
+        shape = tuple(c if isinstance(c, int) else c[coords[i]] for i, c in enumerate(self._chunks))
         if (
             out_is_ndarray
             and not fields
             and is_contiguous_selection(out_selection)
-            and is_total_slice(chunk_selection, self._chunks)
+            and is_total_slice(chunk_selection, shape)
             and not self._filters
             and self._dtype != object
         ):
@@ -2100,7 +2102,7 @@ def _process_chunk(
                     if isinstance(cdata, UncompressedPartialReadBufferV3):
                         cdata = cdata.read_full()
                     chunk = ensure_ndarray_like(cdata).view(self._dtype)
-                    chunk = chunk.reshape(self._chunks, order=self._order)
+                    chunk = chunk.reshape(shape, order=self._order)
                     np.copyto(dest, chunk)
                 return

@@ -2109,7 +2111,7 @@ def _process_chunk(
             if partial_read_decode:
                 cdata.prepare_chunk()
                 # size of chunk
-                tmp = np.empty_like(self._meta_array, shape=self._chunks, dtype=self.dtype)
+                tmp = np.empty_like(self._meta_array, shape=shape, dtype=self.dtype)
                 index_selection = PartialChunkIterator(chunk_selection, self.chunks)
                 for start, nitems, partial_out_selection in index_selection:
                     expected_shape = [
@@ -2138,7 +2140,7 @@ def _process_chunk(
                 return
         except ArrayIndexError:
             cdata = cdata.read_full()
-        chunk = self._decode_chunk(cdata)
+        chunk = self._decode_chunk(cdata, expected_shape=shape)

         # select data from chunk
         if fields:
@@ -2223,7 +2225,9 @@ def _chunk_getitems(
             contexts = ConstantMap(ckeys, constant=Context(meta_array=self._meta_array))
         cdatas = self.chunk_store.getitems(ckeys, contexts=contexts)

-        for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection):
+        for ckey, coords, chunk_select, out_select in zip(
+            ckeys, lchunk_coords, lchunk_selection, lout_selection
+        ):
             if ckey in cdatas:
                 self._process_chunk(
                     out,
@@ -2234,6 +2238,7 @@ def _chunk_getitems(
                     fields,
                     out_select,
                     partial_read_decode=partial_read_decode,
+                    coords=coords,
                 )
             else:
                 # check exception type
@@ -2305,7 +2310,9 @@ def _chunk_setitem(self, chunk_coords, chunk_selection, value, fields=None):

     def _chunk_setitem_nosync(self, chunk_coords, chunk_selection, value, fields=None):
         ckey = self._chunk_key(chunk_coords)
-        cdata = self._process_for_setitem(ckey, chunk_selection, value, fields=fields)
+        cdata = self._process_for_setitem(
+            ckey, chunk_selection, value, coords=chunk_coords, fields=fields
+        )

         # attempt to delete chunk if it only contains the fill value
         if (not self.write_empty_chunks) and all_equal(self.fill_value, cdata):
@@ -2313,8 +2320,9 @@ def _chunk_setitem_nosync(self, chunk_coords, chunk_selection, value, fields=Non
         else:
             self.chunk_store[ckey] = self._encode_chunk(cdata)

-    def _process_for_setitem(self, ckey, chunk_selection, value, fields=None):
-        if is_total_slice(chunk_selection, self._chunks) and not fields:
+    def _process_for_setitem(self, ckey, chunk_selection, value, coords=None, fields=None):
+        shape = tuple(c if isinstance(c, int) else c[coords[i]] for i, c in enumerate(self._chunks))
+        if is_total_slice(chunk_selection, shape) and not fields:
             # totally replace chunk

             # optimization: we are completely replacing the chunk, so no need
@@ -2324,7 +2332,7 @@ def _process_for_setitem(self, ckey, chunk_selection, value, fields=None):

                 # setup array filled with value
                 chunk = np.empty_like(
-                    self._meta_array, shape=self._chunks, dtype=self._dtype, order=self._order
+                    self._meta_array, shape=shape, dtype=self._dtype, order=self._order
                 )
                 chunk.fill(value)

@@ -2346,22 +2354,22 @@ def _process_for_setitem(self, ckey, chunk_selection, value, fields=None):
                 # chunk not initialized
                 if self._fill_value is not None:
                     chunk = np.empty_like(
-                        self._meta_array, shape=self._chunks, dtype=self._dtype, order=self._order
+                        self._meta_array, shape=shape, dtype=self._dtype, order=self._order
                     )
                     chunk.fill(self._fill_value)

                 elif self._dtype == object:
-                    chunk = np.empty(self._chunks, dtype=self._dtype, order=self._order)
+                    chunk = np.empty(shape, dtype=self._dtype, order=self._order)
                 else:
                     # N.B., use zeros here so any region beyond the array has consistent
                     # and compressible data
                     chunk = np.zeros_like(
-                        self._meta_array, shape=self._chunks, dtype=self._dtype, order=self._order
+                        self._meta_array, shape=shape, dtype=self._dtype, order=self._order
                     )

             else:

                 # decode chunk
-                chunk = self._decode_chunk(cdata)
+                chunk = self._decode_chunk(cdata, expected_shape=shape)
                 if not chunk.flags.writeable:
                     chunk = chunk.copy(order="K")
diff --git a/zarr/indexing.py b/zarr/indexing.py
index 487cc8b9d9..26ef6d900b 100644
--- a/zarr/indexing.py
+++ b/zarr/indexing.py
@@ -1,7 +1,10 @@
+from __future__ import annotations
+
 import collections
 import itertools
 import math
 import numbers
+from typing import List

 import numpy as np

@@ -163,6 +166,95 @@ def __iter__(self):
             yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel)


+class VarSliceDimIndexer:
+    def __init__(self, dim_sel: slice, dim_len: int, chunk_lengths: List[int]):
+        # normalize
+        start = dim_sel.start
+        if start is None:
+            start = 0
+        elif start < 0:
+            start = dim_len + start
+
+        stop = dim_sel.stop
+        if stop is None:
+            stop = dim_len
+        elif stop < 0:
+            stop = dim_len + stop
+        else:
+            stop = min(stop, dim_len)
+
+        step = dim_sel.step or 1
+        if step < 0:
+            raise NotImplementedError
+
+        # store attributes
+        self.offsets = np.cumsum([0] + chunk_lengths)
+        self.dim_len = dim_len
+        self.dim_sel = dim_sel
+        self.projections = []
+
+        first_chunk, last_chunk = np.searchsorted(self.offsets[1:], [start, stop])
+
+        rem = 0
+        nfilled = 0
+        next_slice_start = start - self.offsets[first_chunk]
+
+        for i in range(first_chunk, last_chunk + 1):
+            # Setup slice for this chunk
+            slice_start = next_slice_start
+            slice_end = min(chunk_lengths[i], stop - self.offsets[i])
+            slice_len = ceildiv((slice_end - slice_start), step)
+
+            # Prepare for next iteration. We do this ahead of the next iteration
+            # so that we don't need to special case the first chunk
+            rem = (self.offsets[i + 1] - start) % step
+            if rem != 0:
+                next_slice_start = step - rem
+            else:
+                next_slice_start = 0
+
+            # Skip iteration if slice is empty
+            if slice_end <= slice_start:
+                continue
+
+            # Add projection for this chunk + update nfilled
+            cur_nfilled = nfilled + slice_len
+            self.projections.append(
+                ChunkDimProjection(
+                    i,
+                    slice(slice_start, slice_end, step),
+                    slice(nfilled, cur_nfilled),
+                )
+            )
+            nfilled = cur_nfilled
+
+        self.nitems = nfilled
+
+    def __iter__(self):
+        yield from self.projections
+
+
+class VarIntDimIndexer:
+    def __init__(self, dim_sel: int, dim_len: int, chunk_lengths: List[int]):
+
+        self.offsets = np.cumsum([0] + chunk_lengths)
+        self.dim_len = dim_len
+
+        # normalize
+        dim_sel = normalize_integer_selection(dim_sel, self.dim_len)
+
+        # store attributes
+        self.dim_sel = dim_sel
+        self.nitems = 1
+
+    def __iter__(self):
+        for ix, off in enumerate(self.offsets):
+            if off > self.dim_sel:
+                break
+        ix -= 1
+        yield ChunkDimProjection(ix, self.dim_sel - self.offsets[ix], None)
+
+
 def ceildiv(a, b):
     return math.ceil(a / b)

@@ -339,10 +431,16 @@ def __init__(self, selection, array):
         for dim_sel, dim_len, dim_chunk_len in zip(selection, array._shape, array._chunks):

             if is_integer(dim_sel):
-                dim_indexer = IntDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                if isinstance(dim_chunk_len, int):
+                    dim_indexer = IntDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                else:
+                    dim_indexer = VarIntDimIndexer(dim_sel, dim_len, dim_chunk_len)

             elif is_slice(dim_sel):
-                dim_indexer = SliceDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                if isinstance(dim_chunk_len, int):
+                    dim_indexer = SliceDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                else:
+                    dim_indexer = VarSliceDimIndexer(dim_sel, dim_len, dim_chunk_len)

             else:
                 raise IndexError(
@@ -427,6 +525,56 @@ def __iter__(self):
             yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel)


+class VarBoolArrayDimIndexer:
+    def __init__(self, dim_sel, dim_len, dim_chunk_lens: List[int]):
+
+        # check number of dimensions
+        if not is_bool_array(dim_sel, 1):
+            raise IndexError(
+                "Boolean arrays in an orthogonal selection must " "be 1-dimensional only"
+            )
+
+        # check shape
+        if dim_sel.shape[0] != dim_len:
+            raise IndexError(
+                "Boolean array has the wrong length for dimension; "
+                "expected {}, got {}".format(dim_len, dim_sel.shape[0])
+            )
+
+        # store attributes
+        self.dim_sel = dim_sel
+        self.dim_len = dim_len
+        self.dim_chunk_lens = dim_chunk_lens
+        self.nchunks = len(dim_chunk_lens)
+        self.offsets = np.cumsum([0] + dim_chunk_lens)
+
+        # precompute number of selected items for each chunk
+        self.chunk_nitems = np.zeros(self.nchunks, dtype="i8")
+        np.add.reduceat(dim_sel, self.offsets[:-1], out=self.chunk_nitems)
+        self.chunk_nitems_cumsum = np.cumsum(self.chunk_nitems)
+        self.nitems = self.chunk_nitems_cumsum[-1]
+        self.dim_chunk_ixs = np.nonzero(self.chunk_nitems)[0]
+
+    def __iter__(self):
+
+        # iterate over chunks with at least one item
+        for dim_chunk_ix in self.dim_chunk_ixs:
+
+            # find region in chunk
+            dim_offset = self.offsets[dim_chunk_ix]
+            dim_chunk_sel = self.dim_sel[dim_offset : self.offsets[dim_chunk_ix + 1]]
+
+            # find region in output
+            if dim_chunk_ix == 0:
+                start = 0
+            else:
+                start = self.chunk_nitems_cumsum[dim_chunk_ix - 1]
+            stop = self.chunk_nitems_cumsum[dim_chunk_ix]
+            dim_out_sel = slice(start, stop)
+
+            yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel)
+
+
 class Order:
     UNKNOWN = 0
     INCREASING = 1
@@ -460,6 +608,94 @@ def boundscheck_indices(x, dim_len):
         raise BoundsCheckError(dim_len)


+class VarIntArrayDimIndexer:
+    """Integer array selection against a single dimension."""
+
+    def __init__(
+        self,
+        dim_sel,
+        dim_len,
+        dim_chunk_lens: List[int],
+        wraparound=True,
+        boundscheck=True,
+        order=Order.UNKNOWN,
+    ):
+
+        # ensure 1d array
+        dim_sel = np.asanyarray(dim_sel)
+        if not is_integer_array(dim_sel, 1):
+            raise IndexError(
+                "integer arrays in an orthogonal selection must be " "1-dimensional only"
+            )
+
+        # handle wraparound
+        if wraparound:
+            wraparound_indices(dim_sel, dim_len)
+
+        # handle out of bounds
+        if boundscheck:
+            boundscheck_indices(dim_sel, dim_len)
+
+        # store attributes
+        self.dim_len = dim_len
+        self.dim_chunk_lens = dim_chunk_lens
+        self.nchunks = len(dim_chunk_lens)
+        self.nitems = len(dim_sel)
+        self.offsets = np.cumsum([0] + dim_chunk_lens)
+
+        # determine which chunk is needed for each selection item
+        # note: for dense integer selections, the division operation here is the
+        # bottleneck
+        dim_sel_chunk = np.digitize(dim_sel, self.offsets[1:])
+
+        # determine order of indices
+        if order == Order.UNKNOWN:
+            order = Order.check(dim_sel)
+        self.order = order
+
+        if self.order == Order.INCREASING:
+            self.dim_sel = dim_sel
+            self.dim_out_sel = None
+        elif self.order == Order.DECREASING:
+            self.dim_sel = dim_sel[::-1]
+            # TODO should be possible to do this without creating an arange
+            self.dim_out_sel = np.arange(self.nitems - 1, -1, -1)
+        else:
+            # sort indices to group by chunk
+            self.dim_out_sel = np.argsort(dim_sel_chunk)
+            self.dim_sel = np.take(dim_sel, self.dim_out_sel)
+
+        # precompute number of selected items for each chunk
+        self.chunk_nitems = np.bincount(dim_sel_chunk, minlength=self.nchunks)
+
+        # find chunks that we need to visit
+        self.dim_chunk_ixs = np.nonzero(self.chunk_nitems)[0]
+
+        # compute offsets into the output array
+        self.chunk_nitems_cumsum = np.cumsum(self.chunk_nitems)
+
+    def __iter__(self):
+
+        for dim_chunk_ix in self.dim_chunk_ixs:
+
+            # find region in output
+            if dim_chunk_ix == 0:
+                start = 0
+            else:
+                start = self.chunk_nitems_cumsum[dim_chunk_ix - 1]
+            stop = self.chunk_nitems_cumsum[dim_chunk_ix]
+            if self.order == Order.INCREASING:
+                dim_out_sel = slice(start, stop)
+            else:
+                dim_out_sel = self.dim_out_sel[start:stop]
+
+            # find region in chunk
+            dim_offset = self.offsets[dim_chunk_ix]
+            dim_chunk_sel = self.dim_sel[start:stop] - dim_offset
+
+            yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel)
+
+
 class IntArrayDimIndexer:
     """Integer array selection against a single dimension."""

@@ -614,16 +850,28 @@ def __init__(self, selection, array):
         for dim_sel, dim_len, dim_chunk_len in zip(selection, array._shape, array._chunks):

             if is_integer(dim_sel):
-                dim_indexer = IntDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                if isinstance(dim_chunk_len, int):
+                    dim_indexer = IntDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                else:
+                    dim_indexer = VarIntDimIndexer(dim_sel, dim_len, dim_chunk_len)

             elif isinstance(dim_sel, slice):
-                dim_indexer = SliceDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                if isinstance(dim_chunk_len, int):
+                    dim_indexer = SliceDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                else:
+                    dim_indexer = VarSliceDimIndexer(dim_sel, dim_len, dim_chunk_len)

             elif is_integer_array(dim_sel):
-                dim_indexer = IntArrayDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                if isinstance(dim_chunk_len, int):
+                    dim_indexer = IntArrayDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                else:
+                    dim_indexer = VarIntArrayDimIndexer(dim_sel, dim_len, dim_chunk_len)

             elif is_bool_array(dim_sel):
-                dim_indexer = BoolArrayDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                if isinstance(dim_chunk_len, int):
+                    dim_indexer = BoolArrayDimIndexer(dim_sel, dim_len, dim_chunk_len)
+                else:
+                    dim_indexer = VarBoolArrayDimIndexer(dim_sel, dim_len, dim_chunk_len)

             else:
                 raise IndexError(
diff --git a/zarr/tests/test_indexing.py b/zarr/tests/test_indexing.py
index 8a34c1e715..3d43780610 100644
--- a/zarr/tests/test_indexing.py
+++ b/zarr/tests/test_indexing.py
@@ -1790,3 +1790,73 @@ def test_accessed_chunks(shape, chunks, ops):
                     ) == 1
         # Check that no other chunks were accessed
         assert len(delta_counts) == 0
+
+
+@pytest.mark.parametrize(
+    "chunking, index",
+    [
+        # Single chunk
+        ([[10]], (slice(5, 7),)),
+        # Sample from single chunk
+        ([[1, 9]], (slice(5, 7),)),
+        # Slice across multiple chunks
+        ([[1, 5, 4]], (slice(5, 7),)),
+        # Slice across multiple chunks, negative
+        ([[1, 5, 4]], (slice(-6, -2),)),
+        # Full slice across multiple chunks
+        ([[1, 2, 2, 5]], (slice(None),)),
+        # slice across multiple chunks
+        ([[1, 2, 2, 5]], (slice(2, 8),)),
+        # slice across multiple chunks with step
+        ([[1, 2, 2, 5]], (slice(None, None, 2),)),
+        # slice across multiple chunks with start and step
+        ([[1, 2, 2, 5]], (slice(1, None, 2),)),
+        # Skip many chunks
+        ([[3, 3, 1, 1, 1, 1, 3, 1, 3]], (slice(1, None, 4),)),
+        # Slice 2d case
+        ([[1, 2, 2, 5], 5], (slice(1, None, 2), slice(None))),
+        # Fancy indexing
+        ([[1, 2, 2, 5], 1], (np.array([0, 3, 7]),)),
+        # out of order fancy indexing
+        ([[1, 3, 1, 1, 5], 1], (np.array([6, 6, 6, 0, 10]),)),
+        # boolean indexing
+        ([[2, 1, 2], 2], ([True, False, False, True, True], slice(None))),
+        # single int indexing
+        ([[2, 1, 2]], (1,)),
+    ],
+)
+def test_varchunk_indexing(chunking, index):
+    from functools import reduce
+    from operator import mul
+
+    shape = tuple(x * 2 if isinstance(x, int) else sum(x) for x in chunking)
+    np_array = np.arange(reduce(mul, shape)).reshape(shape)
+    var_chunked = zarr.array(np_array, chunks=chunking)
+    regular_chunked = zarr.array(np_array)
+
+    expected = regular_chunked[index]
+    actual = var_chunked[index]
+    np.testing.assert_equal(actual, expected)
+
+
+# single int indexing
+@pytest.mark.parametrize(
+    "chunking, index, err",
+    [
+        # negative step
+        ([[2, 1, 2]], (slice(None, None, -1)), NotImplementedError),
+        # boolean, wrong length
+        ([[2, 1, 2], 2], ([True, False, False, True], slice(None)), IndexError),
+        # boolean orthogonal
+        ([[2, 1, 2], 2], ([True, False, False, True, True], [True, False]), IndexError),
+    ],
+)
+def test_varchunk_errors(chunking, index, err):
+    from functools import reduce
+    from operator import mul
+
+    shape = tuple(x * 2 if isinstance(x, int) else sum(x) for x in chunking)
+    np_array = np.arange(reduce(mul, shape)).reshape(shape)
+    var_chunked = zarr.array(np_array, chunks=chunking)
+    with pytest.raises(err):
+        var_chunked[index]
diff --git a/zarr/util.py b/zarr/util.py
index ea0dd9fcec..ad7f29a009 100644
--- a/zarr/util.py
+++ b/zarr/util.py
@@ -175,7 +175,7 @@ def normalize_chunks(chunks: Any, shape: Tuple[int, ...], typesize: int) -> Tupl
     if -1 in chunks or None in chunks:
        chunks = tuple(s if c == -1 or c is None else int(c) for s, c in zip(shape, chunks))

-    chunks = tuple(int(c) for c in chunks)
+    chunks = tuple(list(int(_) for _ in c) if isinstance(c, list) else int(c) for c in chunks)
     return chunks
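
Note on the zarr/core.py changes: the expression repeated in _process_chunk and _process_for_setitem derives the concrete shape of a single chunk from the mixed int/list chunk spec plus that chunk's grid coordinates. A standalone restatement of that expression, using a hypothetical helper name chosen here for illustration only:

    from typing import List, Sequence, Tuple, Union

    # Per-dimension chunk spec: a fixed length (int) or explicit per-chunk lengths (list).
    ChunkSpec = Union[int, List[int]]

    def chunk_shape(chunks: Sequence[ChunkSpec], coords: Tuple[int, ...]) -> Tuple[int, ...]:
        # Same expression as in the patch: an int passes through unchanged, a list
        # is indexed by that dimension's chunk coordinate.
        return tuple(c if isinstance(c, int) else c[coords[i]] for i, c in enumerate(chunks))

    # With chunks=[[1, 2, 2, 5], 5], the chunk at grid position (3, 0) covers the
    # last 5 rows and the first 5 columns, so its shape is (5, 5).
    assert chunk_shape([[1, 2, 2, 5], 5], (3, 0)) == (5, 5)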
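
A minimal usage sketch of what the patch enables, assuming it is applied; it mirrors cases from test_varchunk_indexing above (the list-of-lists chunks argument is the convention those tests rely on for per-dimension variable chunk lengths, not a previously documented API):

    import numpy as np
    import zarr

    # Dimension 0 uses explicit chunk lengths 1, 2, 2, 5 (summing to the dimension
    # size); dimension 1 keeps a regular chunk length of 5, matching the
    # "Slice 2d case" entry in test_varchunk_indexing.
    data = np.arange(100).reshape(10, 10)
    var_chunked = zarr.array(data, chunks=[[1, 2, 2, 5], 5])
    regular_chunked = zarr.array(data)

    # Plain, strided and fancy selections cross the uneven chunk boundaries and
    # should match a regularly chunked array.
    np.testing.assert_equal(var_chunked[2:8], regular_chunked[2:8])
    np.testing.assert_equal(var_chunked[1::2, :], regular_chunked[1::2, :])
    np.testing.assert_equal(var_chunked[np.array([0, 3, 7])], regular_chunked[np.array([0, 3, 7])])

    # Negative slice steps are not implemented for variable chunks and raise
    # NotImplementedError (see test_varchunk_errors).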