Skip to content

Commit

Permalink
Port ArrayDistinct to functions-array subcrate (#9549)
Browse files Browse the repository at this point in the history
* Issue-9545 - Port ArrayDistinct to functions-array subcrate

* Issue-9545 - Add test coverage on roundtrip_logical_plan

* Issue-9545 - Address review comments
  • Loading branch information
erenavsarogullari authored Mar 12, 2024
1 parent 96669de commit 30c4fd7
Show file tree
Hide file tree
Showing 14 changed files with 174 additions and 60 deletions.
6 changes: 0 additions & 6 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,6 @@ pub enum BuiltinScalarFunction {
ArrayPopFront,
/// array_pop_back
ArrayPopBack,
/// array_distinct
ArrayDistinct,
/// array_element
ArrayElement,
/// array_position
Expand Down Expand Up @@ -325,7 +323,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Tan => Volatility::Immutable,
BuiltinScalarFunction::Tanh => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable,
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
BuiltinScalarFunction::ArrayExcept => Volatility::Immutable,
BuiltinScalarFunction::ArrayPopFront => Volatility::Immutable,
Expand Down Expand Up @@ -416,7 +413,6 @@ impl BuiltinScalarFunction {
// the return type of the built in function.
// Some built-in functions' return type depends on the incoming type.
match self {
BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
List(field)
| LargeList(field)
Expand Down Expand Up @@ -658,7 +654,6 @@ impl BuiltinScalarFunction {
Signature::array_and_index(self.volatility())
}
BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPosition => {
Signature::array_and_element_and_optional_index(self.volatility())
}
Expand Down Expand Up @@ -1073,7 +1068,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::SHA256 => &["sha256"],
BuiltinScalarFunction::SHA384 => &["sha384"],
BuiltinScalarFunction::SHA512 => &["sha512"],
BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"],
BuiltinScalarFunction::ArrayElement => &[
"array_element",
"array_extract",
Expand Down
6 changes: 0 additions & 6 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -612,12 +612,6 @@ scalar_expr!(
first_array second_array,
"Returns an array of the elements that appear in the first array but not in the second."
);
scalar_expr!(
ArrayDistinct,
array_distinct,
array,
"return distinct values from the array after removing duplicates."
);
scalar_expr!(
ArrayPosition,
array_position,
Expand Down
70 changes: 69 additions & 1 deletion datafusion/functions-array/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@ use arrow::compute;
use arrow::datatypes::Field;
use arrow::datatypes::UInt64Type;
use arrow::datatypes::{DataType, Date32Type, IntervalMonthDayNanoType};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
use arrow_schema::FieldRef;
use arrow_schema::SortOptions;

use datafusion_common::cast::{
as_date32_array, as_generic_list_array, as_generic_string_array, as_int64_array,
as_interval_mdn_array, as_large_list_array, as_list_array, as_null_array,
as_string_array,
};
use datafusion_common::{exec_err, not_impl_datafusion_err, DataFusionError, Result};
use datafusion_common::{
exec_err, internal_err, not_impl_datafusion_err, DataFusionError, Result,
};
use itertools::Itertools;
use std::any::type_name;
use std::sync::Arc;

Expand Down Expand Up @@ -865,3 +871,65 @@ pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}
}

/// array_distinct SQL function
/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
pub fn array_distinct(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_distinct needs one argument");
}

// handle null
if args[0].data_type() == &DataType::Null {
return Ok(args[0].clone());
}

// handle for list & largelist
match args[0].data_type() {
DataType::List(field) => {
let array = as_list_array(&args[0])?;
general_array_distinct(array, field)
}
DataType::LargeList(field) => {
let array = as_large_list_array(&args[0])?;
general_array_distinct(array, field)
}
array_type => exec_err!("array_distinct does not support type '{array_type:?}'"),
}
}

/// Removes duplicate elements from every list of a `List` / `LargeList` array.
///
/// Each non-null list is converted to the arrow row format, sorted,
/// de-duplicated and converted back, so the surviving elements come out in
/// row-format sort order rather than in their original order.
///
/// The result always has exactly one entry per input entry: a null list in
/// the input stays a null (zero-length) list in the output.
///
/// # Arguments
///
/// * `array` - the list array whose entries are de-duplicated
/// * `field` - element field used for the resulting list array
///
/// # Errors
///
/// Propagates failures from the row converter, from `compute::concat` and
/// from `GenericListArray::try_new`, and returns an internal error if the
/// row converter yields no column.
pub fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
    array: &GenericListArray<OffsetSize>,
    field: &FieldRef,
) -> Result<ArrayRef> {
    // `compute::concat` rejects an empty list of arrays, so an empty input is
    // returned as-is instead of being rebuilt.
    if array.is_empty() {
        return Ok(Arc::new(array.clone()) as ArrayRef);
    }

    let dt = array.value_type();
    // one offset per entry plus the leading zero
    let mut offsets = Vec::with_capacity(array.len() + 1);
    let mut last_offset = OffsetSize::usize_as(0);
    offsets.push(last_offset);
    let mut new_arrays = Vec::with_capacity(array.len());
    let converter = RowConverter::new(vec![SortField::new(dt)])?;

    // distinct for each list in ListArray
    for entry in array.iter() {
        let values = match entry {
            Some(values) => values,
            None => {
                // Keep a (zero-length) slot for null lists so that output
                // rows stay aligned with input rows.
                offsets.push(last_offset);
                continue;
            }
        };
        let rows = converter.convert_columns(&[values])?;
        // sort elements in list and remove duplicates
        let unique_rows = rows.iter().sorted().dedup().collect::<Vec<_>>();
        let next_offset = last_offset + OffsetSize::usize_as(unique_rows.len());
        offsets.push(next_offset);
        last_offset = next_offset;

        let columns = converter.convert_rows(unique_rows)?;
        let distinct = match columns.first() {
            Some(column) => Arc::clone(column),
            None => {
                return internal_err!("array_distinct: failed to get array from rows")
            }
        };
        new_arrays.push(distinct);
    }

    // Every entry was null: there is nothing to concatenate and the input
    // already is the correct answer.
    if new_arrays.is_empty() {
        return Ok(Arc::new(array.clone()) as ArrayRef);
    }

    let offsets = OffsetBuffer::new(offsets.into());
    let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
    let values = compute::concat(&new_arrays_ref)?;
    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
        field.clone(),
        offsets,
        values,
        // carry over the validity of the input lists
        array.nulls().cloned(),
    )?))
}
2 changes: 2 additions & 0 deletions datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub mod expr_fn {
pub use super::concat::array_prepend;
pub use super::make_array::make_array;
pub use super::udf::array_dims;
pub use super::udf::array_distinct;
pub use super::udf::array_empty;
pub use super::udf::array_length;
pub use super::udf::array_ndims;
Expand Down Expand Up @@ -84,6 +85,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
udf::array_length_udf(),
udf::flatten_udf(),
udf::array_sort_udf(),
udf::array_distinct_udf(),
];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
Expand Down
64 changes: 64 additions & 0 deletions datafusion/functions-array/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -709,3 +709,67 @@ impl ScalarUDFImpl for Flatten {
&self.aliases
}
}

