From ba53ae2a75aa1a17cb5827b711963bb2764e3dc2 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Thu, 25 Jan 2024 11:02:26 -0500 Subject: [PATCH 1/4] test automatic broadcasting --- xarray/tests/test_array_api.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index fddaa120970..d47995745e0 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -77,6 +77,21 @@ def test_broadcast(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(a, e) +def test_broadcast_during_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + np_arr2 = xr.DataArray(np.array([1.0, 2.0]), dims="x") + xp_arr2 = xr.DataArray(xp.asarray([1.0, 2.0]), dims="x") + + expected = np_arr * np_arr2 + actual = xp_arr * xp_arr2 + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + expected = np_arr2 * np_arr + actual = xp_arr2 * xp_arr + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + def test_concat(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays expected = xr.concat((np_arr, np_arr), dim="x") From 1d0171e00eaa95bd44b6cade39395f71ed80d2a7 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Thu, 25 Jan 2024 11:02:36 -0500 Subject: [PATCH 2/4] fix bug --- xarray/core/variable.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 3add7a1441e..b3389a83673 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1476,7 +1476,8 @@ def set_dims(self, dims, shape=None): tmp_shape = tuple(dims_map[d] for d in expanded_dims) expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape) else: - expanded_data = self.data[(None,) * (len(expanded_dims) - self.ndim)] + indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) + expanded_data = self.data[indexer] expanded_var = Variable( expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True From 614f140f756e2ed45a99da50be7e57893153fcd7 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Thu, 25 Jan 2024 11:06:31 -0500 Subject: [PATCH 3/4] whatsnew --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ac0015c14c5..1aeff9f5b41 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Fix bug with broadcasting when wrapping array API-compliant classes. (:issue:`8665`, :pull:`8669`) + By `Tom Nicholas `_. Documentation ~~~~~~~~~~~~~ From ffa64cfde28d9fe345f4f55ad9747a160eec5e46 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jan 2024 16:06:51 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/tests/test_array_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index d47995745e0..5af9f076df2 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -92,6 +92,7 @@ def test_broadcast_during_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) assert isinstance(actual.data, Array) assert_equal(actual, expected) + def test_concat(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays expected = xr.concat((np_arr, np_arr), dim="x")