Skip to content

Commit

Permalink
Add union_extract scalar function (#12116)
Browse files Browse the repository at this point in the history
* feat: add union_extract scalar function

* fix:  docs fmt, add clippy atr, sql error msg

* use arrow-rs implementation

* docs: add union functions section

* docs: simplify union_extract docs

* test: simplify union_extract sqllogictests

* refactor(union_extract): new udf api, docs macro, use any signature

* fix: remove user_doc include attribute

* fix: generate docs

* fix: manually trim sqllogictest generated errors

* fix: fmt

* docs: add union functions section description

* docs: update functions docs

* docs: clarify union_extract description

Co-authored-by: Bruce Ritchie <bruce.ritchie@veeva.com>

* fix: use return_type_from_args, tests, docs

---------

Co-authored-by: Bruce Ritchie <bruce.ritchie@veeva.com>
  • Loading branch information
gstvg and Omega359 authored Feb 14, 2025
1 parent c1338b7 commit 7873e5c
Show file tree
Hide file tree
Showing 6 changed files with 381 additions and 3 deletions.
8 changes: 8 additions & 0 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,7 @@ pub mod scalar_doc_sections {
DOC_SECTION_STRUCT,
DOC_SECTION_MAP,
DOC_SECTION_HASHING,
DOC_SECTION_UNION,
DOC_SECTION_OTHER,
]
}
Expand All @@ -996,6 +997,7 @@ pub mod scalar_doc_sections {
DOC_SECTION_STRUCT,
DOC_SECTION_MAP,
DOC_SECTION_HASHING,
DOC_SECTION_UNION,
DOC_SECTION_OTHER,
]
}
Expand Down Expand Up @@ -1070,4 +1072,10 @@ The following regular expression functions are supported:"#,
label: "Other Functions",
description: None,
};

/// Documentation section for scalar functions that operate on values of the
/// union data type. The description spells out that this is unrelated to the
/// SQL `UNION` operator, since the names are easy to confuse.
pub const DOC_SECTION_UNION: DocSection = DocSection {
    include: true,
    label: "Union Functions",
    description: Some("Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator"),
};
}
8 changes: 8 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub mod nvl;
pub mod nvl2;
pub mod planner;
pub mod r#struct;
pub mod union_extract;
pub mod version;

// create UDFs
Expand All @@ -48,6 +49,7 @@ make_udf_function!(getfield::GetFieldFunc, get_field);
make_udf_function!(coalesce::CoalesceFunc, coalesce);
make_udf_function!(greatest::GreatestFunc, greatest);
make_udf_function!(least::LeastFunc, least);
make_udf_function!(union_extract::UnionExtractFun, union_extract);
make_udf_function!(version::VersionFunc, version);

pub mod expr_fn {
Expand Down Expand Up @@ -99,6 +101,11 @@ pub mod expr_fn {
pub fn get_field(arg1: Expr, arg2: impl Literal) -> Expr {
    // Build the call in two steps: fetch the UDF, then pass the base
    // expression together with the second argument converted to a literal.
    let udf = super::get_field();
    let call_args = vec![arg1, arg2.lit()];
    udf.call(call_args)
}

#[doc = "Returns the value of the field with the given name from the union when it's selected, or NULL otherwise"]
pub fn union_extract(arg1: Expr, arg2: impl Literal) -> Expr {
    // `arg1` is the union expression; `arg2` names the field and is passed
    // to the UDF as a literal expression.
    let udf = super::union_extract();
    let call_args = vec![arg1, arg2.lit()];
    udf.call(call_args)
}
}

/// Returns all DataFusion functions defined in this package
Expand All @@ -121,6 +128,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
coalesce(),
greatest(),
least(),
union_extract(),
version(),
r#struct(),
]
Expand Down
255 changes: 255 additions & 0 deletions datafusion/functions/src/core/union_extract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::Array;
use arrow::datatypes::{DataType, FieldRef, UnionFields};
use datafusion_common::cast::as_union_array;
use datafusion_common::{
exec_datafusion_err, exec_err, internal_err, Result, ScalarValue,
};
use datafusion_doc::Documentation;
use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;

