This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new c52efa4c [AURON-1272] Support HDFS CallerContext (#1260)
c52efa4c is described below

commit c52efa4cddb79c0c92eec900ece1d65ffbd66007
Author: cxzl25 <[email protected]>
AuthorDate: Tue Sep 9 15:19:36 2025 +0800

    [AURON-1272] Support HDFS CallerContext (#1260)
    
    * support HDFS CallerContext
    
    * initNativeThread
    
    * initNativeThread args
    
    * thread name
---
 native-engine/auron-jni-bridge/src/jni_bridge.rs   |  16 +--
 native-engine/auron/src/rt.rs                      |   5 +-
 .../java/org/apache/spark/sql/auron/JniBridge.java |  12 +-
 .../spark/sql/auron/util/TaskContextHelper.scala   | 135 +++++++++++++++++++++
 4 files changed, 152 insertions(+), 16 deletions(-)

diff --git a/native-engine/auron-jni-bridge/src/jni_bridge.rs b/native-engine/auron-jni-bridge/src/jni_bridge.rs
index 9c1d4e29..cc9ed819 100644
--- a/native-engine/auron-jni-bridge/src/jni_bridge.rs
+++ b/native-engine/auron-jni-bridge/src/jni_bridge.rs
@@ -553,8 +553,6 @@ pub struct JniBridge<'a> {
     pub method_getSparkEnvConfAsString_ret: ReturnType,
     pub method_getResource: JStaticMethodID,
     pub method_getResource_ret: ReturnType,
-    pub method_setTaskContext: JStaticMethodID,
-    pub method_setTaskContext_ret: ReturnType,
     pub method_getTaskContext: JStaticMethodID,
     pub method_getTaskContext_ret: ReturnType,
     pub method_getTaskOnHeapSpillManager: JStaticMethodID,
@@ -573,6 +571,8 @@ pub struct JniBridge<'a> {
     pub method_getTotalMemoryLimited_ret: ReturnType,
     pub method_getDirectWriteSpillToDiskFile: JStaticMethodID,
     pub method_getDirectWriteSpillToDiskFile_ret: ReturnType,
+    pub method_initNativeThread: JStaticMethodID,
+    pub method_initNativeThread_ret: ReturnType,
 }
 impl<'a> JniBridge<'a> {
     pub const SIG_TYPE: &'static str = "org/apache/spark/sql/auron/JniBridge";
@@ -611,12 +611,6 @@ impl<'a> JniBridge<'a> {
                 "()Lorg/apache/spark/TaskContext;",
             )?,
             method_getTaskContext_ret: ReturnType::Object,
-            method_setTaskContext: env.get_static_method_id(
-                class,
-                "setTaskContext",
-                "(Lorg/apache/spark/TaskContext;)V",
-            )?,
-            method_setTaskContext_ret: ReturnType::Primitive(Primitive::Void),
             method_getTaskOnHeapSpillManager: env.get_static_method_id(
                 class,
                 "getTaskOnHeapSpillManager",
@@ -657,6 +651,12 @@ impl<'a> JniBridge<'a> {
                 "()Ljava/lang/String;",
             )?,
             method_getDirectWriteSpillToDiskFile_ret: ReturnType::Object,
+            method_initNativeThread: env.get_static_method_id(
+                class,
+                "initNativeThread",
+                "(Ljava/lang/ClassLoader;Lorg/apache/spark/TaskContext;)V",
+            )?,
+            method_initNativeThread_ret: ReturnType::Primitive(Primitive::Void),
         })
     }
 }
