From 6e3afd2d040981c93cad7adb29334eabf34fdfe1 Mon Sep 17 00:00:00 2001
From: Georgi Krastev
Date: Fri, 12 Jul 2024 09:27:34 +0300
Subject: [PATCH] Add tests for encoding and decoding UDAF

---
 .../physical-expr-common/src/aggregate/mod.rs |   5 +
 datafusion/proto/proto/datafusion.proto       |   3 +-
 datafusion/proto/src/generated/pbjson.rs      |  18 ++
 datafusion/proto/src/generated/prost.rs       |   4 +-
 datafusion/proto/src/logical_plan/to_proto.rs |   8 +-
 datafusion/proto/src/physical_plan/mod.rs     |   7 +-
 .../proto/src/physical_plan/to_proto.rs       | 122 ++++-----
 datafusion/proto/tests/cases/mod.rs           |  99 +++++++
 .../tests/cases/roundtrip_logical_plan.rs     | 171 +++++-------
 .../tests/cases/roundtrip_physical_plan.rs    | 251 +++++++++++-------
 10 files changed, 416 insertions(+), 272 deletions(-)

diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs
index db4581a622ac..0e245fd0a66a 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -283,6 +283,11 @@ impl AggregateFunctionExpr {
     pub fn is_distinct(&self) -> bool {
         self.is_distinct
     }
+
+    /// Return if the aggregation ignores nulls
+    pub fn ignore_nulls(&self) -> bool {
+        self.ignore_nulls
+    }
 }
 
 impl AggregateExpr for AggregateFunctionExpr {
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index afe6583c920b..dc551778c5fb 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -858,7 +858,8 @@ message PhysicalAggregateExprNode {
   repeated PhysicalExprNode expr = 2;
   repeated PhysicalSortExprNode ordering_req = 5;
   bool distinct = 3;
-  optional bytes fun_definition = 6;
+  bool ignore_nulls = 6;
+  optional bytes fun_definition = 7;
 }
 
 message PhysicalWindowExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 31923e6dfb41..8f77c24bd911 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -12652,6 +12652,9 @@ impl serde::Serialize for PhysicalAggregateExprNode {
         if self.distinct {
             len += 1;
         }
+        if self.ignore_nulls {
+            len += 1;
+        }
         if self.fun_definition.is_some() {
             len += 1;
         }
@@ -12668,6 +12671,9 @@ impl serde::Serialize for PhysicalAggregateExprNode {
         if self.distinct {
             struct_ser.serialize_field("distinct", &self.distinct)?;
         }
+        if self.ignore_nulls {
+            struct_ser.serialize_field("ignoreNulls", &self.ignore_nulls)?;
+        }
         if let Some(v) = self.fun_definition.as_ref() {
             #[allow(clippy::needless_borrow)]
             struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?;
@@ -12698,6 +12704,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
             "ordering_req",
             "orderingReq",
             "distinct",
+            "ignore_nulls",
+            "ignoreNulls",
             "fun_definition",
             "funDefinition",
             "aggr_function",
@@ -12711,6 +12719,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
             Expr,
             OrderingReq,
             Distinct,
+            IgnoreNulls,
             FunDefinition,
             AggrFunction,
             UserDefinedAggrFunction,
@@ -12738,6 +12747,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
                             "expr" => Ok(GeneratedField::Expr),
                             "orderingReq" | "ordering_req" => Ok(GeneratedField::OrderingReq),
                             "distinct" => Ok(GeneratedField::Distinct),
+                            "ignoreNulls" | "ignore_nulls" => Ok(GeneratedField::IgnoreNulls),
                             "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition),
                             "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction),
"userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), @@ -12763,6 +12773,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { let mut expr__ = None; let mut ordering_req__ = None; let mut distinct__ = None; + let mut ignore_nulls__ = None; let mut fun_definition__ = None; let mut aggregate_function__ = None; while let Some(k) = map_.next_key()? { @@ -12785,6 +12796,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { } distinct__ = Some(map_.next_value()?); } + GeneratedField::IgnoreNulls => { + if ignore_nulls__.is_some() { + return Err(serde::de::Error::duplicate_field("ignoreNulls")); + } + ignore_nulls__ = Some(map_.next_value()?); + } GeneratedField::FunDefinition => { if fun_definition__.is_some() { return Err(serde::de::Error::duplicate_field("funDefinition")); @@ -12811,6 +12828,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { expr: expr__.unwrap_or_default(), ordering_req: ordering_req__.unwrap_or_default(), distinct: distinct__.unwrap_or_default(), + ignore_nulls: ignore_nulls__.unwrap_or_default(), fun_definition: fun_definition__, aggregate_function: aggregate_function__, }) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 743cc6ecccc8..605c56fa946a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1295,7 +1295,9 @@ pub struct PhysicalAggregateExprNode { pub ordering_req: ::prost::alloc::vec::Vec, #[prost(bool, tag = "3")] pub distinct: bool, - #[prost(bytes = "vec", optional, tag = "6")] + #[prost(bool, tag = "6")] + pub ignore_nulls: bool, + #[prost(bytes = "vec", optional, tag = "7")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "1, 4")] pub aggregate_function: ::core::option::Option< diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 7d30f13ad428..9607b918eb89 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -339,7 +339,7 @@ pub fn serialize_expr( protobuf::window_expr_node::WindowFunction::Udaf( aggr_udf.name().to_string(), ), - if buf.is_empty() { None } else { Some(buf) }, + (!buf.is_empty()).then_some(buf), ) } WindowFunctionDefinition::WindowUDF(window_udf) => { @@ -349,7 +349,7 @@ pub fn serialize_expr( protobuf::window_expr_node::WindowFunction::Udwf( window_udf.name().to_string(), ), - if buf.is_empty() { None } else { Some(buf) }, + (!buf.is_empty()).then_some(buf), ) } }; @@ -427,7 +427,7 @@ pub fn serialize_expr( Some(e) => serialize_exprs(e, codec)?, None => vec![], }, - fun_definition: if buf.is_empty() { None } else { Some(buf) }, + fun_definition: (!buf.is_empty()).then_some(buf), }, ))), } @@ -445,7 +445,7 @@ pub fn serialize_expr( protobuf::LogicalExprNode { expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name: func.name().to_string(), - fun_definition: if buf.is_empty() { None } else { Some(buf) }, + fun_definition: (!buf.is_empty()).then_some(buf), args: serialize_exprs(args, codec)?, })), } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index b686874098c9..0d519aefda14 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -491,7 +491,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { 
                                &ordering_req,
                                &physical_schema,
                                name.to_string(),
-                                false,
+                                agg_node.ignore_nulls,
                            )
                        }
                        AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
@@ -506,8 +506,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                            // TODO: `order by` is not supported for UDAF yet
                            let sort_exprs = &[];
                            let ordering_req = &[];
-                            let ignore_nulls = false;
-                            udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false)
+                            udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, agg_node.distinct)
                        }
                    }
                }).transpose()?.ok_or_else(|| {
@@ -2041,7 +2040,7 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync {
 
     fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result<Arc<AggregateUDF>> {
         not_impl_err!(
-            "LogicalExtensionCodec is not provided for aggregate function {name}"
+            "PhysicalExtensionCodec is not provided for aggregate function {name}"
        )
    }
 
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index d85611eb7262..c9a7213fb5e0 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -40,6 +40,7 @@ use datafusion::{
     physical_plan::expressions::LikeExpr,
 };
 use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
+use datafusion_expr::WindowFrame;
 
 use crate::protobuf::{
     self, physical_aggregate_expr_node, physical_window_expr_node, PhysicalSortExprNode,
@@ -66,8 +67,9 @@ pub fn serialize_physical_aggr_expr(
                aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)),
                expr: expressions,
                ordering_req,
-                distinct: false,
-                fun_definition: if buf.is_empty() { None } else { Some(buf) }
+                distinct: a.is_distinct(),
+                ignore_nulls: a.ignore_nulls(),
+                fun_definition: (!buf.is_empty()).then_some(buf)
            },
        )),
    });