// User-facing documentation for `union_extract`: section label, description,
// syntax, an example query with its output, and the two arguments.
#[user_doc(
doc_section(label = "Union Functions"),
description = "Returns the value of the given field in the union when selected, or NULL otherwise.",
syntax_example = "union_extract(union, field_name)",
sql_example = r#"```sql
❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union;
+--------------+----------------------------------+----------------------------------+
| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') |
+--------------+----------------------------------+----------------------------------+
| {a=1} | 1 | |
| {b=3.0} | | 3.0 |
| {a=4} | 4 | |
| {b=} | | |
| {a=} | | |
+--------------+----------------------------------+----------------------------------+
```"#,
standard_argument(name = "union", prefix = "Union"),
argument(
name = "field_name",
description = "String expression to operate on. Must be a constant."
)
)]
/// Scalar UDF implementing `union_extract(union, field_name)`.
///
/// The first argument must be a union and the second a non-null string
/// literal naming one of the union's fields; both are validated when the
/// return type is computed and again when the function is invoked.
#[derive(Debug)]
pub struct UnionExtractFun {
/// Signature of the function; `new` sets it to accept any two arguments.
signature: Signature,
}

impl Default for UnionExtractFun {
fn default() -> Self {
Self::new()
}
}

impl UnionExtractFun {
    /// Creates the function with a signature that accepts any two arguments
    /// and is declared immutable.
    pub fn new() -> Self {
        let signature = Signature::any(2, Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for UnionExtractFun {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "union_extract"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always returns an internal error.
    ///
    /// The result type depends on which field the literal second argument
    /// names, so it cannot be computed from argument types alone;
    /// `return_type_from_args` is the supported entry point.
    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
        internal_err!("union_extract should return type from args")
    }

    /// Resolves the return type to the data type of the union field named by
    /// the second argument, reported as nullable.
    ///
    /// # Errors
    /// Returns an execution error when the argument count is not 2, the first
    /// argument is not a union, the second argument is not a non-null string
    /// literal, or the named field does not exist on the union.
    fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
        if args.arg_types.len() != 2 {
            return exec_err!(
                "union_extract expects 2 arguments, got {} instead",
                args.arg_types.len()
            );
        }

        let DataType::Union(fields, _) = &args.arg_types[0] else {
            return exec_err!(
                "union_extract first argument must be a union, got {} instead",
                args.arg_types[0]
            );
        };

        // Only a literal is accepted here: the field name must be known while
        // planning in order to pick the output type.
        let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else {
            return exec_err!(
                "union_extract second argument must be a non-null string literal, got {} instead",
                args.arg_types[1]
            );
        };

        let field = find_field(fields, field_name)?.1;

        Ok(ReturnInfo::new_nullable(field.data_type().clone()))
    }

    /// Extracts the named field from a union array or a union scalar.
    ///
    /// Arrays are handed to the arrow `union_extract` kernel. For a scalar
    /// union, the held value is returned when its type id matches the target
    /// field's type id; otherwise the result is
    /// `ScalarValue::try_from(<target field data type>)`.
    ///
    /// # Errors
    /// Returns an execution error when the argument count is not 2, the first
    /// argument is not a union, the second argument is not a non-null string
    /// literal, or the named field does not exist on the union.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let args = args.args;

        if args.len() != 2 {
            return exec_err!(
                "union_extract expects 2 arguments, got {} instead",
                args.len()
            );
        }

        // Kept as a `Result` and unwrapped inside each branch below, so that a
        // non-union first argument is reported before a bad field name.
        let target_name = match &args[1] {
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name),
            ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"),
            _ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()),
        };

        match &args[0] {
            ColumnarValue::Array(array) => {
                let union_array = as_union_array(&array).map_err(|_| {
                    exec_datafusion_err!(
                        "union_extract first argument must be a union, got {} instead",
                        array.data_type()
                    )
                })?;

                Ok(ColumnarValue::Array(
                    arrow::compute::kernels::union_extract::union_extract(
                        union_array,
                        target_name?,
                    )?,
                ))
            }
            ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => {
                let target_name = target_name?;
                let (target_type_id, target) = find_field(fields, target_name)?;

                let result = match value {
                    // Clone the inner scalar directly rather than cloning the
                    // `Box` and then moving out of the fresh allocation.
                    Some((type_id, value)) if target_type_id == *type_id => {
                        ScalarValue::clone(value)
                    }
                    // Empty union, or a different field is selected.
                    _ => ScalarValue::try_from(target.data_type())?,
                };

                Ok(ColumnarValue::Scalar(result))
            }
            other => exec_err!(
                "union_extract first argument must be a union, got {} instead",
                other.data_type()
            ),
        }
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}

fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> {
fields
.iter()
.find(|field| field.1.name() == name)
.ok_or_else(|| exec_datafusion_err!("field {name} not found on union"))
}

#[cfg(test)]
mod tests {

    use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use super::UnionExtractFun;

    /// Runs `union_extract` on a dense scalar union holding `held`, asking
    /// for the field named "str".
    fn extract_str(
        fun: &UnionExtractFun,
        fields: &UnionFields,
        held: Option<(i8, Box<ScalarValue>)>,
    ) -> Result<ColumnarValue> {
        let union_scalar = ScalarValue::Union(held, fields.clone(), UnionMode::Dense);

        fun.invoke_with_args(ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Scalar(union_scalar),
                ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
            ],
            number_rows: 1,
            return_type: &DataType::Utf8,
        })
    }

    // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests
    #[test]
    fn test_scalar_value() -> Result<()> {
        let fun = UnionExtractFun::new();

        let fields = UnionFields::new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, false),
                Field::new("int", DataType::Int32, false),
            ],
        );

        // (value held by the union, expected extraction of field "str")
        let cases = [
            // Empty union.
            (None, ScalarValue::Utf8(None)),
            // The "int" field (type id 3) is selected, not "str".
            (
                Some((3, Box::new(ScalarValue::Int32(Some(42))))),
                ScalarValue::Utf8(None),
            ),
            // The "str" field (type id 1) is selected.
            (
                Some((1, Box::new(ScalarValue::new_utf8("42")))),
                ScalarValue::new_utf8("42"),
            ),
        ];

        for (held, expected) in cases {
            let result = extract_str(&fun, &fields, held)?;
            assert_scalar(result, expected);
        }

        Ok(())
    }

    /// Asserts that `value` is a scalar equal to `expected`; panics if it is
    /// an array.
    fn assert_scalar(value: ColumnarValue, expected: ScalarValue) {
        match value {
            ColumnarValue::Scalar(actual) => assert_eq!(actual, expected),
            ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"),
        }
    }
}
32 changes: 29 additions & 3 deletions datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ use std::path::Path;
use std::sync::Arc;

use arrow::array::{
ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray,
StringArray, TimestampNanosecondArray,
Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray,
LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray,
};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use arrow::buffer::ScalarBuffer;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields};
use arrow::record_batch::RecordBatch;
use datafusion::catalog::{
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session,
Expand Down Expand Up @@ -113,6 +114,10 @@ impl TestContext {
info!("Registering metadata table tables");
register_metadata_tables(test_ctx.session_ctx()).await;
}
"union_function.slt" => {
info!("Registering table with union column");
register_union_table(test_ctx.session_ctx())
}
_ => {
info!("Using default SessionContext");
}
Expand Down Expand Up @@ -402,3 +407,24 @@ fn create_example_udf() -> ScalarUDF {
adder,
)
}

/// Registers a table named `union_table` with a single non-nullable column
/// `union_column`.
///
/// The column is a two-row union with one `Int32` field called "int" (type
/// id 3) holding the values 1 and 2; no offsets buffer is passed to the
/// union array.
///
/// # Panics
/// Panics if the hard-coded fixture cannot be built or registered, which
/// would indicate a bug in this helper.
fn register_union_table(ctx: &SessionContext) {
    let union = UnionArray::try_new(
        UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]),
        ScalarBuffer::from(vec![3, 3]),
        None,
        vec![Arc::new(Int32Array::from(vec![1, 2]))],
    )
    .expect("hard-coded union array fixture must be valid");

    let schema = Schema::new(vec![Field::new(
        "union_column",
        union.data_type().clone(),
        false,
    )]);

    // The schema is not used again, so it is moved into the Arc instead of cloned.
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)])
        .expect("union column must match the schema built from it");

    ctx.register_batch("union_table", batch)
        .expect("registering union_table must succeed");
}
Loading

0 comments on commit 7873e5c

Please sign in to comment.