This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 966d8366e feat: Add num_rows and TaskContext to CometUDFBridge.evaluate (#4306)
966d8366e is described below

commit 966d8366e0702d990e25ced4ee7257faa46caaa0
Author: Matt Butrovich <[email protected]>
AuthorDate: Tue May 12 20:12:46 2026 -0400

    feat: Add num_rows and TaskContext to CometUDFBridge.evaluate (#4306)
---
 .../java/org/apache/comet/udf/CometUdfBridge.java  | 55 ++++++++++++++++++----
 .../main/scala/org/apache/comet/udf/CometUDF.scala |  9 +++-
 .../apache/spark/comet/CometTaskContextShim.scala  | 41 ++++++++++++++++
 native/core/src/execution/jni_api.rs               | 21 ++++++++-
 native/core/src/execution/planner.rs               | 26 +++++++---
 native/jni-bridge/src/comet_udf_bridge.rs          |  2 +-
 native/spark-expr/src/jvm_udf/mod.rs               | 24 +++++++++-
 .../scala/org/apache/comet/CometExecIterator.scala |  5 +-
 spark/src/main/scala/org/apache/comet/Native.scala |  5 +-
 9 files changed, 162 insertions(+), 26 deletions(-)

diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
index aed53c57d..5e7681981 100644
--- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
+++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
@@ -27,6 +27,8 @@ import org.apache.arrow.c.Data;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.ValueVector;
+import org.apache.spark.TaskContext;
+import org.apache.spark.comet.CometTaskContextShim;
 
 /**
  * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
@@ -48,13 +50,52 @@ public class CometUdfBridge {
    * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
    * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
    * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
+   * @param numRows row count of the current batch. Mirrors DataFusion's {@code
+   *     ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a
+   *     zero-arg non-deterministic ScalaUDF) ever sees.
+   * @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or
+   *     {@code null} outside a Spark task. Treated as ground truth for the call: installed as the
+   *     thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
+   *     Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
+   *     MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
+   *     left on a worker by a previous task.
    */
   public static void evaluate(
       String udfClassName,
       long[] inputArrayPtrs,
       long[] inputSchemaPtrs,
       long outArrayPtr,
-      long outSchemaPtr) {
+      long outSchemaPtr,
+      int numRows,
+      TaskContext taskContext) {
+    // Save-and-restore rather than only-install-if-null: the propagated context is the ground
+    // truth for this call. Any value already on the thread is either (a) the same object on a
+    // Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
+    TaskContext prior = TaskContext.get();
+    if (taskContext != null) {
+      CometTaskContextShim.set(taskContext);
+    }
+    try {
+      evaluateInternal(
+          udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
+    } finally {
+      if (taskContext != null) {
+        if (prior != null) {
+          CometTaskContextShim.set(prior);
+        } else {
+          CometTaskContextShim.unset();
+        }
+      }
+    }
+  }
+
+  private static void evaluateInternal(
+      String udfClassName,
+      long[] inputArrayPtrs,
+      long[] inputSchemaPtrs,
+      long outArrayPtr,
+      long outSchemaPtr,
+      int numRows) {
     CometUDF udf =
         INSTANCES.computeIfAbsent(
             udfClassName,
@@ -84,23 +125,17 @@ public class CometUdfBridge {
         inputs[i] = Data.importVector(allocator, inArr, inSch, null);
       }
 
-      result = udf.evaluate(inputs);
+      result = udf.evaluate(inputs, numRows);
       if (!(result instanceof FieldVector)) {
         throw new RuntimeException(
             "CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
       }
-      // Result length must match the longest input. Scalar (length-1) inputs
-      // are allowed to be shorter, but a vector input bounds the output.
-      int expectedLen = 0;
-      for (ValueVector v : inputs) {
-        expectedLen = Math.max(expectedLen, v.getValueCount());
-      }
-      if (result.getValueCount() != expectedLen) {
+      if (result.getValueCount() != numRows) {
         throw new RuntimeException(
             "CometUDF.evaluate() returned "
                 + result.getValueCount()
                 + " rows, expected "
-                + expectedLen);
+                + numRows);
       }
       ArrowArray outArr = ArrowArray.wrap(outArrayPtr);
       ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr);
diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
index 29186f0a2..5b6652d90 100644
--- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
+++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
@@ -27,11 +27,16 @@ import org.apache.arrow.vector.ValueVector
  *
  *   - Vector arguments arrive at the row count of the current batch.
  *   - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
- *   - The returned vector's length must match the longest input.
+ *   - The returned vector's length must match `numRows`.
+ *
+ * `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count.
+ * UDFs that always have at least one batch-length input can derive length from the inputs and
+ * ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF)
+ * need `numRows` to know how many rows to produce.
  *
  * Implementations must have a public no-arg constructor and must be stateless: a single instance
  * per class is cached and shared across native worker threads for the lifetime of the JVM.
  */
 trait CometUDF {
-  def evaluate(inputs: Array[ValueVector]): ValueVector
+  def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
 }
diff --git a/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala
new file mode 100644
index 000000000..9218fc5e7
--- /dev/null
+++ b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.comet
+
+import org.apache.spark.TaskContext
+
+/**
+ * Package-private access shim for `TaskContext.setTaskContext` / `TaskContext.unset`.
+ *
+ * Both methods are declared `protected[spark]` on Spark's `TaskContext` companion, so they are
+ * reachable from code inside the `org.apache.spark` package tree but not from `org.apache.comet`.
+ * The Comet JVM UDF bridge needs to set the thread-local `TaskContext` on its caller thread (a
+ * Tokio worker thread with no `TaskContext`) so the user's UDF body and any partition-sensitive
+ * built-ins (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, etc.) see the driving Spark task's
+ * `TaskContext`. This shim lives in `org.apache.spark.comet` so it can call through to the
+ * protected methods, and exposes plain public forwarders the bridge (which lives in
+ * `org.apache.comet.udf`) can use.
+ */
+object CometTaskContextShim {
+
+  def set(taskContext: TaskContext): Unit = TaskContext.setTaskContext(taskContext)
+
+  def unset(): Unit = TaskContext.unset()
+}
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 5d3dbb826..f5b04cc51 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -306,6 +306,13 @@ struct ExecutionContext {
     pub tracing_memory_metric_name: String,
     /// Pre-computed tracing event name for executePlan calls
     pub tracing_event_name: String,
+    /// Spark `TaskContext` captured on the driving Spark task thread at `createPlan` time.
+    /// Threaded into every JVM scalar UDF the planner builds so the JNI bridge can install it
+    /// as the thread-local `TaskContext` for the Tokio worker running the UDF. `None` when no
+    /// driving Spark task is present (unit tests, direct native driver runs). The `Arc` is
+    /// cheap to clone; the underlying `Global<JObject>` releases its JNI global ref on drop
+    /// via `jni`'s `Drop` impl.
+    pub task_context: Option<Arc<Global<JObject<'static>>>>,
 }
 
 /// Accept serialized query plan and return the address of the native query plan.
@@ -332,6 +339,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     task_attempt_id: jlong,
     task_cpus: jlong,
     key_unwrapper_obj: JObject,
+    task_context_obj: JObject,
 ) -> jlong {
     try_unwrap_or_throw(&e, |env| {
         // Deserialize Spark configs
@@ -453,6 +461,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
                 String::new()
             };
 
+            // Capture the driving Spark task's TaskContext as a JNI global reference when
+            // non-null. The `Arc<Global<JObject>>` releases its global ref on drop, so cleanup
+            // is automatic when the ExecutionContext drops.
+            let task_context = if !task_context_obj.is_null() {
+                Some(Arc::new(jni_new_global_ref!(env, task_context_obj)?))
+            } else {
+                None
+            };
+
             let exec_context = Box::new(ExecutionContext {
                 id,
                 task_attempt_id,
@@ -479,6 +496,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
                     "thread_{rust_thread_id}_comet_memory_reserved"
                 ),
                 tracing_event_name,
+                task_context,
             });
 
             Ok(Box::into_raw(exec_context) as i64)
@@ -703,7 +721,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
                 let start = Instant::now();
                 let planner =
                     PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
-                        .with_exec_id(exec_context_id);
+                        .with_exec_id(exec_context_id)
+                        .with_task_context(exec_context.task_context.clone());
                 let (scans, shuffle_scans, root_op) = planner.create_plan(
                     &exec_context.spark_plan,
                     &mut exec_context.input_sources.clone(),
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 478c7a8d9..b00f14002 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -183,6 +183,9 @@ pub struct PhysicalPlanner {
     partition: i32,
     session_ctx: Arc<SessionContext>,
     query_context_registry: Arc<datafusion_comet_spark_expr::QueryContextMap>,
+    /// Captured at `createPlan` time on `ExecutionContext`; see that struct for the
+    /// propagation rationale. `None` when no driving Spark task is available.
+    task_context: Option<Arc<Global<JObject<'static>>>>,
 }
 
 impl Default for PhysicalPlanner {
@@ -198,16 +201,24 @@ impl PhysicalPlanner {
             session_ctx,
             partition,
             query_context_registry: datafusion_comet_spark_expr::create_query_context_map(),
+            task_context: None,
         }
     }
 
-    pub fn with_exec_id(self, exec_context_id: i64) -> Self {
-        Self {
-            exec_context_id,
-            partition: self.partition,
-            session_ctx: Arc::clone(&self.session_ctx),
-            query_context_registry: Arc::clone(&self.query_context_registry),
-        }
+    pub fn with_exec_id(mut self, exec_context_id: i64) -> Self {
+        self.exec_context_id = exec_context_id;
+        self
+    }
+
+    /// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned
+    /// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as
+    /// the thread-local on the Tokio worker driving the UDF.
+    pub fn with_task_context(
+        mut self,
+        task_context: Option<Arc<Global<JObject<'static>>>>,
+    ) -> Self {
+        self.task_context = task_context;
+        self
     }
 
     /// Return session context of this planner.
@@ -735,6 +746,7 @@ impl PhysicalPlanner {
                     args,
                     return_type,
                     udf.return_nullable,
+                    self.task_context.clone(),
                 )))
             }
             expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
diff --git a/native/jni-bridge/src/comet_udf_bridge.rs b/native/jni-bridge/src/comet_udf_bridge.rs
index 89cd8ee51..e531d20cb 100644
--- a/native/jni-bridge/src/comet_udf_bridge.rs
+++ b/native/jni-bridge/src/comet_udf_bridge.rs
@@ -41,7 +41,7 @@ impl<'a> CometUdfBridge<'a> {
             method_evaluate: env.get_static_method_id(
                 JNIString::new(Self::JVM_CLASS),
                 jni::jni_str!("evaluate"),
-                jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"),
+                jni::jni_sig!("(Ljava/lang/String;[J[JJJILorg/apache/spark/TaskContext;)V"),
             )?,
             method_evaluate_ret: ReturnType::Primitive(Primitive::Void),
             class,
diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs
index 668a2b672..4ed25de6e 100644
--- a/native/spark-expr/src/jvm_udf/mod.rs
+++ b/native/spark-expr/src/jvm_udf/mod.rs
@@ -31,7 +31,7 @@ use datafusion::physical_expr::PhysicalExpr;
 
 use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError};
 use datafusion_comet_jni_bridge::JVMClasses;
-use jni::objects::{JObject, JValue};
+use jni::objects::{Global, JObject, JValue};
 
 /// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI.
 /// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`.
@@ -41,6 +41,14 @@ pub struct JvmScalarUdfExpr {
     args: Vec<Arc<dyn PhysicalExpr>>,
     return_type: DataType,
     return_nullable: bool,
+    /// Captured at `createPlan` time and threaded here by the planner. Passed through the
+    /// JNI bridge so `CometUdfBridge.evaluate` can install it as the Tokio worker's
+    /// thread-local `TaskContext`. Without this, partition-sensitive built-ins inside a UDF
+    /// tree (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, user code reading
+    /// `TaskContext.get()`) see `null` and seed / branch incorrectly. `None` when no driving
+    /// Spark task is available; the bridge then leaves whatever `TaskContext.get()` already
+    /// returns in place.
+    task_context: Option<Arc<Global<JObject<'static>>>>,
 }
 
 impl JvmScalarUdfExpr {
@@ -49,12 +57,14 @@ impl JvmScalarUdfExpr {
         args: Vec<Arc<dyn PhysicalExpr>>,
         return_type: DataType,
         return_nullable: bool,
+        task_context: Option<Arc<Global<JObject<'static>>>>,
     ) -> Self {
         Self {
             class_name,
             args,
             return_type,
             return_nullable,
+            task_context,
         }
     }
 }
@@ -186,7 +196,14 @@ impl PhysicalExpr for JvmScalarUdfExpr {
                 .set_region(env, 0, &in_sch_ptrs)
                 .map_err(|e| CometError::JNI { source: e })?;
 
-            // Call CometUdfBridge.evaluate(String, long[], long[], long, long)
+            // Pass a null jobject when no TaskContext was propagated so the bridge's null-guard
+            // leaves the worker thread's current TaskContext.get() in place. The borrow must
+            // outlive `call_static_method_unchecked`.
+            let null_task_context = JObject::null();
+            let task_context_ref: &JObject = match &self.task_context {
+                Some(gref) => gref.as_obj(),
+                None => &null_task_context,
+            };
             let ret = unsafe {
                 env.call_static_method_unchecked(
                     &bridge.class,
@@ -198,6 +215,8 @@ impl PhysicalExpr for JvmScalarUdfExpr {
                         JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(),
                         JValue::Long(out_arr_ptr).as_jni(),
                         JValue::Long(out_sch_ptr).as_jni(),
+                        JValue::Int(batch.num_rows() as i32).as_jni(),
+                        JValue::Object(task_context_ref).as_jni(),
                     ],
                 )
             };
@@ -234,6 +253,7 @@ impl PhysicalExpr for JvmScalarUdfExpr {
             children,
             self.return_type.clone(),
             self.return_nullable,
+            self.task_context.clone(),
         )))
     }
 }
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index a93564811..6140eca55 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -127,7 +127,10 @@ class CometExecIterator(
       memoryConfig.memoryLimitPerTask,
       taskAttemptId,
       taskCPUs,
-      keyUnwrapper)
+      keyUnwrapper,
+      // Propagated to Tokio workers running JVM UDFs so they see this Spark task's
+      // TaskContext. See CometUdfBridge.evaluate.
+      TaskContext.get())
   }
 
   private var nextBatch: Option[ColumnarBatch] = None
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index c003bcd13..3cfa51b6e 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -21,7 +21,7 @@ package org.apache.comet
 
 import java.nio.ByteBuffer
 
-import org.apache.spark.CometTaskMemoryManager
+import org.apache.spark.{CometTaskMemoryManager, TaskContext}
 import org.apache.spark.sql.comet.CometMetricNode
 
 import org.apache.comet.parquet.CometFileKeyUnwrapper
@@ -69,7 +69,8 @@ class Native extends NativeBase {
       memoryLimitPerTask: Long,
       taskAttemptId: Long,
       taskCPUs: Long,
-      keyUnwrapper: CometFileKeyUnwrapper): Long
+      keyUnwrapper: CometFileKeyUnwrapper,
+      taskContext: TaskContext): Long
   // scalastyle:on
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to