@@ -89,12 +91,55 @@ pub fn serialize_physical_aggr_expr(
                expr: expressions,
                ordering_req,
                distinct,
+                ignore_nulls: false,
                fun_definition: None,
            },
        )),
    })
 }
 
+fn serialize_physical_window_aggr_expr(
+    aggr_expr: &dyn AggregateExpr,
+    window_frame: &WindowFrame,
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<(physical_window_expr_node::WindowFunction, Option<Vec<u8>>)> {
+    if let Some(a) = aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
+        if a.is_distinct() || a.ignore_nulls() {
+            // TODO
+            return not_impl_err!(
+                "Distinct or ignore-nulls aggregate functions not supported in window expressions"
+            );
+        }
+
+        let mut buf = Vec::new();
+        codec.try_encode_udaf(a.fun(), &mut buf)?;
+        Ok((
+            physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
+                a.fun().name().to_string(),
+            ),
+            (!buf.is_empty()).then_some(buf),
+        ))
+    } else {
+        let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(aggr_expr)?;
+        if distinct {
+            return not_impl_err!(
+                "Distinct aggregate functions not supported in window expressions"
+            );
+        }
+
+        if !window_frame.start_bound.is_unbounded() {
+            return Err(DataFusionError::Internal(format!(
+                "Expected an unbounded start bound in WindowFrame = {window_frame}"
+            )));
+        }
+
+        Ok((
+            physical_window_expr_node::WindowFunction::AggrFunction(inner as i32),
+            None,
+        ))
+    }
+}
+
 pub fn serialize_physical_window_expr(
     window_expr: Arc<dyn WindowExpr>,
     codec: &dyn PhysicalExtensionCodec,
@@ -171,70 +216,19 @@ pub fn serialize_physical_window_expr(
     } else if let Some(plain_aggr_window_expr) =
         expr.downcast_ref::<PlainAggregateWindowExpr>()
     {
-        let aggr_expr = plain_aggr_window_expr.get_aggregate_expr();
-        if let Some(a) = aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
-            let mut buf = Vec::new();
-            codec.try_encode_udaf(a.fun(), &mut buf)?;
-            (
-                physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
-                    a.fun().name().to_string(),
-                ),
-                if buf.is_empty() { None } else { Some(buf) },
-            )
-        } else {
-            let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
-                plain_aggr_window_expr.get_aggregate_expr().as_ref(),
-            )?;
-
-            if distinct {
-                return not_impl_err!(
-                    "Distinct aggregate functions not supported in window expressions"
-                );
-            }
-
-            if !window_frame.start_bound.is_unbounded() {
-                return Err(DataFusionError::Internal(format!("Invalid PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}")));
-            }
-
-            (
-                physical_window_expr_node::WindowFunction::AggrFunction(inner as i32),
-                None,
-            )
-        }
+        serialize_physical_window_aggr_expr(
+            plain_aggr_window_expr.get_aggregate_expr().as_ref(),
+            window_frame,
+            codec,
+        )?
     } else if let Some(sliding_aggr_window_expr) =
         expr.downcast_ref::<SlidingAggregateWindowExpr>()
     {
-        let aggr_expr = sliding_aggr_window_expr.get_aggregate_expr();
-        if let Some(a) = aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
-            let mut buf = Vec::new();
-            codec.try_encode_udaf(a.fun(), &mut buf)?;
-            (
-                physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
-                    a.fun().name().to_string(),
-                ),
-                if buf.is_empty() { None } else { Some(buf) },
-            )
-        } else {
-            let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
-                sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
-            )?;
-
-            if distinct {
-                // TODO
-                return not_impl_err!(
-                    "Distinct aggregate functions not supported in window expressions"
-                );
-            }
-
-            if window_frame.start_bound.is_unbounded() {
-                return Err(DataFusionError::Internal(format!("Invalid SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}")));
-            }
-
-            (
-                physical_window_expr_node::WindowFunction::AggrFunction(inner as i32),
-                None,
-            )
-        }
+        serialize_physical_window_aggr_expr(
+            sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
+            window_frame,
+            codec,
+        )?
     } else {
         return not_impl_err!("WindowExpr not supported: {window_expr:?}");
     };
