Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Feb 23, 2024
1 parent 12a44fe commit 2e108c7
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@
safe_cast_to_index,
)
from xarray.core.options import _get_keep_attrs
from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray
from xarray.core.types import (
Dims,
QuantileMethods,
T_DataArray,
T_DataWithCoords,
T_Xarray,
)
from xarray.core.utils import (
FrozenMappingWarningOnValuesAccess,
either_dict_or_kwargs,
Expand Down Expand Up @@ -276,9 +282,9 @@ def to_array(self) -> DataArray:
T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup]


def _ensure_1d(group: T_Group, obj: T_Xarray) -> tuple[
def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
T_Group,
T_Xarray,
T_DataWithCoords,
Hashable | None,
list[Hashable],
]:
Expand Down Expand Up @@ -341,10 +347,10 @@ def _apply_loffset(


@dataclass
class ResolvedGrouper:
class ResolvedGrouper(Generic[T_DataWithCoords]):
grouper: Grouper
group: T_Group
obj: T_Xarray
obj: T_DataWithCoords

# Defined by factorize:
codes: DataArray = field(init=False)
Expand All @@ -354,7 +360,7 @@ class ResolvedGrouper:

# _ensure_1d:
group1d: T_Group = field(init=False)
stacked_obj: T_Xarray = field(init=False)
stacked_obj: T_DataWithCoords = field(init=False)
stacked_dim: Hashable | None = field(init=False)
inserted_dims: list[Hashable] = field(init=False)

Expand Down Expand Up @@ -660,7 +666,7 @@ def _validate_groupby_squeeze(squeeze: bool | None) -> None:
)


def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group:
def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group:
from xarray.core.dataarray import DataArray

error_msg = (
Expand Down Expand Up @@ -698,12 +704,12 @@ def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group:
"name of an xarray variable or dimension. "
f"Received {group!r} instead."
)
group = obj[group]
if group.name not in obj._indexes and group.name in obj.dims:
group_da: DataArray = obj[group]
if group_da.name not in obj._indexes and group_da.name in obj.dims:
# DummyGroups should not appear on groupby results
newgroup = _DummyGroup(obj, group.name, group.coords)
newgroup = _DummyGroup(obj, group_da.name, group_da.coords)
else:
newgroup = group
newgroup = group_da

if newgroup.size == 0:
raise ValueError(f"{newgroup.name} must not be empty")
Expand Down

0 comments on commit 2e108c7

Please sign in to comment.