Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Incorrect broadcasting on list-of-string set ops #18918

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions crates/polars-ops/src/chunked_array/list/sets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,10 @@ fn binary(
let offset = if broadcast_rhs {
// going via skip iterator instead of slice doesn't heap alloc nor trigger a bitcount
let a_iter = a.into_iter().skip(start_a).take(end_a - start_a);
let b_iter = b.into_iter();
let b_iter = b
.into_iter()
.skip(first_b as usize)
.take(second_b as usize - first_b as usize);
set_operation(
&mut set,
&mut set2,
Expand All @@ -314,7 +317,10 @@ fn binary(
true,
)
} else if broadcast_lhs {
let a_iter = a.into_iter();
let a_iter = a
.into_iter()
.skip(first_a as usize)
.take(second_a as usize - first_a as usize);
let b_iter = b.into_iter().skip(start_b).take(end_b - start_b);
set_operation(
&mut set,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,28 @@ def test_list_set_operations_binary() -> None:
]


def test_list_set_operations_broadcast_binary() -> None:
df = pl.DataFrame(
{
"a": [["2", "3", "3"], ["3", "1"], ["1", "2", "3"]],
"b": [["1", "2"], ["4"], ["5"]],
}
)

assert df.select(pl.col("a").list.set_intersection(pl.col.b.first())).to_dict(
as_series=False
) == {"a": [["2"], ["1"], ["1", "2"]]}
assert df.select(pl.col("a").list.set_union(pl.col.b.first())).to_dict(
as_series=False
) == {"a": [["2", "3", "1"], ["3", "1", "2"], ["1", "2", "3"]]}
assert df.select(pl.col("a").list.set_difference(pl.col.b.first())).to_dict(
as_series=False
) == {"a": [["3"], ["3"], ["3"]]}
assert df.select(pl.col.b.first().list.set_difference("a")).to_dict(
as_series=False
) == {"b": [["1"], ["2"], []]}


def test_set_operations_14290() -> None:
df = pl.DataFrame(
{
Expand Down
Loading