pgwhalen commented on code in PR #46:
URL: https://github.com/apache/datafusion-java/pull/46#discussion_r3255090886
##########
core/src/main/java/org/apache/datafusion/SessionContext.java:
##########
@@ -204,6 +209,41 @@ public DataFrame readParquet(String path, ParquetReadOptions options) {
return new DataFrame(dfHandle);
}
+ /**
+ * Register a Java-implemented scalar UDF. After registration, the function can be invoked by SQL
+ * via its {@code name} or referenced in DataFusion plans deserialised with {@link #fromProto}.
+ *
+ * <p>Argument and return types are declared at registration time. The UDF is registered with an
+ * exact signature: the runtime will reject calls whose argument types do not match {@code
+ * argTypes} exactly.
+ *
+ * @throws RuntimeException if registration fails (e.g., name already registered with an
+ * incompatible signature, schema serialisation failure).
+ */
+ public void registerUdf(
+ String name,
+ ScalarUdf udf,
+ ArrowType returnType,
+ List<ArrowType> argTypes,
+ Volatility volatility) {
Review Comment:
The [equivalent Rust function](https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.register_udf)
takes a `ScalarUDF` struct that exposes many of these arguments, rather than
taking them as top-level parameters. Is it worth emulating that structure here?
This is obviously a much broader question; it could be asked about nearly
every public API we add. I waffled on it myself and ended up deciding to
follow the Rust API as closely as possible (my ScalarUDF is
[here](https://github.com/pgwhalen/datafusion-java/blob/datafusion-ffi-java/datafusion-ffi-java/src/main/java/org/apache/arrow/datafusion/logical_expr/ScalarUDF.java)
- not saying we should emulate that; just for reference). My justification
was that it would be easier to stay in line with upstream Rust as it
evolves, and it would be clearer what Java supports.
Whatever we decide, I think we ought to document the strategy going forward,
especially for the benefit of coding agents. I attempted this in my own
bindings with Claude rules for defining the [Public
API](https://github.com/pgwhalen/datafusion-java/blob/datafusion-ffi-java/.claude/rules/public-api.md)
as well as for [documenting
it](https://github.com/pgwhalen/datafusion-java/blob/datafusion-ffi-java/.claude/rules/documentation.md),
going so far as to ensure that Javadocs provide a working link to the Rust
equivalent. I'd be happy to write a version of these rules for the new bindings
(it wouldn't make sense to reuse the ones I've linked as-is).
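To make the structural question concrete, here is a minimal sketch of what bundling the registration metadata into one value could look like. Everything in it is an assumption for illustration: `ScalarUdfSpec` and its builder are hypothetical names, and plain `String` stands in for `ArrowType` so the snippet compiles without Arrow on the classpath. The Javadoc link simply reuses the Rust URL quoted above, in the spirit of linking each Java API to its Rust equivalent.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
 * Hypothetical sketch (not part of this PR): one value object carrying everything
 * that registerUdf currently takes as separate top-level parameters.
 *
 * @see <a href="https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.register_udf">Rust
 *     SessionContext::register_udf</a>
 */
final class ScalarUdfSpec {
  /** The same three levels that native/src/udf.rs decodes from a byte. */
  enum Volatility { IMMUTABLE, STABLE, VOLATILE }

  private final String name;
  private final String returnType;     // stand-in for ArrowType
  private final List<String> argTypes; // stand-in for List<ArrowType>
  private final Volatility volatility;

  private ScalarUdfSpec(Builder b) {
    // Fail fast on incomplete specs instead of deferring the error to registration.
    this.name = Objects.requireNonNull(b.name, "name");
    this.returnType = Objects.requireNonNull(b.returnType, "returnType");
    this.argTypes = Collections.unmodifiableList(new ArrayList<>(b.argTypes));
    this.volatility = Objects.requireNonNull(b.volatility, "volatility");
  }

  static Builder named(String name) { return new Builder(name); }

  String name() { return name; }
  String returnType() { return returnType; }
  List<String> argTypes() { return argTypes; }
  Volatility volatility() { return volatility; }

  static final class Builder {
    private final String name;
    private String returnType;
    private final List<String> argTypes = new ArrayList<>();
    // Defaulting to VOLATILE is an assumption: it is the most conservative choice.
    private Volatility volatility = Volatility.VOLATILE;

    private Builder(String name) { this.name = name; }

    Builder returns(String type) { this.returnType = type; return this; }

    Builder arg(String type) {
      this.argTypes.add(Objects.requireNonNull(type, "type"));
      return this;
    }

    Builder volatility(Volatility v) { this.volatility = v; return this; }

    ScalarUdfSpec build() { return new ScalarUdfSpec(this); }
  }
}
```

With something along these lines, a registration overload could accept a spec plus the `ScalarUdf` implementation, e.g. `ScalarUdfSpec.named("add_one").arg("Int32").returns("Int32").build()`, and new upstream fields could be added to the builder without changing the method signature. Whether that is worth the extra type is exactly the open question.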
##########
docs/source/user-guide/scalar-udf.md:
##########
@@ -0,0 +1,97 @@
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+# Scalar UDFs
+
+A scalar UDF is a Java-implemented SQL function that operates on one row at a
+time, expressed in vectorised form: each invocation receives a batch of input
+columns and returns a single output column of the same length.
+
+## Implement
+
+Implement the `ScalarUdf` interface:
+
+```java
+import java.util.List;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.datafusion.ScalarUdf;
+
+public final class AddOne implements ScalarUdf {
+ @Override
+ public FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args) {
+ IntVector in = (IntVector) args.get(0);
+ IntVector out = new IntVector("add_one", allocator);
+ out.allocateNew(in.getValueCount());
+ for (int i = 0; i < in.getValueCount(); i++) {
+ if (in.isNull(i)) out.setNull(i);
+ else out.set(i, in.get(i) + 1);
+ }
+ out.setValueCount(in.getValueCount());
+ return out;
+ }
+}
+```
+
+Allocate any new vectors — including the result — from the supplied
+`BufferAllocator`. The input vectors are read-only views; do not close them.
+Ownership of the returned vector transfers to the framework on return.
Review Comment:
Great API, great docs! Similar to but better than what I did.
##########
native/src/udf.rs:
##########
@@ -0,0 +1,293 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Java-backed scalar UDF support.
+
+use std::any::Any;
+use std::fmt;
+use std::sync::Arc;
+
+use datafusion::arrow::array::{make_array, Array, ArrayRef, StructArray};
+use datafusion::arrow::datatypes::{DataType, Field, Fields};
+use datafusion::arrow::ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema};
+use datafusion::error::DataFusionError;
+use datafusion::logical_expr::{
+ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use jni::objects::{GlobalRef, JStaticMethodID, JThrowable};
+use jni::signature::{Primitive, ReturnType};
+use jni::sys::{jlong, jvalue};
+use jni::JNIEnv;
+
+pub(crate) struct JavaScalarUdf {
+ pub(crate) name: String,
+ pub(crate) signature: Signature,
+ pub(crate) return_type: DataType,
+ /// Global ref to the user's `org.apache.datafusion.ScalarUdf` instance.
+ pub(crate) udf_global_ref: GlobalRef,
+ /// Global ref to the `org.apache.datafusion.internal.JniBridge` class.
+ pub(crate) bridge_class: GlobalRef,
+ /// Method ID for `JniBridge.invokeScalarUdf`.
+ pub(crate) invoke_method: JStaticMethodID,
+}
+
+// SAFETY: JStaticMethodID is a JNI handle that's safe to share because the
+// class it points to is held alive by `bridge_class`. We never mutate
+// `invoke_method` after construction; DataFusion requires `Send + Sync` on
+// `ScalarUDFImpl`.
+unsafe impl Send for JavaScalarUdf {}
+unsafe impl Sync for JavaScalarUdf {}
+
+impl fmt::Debug for JavaScalarUdf {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("JavaScalarUdf")
+ .field("name", &self.name)
+ .field("return_type", &self.return_type)
+ .finish()
+ }
+}
+
+impl PartialEq for JavaScalarUdf {
+ fn eq(&self, other: &Self) -> bool {
+ // Two Java UDFs are equal iff they wrap the same registered name.
+ self.name == other.name
+ }
+}
+
+impl Eq for JavaScalarUdf {}
+
+impl std::hash::Hash for JavaScalarUdf {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.name.hash(state);
+ }
+}
+
+impl ScalarUDFImpl for JavaScalarUdf {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
+ Ok(self.return_type.clone())
+ }
+
+ fn invoke_with_args(
+ &self,
+ args: ScalarFunctionArgs,
+ ) -> datafusion::error::Result<ColumnarValue> {
+ let number_rows = args.number_rows;
+
+ // 1. Materialise scalars to arrays so all columns are length-N.
+ let arrays: Vec<ArrayRef> = args
+ .args
+ .iter()
+ .map(|cv| cv.clone().into_array(number_rows))
+ .collect::<datafusion::error::Result<Vec<_>>>()?;
+
+ // 2. Build a single struct array carrying all arg columns. Field names/types come
+ // from the signature's Exact type list (matches what the Java caller declared).
+ let signature_fields: Vec<Arc<Field>> = match &self.signature.type_signature {
+ TypeSignature::Exact(types) => types
+ .iter()
+ .enumerate()
+ .map(|(i, ty)| Arc::new(Field::new(format!("arg{}", i), ty.clone(), true)))
+ .collect(),
+ _ => {
+ return Err(DataFusionError::Internal(
+ "JavaScalarUdf signature is not Exact; only Signature::exact is supported"
+ .to_string(),
+ ))
+ }
+ };
+
+ let fields = Fields::from(
+ signature_fields
+ .iter()
+ .map(|f| f.as_ref().clone())
+ .collect::<Vec<Field>>(),
+ );
+ let struct_array = StructArray::try_new_with_length(fields, arrays, None, number_rows)
+ .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+ let args_data = struct_array.into_data();
+ let (args_ffi_array, args_ffi_schema) =
+ to_ffi(&args_data).map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+
+ // 3. Pre-allocate empty FFI structs for the result.
+ let result_ffi_array = FFI_ArrowArray::empty();
+ let result_ffi_schema = FFI_ArrowSchema::empty();
+
+ // 4. Box for stable addresses across the JNI call.
+ let mut args_array_box = Box::new(args_ffi_array);
+ let mut args_schema_box = Box::new(args_ffi_schema);
+ let mut result_array_box = Box::new(result_ffi_array);
+ let mut result_schema_box = Box::new(result_ffi_schema);
+
+ let args_array_addr = args_array_box.as_mut() as *mut _ as jlong;
+ let args_schema_addr = args_schema_box.as_mut() as *mut _ as jlong;
+ let result_array_addr = result_array_box.as_mut() as *mut _ as jlong;
+ let result_schema_addr = result_schema_box.as_mut() as *mut _ as jlong;
+
+ // 5. Attach JNI to current thread.
+ let mut env = crate::jvm()
+ .attach_current_thread()
+ .map_err(|e| DataFusionError::Execution(format!("JNI attach failed: {}", e)))?;
+
+ // 6. Call JniBridge.invokeScalarUdf(udf, args*, result*, expectedRowCount).
+ //
+ // Build the jvalue argument array for call_static_method_unchecked.
+ // SAFETY: we build the args inline and pass them immediately; the JObject
+ // pointed to by udf_global_ref is alive for the duration of this call.
+ let expected_rows = i32::try_from(number_rows).map_err(|_| {
+ DataFusionError::Execution(format!(
+ "batch row count {} exceeds i32::MAX; UDFs cannot handle batches larger than 2^31 - 1 rows",
+ number_rows
+ ))
+ })?;
+
+ let udf_jobject = self.udf_global_ref.as_obj();
+ // SAFETY: udf_jobject is derived from a GlobalRef alive for the duration of this
+ // function. The raw pointer is only read by the JNI call below, which happens
+ // before any code that could drop udf_global_ref.
+ let call_args: [jvalue; 6] = [
+ // ScalarUdf instance
+ jvalue {
+ l: udf_jobject.as_raw(),
+ },
+ // argsArrayAddr
+ jvalue { j: args_array_addr },
+ // argsSchemaAddr
+ jvalue {
+ j: args_schema_addr,
+ },
+ // resultArrayAddr
+ jvalue {
+ j: result_array_addr,
+ },
+ // resultSchemaAddr
+ jvalue {
+ j: result_schema_addr,
+ },
+ // expectedRowCount
+ jvalue { i: expected_rows },
+ ];
+
+ let call_result = unsafe {
+ env.call_static_method_unchecked(
+ &self.bridge_class,
+ self.invoke_method,
+ ReturnType::Primitive(Primitive::Void),
+ &call_args,
+ )
+ };
+
+ // 7. If Java threw, translate to DataFusionError. Always check exception_check first.
+ if env.exception_check().unwrap_or(false) {
+ let throwable = env.exception_occurred().map_err(|e| {
+ DataFusionError::Execution(format!("exception_occurred failed: {}", e))
+ })?;
+ env.exception_clear().ok();
+ let message = jthrowable_to_string(&mut env, &throwable, &self.name);
+ return Err(DataFusionError::Execution(message));
+ }
+ call_result.map_err(|e| DataFusionError::Execution(format!("JNI call failed: {}", e)))?;
+
+ // 8. Import result. from_ffi consumes the FFI_ArrowArray.
+ let result_array = *result_array_box;
+ let result_schema = *result_schema_box;
+ // SAFETY: Java's `Data.exportVector` populated `result_array_box` and
+ // `result_schema_box` in place via the C Data Interface, and the
+ // exception check above guarantees the call succeeded without
+ // throwing — so the FFI structs are fully initialized.
+ let result_data = unsafe { from_ffi(result_array, &result_schema) }
+ .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+
+ // 9. Validate type.
+ if result_data.data_type() != &self.return_type {
+ return Err(DataFusionError::Execution(format!(
+ "Java UDF '{}' returned vector of type {:?}; declared return type was {:?}",
+ self.name,
+ result_data.data_type(),
+ self.return_type
+ )));
+ }
+
+ let array: ArrayRef = make_array(result_data);
+ Ok(ColumnarValue::Array(array))
+ }
+}
+
+pub(crate) fn volatility_from_byte(byte: u8) -> datafusion::error::Result<Volatility> {
+ match byte {
+ 0 => Ok(Volatility::Immutable),
+ 1 => Ok(Volatility::Stable),
+ 2 => Ok(Volatility::Volatile),
+ other => Err(DataFusionError::Execution(format!(
+ "unknown volatility byte: {}",
+ other
+ ))),
+ }
+}
+
+/// Best-effort: extract class name and getMessage() from a Java throwable.
Review Comment:
I think it would be nice to propagate the Java stack trace (in full or in part),
at least optionally. In my bindings I made it a [session-configurable
option](https://github.com/pgwhalen/datafusion-java/blob/1c20733aa8b008315af6092a912b58ef3df6a482/datafusion-ffi-java/src/main/java/org/apache/arrow/datafusion/config/ConfigOptions.java#L50-L54),
though I could imagine many other strategies. The implementation would look
quite different in the bridge.
This could make sense as follow-up work covering all upcalls, rather than
something in the scope of this PR.
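As a rough illustration of "in full or in part", the Java side could render a throwable into a single string with a configurable frame limit and hand that string across the bridge. This is only a sketch under assumptions: `ThrowableFormatter`, `describe`, and the `maxFrames` parameter are hypothetical names that appear in neither this PR nor the linked bindings, and the sketch ignores causes and suppressed exceptions.

```java
/** Hypothetical helper: renders a throwable with at most maxFrames stack frames. */
final class ThrowableFormatter {
  private ThrowableFormatter() {}

  /**
   * Returns "ClassName: message", followed by up to maxFrames "at ..." lines.
   * A maxFrames of zero or less yields the header only, which matches the
   * class-name-plus-message behaviour described for jthrowable_to_string.
   */
  static String describe(Throwable t, int maxFrames) {
    StringBuilder sb = new StringBuilder(t.getClass().getName());
    if (t.getMessage() != null) {
      sb.append(": ").append(t.getMessage());
    }
    if (maxFrames <= 0) {
      return sb.toString();
    }
    StackTraceElement[] frames = t.getStackTrace();
    int shown = Math.min(maxFrames, frames.length);
    for (int i = 0; i < shown; i++) {
      sb.append("\n\tat ").append(frames[i]);
    }
    if (shown < frames.length) {
      // Signal truncation the same way the JVM does for nested causes.
      sb.append("\n\t... ").append(frames.length - shown).append(" more");
    }
    return sb.toString();
  }
}
```

A session-level option could then choose the value passed as `maxFrames`; where that option lives and how the string reaches `DataFusionError` is left open here.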
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]