diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 54806988205..b718eed1b31 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -287,7 +287,7 @@ def codes(self) -> DataArray:
         return self.encoded.codes
 
     @property
-    def unique_coord(self) -> DataArray:
+    def unique_coord(self) -> Variable | _DummyGroup:
         return self.encoded.unique_coord
 
     def __post_init__(self) -> None:
@@ -300,13 +300,13 @@ def __post_init__(self) -> None:
 
         self.group = _resolve_group(self.obj, self.group)
 
-        self.encoded = self.factorize()
+        self.encoded = self.grouper.factorize(self.group)
 
     @property
     def name(self) -> Hashable:
         """Name for the grouped coordinate after reduction."""
         # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper
-        (name,) = self.unique_coord.dims
+        (name,) = self.encoded.unique_coord.dims
         return name
 
     @property
@@ -318,28 +318,6 @@ def __len__(self) -> int:
         """Number of groups."""
         return len(self.encoded.full_index)
 
-    @property
-    def dims(self):
-        return self.group1d.dims
-
-    def factorize(self) -> EncodedGroups:
-        encoded = self.grouper.factorize(self.group)
-
-        if encoded.group_indices is None:
-            encoded.group_indices = tuple(
-                g
-                for g in _codes_to_group_indices(
-                    encoded.codes.data.ravel(), len(encoded.full_index)
-                )
-                if g
-            )
-        if encoded.unique_coord is None:
-            unique_values = encoded.full_index[np.unique(encoded.codes)]
-            encoded.unique_coord = Variable(
-                dims=encoded.codes.name, data=unique_values, attrs=self.group.attrs
-            )
-        return encoded
-
 
 def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
     # While we don't generally check the type of every arg, passing
@@ -446,7 +424,7 @@ def factorize(self) -> EncodedGroups:
             (grouper.full_index.values for grouper in groupers),
             names=tuple(grouper.name for grouper in groupers),
         )
-        dim_name = "stacked_" + "_".join(grouper.name for grouper in groupers)
+        dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers)
 
         return EncodedGroups(
             codes=first_codes.copy(data=_flatcodes),
diff --git a/xarray/groupers.py b/xarray/groupers.py
index 45bca540c35..d30d0eeea74 100644
--- a/xarray/groupers.py
+++ b/xarray/groupers.py
@@ -45,7 +45,7 @@ def _coordinates_from_variable(variable: Variable) -> Coordinates:
     return Coordinates(new_vars, indexes)
 
 
-@dataclass
+@dataclass(init=False)
 class EncodedGroups:
     """
     Dataclass for storing intermediate values for GroupBy operation.
@@ -67,19 +67,45 @@ class EncodedGroups:
 
     codes: DataArray
     full_index: pd.Index
-    group_indices: GroupIndices | None = field(default=None)
-    unique_coord: Variable | _DummyGroup | None = field(default=None)
-    coords: Coordinates = field(default_factory=Coordinates)
-
-    def __post_init__(self):
-        assert isinstance(self.codes, DataArray)
-        if self.codes.name is None:
+    group_indices: GroupIndices
+    unique_coord: Variable | _DummyGroup
+    coords: Coordinates
+
+    def __init__(
+        self,
+        codes: DataArray,
+        full_index: pd.Index,
+        group_indices: GroupIndices | None = None,
+        unique_coord: Variable | _DummyGroup | None = None,
+        coords: Coordinates | None = None,
+    ):
+        from xarray.core.groupby import _codes_to_group_indices
+
+        assert isinstance(codes, DataArray)
+        if codes.name is None:
             raise ValueError("Please set a name on the array you are grouping by.")
-        assert isinstance(self.full_index, pd.Index)
-        assert (
-            isinstance(self.unique_coord, Variable | _DummyGroup)
-            or self.unique_coord is None
-        )
+        self.codes = codes
+        assert isinstance(full_index, pd.Index)
+        self.full_index = full_index
+
+        if group_indices is None:
+            self.group_indices = tuple(
+                g
+                for g in _codes_to_group_indices(codes.data.ravel(), len(full_index))
+                if g
+            )
+        else:
+            self.group_indices = group_indices
+
+        if unique_coord is None:
+            unique_values = full_index[np.unique(codes)]
+            self.unique_coord = Variable(
+                dims=codes.name, data=unique_values, attrs=codes.attrs
+            )
+        else:
+            self.unique_coord = unique_coord
+
+        self.coords = coords or Coordinates()
 
 
 class Grouper(ABC):