Skip to content

Commit

Permalink
treat null as unbounded
Browse files Browse the repository at this point in the history
  • Loading branch information
ch-sc committed Feb 18, 2025
1 parent aaf116b commit 5d65ba1
Showing 1 changed file with 32 additions and 59 deletions.
91 changes: 32 additions & 59 deletions datafusion/expr-common/src/interval_arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -622,65 +622,38 @@ impl Interval {
/// NOTE: This function only works with intervals of the same data type.
/// Attempting to compare intervals of different data types will lead
/// to an error.
pub fn union<T: Borrow<Self>>(&self, other: T) -> Result<Option<Self>> {
pub fn union<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
let rhs = other.borrow();
if self.data_type().ne(&rhs.data_type()) {
return internal_err!(
"Cannot calculate the union of intervals with different data types, lhs:{}, rhs:{}",
self.data_type(),
rhs.data_type()
);
};

let lower_bound = match (&self.lower.is_null(), &rhs.lower.is_null()) {
(false, false) => Some(min_of_bounds(&self.lower, &rhs.lower)),
(false, true) => Some(self.lower.clone()),
(true, false) => Some(rhs.lower.clone()),
(true, true) => None,
};
let upper_bound = match (&self.upper.is_null(), &rhs.upper.is_null()) {
(false, false) => Some(max_of_bounds(&self.upper, &rhs.upper)),
(false, true) => Some(self.upper.clone()),
(true, false) => Some(rhs.upper.clone()),
(true, true) => None,
"Cannot calculate the union of intervals with different data types, lhs:{}, rhs:{}",
self.data_type(),
rhs.data_type()
);
};

// If the intervals overlap or touch, return a single merged interval
if self.is_not_null()
&& rhs.is_not_null()
&& (self.upper >= rhs.lower || rhs.upper >= self.lower)
let lower = if self.lower.is_null()
|| (!rhs.lower.is_null() && self.lower <= rhs.lower)
{
return Ok(Some(Self::new(lower_bound.unwrap(), upper_bound.unwrap())));
}

// Handle non-overlapping intervals since interval sets are not supported
// TODO: with interval sets, we should return a set of disjoint intervals
let mut lower_value =
lower_bound.unwrap_or(ScalarValue::try_from(self.lower.data_type())?);
let mut upper_value =
upper_bound.unwrap_or(ScalarValue::try_from(self.lower.data_type())?);

// If both directions are unbounded, return unbounded set
// e.g. {10,None} ∪ {None,2} = {None, None}
if !(lower_value.is_null() || upper_value.is_null()) && lower_value > upper_value
self.lower.clone()
} else {
rhs.lower.clone()
};
let upper = if self.upper.is_null()
|| (!rhs.upper.is_null() && self.upper >= rhs.upper)
{
return Ok(Some(Self::make_unbounded(&lower_value.data_type())?));
}
self.upper.clone()
} else {
rhs.upper.clone()
};

// If only one direction has a bound, the other direction is unbounded
// e.g. {5,NULL} ∪ {1,2} = {1,NULL}
if (self.upper.is_null() && !self.lower.is_null() && self.lower > rhs.upper)
|| (rhs.upper.is_null() && !rhs.lower.is_null() && rhs.lower > self.upper)
{
upper_value = ScalarValue::try_new_null(&self.upper.data_type())?;
}
if (self.lower.is_null() && !self.upper.is_null() && self.upper < rhs.lower)
|| (rhs.lower.is_null() && !rhs.upper.is_null() && rhs.upper < self.lower)
{
lower_value = ScalarValue::try_new_null(&self.lower.data_type())?;
}
// New lower and upper bounds must always construct a valid interval.
debug_assert!(
(lower.is_null() || upper.is_null() || (lower <= upper)),
"The union of two intervals can not be an invalid interval"
);

Ok(Some(Interval::new(lower_value, upper_value)))
Ok(Self { lower, upper })
}

/// Compute the intersection of this interval with the given interval.
Expand Down Expand Up @@ -2815,7 +2788,7 @@ mod tests {
(
Interval::make(Some(1000_i64), None)?,
Interval::make::<i64>(None, None)?,
Interval::make(Some(1000_i64), None)?,
Interval::make_unbounded(&DataType::Int64)?,
),
(
Interval::make(Some(1000_i64), None)?,
Expand All @@ -2825,17 +2798,17 @@ mod tests {
(
Interval::make(Some(1000_i64), None)?,
Interval::make(None, Some(2000_i64))?,
Interval::make(Some(1000_i64), Some(2000_i64))?,
Interval::make_unbounded(&DataType::Int64)?,
),
(
Interval::make::<i64>(None, None)?,
Interval::make(Some(1000_i64), Some(2000_i64))?,
Interval::make(Some(1000_i64), Some(2000_i64))?,
Interval::make_unbounded(&DataType::Int64)?,
),
(
Interval::make(Some(1000_i64), Some(2000_i64))?,
Interval::make::<i64>(None, None)?,
Interval::make(Some(1000_i64), Some(2000_i64))?,
Interval::make_unbounded(&DataType::Int64)?,
),
(
Interval::make::<i64>(None, None)?,
Expand All @@ -2845,7 +2818,7 @@ mod tests {
(
Interval::make(Some(1000_i64), Some(2000_i64))?,
Interval::make(Some(1000_i64), None)?,
Interval::make(Some(1000_i64), Some(2000_i64))?,
Interval::make(Some(1000_i64), None)?,
),
(
Interval::make(Some(1000_i64), Some(2000_i64))?,
Expand All @@ -2870,17 +2843,17 @@ mod tests {
(
Interval::make(None, Some(2000_u64))?,
Interval::make(Some(500_u64), None)?,
Interval::make(Some(0_u64), Some(2000_u64))?,
Interval::make_unbounded(&DataType::UInt64)?,
),
(
Interval::make(Some(0_u64), Some(0_u64))?,
Interval::make(Some(0_u64), None)?,
Interval::make(Some(0_u64), Some(0_u64))?,
Interval::make(Some(0_u64), None)?,
),
(
Interval::make(Some(1000.0_f32), None)?,
Interval::make(None, Some(1000.0_f32))?,
Interval::make(Some(1000.0_f32), Some(1000.0_f32))?,
Interval::make_unbounded(&DataType::Float32)?,
),
(
Interval::make(Some(1000.0_f32), Some(1500.0_f32))?,
Expand All @@ -2904,7 +2877,7 @@ mod tests {
),
];
for (first, second, expected) in possible_cases {
assert_eq!(first.union(second.clone())?.unwrap(), expected)
assert_eq!(first.union(second.clone())?, expected)
}

Ok(())
Expand Down

0 comments on commit 5d65ba1

Please sign in to comment.