This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 966d8366e feat: Add num_rows and TaskContext to CometUDFBridge.evaluate (#4306)
966d8366e is described below
commit 966d8366e0702d990e25ced4ee7257faa46caaa0
Author: Matt Butrovich <[email protected]>
AuthorDate: Tue May 12 20:12:46 2026 -0400
feat: Add num_rows and TaskContext to CometUDFBridge.evaluate (#4306)
---
.../java/org/apache/comet/udf/CometUdfBridge.java | 55 ++++++++++++++++++----
.../main/scala/org/apache/comet/udf/CometUDF.scala | 9 +++-
.../apache/spark/comet/CometTaskContextShim.scala | 41 ++++++++++++++++
native/core/src/execution/jni_api.rs | 21 ++++++++-
native/core/src/execution/planner.rs | 26 +++++++---
native/jni-bridge/src/comet_udf_bridge.rs | 2 +-
native/spark-expr/src/jvm_udf/mod.rs | 24 +++++++++-
.../scala/org/apache/comet/CometExecIterator.scala | 5 +-
spark/src/main/scala/org/apache/comet/Native.scala | 5 +-
9 files changed, 162 insertions(+), 26 deletions(-)
diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
index aed53c57d..5e7681981 100644
--- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
+++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
@@ -27,6 +27,8 @@ import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
+import org.apache.spark.TaskContext;
+import org.apache.spark.comet.CometTaskContextShim;
/**
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
@@ -48,13 +50,52 @@ public class CometUdfBridge {
* @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
* @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
* @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
+ * @param numRows row count of the current batch. Mirrors DataFusion's {@code
+ * ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a
+ * zero-arg non-deterministic ScalaUDF) ever sees.
+ * @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or
+ * {@code null} outside a Spark task. Treated as ground truth for the call: installed as the
+ * thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
+ * Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
+ * MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
+ * left on a worker by a previous task.
*/
public static void evaluate(
String udfClassName,
long[] inputArrayPtrs,
long[] inputSchemaPtrs,
long outArrayPtr,
- long outSchemaPtr) {
+ long outSchemaPtr,
+ int numRows,
+ TaskContext taskContext) {
+ // Save-and-restore rather than only-install-if-null: the propagated context is the ground
+ // truth for this call. Any value already on the thread is either (a) the same object on a
+ // Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
+ TaskContext prior = TaskContext.get();
+ if (taskContext != null) {
+ CometTaskContextShim.set(taskContext);
+ }
+ try {
+ evaluateInternal(
+ udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
+ } finally {
+ if (taskContext != null) {
+ if (prior != null) {
+ CometTaskContextShim.set(prior);
+ } else {
+ CometTaskContextShim.unset();
+ }
+ }
+ }
+ }
+
+ private static void evaluateInternal(
+ String udfClassName,
+ long[] inputArrayPtrs,
+ long[] inputSchemaPtrs,
+ long outArrayPtr,
+ long outSchemaPtr,
+ int numRows) {
CometUDF udf =
INSTANCES.computeIfAbsent(
udfClassName,
@@ -84,23 +125,17 @@ public class CometUdfBridge {
inputs[i] = Data.importVector(allocator, inArr, inSch, null);
}
- result = udf.evaluate(inputs);
+ result = udf.evaluate(inputs, numRows);
if (!(result instanceof FieldVector)) {
throw new RuntimeException(
"CometUDF.evaluate() must return a FieldVector, got: " +
result.getClass().getName());
}
- // Result length must match the longest input. Scalar (length-1) inputs
- // are allowed to be shorter, but a vector input bounds the output.
- int expectedLen = 0;
- for (ValueVector v : inputs) {
- expectedLen = Math.max(expectedLen, v.getValueCount());
- }
- if (result.getValueCount() != expectedLen) {
+ if (result.getValueCount() != numRows) {
throw new RuntimeException(
"CometUDF.evaluate() returned "
+ result.getValueCount()
+ " rows, expected "
- + expectedLen);
+ + numRows);
}
ArrowArray outArr = ArrowArray.wrap(outArrayPtr);
ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr);
diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
index 29186f0a2..5b6652d90 100644
--- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
+++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
@@ -27,11 +27,16 @@ import org.apache.arrow.vector.ValueVector
*
* - Vector arguments arrive at the row count of the current batch.
* - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
- * - The returned vector's length must match the longest input.
+ * - The returned vector's length must match `numRows`.
+ *
+ * `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count.
+ * UDFs that always have at least one batch-length input can derive length from the inputs and
+ * ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF)
+ * need `numRows` to know how many rows to produce.
*
* Implementations must have a public no-arg constructor and must be stateless: a single instance
* per class is cached and shared across native worker threads for the lifetime of the JVM.
*/
trait CometUDF {
- def evaluate(inputs: Array[ValueVector]): ValueVector
+ def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
}
diff --git a/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala
new file mode 100644
index 000000000..9218fc5e7
--- /dev/null
+++ b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.comet
+
+import org.apache.spark.TaskContext
+
+/**
+ * Package-private access shim for `TaskContext.setTaskContext` / `TaskContext.unset`.
+ *
+ * Both methods are declared `protected[spark]` on Spark's `TaskContext` companion, so they are
+ * reachable from code inside the `org.apache.spark` package tree but not from `org.apache.comet`.
+ * The Comet JVM UDF bridge needs to set the thread-local `TaskContext` on its caller thread (a
+ * Tokio worker thread with no `TaskContext`) so the user's UDF body and any partition-sensitive
+ * built-ins (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, etc.) see the driving Spark task's
+ * `TaskContext`. This shim lives in `org.apache.spark.comet` so it can call through to the
+ * protected methods, and exposes plain public forwarders the bridge (which lives in
+ * `org.apache.comet.udf`) can use.
+ */
+object CometTaskContextShim {
+
+ def set(taskContext: TaskContext): Unit = TaskContext.setTaskContext(taskContext)
+
+ def unset(): Unit = TaskContext.unset()
+}
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 5d3dbb826..f5b04cc51 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -306,6 +306,13 @@ struct ExecutionContext {
pub tracing_memory_metric_name: String,
/// Pre-computed tracing event name for executePlan calls
pub tracing_event_name: String,
+ /// Spark `TaskContext` captured on the driving Spark task thread at `createPlan` time.
+ /// Threaded into every JVM scalar UDF the planner builds so the JNI bridge can install it
+ /// as the thread-local `TaskContext` for the Tokio worker running the UDF. `None` when no
+ /// driving Spark task is present (unit tests, direct native driver runs). The `Arc` is
+ /// cheap to clone; the underlying `Global<JObject>` releases its JNI global ref on drop
+ /// via `jni`'s `Drop` impl.
+ pub task_context: Option<Arc<Global<JObject<'static>>>>,
}
/// Accept serialized query plan and return the address of the native query plan.
@@ -332,6 +339,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
task_attempt_id: jlong,
task_cpus: jlong,
key_unwrapper_obj: JObject,
+ task_context_obj: JObject,
) -> jlong {
try_unwrap_or_throw(&e, |env| {
// Deserialize Spark configs
@@ -453,6 +461,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
String::new()
};
+ // Capture the driving Spark task's TaskContext as a JNI global reference when
+ // non-null. The `Arc<Global<JObject>>` releases its global ref on drop, so cleanup
+ // is automatic when the ExecutionContext drops.
+ let task_context = if !task_context_obj.is_null() {
+ Some(Arc::new(jni_new_global_ref!(env, task_context_obj)?))
+ } else {
+ None
+ };
+
let exec_context = Box::new(ExecutionContext {
id,
task_attempt_id,
@@ -479,6 +496,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
"thread_{rust_thread_id}_comet_memory_reserved"
),
tracing_event_name,
+ task_context,
});
Ok(Box::into_raw(exec_context) as i64)
@@ -703,7 +721,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
let start = Instant::now();
let planner =
PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
- .with_exec_id(exec_context_id);
+ .with_exec_id(exec_context_id)
+ .with_task_context(exec_context.task_context.clone());
let (scans, shuffle_scans, root_op) = planner.create_plan(
&exec_context.spark_plan,
&mut exec_context.input_sources.clone(),
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 478c7a8d9..b00f14002 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -183,6 +183,9 @@ pub struct PhysicalPlanner {
partition: i32,
session_ctx: Arc<SessionContext>,
query_context_registry: Arc<datafusion_comet_spark_expr::QueryContextMap>,
+ /// Captured at `createPlan` time on `ExecutionContext`; see that struct for the
+ /// propagation rationale. `None` when no driving Spark task is available.
+ task_context: Option<Arc<Global<JObject<'static>>>>,
}
impl Default for PhysicalPlanner {
@@ -198,16 +201,24 @@ impl PhysicalPlanner {
session_ctx,
partition,
query_context_registry:
datafusion_comet_spark_expr::create_query_context_map(),
+ task_context: None,
}
}
- pub fn with_exec_id(self, exec_context_id: i64) -> Self {
- Self {
- exec_context_id,
- partition: self.partition,
- session_ctx: Arc::clone(&self.session_ctx),
- query_context_registry: Arc::clone(&self.query_context_registry),
- }
+ pub fn with_exec_id(mut self, exec_context_id: i64) -> Self {
+ self.exec_context_id = exec_context_id;
+ self
+ }
+
+ /// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned
+ /// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as
+ /// the thread-local on the Tokio worker driving the UDF.
+ pub fn with_task_context(
+ mut self,
+ task_context: Option<Arc<Global<JObject<'static>>>>,
+ ) -> Self {
+ self.task_context = task_context;
+ self
}
/// Return session context of this planner.
@@ -735,6 +746,7 @@ impl PhysicalPlanner {
args,
return_type,
udf.return_nullable,
+ self.task_context.clone(),
)))
}
expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
diff --git a/native/jni-bridge/src/comet_udf_bridge.rs b/native/jni-bridge/src/comet_udf_bridge.rs
index 89cd8ee51..e531d20cb 100644
--- a/native/jni-bridge/src/comet_udf_bridge.rs
+++ b/native/jni-bridge/src/comet_udf_bridge.rs
@@ -41,7 +41,7 @@ impl<'a> CometUdfBridge<'a> {
method_evaluate: env.get_static_method_id(
JNIString::new(Self::JVM_CLASS),
jni::jni_str!("evaluate"),
- jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"),
+ jni::jni_sig!("(Ljava/lang/String;[J[JJJILorg/apache/spark/TaskContext;)V"),
)?,
method_evaluate_ret: ReturnType::Primitive(Primitive::Void),
class,
diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs
index 668a2b672..4ed25de6e 100644
--- a/native/spark-expr/src/jvm_udf/mod.rs
+++ b/native/spark-expr/src/jvm_udf/mod.rs
@@ -31,7 +31,7 @@ use datafusion::physical_expr::PhysicalExpr;
use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError};
use datafusion_comet_jni_bridge::JVMClasses;
-use jni::objects::{JObject, JValue};
+use jni::objects::{Global, JObject, JValue};
/// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI.
/// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`.
@@ -41,6 +41,14 @@ pub struct JvmScalarUdfExpr {
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
return_nullable: bool,
+ /// Captured at `createPlan` time and threaded here by the planner. Passed through the
+ /// JNI bridge so `CometUdfBridge.evaluate` can install it as the Tokio worker's
+ /// thread-local `TaskContext`. Without this, partition-sensitive built-ins inside a UDF
+ /// tree (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, user code reading
+ /// `TaskContext.get()`) see `null` and seed / branch incorrectly. `None` when no driving
+ /// Spark task is available; the bridge then leaves whatever `TaskContext.get()` already
+ /// returns in place.
+ task_context: Option<Arc<Global<JObject<'static>>>>,
}
impl JvmScalarUdfExpr {
@@ -49,12 +57,14 @@ impl JvmScalarUdfExpr {
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
return_nullable: bool,
+ task_context: Option<Arc<Global<JObject<'static>>>>,
) -> Self {
Self {
class_name,
args,
return_type,
return_nullable,
+ task_context,
}
}
}
@@ -186,7 +196,14 @@ impl PhysicalExpr for JvmScalarUdfExpr {
.set_region(env, 0, &in_sch_ptrs)
.map_err(|e| CometError::JNI { source: e })?;
- // Call CometUdfBridge.evaluate(String, long[], long[], long, long)
+ // Pass a null jobject when no TaskContext was propagated so the bridge's null-guard
+ // leaves the worker thread's current TaskContext.get() in place. The borrow must
+ // outlive `call_static_method_unchecked`.
+ let null_task_context = JObject::null();
+ let task_context_ref: &JObject = match &self.task_context {
+ Some(gref) => gref.as_obj(),
+ None => &null_task_context,
+ };
let ret = unsafe {
env.call_static_method_unchecked(
&bridge.class,
@@ -198,6 +215,8 @@ impl PhysicalExpr for JvmScalarUdfExpr {
JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(),
JValue::Long(out_arr_ptr).as_jni(),
JValue::Long(out_sch_ptr).as_jni(),
+ JValue::Int(batch.num_rows() as i32).as_jni(),
+ JValue::Object(task_context_ref).as_jni(),
],
)
};
@@ -234,6 +253,7 @@ impl PhysicalExpr for JvmScalarUdfExpr {
children,
self.return_type.clone(),
self.return_nullable,
+ self.task_context.clone(),
)))
}
}
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index a93564811..6140eca55 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -127,7 +127,10 @@ class CometExecIterator(
memoryConfig.memoryLimitPerTask,
taskAttemptId,
taskCPUs,
- keyUnwrapper)
+ keyUnwrapper,
+ // Propagated to Tokio workers running JVM UDFs so they see this Spark task's
+ // TaskContext. See CometUdfBridge.evaluate.
+ TaskContext.get())
}
private var nextBatch: Option[ColumnarBatch] = None
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index c003bcd13..3cfa51b6e 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -21,7 +21,7 @@ package org.apache.comet
import java.nio.ByteBuffer
-import org.apache.spark.CometTaskMemoryManager
+import org.apache.spark.{CometTaskMemoryManager, TaskContext}
import org.apache.spark.sql.comet.CometMetricNode
import org.apache.comet.parquet.CometFileKeyUnwrapper
@@ -69,7 +69,8 @@ class Native extends NativeBase {
memoryLimitPerTask: Long,
taskAttemptId: Long,
taskCPUs: Long,
- keyUnwrapper: CometFileKeyUnwrapper): Long
+ keyUnwrapper: CometFileKeyUnwrapper,
+ taskContext: TaskContext): Long
// scalastyle:on
/**
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]