Better broadcasting of unreduced variables.
dcherian committed Aug 20, 2024
1 parent 1e6acbe commit bcecb15
Showing 2 changed files with 102 additions and 33 deletions.
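The user-visible change, mirrored by the test updates below: when reducing with multiple groupers, non-dim coordinate variables are no longer dropped up front; any that survive the reduction are broadcast onto the result. A minimal sketch of the new behaviour, assuming a small 2-D array analogous to the `square` fixture in `test_multiple_groupers` (names and values here are illustrative):

```python
import numpy as np
import xarray as xr
from xarray.groupers import UniqueGrouper

# Hypothetical stand-in for the `square` array used in test_multiple_groupers.
square = xr.DataArray(
    np.arange(4.0).reshape(2, 2),
    dims=["x", "y"],
    coords={
        "x": [0, 1],
        "y": [0, 1],
        "a": ("x", [10, 20]),  # non-dim coords sharing dims with the groupers
        "b": ("y", [30, 40]),
    },
)

result = square.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean()
# Before this commit, "a" and "b" were dropped; now they are kept and
# broadcast, matching xr.broadcast(square.a, square.b).
assert "a" in result.coords and "b" in result.coords
```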
83 changes: 60 additions & 23 deletions xarray/core/groupby.py
@@ -25,6 +25,7 @@
 from xarray.core.formatting import format_array_flat
 from xarray.core.indexes import (
     PandasIndex,
+    PandasMultiIndex,
     filter_indexes_from_coords,
 )
 from xarray.core.options import OPTIONS, _get_keep_attrs
@@ -518,19 +519,6 @@ def __init__(
             self.encoded = grouper.encoded
         else:
             self.encoded = ComposedGrouper(groupers).factorize()
-            grouper_dims = set(
-                itertools.chain(*tuple(grouper.group.dims for grouper in groupers))
-            )
-            grouper_names = set(grouper.group.name for grouper in groupers)
-            # Drop any non-dim coords that we are not grouping by;
-            # even if they share dims with the group variables.
-            # This mirrors our behaviour for the single grouper case.
-            to_drop = tuple(
-                name
-                for name in (set(obj.coords) - set(obj.xindexes))
-                if name not in grouper_names and set(obj[name].dims) & grouper_dims
-            )
-            obj = obj.drop_vars(to_drop)
 
         # specification for the groupby operation
         # TODO: handle obj having variables that are not present on any of the groupers
@@ -756,6 +744,8 @@ def _maybe_restore_empty_groups(self, combined):
     def _maybe_unstack(self, obj):
         """This gets called if we are applying on an array with a
         multidimensional group."""
+        from xarray.groupers import UniqueGrouper
+
         stacked_dim = self._stacked_dim
         if stacked_dim is not None and stacked_dim in obj.dims:
             inserted_dims = self._inserted_dims
@@ -765,9 +755,19 @@ def _maybe_unstack(self, obj):
                 del obj.coords[dim]
             obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords))
         elif len(self.groupers) > 1:
             # TODO: we could clean this up by setting the appropriate `stacked_dim`
             #       and `inserted_dims`
+            # if multiple groupers all share the same single dimension, then
+            # we don't stack/unstack. Do that manually now.
             obj = obj.unstack(*self.encoded.unique_coord.dims)
+            to_drop = [
+                grouper.name
+                for grouper in self.groupers
+                if isinstance(grouper.group, _DummyGroup)
+                and isinstance(grouper.grouper, UniqueGrouper)
+            ]
+            obj = obj.drop_vars(to_drop)
         return obj
 
     def _flox_reduce(
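For context on why the dummy-group names are dropped: grouping an object directly over its own unlabelled dimensions wraps each dimension in a `_DummyGroup`, and the stack/apply/unstack round-trip would otherwise leave synthetic coordinates behind. A sketch built from the `b` array in the updated tests:

```python
import numpy as np
import xarray as xr
from xarray.groupers import UniqueGrouper

# b has no coordinate labels along x/y, so each grouper wraps a _DummyGroup.
b = xr.DataArray(
    np.random.RandomState(0).randn(2, 3, 4),
    coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])},
    dims=["x", "y", "z"],
)
gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
# After the stack/apply/unstack round-trip, the synthetic x/y coordinates
# created for the dummy groups are dropped again, so the reduction
# round-trips cleanly:
xr.testing.assert_identical(gb.mean("z"), b.mean("z"))
```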
@@ -784,6 +784,12 @@
         from xarray.groupers import BinGrouper
 
         obj = self._original_obj
+        variables = (
+            {k: v.variable for k, v in obj.data_vars.items()}
+            if isinstance(obj, Dataset)
+            else obj._coords
+        )
+
         any_isbin = any(
             isinstance(grouper.grouper, BinGrouper) for grouper in self.groupers
         )
@@ -797,12 +803,27 @@
             # flox >=0.9 will choose this on its own.
             kwargs.setdefault("method", "cohorts")
 
-        numeric_only = kwargs.pop("numeric_only", None)
-        if numeric_only:
+        midx_grouping_vars: tuple[Hashable, ...] = ()
+        for grouper in self.groupers:
+            name = grouper.name
+            maybe_midx = obj._indexes.get(name, None)
+            if isinstance(maybe_midx, PandasMultiIndex):
+                midx_grouping_vars += tuple(maybe_midx.index.names) + (name,)
+
+        # For datasets, running a numeric-only reduction on a non-numeric
+        # variable will just drop it.
+        non_numeric: dict[Hashable, Variable]
+        if kwargs.pop("numeric_only", None):
             non_numeric = {
                 name: var
-                for name, var in obj.data_vars.items()
-                if not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_))
+                for name, var in variables.items()
+                if (
+                    not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_))
+                    # this avoids dropping any levels of a MultiIndex, which raises
+                    # a warning
+                    and name not in midx_grouping_vars
+                    and name not in obj.dims
+                )
             }
         else:
             non_numeric = {}
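At the user level, the `numeric_only` path corresponds to the long-standing Dataset behaviour of silently dropping non-numeric variables from reductions. A minimal sketch with hypothetical data:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "temp": ("x", np.array([1.0, 2.0, 3.0, 4.0])),
        "label": ("x", np.array(["a", "b", "a", "b"], dtype=object)),
    },
    coords={"x": [0, 0, 1, 1]},
)
# Dataset reductions like mean() default to numeric_only=True, so the
# string variable "label" is dropped rather than passed to the grouped
# reduction.
result = ds.groupby("x").mean()
assert "label" not in result.data_vars
```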
@@ -878,8 +899,8 @@ def _flox_reduce(
         group_dims: set[str | Hashable] = set(
             itertools.chain(*(grouper.group.dims for grouper in self.groupers))
         )
+        new_coords = {}
         if group_dims.issubset(set(parsed_dim)):
-            new_coords = {}
             new_indexes = {}
             for grouper in self.groupers:
                 output_index = grouper.full_index
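One user-visible consequence of reindexing to `grouper.full_index`: groups that end up empty reappear in the output instead of vanishing. A hedged sketch with illustrative bins:

```python
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, 2.0, 4.0], dims="x", coords={"x": [0, 1, 2]})
result = da.groupby_bins("x", bins=[-1, 1, 4, 7]).mean()
# The (4, 7] bin has no members; reindexing to the grouper's full_index
# restores it as NaN instead of dropping it.
assert result.sizes["x_bins"] == 3
```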
@@ -890,12 +911,28 @@
                     Coordinates(new_coords, new_indexes)
                 ).drop_vars(unindexed_dims)
 
-            # broadcast and restore non-numeric data variables (backcompat)
-            for name, var in non_numeric.items():
-                if all(d not in var.dims for d in parsed_dim):
-                    result[name] = var.variable.set_dims(
-                        (name,) + var.dims, (result.sizes[name],) + var.shape
-                    )
+            # broadcast any non-dim coord variables that don't
+            # share all dimensions with the grouper
+            result_variables = (
+                result._variables if isinstance(result, Dataset) else result._coords
+            )
+            to_broadcast: dict[Hashable, Variable] = {}
+            for name, var in variables.items():
+                dims_set = set(var.dims)
+                if (
+                    dims_set <= set(parsed_dim)
+                    and (dims_set & set(result.dims))
+                    and name not in result_variables
+                ):
+                    to_broadcast[name] = var
+            for name, var in to_broadcast.items():
+                if new_dims := tuple(d for d in parsed_dim if d not in var.dims):
+                    new_sizes = tuple(
+                        result.sizes.get(dim, obj.sizes.get(dim)) for dim in new_dims
+                    )
+                    result[name] = var.set_dims(
+                        new_dims + var.dims, new_sizes + var.shape
+                    ).transpose(..., *result.dims)
 
         if not isinstance(result, Dataset):
             # only restore dimension order for arrays
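The broadcast step above leans on `Variable.set_dims`, which prepends new dimensions of the requested sizes. A standalone sketch of that primitive (names are illustrative):

```python
import numpy as np
from xarray import Variable

var = Variable(("y",), np.array([1.0, 2.0, 3.0]))
# Prepend a new "x" dimension of size 2; set_dims broadcasts the data to
# the requested shape (a read-only view where possible, not a copy).
expanded = var.set_dims(("x",) + var.dims, (2,) + var.shape)
assert expanded.dims == ("x", "y")
assert expanded.shape == (2, 3)
```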
52 changes: 42 additions & 10 deletions xarray/tests/test_groupby.py
@@ -11,6 +11,7 @@
 
 import xarray as xr
 from xarray import DataArray, Dataset, Variable
+from xarray.core.alignment import broadcast
 from xarray.core.groupby import _consolidate_slices
 from xarray.core.types import InterpOptions
 from xarray.groupers import BinGrouper, EncodedGroups, Grouper, UniqueGrouper
@@ -807,14 +808,18 @@ def test_groupby_dataset_errors() -> None:
         data.groupby(data.coords["dim1"].to_index())  # type: ignore[arg-type]
 
 
+@pytest.mark.parametrize("use_flox", [True, False])
 @pytest.mark.parametrize(
     "by_func",
     [
         pytest.param(lambda x: x, id="group-by-string"),
         pytest.param(lambda x: {x: UniqueGrouper()}, id="group-by-unique-grouper"),
     ],
 )
-def test_groupby_dataset_reduce_ellipsis(by_func) -> None:
+@pytest.mark.parametrize("letters_as_coord", [True, False])
+def test_groupby_dataset_reduce_ellipsis(
+    by_func, use_flox: bool, letters_as_coord: bool
+) -> None:
     data = Dataset(
         {
             "xy": (["x", "y"], np.random.randn(3, 4)),
@@ -824,13 +829,18 @@ def test_groupby_dataset_reduce_ellipsis(by_func) -> None:
         }
     )
 
+    if letters_as_coord:
+        data = data.set_coords("letters")
+
     expected = data.mean("y")
     expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3})
     gb = data.groupby(by_func("x"))
-    actual = gb.mean(...)
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.mean(...)
     assert_allclose(expected, actual)
 
-    actual = gb.mean("y")
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.mean("y")
     assert_allclose(expected, actual)
 
     letters = data["letters"]
@@ -842,7 +852,8 @@ def test_groupby_dataset_reduce_ellipsis(by_func) -> None:
         }
     )
     gb = data.groupby(by_func("letters"))
-    actual = gb.mean(...)
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.mean(...)
     assert_allclose(expected, actual)
 
 
@@ -2606,22 +2617,22 @@ def test_multiple_groupers(use_flox) -> None:
     )
     assert_identical(actual, expected)
 
+    expected = square.astype(np.float64)
+    expected["a"], expected["b"] = broadcast(square.a, square.b)
     with xr.set_options(use_flox=use_flox):
         assert_identical(
-            square.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean(),
-            square.astype(np.float64).drop_vars(("a", "b")),
+            square.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean(), expected
         )
 
     b = xr.DataArray(
         np.random.RandomState(0).randn(2, 3, 4),
         coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])},
         dims=["x", "y", "z"],
     )
+    gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
+    repr(gb)
     with xr.set_options(use_flox=use_flox):
-        assert_identical(
-            b.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean("z"),
-            b.mean("z"),
-        )
+        assert_identical(gb.mean("z"), b.mean("z"))
 
     gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper())
     repr(gb)
@@ -2651,3 +2662,24 @@ def test_multiple_groupers(use_flox) -> None:
     # 1. lambda x: x
     # 2. grouped-reduce on unique coords is identical to array
     # 3. group_over == groupby-reduce along other dimensions
+
+
+def test_weather_data_resample():
+    # from the docs
+    times = pd.date_range("2000-01-01", "2001-12-31", name="time")
+    annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28))
+
+    base = 10 + 15 * annual_cycle.reshape(-1, 1)
+    tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3)
+    tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3)
+
+    ds = xr.Dataset(
+        {
+            "tmin": (("time", "location"), tmin_values),
+            "tmax": (("time", "location"), tmax_values),
+        },
+        {"time": times, "location": ["IA", "IN", "IL"]},
+    )
+
+    actual = ds.resample(time="1MS").mean()
+    assert "location" in actual._indexes
