Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposed API change for AggregateUDF to support ORDER BY #1

Closed
wants to merge 12 commits into from
10 changes: 8 additions & 2 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::Schema;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_physical_expr::NullState;
use std::{any::Any, sync::Arc};
Expand Down Expand Up @@ -85,7 +86,12 @@ impl AggregateUDFImpl for GeoMeanUdaf {
/// is supported, DataFusion will use this row oriented
/// accumulator when the aggregate function is used as a window function
/// or when there are only aggregates (no GROUP BY columns) in the plan.
fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
fn accumulator(
&self,
_arg: &DataType,
_sort_exprs: Vec<Expr>,
_schema: Option<&Schema>,
) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(GeometricMean::new()))
}

Expand Down Expand Up @@ -191,7 +197,7 @@ impl Accumulator for GeometricMean {

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::arrow::datatypes::{Field, Schema};
use datafusion::arrow::datatypes::Field;
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async fn main() -> Result<()> {
Arc::new(DataType::Float64),
Volatility::Immutable,
// This is the accumulator factory; DataFusion uses it to create new accumulators.
Arc::new(|_| Ok(Box::new(GeometricMean::new()))),
Arc::new(|_, _, _| Ok(Box::new(GeometricMean::new()))),
// This is the description of the state. `state()` must match the types here.
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);
Expand Down
17 changes: 8 additions & 9 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,23 +244,19 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
distinct,
args,
filter,
order_by,
order_by: _,
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(..) => {
create_function_physical_name(func_def.name(), *distinct, args)
}
AggregateFunctionDefinition::UDF(fun) => {
// TODO: Add support for filter and order by in AggregateUDF
// TODO: Add support for filter in AggregateUDF
if filter.is_some() {
return exec_err!(
"aggregate expression with filter is not supported"
);
}
if order_by.is_some() {
return exec_err!(
"aggregate expression with order_by is not supported"
);
}

let names = args
.iter()
.map(|e| create_physical_name(e, false))
Expand Down Expand Up @@ -1689,13 +1685,16 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
(agg_expr, filter, order_by)
}
AggregateFunctionDefinition::UDF(fun) => {
let ordering_reqs: Vec<PhysicalSortExpr> =
order_by.clone().unwrap_or(vec![]);
let agg_expr = udaf::create_aggregate_expr(
fun,
&args,
&ordering_reqs,
physical_input_schema,
name,
);
(agg_expr?, filter, order_by)
)?;
(agg_expr, filter, order_by)
}
AggregateFunctionDefinition::Name(_) => {
return internal_err!(
Expand Down
129 changes: 120 additions & 9 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use arrow_schema::{Schema, SortOptions};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
Expand All @@ -43,12 +43,14 @@ use datafusion::{
scalar::ScalarValue,
};
use datafusion_common::{
assert_contains, cast::as_primitive_array, exec_err, DataFusionError,
assert_contains, cast::as_primitive_array, exec_err, Column, DataFusionError,
};
use datafusion_expr::{
create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
create_udaf, create_udaf_with_ordering, expr::Sort, AggregateUDFImpl, Expr,
GroupsAccumulator, SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
use datafusion_physical_expr::expressions::{self, FirstValueAccumulator};
use datafusion_physical_expr::{expressions::AvgAccumulator, PhysicalSortExpr};

/// Test to show the contents of the setup
#[tokio::test]
Expand Down Expand Up @@ -210,6 +212,110 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}

#[tokio::test]
/// Verifies that a UDAF created via `create_udaf_with_ordering` receives its
/// ORDER BY expressions and input schema, by registering a `my_first` aggregate
/// backed by DataFusion's `FirstValueAccumulator` and checking that
/// `MY_FIRST(a ORDER BY a DESC)` returns the per-group maximum of `a`.
async fn simple_udaf_order() -> Result<()> {
    // Two-column schema: `a` is the aggregated value, `b` is the GROUP BY key.
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]);

    // Groups: b=1 -> a in {1, 2}; b=2 -> a in {3, 4}.
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
            Arc::new(Int32Array::from(vec![1, 1, 2, 2])),
        ],
    )?;

    let ctx = SessionContext::new();

    let provider = MemTable::try_new(Arc::new(schema.clone()), vec![vec![batch]])?;
    ctx.register_table("t", Arc::new(provider))?;

    /// Accumulator factory matching `AccumulatorFactoryFunction`: builds a
    /// `FirstValueAccumulator` honoring the logical ORDER BY expressions.
    ///
    /// `order_by` holds logical `Expr::Sort` expressions; `schema` is the
    /// physical input schema needed to resolve sort columns to physical
    /// expressions (required here, hence the `unwrap`).
    fn create_accumulator(
        data_type: &DataType,
        order_by: Vec<Expr>,
        schema: Option<&Schema>,
    ) -> Result<Box<dyn Accumulator>> {
        // This test always supplies an ordering, so the schema must be present.
        let schema = schema.unwrap();

        let mut all_sort_orders = vec![];

        // Construct PhysicalSortExpr objects from Expr objects.
        // Note: non-Sort expressions and non-Column sort keys are silently
        // skipped — acceptable for this test's single-column ordering.
        let mut sort_exprs = vec![];
        for expr in order_by {
            if let Expr::Sort(sort) = expr {
                if let Expr::Column(col) = sort.expr.as_ref() {
                    let name = &col.name;
                    let e = expressions::col(name, schema)?;
                    sort_exprs.push(PhysicalSortExpr {
                        expr: e,
                        options: SortOptions {
                            // Expr::Sort stores `asc`; SortOptions stores
                            // `descending`, hence the negation.
                            descending: !sort.asc,
                            nulls_first: sort.nulls_first,
                        },
                    });
                }
            }
        }
        if !sort_exprs.is_empty() {
            all_sort_orders.extend(sort_exprs);
        }

        let ordering_req = all_sort_orders;

        // FirstValueAccumulator needs the data types of the ordering
        // expressions so it can carry them in its intermediate state.
        let ordering_types = ordering_req
            .iter()
            .map(|e| e.expr.data_type(schema))
            .collect::<Result<Vec<_>>>()?;

        let acc = FirstValueAccumulator::try_new(
            data_type,
            ordering_types.as_slice(),
            ordering_req,
        )?;
        Ok(Box::new(acc))
    }

    // Define a UDAF using DataFusion's FirstValue accumulator.
    // State types: the value (Int32), the ordering value (Int32), and the
    // is_set flag (Boolean) — this must match FirstValueAccumulator's state.
    let my_first = create_udaf_with_ordering(
        "my_first",
        vec![DataType::Int32],
        Arc::new(DataType::Int32),
        Volatility::Immutable,
        Arc::new(create_accumulator),
        Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]),
        // Default ordering registered with the UDAF: t.a descending,
        // nulls last.
        vec![Expr::Sort(Sort {
            expr: Box::new(Expr::Column(Column::new(Some("t"), "a"))),
            asc: false,
            nulls_first: false,
        })],
        Some(&schema),
    );

    ctx.register_udaf(my_first);

    // Equivalent to
    // `SELECT FIRST_VALUE(a ORDER BY a DESC) FROM t GROUP BY b ORDER BY b`:
    // the first value per group under descending order is the group maximum.
    let result = ctx
        .sql("SELECT MY_FIRST(a order by a desc) FROM t group by b order by b")
        .await?
        .collect()
        .await?;

    // b=1 -> max(a)=2, b=2 -> max(a)=4.
    let expected = [
        "+---------------+",
        "| my_first(t.a) |",
        "+---------------+",
        "| 2 |",
        "| 4 |",
        "+---------------+",
    ];
    assert_batches_eq!(expected, &result);

    Ok(())
}