@@ -492,7 +486,7 @@ pub fn serialize_physical_expr(
             protobuf::PhysicalScalarUdfNode {
                 name: expr.name().to_string(),
                 args: serialize_physical_exprs(expr.args().to_vec(), codec)?,
-                fun_definition: if buf.is_empty() { None } else { Some(buf) },
+                fun_definition: (!buf.is_empty()).then_some(buf),
                 return_type: Some(expr.return_type().try_into()?),
             },
         )),
diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs
index b17289205f3d..1f837b7f42e8 100644
--- a/datafusion/proto/tests/cases/mod.rs
+++ b/datafusion/proto/tests/cases/mod.rs
@@ -15,6 +15,105 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::any::Any;
+
+use arrow::datatypes::DataType;
+
+use datafusion_common::plan_err;
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::{
+    Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature, Volatility,
+};
+
 mod roundtrip_logical_plan;
 mod roundtrip_physical_plan;
 mod serialize;
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct MyRegexUdf {
+    signature: Signature,
+    // regex as original string
+    pattern: String,
+}
+
+impl MyRegexUdf {
+    fn new(pattern: String) -> Self {
+        let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
+        Self { signature, pattern }
+    }
+}
+
+/// Implement the ScalarUDFImpl trait for MyRegexUdf
+impl ScalarUDFImpl for MyRegexUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "regex_udf"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(&self, args: &[DataType]) -> datafusion_common::Result<DataType> {
+        if matches!(args, [DataType::Utf8]) {
+            Ok(DataType::Int64)
+        } else {
+            plan_err!("regex_udf only accepts Utf8 arguments")
+        }
+    }
+    fn invoke(
+        &self,
+        _args: &[ColumnarValue],
+    ) -> datafusion_common::Result<ColumnarValue> {
+        unimplemented!()
+    }
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct MyRegexUdfNode {
+    #[prost(string, tag = "1")]
+    pub pattern: String,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct MyAggregateUDF {
+    signature: Signature,
+    result: String,
+}
+
+impl MyAggregateUDF {
+    fn new(result: String) -> Self {
+        let signature = Signature::exact(vec![DataType::Int64], Volatility::Immutable);
+        Self { signature, result }
+    }
+}
+
+impl AggregateUDFImpl for MyAggregateUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "aggregate_udf"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(
+        &self,
+        _arg_types: &[DataType],
+    ) -> datafusion_common::Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+    fn accumulator(
+        &self,
+        _acc_args: AccumulatorArgs,
+    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
+        unimplemented!()
+    }
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct MyAggregateUdfNode {
+    #[prost(string, tag = "1")]
+    pub result: String,
+}
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index d0209d811b7c..0117502f400d 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -28,15 +28,12 @@ use arrow::datatypes::{
     DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType,
     IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
 };
+use prost::Message;
+
 use datafusion::datasource::file_format::arrow::ArrowFormatFactory;
 use datafusion::datasource::file_format::csv::CsvFormatFactory;
 use datafusion::datasource::file_format::format_as_file_type;
 use datafusion::datasource::file_format::parquet::ParquetFormatFactory;
-use datafusion_proto::logical_plan::file_formats::{
-    ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, ParquetLogicalExtensionCodec,
-};
-use prost::Message;
-
 use datafusion::datasource::provider::TableProviderFactory;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::session_state::SessionStateBuilder;
@@ -62,9 +59,9 @@ use datafusion_expr::expr::{
 };
 use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
 use datafusion_expr::{
-    Accumulator, AggregateExt, AggregateFunction, ColumnarValue, ExprSchemable,
-    LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature,
-    TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits,
+    Accumulator, AggregateExt, AggregateFunction, AggregateUDF, ColumnarValue,
+    ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF,
+    Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits,
     WindowFunctionDefinition, WindowUDF, WindowUDFImpl,
 };
 use datafusion_functions_aggregate::average::avg_udaf;
@@ -76,12 +73,17 @@ use datafusion_proto::bytes::{
     logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
     logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec,
 };
+use datafusion_proto::logical_plan::file_formats::{
+    ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, ParquetLogicalExtensionCodec,
+};
 use datafusion_proto::logical_plan::to_proto::serialize_expr;
 use datafusion_proto::logical_plan::{
     from_proto, DefaultLogicalExtensionCodec, LogicalExtensionCodec,
 };
 use datafusion_proto::protobuf;
 
+use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode};
+
 #[cfg(feature = "json")]
 fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) {
     let string = serde_json::to_string(proto).unwrap();
@@ -744,7 +746,7 @@ pub mod proto {
         pub k: u64,
 
         #[prost(message, optional, tag = "2")]
-        pub expr: ::core::option::Option<LogicalExprNode>,
+        pub expr: Option<LogicalExprNode>,
     }
 
     #[derive(Clone, PartialEq, Eq, ::prost::Message)]
     pub struct TopKPlanProto {
         #[prost(uint64, tag = "1")]
         pub k: u64,
     }
-
-    #[derive(Clone, PartialEq, ::prost::Message)]
-    pub struct MyRegexUdfNode {
-        #[prost(string, tag = "1")]
-        pub pattern: String,
-    }
 }
 
 #[derive(PartialEq, Eq, Hash)]
