This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b466f32077e3 [SPARK-49544][CONNECT] Replace coarse-locking in 
SparkConnectExecutionManager with ConcurrentMap
b466f32077e3 is described below

commit b466f32077e3f241cb8dfcd926098e4594635ace
Author: Changgyoo Park <[email protected]>
AuthorDate: Thu Sep 12 08:38:50 2024 +0900

    [SPARK-49544][CONNECT] Replace coarse-locking in 
SparkConnectExecutionManager with ConcurrentMap
    
    ### What changes were proposed in this pull request?
    
    Replace the coarse-locking mechanism implemented in 
SparkConnectExecutionManager with a ConcurrentMap in order to reduce lock 
contention.
    
    ### Why are the changes needed?
    
    When there are too many threads, e.g., ~10K threads on a 4-core node, OS 
scheduling may cause priority inversion that leads to serious performance 
problems, e.g., a 1000s delay when reattaching to an execute holder.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing test cases.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48034 from changgyoopark-db/SPARK-49544.
    
    Authored-by: Changgyoo Park <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../spark/sql/connect/service/ExecuteHolder.scala  |  28 +---
 .../service/SparkConnectExecutionManager.scala     | 185 +++++++++++++--------
 .../service/ExecuteEventsManagerSuite.scala        |   3 +-
 3 files changed, 123 insertions(+), 93 deletions(-)

diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index ec7ebbe92d72..dc349c3e3325 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -17,12 +17,10 @@
 
 package org.apache.spark.sql.connect.service
 
-import java.util.UUID
-
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.SparkEnv
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Observation
@@ -35,30 +33,19 @@ import org.apache.spark.util.SystemClock
  * Object used to hold the Spark Connect execution state.
  */
 private[connect] class ExecuteHolder(
+    val executeKey: ExecuteKey,
     val request: proto.ExecutePlanRequest,
     val sessionHolder: SessionHolder)
     extends Logging {
 
   val session = sessionHolder.session
 
-  val operationId = if (request.hasOperationId) {
-    try {
-      UUID.fromString(request.getOperationId).toString
-    } catch {
-      case _: IllegalArgumentException =>
-        throw new SparkSQLException(
-          errorClass = "INVALID_HANDLE.FORMAT",
-          messageParameters = Map("handle" -> request.getOperationId))
-    }
-  } else {
-    UUID.randomUUID().toString
-  }
-
   /**
    * Tag that is set for this execution on SparkContext, via 
SparkContext.addJobTag. Used
    * (internally) for cancellation of the Spark Jobs ran by this execution.
    */
-  val jobTag = ExecuteJobTag(sessionHolder.userId, sessionHolder.sessionId, 
operationId)
+  val jobTag =
+    ExecuteJobTag(sessionHolder.userId, sessionHolder.sessionId, 
executeKey.operationId)
 
   /**
    * Tags set by Spark Connect client users via SparkSession.addTag. Used to 
identify and group
@@ -278,7 +265,7 @@ private[connect] class ExecuteHolder(
       request = request,
       userId = sessionHolder.userId,
       sessionId = sessionHolder.sessionId,
-      operationId = operationId,
+      operationId = executeKey.operationId,
       jobTag = jobTag,
       sparkSessionTags = sparkSessionTags,
       reattachable = reattachable,
@@ -289,7 +276,10 @@ private[connect] class ExecuteHolder(
   }
 
   /** Get key used by SparkConnectExecutionManager global tracker. */
-  def key: ExecuteKey = ExecuteKey(sessionHolder.userId, 
sessionHolder.sessionId, operationId)
+  def key: ExecuteKey = executeKey
+
+  /** Get the operation ID. */
+  def operationId: String = key.operationId
 }
 
 /** Used to identify ExecuteHolder jobTag among SparkContext.SPARK_JOB_TAGS. */
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index 6681a5f509c6..61b41f932199 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.connect.service
 
-import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
+import java.util.UUID
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, 
ScheduledExecutorService, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
@@ -36,6 +37,24 @@ import org.apache.spark.util.ThreadUtils
 // Unique key identifying execution by combination of user, session and 
operation id
 case class ExecuteKey(userId: String, sessionId: String, operationId: String)
 
