Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ORDER BY in AggregateUDF #9249

Closed
wants to merge 19 commits into from
10 changes: 8 additions & 2 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::Schema;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_physical_expr::NullState;
use std::{any::Any, sync::Arc};
Expand Down Expand Up @@ -85,7 +86,12 @@ impl AggregateUDFImpl for GeoMeanUdaf {
/// is supported, DataFusion will use this row oriented
/// accumulator when the aggregate function is used as a window function
/// or when there are only aggregates (no GROUP BY columns) in the plan.
fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
fn accumulator(
&self,
_arg: &DataType,
_sort_exprs: &[Expr],
_schema: &Schema,
) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(GeometricMean::new()))
}

Expand Down Expand Up @@ -191,7 +197,7 @@ impl Accumulator for GeometricMean {

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::arrow::datatypes::{Field, Schema};
use datafusion::arrow::datatypes::Field;
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async fn main() -> Result<()> {
Arc::new(DataType::Float64),
Volatility::Immutable,
// This is the accumulator factory; DataFusion uses it to create new accumulators.
Arc::new(|_| Ok(Box::new(GeometricMean::new()))),
Arc::new(|_, _, _| Ok(Box::new(GeometricMean::new()))),
// This is the description of the state. `state()` must match the types here.
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);
Expand Down
21 changes: 12 additions & 9 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,24 +246,20 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
distinct,
args,
filter,
order_by,
order_by: _,
null_treatment: _,
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(..) => {
create_function_physical_name(func_def.name(), *distinct, args)
}
AggregateFunctionDefinition::UDF(fun) => {
// TODO: Add support for filter and order by in AggregateUDF
// TODO: Add support for filter in AggregateUDF
if filter.is_some() {
return exec_err!(
"aggregate expression with filter is not supported"
);
}
if order_by.is_some() {
return exec_err!(
"aggregate expression with order_by is not supported"
);
}

let names = args
.iter()
.map(|e| create_physical_name(e, false))
Expand Down Expand Up @@ -1682,6 +1678,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
)?),
None => None,
};

let sort_exprs = order_by.clone().unwrap_or(vec![]);
let order_by = match order_by {
Some(e) => Some(
e.iter()
Expand Down Expand Up @@ -1714,13 +1712,18 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
(agg_expr, filter, order_by)
}
AggregateFunctionDefinition::UDF(fun) => {
let ordering_reqs: Vec<PhysicalSortExpr> =
order_by.clone().unwrap_or(vec![]);

let agg_expr = udaf::create_aggregate_expr(
fun,
&args,
&sort_exprs,
&ordering_reqs,
physical_input_schema,
name,
);
(agg_expr?, filter, order_by)
)?;
(agg_expr, filter, order_by)
}
AggregateFunctionDefinition::Name(_) => {
return internal_err!(
Expand Down
123 changes: 113 additions & 10 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use arrow_schema::{Schema, SortOptions};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
Expand All @@ -45,9 +45,11 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
create_udaf, create_udaf_with_ordering, AggregateUDFImpl, Expr, GroupsAccumulator,
SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
use datafusion_physical_expr::expressions::{self, FirstValueAccumulator};
use datafusion_physical_expr::{expressions::AvgAccumulator, PhysicalSortExpr};

/// Test to show the contents of the setup
#[tokio::test]
Expand Down Expand Up @@ -209,6 +211,102 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}

/// Checks that an `ORDER BY` clause inside a user-defined aggregate call is
/// forwarded to the accumulator factory.
///
/// A UDAF named `my_first` is registered via `create_udaf_with_ordering`; its
/// factory builds a `FirstValueAccumulator` from the sort expressions it
/// receives. The query orders by `a` descending within each `b` group, so the
/// expected output is the largest `a` of each group.
#[tokio::test]
async fn simple_udaf_order() -> Result<()> {
// Two Int32 columns: `a` is the aggregated value, `b` is the grouping key.
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);

// a = [1, 2, 3, 4], b = [1, 1, 2, 2] => group b=1 holds {1, 2}, group b=2 holds {3, 4}.
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
Arc::new(Int32Array::from(vec![1, 1, 2, 2])),
],
)?;

let ctx = SessionContext::new();

// Expose the batch as in-memory table `t`.
let provider = MemTable::try_new(Arc::new(schema.clone()), vec![vec![batch]])?;
ctx.register_table("t", Arc::new(provider))?;

