This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6eee2717d7b5 [SPARK-50028][CONNECT] Replace global locks in Spark 
Connect server listener with fine-grained locks
6eee2717d7b5 is described below

commit 6eee2717d7b51053c5d1f6e7570bc35001f7b288
Author: Changgyoo Park <[email protected]>
AuthorDate: Tue Nov 5 08:58:41 2024 -0800

    [SPARK-50028][CONNECT] Replace global locks in Spark Connect server 
listener with fine-grained locks
    
    ### What changes were proposed in this pull request?
    
    Replace global locks in Spark Connect server listener with fine-grained 
locks.
    
    ### Why are the changes needed?
    
    Bug fix:
    - onJobStart and onSQLExecutionStart were not properly synchronized, posing 
a potential data race issue — e.g., the map could be resized while it is being read.
    
    Get rid of global locks.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48544 from changgyoopark-db/SPARK-50028.
    
    Authored-by: Changgyoo Park <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../connect/ui/SparkConnectServerListener.scala    | 210 ++++++++++++---------
 1 file changed, 116 insertions(+), 94 deletions(-)

diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala
index 65db08be7f90..3a93bbae3f2b 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connect.ui
 
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
+
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
@@ -37,8 +39,10 @@ private[connect] class SparkConnectServerListener(
     extends SparkListener
     with Logging {
 
-  private val sessionList = new mutable.LinkedHashMap[String, LiveSessionData]
-  private val executionList = new mutable.LinkedHashMap[String, 
LiveExecutionData]
+  private val sessionList: ConcurrentMap[String, LiveSessionData] =
+    new ConcurrentHashMap[String, LiveSessionData]()
+  private val executionList: ConcurrentMap[String, LiveExecutionData] =
+    new ConcurrentHashMap[String, LiveExecutionData]
 
   private val (retainedStatements: Int, retainedSessions: Int) = {
     (
@@ -52,8 +56,8 @@ private[connect] class SparkConnectServerListener(
   private val liveUpdatePeriodNs = if (live) 
sparkConf.get(LIVE_ENTITY_UPDATE_PERIOD) else -1L
 
   // Returns true if this listener has no live data. Exposed for tests only.
-  private[connect] def noLiveData(): Boolean = synchronized {
-    sessionList.isEmpty && executionList.isEmpty
+  private[connect] def noLiveData(): Boolean = {
+    sessionList.isEmpty() && executionList.isEmpty()
   }
 
   kvstore.addTrigger(classOf[SessionInfo], retainedSessions) { count =>
@@ -80,11 +84,10 @@ private[connect] class SparkConnectServerListener(
       return
     }
     val executeJobTag = executeJobTagOpt.get
-    val exec = executionList.get(executeJobTag)
+    val exec = Option(executionList.get(executeJobTag))
     if (exec.nonEmpty) {
       exec.foreach { exec =>
-        exec.jobId += jobStart.jobId.toString
-        updateLiveStore(exec)
+        updateLiveStore(exec) { exec => exec.jobId += jobStart.jobId.toString }
       }
     } else {
       // It may possible that event reordering happens, such a way that 
JobStart event come after
@@ -103,9 +106,10 @@ private[connect] class SparkConnectServerListener(
           exec.userId,
           exec.operationId,
           exec.sparkSessionTags)
-        liveExec.sqlExecId = exec.sqlExecId
-        liveExec.jobId += jobStart.jobId.toString
-        updateStoreWithTriggerEnabled(liveExec)
+        updateStoreWithTriggerEnabled(liveExec) { liveExec =>
+          liveExec.sqlExecId = exec.sqlExecId
+          liveExec.jobId += jobStart.jobId.toString
+        }
         executionList.remove(liveExec.jobTag)
       }
     }
@@ -136,11 +140,10 @@ private[connect] class SparkConnectServerListener(
       return
     }
     val executeJobTag = executeJobTagOpt.get
-    val exec = executionList.get(executeJobTag)
+    val exec = Option(executionList.get(executeJobTag))
     if (exec.nonEmpty) {
       exec.foreach { exec =>
-        exec.sqlExecId += e.executionId.toString
-        updateLiveStore(exec)
+        updateLiveStore(exec) { exec => exec.sqlExecId += 
e.executionId.toString }
       }
     } else {
       // This block guards against potential event re-ordering where a 
SQLExecutionStart
@@ -158,15 +161,16 @@ private[connect] class SparkConnectServerListener(
           exec.userId,
           exec.operationId,
           exec.sparkSessionTags)
-        liveExec.jobId = exec.jobId
-        liveExec.sqlExecId += e.executionId.toString
-        updateStoreWithTriggerEnabled(liveExec)
+        updateStoreWithTriggerEnabled(liveExec) { liveExec =>
+          liveExec.jobId = exec.jobId
+          liveExec.sqlExecId += e.executionId.toString
+        }
         executionList.remove(liveExec.jobTag)
       }
     }
   }
 
-  private def onOperationStarted(e: SparkListenerConnectOperationStarted) = 
synchronized {
+  private def onOperationStarted(e: SparkListenerConnectOperationStarted) = {
     val executionData = getOrCreateExecution(
       e.jobTag,
       e.statementText,
@@ -175,13 +179,12 @@ private[connect] class SparkConnectServerListener(
       e.userId,
       e.operationId,
       e.sparkSessionTags)
-    executionData.state = ExecutionState.STARTED
-    executionList.put(e.jobTag, executionData)
-    updateLiveStore(executionData)
-    sessionList.get(e.sessionId) match {
+    updateLiveStore(executionData) { executionData =>
+      executionData.state = ExecutionState.STARTED
+    }
+    Option(sessionList.get(e.sessionId)) match {
       case Some(sessionData) =>
-        sessionData.totalExecution += 1
-        updateLiveStore(sessionData)
+        updateLiveStore(sessionData) { sessionData => 
sessionData.totalExecution += 1 }
       case None =>
         logWarning(
           log"onOperationStart called with unknown session id: 
${MDC(SESSION_ID, e.sessionId)}." +
@@ -189,11 +192,12 @@ private[connect] class SparkConnectServerListener(
     }
   }
 
-  private def onOperationAnalyzed(e: SparkListenerConnectOperationAnalyzed) = 
synchronized {
-    executionList.get(e.jobTag) match {
+  private def onOperationAnalyzed(e: SparkListenerConnectOperationAnalyzed) = {
+    Option(executionList.get(e.jobTag)) match {
       case Some(executionData) =>
-        executionData.state = ExecutionState.COMPILED
-        updateLiveStore(executionData)
+        updateLiveStore(executionData) { executionData =>
+          executionData.state = ExecutionState.COMPILED
+        }
       case None =>
         logWarning(
           log"onOperationAnalyzed called with " +
@@ -202,11 +206,12 @@ private[connect] class SparkConnectServerListener(
   }
 
   private def onOperationReadyForExecution(
-      e: SparkListenerConnectOperationReadyForExecution): Unit = synchronized {
-    executionList.get(e.jobTag) match {
+      e: SparkListenerConnectOperationReadyForExecution): Unit = {
+    Option(executionList.get(e.jobTag)) match {
       case Some(executionData) =>
-        executionData.state = ExecutionState.READY
-        updateLiveStore(executionData)
+        updateLiveStore(executionData) { executionData =>
+          executionData.state = ExecutionState.READY
+        }
       case None =>
         logWarning(
           log"onOperationReadyForExecution called with " +
@@ -214,97 +219,113 @@ private[connect] class SparkConnectServerListener(
     }
   }
 
-  private def onOperationCanceled(e: SparkListenerConnectOperationCanceled) = 
synchronized {
-    executionList.get(e.jobTag) match {
+  private def onOperationCanceled(e: SparkListenerConnectOperationCanceled) = {
+    Option(executionList.get(e.jobTag)) match {
       case Some(executionData) =>
-        executionData.finishTimestamp = e.eventTime
-        executionData.state = ExecutionState.CANCELED
-        updateLiveStore(executionData)
+        updateLiveStore(executionData) { executionData =>
+          executionData.finishTimestamp = e.eventTime
+          executionData.state = ExecutionState.CANCELED
+        }
       case None =>
         logWarning(
           log"onOperationCanceled called with " +
             log"unknown operation id: ${MDC(OP_ID, e.jobTag)}")
     }
   }
-  private def onOperationFailed(e: SparkListenerConnectOperationFailed) = 
synchronized {
-    executionList.get(e.jobTag) match {
+  private def onOperationFailed(e: SparkListenerConnectOperationFailed) = {
+    Option(executionList.get(e.jobTag)) match {
       case Some(executionData) =>
-        executionData.finishTimestamp = e.eventTime
-        executionData.detail = e.errorMessage
-        executionData.state = ExecutionState.FAILED
-        updateLiveStore(executionData)
+        updateLiveStore(executionData) { executionData =>
+          executionData.finishTimestamp = e.eventTime
+          executionData.detail = e.errorMessage
+          executionData.state = ExecutionState.FAILED
+        }
       case None =>
         logWarning(
           log"onOperationFailed called with " +
             log"unknown operation id: ${MDC(OP_ID, e.jobTag)}")
     }
   }
-  private def onOperationFinished(e: SparkListenerConnectOperationFinished) = 
synchronized {
-    executionList.get(e.jobTag) match {
+  private def onOperationFinished(e: SparkListenerConnectOperationFinished) = {
+    Option(executionList.get(e.jobTag)) match {
       case Some(executionData) =>
-        executionData.finishTimestamp = e.eventTime
-        executionData.state = ExecutionState.FINISHED
-        updateLiveStore(executionData)
+        updateLiveStore(executionData) { executionData =>
+          executionData.finishTimestamp = e.eventTime
+          executionData.state = ExecutionState.FINISHED
+        }
       case None =>
         logWarning(
           log"onOperationFinished called with " +
             log"unknown operation id: ${MDC(OP_ID, e.jobTag)}")
     }
   }
-  private def onOperationClosed(e: SparkListenerConnectOperationClosed) = 
synchronized {
-    executionList.get(e.jobTag) match {
-      case Some(executionData) =>
-        executionData.closeTimestamp = e.eventTime
-        executionData.state = ExecutionState.CLOSED
-        updateStoreWithTriggerEnabled(executionData)
-        executionList.remove(e.jobTag)
-      case None =>
-        logWarning(
-          log"onOperationClosed called with " +
-            log"unknown operation id: ${MDC(OP_ID, e.jobTag)}")
-    }
+  private def onOperationClosed(e: SparkListenerConnectOperationClosed) = {
+    executionList.compute(
+      e.jobTag,
+      (_, executionData) => {
+        if (executionData != null) {
+          updateStoreWithTriggerEnabled(executionData) { executionData =>
+            executionData.closeTimestamp = e.eventTime
+            executionData.state = ExecutionState.CLOSED
+          }
+        } else {
+          logWarning(
+            log"onOperationClosed called with " +
+              log"unknown operation id: ${MDC(OP_ID, e.jobTag)}")
+        }
+        null
+      })
   }
 
-  private def onSessionStarted(e: SparkListenerConnectSessionStarted) = 
synchronized {
+  private def onSessionStarted(e: SparkListenerConnectSessionStarted) = {
     val session = getOrCreateSession(e.sessionId, e.userId, e.eventTime)
-    sessionList.put(e.sessionId, session)
-    updateLiveStore(session)
+    updateLiveStore(session) { _ => () }
   }
 
-  private def onSessionClosed(e: SparkListenerConnectSessionClosed) = 
synchronized {
-    sessionList.get(e.sessionId) match {
-      case Some(sessionData) =>
-        sessionData.finishTimestamp = e.eventTime
-        updateStoreWithTriggerEnabled(sessionData)
-        sessionList.remove(e.sessionId)
-
-      case None =>
-        logWarning(
-          log"onSessionClosed called with " +
-            log"unknown session id: ${MDC(SESSION_ID, e.sessionId)}")
-    }
+  private def onSessionClosed(e: SparkListenerConnectSessionClosed) = {
+    sessionList.compute(
+      e.sessionId,
+      (_, sessionData) => {
+        if (sessionData != null) {
+          updateStoreWithTriggerEnabled(sessionData) { sessionData =>
+            sessionData.finishTimestamp = e.eventTime
+          }
+        } else {
+          logWarning(
+            log"onSessionClosed called with " +
+              log"unknown session id: ${MDC(SESSION_ID, e.sessionId)}")
+        }
+        null
+      })
   }
 
   // Update both live and history stores. Trigger is enabled by default, hence
   // it will cleanup the entity which exceeds the threshold.
-  def updateStoreWithTriggerEnabled(entity: LiveEntity): Unit = synchronized {
-    entity.write(kvstore, System.nanoTime(), checkTriggers = true)
-  }
+  def updateStoreWithTriggerEnabled[T <: LiveEntity](entity: T)(updater: T => 
Unit): Unit =
+    entity.synchronized {
+      updater(entity)
+      entity.write(kvstore, System.nanoTime(), checkTriggers = true)
+    }
 
   // Update only live stores. If trigger is enabled, it will cleanup entity
   // which exceeds the threshold.
-  def updateLiveStore(entity: LiveEntity, trigger: Boolean = false): Unit = 
synchronized {
-    val now = System.nanoTime()
-    if (live && liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > 
liveUpdatePeriodNs) {
-      entity.write(kvstore, now, checkTriggers = trigger)
+  def updateLiveStore[T <: LiveEntity](entity: T, trigger: Boolean = false)(
+      updater: T => Unit): Unit =
+    entity.synchronized {
+      updater(entity)
+      val now = System.nanoTime()
+      if (live && liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > 
liveUpdatePeriodNs) {
+        entity.write(kvstore, now, checkTriggers = trigger)
+      }
     }
-  }
 
   private def getOrCreateSession(
       sessionId: String,
       userName: String,
-      startTime: Long): LiveSessionData = synchronized {
-    sessionList.getOrElseUpdate(sessionId, new LiveSessionData(sessionId, 
startTime, userName))
+      startTime: Long): LiveSessionData = {
+    sessionList.computeIfAbsent(
+      sessionId,
+      _ => new LiveSessionData(sessionId, startTime, userName))
   }
 
   private def getOrCreateExecution(
@@ -314,17 +335,18 @@ private[connect] class SparkConnectServerListener(
       startTimestamp: Long,
       userId: String,
       operationId: String,
-      sparkSessionTags: Set[String]): LiveExecutionData = synchronized {
-    executionList.getOrElseUpdate(
+      sparkSessionTags: Set[String]): LiveExecutionData = {
+    executionList.computeIfAbsent(
       jobTag,
-      new LiveExecutionData(
-        jobTag,
-        statement,
-        sessionId,
-        startTimestamp,
-        userId,
-        operationId,
-        sparkSessionTags))
+      _ =>
+        new LiveExecutionData(
+          jobTag,
+          statement,
+          sessionId,
+          startTimestamp,
+          userId,
+          operationId,
+          sparkSessionTags))
   }
 
   private def cleanupExecutions(count: Long): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to