From 96669de2abd9056817f80a56628402bf6112b267 Mon Sep 17 00:00:00 2001 From: jakevin Date: Tue, 12 Mar 2024 15:02:23 +0800 Subject: [PATCH] refactor: unify some plan optimization in CommonSubexprEliminate (#9556) --- .../optimizer/src/common_subexpr_eliminate.rs | 99 ++++--------------- 1 file changed, 19 insertions(+), 80 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 30c184a28e33..7b8eccad5133 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -33,9 +33,7 @@ use datafusion_common::{ DataFusionError, Result, }; use datafusion_expr::expr::Alias; -use datafusion_expr::logical_plan::{ - Aggregate, Filter, LogicalPlan, Projection, Sort, Window, -}; +use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; use datafusion_expr::{col, Expr, ExprSchemable}; /// A map from expression's identifier to tuple including @@ -44,13 +42,13 @@ use datafusion_expr::{col, Expr, ExprSchemable}; /// - DataType of this expression. type ExprSet = HashMap; -/// Identifier type. Current implementation use describe of a expression (type String) as +/// Identifier type. Current implementation use describe of an expression (type String) as /// Identifier. /// -/// A Identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no +/// An identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no /// collision (as low as possible)" /// -/// Since a identifier is likely to be copied many times, it is better that a identifier +/// Since an identifier is likely to be copied many times, it is better that an identifier /// is small or "copy". otherwise some kinds of reference count is needed. String description /// here is not such a good choose. type Identifier = String; @@ -108,61 +106,6 @@ impl CommonSubexprEliminate { Ok((rewrite_exprs, new_input)) } - fn try_optimize_projection( - &self, - projection: &Projection, - config: &dyn OptimizerConfig, - ) -> Result { - let Projection { expr, input, .. } = projection; - let input_schema = Arc::clone(input.schema()); - let mut expr_set = ExprSet::new(); - - // Visit expr list and build expr identifier to occuring count map (`expr_set`). - let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?; - - let (mut new_expr, new_input) = - self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?; - - // Since projection expr changes, schema changes also. Use try_new method. - Projection::try_new(pop_expr(&mut new_expr)?, Arc::new(new_input)) - .map(LogicalPlan::Projection) - } - - fn try_optimize_filter( - &self, - filter: &Filter, - config: &dyn OptimizerConfig, - ) -> Result { - let mut expr_set = ExprSet::new(); - let predicate = &filter.predicate; - let input_schema = Arc::clone(filter.input.schema()); - let mut id_array = vec![]; - expr_to_identifier( - predicate, - &mut expr_set, - &mut id_array, - input_schema, - ExprMask::Normal, - )?; - - let (mut new_expr, new_input) = self.rewrite_expr( - &[&[predicate.clone()]], - &[&[id_array]], - &filter.input, - &expr_set, - config, - )?; - - if let Some(predicate) = pop_expr(&mut new_expr)?.pop() { - Ok(LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_input), - )?)) - } else { - internal_err!("Failed to pop predicate expr") - } - } - fn try_optimize_window( &self, window: &Window, @@ -354,25 +297,24 @@ impl CommonSubexprEliminate { } } - fn try_optimize_sort( + fn try_unary_plan( &self, - sort: &Sort, + plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result { - let Sort { expr, input, fetch } = sort; + let expr = plan.expressions(); + let inputs = plan.inputs(); + let input = inputs[0]; + let input_schema = Arc::clone(input.schema()); let mut expr_set = ExprSet::new(); - let input_schema = Arc::clone(input.schema()); - let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?; + // Visit expr list and build expr identifier to occuring count map (`expr_set`). + let arrays = to_arrays(&expr, input_schema, &mut expr_set, ExprMask::Normal)?; let (mut new_expr, new_input) = - self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?; + self.rewrite_expr(&[&expr], &[&arrays], input, &expr_set, config)?; - Ok(LogicalPlan::Sort(Sort { - expr: pop_expr(&mut new_expr)?, - input: Arc::new(new_input), - fetch: *fetch, - })) + plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input]) } } @@ -383,19 +325,15 @@ impl OptimizerRule for CommonSubexprEliminate { config: &dyn OptimizerConfig, ) -> Result> { let optimized_plan = match plan { - LogicalPlan::Projection(projection) => { - Some(self.try_optimize_projection(projection, config)?) - } - LogicalPlan::Filter(filter) => { - Some(self.try_optimize_filter(filter, config)?) - } + LogicalPlan::Projection(_) + | LogicalPlan::Sort(_) + | LogicalPlan::Filter(_) => Some(self.try_unary_plan(plan, config)?), LogicalPlan::Window(window) => { Some(self.try_optimize_window(window, config)?) } LogicalPlan::Aggregate(aggregate) => { Some(self.try_optimize_aggregate(aggregate, config)?) } - LogicalPlan::Sort(sort) => Some(self.try_optimize_sort(sort, config)?), LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_) | LogicalPlan::Repartition(_) @@ -1321,7 +1259,8 @@ mod test { .build()?; let expected = "Projection: test.a, test.b, test.c\ - \n Filter: Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a - Int32(10) > Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\ + \n Filter: Int32(1) + test.atest.aInt32(1) - Int32(10) > Int32(1) + test.atest.aInt32(1)\ + \n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\ \n TableScan: test"; assert_optimized_plan_eq(expected, &plan);