diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 151ed9da105..8865eb98481 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,6 +23,9 @@ v2024.02.0 (unreleased) New Features ~~~~~~~~~~~~ +- Xarray now defers to flox's `heuristics `_ + to set default `method` for groupby problems. This only applies to ``flox>=0.9``. + By `Deepak Cherian `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ebb488d42c9..3a9c3ee6470 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd +from packaging.version import Version from xarray.core import dtypes, duck_array_ops, nputils, ops from xarray.core._aggregations import ( @@ -1003,6 +1004,7 @@ def _flox_reduce( **kwargs: Any, ): """Adaptor function that translates our groupby API to that of flox.""" + import flox from flox.xarray import xarray_reduce from xarray.core.dataset import Dataset @@ -1014,10 +1016,11 @@ def _flox_reduce( if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - # preserve current strategy (approximately) for dask groupby. - # We want to control the default anyway to prevent surprises - # if flox decides to change its default - kwargs.setdefault("method", "cohorts") + if Version(flox.__version__) < Version("0.9"): + # preserve current strategy (approximately) for dask groupby + # on older flox versions to prevent surprises. + # flox >=0.9 will choose this on its own. + kwargs.setdefault("method", "cohorts") numeric_only = kwargs.pop("numeric_only", None) if numeric_only: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e45d8ed0bef..25fabd5e2b9 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3,10 +3,12 @@ import datetime import operator import warnings +from unittest import mock import numpy as np import pandas as pd import pytest +from packaging.version import Version import xarray as xr from xarray import DataArray, Dataset, Variable @@ -2499,3 +2501,20 @@ def test_groupby_dim_no_dim_equal(use_flox): actual1 = da.drop_vars("lat").groupby("lat", squeeze=False).sum() actual2 = da.groupby("lat", squeeze=False).sum() assert_identical(actual1, actual2.drop_vars("lat")) + + +@requires_flox +def test_default_flox_method(): + import flox.xarray + + da = xr.DataArray([1, 2, 3], dims="x", coords={"label": ("x", [2, 2, 1])}) + + result = xr.DataArray([3, 3], dims="label", coords={"label": [1, 2]}) + with mock.patch("flox.xarray.xarray_reduce", return_value=result) as mocked_reduce: + da.groupby("label").sum() + + kwargs = mocked_reduce.call_args.kwargs + if Version(flox.__version__) < Version("0.9.0"): + assert kwargs["method"] == "cohorts" + else: + assert "method" not in kwargs