diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index a58a8cf51681..bc805383df36 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -18,14 +18,19 @@ //! This module contains end to end demonstrations of creating //! user defined aggregate functions -use arrow::{array::AsArray, datatypes::Fields}; -use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::Schema; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; +use arrow::{array::AsArray, datatypes::Fields}; +use arrow_array::{ + types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray, +}; +use arrow_schema::Schema; + +use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; use datafusion::test_util::plan_and_collect; use datafusion::{ @@ -45,7 +50,8 @@ use datafusion::{ }; use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, + col, create_udaf, AggregateUDFImpl, GroupsAccumulator, LogicalPlanBuilder, + SimpleAggregateUDF, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -376,6 +382,56 @@ async fn test_groups_accumulator() -> Result<()> { Ok(()) } +#[ignore] +#[tokio::test] +async fn test_parameterized_aggregate_udf() -> Result<()> { + let batch = RecordBatch::try_from_iter([( + "text", + Arc::new(StringArray::from(vec!["foo"])) as ArrayRef, + )])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable); + let udf1 = AggregateUDF::from(TestGroupsAccumulator { + signature: signature.clone(), + result: 1, + }); + let udf2 = AggregateUDF::from(TestGroupsAccumulator { + signature: signature.clone(), + result: 2, + }); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .aggregate( + [col("text")], + [ + udf1.call(vec![col("text")]).alias("a"), + udf2.call(vec![col("text")]).alias("b"), + ], + )? + .build()?; + + assert_eq!( + format!("{plan:?}"), + "Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n TableScan: t projection=[text]" + ); + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + let expected = [ + "+------+---+---+", + "| text | a | b |", + "+------+---+---+", + "| foo | 1 | 2 |", + "+------+---+---+", + ]; + assert_batches_eq!(expected, &actual); + + ctx.deregister_table("t")?; + Ok(()) +} + /// Returns an context with a table "t" and the "first" and "time_sum" /// aggregate functions registered. /// @@ -733,6 +789,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator { fn create_groups_accumulator(&self) -> Result> { Ok(Box::new(self.clone())) } + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.result == other.result && self.signature == other.signature + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.signature.hash(hasher); + self.result.hash(hasher); + hasher.finish() + } } impl Accumulator for TestGroupsAccumulator { diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 86be887198ae..d0938aebd269 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,12 +15,23 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::iter; +use std::sync::Arc; + use arrow::compute::kernels::numeric::add; +use arrow_array::builder::BooleanBuilder; +use arrow_array::cast::AsArray; +use arrow_array::StringArray; use arrow_array::{ Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array, }; use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; +use rand::{thread_rng, Rng}; +use regex::Regex; + use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; @@ -36,10 +47,6 @@ use datafusion_expr::{ create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; -use rand::{thread_rng, Rng}; -use std::any::Any; -use std::iter; -use std::sync::Arc; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and @@ -961,6 +968,121 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { Ok(()) } +#[derive(Debug)] +struct MyRegexUdf { + signature: Signature, + regex: Regex, +} + +impl MyRegexUdf { + fn new(pattern: &str) -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + regex: Regex::new(pattern).expect("regex"), + } + } + + fn matches(&self, value: Option<&str>) -> Option { + Some(self.regex.is_match(value?)) + } +} + +impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regex_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> Result { + if matches!(args, [DataType::Utf8]) { + Ok(DataType::Boolean) + } else { + plan_err!("regex_udf only accepts a Utf8 argument") + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args { + [ColumnarValue::Scalar(ScalarValue::Utf8(value))] => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean( + self.matches(value.as_deref()), + ))) + } + [ColumnarValue::Array(values)] => { + let mut builder = BooleanBuilder::with_capacity(values.len()); + for value in values.as_string::() { + builder.append_option(self.matches(value)) + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + _ => exec_err!("regex_udf only accepts a Utf8 arguments"), + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.regex.as_str() == other.regex.as_str() + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.regex.as_str().hash(hasher); + hasher.finish() + } +} + +#[tokio::test] +async fn test_parameterized_scalar_udf() -> Result<()> { + let batch = RecordBatch::try_from_iter([( + "text", + Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef, + )])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}")); + let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar")); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .filter( + foo_udf + .call(vec![col("text")]) + .and(bar_udf.call(vec![col("text")])), + )? + .filter(col("text").is_not_null())? + .build()?; + + assert_eq!( + format!("{plan:?}"), + "Filter: t.text IS NOT NULL\n Filter: regex_udf(t.text) AND regex_udf(t.text)\n TableScan: t projection=[text]" + ); + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + let expected = [ + "+--------+", + "| text |", + "+--------+", + "| foobar |", + "| barfoo |", + "+--------+", + ]; + assert_batches_eq!(expected, &actual); + + ctx.deregister_table("t")?; + Ok(()) +} + fn create_udf_context() -> SessionContext { let ctx = SessionContext::new(); // register a custom UDF diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index c46dd9cd3a6f..7ad65c64316f 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,16 +17,20 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::sync::Arc; + +use arrow::datatypes::DataType; + +use datafusion_common::{not_impl_err, Result}; + use crate::groups_accumulator::GroupsAccumulator; use crate::{Accumulator, Expr}; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; -use arrow::datatypes::DataType; -use datafusion_common::{not_impl_err, Result}; -use std::any::Any; -use std::fmt::{self, Debug, Formatter}; -use std::sync::Arc; /// Logical representation of a user-defined [aggregate function] (UDAF). /// @@ -66,16 +70,15 @@ pub struct AggregateUDF { impl PartialEq for AggregateUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for AggregateUDF {} -impl std::hash::Hash for AggregateUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for AggregateUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } @@ -218,7 +221,7 @@ where /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { /// signature: Signature -/// }; +/// } /// /// impl GeoMeanUdf { /// fn new() -> Self { @@ -298,6 +301,33 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn aliases(&self) -> &[String] { &[] } + + /// Return true if this aggregate UDF is equal to the other. + /// + /// Allows customizing the equality of aggregate UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this aggregate UDF. + /// + /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } /// AggregateUDF that adds an alias to the underlying function. It is better to @@ -348,6 +378,21 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 56266a05170b..ce6ff5a47b92 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,18 +17,20 @@ //! [`ScalarUDF`]: Scalar User Defined Functions +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::sync::Arc; + +use arrow::datatypes::DataType; + +use datafusion_common::{ExprSchema, Result}; + use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::{ ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, ScalarFunctionImplementation, Signature, }; -use arrow::datatypes::DataType; -use datafusion_common::{ExprSchema, Result}; -use std::any::Any; -use std::fmt; -use std::fmt::Debug; -use std::fmt::Formatter; -use std::sync::Arc; /// Logical representation of a Scalar User Defined Function. /// @@ -57,16 +59,15 @@ pub struct ScalarUDF { impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for ScalarUDF {} -impl std::hash::Hash for ScalarUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for ScalarUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } @@ -224,7 +225,7 @@ where /// #[derive(Debug)] /// struct AddOne { /// signature: Signature -/// }; +/// } /// /// impl AddOne { /// fn new() -> Self { @@ -376,6 +377,33 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { ) -> Result { Ok(ExprSimplifyResult::Original(args)) } + + /// Return true if this scalar UDF is equal to the other. + /// + /// Allows customizing the equality of scalar UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this scalar UDF. + /// + /// Allows customizing the hash code of scalar UDFs. Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -393,7 +421,6 @@ impl AliasedScalarUDFImpl { ) -> Self { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } } } @@ -421,6 +448,21 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } } /// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index d3925f2e1925..1fd1af86cf0a 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -17,18 +17,22 @@ //! [`WindowUDF`]: User Defined Window Functions -use crate::{ - Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, - WindowFrame, -}; -use arrow::datatypes::DataType; -use datafusion_common::Result; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; +use arrow::datatypes::DataType; + +use datafusion_common::Result; + +use crate::{ + Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, + WindowFrame, +}; + /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. /// @@ -62,16 +66,15 @@ impl Display for WindowUDF { impl PartialEq for WindowUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for WindowUDF {} -impl std::hash::Hash for WindowUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for WindowUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } @@ -205,7 +208,7 @@ where /// #[derive(Debug, Clone)] /// struct SmoothIt { /// signature: Signature -/// }; +/// } /// /// impl SmoothIt { /// fn new() -> Self { @@ -266,6 +269,33 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn aliases(&self) -> &[String] { &[] } + + /// Return true if this window UDF is equal to the other. + /// + /// Allows customizing the equality of window UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this window UDF. + /// + /// Allows customizing the hash code of window UDFs. Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } /// WindowUDF that adds an alias to the underlying function. It is better to @@ -312,6 +342,21 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } } /// Implementation of [`WindowUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index c7907d0db16a..456c93b24a08 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -864,7 +864,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( } } Expr::Literal(_) => { - indexes.push(std::usize::MAX); + indexes.push(usize::MAX); } _ => {} }