From beb7a350cacb4cec2f1eb34a575250373aeeeea2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 12:12:36 +0200 Subject: [PATCH] Add manip functions --- xarray/namedarray/_array_api/__init__.py | 28 ++-- .../_array_api/_manipulation_functions.py | 126 ++++++++++++++++-- xarray/namedarray/_array_api/_utils.py | 9 ++ 3 files changed, 136 insertions(+), 27 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 8167d6adc8f..c5afda1b329 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -259,30 +259,30 @@ ] from xarray.namedarray._array_api._manipulation_functions import ( - # broadcast_arrays, - # broadcast_to, - # concat, + broadcast_arrays, + broadcast_to, + concat, expand_dims, - # flip, - # moveaxis, + flip, + moveaxis, permute_dims, reshape, - # roll, - # squeeze, + roll, + squeeze, stack, ) __all__ += [ - # "broadcast_arrays", - # "broadcast_to", - # "concat", + "broadcast_arrays", + "broadcast_to", + "concat", "expand_dims", - # "flip", - # "moveaxis", + "flip", + "moveaxis", "permute_dims", "reshape", - # "roll", - # "squeeze", + "roll", + "squeeze", "stack", ] diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 662911cbbf6..977a45f41da 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -3,7 +3,12 @@ from typing import Any from xarray.namedarray._array_api._creation_functions import asarray -from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._array_api._data_type_functions import result_type +from xarray.namedarray._array_api._utils import ( + _get_data_namespace, + _infer_dims, + _insert_dim, +) from xarray.namedarray._typing import ( Default, _arrayapi, @@ -11,12 +16,41 @@ _Axis, _default, _Dim, + _Dims, _DType, _Shape, ) -from xarray.namedarray.core import ( - NamedArray, -) +from xarray.namedarray.core import NamedArray + + +def broadcast_arrays(*arrays: NamedArray) -> list[NamedArray]: + x = arrays[0] + xp = _get_data_namespace(x) + _arrays = tuple(a._data for a in arrays) + _datas = xp.broadcast_arrays(_arrays) + out = [] + for _data in _datas: + _dims = _infer_dims(_data) # TODO: Fix dims + out.append(x._new(_dims, _data)) + return out + + +def broadcast_to(x: NamedArray, /, shape: _Shape) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.broadcast_to(x._data, shape=shape) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + +def concat( + arrays: tuple[NamedArray, ...] | list[NamedArray], /, *, axis: _Axis | None = 0 +) -> NamedArray: + xp = _get_data_namespace(arrays[0]) + dtype = result_type(*arrays) + arrays = tuple(a._data for a in arrays) + _data = xp.concat(arrays, axis=axis, dtype=dtype) + _dims = _infer_dims(_data) + return NamedArray(_dims, _data) def expand_dims( @@ -57,13 +91,23 @@ def expand_dims( [3., 4.]]]) """ xp = _get_data_namespace(x) - dims = x.dims - if dim is _default: - dim = f"dim_{len(dims)}" - d = list(dims) - d.insert(axis, dim) - out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) - return out + _data = xp.expand_dims(x._data, axis=axis) + _dims = _insert_dim(x.dims, dim, axis) + return x._new(_dims, _data) + + +def flip(x: NamedArray, /, *, axis: _Axes | None = None) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.flip(x._data, axis=axis) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + +def moveaxis(x: NamedArray, source: _Axes, destination: _Axes, /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.moveaxis(x._data, source=source, destination=destination) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]: @@ -95,6 +139,19 @@ def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DT return out +def repeat( + x: NamedArray, + repeats: int | NamedArray, + /, + *, + axis: _Axis | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.repeat(x._data, repeats, axis=axis) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + def reshape(x, /, shape: _Shape, *, copy: bool | None = None): xp = _get_data_namespace(x) _data = xp.reshape(x._data, shape) @@ -105,5 +162,48 @@ def reshape(x, /, shape: _Shape, *, copy: bool | None = None): return out -def stack(arrays, /, *, axis=0): - raise NotImplementedError("TODO:") +def roll( + x: NamedArray, + /, + shift: int | tuple[int, ...], + *, + axis: _Axes | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.roll(x._data, shift=shift, axis=axis) + return x._new(_data) + + +def squeeze(x: NamedArray, /, axis: _Axes) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.squeeze(x._data, axis=axis) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + +def stack( + arrays: tuple[NamedArray, ...] | list[NamedArray], /, *, axis: _Axis = 0 +) -> NamedArray: + x = arrays[0] + xp = _get_data_namespace(x) + arrays = tuple(a._data for a in arrays) + _data = xp.stack(arrays, axis=axis) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + +def tile(x: NamedArray, repetitions: tuple[int, ...], /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.tile(x._data, repetitions) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + +def unstack(x: NamedArray, /, *, axis: _Axis = 0) -> tuple[NamedArray, ...]: + xp = _get_data_namespace(x) + _datas = xp.unstack(x._data, axis=axis) + out = () + for _data in _datas: + _dims = _infer_dims(_data) # TODO: Fix dims + out += (x._new(_dims, _data),) + return out diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index eba610ec5c1..77afdeb07cd 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -14,6 +14,7 @@ _DimsLike, _DType, _dtype, + _Axis, _Shape, duckarray, ) @@ -162,3 +163,11 @@ def _get_remaining_dims( dims = tuple(adim for n, adim in enumerate(x.dims) if n not in removed_axes) return dims, data + + +def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims: + if dim is _default: + dim = f"dim_{len(dims)}" + d = list(dims) + d.insert(axis, dim) + return tuple(d)