Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
brokkoli71 committed Feb 12, 2025
1 parent a9c0eab commit b005620
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
15 changes: 10 additions & 5 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
NDBuffer,
default_buffer_prototype,
)
from zarr.core.buffer.core import NDArrayLike
from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition, normalize_chunks
from zarr.core.chunk_key_encodings import (
ChunkKeyEncoding,
Expand Down Expand Up @@ -1400,7 +1401,7 @@ async def _set_selection(
value = value.astype(dtype=self.metadata.dtype, order="A")
else:
value = np.array(value, dtype=self.metadata.dtype, order="A")
value = cast(NDArrayOrScalarLike, value)
value = cast(NDArrayLike, value)
# We accept any ndarray like object from the user and convert it
# to a NDBuffer (or subclass). From this point onwards, we only pass
# Buffer and NDBuffer between components.
Expand Down Expand Up @@ -2260,7 +2261,7 @@ def _iter_chunk_regions(

def __array__(
self, dtype: npt.DTypeLike | None = None, copy: bool | None = None
) -> NDArrayOrScalarLike:
) -> NDArrayLike:
"""
This method is used by numpy when converting zarr.Array into a numpy array.
For more information, see https://numpy.org/devdocs/user/basics.interoperability.html#the-array-method
Expand All @@ -2269,9 +2270,13 @@ def __array__(
msg = "`copy=False` is not supported. This method always creates a copy."
raise ValueError(msg)

arr_np = self[...]
if self.ndim == 0:
arr_np = np.array(arr_np)
arr = self[...]
arr_np: NDArrayLike

if not hasattr(arr, "astype"):
arr_np = np.array(arr, dtype=dtype)
else:
arr_np = arr

if dtype is not None:
arr_np = arr_np.astype(dtype)
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/core/buffer/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Callable, Iterable
from typing import Self

from zarr.core.buffer.core import ArrayLike, NDArrayOrScalarLike
from zarr.core.buffer.core import ArrayLike, NDArrayLike
from zarr.core.common import BytesLike


Expand Down Expand Up @@ -142,7 +142,7 @@ class NDBuffer(core.NDBuffer):
ndarray-like object that is convertible to a regular Numpy array.
"""

def __init__(self, array: NDArrayOrScalarLike) -> None:
def __init__(self, array: NDArrayLike) -> None:
super().__init__(array)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/core/buffer/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy.typing as npt

from zarr.core.buffer import core
from zarr.core.buffer.core import ArrayLike, BufferPrototype, NDArrayOrScalarLike
from zarr.core.buffer.core import ArrayLike, BufferPrototype, NDArrayLike

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -136,7 +136,7 @@ class NDBuffer(core.NDBuffer):
ndarray-like object that is convertible to a regular Numpy array.
"""

def __init__(self, array: NDArrayOrScalarLike) -> None:
def __init__(self, array: NDArrayLike) -> None:
if cp is None:
raise ImportError(
"Cannot use zarr.buffer.gpu.NDBuffer without cupy. Please install cupy."
Expand Down

0 comments on commit b005620

Please sign in to comment.