Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 12, 2024
1 parent 49e5b91 commit 929b70c
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 16 deletions.
4 changes: 2 additions & 2 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime
import sys
from collections.abc import Hashable, Iterator, Mapping, Sequence
from collections.abc import Hashable, Iterator, Mapping, Sequence, Collection
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -182,7 +182,7 @@ def copy(
DsCompatible = Union["Dataset", "DaCompatible"]
GroupByCompatible = Union["Dataset", "DataArray"]

Dims = Union[Hashable, Sequence[Hashable], None]
Dims = Union[Hashable, Collection[Hashable], None] # Note: Hashable include ellipsis

# FYI in some cases we don't allow `None`, which this doesn't take account of.
T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]]
Expand Down
19 changes: 11 additions & 8 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,7 @@ def drop_missing_dims(

@overload
def parse_dims(
dim: str | Iterable[Hashable] | T_None,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
Expand All @@ -999,7 +999,7 @@ def parse_dims(

@overload
def parse_dims(
dim: str | Iterable[Hashable] | T_None,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
Expand Down Expand Up @@ -1042,11 +1042,14 @@ def parse_dims(
if replace_none:
return all_dims
return dim
if isinstance(dim, str):
dim = (dim,)
if isinstance(dim, Collection) and not isinstance(dim, str):
dim = tuple(dim)
else:
dim = (dim, )

if check_exists:
_check_dims(set(dim), set(all_dims))
return tuple(dim)
return dim


@overload
Expand Down Expand Up @@ -1124,9 +1127,9 @@ def parse_ordered_dims(
)


def _check_dims(dim: set[Hashable | ellipsis], all_dims: set[Hashable]) -> None:
wrong_dims = dim - all_dims
if wrong_dims and wrong_dims != {...}:
def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None:
wrong_dims = (dim - all_dims) - {...}
if wrong_dims:
wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims)
raise ValueError(
f"Dimension(s) {wrong_dims_str} do not exist. Expected one or more of {all_dims}"
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,8 +838,8 @@ def test_interpolate_chunk_1d(
if chunked:
dest[dim] = xr.DataArray(data=dest[dim], dims=[dim])
dest[dim] = dest[dim].chunk(2)
actual = da.interp(method=method, **dest, kwargs=kwargs) # type: ignore
expected = da.compute().interp(method=method, **dest, kwargs=kwargs) # type: ignore
actual = da.interp(method=method, **dest, kwargs=kwargs)
expected = da.compute().interp(method=method, **dest, kwargs=kwargs)

assert_identical(actual, expected)

Expand Down
10 changes: 6 additions & 4 deletions xarray/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,19 +255,21 @@ def test_infix_dims_errors(supplied, all_):
["dim", "expected"],
[
pytest.param("a", ("a",), id="str"),
pytest.param(1, (1,), id="non_str_hashable"),
pytest.param(["a", "b"], ("a", "b"), id="list_of_str"),
pytest.param(["a", 1], ("a", 1), id="list_mixed"),
pytest.param(["a", ...], ("a", ...), id="list_with_ellipsis"),
pytest.param(("a", "b"), ("a", "b"), id="tuple_of_str"),
pytest.param(["a", ("b", "c")], ("a", ("b", "c")), id="list_with_tuple"),
pytest.param((("b", "c"),), (("b", "c"),), id="tuple_of_tuple"),
pytest.param({"a", 1}, tuple({"a", 1}), id="non_sequence_collection"),
pytest.param((), (), id="empty_tuple"),
pytest.param(set(), (), id="empty_collection"),
pytest.param(None, None, id="None"),
pytest.param(..., ..., id="ellipsis"),
],
)
def test_parse_dims(
dim: str | Iterable[Hashable] | None,
expected: tuple[Hashable, ...],
) -> None:
def test_parse_dims(dim, expected):
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
actual = utils.parse_dims(dim, all_dims, replace_none=False)
assert actual == expected
Expand Down

0 comments on commit 929b70c

Please sign in to comment.