diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index ba8856e178b..f9d308171a9 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -63,6 +63,8 @@ Deprecations
 
 Bug fixes
 ~~~~~~~~~
+- Support non-string hashable dimensions in :py:class:`xarray.DataArray` (:issue:`8546`, :pull:`8559`).
+  By `Michael Niklas <https://github.com/headtr1ck>`_.
 - Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`).
   By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
 - Vendor `SerializableLock` from dask and use as default lock for netcdf4 backends (:issue:`8442`, :pull:`8571`).
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index e8d6f82136b..d2ea0f8a1a4 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -11,6 +11,8 @@
     Generic,
     Literal,
     NoReturn,
+    TypeVar,
+    Union,
     overload,
 )
 
@@ -61,6 +63,7 @@
     ReprObject,
     _default,
     either_dict_or_kwargs,
+    hashable,
 )
 from xarray.core.variable import (
     IndexVariable,
@@ -73,23 +76,11 @@
 from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
 
 if TYPE_CHECKING:
-    from typing import TypeVar, Union
-
+    from dask.dataframe import DataFrame as DaskDataFrame
+    from dask.delayed import Delayed
+    from iris.cube import Cube as iris_Cube
     from numpy.typing import ArrayLike
 
-    try:
-        from dask.dataframe import DataFrame as DaskDataFrame
-    except ImportError:
-        DaskDataFrame = None
-    try:
-        from dask.delayed import Delayed
-    except ImportError:
-        Delayed = None  # type: ignore[misc,assignment]
-    try:
-        from iris.cube import Cube as iris_Cube
-    except ImportError:
-        iris_Cube = None
-
     from xarray.backends import ZarrStore
     from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
     from xarray.core.groupby import DataArrayGroupBy
@@ -140,7 +131,9 @@ def _check_coords_dims(shape, coords, dim):
 
 
 def _infer_coords_and_dims(
-    shape, coords, dims
+    shape: tuple[int, ...],
+    coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None,
+    dims: str | Iterable[Hashable] | None,
 ) -> tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]:
     """All the logic for creating a new DataArray"""
 
@@ -157,8 +150,7 @@ def _infer_coords_and_dims(
 
     if isinstance(dims, str):
         dims = (dims,)
-
-    if dims is None:
+    elif dims is None:
         dims = [f"dim_{n}" for n in range(len(shape))]
         if coords is not None and len(coords) == len(shape):
             # try to infer dimensions from coords
@@ -168,16 +160,15 @@ def _infer_coords_and_dims(
                 for n, (dim, coord) in enumerate(zip(dims, coords)):
                     coord = as_variable(coord, name=dims[n]).to_index_variable()
                     dims[n] = coord.name
-        dims = tuple(dims)
-    elif len(dims) != len(shape):
+    dims_tuple = tuple(dims)
+    if len(dims_tuple) != len(shape):
         raise ValueError(
             "different number of dimensions on data "
-            f"and dims: {len(shape)} vs {len(dims)}"
+            f"and dims: {len(shape)} vs {len(dims_tuple)}"
         )
-    else:
-        for d in dims:
-            if not isinstance(d, str):
-                raise TypeError(f"dimension {d} is not a string")
+    for d in dims_tuple:
+        if not hashable(d):
+            raise TypeError(f"Dimension {d} is not hashable")
 
     new_coords: Mapping[Hashable, Any]
 
@@ -189,17 +180,21 @@ def _infer_coords_and_dims(
             for k, v in coords.items():
                 new_coords[k] = as_variable(v, name=k)
         elif coords is not None:
-            for dim, coord in zip(dims, coords):
+            for dim, coord in zip(dims_tuple, coords):
                 var = as_variable(coord, name=dim)
                 var.dims = (dim,)
                 new_coords[dim] = var.to_index_variable()
 
-    _check_coords_dims(shape, new_coords, dims)
+    _check_coords_dims(shape, new_coords, dims_tuple)
 
-    return new_coords, dims
+    return new_coords, dims_tuple
 
 
-def _check_data_shape(data, coords, dims):
+def _check_data_shape(
+    data: Any,
+    coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None,
+    dims: str | Iterable[Hashable] | None,
+) -> Any:
     if data is dtypes.NA:
         data = np.nan
     if coords is not None and utils.is_scalar(data, include_0d=False):
@@ -405,10 +400,8 @@ class DataArray(
     def __init__(
         self,
         data: Any = dtypes.NA,
-        coords: Sequence[Sequence[Any] | pd.Index | DataArray]
-        | Mapping[Any, Any]
-        | None = None,
-        dims: Hashable | Sequence[Hashable] | None = None,
+        coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None,
+        dims: str | Iterable[Hashable] | None = None,
         name: Hashable | None = None,
         attrs: Mapping | None = None,
         # internal parameters
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index c19083915f4..c83a56bb373 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -130,6 +130,8 @@
 from xarray.util.deprecation_helpers import _deprecate_positional_args
 
 if TYPE_CHECKING:
+    from dask.dataframe import DataFrame as DaskDataFrame
+    from dask.delayed import Delayed
     from numpy.typing import ArrayLike
 
     from xarray.backends import AbstractDataStore, ZarrStore
@@ -164,15 +166,6 @@
     )
     from xarray.core.weighted import DatasetWeighted
 
-    try:
-        from dask.delayed import Delayed
-    except ImportError:
-        Delayed = None  # type: ignore[misc,assignment]
-    try:
-        from dask.dataframe import DataFrame as DaskDataFrame
-    except ImportError:
-        DaskDataFrame = None
-
 
 # list of attributes of pd.DatetimeIndex that are ndarrays of time info
 _DATETIMEINDEX_COMPONENTS = [
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index 9d65c3ef021..207caba48f0 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -59,7 +59,11 @@ def _importorskip(
             raise ImportError("Minimum version not satisfied")
     except ImportError:
         has = False
-    func = pytest.mark.skipif(not has, reason=f"requires {modname}")
+
+    reason = f"requires {modname}"
+    if minversion is not None:
+        reason += f">={minversion}"
+    func = pytest.mark.skipif(not has, reason=reason)
     return has, func
 
 
@@ -122,10 +126,7 @@ def _importorskip(
     not has_pandas_version_two, reason="requires pandas 2.0.0"
 )
 has_numpy_array_api, requires_numpy_array_api = _importorskip("numpy", "1.26.0")
-has_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0")
-requires_h5netcdf_ros3 = pytest.mark.skipif(
-    not has_h5netcdf_ros3[0], reason="requires h5netcdf 1.3.0"
-)
+has_h5netcdf_ros3, requires_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0")
 
 has_netCDF4_1_6_2_or_above, requires_netCDF4_1_6_2_or_above = _importorskip(
     "netCDF4", "1.6.2"
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 6dc85fc5691..ab1fc316f77 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -401,8 +401,8 @@ def test_constructor_invalid(self) -> None:
         with pytest.raises(ValueError, match=r"not a subset of the .* dim"):
             DataArray(data, {"x": [0, 1, 2]})
 
-        with pytest.raises(TypeError, match=r"is not a string"):
-            DataArray(data, dims=["x", None])
+        with pytest.raises(TypeError, match=r"is not hashable"):
+            DataArray(data, dims=["x", []])  # type: ignore[list-item]
 
         with pytest.raises(ValueError, match=r"conflicting sizes for dim"):
             DataArray([1, 2, 3], coords=[("x", [0, 1])])
diff --git a/xarray/tests/test_hashable.py b/xarray/tests/test_hashable.py
new file mode 100644
index 00000000000..9f92c604dc3
--- /dev/null
+++ b/xarray/tests/test_hashable.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import TYPE_CHECKING, Union
+
+import pytest
+
+from xarray import DataArray, Dataset, Variable
+
+if TYPE_CHECKING:
+    from xarray.core.types import TypeAlias
+
+    DimT: TypeAlias = Union[int, tuple, "DEnum", "CustomHashable"]
+
+
+class DEnum(Enum):
+    dim = "dim"
+
+
+class CustomHashable:
+    def __init__(self, a: int) -> None:
+        self.a = a
+
+    def __hash__(self) -> int:
+        return self.a
+
+
+parametrize_dim = pytest.mark.parametrize(
+    "dim",
+    [
+        pytest.param(5, id="int"),
+        pytest.param(("a", "b"), id="tuple"),
+        pytest.param(DEnum.dim, id="enum"),
+        pytest.param(CustomHashable(3), id="HashableObject"),
+    ],
+)
+
+
+@parametrize_dim
+def test_hashable_dims(dim: DimT) -> None:
+    v = Variable([dim], [1, 2, 3])
+    da = DataArray([1, 2, 3], dims=[dim])
+    Dataset({"a": ([dim], [1, 2, 3])})
+
+    # alternative constructors
+    DataArray(v)
+    Dataset({"a": v})
+    Dataset({"a": da})
+
+
+@parametrize_dim
+def test_dataset_variable_hashable_names(dim: DimT) -> None:
+    Dataset({dim: ("x", [1, 2, 3])})
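
For illustration, a minimal sketch (not part of the diff) of the user-facing behaviour this patch enables; the `Dim` enum below is an invented example name:

    from enum import Enum

    import xarray as xr


    class Dim(Enum):
        x = "x"


    # After this change, any hashable object can name a dimension, not just str.
    da = xr.DataArray([1, 2, 3], dims=[Dim.x])
    assert da.dims == (Dim.x,)

    # Unhashable objects are rejected with the new error message.
    try:
        xr.DataArray([1, 2, 3], dims=[["x"]])
    except TypeError as err:
        print(err)  # Dimension ['x'] is not hashable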