Skip to content

Commit

Permalink
Update tests for simplify logic
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Mar 6, 2025
1 parent f0bbf06 commit 74d6269
Showing 1 changed file with 41 additions and 17 deletions.
58 changes: 41 additions & 17 deletions datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,9 @@ mod tests {
let expected = col("c1").lt(null_i32());
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);

// cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12)
// cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => BOOL(NULL)
let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
let expected = null_i8().lt(lit(12i8));
let expected = null_bool();
assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
}

Expand Down Expand Up @@ -528,7 +528,7 @@ mod tests {
// Verify reversed argument order
// arrow_cast('value', 'Dictionary<Int32, Utf8>') = cast(str1 as Dictionary<Int32, Utf8>) => Utf8('value1') = str1
let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type()));
let expected = lit("value").eq(col("str1"));
let expected = col("str1").eq(lit("value"));
assert_eq!(optimize_test(expr_input, &schema), expected);
}

Expand Down Expand Up @@ -645,15 +645,27 @@ mod tests {
#[test]
fn test_unwrap_list_cast_comparison() {
let schema = expr_test_schema();
// INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24))
let expr_lt =
cast(col("c1"), DataType::Int64).in_list(vec![lit(12i64), lit(24i64)], false);
let expected = col("c1").in_list(vec![lit(12i32), lit(24i32)], false);
// INT32(C1) IN (INT32(12),INT64(23),INT64(34),INT64(56),INT64(78)) ->
// INT32(C1) IN (INT32(12),INT32(23),INT32(34),INT32(56),INT32(78))
let expr_lt = cast(col("c1"), DataType::Int64).in_list(
vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)],
false,
);
let expected = col("c1").in_list(
vec![lit(12i32), lit(23i32), lit(34i32), lit(56i32), lit(78i32)],
false,
);
assert_eq!(optimize_test(expr_lt, &schema), expected);
// INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24))
let expr_lt =
cast(col("c2"), DataType::Int32).in_list(vec![null_i32(), lit(14i32)], false);
let expected = col("c2").in_list(vec![null_i64(), lit(14i64)], false);
// INT32(C2) IN (INT64(NULL),INT64(24),INT64(34),INT64(56),INT64(78)) ->
// INT32(C2) IN (INT32(NULL),INT32(24),INT32(34),INT32(56),INT32(78))
let expr_lt = cast(col("c2"), DataType::Int32).in_list(
vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)],
false,
);
let expected = col("c2").in_list(
vec![null_i64(), lit(24i64), lit(34i64), lit(56i64), lit(78i64)],
false,
);

assert_eq!(optimize_test(expr_lt, &schema), expected);

Expand All @@ -679,10 +691,14 @@ mod tests {
);
assert_eq!(optimize_test(expr_lt, &schema), expected);

// cast(INT32(12), INT64) IN (.....)
let expr_lt = cast(lit(12i32), DataType::Int64)
.in_list(vec![lit(13i64), lit(12i64)], false);
let expected = lit(12i32).in_list(vec![lit(13i32), lit(12i32)], false);
// cast(INT32(12), INT64) IN (.....) =>
// INT64(12) IN (INT64(12),INT64(13),INT64(14),INT64(15),INT64(16))
// => true
let expr_lt = cast(lit(12i32), DataType::Int64).in_list(
vec![lit(12i64), lit(13i64), lit(14i64), lit(15i64), lit(16i64)],
false,
);
let expected = lit(true);
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

Expand Down Expand Up @@ -720,8 +736,12 @@ mod tests {
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);

// inlist for unsupported data type
let expr_input =
in_list(cast(col("c6"), DataType::Float64), vec![lit(0f64)], false);
let expr_input = in_list(
cast(col("c6"), DataType::Float64),
// need more literals to avoid rewriting to binary expr
vec![lit(0f64), lit(1f64), lit(2f64), lit(3f64), lit(4f64)],
false,
);
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
}

Expand Down Expand Up @@ -769,6 +789,10 @@ mod tests {
)
}

fn null_bool() -> Expr {
lit(ScalarValue::Boolean(None))
}

fn null_i8() -> Expr {
lit(ScalarValue::Int8(None))
}
Expand Down

0 comments on commit 74d6269

Please sign in to comment.