Skip to content

Commit

Permalink
add more statistical functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Aug 19, 2024
1 parent 1855f7f commit cd90618
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions xarray/namedarray/_array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@
)


def max(
x: NamedArray[Any, _DType],
/,
*,
dims: _Dims | Default = _default,
keepdims: bool = False,
axis: _AxisLike | None = None,
) -> NamedArray[Any, _DType]:
xp = _get_data_namespace(x)
axis_ = _dims_to_axis(x, dims, axis)
d = xp.max(x._data, axis=axis_, keepdims=False) # We fix keepdims later
# TODO: Why do we need to do the keepdims ourselves?
dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out


def mean(
x: NamedArray[Any, _DType],
/,
Expand Down Expand Up @@ -82,3 +99,54 @@ def mean(
dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out


def min(
x: NamedArray[Any, _DType],
/,
*,
dims: _Dims | Default = _default,
keepdims: bool = False,
axis: _AxisLike | None = None,
) -> NamedArray[Any, _DType]:
xp = _get_data_namespace(x)
axis_ = _dims_to_axis(x, dims, axis)
d = xp.min(x._data, axis=axis_, keepdims=False) # We fix keepdims later
# TODO: Why do we need to do the keepdims ourselves?
dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out


def prod(
x: NamedArray[Any, _DType],
/,
*,
dims: _Dims | Default = _default,
keepdims: bool = False,
axis: _AxisLike | None = None,
) -> NamedArray[Any, _DType]:
xp = _get_data_namespace(x)
axis_ = _dims_to_axis(x, dims, axis)
d = xp.prod(x._data, axis=axis_, keepdims=False) # We fix keepdims later
# TODO: Why do we need to do the keepdims ourselves?
dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out


def sum(
x: NamedArray[Any, _DType],
/,
*,
dims: _Dims | Default = _default,
keepdims: bool = False,
axis: _AxisLike | None = None,
) -> NamedArray[Any, _DType]:
xp = _get_data_namespace(x)
axis_ = _dims_to_axis(x, dims, axis)
d = xp.sum(x._data, axis=axis_, keepdims=False) # We fix keepdims later
# TODO: Why do we need to do the keepdims ourselves?
dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out

0 comments on commit cd90618

Please sign in to comment.