def _handle_ddof_vs_correction_kwargs(
    kwargs: dict[str, Any], values: "T_DuckArray"
) -> dict[str, Any]:
    """Reconcile the ``ddof`` and ``correction`` keywords of ``std``/``var``.

    Both keywords name the same degrees-of-freedom correction (see GH8566).
    Exactly one of them is left in ``kwargs``, chosen to suit the array
    library that backs ``values``.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments of the reduction. ``"ddof"`` and ``"correction"``
        are popped from it, so the dict is modified in place.
    values : duck array
        The data about to be reduced; only inspected, never modified.

    Returns
    -------
    dict
        The same ``kwargs`` object, holding either ``"ddof"`` or
        ``"correction"`` (never both). The value is 0 when the caller
        supplied neither.

    Raises
    ------
    ValueError
        If both keywords are given with different values.

    Warns
    -----
    UserWarning
        If both keywords are given with the same value.
    """
    ddof = kwargs.pop("ddof", None)
    correction = kwargs.pop("correction", None)

    if ddof is not None and correction is not None:
        if ddof != correction:
            raise ValueError(
                "ddof and correction both refer to the same thing, and so cannot be meaningfully set to different values. "
                f"Got ddof={ddof} but correction={correction}"
            )
        # both kwargs were passed, but they are the same value
        warnings.warn(
            "ddof and correction both refer to the same thing - you don't need to pass them both",
            UserWarning,
        )
        degrees_of_freedom_correction_val = ddof
    elif ddof is not None:
        degrees_of_freedom_correction_val = ddof
    elif correction is not None:
        degrees_of_freedom_correction_val = correction
    else:
        degrees_of_freedom_correction_val = 0  # default to 0, see GH8566

    # Assume that only array-API-compliant libraries implement ``correction``
    # instead of ``ddof``; the array-API standard introduced the rename, see
    # https://github.com/data-apis/array-api/issues/695
    #
    # NumPy >= 2.0 ndarrays also expose ``__array_namespace__``, but NumPy
    # input has to keep ``ddof``: the nan-aware reductions (nanops.nanvar /
    # nanops.nanstd) only take ``ddof``, and ``np.std`` / ``np.var`` accept
    # ``ddof`` in every NumPy version.
    uses_correction = hasattr(values, "__array_namespace__") and not isinstance(
        values, np.ndarray
    )
    if uses_correction:
        kwargs["correction"] = degrees_of_freedom_correction_val
    else:
        kwargs["ddof"] = degrees_of_freedom_correction_val

    return kwargs
def test_statistics(arrays) -> None:
    """Check ``std`` with the ``ddof`` / ``correction`` keywords.

    Covers NumPy-backed data (with NaNs) and array-API-backed data (NaNs
    filled), including the warning for redundant keywords and the error for
    conflicting ones.
    """
    with xr.set_options(use_bottleneck=False):
        np_arr_nans, xp_arr_nans = arrays

        # TODO this is not really a test of the array API compatibility
        # specifically, just of the ddof vs correction kwarg handling
        expected = np_arr_nans.std(ddof=1)
        actual = np_arr_nans.std(correction=1)
        assert_equal(actual, expected)

        # test the default value
        expected = np_arr_nans.std()
        actual = np_arr_nans.std(correction=0)
        assert_equal(actual, expected)

        # TODO have to remove the NaNs in the example data because array API
        # standard can't handle them yet, see
        # https://github.com/pydata/xarray/pull/7424#issuecomment-1373979208
        np_arr, xp_arr = np_arr_nans.fillna(0), xp_arr_nans.fillna(0)

        expected = np_arr.std()
        actual = xp_arr.std(skipna=False)
        assert isinstance(actual.data, Array)
        assert_equal(actual, expected)

        # either spelling of the keyword must give the same result
        expected = np_arr.std(ddof=1)
        for dof_kwargs in ({"ddof": 1}, {"correction": 1}):
            actual = xp_arr.std(skipna=False, **dof_kwargs)
            assert isinstance(actual.data, Array)
            assert_equal(actual, expected)

        # passing both keywords with the same value works but must warn
        with pytest.warns(
            UserWarning, match="ddof and correction both refer to the same thing"
        ):
            actual = xp_arr.std(skipna=False, ddof=1, correction=1)
        assert isinstance(actual.data, Array)
        assert_equal(actual, expected)

        # passing both keywords with different values is an error
        with pytest.raises(
            ValueError, match="cannot be meaningfully set to different values"
        ):
            xp_arr.std(skipna=False, ddof=1, correction=0)