+object ExecuteKey {
+  def apply(request: proto.ExecutePlanRequest, sessionHolder: SessionHolder): 
ExecuteKey = {
+    val operationId = if (request.hasOperationId) {
+      try {
+        UUID.fromString(request.getOperationId).toString
+      } catch {
+        case _: IllegalArgumentException =>
+          throw new SparkSQLException(
+            errorClass = "INVALID_HANDLE.FORMAT",
+            messageParameters = Map("handle" -> request.getOperationId))
+      }
+    } else {
+      UUID.randomUUID().toString
+    }
+    ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, operationId)
+  }
+}
+
 /**
  * Global tracker of all ExecuteHolder executions.
  *
@@ -44,10 +63,9 @@ case class ExecuteKey(userId: String, sessionId: String, 
operationId: String)
  */
 private[connect] class SparkConnectExecutionManager() extends Logging {
 
-  /** Hash table containing all current executions. Guarded by executionsLock. 
*/
-  @GuardedBy("executionsLock")
-  private val executions: mutable.HashMap[ExecuteKey, ExecuteHolder] =
-    new mutable.HashMap[ExecuteKey, ExecuteHolder]()
+  /** Concurrent hash table containing all the current executions. */
+  private val executions: ConcurrentMap[ExecuteKey, ExecuteHolder] =
+    new ConcurrentHashMap[ExecuteKey, ExecuteHolder]()
   private val executionsLock = new Object
 
   /** Graveyard of tombstones of executions that were abandoned and removed. */
@@ -61,6 +79,7 @@ private[connect] class SparkConnectExecutionManager() extends 
Logging {
   private var lastExecutionTimeMs: Option[Long] = 
Some(System.currentTimeMillis())
 
   /** Executor for the periodic maintenance */
+  @GuardedBy("executionsLock")
   private var scheduledExecutor: Option[ScheduledExecutorService] = None
 
   /**
@@ -76,27 +95,35 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
         request.getUserContext.getUserId,
         request.getSessionId,
         previousSessionId)
-    val executeHolder = new ExecuteHolder(request, sessionHolder)
+    val executeKey = ExecuteKey(request, sessionHolder)
+    val executeHolder = executions.compute(
+      executeKey,
+      (executeKey, oldExecuteHolder) => {
+        // Check if the operation already exists, either in the active 
execution map, or in the
+        // graveyard of tombstones of executions that have been abandoned. The 
latter is to prevent
+        // double executions when the client retries, thinking it never 
reached the server, but in
+        // fact it did, and already got removed as abandoned.
+        if (oldExecuteHolder != null) {
+          throw new SparkSQLException(
+            errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
+            messageParameters = Map("handle" -> executeKey.operationId))
+        }
+        if (getAbandonedTombstone(executeKey).isDefined) {
+          throw new SparkSQLException(
+            errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
+            messageParameters = Map("handle" -> executeKey.operationId))
+        }
+        new ExecuteHolder(executeKey, request, sessionHolder)
+      })
+
+    sessionHolder.addExecuteHolder(executeHolder)
+
     executionsLock.synchronized {
-      // Check if the operation already exists, both in active executions, and 
in the graveyard
-      // of tombstones of executions that have been abandoned.
-      // The latter is to prevent double execution when a client retries 
execution, thinking it
-      // never reached the server, but in fact it did, and already got removed 
as abandoned.
-      if (executions.get(executeHolder.key).isDefined) {
-        throw new SparkSQLException(
-          errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
-          messageParameters = Map("handle" -> executeHolder.operationId))
-      }
-      if (getAbandonedTombstone(executeHolder.key).isDefined) {
-        throw new SparkSQLException(
-          errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
-          messageParameters = Map("handle" -> executeHolder.operationId))
+      if (!executions.isEmpty()) {
+        lastExecutionTimeMs = None
       }
-      sessionHolder.addExecuteHolder(executeHolder)
-      executions.put(executeHolder.key, executeHolder)
-      lastExecutionTimeMs = None
-      logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} 
is created.")
     }
+    logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} 
is created.")
 
     schedulePeriodicChecks() // Starts the maintenance thread if it hasn't 
started.
 
@@ -108,43 +135,50 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
    * execution if still running, free all resources.
    */
   private[connect] def removeExecuteHolder(key: ExecuteKey, abandoned: Boolean 
= false): Unit = {
-    var executeHolder: Option[ExecuteHolder] = None
+    val executeHolder = executions.get(key)
+
+    if (executeHolder == null) {
+      return
+    }
+
+    // Put into abandonedTombstones before removing it from executions, so 
that the client ends up
+    // getting an INVALID_HANDLE.OPERATION_ABANDONED error on a retry.
+    if (abandoned) {
+      abandonedTombstones.put(key, executeHolder.getExecuteInfo)
+    }
+
+    // Remove the execution from the map *after* putting it in 
abandonedTombstones.
+    executions.remove(key)
+    executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId)
+
     executionsLock.synchronized {
-      executeHolder = executions.remove(key)
-      executeHolder.foreach { e =>
-        // Put into abandonedTombstones under lock, so that if it's accessed 
it will end up
-        // with INVALID_HANDLE.OPERATION_ABANDONED error.
-        if (abandoned) {
-          abandonedTombstones.put(key, e.getExecuteInfo)
-        }
-        e.sessionHolder.removeExecuteHolder(e.operationId)
-      }
       if (executions.isEmpty) {
         lastExecutionTimeMs = Some(System.currentTimeMillis())
       }
-      logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.")
     }
-    // close the execution outside the lock
-    executeHolder.foreach { e =>
-      e.close()
-      if (abandoned) {
-        // Update in abandonedTombstones: above it wasn't yet updated with 
closedTime etc.
-        abandonedTombstones.put(key, e.getExecuteInfo)
-      }
+
+    logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.")
+
+    executeHolder.close()
+    if (abandoned) {
+      // Update in abandonedTombstones: above it wasn't yet updated with 
closedTime etc.
+      abandonedTombstones.put(key, executeHolder.getExecuteInfo)
     }
   }
 
   private[connect] def getExecuteHolder(key: ExecuteKey): 
Option[ExecuteHolder] = {
-    executionsLock.synchronized {
-      executions.get(key)
-    }
+    Option(executions.get(key))
   }
 
   private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = {
-    val sessionExecutionHolders = executionsLock.synchronized {
-      executions.filter(_._2.sessionHolder.key == key)
-    }
-    sessionExecutionHolders.foreach { case (_, executeHolder) =>
+    var sessionExecutionHolders = mutable.ArrayBuffer[ExecuteHolder]()
+    executions.forEach((_, executeHolder) => {
+      if (executeHolder.sessionHolder.key == key) {
+        sessionExecutionHolders += executeHolder
+      }
+    })
+
+    sessionExecutionHolders.foreach { executeHolder =>
       val info = executeHolder.getExecuteInfo
       logInfo(
         log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in 
removeSessionExecutions.")
@@ -161,11 +195,11 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
    * If there are no executions, return Left with System.currentTimeMillis of 
last active
    * execution. Otherwise return Right with list of ExecuteInfo of all 
executions.
    */
-  def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = 
executionsLock.synchronized {
+  def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = {
     if (executions.isEmpty) {
       Left(lastExecutionTimeMs.get)
     } else {
-      Right(executions.values.map(_.getExecuteInfo).toBuffer.toSeq)
+      Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq)
     }
   }
 
@@ -177,16 +211,22 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
     abandonedTombstones.asMap.asScala.values.toSeq
   }
 
-  private[connect] def shutdown(): Unit = executionsLock.synchronized {
-    scheduledExecutor.foreach { executor =>
-      ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
+  private[connect] def shutdown(): Unit = {
+    executionsLock.synchronized {
+      scheduledExecutor.foreach { executor =>
+        ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
+      }
+      scheduledExecutor = None
     }
-    scheduledExecutor = None
+
     // note: this does not cleanly shut down the executions, but the server is 
shutting down.
     executions.clear()
     abandonedTombstones.invalidateAll()
-    if (lastExecutionTimeMs.isEmpty) {
-      lastExecutionTimeMs = Some(System.currentTimeMillis())
+
+    executionsLock.synchronized {
+      if (lastExecutionTimeMs.isEmpty) {
+        lastExecutionTimeMs = Some(System.currentTimeMillis())
+      }
     }
   }
 
@@ -225,19 +265,18 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
 
     // Find any detached executions that expired and should be removed.
     val toRemove = new mutable.ArrayBuffer[ExecuteHolder]()
-    executionsLock.synchronized {
-      val nowMs = System.currentTimeMillis()
+    val nowMs = System.currentTimeMillis()
 
-      executions.values.foreach { executeHolder =>
-        executeHolder.lastAttachedRpcTimeMs match {
-          case Some(detached) =>
-            if (detached + timeout <= nowMs) {
-              toRemove += executeHolder
-            }
-          case _ => // execution is active
-        }
+    executions.forEach((_, executeHolder) => {
+      executeHolder.lastAttachedRpcTimeMs match {
+        case Some(detached) =>
+          if (detached + timeout <= nowMs) {
+            toRemove += executeHolder
+          }
+        case _ => // execution is active
       }
-    }
+    })
+
     // .. and remove them.
     toRemove.foreach { executeHolder =>
       val info = executeHolder.getExecuteInfo
@@ -250,16 +289,16 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
   }
 
   // For testing.
-  private[connect] def setAllRPCsDeadline(deadlineMs: Long) = 
executionsLock.synchronized {
-    executions.values.foreach(_.setGrpcResponseSendersDeadline(deadlineMs))
+  private[connect] def setAllRPCsDeadline(deadlineMs: Long) = {
+    
executions.values().asScala.foreach(_.setGrpcResponseSendersDeadline(deadlineMs))
   }
 
   // For testing.
-  private[connect] def interruptAllRPCs() = executionsLock.synchronized {
-    executions.values.foreach(_.interruptGrpcResponseSenders())
+  private[connect] def interruptAllRPCs() = {
+    executions.values().asScala.foreach(_.interruptGrpcResponseSenders())
   }
 
-  private[connect] def listExecuteHolders: Seq[ExecuteHolder] = 
executionsLock.synchronized {
-    executions.values.toSeq
+  private[connect] def listExecuteHolders: Seq[ExecuteHolder] = {
+    executions.values().asScala.toSeq
   }
 }
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
index dbe8420eab03..a9843e261fff 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
@@ -374,7 +374,8 @@ class ExecuteEventsManagerSuite
       .setClientType(DEFAULT_CLIENT_TYPE)
       .build()
 
-    val executeHolder = new ExecuteHolder(executePlanRequest, sessionHolder)
+    val executeKey = ExecuteKey(executePlanRequest, sessionHolder)
+    val executeHolder = new ExecuteHolder(executeKey, executePlanRequest, 
sessionHolder)
 
     val eventsManager = ExecuteEventsManager(executeHolder, DEFAULT_CLOCK)
     eventsManager.status_(executeStatus)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to