This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8bd7b6e1cf00 [SPARK-55606][CONNECT] Server-side implementation of GetStatus API
8bd7b6e1cf00 is described below
commit 8bd7b6e1cf000d57d910dfd326de980f48bcbe85
Author: Anastasiia Terenteva <[email protected]>
AuthorDate: Wed Feb 25 10:25:06 2026 -0400
[SPARK-55606][CONNECT] Server-side implementation of GetStatus API
### What changes were proposed in this pull request?
Server-side implementation of the GetStatus API:
- Introduce a variable in ExecuteEventsManager to track execution
termination reason after it's closed.
- Track minimal execution termination information in SessionHolder's
inactiveOperations cache.
- Use SessionHolder's activeOperations and inactiveOperations lists for
determining execution status in GetStatus API handler.
- Add plugin interface for GetStatus operation for processing custom proto
extensions.
### Why are the changes needed?
The GetStatus API allows clients to monitor the status of executions in a session, which is
particularly useful for multithreaded clients.
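As a rough sketch of the request/response shape (the `sessionId`, `userId`, and `opId` values
below are placeholders; the actual client-side API arrives in the follow-up client PR):

```scala
import org.apache.spark.connect.proto

// Build a GetStatusRequest asking for the status of one operation.
val request = proto.GetStatusRequest
  .newBuilder()
  .setSessionId(sessionId)
  .setUserContext(proto.UserContext.newBuilder().setUserId(userId))
  .setOperationStatus(
    proto.GetStatusRequest.OperationStatusRequest
      .newBuilder()
      .addOperationIds(opId))
  .build()

// The response carries one OperationStatus per requested operation (or per
// known operation when no IDs are given), e.g. OPERATION_STATE_RUNNING.
```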
### Does this PR introduce _any_ user-facing change?
Yes. It's a new Spark Connect API.
### How was this patch tested?
- New tests were added.
- E2E tests that check real (not mocked) execution lifecycles are coming in a
client-side PR.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude 4.6 Opus High
Closes #54445 from terana/get-status-server.
Authored-by: Anastasiia Terenteva <[email protected]>
Signed-off-by: Herman van Hövell <[email protected]>
---
.../spark/sql/connect/plugin/GetStatusPlugin.java | 77 +++++
.../apache/spark/sql/connect/config/Connect.scala | 12 +
.../plugin/SparkConnectPluginRegistry.scala | 46 ++-
.../sql/connect/service/ExecuteEventsManager.scala | 79 ++++-
.../spark/sql/connect/service/ExecuteHolder.scala | 21 ++
.../spark/sql/connect/service/SessionHolder.scala | 60 +++-
.../service/SparkConnectExecutionManager.scala | 4 +-
.../service/SparkConnectGetStatusHandler.scala | 218 +++++++++++++
.../sql/connect/service/SparkConnectService.scala | 13 +
.../spark/sql/connect/SparkConnectTestUtils.scala | 28 +-
.../plugin/SparkConnectPluginRegistrySuite.scala | 54 ++++
.../service/ExecuteEventsManagerSuite.scala | 58 ++++
.../connect/service/GetStatusHandlerSuite.scala | 354 +++++++++++++++++++++
.../SparkConnectExecutionManagerSuite.scala | 118 +++++++
.../service/SparkConnectSessionHolderSuite.scala | 88 +++++
15 files changed, 1211 insertions(+), 19 deletions(-)
diff --git a/sql/connect/server/src/main/java/org/apache/spark/sql/connect/plugin/GetStatusPlugin.java b/sql/connect/server/src/main/java/org/apache/spark/sql/connect/plugin/GetStatusPlugin.java
new file mode 100644
index 000000000000..77a519099bf7
--- /dev/null
+++ b/sql/connect/server/src/main/java/org/apache/spark/sql/connect/plugin/GetStatusPlugin.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.plugin;
+
+import com.google.protobuf.Any;
+
+import java.util.List;
+import java.util.Optional;
+
+import org.apache.spark.sql.connect.service.SessionHolder;
+
+/**
+ * Plugin interface for extending GetStatus RPC behavior in Spark Connect.
+ *
+ * <p>Classes implementing this interface must be trivially constructable (have a no-argument
+ * constructor) and should not rely on internal state. The plugin is invoked during GetStatus
+ * request handling, allowing custom logic to be executed based on request extensions.
+ *
+ * <p>The GetStatus RPC message has two extension points:
+ * <ul>
+ * <li>{@code GetStatusRequest.extensions} - request-level extensions</li>
+ * <li>{@code GetStatusRequest.OperationStatusRequest.extensions}
+ * - operation-level extensions</li>
+ * </ul>
+ *
+ * <p>And corresponding response extension points:
+ * <ul>
+ * <li>{@code GetStatusResponse.extensions} - response-level extensions</li>
+ * <li>{@code GetStatusResponse.OperationStatus.extensions} - operation-level extensions</li>
+ * </ul>
+ */
+public interface GetStatusPlugin {
+
+ /**
+ * Process request-level extensions from a GetStatus request.
+ *
+ * <p>This method is called once per GetStatus request, before operation statuses are processed.
+ * Plugins can use the request extensions to customize behavior and return response extensions.
+ *
+ * @param sessionHolder the session holder for the current session
+ * @param requestExtensions the extensions from the GetStatus request
+ * @return optional list of response extensions to add to the GetStatusResponse;
+ * return {@code Optional.empty()} if this plugin does not handle the request extensions
+ */
+ Optional<List<Any>> processRequestExtensions(
+ SessionHolder sessionHolder, List<Any> requestExtensions);
+
+ /**
+ * Process operation-level extensions from an OperationStatusRequest.
+ *
+ * <p>This method is called once per operation whose status is requested.
+ * Plugins can use the operation-level extensions to customize per-operation behavior.
+ *
+ * @param operationId the operation ID being queried
+ * @param sessionHolder the session holder for the current session
+ * @param operationExtensions the extensions from the OperationStatusRequest
+ * @return optional list of response extensions to add to the OperationStatus;
+ * return {@code Optional.empty()} if this plugin does not handle the extensions
+ */
+ Optional<List<Any>> processOperationExtensions(
+ String operationId, SessionHolder sessionHolder, List<Any> operationExtensions);
+}
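For illustration, a minimal Scala implementation of this interface could look like the sketch
below (the `EchoPlugin` name and the StringValue payload are illustrative; the test suites in
this PR use similar echo plugins):

```scala
import java.util.{List => JList, Optional}

import com.google.protobuf.{Any => ProtoAny, StringValue}

import org.apache.spark.sql.connect.plugin.GetStatusPlugin
import org.apache.spark.sql.connect.service.SessionHolder

class EchoPlugin extends GetStatusPlugin {
  // Echo every StringValue request extension back, prefixed with "echo:".
  override def processRequestExtensions(
      sessionHolder: SessionHolder,
      requestExtensions: JList[ProtoAny]): Optional[JList[ProtoAny]] = {
    if (requestExtensions.isEmpty) return Optional.empty()
    val out = new java.util.ArrayList[ProtoAny]()
    requestExtensions.forEach { ext =>
      if (ext.is(classOf[StringValue])) {
        val value = ext.unpack(classOf[StringValue]).getValue
        out.add(ProtoAny.pack(StringValue.of(s"echo:$value")))
      }
    }
    Optional.of(out)
  }

  // This sketch leaves operation-level extensions unhandled.
  override def processOperationExtensions(
      operationId: String,
      sessionHolder: SessionHolder,
      operationExtensions: JList[ProtoAny]): Optional[JList[ProtoAny]] =
    Optional.empty()
}
```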
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 1df97d855678..e2d496239d29 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -218,6 +218,18 @@ object Connect {
.toSequence
.createWithDefault(Nil)
+ val CONNECT_EXTENSIONS_GET_STATUS_CLASSES =
+ buildStaticConf("spark.connect.extensions.getStatus.classes")
+ .doc("""
+ |Comma separated list of classes that implement the trait
+ |org.apache.spark.sql.connect.plugin.GetStatusPlugin to support custom
+ |GetStatus extensions in proto.
+ |""".stripMargin)
+ .version("4.1.0")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
val CONNECT_ML_BACKEND_CLASSES =
buildConf("spark.connect.ml.backend.classes")
.doc("""
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistry.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistry.scala
index fc4e7aed7aed..b26b8fb43fbe 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistry.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistry.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils
/**
- * This object provides a global list of configured relation, expression and command plugins for
- * Spark Connect. The plugins are used to handle custom message types.
+ * This object provides a global list of configured relation, expression, command, and getStatus
+ * plugins for Spark Connect. The plugins are used to handle custom message types.
*/
object SparkConnectPluginRegistry {
@@ -46,15 +46,22 @@ object SparkConnectPluginRegistry {
// expression[DummyExpressionPlugin](classOf[DummyExpressionPlugin])
)
+ private lazy val getStatusPluginChain: Seq[getStatusPluginBuilder] = Seq(
+ // Adding a new plugin at compile time works like the example below:
+ // getStatus[DummyGetStatusPlugin](classOf[DummyGetStatusPlugin])
+ )
+
private var initialized = false
private var relationRegistryCache: Seq[RelationPlugin] = Seq.empty
private var expressionRegistryCache: Seq[ExpressionPlugin] = Seq.empty
private var commandRegistryCache: Seq[CommandPlugin] = Seq.empty
+ private var getStatusRegistryCache: Seq[GetStatusPlugin] = Seq.empty
// Type used to identify the closure responsible to instantiate a ServerInterceptor.
type relationPluginBuilder = () => RelationPlugin
type expressionPluginBuilder = () => ExpressionPlugin
type commandPluginBuilder = () => CommandPlugin
+ type getStatusPluginBuilder = () => GetStatusPlugin
def relationRegistry: Seq[RelationPlugin] = withInitialize {
relationRegistryCache
@@ -65,6 +72,9 @@ object SparkConnectPluginRegistry {
def commandRegistry: Seq[CommandPlugin] = withInitialize {
commandRegistryCache
}
+ def getStatusRegistry: Seq[GetStatusPlugin] = withInitialize {
+ getStatusRegistryCache
+ }
def mlBackendRegistry(conf: SQLConf): Seq[MLBackendPlugin] =
loadMlBackendPlugins(conf)
private def withInitialize[T](f: => Seq[T]): Seq[T] = {
@@ -73,6 +83,7 @@ object SparkConnectPluginRegistry {
relationRegistryCache = loadRelationPlugins()
expressionRegistryCache = loadExpressionPlugins()
commandRegistryCache = loadCommandPlugins()
+ getStatusRegistryCache = loadGetStatusPlugins()
initialized = true
}
}
@@ -88,6 +99,23 @@ object SparkConnectPluginRegistry {
}
}
+ /**
+ * Only visible for testing. Allows injecting test GetStatus plugins directly into the registry
+ * cache, bypassing the normal plugin chain loading. Forces initialization of all other caches
+ * if not already initialized, then overrides the GetStatus cache.
+ */
+ private[connect] def setGetStatusPluginsForTesting(plugins: Seq[GetStatusPlugin]): Unit = {
+ synchronized {
+ if (!initialized) {
+ relationRegistryCache = loadRelationPlugins()
+ expressionRegistryCache = loadExpressionPlugins()
+ commandRegistryCache = loadCommandPlugins()
+ initialized = true
+ }
+ getStatusRegistryCache = plugins
+ }
+ }
+
/**
* Only visible for testing
*/
@@ -109,6 +137,12 @@ object SparkConnectPluginRegistry {
SparkEnv.get.conf.get(Connect.CONNECT_EXTENSIONS_COMMAND_CLASSES))
}
+ private[connect] def loadGetStatusPlugins(): Seq[GetStatusPlugin] = {
+ getStatusPluginChain.map(x => x()) ++
+ createConfiguredPlugins(
+ SparkEnv.get.conf.get(Connect.CONNECT_EXTENSIONS_GET_STATUS_CLASSES))
+ }
+
private[connect] def loadMlBackendPlugins(sqlConf: SQLConf): Seq[MLBackendPlugin] = {
createConfiguredPlugins(sqlConf.getConf(Connect.CONNECT_ML_BACKEND_CLASSES))
}
@@ -182,4 +216,12 @@ object SparkConnectPluginRegistry {
*/
def command[T <: CommandPlugin](cls: Class[T]): commandPluginBuilder =
() => createInstance[CommandPlugin, T](cls)
+
+ /**
+ * Creates a callable expression that instantiates the configured GetStatus plugin.
+ *
+ * Visible for testing only.
+ */
+ def getStatus[T <: GetStatusPlugin](cls: Class[T]): getStatusPluginBuilder =
+ () => createInstance[GetStatusPlugin, T](cls)
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
index 61cd95621d15..351be8875ba1 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
@@ -47,7 +47,55 @@ object ExecuteStatus {
}
/**
- * Post request Connect events to @link org.apache.spark.scheduler.LiveListenerBus.
+ * Records why an operation terminated so that the reason remains available after the operation
+ * transitions to the closed state.
+ */
+sealed abstract class TerminationReason(val value: Int)
+
+object TerminationReason {
+ case object Succeeded extends TerminationReason(0)
+ case object Failed extends TerminationReason(1)
+ case object Canceled extends TerminationReason(2)
+}
+
+/**
+ * Manage the lifecycle of an operation by tracking its status and termination reason. Post
+ * request Connect events to @link org.apache.spark.scheduler.LiveListenerBus and serve as an
+ * information source for GetStatus RPC.
+ *
+ * {{{
+ * +---------------------------------------------+
+ * | ExecuteEventsManager |
+ * | |
+ * | - Tracks operation lifecycle state |
+ * | - Posts events to event bus |
+ * | - Maintains termination info in memory |
+ * +---------------------------------------------+
+ * |
+ * +-----------+-----------+
+ * | |
+ * v v
+ * +------------------+ +------------------+
+ * | GetStatus RPC | | Query History |
+ * | (Direct API) | | (Audit Log) |
+ * +------------------+ +------------------+
+ *
+ * State Mapping Matrix:
+ *
+ * ExecuteStatus -> GetStatus API -> Query History
+ * + TerminationReason
+ * ----------------------------------------------------------------
+ * Pending -> RUNNING -> (no event posted)
+ * Started -> RUNNING -> STARTED
+ * Analyzed -> RUNNING -> COMPILED
+ * ReadyForExecution -> RUNNING -> READY
+ * Finished -> TERMINATING -> FINISHED
+ * Failed -> TERMINATING -> FAILED
+ * Canceled -> TERMINATING -> CANCELED
+ * Closed+Succeeded -> SUCCEEDED -> CLOSED
+ * Closed+Failed -> FAILED -> CLOSED
+ * Closed+Canceled -> CANCELLED -> CLOSED
+ * }}}
*
* @param executeHolder:
* Request for which the events are generated.
@@ -70,7 +118,7 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
private def sessionStatus = sessionHolder.eventManager.status
- private var _status: ExecuteStatus = ExecuteStatus.Pending
+ @volatile private var _status: ExecuteStatus = ExecuteStatus.Pending
private var error = Option.empty[Boolean]
@@ -78,12 +126,36 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
private var producedRowCount = Option.empty[Long]
+ @volatile private var _terminationReason: Option[TerminationReason] = None
+
/**
* @return
* Last event posted by the Connect request
*/
private[connect] def status: ExecuteStatus = _status
+ /**
+ * @return
+ * The reason for termination, set when the operation finishes, fails, or is canceled. Since
+ * the closed state itself does not convey why the operation ended, this value preserves that
+ * information for later use.
+ */
+ private[connect] def terminationReason: Option[TerminationReason] = _terminationReason
+
+ /**
+ * Updates the termination reason only if the new reason has a higher value than the current
+ * one. This establishes the ordering Canceled > Failed > Succeeded, which is consistent with
+ * the ExecuteStatus ordering. It handles the cases where execution is interrupted or fails
+ * during cleanup.
+ */
+ private def updateTerminationReason(newReason: TerminationReason): Unit = {
+ _terminationReason match {
+ case Some(currentReason) if currentReason.value >= newReason.value =>
+ case _ =>
+ _terminationReason = Some(newReason)
+ }
+ }
+
/**
* @return
* True when the Connect request has posted @link
@@ -184,6 +256,7 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
ExecuteStatus.Failed),
ExecuteStatus.Canceled)
canceled = Some(true)
+ updateTerminationReason(TerminationReason.Canceled)
listenerBus
.post(SparkListenerConnectOperationCanceled(jobTag, operationId, clock.getTimeMillis()))
}
@@ -203,6 +276,7 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
ExecuteStatus.Finished),
ExecuteStatus.Failed)
error = Some(true)
+ updateTerminationReason(TerminationReason.Failed)
listenerBus.post(
SparkListenerConnectOperationFailed(
jobTag,
@@ -224,6 +298,7 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
List(ExecuteStatus.Started, ExecuteStatus.ReadyForExecution),
ExecuteStatus.Finished)
producedRowCount = producedRowsCountOpt
+ updateTerminationReason(TerminationReason.Succeeded)
listenerBus
.post(
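The precedence encoded by updateTerminationReason means a cancel that lands after a successful
finish (e.g. an interrupt during cleanup) wins. The new tests below exercise this directly:

```scala
// From ExecuteEventsManagerSuite: Succeeded(0) < Failed(1) < Canceled(2), and
// updates only ever move the reason "up" this ordering.
val events = setupEvents(ExecuteStatus.Started) // test helper from this PR
events.postFinished()
assert(events.terminationReason.contains(TerminationReason.Succeeded))
events.postCanceled()
assert(events.terminationReason.contains(TerminationReason.Canceled)) // 2 > 0, wins
```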
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index 42574b1f8d43..7f5670f43960 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -308,6 +308,8 @@ private[connect] class ExecuteHolder(
responseObserver.removeAll()
// Post "closed" to UI.
eventsManager.postClosed()
+ // Update the termination info in the session holder after closure.
+ sessionHolder.closeOperation(this)
}
}
@@ -331,6 +333,7 @@ private[connect] class ExecuteHolder(
sparkSessionTags = sparkSessionTags,
reattachable = reattachable,
status = eventsManager.status,
+ terminationReason = eventsManager.terminationReason,
creationTimeNs = creationTimeNs,
lastAttachedRpcTimeNs = lastAttachedRpcTimeNs,
closedTimeNs = closedTimeNs)
@@ -341,6 +344,15 @@ private[connect] class ExecuteHolder(
/** Get the operation ID. */
def operationId: String = key.operationId
+
+ def getTerminationInfo: TerminationInfo = {
+ TerminationInfo(
+ userId = sessionHolder.userId,
+ sessionId = sessionHolder.sessionId,
+ operationId = executeKey.operationId,
+ status = eventsManager.status,
+ terminationReason = eventsManager.terminationReason)
+ }
}
private object ExecuteHolder {
@@ -406,9 +418,18 @@ case class ExecuteInfo(
sparkSessionTags: Set[String],
reattachable: Boolean,
status: ExecuteStatus,
+ terminationReason: Option[TerminationReason],
creationTimeNs: Long,
lastAttachedRpcTimeNs: Option[Long],
closedTimeNs: Option[Long]) {
def key: ExecuteKey = ExecuteKey(userId, sessionId, operationId)
}
+
+/** Minimal termination status information for inactive operations. */
+case class TerminationInfo(
+ userId: String,
+ sessionId: String,
+ operationId: String,
+ status: ExecuteStatus,
+ terminationReason: Option[TerminationReason])
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index d0d0f0ba750a..912543ac13dd 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -96,16 +96,15 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
// Set of active operation IDs for this session.
private val activeOperationIds: mutable.Set[String] = mutable.Set.empty
- // Cache of inactive operation IDs for this session, either completed, interrupted or abandoned.
- // The Boolean is just a placeholder since Guava needs a <K, V> pair.
- private lazy val inactiveOperationIds: Cache[String, Boolean] =
+ // Cache of inactive operations for this session, either completed, interrupted or abandoned.
+ private lazy val inactiveOperations: Cache[String, TerminationInfo] =
+ private lazy val inactiveOperations: Cache[String, TerminationInfo] =
CacheBuilder
.newBuilder()
.ticker(Ticker.systemTicker())
.expireAfterAccess(
SparkEnv.get.conf.get(Connect.CONNECT_INACTIVE_OPERATIONS_CACHE_EXPIRATION_MINS),
TimeUnit.MINUTES)
- .build[String, Boolean]()
+ .build[String, TerminationInfo]()
// The cache that maps an error id to a throwable. The throwable in cache is independent to
// each other.
@@ -197,7 +196,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
if (activeOperationIds.contains(operationId)) {
return Some(true)
}
- Option(inactiveOperationIds.getIfPresent(operationId)) match {
+ Option(inactiveOperations.getIfPresent(operationId)) match {
case Some(_) =>
return Some(false)
case None =>
@@ -206,13 +205,50 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
}
/**
- * Close an operation in this session by removing its operation ID.
+ * Returns the TerminationInfo for an inactive operation if it exists in the cache. Cache
+ * expiration is configured with CONNECT_INACTIVE_OPERATIONS_CACHE_EXPIRATION_MINS.
*
- * Called only by SparkConnectExecutionManager when an execution is ended.
+ * @param operationId
+ * @return
+ * Some(TerminationInfo) if the operation was closed recently, None if no inactive operation
+ * with this id is found.
+ */
+ private[service] def getInactiveOperationInfo(operationId: String): Option[TerminationInfo] = {
+ Option(inactiveOperations.getIfPresent(operationId))
+ }
+
+ /**
+ * Returns all inactive operations for this session. These are operations that were closed and
+ * are still in the cache. Cache expiration is configured with
+ * CONNECT_INACTIVE_OPERATIONS_CACHE_EXPIRATION_MINS.
+ *
+ * @return
+ * Sequence of TerminationInfo for inactive operations.
+ */
+ private[service] def listInactiveOperations(): Seq[TerminationInfo] = {
+ inactiveOperations.asMap().values().asScala.toSeq
+ }
+
+ /**
+ * Returns all active operation IDs for this session.
+ *
+ * @return
+ * Sequence of operation IDs that are currently active.
+ */
+ private[service] def listActiveOperationIds(): Seq[String] = {
+ activeOperationIds.synchronized {
+ activeOperationIds.toSeq
+ }
+ }
+
+ /**
+ * Close an operation in this session by storing its TerminationInfo and removing it from the
+ * active set.
*/
- private[service] def closeOperation(operationId: String): Unit = {
- inactiveOperationIds.put(operationId, true)
- activeOperationIds.remove(operationId)
+ private[service] def closeOperation(executeHolder: ExecuteHolder): Unit = {
+ val terminationInfo = executeHolder.getTerminationInfo
+ inactiveOperations.put(terminationInfo.operationId, terminationInfo)
+ activeOperationIds.remove(terminationInfo.operationId)
}
/**
@@ -249,7 +285,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
SparkConnectService.executionManager.getExecuteHolder(executeKey).foreach { executeHolder =>
if (executeHolder.sparkSessionTags.contains(tag)) {
if (executeHolder.interrupt()) {
- closeOperation(operationId)
+ closeOperation(executeHolder)
interruptedIds += operationId
}
}
@@ -268,7 +304,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
val executeKey = ExecuteKey(userId, sessionId, operationId)
SparkConnectService.executionManager.getExecuteHolder(executeKey).foreach { executeHolder =>
if (executeHolder.interrupt()) {
- closeOperation(operationId)
+ closeOperation(executeHolder)
interruptedIds += operationId
}
}
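With these pieces in place, the visibility of an operation's status through SessionHolder can
be sketched as follows (`opId` is a placeholder for the closed operation's ID):

```scala
// Sketch: closeOperation is invoked from ExecuteHolder.close and the
// execution manager; afterwards the termination info is still queryable.
sessionHolder.closeOperation(executeHolder)
assert(sessionHolder.getInactiveOperationInfo(opId).isDefined)
// After CONNECT_INACTIVE_OPERATIONS_CACHE_EXPIRATION_MINS without access, the
// cache entry expires and GetStatus reports OPERATION_STATE_UNKNOWN for opId.
```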
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index 5f01676d9f89..768c6a858188 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -157,12 +157,12 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
// getting an INVALID_HANDLE.OPERATION_ABANDONED error on a retry.
if (abandoned) {
abandonedTombstones.put(key, executeHolder.getExecuteInfo)
- executeHolder.sessionHolder.closeOperation(executeHolder.operationId)
+ executeHolder.sessionHolder.closeOperation(executeHolder)
}
// Remove the execution from the map *after* putting it in abandonedTombstones.
executions.remove(key)
- executeHolder.sessionHolder.closeOperation(executeHolder.operationId)
+ executeHolder.sessionHolder.closeOperation(executeHolder)
updateLastExecutionTime()
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectGetStatusHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectGetStatusHandler.scala
new file mode 100644
index 000000000000..c37061ff2eba
--- /dev/null
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectGetStatusHandler.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import scala.jdk.CollectionConverters._
+import scala.jdk.OptionConverters._
+import scala.util.control.NonFatal
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.{Logging, LogKeys}
+import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
+
class SparkConnectGetStatusHandler(responseObserver: StreamObserver[proto.GetStatusResponse])
+ extends Logging {
+
+ def handle(request: proto.GetStatusRequest): Unit = {
+ val previousSessionId = request.hasClientObservedServerSideSessionId match {
+ case true => Some(request.getClientObservedServerSideSessionId)
+ case false => None
+ }
+ val sessionHolder = SparkConnectService.sessionManager.getIsolatedSession(
+ SessionKey(request.getUserContext.getUserId, request.getSessionId),
+ previousSessionId)
+
+ val responseBuilder = proto.GetStatusResponse
+ .newBuilder()
+ .setSessionId(request.getSessionId)
+ .setServerSideSessionId(sessionHolder.serverSessionId)
+
+ val responseExtensions =
+ processRequestExtensionsViaPlugins(sessionHolder, request.getExtensionsList)
+ responseExtensions.foreach(responseBuilder.addExtensions)
+
+ if (request.hasOperationStatus) {
+ val operationStatusRequest = request.getOperationStatus
+ val requestedOperationIds =
+ operationStatusRequest.getOperationIdsList.asScala.distinct.toSeq
+ val operationExtensions = operationStatusRequest.getExtensionsList
+
+ val operationStatuses = if (requestedOperationIds.isEmpty) {
+ // If no specific operation IDs are requested,
+ // return status of all known operations in session
+ getAllOperationStatuses(sessionHolder, operationExtensions)
+ } else {
+ // Return status only for the requested operation IDs
+ requestedOperationIds.map { operationId =>
+ getOperationStatus(sessionHolder, operationId, operationExtensions)
+ }
+ }
+
+ operationStatuses.foreach(responseBuilder.addOperationStatuses)
+ }
+
+ responseObserver.onNext(responseBuilder.build())
+ responseObserver.onCompleted()
+ }
+
+ private def getOperationStatus(
+ sessionHolder: SessionHolder,
+ operationId: String,
+ operationExtensions: java.util.List[com.google.protobuf.Any])
+ : proto.GetStatusResponse.OperationStatus = {
+ val executeKey = ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, operationId)
+
+ // First look up the operation in the active list, then in the inactive one. This ordering
+ // handles the case where a concurrent thread moves the operation to inactive between the two
+ // lookups, so the operation is found in at least one of the lists.
+ val activeState: Option[proto.GetStatusResponse.OperationStatus.OperationState] =
+ SparkConnectService.executionManager
+ .getExecuteHolder(executeKey)
+ .map { executeHolder =>
+ val info = executeHolder.getExecuteInfo
+ mapStatusToState(info.operationId, info.status, info.terminationReason)
+ }
+
+ // Check inactiveOperations - this status prevails over activeState.
+ val state = sessionHolder
+ .getInactiveOperationInfo(operationId)
+ .map { info =>
+ mapStatusToState(info.operationId, info.status, info.terminationReason)
+ }
+ .orElse(activeState)
+ .getOrElse(proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_UNKNOWN)
+
+ val responseExtensions =
+ processOperationExtensionsViaPlugins(sessionHolder, operationExtensions, operationId)
+
+ buildOperationStatus(operationId, state, responseExtensions)
+ }
+
+ private def getAllOperationStatuses(
+ sessionHolder: SessionHolder,
+ operationExtensions: java.util.List[com.google.protobuf.Any])
+ : Seq[proto.GetStatusResponse.OperationStatus] = {
+ val allOperationIds =
+ (sessionHolder.listActiveOperationIds() ++
+ sessionHolder.listInactiveOperations().map(_.operationId)).distinct
+
+ allOperationIds.map { operationId =>
+ getOperationStatus(sessionHolder, operationId, operationExtensions)
+ }
+ }
+
+ private def mapStatusToState(
+ operationId: String,
+ status: ExecuteStatus,
+ terminationReason: Option[TerminationReason])
+ : proto.GetStatusResponse.OperationStatus.OperationState = {
+ status match {
+ case ExecuteStatus.Pending | ExecuteStatus.Started | ExecuteStatus.Analyzed |
+ ExecuteStatus.ReadyForExecution =>
+ proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_RUNNING
+
+ // Finished, Failed, Canceled are terminating states - resources haven't been cleaned yet
+ case ExecuteStatus.Finished | ExecuteStatus.Failed | ExecuteStatus.Canceled =>
+ proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_TERMINATING
+
+ case ExecuteStatus.Closed =>
+ if (terminationReason.isEmpty) {
+ // This should not happen: ExecuteEventsManager processes state transitions
+ // from a single thread at a time, so there are no concurrent changes and
+ // terminationReason should always be set before reaching Closed.
+ logError(
+ log"Operation ${MDC(LogKeys.OPERATION_ID, operationId)} is Closed but " +
+ log"terminationReason is not set. status=${MDC(LogKeys.STATUS, status)}")
+ }
+ mapTerminationReasonToState(terminationReason)
+ }
+ }
+
+ private def mapTerminationReasonToState(terminationReason: Option[TerminationReason])
+ : proto.GetStatusResponse.OperationStatus.OperationState = {
+ terminationReason match {
+ case Some(TerminationReason.Succeeded) =>
+ proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED
+ case Some(TerminationReason.Failed) =>
+ proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_FAILED
+ case Some(TerminationReason.Canceled) =>
+ proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_CANCELLED
+ case None =>
+ proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_UNKNOWN
+ }
+ }
+
+ private def buildOperationStatus(
+ operationId: String,
+ state: proto.GetStatusResponse.OperationStatus.OperationState,
+ extensions: Seq[com.google.protobuf.Any] = Seq.empty)
+ : proto.GetStatusResponse.OperationStatus = {
+ val builder = proto.GetStatusResponse.OperationStatus
+ .newBuilder()
+ .setOperationId(operationId)
+ .setState(state)
+ extensions.foreach(builder.addExtensions)
+ builder.build()
+ }
+
+ private def processRequestExtensionsViaPlugins(
+ sessionHolder: SessionHolder,
+ requestExtensions: java.util.List[com.google.protobuf.Any])
+ : Seq[com.google.protobuf.Any] = {
+ SparkConnectPluginRegistry.getStatusRegistry.flatMap { plugin =>
+ try {
+ plugin.processRequestExtensions(sessionHolder, requestExtensions).toScala match {
+ case Some(extensions) => extensions.asScala.toSeq
+ case None => Seq.empty
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(
+ log"Plugin ${MDC(LogKeys.CLASS_NAME, plugin.getClass.getName)}
failed to process " +
+ log"request extensions",
+ e)
+ Seq.empty
+ }
+ }
+ }
+
+ private def processOperationExtensionsViaPlugins(
+ sessionHolder: SessionHolder,
+ operationExtensions: java.util.List[com.google.protobuf.Any],
+ operationId: String): Seq[com.google.protobuf.Any] = {
+ SparkConnectPluginRegistry.getStatusRegistry.flatMap { plugin =>
+ try {
+ plugin
.processOperationExtensions(operationId, sessionHolder, operationExtensions)
+ .toScala match {
+ case Some(extensions) => extensions.asScala.toSeq
+ case None => Seq.empty
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(
+ log"Plugin ${MDC(LogKeys.CLASS_NAME, plugin.getClass.getName)}
failed to process " +
+ log"operation extensions for operation
${MDC(LogKeys.OPERATION_ID, operationId)}",
+ e)
+ Seq.empty
+ }
+ }
+ }
+}
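The test suite below drives this handler directly, without a gRPC channel; in sketch form:

```scala
// Mirrors GetStatusHandlerSuite: call the handler with a stub observer and
// wait for the single response it emits.
val observer = new GetStatusResponseObserver() // test observer defined below
new SparkConnectGetStatusHandler(observer).handle(request)
val response = ThreadUtils.awaitResult(observer.promise.future, 10.seconds)
```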
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 00b93c19b2c7..c14c21bd6ccb 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -246,6 +246,19 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
}
}
+ override def getStatus(
+ request: proto.GetStatusRequest,
+ responseObserver: StreamObserver[proto.GetStatusResponse]): Unit = {
+ try {
+ new SparkConnectGetStatusHandler(responseObserver).handle(request)
+ } catch
+ ErrorUtils.handleError(
+ "getStatus",
+ observer = responseObserver,
+ userId = request.getUserContext.getUserId,
+ sessionId = request.getSessionId)
+ }
+
private def methodWithCustomMarshallers(
methodDesc: MethodDescriptor[Message, Message]): MethodDescriptor[Message, Message] = {
val recursionLimit =
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala
index d06c93cc1cad..d6a6390600da 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala
@@ -18,8 +18,9 @@ package org.apache.spark.sql.connect
import java.util.UUID
+import org.apache.spark.connect.proto
import org.apache.spark.sql.classic.SparkSession
-import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
+import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, SessionHolder, SessionStatus, SparkConnectService}
object SparkConnectTestUtils {
@@ -33,4 +34,29 @@ object SparkConnectTestUtils {
SparkConnectService.sessionManager.putSessionForTesting(ret)
ret
}
+
+ /** Creates a dummy execute holder for use in tests. */
+ def createDummyExecuteHolder(
+ sessionHolder: SessionHolder,
+ command: proto.Command): ExecuteHolder = {
+ sessionHolder.eventManager.status_(SessionStatus.Started)
+ val request = proto.ExecutePlanRequest
+ .newBuilder()
+ .setPlan(
+ proto.Plan
+ .newBuilder()
+ .setCommand(command)
+ .build())
+ .setSessionId(sessionHolder.sessionId)
+ .setUserContext(
+ proto.UserContext
+ .newBuilder()
+ .setUserId(sessionHolder.userId)
+ .build())
+ .build()
+ val executeHolder =
+ SparkConnectService.executionManager.createExecuteHolder(request)
+ executeHolder.eventsManager.status_(ExecuteStatus.Started)
+ executeHolder
+ }
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
index 32dbc9595eab..4617a8684224 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.connect.ConnectProtoUtils
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.{SparkConnectPlanner, SparkConnectPlanTest}
+import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.test.SharedSparkSession
class DummyPlugin extends RelationPlugin {
@@ -45,6 +46,19 @@ class DummyExpressionPlugin extends ExpressionPlugin {
planner: SparkConnectPlanner): Optional[Expression] = Optional.empty()
}
+class DummyGetStatusPlugin extends GetStatusPlugin {
+ override def processRequestExtensions(
+ sessionHolder: SessionHolder,
+ requestExtensions: java.util.List[protobuf.Any]): Optional[java.util.List[protobuf.Any]] =
+ Optional.empty()
+
+ override def processOperationExtensions(
+ operationId: String,
+ sessionHolder: SessionHolder,
+ operationExtensions: java.util.List[protobuf.Any]): Optional[java.util.List[protobuf.Any]] =
+ Optional.empty()
+}
+
class DummyPluginNoTrivialCtor(id: Int) extends RelationPlugin {
override def transform(
relation: Array[Byte],
@@ -117,6 +131,9 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
if (SparkEnv.get.conf.contains(Connect.CONNECT_EXTENSIONS_COMMAND_CLASSES)) {
SparkEnv.get.conf.remove(Connect.CONNECT_EXTENSIONS_COMMAND_CLASSES)
}
+ if (SparkEnv.get.conf.contains(Connect.CONNECT_EXTENSIONS_GET_STATUS_CLASSES)) {
+ SparkEnv.get.conf.remove(Connect.CONNECT_EXTENSIONS_GET_STATUS_CLASSES)
+ }
SparkConnectPluginRegistry.reset()
}
@@ -239,6 +256,31 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
assert(SparkConnectPluginRegistry.loadCommandPlugins().isEmpty)
}
+ test("GetStatus registry is empty by default") {
+ assert(SparkConnectPluginRegistry.loadGetStatusPlugins().isEmpty)
+ }
+
+ test("GetStatus plugin loaded dynamically from config") {
+ withSparkConf(
+ Connect.CONNECT_EXTENSIONS_GET_STATUS_CLASSES.key ->
+ "org.apache.spark.sql.connect.plugin.DummyGetStatusPlugin") {
+ val plugins = SparkConnectPluginRegistry.loadGetStatusPlugins()
+ assert(plugins.size == 1)
+ assert(plugins.head.isInstanceOf[GetStatusPlugin])
+ }
+ }
+
+ test("Multiple GetStatus plugins loaded dynamically from config") {
+ withSparkConf(
+ Connect.CONNECT_EXTENSIONS_GET_STATUS_CLASSES.key ->
+ ("org.apache.spark.sql.connect.plugin.DummyGetStatusPlugin," +
+ "org.apache.spark.sql.connect.plugin.DummyGetStatusPlugin")) {
+ val plugins = SparkConnectPluginRegistry.loadGetStatusPlugins()
+ assert(plugins.size == 2)
+ plugins.foreach(p => assert(p.isInstanceOf[GetStatusPlugin]))
+ }
+ }
+
test("Building builders using factory methods") {
val x = SparkConnectPluginRegistry.relation[DummyPlugin](classOf[DummyPlugin])
assert(x != null)
@@ -247,6 +289,10 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
SparkConnectPluginRegistry.expression[DummyExpressionPlugin](classOf[DummyExpressionPlugin])
assert(y != null)
assert(y().isInstanceOf[ExpressionPlugin])
+ val z =
+ SparkConnectPluginRegistry.getStatus[DummyGetStatusPlugin](classOf[DummyGetStatusPlugin])
+ assert(z != null)
+ assert(z().isInstanceOf[GetStatusPlugin])
}
test("Configured class not found is properly thrown") {
@@ -265,6 +311,14 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
SparkEnv.get.conf.get(Connect.CONNECT_EXTENSIONS_RELATION_CLASSES))
}
}
+
+ withSparkConf(
+ Connect.CONNECT_EXTENSIONS_GET_STATUS_CLASSES.key -> "this.class.does.not.exist") {
+ assertThrows[ClassNotFoundException] {
+ SparkConnectPluginRegistry.createConfiguredPlugins(
+ SparkEnv.get.conf.get(Connect.CONNECT_EXTENSIONS_GET_STATUS_CLASSES))
+ }
+ }
}
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
index a17c76ae9528..a96d0ab977c5 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
@@ -341,6 +341,64 @@ class ExecuteEventsManagerSuite
}
}
+ test("terminationReason is None initially") {
+ val events = setupEvents(ExecuteStatus.Pending)
+ assert(events.terminationReason.isEmpty)
+ }
+
+ test("terminationReason is set to Succeeded after postFinished") {
+ val events = setupEvents(ExecuteStatus.Started)
+ assert(events.terminationReason.isEmpty)
+ events.postFinished()
+ assert(events.terminationReason.contains(TerminationReason.Succeeded))
+ }
+
+ test("terminationReason is set to Failed after postFailed") {
+ val events = setupEvents(ExecuteStatus.Started)
+ assert(events.terminationReason.isEmpty)
+ events.postFailed(DEFAULT_ERROR)
+ assert(events.terminationReason.contains(TerminationReason.Failed))
+ }
+
+ test("terminationReason is set to Canceled after postCanceled") {
+ val events = setupEvents(ExecuteStatus.Started)
+ assert(events.terminationReason.isEmpty)
+ events.postCanceled()
+ assert(events.terminationReason.contains(TerminationReason.Canceled))
+ }
+
+ test("terminationReason remains unchanged after postClosed") {
+ val events = setupEvents(ExecuteStatus.Started)
+ events.postFinished()
+ assert(events.terminationReason.contains(TerminationReason.Succeeded))
+ events.postClosed()
+ assert(events.terminationReason.contains(TerminationReason.Succeeded))
+ }
+
+ test("terminationReason: Canceled takes precedence over Succeeded") {
+ val events = setupEvents(ExecuteStatus.Started)
+ events.postFinished()
+ assert(events.terminationReason.contains(TerminationReason.Succeeded))
+ events.postCanceled()
+ assert(events.terminationReason.contains(TerminationReason.Canceled))
+ }
+
+ test("terminationReason: Canceled takes precedence over Failed") {
+ val events = setupEvents(ExecuteStatus.Started)
+ events.postFailed(DEFAULT_ERROR)
+ assert(events.terminationReason.contains(TerminationReason.Failed))
+ events.postCanceled()
+ assert(events.terminationReason.contains(TerminationReason.Canceled))
+ }
+
+ test("terminationReason: Failed takes precedence over Succeeded") {
+ val events = setupEvents(ExecuteStatus.Started)
+ events.postFinished()
+ assert(events.terminationReason.contains(TerminationReason.Succeeded))
+ events.postFailed(DEFAULT_ERROR)
+ assert(events.terminationReason.contains(TerminationReason.Failed))
+ }
+
def setupEvents(
executeStatus: ExecuteStatus,
sessionStatus: SessionStatus = SessionStatus.Started): ExecuteEventsManager = {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/GetStatusHandlerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/GetStatusHandlerSuite.scala
new file mode 100644
index 000000000000..21d96fd8a87d
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/GetStatusHandlerSuite.scala
@@ -0,0 +1,354 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import java.util
+import java.util.{Optional, UUID}
+
+import scala.concurrent.Promise
+import scala.concurrent.duration._
+import scala.jdk.CollectionConverters._
+
+import com.google.protobuf
+import com.google.protobuf.StringValue
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.GetStatusResponse
+import org.apache.spark.sql.connect.SparkConnectTestUtils
+import org.apache.spark.sql.connect.plugin.{GetStatusPlugin, SparkConnectPluginRegistry}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A test-only base class for GetStatusPlugins that echo extensions back with configurable
+ * prefixes. For each input extension containing a StringValue, it produces a response extension
+ * with the value prefixed by "{requestPrefix}" or "{opPrefix}{operationId}:".
+ */
+abstract class EchoGetStatusPluginBase(requestPrefix: String, opPrefix: String)
+ extends GetStatusPlugin {
+
+ override def processRequestExtensions(
+ sessionHolder: SessionHolder,
+ requestExtensions: util.List[protobuf.Any]): Optional[util.List[protobuf.Any]] =
+ echoWithPrefix(requestExtensions, requestPrefix)
+
+ override def processOperationExtensions(
+ operationId: String,
+ sessionHolder: SessionHolder,
+ operationExtensions: util.List[protobuf.Any]): Optional[util.List[protobuf.Any]] =
+ echoWithPrefix(operationExtensions, s"$opPrefix$operationId:")
+
+ private def echoWithPrefix(
+ extensions: util.List[protobuf.Any],
+ prefix: String): Optional[util.List[protobuf.Any]] = {
+ if (extensions.isEmpty) return Optional.empty()
+ val result = new util.ArrayList[protobuf.Any]()
+ extensions.forEach { ext =>
+ if (ext.is(classOf[StringValue])) {
+ val value = ext.unpack(classOf[StringValue]).getValue
+ result.add(protobuf.Any.pack(StringValue.of(s"$prefix$value")))
+ }
+ }
+ Optional.of(result)
+ }
+}
+
+class EchoGetStatusPlugin extends EchoGetStatusPluginBase("request-echo:", "op-echo:")
+
+class SecondEchoGetStatusPlugin extends EchoGetStatusPluginBase("second-request:", "second-op:")
+
+/**
+ * A no-op plugin that always returns Optional.empty() for both request and operation extensions.
+ */
+class NoOpGetStatusPlugin extends GetStatusPlugin {
+ override def processRequestExtensions(
+ sessionHolder: SessionHolder,
+ requestExtensions: util.List[protobuf.Any]): Optional[util.List[protobuf.Any]] =
+ Optional.empty()
+
+ override def processOperationExtensions(
+ operationId: String,
+ sessionHolder: SessionHolder,
+ operationExtensions: util.List[protobuf.Any]): Optional[util.List[protobuf.Any]] =
+ Optional.empty()
+}
+
+/**
+ * A plugin that always throws a RuntimeException.
+ */
+class FailingGetStatusPlugin extends GetStatusPlugin {
+ override def processRequestExtensions(
+ sessionHolder: SessionHolder,
+ requestExtensions: util.List[protobuf.Any]): Optional[util.List[protobuf.Any]] =
+ throw new RuntimeException("request plugin failure")
+
+ override def processOperationExtensions(
+ operationId: String,
+ sessionHolder: SessionHolder,
+ operationExtensions: util.List[protobuf.Any]): Optional[util.List[protobuf.Any]] =
+ throw new RuntimeException("operation plugin failure")
+}
+
+class GetStatusHandlerSuite extends SharedSparkSession {
+
+ // Default userId matching SparkConnectTestUtils.createDummySessionHolder default
+ private val defaultUserId = "testUser"
+
+ protected override def afterEach(): Unit = {
+ super.afterEach()
+ SparkConnectService.sessionManager.invalidateAllSessions()
+ SparkConnectPluginRegistry.reset()
+ }
+
+ private def sendGetStatusRequest(
+ sessionId: String,
+ userId: String = defaultUserId,
+ serverSideSessionId: Option[String] = None,
+ requestExtensions: Seq[protobuf.Any] = Seq.empty,
+ customize: proto.GetStatusRequest.Builder => Unit = _ => ()): GetStatusResponse = {
+ val userContext = proto.UserContext.newBuilder().setUserId(userId).build()
+ val requestBuilder = proto.GetStatusRequest
+ .newBuilder()
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+
+ requestExtensions.foreach(requestBuilder.addExtensions)
+ serverSideSessionId.foreach(requestBuilder.setClientObservedServerSideSessionId)
+ customize(requestBuilder)
+
+ val request = requestBuilder.build()
+ val responseObserver = new GetStatusResponseObserver()
+ val handler = new SparkConnectGetStatusHandler(responseObserver)
+ handler.handle(request)
+
+ ThreadUtils.awaitResult(responseObserver.promise.future, 10.seconds)
+ }
+
+ private def sendGetOperationStatusRequest(
+ sessionId: String,
+ operationIds: Seq[String] = Seq.empty,
+ userId: String = defaultUserId,
+ serverSideSessionId: Option[String] = None,
+ requestExtensions: Seq[protobuf.Any] = Seq.empty,
+ operationExtensions: Seq[protobuf.Any] = Seq.empty): GetStatusResponse = {
+ sendGetStatusRequest(
+ sessionId,
+ userId,
+ serverSideSessionId,
+ requestExtensions,
+ { builder =>
+ val operationStatusRequest = proto.GetStatusRequest.OperationStatusRequest.newBuilder()
+ operationIds.foreach(operationStatusRequest.addOperationIds)
+ operationExtensions.foreach(operationStatusRequest.addExtensions)
+ builder.setOperationStatus(operationStatusRequest)
+ })
+ }
+
+ test("GetStatus returns session info for new session") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val response = sendGetStatusRequest(sessionHolder.sessionId, userId = sessionHolder.userId)
+
+ assert(response.getSessionId == sessionHolder.sessionId)
+ assert(response.getServerSideSessionId.nonEmpty)
+ }
+
+ test("GetStatus without operation IDs returns all existing operations") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder1 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val executeHolder2 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val executeHolder3 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+
+ val response =
+ sendGetOperationStatusRequest(sessionHolder.sessionId, userId = sessionHolder.userId)
+
+ val statuses = response.getOperationStatusesList.asScala
+ assert(statuses.size == 3)
+ val operationIds = statuses.map(_.getOperationId).toSet
+ assert(operationIds.contains(executeHolder1.operationId))
+ assert(operationIds.contains(executeHolder2.operationId))
+ assert(operationIds.contains(executeHolder3.operationId))
+ }
+
+ test("GetStatus returns UNKNOWN status for non-existent operation ID") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val nonExistentOperationId = UUID.randomUUID().toString
+
+ val response = sendGetOperationStatusRequest(
+ sessionHolder.sessionId,
+ Seq(nonExistentOperationId),
+ userId = sessionHolder.userId)
+
+ val statuses = response.getOperationStatusesList.asScala
+ assert(statuses.size == 1)
+ assert(statuses.head.getOperationId == nonExistentOperationId)
+ assert(
+ statuses.head.getState ==
+ proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_UNKNOWN)
+ }
+
+ test("GetStatus returns RUNNING status for active operation") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+
+ // The execute holder is created with Started status by createDummyExecuteHolder
+ val response = sendGetOperationStatusRequest(
+ sessionHolder.sessionId,
+ Seq(executeHolder.operationId),
+ sessionHolder.userId)
+
+ val statuses = response.getOperationStatusesList.asScala
+ assert(statuses.size == 1)
+ assert(statuses.head.getOperationId == executeHolder.operationId)
+ assert(
+ statuses.head.getState ==
+ proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_RUNNING)
+ }
+
+ test("GetStatus propagates both request and operation extensions via plugin") {
+ SparkConnectPluginRegistry.setGetStatusPluginsForTesting(Seq(new EchoGetStatusPlugin()))
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder1 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val executeHolder2 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+
+ val reqExt = protobuf.Any.pack(StringValue.of("req-data"))
+ val opExt = protobuf.Any.pack(StringValue.of("op-data"))
+ val response = sendGetOperationStatusRequest(
+ sessionHolder.sessionId,
+ operationIds = Seq(executeHolder1.operationId, executeHolder2.operationId),
+ userId = sessionHolder.userId,
+ requestExtensions = Seq(reqExt),
+ operationExtensions = Seq(opExt))
+
+ // Verify request-level extensions
+ val responseExtensions = response.getExtensionsList.asScala
+ assert(responseExtensions.size == 1)
+ assert(
+ responseExtensions.head.unpack(classOf[StringValue]).getValue == "request-echo:req-data")
+
+ // Verify operation-level extensions for both operations
+ val statuses = response.getOperationStatusesList.asScala
+ assert(statuses.size == 2)
+
+ val statusByOpId = statuses.map(s => s.getOperationId -> s).toMap
+ Seq(executeHolder1, executeHolder2).foreach { holder =>
+ val opStatus = statusByOpId(holder.operationId)
+ assert(
+ opStatus.getState ==
+ proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_RUNNING)
+
+ val opExtensions = opStatus.getExtensionsList.asScala
+ assert(opExtensions.size == 1)
+ assert(
+ opExtensions.head.unpack(classOf[StringValue]).getValue ==
+ s"op-echo:${holder.operationId}:op-data")
+ }
+ }
+
+ test("GetStatus with no plugin returns no extensions") {
+ SparkConnectPluginRegistry.setGetStatusPluginsForTesting(Seq.empty)
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val operationId = UUID.randomUUID().toString
+
+ val reqExt = protobuf.Any.pack(StringValue.of("ignored"))
+ val opExt = protobuf.Any.pack(StringValue.of("also-ignored"))
+ val response = sendGetOperationStatusRequest(
+ sessionHolder.sessionId,
+ operationIds = Seq(operationId),
+ userId = sessionHolder.userId,
+ requestExtensions = Seq(reqExt),
+ operationExtensions = Seq(opExt))
+
+ assert(response.getExtensionsList.isEmpty)
+ val statuses = response.getOperationStatusesList.asScala
+ assert(statuses.size == 1)
+ assert(statuses.head.getExtensionsList.isEmpty)
+ }
+
+ test("GetStatus aggregates extensions from multiple plugins, skipping empty ones") {
+ SparkConnectPluginRegistry.setGetStatusPluginsForTesting(
+ Seq(new EchoGetStatusPlugin(), new NoOpGetStatusPlugin(), new SecondEchoGetStatusPlugin()))
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+
+ val reqExt = protobuf.Any.pack(StringValue.of("hello"))
+ val opExt = protobuf.Any.pack(StringValue.of("world"))
+ val response = sendGetOperationStatusRequest(
+ sessionHolder.sessionId,
+ operationIds = Seq(executeHolder.operationId),
+ userId = sessionHolder.userId,
+ requestExtensions = Seq(reqExt),
+ operationExtensions = Seq(opExt))
+
+ val responseExtValues = response.getExtensionsList.asScala
+ .map(_.unpack(classOf[StringValue]).getValue)
+ assert(responseExtValues.size == 2)
+ assert(responseExtValues.contains("request-echo:hello"))
+ assert(responseExtValues.contains("second-request:hello"))
+
+ val statuses = response.getOperationStatusesList.asScala
+ assert(statuses.size == 1)
+ val opExtValues = statuses.head.getExtensionsList.asScala
+ .map(_.unpack(classOf[StringValue]).getValue)
+ assert(opExtValues.size == 2)
+ assert(opExtValues.contains(s"op-echo:${executeHolder.operationId}:world"))
+ assert(opExtValues.contains(s"second-op:${executeHolder.operationId}:world"))
+ }
+
+ test("GetStatus chain isolates plugin failures and collects from healthy plugins") {
+ SparkConnectPluginRegistry.setGetStatusPluginsForTesting(
+ Seq(
+ new EchoGetStatusPlugin(),
+ new FailingGetStatusPlugin(),
+ new SecondEchoGetStatusPlugin()))
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+
+ val reqExt = protobuf.Any.pack(StringValue.of("safe"))
+ val opExt = protobuf.Any.pack(StringValue.of("safe"))
+ val response = sendGetOperationStatusRequest(
+ sessionHolder.sessionId,
+ operationIds = Seq(executeHolder.operationId),
+ userId = sessionHolder.userId,
+ requestExtensions = Seq(reqExt),
+ operationExtensions = Seq(opExt))
+
+ val responseExtValues = response.getExtensionsList.asScala
+ .map(_.unpack(classOf[StringValue]).getValue)
+ assert(responseExtValues.size == 2)
+ assert(responseExtValues.contains("request-echo:safe"))
+ assert(responseExtValues.contains("second-request:safe"))
+
+ val opExtValues = response.getOperationStatusesList.asScala.head.getExtensionsList.asScala
+ .map(_.unpack(classOf[StringValue]).getValue)
+ assert(opExtValues.size == 2)
+ assert(opExtValues.contains(s"op-echo:${executeHolder.operationId}:safe"))
+ assert(opExtValues.contains(s"second-op:${executeHolder.operationId}:safe"))
+ }
+}
+
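+// Test helper: a StreamObserver that surfaces the single GetStatusResponse (or the RPC
+// failure) through a Promise, so tests can synchronously await the handler's result.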
+private class GetStatusResponseObserver extends StreamObserver[proto.GetStatusResponse] {
+ val promise: Promise[GetStatusResponse] = Promise()
+ override def onNext(value: proto.GetStatusResponse): Unit = promise.success(value)
+ override def onError(t: Throwable): Unit = promise.failure(t)
+ override def onCompleted(): Unit = {}
+}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManagerSuite.scala
new file mode 100644
index 000000000000..228bb6e83a98
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManagerSuite.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.SparkConnectTestUtils
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Test suite for SparkConnectExecutionManager, covering the tombstones kept for abandoned
+ * executions and the per-session inactiveOperations cache that backs the GetStatus API.
+ */
+class SparkConnectExecutionManagerSuite extends SharedSparkSession {
+
+ protected override def afterEach(): Unit = {
+ super.afterEach()
+ SparkConnectService.sessionManager.invalidateAllSessions()
+ }
+
+ private def executionManager: SparkConnectExecutionManager = {
+ SparkConnectService.executionManager
+ }
+
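+ // A "tombstone" below is the minimal record the execution manager keeps for an abandoned
+ // execution after its ExecuteHolder is removed, so GetStatus can still report its fate.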
+ test("tombstone is updated with Closed status after removeExecuteHolder with
abandoned") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val executeKey = executeHolder.key
+
+ executionManager.removeExecuteHolder(executeKey, abandoned = true)
+
+ val tombstoneInfo = executionManager.getAbandonedTombstone(executeKey)
+ assert(tombstoneInfo.isDefined, "Tombstone should exist for abandoned operation")
+
+ val info = tombstoneInfo.get
+ assert(
+ info.status == ExecuteStatus.Closed,
+ s"Expected Closed status in tombstone, got ${info.status}")
+ assert(info.closedTimeNs.isDefined, "closedTimeNs should be set after close()")
+ assert(info.closedTimeNs.get > 0, "closedTimeNs should be > 0")
+ }
+
+ test("normal execution removal does not create tombstone") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val executeKey = executeHolder.key
+
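+ // Plain removal (without the abandoned flag) models an execution that completed normally.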
+ executionManager.removeExecuteHolder(executeKey)
+
+ val tombstoneInfo = executionManager.getAbandonedTombstone(executeKey)
+ assert(tombstoneInfo.isEmpty, "Tombstone should not exist for normal (non-abandoned) removal")
+ }
+
+ test("inactiveOperations cache has correct state after abandoned removal") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val operationId = executeHolder.operationId
+
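+ // Abandoning the execution cancels it first; that cancellation is what populates
+ // terminationReason before closeOperation records the entry in the inactive cache.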
+ executionManager.removeExecuteHolder(executeHolder.key, abandoned = true)
+
+ val inactiveInfo = sessionHolder.getInactiveOperationInfo(operationId)
+ assert(inactiveInfo.isDefined, "Operation should be in inactive operations cache")
+
+ val info = inactiveInfo.get
+ assert(
+ info.status == ExecuteStatus.Closed,
+ s"Expected Closed status in inactive cache, got ${info.status}")
+ assert(
+ info.terminationReason.isDefined,
+ "terminationReason should be set by postCanceled and captured by
closeOperation")
+ assert(
+ info.terminationReason.get == TerminationReason.Canceled,
+ s"Expected Canceled terminationReason for abandoned, got
${info.terminationReason}")
+ }
+
+ test("inactiveOperations cache has correct state after normal removal") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val operationId = executeHolder.operationId
+
+ assert(
+ sessionHolder.getOperationStatus(operationId).contains(true),
+ "Operation should be active before removal")
+ assert(
+ sessionHolder.getInactiveOperationInfo(operationId).isEmpty,
+ "Operation should not be in inactive cache before removal")
+
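+ // GetStatus derives operation state from these two per-session lists, so removal must
+ // move the operation from activeOperations into the inactiveOperations cache.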
+ executionManager.removeExecuteHolder(executeHolder.key)
+
+ assert(
+ sessionHolder.getOperationStatus(operationId).contains(false),
+ "Operation should be inactive after removal")
+ val inactiveInfo = sessionHolder.getInactiveOperationInfo(operationId)
+ assert(inactiveInfo.isDefined, "Operation should be in inactive cache after removal")
+
+ val info = inactiveInfo.get
+ assert(info.operationId == operationId, "Operation ID should match")
+ assert(
+ info.status == ExecuteStatus.Closed,
+ s"Expected Closed status in inactive cache, got ${info.status}")
+ }
+}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 1b747705e9ad..17402ab5ddb4 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -431,6 +431,94 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
assert(ex.getMessage.contains("already exists"))
}
+ test("getInactiveOperationInfo returns TerminationInfo for closed
operations") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val operationId = executeHolder.operationId
+
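+ // closeOperation moves the execution from the session's active operations into the
+ // inactiveOperations cache that getInactiveOperationInfo reads.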
+ sessionHolder.closeOperation(executeHolder)
+
+ val inactiveInfo = sessionHolder.getInactiveOperationInfo(operationId)
+ assert(inactiveInfo.isDefined)
+ assert(inactiveInfo.get.operationId == operationId)
+ }
+
+ test("getInactiveOperationInfo returns None for active operations") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val command = proto.Command.newBuilder().build()
+ val executeHolder = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val operationId = executeHolder.operationId
+
+ assert(sessionHolder.getInactiveOperationInfo(operationId).isEmpty)
+ }
+
+ test("getInactiveOperationInfo returns None for unknown operations") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ assert(sessionHolder.getInactiveOperationInfo("unknown-op").isEmpty)
+ }
+
+ test("listInactiveOperations returns all closed operations") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+
+ val command = proto.Command.newBuilder().build()
+ val executeHolder1 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val executeHolder2 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val executeHolder3 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+
+ sessionHolder.closeOperation(executeHolder1)
+ sessionHolder.closeOperation(executeHolder2)
+ sessionHolder.closeOperation(executeHolder3)
+
+ val inactiveOps = sessionHolder.listInactiveOperations()
+ assert(inactiveOps.size == 3)
+ val inactiveOpIds = inactiveOps.map(_.operationId).toSet
+ assert(inactiveOpIds.contains(executeHolder1.operationId))
+ assert(inactiveOpIds.contains(executeHolder2.operationId))
+ assert(inactiveOpIds.contains(executeHolder3.operationId))
+ }
+
+ test("listInactiveOperations returns empty for new session") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ assert(sessionHolder.listInactiveOperations().isEmpty)
+ }
+
+ test("listActiveOperationIds returns all active operations") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+
+ val command = proto.Command.newBuilder().build()
+ val executeHolder1 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val executeHolder2 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val executeHolder3 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+
+ val activeOps = sessionHolder.listActiveOperationIds()
+ assert(activeOps.size == 3)
+ val activeOpIds = activeOps.toSet
+ assert(activeOpIds.contains(executeHolder1.operationId))
+ assert(activeOpIds.contains(executeHolder2.operationId))
+ assert(activeOpIds.contains(executeHolder3.operationId))
+ }
+
+ test("listActiveOperationIds returns empty for new session") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ assert(sessionHolder.listActiveOperationIds().isEmpty)
+ }
+
+ test("listActiveOperationIds excludes closed operations") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+
+ val command = proto.Command.newBuilder().build()
+ val executeHolder1 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+ val executeHolder2 = SparkConnectTestUtils.createDummyExecuteHolder(sessionHolder, command)
+
+ sessionHolder.closeOperation(executeHolder1)
+
+ val activeOps = sessionHolder.listActiveOperationIds()
+ assert(activeOps.size == 1)
+ assert(activeOps.contains(executeHolder2.operationId))
+ assert(!activeOps.contains(executeHolder1.operationId))
+ }
+
test("Pipeline execution cache") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
val graphId = "test_graph"