diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index c5add8782c51..a18e6df59bf1 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -19,8 +19,9 @@ use std::sync::Arc; use abi_stable::StableAbi; use arrow::{ + array::{make_array, ArrayRef}, datatypes::{Schema, SchemaRef}, - ffi::{FFI_ArrowArray, FFI_ArrowSchema}, + ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; use log::error; @@ -68,3 +69,13 @@ pub struct WrappedArray { pub schema: WrappedSchema, } + +impl TryFrom for ArrayRef { + type Error = arrow::error::ArrowError; + + fn try_from(value: WrappedArray) -> Result { + let data = unsafe { from_ffi(value.array, &value.schema.0)? }; + + Ok(make_array(data)) + } +} diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs index 6c5db1218563..8087acfa33c8 100644 --- a/datafusion/ffi/src/execution_plan.rs +++ b/datafusion/ffi/src/execution_plan.rs @@ -30,7 +30,8 @@ use datafusion::{ use tokio::runtime::Handle; use crate::{ - plan_properties::FFI_PlanProperties, record_batch_stream::FFI_RecordBatchStream, + df_result, plan_properties::FFI_PlanProperties, + record_batch_stream::FFI_RecordBatchStream, rresult, }; /// A stable struct for sharing a [`ExecutionPlan`] across FFI boundaries. 
@@ -112,13 +113,11 @@ unsafe extern "C" fn execute_fn_wrapper( let ctx = &(*private_data).context; let runtime = (*private_data).runtime.clone(); - match plan.execute(partition, Arc::clone(ctx)) { - Ok(rbs) => RResult::ROk(FFI_RecordBatchStream::new(rbs, runtime)), - Err(e) => RResult::RErr( - format!("Error occurred during FFI_ExecutionPlan execute: {}", e).into(), - ), - } + rresult!(plan + .execute(partition, Arc::clone(ctx)) + .map(|rbs| FFI_RecordBatchStream::new(rbs, runtime))) } + unsafe extern "C" fn name_fn_wrapper(plan: &FFI_ExecutionPlan) -> RString { let private_data = plan.private_data as *const ExecutionPlanPrivateData; let plan = &(*private_data).plan; @@ -274,16 +273,8 @@ impl ExecutionPlan for ForeignExecutionPlan { _context: Arc, ) -> Result { unsafe { - match (self.plan.execute)(&self.plan, partition) { - RResult::ROk(stream) => { - let stream = Pin::new(Box::new(stream)) as SendableRecordBatchStream; - Ok(stream) - } - RResult::RErr(e) => Err(DataFusionError::Execution(format!( - "Error occurred during FFI call to FFI_ExecutionPlan execute. 
{}", - e - ))), - } + df_result!((self.plan.execute)(&self.plan, partition)) + .map(|stream| Pin::new(Box::new(stream)) as SendableRecordBatchStream) } } } diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index b25528234773..bbcdd85ff80a 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -26,6 +26,9 @@ pub mod record_batch_stream; pub mod session_config; pub mod table_provider; pub mod table_source; +pub mod udf; +pub mod util; +pub mod volatility; #[cfg(feature = "integration-tests")] pub mod tests; diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 3c7bc886aede..3592c16b8fab 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -19,8 +19,8 @@ use std::{ffi::c_void, sync::Arc}; use abi_stable::{ std_types::{ - RResult::{self, RErr, ROk}, - RStr, RVec, + RResult::{self, ROk}, + RString, RVec, }, StableAbi, }; @@ -44,7 +44,7 @@ use datafusion_proto::{ }; use prost::Message; -use crate::arrow_wrappers::WrappedSchema; +use crate::{arrow_wrappers::WrappedSchema, df_result, rresult_return}; /// A stable struct for sharing [`PlanProperties`] across FFI boundaries. #[repr(C)] @@ -54,7 +54,7 @@ pub struct FFI_PlanProperties { /// The output partitioning is a [`Partitioning`] protobuf message serialized /// into bytes to pass across the FFI boundary. pub output_partitioning: - unsafe extern "C" fn(plan: &Self) -> RResult, RStr<'static>>, + unsafe extern "C" fn(plan: &Self) -> RResult, RString>, /// Return the emission type of the plan. pub emission_type: unsafe extern "C" fn(plan: &Self) -> FFI_EmissionType, @@ -64,8 +64,7 @@ pub struct FFI_PlanProperties { /// The output ordering is a [`PhysicalSortExprNodeCollection`] protobuf message /// serialized into bytes to pass across the FFI boundary. 
- pub output_ordering: - unsafe extern "C" fn(plan: &Self) -> RResult, RStr<'static>>, + pub output_ordering: unsafe extern "C" fn(plan: &Self) -> RResult, RString>, /// Return the schema of the plan. pub schema: unsafe extern "C" fn(plan: &Self) -> WrappedSchema, @@ -84,21 +83,13 @@ struct PlanPropertiesPrivateData { unsafe extern "C" fn output_partitioning_fn_wrapper( properties: &FFI_PlanProperties, -) -> RResult, RStr<'static>> { +) -> RResult, RString> { let private_data = properties.private_data as *const PlanPropertiesPrivateData; let props = &(*private_data).props; let codec = DefaultPhysicalExtensionCodec {}; let partitioning_data = - match serialize_partitioning(props.output_partitioning(), &codec) { - Ok(p) => p, - Err(_) => { - return RErr( - "unable to serialize output_partitioning in FFI_PlanProperties" - .into(), - ) - } - }; + rresult_return!(serialize_partitioning(props.output_partitioning(), &codec)); let output_partitioning = partitioning_data.encode_to_vec(); ROk(output_partitioning.into()) @@ -122,31 +113,24 @@ unsafe extern "C" fn boundedness_fn_wrapper( unsafe extern "C" fn output_ordering_fn_wrapper( properties: &FFI_PlanProperties, -) -> RResult, RStr<'static>> { +) -> RResult, RString> { let private_data = properties.private_data as *const PlanPropertiesPrivateData; let props = &(*private_data).props; let codec = DefaultPhysicalExtensionCodec {}; - let output_ordering = - match props.output_ordering() { - Some(ordering) => { - let physical_sort_expr_nodes = - match serialize_physical_sort_exprs(ordering.to_owned(), &codec) { - Ok(v) => v, - Err(_) => return RErr( - "unable to serialize output_ordering in FFI_PlanProperties" - .into(), - ), - }; - - let ordering_data = PhysicalSortExprNodeCollection { - physical_sort_expr_nodes, - }; - - ordering_data.encode_to_vec() - } - None => Vec::default(), - }; + let output_ordering = match props.output_ordering() { + Some(ordering) => { + let physical_sort_expr_nodes = rresult_return!( + 
serialize_physical_sort_exprs(ordering.to_owned(), &codec) + ); + let ordering_data = PhysicalSortExprNodeCollection { + physical_sort_expr_nodes, + }; + + ordering_data.encode_to_vec() + } + None => Vec::default(), + }; ROk(output_ordering.into()) } @@ -200,40 +184,32 @@ impl TryFrom for PlanProperties { let codex = DefaultPhysicalExtensionCodec {}; let ffi_orderings = unsafe { (ffi_props.output_ordering)(&ffi_props) }; - let orderings = match ffi_orderings { - ROk(ordering_vec) => { - let proto_output_ordering = - PhysicalSortExprNodeCollection::decode(ordering_vec.as_ref()) - .map_err(|e| DataFusionError::External(Box::new(e)))?; - Some(parse_physical_sort_exprs( - &proto_output_ordering.physical_sort_expr_nodes, - &default_ctx, - &schema, - &codex, - )?) - } - RErr(e) => return Err(DataFusionError::Plan(e.to_string())), - }; - let ffi_partitioning = unsafe { (ffi_props.output_partitioning)(&ffi_props) }; - let partitioning = match ffi_partitioning { - ROk(partitioning_vec) => { - let proto_output_partitioning = - Partitioning::decode(partitioning_vec.as_ref()) - .map_err(|e| DataFusionError::External(Box::new(e)))?; - parse_protobuf_partitioning( - Some(&proto_output_partitioning), - &default_ctx, - &schema, - &codex, - )? - .ok_or(DataFusionError::Plan( - "Unable to deserialize partitioning protobuf in FFI_PlanProperties" - .to_string(), - )) - } - RErr(e) => Err(DataFusionError::Plan(e.to_string())), - }?; + let proto_output_ordering = + PhysicalSortExprNodeCollection::decode(df_result!(ffi_orderings)?.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let orderings = Some(parse_physical_sort_exprs( + &proto_output_ordering.physical_sort_expr_nodes, + &default_ctx, + &schema, + &codex, + )?); + + let partitioning_vec = + unsafe { df_result!((ffi_props.output_partitioning)(&ffi_props))? 
}; + let proto_output_partitioning = + Partitioning::decode(partitioning_vec.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let partitioning = parse_protobuf_partitioning( + Some(&proto_output_partitioning), + &default_ctx, + &schema, + &codex, + )? + .ok_or(DataFusionError::Plan( + "Unable to deserialize partitioning protobuf in FFI_PlanProperties" + .to_string(), + ))?; let eq_properties = match orderings { Some(ordering) => { diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 466ce247678a..939c4050028c 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -35,7 +35,10 @@ use datafusion::{ use futures::{Stream, TryStreamExt}; use tokio::runtime::Handle; -use crate::arrow_wrappers::{WrappedArray, WrappedSchema}; +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + rresult, +}; /// A stable struct for sharing [`RecordBatchStream`] across FFI boundaries. /// We use the async-ffi crate for handling async calls across libraries. 
@@ -97,13 +100,12 @@ fn record_batch_to_wrapped_array( record_batch: RecordBatch, ) -> RResult { let struct_array = StructArray::from(record_batch); - match to_ffi(&struct_array.to_data()) { - Ok((array, schema)) => RResult::ROk(WrappedArray { + rresult!( + to_ffi(&struct_array.to_data()).map(|(array, schema)| WrappedArray { array, - schema: WrappedSchema(schema), - }), - Err(e) => RResult::RErr(e.to_string().into()), - } + schema: WrappedSchema(schema) + }) + ) } // probably want to use pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result { diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs index 978ac10206bd..0b4080abcb55 100644 --- a/datafusion/ffi/src/table_provider.rs +++ b/datafusion/ffi/src/table_provider.rs @@ -44,6 +44,7 @@ use tokio::runtime::Handle; use crate::{ arrow_wrappers::WrappedSchema, + df_result, rresult_return, session_config::ForeignSessionConfig, table_source::{FFI_TableProviderFilterPushDown, FFI_TableType}, }; @@ -233,10 +234,7 @@ unsafe extern "C" fn scan_fn_wrapper( let runtime = &(*private_data).runtime; async move { - let config = match ForeignSessionConfig::try_from(&session_config) { - Ok(c) => c, - Err(e) => return RResult::RErr(e.to_string().into()), - }; + let config = rresult_return!(ForeignSessionConfig::try_from(&session_config)); let session = SessionStateBuilder::new() .with_default_features() .with_config(config.0) @@ -250,15 +248,13 @@ unsafe extern "C" fn scan_fn_wrapper( let codec = DefaultLogicalExtensionCodec {}; let proto_filters = - match LogicalExprList::decode(filters_serialized.as_ref()) { - Ok(f) => f, - Err(e) => return RResult::RErr(e.to_string().into()), - }; - - match parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec) { - Ok(f) => f, - Err(e) => return RResult::RErr(e.to_string().into()), - } + rresult_return!(LogicalExprList::decode(filters_serialized.as_ref())); + + rresult_return!(parse_exprs( + proto_filters.expr.iter(), + 
&default_ctx, + &codec + )) } }; @@ -268,13 +264,11 @@ unsafe extern "C" fn scan_fn_wrapper( false => Some(&projections), }; - let plan = match internal_provider - .scan(&ctx.state(), maybe_projections, &filters, limit.into()) - .await - { - Ok(p) => p, - Err(e) => return RResult::RErr(e.to_string().into()), - }; + let plan = rresult_return!( + internal_provider + .scan(&ctx.state(), maybe_projections, &filters, limit.into()) + .await + ); RResult::ROk(FFI_ExecutionPlan::new( plan, @@ -298,30 +292,22 @@ unsafe extern "C" fn insert_into_fn_wrapper( let runtime = &(*private_data).runtime; async move { - let config = match ForeignSessionConfig::try_from(&session_config) { - Ok(c) => c, - Err(e) => return RResult::RErr(e.to_string().into()), - }; + let config = rresult_return!(ForeignSessionConfig::try_from(&session_config)); let session = SessionStateBuilder::new() .with_default_features() .with_config(config.0) .build(); let ctx = SessionContext::new_with_state(session); - let input = match ForeignExecutionPlan::try_from(&input) { - Ok(input) => Arc::new(input), - Err(e) => return RResult::RErr(e.to_string().into()), - }; + let input = rresult_return!(ForeignExecutionPlan::try_from(&input).map(Arc::new)); let insert_op = InsertOp::from(insert_op); - let plan = match internal_provider - .insert_into(&ctx.state(), input, insert_op) - .await - { - Ok(p) => p, - Err(e) => return RResult::RErr(e.to_string().into()), - }; + let plan = rresult_return!( + internal_provider + .insert_into(&ctx.state(), input, insert_op) + .await + ); RResult::ROk(FFI_ExecutionPlan::new( plan, @@ -456,14 +442,7 @@ impl TableProvider for ForeignTableProvider { ) .await; - match maybe_plan { - RResult::ROk(p) => ForeignExecutionPlan::try_from(&p)?, - RResult::RErr(_) => { - return Err(DataFusionError::Internal( - "Unable to perform scan via FFI".to_string(), - )) - } - } + ForeignExecutionPlan::try_from(&df_result!(maybe_plan)?)? 
}; Ok(Arc::new(plan)) @@ -493,12 +472,9 @@ impl TableProvider for ForeignTableProvider { }; let serialized_filters = expr_list.encode_to_vec(); - let pushdowns = pushdown_fn(&self.0, serialized_filters.into()); + let pushdowns = df_result!(pushdown_fn(&self.0, serialized_filters.into()))?; - match pushdowns { - RResult::ROk(p) => Ok(p.iter().map(|v| v.into()).collect()), - RResult::RErr(e) => Err(DataFusionError::Plan(e.to_string())), - } + Ok(pushdowns.iter().map(|v| v.into()).collect()) } } @@ -519,15 +495,7 @@ impl TableProvider for ForeignTableProvider { let maybe_plan = (self.0.insert_into)(&self.0, &session_config, &input, insert_op).await; - match maybe_plan { - RResult::ROk(p) => ForeignExecutionPlan::try_from(&p)?, - RResult::RErr(e) => { - return Err(DataFusionError::Internal(format!( - "Unable to perform insert_into via FFI: {}", - e - ))) - } - } + ForeignExecutionPlan::try_from(&df_result!(maybe_plan)?)? }; Ok(Arc::new(plan)) diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index a5fc74b840d1..5a471cb8fe43 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -26,7 +26,7 @@ use abi_stable::{ StableAbi, }; -use super::table_provider::FFI_TableProvider; +use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use arrow::array::RecordBatch; use async_provider::create_async_table_provider; use datafusion::{ @@ -34,27 +34,30 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; +use udf_udaf_udwf::create_ffi_abs_func; mod async_provider; mod sync_provider; +mod udf_udaf_udwf; #[repr(C)] #[derive(StableAbi)] -#[sabi(kind(Prefix(prefix_ref = TableProviderModuleRef)))] +#[sabi(kind(Prefix(prefix_ref = ForeignLibraryModuleRef)))] /// This struct defines the module interfaces. It is to be shared by /// both the module loading program and library that implements the -/// module. 
It is possible to move this definition into the loading -/// program and reference it in the modules, but this example shows -/// how a user may wish to separate these concerns. -pub struct TableProviderModule { +/// module. +pub struct ForeignLibraryModule { /// Constructs the table provider pub create_table: extern "C" fn(synchronous: bool) -> FFI_TableProvider, + /// Create a scalar UDF + pub create_scalar_udf: extern "C" fn() -> FFI_ScalarUDF, + pub version: extern "C" fn() -> u64, } -impl RootModule for TableProviderModuleRef { - declare_root_module_statics! {TableProviderModuleRef} +impl RootModule for ForeignLibraryModuleRef { + declare_root_module_statics! {ForeignLibraryModuleRef} const BASE_NAME: &'static str = "datafusion_ffi"; const NAME: &'static str = "datafusion_ffi"; const VERSION_STRINGS: VersionStrings = package_version_strings!(); @@ -64,7 +67,7 @@ impl RootModule for TableProviderModuleRef { } } -fn create_test_schema() -> Arc { +pub fn create_test_schema() -> Arc { Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Float64, true), @@ -90,9 +93,10 @@ extern "C" fn construct_table_provider(synchronous: bool) -> FFI_TableProvider { #[export_root_module] /// This defines the entry point for using the module. -pub fn get_simple_memory_table() -> TableProviderModuleRef { - TableProviderModule { +pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { + ForeignLibraryModule { create_table: construct_table_provider, + create_scalar_udf: create_ffi_abs_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs new file mode 100644 index 000000000000..e8a13aac1308 --- /dev/null +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::udf::FFI_ScalarUDF; +use datafusion::{functions::math::abs::AbsFunc, logical_expr::ScalarUDF}; + +use std::sync::Arc; + +pub(crate) extern "C" fn create_ffi_abs_func() -> FFI_ScalarUDF { + let udf: Arc = Arc::new(AbsFunc::new().into()); + + udf.into() +} diff --git a/datafusion/ffi/src/udf.rs b/datafusion/ffi/src/udf.rs new file mode 100644 index 000000000000..bbc9cf936cee --- /dev/null +++ b/datafusion/ffi/src/udf.rs @@ -0,0 +1,351 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::datatypes::DataType; +use arrow::{ + array::ArrayRef, + error::ArrowError, + ffi::{from_ffi, to_ffi, FFI_ArrowSchema}, +}; +use datafusion::{ + error::DataFusionError, + logical_expr::type_coercion::functions::data_types_with_scalar_udf, +}; +use datafusion::{ + error::Result, + logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + }, +}; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; + +/// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_ScalarUDF { + /// FFI equivalent to the `name` of a [`ScalarUDF`] + pub name: RString, + + /// FFI equivalent to the `aliases` of a [`ScalarUDF`] + pub aliases: RVec, + + /// FFI equivalent to the `volatility` of a [`ScalarUDF`] + pub volatility: FFI_Volatility, + + /// Determines the return type of the underlying [`ScalarUDF`] based on the + /// argument types. + pub return_type: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec, + ) -> RResult, + + /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray` + /// within an AbiStable wrapper. + pub invoke_with_args: unsafe extern "C" fn( + udf: &Self, + args: RVec, + num_rows: usize, + return_type: WrappedSchema, + ) -> RResult, + + /// See [`ScalarUDFImpl`] for details on short_circuits + pub short_circuits: bool, + + /// Performs type coercion. To simplify this interface, all UDFs are treated as having + /// user defined signatures, which will in turn cause coerce_types to be called. 
This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`ScalarUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec, + ) -> RResult, RString>, + + /// Used to create a clone on the provider of the udf. This should + /// only need to be called by the receiver of the udf. + pub clone: unsafe extern "C" fn(udf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udf. + /// A [`ForeignScalarUDF`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_ScalarUDF {} +unsafe impl Sync for FFI_ScalarUDF {} + +pub struct ScalarUDFPrivateData { + pub udf: Arc, +} + +unsafe extern "C" fn return_type_fn_wrapper( + udf: &FFI_ScalarUDF, + arg_types: RVec, +) -> RResult { + let private_data = udf.private_data as *const ScalarUDFPrivateData; + let udf = &(*private_data).udf; + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_type = udf + .return_type(&arg_types) + .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from)) + .map(WrappedSchema); + + rresult!(return_type) +} + +unsafe extern "C" fn coerce_types_fn_wrapper( + udf: &FFI_ScalarUDF, + arg_types: RVec, +) -> RResult, RString> { + let private_data = udf.private_data as *const ScalarUDFPrivateData; + let udf = &(*private_data).udf; + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf)); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + +unsafe extern "C" fn invoke_with_args_fn_wrapper( + udf: &FFI_ScalarUDF, + args: RVec, + number_rows: usize, + return_type: WrappedSchema, +) -> RResult { + let private_data = udf.private_data 
as *const ScalarUDFPrivateData; + let udf = &(*private_data).udf; + + let args = args + .into_iter() + .map(|arr| { + from_ffi(arr.array, &arr.schema.0) + .map(|v| ColumnarValue::Array(arrow::array::make_array(v))) + }) + .collect::>(); + + let args = rresult_return!(args); + let return_type = rresult_return!(DataType::try_from(&return_type.0)); + + let args = ScalarFunctionArgs { + args, + number_rows, + return_type: &return_type, + }; + + let result = rresult_return!(udf + .invoke_with_args(args) + .and_then(|r| r.to_array(number_rows))); + + let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data())); + + RResult::ROk(WrappedArray { + array: result_array, + schema: WrappedSchema(result_schema), + }) +} + +unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) { + let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF { + let private_data = udf.private_data as *const ScalarUDFPrivateData; + let udf_data = &(*private_data); + + Arc::clone(&udf_data.udf).into() +} + +impl Clone for FFI_ScalarUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_ScalarUDF { + fn from(udf: Arc) -> Self { + let name = udf.name().into(); + let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect(); + let volatility = udf.signature().volatility.into(); + let short_circuits = udf.short_circuits(); + + let private_data = Box::new(ScalarUDFPrivateData { udf }); + + Self { + name, + aliases, + volatility, + short_circuits, + invoke_with_args: invoke_with_args_fn_wrapper, + return_type: return_type_fn_wrapper, + coerce_types: coerce_types_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_ScalarUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// 
This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignScalarUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_ScalarUDF. +#[derive(Debug)] +pub struct ForeignScalarUDF { + name: String, + aliases: Vec, + udf: FFI_ScalarUDF, + signature: Signature, +} + +unsafe impl Send for ForeignScalarUDF {} +unsafe impl Sync for ForeignScalarUDF {} + +impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF { + type Error = DataFusionError; + + fn try_from(udf: &FFI_ScalarUDF) -> Result { + let name = udf.name.to_owned().into(); + let signature = Signature::user_defined((&udf.volatility).into()); + + let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Self { + name, + udf: udf.clone(), + aliases, + signature, + }) + } +} + +impl ScalarUDFImpl for ForeignScalarUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + + let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) }; + + let result = df_result!(result); + + result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) + } + + fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { + args, + number_rows, + return_type, + } = invoke_args; + + let args = args + .into_iter() + .map(|v| v.to_array(number_rows)) + .collect::>>()? + .into_iter() + .map(|v| { + to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray { + array: ffi_array, + schema: WrappedSchema(ffi_schema), + }) + }) + .collect::, ArrowError>>()? 
+ .into(); + + let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?); + + let result = unsafe { + (self.udf.invoke_with_args)(&self.udf, args, number_rows, return_type) + }; + + let result = df_result!(result)?; + let result_array: ArrayRef = result.try_into()?; + + Ok(ColumnarValue::Array(result_array)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn short_circuits(&self) -> bool { + self.udf.short_circuits + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_round_trip_scalar_udf() -> Result<()> { + let original_udf = datafusion::functions::math::abs::AbsFunc::new(); + let original_udf = Arc::new(ScalarUDF::from(original_udf)); + + let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into(); + + let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?; + + assert!(original_udf.name() == foreign_udf.name()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs new file mode 100644 index 000000000000..9d5f2aefe324 --- /dev/null +++ b/datafusion/ffi/src/util.rs @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use abi_stable::std_types::RVec; +use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; + +use crate::arrow_wrappers::WrappedSchema; + +/// This macro is a helpful conversion utility to convert from an abi_stable::RResult to a +/// DataFusion result. +#[macro_export] +macro_rules! df_result { + ( $x:expr ) => { + match $x { + abi_stable::std_types::RResult::ROk(v) => Ok(v), + abi_stable::std_types::RResult::RErr(e) => { + Err(datafusion::error::DataFusionError::Execution(e.to_string())) + } + } + }; +} + +/// This macro is a helpful conversion utility to convert from a DataFusion Result to an abi_stable::RResult +#[macro_export] +macro_rules! rresult { + ( $x:expr ) => { + match $x { + Ok(v) => abi_stable::std_types::RResult::ROk(v), + Err(e) => abi_stable::std_types::RResult::RErr( + abi_stable::std_types::RString::from(e.to_string()), + ), + } + }; +} + +/// This macro is a helpful conversion utility to convert from a DataFusion Result to an abi_stable::RResult +/// and to also call return when it is an error. Since you cannot use `?` on an RResult, this is designed +/// to mimic the pattern. +#[macro_export] +macro_rules! 
rresult_return { + ( $x:expr ) => { + match $x { + Ok(v) => v, + Err(e) => { + return abi_stable::std_types::RResult::RErr( + abi_stable::std_types::RString::from(e.to_string()), + ) + } + } + }; +} + +/// This is a utility function to convert a slice of [`DataType`] to its equivalent +/// FFI friendly counterpart, [`WrappedSchema`] +pub fn vec_datatype_to_rvec_wrapped( + data_types: &[DataType], +) -> Result, arrow::error::ArrowError> { + Ok(data_types + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, arrow::error::ArrowError>>()? + .into_iter() + .map(WrappedSchema) + .collect()) +} + +/// This is a utility function to convert an FFI friendly vector of [`WrappedSchema`] +/// to their equivalent [`DataType`]. +pub fn rvec_wrapped_to_vec_datatype( + data_types: &RVec, +) -> Result, arrow::error::ArrowError> { + data_types + .iter() + .map(|d| DataType::try_from(&d.0)) + .collect() +} + +#[cfg(test)] +mod tests { + use abi_stable::std_types::{RResult, RString}; + use datafusion::error::DataFusionError; + + fn wrap_result(result: Result) -> RResult { + RResult::ROk(rresult_return!(result)) + } + + #[test] + fn test_conversion() { + const VALID_VALUE: &str = "valid_value"; + const ERROR_VALUE: &str = "error_value"; + + let ok_r_result: RResult = + RResult::ROk(VALID_VALUE.to_string().into()); + let err_r_result: RResult = + RResult::RErr(ERROR_VALUE.to_string().into()); + + let returned_ok_result = df_result!(ok_r_result); + assert!(returned_ok_result.is_ok()); + assert!(returned_ok_result.unwrap().to_string() == VALID_VALUE); + + let returned_err_result = df_result!(err_r_result); + assert!(returned_err_result.is_err()); + assert!( + returned_err_result.unwrap_err().to_string() + == format!("Execution error: {}", ERROR_VALUE) + ); + + let ok_result: Result = Ok(VALID_VALUE.to_string()); + let err_result: Result = + Err(DataFusionError::Execution(ERROR_VALUE.to_string())); + + let returned_ok_r_result = wrap_result(ok_result); + assert!(returned_ok_r_result == 
RResult::ROk(VALID_VALUE.into())); + + let returned_err_r_result = wrap_result(err_result); + assert!( + returned_err_r_result + == RResult::RErr(format!("Execution error: {}", ERROR_VALUE).into()) + ); + } +} diff --git a/datafusion/ffi/src/volatility.rs b/datafusion/ffi/src/volatility.rs new file mode 100644 index 000000000000..8b565b91b76d --- /dev/null +++ b/datafusion/ffi/src/volatility.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use abi_stable::StableAbi; +use datafusion::logical_expr::Volatility; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_Volatility { + Immutable, + Stable, + Volatile, +} + +impl From for FFI_Volatility { + fn from(value: Volatility) -> Self { + match value { + Volatility::Immutable => Self::Immutable, + Volatility::Stable => Self::Stable, + Volatility::Volatile => Self::Volatile, + } + } +} + +impl From<&FFI_Volatility> for Volatility { + fn from(value: &FFI_Volatility) -> Self { + match value { + FFI_Volatility::Immutable => Self::Immutable, + FFI_Volatility::Stable => Self::Stable, + FFI_Volatility::Volatile => Self::Volatile, + } + } +} diff --git a/datafusion/ffi/tests/table_provider.rs b/datafusion/ffi/tests/ffi_integration.rs similarity index 68% rename from datafusion/ffi/tests/table_provider.rs rename to datafusion/ffi/tests/ffi_integration.rs index 9169c9f4221c..84e120df4299 100644 --- a/datafusion/ffi/tests/table_provider.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -21,10 +21,13 @@ mod tests { use abi_stable::library::RootModule; + use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; - use datafusion::prelude::SessionContext; + use datafusion::logical_expr::ScalarUDF; + use datafusion::prelude::{col, SessionContext}; use datafusion_ffi::table_provider::ForeignTableProvider; - use datafusion_ffi::tests::TableProviderModuleRef; + use datafusion_ffi::tests::{create_record_batch, ForeignLibraryModuleRef}; + use datafusion_ffi::udf::ForeignScalarUDF; use std::path::Path; use std::sync::Arc; @@ -61,11 +64,7 @@ mod tests { Ok(best_path) } - /// It is important that this test is in the `tests` directory and not in the - /// library directory so we can verify we are building a dynamic library and - /// testing it via a different executable. 
- #[cfg(feature = "integration-tests")] - async fn test_table_provider(synchronous: bool) -> Result<()> { + fn get_module() -> Result { let expected_version = datafusion_ffi::version(); let crate_root = Path::new(env!("CARGO_MANIFEST_DIR")); @@ -80,22 +79,30 @@ mod tests { // so you will need to change the approach here based on your use case. // let target: &std::path::Path = "../../../../target/".as_ref(); let library_path = - compute_library_path::(target_dir.as_path()) + compute_library_path::(target_dir.as_path()) .map_err(|e| DataFusionError::External(Box::new(e)))? .join("deps"); // Load the module - let table_provider_module = - TableProviderModuleRef::load_from_directory(&library_path) - .map_err(|e| DataFusionError::External(Box::new(e)))?; + let module = ForeignLibraryModuleRef::load_from_directory(&library_path) + .map_err(|e| DataFusionError::External(Box::new(e)))?; assert_eq!( - table_provider_module + module .version() .expect("Unable to call version on FFI module")(), expected_version ); + Ok(module) + } + + /// It is important that this test is in the `tests` directory and not in the + /// library directory so we can verify we are building a dynamic library and + /// testing it via a different executable. + async fn test_table_provider(synchronous: bool) -> Result<()> { + let table_provider_module = get_module()?; + // By calling the code below, the table provided will be created within // the module's code. 
let ffi_table_provider = table_provider_module.create_table().ok_or( @@ -116,9 +123,9 @@ mod tests { let results = df.collect().await?; assert_eq!(results.len(), 3); - assert_eq!(results[0], datafusion_ffi::tests::create_record_batch(1, 5)); - assert_eq!(results[1], datafusion_ffi::tests::create_record_batch(6, 1)); - assert_eq!(results[2], datafusion_ffi::tests::create_record_batch(7, 5)); + assert_eq!(results[0], create_record_batch(1, 5)); + assert_eq!(results[1], create_record_batch(6, 1)); + assert_eq!(results[2], create_record_batch(7, 5)); Ok(()) } @@ -132,4 +139,44 @@ mod tests { async fn sync_test_table_provider() -> Result<()> { test_table_provider(true).await } + + /// This test validates that we can load an external module and use a scalar + /// udf defined in it via the foreign function interface. In this case we are + /// using the abs() function as our scalar UDF. + #[tokio::test] + async fn test_scalar_udf() -> Result<()> { + let module = get_module()?; + + let ffi_abs_func = + module + .create_scalar_udf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_scalar_udf" + .to_string(), + ))?(); + let foreign_abs_func: ForeignScalarUDF = (&ffi_abs_func).try_into()?; + + let udf: ScalarUDF = foreign_abs_func.into(); + + let ctx = SessionContext::default(); + let df = ctx.read_batch(create_record_batch(-5, 5))?; + + let df = df + .with_column("abs_a", udf.call(vec![col("a")]))? + .with_column("abs_b", udf.call(vec![col("b")]))?; + + let result = df.collect().await?; + + let expected = record_batch!( + ("a", Int32, vec![-5, -4, -3, -2, -1]), + ("b", Float64, vec![-5., -4., -3., -2., -1.]), + ("abs_a", Int32, vec![5, 4, 3, 2, 1]), + ("abs_b", Float64, vec![5., 4., 3., 2., 1.]) + )?; + + assert!(result.len() == 1); + assert!(result[0] == expected); + + Ok(()) + } }