// Wires `ArrayDistinct` into the crate's UDF plumbing.
// NOTE(review): `make_udf_function!` is defined outside this view; presumably it
// generates the `array_distinct` expression helper (taking one `array` argument)
// and the `array_distinct_udf()` accessor that lib.rs re-exports and registers —
// confirm against the macro definition.
make_udf_function!(
ArrayDistinct,
array_distinct,
array,
"return distinct values from the array after removing duplicates.",
array_distinct_udf
);

/// Scalar UDF implementation backing the `array_distinct` SQL function.
#[derive(Debug)]
pub(super) struct ArrayDistinct {
// argument signature, handed out by `ScalarUDFImpl::signature`
signature: Signature,
// names handed out by `ScalarUDFImpl::aliases`
aliases: Vec<String>,
}

impl crate::udf::ArrayDistinct {
    /// Creates the UDF with an immutable, array-typed signature and the
    /// names `array_distinct` and `list_distinct` as aliases.
    pub fn new() -> Self {
        let signature = Signature::array(Volatility::Immutable);
        let aliases = vec![
            String::from("array_distinct"),
            String::from("list_distinct"),
        ];
        Self { signature, aliases }
    }
}

impl ScalarUDFImpl for crate::udf::ArrayDistinct {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"array_distinct"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;
match &arg_types[0] {
List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new(
"item",
field.data_type().clone(),
true,
)))),
LargeList(field) => Ok(LargeList(Arc::new(Field::new(
"item",
field.data_type().clone(),
true,
)))),
_ => exec_err!(
"Not reachable, data_type should be List, LargeList or FixedSizeList"
),
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
crate::kernels::array_distinct(&args).map(ColumnarValue::Array)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}
26 changes: 0 additions & 26 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1539,32 +1539,6 @@ pub fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
)?))
}

