This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 37e4c2d1883d [SPARK-55646][SQL] Refactored
SQLExecution.withThreadLocalCaptured to separate thread-local capture from
execution
37e4c2d1883d is described below
commit 37e4c2d1883d8ce356cd5dbae555571443c8115a
Author: huanliwang-db <[email protected]>
AuthorDate: Tue Feb 24 12:32:37 2026 +0800
[SPARK-55646][SQL] Refactored SQLExecution.withThreadLocalCaptured to
separate thread-local capture from execution
### What changes were proposed in this pull request?
Previously, callers had to provide an ExecutorService upfront: thread-local
capture and task submission were fused into a single call that immediately
returned a CompletableFuture.
Now, captureThreadLocals(sparkSession) captures the current thread's SQL
context into a standalone SQLExecutionThreadLocalCaptured object. Callers can
then invoke `runWith { body }` on any thread, at any time, using any
concurrency primitive — not just ExecutorService.
`withThreadLocalCaptured` is preserved for backward compatibility and now
delegates to these two primitives.
### Why are the changes needed?
Refactoring to make withThreadLocalCaptured easier to use.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
Existing UTs
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #54434 from huanliwang-db/huanliwang-db/refactor-sqlthread.
Authored-by: huanliwang-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../apache/spark/sql/execution/SQLExecution.scala | 76 +++++++++++++++-------
1 file changed, 51 insertions(+), 25 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 96a0053f97b1..f25e908a9cdb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
-import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkContext,
SparkEnv, SparkException, SparkThrowable, SparkThrowableHelper}
+import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, JobArtifactState,
SparkContext, SparkEnv, SparkException, SparkThrowable, SparkThrowableHelper}
import org.apache.spark.SparkContext.{SPARK_JOB_DESCRIPTION,
SPARK_JOB_INTERRUPT_ON_CANCEL}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{SPARK_DRIVER_PREFIX,
SPARK_EXECUTOR_PREFIX}
@@ -38,6 +38,43 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.SQL_EVENT_TRUNCATE_LENGTH
import org.apache.spark.util.{Utils, UUIDv7Generator}
+/**
+ * Captures SQL-specific thread-local variables so they can be restored on a
different thread.
+ * Use [[SQLExecution.captureThreadLocals]] to create an instance on the
originating thread,
+ * then call [[runWith]] on the target thread to execute a block with these
thread locals applied.
+ */
+case class SQLExecutionThreadLocalCaptured(
+ sparkSession: SparkSession,
+ localProps: java.util.Properties,
+ artifactState: JobArtifactState) {
+
+ /**
+ * Run the given body with the captured thread-local variables applied on
the current thread.
+ * Original thread-local values are saved and restored after the body
completes.
+ */
+ def runWith[T](body: => T): T = {
+ val sc = sparkSession.sparkContext
+ JobArtifactSet.withActiveJobArtifactState(artifactState) {
+ val originalSession = SparkSession.getActiveSession
+ val originalLocalProps = sc.getLocalProperties
+ SparkSession.setActiveSession(sparkSession)
+ val res = SQLExecution.withSessionTagsApplied(sparkSession) {
+ sc.setLocalProperties(localProps)
+ val res = body
+ // reset active session and local props.
+ sc.setLocalProperties(originalLocalProps)
+ res
+ }
+ if (originalSession.nonEmpty) {
+ SparkSession.setActiveSession(originalSession.get)
+ } else {
+ SparkSession.clearActiveSession()
+ }
+ res
+ }
+ }
+}
+
object SQLExecution extends Logging {
val EXECUTION_ID_KEY = "spark.sql.execution.id"
@@ -343,36 +380,25 @@ object SQLExecution extends Logging {
}
}
+ def captureThreadLocals(sparkSession: SparkSession):
SQLExecutionThreadLocalCaptured = {
+ val sc = sparkSession.sparkContext
+ val localProps = Utils.cloneProperties(sc.getLocalProperties)
+ // `getCurrentJobArtifactState` will return a state only in Spark Connect
mode. In non-Connect
+ // mode, we default back to the resources of the current Spark session.
+ val artifactState =
+
JobArtifactSet.getCurrentJobArtifactState.getOrElse(sparkSession.artifactManager.state)
+ SQLExecutionThreadLocalCaptured(sparkSession, localProps, artifactState)
+ }
+
/**
* Wrap passed function to ensure necessary thread-local variables like
* SparkContext local properties are forwarded to execution thread
*/
def withThreadLocalCaptured[T](
sparkSession: SparkSession, exec: ExecutorService) (body: => T):
CompletableFuture[T] = {
- val activeSession = sparkSession
- val sc = sparkSession.sparkContext
- val localProps = Utils.cloneProperties(sc.getLocalProperties)
- // `getCurrentJobArtifactState` will return a stat only in Spark Connect
mode. In non-Connect
- // mode, we default back to the resources of the current Spark session.
- val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse(
- activeSession.artifactManager.state)
- CompletableFuture.supplyAsync(() =>
JobArtifactSet.withActiveJobArtifactState(artifactState) {
- val originalSession = SparkSession.getActiveSession
- val originalLocalProps = sc.getLocalProperties
- SparkSession.setActiveSession(activeSession)
- val res = withSessionTagsApplied(activeSession) {
- sc.setLocalProperties(localProps)
- val res = body
- // reset active session and local props.
- sc.setLocalProperties(originalLocalProps)
- res
- }
- if (originalSession.nonEmpty) {
- SparkSession.setActiveSession(originalSession.get)
- } else {
- SparkSession.clearActiveSession()
- }
- res
+ val threadLocalCaptured = captureThreadLocals(sparkSession)
+ CompletableFuture.supplyAsync(() => {
+ threadLocalCaptured.runWith(body)
}, exec)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]