Skip to content

Commit

Permalink
fix: Incorrect mode for sorted input (#18945)
Browse files — browse the repository at this point in the history
  • Loading branch information
orlp authored Sep 26, 2024
1 parent 503582e commit 71a8b05
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 20 deletions.
31 changes: 12 additions & 19 deletions crates/polars-ops/src/chunked_array/mode.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use arrow::legacy::utils::CustomIterTools;
use polars_core::prelude::*;
use polars_core::{with_match_physical_integer_polars_type, POOL};

Expand Down Expand Up @@ -33,29 +32,23 @@ fn mode_64(ca: &Float64Chunked) -> PolarsResult<Float64Chunked> {
fn mode_indices(groups: GroupsProxy) -> Vec<IdxSize> {
match groups {
GroupsProxy::Idx(groups) => {
let mut groups = groups.into_iter().collect_trusted::<Vec<_>>();
groups.sort_unstable_by_key(|k| k.1.len());
let last = &groups.last().unwrap();
let max_occur = last.1.len();
let Some(max_len) = groups.iter().map(|g| g.1.len()).max() else {
return Vec::new();
};
groups
.iter()
.rev()
.take_while(|v| v.1.len() == max_occur)
.map(|v| v.0)
.into_iter()
.filter(|g| g.1.len() == max_len)
.map(|g| g.0)
.collect()
},
GroupsProxy::Slice { groups, .. } => {
let last = groups.last().unwrap();
let max_occur = last[1];

let Some(max_len) = groups.iter().map(|g| g[1]).max() else {
return Vec::new();
};
groups
.iter()
.rev()
.take_while(|v| {
let len = v[1];
len == max_occur
})
.map(|v| v[0])
.into_iter()
.filter(|g| g[1] == max_len)
.map(|g| g[0])
.collect()
},
}
Expand Down
3 changes: 2 additions & 1 deletion py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,7 @@ def test_reinterpret() -> None:
def test_mode() -> None:
s = pl.Series("a", [1, 1, 2])
assert s.mode().to_list() == [1]
assert s.set_sorted().mode().to_list() == [1]

df = pl.DataFrame([s])
assert df.select([pl.col("a").mode()])["a"].to_list() == [1]
Expand All @@ -990,7 +991,7 @@ def test_mode() -> None:
assert pl.Series([1.0, 2.0, 3.0, 2.0]).mode().item() == 2.0

# sorted data
assert pl.int_range(0, 3, eager=True).mode().to_list() == [2, 1, 0]
assert set(pl.int_range(0, 3, eager=True).mode().to_list()) == {0, 1, 2}


def test_diff() -> None:
Expand Down

0 comments on commit 71a8b05

Please sign in to comment.