From cd9061827805d2a245a1c2e790a000a9815e06d9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 22:22:49 +0200 Subject: [PATCH] add more statistical functions --- .../_array_api/statistical_functions.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/xarray/namedarray/_array_api/statistical_functions.py b/xarray/namedarray/_array_api/statistical_functions.py index 8aa1db92a7f..d99f0f812ef 100644 --- a/xarray/namedarray/_array_api/statistical_functions.py +++ b/xarray/namedarray/_array_api/statistical_functions.py @@ -17,6 +17,23 @@ ) +def max( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.max(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + def mean( x: NamedArray[Any, _DType], /, @@ -82,3 +99,54 @@ def mean( dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out + + +def min( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.min(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +def prod( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.prod(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +def sum( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.sum(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out