From d9931ef78879352e35b8d8c7afa716678f1adae4 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Tue, 6 Feb 2024 04:39:21 -0500 Subject: [PATCH] use array namespace again --- xarray/core/duck_array_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 91161bd3510..035255aa619 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -331,7 +331,7 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition) - return xp.where(condition, *as_shared_dtype([x, y])) + return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) def where_method(data, cond, other=dtypes.NA): @@ -354,14 +354,14 @@ def concatenate(arrays, axis=0): arrays[0], np.ndarray ): xp = get_array_namespace(arrays[0]) - return xp.concat(as_shared_dtype(arrays), axis=axis) + return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) return _concatenate(as_shared_dtype(arrays), axis=axis) def stack(arrays, axis=0): """stack() with better dtype promotion rules.""" xp = get_array_namespace(arrays[0]) - return xp.stack(as_shared_dtype(arrays), axis=axis) + return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis) def reshape(array, shape):