Skip to content

Commit

Permalink
Add tests for encoding and decoding UDAF
Browse files Browse the repository at this point in the history
  • Loading branch information
joroKr21 committed Jul 16, 2024
1 parent 65302ca commit 6e3afd2
Show file tree
Hide file tree
Showing 10 changed files with 416 additions and 272 deletions.
5 changes: 5 additions & 0 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,11 @@ impl AggregateFunctionExpr {
pub fn is_distinct(&self) -> bool {
self.is_distinct
}

/// Return if the aggregation ignores nulls
pub fn ignore_nulls(&self) -> bool {
self.ignore_nulls
}
}

impl AggregateExpr for AggregateFunctionExpr {
Expand Down
3 changes: 2 additions & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,8 @@ message PhysicalAggregateExprNode {
repeated PhysicalExprNode expr = 2;
repeated PhysicalSortExprNode ordering_req = 5;
bool distinct = 3;
optional bytes fun_definition = 6;
bool ignore_nulls = 6;
optional bytes fun_definition = 7;
}

message PhysicalWindowExprNode {
Expand Down
18 changes: 18 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ pub fn serialize_expr(
protobuf::window_expr_node::WindowFunction::Udaf(
aggr_udf.name().to_string(),
),
if buf.is_empty() { None } else { Some(buf) },
(!buf.is_empty()).then_some(buf),
)
}
WindowFunctionDefinition::WindowUDF(window_udf) => {
Expand All @@ -349,7 +349,7 @@ pub fn serialize_expr(
protobuf::window_expr_node::WindowFunction::Udwf(
window_udf.name().to_string(),
),
if buf.is_empty() { None } else { Some(buf) },
(!buf.is_empty()).then_some(buf),
)
}
};
Expand Down Expand Up @@ -427,7 +427,7 @@ pub fn serialize_expr(
Some(e) => serialize_exprs(e, codec)?,
None => vec![],
},
fun_definition: if buf.is_empty() { None } else { Some(buf) },
fun_definition: (!buf.is_empty()).then_some(buf),
},
))),
}
Expand All @@ -445,7 +445,7 @@ pub fn serialize_expr(
protobuf::LogicalExprNode {
expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
fun_name: func.name().to_string(),
fun_definition: if buf.is_empty() { None } else { Some(buf) },
fun_definition: (!buf.is_empty()).then_some(buf),
args: serialize_exprs(args, codec)?,
})),
}
Expand Down
7 changes: 3 additions & 4 deletions datafusion/proto/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
&ordering_req,
&physical_schema,
name.to_string(),
false,
agg_node.ignore_nulls,
)
}
AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
Expand All @@ -506,8 +506,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
// TODO: `order by` is not supported for UDAF yet
let sort_exprs = &[];
let ordering_req = &[];
let ignore_nulls = false;
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false)
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, agg_node.distinct)
}
}
}).transpose()?.ok_or_else(|| {
Expand Down Expand Up @@ -2041,7 +2040,7 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync {

fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result<Arc<AggregateUDF>> {
not_impl_err!(
"LogicalExtensionCodec is not provided for aggregate function {name}"
"PhysicalExtensionCodec is not provided for aggregate function {name}"
)
}

Expand Down
122 changes: 58 additions & 64 deletions datafusion/proto/src/physical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use datafusion::{
physical_plan::expressions::LikeExpr,
};
use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
use datafusion_expr::WindowFrame;

use crate::protobuf::{
self, physical_aggregate_expr_node, physical_window_expr_node, PhysicalSortExprNode,
Expand All @@ -66,8 +67,9 @@ pub fn serialize_physical_aggr_expr(
aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)),
expr: expressions,
ordering_req,
distinct: false,
fun_definition: if buf.is_empty() { None } else { Some(buf) }
distinct: a.is_distinct(),
ignore_nulls: a.ignore_nulls(),
fun_definition: (!buf.is_empty()).then_some(buf)
},
)),
});
Expand All @@ -89,12 +91,55 @@ pub fn serialize_physical_aggr_expr(
expr: expressions,
ordering_req,
distinct,
ignore_nulls: false,
fun_definition: None,
},
)),
})
}

