diff --git a/cubed/array_api/array_object.py b/cubed/array_api/array_object.py index 94121d7e..83b1ee94 100644 --- a/cubed/array_api/array_object.py +++ b/cubed/array_api/array_object.py @@ -367,9 +367,9 @@ def __array_namespace__(self, /, *, api_version=None): "2023.12", ): raise ValueError(f"Unrecognized array API version: {api_version!r}") - import cubed.array_api as array_api + import cubed - return array_api + return cubed def __bool__(self, /): if self.ndim != 0: diff --git a/cubed/nan_functions.py b/cubed/nan_functions.py index 928a2f41..4f451580 100644 --- a/cubed/nan_functions.py +++ b/cubed/nan_functions.py @@ -18,9 +18,9 @@ # https://github.com/data-apis/array-api/issues/621 -def nanmean(x, /, *, axis=None, keepdims=False, split_every=None): +def nanmean(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None): """Compute the arithmetic mean along the specified axis, ignoring NaNs.""" - dtype = x.dtype + dtype = dtype or x.dtype intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)] return reduction( x,