@@ -890,51 +886,9 @@ impl LogicalExtensionCodec for TopKExtensionCodec {
 }
 
 #[derive(Debug)]
-struct MyRegexUdf {
-    signature: Signature,
-    // regex as original string
-    pattern: String,
-}
-
-impl MyRegexUdf {
-    fn new(pattern: String) -> Self {
-        Self {
-            signature: Signature::uniform(
-                1,
-                vec![DataType::Int32],
-                Volatility::Immutable,
-            ),
-            pattern,
-        }
-    }
-}
-
-/// Implement the ScalarUDFImpl trait for MyRegexUdf
-impl ScalarUDFImpl for MyRegexUdf {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-    fn name(&self) -> &str {
-        "regex_udf"
-    }
-    fn signature(&self) -> &Signature {
-        &self.signature
-    }
-    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
-        if !matches!(args.first(), Some(&DataType::Utf8)) {
-            return plan_err!("regex_udf only accepts Utf8 arguments");
-        }
-        Ok(DataType::Int32)
-    }
-    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
-        unimplemented!()
-    }
-}
-
-#[derive(Debug)]
-pub struct ScalarUDFExtensionCodec {}
+pub struct UDFExtensionCodec;
 
-impl LogicalExtensionCodec for ScalarUDFExtensionCodec {
+impl LogicalExtensionCodec for UDFExtensionCodec {
     fn try_decode(
         &self,
         _buf: &[u8],
@@ -969,13 +923,11 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec {
     fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
         if name == "regex_udf" {
-            let proto = proto::MyRegexUdfNode::decode(buf).map_err(|err| {
-                DataFusionError::Internal(format!("failed to decode regex_udf: {}", err))
+            let proto = MyRegexUdfNode::decode(buf).map_err(|err| {
+                DataFusionError::Internal(format!("failed to decode regex_udf: {err}"))
             })?;
 
-            Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new(
-                proto.pattern,
-            ))))
+            Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern))))
         } else {
             not_impl_err!("unrecognized scalar UDF implementation, cannot decode")
         }
     }
