Skip to content

Commit

Permalink
Allow grouping by multiple nD arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 19, 2024
1 parent 70e43a5 commit 1e6acbe
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
5 changes: 0 additions & 5 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,11 +517,6 @@ def __init__(
(grouper,) = groupers
self.encoded = grouper.encoded
else:
for grouper in groupers:
if grouper.group.ndim > 1:
raise NotImplementedError(
"Only grouping by multiple 1D variables is supported at the moment."
)
self.encoded = ComposedGrouper(groupers).factorize()
grouper_dims = set(
itertools.chain(*tuple(grouper.group.dims for grouper in groupers))
Expand Down
29 changes: 29 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2612,6 +2612,35 @@ def test_multiple_groupers(use_flox) -> None:
square.astype(np.float64).drop_vars(("a", "b")),
)

b = xr.DataArray(
np.random.RandomState(0).randn(2, 3, 4),
coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])},
dims=["x", "y", "z"],
)
with xr.set_options(use_flox=use_flox):
assert_identical(
b.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean("z"),
b.mean("z"),
)

gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper())
repr(gb)
with xr.set_options(use_flox=use_flox):
actual = gb.mean()
expected = b.drop_vars("xy").rename({"y": "xy"}).copy(deep=True)
newval = b.isel(x=1, y=slice(1, None)).mean("y").data
expected.loc[dict(x=1, xy=1)] = expected.sel(x=1, xy=0).data
expected.loc[dict(x=1, xy=0)] = np.nan
expected.loc[dict(x=1, xy=2)] = newval
expected["xy"] = ("xy", ["a", "b", "c"])
# TODO: is order of dims correct?
assert_identical(actual, expected.transpose("z", "x", "xy"))

# assert_identical(
# b.groupby(['x', 'y']).apply(lambda x: x - x.mean()),
# b - b.mean("z"),
# )

# gb = square.groupby(x=UniqueGrouper(), y=UniqueGrouper())
# gb - gb.mean()

Expand Down

0 comments on commit 1e6acbe

Please sign in to comment.