Skip to content

Commit

Permalink
fix: graceful NULL and type error handling in array functions (#14737)
Browse files Browse the repository at this point in the history
* feat: arbitrary typed argument in array function

* fix: array_sort null handling

* fix: array_resize signature

* test: add array_sort sqllogictest for null and invalid types

* fix: don't match error message

* chore: use string instead of data type

* refactor: use new_null_array

* fix: pass null to array argument should return null

* fix: handle null argument for array in replace and resize

* fix: mismatched error message

* fix: incorrect number of rows returned

* test: update null tests

* fix: treat NULLs as lists directly to prevent extra handling

* fix: incorrect null pushing in array_sort
  • Loading branch information
alan910127 authored Mar 6, 2025
1 parent 9a4c9d5 commit 43ecd9b
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 25 deletions.
5 changes: 5 additions & 0 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ pub enum ArrayFunctionArgument {
/// An argument of type List/LargeList/FixedSizeList. All Array arguments must be coercible
/// to the same type.
Array,
/// A Utf8 argument.
String,
}

impl Display for ArrayFunctionArgument {
Expand All @@ -372,6 +374,9 @@ impl Display for ArrayFunctionArgument {
ArrayFunctionArgument::Array => {
write!(f, "array")
}
ArrayFunctionArgument::String => {
write!(f, "string")
}
}
}
}
Expand Down
9 changes: 7 additions & 2 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use super::binary::{binary_numeric_coercion, comparison_coercion};
use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
datatypes::{DataType, Field, TimeUnit},
};
use datafusion_common::types::LogicalType;
use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion};
Expand Down Expand Up @@ -387,7 +387,7 @@ fn get_valid_types(
new_base_type =
coerce_array_types(function_name, current_type, &new_base_type)?;
}
ArrayFunctionArgument::Index => {}
ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {}
}
}
let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only(
Expand All @@ -408,6 +408,7 @@ fn get_valid_types(
let valid_type = match argument_type {
ArrayFunctionArgument::Element => new_elem_type.clone(),
ArrayFunctionArgument::Index => DataType::Int64,
ArrayFunctionArgument::String => DataType::Utf8,
ArrayFunctionArgument::Array => {
let Some(current_type) = array(current_type) else {
return Ok(vec![vec![]]);
Expand Down Expand Up @@ -435,6 +436,10 @@ fn get_valid_types(
match array_type {
DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()),
DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),
DataType::Null => Some(DataType::List(Arc::new(Field::new_list_field(
DataType::Int64,
true,
)))),
_ => None,
}
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-nested/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ impl ScalarUDFImpl for ArrayElement {
List(field)
| LargeList(field)
| FixedSizeList(field, _) => Ok(field.data_type().clone()),
DataType::Null => Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true)))),
_ => plan_err!(
"ArrayElement can only accept List, LargeList or FixedSizeList as the first argument"
),
Expand Down
7 changes: 5 additions & 2 deletions datafusion/functions-nested/src/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
//! [`ScalarUDFImpl`] definitions for array_replace, array_replace_n and array_replace_all functions.
use arrow::array::{
Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData,
NullBufferBuilder, OffsetSizeTrait,
new_null_array, Array, ArrayRef, AsArray, Capacities, GenericListArray,
MutableArrayData, NullBufferBuilder, OffsetSizeTrait,
};
use arrow::datatypes::{DataType, Field};

Expand Down Expand Up @@ -429,6 +429,7 @@ pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => exec_err!("array_replace does not support type '{array_type:?}'."),
}
}
Expand All @@ -447,6 +448,7 @@ pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => {
exec_err!("array_replace_n does not support type '{array_type:?}'.")
}
Expand All @@ -467,6 +469,7 @@ pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => {
exec_err!("array_replace_all does not support type '{array_type:?}'.")
}
Expand Down
32 changes: 28 additions & 4 deletions datafusion/functions-nested/src/resize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ use arrow::array::{
MutableArrayData, NullBufferBuilder, OffsetSizeTrait,
};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::ArrowNativeType;
use arrow::datatypes::DataType;
use arrow::datatypes::{ArrowNativeType, Field};
use arrow::datatypes::{
DataType::{FixedSizeList, LargeList, List},
FieldRef,
};
use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
use datafusion_common::utils::ListCoercion;
use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
Expand Down Expand Up @@ -83,7 +85,26 @@ impl Default for ArrayResize {
impl ArrayResize {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Array,
ArrayFunctionArgument::Index,
],
array_coercion: Some(ListCoercion::FixedSizedListToList),
}),
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Array,
ArrayFunctionArgument::Index,
ArrayFunctionArgument::Element,
],
array_coercion: Some(ListCoercion::FixedSizedListToList),
}),
],
Volatility::Immutable,
),
aliases: vec!["list_resize".to_string()],
}
}
Expand All @@ -106,6 +127,9 @@ impl ScalarUDFImpl for ArrayResize {
match &arg_types[0] {
List(field) | FixedSizeList(field, _) => Ok(List(Arc::clone(field))),
LargeList(field) => Ok(LargeList(Arc::clone(field))),
DataType::Null => {
Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true))))
}
_ => exec_err!(
"Not reachable, data_type should be List, LargeList or FixedSizeList"
),
Expand Down Expand Up @@ -137,7 +161,7 @@ pub(crate) fn array_resize_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
let array = &arg[0];

