Skip to content

Commit

Permalink
Fix typing in zarr.indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
dstansby committed Apr 14, 2024
1 parent 59d9d09 commit c27b0a8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ module = [
"zarr.v3.group",
"zarr.core",
"zarr.hierarchy",
"zarr.indexing",
"zarr.storage",
"tests.*",
]
Expand Down
17 changes: 14 additions & 3 deletions src/zarr/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numbers

import numpy as np
import numpy.typing as npt


from zarr.errors import (
Expand Down Expand Up @@ -330,6 +331,7 @@ def __init__(self, selection, array):
# setup per-dimension indexers
dim_indexers = []
for dim_sel, dim_len, dim_chunk_len in zip(selection, array._shape, array._chunks):
dim_indexer: IntDimIndexer | SliceDimIndexer
if is_integer(dim_sel):
dim_indexer = IntDimIndexer(dim_sel, dim_len, dim_chunk_len)

Expand Down Expand Up @@ -521,9 +523,12 @@ def __iter__(self):
else:
start = self.chunk_nitems_cumsum[dim_chunk_ix - 1]
stop = self.chunk_nitems_cumsum[dim_chunk_ix]

dim_out_sel: slice | np.ndarray
if self.order == Order.INCREASING:
dim_out_sel = slice(start, stop)
else:
assert self.dim_out_sel is not None
dim_out_sel = self.dim_out_sel[start:stop]

# find region in chunk
Expand Down Expand Up @@ -577,11 +582,10 @@ def oindex_set(a, selection, value):
selection = ix_(selection, a.shape)
if not np.isscalar(value) and drop_axes:
value = np.asanyarray(value)
value_selection = [slice(None)] * len(a.shape)
value_selection: list[slice | None] = [slice(None)] * len(a.shape)
for i in drop_axes:
value_selection[i] = np.newaxis
value_selection = tuple(value_selection)
value = value[value_selection]
value = value[tuple(value_selection)]
a[selection] = value


Expand All @@ -597,6 +601,7 @@ def __init__(self, selection, array):
# setup per-dimension indexers
dim_indexers = []
for dim_sel, dim_len, dim_chunk_len in zip(selection, array._shape, array._chunks):
dim_indexer: IntDimIndexer | SliceDimIndexer | IntArrayDimIndexer | BoolArrayDimIndexer
if is_integer(dim_sel):
dim_indexer = IntDimIndexer(dim_sel, dim_len, dim_chunk_len)

Expand All @@ -622,6 +627,8 @@ def __init__(self, selection, array):
self.dim_indexers = dim_indexers
self.shape = tuple(s.nitems for s in self.dim_indexers if not isinstance(s, IntDimIndexer))
self.is_advanced = not is_basic_selection(selection)

self.drop_axes: tuple[int, ...] | None
if self.is_advanced:
self.drop_axes = tuple(
i
Expand Down Expand Up @@ -797,6 +804,7 @@ def __init__(self, selection, array):
boundscheck_indices(dim_sel, dim_len)

# compute chunk index for each point in the selection
chunks_multi_index: npt.ArrayLike
chunks_multi_index = tuple(
dim_sel // dim_chunk_len for (dim_sel, dim_chunk_len) in zip(selection, array._chunks)
)
Expand Down Expand Up @@ -848,6 +856,8 @@ def __iter__(self):
else:
start = self.chunk_nitems_cumsum[chunk_rix - 1]
stop = self.chunk_nitems_cumsum[chunk_rix]

out_selection: slice | np.ndarray
if self.sel_sort is None:
out_selection = slice(start, stop)
else:
Expand Down Expand Up @@ -952,6 +962,7 @@ def check_no_multi_fields(fields):


def pop_fields(selection):
fields: str | None | list[str]
if isinstance(selection, str):
# single field selection
fields = selection
Expand Down

0 comments on commit c27b0a8

Please sign in to comment.