This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6eee2717d7b5 [SPARK-50028][CONNECT] Replace global locks in Spark
Connect server listener with fine-grained locks
6eee2717d7b5 is described below
commit 6eee2717d7b51053c5d1f6e7570bc35001f7b288
Author: Changgyoo Park <[email protected]>
AuthorDate: Tue Nov 5 08:58:41 2024 -0800
[SPARK-50028][CONNECT] Replace global locks in Spark Connect server
listener with fine-grained locks
### What changes were proposed in this pull request?
Replace global locks in Spark Connect server listener with fine-grained
locks.
### Why are the changes needed?
Bug fix: onJobStart and onSQLExecutionStart were not properly synchronized,
posing a potential data-race issue — for example, the underlying map could be
resized while being read. This change removes the global locks and replaces
them with fine-grained, per-entity synchronization.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48544 from changgyoopark-db/SPARK-50028.
Authored-by: Changgyoo Park <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../connect/ui/SparkConnectServerListener.scala | 210 ++++++++++++---------
1 file changed, 116 insertions(+), 94 deletions(-)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala
index 65db08be7f90..3a93bbae3f2b 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.connect.ui
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
+
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -37,8 +39,10 @@ private[connect] class SparkConnectServerListener(
extends SparkListener
with Logging {
- private val sessionList = new mutable.LinkedHashMap[String, LiveSessionData]
- private val executionList = new mutable.LinkedHashMap[String,
LiveExecutionData]
+ private val sessionList: ConcurrentMap[String, LiveSessionData] =
+ new ConcurrentHashMap[String, LiveSessionData]()
+ private val executionList: ConcurrentMap[String, LiveExecutionData] =
+ new ConcurrentHashMap[String, LiveExecutionData]
private val (retainedStatements: Int, retainedSessions: Int) = {
(
@@ -52,8 +56,8 @@ private[connect] class SparkConnectServerListener(
private val liveUpdatePeriodNs = if (live)
sparkConf.get(LIVE_ENTITY_UPDATE_PERIOD) else -1L
// Returns true if this listener has no live data. Exposed for tests only.
- private[connect] def noLiveData(): Boolean = synchronized {
- sessionList.isEmpty && executionList.isEmpty
+ private[connect] def noLiveData(): Boolean = {
+ sessionList.isEmpty() && executionList.isEmpty()
}
kvstore.addTrigger(classOf[SessionInfo], retainedSessions) { count =>
@@ -80,11 +84,10 @@ private[connect] class SparkConnectServerListener(
return
}
val executeJobTag = executeJobTagOpt.get
- val exec = executionList.get(executeJobTag)
+ val exec = Option(executionList.get(executeJobTag))
if (exec.nonEmpty) {
exec.foreach { exec =>
- exec.jobId += jobStart.jobId.toString
- updateLiveStore(exec)
+ updateLiveStore(exec) { exec => exec.jobId += jobStart.jobId.toString }
}
} else {
// It may possible that event reordering happens, such a way that
JobStart event come after
@@ -103,9 +106,10 @@ private[connect] class SparkConnectServerListener(
exec.userId,
exec.operationId,
exec.sparkSessionTags)
- liveExec.sqlExecId = exec.sqlExecId
- liveExec.jobId += jobStart.jobId.toString
- updateStoreWithTriggerEnabled(liveExec)
+ updateStoreWithTriggerEnabled(liveExec) { liveExec =>
+ liveExec.sqlExecId = exec.sqlExecId
+ liveExec.jobId += jobStart.jobId.toString
+ }
executionList.remove(liveExec.jobTag)
}
}
@@ -136,11 +140,10 @@ private[connect] class SparkConnectServerListener(
return
}
val executeJobTag = executeJobTagOpt.get
- val exec = executionList.get(executeJobTag)
+ val exec = Option(executionList.get(executeJobTag))
if (exec.nonEmpty) {
exec.foreach { exec =>
- exec.sqlExecId += e.executionId.toString
- updateLiveStore(exec)
+ updateLiveStore(exec) { exec => exec.sqlExecId +=
e.executionId.toString }
}
} else {
// This block guards against potential event re-ordering where a
SQLExecutionStart
@@ -158,15 +161,16 @@ private[connect] class SparkConnectServerListener(
exec.userId,
exec.operationId,
exec.sparkSessionTags)
- liveExec.jobId = exec.jobId
- liveExec.sqlExecId += e.executionId.toString
- updateStoreWithTriggerEnabled(liveExec)
+ updateStoreWithTriggerEnabled(liveExec) { liveExec =>
+ liveExec.jobId = exec.jobId
+ liveExec.sqlExecId += e.executionId.toString
+ }
executionList.remove(liveExec.jobTag)
}
}
}
- private def onOperationStarted(e: SparkListenerConnectOperationStarted) =
synchronized {
+ private def onOperationStarted(e: SparkListenerConnectOperationStarted) = {
val executionData = getOrCreateExecution(
e.jobTag,
e.statementText,
@@ -175,13 +179,12 @@ private[connect] class SparkConnectServerListener(
e.userId,
e.operationId,
e.sparkSessionTags)
- executionData.state = ExecutionState.STARTED
- executionList.put(e.jobTag, executionData)
- updateLiveStore(executionData)
- sessionList.get(e.sessionId) match {
+ updateLiveStore(executionData) { executionData =>
+ executionData.state = ExecutionState.STARTED
+ }
+ Option(sessionList.get(e.sessionId)) match {
case Some(sessionData) =>
- sessionData.totalExecution += 1
- updateLiveStore(sessionData)
+ updateLiveStore(sessionData) { sessionData =>
sessionData.totalExecution += 1 }
case None =>
logWarning(
log"onOperationStart called with unknown session id:
${MDC(SESSION_ID, e.sessionId)}." +
@@ -189,11 +192,12 @@ private[connect] class SparkConnectServerListener(
}
}
- private def onOperationAnalyzed(e: SparkListenerConnectOperationAnalyzed) =
synchronized {
- executionList.get(e.jobTag) match {
+ private def onOperationAnalyzed(e: SparkListenerConnectOperationAnalyzed) = {
+ Option(executionList.get(e.jobTag)) match {
case Some(executionData) =>
- executionData.state = ExecutionState.COMPILED
- updateLiveStore(executionData)
+ updateLiveStore(executionData) { executionData =>
+ executionData.state = ExecutionState.COMPILED
+ }
case None =>
logWarning(
log"onOperationAnalyzed called with " +
@@ -202,11 +206,12 @@ private[connect] class SparkConnectServerListener(
}
private def onOperationReadyForExecution(
- e: SparkListenerConnectOperationReadyForExecution): Unit = synchronized {
- executionList.get(e.jobTag) match {
+ e: SparkListenerConnectOperationReadyForExecution): Unit = {
+ Option(executionList.get(e.jobTag)) match {
case Some(executionData) =>
- executionData.state = ExecutionState.READY
- updateLiveStore(executionData)
+ updateLiveStore(executionData) { executionData =>
+ executionData.state = ExecutionState.READY
+ }
case None =>
logWarning(
log"onOperationReadyForExecution called with " +
@@ -214,97 +219,113 @@ private[connect] class SparkConnectServerListener(
}
}
- private def onOperationCanceled(e: SparkListenerConnectOperationCanceled) =
synchronized {
- executionList.get(e.jobTag) match {
+ private def onOperationCanceled(e: SparkListenerConnectOperationCanceled) = {
+ Option(executionList.get(e.jobTag)) match {
case Some(executionData) =>
- executionData.finishTimestamp = e.eventTime
- executionData.state = ExecutionState.CANCELED
- updateLiveStore(executionData)
+ updateLiveStore(executionData) { executionData =>
+ executionData.finishTimestamp = e.eventTime
+ executionData.state = ExecutionState.CANCELED
+ }
case None =>
logWarning(
log"onOperationCanceled called with " +
log"unknown operation id: ${MDC(OP_ID, e.jobTag)}")
}
}
- private def onOperationFailed(e: SparkListenerConnectOperationFailed) =
synchronized {
- executionList.get(e.jobTag) match {
+ private def onOperationFailed(e: SparkListenerConnectOperationFailed) = {
+ Option(executionList.get(e.jobTag)) match {
case Some(executionData) =>
- executionData.finishTimestamp = e.eventTime
- executionData.detail = e.errorMessage
- executionData.state = ExecutionState.FAILED
- updateLiveStore(executionData)
+ updateLiveStore(executionData) { executionData =>
+ executionData.finishTimestamp = e.eventTime
+ executionData.detail = e.errorMessage
+ executionData.state = ExecutionState.FAILED
+ }
case None =>
logWarning(
log"onOperationFailed called with " +
log"unknown operation id: ${MDC(OP_ID, e.jobTag)}")
}
}
- private def onOperationFinished(e: SparkListenerConnectOperationFinished) =
synchronized {
- executionList.get(e.jobTag) match {
+ private def onOperationFinished(e: SparkListenerConnectOperationFinished) = {
+ Option(executionList.get(e.jobTag)) match {
case Some(executionData) =>
- executionData.finishTimestamp = e.eventTime
- executionData.state = ExecutionState.FINISHED
- updateLiveStore(executionData)
+ updateLiveStore(executionData) { executionData =>
+ executionData.finishTimestamp = e.eventTime
+ executionData.state = ExecutionState.FINISHED
+ }
case None =>
logWarning(
log"onOperationFinished called with " +
log"unknown operation id: ${MDC(OP_ID, e.jobTag)}")
}
}
- private def onOperationClosed(e: SparkListenerConnectOperationClosed) =
synchronized {
- executionList.get(e.jobTag) match {
- case Some(executionData) =>
- executionData.closeTimestamp = e.eventTime
- executionData.state = ExecutionState.CLOSED
- updateStoreWithTriggerEnabled(executionData)
- executionList.remove(e.jobTag)
- case None =>
- logWarning(
- log"onOperationClosed called with " +
- log"unknown operation id: ${MDC(OP_ID, e.jobTag)}")
- }
+ private def onOperationClosed(e: SparkListenerConnectOperationClosed) = {
+ executionList.compute(
+ e.jobTag,
+ (_, executionData) => {
+ if (executionData != null) {
+ updateStoreWithTriggerEnabled(executionData) { executionData =>
+ executionData.closeTimestamp = e.eventTime
+ executionData.state = ExecutionState.CLOSED
+ }
+ } else {
+ logWarning(
+ log"onOperationClosed called with " +
+ log"unknown operation id: ${MDC(OP_ID, e.jobTag)}")
+ }
+ null
+ })
}
- private def onSessionStarted(e: SparkListenerConnectSessionStarted) =
synchronized {
+ private def onSessionStarted(e: SparkListenerConnectSessionStarted) = {
val session = getOrCreateSession(e.sessionId, e.userId, e.eventTime)
- sessionList.put(e.sessionId, session)
- updateLiveStore(session)
+ updateLiveStore(session) { _ => () }
}
- private def onSessionClosed(e: SparkListenerConnectSessionClosed) =
synchronized {
- sessionList.get(e.sessionId) match {
- case Some(sessionData) =>
- sessionData.finishTimestamp = e.eventTime
- updateStoreWithTriggerEnabled(sessionData)
- sessionList.remove(e.sessionId)
-
- case None =>
- logWarning(
- log"onSessionClosed called with " +
- log"unknown session id: ${MDC(SESSION_ID, e.sessionId)}")
- }
+ private def onSessionClosed(e: SparkListenerConnectSessionClosed) = {
+ sessionList.compute(
+ e.sessionId,
+ (_, sessionData) => {
+ if (sessionData != null) {
+ updateStoreWithTriggerEnabled(sessionData) { sessionData =>
+ sessionData.finishTimestamp = e.eventTime
+ }
+ } else {
+ logWarning(
+ log"onSessionClosed called with " +
+ log"unknown session id: ${MDC(SESSION_ID, e.sessionId)}")
+ }
+ null
+ })
}
// Update both live and history stores. Trigger is enabled by default, hence
// it will cleanup the entity which exceeds the threshold.
- def updateStoreWithTriggerEnabled(entity: LiveEntity): Unit = synchronized {
- entity.write(kvstore, System.nanoTime(), checkTriggers = true)
- }
+ def updateStoreWithTriggerEnabled[T <: LiveEntity](entity: T)(updater: T =>
Unit): Unit =
+ entity.synchronized {
+ updater(entity)
+ entity.write(kvstore, System.nanoTime(), checkTriggers = true)
+ }
// Update only live stores. If trigger is enabled, it will cleanup entity
// which exceeds the threshold.
- def updateLiveStore(entity: LiveEntity, trigger: Boolean = false): Unit =
synchronized {
- val now = System.nanoTime()
- if (live && liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime >
liveUpdatePeriodNs) {
- entity.write(kvstore, now, checkTriggers = trigger)
+ def updateLiveStore[T <: LiveEntity](entity: T, trigger: Boolean = false)(
+ updater: T => Unit): Unit =
+ entity.synchronized {
+ updater(entity)
+ val now = System.nanoTime()
+ if (live && liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime >
liveUpdatePeriodNs) {
+ entity.write(kvstore, now, checkTriggers = trigger)
+ }
}
- }
private def getOrCreateSession(
sessionId: String,
userName: String,
- startTime: Long): LiveSessionData = synchronized {
- sessionList.getOrElseUpdate(sessionId, new LiveSessionData(sessionId,
startTime, userName))
+ startTime: Long): LiveSessionData = {
+ sessionList.computeIfAbsent(
+ sessionId,
+ _ => new LiveSessionData(sessionId, startTime, userName))
}
private def getOrCreateExecution(
@@ -314,17 +335,18 @@ private[connect] class SparkConnectServerListener(
startTimestamp: Long,
userId: String,
operationId: String,
- sparkSessionTags: Set[String]): LiveExecutionData = synchronized {
- executionList.getOrElseUpdate(
+ sparkSessionTags: Set[String]): LiveExecutionData = {
+ executionList.computeIfAbsent(
jobTag,
- new LiveExecutionData(
- jobTag,
- statement,
- sessionId,
- startTimestamp,
- userId,
- operationId,
- sparkSessionTags))
+ _ =>
+ new LiveExecutionData(
+ jobTag,
+ statement,
+ sessionId,
+ startTimestamp,
+ userId,
+ operationId,
+ sparkSessionTags))
}
private def cleanupExecutions(count: Long): Unit = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]