diff --git a/crates/polars-ops/src/chunked_array/mode.rs b/crates/polars-ops/src/chunked_array/mode.rs index a36b161775ca..1c981c9eb3e2 100644 --- a/crates/polars-ops/src/chunked_array/mode.rs +++ b/crates/polars-ops/src/chunked_array/mode.rs @@ -1,4 +1,3 @@ -use arrow::legacy::utils::CustomIterTools; use polars_core::prelude::*; use polars_core::{with_match_physical_integer_polars_type, POOL}; @@ -33,29 +32,23 @@ fn mode_64(ca: &Float64Chunked) -> PolarsResult { fn mode_indices(groups: GroupsProxy) -> Vec { match groups { GroupsProxy::Idx(groups) => { - let mut groups = groups.into_iter().collect_trusted::>(); - groups.sort_unstable_by_key(|k| k.1.len()); - let last = &groups.last().unwrap(); - let max_occur = last.1.len(); + let Some(max_len) = groups.iter().map(|g| g.1.len()).max() else { + return Vec::new(); + }; groups - .iter() - .rev() - .take_while(|v| v.1.len() == max_occur) - .map(|v| v.0) + .into_iter() + .filter(|g| g.1.len() == max_len) + .map(|g| g.0) .collect() }, GroupsProxy::Slice { groups, .. } => { - let last = groups.last().unwrap(); - let max_occur = last[1]; - + let Some(max_len) = groups.iter().map(|g| g[1]).max() else { + return Vec::new(); + }; groups - .iter() - .rev() - .take_while(|v| { - let len = v[1]; - len == max_occur - }) - .map(|v| v[0]) + .into_iter() + .filter(|g| g[1] == max_len) + .map(|g| g[0]) .collect() }, } diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 2a4d6f1d5285..0f45cdfb6e21 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -980,6 +980,7 @@ def test_reinterpret() -> None: def test_mode() -> None: s = pl.Series("a", [1, 1, 2]) assert s.mode().to_list() == [1] + assert s.set_sorted().mode().to_list() == [1] df = pl.DataFrame([s]) assert df.select([pl.col("a").mode()])["a"].to_list() == [1] @@ -990,7 +991,7 @@ def test_mode() -> None: assert pl.Series([1.0, 2.0, 3.0, 2.0]).mode().item() == 2.0 # sorted data - assert pl.int_range(0, 3, eager=True).mode().to_list() == [2, 1, 0] + assert set(pl.int_range(0, 3, eager=True).mode().to_list()) == {0, 1, 2} def test_diff() -> None: