From 1d9a64b828d908110d67c590a2a7711a6554a9cc Mon Sep 17 00:00:00 2001
From: Dan Harris <1327726+thinkharderdev@users.noreply.github.com>
Date: Sun, 7 Jul 2024 06:37:34 -0400
Subject: [PATCH] AggregateExec: Take grouping sets into account for
 InputOrderMode (#11301)

* AggregateExec: Take grouping sets into account for InputOrderMode

* pr comments
---
 .../physical-plan/src/aggregates/mod.rs       | 121 ++++++++++++++++--
 1 file changed, 112 insertions(+), 9 deletions(-)

diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index e263876b07d5..09f00e80f6dd 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -346,14 +346,26 @@ impl AggregateExec {
         new_requirement.extend(req);
         new_requirement = collapse_lex_req(new_requirement);
 
-        let input_order_mode =
-            if indices.len() == groupby_exprs.len() && !indices.is_empty() {
-                InputOrderMode::Sorted
-            } else if !indices.is_empty() {
-                InputOrderMode::PartiallySorted(indices)
-            } else {
-                InputOrderMode::Linear
-            };
+        // If our aggregation has grouping sets, the base grouping exprs are
+        // expanded based on the flags in `group_by.groups`: for each group we
+        // swap the grouping expr for `null` when the flag is `true`.
+        // That means an index in `indices` is valid if and only if the
+        // corresponding expr is not replaced by `null` in any group.
+        let indices: Vec<usize> = indices
+            .into_iter()
+            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
+            .collect();
+
+        let input_order_mode = if indices.len() == groupby_exprs.len()
+            && !indices.is_empty()
+            && group_by.groups.len() == 1
+        {
+            InputOrderMode::Sorted
+        } else if !indices.is_empty() {
+            InputOrderMode::PartiallySorted(indices)
+        } else {
+            InputOrderMode::Linear
+        };
 
         // construct a map from the input expression to the output expression of the Aggregation group by
         let projection_mapping =
@@ -1227,6 +1239,7 @@ mod tests {
     use arrow::compute::{concat_batches, SortOptions};
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
     use arrow::record_batch::RecordBatch;
+    use arrow_array::{Float32Array, Int32Array};
     use datafusion_common::{
         assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
         Result, ScalarValue,
     };
     use datafusion_execution::memory_pool::FairSpillPool;
     use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use datafusion_physical_expr::expressions::{
-        lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg,
+        lit, ApproxDistinct, Count, FirstValue, LastValue, Literal, Median,
+        OrderSensitiveArrayAgg,
     };
     use datafusion_physical_expr::{
         reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalExpr,
         PhysicalSortExpr,
     };
 
+    use crate::common::collect;
     use futures::{FutureExt, Stream};
 
     // Generate a schema which consists of 5 columns (a, b, c, d, e)
@@ -2233,4 +2248,92 @@ mod tests {
         assert_eq!(new_agg.schema(), aggregate_exec.schema());
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_agg_exec_group_by_const() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Float32, true),
+            Field::new("b", DataType::Float32, true),
+            Field::new("const", DataType::Int32, false),
+        ]));
+
+        let col_a = col("a", &schema)?;
+        let col_b = col("b", &schema)?;
+        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
+
+        let groups = PhysicalGroupBy::new(
+            vec![
+                (col_a, "a".to_string()),
+                (col_b, "b".to_string()),
+                (const_expr, "const".to_string()),
+            ],
+            vec![
+                (
+                    Arc::new(Literal::new(ScalarValue::Float32(None))),
+                    "a".to_string(),
+                ),
+                (
+                    Arc::new(Literal::new(ScalarValue::Float32(None))),
+                    "b".to_string(),
+                ),
+                (
+                    Arc::new(Literal::new(ScalarValue::Int32(None))),
+                    "const".to_string(),
+                ),
+            ],
+            vec![
+                vec![false, true, true],
+                vec![true, false, true],
+                vec![true, true, false],
+            ],
+        );
+
+        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
+            &AggregateFunction::Count,
+            false,
+            &[lit(1)],
+            &[],
+            schema.as_ref(),
+            "1",
+            false,
+        )?];
+
+        let input_batches = (0..4)
+            .map(|_| {
+                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
+                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
+                let c = Arc::new(Int32Array::from(vec![1; 8192]));
+
+                RecordBatch::try_new(schema.clone(), vec![a, b, c]).unwrap()
+            })
+            .collect();
+
+        let input =
+            Arc::new(MemoryExec::try_new(&[input_batches], schema.clone(), None)?);
+
+        let aggregate_exec = Arc::new(AggregateExec::try_new(
+            AggregateMode::Partial,
+            groups,
+            aggregates.clone(),
+            vec![None],
+            input,
+            schema,
+        )?);
+
+        let output =
+            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
+
+        let expected = [
+            "+-----+-----+-------+----------+",
+            "| a   | b   | const | 1[count] |",
+            "+-----+-----+-------+----------+",
+            "|     | 0.0 |       | 32768    |",
+            "| 0.0 |     |       | 32768    |",
+            "|     |     | 1     | 32768    |",
+            "+-----+-----+-------+----------+",
+        ];
+        assert_batches_sorted_eq!(expected, &output);
+
+        Ok(())
+    }
 }