Skip to content

Commit

Permalink
revisit union interval logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ch-sc committed Feb 14, 2025
1 parent cb11459 commit b4bd851
Showing 1 changed file with 56 additions and 47 deletions.
103 changes: 56 additions & 47 deletions datafusion/expr-common/src/interval_arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,10 @@ impl Interval {
self.lower.is_null() && self.upper.is_null()
}

/// Returns `true` when neither the lower nor the upper bound of this
/// interval is `NULL`.
pub fn is_not_null(&self) -> bool {
    !self.lower.is_null() && !self.upper.is_null()
}

pub const CERTAINLY_FALSE: Self = Self {
lower: ScalarValue::Boolean(Some(false)),
upper: ScalarValue::Boolean(Some(false)),
Expand Down Expand Up @@ -620,57 +624,62 @@ impl Interval {
/// to an error.
pub fn union<T: Borrow<Self>>(&self, other: T) -> Result<Option<Self>> {
let rhs = other.borrow();

if self.data_type().ne(&rhs.data_type()) {
BinaryTypeCoercer::new(&self.data_type(), &Operator::Plus, &rhs.data_type()).get_result_type()
.map_err(|e|
DataFusionError::Internal(
format!(
"Cannot coerce data types for interval union, lhs:{}, rhs:{}. internal error: {}",
self.data_type(),
rhs.data_type(),
e
))
)?;
return internal_err!(
"Cannot calculate the union of intervals with different data types, lhs:{}, rhs:{}",
self.data_type(),
rhs.data_type()
);
};

// If the upper bound of one side is less than the lower bound of the
// other side or vice versa, then the resulting interval is expanded
// accordingly. Note that, this can only happen if one side has a null
// value.
//
// Examples:
// [1, 2] ∪ [3, NULL] = [1, 3]
// [3, NULL] ∪ [1, NULL] = [1, 3]
// [3, NULL] ∪ [NULL, 1] = [1, 3]
let (lower, upper) = if (!(self.lower.is_null() || rhs.upper.is_null())
&& self.lower > rhs.upper)
|| (!(self.upper.is_null() || rhs.lower.is_null()) && self.upper < rhs.lower)
{
(
min_of_bounds(
&min_of_bounds(&min_of_bounds(&self.lower, &rhs.lower), &self.upper),
&rhs.upper,
),
max_of_bounds(
&max_of_bounds(&max_of_bounds(&self.lower, &rhs.lower), &self.upper),
&rhs.upper,
),
)
} else {
(
min_of_bounds(&self.lower, &rhs.lower),
max_of_bounds(&self.upper, &rhs.upper),
)
let lower_bound = match (&self.lower.is_null(), &rhs.lower.is_null()) {
(false, false) => Some(min_of_bounds(&self.lower, &rhs.lower)),
(false, true) => Some(self.lower.clone()),
(true, false) => Some(rhs.lower.clone()),
(true, true) => None,
};
let upper_bound = match (&self.upper.is_null(), &rhs.upper.is_null()) {
(false, false) => Some(max_of_bounds(&self.upper, &rhs.upper)),
(false, true) => Some(self.upper.clone()),
(true, false) => Some(rhs.upper.clone()),
(true, true) => None,
};

// New lower and upper bounds must always construct a valid interval.
assert!(
lower.is_null() || upper.is_null() || (lower <= upper),
"The union of two intervals can not be an invalid interval"
);
// If the intervals overlap or touch, return a single merged interval
if (self.is_not_null() && rhs.is_not_null()) && self.upper >= rhs.lower
|| rhs.upper >= self.lower
{
return Ok(Some(Self::new(lower_bound.unwrap(), upper_bound.unwrap())));
}

Ok(Some(Self { lower, upper }))
// Handle non-overlapping intervals since interval sets are not supported
// TODO: with interval sets, we should return a set of disjoint intervals
let mut lower_value =
lower_bound.unwrap_or(ScalarValue::try_from(self.lower.data_type())?);
let mut upper_value =
upper_bound.unwrap_or(ScalarValue::try_from(self.lower.data_type())?);

// If both directions are unbounded, return unbounded set
// e.g. {10,None} ∪ {None,2} = {None, None}
if !(lower_value.is_null() || upper_value.is_null()) && lower_value > upper_value
{
return Ok(Some(Self::make_unbounded(&lower_value.data_type())?));
}

// If only one direction has a bound, the other direction is unbounded
// e.g. {5,NULL} ∪ {1,2} = {1,NULL}
if (self.upper.is_null() && !self.lower.is_null() && self.lower > rhs.upper)
|| (rhs.upper.is_null() && !rhs.lower.is_null() && rhs.lower > self.upper)
{
upper_value = ScalarValue::try_new_null(&self.upper.data_type())?;
}
if (self.lower.is_null() && !self.upper.is_null() && self.upper < rhs.lower)
|| (rhs.lower.is_null() && !rhs.upper.is_null() && rhs.upper < self.lower)
{
lower_value = ScalarValue::try_new_null(&self.lower.data_type())?;
}

Ok(Some(Interval::new(lower_value, upper_value)))
}

/// Compute the intersection of this interval with the given interval.
Expand Down Expand Up @@ -2850,12 +2859,12 @@ mod tests {
(
Interval::make(Some(1000_i64), None)?,
Interval::make(None, Some(10_i64))?,
Interval::make(Some(10_i64), Some(1000_i64))?,
Interval::make::<i64>(None, None)?,
),
(
Interval::make(Some(1000_i64), None)?,
Interval::make(Some(1_i64), Some(10_i64))?,
Interval::make(Some(1_i64), Some(1000_i64))?,
Interval::make(Some(1_i64), None)?,
),
(
Interval::make(None, Some(2000_u64))?,
Expand Down

0 comments on commit b4bd851

Please sign in to comment.