@@ -984,11 +936,39 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec {
     fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
         let binding = node.inner();
         let udf = binding.as_any().downcast_ref::<MyRegexUdf>().unwrap();
-        let proto = proto::MyRegexUdfNode {
+        let proto = MyRegexUdfNode {
             pattern: udf.pattern.clone(),
         };
-        proto.encode(buf).map_err(|e| {
-            DataFusionError::Internal(format!("failed to encode udf: {e:?}"))
+        proto.encode(buf).map_err(|err| {
+            DataFusionError::Internal(format!("failed to encode udf: {err}"))
         })?;
         Ok(())
     }
+
+    fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
+        if name == "aggregate_udf" {
+            let proto = MyAggregateUdfNode::decode(buf).map_err(|err| {
+                DataFusionError::Internal(format!(
+                    "failed to decode aggregate_udf: {err}"
+                ))
+            })?;
+
+            Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new(
+                proto.result,
+            ))))
+        } else {
+            not_impl_err!("unrecognized aggregate UDF implementation, cannot decode")
+        }
+    }
+
+    fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
+        let binding = node.inner();
+        let udf = binding.as_any().downcast_ref::<MyAggregateUDF>().unwrap();
+        let proto = MyAggregateUdfNode {
+            result: udf.result.clone(),
+        };
+        proto.encode(buf).map_err(|err| {
+            DataFusionError::Internal(format!("failed to encode udf: {err}"))
+        })?;
+        Ok(())
+    }
@@ -1563,8 +1543,7 @@ fn roundtrip_null_scalar_values() {
 
     for test_case in test_types.into_iter() {
         let proto_scalar: protobuf::ScalarValue = (&test_case).try_into().unwrap();
-        let returned_scalar: datafusion::scalar::ScalarValue =
-            (&proto_scalar).try_into().unwrap();
+        let returned_scalar: ScalarValue = (&proto_scalar).try_into().unwrap();
         assert_eq!(format!("{:?}", &test_case), format!("{returned_scalar:?}"));
     }
 }
@@ -1893,22 +1872,19 @@ fn roundtrip_aggregate_udf() {
     struct Dummy {}
 
     impl Accumulator for Dummy {
-        fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
+        fn state(&mut self) -> Result<Vec<ScalarValue>> {
             Ok(vec![])
         }
 
-        fn update_batch(
-            &mut self,
-            _values: &[ArrayRef],
-        ) -> datafusion::error::Result<()> {
+        fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
             Ok(())
         }
 
-        fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> {
+        fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> {
             Ok(())
         }
 
-        fn evaluate(&mut self) -> datafusion::error::Result<ScalarValue> {
+        fn evaluate(&mut self) -> Result<ScalarValue> {
             Ok(ScalarValue::Float64(None))
         }
 
@@ -1976,25 +1952,27 @@ fn roundtrip_scalar_udf() {
 
 #[test]
 fn roundtrip_scalar_udf_extension_codec() {
-    let pattern = ".*";
-    let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string()));
-    let test_expr =
-        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf.clone()), vec![]));
-
+    let udf = ScalarUDF::from(MyRegexUdf::new(".*".to_owned()));
+    let test_expr = udf.call(vec!["foo".lit()]);
     let ctx = SessionContext::new();
-    ctx.register_udf(udf);
-
-    let extension_codec = ScalarUDFExtensionCodec {};
-    let proto: protobuf::LogicalExprNode =
-        match serialize_expr(&test_expr, &extension_codec) {
-            Ok(p) => p,
-            Err(e) => panic!("Error serializing expression: {:?}", e),
-        };
-    let round_trip: Expr =
-        from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap();
+    let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr");
+    let round_trip =
+        from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr");
 
     assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}"));
