This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f95a7f4dd2c [SPARK-50260][CONNECT] Refactor and optimize Spark Connect execution and session management
4f95a7f4dd2c is described below

commit 4f95a7f4dd2c2d30be549d2a90e52e44046e0726
Author: changgyoopark-db <[email protected]>
AuthorDate: Tue Nov 12 14:00:05 2024 +0900

    [SPARK-50260][CONNECT] Refactor and optimize Spark Connect execution and session management
    
    ### What changes were proposed in this pull request?
    
    Code refactoring.
    - Replace int with a dedicated case class to represent the state of an execution thread.
    
    Minor optimization.
    - Remove unnecessary steps before actually removing expired executions and sessions.
    
    ### Why are the changes needed?
    
    Improve code readability.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48792 from changgyoopark-db/SPARK-50260.
    
    Authored-by: changgyoopark-db <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../connect/execution/ExecuteThreadRunner.scala    | 22 +++++++++------
 .../service/SparkConnectExecutionManager.scala     | 32 +++++++---------------
 .../service/SparkConnectSessionManager.scala       | 16 +++--------
 3 files changed, 27 insertions(+), 43 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 61be2bc4eb99..d27f390a23f9 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.connect.execution
 
-import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.atomic.AtomicReference
 
 import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
@@ -41,7 +41,8 @@ import org.apache.spark.util.Utils
 private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging {
 
   /** The thread state. */
-  private val state: AtomicInteger = new AtomicInteger(ThreadState.notStarted)
+  private val state: AtomicReference[ThreadStateInfo] = new AtomicReference(
+    ThreadState.notStarted)
 
   // The newly created thread will inherit all InheritableThreadLocals used by Spark,
   // e.g. SparkContext.localProperties. If considering implementing a thread-pool,
@@ -349,17 +350,20 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
 private object ThreadState {
 
   /** The thread has not started: transition to interrupted or started. */
-  val notStarted: Int = 0
+  val notStarted: ThreadStateInfo = ThreadStateInfo(0)
 
   /** Execution was interrupted: terminal state. */
-  val interrupted: Int = 1
+  val interrupted: ThreadStateInfo = ThreadStateInfo(1)
 
   /** The thread has started: transition to startedInterrupted or completed. */
-  val started: Int = 2
+  val started: ThreadStateInfo = ThreadStateInfo(2)
 
-  /** The thread has started and execution was interrupted: transition to completed. */
-  val startedInterrupted: Int = 3
+  /** The thread was started and execution has been interrupted: transition to completed. */
+  val startedInterrupted: ThreadStateInfo = ThreadStateInfo(3)
 
-  /** Execution was completed: terminal state. */
-  val completed: Int = 4
+  /** Execution has been completed: terminal state. */
+  val completed: ThreadStateInfo = ThreadStateInfo(4)
 }
+
+/** Represents the state of an execution thread. */
+case class ThreadStateInfo(val transitionState: Int)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index d9eb5438c388..f750ca6db67a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -21,7 +21,6 @@ import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
 import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
 
-import scala.collection.mutable
 import scala.concurrent.duration.FiniteDuration
 import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
@@ -160,19 +159,14 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
   }
 
   private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = {
-    var sessionExecutionHolders = mutable.ArrayBuffer[ExecuteHolder]()
     executions.forEach((_, executeHolder) => {
       if (executeHolder.sessionHolder.key == key) {
-        sessionExecutionHolders += executeHolder
+        val info = executeHolder.getExecuteInfo
+        logInfo(
+          log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.")
+        removeExecuteHolder(executeHolder.key, abandoned = true)
       }
     })
-
-    sessionExecutionHolders.foreach { executeHolder =>
-      val info = executeHolder.getExecuteInfo
-      logInfo(
-        log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.")
-      removeExecuteHolder(executeHolder.key, abandoned = true)
-    }
   }
 
   /** Get info about abandoned execution, if there is one. */
@@ -252,30 +246,24 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
 
   // Visible for testing.
   private[connect] def periodicMaintenance(timeout: Long): Unit = {
+    // Find any detached executions that expired and should be removed.
     logInfo("Started periodic run of SparkConnectExecutionManager maintenance.")
 
-    // Find any detached executions that expired and should be removed.
-    val toRemove = new mutable.ArrayBuffer[ExecuteHolder]()
     val nowMs = System.currentTimeMillis()
-
     executions.forEach((_, executeHolder) => {
       executeHolder.lastAttachedRpcTimeMs match {
         case Some(detached) =>
           if (detached + timeout <= nowMs) {
-            toRemove += executeHolder
+            val info = executeHolder.getExecuteInfo
+            logInfo(
+              log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was abandoned " +
+                log"and expired and will be removed.")
+            removeExecuteHolder(executeHolder.key, abandoned = true)
           }
         case _ => // execution is active
       }
     })
 
-    // .. and remove them.
-    toRemove.foreach { executeHolder =>
-      val info = executeHolder.getExecuteInfo
-      logInfo(
-        log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was abandoned " +
-          log"and expired and will be removed.")
-      removeExecuteHolder(executeHolder.key, abandoned = true)
-    }
     logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.")
   }
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index 4ca3a80bfb98..a306856efa33 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -21,7 +21,6 @@ import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
 import java.util.concurrent.atomic.AtomicReference
 
-import scala.collection.mutable
 import scala.concurrent.duration.FiniteDuration
 import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
@@ -226,9 +225,8 @@ class SparkConnectSessionManager extends Logging {
   private def periodicMaintenance(
       defaultInactiveTimeoutMs: Long,
       ignoreCustomTimeout: Boolean): Unit = {
-    logInfo("Started periodic run of SparkConnectSessionManager maintenance.")
     // Find any sessions that expired and should be removed.
-    val toRemove = new mutable.ArrayBuffer[SessionHolder]()
+    logInfo("Started periodic run of SparkConnectSessionManager maintenance.")
 
     def shouldExpire(info: SessionHolderInfo, nowMs: Long): Boolean = {
       val timeoutMs = if (info.customInactiveTimeoutMs.isDefined && !ignoreCustomTimeout) {
@@ -242,15 +240,8 @@ class SparkConnectSessionManager extends Logging {
 
     val nowMs = System.currentTimeMillis()
     sessionStore.forEach((_, sessionHolder) => {
-      if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) {
-        toRemove += sessionHolder
-      }
-    })
-
-    // .. and remove them.
-    toRemove.foreach { sessionHolder =>
       val info = sessionHolder.getSessionHolderInfo
-      if (shouldExpire(info, System.currentTimeMillis())) {
+      if (shouldExpire(info, nowMs)) {
         logInfo(
           log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
             log"and will be closed.")
@@ -261,7 +252,8 @@ class SparkConnectSessionManager extends Logging {
         case NonFatal(ex) => logWarning("Unexpected exception closing session", ex)
         }
       }
-    }
+    })
+
     logInfo("Finished periodic run of SparkConnectSessionManager maintenance.")
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to