diff --git a/xarray/namedarray/_array_api/statistical_functions.py b/xarray/namedarray/_array_api/statistical_functions.py index 8aa1db92a7f..d99f0f812ef 100644 --- a/xarray/namedarray/_array_api/statistical_functions.py +++ b/xarray/namedarray/_array_api/statistical_functions.py @@ -17,6 +17,23 @@ ) +def max( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.max(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + def mean( x: NamedArray[Any, _DType], /, @@ -82,3 +99,54 @@ def mean( dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out + + +def min( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.min(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +def prod( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.prod(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +def sum( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.sum(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out