m09526 commented on code in PR #14775: URL: https://github.com/apache/datafusion/pull/14775#discussion_r1975307949
########## datafusion/ffi/src/arrow_wrappers.rs: ########## @@ -31,30 +32,37 @@ use log::error; #[derive(Debug, StableAbi)] pub struct WrappedSchema(#[sabi(unsafe_opaque_field)] pub FFI_ArrowSchema); +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] Review Comment: Is returning an empty schema with a warning the best thing to do here? It looks like it could create a silent failure which could be ignored in client code. Would it be safer to outright fail at this point? ########## datafusion/ffi/src/udaf/accumulator.rs: ########## @@ -0,0 +1,356 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::ffi::c_void; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::Accumulator, + scalar::ScalarValue, +}; +use prost::Message; + +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_Accumulator { Review Comment: Could use a comment explaining purpose of struct and links to any relevant documentation. ########## datafusion/ffi/tests/ffi_integration.rs: ########## @@ -179,4 +181,103 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_ffi_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_avg_func = Review Comment: ffi_avg_func suggests average is being used, but code creates sum UDAF! ########## datafusion/ffi/src/udaf/mod.rs: ########## @@ -0,0 +1,707 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ROption, RResult, RStr, RString, RVec}, + StableAbi, +}; +use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; +use arrow::datatypes::{DataType, Field}; +use arrow::ffi::FFI_ArrowSchema; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + type_coercion::functions::data_types_with_aggregate_udf, + utils::AggregateOrderSensitivity, + Accumulator, GroupsAccumulator, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{AggregateUDF, AggregateUDFImpl, Signature}, +}; +use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; +use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + +use crate::{ + arrow_wrappers::WrappedSchema, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; +use prost::{DecodeError, Message}; + +mod accumulator; +mod accumulator_args; +mod groups_accumulator; + +/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AggregateUDF { + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub name: RString, + + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub aliases: RVec<RString>, + + /// FFI equivalent to the `name` of a [`AggregateUDF`] Review Comment: Comment needs updating. ########## datafusion/ffi/src/udaf/groups_accumulator.rs: ########## @@ -0,0 +1,542 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ffi::c_void; + +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::{ + array::{Array, ArrayRef, BooleanArray}, + error::ArrowError, + ffi::to_ffi, +}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{EmitTo, GroupsAccumulator}, +}; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, +}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_GroupsAccumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec<WrappedArray>, + group_indices: RVec<usize>, + opt_filter: ROption<WrappedArray>, + total_num_groups: usize, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: unsafe extern "C" fn( + accumulator: &Self, + emit_to: FFI_EmitTo, + ) -> RResult<WrappedArray, RString>, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: unsafe extern "C" fn( + accumulator: &Self, + emit_to: FFI_EmitTo, + ) -> RResult<RVec<WrappedArray>, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec<WrappedArray>, + group_indices: RVec<usize>, + opt_filter: ROption<WrappedArray>, + total_num_groups: usize, + ) -> RResult<(), RString>, + + pub convert_to_state: unsafe extern "C" fn( + accumulator: 
&Self, + values: RVec<WrappedArray>, + opt_filter: ROption<WrappedArray>, + ) + -> RResult<RVec<WrappedArray>, RString>, + + pub supports_convert_to_state: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignGroupsAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_GroupsAccumulator {} +unsafe impl Sync for FFI_GroupsAccumulator {} + +pub struct GroupsAccumulatorPrivateData { + pub accumulator: Box<dyn GroupsAccumulator>, +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec<WrappedArray>, + group_indices: RVec<usize>, + opt_filter: ROption<WrappedArray>, + total_num_groups: usize, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; Review Comment: Refactor private data access into separate function? ########## datafusion/ffi/src/arrow_wrappers.rs: ########## @@ -31,30 +32,37 @@ use log::error; #[derive(Debug, StableAbi)] pub struct WrappedSchema(#[sabi(unsafe_opaque_field)] pub FFI_ArrowSchema); +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] +fn catch_ffi_schema_error(e: ArrowError) -> FFI_ArrowSchema { + error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. 
{}", e); + FFI_ArrowSchema::empty() +} + impl From<SchemaRef> for WrappedSchema { fn from(value: SchemaRef) -> Self { - let ffi_schema = match FFI_ArrowSchema::try_from(value.as_ref()) { - Ok(s) => s, - Err(e) => { - error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); - FFI_ArrowSchema::empty() - } - }; - - WrappedSchema(ffi_schema) + WrappedSchema( + FFI_ArrowSchema::try_from(value.as_ref()) + .unwrap_or_else(catch_ffi_schema_error), + ) } } +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] +fn catch_df_schema_error(e: ArrowError) -> Schema { + error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); + Schema::empty() +} Review Comment: As above. 
########## datafusion/ffi/src/record_batch_stream.rs: ########## @@ -196,3 +196,48 @@ impl Stream for FFI_RecordBatchStream { } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::{ + common::record_batch, error::Result, execution::SendableRecordBatchStream, + test_util::bounded_stream, + }; + + use super::FFI_RecordBatchStream; + use futures::StreamExt; + + #[tokio::test] + async fn test_round_trip_record_batch_stream() -> Result<()> { + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 3]), + ("b", Float64, vec![Some(4.0), None, Some(5.0)]) + )?; + let original_rbs = bounded_stream(record_batch.clone(), 1); + + let ffi_rbs: FFI_RecordBatchStream = original_rbs.into(); + let mut ffi_rbs: SendableRecordBatchStream = Box::pin(ffi_rbs); + + let schema = ffi_rbs.schema(); + assert_eq!( + schema, + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true) + ])) + ); + + let batch = ffi_rbs.next().await; + assert!(batch.is_some()); + assert!(batch.unwrap().is_ok()); + Review Comment: Assert batch contains correct contents here? ########## datafusion/ffi/src/udaf/mod.rs: ########## @@ -0,0 +1,707 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ROption, RResult, RStr, RString, RVec}, + StableAbi, +}; +use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; +use arrow::datatypes::{DataType, Field}; +use arrow::ffi::FFI_ArrowSchema; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + type_coercion::functions::data_types_with_aggregate_udf, + utils::AggregateOrderSensitivity, + Accumulator, GroupsAccumulator, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{AggregateUDF, AggregateUDFImpl, Signature}, +}; +use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; +use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + +use crate::{ + arrow_wrappers::WrappedSchema, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; +use prost::{DecodeError, Message}; + +mod accumulator; +mod accumulator_args; +mod groups_accumulator; + +/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AggregateUDF { + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub name: RString, + + /// FFI equivalent to the `name` of a [`AggregateUDF`] Review Comment: Comment doesn't reflect field. Should be "aliases"? ########## datafusion/ffi/src/udaf/groups_accumulator.rs: ########## @@ -0,0 +1,542 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ffi::c_void; + +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::{ + array::{Array, ArrayRef, BooleanArray}, + error::ArrowError, + ffi::to_ffi, +}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{EmitTo, GroupsAccumulator}, +}; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, +}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_GroupsAccumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec<WrappedArray>, + group_indices: RVec<usize>, + opt_filter: ROption<WrappedArray>, + total_num_groups: usize, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: unsafe extern "C" fn( + accumulator: &Self, + emit_to: FFI_EmitTo, + ) -> RResult<WrappedArray, RString>, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: unsafe extern "C" fn( + accumulator: &Self, + emit_to: FFI_EmitTo, + ) -> RResult<RVec<WrappedArray>, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec<WrappedArray>, + group_indices: RVec<usize>, + opt_filter: ROption<WrappedArray>, + total_num_groups: usize, + ) -> RResult<(), RString>, + + pub convert_to_state: unsafe extern "C" fn( + accumulator: 
&Self, + values: RVec<WrappedArray>, + opt_filter: ROption<WrappedArray>, + ) + -> RResult<RVec<WrappedArray>, RString>, + + pub supports_convert_to_state: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignGroupsAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_GroupsAccumulator {} +unsafe impl Sync for FFI_GroupsAccumulator {} + +pub struct GroupsAccumulatorPrivateData { + pub accumulator: Box<dyn GroupsAccumulator>, +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec<WrappedArray>, + group_indices: RVec<usize>, + opt_filter: ROption<WrappedArray>, + total_num_groups: usize, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::<Result<Vec<ArrayRef>>>(); + let values_arrays = rresult_return!(values_arrays); + + let group_indices: Vec<usize> = group_indices.into_iter().collect(); + + let maybe_filter = opt_filter.into_option().and_then(|filter| { + match ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. 
{}", e); + None + } + } + }).map(|arr| arr.into_data()); + let opt_filter = maybe_filter.map(BooleanArray::from); + + rresult!(accum_data.accumulator.update_batch( + &values_arrays, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult<WrappedArray, RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let result = rresult_return!(accum_data.accumulator.evaluate(emit_to.into())); + + rresult!(WrappedArray::try_from(&result)) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + accum_data.accumulator.size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult<RVec<WrappedArray>, RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let state = rresult_return!(accum_data.accumulator.state(emit_to.into())); + rresult!(state + .into_iter() + .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from)) + .collect::<Result<RVec<_>>>()) +} + +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec<WrappedArray>, + group_indices: RVec<usize>, + opt_filter: ROption<WrappedArray>, + total_num_groups: usize, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::<Result<Vec<ArrayRef>>>(); + let values_arrays = rresult_return!(values_arrays); + + let group_indices: Vec<usize> = 
group_indices.into_iter().collect(); + + let maybe_filter = opt_filter.into_option().and_then(|filter| { + match ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. {}", e); + None + } + } + }).map(|arr| arr.into_data()); + let opt_filter = maybe_filter.map(BooleanArray::from); + + rresult!(accum_data.accumulator.merge_batch( + &values_arrays, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn convert_to_state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + values: RVec<WrappedArray>, + opt_filter: ROption<WrappedArray>, +) -> RResult<RVec<WrappedArray>, RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values = rresult_return!(values + .into_iter() + .map(|v| ArrayRef::try_from(v).map_err(DataFusionError::from)) + .collect::<Result<Vec<_>>>()); + + let opt_filter = opt_filter.into_option().and_then(|filter| { + match ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. {}", e); + None + } + } + }).map(|arr| arr.into_data()).map(BooleanArray::from); + + let state = rresult_return!(accum_data + .accumulator + .convert_to_state(&values, opt_filter.as_ref())); + + rresult!(state + .iter() + .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from)) + .collect::<Result<RVec<_>>>()) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); + drop(private_data); +} + +// unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> FFI_GroupsAccumulator { Review Comment: Remove dead code. 
########## datafusion/ffi/src/udaf/accumulator.rs: ########## @@ -0,0 +1,356 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ffi::c_void; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::Accumulator, + scalar::ScalarValue, +}; +use prost::Message; + +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_Accumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec<WrappedArray>, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: unsafe extern "C" fn(accumulator: &Self) -> RResult<RVec<u8>, RString>, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: + unsafe extern "C" fn(accumulator: &Self) -> RResult<RVec<RVec<u8>>, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + states: RVec<WrappedArray>, + ) -> RResult<(), RString>, + + pub retract_batch: unsafe extern "C" fn( + accumulator: &mut Self, + 
values: RVec<WrappedArray>, + ) -> RResult<(), RString>, + + pub supports_retract_batch: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_Accumulator {} +unsafe impl Sync for FFI_Accumulator {} + +pub struct AccumulatorPrivateData { + pub accumulator: Box<dyn Accumulator>, +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec<WrappedArray>, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::<Result<Vec<ArrayRef>>>(); + let values_arrays = rresult_return!(values_arrays); + + rresult!(accum_data.accumulator.update_batch(&values_arrays)) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &FFI_Accumulator, +) -> RResult<RVec<u8>, RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; Review Comment: Refactor these two repeated lines into a private inline(able) function? ########## datafusion/ffi/src/udaf/mod.rs: ########## @@ -0,0 +1,707 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ROption, RResult, RStr, RString, RVec}, + StableAbi, +}; +use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; +use arrow::datatypes::{DataType, Field}; +use arrow::ffi::FFI_ArrowSchema; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + type_coercion::functions::data_types_with_aggregate_udf, + utils::AggregateOrderSensitivity, + Accumulator, GroupsAccumulator, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{AggregateUDF, AggregateUDFImpl, Signature}, +}; +use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; +use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + +use crate::{ + arrow_wrappers::WrappedSchema, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; +use prost::{DecodeError, Message}; + +mod accumulator; +mod accumulator_args; +mod groups_accumulator; + +/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. 
+#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AggregateUDF { + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub name: RString, + + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub aliases: RVec<RString>, + + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub volatility: FFI_Volatility, + + pub return_type: unsafe extern "C" fn( + udaf: &Self, + arg_types: RVec<WrappedSchema>, + ) -> RResult<WrappedSchema, RString>, + + pub is_nullable: bool, + + pub groups_accumulator_supported: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool, + + pub accumulator: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult<FFI_Accumulator, RString>, + + pub create_sliding_accumulator: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult<FFI_Accumulator, RString>, + + #[allow(clippy::type_complexity)] + pub state_fields: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_types: RVec<WrappedSchema>, + return_type: WrappedSchema, + ordering_fields: RVec<RVec<u8>>, + is_distinct: bool, + ) -> RResult<RVec<RVec<u8>>, RString>, + + pub create_groups_accumulator: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult<FFI_GroupsAccumulator, RString>, + + pub with_beneficial_ordering: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, + ) -> RResult<ROption<FFI_AggregateUDF>, RString>, + + pub order_sensitivity: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity, + + /// Performs type coercion. To simplify this interface, all UDFs are treated as having + /// user defined signatures, which will in turn cause coerce_types to be called. 
This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`AggregateUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec<WrappedSchema>, + ) -> RResult<RVec<WrappedSchema>, RString>, + + /// Used to create a clone on the provider of the udaf. This should + /// only need to be called by the receiver of the udaf. + pub clone: unsafe extern "C" fn(udaf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udaf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udaf. + /// A [`ForeignAggregateUDF`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_AggregateUDF {} +unsafe impl Sync for FFI_AggregateUDF {} + +pub struct AggregateUDFPrivateData { + pub udaf: Arc<AggregateUDF>, +} + +impl FFI_AggregateUDF { + unsafe fn inner(&self) -> &Arc<AggregateUDF> { + let private_data = self.private_data as *const AggregateUDFPrivateData; + &(*private_data).udaf + } +} + +unsafe extern "C" fn return_type_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec<WrappedSchema>, +) -> RResult<WrappedSchema, RString> { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_type = udaf + .return_type(&arg_types) + .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from)) + .map(WrappedSchema); + + rresult!(return_type) +} + +unsafe extern "C" fn accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult<FFI_Accumulator, RString> { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_sliding_accumulator_fn_wrapper( + udaf: 
&FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult<FFI_Accumulator, RString> { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_sliding_accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_groups_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult<FFI_GroupsAccumulator, RString> { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_groups_accumulator(accumulator_args.into()) + .map(FFI_GroupsAccumulator::from)) +} + +unsafe extern "C" fn groups_accumulator_supported_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> bool { + let udaf = udaf.inner(); + + ForeignAccumulatorArgs::try_from(args) + .map(|a| udaf.groups_accumulator_supported((&a).into())) + .unwrap_or_else(|e| { + log::warn!("Unable to parse accumulator args. 
{}", e); + false + }) +} + +unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, +) -> RResult<ROption<FFI_AggregateUDF>, RString> { + let udaf = udaf.inner().as_ref().clone(); + + let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering)); + let result = rresult_return!(result + .map(|func| func.with_beneficial_ordering(beneficial_ordering)) + .transpose()) + .flatten() + .map(|func| FFI_AggregateUDF::from(Arc::new(func))); + + RResult::ROk(result.into()) +} + +unsafe extern "C" fn state_fields_fn_wrapper( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_types: RVec<WrappedSchema>, + return_type: WrappedSchema, + ordering_fields: RVec<RVec<u8>>, + is_distinct: bool, +) -> RResult<RVec<RVec<u8>>, RString> { + let udaf = udaf.inner(); + + let input_types = &rresult_return!(rvec_wrapped_to_vec_datatype(&input_types)); + let return_type = &rresult_return!(DataType::try_from(&return_type.0)); + + let ordering_fields = &rresult_return!(ordering_fields + .into_iter() + .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref())) + .collect::<std::result::Result<Vec<_>, DecodeError>>()); + + let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields)); + + let args = StateFieldsArgs { + name: name.as_str(), + input_types, + return_type, + ordering_fields, + is_distinct, + }; + + let state_fields = rresult_return!(udaf.state_fields(args)); + let state_fields = rresult_return!(state_fields + .iter() + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::<Result<Vec<_>>>()) + .into_iter() + .map(|field| field.encode_to_vec().into()) + .collect(); + + RResult::ROk(state_fields) +} + +unsafe extern "C" fn order_sensitivity_fn_wrapper( + udaf: &FFI_AggregateUDF, +) -> FFI_AggregateOrderSensitivity { + udaf.inner().order_sensitivity().into() +} + +unsafe extern "C" fn coerce_types_fn_wrapper( + 
udaf: &FFI_AggregateUDF, + arg_types: RVec<WrappedSchema>, +) -> RResult<RVec<WrappedSchema>, RString> { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_types = rresult_return!(data_types_with_aggregate_udf(&arg_types, udaf)); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + +unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { + let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { + Arc::clone(udaf.inner()).into() +} + +impl Clone for FFI_AggregateUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From<Arc<AggregateUDF>> for FFI_AggregateUDF { + fn from(udaf: Arc<AggregateUDF>) -> Self { + let name = udaf.name().into(); + let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect(); + let is_nullable = udaf.is_nullable(); + let volatility = udaf.signature().volatility.into(); + + let private_data = Box::new(AggregateUDFPrivateData { udaf }); + + Self { + name, + is_nullable, + volatility, + aliases, + return_type: return_type_fn_wrapper, + accumulator: accumulator_fn_wrapper, + create_sliding_accumulator: create_sliding_accumulator_fn_wrapper, + create_groups_accumulator: create_groups_accumulator_fn_wrapper, + groups_accumulator_supported: groups_accumulator_supported_fn_wrapper, + with_beneficial_ordering: with_beneficial_ordering_fn_wrapper, + state_fields: state_fields_fn_wrapper, + order_sensitivity: order_sensitivity_fn_wrapper, + coerce_types: coerce_types_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_AggregateUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// 
library across a FFI boundary. +/// +/// The ForeignAggregateUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_AggregateUDF. +#[derive(Debug)] +pub struct ForeignAggregateUDF { + signature: Signature, + aliases: Vec<String>, + udaf: FFI_AggregateUDF, +} + +unsafe impl Send for ForeignAggregateUDF {} +unsafe impl Sync for ForeignAggregateUDF {} + +impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { + type Error = DataFusionError; + + fn try_from(udaf: &FFI_AggregateUDF) -> Result<Self, Self::Error> { + let signature = Signature::user_defined((&udaf.volatility).into()); + let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Self { + udaf: udaf.clone(), + signature, + aliases, + }) + } +} + +impl AggregateUDFImpl for ForeignAggregateUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + self.udaf.name.as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + + let result = unsafe { (self.udaf.return_type)(&self.udaf, arg_types) }; + + let result = df_result!(result); + + result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) + } + + fn is_nullable(&self) -> bool { + self.udaf.is_nullable + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { + let args = acc_args.try_into()?; + unsafe { + df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| { + Box::new(ForeignAccumulator::from(accum)) as Box<dyn Accumulator> + }) + } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { + unsafe { + let name = RStr::from_str(args.name); + let input_types = vec_datatype_to_rvec_wrapped(args.input_types)?; + let return_type = 
WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); + let ordering_fields = args + .ordering_fields + .iter() + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::<Result<Vec<_>>>()? + .into_iter() + .map(|proto_field| proto_field.encode_to_vec().into()) + .collect(); + + let fields = df_result!((self.udaf.state_fields)( + &self.udaf, + &name, + input_types, + return_type, + ordering_fields, + args.is_distinct + ))?; + let fields = fields + .into_iter() + .map(|field_bytes| { + datafusion_proto_common::Field::decode(field_bytes.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + }) + .collect::<Result<Vec<_>>>()?; + + parse_proto_fields_to_fields(fields.iter()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + } + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + let args = match FFI_AccumulatorArgs::try_from(args) { + Ok(v) => v, + Err(e) => { + log::warn!("Attempting to convert accumulator arguments: {}", e); + return false; + } + }; + + unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) } + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result<Box<dyn GroupsAccumulator>> { + let args = FFI_AccumulatorArgs::try_from(args)?; + + unsafe { + df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map( + |accum| { + Box::new(ForeignGroupsAccumulator::from(accum)) + as Box<dyn GroupsAccumulator> + }, + ) + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result<Box<dyn Accumulator>> { + let args = args.try_into()?; + unsafe { + df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map( + |accum| Box::new(ForeignAccumulator::from(accum)) as Box<dyn Accumulator>, + ) + } + } + + fn with_beneficial_ordering( + self: Arc<Self>, + beneficial_ordering: bool, + ) -> Result<Option<Arc<dyn 
AggregateUDFImpl>>> { + unsafe { + let result = df_result!((self.udaf.with_beneficial_ordering)( + &self.udaf, + beneficial_ordering + ))? + .into_option(); + + let result = result + .map(|func| ForeignAggregateUDF::try_from(&func)) + .transpose()?; + + Ok(result.map(|func| Arc::new(func) as Arc<dyn AggregateUDFImpl>)) + } + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() } + } + + fn simplify(&self) -> Option<AggregateFunctionSimplification> { + None + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = + df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) + } + } +} + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_AggregateOrderSensitivity { + Insensitive, + HardRequirement, + Beneficial, +} + +impl From<FFI_AggregateOrderSensitivity> for AggregateOrderSensitivity { + fn from(value: FFI_AggregateOrderSensitivity) -> Self { + match value { + FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive, + FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity { + fn from(value: AggregateOrderSensitivity) -> Self { + match value { + AggregateOrderSensitivity::Insensitive => Self::Insensitive, + AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use datafusion::{ + common::create_array, + functions_aggregate::sum::Sum, + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::expressions::col, + scalar::ScalarValue, + }; + + use super::*; + + fn 
create_test_foreign_udaf( + original_udaf: impl AggregateUDFImpl + 'static, + ) -> Result<AggregateUDF> { + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + Ok(foreign_udaf.into()) + } + + #[test] + fn test_round_trip_udaf() -> Result<()> { + let original_udaf = Sum::new(); + let original_name = original_udaf.name().to_owned(); + + let foreign_udaf = create_test_foreign_udaf(original_udaf)?; + // let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); Review Comment: Remove dead code. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org