Skip to content

Commit

Permalink
Simplify ResolvedGrouper
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 16, 2024
1 parent d419fb8 commit 82e9514
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 39 deletions.
30 changes: 4 additions & 26 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def codes(self) -> DataArray:
return self.encoded.codes

@property
def unique_coord(self) -> DataArray:
def unique_coord(self) -> Variable | _DummyGroup:
return self.encoded.unique_coord

def __post_init__(self) -> None:
Expand All @@ -300,13 +300,13 @@ def __post_init__(self) -> None:

self.group = _resolve_group(self.obj, self.group)

self.encoded = self.factorize()
self.encoded = self.grouper.factorize(self.group)

@property
def name(self) -> Hashable:
"""Name for the grouped coordinate after reduction."""
# the name has to come from unique_coord because we need `_bins` suffix for BinGrouper
(name,) = self.unique_coord.dims
(name,) = self.encoded.unique_coord.dims
return name

@property
Expand All @@ -318,28 +318,6 @@ def __len__(self) -> int:
"""Number of groups."""
return len(self.encoded.full_index)

@property
def dims(self):
    """Return the ``dims`` of ``self.group1d``."""
    group1d = self.group1d
    return group1d.dims

def factorize(self) -> EncodedGroups:
    """
    Factorize ``self.group`` with ``self.grouper`` and fill in missing fields.

    Returns
    -------
    EncodedGroups
        The object returned by ``self.grouper.factorize(self.group)``. When the
        grouper left ``group_indices`` or ``unique_coord`` as ``None`` they are
        derived here from ``codes`` and ``full_index`` and set in place.
    """
    encoded = self.grouper.factorize(self.group)

    if encoded.group_indices is None:
        # Build the per-group index collections from the flattened integer
        # codes; falsy (empty) entries are dropped.
        encoded.group_indices = tuple(
            g
            for g in _codes_to_group_indices(
                encoded.codes.data.ravel(), len(encoded.full_index)
            )
            if g
        )
    if encoded.unique_coord is None:
        unique_codes = np.unique(encoded.codes)
        # Negative codes must not be used as positions: ``full_index[-1]``
        # would silently wrap around and select the *last* label, adding a
        # spurious group. (Presumably -1 is the "belongs to no group"
        # sentinel -- confirm against the groupers.)
        unique_codes = unique_codes[unique_codes >= 0]
        unique_values = encoded.full_index[unique_codes]
        encoded.unique_coord = Variable(
            dims=encoded.codes.name, data=unique_values, attrs=self.group.attrs
        )
    return encoded


def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
# While we don't generally check the type of every arg, passing
Expand Down Expand Up @@ -446,7 +424,7 @@ def factorize(self) -> EncodedGroups:
(grouper.full_index.values for grouper in groupers),
names=tuple(grouper.name for grouper in groupers),
)
dim_name = "stacked_" + "_".join(grouper.name for grouper in groupers)
dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers)

return EncodedGroups(
codes=first_codes.copy(data=_flatcodes),
Expand Down
52 changes: 39 additions & 13 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _coordinates_from_variable(variable: Variable) -> Coordinates:
return Coordinates(new_vars, indexes)


@dataclass
@dataclass(init=False)
class EncodedGroups:
"""
Dataclass for storing intermediate values for GroupBy operation.
Expand All @@ -67,19 +67,45 @@ class EncodedGroups:

codes: DataArray
full_index: pd.Index
group_indices: GroupIndices | None = field(default=None)
unique_coord: Variable | _DummyGroup | None = field(default=None)
coords: Coordinates = field(default_factory=Coordinates)

def __post_init__(self):
assert isinstance(self.codes, DataArray)
if self.codes.name is None:
group_indices: GroupIndices
unique_coord: Variable | _DummyGroup
coords: Coordinates

def __init__(
    self,
    codes: DataArray,
    full_index: pd.Index,
    group_indices: GroupIndices | None = None,
    unique_coord: Variable | _DummyGroup | None = None,
    coords: Coordinates | None = None,
):
    """
    Store the factorization results, deriving any optional piece not given.

    Parameters
    ----------
    codes : DataArray
        Named array of integer codes indexing into ``full_index``.
    full_index : pd.Index
        Index of all group labels.
    group_indices : optional
        Per-group indices. When ``None`` they are computed from ``codes``
        with ``_codes_to_group_indices`` and empty groups are dropped.
    unique_coord : Variable or _DummyGroup, optional
        Labels of the groups that are actually present. When ``None`` it is
        built from the unique non-negative codes looked up in ``full_index``.
    coords : Coordinates, optional
        Coordinates to store; a new empty ``Coordinates`` when ``None``.

    Raises
    ------
    ValueError
        If ``codes`` has no name.
    """
    # Imported here, not at module level, as in the original code
    # (presumably to avoid a circular import -- confirm).
    from xarray.core.groupby import _codes_to_group_indices

    assert isinstance(codes, DataArray)
    if codes.name is None:
        raise ValueError("Please set a name on the array you are grouping by.")
    self.codes = codes
    assert isinstance(full_index, pd.Index)
    self.full_index = full_index

    if group_indices is None:
        self.group_indices = tuple(
            g
            for g in _codes_to_group_indices(codes.data.ravel(), len(full_index))
            if g
        )
    else:
        self.group_indices = group_indices

    if unique_coord is None:
        unique_codes = np.unique(codes)
        # Negative codes must not be used as positions: ``full_index[-1]``
        # would silently wrap around and select the *last* label, adding a
        # spurious group. (Presumably -1 is the "belongs to no group"
        # sentinel -- confirm against the groupers.)
        unique_codes = unique_codes[unique_codes >= 0]
        unique_values = full_index[unique_codes]
        self.unique_coord = Variable(
            dims=codes.name, data=unique_values, attrs=codes.attrs
        )
    else:
        self.unique_coord = unique_coord

    # Test against None rather than truthiness so that an explicitly passed
    # (possibly empty) Coordinates object is kept as given.
    self.coords = Coordinates() if coords is None else coords


class Grouper(ABC):
Expand Down

0 comments on commit 82e9514

Please sign in to comment.