/// tests the creation, registration and usage of a UDAF
#[tokio::test]
async fn simple_udaf() -> Result<()> {
Expand All @@ -235,7 +341,7 @@ async fn simple_udaf() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down Expand Up @@ -268,7 +374,7 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down Expand Up @@ -439,7 +545,7 @@ impl TimeSum {

let captured_state = Arc::clone(&test_state);
let accumulator: AccumulatorFactoryFunction =
Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));
Arc::new(move |_, _, _| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));

let time_sum = AggregateUDF::from(SimpleAggregateUDF::new(
name,
Expand Down Expand Up @@ -538,7 +644,7 @@ impl FirstSelector {
let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];

let accumulator: AccumulatorFactoryFunction =
Arc::new(|_| Ok(Box::new(Self::new())));
Arc::new(|_, _, _| Ok(Box::new(Self::new())));

let volatility = Volatility::Immutable;

Expand Down Expand Up @@ -659,7 +765,12 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
Ok(DataType::UInt64)
}

fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
fn accumulator(
&self,
_arg: &DataType,
_sort_exprs: Vec<Expr>,
_schema: Option<&Schema>,
) -> Result<Box<dyn Accumulator>> {
// should use groups accumulator
panic!("accumulator shouldn't invoke");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ async fn udaf_as_window_func() -> Result<()> {
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(MyAccumulator))),
Arc::new(|_, _, _| Ok(Box::new(MyAccumulator))),
Arc::new(vec![DataType::Int32]),
);

Expand Down
11 changes: 8 additions & 3 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::{
ScalarUDF, Signature, Volatility,
};
use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl};
use arrow::datatypes::DataType;
use arrow::datatypes::{DataType, Schema};
use datafusion_common::{Column, Result};
use std::any::Any;
use std::fmt::Debug;
Expand Down Expand Up @@ -1150,8 +1150,13 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
Ok(self.return_type.clone())
}

fn accumulator(&self, arg: &DataType) -> Result<Box<dyn crate::Accumulator>> {
(self.accumulator)(arg)
fn accumulator(
&self,
arg: &DataType,
sort_exprs: Vec<Expr>,
schema: &Schema,
) -> Result<Box<dyn crate::Accumulator>> {
(self.accumulator)(arg, sort_exprs, schema)
}

fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
Expand Down
11 changes: 6 additions & 5 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

//! Function module contains typing and signature for built-in and user defined functions.

use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature};
use crate::{Accumulator, BuiltinScalarFunction, Expr, PartitionEvaluator, Signature};
use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue};
use arrow::datatypes::DataType;
use arrow::datatypes::{DataType, Schema};
use datafusion_common::utils::datafusion_strsim;
use datafusion_common::Result;
use std::sync::Arc;
Expand All @@ -41,9 +41,10 @@ pub type ReturnTypeFunction =
Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;

/// Factory that returns an accumulator for the given aggregate, given
/// its return datatype.
pub type AccumulatorFactoryFunction =
Arc<dyn Fn(&DataType) -> Result<Box<dyn Accumulator>> + Send + Sync>;
/// its return datatype, the sorting expressions, and the schema used for ordering.
pub type AccumulatorFactoryFunction = Arc<
dyn Fn(&DataType, Vec<Expr>, &Schema) -> Result<Box<dyn Accumulator>> + Send + Sync,
>;

/// Factory that creates a PartitionEvaluator for the given window
/// function
Expand Down
Loading