This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3b34891e5b9c [SPARK-49684][CONNECT] Remove global locks from session and execution managers
3b34891e5b9c is described below

commit 3b34891e5b9c2694b7ffdc265290e25847dc3437
Author: Changgyoo Park <[email protected]>
AuthorDate: Thu Sep 19 09:10:51 2024 +0900

    [SPARK-49684][CONNECT] Remove global locks from session and execution managers
    
    ### What changes were proposed in this pull request?
    
    Eliminate the use of global locks in the session and execution managers.
    The locks in the streaming query manager (SparkConnectStreamingQueryCache)
    are kept, because they cannot be easily removed: the tag and query maps
    seemingly need to be synchronised.
    
    ### Why are the changes needed?
    
    The global locks are removed in order to achieve true scalability.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48131 from changgyoopark-db/SPARK-49684.
    
    Authored-by: Changgyoo Park <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../service/SparkConnectExecutionManager.scala     | 59 +++++++++------------
 .../service/SparkConnectSessionManager.scala       | 60 +++++++++-------------
 .../service/SparkConnectStreamingQueryCache.scala  | 22 ++++----
 3 files changed, 61 insertions(+), 80 deletions(-)

diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index 61b41f932199..d66964b8d34b 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.service
 
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, 
ScheduledExecutorService, TimeUnit}
-import javax.annotation.concurrent.GuardedBy
+import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
 
 import scala.collection.mutable
 import scala.concurrent.duration.FiniteDuration
@@ -66,7 +66,6 @@ private[connect] class SparkConnectExecutionManager() extends 
Logging {
   /** Concurrent hash table containing all the current executions. */
   private val executions: ConcurrentMap[ExecuteKey, ExecuteHolder] =
     new ConcurrentHashMap[ExecuteKey, ExecuteHolder]()
-  private val executionsLock = new Object
 
   /** Graveyard of tombstones of executions that were abandoned and removed. */
   private val abandonedTombstones = CacheBuilder
@@ -74,13 +73,12 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
     
.maximumSize(SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_ABANDONED_TOMBSTONES_SIZE))
     .build[ExecuteKey, ExecuteInfo]()
 
-  /** None if there are no executions. Otherwise, the time when the last 
execution was removed. */
-  @GuardedBy("executionsLock")
-  private var lastExecutionTimeMs: Option[Long] = 
Some(System.currentTimeMillis())
+  /** The time when the last execution was removed. */
+  private var lastExecutionTimeMs: AtomicLong = new 
AtomicLong(System.currentTimeMillis())
 
   /** Executor for the periodic maintenance */
-  @GuardedBy("executionsLock")
-  private var scheduledExecutor: Option[ScheduledExecutorService] = None
+  private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+    new AtomicReference[ScheduledExecutorService]()
 
   /**
    * Create a new ExecuteHolder and register it with this global manager and 
with its session.
@@ -118,11 +116,6 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
 
     sessionHolder.addExecuteHolder(executeHolder)
 
-    executionsLock.synchronized {
-      if (!executions.isEmpty()) {
-        lastExecutionTimeMs = None
-      }
-    }
     logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} 
is created.")
 
     schedulePeriodicChecks() // Starts the maintenance thread if it hasn't 
started.
@@ -151,11 +144,7 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
     executions.remove(key)
     executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId)
 
-    executionsLock.synchronized {
-      if (executions.isEmpty) {
-        lastExecutionTimeMs = Some(System.currentTimeMillis())
-      }
-    }
+    updateLastExecutionTime()
 
     logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.")
 
@@ -197,7 +186,7 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
    */
   def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = {
     if (executions.isEmpty) {
-      Left(lastExecutionTimeMs.get)
+      Left(lastExecutionTimeMs.getAcquire())
     } else {
       Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq)
     }
