From 4f276c5b92edb78844ba84bffbc0ba5812bd0f3b Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 4 Feb 2025 15:05:12 +0000 Subject: [PATCH] fix: Always collect `Iterator[IntoExpr]` in `utils.flatten` (#1934) --- narwhals/utils.py | 6 +- tests/expr_and_series/all_horizontal_test.py | 229 ++++++++++--------- 2 files changed, 126 insertions(+), 109 deletions(-) diff --git a/narwhals/utils.py b/narwhals/utils.py index 7f1229793..726180689 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -387,11 +387,7 @@ def remove_suffix(text: str, suffix: str) -> str: # pragma: no cover def flatten(args: Any) -> list[Any]: - if not args: - return [] - if len(args) == 1 and _is_iterable(args[0]): - return args[0] # type: ignore[no-any-return] - return args # type: ignore[no-any-return] + return list(args[0] if (len(args) == 1 and _is_iterable(args[0])) else args) def tupleify(arg: Any) -> Any: diff --git a/tests/expr_and_series/all_horizontal_test.py b/tests/expr_and_series/all_horizontal_test.py index 706c42baf..086b35e3a 100644 --- a/tests/expr_and_series/all_horizontal_test.py +++ b/tests/expr_and_series/all_horizontal_test.py @@ -1,104 +1,125 @@ -from __future__ import annotations - -from typing import Any - -import pytest - -import narwhals.stable.v1 as nw -from tests.utils import POLARS_VERSION -from tests.utils import Constructor -from tests.utils import ConstructorEager -from tests.utils import assert_equal_data - - -@pytest.mark.parametrize("expr1", ["a", nw.col("a")]) -@pytest.mark.parametrize("expr2", ["b", nw.col("b")]) -def test_allh(constructor: Constructor, expr1: Any, expr2: Any) -> None: - data = { - "a": [False, False, True], - "b": [False, True, True], - } - df = nw.from_native(constructor(data)) - result = df.select(all=nw.all_horizontal(expr1, expr2)) - - expected = {"all": [False, False, True]} - assert_equal_data(result, expected) - - -def test_allh_series(constructor_eager: ConstructorEager) -> None: - data = { - "a": [False, False, True], - "b": [False, True, True], - } - df = nw.from_native(constructor_eager(data), eager_only=True) - result = df.select(all=nw.all_horizontal(df["a"], df["b"])) - - expected = {"all": [False, False, True]} - assert_equal_data(result, expected) - - -def test_allh_all(constructor: Constructor) -> None: - data = { - "a": [False, False, True], - "b": [False, True, True], - } - df = nw.from_native(constructor(data)) - result = df.select(all=nw.all_horizontal(nw.all())) - expected = {"all": [False, False, True]} - assert_equal_data(result, expected) - result = df.select(nw.all_horizontal(nw.all())) - expected = {"a": [False, False, True]} - assert_equal_data(result, expected) - - -def test_allh_nth( - constructor: Constructor, - request: pytest.FixtureRequest, -) -> None: - if "polars" in str(constructor) and POLARS_VERSION < (1, 0): - request.applymarker(pytest.mark.xfail) - data = { - "a": [False, False, True], - "b": [False, True, True], - } - df = nw.from_native(constructor(data)) - result = df.select(nw.all_horizontal(nw.nth(0, 1))) - expected = {"a": [False, False, True]} - assert_equal_data(result, expected) - result = df.select(nw.all_horizontal(nw.col("a"), nw.nth(0))) - expected = {"a": [False, False, True]} - assert_equal_data(result, expected) - - -def test_horizontal_expressions_empty(constructor: Constructor) -> None: - data = { - "a": [False, False, True], - "b": [False, True, True], - } - df = nw.from_native(constructor(data)) - with pytest.raises( - ValueError, match=r"At least one expression must be passed.*all_horizontal" - ): - df.select(nw.all_horizontal()) - with pytest.raises( - ValueError, match=r"At least one expression must be passed.*any_horizontal" - ): - df.select(nw.any_horizontal()) - with pytest.raises( - ValueError, match=r"At least one expression must be passed.*mean_horizontal" - ): - df.select(nw.mean_horizontal()) - with pytest.raises( - ValueError, match=r"At least one expression must be passed.*sum_horizontal" - ): - df.select(nw.sum_horizontal()) - - with pytest.raises( - ValueError, match=r"At least one expression must be passed.*max_horizontal" - ): - df.select(nw.max_horizontal()) - - with pytest.raises( - ValueError, match=r"At least one expression must be passed.*min_horizontal" - ): - df.select(nw.min_horizontal()) +from __future__ import annotations + +from typing import Any + +import pytest + +import narwhals.stable.v1 as nw +from tests.utils import POLARS_VERSION +from tests.utils import Constructor +from tests.utils import ConstructorEager +from tests.utils import assert_equal_data + + +@pytest.mark.parametrize("expr1", ["a", nw.col("a")]) +@pytest.mark.parametrize("expr2", ["b", nw.col("b")]) +def test_allh(constructor: Constructor, expr1: Any, expr2: Any) -> None: + data = { + "a": [False, False, True], + "b": [False, True, True], + } + df = nw.from_native(constructor(data)) + result = df.select(all=nw.all_horizontal(expr1, expr2)) + + expected = {"all": [False, False, True]} + assert_equal_data(result, expected) + + +def test_allh_series(constructor_eager: ConstructorEager) -> None: + data = { + "a": [False, False, True], + "b": [False, True, True], + } + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df.select(all=nw.all_horizontal(df["a"], df["b"])) + + expected = {"all": [False, False, True]} + assert_equal_data(result, expected) + + +def test_allh_all(constructor: Constructor) -> None: + data = { + "a": [False, False, True], + "b": [False, True, True], + } + df = nw.from_native(constructor(data)) + result = df.select(all=nw.all_horizontal(nw.all())) + expected = {"all": [False, False, True]} + assert_equal_data(result, expected) + result = df.select(nw.all_horizontal(nw.all())) + expected = {"a": [False, False, True]} + assert_equal_data(result, expected) + + +def test_allh_nth( + constructor: Constructor, + request: pytest.FixtureRequest, +) -> None: + if "polars" in str(constructor) and POLARS_VERSION < (1, 0): + request.applymarker(pytest.mark.xfail) + data = { + "a": [False, False, True], + "b": [False, True, True], + } + df = nw.from_native(constructor(data)) + result = df.select(nw.all_horizontal(nw.nth(0, 1))) + expected = {"a": [False, False, True]} + assert_equal_data(result, expected) + result = df.select(nw.all_horizontal(nw.col("a"), nw.nth(0))) + expected = {"a": [False, False, True]} + assert_equal_data(result, expected) + + +def test_allh_iterator(constructor: Constructor) -> None: + def iter_eq(items: Any, /) -> Any: + for column, value in items: + yield nw.col(column) == value + + data = {"a": [1, 2, 3, 3, 3], "b": ["b", "b", "a", "a", "b"]} + df = nw.from_native(constructor(data)) + expr_items = [("a", 3), ("b", "b")] + expected = {"a": [3], "b": ["b"]} + + eager = nw.all_horizontal(list(iter_eq(expr_items))) + assert_equal_data(df.filter(eager), expected) + unpacked = nw.all_horizontal(*iter_eq(expr_items)) + assert_equal_data(df.filter(unpacked), expected) + lazy = nw.all_horizontal(iter_eq(expr_items)) + + assert_equal_data(df.filter(lazy), expected) + assert_equal_data(df.filter(lazy), expected) + assert_equal_data(df.filter(lazy), expected) + + +def test_horizontal_expressions_empty(constructor: Constructor) -> None: + data = { + "a": [False, False, True], + "b": [False, True, True], + } + df = nw.from_native(constructor(data)) + with pytest.raises( + ValueError, match=r"At least one expression must be passed.*all_horizontal" + ): + df.select(nw.all_horizontal()) + with pytest.raises( + ValueError, match=r"At least one expression must be passed.*any_horizontal" + ): + df.select(nw.any_horizontal()) + with pytest.raises( + ValueError, match=r"At least one expression must be passed.*mean_horizontal" + ): + df.select(nw.mean_horizontal()) + with pytest.raises( + ValueError, match=r"At least one expression must be passed.*sum_horizontal" + ): + df.select(nw.sum_horizontal()) + + with pytest.raises( + ValueError, match=r"At least one expression must be passed.*max_horizontal" + ): + df.select(nw.max_horizontal()) + + with pytest.raises( + ValueError, match=r"At least one expression must be passed.*min_horizontal" + ): + df.select(nw.min_horizontal())