diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 927c10df3c9..174c5ea3794 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -25,6 +25,7 @@
 from xarray.core.formatting import format_array_flat
 from xarray.core.indexes import (
     PandasIndex,
+    PandasMultiIndex,
     filter_indexes_from_coords,
 )
 from xarray.core.options import OPTIONS, _get_keep_attrs
@@ -518,19 +519,6 @@ def __init__(
             self.encoded = grouper.encoded
         else:
             self.encoded = ComposedGrouper(groupers).factorize()
-            grouper_dims = set(
-                itertools.chain(*tuple(grouper.group.dims for grouper in groupers))
-            )
-            grouper_names = set(grouper.group.name for grouper in groupers)
-            # Drop any non-dim coords that we are not grouping by;
-            # even if they share dims with the group variables.
-            # This mirrors our behaviour for the single grouper case.
-            to_drop = tuple(
-                name
-                for name in (set(obj.coords) - set(obj.xindexes))
-                if name not in grouper_names and set(obj[name].dims) & grouper_dims
-            )
-            obj = obj.drop_vars(to_drop)

         # specification for the groupby operation
         # TODO: handle obj having variables that are not present on any of the groupers
@@ -756,6 +744,8 @@ def _maybe_restore_empty_groups(self, combined):
     def _maybe_unstack(self, obj):
         """This gets called if we are applying on an array with a
         multidimensional group."""
+        from xarray.groupers import UniqueGrouper
+
         stacked_dim = self._stacked_dim
         if stacked_dim is not None and stacked_dim in obj.dims:
             inserted_dims = self._inserted_dims
@@ -765,9 +755,19 @@ def _maybe_unstack(self, obj):
                     del obj.coords[dim]
             obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords))
         elif len(self.groupers) > 1:
+            # TODO: we could clean this up by setting the appropriate `stacked_dim`
+            # and `inserted_dims`
             # if multiple groupers all share the same single dimension, then
             # we don't stack/unstack. Do that manually now.
             obj = obj.unstack(*self.encoded.unique_coord.dims)
+            to_drop = [
+                grouper.name
+                for grouper in self.groupers
+                if isinstance(grouper.group, _DummyGroup)
+                and isinstance(grouper.grouper, UniqueGrouper)
+            ]
+            obj = obj.drop_vars(to_drop)
+
         return obj

     def _flox_reduce(
@@ -784,6 +784,12 @@ def _flox_reduce(
         from xarray.groupers import BinGrouper

         obj = self._original_obj
+        variables = (
+            {k: v.variable for k, v in obj.data_vars.items()}
+            if isinstance(obj, Dataset)
+            else obj._coords
+        )
+
         any_isbin = any(
             isinstance(grouper.grouper, BinGrouper) for grouper in self.groupers
         )
@@ -797,12 +803,27 @@ def _flox_reduce(
             # flox >=0.9 will choose this on its own.
             kwargs.setdefault("method", "cohorts")

-        numeric_only = kwargs.pop("numeric_only", None)
-        if numeric_only:
+        midx_grouping_vars: tuple[Hashable, ...] = ()
+        for grouper in self.groupers:
+            name = grouper.name
+            maybe_midx = obj._indexes.get(name, None)
+            if isinstance(maybe_midx, PandasMultiIndex):
+                midx_grouping_vars += tuple(maybe_midx.index.names) + (name,)
+
+        # For datasets, running a numeric-only reduction on non-numeric
+        # variable will just drop it.
+        non_numeric: dict[Hashable, Variable]
+        if kwargs.pop("numeric_only", None):
             non_numeric = {
                 name: var
-                for name, var in obj.data_vars.items()
-                if not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_))
+                for name, var in variables.items()
+                if (
+                    not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_))
+                    # this avoids dropping any levels of a MultiIndex, which raises
+                    # a warning
+                    and name not in midx_grouping_vars
+                    and name not in obj.dims
+                )
             }
         else:
             non_numeric = {}
@@ -878,8 +899,8 @@ def _flox_reduce(
         group_dims: set[str | Hashable] = set(
             itertools.chain(*(grouper.group.dims for grouper in self.groupers))
         )
+        new_coords = {}
         if group_dims.issubset(set(parsed_dim)):
-            new_coords = {}
             new_indexes = {}
             for grouper in self.groupers:
                 output_index = grouper.full_index
@@ -890,12 +911,28 @@ def _flox_reduce(
                 Coordinates(new_coords, new_indexes)
             ).drop_vars(unindexed_dims)

-        # broadcast and restore non-numeric data variables (backcompat)
-        for name, var in non_numeric.items():
-            if all(d not in var.dims for d in parsed_dim):
-                result[name] = var.variable.set_dims(
-                    (name,) + var.dims, (result.sizes[name],) + var.shape
+        # broadcast any non-dim coord variables that don't
+        # share all dimensions with the grouper
+        result_variables = (
+            result._variables if isinstance(result, Dataset) else result._coords
+        )
+        to_broadcast: dict[Hashable, Variable] = {}
+        for name, var in variables.items():
+            dims_set = set(var.dims)
+            if (
+                dims_set <= set(parsed_dim)
+                and (dims_set & set(result.dims))
+                and name not in result_variables
+            ):
+                to_broadcast[name] = var
+        for name, var in to_broadcast.items():
+            if new_dims := tuple(d for d in parsed_dim if d not in var.dims):
+                new_sizes = tuple(
+                    result.sizes.get(dim, obj.sizes.get(dim)) for dim in new_dims
                 )
+                result[name] = var.set_dims(
+                    new_dims + var.dims, new_sizes + var.shape
+                ).transpose(..., *result.dims)

         if not isinstance(result, Dataset):
             # only restore dimension order for arrays
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 576c48da2c2..4c12709b511 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -11,6 +11,7 @@

 import xarray as xr
 from xarray import DataArray, Dataset, Variable
+from xarray.core.alignment import broadcast
 from xarray.core.groupby import _consolidate_slices
 from xarray.core.types import InterpOptions
 from xarray.groupers import BinGrouper, EncodedGroups, Grouper, UniqueGrouper
@@ -807,6 +808,7 @@ def test_groupby_dataset_errors() -> None:
         data.groupby(data.coords["dim1"].to_index())  # type: ignore[arg-type]


+@pytest.mark.parametrize("use_flox", [True, False])
 @pytest.mark.parametrize(
     "by_func",
     [
@@ -814,7 +816,10 @@ def test_groupby_dataset_errors() -> None:
         pytest.param(lambda x: {x: UniqueGrouper()}, id="group-by-unique-grouper"),
     ],
 )
-def test_groupby_dataset_reduce_ellipsis(by_func) -> None:
+@pytest.mark.parametrize("letters_as_coord", [True, False])
+def test_groupby_dataset_reduce_ellipsis(
+    by_func, use_flox: bool, letters_as_coord: bool
+) -> None:
     data = Dataset(
         {
             "xy": (["x", "y"], np.random.randn(3, 4)),
@@ -824,13 +829,18 @@ def test_groupby_dataset_reduce_ellipsis(by_func) -> None:
         }
     )

+    if letters_as_coord:
+        data = data.set_coords("letters")
+
     expected = data.mean("y")
     expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3})
     gb = data.groupby(by_func("x"))
-    actual = gb.mean(...)
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.mean(...)
     assert_allclose(expected, actual)

-    actual = gb.mean("y")
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.mean("y")
     assert_allclose(expected, actual)

     letters = data["letters"]
@@ -842,7 +852,8 @@ def test_groupby_dataset_reduce_ellipsis(by_func) -> None:
         }
     )
     gb = data.groupby(by_func("letters"))
-    actual = gb.mean(...)
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.mean(...)
     assert_allclose(expected, actual)


@@ -2606,10 +2617,11 @@ def test_multiple_groupers(use_flox) -> None:
     )
     assert_identical(actual, expected)

+    expected = square.astype(np.float64)
+    expected["a"], expected["b"] = broadcast(square.a, square.b)
     with xr.set_options(use_flox=use_flox):
         assert_identical(
-            square.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean(),
-            square.astype(np.float64).drop_vars(("a", "b")),
+            square.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean(), expected
         )

     b = xr.DataArray(
@@ -2617,11 +2629,10 @@ def test_multiple_groupers(use_flox) -> None:
         coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])},
         dims=["x", "y", "z"],
     )
+    gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
+    repr(gb)
     with xr.set_options(use_flox=use_flox):
-        assert_identical(
-            b.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean("z"),
-            b.mean("z"),
-        )
+        assert_identical(gb.mean("z"), b.mean("z"))

     gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper())
     repr(gb)
@@ -2651,3 +2662,24 @@ def test_multiple_groupers(use_flox) -> None:
     # 1. lambda x: x
     # 2. grouped-reduce on unique coords is identical to array
     # 3. group_over == groupby-reduce along other dimensions
+
+
+def test_weather_data_resample():
+    # from the docs
+    times = pd.date_range("2000-01-01", "2001-12-31", name="time")
+    annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28))
+
+    base = 10 + 15 * annual_cycle.reshape(-1, 1)
+    tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3)
+    tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3)
+
+    ds = xr.Dataset(
+        {
+            "tmin": (("time", "location"), tmin_values),
+            "tmax": (("time", "location"), tmax_values),
+        },
+        {"time": times, "location": ["IA", "IN", "IL"]},
+    )
+
+    actual = ds.resample(time="1MS").mean()
+    assert "location" in actual._indexes
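
Illustrative usage sketch (not part of the patch): the updated `test_multiple_groupers` expectation above asserts that broadcastable non-dimension coordinates now survive a multi-grouper reduction instead of being dropped. The snippet below mirrors that test case; it assumes a recent xarray that exposes `xarray.groupers.UniqueGrouper` and the multiple-grouper `groupby` API, and the coordinate names "a" and "b" are taken from the test fixture, not from any public contract.

    import numpy as np
    import xarray as xr
    from xarray.groupers import UniqueGrouper

    # Mirrors the "square" fixture in test_multiple_groupers: two non-dim
    # coordinates that share dims with the (dummy) group variables.
    square = xr.DataArray(
        np.arange(16, dtype=np.float64).reshape(4, 4),
        dims=["x", "y"],
        coords={"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])},
    )

    result = square.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean()
    # With this change, "a" and "b" are expected to remain on the result,
    # broadcast against its dimensions, rather than being dropped.
    print(result.coords)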