Skip to content

Commit

Permalink
Generalize Coordinates assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 16, 2024
1 parent b13497f commit 7f23eea
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ class GroupBy(Generic[T_Xarray]):
_inserted_dims: list[Hashable]

_unique_coord: Variable | _DummyGroup
coords: Coordinates | None
coords: Coordinates

_len: int

Expand Down Expand Up @@ -531,15 +531,27 @@ def __init__(
(grouper,) = self.groupers
self._group_indices = grouper.group_indices
self._unique_coord = grouper.unique_coord
self.coords = None

# TODO: clean this up by moving this logic to the Grouper Classes
if not isinstance(grouper.unique_coord, _DummyGroup):
new_index, index_vars = create_default_index_implicit(
grouper.unique_coord
)
indexes = {k: new_index for k in index_vars}
new_vars = new_index.create_variables()
self.coords = Coordinates(new_vars, indexes)
else:
self.coords = Coordinates()

else:
midx = pd.MultiIndex.from_product(
(grouper.unique_coord.data for grouper in groupers),
names=tuple(grouper.name for grouper in groupers),
)
self.coords = Coordinates.from_pandas_multiindex(midx, dim=self._group_dim)
dim_name = "stacked_" + "_".join(grouper.name for grouper in groupers)
self.coords = Coordinates.from_pandas_multiindex(midx, dim=dim_name)
self._group_indices = _codes_to_group_indices(_flatcodes.ravel(), self._len)
self._unique_coord = Variable(dims=(self._group_dim,), data=midx.values)
self._unique_coord = Variable(dims=(dim_name,), data=midx.values)

# cached attributes
self._groups = None
Expand Down Expand Up @@ -649,7 +661,6 @@ def _infer_concat_args(self, applied_example):
):
# When binning we actually do set the index
coord = None
coord = getattr(coord, "variable", coord)
return coord, dim, positions

def _binary_op(self, other, f, reflexive=False):
Expand Down Expand Up @@ -774,7 +785,7 @@ def _maybe_unstack(self, obj):
del obj.coords[dim]
obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords))
if len(self.groupers) > 1:
obj = obj.unstack(self._group_dim)
obj = obj.unstack(*self._unique_coord.dims)
return obj

def _flox_reduce(
Expand Down Expand Up @@ -1290,12 +1301,7 @@ def _combine(self, applied, shortcut=False):
combined = self._restore_dim_order(combined)
# assign coord and index when the applied function does not return that coord
if coord is not None and dim not in applied_example.dims:
if self.coords is not None:
combined = combined.assign_coords(self.coords)
else:
index, index_vars = create_default_index_implicit(coord)
indexes = {k: index for k in index_vars}
combined = combined._overwrite_indexes(indexes, index_vars)
combined = combined.assign_coords(self.coords)
combined = self._maybe_unstack(combined)
combined = self._maybe_restore_empty_groups(combined)
return combined
Expand Down

0 comments on commit 7f23eea

Please sign in to comment.