From 2e108c747fbbe36777dd6c0ee512e9f91895a002 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 22 Feb 2024 20:18:43 -0700 Subject: [PATCH] Fixes --- xarray/core/groupby.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 8c6ffd5a797..d656cf58a6a 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -29,7 +29,13 @@ safe_cast_to_index, ) from xarray.core.options import _get_keep_attrs -from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray +from xarray.core.types import ( + Dims, + QuantileMethods, + T_DataArray, + T_DataWithCoords, + T_Xarray, +) from xarray.core.utils import ( FrozenMappingWarningOnValuesAccess, either_dict_or_kwargs, @@ -276,9 +282,9 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] -def _ensure_1d(group: T_Group, obj: T_Xarray) -> tuple[ +def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ T_Group, - T_Xarray, + T_DataWithCoords, Hashable | None, list[Hashable], ]: @@ -341,10 +347,10 @@ def _apply_loffset( @dataclass -class ResolvedGrouper: +class ResolvedGrouper(Generic[T_DataWithCoords]): grouper: Grouper group: T_Group - obj: T_Xarray + obj: T_DataWithCoords # Defined by factorize: codes: DataArray = field(init=False) @@ -354,7 +360,7 @@ class ResolvedGrouper: # _ensure_1d: group1d: T_Group = field(init=False) - stacked_obj: T_Xarray = field(init=False) + stacked_obj: T_DataWithCoords = field(init=False) stacked_dim: Hashable | None = field(init=False) inserted_dims: list[Hashable] = field(init=False) @@ -660,7 +666,7 @@ def _validate_groupby_squeeze(squeeze: bool | None) -> None: ) -def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group: +def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group: from xarray.core.dataarray import DataArray error_msg = ( @@ -698,12 +704,12 @@ def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group: "name of an xarray variable or dimension. " f"Received {group!r} instead." ) - group = obj[group] - if group.name not in obj._indexes and group.name in obj.dims: + group_da: DataArray = obj[group] + if group_da.name not in obj._indexes and group_da.name in obj.dims: # DummyGroups should not appear on groupby results - newgroup = _DummyGroup(obj, group.name, group.coords) + newgroup = _DummyGroup(obj, group_da.name, group_da.coords) else: - newgroup = group + newgroup = group_da if newgroup.size == 0: raise ValueError(f"{newgroup.name} must not be empty")