Skip to content

Commit

Permalink
vendor dask.array.normalize_chunks and its dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
TomNicholas committed Aug 12, 2024
1 parent 8db961e commit 955c56e
Show file tree
Hide file tree
Showing 5 changed files with 519 additions and 27 deletions.
38 changes: 11 additions & 27 deletions xarray/namedarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
from packaging.version import Version

from xarray.namedarray._typing import ErrorOptionsWithWarn, _DimsLike
from xarray.namedarray._typing import ErrorOptionsWithWarn, _DimsLike, _NormalizedChunks

if TYPE_CHECKING:
if sys.version_info >= (3, 10):
Expand All @@ -27,7 +27,14 @@
DaskArray = NDArray # type: ignore
DaskCollection: Any = NDArray # type: ignore

from xarray.namedarray._typing import T_Chunks, _Dim, _Dims, _Shape, duckarray
from xarray.namedarray._typing import (
T_Chunks,
_Dim,
_Dims,
_DType,
_Shape,
duckarray,
)


K = TypeVar("K")
Expand Down Expand Up @@ -224,9 +231,6 @@ def __dask_tokenize__(self) -> object:
return normalize_token((type(self), self._value))


from xarray.namedarray._typing import _DType, _NormalizedChunks


def normalize_chunks_to_tuples(
chunks: T_Chunks,
dims: _Dims,
Expand Down Expand Up @@ -263,6 +267,7 @@ def normalize_chunks_to_tuples(
}

# (everything below here is vendored from dask)
from xarray.vendor.dask.array.utils import validate_axis

# validate that chunk lengths are valid choices
ndim = len(dims)
Expand All @@ -282,7 +287,7 @@ def normalize_chunks_to_tuples(
)

# TODO vendor the normalize_chunks function and remove it from the ChunkManager
from dask.array.core import normalize_chunks
from xarray.vendor.dask.array.core import normalize_chunks

# supports the 'auto' option, using previous_chunks as a fallback
return cast(
Expand All @@ -291,24 +296,3 @@ def normalize_chunks_to_tuples(
chunks, shape, dtype=dtype, previous_chunks=_previous_chunks
),
)


import numbers

if Version(np.__version__).release >= (2, 0):
from numpy.exceptions import AxisError
else:
from numpy import AxisError


def validate_axis(axis: int, ndim: int) -> int:
"""Validate an input to axis= keywords"""
if isinstance(axis, (tuple, list)):
return tuple(validate_axis(ax, ndim) for ax in axis)
if not isinstance(axis, numbers.Integral):
raise TypeError(f"Axis value must be an integer, got {axis}")
if axis < -ndim or axis >= ndim:
raise AxisError(f"Axis {axis} is out of bounds for array of dimension {ndim}")
if axis < 0:
axis += ndim
return axis
Loading

0 comments on commit 955c56e

Please sign in to comment.