+    roundtrip_json_test(&proto);
+}
+
+#[test]
+fn roundtrip_aggregate_udf_extension_codec() {
+    let udf = AggregateUDF::from(MyAggregateUDF::new("DataFusion".to_owned()));
+    let test_expr = udf.call(vec![42.lit()]);
+    let ctx = SessionContext::new();
+    let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr");
+    let round_trip =
+        from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr");
+    assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}"));
     roundtrip_json_test(&proto);
 }
 
@@ -2120,22 +2098,19 @@ fn roundtrip_window() {
     struct DummyAggr {}
 
     impl Accumulator for DummyAggr {
-        fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
+        fn state(&mut self) -> Result<Vec<ScalarValue>> {
             Ok(vec![])
         }
 
-        fn update_batch(
-            &mut self,
-            _values: &[ArrayRef],
-        ) -> datafusion::error::Result<()> {
+        fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
             Ok(())
         }
 
-        fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> {
+        fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> {
             Ok(())
        }
 
-        fn evaluate(&mut self) -> datafusion::error::Result<ScalarValue> {
+        fn evaluate(&mut self) -> Result<ScalarValue> {
             Ok(ScalarValue::Float64(None))
         }
 
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 2fcc65008fd8..fba6dfe42599 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::RecordBatch;
 use std::any::Any;
 use std::fmt::Display;
 use std::hash::Hasher;
@@ -23,8 +22,8 @@ use std::ops::Deref;
 use std::sync::Arc;
 use std::vec;
 
+use arrow::array::RecordBatch;
 use arrow::csv::WriterBuilder;
-use datafusion::functions_aggregate::sum::sum_udaf;
 use prost::Message;
 
 use datafusion::arrow::array::ArrayRef;
@@ -40,9 +39,10 @@ use datafusion::datasource::physical_plan::{
     FileSinkConfig, ParquetExec,
 };
 use datafusion::execution::FunctionRegistry;
+use datafusion::functions_aggregate::sum::sum_udaf;
 use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
 use datafusion::physical_expr::aggregate::utils::down_cast_any_ref;
-use datafusion::physical_expr::expressions::Max;
+use datafusion::physical_expr::expressions::{Literal, Max};
 use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
 use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
 use datafusion::physical_plan::aggregates::{
@@ -70,7 +70,7 @@ use datafusion::physical_plan::windows::{
     BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec,
 };
 use datafusion::physical_plan::{
-    udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics,
+    AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics,
 };
 use datafusion::prelude::SessionContext;
 use datafusion::scalar::ScalarValue;
@@ -79,10 +79,10 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions;
 use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
-use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result};
+use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
 use datafusion_expr::{
     Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF,
-    ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound,
+    Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound,
 };
 use datafusion_functions_aggregate::average::avg_udaf;
 use datafusion_functions_aggregate::nth_value::nth_value_udaf;
@@ -92,6 +92,8 @@ use datafusion_proto::physical_plan::{
 };
 use datafusion_proto::protobuf;
 
+use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode};
+
 /// Perform a serde roundtrip and assert that the string representation of the before and after plans
 /// are identical. Note that this often isn't sufficient to guarantee that no information is
 /// lost during serde because the string representation of a plan often only shows a subset of state.
@@ -312,7 +314,7 @@ fn roundtrip_window() -> Result<()> {
     );
 
     let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?];
-    let sum_expr = udaf::create_aggregate_expr(
+    let sum_expr = create_aggregate_expr(
         &sum_udaf(),
         &args,
         &[],
@@ -367,7 +369,7 @@ fn rountrip_aggregate() -> Result<()> {
             false,
         )?],
         // NTH_VALUE
-        vec![udaf::create_aggregate_expr(
+        vec![create_aggregate_expr(
             &nth_value_udaf(),
             &[col("b", &schema)?, lit(1u64)],
             &[],
@@ -379,7 +381,7 @@ fn rountrip_aggregate() -> Result<()> {
             false,
         )?],
         // STRING_AGG
-        vec![udaf::create_aggregate_expr(
+        vec![create_aggregate_expr(
            &AggregateUDF::new_from_impl(StringAgg::new()),
            &[
                cast(col("b", &schema)?, &schema, DataType::Utf8)?,
@@ -490,7 +492,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
     let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
         vec![(col("a", &schema)?, "unused".to_string())];
 
-    let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![udaf::create_aggregate_expr(
+    let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
         &udaf,
         &[col("b", &schema)?],
         &[],
@@ -845,123 +847,161 @@ fn roundtrip_scalar_udf() -> Result<()> {
     roundtrip_test_with_context(Arc::new(project), &ctx)
 }
 
-#[test]
-fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
-    #[derive(Debug)]
-    struct MyRegexUdf {
-        signature: Signature,
-        // regex as original string
-        pattern: String,
+#[derive(Debug)]
+struct UDFExtensionCodec;
+
+impl PhysicalExtensionCodec for UDFExtensionCodec {
+    fn try_decode(
+        &self,
+        _buf: &[u8],
+        _inputs: &[Arc<dyn ExecutionPlan>],
+        _registry: &dyn FunctionRegistry,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        not_impl_err!("No extension codec provided")
     }
 
-    impl MyRegexUdf {
-        fn new(pattern: String) -> Self {
-            Self {
-                signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
-                pattern,
-            }
-        }
+    fn try_encode(
+        &self,
+        _node: Arc<dyn ExecutionPlan>,
+        _buf: &mut Vec<u8>,
+    ) -> Result<()> {
+        not_impl_err!("No extension codec provided")
     }
 
-    /// Implement the ScalarUDFImpl trait for MyRegexUdf
-    impl ScalarUDFImpl for MyRegexUdf {
-        fn as_any(&self) -> &dyn Any {
-            self
-        }
+    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
+        if name == "regex_udf" {
+            let proto = MyRegexUdfNode::decode(buf).map_err(|err| {
+                DataFusionError::Internal(format!("failed to decode regex_udf: {err}"))
+            })?;
 
-        fn name(&self) -> &str {
-            "regex_udf"
+            Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern))))
+        } else {
+            not_impl_err!("unrecognized scalar UDF implementation, cannot decode")
         }
+    }
 
-        fn signature(&self) -> &Signature {
-            &self.signature
+    fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
+        let binding = node.inner();
+        if let Some(udf) = binding.as_any().downcast_ref::<MyRegexUdf>() {
+            let proto = MyRegexUdfNode {
+                pattern: udf.pattern.clone(),
+            };
+            proto.encode(buf).map_err(|err| {
+                DataFusionError::Internal(format!("failed to encode udf: {err}"))
+            })?;
         }
+        Ok(())
+    }
 
-        fn return_type(&self, args: &[DataType]) -> Result<DataType> {
-            if !matches!(args.first(), Some(&DataType::Utf8)) {
-                return plan_err!("regex_udf only accepts Utf8 arguments");
-            }
-            Ok(DataType::Int64)
+    fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
+        if name == "aggregate_udf" {
+            let proto = MyAggregateUdfNode::decode(buf).map_err(|err| {
+                DataFusionError::Internal(format!(
+                    "failed to decode aggregate_udf: {err}"
+                ))
+            })?;
+
+            Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new(
+                proto.result,
+            ))))
+        } else {
+            not_impl_err!("unrecognized aggregate UDF implementation, cannot decode")
         }
+    }
 
-        fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
-            unimplemented!()
+    fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
+        let binding = node.inner();
+        if let Some(udf) = binding.as_any().downcast_ref::<MyAggregateUDF>() {
+            let proto = MyAggregateUdfNode {
+                result: udf.result.clone(),
+            };
+            proto.encode(buf).map_err(|err| {
+                DataFusionError::Internal(format!("failed to encode udf: {err}"))
+            })?;
         }
+        Ok(())
     }
+}
 
+#[test]
+fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
+    let field_text = Field::new("text", DataType::Utf8, true);
+    let field_published = Field::new("published", DataType::Boolean, false);
+    let field_author = Field::new("author", DataType::Utf8, false);
+    let schema = Arc::new(Schema::new(vec![field_text, field_published, field_author]));
+    let input = Arc::new(EmptyExec::new(schema.clone()));
 
-    #[derive(Clone, PartialEq, ::prost::Message)]
-    pub struct MyRegexUdfNode {
-        #[prost(string, tag = "1")]
-        pub pattern: String,
-    }
+    let udf_expr = Arc::new(ScalarFunctionExpr::new(
+        "regex_udf",
+        Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))),
+        vec![col("text", &schema)?],
+        DataType::Int64,
+    ));
 
-    #[derive(Debug)]
-    pub struct ScalarUDFExtensionCodec {}
+    let filter = Arc::new(FilterExec::try_new(
+        Arc::new(BinaryExpr::new(
+            col("published", &schema)?,
+            Operator::And,
+            Arc::new(BinaryExpr::new(udf_expr.clone(), Operator::Gt, lit(0))),
+        )),
+        input,
+    )?);
 
-    impl PhysicalExtensionCodec for ScalarUDFExtensionCodec {
-        fn try_decode(
-            &self,
-            _buf: &[u8],
-            _inputs: &[Arc<dyn ExecutionPlan>],
-            _registry: &dyn FunctionRegistry,
-        ) -> Result<Arc<dyn ExecutionPlan>> {
-            not_impl_err!("No extension codec provided")
-        }
+    let window = Arc::new(WindowAggExec::try_new(
+        vec![Arc::new(PlainAggregateWindowExpr::new(
+            Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)),
+            &[col("author", &schema)?],
+            &[],
+            Arc::new(WindowFrame::new(None)),
+        ))],
+        filter,
+        vec![col("author", &schema)?],
+    )?);
 
-        fn try_encode(
-            &self,
-            _node: Arc<dyn ExecutionPlan>,
-            _buf: &mut Vec<u8>,
-        ) -> Result<()> {
-            not_impl_err!("No extension codec provided")
-        }
+    let aggregate = Arc::new(AggregateExec::try_new(
+        AggregateMode::Final,
+        PhysicalGroupBy::new(vec![], vec![], vec![]),
+        vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))],
+        vec![None],
+        window,
+        schema.clone(),
+    )?);
 
-        fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
-            if name == "regex_udf" {
-                let proto = MyRegexUdfNode::decode(buf).map_err(|err| {
-                    DataFusionError::Internal(format!(
-                        "failed to decode regex_udf: {}",
-                        err
-                    ))
-                })?;
-
-                Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new(
-                    proto.pattern,
-                ))))
-            } else {
-                not_impl_err!("unrecognized scalar UDF implementation, cannot decode")
-            }
-        }
+    let ctx = SessionContext::new();
+    roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?;
+    Ok(())
+}
 
