diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 78606f05ae81..2bfd18c03a7f 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -411,7 +411,7 @@ impl CaseExpr { fn expr_or_expr(&self, batch: &RecordBatch) -> Result { let return_type = self.data_type(&batch.schema())?; - // evalute when condition on batch + // evaluate when condition on batch let when_value = self.when_then_expr[0].0.evaluate(batch)?; let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|e| { @@ -421,6 +421,8 @@ impl CaseExpr { ) })?; + let current_value = new_null_array(&return_type, batch.num_rows()); + // Treat 'NULL' as false value let when_value = match when_value.null_count() { 0 => Cow::Borrowed(when_value), @@ -429,8 +431,20 @@ impl CaseExpr { let then_value = self.when_then_expr[0] .1 - .evaluate_selection(batch, &when_value)? - .into_array(batch.num_rows())?; + .evaluate_selection(batch, &when_value)?; + + let then_value = match then_value { + ColumnarValue::Scalar(ScalarValue::Null) => { + nullif(current_value.as_ref(), &when_value)? + } + ColumnarValue::Scalar(then_value) => { + zip(&when_value, &then_value.to_scalar()?, ¤t_value)? + } + ColumnarValue::Array(then_value) => { + zip(&when_value, &then_value, ¤t_value)? + } + }; + let then_value = zip(&when_value, &then_value, ¤t_value)?; // evaluate else expression on the values not covered by when_value let remainder = not(&when_value)?; @@ -574,7 +588,7 @@ pub fn case( mod tests { use super::*; - use crate::expressions::{binary, cast, col, lit, BinaryExpr}; + use crate::expressions::{binary, cast, col, lit, BinaryExpr, IsNullExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; @@ -795,6 +809,21 @@ mod tests { Ok(batch) } + fn case_test_batch2() -> Result { + let schema = Schema::new(vec![Field::new( + "a", + DataType::List(Field::new("item", DataType::Int32, true).into()), + true, + )]); + let int_builder = Int32Builder::new(); + let mut list_builder = ListBuilder::with_capacity(int_builder, 2); + list_builder.append_value(vec![Some(1), Some(2), Some(3)]); + list_builder.append_null(); + let a = list_builder.finish(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + Ok(batch) + } + #[test] fn case_without_expr_else() -> Result<()> { let batch = case_test_batch()?; @@ -1311,6 +1340,35 @@ mod tests { Ok(()) } + /// Regression test for https://github.com/apache/datafusion/issues/14277 + #[test] + fn issue_14277() -> Result<()> { + let batch = case_test_batch2()?; + let schema = batch.schema(); + let when = Arc::new(IsNullExpr::new(col("a", &schema)?)); + let then = Arc::new(Literal::new(ScalarValue::Null)); + let else_expr = col("a", &schema)?; + let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?; + assert!(matches!( + expr.eval_method, + EvalMethod::ExpressionOrExpression + )); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_list_array(&result); + assert_eq!(2, result.len()); + assert!(!result.is_null(0)); + assert!(result.is_null(1)); + let int32_array = result.value(0); + let int32_array = + as_int32_array(&int32_array).expect("failed to downcast to Int32Array"); + let expected = &Int32Array::from(vec![Some(1), Some(2), Some(3)]); + assert_eq!(expected, int32_array); + Ok(()) + } + fn make_col(name: &str, index: usize) -> Arc { Arc::new(Column::new(name, index)) }