Skip to content

Commit

Permalink
Add tests for SessionContext::create_physical_expr
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Mar 5, 2025
1 parent ff67314 commit b59da14
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions datafusion/core/tests/expr_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ use datafusion_functions_nested::expr_ext::{IndexAccessor, SliceAccessor};
use sqlparser::ast::NullTreatment;
/// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan
use std::sync::{Arc, LazyLock};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_optimizer::simplify_expressions::ExprSimplifier;

mod parse_sql_expr;
mod simplification;
Expand Down Expand Up @@ -304,6 +307,38 @@ async fn test_aggregate_ext_null_treatment() {
.await;
}

#[tokio::test]
async fn test_create_physical_expr() {
// create_physical_expr does not simplify the expression
// 1 + 1
create_expr_test(lit(1i32) + lit(2i32), "1 + 2");
// However, you can run the simplifier before creating the physical
// expression. This mimics what delta.rs and other non-sql libraries do to
// create predicates
//
// 1 + 1
create_simplified_expr_test(lit(1i32) + lit(2i32), "3");
}

#[tokio::test]
async fn test_create_physical_expr_coercion() {
// create_physical_expr does apply type coercion and unwrapping in cast
//
// expect the cast on the literals
// compare string function to int `id = 1`
create_expr_test(col("id").eq(lit(1i32)), "id@0 = CAST(1 AS Utf8)");
create_expr_test(lit(1i32).eq(col("id")), "CAST(1 AS Utf8) = id@0");
// compare int col to string literal `i = '202410'`
// Note this casts the column (not the field)
create_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410");
create_expr_test(lit("202410").eq(col("i")), "202410 = CAST(i@1 AS Utf8)");
// however, when simplified the casts on i should removed
create_simplified_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410");
create_simplified_expr_test(lit("202410").eq(col("i")), "CAST(i@1 AS Utf8) = 202410");
}



/// Evaluates the specified expr as an aggregate and compares the result to the
/// expected result.
async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) {
Expand Down Expand Up @@ -350,6 +385,38 @@ fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) {
);
}

/// Creates the physical expression from Expr and compares the Debug expression
/// to the expected result.
fn create_expr_test(expr: Expr, expected_expr: &str) {
let batch = &TEST_BATCH;
let df_schema = DFSchema::try_from(batch.schema()).unwrap();
let physical_expr = SessionContext::new()
.create_physical_expr(expr, &df_schema)
.unwrap();

assert_eq!(physical_expr.to_string(), expected_expr);
}

/// Creates the physical expression from Expr and runs the expr simplifier
fn create_simplified_expr_test(expr: Expr, expected_expr: &str) {
let batch = &TEST_BATCH;
let df_schema = DFSchema::try_from(batch.schema()).unwrap();

// Simplify the expression first
let props = ExecutionProps::new();
let simplify_context =
SimplifyContext::new(&props).with_schema(df_schema.clone().into());
let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10);
let simplified = simplifier.simplify(expr).unwrap();
create_expr_test(simplified, expected_expr);
}

/// Returns a Batch with 3 rows and 4 columns:
///
/// id: Utf8
/// i: Int64
/// props: Struct
/// list: List<String>
static TEST_BATCH: LazyLock<RecordBatch> = LazyLock::new(|| {
let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3"]));
let int_array: ArrayRef =
Expand Down

0 comments on commit b59da14

Please sign in to comment.