From f1508b55122185c9193e5055547c64a996713f4a Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Thu, 26 Sep 2024 23:25:24 +1000 Subject: [PATCH 1/5] fix: Fix `Expr.over` with `order_by` did not take effect if group keys were sorted --- .../polars-core/src/frame/group_by/proxy.rs | 2 +- crates/polars-expr/src/expressions/window.rs | 22 ++++++------------- .../tests/unit/operations/test_window.py | 18 +++++++++++++++ 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/crates/polars-core/src/frame/group_by/proxy.rs b/crates/polars-core/src/frame/group_by/proxy.rs index d9aedc261faf..adddf4e8dac7 100644 --- a/crates/polars-core/src/frame/group_by/proxy.rs +++ b/crates/polars-core/src/frame/group_by/proxy.rs @@ -395,7 +395,7 @@ impl GroupsProxy { } } - pub(crate) fn is_sorted_flag(&self) -> bool { + pub fn is_sorted_flag(&self) -> bool { match self { GroupsProxy::Idx(groups) => groups.is_sorted_flag(), GroupsProxy::Slice { .. } => true, diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index 5a455cf5932b..2d1801753ec5 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -315,7 +315,6 @@ impl WindowExpr { fn determine_map_strategy( &self, agg_state: &AggState, - sorted_keys: bool, gb: &GroupBy, ) -> PolarsResult { match (self.mapping, agg_state) { @@ -333,19 +332,12 @@ impl WindowExpr { (WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join), // no explicit aggregations, map over the groups //`(col("x").sum() * col("y")).over("groups")` - (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => { - if sorted_keys { - if let GroupsProxy::Idx(g) = gb.get_groups() { - debug_assert!(g.is_sorted_flag()) - } - // GroupsProxy::Slice is always sorted - - // Note that group columns must be sorted for this to make sense!!! - Ok(MapStrategy::Explode) - } else { - Ok(MapStrategy::Map) - } + (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) + if gb.get_groups().is_sorted_flag() => + { + Ok(MapStrategy::Explode) }, + (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => Ok(MapStrategy::Map), // no aggregations, just return column // or an aggregation that has been flattened // we have to check which one @@ -502,7 +494,7 @@ impl PhysicalExpr for WindowExpr { // to make sure that the caches align we sort // the groups, so that the cached groups and join keys // are consistent among all windows - if sort_groups || state.cache_window() { + if self.order_by.is_none() && (sort_groups || state.cache_window()) { groups.sort() } let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns)); @@ -516,7 +508,7 @@ impl PhysicalExpr for WindowExpr { let mut ac = self.run_aggregation(df, state, &gb)?; use MapStrategy::*; - match self.determine_map_strategy(ac.agg_state(), sorted_keys, &gb)? { + match self.determine_map_strategy(ac.agg_state(), &gb)? { Nothing => { let mut out = ac.flat_naive().into_owned(); diff --git a/py-polars/tests/unit/operations/test_window.py b/py-polars/tests/unit/operations/test_window.py index 8171fd5f9b03..69b9cd8f0e55 100644 --- a/py-polars/tests/unit/operations/test_window.py +++ b/py-polars/tests/unit/operations/test_window.py @@ -518,3 +518,21 @@ def test_lit_window_broadcast() -> None: assert pl.DataFrame({"a": [1, 1, 2]}).select(pl.lit(0).over("a").alias("a"))[ "a" ].to_list() == [0, 0, 0] + + +def test_order_by_sorted_keys_18943() -> None: + df = pl.DataFrame( + { + "g": [1, 1, 1, 1], + "t": [4, 3, 2, 1], + "x": [10, 20, 30, 40], + } + ) + + expect = pl.DataFrame({"x": [100, 90, 70, 40]}) + + out = df.select(pl.col("x").cum_sum().over("g", order_by="t")) + assert_frame_equal(out, expect) + + out = df.set_sorted("g").select(pl.col("x").cum_sum().over("g", order_by="t")) + assert_frame_equal(out, expect) From e4af44f779434945e4bf15604c5d486164ed19d3 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Fri, 27 Sep 2024 15:48:28 +1000 Subject: [PATCH 2/5] c --- crates/polars-expr/src/expressions/window.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index 2d1801753ec5..97ff973f9108 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -335,7 +335,11 @@ impl WindowExpr { (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) if gb.get_groups().is_sorted_flag() => { - Ok(MapStrategy::Explode) + if let GroupsProxy::Idx(_) = gb.get_groups() { + Ok(MapStrategy::Map) + } else { + Ok(MapStrategy::Explode) + } }, (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => Ok(MapStrategy::Map), // no aggregations, just return column From bfe67a3422789a27fbcd6c3d157000b816431913 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Fri, 27 Sep 2024 15:50:11 +1000 Subject: [PATCH 3/5] c --- crates/polars-expr/src/expressions/window.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index 97ff973f9108..ddff0f0975f6 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -498,7 +498,7 @@ impl PhysicalExpr for WindowExpr { // to make sure that the caches align we sort // the groups, so that the cached groups and join keys // are consistent among all windows - if self.order_by.is_none() && (sort_groups || state.cache_window()) { + if sort_groups || state.cache_window() { groups.sort() } let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns)); From 140bf152a06ad98103239f7773d99446be358cfe Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Fri, 27 Sep 2024 16:16:25 +1000 Subject: [PATCH 4/5] c --- crates/polars/tests/it/lazy/expressions/window.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/polars/tests/it/lazy/expressions/window.rs b/crates/polars/tests/it/lazy/expressions/window.rs index 21d8a3d26bf7..fb52ac3810bb 100644 --- a/crates/polars/tests/it/lazy/expressions/window.rs +++ b/crates/polars/tests/it/lazy/expressions/window.rs @@ -150,9 +150,7 @@ fn test_sort_by_in_groups() -> PolarsResult<()> { col("cars"), col("A") .sort_by([col("B")], SortMultipleOptions::default()) - .implode() .over([col("cars")]) - .explode() .alias("sorted_A_by_B"), ]) .collect()?; From ccc9580f05f067823d367c6aa12e9995991502ae Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Fri, 27 Sep 2024 16:28:32 +1000 Subject: [PATCH 5/5] c --- crates/polars-core/src/frame/group_by/proxy.rs | 2 +- crates/polars-expr/src/expressions/window.rs | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/crates/polars-core/src/frame/group_by/proxy.rs b/crates/polars-core/src/frame/group_by/proxy.rs index adddf4e8dac7..d9aedc261faf 100644 --- a/crates/polars-core/src/frame/group_by/proxy.rs +++ b/crates/polars-core/src/frame/group_by/proxy.rs @@ -395,7 +395,7 @@ impl GroupsProxy { } } - pub fn is_sorted_flag(&self) -> bool { + pub(crate) fn is_sorted_flag(&self) -> bool { match self { GroupsProxy::Idx(groups) => groups.is_sorted_flag(), GroupsProxy::Slice { .. } => true, diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index ddff0f0975f6..b47d1744f662 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -332,16 +332,14 @@ impl WindowExpr { (WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join), // no explicit aggregations, map over the groups //`(col("x").sum() * col("y")).over("groups")` - (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) - if gb.get_groups().is_sorted_flag() => - { - if let GroupsProxy::Idx(_) = gb.get_groups() { - Ok(MapStrategy::Map) - } else { + (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => { + if let GroupsProxy::Slice { .. } = gb.get_groups() { + // Result can be directly exploded if the input was sorted. Ok(MapStrategy::Explode) + } else { + Ok(MapStrategy::Map) } }, - (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => Ok(MapStrategy::Map), // no aggregations, just return column // or an aggregation that has been flattened // we have to check which one