From 67b1ae4f9dfe3d477f827f4997924070cc0340e4 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Fri, 24 Jan 2025 08:21:26 +0100 Subject: [PATCH] feat: allow any ddof value for duckdb `var` and `std` (#1858) --- narwhals/_duckdb/expr.py | 37 ++++++++++++++++++------------- tests/expr_and_series/std_test.py | 21 ++++++------------ tests/expr_and_series/var_test.py | 21 ++++++------------ tests/group_by_test.py | 8 +------ 4 files changed, 36 insertions(+), 51 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 3571aadbc..031e48e67 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -460,27 +460,32 @@ def len(self: Self) -> Self: ) def std(self: Self, ddof: int) -> Self: - if ddof == 1: - func = "stddev_samp" - elif ddof == 0: - func = "stddev_pop" - else: - msg = f"std with ddof {ddof} is not supported in DuckDB" - raise NotImplementedError(msg) + def _std(_input: duckdb.Expression, ddof: int) -> duckdb.Expression: + n_samples = FunctionExpression("count", _input) + + return ( + FunctionExpression("stddev_pop", _input) + * FunctionExpression("sqrt", n_samples) + / (FunctionExpression("sqrt", (n_samples - ddof))) + ) + return self._from_call( - lambda _input: FunctionExpression(func, _input), "std", returns_scalar=True + _std, + "std", + ddof=ddof, + returns_scalar=True, ) def var(self: Self, ddof: int) -> Self: - if ddof == 1: - func = "var_samp" - elif ddof == 0: - func = "var_pop" - else: - msg = f"var with ddof {ddof} is not supported in DuckDB" - raise NotImplementedError(msg) + def _var(_input: duckdb.Expression, ddof: int) -> duckdb.Expression: + n_samples = FunctionExpression("count", _input) + return FunctionExpression("var_pop", _input) * n_samples / (n_samples - ddof) + return self._from_call( - lambda _input: FunctionExpression(func, _input), "var", returns_scalar=True + _var, + "var", + ddof=ddof, + returns_scalar=True, ) def max(self: Self) -> Self: diff --git a/tests/expr_and_series/std_test.py b/tests/expr_and_series/std_test.py index 0cc9b6722..1f8735c19 100644 --- a/tests/expr_and_series/std_test.py +++ b/tests/expr_and_series/std_test.py @@ -1,7 +1,5 @@ from __future__ import annotations -from contextlib import nullcontext as does_not_raise - import pytest import narwhals.stable.v1 as nw @@ -34,19 +32,14 @@ def test_std(constructor: Constructor, input_data: dict[str, list[float | None]] "z_ddof_0": [0.816497], } assert_equal_data(result, expected_results) - context = ( - pytest.raises(NotImplementedError) - if "duckdb" in str(constructor) - else does_not_raise() + + result = df.select( + nw.col("b").std(ddof=2).alias("b_ddof_2"), ) - with context: - result = df.select( - nw.col("b").std(ddof=2).alias("b_ddof_2"), - ) - expected_results = { - "b_ddof_2": [1.632993], - } - assert_equal_data(result, expected_results) + expected_results = { + "b_ddof_2": [1.632993], + } + assert_equal_data(result, expected_results) @pytest.mark.parametrize("input_data", [data, data_with_nulls]) diff --git a/tests/expr_and_series/var_test.py b/tests/expr_and_series/var_test.py index 0edd8e305..109985400 100644 --- a/tests/expr_and_series/var_test.py +++ b/tests/expr_and_series/var_test.py @@ -1,7 +1,5 @@ from __future__ import annotations -from contextlib import nullcontext as does_not_raise - import pytest import narwhals.stable.v1 as nw @@ -34,19 +32,14 @@ def test_var(constructor: Constructor, input_data: dict[str, list[float | None]] "z_ddof_0": [0.6666666666666666], } assert_equal_data(result, expected_results) - context = ( - pytest.raises(NotImplementedError) - if "duckdb" in str(constructor) - else does_not_raise() + + result = df.select( + nw.col("b").var(ddof=2).alias("b_ddof_2"), ) - with context: - result = df.select( - nw.col("b").var(ddof=2).alias("b_ddof_2"), - ) - expected_results = { - "b_ddof_2": [2.666666666666667], - } - assert_equal_data(result, expected_results) + expected_results = { + "b_ddof_2": [2.666666666666667], + } + assert_equal_data(result, expected_results) @pytest.mark.parametrize("input_data", [data, data_with_nulls]) diff --git a/tests/group_by_test.py b/tests/group_by_test.py index 3ba35fabc..9929d36cd 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -137,11 +137,7 @@ def test_group_by_depth_1_agg( ("var", 2), ], ) -def test_group_by_depth_1_std_var( - constructor: Constructor, attr: str, ddof: int, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor) and ddof == 2: - request.applymarker(pytest.mark.xfail) +def test_group_by_depth_1_std_var(constructor: Constructor, attr: str, ddof: int) -> None: data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]} _pow = 0.5 if attr == "std" else 1 expected = { @@ -398,8 +394,6 @@ def test_all_kind_of_aggs( # and modin lol https://github.com/modin-project/modin/issues/7414 # and cudf https://github.com/rapidsai/cudf/issues/17649 request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (1, 4): # Bug in old pandas, can't do DataFrameGroupBy[['b', 'b']] request.applymarker(pytest.mark.xfail)