From 7a510220ced1fb72a0914efd9e11dd669348568c Mon Sep 17 00:00:00 2001
From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
Date: Sat, 1 Mar 2025 12:44:16 +0100
Subject: [PATCH] refactor: simplify expressions first

Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
---
 crates/core/src/delta_datafusion/mod.rs | 18 ++++++++--
 crates/core/src/operations/merge/mod.rs | 47 ++++++++++++++++---------
 python/tests/test_merge.py              |  8 ++---
 3 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs
index e2fdcbdada..ab475b55ea 100644
--- a/crates/core/src/delta_datafusion/mod.rs
+++ b/crates/core/src/delta_datafusion/mod.rs
@@ -47,6 +47,7 @@ use datafusion::datasource::{listing::PartitionedFile, MemTable, TableProvider,
 use datafusion::execution::context::{SessionConfig, SessionContext, SessionState, TaskContext};
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::execution::FunctionRegistry;
+use datafusion::optimizer::simplify_expressions::ExprSimplifier;
 use datafusion::physical_optimizer::pruning::PruningPredicate;
 use datafusion_common::scalar::ScalarValue;
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
@@ -56,6 +57,7 @@ use datafusion_common::{
 };
 use datafusion_expr::execution_props::ExecutionProps;
 use datafusion_expr::logical_plan::CreateExternalTable;
+use datafusion_expr::simplify::SimplifyContext;
 use datafusion_expr::utils::conjunction;
 use datafusion_expr::{col, Expr, Extension, LogicalPlan, TableProviderFilterPushDown, Volatility};
 use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
@@ -545,9 +547,19 @@ impl<'a> DeltaScanBuilder<'a> {
 
         let context = SessionContext::new();
         let df_schema = logical_schema.clone().to_dfschema()?;
-        let logical_filter = self
-            .filter
-            .map(|expr| context.create_physical_expr(expr, &df_schema).unwrap());
+
+        let logical_filter = self.filter.map(|expr| {
+            // Simplify the expression first
+            let props = ExecutionProps::new();
+            let simplify_context =
+                SimplifyContext::new(&props).with_schema(df_schema.clone().into());
+            let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10);
+            let simplified = simplifier.simplify(expr).unwrap();
+
+            context
+                .create_physical_expr(simplified, &df_schema)
+                .unwrap()
+        });
 
         // Perform Pruning of files to scan
         let (files, files_scanned, files_pruned) = match self.files {
diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs
index 26829c569a..bf073b69fa 100644
--- a/crates/core/src/operations/merge/mod.rs
+++ b/crates/core/src/operations/merge/mod.rs
@@ -41,6 +41,7 @@ use datafusion::error::Result as DataFusionResult;
 use datafusion::execution::context::SessionConfig;
 use datafusion::execution::session_state::SessionStateBuilder;
 use datafusion::logical_expr::build_join_schema;
+use datafusion::optimizer::simplify_expressions::ExprSimplifier;
 use datafusion::physical_plan::metrics::MetricBuilder;
 use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
 use datafusion::{
@@ -50,6 +51,8 @@ use datafusion::{
 };
 use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{Column, DFSchema, ExprSchema, ScalarValue, TableReference};
+use datafusion_expr::execution_props::ExecutionProps;
+use datafusion_expr::simplify::SimplifyContext;
 use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType};
 use datafusion_expr::{
     Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE,
@@ -60,6 +63,7 @@ use filter::try_construct_early_filter;
 use futures::future::BoxFuture;
 use parquet::file::properties::WriterProperties;
 use serde::Serialize;
+use tracing::field::debug;
 use tracing::log::*;
 use uuid::Uuid;
 
@@ -844,11 +848,32 @@ async fn execute(
             streaming,
         )
         .await?
+    }
+    .map(|e| {
+        // simplify the expression so we have
+        let props = ExecutionProps::new();
+        let simplify_context = SimplifyContext::new(&props).with_schema(target.schema().clone());
+        let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10);
+        simplifier.simplify(e).unwrap()
+    });
+
+    // Predicate will be used for conflict detection
+    let commit_predicate = match target_subset_filter.clone() {
+        None => None, // No predicate means it's a full table merge
+        Some(some_filter) => {
+            let predict_expr = match &target_alias {
+                None => some_filter,
+                Some(alias) => remove_table_alias(some_filter, alias),
+            };
+            Some(fmt_expr_to_sql(&predict_expr)?)
+        }
     };
 
+    debug!("Using target subset filter: {:?}", commit_predicate);
+
     let file_column = Arc::new(scan_config.file_column_name.clone().unwrap());
     // Need to manually push this filter into the scan... We want to PRUNE files not FILTER RECORDS
-    let target = match target_subset_filter.clone() {
+    let target = match target_subset_filter {
         Some(filter) => {
             let filter = match &target_alias {
                 Some(alias) => remove_table_alias(filter, alias),
@@ -1407,18 +1432,6 @@ async fn execute(
         app_metadata.insert("operationMetrics".to_owned(), map);
     }
 
-    // Predicate will be used for conflict detection
-    let commit_predicate = match target_subset_filter {
-        None => None, // No predicate means it's a full table merge
-        Some(some_filter) => {
-            let predict_expr = match &target_alias {
-                None => some_filter,
-                Some(alias) => remove_table_alias(some_filter, alias),
-            };
-            Some(fmt_expr_to_sql(&predict_expr)?)
-        }
-    };
-
     // Do not make a commit when there are zero updates to the state
     let operation = DeltaOperation::Merge {
         predicate: commit_predicate,
@@ -2522,7 +2535,7 @@ mod tests {
         let parameters = last_commit.operation_parameters.clone().unwrap();
         assert_eq!(
             parameters["predicate"],
-            "id BETWEEN 'B' AND 'C' AND modified = '2021-02-02'"
+            "id >= 'B' AND id <= 'C' AND modified = '2021-02-02'"
         );
         assert_eq!(
             parameters["mergePredicate"],
@@ -2773,7 +2786,7 @@ mod tests {
             extra_info["operationMetrics"],
             serde_json::to_value(&metrics).unwrap()
         );
-        assert_eq!(parameters["predicate"], "id BETWEEN 'B' AND 'X'");
+        assert_eq!(parameters["predicate"], "id >= 'B' AND id <= 'X'");
         assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
         assert_eq!(
             parameters["matchedPredicates"],
@@ -3192,7 +3205,7 @@ mod tests {
 
         assert_eq!(
             parameters["predicate"],
-            json!("id BETWEEN 'B' AND 'X' AND modified = '2021-02-02'")
+            json!("id >= 'B' AND id <= 'X' AND modified = '2021-02-02'")
         );
 
         let expected = vec![
@@ -3276,7 +3289,7 @@ mod tests {
 
         assert_eq!(
             parameters["predicate"],
-            json!("id BETWEEN 'B' AND 'X' AND modified = '2021-02-02'")
+            json!("id >= 'B' AND id <= 'X' AND modified = '2021-02-02'")
         );
 
         let expected = vec![
diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py
index 5fa9b73406..b66019187e 100644
--- a/python/tests/test_merge.py
+++ b/python/tests/test_merge.py
@@ -1066,7 +1066,7 @@ def test_merge_date_partitioned_2344(tmp_path: pathlib.Path, streaming: bool):
     if not streaming:
         assert (
             last_action["operationParameters"].get("predicate")
-            == "'2022-02-01'::date = date"
+            == "date = '2022-02-01'::date"
         )
     else:
         # In streaming mode we don't use aggregated stats of the source in the predicate
@@ -1078,11 +1078,11 @@ def test_merge_date_partitioned_2344(tmp_path: pathlib.Path, streaming: bool):
     [
         (
             None,
-            "arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, None)') = datetime",
+            "datetime = arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, None)')",
         ),
         (
             "UTC",
-            "arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, Some(\"UTC\"))') = datetime",
+            "datetime = arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, Some(\"UTC\"))')",
        ),
     ],
 )
@@ -1457,7 +1457,7 @@ def test_merge_on_decimal_3033(tmp_path):
 
     assert (
         string_predicate
-        == "timestamp BETWEEN arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND altitude BETWEEN '1505'::decimal(4, 1) AND '1505'::decimal(4, 1)"
+        == "timestamp >= arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND timestamp <= arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND altitude >= '1505'::decimal(4, 1) AND altitude <= '1505'::decimal(4, 1)"