/// array_distinct SQL function
/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
///
/// Expects exactly one argument. A `Null`-typed argument is returned
/// unchanged, `List` and `LargeList` arguments are passed on to
/// `general_array_distinct`, and any other data type produces an execution
/// error.
pub fn array_distinct(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_distinct needs one argument");
}

// handle null: an untyped null array is handed back as-is
if args[0].data_type() == &DataType::Null {
return Ok(args[0].clone());
}

// handle for list & largelist: both share the generic implementation
match args[0].data_type() {
DataType::List(field) => {
let array = as_list_array(&args[0])?;
general_array_distinct(array, field)
}
DataType::LargeList(field) => {
let array = as_large_list_array(&args[0])?;
general_array_distinct(array, field)
}
array_type => exec_err!("array_distinct does not support type '{array_type:?}'"),
}
}

/// array_resize SQL function
pub fn array_resize(arg: &[ArrayRef]) -> Result<ArrayRef> {
if arg.len() < 2 || arg.len() > 3 {
Expand Down
3 changes: 0 additions & 3 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,6 @@ pub fn create_physical_fun(
}

// array functions
BuiltinScalarFunction::ArrayDistinct => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_distinct)(args)
}),
BuiltinScalarFunction::ArrayElement => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_element)(args)
}),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ enum ScalarFunction {
SubstrIndex = 126;
FindInSet = 127;
/// 128 was ArraySort
ArrayDistinct = 129;
/// 129 was ArrayDistinct
ArrayResize = 130;
EndsWith = 131;
/// 132 was InStr
Expand Down
3 changes: 0 additions & 3 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 6 additions & 10 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ use datafusion_common::{
use datafusion_expr::expr::Unnest;
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
use datafusion_expr::{
acosh, array_distinct, array_element, array_except, array_intersect, array_pop_back,
array_pop_front, array_position, array_positions, array_remove, array_remove_all,
array_remove_n, array_repeat, array_replace, array_replace_all, array_replace_n,
array_resize, array_slice, array_union, ascii, asinh, atan, atan2, atanh, bit_length,
btrim, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos,
cosh, cot, current_date, current_time, degrees, digest, ends_with, exp,
acosh, array_element, array_except, array_intersect, array_pop_back, array_pop_front,
array_position, array_positions, array_remove, array_remove_all, array_remove_n,
array_repeat, array_replace, array_replace_all, array_replace_n, array_resize,
array_slice, array_union, ascii, asinh, atan, atan2, atanh, bit_length, btrim, cbrt,
ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot,
current_date, current_time, degrees, digest, ends_with, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, find_in_set, floor, from_unixtime, gcd, initcap, iszero, lcm, left,
levenshtein, ln, log, log10, log2,
Expand Down Expand Up @@ -475,7 +475,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Ltrim => Self::Ltrim,
ScalarFunction::Rtrim => Self::Rtrim,
ScalarFunction::ArrayExcept => Self::ArrayExcept,
ScalarFunction::ArrayDistinct => Self::ArrayDistinct,
ScalarFunction::ArrayElement => Self::ArrayElement,
ScalarFunction::ArrayPopFront => Self::ArrayPopFront,
ScalarFunction::ArrayPopBack => Self::ArrayPopBack,
Expand Down Expand Up @@ -1463,9 +1462,6 @@ pub fn parse_expr(
parse_expr(&args[2], registry, codec)?,
parse_expr(&args[3], registry, codec)?,
)),
ScalarFunction::ArrayDistinct => {
Ok(array_distinct(parse_expr(&args[0], registry, codec)?))
}
ScalarFunction::ArrayElement => Ok(array_element(
parse_expr(&args[0], registry, codec)?,
parse_expr(&args[1], registry, codec)?,
Expand Down
1 change: 0 additions & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1457,7 +1457,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Rtrim => Self::Rtrim,
BuiltinScalarFunction::ToChar => Self::ToChar,
BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept,
BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct,
BuiltinScalarFunction::ArrayElement => Self::ArrayElement,
BuiltinScalarFunction::ArrayPopFront => Self::ArrayPopFront,
BuiltinScalarFunction::ArrayPopBack => Self::ArrayPopBack,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ async fn roundtrip_expr_api() -> Result<()> {
lit("desc"),
lit("NULLS LAST"),
),
array_distinct(make_array(vec![lit(1), lit(3), lit(3), lit(2), lit(2)])),
];

// ensure expressions created with the expr api can be round tripped
Expand Down
Loading

0 comments on commit 30c4fd7

Please sign in to comment.