Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix regression in CASE expression #14283

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 62 additions & 4 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ impl CaseExpr {
fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;

// evalute when condition on batch
// evaluate when condition on batch
let when_value = self.when_then_expr[0].0.evaluate(batch)?;
let when_value = when_value.into_array(batch.num_rows())?;
let when_value = as_boolean_array(&when_value).map_err(|e| {
Expand All @@ -421,6 +421,8 @@ impl CaseExpr {
)
})?;

let current_value = new_null_array(&return_type, batch.num_rows());

// Treat 'NULL' as false value
let when_value = match when_value.null_count() {
0 => Cow::Borrowed(when_value),
Expand All @@ -429,8 +431,20 @@ impl CaseExpr {

let then_value = self.when_then_expr[0]
.1
.evaluate_selection(batch, &when_value)?
.into_array(batch.num_rows())?;
.evaluate_selection(batch, &when_value)?;

let then_value = match then_value {
ColumnarValue::Scalar(ScalarValue::Null) => {
nullif(current_value.as_ref(), &when_value)?
}
ColumnarValue::Scalar(then_value) => {
zip(&when_value, &then_value.to_scalar()?, &current_value)?
}
ColumnarValue::Array(then_value) => {
zip(&when_value, &then_value, &current_value)?
}
};
let then_value = zip(&when_value, &then_value, &current_value)?;

// evaluate else expression on the values not covered by when_value
let remainder = not(&when_value)?;
Expand Down Expand Up @@ -574,7 +588,7 @@ pub fn case(
mod tests {
use super::*;

use crate::expressions::{binary, cast, col, lit, BinaryExpr};
use crate::expressions::{binary, cast, col, lit, BinaryExpr, IsNullExpr};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType::Float64;
use arrow::datatypes::*;
Expand Down Expand Up @@ -795,6 +809,21 @@ mod tests {
Ok(batch)
}

/// Builds a two-row `RecordBatch` with a single nullable column `a` of type
/// `List(Int32)`: row 0 holds `[1, 2, 3]` and row 1 is a null list.
fn case_test_batch2() -> Result<RecordBatch> {
    // Schema: one nullable list column whose items are nullable Int32 values.
    let item_field = Field::new("item", DataType::Int32, true);
    let list_field = Field::new("a", DataType::List(item_field.into()), true);
    let schema = Arc::new(Schema::new(vec![list_field]));

    // Column data: one populated list followed by one null entry.
    let mut lists = ListBuilder::with_capacity(Int32Builder::new(), 2);
    lists.append_value(vec![Some(1), Some(2), Some(3)]);
    lists.append_null();

    Ok(RecordBatch::try_new(schema, vec![Arc::new(lists.finish())])?)
}

#[test]
fn case_without_expr_else() -> Result<()> {
let batch = case_test_batch()?;
Expand Down Expand Up @@ -1311,6 +1340,35 @@ mod tests {
Ok(())
}

/// Regression test for https://github.com/apache/datafusion/issues/14277
///
/// Evaluates `CASE WHEN a IS NULL THEN NULL ELSE a END` over the list column
/// produced by `case_test_batch2` and checks that the non-null list is passed
/// through unchanged while the null row stays null.
#[test]
fn issue_14277() -> Result<()> {
    let batch = case_test_batch2()?;
    let schema = batch.schema();

    // WHEN a IS NULL THEN NULL ELSE a
    let is_null = Arc::new(IsNullExpr::new(col("a", &schema)?));
    let null_literal = Arc::new(Literal::new(ScalarValue::Null));
    let fallback = col("a", &schema)?;
    let case_expr =
        CaseExpr::try_new(None, vec![(is_null, null_literal)], Some(fallback))?;

    // The test is only meaningful if this evaluation strategy was selected.
    assert!(matches!(
        case_expr.eval_method,
        EvalMethod::ExpressionOrExpression
    ));

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let lists = as_list_array(&evaluated);

    // Row 0 keeps its list, row 1 remains null.
    assert_eq!(2, lists.len());
    assert!(!lists.is_null(0));
    assert!(lists.is_null(1));

    let first = lists.value(0);
    let first = as_int32_array(&first).expect("failed to downcast to Int32Array");
    let expected = &Int32Array::from(vec![Some(1), Some(2), Some(3)]);
    assert_eq!(expected, first);

    Ok(())
}

/// Creates a `Column` from `name` and `index` and returns it as a shared
/// `Arc<dyn PhysicalExpr>`.
fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
    let column = Column::new(name, index);
    Arc::new(column)
}
Expand Down
Loading