refactor: simplify expressions first
Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
ion-elgreco committed Mar 1, 2025
Parent 49ed2c8 · commit 7a51022
Showing 3 changed files with 49 additions and 24 deletions.
crates/core/src/delta_datafusion/mod.rs (15 additions & 3 deletions)
@@ -47,6 +47,7 @@ use datafusion::datasource::{listing::PartitionedFile, MemTable, TableProvider,
 use datafusion::execution::context::{SessionConfig, SessionContext, SessionState, TaskContext};
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::execution::FunctionRegistry;
+use datafusion::optimizer::simplify_expressions::ExprSimplifier;
 use datafusion::physical_optimizer::pruning::PruningPredicate;
 use datafusion_common::scalar::ScalarValue;
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
@@ -56,6 +57,7 @@ use datafusion_common::{
 };
 use datafusion_expr::execution_props::ExecutionProps;
 use datafusion_expr::logical_plan::CreateExternalTable;
+use datafusion_expr::simplify::SimplifyContext;
 use datafusion_expr::utils::conjunction;
 use datafusion_expr::{col, Expr, Extension, LogicalPlan, TableProviderFilterPushDown, Volatility};
 use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
@@ -545,9 +547,19 @@ impl<'a> DeltaScanBuilder<'a> {
 
         let context = SessionContext::new();
         let df_schema = logical_schema.clone().to_dfschema()?;
-        let logical_filter = self
-            .filter
-            .map(|expr| context.create_physical_expr(expr, &df_schema).unwrap());
+
+        let logical_filter = self.filter.map(|expr| {
+            // Simplify the expression first
+            let props = ExecutionProps::new();
+            let simplify_context =
+                SimplifyContext::new(&props).with_schema(df_schema.clone().into());
+            let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10);
+            let simplified = simplifier.simplify(expr).unwrap();
+
+            context
+                .create_physical_expr(simplified, &df_schema)
+                .unwrap()
+        });
 
         // Perform Pruning of files to scan
         let (files, files_scanned, files_pruned) = match self.files {
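For orientation, here is a minimal standalone sketch of the pattern this hunk introduces: simplify a logical filter with DataFusion's ExprSimplifier before lowering it with SessionContext::create_physical_expr. The helper name, schema, and column are illustrative only and not part of the commit; the DataFusion calls are the same ones used above.

use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion_common::ToDFSchema;
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{col, lit, Expr};
use datafusion_physical_expr::PhysicalExpr;

// Hypothetical helper: simplify `filter` against a schema, then lower it.
fn simplify_then_lower(ctx: &SessionContext, filter: Expr) -> Result<Arc<dyn PhysicalExpr>> {
    // Stand-in for `logical_schema` in DeltaScanBuilder above.
    let df_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, true)])).to_dfschema()?;

    // Simplify first: constant folding plus rewrites such as BETWEEN -> >= AND <=.
    let props = ExecutionProps::new();
    let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone().into());
    let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10);
    let simplified = simplifier.simplify(filter)?;

    // Only then create the physical expression, as the new code path does.
    ctx.create_physical_expr(simplified, &df_schema)
}

fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let physical = simplify_then_lower(&ctx, col("id").between(lit("B"), lit("C")))?;
    println!("{physical}");
    Ok(())
}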
crates/core/src/operations/merge/mod.rs (30 additions & 17 deletions)
@@ -41,6 +41,7 @@ use datafusion::error::Result as DataFusionResult;
 use datafusion::execution::context::SessionConfig;
 use datafusion::execution::session_state::SessionStateBuilder;
 use datafusion::logical_expr::build_join_schema;
+use datafusion::optimizer::simplify_expressions::ExprSimplifier;
 use datafusion::physical_plan::metrics::MetricBuilder;
 use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
 use datafusion::{
@@ -50,6 +51,8 @@ use datafusion::{
 };
 use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{Column, DFSchema, ExprSchema, ScalarValue, TableReference};
+use datafusion_expr::execution_props::ExecutionProps;
+use datafusion_expr::simplify::SimplifyContext;
 use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType};
 use datafusion_expr::{
     Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE,
@@ -60,6 +63,7 @@ use filter::try_construct_early_filter;
 use futures::future::BoxFuture;
 use parquet::file::properties::WriterProperties;
 use serde::Serialize;
+use tracing::field::debug;
 use tracing::log::*;
 use uuid::Uuid;
@@ -844,11 +848,32 @@ async fn execute(
             streaming,
         )
         .await?
     }
+    .map(|e| {
+        // simplify the expression so we have
+        let props = ExecutionProps::new();
+        let simplify_context = SimplifyContext::new(&props).with_schema(target.schema().clone());
+        let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10);
+        simplifier.simplify(e).unwrap()
+    });
+
+    // Predicate will be used for conflict detection
+    let commit_predicate = match target_subset_filter.clone() {
+        None => None, // No predicate means it's a full table merge
+        Some(some_filter) => {
+            let predict_expr = match &target_alias {
+                None => some_filter,
+                Some(alias) => remove_table_alias(some_filter, alias),
+            };
+            Some(fmt_expr_to_sql(&predict_expr)?)
+        }
+    };
+
+    debug!("Using target subset filter: {:?}", commit_predicate);
 
     let file_column = Arc::new(scan_config.file_column_name.clone().unwrap());
     // Need to manually push this filter into the scan... We want to PRUNE files not FILTER RECORDS
-    let target = match target_subset_filter.clone() {
+    let target = match target_subset_filter {
         Some(filter) => {
             let filter = match &target_alias {
                 Some(alias) => remove_table_alias(filter, alias),
@@ -1407,18 +1432,6 @@ async fn execute(
         app_metadata.insert("operationMetrics".to_owned(), map);
     }
 
-    // Predicate will be used for conflict detection
-    let commit_predicate = match target_subset_filter {
-        None => None, // No predicate means it's a full table merge
-        Some(some_filter) => {
-            let predict_expr = match &target_alias {
-                None => some_filter,
-                Some(alias) => remove_table_alias(some_filter, alias),
-            };
-            Some(fmt_expr_to_sql(&predict_expr)?)
-        }
-    };
-
     // Do not make a commit when there are zero updates to the state
     let operation = DeltaOperation::Merge {
         predicate: commit_predicate,
@@ -2522,7 +2535,7 @@ mod tests {
         let parameters = last_commit.operation_parameters.clone().unwrap();
         assert_eq!(
             parameters["predicate"],
-            "id BETWEEN 'B' AND 'C' AND modified = '2021-02-02'"
+            "id >= 'B' AND id <= 'C' AND modified = '2021-02-02'"
         );
         assert_eq!(
             parameters["mergePredicate"],
@@ -2773,7 +2786,7 @@ mod tests {
             extra_info["operationMetrics"],
             serde_json::to_value(&metrics).unwrap()
         );
-        assert_eq!(parameters["predicate"], "id BETWEEN 'B' AND 'X'");
+        assert_eq!(parameters["predicate"], "id >= 'B' AND id <= 'X'");
         assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
         assert_eq!(
             parameters["matchedPredicates"],
@@ -3192,7 +3205,7 @@ mod tests {
 
         assert_eq!(
             parameters["predicate"],
-            json!("id BETWEEN 'B' AND 'X' AND modified = '2021-02-02'")
+            json!("id >= 'B' AND id <= 'X' AND modified = '2021-02-02'")
         );
 
         let expected = vec![
@@ -3276,7 +3289,7 @@ mod tests {
 
         assert_eq!(
             parameters["predicate"],
-            json!("id BETWEEN 'B' AND 'X' AND modified = '2021-02-02'")
+            json!("id >= 'B' AND id <= 'X' AND modified = '2021-02-02'")
         );
 
         let expected = vec![
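To see why the expected predicate strings in the tests above change, here is a minimal, hypothetical sketch of running DataFusion's simplifier over an early filter of the shape the merge path builds. The schema and column names are illustrative, and the exact printed form is approximate; per the updated assertions, BETWEEN collapses into a >= / <= conjunction and literals move to the right-hand side of comparisons.

use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion_common::ToDFSchema;
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{col, lit};

fn main() -> Result<()> {
    // Hypothetical target schema standing in for `target.schema()` above.
    let df_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, true),
        Field::new("modified", DataType::Utf8, true),
    ]))
    .to_dfschema()?;

    // Literal-on-the-left equality plus a BETWEEN, as an early filter built from
    // source min/max stats might look before this change.
    let filter = lit("2021-02-02")
        .eq(col("modified"))
        .and(col("id").between(lit("B"), lit("X")));

    let props = ExecutionProps::new();
    let ctx = SimplifyContext::new(&props).with_schema(df_schema.into());
    let simplified = ExprSimplifier::new(ctx).with_max_cycles(10).simplify(filter)?;

    // Expected shape (per the updated test assertions): roughly
    // modified = '2021-02-02' AND id >= 'B' AND id <= 'X'.
    println!("{simplified}");
    Ok(())
}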
python/tests/test_merge.py (4 additions & 4 deletions)
@@ -1066,7 +1066,7 @@ def test_merge_date_partitioned_2344(tmp_path: pathlib.Path, streaming: bool):
     if not streaming:
         assert (
             last_action["operationParameters"].get("predicate")
-            == "'2022-02-01'::date = date"
+            == "date = '2022-02-01'::date"
         )
     else:
         # In streaming mode we don't use aggregated stats of the source in the predicate
@@ -1078,11 +1078,11 @@ def test_merge_date_partitioned_2344(tmp_path: pathlib.Path, streaming: bool):
     [
         (
             None,
-            "arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, None)') = datetime",
+            "datetime = arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, None)')",
         ),
         (
             "UTC",
-            "arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, Some(\"UTC\"))') = datetime",
+            "datetime = arrow_cast('2022-02-01T00:00:00.000000', 'Timestamp(Microsecond, Some(\"UTC\"))')",
         ),
     ],
 )
@@ -1457,7 +1457,7 @@ def test_merge_on_decimal_3033(tmp_path):
 
     assert (
         string_predicate
-        == "timestamp BETWEEN arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND altitude BETWEEN '1505'::decimal(4, 1) AND '1505'::decimal(4, 1)"
+        == "timestamp >= arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND timestamp <= arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND altitude >= '1505'::decimal(4, 1) AND altitude <= '1505'::decimal(4, 1)"
     )
 
