From 74d626982f0ff2803514e8dd5275f30b9e7b9835 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 5 Mar 2025 19:37:14 -0500 Subject: [PATCH] Update tests for simplify logic --- .../src/simplify_expressions/unwrap_cast.rs | 58 +++++++++++++------ 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 764c0745d241..7670bdf98bb4 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -492,9 +492,9 @@ mod tests { let expected = col("c1").lt(null_i32()); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); - // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) + // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => BOOL(NULL) let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); - let expected = null_i8().lt(lit(12i8)); + let expected = null_bool(); assert_eq!(optimize_test(lit_lt_lit, &schema), expected); } @@ -528,7 +528,7 @@ mod tests { // Verify reversed argument order // arrow_cast('value', 'Dictionary') = cast(str1 as Dictionary) => Utf8('value1') = str1 let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type())); - let expected = lit("value").eq(col("str1")); + let expected = col("str1").eq(lit("value")); assert_eq!(optimize_test(expr_input, &schema), expected); } @@ -645,15 +645,27 @@ mod tests { #[test] fn test_unwrap_list_cast_comparison() { let schema = expr_test_schema(); - // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = - cast(col("c1"), DataType::Int64).in_list(vec![lit(12i64), lit(24i64)], false); - let expected = col("c1").in_list(vec![lit(12i32), lit(24i32)], false); + // INT32(C1) IN (INT32(12),INT64(23),INT64(34),INT64(56),INT64(78)) -> + // INT32(C1) IN (INT32(12),INT32(23),INT32(34),INT32(56),INT32(78)) + let expr_lt = 
cast(col("c1"), DataType::Int64).in_list( + vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)], + false, + ); + let expected = col("c1").in_list( + vec![lit(12i32), lit(23i32), lit(34i32), lit(56i32), lit(78i32)], + false, + ); assert_eq!(optimize_test(expr_lt, &schema), expected); - // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = - cast(col("c2"), DataType::Int32).in_list(vec![null_i32(), lit(14i32)], false); - let expected = col("c2").in_list(vec![null_i64(), lit(14i64)], false); + // cast(C2, INT32) IN (INT32(NULL),INT32(24),INT64(34),INT64(56),INT64(78)) -> + // C2 IN (INT64(NULL),INT64(24),INT64(34),INT64(56),INT64(78)) + let expr_lt = cast(col("c2"), DataType::Int32).in_list( + vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)], + false, + ); + let expected = col("c2").in_list( + vec![null_i64(), lit(24i64), lit(34i64), lit(56i64), lit(78i64)], + false, + ); assert_eq!(optimize_test(expr_lt, &schema), expected); @@ -679,10 +691,14 @@ mod tests { ); assert_eq!(optimize_test(expr_lt, &schema), expected); - // cast(INT32(12), INT64) IN (.....) - let expr_lt = cast(lit(12i32), DataType::Int64) - .in_list(vec![lit(13i64), lit(12i64)], false); - let expected = lit(12i32).in_list(vec![lit(13i32), lit(12i32)], false); + // cast(INT32(12), INT64) IN (.....) 
=> + // INT64(12) IN (INT64(12),INT64(13),INT64(14),INT64(15),INT64(16)) + // => true + let expr_lt = cast(lit(12i32), DataType::Int64).in_list( + vec![lit(12i64), lit(13i64), lit(14i64), lit(15i64), lit(16i64)], + false, + ); + let expected = lit(true); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -720,8 +736,12 @@ mod tests { assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); // inlist for unsupported data type - let expr_input = - in_list(cast(col("c6"), DataType::Float64), vec![lit(0f64)], false); + let expr_input = in_list( + cast(col("c6"), DataType::Float64), + // need more literals to avoid rewriting to binary expr + vec![lit(0f64), lit(1f64), lit(2f64), lit(3f64), lit(4f64)], + false, + ); assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); } @@ -769,6 +789,10 @@ mod tests { ) } + fn null_bool() -> Expr { + lit(ScalarValue::Boolean(None)) + } + fn null_i8() -> Expr { lit(ScalarValue::Int8(None)) }