fn serialize_physical_window_aggr_expr(
aggr_expr: &dyn AggregateExpr,
window_frame: &WindowFrame,
codec: &dyn PhysicalExtensionCodec,
) -> Result<(physical_window_expr_node::WindowFunction, Option<Vec<u8>>)> {
if let Some(a) = aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if a.is_distinct() || a.ignore_nulls() {
// TODO
return not_impl_err!(
"Distinct aggregate functions not supported in window expressions"
);
}

let mut buf = Vec::new();
codec.try_encode_udaf(a.fun(), &mut buf)?;
Ok((
physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
a.fun().name().to_string(),
),
(!buf.is_empty()).then_some(buf),
))
} else {
let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(aggr_expr)?;
if distinct {
return not_impl_err!(
"Distinct aggregate functions not supported in window expressions"
);
}

if !window_frame.start_bound.is_unbounded() {
return Err(DataFusionError::Internal(format!(
"Unbounded start bound in WindowFrame = {window_frame}"
)));
}

Ok((
physical_window_expr_node::WindowFunction::AggrFunction(inner as i32),
None,
))
}
}

pub fn serialize_physical_window_expr(
window_expr: Arc<dyn WindowExpr>,
codec: &dyn PhysicalExtensionCodec,
Expand Down Expand Up @@ -171,70 +216,19 @@ pub fn serialize_physical_window_expr(
} else if let Some(plain_aggr_window_expr) =
expr.downcast_ref::<PlainAggregateWindowExpr>()
{
let aggr_expr = plain_aggr_window_expr.get_aggregate_expr();
if let Some(a) = aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
let mut buf = Vec::new();
codec.try_encode_udaf(a.fun(), &mut buf)?;
(
physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
a.fun().name().to_string(),
),
if buf.is_empty() { None } else { Some(buf) },
)
} else {
let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
plain_aggr_window_expr.get_aggregate_expr().as_ref(),
)?;

if distinct {
return not_impl_err!(
"Distinct aggregate functions not supported in window expressions"
);
}

if !window_frame.start_bound.is_unbounded() {
return Err(DataFusionError::Internal(format!("Invalid PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}")));
}

(
physical_window_expr_node::WindowFunction::AggrFunction(inner as i32),
None,
)
}
serialize_physical_window_aggr_expr(
plain_aggr_window_expr.get_aggregate_expr().as_ref(),
window_frame,
codec,
)?
} else if let Some(sliding_aggr_window_expr) =
expr.downcast_ref::<SlidingAggregateWindowExpr>()
{
let aggr_expr = sliding_aggr_window_expr.get_aggregate_expr();
if let Some(a) = aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
let mut buf = Vec::new();
codec.try_encode_udaf(a.fun(), &mut buf)?;
(
physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
a.fun().name().to_string(),
),
if buf.is_empty() { None } else { Some(buf) },
)
} else {
let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
)?;

if distinct {
// TODO
return not_impl_err!(
"Distinct aggregate functions not supported in window expressions"
);
}

if window_frame.start_bound.is_unbounded() {
return Err(DataFusionError::Internal(format!("Invalid SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}")));
}

(
physical_window_expr_node::WindowFunction::AggrFunction(inner as i32),
None,
)
}
serialize_physical_window_aggr_expr(
sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
window_frame,
codec,
)?
} else {
return not_impl_err!("WindowExpr not supported: {window_expr:?}");
};
Expand Down Expand Up @@ -492,7 +486,7 @@ pub fn serialize_physical_expr(
protobuf::PhysicalScalarUdfNode {
name: expr.name().to_string(),
args: serialize_physical_exprs(expr.args().to_vec(), codec)?,
fun_definition: if buf.is_empty() { None } else { Some(buf) },
fun_definition: (!buf.is_empty()).then_some(buf),
return_type: Some(expr.return_type().try_into()?),
},
)),
Expand Down
Loading

0 comments on commit 6e3afd2

Please sign in to comment.