This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4f95a7f4dd2c [SPARK-50260][CONNECT] Refactor and optimize Spark
Connect execution and session management
4f95a7f4dd2c is described below
commit 4f95a7f4dd2c2d30be549d2a90e52e44046e0726
Author: changgyoopark-db <[email protected]>
AuthorDate: Tue Nov 12 14:00:05 2024 +0900
[SPARK-50260][CONNECT] Refactor and optimize Spark Connect execution and
session management
### What changes were proposed in this pull request?
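Code refactoring.
- Replace the `Int` used to represent the state of an execution thread with a
dedicated case class, `ThreadStateInfo`.
Minor optimization.
- Remove expired executions and sessions directly while iterating over them,
instead of first collecting them into an intermediate buffer and then
iterating over that buffer a second time.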
### Why are the changes needed?
Improve code readability, and avoid allocating an intermediate buffer every
time expired executions or sessions are removed.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48792 from changgyoopark-db/SPARK-50260.
Authored-by: changgyoopark-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../connect/execution/ExecuteThreadRunner.scala | 22 +++++++++------
.../service/SparkConnectExecutionManager.scala | 32 +++++++---------------
.../service/SparkConnectSessionManager.scala | 16 +++--------
3 files changed, 27 insertions(+), 43 deletions(-)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 61be2bc4eb99..d27f390a23f9 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.connect.execution
-import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.atomic.AtomicReference
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
@@ -41,7 +41,8 @@ import org.apache.spark.util.Utils
private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder)
extends Logging {
/** The thread state. */
- private val state: AtomicInteger = new AtomicInteger(ThreadState.notStarted)
+ private val state: AtomicReference[ThreadStateInfo] = new AtomicReference(
+ ThreadState.notStarted)
// The newly created thread will inherit all InheritableThreadLocals used by
Spark,
// e.g. SparkContext.localProperties. If considering implementing a
thread-pool,
@@ -349,17 +350,20 @@ private[connect] class ExecuteThreadRunner(executeHolder:
ExecuteHolder) extends
private object ThreadState {
/** The thread has not started: transition to interrupted or started. */
- val notStarted: Int = 0
+ val notStarted: ThreadStateInfo = ThreadStateInfo(0)
/** Execution was interrupted: terminal state. */
- val interrupted: Int = 1
+ val interrupted: ThreadStateInfo = ThreadStateInfo(1)
/** The thread has started: transition to startedInterrupted or completed. */
- val started: Int = 2
+ val started: ThreadStateInfo = ThreadStateInfo(2)
- /** The thread has started and execution was interrupted: transition to
completed. */
- val startedInterrupted: Int = 3
+ /** The thread was started and execution has been interrupted: transition to
completed. */
+ val startedInterrupted: ThreadStateInfo = ThreadStateInfo(3)
- /** Execution was completed: terminal state. */
- val completed: Int = 4
+ /** Execution has been completed: terminal state. */
+ val completed: ThreadStateInfo = ThreadStateInfo(4)
}
+
+/** Represents the state of an execution thread. */
+case class ThreadStateInfo(val transitionState: Int)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index d9eb5438c388..f750ca6db67a 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -21,7 +21,6 @@ import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors,
ScheduledExecutorService, TimeUnit}
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
-import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
@@ -160,19 +159,14 @@ private[connect] class SparkConnectExecutionManager()
extends Logging {
}
private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = {
- var sessionExecutionHolders = mutable.ArrayBuffer[ExecuteHolder]()
executions.forEach((_, executeHolder) => {
if (executeHolder.sessionHolder.key == key) {
- sessionExecutionHolders += executeHolder
+ val info = executeHolder.getExecuteInfo
+ logInfo(
+ log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in
removeSessionExecutions.")
+ removeExecuteHolder(executeHolder.key, abandoned = true)
}
})
-
- sessionExecutionHolders.foreach { executeHolder =>
- val info = executeHolder.getExecuteInfo
- logInfo(
- log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in
removeSessionExecutions.")
- removeExecuteHolder(executeHolder.key, abandoned = true)
- }
}
/** Get info about abandoned execution, if there is one. */
@@ -252,30 +246,24 @@ private[connect] class SparkConnectExecutionManager()
extends Logging {
// Visible for testing.
private[connect] def periodicMaintenance(timeout: Long): Unit = {
+ // Find any detached executions that expired and should be removed.
logInfo("Started periodic run of SparkConnectExecutionManager
maintenance.")
- // Find any detached executions that expired and should be removed.
- val toRemove = new mutable.ArrayBuffer[ExecuteHolder]()
val nowMs = System.currentTimeMillis()
-
executions.forEach((_, executeHolder) => {
executeHolder.lastAttachedRpcTimeMs match {
case Some(detached) =>
if (detached + timeout <= nowMs) {
- toRemove += executeHolder
+ val info = executeHolder.getExecuteInfo
+ logInfo(
+ log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was
abandoned " +
+ log"and expired and will be removed.")
+ removeExecuteHolder(executeHolder.key, abandoned = true)
}
case _ => // execution is active
}
})
- // .. and remove them.
- toRemove.foreach { executeHolder =>
- val info = executeHolder.getExecuteInfo
- logInfo(
- log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was
abandoned " +
- log"and expired and will be removed.")
- removeExecuteHolder(executeHolder.key, abandoned = true)
- }
logInfo("Finished periodic run of SparkConnectExecutionManager
maintenance.")
}
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index 4ca3a80bfb98..a306856efa33 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -21,7 +21,6 @@ import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors,
ScheduledExecutorService, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
-import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
@@ -226,9 +225,8 @@ class SparkConnectSessionManager extends Logging {
private def periodicMaintenance(
defaultInactiveTimeoutMs: Long,
ignoreCustomTimeout: Boolean): Unit = {
- logInfo("Started periodic run of SparkConnectSessionManager maintenance.")
// Find any sessions that expired and should be removed.
- val toRemove = new mutable.ArrayBuffer[SessionHolder]()
+ logInfo("Started periodic run of SparkConnectSessionManager maintenance.")
def shouldExpire(info: SessionHolderInfo, nowMs: Long): Boolean = {
val timeoutMs = if (info.customInactiveTimeoutMs.isDefined &&
!ignoreCustomTimeout) {
@@ -242,15 +240,8 @@ class SparkConnectSessionManager extends Logging {
val nowMs = System.currentTimeMillis()
sessionStore.forEach((_, sessionHolder) => {
- if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) {
- toRemove += sessionHolder
- }
- })
-
- // .. and remove them.
- toRemove.foreach { sessionHolder =>
val info = sessionHolder.getSessionHolderInfo
- if (shouldExpire(info, System.currentTimeMillis())) {
+ if (shouldExpire(info, nowMs)) {
logInfo(
log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
log"and will be closed.")
@@ -261,7 +252,8 @@ class SparkConnectSessionManager extends Logging {
case NonFatal(ex) => logWarning("Unexpected exception closing
session", ex)
}
}
- }
+ })
+
logInfo("Finished periodic run of SparkConnectSessionManager maintenance.")
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]