diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 44f5a2340c09..7670bdf98bb4 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! UnwrapCastInComparison attempts to remove casts from +//! Unwrap casts in binary comparisons +//! +//! The functions in this module attempt to remove casts from //! comparisons to literals ([`ScalarValue`]s) by applying the casts //! to the literals if possible. It is inspired by the optimizer rule //! `UnwrapCastInBinaryComparison` of Spark. @@ -43,13 +45,13 @@ //! If the DataType of c1 is INT32. Given the filter //! //! ```text -//! Filter: cast(c1 as INT64) > INT64(10)` +//! cast(c1 as INT64) > INT64(10)` //! ``` //! //! This rule will remove the cast and rewrite the expression to: //! //! ```text -//! Filter: c1 > INT32(10) +//! c1 > INT32(10) //! ``` //! @@ -437,3 +439,911 @@ fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option Ordering::Equal => Some(value), } } + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + use std::sync::Arc; + + use crate::simplify_expressions::ExprSimplifier; + use arrow::compute::{cast_with_options, CastOptions}; + use arrow::datatypes::Field; + use datafusion_common::{DFSchema, DFSchemaRef}; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::simplify::SimplifyContext; + use datafusion_expr::{cast, col, in_list, try_cast}; + + #[test] + fn test_not_unwrap_cast_comparison() { + let schema = expr_test_schema(); + // cast(INT32(c1), INT64) > INT64(c2) + let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2")); + assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2); + + // INT32(c1) < INT32(16), the type is same + let expr_lt = col("c1").lt(lit(16i32)); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type + let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64)); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + } + + #[test] + fn test_unwrap_cast_comparison() { + let schema = expr_test_schema(); + // cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16)) + // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) + let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)); + let expected = col("c1").lt(lit(16i32)); + assert_eq!(optimize_test(expr_lt, &schema), expected); + let expr_lt = try_cast(col("c1"), DataType::Int64).lt(lit(16i64)); + let expected = col("c1").lt(lit(16i32)); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16) + let c2_eq_lit = cast(col("c2"), DataType::Int32).eq(lit(16i32)); + let expected = col("c2").eq(lit(16i64)); + assert_eq!(optimize_test(c2_eq_lit, &schema), expected); + + // cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL) + let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64()); + let expected = col("c1").lt(null_i32()); + assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); + + // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => BOOL(NULL) + let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); + let expected = null_bool(); + assert_eq!(optimize_test(lit_lt_lit, &schema), expected); + } + + #[test] + fn test_unwrap_cast_comparison_unsigned() { + // "cast(c6, UINT64) = 0u64 => c6 = 0u32 + let schema = expr_test_schema(); + let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64)); + let expected = col("c6").eq(lit(0u32)); + assert_eq!(optimize_test(expr_input, &schema), expected); + } + + #[test] + fn test_unwrap_cast_comparison_string() { + let schema = expr_test_schema(); + let dict = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::from("value")), + ); + + // cast(str1 as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = Utf8('value1') + let expr_input = cast(col("str1"), dict.data_type()).eq(lit(dict.clone())); + let expected = col("str1").eq(lit("value")); + assert_eq!(optimize_test(expr_input, &schema), expected); + + // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value', 'Dictionary') + let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value")); + let expected = col("tag").eq(lit(dict.clone())); + assert_eq!(optimize_test(expr_input, &schema), expected); + + // Verify reversed argument order + // arrow_cast('value', 'Dictionary') = cast(str1 as Dictionary) => Utf8('value1') = str1 + let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type())); + let expected = col("str1").eq(lit("value")); + assert_eq!(optimize_test(expr_input, &schema), expected); + } + + #[test] + fn test_unwrap_cast_comparison_large_string() { + let schema = expr_test_schema(); + // cast(largestr as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = LargeUtf8('value1') + let dict = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))), + ); + let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict)); + let expected = + col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned())))); + assert_eq!(optimize_test(expr_input, &schema), expected); + } + + #[test] + fn test_not_unwrap_cast_with_decimal_comparison() { + let schema = expr_test_schema(); + // integer to decimal: value is out of the bounds of the decimal + // cast(c3, INT64) = INT64(100000000000000000) + let expr_eq = cast(col("c3"), DataType::Int64).eq(lit(100000000000000000i64)); + assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); + + // cast(c4, INT64) = INT64(1000) will overflow the i128 + let expr_eq = cast(col("c4"), DataType::Int64).eq(lit(1000i64)); + assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); + + // decimal to decimal: value will lose the scale when convert to the target data type + // c3 = DECIMAL(12340,20,4) + let expr_eq = + cast(col("c3"), DataType::Decimal128(20, 4)).eq(lit_decimal(12340, 20, 4)); + assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); + + // decimal to integer + // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type + let expr_eq = + cast(col("c1"), DataType::Decimal128(10, 1)).eq(lit_decimal(123, 10, 1)); + assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); + + // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type + let expr_eq = + cast(col("c1"), DataType::Decimal128(10, 2)).eq(lit_decimal(1230, 10, 2)); + assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); + } + + #[test] + fn test_unwrap_cast_with_decimal_lit_comparison() { + let schema = expr_test_schema(); + // integer to decimal + // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); + let expr_lt = try_cast(col("c3"), DataType::Int64).lt(lit(16i64)); + let expected = col("c3").lt(lit_decimal(1600, 18, 2)); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // c3 < INT64(NULL) + let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64()); + let expected = col("c3").lt(null_decimal(18, 2)); + assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); + + // decimal to decimal + // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) + let expr_lt = + cast(col("c3"), DataType::Decimal128(10, 0)).lt(lit_decimal(123, 10, 0)); + let expected = col("c3").lt(lit_decimal(12300, 18, 2)); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) + let expr_lt = + cast(col("c3"), DataType::Decimal128(10, 3)).lt(lit_decimal(1230, 10, 3)); + let expected = col("c3").lt(lit_decimal(123, 18, 2)); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // decimal to integer + // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) + let expr_lt = + cast(col("c1"), DataType::Decimal128(10, 2)).lt(lit_decimal(12300, 10, 2)); + let expected = col("c1").lt(lit(123i32)); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + + #[test] + fn test_not_unwrap_list_cast_lit_comparison() { + let schema = expr_test_schema(); + // internal left type is not supported + // FLOAT32(C5) in ... + let expr_lt = + cast(col("c5"), DataType::Int64).in_list(vec![lit(12i64), lit(12i64)], false); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12), Float32(12)) + let expr_lt = cast(col("c1"), DataType::Float32) + .in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // INT32(C1) in (INT64(99999999999), INT64(12)) + let expr_lt = cast(col("c1"), DataType::Int64) + .in_list(vec![lit(12i32), lit(99999999999i64)], false); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3)) + let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list( + vec![ + lit_decimal(12, 12, 3), + lit_decimal(12, 12, 3), + lit_decimal(128, 12, 3), + ], + false, + ); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + } + + #[test] + fn test_unwrap_list_cast_comparison() { + let schema = expr_test_schema(); + // INT32(C1) IN (INT32(12),INT64(23),INT64(34),INT64(56),INT64(78)) -> + // INT32(C1) IN (INT32(12),INT32(23),INT32(34),INT32(56),INT32(78)) + let expr_lt = cast(col("c1"), DataType::Int64).in_list( + vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)], + false, + ); + let expected = col("c1").in_list( + vec![lit(12i32), lit(23i32), lit(34i32), lit(56i32), lit(78i32)], + false, + ); + assert_eq!(optimize_test(expr_lt, &schema), expected); + // INT32(C2) IN (INT64(NULL),INT64(24),INT64(34),INT64(56),INT64(78)) -> + // INT32(C2) IN (INT32(NULL),INT32(24),INT32(34),INT32(56),INT32(78)) + let expr_lt = cast(col("c2"), DataType::Int32).in_list( + vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)], + false, + ); + let expected = col("c2").in_list( + vec![null_i64(), lit(24i64), lit(34i64), lit(56i64), lit(78i64)], + false, + ); + + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // decimal test case + // c3 is decimal(18,2) + let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list( + vec![ + lit_decimal(12000, 19, 3), + lit_decimal(24000, 19, 3), + lit_decimal(1280, 19, 3), + lit_decimal(1240, 19, 3), + ], + false, + ); + let expected = col("c3").in_list( + vec![ + lit_decimal(1200, 18, 2), + lit_decimal(2400, 18, 2), + lit_decimal(128, 18, 2), + lit_decimal(124, 18, 2), + ], + false, + ); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // cast(INT32(12), INT64) IN (.....) => + // INT64(12) IN (INT64(12),INT64(13),INT64(14),INT64(15),INT64(16)) + // => true + let expr_lt = cast(lit(12i32), DataType::Int64).in_list( + vec![lit(12i64), lit(13i64), lit(14i64), lit(15i64), lit(16i64)], + false, + ); + let expected = lit(true); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + + #[test] + fn aliased() { + let schema = expr_test_schema(); + // c1 < INT64(16) -> c1 < cast(INT32(16)) + // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) + let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).alias("x"); + let expected = col("c1").lt(lit(16i32)).alias("x"); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + + #[test] + fn nested() { + let schema = expr_test_schema(); + // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32) + // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32 + let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).or(cast( + col("c1"), + DataType::Int64, + ) + .gt(lit(32i64))); + let expected = col("c1").lt(lit(16i32)).or(col("c1").gt(lit(32i32))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + + #[test] + fn test_not_support_data_type() { + // "c6 > 0" will be cast to `cast(c6 as float) > 0 + // but the type of c6 is uint32 + // the rewriter will not throw error and just return the original expr + let schema = expr_test_schema(); + let expr_input = cast(col("c6"), DataType::Float64).eq(lit(0f64)); + assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); + + // inlist for unsupported data type + let expr_input = in_list( + cast(col("c6"), DataType::Float64), + // need more literals to avoid rewriting to binary expr + vec![lit(0f64), lit(1f64), lit(2f64), lit(3f64), lit(4f64)], + false, + ); + assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); + } + + #[test] + /// Basic integration test for unwrapping casts with different timezones + fn test_unwrap_cast_with_timestamp_nanos() { + let schema = expr_test_schema(); + // cast(ts_nano as Timestamp(Nanosecond, UTC)) < 1666612093000000000::Timestamp(Nanosecond, Utc)) + let expr_lt = try_cast(col("ts_nano_none"), timestamp_nano_utc_type()) + .lt(lit_timestamp_nano_utc(1666612093000000000)); + let expected = + col("ts_nano_none").lt(lit_timestamp_nano_none(1666612093000000000)); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { + let props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&props).with_schema(Arc::clone(schema)), + ); + + simplifier.simplify(expr).unwrap() + } + + fn expr_test_schema() -> DFSchemaRef { + Arc::new( + DFSchema::from_unqualified_fields( + vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int64, false), + Field::new("c3", DataType::Decimal128(18, 2), false), + Field::new("c4", DataType::Decimal128(38, 37), false), + Field::new("c5", DataType::Float32, false), + Field::new("c6", DataType::UInt32, false), + Field::new("ts_nano_none", timestamp_nano_none_type(), false), + Field::new("ts_nano_utf", timestamp_nano_utc_type(), false), + Field::new("str1", DataType::Utf8, false), + Field::new("largestr", DataType::LargeUtf8, false), + Field::new("tag", dictionary_tag_type(), false), + ] + .into(), + HashMap::new(), + ) + .unwrap(), + ) + } + + fn null_bool() -> Expr { + lit(ScalarValue::Boolean(None)) + } + + fn null_i8() -> Expr { + lit(ScalarValue::Int8(None)) + } + + fn null_i32() -> Expr { + lit(ScalarValue::Int32(None)) + } + + fn null_i64() -> Expr { + lit(ScalarValue::Int64(None)) + } + + fn lit_decimal(value: i128, precision: u8, scale: i8) -> Expr { + lit(ScalarValue::Decimal128(Some(value), precision, scale)) + } + + fn lit_timestamp_nano_none(ts: i64) -> Expr { + lit(ScalarValue::TimestampNanosecond(Some(ts), None)) + } + + fn lit_timestamp_nano_utc(ts: i64) -> Expr { + let utc = Some("+0:00".into()); + lit(ScalarValue::TimestampNanosecond(Some(ts), utc)) + } + + fn null_decimal(precision: u8, scale: i8) -> Expr { + lit(ScalarValue::Decimal128(None, precision, scale)) + } + + fn timestamp_nano_none_type() -> DataType { + DataType::Timestamp(TimeUnit::Nanosecond, None) + } + + // this is the type that now() returns + fn timestamp_nano_utc_type() -> DataType { + let utc = Some("+0:00".into()); + DataType::Timestamp(TimeUnit::Nanosecond, utc) + } + + // a dictionary type for storing string tags + fn dictionary_tag_type() -> DataType { + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + } + + #[test] + fn test_try_cast_to_type_nulls() { + // test that nulls can be cast to/from all integer types + let scalars = vec![ + ScalarValue::Int8(None), + ScalarValue::Int16(None), + ScalarValue::Int32(None), + ScalarValue::Int64(None), + ScalarValue::UInt8(None), + ScalarValue::UInt16(None), + ScalarValue::UInt32(None), + ScalarValue::UInt64(None), + ScalarValue::Decimal128(None, 3, 0), + ScalarValue::Decimal128(None, 8, 2), + ScalarValue::Utf8(None), + ScalarValue::LargeUtf8(None), + ]; + + for s1 in &scalars { + for s2 in &scalars { + let expected_value = ExpectedCast::Value(s2.clone()); + + expect_cast(s1.clone(), s2.data_type(), expected_value); + } + } + } + + #[test] + fn test_try_cast_to_type_int_in_range() { + // test values that can be cast to/from all integer types + let scalars = vec![ + ScalarValue::Int8(Some(123)), + ScalarValue::Int16(Some(123)), + ScalarValue::Int32(Some(123)), + ScalarValue::Int64(Some(123)), + ScalarValue::UInt8(Some(123)), + ScalarValue::UInt16(Some(123)), + ScalarValue::UInt32(Some(123)), + ScalarValue::UInt64(Some(123)), + ScalarValue::Decimal128(Some(123), 3, 0), + ScalarValue::Decimal128(Some(12300), 8, 2), + ]; + + for s1 in &scalars { + for s2 in &scalars { + let expected_value = ExpectedCast::Value(s2.clone()); + + expect_cast(s1.clone(), s2.data_type(), expected_value); + } + } + + let max_i32 = ScalarValue::Int32(Some(i32::MAX)); + expect_cast( + max_i32, + DataType::UInt64, + ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))), + ); + + let min_i32 = ScalarValue::Int32(Some(i32::MIN)); + expect_cast( + min_i32, + DataType::Int64, + ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))), + ); + + let max_i64 = ScalarValue::Int64(Some(i64::MAX)); + expect_cast( + max_i64, + DataType::UInt64, + ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))), + ); + } + + #[test] + fn test_try_cast_to_type_int_out_of_range() { + let min_i32 = ScalarValue::Int32(Some(i32::MIN)); + let min_i64 = ScalarValue::Int64(Some(i64::MIN)); + let max_i64 = ScalarValue::Int64(Some(i64::MAX)); + let max_u64 = ScalarValue::UInt64(Some(u64::MAX)); + + expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue); + + expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue); + + expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue); + + expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue); + + expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue); + + expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue); + + // decimal out of range + expect_cast( + ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0), + DataType::Int64, + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1), + DataType::Int64, + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_try_decimal_cast_in_range() { + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(3, 0), + ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)), + ); + + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(8, 0), + ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)), + ); + + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(8, 5), + ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)), + ); + } + + #[test] + fn test_try_decimal_cast_out_of_range() { + // decimal would lose precision + expect_cast( + ScalarValue::Decimal128(Some(12345), 5, 2), + DataType::Decimal128(3, 0), + ExpectedCast::NoValue, + ); + + // decimal would lose precision + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(2, 0), + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_try_cast_to_type_timestamps() { + for time_unit in [ + TimeUnit::Second, + TimeUnit::Millisecond, + TimeUnit::Microsecond, + TimeUnit::Nanosecond, + ] { + let utc = Some("+00:00".into()); + // No timezone, utc timezone + let (lit_tz_none, lit_tz_utc) = match time_unit { + TimeUnit::Second => ( + ScalarValue::TimestampSecond(Some(12345), None), + ScalarValue::TimestampSecond(Some(12345), utc), + ), + + TimeUnit::Millisecond => ( + ScalarValue::TimestampMillisecond(Some(12345), None), + ScalarValue::TimestampMillisecond(Some(12345), utc), + ), + + TimeUnit::Microsecond => ( + ScalarValue::TimestampMicrosecond(Some(12345), None), + ScalarValue::TimestampMicrosecond(Some(12345), utc), + ), + + TimeUnit::Nanosecond => ( + ScalarValue::TimestampNanosecond(Some(12345), None), + ScalarValue::TimestampNanosecond(Some(12345), utc), + ), + }; + + // DataFusion ignores timezones for comparisons of ScalarValue + // so double check it here + assert_eq!(lit_tz_none, lit_tz_utc); + + // e.g. DataType::Timestamp(_, None) + let dt_tz_none = lit_tz_none.data_type(); + + // e.g. DataType::Timestamp(_, Some(utc)) + let dt_tz_utc = lit_tz_utc.data_type(); + + // None <--> None + expect_cast( + lit_tz_none.clone(), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // None <--> Utc + expect_cast( + lit_tz_none.clone(), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // Utc <--> None + expect_cast( + lit_tz_utc.clone(), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // Utc <--> Utc + expect_cast( + lit_tz_utc.clone(), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // timestamp to int64 + expect_cast( + lit_tz_utc.clone(), + DataType::Int64, + ExpectedCast::Value(ScalarValue::Int64(Some(12345))), + ); + + // int64 to timestamp + expect_cast( + ScalarValue::Int64(Some(12345)), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // int64 to timestamp + expect_cast( + ScalarValue::Int64(Some(12345)), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // timestamp to string (not supported yet) + expect_cast( + lit_tz_utc.clone(), + DataType::LargeUtf8, + ExpectedCast::NoValue, + ); + } + } + + #[test] + fn test_try_cast_to_type_unsupported() { + // int64 to list + expect_cast( + ScalarValue::Int64(Some(12345)), + DataType::List(Arc::new(Field::new("f", DataType::Int32, true))), + ExpectedCast::NoValue, + ); + } + + #[derive(Debug, Clone)] + enum ExpectedCast { + /// test successfully cast value and it is as specified + Value(ScalarValue), + /// test returned OK, but could not cast the value + NoValue, + } + + /// Runs try_cast_literal_to_type with the specified inputs and + /// ensure it computes the expected output, and ensures the + /// casting is consistent with the Arrow kernels + fn expect_cast( + literal: ScalarValue, + target_type: DataType, + expected_result: ExpectedCast, + ) { + let actual_value = try_cast_literal_to_type(&literal, &target_type); + + println!("expect_cast: "); + println!(" {literal:?} --> {target_type:?}"); + println!(" expected_result: {expected_result:?}"); + println!(" actual_result: {actual_value:?}"); + + match expected_result { + ExpectedCast::Value(expected_value) => { + let actual_value = + actual_value.expect("Expected cast value but got None"); + + assert_eq!(actual_value, expected_value); + + // Verify that calling the arrow + // cast kernel yields the same results + // input array + let literal_array = literal + .to_array_of_size(1) + .expect("Failed to convert to array of size"); + let expected_array = expected_value + .to_array_of_size(1) + .expect("Failed to convert to array of size"); + let cast_array = cast_with_options( + &literal_array, + &target_type, + &CastOptions::default(), + ) + .expect("Expected to be cast array with arrow cast kernel"); + + assert_eq!( + &expected_array, &cast_array, + "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}" + ); + + // Verify that for timestamp types the timezones are the same + // (ScalarValue::cmp doesn't account for timezones); + if let ( + DataType::Timestamp(left_unit, left_tz), + DataType::Timestamp(right_unit, right_tz), + ) = (actual_value.data_type(), expected_value.data_type()) + { + assert_eq!(left_unit, right_unit); + assert_eq!(left_tz, right_tz); + } + } + ExpectedCast::NoValue => { + assert!( + actual_value.is_none(), + "Expected no cast value, but got {actual_value:?}" + ); + } + } + } + + #[test] + fn test_try_cast_literal_to_timestamp() { + // same timestamp + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123456), None) + ); + + // TimestampNanosecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123), None) + ); + + // TimestampNanosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampNanosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None)); + + // TimestampMicrosecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000), None) + ); + + // TimestampMicrosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampMicrosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None)); + + // TimestampMillisecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000), None) + ); + + // TimestampMillisecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000), None) + ); + // TimestampMillisecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None)); + + // TimestampSecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000000), None) + ); + + // TimestampSecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000000), None) + ); + + // TimestampSecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMillisecond(Some(123000), None) + ); + + // overflow + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(i64::MAX), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); + } + + #[test] + fn test_try_cast_to_string_type() { + let scalars = vec![ + ScalarValue::from("string"), + ScalarValue::LargeUtf8(Some("string".to_owned())), + ]; + + for s1 in &scalars { + for s2 in &scalars { + let expected_value = ExpectedCast::Value(s2.clone()); + + expect_cast(s1.clone(), s2.data_type(), expected_value); + } + } + } + #[test] + fn test_try_cast_to_dictionary_type() { + fn dictionary_type(t: DataType) -> DataType { + DataType::Dictionary(Box::new(DataType::Int32), Box::new(t)) + } + fn dictionary_value(value: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value)) + } + let scalars = vec![ + ScalarValue::from("string"), + ScalarValue::LargeUtf8(Some("string".to_owned())), + ]; + for s in &scalars { + expect_cast( + s.clone(), + dictionary_type(s.data_type()), + ExpectedCast::Value(dictionary_value(s.clone())), + ); + expect_cast( + dictionary_value(s.clone()), + s.data_type(), + ExpectedCast::Value(s.clone()), + ) + } + } +}