// Checks if entire array is null
if array.null_count() == array.len() {
if array.logical_null_count() == array.len() {
let return_type = match array.data_type() {
List(field) => List(Arc::clone(field)),
LargeList(field) => LargeList(Arc::clone(field)),
Expand Down
51 changes: 42 additions & 9 deletions datafusion/functions-nested/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
//! [`ScalarUDFImpl`] definitions for array_sort function.
use crate::utils::make_scalar_function;
use arrow::array::{Array, ArrayRef, ListArray, NullBufferBuilder};
use arrow::array::{new_null_array, Array, ArrayRef, ListArray, NullBufferBuilder};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::DataType::{FixedSizeList, LargeList, List};
use arrow::datatypes::{DataType, Field};
use arrow::{compute, compute::SortOptions};
use datafusion_common::cast::{as_list_array, as_string_array};
use datafusion_common::{exec_err, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
Expand Down Expand Up @@ -87,7 +88,30 @@ impl Default for ArraySort {
impl ArraySort {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: vec![ArrayFunctionArgument::Array],
array_coercion: None,
}),
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Array,
ArrayFunctionArgument::String,
],
array_coercion: None,
}),
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Array,
ArrayFunctionArgument::String,
ArrayFunctionArgument::String,
],
array_coercion: None,
}),
],
Volatility::Immutable,
),
aliases: vec!["list_sort".to_string()],
}
}
Expand Down Expand Up @@ -115,6 +139,7 @@ impl ScalarUDFImpl for ArraySort {
field.data_type().clone(),
true,
)))),
DataType::Null => Ok(DataType::Null),
_ => exec_err!(
"Not reachable, data_type should be List, LargeList or FixedSizeList"
),
Expand Down Expand Up @@ -143,6 +168,10 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
return exec_err!("array_sort expects one to three arguments");
}

if args[1..].iter().any(|array| array.is_null(0)) {
return Ok(new_null_array(args[0].data_type(), args[0].len()));
}

let sort_option = match args.len() {
1 => None,
2 => {
Expand Down Expand Up @@ -196,12 +225,16 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
.map(|a| a.as_ref())
.collect::<Vec<&dyn Array>>();

let list_arr = ListArray::new(
Arc::new(Field::new_list_field(data_type, true)),
OffsetBuffer::from_lengths(array_lengths),
Arc::new(compute::concat(elements.as_slice())?),
buffer,
);
let list_arr = if elements.is_empty() {
ListArray::new_null(Arc::new(Field::new_list_field(data_type, true)), row_count)
} else {
ListArray::new(
Arc::new(Field::new_list_field(data_type, true)),
OffsetBuffer::from_lengths(array_lengths),
Arc::new(compute::concat(elements.as_slice())?),
buffer,
)
};
Ok(Arc::new(list_arr))
}

Expand Down
69 changes: 61 additions & 8 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1204,8 +1204,10 @@ select array_element([1, 2], NULL);
----
NULL

query error
query I
select array_element(NULL, 2);
----
NULL

# array_element scalar function #1 (with positive index)
query IT
Expand Down Expand Up @@ -2265,6 +2267,52 @@ select array_sort([]);
----
[]

# test with null arguments
query ?
select array_sort(NULL);
----
NULL

query ?
select array_sort(column1, NULL) from arrays_values;
----
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL

query ??
select array_sort(column1, 'DESC', NULL), array_sort(column1, 'ASC', NULL) from arrays_values;
----
NULL NULL
NULL NULL
NULL NULL
NULL NULL
NULL NULL
NULL NULL
NULL NULL
NULL NULL

query ??
select array_sort(column1, NULL, 'NULLS FIRST'), array_sort(column1, NULL, 'NULLS LAST') from arrays_values;
----
NULL NULL
NULL NULL
NULL NULL
NULL NULL
NULL NULL
NULL NULL
NULL NULL
NULL NULL

## test with argument of incorrect types
query error DataFusion error: Execution error: the second parameter of array_sort expects DESC or ASC
select array_sort([1, 3, null, 5, NULL, -5], 1), array_sort([1, 3, null, 5, NULL, -5], 'DESC', 1), array_sort([1, 3, null, 5, NULL, -5], 1, 1);

# test with empty row, the row that does not match the condition has row count 0
statement ok
create table t1(a int, b int) as values (100, 1), (101, 2), (102, 3), (101, 2);
Expand All @@ -2290,8 +2338,10 @@ select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3

# array_append with NULLs

query error
query ?
select array_append(null, 1);
----
[1]

query error
select array_append(null, [2, 3]);
Expand Down Expand Up @@ -2539,8 +2589,10 @@ select array_append(column1, arrow_cast(make_array(1, 11, 111), 'FixedSizeList(3
# DuckDB: [4]
# ClickHouse: Null
# Since they don't have the same result, we now return [4] (matching DuckDB) instead of an error
query error
query ?
select array_prepend(4, NULL);
----
[4]

query ?
select array_prepend(4, []);
Expand Down Expand Up @@ -2575,11 +2627,10 @@ select array_prepend(null, [[1,2,3]]);
query error
select array_prepend([], []);

# DuckDB: [null]
# ClickHouse: [null]
# TODO: We may also return [null]
query error
query ?
select array_prepend(null, null);
----
[NULL]

query ?
select array_append([], null);
Expand Down Expand Up @@ -5264,9 +5315,11 @@ NULL [3] [5]
# array_ndims scalar function #1

#follow PostgreSQL
query error
query I
select
array_ndims(null);
----
NULL

query I
select
Expand Down

0 comments on commit 43ecd9b

Please sign in to comment.