From 4364d780c5ee7fa9b5e3b7e8e1a7d88e86216932 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Wed, 15 Jan 2025 14:43:00 -0800 Subject: [PATCH] Move private function to the bottom of the file. --- cubed/array_api/statistical_functions.py | 54 ++++++++++++------------ 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index b26279bc..07b7f231 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -113,33 +113,6 @@ def min(x, /, *, axis=None, keepdims=False, split_every=None): ) -def _validate_and_define_numeric_or_bool_dtype(x, dtype=None, *, fname=None, device=None): - """Validate the type of the numeric function. If it's None, provide a good default dtype.""" - dtypes = __array_namespace_info__().default_dtypes(device=device) - - # Validate. - # boolean is allowed by numpy - if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: - errmsg = "Only numeric or boolean dtypes are allowed" - if fname: - errmsg += f" in {fname}" - raise TypeError(errmsg) - - # Choose a good default dtype, when None - if dtype is None: - if x.dtype in _boolean_dtypes: - dtype = dtypes['integral'] - elif x.dtype in _signed_integer_dtypes: - dtype = dtypes['integral'] - elif x.dtype in _unsigned_integer_dtypes: - #TODO(#658): I don't think "indexing" --> uint64; is this correct? - dtype = dtypes['indexing'] - else: - dtype = x.dtype - - return dtype - - def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None): dtype = _validate_and_define_numeric_or_bool_dtype(x, dtype, fname="prod", device=device) extra_func_kwargs = dict(dtype=dtype) @@ -249,3 +222,30 @@ def _var_combine(a, axis=None, correction=None, **kwargs): def _var_aggregate(a, correction=None, **kwargs): return nxp.divide(a["M2"], a["n"] - correction) + + +def _validate_and_define_numeric_or_bool_dtype(x, dtype=None, *, fname=None, device=None): + """Validate the type of the numeric function. If it's None, provide a good default dtype.""" + dtypes = __array_namespace_info__().default_dtypes(device=device) + + # Validate. + # boolean is allowed by numpy + if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: + errmsg = "Only numeric or boolean dtypes are allowed" + if fname: + errmsg += f" in {fname}" + raise TypeError(errmsg) + + # Choose a good default dtype, when None + if dtype is None: + if x.dtype in _boolean_dtypes: + dtype = dtypes['integral'] + elif x.dtype in _signed_integer_dtypes: + dtype = dtypes['integral'] + elif x.dtype in _unsigned_integer_dtypes: + #TODO(#658): I don't think "indexing" --> uint64; is this correct? + dtype = dtypes['indexing'] + else: + dtype = x.dtype + + return dtype