Skip to content

Commit

Permalink
use array namespace again
Browse files Browse the repository at this point in the history
  • Loading branch information
TomNicholas committed Feb 6, 2024
1 parent 1467c4c commit d9931ef
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def sum_where(data, axis=None, dtype=None, where=None):
def where(condition, x, y):
"""Three argument where() with better dtype promotion rules."""
xp = get_array_namespace(condition)
return xp.where(condition, *as_shared_dtype([x, y]))
return xp.where(condition, *as_shared_dtype([x, y], xp=xp))


def where_method(data, cond, other=dtypes.NA):
Expand All @@ -354,14 +354,14 @@ def concatenate(arrays, axis=0):
arrays[0], np.ndarray
):
xp = get_array_namespace(arrays[0])
return xp.concat(as_shared_dtype(arrays), axis=axis)
return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis)
return _concatenate(as_shared_dtype(arrays), axis=axis)


def stack(arrays, axis=0):
"""stack() with better dtype promotion rules."""
xp = get_array_namespace(arrays[0])
return xp.stack(as_shared_dtype(arrays), axis=axis)
return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis)


def reshape(array, shape):
Expand Down

0 comments on commit d9931ef

Please sign in to comment.