Skip to content

Commit

Permalink
Port StringToArray to function-arrays subcrate (#9543)
Browse files Browse the repository at this point in the history
* Issue-9497 - Port StringToArray to function-arrays

* Issue-9497 - Fix formatting issues

* Issue-9497 - Format expressions.md documentation
  • Loading branch information
erenavsarogullari authored Mar 11, 2024
1 parent 88187d4 commit 4cd3c43
Show file tree
Hide file tree
Showing 15 changed files with 224 additions and 188 deletions.
18 changes: 0 additions & 18 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,6 @@ pub enum BuiltinScalarFunction {
SHA512,
/// split_part
SplitPart,
/// string_to_array
StringToArray,
/// starts_with
StartsWith,
/// strpos
Expand Down Expand Up @@ -383,7 +381,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::SHA512 => Volatility::Immutable,
BuiltinScalarFunction::Digest => Volatility::Immutable,
BuiltinScalarFunction::SplitPart => Volatility::Immutable,
BuiltinScalarFunction::StringToArray => Volatility::Immutable,
BuiltinScalarFunction::StartsWith => Volatility::Immutable,
BuiltinScalarFunction::Strpos => Volatility::Immutable,
BuiltinScalarFunction::Substr => Volatility::Immutable,
Expand Down Expand Up @@ -556,11 +553,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::SplitPart => {
utf8_to_str_type(&input_expr_types[0], "split_part")
}
BuiltinScalarFunction::StringToArray => Ok(List(Arc::new(Field::new(
"item",
input_expr_types[0].clone(),
true,
)))),
BuiltinScalarFunction::StartsWith => Ok(Boolean),
BuiltinScalarFunction::EndsWith => Ok(Boolean),
BuiltinScalarFunction::Strpos => {
Expand Down Expand Up @@ -833,13 +825,6 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::StringToArray => Signature::one_of(
vec![
TypeSignature::Uniform(2, vec![Utf8, LargeUtf8]),
TypeSignature::Uniform(3, vec![Utf8, LargeUtf8]),
],
self.volatility(),
),

BuiltinScalarFunction::EndsWith
| BuiltinScalarFunction::Strpos
Expand Down Expand Up @@ -1087,9 +1072,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Rpad => &["rpad"],
BuiltinScalarFunction::Rtrim => &["rtrim"],
BuiltinScalarFunction::SplitPart => &["split_part"],
BuiltinScalarFunction::StringToArray => {
&["string_to_array", "string_to_list"]
}
BuiltinScalarFunction::StartsWith => &["starts_with"],
BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"],
BuiltinScalarFunction::Substr => &["substr"],
Expand Down
2 changes: 0 additions & 2 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,6 @@ scalar_expr!(SHA256, sha256, string, "SHA-256 hash");
scalar_expr!(SHA384, sha384, string, "SHA-384 hash");
scalar_expr!(SHA512, sha512, string, "SHA-512 hash");
scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index.");
scalar_expr!(StringToArray, string_to_array, string delimiter null_string, "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`");
scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`");
Expand Down Expand Up @@ -1275,7 +1274,6 @@ mod test {
test_scalar_expr!(SHA384, sha384, string);
test_scalar_expr!(SHA512, sha512, string);
test_scalar_expr!(SplitPart, split_part, expr, delimiter, index);
test_scalar_expr!(StringToArray, string_to_array, expr, delimiter, null_value);
test_scalar_expr!(StartsWith, starts_with, string, characters);
test_scalar_expr!(EndsWith, ends_with, string, characters);
test_scalar_expr!(Strpos, strpos, string, substring);
Expand Down
106 changes: 99 additions & 7 deletions datafusion/functions-array/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@
use arrow::array::{
Array, ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array,
GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray,
OffsetSizeTrait, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, LargeListArray,
LargeStringArray, ListArray, ListBuilder, OffsetSizeTrait, StringArray,
StringBuilder, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::array::{LargeListArray, ListArray};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::Field;
use arrow::datatypes::UInt64Type;
use arrow::datatypes::{DataType, Date32Type, IntervalMonthDayNanoType};
use datafusion_common::cast::{
as_date32_array, as_generic_list_array, as_int64_array, as_interval_mdn_array,
as_large_list_array, as_list_array, as_null_array, as_string_array,
as_date32_array, as_generic_list_array, as_generic_string_array, as_int64_array,
as_interval_mdn_array, as_large_list_array, as_list_array, as_null_array,
as_string_array,
};
use datafusion_common::DataFusionError;
use datafusion_common::{exec_err, not_impl_datafusion_err, Result};
use datafusion_common::{exec_err, not_impl_datafusion_err, DataFusionError, Result};
use std::any::type_name;
use std::sync::Arc;

Expand Down Expand Up @@ -261,6 +261,98 @@ pub(super) fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(string_arr))
}

/// Splits string at occurrences of delimiter and returns an array of parts
/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]'
pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() < 2 || args.len() > 3 {
return exec_err!("string_to_array expects two or three arguments");
}
let string_array = as_generic_string_array::<T>(&args[0])?;
let delimiter_array = as_generic_string_array::<T>(&args[1])?;

let mut list_builder = ListBuilder::new(StringBuilder::with_capacity(
string_array.len(),
string_array.get_buffer_memory_size(),
));

match args.len() {
2 => {
string_array.iter().zip(delimiter_array.iter()).for_each(
|(string, delimiter)| {
match (string, delimiter) {
(Some(string), Some("")) => {
list_builder.values().append_value(string);
list_builder.append(true);
}
(Some(string), Some(delimiter)) => {
string.split(delimiter).for_each(|s| {
list_builder.values().append_value(s);
});
list_builder.append(true);
}
(Some(string), None) => {
string.chars().map(|c| c.to_string()).for_each(|c| {
list_builder.values().append_value(c);
});
list_builder.append(true);
}
_ => list_builder.append(false), // null value
}
},
);
}

3 => {
let null_value_array = as_generic_string_array::<T>(&args[2])?;
string_array
.iter()
.zip(delimiter_array.iter())
.zip(null_value_array.iter())
.for_each(|((string, delimiter), null_value)| {
match (string, delimiter) {
(Some(string), Some("")) => {
if Some(string) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(string);
}
list_builder.append(true);
}
(Some(string), Some(delimiter)) => {
string.split(delimiter).for_each(|s| {
if Some(s) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(s);
}
});
list_builder.append(true);
}
(Some(string), None) => {
string.chars().map(|c| c.to_string()).for_each(|c| {
if Some(c.as_str()) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(c);
}
});
list_builder.append(true);
}
_ => list_builder.append(false), // null value
}
});
}
_ => {
return exec_err!(
"Expect string_to_array function to take two or three parameters"
)
}
}

let list_array = list_builder.finish();
Ok(Arc::new(list_array) as ArrayRef)
}

/// Generates an array of integers from start to stop with a given step.
///
/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values.
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@ pub mod expr_fn {
pub use super::udf::flatten;
pub use super::udf::gen_series;
pub use super::udf::range;
pub use super::udf::string_to_array;
}

/// Registers all enabled packages with a [`FunctionRegistry`]
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
let functions: Vec<Arc<ScalarUDF>> = vec![
udf::array_to_string_udf(),
udf::string_to_array_udf(),
udf::range_udf(),
udf::gen_series_udf(),
udf::array_dims_udf(),
Expand Down
76 changes: 76 additions & 0 deletions datafusion/functions-array/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! [`ScalarUDFImpl`] definitions for array functions.
use arrow::array::{NullArray, StringArray};
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use arrow::datatypes::IntervalUnit::MonthDayNano;
Expand Down Expand Up @@ -89,6 +90,81 @@ impl ScalarUDFImpl for ArrayToString {
}
}

// Wires up the `StringToArray` UDF: the names passed here are the ones re-exported
// from `lib.rs` as `expr_fn::string_to_array` and registered via `string_to_array_udf()`.
// NOTE(review): the exact expansion lives in `make_udf_function!`, which is not
// visible from this file section — confirm against the macro definition.
make_udf_function!(StringToArray,
string_to_array,
string delimiter null_string, // arg name
"splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`", // doc
string_to_array_udf // internal function name
);
/// Scalar UDF implementation backing `string_to_array` (alias `string_to_list`).
#[derive(Debug)]
pub(super) struct StringToArray {
    /// Accepted argument signature and volatility.
    signature: Signature,
    /// Every name this function can be called by.
    aliases: Vec<String>,
}

impl StringToArray {
    /// Creates the UDF with an immutable, variadic-any signature and its aliases.
    pub fn new() -> Self {
        let aliases = ["string_to_array", "string_to_list"]
            .iter()
            .map(|alias| String::from(*alias))
            .collect();
        let signature = Signature::variadic_any(Volatility::Immutable);
        Self { signature, aliases }
    }
}

impl ScalarUDFImpl for StringToArray {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "string_to_array"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Returns `List<item: T>` where `T` is the type of the first argument.
    ///
    /// # Errors
    /// Returns a plan error when the first argument is missing or is neither
    /// `Utf8` nor `LargeUtf8`.
    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        use DataType::*;
        // `first()` instead of `[0]` so an empty argument list is reported as an
        // error rather than an out-of-bounds panic.
        match arg_types.first() {
            Some(string_type @ (Utf8 | LargeUtf8)) => {
                Ok(List(Arc::new(Field::new("item", string_type.clone(), true))))
            }
            _ => {
                plan_err!(
                    "The string_to_array function can only accept Utf8 or LargeUtf8."
                )
            }
        }
    }

    /// Evaluates `string_to_array(string, delimiter [, null_string])`.
    ///
    /// # Errors
    /// Returns an execution error when the number of arguments is not two or
    /// three, or when the first argument is neither `Utf8` nor `LargeUtf8`.
    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
        // The signature is variadic, so the arity has to be validated here before
        // any positional access.
        if args.len() < 2 || args.len() > 3 {
            return exec_err!("string_to_array expects two or three arguments");
        }
        let mut args = ColumnarValue::values_to_arrays(args)?;
        let is_large = matches!(args[0].data_type(), DataType::LargeUtf8);

        // Case: delimiter (or null_string) is a NULL literal. It arrives as a
        // `NullArray` and must be replaced by an all-null string array whose
        // offset width matches the first argument, otherwise the kernel's
        // downcast fails.
        for arg in args.iter_mut().skip(1) {
            if arg.as_any().is::<NullArray>() {
                let len = arg.len();
                if is_large {
                    *arg = Arc::new(arrow::array::LargeStringArray::new_null(len));
                } else {
                    *arg = Arc::new(StringArray::new_null(len));
                }
            }
        }

        match args[0].data_type() {
            DataType::Utf8 => {
                crate::kernels::string_to_array::<i32>(&args).map(ColumnarValue::Array)
            }
            DataType::LargeUtf8 => {
                crate::kernels::string_to_array::<i64>(&args).map(ColumnarValue::Array)
            }
            other => {
                exec_err!("unsupported type for string_to_array function as {other}")
            }
        }
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}

make_udf_function!(
Range,
range,
Expand Down
Loading

0 comments on commit 4cd3c43

Please sign in to comment.