diff --git a/native-engine/auron/src/rt.rs b/native-engine/auron/src/rt.rs
index 55d10d84..4b33b6d8 100644
--- a/native-engine/auron/src/rt.rs
+++ b/native-engine/auron/src/rt.rs
@@ -115,10 +115,7 @@ impl NativeExecutionRuntime {
             .on_thread_start(move || {
                 let classloader = JavaClasses::get().classloader;
                 let _ = jni_call_static!(
-                    JniBridge.setContextClassLoader(classloader) -> ()
-                );
-                let _ = jni_call_static!(
-                    
JniBridge.setTaskContext(spark_task_context_global.as_obj()) -> ()
+                    
JniBridge.initNativeThread(classloader,spark_task_context_global.as_obj()) -> ()
                 );
                 THREAD_STAGE_ID.set(stage_id);
                 THREAD_PARTITION_ID.set(partition_id);
diff --git a/spark-extension/src/main/java/org/apache/spark/sql/auron/JniBridge.java b/spark-extension/src/main/java/org/apache/spark/sql/auron/JniBridge.java
index c36aeafd..3df98c5d 100644
--- a/spark-extension/src/main/java/org/apache/spark/sql/auron/JniBridge.java
+++ b/spark-extension/src/main/java/org/apache/spark/sql/auron/JniBridge.java
@@ -32,6 +32,7 @@ import org.apache.spark.auron.FSDataOutputWrapper;
 import org.apache.spark.auron.FSDataOutputWrapper$;
 import org.apache.spark.sql.auron.memory.OnHeapSpillManager;
 import org.apache.spark.sql.auron.memory.OnHeapSpillManager$;
+import org.apache.spark.sql.auron.util.TaskContextHelper$;
 
 @SuppressWarnings("unused")
 public class JniBridge {
@@ -65,10 +66,6 @@ public class JniBridge {
         return TaskContext$.MODULE$.get();
     }
 
-    public static void setTaskContext(TaskContext tc) {
-        TaskContext$.MODULE$.setTaskContext(tc);
-    }
-
     public static OnHeapSpillManager getTaskOnHeapSpillManager() {
         return OnHeapSpillManager$.MODULE$.current();
     }
@@ -117,4 +114,11 @@ public class JniBridge {
                 ._2
                 .getPath();
     }
+
+    public static void initNativeThread(ClassLoader cl, TaskContext tc) {
+        setContextClassLoader(cl);
+        TaskContext$.MODULE$.setTaskContext(tc);
+        TaskContextHelper$.MODULE$.setNativeThreadName();
+        TaskContextHelper$.MODULE$.setHDFSCallerContext();
+    }
 }
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/util/TaskContextHelper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/util/TaskContextHelper.scala
new file mode 100644
index 00000000..8e5d7353
--- /dev/null
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/util/TaskContextHelper.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron.util
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.APP_CALLER_CONTEXT
+import org.apache.spark.util.Utils
+
+object TaskContextHelper extends Logging {
+
+  private val callerContextSupported: Boolean = {
+    SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
+      try {
+        Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+        Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
+        true
+      } catch {
+        case _: ClassNotFoundException =>
+          false
+        case NonFatal(e) =>
+          logWarning("Fail to load the CallerContext class", e)
+          false
+      }
+    }
+  }
+
+  def setNativeThreadName(): Unit = {
+    val context: TaskContext = TaskContext.get()
+    val thread = Thread.currentThread()
+    val threadName = if (context != null) {
+      s"auron native task ${context.partitionId()}.${context.attemptNumber()} in stage ${context
+        .stageId()}.${context.stageAttemptNumber()} (TID ${context.taskAttemptId()})"
+    } else {
+      "auron native task " + thread.getName
+    }
+    thread.setName(threadName)
+  }
+
+  def setHDFSCallerContext(): Unit = {
+    if (!callerContextSupported) {
+      return
+    }
+    val context: TaskContext = TaskContext.get()
+    if (context != null) {
+      val conf = SparkEnv.get.conf
+      val appId = conf.get("spark.app.id", "")
+      val appAttemptId = conf.get("spark.app.attempt.id", "")
+      // Spark executor cannot get the jobId from TaskContext, so we set a default value -1 here.
+      val jobId = -1
+      new CallerContextHelper(
+        "TASK",
+        conf.get(APP_CALLER_CONTEXT),
+        Option(appId),
+        if (appAttemptId == "") None else Option(appAttemptId),
+        Option(jobId),
+        Option(context.stageId()),
+        Option(context.stageAttemptNumber()),
+        Option(context.taskAttemptId()),
+        Option(context.attemptNumber())).setCurrentContext()
+    }
+  }
+
+  /**
+   * Copied from Apache Spark org.apache.spark.util.CallerContext
+   */
+  private class CallerContextHelper(
+      from: String,
+      upstreamCallerContext: Option[String] = None,
+      appId: Option[String] = None,
+      appAttemptId: Option[String] = None,
+      jobId: Option[Int] = None,
+      stageId: Option[Int] = None,
+      stageAttemptId: Option[Int] = None,
+      taskId: Option[Long] = None,
+      taskAttemptNumber: Option[Int] = None)
+      extends Logging {
+
+    private val context = prepareContext(
+      "SPARK_" +
+        from +
+        appId.map("_" + _).getOrElse("") +
+        appAttemptId.map("_" + _).getOrElse("") +
+        jobId.map("_JId_" + _).getOrElse("") +
+        stageId.map("_SId_" + _).getOrElse("") +
+        stageAttemptId.map("_" + _).getOrElse("") +
+        taskId.map("_TId_" + _).getOrElse("") +
+        taskAttemptNumber.map("_" + _).getOrElse("") +
+        upstreamCallerContext.map("_" + _).getOrElse(""))
+
+    private def prepareContext(context: String): String = {
+      lazy val len = SparkHadoopUtil.get.conf.getInt("hadoop.caller.context.max.size", 128)
+      if (context == null || context.length <= len) {
+        context
+      } else {
+        val finalContext = context.substring(0, len)
+        logWarning(s"Truncated Spark caller context from $context to $finalContext")
+        finalContext
+      }
+    }
+
+    def setCurrentContext(): Unit = {
+      if (callerContextSupported) {
+        try {
+          val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+          val builder: Class[AnyRef] =
+            Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
+          val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
+          val hdfsContext = builder.getMethod("build").invoke(builderInst)
+          callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
+        } catch {
+          case NonFatal(e) =>
+            logWarning("Fail to set Spark caller context", e)
+        }
+      }
+    }
+  }
+}

Reply via email to