diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index ba6fadbf7235..063417a254be 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -358,6 +358,8 @@ pub enum ArrayFunctionArgument { /// An argument of type List/LargeList/FixedSizeList. All Array arguments must be coercible /// to the same type. Array, + // A Utf8 argument. + String, } impl Display for ArrayFunctionArgument { @@ -372,6 +374,9 @@ impl Display for ArrayFunctionArgument { ArrayFunctionArgument::Array => { write!(f, "array") } + ArrayFunctionArgument::String => { + write!(f, "string") + } } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index b471feca043f..0ec017bdc27f 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -19,7 +19,7 @@ use super::binary::{binary_numeric_coercion, comparison_coercion}; use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; use arrow::{ compute::can_cast_types, - datatypes::{DataType, TimeUnit}, + datatypes::{DataType, Field, TimeUnit}, }; use datafusion_common::types::LogicalType; use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion}; @@ -387,7 +387,7 @@ fn get_valid_types( new_base_type = coerce_array_types(function_name, current_type, &new_base_type)?; } - ArrayFunctionArgument::Index => {} + ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {} } } let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only( @@ -408,6 +408,7 @@ fn get_valid_types( let valid_type = match argument_type { ArrayFunctionArgument::Element => new_elem_type.clone(), ArrayFunctionArgument::Index => DataType::Int64, + ArrayFunctionArgument::String => DataType::Utf8, ArrayFunctionArgument::Array => { let Some(current_type) = array(current_type) else { return Ok(vec![vec![]]); @@ -435,6 +436,10 @@ fn get_valid_types( match array_type { DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()), DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))), + DataType::Null => Some(DataType::List(Arc::new(Field::new_list_field( + DataType::Int64, + true, + )))), _ => None, } } diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 422b1b612850..0f50f62dd8d2 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -166,6 +166,7 @@ impl ScalarUDFImpl for ArrayElement { List(field) | LargeList(field) | FixedSizeList(field, _) => Ok(field.data_type().clone()), + DataType::Null => Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true)))), _ => plan_err!( "ArrayElement can only accept List, LargeList or FixedSizeList as the first argument" ), diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 71bfedb72d1c..3dbe672c5b02 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -18,8 +18,8 @@ //! [`ScalarUDFImpl`] definitions for array_replace, array_replace_n and array_replace_all functions. use arrow::array::{ - Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData, - NullBufferBuilder, OffsetSizeTrait, + new_null_array, Array, ArrayRef, AsArray, Capacities, GenericListArray, + MutableArrayData, NullBufferBuilder, OffsetSizeTrait, }; use arrow::datatypes::{DataType, Field}; @@ -429,6 +429,7 @@ pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result { let list_array = array.as_list::(); general_replace::(list_array, from, to, arr_n) } + DataType::Null => Ok(new_null_array(array.data_type(), 1)), array_type => exec_err!("array_replace does not support type '{array_type:?}'."), } } @@ -447,6 +448,7 @@ pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result { let list_array = array.as_list::(); general_replace::(list_array, from, to, arr_n) } + DataType::Null => Ok(new_null_array(array.data_type(), 1)), array_type => { exec_err!("array_replace_n does not support type '{array_type:?}'.") } @@ -467,6 +469,7 @@ pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result { let list_array = array.as_list::(); general_replace::(list_array, from, to, arr_n) } + DataType::Null => Ok(new_null_array(array.data_type(), 1)), array_type => { exec_err!("array_replace_all does not support type '{array_type:?}'.") } diff --git a/datafusion/functions-nested/src/resize.rs b/datafusion/functions-nested/src/resize.rs index 6c0b91a678e7..145d7e80043b 100644 --- a/datafusion/functions-nested/src/resize.rs +++ b/datafusion/functions-nested/src/resize.rs @@ -23,16 +23,18 @@ use arrow::array::{ MutableArrayData, NullBufferBuilder, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; -use arrow::datatypes::ArrowNativeType; use arrow::datatypes::DataType; +use arrow::datatypes::{ArrowNativeType, Field}; use arrow::datatypes::{ DataType::{FixedSizeList, LargeList, List}, FieldRef, }; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::utils::ListCoercion; use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -83,7 +85,26 @@ impl Default for ArrayResize { impl ArrayResize { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::one_of( + vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Element, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }), + ], + Volatility::Immutable, + ), aliases: vec!["list_resize".to_string()], } } @@ -106,6 +127,9 @@ impl ScalarUDFImpl for ArrayResize { match &arg_types[0] { List(field) | FixedSizeList(field, _) => Ok(List(Arc::clone(field))), LargeList(field) => Ok(LargeList(Arc::clone(field))), + DataType::Null => { + Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true)))) + } _ => exec_err!( "Not reachable, data_type should be List, LargeList or FixedSizeList" ), @@ -137,7 +161,7 @@ pub(crate) fn array_resize_inner(arg: &[ArrayRef]) -> Result { let array = &arg[0]; // Checks if entire array is null - if array.null_count() == array.len() { + if array.logical_null_count() == array.len() { let return_type = match array.data_type() { List(field) => List(Arc::clone(field)), LargeList(field) => LargeList(Arc::clone(field)), diff --git a/datafusion/functions-nested/src/sort.rs b/datafusion/functions-nested/src/sort.rs index 7dbf9f2b211e..1db245fe52fe 100644 --- a/datafusion/functions-nested/src/sort.rs +++ b/datafusion/functions-nested/src/sort.rs @@ -18,7 +18,7 @@ //! [`ScalarUDFImpl`] definitions for array_sort function. use crate::utils::make_scalar_function; -use arrow::array::{Array, ArrayRef, ListArray, NullBufferBuilder}; +use arrow::array::{new_null_array, Array, ArrayRef, ListArray, NullBufferBuilder}; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType::{FixedSizeList, LargeList, List}; use arrow::datatypes::{DataType, Field}; @@ -26,7 +26,8 @@ use arrow::{compute, compute::SortOptions}; use datafusion_common::cast::{as_list_array, as_string_array}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -87,7 +88,30 @@ impl Default for ArraySort { impl ArraySort { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::one_of( + vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::String, + ], + array_coercion: None, + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::String, + ArrayFunctionArgument::String, + ], + array_coercion: None, + }), + ], + Volatility::Immutable, + ), aliases: vec!["list_sort".to_string()], } } @@ -115,6 +139,7 @@ impl ScalarUDFImpl for ArraySort { field.data_type().clone(), true, )))), + DataType::Null => Ok(DataType::Null), _ => exec_err!( "Not reachable, data_type should be List, LargeList or FixedSizeList" ), @@ -143,6 +168,10 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { return exec_err!("array_sort expects one to three arguments"); } + if args[1..].iter().any(|array| array.is_null(0)) { + return Ok(new_null_array(args[0].data_type(), args[0].len())); + } + let sort_option = match args.len() { 1 => None, 2 => { @@ -196,12 +225,16 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { .map(|a| a.as_ref()) .collect::>(); - let list_arr = ListArray::new( - Arc::new(Field::new_list_field(data_type, true)), - OffsetBuffer::from_lengths(array_lengths), - Arc::new(compute::concat(elements.as_slice())?), - buffer, - ); + let list_arr = if elements.is_empty() { + ListArray::new_null(Arc::new(Field::new_list_field(data_type, true)), row_count) + } else { + ListArray::new( + Arc::new(Field::new_list_field(data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + buffer, + ) + }; Ok(Arc::new(list_arr)) } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 6b5b246aee51..c8f6a985bb22 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1204,8 +1204,10 @@ select array_element([1, 2], NULL); ---- NULL -query error +query I select array_element(NULL, 2); +---- +NULL # array_element scalar function #1 (with positive index) query IT @@ -2265,6 +2267,52 @@ select array_sort([]); ---- [] +# test with null arguments +query ? +select array_sort(NULL); +---- +NULL + +query ? +select array_sort(column1, NULL) from arrays_values; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query ?? +select array_sort(column1, 'DESC', NULL), array_sort(column1, 'ASC', NULL) from arrays_values; +---- +NULL NULL +NULL NULL +NULL NULL +NULL NULL +NULL NULL +NULL NULL +NULL NULL +NULL NULL + +query ?? +select array_sort(column1, NULL, 'NULLS FIRST'), array_sort(column1, NULL, 'NULLS LAST') from arrays_values; +---- +NULL NULL +NULL NULL +NULL NULL +NULL NULL +NULL NULL +NULL NULL +NULL NULL +NULL NULL + +## test with argument of incorrect types +query error DataFusion error: Execution error: the second parameter of array_sort expects DESC or ASC +select array_sort([1, 3, null, 5, NULL, -5], 1), array_sort([1, 3, null, 5, NULL, -5], 'DESC', 1), array_sort([1, 3, null, 5, NULL, -5], 1, 1); + # test with empty row, the row that does not match the condition has row count 0 statement ok create table t1(a int, b int) as values (100, 1), (101, 2), (102, 3), (101, 2); @@ -2290,8 +2338,10 @@ select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3 # array_append with NULLs -query error +query ? select array_append(null, 1); +---- +[1] query error select array_append(null, [2, 3]); @@ -2539,8 +2589,10 @@ select array_append(column1, arrow_cast(make_array(1, 11, 111), 'FixedSizeList(3 # DuckDB: [4] # ClickHouse: Null # Since they dont have the same result, we just follow Postgres, return error -query error +query ? select array_prepend(4, NULL); +---- +[4] query ? select array_prepend(4, []); @@ -2575,11 +2627,10 @@ select array_prepend(null, [[1,2,3]]); query error select array_prepend([], []); -# DuckDB: [null] -# ClickHouse: [null] -# TODO: We may also return [null] -query error +query ? select array_prepend(null, null); +---- +[NULL] query ? select array_append([], null); @@ -5264,9 +5315,11 @@ NULL [3] [5] # array_ndims scalar function #1 #follow PostgreSQL -query error +query I select array_ndims(null); +---- +NULL query I select