From 71f9d0c1d69c2404422f9f19bf78f0062c6b2c5e Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 17 Feb 2025 21:18:07 +0800 Subject: [PATCH] Signature::Coercible with user defined implicit casting (#14440) * coerciblev2 Signed-off-by: Jay Zhan * repeat Signed-off-by: Jay Zhan * fix possible types * replace all coerciblev1 * cleanup * remove specialize logic * comment * err msg * ci escape * rm coerciblev1 Signed-off-by: Jay Zhan * fmt * rename * rename * refactor * make default_casted_type private * cleanup * fmt * integer * rm binary for ascii * rm unused * conflit * fmt * Rename get_example_types, make method on TypeSignatureClass * Move more logic into TypeSignatureClass * fix docs * 46 * enum * fmt * fmt * doc * upd doc --------- Signed-off-by: Jay Zhan Co-authored-by: Andrew Lamb --- Cargo.lock | 1 + datafusion/catalog/src/information_schema.rs | 6 +- datafusion/common/src/types/native.rs | 9 + datafusion/expr-common/Cargo.toml | 1 + datafusion/expr-common/src/signature.rs | 359 +++++++++++++++--- .../expr/src/type_coercion/functions.rs | 107 ++---- .../functions/src/datetime/date_part.rs | 28 +- datafusion/functions/src/string/ascii.rs | 11 +- datafusion/functions/src/string/repeat.rs | 13 +- datafusion/sqllogictest/test_files/expr.slt | 12 +- 10 files changed, 403 insertions(+), 144 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 29d88d80aac1..bc8b2943b246 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2009,6 +2009,7 @@ version = "45.0.0" dependencies = [ "arrow", "datafusion-common", + "indexmap 2.7.1", "itertools 0.14.0", "paste", ] diff --git a/datafusion/catalog/src/information_schema.rs b/datafusion/catalog/src/information_schema.rs index e68e636989f8..7948c0299d39 100644 --- a/datafusion/catalog/src/information_schema.rs +++ b/datafusion/catalog/src/information_schema.rs @@ -405,7 +405,7 @@ fn get_udf_args_and_return_types( udf: &Arc, ) -> Result, Option)>> { let signature = udf.signature(); - let arg_types = 
signature.type_signature.get_possible_types(); + let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { Ok(vec![(vec![], None)]) } else { @@ -428,7 +428,7 @@ fn get_udaf_args_and_return_types( udaf: &Arc, ) -> Result, Option)>> { let signature = udaf.signature(); - let arg_types = signature.type_signature.get_possible_types(); + let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { Ok(vec![(vec![], None)]) } else { @@ -452,7 +452,7 @@ fn get_udwf_args_and_return_types( udwf: &Arc, ) -> Result, Option)>> { let signature = udwf.signature(); - let arg_types = signature.type_signature.get_possible_types(); + let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { Ok(vec![(vec![], None)]) } else { diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index a4c4dfc7b106..39c79b4b9974 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -198,6 +198,11 @@ impl LogicalType for NativeType { TypeSignature::Native(self) } + /// Returns the default casted type for the given arrow type + /// + /// For types like String or Date, multiple arrow types mapped to the same logical type + /// If the given arrow type is one of them, we return the same type + /// Otherwise, we define the default casted type for the given arrow type fn default_cast_for(&self, origin: &DataType) -> Result { use DataType::*; @@ -226,6 +231,10 @@ impl LogicalType for NativeType { (Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s), (Self::Decimal(p, s), _) => Decimal256(*p, *s), (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), + // If given type is Date, return the same type + (Self::Date, origin) if matches!(origin, Date32 | Date64) => { + origin.to_owned() + } (Self::Date, _) => Date32, (Self::Time(tu), _) => match tu { TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu), diff --git 
a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index 109d8e0b89a6..abc78a9f084b 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -39,5 +39,6 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } datafusion-common = { workspace = true } +indexmap = { workspace = true } itertools = { workspace = true } paste = "^1.0" diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 4ca4961d7b63..ba6fadbf7235 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -19,11 +19,14 @@ //! and return types of functions in DataFusion. use std::fmt::Display; +use std::hash::Hash; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::types::{LogicalTypeRef, NativeType}; +use datafusion_common::internal_err; +use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; use datafusion_common::utils::ListCoercion; +use indexmap::IndexSet; use itertools::Itertools; /// Constant that is used as a placeholder for any valid timezone. @@ -127,12 +130,11 @@ pub enum TypeSignature { Exact(Vec), /// One or more arguments belonging to the [`TypeSignatureClass`], in order. /// - /// For example, `Coercible(vec![logical_float64()])` accepts - /// arguments like `vec![Int32]` or `vec![Float32]` - /// since i32 and f32 can be cast to f64 + /// [`Coercion`] contains not only the desired type but also the allowed casts. + /// For example, if you expect a function has string type, but you also allow it to be casted from binary type. /// /// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`]. - Coercible(Vec), + Coercible(Vec), /// One or more arguments coercible to a single, comparable type. 
/// /// Each argument will be coerced to a single type using the @@ -209,14 +211,13 @@ impl TypeSignature { #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)] pub enum TypeSignatureClass { Timestamp, - Date, Time, Interval, Duration, Native(LogicalTypeRef), // TODO: // Numeric - // Integer + Integer, } impl Display for TypeSignatureClass { @@ -225,6 +226,89 @@ impl Display for TypeSignatureClass { } } +impl TypeSignatureClass { + /// Get example acceptable types for this `TypeSignatureClass` + /// + /// This is used for `information_schema` and can be used to generate + /// documentation or error messages. + fn get_example_types(&self) -> Vec { + match self { + TypeSignatureClass::Native(l) => get_data_types(l.native()), + TypeSignatureClass::Timestamp => { + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp( + TimeUnit::Nanosecond, + Some(TIMEZONE_WILDCARD.into()), + ), + ] + } + TypeSignatureClass::Time => { + vec![DataType::Time64(TimeUnit::Nanosecond)] + } + TypeSignatureClass::Interval => { + vec![DataType::Interval(IntervalUnit::DayTime)] + } + TypeSignatureClass::Duration => { + vec![DataType::Duration(TimeUnit::Nanosecond)] + } + TypeSignatureClass::Integer => { + vec![DataType::Int64] + } + } + } + + /// Does the specified `NativeType` match this type signature class? 
+ pub fn matches_native_type( + self: &TypeSignatureClass, + logical_type: &NativeType, + ) -> bool { + if logical_type == &NativeType::Null { + return true; + } + + match self { + TypeSignatureClass::Native(t) if t.native() == logical_type => true, + TypeSignatureClass::Timestamp if logical_type.is_timestamp() => true, + TypeSignatureClass::Time if logical_type.is_time() => true, + TypeSignatureClass::Interval if logical_type.is_interval() => true, + TypeSignatureClass::Duration if logical_type.is_duration() => true, + TypeSignatureClass::Integer if logical_type.is_integer() => true, + _ => false, + } + } + + /// What type would `origin_type` be casted to when casting to the specified native type? + pub fn default_casted_type( + &self, + native_type: &NativeType, + origin_type: &DataType, + ) -> datafusion_common::Result { + match self { + TypeSignatureClass::Native(logical_type) => { + logical_type.native().default_cast_for(origin_type) + } + // If the given type is already a timestamp, we don't change the unit and timezone + TypeSignatureClass::Timestamp if native_type.is_timestamp() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Time if native_type.is_time() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Interval if native_type.is_interval() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Duration if native_type.is_duration() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Integer if native_type.is_integer() => { + Ok(origin_type.to_owned()) + } + _ => internal_err!("May miss the matching logic in `matches_native_type`"), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ArrayFunctionSignature { /// A function takes at least one List/LargeList/FixedSizeList argument. 
@@ -316,8 +400,8 @@ impl TypeSignature { TypeSignature::Comparable(num) => { vec![format!("Comparable({num})")] } - TypeSignature::Coercible(types) => { - vec![Self::join_types(types, ", ")] + TypeSignature::Coercible(coercions) => { + vec![Self::join_types(coercions, ", ")] } TypeSignature::Exact(types) => { vec![Self::join_types(types, ", ")] @@ -371,44 +455,45 @@ impl TypeSignature { } } - /// get all possible types for the given `TypeSignature` + #[deprecated(since = "46.0.0", note = "See get_example_types instead")] pub fn get_possible_types(&self) -> Vec> { + self.get_example_types() + } + + /// Return example acceptable types for this `TypeSignature`' + /// + /// Returns a `Vec` for each argument to the function + /// + /// This is used for `information_schema` and can be used to generate + /// documentation or error messages. + pub fn get_example_types(&self) -> Vec> { match self { TypeSignature::Exact(types) => vec![types.clone()], TypeSignature::OneOf(types) => types .iter() - .flat_map(|type_sig| type_sig.get_possible_types()) + .flat_map(|type_sig| type_sig.get_example_types()) .collect(), TypeSignature::Uniform(arg_count, types) => types .iter() .cloned() .map(|data_type| vec![data_type; *arg_count]) .collect(), - TypeSignature::Coercible(types) => types + TypeSignature::Coercible(coercions) => coercions .iter() - .map(|logical_type| match logical_type { - TypeSignatureClass::Native(l) => get_data_types(l.native()), - TypeSignatureClass::Timestamp => { - vec![ - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp( - TimeUnit::Nanosecond, - Some(TIMEZONE_WILDCARD.into()), - ), - ] - } - TypeSignatureClass::Date => { - vec![DataType::Date64] - } - TypeSignatureClass::Time => { - vec![DataType::Time64(TimeUnit::Nanosecond)] - } - TypeSignatureClass::Interval => { - vec![DataType::Interval(IntervalUnit::DayTime)] - } - TypeSignatureClass::Duration => { - vec![DataType::Duration(TimeUnit::Nanosecond)] + .map(|c| { + let mut all_types: 
IndexSet = + c.desired_type().get_example_types().into_iter().collect(); + + if let Some(implicit_coercion) = c.implicit_coercion() { + let allowed_casts: Vec = implicit_coercion + .allowed_source_types + .iter() + .flat_map(|t| t.get_example_types()) + .collect(); + all_types.extend(allowed_casts); } + + all_types.into_iter().collect::>() }) .multi_cartesian_product() .collect(), @@ -466,6 +551,186 @@ fn get_data_types(native_type: &NativeType) -> Vec { } } +/// Represents type coercion rules for function arguments, specifying both the desired type +/// and optional implicit coercion rules for source types. +/// +/// # Examples +/// +/// ``` +/// use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; +/// use datafusion_common::types::{NativeType, logical_binary, logical_string}; +/// +/// // Exact coercion that only accepts timestamp types +/// let exact = Coercion::new_exact(TypeSignatureClass::Timestamp); +/// +/// // Implicit coercion that accepts string types but can coerce from binary types +/// let implicit = Coercion::new_implicit( +/// TypeSignatureClass::Native(logical_string()), +/// vec![TypeSignatureClass::Native(logical_binary())], +/// NativeType::String +/// ); +/// ``` +/// +/// There are two variants: +/// +/// * `Exact` - Only accepts arguments that exactly match the desired type +/// * `Implicit` - Accepts the desired type and can coerce from specified source types +#[derive(Debug, Clone, Eq, PartialOrd)] +pub enum Coercion { + /// Coercion that only accepts arguments exactly matching the desired type. + Exact { + /// The required type for the argument + desired_type: TypeSignatureClass, + }, + + /// Coercion that accepts the desired type and can implicitly coerce from other types. 
+ Implicit { + /// The primary desired type for the argument + desired_type: TypeSignatureClass, + /// Rules for implicit coercion from other types + implicit_coercion: ImplicitCoercion, + }, +} + +impl Coercion { + pub fn new_exact(desired_type: TypeSignatureClass) -> Self { + Self::Exact { desired_type } + } + + /// Create a new coercion with implicit coercion rules. + /// + /// `allowed_source_types` defines the possible types that can be coerced to `desired_type`. + /// `default_casted_type` is the default type to be used for coercion if we cast from other types via `allowed_source_types`. + pub fn new_implicit( + desired_type: TypeSignatureClass, + allowed_source_types: Vec, + default_casted_type: NativeType, + ) -> Self { + Self::Implicit { + desired_type, + implicit_coercion: ImplicitCoercion { + allowed_source_types, + default_casted_type, + }, + } + } + + pub fn allowed_source_types(&self) -> &[TypeSignatureClass] { + match self { + Coercion::Exact { .. } => &[], + Coercion::Implicit { + implicit_coercion, .. + } => implicit_coercion.allowed_source_types.as_slice(), + } + } + + pub fn default_casted_type(&self) -> Option<&NativeType> { + match self { + Coercion::Exact { .. } => None, + Coercion::Implicit { + implicit_coercion, .. + } => Some(&implicit_coercion.default_casted_type), + } + } + + pub fn desired_type(&self) -> &TypeSignatureClass { + match self { + Coercion::Exact { desired_type } => desired_type, + Coercion::Implicit { desired_type, .. } => desired_type, + } + } + + pub fn implicit_coercion(&self) -> Option<&ImplicitCoercion> { + match self { + Coercion::Exact { .. } => None, + Coercion::Implicit { + implicit_coercion, .. 
+ } => Some(implicit_coercion), + } + } +} + +impl Display for Coercion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Coercion({}", self.desired_type())?; + if let Some(implicit_coercion) = self.implicit_coercion() { + write!(f, ", implicit_coercion={implicit_coercion}",) + } else { + write!(f, ")") + } + } +} + +impl PartialEq for Coercion { + fn eq(&self, other: &Self) -> bool { + self.desired_type() == other.desired_type() + && self.implicit_coercion() == other.implicit_coercion() + } +} + +impl Hash for Coercion { + fn hash(&self, state: &mut H) { + self.desired_type().hash(state); + self.implicit_coercion().hash(state); + } +} + +/// Defines rules for implicit type coercion, specifying which source types can be +/// coerced and the default type to use when coercing. +/// +/// This is used by functions to specify which types they can accept via implicit +/// coercion in addition to their primary desired type. +/// +/// # Examples +/// +/// ``` +/// use arrow::datatypes::TimeUnit; +/// +/// use datafusion_expr_common::signature::{Coercion, ImplicitCoercion, TypeSignatureClass}; +/// use datafusion_common::types::{NativeType, logical_binary}; +/// +/// // Allow coercing from binary types to timestamp, coerce to specific timestamp unit and timezone +/// let implicit = Coercion::new_implicit( +/// TypeSignatureClass::Timestamp, +/// vec![TypeSignatureClass::Native(logical_binary())], +/// NativeType::Timestamp(TimeUnit::Second, None), +/// ); +/// ``` +#[derive(Debug, Clone, Eq, PartialOrd)] +pub struct ImplicitCoercion { + /// The types that can be coerced from via implicit casting + allowed_source_types: Vec, + + /// The default type to use when coercing from allowed source types. + /// This is particularly important for types like Timestamp that have multiple + /// possible configurations (different time units and timezones). 
+ default_casted_type: NativeType, +} + +impl Display for ImplicitCoercion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ImplicitCoercion({:?}, default_type={:?})", + self.allowed_source_types, self.default_casted_type + ) + } +} + +impl PartialEq for ImplicitCoercion { + fn eq(&self, other: &Self) -> bool { + self.allowed_source_types == other.allowed_source_types + && self.default_casted_type == other.default_casted_type + } +} + +impl Hash for ImplicitCoercion { + fn hash(&self, state: &mut H) { + self.allowed_source_types.hash(state); + self.default_casted_type.hash(state); + } +} + /// Defines the supported argument types ([`TypeSignature`]) and [`Volatility`] for a function. /// /// DataFusion will automatically coerce (cast) argument types to one of the supported @@ -542,11 +807,9 @@ impl Signature { volatility, } } + /// Target coerce types in order - pub fn coercible( - target_types: Vec, - volatility: Volatility, - ) -> Self { + pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { Self { type_signature: TypeSignature::Coercible(target_types), volatility, @@ -721,14 +984,14 @@ mod tests { #[test] fn test_get_possible_types() { let type_signature = TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!(possible_types, vec![vec![DataType::Int32, DataType::Int64]]); let type_signature = TypeSignature::OneOf(vec![ TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]), TypeSignature::Exact(vec![DataType::Float32, DataType::Float64]), ]); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ @@ -742,7 +1005,7 @@ mod tests { TypeSignature::Exact(vec![DataType::Float32, DataType::Float64]), TypeSignature::Exact(vec![DataType::Utf8]), ]); - let possible_types = 
type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ @@ -754,7 +1017,7 @@ mod tests { let type_signature = TypeSignature::Uniform(2, vec![DataType::Float32, DataType::Int64]); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ @@ -764,10 +1027,10 @@ mod tests { ); let type_signature = TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Native(logical_int64()), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_int64())), ]); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ @@ -779,14 +1042,14 @@ mod tests { let type_signature = TypeSignature::Variadic(vec![DataType::Int32, DataType::Int64]); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![vec![DataType::Int32], vec![DataType::Int64]] ); let type_signature = TypeSignature::Numeric(2); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ @@ -804,7 +1067,7 @@ mod tests { ); let type_signature = TypeSignature::String(2); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 7fda92862be9..b471feca043f 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,19 +21,15 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; +use 
datafusion_common::types::LogicalType; use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, - types::{LogicalType, NativeType}, - utils::list_ndims, - Result, + exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType, + utils::list_ndims, Result, }; use datafusion_expr_common::signature::ArrayFunctionArgument; use datafusion_expr_common::{ - signature::{ - ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD, - TIMEZONE_WILDCARD, - }, + signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, type_coercion::binary::comparison_coercion_numeric, type_coercion::binary::string_coercion, }; @@ -604,75 +600,36 @@ fn get_valid_types( vec![vec![target_type; *num]] } } - TypeSignature::Coercible(target_types) => { - function_length_check( - function_name, - current_types.len(), - target_types.len(), - )?; - - // Aim to keep this logic as SIMPLE as possible! - // Make sure the corresponding test is covered - // If this function becomes COMPLEX, create another new signature! 
- fn can_coerce_to( - function_name: &str, - current_type: &DataType, - target_type_class: &TypeSignatureClass, - ) -> Result { - let logical_type: NativeType = current_type.into(); - - match target_type_class { - TypeSignatureClass::Native(native_type) => { - let target_type = native_type.native(); - if &logical_type == target_type { - return target_type.default_cast_for(current_type); - } - - if logical_type == NativeType::Null { - return target_type.default_cast_for(current_type); - } - - if target_type.is_integer() && logical_type.is_integer() { - return target_type.default_cast_for(current_type); - } - - internal_err!( - "Function '{function_name}' expects {target_type_class} but received {current_type}" - ) - } - // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp - TypeSignatureClass::Timestamp - if logical_type == NativeType::String => - { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Date if logical_type.is_date() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Time if logical_type.is_time() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Interval if logical_type.is_interval() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Duration if logical_type.is_duration() => { - Ok(current_type.to_owned()) - } - _ => { - not_impl_err!("Function '{function_name}' got logical_type: {logical_type} with target_type_class: {target_type_class}") - } - } - } + TypeSignature::Coercible(param_types) => { + function_length_check(function_name, current_types.len(), param_types.len())?; let mut new_types = Vec::with_capacity(current_types.len()); - for (current_type, target_type_class) in - current_types.iter().zip(target_types.iter()) - { - let target_type = can_coerce_to(function_name, current_type, target_type_class)?; - new_types.push(target_type); + 
for (current_type, param) in current_types.iter().zip(param_types.iter()) { + let current_native_type: NativeType = current_type.into(); + + if param.desired_type().matches_native_type(&current_native_type) { + let casted_type = param.desired_type().default_casted_type( + &current_native_type, + current_type, + )?; + + new_types.push(casted_type); + } else if param + .allowed_source_types() + .iter() + .any(|t| t.matches_native_type(&current_native_type)) { + // If the condition is met which means `implicit coercion`` is provided so we can safely unwrap + let default_casted_type = param.default_casted_type().unwrap(); + let casted_type = default_casted_type.default_cast_for(current_type)?; + new_types.push(casted_type); + } else { + return internal_err!( + "Expect {} but received {}, DataType: {}", + param.desired_type(), + current_native_type, + current_type + ); + } } vec![new_types] diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 9df91da67f39..49b7a4ec462a 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -27,6 +27,7 @@ use arrow::datatypes::DataType::{ }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; +use datafusion_common::types::{logical_date, NativeType}; use datafusion_common::{ cast::{ @@ -44,7 +45,7 @@ use datafusion_expr::{ ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use datafusion_expr_common::signature::TypeSignatureClass; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; #[user_doc( @@ -95,24 +96,29 @@ impl DatePartFunc { signature: Signature::one_of( vec![ TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Timestamp, + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + 
Coercion::new_implicit( + TypeSignatureClass::Timestamp, + // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(Nanosecond, None), + ), ]), TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Date, + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_date())), ]), TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Time, + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Time), ]), TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Interval, + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Interval), ]), TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Duration, + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Duration), ]), ], Volatility::Immutable, diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 858eddc7c8f8..3832ad2a341d 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -19,9 +19,11 @@ use crate::utils::make_scalar_function; use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; use arrow::datatypes::DataType; use arrow::error::ArrowError; +use datafusion_common::types::logical_string; use datafusion_common::{internal_err, Result}; -use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use 
datafusion_expr_common::signature::Coercion; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -61,7 +63,12 @@ impl Default for AsciiFunc { impl AsciiFunc { pub fn new() -> Self { Self { - signature: Signature::string(1, Volatility::Immutable), + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Native( + logical_string(), + ))], + Volatility::Immutable, + ), } } } diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 8253754c2b83..8fdbc3dd296f 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -26,11 +26,11 @@ use arrow::array::{ use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; -use datafusion_common::types::{logical_int64, logical_string}; +use datafusion_common::types::{logical_int64, logical_string, NativeType}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use datafusion_expr_common::signature::TypeSignatureClass; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; #[user_doc( @@ -67,8 +67,13 @@ impl RepeatFunc { Self { signature: Signature::coercible( vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Native(logical_int64()), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + // Accept all integer types but cast them to i64 + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), ], Volatility::Immutable, ), diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index a0264c43622f..7980b180ae68 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ 
b/datafusion/sqllogictest/test_files/expr.slt @@ -324,6 +324,16 @@ SELECT ascii('x') ---- 120 +query I +SELECT ascii('222') +---- +50 + +query I +SELECT ascii('0xa') +---- +48 + query I SELECT ascii(NULL) ---- @@ -571,7 +581,7 @@ select repeat('-1.2', arrow_cast(3, 'Int32')); ---- -1.2-1.2-1.2 -query error DataFusion error: Error during planning: Internal error: Function 'repeat' expects TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received Float64 +query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received NativeType::Float64, DataType: Float64 select repeat('-1.2', 3.2); query T