This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3cb8d6e59999 [SPARK-49548][CONNECT] Replace coarse-locking in 
SparkConnectSessionManager with ConcurrentMap
3cb8d6e59999 is described below

commit 3cb8d6e59999e5525374c62f964c57657935311c
Author: Changgyoo Park <[email protected]>
AuthorDate: Thu Sep 12 08:38:10 2024 +0900

    [SPARK-49548][CONNECT] Replace coarse-locking in SparkConnectSessionManager 
with ConcurrentMap
    
    ### What changes were proposed in this pull request?
    
    Replace the coarse-locking in SparkConnectSessionManager with ConcurrentMap 
in order to minimise lock contention when there are many sessions.
    
    ### Why are the changes needed?
    
    It is a spin-off from https://github.com/apache/spark/pull/48034: that PR 
addresses many-execution cases, whereas this one addresses many-session 
situations.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing test cases.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48036 from changgyoopark-db/SPARK-49548.
    
    Authored-by: Changgyoo Park <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../service/SparkConnectSessionManager.scala       | 99 +++++++++++-----------
 1 file changed, 49 insertions(+), 50 deletions(-)

diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index edaaa640bf12..fec01813de6e 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.connect.service
 
 import java.util.UUID
-import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, 
ScheduledExecutorService, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
@@ -42,8 +42,8 @@ class SparkConnectSessionManager extends Logging {
 
   private val sessionsLock = new Object
 
-  @GuardedBy("sessionsLock")
-  private val sessionStore = mutable.HashMap[SessionKey, SessionHolder]()
+  private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
+    new ConcurrentHashMap[SessionKey, SessionHolder]()
 
   private val closedSessionsCache =
     CacheBuilder
@@ -52,6 +52,7 @@ class SparkConnectSessionManager extends Logging {
       .build[SessionKey, SessionHolderInfo]()
 
   /** Executor for the periodic maintenance */
+  @GuardedBy("sessionsLock")
   private var scheduledExecutor: Option[ScheduledExecutorService] = None
 
   private def validateSessionId(
@@ -121,43 +122,39 @@ class SparkConnectSessionManager extends Logging {
   private def getSession(key: SessionKey, default: Option[() => 
SessionHolder]): SessionHolder = {
     schedulePeriodicChecks() // Starts the maintenance thread if it hasn't 
started yet.
 
-    sessionsLock.synchronized {
-      // try to get existing session from store
-      val sessionOpt = sessionStore.get(key)
-      // create using default if missing
-      val session = sessionOpt match {
-        case Some(s) => s
-        case None =>
-          default match {
-            case Some(callable) =>
-              val session = callable()
-              sessionStore.put(key, session)
-              session
-            case None =>
-              null
-          }
-      }
-      // record access time before returning
-      session match {
-        case null =>
-          null
-        case s: SessionHolder =>
-          s.updateAccessTime()
-          s
-      }
+    // Get the existing session from the store or create a new one.
+    val session = default match {
+      case Some(callable) =>
+        sessionStore.computeIfAbsent(key, _ => callable())
+      case None =>
+        sessionStore.get(key)
     }
+
+    // Record the access time before returning the session holder.
+    if (session != null) {
+      session.updateAccessTime()
+    }
+
+    session
   }
 
   // Removes session from sessionStore and returns it.
   private def removeSessionHolder(key: SessionKey): Option[SessionHolder] = {
     var sessionHolder: Option[SessionHolder] = None
-    sessionsLock.synchronized {
-      sessionHolder = sessionStore.remove(key)
-      sessionHolder.foreach { s =>
-        // Put into closedSessionsCache, so that it cannot get accidentally 
recreated
-        // by getOrCreateIsolatedSession.
-        closedSessionsCache.put(s.key, s.getSessionHolderInfo)
-      }
+
+    // The session holder should remain in the session store until it is added 
to the closed session
+    // cache, because of a subtle data race: a new session with the same key 
can be created if the
+    // closed session cache does not contain the key right after the key has 
been removed from the
+    // session store.
+    sessionHolder = Option(sessionStore.get(key))
+
+    sessionHolder.foreach { s =>
+      // Put into closedSessionsCache to prevent the same session from being 
recreated by
+      // getOrCreateIsolatedSession.
+      closedSessionsCache.put(s.key, s.getSessionHolderInfo)
+
+      // Then, remove the session holder from the session store.
+      sessionStore.remove(key)
     }
     sessionHolder
   }
@@ -176,21 +173,24 @@ class SparkConnectSessionManager extends Logging {
     sessionHolder.foreach(shutdownSessionHolder(_))
   }
 
-  private[connect] def shutdown(): Unit = sessionsLock.synchronized {
-    scheduledExecutor.foreach { executor =>
-      ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
+  private[connect] def shutdown(): Unit = {
+    sessionsLock.synchronized {
+      scheduledExecutor.foreach { executor =>
+        ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
+      }
+      scheduledExecutor = None
     }
-    scheduledExecutor = None
+
     // note: this does not cleanly shut down the sessions, but the server is 
shutting down.
     sessionStore.clear()
     closedSessionsCache.invalidateAll()
   }
 
-  def listActiveSessions: Seq[SessionHolderInfo] = sessionsLock.synchronized {
-    sessionStore.values.map(_.getSessionHolderInfo).toSeq
+  def listActiveSessions: Seq[SessionHolderInfo] = {
+    sessionStore.values().asScala.map(_.getSessionHolderInfo).toSeq
   }
 
-  def listClosedSessions: Seq[SessionHolderInfo] = sessionsLock.synchronized {
+  def listClosedSessions: Seq[SessionHolderInfo] = {
     closedSessionsCache.asMap.asScala.values.toSeq
   }
 
@@ -246,18 +246,17 @@ class SparkConnectSessionManager extends Logging {
       timeoutMs != -1 && info.lastAccessTimeMs + timeoutMs <= nowMs
     }
 
-    sessionsLock.synchronized {
-      val nowMs = System.currentTimeMillis()
-      sessionStore.values.foreach { sessionHolder =>
-        if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) {
-          toRemove += sessionHolder
-        }
+    val nowMs = System.currentTimeMillis()
+    sessionStore.forEach((_, sessionHolder) => {
+      if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) {
+        toRemove += sessionHolder
       }
-    }
+    })
+
     // .. and remove them.
     toRemove.foreach { sessionHolder =>
       // This doesn't use closeSession to be able to do the extra last chance 
check under lock.
-      val removedSession = sessionsLock.synchronized {
+      val removedSession = {
         // Last chance - check expiration time and remove under lock if 
expired.
         val info = sessionHolder.getSessionHolderInfo
         if (shouldExpire(info, System.currentTimeMillis())) {
@@ -309,7 +308,7 @@ class SparkConnectSessionManager extends Logging {
   /**
    * Used for testing
    */
-  private[connect] def invalidateAllSessions(): Unit = 
sessionsLock.synchronized {
+  private[connect] def invalidateAllSessions(): Unit = {
     periodicMaintenance(defaultInactiveTimeoutMs = 0L, ignoreCustomTimeout = 
true)
     assert(sessionStore.isEmpty)
     closedSessionsCache.invalidateAll()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to