diff --git a/Cargo.lock b/Cargo.lock index 22f06f9932a6..06faa3f3f041 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2011,6 +2011,7 @@ name = "datafusion-expr-common" version = "45.0.0" dependencies = [ "arrow", + "arrow-buffer", "datafusion-common", "indexmap 2.7.1", "itertools 0.14.0", diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 349850df6148..fea11ec5d822 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -24,7 +24,7 @@ use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::stats::Precision; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{ColumnStatistics, DFSchema}; +use datafusion::common::{internal_datafusion_err, ColumnStatistics, DFSchema}; use datafusion::common::{ScalarValue, ToDFSchema}; use datafusion::error::Result; use datafusion::functions_aggregate::first_last::first_value_udaf; @@ -302,10 +302,17 @@ fn boundary_analysis_and_selectivity_demo() -> Result<()> { distinct_count: Precision::Absent, }; + let field = schema.fields().first().ok_or_else(|| { + internal_datafusion_err!("schema does not have a field at index 0") + })?; + // We can then build our expression boundaries from the column statistics // allowing the analysis to be more precise. - let initial_boundaries = - vec![ExprBoundaries::try_from_column(&schema, &column_stats, 0)?]; + let initial_boundaries = vec![ExprBoundaries::try_from_column( + field.as_ref(), + &column_stats, + 0, + )?]; // With the above we can perform the boundary analysis similar to the previous // example. diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 9059ae07e648..02c038fe878f 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1583,6 +1583,17 @@ impl ScalarValue { } } + /// Returns negation for a boolean scalar value + pub fn boolean_negate(&self) -> Result { + match self { + ScalarValue::Boolean(None) => Ok(self.clone()), + ScalarValue::Boolean(Some(value)) => Ok(ScalarValue::Boolean(Some(!value))), + value => { + _internal_err!("Can not run boolean negative on scalar value {value:?}") + } + } + } + /// Wrapping addition of `ScalarValue` /// /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 5b841db53c5e..7ba754a12105 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -158,7 +158,7 @@ impl Precision { /// rows are selected. A selectivity of `0.5` means half the rows are /// selected. Will always return inexact statistics. pub fn with_estimated_selectivity(self, selectivity: f64) -> Self { - self.map(|v| ((v as f64 * selectivity).ceil()) as usize) + self.map(|v| (v as f64 * selectivity).ceil() as usize) .to_inexact() } } diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index 375af94acaf4..48e65d63ac74 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -32,7 +32,6 @@ use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskCon use datafusion_expr::Operator; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; -use datafusion_physical_expr::intervals::utils::check_support; use datafusion_physical_expr::PhysicalExprRef; use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr}; use datafusion_physical_optimizer::join_selection::{ @@ -1071,21 +1070,21 @@ fn check_expr_supported() { Operator::Plus, Arc::new(Column::new("a", 0)), )) as Arc; - assert!(check_support(&supported_expr, &schema)); + assert!(&supported_expr.supports_bounds_evaluation(&schema)); let supported_expr_2 = Arc::new(Column::new("a", 0)) as Arc; - assert!(check_support(&supported_expr_2, &schema)); + assert!(&supported_expr_2.supports_bounds_evaluation(&schema)); let unsupported_expr = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Or, Arc::new(Column::new("a", 0)), )) as Arc; - assert!(!check_support(&unsupported_expr, &schema)); + assert!(!&unsupported_expr.supports_bounds_evaluation(&schema)); let unsupported_expr_2 = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Or, Arc::new(NegativeExpr::new(Arc::new(Column::new("a", 0)))), )) as Arc; - assert!(!check_support(&unsupported_expr_2, &schema)); + assert!(!&unsupported_expr_2.supports_bounds_evaluation(&schema)); } struct TestCase { diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index 14717dd78135..f551f7798ccb 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -42,3 +42,6 @@ datafusion-common = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } paste = "^1.0" + +[dev-dependencies] +arrow-buffer = { workspace = true } diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 9d00b45962bc..a9bdbf1f00b7 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -31,7 +31,7 @@ use arrow::datatypes::{ MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL256_FOR_EACH_PRECISION, }; use datafusion_common::rounding::{alter_fp_rounding_mode, next_down, next_up}; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; macro_rules! get_extreme_value { ($extreme:ident, $value:expr) => { @@ -307,23 +307,15 @@ impl Interval { // Standardize floating-point endpoints: DataType::Float32 => handle_float_intervals!(Float32, f32, lower, upper), DataType::Float64 => handle_float_intervals!(Float64, f64, lower, upper), - // Unsigned null values for lower bounds are set to zero: - DataType::UInt8 if lower.is_null() => Self { - lower: ScalarValue::UInt8(Some(0)), - upper, - }, - DataType::UInt16 if lower.is_null() => Self { - lower: ScalarValue::UInt16(Some(0)), - upper, - }, - DataType::UInt32 if lower.is_null() => Self { - lower: ScalarValue::UInt32(Some(0)), - upper, - }, - DataType::UInt64 if lower.is_null() => Self { - lower: ScalarValue::UInt64(Some(0)), - upper, - }, + // Lower bounds of unsigned integer null values are set to zero: + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 + if lower.is_null() => + { + Self { + lower: ScalarValue::new_zero(&lower.data_type()).unwrap(), + upper, + } + } // Other data types do not require standardization: _ => Self { lower, upper }, } @@ -406,8 +398,8 @@ impl Interval { // There must be no way to create an interval whose endpoints have // different types. - debug_assert!( - lower_type == upper_type, + debug_assert_eq!( + lower_type, upper_type, "Interval bounds have different types: {lower_type} != {upper_type}" ); lower_type @@ -631,13 +623,19 @@ impl Interval { /// to an error. pub fn intersect>(&self, other: T) -> Result> { let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { - return internal_err!( - "Only intervals with the same data type are intersectable, lhs:{}, rhs:{}", - self.data_type(), - rhs.data_type() - ); - }; + BinaryTypeCoercer::new(&self.data_type(), &Operator::Plus, &rhs.data_type()).get_result_type() + .map_err(|e| + DataFusionError::Internal( + format!( + "Cannot coerce data types for interval intersection, lhs:{}, rhs:{}. internal error: {}", + self.data_type(), + rhs.data_type(), + e + )) + )?; + } // If it is evident that the result is an empty interval, short-circuit // and directly return `None`. @@ -652,7 +650,7 @@ impl Interval { // New lower and upper bounds must always construct a valid interval. debug_assert!( - (lower.is_null() || upper.is_null() || (lower <= upper)), + lower.is_null() || upper.is_null() || (lower <= upper), "The intersection of two intervals can not be an invalid interval" ); @@ -941,6 +939,19 @@ impl Interval { upper: self.lower.arithmetic_negate()?, }) } + + pub fn boolean_negate(self) -> Result { + if self.data_type() != DataType::Boolean { + return internal_err!( + "Boolean negation is only supported for boolean intervals" + ); + } + + Ok(Self { + lower: self.lower().clone().boolean_negate()?, + upper: self.upper().clone().boolean_negate()?, + }) + } } impl Display for Interval { @@ -963,6 +974,23 @@ pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result lhs.sub(rhs), Operator::Multiply => lhs.mul(rhs), Operator::Divide => lhs.div(rhs), + Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => { + NullableInterval::from(lhs) + .apply_operator(op, &rhs.into()) + .and_then(|nullable_interval| match nullable_interval { + NullableInterval::Null { .. } => { + let return_type = BinaryTypeCoercer::new( + &lhs.data_type(), + op, + &rhs.data_type(), + ) + .get_result_type()?; + Interval::make_unbounded(&return_type) + } + NullableInterval::MaybeNull { values } + | NullableInterval::NotNull { values } => Ok(values), + }) + } _ => internal_err!("Interval arithmetic does not support the operator {op}"), } } @@ -1119,14 +1147,14 @@ fn handle_overflow( } } -// This function should remain private since it may corrupt the an interval if +// This function should remain private since it may corrupt an interval if // used without caution. fn next_value(value: ScalarValue) -> ScalarValue { use ScalarValue::*; value_transition!(MAX, true, value) } -// This function should remain private since it may corrupt the an interval if +// This function should remain private since it may corrupt an interval if // used without caution. fn prev_value(value: ScalarValue) -> ScalarValue { use ScalarValue::*; @@ -1136,10 +1164,10 @@ fn prev_value(value: ScalarValue) -> ScalarValue { trait OneTrait: Sized + std::ops::Add + std::ops::Sub { fn one() -> Self; } -macro_rules! impl_OneTrait{ +macro_rules! impl_one_trait { ($($m:ty),*) => {$( impl OneTrait for $m { fn one() -> Self { 1 as $m } })*} } -impl_OneTrait! {u8, u16, u32, u64, i8, i16, i32, i64, i128} +impl_one_trait! {u8, u16, u32, u64, i8, i16, i32, i64, i128} impl OneTrait for IntervalDayTime { fn one() -> Self { @@ -1298,18 +1326,18 @@ pub fn satisfy_greater( } if !left.upper.is_null() && left.upper <= right.lower { - if !strict && left.upper == right.lower { + return if !strict && left.upper == right.lower { // Singleton intervals: - return Ok(Some(( + Ok(Some(( Interval::new(left.upper.clone(), left.upper.clone()), Interval::new(left.upper.clone(), left.upper.clone()), - ))); + ))) } else { // Left-hand side: <--======----0------------> // Right-hand side: <------------0--======----> // No intersection, infeasible to propagate: - return Ok(None); - } + Ok(None) + }; } // Only the lower bound of left-hand side and the upper bound of the right-hand @@ -1690,6 +1718,24 @@ impl Display for NullableInterval { } } +impl From<&Interval> for NullableInterval { + fn from(value: &Interval) -> Self { + if value.is_unbounded() { + Self::Null { + datatype: value.data_type(), + } + } else if value.lower.is_null() || value.upper.is_null() { + Self::MaybeNull { + values: value.clone(), + } + } else { + Self::NotNull { + values: value.clone(), + } + } + } +} + impl From for NullableInterval { /// Create an interval that represents a single value. fn from(value: ScalarValue) -> Self { @@ -1929,8 +1975,14 @@ mod tests { }; use arrow::datatypes::DataType; + use arrow_buffer::IntervalDayTime as ArrowIntervalDayTime; use datafusion_common::rounding::{next_down, next_up}; + use datafusion_common::ScalarValue::{ + Date32, DurationMillisecond, DurationSecond, IntervalDayTime, IntervalYearMonth, + TimestampSecond, + }; use datafusion_common::{Result, ScalarValue}; + use ScalarValue::{Date64, Time32Millisecond}; #[test] fn test_next_prev_value() -> Result<()> { @@ -2147,6 +2199,24 @@ mod tests { prev_value(ScalarValue::Float32(Some(-1.0))), )?, ), + ( + Interval::new(Date64(Some(1)), Date64(Some(1))), + Interval::new(Date64(Some(-1)), Date64(Some(-1))), + ), + ( + Interval::new( + TimestampSecond(Some(1), None), + TimestampSecond(Some(10), None), + ), + Interval::new( + TimestampSecond(Some(-10), None), + TimestampSecond(Some(-1), None), + ), + ), + ( + Interval::new(DurationSecond(Some(1)), DurationSecond(Some(10))), + Interval::new(DurationSecond(Some(-10)), DurationSecond(Some(-1))), + ), ]; for (first, second) in exactly_gt_cases { assert_eq!(first.gt(second.clone())?, Interval::CERTAINLY_TRUE); @@ -2184,6 +2254,24 @@ mod tests { ScalarValue::Float32(Some(-1.0_f32)), )?, ), + ( + Interval::new(Date64(Some(1)), Date64(Some(10))), + Interval::new(Date64(Some(1)), Date64(Some(1))), + ), + ( + Interval::new( + TimestampSecond(Some(1), None), + TimestampSecond(Some(10), None), + ), + Interval::new( + TimestampSecond(Some(1), None), + TimestampSecond(Some(1), None), + ), + ), + ( + Interval::new(DurationSecond(Some(1)), DurationSecond(Some(10))), + Interval::new(DurationSecond(Some(1)), DurationSecond(Some(1))), + ), ]; for (first, second) in possibly_gt_cases { assert_eq!(first.gt(second.clone())?, Interval::UNCERTAIN); @@ -2221,6 +2309,24 @@ mod tests { next_value(ScalarValue::Float32(Some(-1.0_f32))), )?, ), + ( + Interval::new(Date64(Some(1)), Date64(Some(10))), + Interval::new(Date64(Some(10)), Date64(Some(100))), + ), + ( + Interval::new( + TimestampSecond(Some(1), None), + TimestampSecond(Some(10), None), + ), + Interval::new( + TimestampSecond(Some(10), None), + TimestampSecond(None, None), + ), + ), + ( + Interval::new(DurationSecond(Some(-10)), DurationSecond(Some(-1))), + Interval::new(DurationSecond(Some(1)), DurationSecond(Some(1))), + ), ]; for (first, second) in not_gt_cases { assert_eq!(first.gt(second.clone())?, Interval::CERTAINLY_FALSE); @@ -2267,6 +2373,30 @@ mod tests { ScalarValue::Float32(Some(-1.0)), )?, ), + ( + Interval::new( + ScalarValue::Time32Second(Some(0)), + ScalarValue::Time32Second(Some(10)), + ), + Interval::new( + ScalarValue::Time32Second(Some(-1)), + ScalarValue::Time32Second(Some(-1)), + ), + ), + ( + Interval::new( + TimestampSecond(Some(1), None), + TimestampSecond(Some(10), None), + ), + Interval::new( + TimestampSecond(Some(1), None), + TimestampSecond(Some(1), None), + ), + ), + ( + Interval::new(DurationSecond(Some(-10)), DurationSecond(Some(1))), + Interval::new(DurationSecond(Some(-10)), DurationSecond(Some(-10))), + ), ]; for (first, second) in exactly_gteq_cases { assert_eq!(first.gt_eq(second.clone())?, Interval::CERTAINLY_TRUE); @@ -2304,6 +2434,30 @@ mod tests { next_value(ScalarValue::Float32(Some(-1.0_f32))), )?, ), + ( + Interval::new( + ScalarValue::Time32Second(Some(0)), + ScalarValue::Time32Second(Some(10)), + ), + Interval::new( + ScalarValue::Time32Second(Some(0)), + ScalarValue::Time32Second(None), + ), + ), + ( + Interval::new( + TimestampSecond(Some(1), None), + TimestampSecond(Some(10), None), + ), + Interval::new( + TimestampSecond(Some(1), None), + TimestampSecond(Some(10), None), + ), + ), + ( + Interval::new(DurationSecond(Some(-10)), DurationSecond(Some(1))), + Interval::new(DurationSecond(None), DurationSecond(Some(0))), + ), ]; for (first, second) in possibly_gteq_cases { assert_eq!(first.gt_eq(second.clone())?, Interval::UNCERTAIN); @@ -2337,6 +2491,30 @@ mod tests { next_value(ScalarValue::Float32(Some(-1.0))), )?, ), + ( + Interval::new( + ScalarValue::Time32Second(Some(-10)), + ScalarValue::Time32Second(Some(0)), + ), + Interval::new( + ScalarValue::Time32Second(Some(1)), + ScalarValue::Time32Second(Some(10)), + ), + ), + ( + Interval::new( + TimestampSecond(Some(5), None), + TimestampSecond(Some(9), None), + ), + Interval::new( + TimestampSecond(Some(10), None), + TimestampSecond(Some(100), None), + ), + ), + ( + Interval::new(DurationSecond(None), DurationSecond(Some(-1))), + Interval::new(DurationSecond(Some(0)), DurationSecond(Some(1))), + ), ]; for (first, second) in not_gteq_cases { assert_eq!(first.gt_eq(second.clone())?, Interval::CERTAINLY_FALSE); @@ -2365,6 +2543,28 @@ mod tests { Interval::make(Some(f64::MIN), Some(f64::MIN))?, Interval::make(Some(f64::MIN), Some(f64::MIN))?, ), + ( + Interval::new(Date64(Some(1000)), Date64(Some(1000))), + Interval::new(Date64(Some(1000)), Date64(Some(1000))), + ), + ( + Interval::new( + Time32Millisecond(Some(1000)), + Time32Millisecond(Some(1000)), + ), + Interval::new( + Time32Millisecond(Some(1000)), + Time32Millisecond(Some(1000)), + ), + ), + ( + Interval::new(IntervalYearMonth(Some(10)), IntervalYearMonth(Some(10))), + Interval::new(IntervalYearMonth(Some(10)), IntervalYearMonth(Some(10))), + ), + ( + Interval::new(DurationSecond(Some(10)), DurationSecond(Some(10))), + Interval::new(DurationSecond(Some(10)), DurationSecond(Some(10))), + ), ]; for (first, second) in exactly_eq_cases { assert_eq!(first.equal(second.clone())?, Interval::CERTAINLY_TRUE); @@ -2550,6 +2750,17 @@ mod tests { Interval::make(Some(32.0_f64), Some(64.0_f64))?, Interval::make(Some(32.0_f64), Some(32.0_f64))?, ), + ( + Interval::new(DurationSecond(Some(1)), DurationSecond(Some(10))), + Interval::new( + DurationMillisecond(Some(1001)), + DurationMillisecond(Some(1010)), + ), + Interval::new( + DurationMillisecond(Some(1001)), + DurationMillisecond(Some(1010)), + ), + ), ]; for (first, second, expected) in possible_cases { assert_eq!(first.intersect(second)?.unwrap(), expected) @@ -2594,6 +2805,107 @@ mod tests { Ok(()) } + #[test] + fn test_union() -> Result<()> { + let possible_cases: Vec<(Interval, Interval, Interval)> = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::(None, None)?, + Interval::make_unbounded(&DataType::Int64)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(2000_i64))?, + Interval::make_unbounded(&DataType::Int64)?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make_unbounded(&DataType::Int64)?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make::(None, None)?, + Interval::make_unbounded(&DataType::Int64)?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(500_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(10_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1_i64), Some(10_i64))?, + Interval::make(Some(1_i64), None)?, + ), + ( + Interval::make(None, Some(2000_u64))?, + Interval::make(Some(500_u64), None)?, + Interval::make_unbounded(&DataType::UInt64)?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), None)?, + Interval::make(Some(0_u64), None)?, + ), + ( + Interval::make(Some(1000.0_f32), None)?, + Interval::make(None, Some(1000.0_f32))?, + Interval::make_unbounded(&DataType::Float32)?, + ), + ( + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + Interval::make(Some(0.0_f32), Some(1500.0_f32))?, + Interval::make(Some(0.0_f32), Some(1500.0_f32))?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + Interval::make(Some(-1500.0_f64), Some(2000.0_f64))?, + Interval::make(Some(-1500.0_f64), Some(2000.0_f64))?, + ), + ( + Interval::make(Some(16.0_f64), Some(32.0_f64))?, + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(16.0_f64), Some(64.0_f64))?, + ), + ( + Interval::new(DurationSecond(Some(1)), DurationSecond(Some(10))), + Interval::new(DurationSecond(Some(10)), DurationSecond(Some(100))), + Interval::new(DurationSecond(Some(1)), DurationSecond(Some(100))), + ), + ]; + for (first, second, expected) in possible_cases { + assert_eq!(first.union(second.clone())?, expected) + } + + Ok(()) + } + #[test] fn union_test() -> Result<()> { let possible_cases = vec![ @@ -2727,6 +3039,22 @@ mod tests { Interval::make(Some(1501_i64), Some(1999_i64))?, Interval::CERTAINLY_TRUE, ), + ( + Interval::new( + TimestampSecond(Some(1), None), + TimestampSecond(Some(10), None), + ), + Interval::new( + TimestampSecond(Some(2), None), + TimestampSecond(Some(5), None), + ), + Interval::CERTAINLY_TRUE, + ), + ( + Interval::new(DurationSecond(Some(0)), DurationSecond(Some(600))), + Interval::new(DurationSecond(Some(1)), DurationSecond(Some(599))), + Interval::CERTAINLY_TRUE, + ), ( Interval::make(Some(1000_i64), None)?, Interval::make::(None, None)?, @@ -2742,6 +3070,11 @@ mod tests { Interval::make(Some(32.0), Some(64.0))?, Interval::UNCERTAIN, ), + ( + Interval::make::(Some(3_i64), Some(5_i64))?, + Interval::make::(Some(0_i64), Some(9_i64))?, + Interval::UNCERTAIN, + ), ( Interval::make(Some(1000_i64), None)?, Interval::make(None, Some(0_i64))?, @@ -2768,6 +3101,17 @@ mod tests { Interval::make(Some(1.0_f32), Some(1.0_f32))?, Interval::CERTAINLY_FALSE, ), + ( + Interval::new( + ScalarValue::Time32Second(Some(0)), + ScalarValue::Time32Second(Some(60)), + ), + Interval::new( + ScalarValue::Time32Second(Some(61)), + ScalarValue::Time32Second(Some(120)), + ), + Interval::CERTAINLY_FALSE, + ), ]; for (first, second, expected) in possible_cases { assert_eq!(first.contains(second)?, expected) @@ -2881,6 +3225,41 @@ mod tests { Interval::make(None, Some(200_f64))?, Interval::make(None, Some(300_f64))?, ), + ( + Interval::new( + TimestampSecond(Some(100), None), + TimestampSecond(Some(200), None), + ), + Interval::new(DurationSecond(Some(100)), DurationSecond(Some(200))), + Interval::new( + TimestampSecond(Some(200), None), + TimestampSecond(Some(400), None), + ), + ), + ( + Interval::new(Date32(Some(100)), Date32(Some(100))), + Interval::new( + IntervalDayTime(Some(ArrowIntervalDayTime { + days: 1, + milliseconds: 0, + })), + IntervalDayTime(Some(ArrowIntervalDayTime { + days: 10, + milliseconds: 0, + })), + ), + Interval::new(Date32(Some(101)), Date32(Some(110))), + ), + ( + Interval::new(DurationSecond(Some(100)), DurationSecond(Some(100))), + Interval::new(DurationSecond(Some(100)), DurationSecond(Some(100))), + Interval::new(DurationSecond(Some(200)), DurationSecond(Some(200))), + ), + ( + Interval::new(IntervalYearMonth(Some(100)), IntervalYearMonth(Some(100))), + Interval::new(IntervalYearMonth(Some(100)), IntervalYearMonth(Some(100))), + Interval::new(IntervalYearMonth(Some(200)), IntervalYearMonth(Some(200))), + ), ]; for case in cases { let result = case.0.add(case.1)?; diff --git a/datafusion/expr-common/src/operator.rs b/datafusion/expr-common/src/operator.rs index 6ca0f04897ac..ac8a9d9e06dd 100644 --- a/datafusion/expr-common/src/operator.rs +++ b/datafusion/expr-common/src/operator.rs @@ -164,6 +164,25 @@ impl Operator { ) } + /// Indicates whether this operator supports interval arithmetic + pub fn supports_interval_evaluation(&self) -> bool { + matches!( + &self, + &Operator::Plus + | &Operator::Minus + | &Operator::And + | &Operator::Gt + | &Operator::GtEq + | &Operator::Lt + | &Operator::LtEq + | &Operator::Eq + | &Operator::Multiply + | &Operator::Divide + | &Operator::IsDistinctFrom + | &Operator::IsNotDistinctFrom + ) + } + /// Return true if the comparison operator can be used in interval arithmetic and constraint /// propagation /// diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 64c26192ae0f..a4f7ed436443 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -1493,7 +1493,8 @@ fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { #[cfg(test)] mod tests { use super::*; - + use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth}; + use arrow::datatypes::TimeUnit::Nanosecond; use datafusion_common::assert_contains; #[test] @@ -1673,7 +1674,7 @@ mod tests { #[test] fn test_date_timestamp_arithmetic_error() -> Result<()> { let (lhs, rhs) = BinaryTypeCoercer::new( - &DataType::Timestamp(TimeUnit::Nanosecond, None), + &DataType::Timestamp(Nanosecond, None), &Operator::Minus, &DataType::Timestamp(TimeUnit::Millisecond, None), ) @@ -1768,33 +1769,33 @@ mod tests { ); test_coercion_binary_rule!( DataType::Utf8, - DataType::Time64(TimeUnit::Nanosecond), + DataType::Time64(Nanosecond), Operator::Eq, - DataType::Time64(TimeUnit::Nanosecond) + DataType::Time64(Nanosecond) ); test_coercion_binary_rule!( DataType::Utf8, DataType::Timestamp(TimeUnit::Second, None), Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) + DataType::Timestamp(Nanosecond, None) ); test_coercion_binary_rule!( DataType::Utf8, DataType::Timestamp(TimeUnit::Millisecond, None), Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) + DataType::Timestamp(Nanosecond, None) ); test_coercion_binary_rule!( DataType::Utf8, DataType::Timestamp(TimeUnit::Microsecond, None), Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) + DataType::Timestamp(Nanosecond, None) ); test_coercion_binary_rule!( DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(Nanosecond, None), Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) + DataType::Timestamp(Nanosecond, None) ); test_coercion_binary_rule!( DataType::Utf8, @@ -1883,6 +1884,35 @@ mod tests { Ok(()) } + #[test] + fn test_type_coercion_temporal() -> Result<()> { + test_coercion_binary_rule!( + DataType::Duration(TimeUnit::Second), + DataType::Duration(TimeUnit::Second), + Operator::Plus, + DataType::Duration(TimeUnit::Second) + ); + test_coercion_binary_rule!( + DataType::Duration(TimeUnit::Second), + DataType::Duration(Nanosecond), + Operator::Plus, + DataType::Interval(MonthDayNano) + ); + test_coercion_binary_rule!( + DataType::Interval(YearMonth), + DataType::Interval(YearMonth), + Operator::Plus, + DataType::Interval(YearMonth) + ); + test_coercion_binary_rule!( + DataType::Interval(YearMonth), + DataType::Interval(MonthDayNano), + Operator::Plus, + DataType::Interval(MonthDayNano) + ); + Ok(()) + } + #[test] fn test_type_coercion_arithmetic() -> Result<()> { // integer diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 74c3c2775c1c..53f053db6b12 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -23,7 +23,7 @@ use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ ColumnarValue, Documentation, Expr, ScalarFunctionImplementation, Signature, }; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; @@ -279,6 +279,11 @@ impl ScalarUDF { self.inner.evaluate_bounds(inputs) } + /// Indicates whether this ['ScalarUDF'] supports interval arithmetic. + pub fn supports_bounds_evaluation(&self, schema: &SchemaRef) -> bool { + self.inner.supports_bounds_evaluation(schema) + } + /// Updates bounds for child expressions, given a known interval for this /// function. This is used to propagate constraints down through an expression /// tree. @@ -717,9 +722,18 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`, /// then the output interval would be `[0, 3]`. - fn evaluate_bounds(&self, _input: &[&Interval]) -> Result { - // We cannot assume the input datatype is the same of output type. - Interval::make_unbounded(&DataType::Null) + fn evaluate_bounds(&self, input: &[&Interval]) -> Result { + let input_data_types = input + .iter() + .map(|i| i.data_type()) + .collect::>(); + let return_type = self.return_type(&input_data_types)?; + Interval::make_unbounded(&return_type) + } + + /// Indicates whether this ['ScalarUDFImpl'] supports interval arithmetic. + fn supports_bounds_evaluation(&self, _schema: &SchemaRef) -> bool { + false } /// Updates bounds for child expressions, given a known interval for this diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index cc2ff2f24790..505abb2920c9 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -24,7 +24,7 @@ use crate::utils::scatter; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; @@ -119,6 +119,11 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { not_impl_err!("Not implemented for {self}") } + /// Indicates whether interval arithmetic is supported for this expression. + fn supports_bounds_evaluation(&self, _schema: &SchemaRef) -> bool { + false + } + /// Updates bounds for child expressions, given a known interval for this /// expression. /// diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 114007bfa6af..abf488ca8240 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -26,6 +26,7 @@ use datafusion_expr_common::sort_properties::ExprProperties; use crate::physical_expr::PhysicalExpr; use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; use crate::tree_node::ExprContext; +use arrow::datatypes::DataType; /// Represents a [`PhysicalExpr`] node with associated properties (order and /// range) in a context where properties are tracked. @@ -103,6 +104,31 @@ pub fn reverse_order_bys(order_bys: &LexOrdering) -> LexOrdering { .collect() } +/// Indicates whether interval arithmetic is supported for the given data type. +pub fn is_supported_datatype_for_bounds_eval(data_type: &DataType) -> bool { + matches!( + data_type, + &DataType::Int64 + | &DataType::Int32 + | &DataType::Int16 + | &DataType::Int8 + | &DataType::UInt64 + | &DataType::UInt32 + | &DataType::UInt16 + | &DataType::UInt8 + | &DataType::Float64 + | &DataType::Float32 + | &DataType::Float16 + | &DataType::Timestamp(_, _) + | &DataType::Date32 + | &DataType::Date64 + | &DataType::Time32(_) + | &DataType::Time64(_) + | &DataType::Interval(_) + | &DataType::Duration(_) + ) +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index 5abd50f6d1b4..0dd8666d217c 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -25,7 +25,7 @@ use crate::intervals::cp_solver::{ExprIntervalGraph, PropagationResult}; use crate::utils::collect_columns; use crate::PhysicalExpr; -use arrow::datatypes::Schema; +use arrow::datatypes::{Field, Schema}; use datafusion_common::stats::Precision; use datafusion_common::{ internal_datafusion_err, internal_err, ColumnStatistics, Result, ScalarValue, @@ -60,13 +60,17 @@ impl AnalysisContext { /// Create a new analysis context from column statistics. pub fn try_from_statistics( - input_schema: &Schema, + schema: &Schema, statistics: &[ColumnStatistics], ) -> Result { - statistics + schema + .fields() .iter() + .zip(statistics.iter()) .enumerate() - .map(|(idx, stats)| ExprBoundaries::try_from_column(input_schema, stats, idx)) + .map(|(idx, (field, stats))| { + ExprBoundaries::try_from_column(field.as_ref(), stats, idx) + }) .collect::>>() .map(Self::new) } @@ -94,17 +98,10 @@ pub struct ExprBoundaries { impl ExprBoundaries { /// Create a new `ExprBoundaries` object from column level statistics. pub fn try_from_column( - schema: &Schema, + field: &Field, col_stats: &ColumnStatistics, col_index: usize, ) -> Result { - let field = schema.fields().get(col_index).ok_or_else(|| { - internal_datafusion_err!( - "Could not create `ExprBoundaries`: in `try_from_column` `col_index` - has gone out of bounds with a value of {col_index}, the schema has {} columns.", - schema.fields.len() - ) - })?; let empty_field = ScalarValue::try_from(field.data_type()).unwrap_or(ScalarValue::Null); let interval = Interval::try_new( diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 1f16c5471ed7..ba3ebf5f2149 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -381,6 +381,25 @@ impl PhysicalExpr for BinaryExpr { )) } + fn supports_bounds_evaluation(&self, schema: &SchemaRef) -> bool { + // Interval data types must be compatible for the given operation + if let (Ok(lhs), Ok(rhs)) = ( + self.left.data_type(schema.as_ref()), + self.right.data_type(schema.as_ref()), + ) { + if BinaryTypeCoercer::new(&lhs, &self.op, &rhs) + .get_result_type() + .is_err() + { + return false; + } + } + + self.op().supports_interval_evaluation() + && self.left.supports_bounds_evaluation(schema) + && self.right.supports_bounds_evaluation(schema) + } + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { // Get children intervals: let left_interval = children[0]; diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 8a093e0ae92e..88d465f4a6be 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; -use arrow::datatypes::{DataType, DataType::*, Schema}; +use arrow::datatypes::{DataType, DataType::*, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; @@ -164,6 +164,10 @@ impl PhysicalExpr for CastExpr { children[0].cast_to(&self.cast_type, &self.cast_options) } + fn supports_bounds_evaluation(&self, schema: &SchemaRef) -> bool { + self.expr().supports_bounds_evaluation(schema) + } + fn propagate_constraints( &self, interval: &Interval, diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 0ec985887c3f..51217cab6620 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -29,6 +29,7 @@ use arrow::{ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::utils::is_supported_datatype_for_bounds_eval; /// Represents the column at a given index in a RecordBatch /// @@ -137,6 +138,14 @@ impl PhysicalExpr for Column { ) -> Result> { Ok(self) } + + fn supports_bounds_evaluation(&self, schema: &SchemaRef) -> bool { + if let Ok(field) = schema.field_with_name(self.name()) { + is_supported_datatype_for_bounds_eval(field.data_type()) + } else { + false + } + } } impl Column { diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index dfe9a905dfea..b67f1811f8fb 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -45,6 +45,7 @@ use datafusion_physical_expr_common::datum::compare_with_eq; use ahash::RandomState; use datafusion_common::HashMap; +use datafusion_expr::interval_arithmetic::Interval; use hashbrown::hash_map::RawEntryMut; /// InList @@ -398,6 +399,51 @@ impl PhysicalExpr for InListExpr { self.static_filter.clone(), ))) } + + /// The output interval is computed by checking if the list item intervals are + /// a subset of, overlap, or are disjoint with the input expression's interval. + /// + /// If [InListExpr::negated] is true, the output interval gets negated. + /// + /// # Example: + /// If the input expression's interval is a superset of the + /// conjunction of the list items intervals, the output + /// interval is [`Interval::CERTAINLY_TRUE`]. + /// + /// ```text + /// interval of expr: ....---------------------.... + /// Some list items: ..........|..|.....|.|....... + /// + /// output interval: [`true`, `true`] + /// ``` + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + let expr_bounds = children[0]; + + debug_assert!( + children.len() >= 2, + "InListExpr requires at least one list item" + ); + + // conjunction of list item intervals + let list_bounds = children + .iter() + .skip(2) + .try_fold(children[1].clone(), |acc, item| acc.union(*item))?; + + if self.negated { + expr_bounds.contains(list_bounds)?.boolean_negate() + } else { + expr_bounds.contains(list_bounds) + } + } + + fn supports_bounds_evaluation(&self, schema: &SchemaRef) -> bool { + self.expr.supports_bounds_evaluation(schema) + && self + .list + .iter() + .all(|expr| expr.supports_bounds_evaluation(schema)) + } } impl PartialEq for InListExpr { @@ -1422,4 +1468,37 @@ mod tests { Ok(()) } + + #[test] + fn test_in_list_bounds_eval() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let col_a = col("a", &schema)?; + let list = vec![lit(0i64), lit(2i64), lit(9i64), lit(6i64)]; + + let expr = in_list(col_a, list, &false, &schema).unwrap(); + + let child_intervals: &[&Interval] = &[ + &Interval::make(Some(3_i64), Some(5_i64))?, + &Interval::make(Some(0_i64), Some(2_i64))?, + &Interval::make(Some(6_i64), Some(9_i64))?, + ]; + let result = expr.evaluate_bounds(child_intervals)?; + debug_assert_eq!(result, Interval::UNCERTAIN); + + let child_intervals: &[&Interval] = &[ + &Interval::make(Some(3_i64), Some(5_i64))?, + &Interval::make(Some(4_i64), Some(4_i64))?, + ]; + let result = expr.evaluate_bounds(child_intervals)?; + debug_assert_eq!(result, Interval::CERTAINLY_TRUE); + + let child_intervals: &[&Interval] = &[ + &Interval::make(Some(3_i64), Some(5_i64))?, + &Interval::make(Some(10_i64), Some(10_i64))?, + ]; + let result = expr.evaluate_bounds(child_intervals)?; + debug_assert_eq!(result, Interval::CERTAINLY_FALSE); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 47dc53d12555..bbd5d80908fc 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -27,6 +27,7 @@ use arrow::{ }; use datafusion_common::Result; use datafusion_common::ScalarValue; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; /// IS NOT NULL expression @@ -104,6 +105,17 @@ impl PhysicalExpr for IsNotNullExpr { ) -> Result> { Ok(Arc::new(IsNotNullExpr::new(Arc::clone(&children[0])))) } + + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + let inner = children[0]; + Ok(if inner.is_unbounded() { + Interval::CERTAINLY_FALSE + } else if inner.lower().is_null() || inner.upper().is_null() { + Interval::UNCERTAIN + } else { + Interval::CERTAINLY_TRUE + }) + } } /// Create an IS NOT NULL expression diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 232f9769b056..5a3da5f58d9b 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -23,15 +23,16 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::SchemaRef; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::Expr; -use datafusion_expr_common::columnar_value::ColumnarValue; -use datafusion_expr_common::interval_arithmetic::Interval; -use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_physical_expr_common::utils::is_supported_datatype_for_bounds_eval; /// Represents a literal value #[derive(Debug, PartialEq, Eq, Hash)] @@ -93,6 +94,14 @@ impl PhysicalExpr for Literal { preserves_lex_ordering: true, }) } + + fn supports_bounds_evaluation(&self, schema: &SchemaRef) -> bool { + if let Ok(dt) = self.data_type(schema) { + is_supported_datatype_for_bounds_eval(&dt) + } else { + false + } + } } /// Create a literal expression diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 8795545274a2..c19e14e87933 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; +use arrow::datatypes::SchemaRef; use arrow::{ compute::kernels::numeric::neg_wrapping, datatypes::{DataType, Schema}, @@ -121,6 +122,10 @@ impl PhysicalExpr for NegativeExpr { children[0].arithmetic_negate() } + fn supports_bounds_evaluation(&self, schema: &SchemaRef) -> bool { + self.arg().supports_bounds_evaluation(schema) + } + /// Returns a new [`Interval`] of a NegativeExpr that has the existing `interval` given that /// given the input interval is known to be `children`. fn propagate_constraints( diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index a53814c3ad2b..66c80b52dd0e 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -157,11 +157,11 @@ use datafusion_common::{internal_err, Result}; use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; use datafusion_expr::Operator; +use datafusion_expr::type_coercion::{is_datetime, is_interval}; use petgraph::graph::NodeIndex; use petgraph::stable_graph::{DefaultIx, StableGraph}; use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef}; use petgraph::Outgoing; - /// This object implements a directed acyclic expression graph (DAEG) that /// is used to compute ranges for expressions through interval arithmetic. #[derive(Clone, Debug)] @@ -256,42 +256,27 @@ pub fn propagate_arithmetic( right_child: &Interval, ) -> Result> { let inverse_op = get_inverse_op(*op)?; - match (left_child.data_type(), right_child.data_type()) { - // If we have a child whose type is a time interval (i.e. DataType::Interval), - // we need special handling since timestamp differencing results in a - // Duration type. - (DataType::Timestamp(..), DataType::Interval(_)) => { - propagate_time_interval_at_right( - left_child, - right_child, - parent, - op, - &inverse_op, - ) - } - (DataType::Interval(_), DataType::Timestamp(..)) => { - propagate_time_interval_at_left( - left_child, - right_child, - parent, - op, - &inverse_op, - ) - } - _ => { - // First, propagate to the left: - match apply_operator(&inverse_op, parent, right_child)? - .intersect(left_child)? - { - // Left is feasible: - Some(value) => Ok( - // Propagate to the right using the new left. - propagate_right(&value, parent, right_child, op, &inverse_op)? - .map(|right| (value, right)), - ), - // If the left child is infeasible, short-circuit. - None => Ok(None), - } + + // If we have a child whose data type is datetime (i.e. timestamp), + // we need special handling since timestamp differencing results in + // a Duration type. + if is_datetime(&left_child.data_type()) && is_interval(&right_child.data_type()) { + propagate_time_interval_at_right(left_child, right_child, parent, op, &inverse_op) + } else if is_interval(&left_child.data_type()) + && is_datetime(&right_child.data_type()) + { + propagate_time_interval_at_left(left_child, right_child, parent, op, &inverse_op) + } else { + // First, propagate to the left: + match apply_operator(&inverse_op, parent, right_child)?.intersect(left_child)? { + // Left is feasible: + Some(value) => Ok( + // Propagate to the right using the new left. + propagate_right(&value, parent, right_child, op, &inverse_op)? + .map(|right| (value, right)), + ), + // If the left child is infeasible, short-circuit. + None => Ok(None), } } } @@ -347,9 +332,14 @@ pub fn propagate_comparison( ) -> Result> { if parent == &Interval::CERTAINLY_TRUE { match op { - Operator::Eq => left_child.intersect(right_child).map(|result| { - result.map(|intersection| (intersection.clone(), intersection)) - }), + Operator::Eq | Operator::IsNotDistinctFrom => { + left_child.intersect(right_child).map(|result| { + result.map(|intersection| (intersection.clone(), intersection)) + }) + } + Operator::NotEq | Operator::IsDistinctFrom => left_child + .union(right_child) + .map(|union| Some((union.clone(), union.clone()))), Operator::Gt => satisfy_greater(left_child, right_child, true), Operator::GtEq => satisfy_greater(left_child, right_child, false), Operator::Lt => satisfy_greater(right_child, left_child, true) @@ -362,7 +352,10 @@ pub fn propagate_comparison( } } else if parent == &Interval::CERTAINLY_FALSE { match op { - Operator::Eq => { + Operator::Eq + | Operator::IsNotDistinctFrom + | Operator::NotEq + | Operator::IsDistinctFrom => { // TODO: Propagation is not possible until we support interval sets. Ok(None) } diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index 910631ef4a43..c6d1f2519e4c 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -17,51 +17,11 @@ //! Utility functions for the interval arithmetic library -use std::sync::Arc; - -use crate::{ - expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr}, - PhysicalExpr, -}; - use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; -use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::Operator; -/// Indicates whether interval arithmetic is supported for the given expression. -/// Currently, we do not support all [`PhysicalExpr`]s for interval calculations. -/// We do not support every type of [`Operator`]s either. Over time, this check -/// will relax as more types of `PhysicalExpr`s and `Operator`s are supported. -/// Currently, [`CastExpr`], [`NegativeExpr`], [`BinaryExpr`], [`Column`] and [`Literal`] are supported. -pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { - let expr_any = expr.as_any(); - if let Some(binary_expr) = expr_any.downcast_ref::() { - is_operator_supported(binary_expr.op()) - && check_support(binary_expr.left(), schema) - && check_support(binary_expr.right(), schema) - } else if let Some(column) = expr_any.downcast_ref::() { - if let Ok(field) = schema.field_with_name(column.name()) { - is_datatype_supported(field.data_type()) - } else { - return false; - } - } else if let Some(literal) = expr_any.downcast_ref::() { - if let Ok(dt) = literal.data_type(schema) { - is_datatype_supported(&dt) - } else { - return false; - } - } else if let Some(cast) = expr_any.downcast_ref::() { - check_support(cast.expr(), schema) - } else if let Some(negative) = expr_any.downcast_ref::() { - check_support(negative.arg(), schema) - } else { - false - } -} - // This function returns the inverse operator of the given operator. pub fn get_inverse_op(op: Operator) -> Result { match op { @@ -73,40 +33,6 @@ pub fn get_inverse_op(op: Operator) -> Result { } } -/// Indicates whether interval arithmetic is supported for the given operator. -pub fn is_operator_supported(op: &Operator) -> bool { - matches!( - op, - &Operator::Plus - | &Operator::Minus - | &Operator::And - | &Operator::Gt - | &Operator::GtEq - | &Operator::Lt - | &Operator::LtEq - | &Operator::Eq - | &Operator::Multiply - | &Operator::Divide - ) -} - -/// Indicates whether interval arithmetic is supported for the given data type. -pub fn is_datatype_supported(data_type: &DataType) -> bool { - matches!( - data_type, - &DataType::Int64 - | &DataType::Int32 - | &DataType::Int16 - | &DataType::Int8 - | &DataType::UInt64 - | &DataType::UInt32 - | &DataType::UInt16 - | &DataType::UInt8 - | &DataType::Float64 - | &DataType::Float32 - ) -} - /// Converts an [`Interval`] of time intervals to one of `Duration`s, if applicable. Otherwise, returns [`None`]. pub fn convert_interval_type_to_duration(interval: &Interval) -> Option { if let (Some(lower), Some(upper)) = ( diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index bd38fb22ccbc..790bc7224101 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -38,7 +38,7 @@ use crate::expressions::Literal; use crate::PhysicalExpr; use arrow::array::{Array, RecordBatch}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; @@ -237,6 +237,10 @@ impl PhysicalExpr for ScalarFunctionExpr { self.fun.evaluate_bounds(children) } + fn supports_bounds_evaluation(&self, schema: &SchemaRef) -> bool { + self.fun.supports_bounds_evaluation(schema) + } + fn propagate_constraints( &self, interval: &Interval, diff --git a/datafusion/physical-optimizer/src/sanity_checker.rs b/datafusion/physical-optimizer/src/sanity_checker.rs index 8edbb0f09114..cbdf1773d8f3 100644 --- a/datafusion/physical-optimizer/src/sanity_checker.rs +++ b/datafusion/physical-optimizer/src/sanity_checker.rs @@ -29,13 +29,13 @@ use datafusion_physical_plan::ExecutionPlan; use datafusion_common::config::{ConfigOptions, OptimizerOptions}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; use crate::PhysicalOptimizerRule; use datafusion_physical_expr_common::sort_expr::format_physical_sort_requirement_list; +use datafusion_physical_expr_common::utils::is_supported_datatype_for_bounds_eval; use itertools::izip; /// The SanityCheckPlan rule rejects the following query plans: @@ -113,12 +113,14 @@ pub fn check_finiteness_requirements( /// [`Operator`]: datafusion_expr::Operator fn is_prunable(join: &SymmetricHashJoinExec) -> bool { join.filter().is_some_and(|filter| { - check_support(filter.expression(), &join.schema()) + filter + .expression() + .supports_bounds_evaluation(&join.schema()) && filter .schema() .fields() .iter() - .all(|f| is_datatype_supported(f.data_type())) + .all(|f| is_supported_datatype_for_bounds_eval(f.data_type())) }) } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index a66873bc6576..b3e8b97765ba 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -47,7 +47,6 @@ use datafusion_execution::TaskContext; use datafusion_expr::Operator; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::BinaryExpr; -use datafusion_physical_expr::intervals::utils::check_support; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ analyze, split_conjunction, AcrossPartitions, AnalysisContext, ConstExpr, @@ -175,7 +174,7 @@ impl FilterExec { ) -> Result { let input_stats = input.statistics()?; let schema = input.schema(); - if !check_support(predicate, &schema) { + if !predicate.supports_bounds_evaluation(&schema) { let selectivity = default_selectivity as f64 / 100.0; let mut stats = input_stats.to_inexact(); stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity);