Skip to content

Commit

Permalink
fix: Always collect Iterator[IntoExpr] in utils.flatten (#1934)
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned authored Feb 4, 2025
1 parent 4a2ca52 commit 4f276c5
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 109 deletions.
6 changes: 1 addition & 5 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,11 +387,7 @@ def remove_suffix(text: str, suffix: str) -> str: # pragma: no cover


def flatten(args: Any) -> list[Any]:
if not args:
return []
if len(args) == 1 and _is_iterable(args[0]):
return args[0] # type: ignore[no-any-return]
return args # type: ignore[no-any-return]
return list(args[0] if (len(args) == 1 and _is_iterable(args[0])) else args)


def tupleify(arg: Any) -> Any:
Expand Down
229 changes: 125 additions & 104 deletions tests/expr_and_series/all_horizontal_test.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,125 @@
from __future__ import annotations

from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import POLARS_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data


@pytest.mark.parametrize("expr1", ["a", nw.col("a")])
@pytest.mark.parametrize("expr2", ["b", nw.col("b")])
def test_allh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor(data))
result = df.select(all=nw.all_horizontal(expr1, expr2))

expected = {"all": [False, False, True]}
assert_equal_data(result, expected)


def test_allh_series(constructor_eager: ConstructorEager) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor_eager(data), eager_only=True)
result = df.select(all=nw.all_horizontal(df["a"], df["b"]))

expected = {"all": [False, False, True]}
assert_equal_data(result, expected)


def test_allh_all(constructor: Constructor) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor(data))
result = df.select(all=nw.all_horizontal(nw.all()))
expected = {"all": [False, False, True]}
assert_equal_data(result, expected)
result = df.select(nw.all_horizontal(nw.all()))
expected = {"a": [False, False, True]}
assert_equal_data(result, expected)


def test_allh_nth(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if "polars" in str(constructor) and POLARS_VERSION < (1, 0):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor(data))
result = df.select(nw.all_horizontal(nw.nth(0, 1)))
expected = {"a": [False, False, True]}
assert_equal_data(result, expected)
result = df.select(nw.all_horizontal(nw.col("a"), nw.nth(0)))
expected = {"a": [False, False, True]}
assert_equal_data(result, expected)


def test_horizontal_expressions_empty(constructor: Constructor) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor(data))
with pytest.raises(
ValueError, match=r"At least one expression must be passed.*all_horizontal"
):
df.select(nw.all_horizontal())
with pytest.raises(
ValueError, match=r"At least one expression must be passed.*any_horizontal"
):
df.select(nw.any_horizontal())
with pytest.raises(
ValueError, match=r"At least one expression must be passed.*mean_horizontal"
):
df.select(nw.mean_horizontal())
with pytest.raises(
ValueError, match=r"At least one expression must be passed.*sum_horizontal"
):
df.select(nw.sum_horizontal())

with pytest.raises(
ValueError, match=r"At least one expression must be passed.*max_horizontal"
):
df.select(nw.max_horizontal())

with pytest.raises(
ValueError, match=r"At least one expression must be passed.*min_horizontal"
):
df.select(nw.min_horizontal())
from __future__ import annotations

from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import POLARS_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data


@pytest.mark.parametrize("expr1", ["a", nw.col("a")])
@pytest.mark.parametrize("expr2", ["b", nw.col("b")])
def test_allh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor(data))
result = df.select(all=nw.all_horizontal(expr1, expr2))

expected = {"all": [False, False, True]}
assert_equal_data(result, expected)


def test_allh_series(constructor_eager: ConstructorEager) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor_eager(data), eager_only=True)
result = df.select(all=nw.all_horizontal(df["a"], df["b"]))

expected = {"all": [False, False, True]}
assert_equal_data(result, expected)


def test_allh_all(constructor: Constructor) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor(data))
result = df.select(all=nw.all_horizontal(nw.all()))
expected = {"all": [False, False, True]}
assert_equal_data(result, expected)
result = df.select(nw.all_horizontal(nw.all()))
expected = {"a": [False, False, True]}
assert_equal_data(result, expected)


def test_allh_nth(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if "polars" in str(constructor) and POLARS_VERSION < (1, 0):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor(data))
result = df.select(nw.all_horizontal(nw.nth(0, 1)))
expected = {"a": [False, False, True]}
assert_equal_data(result, expected)
result = df.select(nw.all_horizontal(nw.col("a"), nw.nth(0)))
expected = {"a": [False, False, True]}
assert_equal_data(result, expected)


def test_allh_iterator(constructor: Constructor) -> None:
def iter_eq(items: Any, /) -> Any:
for column, value in items:
yield nw.col(column) == value

data = {"a": [1, 2, 3, 3, 3], "b": ["b", "b", "a", "a", "b"]}
df = nw.from_native(constructor(data))
expr_items = [("a", 3), ("b", "b")]
expected = {"a": [3], "b": ["b"]}

eager = nw.all_horizontal(list(iter_eq(expr_items)))
assert_equal_data(df.filter(eager), expected)
unpacked = nw.all_horizontal(*iter_eq(expr_items))
assert_equal_data(df.filter(unpacked), expected)
lazy = nw.all_horizontal(iter_eq(expr_items))

assert_equal_data(df.filter(lazy), expected)
assert_equal_data(df.filter(lazy), expected)
assert_equal_data(df.filter(lazy), expected)


def test_horizontal_expressions_empty(constructor: Constructor) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
}
df = nw.from_native(constructor(data))
with pytest.raises(
ValueError, match=r"At least one expression must be passed.*all_horizontal"
):
df.select(nw.all_horizontal())
with pytest.raises(
ValueError, match=r"At least one expression must be passed.*any_horizontal"
):
df.select(nw.any_horizontal())
with pytest.raises(
ValueError, match=r"At least one expression must be passed.*mean_horizontal"
):
df.select(nw.mean_horizontal())
with pytest.raises(
ValueError, match=r"At least one expression must be passed.*sum_horizontal"
):
df.select(nw.sum_horizontal())

with pytest.raises(
ValueError, match=r"At least one expression must be passed.*max_horizontal"
):
df.select(nw.max_horizontal())

with pytest.raises(
ValueError, match=r"At least one expression must be passed.*min_horizontal"
):
df.select(nw.min_horizontal())

0 comments on commit 4f276c5

Please sign in to comment.