-        fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
-            let binding = node.inner();
-            if let Some(udf) = binding.as_any().downcast_ref::<MyRegexUdf>() {
-                let proto = MyRegexUdfNode {
-                    pattern: udf.pattern.clone(),
-                };
-                proto.encode(buf).map_err(|e| {
-                    DataFusionError::Internal(format!("failed to encode udf: {e:?}"))
{e:?}")) - })?; - } - Ok(()) - } - } + let ctx = SessionContext::new(); + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + Ok(()) +} +#[test] +fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { let field_text = Field::new("text", DataType::Utf8, true); let field_published = Field::new("published", DataType::Boolean, false); let field_author = Field::new("author", DataType::Utf8, false); let schema = Arc::new(Schema::new(vec![field_text, field_published, field_author])); let input = Arc::new(EmptyExec::new(schema.clone())); - let pattern = ".*"; - let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); let udf_expr = Arc::new(ScalarFunctionExpr::new( - udf.name(), - Arc::new(udf.clone()), + "regex_udf", + Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], DataType::Int64, )); + let udaf = AggregateUDF::from(MyAggregateUDF::new("result".to_string())); + let aggr_args: [Arc; 1] = + [Arc::new(Literal::new(ScalarValue::from(42)))]; + let aggr_expr = create_aggregate_expr( + &udaf, + &aggr_args, + &[], + &[], + &[], + &schema, + "aggregate_udf", + false, + false, + )?; + let filter = Arc::new(FilterExec::try_new( Arc::new(BinaryExpr::new( col("published", &schema)?, @@ -973,7 +1013,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { let window = Arc::new(WindowAggExec::try_new( vec![Arc::new(PlainAggregateWindowExpr::new( - Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)), + aggr_expr, &[col("author", &schema)?], &[], Arc::new(WindowFrame::new(None)), @@ -982,18 +1022,29 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { vec![col("author", &schema)?], )?); + let aggr_expr = create_aggregate_expr( + &udaf, + &aggr_args, + &[], + &[], + &[], + &schema, + "aggregate_udf", + true, + true, + )?; + let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::new(vec![], vec![], vec![]), - vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))], + vec![aggr_expr], vec![None], window, schema.clone(), )?); let ctx = SessionContext::new(); - let codec = ScalarUDFExtensionCodec {}; - roundtrip_test_and_return(aggregate, &ctx, &codec)?; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; Ok(()) }