Skip to content

Commit

Permalink
Add manip functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Aug 25, 2024
1 parent 9618f2a commit beb7a35
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 27 deletions.
28 changes: 14 additions & 14 deletions xarray/namedarray/_array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,30 +259,30 @@
]

from xarray.namedarray._array_api._manipulation_functions import (
# broadcast_arrays,
# broadcast_to,
# concat,
broadcast_arrays,
broadcast_to,
concat,
expand_dims,
# flip,
# moveaxis,
flip,
moveaxis,
permute_dims,
reshape,
# roll,
# squeeze,
roll,
squeeze,
stack,
)

__all__ += [
# "broadcast_arrays",
# "broadcast_to",
# "concat",
"broadcast_arrays",
"broadcast_to",
"concat",
"expand_dims",
# "flip",
# "moveaxis",
"flip",
"moveaxis",
"permute_dims",
"reshape",
# "roll",
# "squeeze",
"roll",
"squeeze",
"stack",
]

Expand Down
126 changes: 113 additions & 13 deletions xarray/namedarray/_array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,54 @@
from typing import Any

from xarray.namedarray._array_api._creation_functions import asarray
from xarray.namedarray._array_api._utils import _get_data_namespace
from xarray.namedarray._array_api._data_type_functions import result_type
from xarray.namedarray._array_api._utils import (
_get_data_namespace,
_infer_dims,
_insert_dim,
)
from xarray.namedarray._typing import (
Default,
_arrayapi,
_Axes,
_Axis,
_default,
_Dim,
_Dims,
_DType,
_Shape,
)
from xarray.namedarray.core import (
NamedArray,
)
from xarray.namedarray.core import NamedArray


def broadcast_arrays(*arrays: NamedArray) -> list[NamedArray]:
x = arrays[0]
xp = _get_data_namespace(x)
_arrays = tuple(a._data for a in arrays)
_datas = xp.broadcast_arrays(_arrays)
out = []
for _data in _datas:
_dims = _infer_dims(_data) # TODO: Fix dims
out.append(x._new(_dims, _data))
return out


def broadcast_to(x: NamedArray, /, shape: _Shape) -> NamedArray:
xp = _get_data_namespace(x)
_data = xp.broadcast_to(x._data, shape=shape)
_dims = _infer_dims(_data) # TODO: Fix dims
return x._new(_dims, _data)


def concat(
arrays: tuple[NamedArray, ...] | list[NamedArray], /, *, axis: _Axis | None = 0
) -> NamedArray:
xp = _get_data_namespace(arrays[0])
dtype = result_type(*arrays)
arrays = tuple(a._data for a in arrays)
_data = xp.concat(arrays, axis=axis, dtype=dtype)
_dims = _infer_dims(_data)
return NamedArray(_dims, _data)


def expand_dims(
Expand Down Expand Up @@ -57,13 +91,23 @@ def expand_dims(
[3., 4.]]])
"""
xp = _get_data_namespace(x)
dims = x.dims
if dim is _default:
dim = f"dim_{len(dims)}"
d = list(dims)
d.insert(axis, dim)
out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis))
return out
_data = xp.expand_dims(x._data, axis=axis)
_dims = _insert_dim(x.dims, dim, axis)
return x._new(_dims, _data)


def flip(x: NamedArray, /, *, axis: _Axes | None = None) -> NamedArray:
xp = _get_data_namespace(x)
_data = xp.flip(x._data, axis=axis)
_dims = _infer_dims(_data) # TODO: Fix dims
return x._new(_dims, _data)


def moveaxis(x: NamedArray, source: _Axes, destination: _Axes, /) -> NamedArray:
xp = _get_data_namespace(x)
_data = xp.moveaxis(x._data, source=source, destination=destination)
_dims = _infer_dims(_data) # TODO: Fix dims
return x._new(_dims, _data)


def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]:
Expand Down Expand Up @@ -95,6 +139,19 @@ def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DT
return out


def repeat(
x: NamedArray,
repeats: int | NamedArray,
/,
*,
axis: _Axis | None = None,
) -> NamedArray:
xp = _get_data_namespace(x)
_data = xp.repeat(x._data, repeats, axis=axis)
_dims = _infer_dims(_data) # TODO: Fix dims
return x._new(_dims, _data)


def reshape(x, /, shape: _Shape, *, copy: bool | None = None):
xp = _get_data_namespace(x)
_data = xp.reshape(x._data, shape)
Expand All @@ -105,5 +162,48 @@ def reshape(x, /, shape: _Shape, *, copy: bool | None = None):
return out


def stack(arrays, /, *, axis=0):
raise NotImplementedError("TODO:")
def roll(
x: NamedArray,
/,
shift: int | tuple[int, ...],
*,
axis: _Axes | None = None,
) -> NamedArray:
xp = _get_data_namespace(x)
_data = xp.roll(x._data, shift=shift, axis=axis)
return x._new(_data)


def squeeze(x: NamedArray, /, axis: _Axes) -> NamedArray:
xp = _get_data_namespace(x)
_data = xp.squeeze(x._data, axis=axis)
_dims = _infer_dims(_data) # TODO: Fix dims
return x._new(_dims, _data)


def stack(
arrays: tuple[NamedArray, ...] | list[NamedArray], /, *, axis: _Axis = 0
) -> NamedArray:
x = arrays[0]
xp = _get_data_namespace(x)
arrays = tuple(a._data for a in arrays)
_data = xp.stack(arrays, axis=axis)
_dims = _infer_dims(_data) # TODO: Fix dims
return x._new(_dims, _data)


def tile(x: NamedArray, repetitions: tuple[int, ...], /) -> NamedArray:
xp = _get_data_namespace(x)
_data = xp.tile(x._data, repetitions)
_dims = _infer_dims(_data) # TODO: Fix dims
return x._new(_dims, _data)


def unstack(x: NamedArray, /, *, axis: _Axis = 0) -> tuple[NamedArray, ...]:
xp = _get_data_namespace(x)
_datas = xp.unstack(x._data, axis=axis)
out = ()
for _data in _datas:
_dims = _infer_dims(_data) # TODO: Fix dims
out += (x._new(_dims, _data),)
return out
9 changes: 9 additions & 0 deletions xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_DimsLike,
_DType,
_dtype,
_Axis,
_Shape,
duckarray,
)
Expand Down Expand Up @@ -162,3 +163,11 @@ def _get_remaining_dims(
dims = tuple(adim for n, adim in enumerate(x.dims) if n not in removed_axes)

return dims, data


def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims:
if dim is _default:
dim = f"dim_{len(dims)}"
d = list(dims)
d.insert(axis, dim)
return tuple(d)

0 comments on commit beb7a35

Please sign in to comment.