// Accumulator factory handed to `create_udaf_with_ordering`: receives the
// argument data type, the logical ORDER BY expressions and the schema, and
// returns a `FirstValueAccumulator` configured with that ordering.
fn create_accumulator(
data_type: &DataType,
order_by: &[Expr],
schema: &Schema,
) -> Result<Box<dyn Accumulator>> {
let mut all_sort_orders = vec![];

// Construct PhysicalSortExpr objects from Expr objects.
// Only `Expr::Sort` wrapping a plain `Expr::Column` is converted; any
// other expression shape is silently skipped (there is no `else` branch).
let mut sort_exprs = vec![];
for expr in order_by {
if let Expr::Sort(sort) = expr {
if let Expr::Column(col) = sort.expr.as_ref() {
let name = &col.name;
let e = expressions::col(name, schema)?;
sort_exprs.push(PhysicalSortExpr {
expr: e,
options: SortOptions {
// The logical sort carries `asc`; the physical options carry `descending`.
descending: !sort.asc,
nulls_first: sort.nulls_first,
},
});
}
}
}
if !sort_exprs.is_empty() {
all_sort_orders.extend(sort_exprs);
}

let ordering_req = all_sort_orders;

// Data type of each ordering expression, resolved against `schema`.
let ordering_dtypes = ordering_req
.iter()
.map(|e| e.expr.data_type(schema))
.collect::<Result<Vec<_>>>()?;

// NOTE(review): the trailing `false` is presumably `ignore_nulls` (see the
// review remark embedded in the argument list) -- confirm against
// `FirstValueAccumulator::try_new`.
let acc = FirstValueAccumulator::try_new(
data_type,
&ordering_dtypes,
ordering_req,
false,
Copy link
Contributor Author

@jayzhan211 jayzhan211 Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ignore_nulls or other arguments are given in the accumulator provided by the user.

)?;
Ok(Box::new(acc))
}

// Define a UDAF that reuses one of DataFusion's own accumulators.
// NOTE(review): the state types below presumably mirror the state emitted by
// `FirstValueAccumulator` for a single Int32 ordering column -- confirm.
let my_first = create_udaf_with_ordering(
"my_first",
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
Arc::new(create_accumulator),
Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]),
);

ctx.register_udaf(my_first);

// Should be the same as `SELECT FIRST_VALUE(a order by a desc) FROM t group by b order by b`
let result = ctx
.sql("SELECT MY_FIRST(a order by a desc) FROM t group by b order by b")
.await?
.collect()
.await?;

// With `a` sorted descending inside each group, the first value is the
// group's maximum: 2 for b=1 and 4 for b=2.
let expected = [
"+---------------+",
"| my_first(t.a) |",
"+---------------+",
"| 2 |",
"| 4 |",
"+---------------+",
];
assert_batches_eq!(expected, &result);

Ok(())
}

/// tests the creation, registration and usage of a UDAF
#[tokio::test]
async fn simple_udaf() -> Result<()> {
Expand All @@ -234,7 +332,7 @@ async fn simple_udaf() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down Expand Up @@ -262,7 +360,7 @@ async fn deregister_udaf() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down Expand Up @@ -290,7 +388,7 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down Expand Up @@ -333,7 +431,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
)
.with_aliases(vec!["dummy_alias"]);
Expand Down Expand Up @@ -497,7 +595,7 @@ impl TimeSum {

let captured_state = Arc::clone(&test_state);
let accumulator: AccumulatorFactoryFunction =
Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));
Arc::new(move |_, _, _| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));

let time_sum = AggregateUDF::from(SimpleAggregateUDF::new(
name,
Expand Down Expand Up @@ -596,7 +694,7 @@ impl FirstSelector {
let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];

let accumulator: AccumulatorFactoryFunction =
Arc::new(|_| Ok(Box::new(Self::new())));
Arc::new(|_, _, _| Ok(Box::new(Self::new())));

let volatility = Volatility::Immutable;

Expand Down Expand Up @@ -717,7 +815,12 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
Ok(DataType::UInt64)
}

fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
fn accumulator(
&self,
_arg: &DataType,
_sort_exprs: &[Expr],
_schema: &Schema,
) -> Result<Box<dyn Accumulator>> {
// should use groups accumulator
panic!("accumulator shouldn't invoke");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ async fn udaf_as_window_func() -> Result<()> {
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(MyAccumulator))),
Arc::new(|_, _, _| Ok(Box::new(MyAccumulator))),
Arc::new(vec![DataType::Int32]),
);

Expand Down
Loading
Loading