@@ -212,22 +201,23 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
   }
 
   private[connect] def shutdown(): Unit = {
-    executionsLock.synchronized {
-      scheduledExecutor.foreach { executor =>
-        ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
-      }
-      scheduledExecutor = None
+    val executor = scheduledExecutor.getAndSet(null)
+    if (executor != null) {
+      ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
     }
 
     // note: this does not cleanly shut down the executions, but the server is 
shutting down.
     executions.clear()
     abandonedTombstones.invalidateAll()
 
-    executionsLock.synchronized {
-      if (lastExecutionTimeMs.isEmpty) {
-        lastExecutionTimeMs = Some(System.currentTimeMillis())
-      }
-    }
+    updateLastExecutionTime()
+  }
+
+  /**
+   * Updates the last execution time after the last execution has been removed.
+   */
+  private def updateLastExecutionTime(): Unit = {
+    lastExecutionTimeMs.getAndUpdate(prev => 
prev.max(System.currentTimeMillis()))
   }
 
   /**
@@ -235,16 +225,16 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
    * for executions that have not been closed, but are left with no RPC 
attached to them, and
    * removes them after a timeout.
    */
-  private def schedulePeriodicChecks(): Unit = executionsLock.synchronized {
-    scheduledExecutor match {
-      case Some(_) => // Already running.
-      case None =>
+  private def schedulePeriodicChecks(): Unit = {
+    var executor = scheduledExecutor.getAcquire()
+    if (executor == null) {
+      executor = Executors.newSingleThreadScheduledExecutor()
+      if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) 
{
         val interval = 
SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL)
         logInfo(
           log"Starting thread for cleanup of abandoned executions every " +
             log"${MDC(LogKeys.INTERVAL, interval)} ms")
-        scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
-        scheduledExecutor.get.scheduleAtFixedRate(
+        executor.scheduleAtFixedRate(
           () => {
             try {
               val timeout = 
SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT)
@@ -256,6 +246,7 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
           interval,
           interval,
           TimeUnit.MILLISECONDS)
+      }
     }
   }
 
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index fec01813de6e..4ca3a80bfb98 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.service
 
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, 
ScheduledExecutorService, TimeUnit}
-import javax.annotation.concurrent.GuardedBy
+import java.util.concurrent.atomic.AtomicReference
 
 import scala.collection.mutable
 import scala.concurrent.duration.FiniteDuration
@@ -40,8 +40,6 @@ import org.apache.spark.util.ThreadUtils
  */
 class SparkConnectSessionManager extends Logging {
 
-  private val sessionsLock = new Object
-
   private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
     new ConcurrentHashMap[SessionKey, SessionHolder]()
 
@@ -52,8 +50,8 @@ class SparkConnectSessionManager extends Logging {
       .build[SessionKey, SessionHolderInfo]()
 
   /** Executor for the periodic maintenance */
-  @GuardedBy("sessionsLock")
-  private var scheduledExecutor: Option[ScheduledExecutorService] = None
+  private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+    new AtomicReference[ScheduledExecutorService]()
 
   private def validateSessionId(
       key: SessionKey,
@@ -75,8 +73,6 @@ class SparkConnectSessionManager extends Logging {
     val holder = getSession(
       key,
       Some(() => {
-        // Executed under sessionsState lock in getSession,  to guard against 
concurrent removal
-        // and insertion into closedSessionsCache.
         validateSessionCreate(key)
         val holder = SessionHolder(key.userId, key.sessionId, 
newIsolatedSession())
         holder.initializeSession()
@@ -168,17 +164,14 @@ class SparkConnectSessionManager extends Logging {
 
   def closeSession(key: SessionKey): Unit = {
     val sessionHolder = removeSessionHolder(key)
-    // Rest of the cleanup outside sessionLock - the session cannot be 
accessed anymore by
-    // getOrCreateIsolatedSession.
+    // Rest of the cleanup: the session cannot be accessed anymore by 
getOrCreateIsolatedSession.
     sessionHolder.foreach(shutdownSessionHolder(_))
   }
 
   private[connect] def shutdown(): Unit = {
-    sessionsLock.synchronized {
-      scheduledExecutor.foreach { executor =>
-        ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
-      }
-      scheduledExecutor = None
+    val executor = scheduledExecutor.getAndSet(null)
+    if (executor != null) {
+      ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
     }
 
     // note: this does not cleanly shut down the sessions, but the server is 
shutting down.
@@ -199,16 +192,16 @@ class SparkConnectSessionManager extends Logging {
    *
    * The checks are looking to remove sessions that expired.
    */
-  private def schedulePeriodicChecks(): Unit = sessionsLock.synchronized {
-    scheduledExecutor match {
-      case Some(_) => // Already running.
-      case None =>
+  private def schedulePeriodicChecks(): Unit = {
+    var executor = scheduledExecutor.getAcquire()
+    if (executor == null) {
+      executor = Executors.newSingleThreadScheduledExecutor()
+      if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) 
{
         val interval = 
SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL)
         logInfo(
           log"Starting thread for cleanup of expired sessions every " +
             log"${MDC(INTERVAL, interval)} ms")
-        scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
-        scheduledExecutor.get.scheduleAtFixedRate(
+        executor.scheduleAtFixedRate(
           () => {
             try {
               val defaultInactiveTimeoutMs =
@@ -221,6 +214,7 @@ class SparkConnectSessionManager extends Logging {
           interval,
           interval,
           TimeUnit.MILLISECONDS)
+      }
     }
   }
 
@@ -255,24 +249,18 @@ class SparkConnectSessionManager extends Logging {
 
     // .. and remove them.
     toRemove.foreach { sessionHolder =>
-      // This doesn't use closeSession to be able to do the extra last chance 
check under lock.
-      val removedSession = {
-        // Last chance - check expiration time and remove under lock if 
expired.
-        val info = sessionHolder.getSessionHolderInfo
-        if (shouldExpire(info, System.currentTimeMillis())) {
-          logInfo(
-            log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
-              log"and will be closed.")
-          removeSessionHolder(info.key)
-        } else {
-          None
+      val info = sessionHolder.getSessionHolderInfo
+      if (shouldExpire(info, System.currentTimeMillis())) {
+        logInfo(
+          log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
+            log"and will be closed.")
+        removeSessionHolder(info.key)
+        try {
+          shutdownSessionHolder(sessionHolder)
+        } catch {
+          case NonFatal(ex) => logWarning("Unexpected exception closing 
session", ex)
         }
       }
-      // do shutdown and cleanup outside of lock.
-      try removedSession.foreach(shutdownSessionHolder(_))
-      catch {
-        case NonFatal(ex) => logWarning("Unexpected exception closing 
session", ex)
-      }
     }
     logInfo("Finished periodic run of SparkConnectSessionManager maintenance.")
   }
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
index 03719ddd8741..8241672d5107 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service
 import java.util.concurrent.Executors
 import java.util.concurrent.ScheduledExecutorService
 import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
@@ -185,10 +186,10 @@ private[connect] class SparkConnectStreamingQueryCache(
 
   // Visible for testing.
   private[service] def shutdown(): Unit = queryCacheLock.synchronized {
-    scheduledExecutor.foreach { executor =>
+    val executor = scheduledExecutor.getAndSet(null)
+    if (executor != null) {
       ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
     }
-    scheduledExecutor = None
   }
 
   @GuardedBy("queryCacheLock")
@@ -199,19 +200,19 @@ private[connect] class SparkConnectStreamingQueryCache(
   private val taggedQueries = new mutable.HashMap[String, 
mutable.ArrayBuffer[QueryCacheKey]]
   private val taggedQueriesLock = new Object
 
-  @GuardedBy("queryCacheLock")
-  private var scheduledExecutor: Option[ScheduledExecutorService] = None
+  private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+    new AtomicReference[ScheduledExecutorService]()
 
   /** Schedules periodic checks if it is not already scheduled */
-  private def schedulePeriodicChecks(): Unit = queryCacheLock.synchronized {
-    scheduledExecutor match {
-      case Some(_) => // Already running.
-      case None =>
+  private def schedulePeriodicChecks(): Unit = {
+    var executor = scheduledExecutor.getAcquire()
+    if (executor == null) {
+      executor = Executors.newSingleThreadScheduledExecutor()
+      if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) 
{
         logInfo(
           log"Starting thread for polling streaming sessions " +
             log"every ${MDC(DURATION, sessionPollingPeriod.toMillis)}")
-        scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
-        scheduledExecutor.get.scheduleAtFixedRate(
+        executor.scheduleAtFixedRate(
           () => {
             try periodicMaintenance()
             catch {
@@ -221,6 +222,7 @@ private[connect] class SparkConnectStreamingQueryCache(
           sessionPollingPeriod.toMillis,
           sessionPollingPeriod.toMillis,
           TimeUnit.MILLISECONDS)
+      }
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to