diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 35cf0c548e2..151ed9da105 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Fix bug with broadcasting when wrapping array API-compliant classes. (:issue:`8665`, :pull:`8669`) + By `Tom Nicholas `_. - Ensure :py:meth:`DataArray.unstack` works when wrapping array API-compliant classes. (:issue:`8666`, :pull:`8668`) By `Tom Nicholas `_. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index e4add9f838e..119495a486a 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1476,7 +1476,8 @@ def set_dims(self, dims, shape=None): tmp_shape = tuple(dims_map[d] for d in expanded_dims) expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape) else: - expanded_data = self.data[(None,) * (len(expanded_dims) - self.ndim)] + indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) + expanded_data = self.data[indexer] expanded_var = Variable( expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index fea36d9aca4..03dcfd9b20f 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -77,6 +77,22 @@ def test_broadcast(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(a, e) +def test_broadcast_during_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + np_arr2 = xr.DataArray(np.array([1.0, 2.0]), dims="x") + xp_arr2 = xr.DataArray(xp.asarray([1.0, 2.0]), dims="x") + + expected = np_arr * np_arr2 + actual = xp_arr * xp_arr2 + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + expected = np_arr2 * np_arr + actual = xp_arr2 * xp_arr + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + def test_concat(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays expected = xr.concat((np_arr, np_arr), dim="x")