This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 59e291d36c4a [SPARK-45680][CONNECT] Release session
59e291d36c4a is described below
commit 59e291d36c4a9d956b993968a324359b3d75fe5f
Author: Juliusz Sompolski <[email protected]>
AuthorDate: Thu Nov 2 09:11:48 2023 +0900
[SPARK-45680][CONNECT] Release session
### What changes were proposed in this pull request?
Introduce a new `ReleaseSession` Spark Connect RPC, which cancels
everything running in the session and removes the session server side. Refactor
code around managing the cache of sessions into `SparkConnectSessionManager`.
### Why are the changes needed?
Better session management.
### Does this PR introduce _any_ user-facing change?
Not really. `SparkSession.stop()` API already existed on the client side.
It was closing the client's network connection, but the Session was still there
cached for 1 hour on the server side.
Caveats, which were not really supported user behaviour:
* After `session.stop()`, user could have created a new session with the
same session_id in Configuration. That session would be a new session on the
client side, but connect to the old cached session in the server. It could
therefore e.g. access that old session's state like views or artifacts.
* If a session timed out and was removed in the server, it used to be that
a new request would re-create the session. The client would then see this as
the old session, but the server would see a new one, and e.g. not have access
to old session state that was removed.
* User is no longer allowed to create a new session with the same
session_id as before.
### How was this patch tested?
Tests added.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43546 from juliuszsompolski/release-session.
Lead-authored-by: Juliusz Sompolski <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../src/main/resources/error/error-classes.json | 5 +
.../scala/org/apache/spark/sql/SparkSession.scala | 8 +
.../apache/spark/sql/PlanGenerationTestSuite.scala | 4 +-
.../org/apache/spark/sql/SparkSessionSuite.scala | 38 +++--
.../src/main/protobuf/spark/connect/base.proto | 30 ++++
.../client/CustomSparkConnectBlockingStub.scala | 11 ++
.../sql/connect/client/SparkConnectClient.scala | 10 ++
.../apache/spark/sql/connect/config/Connect.scala | 18 +++
.../spark/sql/connect/service/SessionHolder.scala | 79 ++++++++-
.../service/SparkConnectExecutionManager.scala | 23 ++-
.../SparkConnectReleaseExecuteHandler.scala | 4 +-
.../SparkConnectReleaseSessionHandler.scala | 40 +++++
.../sql/connect/service/SparkConnectService.scala | 117 +++-----------
.../service/SparkConnectSessionManager.scala | 177 +++++++++++++++++++++
.../spark/sql/connect/utils/ErrorUtils.scala | 27 ++--
.../spark/sql/connect/SparkConnectServerTest.scala | 21 ++-
.../execution/ReattachableExecuteSuite.scala | 4 +
.../connect/planner/SparkConnectServiceSuite.scala | 4 +-
.../service/SparkConnectServiceE2ESuite.scala | 158 ++++++++++++++++++
...-error-conditions-invalid-handle-error-class.md | 4 +
python/pyspark/sql/connect/client/core.py | 23 ++-
python/pyspark/sql/connect/proto/base_pb2.py | 42 ++---
python/pyspark/sql/connect/proto/base_pb2.pyi | 78 +++++++++
python/pyspark/sql/connect/proto/base_pb2_grpc.py | 49 ++++++
python/pyspark/sql/connect/session.py | 12 +-
.../sql/tests/connect/test_connect_basic.py | 1 +
26 files changed, 819 insertions(+), 168 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json
b/common/utils/src/main/resources/error/error-classes.json
index 278011b8cc8f..af32bcf129c0 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -1737,6 +1737,11 @@
"Session already exists."
]
},
+ "SESSION_CLOSED" : {
+ "message" : [
+ "Session was closed."
+ ]
+ },
"SESSION_NOT_FOUND" : {
"message" : [
"Session not found."
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 969ac017ecb1..1cc1c8400fa8 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -665,6 +665,9 @@ class SparkSession private[sql] (
* @since 3.4.0
*/
override def close(): Unit = {
+ if (releaseSessionOnClose) {
+ client.releaseSession()
+ }
client.shutdown()
allocator.close()
SparkSession.onSessionClose(this)
@@ -735,6 +738,11 @@ class SparkSession private[sql] (
* We null out the instance for now.
*/
private def writeReplace(): Any = null
+
+ /**
+ * Set to false to prevent client.releaseSession on close() (testing only)
+ */
+ private[sql] var releaseSessionOnClose = true
}
// The minimal builder needed to create a spark session.
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index cf287088b59f..5cc63bc45a04 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -120,7 +120,9 @@ class PlanGenerationTestSuite
}
override protected def afterAll(): Unit = {
- session.close()
+ // Don't call client.releaseSession on close(), because the connection
details are dummy.
+ session.releaseSessionOnClose = false
+ session.stop()
if (cleanOrphanedGoldenFiles) {
cleanOrphanedGoldenFile()
}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
index 4c858262c6ef..8abc41639fdd 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
@@ -33,18 +33,24 @@ class SparkSessionSuite extends ConnectFunSuite {
private val connectionString2: String = "sc://test.me:14099"
private val connectionString3: String = "sc://doit:16845"
+ private def closeSession(session: SparkSession): Unit = {
+ // Don't call client.releaseSession on close(), because the connection
details are dummy.
+ session.releaseSessionOnClose = false
+ session.close()
+ }
+
test("default") {
val session = SparkSession.builder().getOrCreate()
assert(session.client.configuration.host == "localhost")
assert(session.client.configuration.port == 15002)
- session.close()
+ closeSession(session)
}
test("remote") {
val session =
SparkSession.builder().remote(connectionString2).getOrCreate()
assert(session.client.configuration.host == "test.me")
assert(session.client.configuration.port == 14099)
- session.close()
+ closeSession(session)
}
test("getOrCreate") {
@@ -53,8 +59,8 @@ class SparkSessionSuite extends ConnectFunSuite {
try {
assert(session1 eq session2)
} finally {
- session1.close()
- session2.close()
+ closeSession(session1)
+ closeSession(session2)
}
}
@@ -65,8 +71,8 @@ class SparkSessionSuite extends ConnectFunSuite {
assert(session1 ne session2)
assert(session1.client.configuration == session2.client.configuration)
} finally {
- session1.close()
- session2.close()
+ closeSession(session1)
+ closeSession(session2)
}
}
@@ -77,8 +83,8 @@ class SparkSessionSuite extends ConnectFunSuite {
assert(session1 ne session2)
assert(session1.client.configuration == session2.client.configuration)
} finally {
- session1.close()
- session2.close()
+ closeSession(session1)
+ closeSession(session2)
}
}
@@ -98,7 +104,7 @@ class SparkSessionSuite extends ConnectFunSuite {
assertThrows[RuntimeException] {
session.range(10).count()
}
- session.close()
+ closeSession(session)
}
test("Default/Active session") {
@@ -136,12 +142,12 @@ class SparkSessionSuite extends ConnectFunSuite {
assert(SparkSession.getActiveSession.contains(session1))
// Close session1
- session1.close()
+ closeSession(session1)
assert(SparkSession.getDefaultSession.contains(session2))
assert(SparkSession.getActiveSession.isEmpty)
// Close session2
- session2.close()
+ closeSession(session2)
assert(SparkSession.getDefaultSession.isEmpty)
assert(SparkSession.getActiveSession.isEmpty)
}
@@ -187,7 +193,7 @@ class SparkSessionSuite extends ConnectFunSuite {
// Step 3 - close session 1, no more default session in both scripts
phaser.arriveAndAwaitAdvance()
- session1.close()
+ closeSession(session1)
// Step 4 - no default session, same active session.
phaser.arriveAndAwaitAdvance()
@@ -240,13 +246,13 @@ class SparkSessionSuite extends ConnectFunSuite {
// Step 7 - close active session in script2
phaser.arriveAndAwaitAdvance()
- internalSession.close()
+ closeSession(internalSession)
assert(SparkSession.getActiveSession.isEmpty)
}
assert(script1.get())
assert(script2.get())
assert(SparkSession.getActiveSession.contains(session2))
- session2.close()
+ closeSession(session2)
assert(SparkSession.getActiveSession.isEmpty)
} finally {
executor.shutdown()
@@ -254,13 +260,13 @@ class SparkSessionSuite extends ConnectFunSuite {
}
test("deprecated methods") {
- SparkSession
+ val session = SparkSession
.builder()
.master("yayay")
.appName("bob")
.enableHiveSupport()
.create()
- .close()
+ closeSession(session)
}
test("serialize as null") {
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 27f51551ba92..19a94a5a429f 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -784,6 +784,30 @@ message ReleaseExecuteResponse {
optional string operation_id = 2;
}
+message ReleaseSessionRequest {
+ // (Required)
+ //
+ // The session_id of the request to release.
+ // This must be an id of an existing session.
+ string session_id = 1;
+
+ // (Required) User context
+ //
+ // user_context.user_id and session+id both identify a unique remote spark
session on the
+ // server side.
+ UserContext user_context = 2;
+
+ // Provides optional information about the client sending the request. This
field
+ // can be used for language or version specific information and is only
intended for
+ // logging purposes and will not be interpreted by the server.
+ optional string client_type = 3;
+}
+
+message ReleaseSessionResponse {
+ // Session id of the session on which the release was executed.
+ string session_id = 1;
+}
+
message FetchErrorDetailsRequest {
// (Required)
@@ -934,6 +958,12 @@ service SparkConnectService {
// RPC and ReleaseExecute may not be used.
rpc ReleaseExecute(ReleaseExecuteRequest) returns (ReleaseExecuteResponse) {}
+ // Release a session.
+ // All the executions in the session will be released. Any further requests
for the session with
+ // that session_id for the given user_id will fail. If the session didn't
exist or was already
+ // released, this is a noop.
+ rpc ReleaseSession(ReleaseSessionRequest) returns (ReleaseSessionResponse) {}
+
// FetchErrorDetails retrieves the matched exception with details based on a
provided error id.
rpc FetchErrorDetails(FetchErrorDetailsRequest) returns
(FetchErrorDetailsResponse) {}
}
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index f2efa26f6b60..e963b4136160 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -96,6 +96,17 @@ private[connect] class CustomSparkConnectBlockingStub(
}
}
+ def releaseSession(request: ReleaseSessionRequest): ReleaseSessionResponse =
{
+ grpcExceptionConverter.convert(
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType) {
+ retryHandler.retry {
+ stub.releaseSession(request)
+ }
+ }
+ }
+
def artifactStatus(request: ArtifactStatusesRequest):
ArtifactStatusesResponse = {
grpcExceptionConverter.convert(
request.getSessionId,
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 42ace003da89..6d3d9420e226 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -243,6 +243,16 @@ private[sql] class SparkConnectClient(
bstub.interrupt(request)
}
+ private[sql] def releaseSession(): proto.ReleaseSessionResponse = {
+ val builder = proto.ReleaseSessionRequest.newBuilder()
+ val request = builder
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+ .setClientType(userAgent)
+ .build()
+ bstub.releaseSession(request)
+ }
+
private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
override def childValue(parent: mutable.Set[String]): mutable.Set[String]
= {
// Note: make a clone such that changes in the parent tags aren't
reflected in
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 2b3f218362cd..1a5944676f5f 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -74,6 +74,24 @@ object Connect {
.intConf
.createWithDefault(1024)
+ val CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT =
+ buildStaticConf("spark.connect.session.manager.defaultSessionTimeout")
+ .internal()
+ .doc("Timeout after which sessions without any new incoming RPC will be
removed.")
+ .version("4.0.0")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createWithDefaultString("60m")
+
+ val CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE =
+
buildStaticConf("spark.connect.session.manager.closedSessionsTombstonesSize")
+ .internal()
+ .doc(
+ "Maximum size of the cache of sessions after which sessions that did
not receive any " +
+ "requests will be removed.")
+ .version("4.0.0")
+ .intConf
+ .createWithDefaultString("1000")
+
val CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT =
buildStaticConf("spark.connect.execute.manager.detachedTimeout")
.internal()
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index dcced21f3714..792012a682b2 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._
import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder
-import org.apache.spark.{JobArtifactSet, SparkException}
+import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
@@ -40,12 +40,19 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.SystemClock
import org.apache.spark.util.Utils
+// Unique key identifying session by combination of user, and session id
+case class SessionKey(userId: String, sessionId: String)
+
/**
* Object used to hold the Spark Connect session state.
*/
case class SessionHolder(userId: String, sessionId: String, session:
SparkSession)
extends Logging {
+ @volatile private var lastRpcAccessTime: Option[Long] = None
+
+ @volatile private var isClosing: Boolean = false
+
private val executions: ConcurrentMap[String, ExecuteHolder] =
new ConcurrentHashMap[String, ExecuteHolder]()
@@ -73,8 +80,21 @@ case class SessionHolder(userId: String, sessionId: String,
session: SparkSessio
private[connect] lazy val streamingForeachBatchRunnerCleanerCache =
new StreamingForeachBatchHelper.CleanerCache(this)
- /** Add ExecuteHolder to this session. Called only by
SparkConnectExecutionManager. */
+ def key: SessionKey = SessionKey(userId, sessionId)
+
+ /**
+ * Add ExecuteHolder to this session.
+ *
+ * Called only by SparkConnectExecutionManager under executionsLock.
+ */
private[service] def addExecuteHolder(executeHolder: ExecuteHolder): Unit = {
+ if (isClosing) {
+ // Do not accept new executions if the session is closing.
+ throw new SparkSQLException(
+ errorClass = "INVALID_HANDLE.SESSION_CLOSED",
+ messageParameters = Map("handle" -> sessionId))
+ }
+
val oldExecute = executions.putIfAbsent(executeHolder.operationId,
executeHolder)
if (oldExecute != null) {
// the existence of this should already be checked by
SparkConnectExecutionManager
@@ -160,21 +180,55 @@ case class SessionHolder(userId: String, sessionId:
String, session: SparkSessio
*/
def classloader: ClassLoader = artifactManager.classloader
+ private[connect] def updateAccessTime(): Unit = {
+ lastRpcAccessTime = Some(System.currentTimeMillis())
+ }
+
+ /**
+ * Initialize the session.
+ *
+ * Called only by SparkConnectSessionManager.
+ */
private[connect] def initializeSession(): Unit = {
+ updateAccessTime()
eventManager.postStarted()
}
/**
* Expire this session and trigger state cleanup mechanisms.
+ *
+ * Called only by SparkConnectSessionManager.
*/
- private[connect] def expireSession(): Unit = {
- logDebug(s"Expiring session with userId: $userId and sessionId:
$sessionId")
+ private[connect] def close(): Unit = {
+ logInfo(s"Closing session with userId: $userId and sessionId: $sessionId")
+
+ // After isClosing=true, SessionHolder.addExecuteHolder() will not allow
new executions for
+ // this session. Because both SessionHolder.addExecuteHolder() and
+ // SparkConnectExecutionManager.removeAllExecutionsForSession() are
executed under
+ // executionsLock, this guarantees that removeAllExecutionsForSession
triggered below will
+ // remove all executions and no new executions will be added in the
meanwhile.
+ isClosing = true
+
+ // Note on the below notes about concurrency:
+ // While closing the session can potentially race with operations started
on the session, the
+ // intended use is that the client session will get closed when it's
really not used anymore,
+ // or that it expires due to inactivity, in which case there should be no
races.
+
+ // Clean up all artifacts.
+ // Note: there can be concurrent AddArtifact calls still adding something.
artifactManager.cleanUpResources()
- eventManager.postClosed()
- // Clean up running queries
+
+ // Clean up running streaming queries.
+ // Note: there can be concurrent streaming queries being started.
SparkConnectService.streamingSessionManager.cleanupRunningQueries(this)
streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any
streaming workers.
removeAllListeners() // removes all listener and stop python listener
processes if necessary.
+
+ // Clean up all executions
+ // It is guaranteed at this point that no new addExecuteHolder are getting
started.
+
SparkConnectService.executionManager.removeAllExecutionsForSession(this.key)
+
+ eventManager.postClosed()
}
/**
@@ -204,6 +258,10 @@ case class SessionHolder(userId: String, sessionId:
String, session: SparkSessio
}
}
+ /** Get SessionInfo with information about this SessionHolder. */
+ def getSessionHolderInfo: SessionHolderInfo =
+ SessionHolderInfo(userId, sessionId, eventManager.status,
lastRpcAccessTime)
+
/**
* Caches given DataFrame with the ID. The cache does not expire. The entry
needs to be
* explicitly removed by the owners of the DataFrame once it is not needed.
@@ -291,7 +349,14 @@ object SessionHolder {
userId = "testUser",
sessionId = UUID.randomUUID().toString,
session = session)
- SparkConnectService.putSessionForTesting(ret)
+ SparkConnectService.sessionManager.putSessionForTesting(ret)
ret
}
}
+
+/** Basic information about SessionHolder. */
+case class SessionHolderInfo(
+ userId: String,
+ sessionId: String,
+ status: SessionStatus,
+ lastRpcAccesTime: Option[Long])
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index 3c7254897822..c004358e1cf1 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -95,11 +95,16 @@ private[connect] class SparkConnectExecutionManager()
extends Logging {
* Remove an ExecuteHolder from this global manager and from its session.
Interrupt the
* execution if still running, free all resources.
*/
- private[connect] def removeExecuteHolder(key: ExecuteKey): Unit = {
+ private[connect] def removeExecuteHolder(key: ExecuteKey, abandoned: Boolean
= false): Unit = {
var executeHolder: Option[ExecuteHolder] = None
executionsLock.synchronized {
executeHolder = executions.remove(key)
- executeHolder.foreach(e =>
e.sessionHolder.removeExecuteHolder(e.operationId))
+ executeHolder.foreach { e =>
+ if (abandoned) {
+ abandonedTombstones.put(key, e.getExecuteInfo)
+ }
+ e.sessionHolder.removeExecuteHolder(e.operationId)
+ }
if (executions.isEmpty) {
lastExecutionTime = Some(System.currentTimeMillis())
}
@@ -115,6 +120,17 @@ private[connect] class SparkConnectExecutionManager()
extends Logging {
}
}
+ private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = {
+ val sessionExecutionHolders = executionsLock.synchronized {
+ executions.filter(_._2.sessionHolder.key == key)
+ }
+ sessionExecutionHolders.foreach { case (_, executeHolder) =>
+ val info = executeHolder.getExecuteInfo
+ logInfo(s"Execution $info removed in removeSessionExecutions.")
+ removeExecuteHolder(executeHolder.key, abandoned = true)
+ }
+ }
+
/** Get info about abandoned execution, if there is one. */
private[connect] def getAbandonedTombstone(key: ExecuteKey):
Option[ExecuteInfo] = {
Option(abandonedTombstones.getIfPresent(key))
@@ -204,8 +220,7 @@ private[connect] class SparkConnectExecutionManager()
extends Logging {
toRemove.foreach { executeHolder =>
val info = executeHolder.getExecuteInfo
logInfo(s"Found execution $info that was abandoned and expired and
will be removed.")
- removeExecuteHolder(executeHolder.key)
- abandonedTombstones.put(executeHolder.key, info)
+ removeExecuteHolder(executeHolder.key, abandoned = true)
}
}
logInfo("Finished periodic run of SparkConnectExecutionManager
maintenance.")
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala
index a3a7815609e4..1ca886960d53 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala
@@ -28,8 +28,8 @@ class SparkConnectReleaseExecuteHandler(
extends Logging {
def handle(v: proto.ReleaseExecuteRequest): Unit = {
- val sessionHolder = SparkConnectService
- .getIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
+ val sessionHolder = SparkConnectService.sessionManager
+ .getIsolatedSession(SessionKey(v.getUserContext.getUserId,
v.getSessionId))
val responseBuilder = proto.ReleaseExecuteResponse
.newBuilder()
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala
new file mode 100644
index 000000000000..a32852bac45e
--- /dev/null
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+
+class SparkConnectReleaseSessionHandler(
+ responseObserver: StreamObserver[proto.ReleaseSessionResponse])
+ extends Logging {
+
+ def handle(v: proto.ReleaseSessionRequest): Unit = {
+ val responseBuilder = proto.ReleaseSessionResponse.newBuilder()
+ responseBuilder.setSessionId(v.getSessionId)
+
+ // If the session doesn't exist, this will just be a noop.
+ val key = SessionKey(v.getUserContext.getUserId, v.getSessionId)
+ SparkConnectService.sessionManager.closeSession(key)
+
+ responseObserver.onNext(responseBuilder.build())
+ responseObserver.onCompleted()
+ }
+}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index e82c9cba5626..e4b60eeeff0d 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -18,13 +18,10 @@
package org.apache.spark.sql.connect.service
import java.net.InetSocketAddress
-import java.util.UUID
-import java.util.concurrent.{Callable, TimeUnit}
+import java.util.concurrent.TimeUnit
import scala.jdk.CollectionConverters._
-import com.google.common.base.Ticker
-import com.google.common.cache.{CacheBuilder, RemovalListener,
RemovalNotification}
import com.google.protobuf.MessageLite
import io.grpc.{BindableService, MethodDescriptor, Server,
ServerMethodDefinition, ServerServiceDefinition}
import io.grpc.MethodDescriptor.PrototypeMarshaller
@@ -34,13 +31,12 @@ import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.stub.StreamObserver
import org.apache.commons.lang3.StringUtils
-import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException}
+import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest,
AddArtifactsResponse, SparkConnectServiceGrpc}
import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.UI.UI_ENABLED
-import org.apache.spark.sql.SparkSession
import
org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS,
CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT,
CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE}
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore,
SparkConnectServerListener, SparkConnectServerTab}
import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -201,6 +197,22 @@ class SparkConnectService(debug: Boolean) extends
AsyncService with BindableServ
sessionId = request.getSessionId)
}
+ /**
+ * Release session.
+ */
+ override def releaseSession(
+ request: proto.ReleaseSessionRequest,
+ responseObserver: StreamObserver[proto.ReleaseSessionResponse]): Unit = {
+ try {
+ new SparkConnectReleaseSessionHandler(responseObserver).handle(request)
+ } catch
+ ErrorUtils.handleError(
+ "releaseSession",
+ observer = responseObserver,
+ userId = request.getUserContext.getUserId,
+ sessionId = request.getSessionId)
+ }
+
override def fetchErrorDetails(
request: proto.FetchErrorDetailsRequest,
responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]): Unit
= {
@@ -268,14 +280,6 @@ class SparkConnectService(debug: Boolean) extends
AsyncService with BindableServ
*/
object SparkConnectService extends Logging {
- private val CACHE_SIZE = 100
-
- private val CACHE_TIMEOUT_SECONDS = 3600
-
- // Type alias for the SessionCacheKey. Right now this is a String but allows
us to switch to a
- // different or complex type easily.
- private type SessionCacheKey = (String, String)
-
private[connect] var server: Server = _
private[connect] var uiTab: Option[SparkConnectServerTab] = None
@@ -289,77 +293,18 @@ object SparkConnectService extends Logging {
server.getPort
}
- private val userSessionMapping =
- cacheBuilder(CACHE_SIZE, CACHE_TIMEOUT_SECONDS).build[SessionCacheKey,
SessionHolder]()
-
private[connect] lazy val executionManager = new
SparkConnectExecutionManager()
+ private[connect] lazy val sessionManager = new SparkConnectSessionManager()
+
private[connect] val streamingSessionManager =
new SparkConnectStreamingQueryCache()
- private class RemoveSessionListener extends RemovalListener[SessionCacheKey,
SessionHolder] {
- override def onRemoval(
- notification: RemovalNotification[SessionCacheKey, SessionHolder]):
Unit = {
- notification.getValue.expireSession()
- }
- }
-
- // Simple builder for creating the cache of Sessions.
- private def cacheBuilder(cacheSize: Int, timeoutSeconds: Int):
CacheBuilder[Object, Object] = {
- var cacheBuilder = CacheBuilder.newBuilder().ticker(Ticker.systemTicker())
- if (cacheSize >= 0) {
- cacheBuilder = cacheBuilder.maximumSize(cacheSize)
- }
- if (timeoutSeconds >= 0) {
- cacheBuilder.expireAfterAccess(timeoutSeconds, TimeUnit.SECONDS)
- }
- cacheBuilder.removalListener(new RemoveSessionListener)
- cacheBuilder
- }
-
/**
* Based on the userId and sessionId, find or create a new SparkSession.
*/
def getOrCreateIsolatedSession(userId: String, sessionId: String):
SessionHolder = {
- getSessionOrDefault(
- userId,
- sessionId,
- () => {
- val holder = SessionHolder(userId, sessionId, newIsolatedSession())
- holder.initializeSession()
- holder
- })
- }
-
- /**
- * Based on the userId and sessionId, find an existing SparkSession or throw
error.
- */
- def getIsolatedSession(userId: String, sessionId: String): SessionHolder = {
- getSessionOrDefault(
- userId,
- sessionId,
- () => {
- logDebug(s"Session not found: ($userId, $sessionId)")
- throw new SparkSQLException(
- errorClass = "INVALID_HANDLE.SESSION_NOT_FOUND",
- messageParameters = Map("handle" -> sessionId))
- })
- }
-
- private def getSessionOrDefault(
- userId: String,
- sessionId: String,
- default: Callable[SessionHolder]): SessionHolder = {
- // Validate that sessionId is formatted like UUID before creating session.
- try {
- UUID.fromString(sessionId).toString
- } catch {
- case _: IllegalArgumentException =>
- throw new SparkSQLException(
- errorClass = "INVALID_HANDLE.FORMAT",
- messageParameters = Map("handle" -> sessionId))
- }
- userSessionMapping.get((userId, sessionId), default)
+ sessionManager.getOrCreateIsolatedSession(SessionKey(userId, sessionId))
}
/**
@@ -368,24 +313,6 @@ object SparkConnectService extends Logging {
*/
def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] =
executionManager.listActiveExecutions
- /**
- * Used for testing
- */
- private[connect] def invalidateAllSessions(): Unit = {
- userSessionMapping.invalidateAll()
- }
-
- /**
- * Used for testing.
- */
- private[connect] def putSessionForTesting(sessionHolder: SessionHolder):
Unit = {
- userSessionMapping.put((sessionHolder.userId, sessionHolder.sessionId),
sessionHolder)
- }
-
- private def newIsolatedSession(): SparkSession = {
- SparkSession.active.newSession()
- }
-
private def createListenerAndUI(sc: SparkContext): Unit = {
val kvStore = sc.statusStore.store.asInstanceOf[ElementTrackingStore]
listener = new SparkConnectServerListener(kvStore, sc.conf)
@@ -445,7 +372,7 @@ object SparkConnectService extends Logging {
}
streamingSessionManager.shutdown()
executionManager.shutdown()
- userSessionMapping.invalidateAll()
+ sessionManager.shutdown()
uiTab.foreach(_.detach())
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
new file mode 100644
index 000000000000..5c8e3c611586
--- /dev/null
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+import java.util.concurrent.{Callable, TimeUnit}
+
+import com.google.common.base.Ticker
+import com.google.common.cache.{CacheBuilder, RemovalListener,
RemovalNotification}
+
+import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import
org.apache.spark.sql.connect.config.Connect.{CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE,
CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT}
+
+/**
+ * Global tracker of all SessionHolders holding Spark Connect sessions.
+ */
/**
 * Global tracker of all SessionHolders holding Spark Connect sessions.
 *
 * Active sessions live in `sessionStore` and are evicted after a configured period of
 * inactivity. Closed sessions leave a tombstone in `closedSessionsCache` so that the same
 * (userId, sessionId) key cannot be silently recreated after release.
 */
class SparkConnectSessionManager extends Logging {

  // Guards the ordering between removal from sessionStore and insertion of the tombstone
  // into closedSessionsCache, so that a released session cannot be concurrently recreated.
  private val sessionsLock = new Object

  // Active sessions, expired after the configured inactivity timeout.
  private val sessionStore =
    CacheBuilder
      .newBuilder()
      .ticker(Ticker.systemTicker())
      .expireAfterAccess(
        SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT),
        TimeUnit.MILLISECONDS)
      .removalListener(new RemoveSessionListener)
      .build[SessionKey, SessionHolder]()

  // Size-bounded tombstones of closed sessions, used to reject reuse of a closed key.
  private val closedSessionsCache =
    CacheBuilder
      .newBuilder()
      .maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
      .build[SessionKey, SessionHolderInfo]()

  /**
   * Based on the userId and sessionId, find or create a new SparkSession.
   */
  private[connect] def getOrCreateIsolatedSession(key: SessionKey): SessionHolder = {
    // Lock to guard against concurrent removal and insertion into closedSessionsCache.
    sessionsLock.synchronized {
      getSession(
        key,
        Some(() => {
          validateSessionCreate(key)
          val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession())
          holder.initializeSession()
          holder
        }))
    }
  }

  /**
   * Based on the userId and sessionId, find an existing SparkSession or throw error.
   */
  private[connect] def getIsolatedSession(key: SessionKey): SessionHolder = {
    getSession(
      key,
      Some(() => {
        logDebug(s"Session not found: $key")
        // Distinguish a session that was explicitly closed from one that never existed.
        if (closedSessionsCache.getIfPresent(key) != null) {
          throw new SparkSQLException(
            errorClass = "INVALID_HANDLE.SESSION_CLOSED",
            messageParameters = Map("handle" -> key.sessionId))
        } else {
          throw new SparkSQLException(
            errorClass = "INVALID_HANDLE.SESSION_NOT_FOUND",
            messageParameters = Map("handle" -> key.sessionId))
        }
      }))
  }

  /**
   * Based on the userId and sessionId, get an existing SparkSession if present.
   */
  private[connect] def getIsolatedSessionIfPresent(key: SessionKey): Option[SessionHolder] = {
    Option(getSession(key, None))
  }

  // Looks the session up in the store; when `default` is given, the store creates and caches
  // the holder on a miss. May return null only when no default was supplied.
  private def getSession(
      key: SessionKey,
      default: Option[Callable[SessionHolder]]): SessionHolder = {
    val holder = default match {
      case Some(creator) => sessionStore.get(key, creator)
      case None => sessionStore.getIfPresent(key)
    }
    // Record access time before returning.
    Option(holder).foreach(_.updateAccessTime())
    holder
  }

  def closeSession(key: SessionKey): Unit = {
    // Invalidate will trigger RemoveSessionListener.
    sessionStore.invalidate(key)
  }

  private class RemoveSessionListener extends RemovalListener[SessionKey, SessionHolder] {
    override def onRemoval(notification: RemovalNotification[SessionKey, SessionHolder]): Unit = {
      val sessionHolder = notification.getValue
      sessionsLock.synchronized {
        // First put into closedSessionsCache, so that it cannot get accidentally recreated
        // by getOrCreateIsolatedSession.
        closedSessionsCache.put(sessionHolder.key, sessionHolder.getSessionHolderInfo)
      }
      // Rest of the cleanup outside sessionsLock - the session cannot be accessed anymore
      // by getOrCreateIsolatedSession.
      sessionHolder.close()
    }
  }

  def shutdown(): Unit = {
    sessionsLock.synchronized {
      sessionStore.invalidateAll()
      closedSessionsCache.invalidateAll()
    }
  }

  private def newIsolatedSession(): SparkSession = {
    SparkSession.active.newSession()
  }

  private def validateSessionCreate(key: SessionKey): Unit = {
    // Validate that sessionId is formatted like UUID before creating session.
    try {
      UUID.fromString(key.sessionId).toString
    } catch {
      case _: IllegalArgumentException =>
        throw new SparkSQLException(
          errorClass = "INVALID_HANDLE.FORMAT",
          messageParameters = Map("handle" -> key.sessionId))
    }
    // Validate that session with that key has not been already closed.
    if (closedSessionsCache.getIfPresent(key) != null) {
      throw new SparkSQLException(
        errorClass = "INVALID_HANDLE.SESSION_CLOSED",
        messageParameters = Map("handle" -> key.sessionId))
    }
  }

  /**
   * Used for testing.
   */
  private[connect] def invalidateAllSessions(): Unit = {
    sessionStore.invalidateAll()
    closedSessionsCache.invalidateAll()
  }

  /**
   * Used for testing.
   */
  private[connect] def putSessionForTesting(sessionHolder: SessionHolder): Unit = {
    sessionStore.put(sessionHolder.key, sessionHolder)
  }
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 741fa97f1787..837ee5a00227 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -41,7 +41,7 @@ import org.apache.spark.api.python.PythonException
import org.apache.spark.connect.proto.FetchErrorDetailsResponse
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.config.Connect
-import org.apache.spark.sql.connect.service.{ExecuteEventsManager,
SessionHolder, SparkConnectService}
+import org.apache.spark.sql.connect.service.{ExecuteEventsManager,
SessionHolder, SessionKey, SparkConnectService}
import org.apache.spark.sql.internal.SQLConf
private[connect] object ErrorUtils extends Logging {
@@ -153,7 +153,9 @@ private[connect] object ErrorUtils extends Logging {
.build()
}
- private def buildStatusFromThrowable(st: Throwable, sessionHolder:
SessionHolder): RPCStatus = {
+ private def buildStatusFromThrowable(
+ st: Throwable,
+ sessionHolderOpt: Option[SessionHolder]): RPCStatus = {
val errorInfo = ErrorInfo
.newBuilder()
.setReason(st.getClass.getName)
@@ -162,20 +164,20 @@ private[connect] object ErrorUtils extends Logging {
"classes",
JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))
- if (sessionHolder.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)) {
+ if
(sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)))
{
// Generate a new unique key for this exception.
val errorId = UUID.randomUUID().toString
errorInfo.putMetadata("errorId", errorId)
- sessionHolder.errorIdToError
+ sessionHolderOpt.get.errorIdToError
.put(errorId, st)
}
lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st))
val withStackTrace =
- if (sessionHolder.session.conf.get(
- SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty) {
+ if (sessionHolderOpt.exists(
+ _.session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) &&
stackTrace.nonEmpty)) {
val maxSize =
SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE)
errorInfo.putMetadata("stackTrace",
StringUtils.abbreviate(stackTrace.get, maxSize))
} else {
@@ -215,19 +217,22 @@ private[connect] object ErrorUtils extends Logging {
sessionId: String,
events: Option[ExecuteEventsManager] = None,
isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = {
- val sessionHolder =
- SparkConnectService
- .getOrCreateIsolatedSession(userId, sessionId)
+
+ // SessionHolder may not be present, e.g. if the session was already
closed.
+ // When SessionHolder is not present error details will not be available
for FetchErrorDetails.
+ val sessionHolderOpt =
+ SparkConnectService.sessionManager.getIsolatedSessionIfPresent(
+ SessionKey(userId, sessionId))
val partial: PartialFunction[Throwable, (Throwable, Throwable)] = {
case se: SparkException if isPythonExecutionException(se) =>
(
se,
StatusProto.toStatusRuntimeException(
- buildStatusFromThrowable(se.getCause, sessionHolder)))
+ buildStatusFromThrowable(se.getCause, sessionHolderOpt)))
case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e)
=>
- (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e,
sessionHolder)))
+ (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e,
sessionHolderOpt)))
case e: Throwable =>
(
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 7b02377f4847..120126f20ec2 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -59,10 +59,6 @@ trait SparkConnectServerTest extends SharedSparkSession {
withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key,
serverPort.toString)) {
SparkConnectService.start(spark.sparkContext)
}
- // register udf directly on the server, we're not testing client UDFs
here...
- val serverSession =
- SparkConnectService.getOrCreateIsolatedSession(defaultUserId,
defaultSessionId).session
- serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms
}))
}
override def afterAll(): Unit = {
@@ -84,6 +80,7 @@ trait SparkConnectServerTest extends SharedSparkSession {
protected def clearAllExecutions(): Unit = {
SparkConnectService.executionManager.listExecuteHolders.foreach(_.close())
SparkConnectService.executionManager.periodicMaintenance(0)
+ SparkConnectService.sessionManager.invalidateAllSessions()
assertNoActiveExecutions()
}
@@ -215,12 +212,24 @@ trait SparkConnectServerTest extends SharedSparkSession {
}
}
+ protected def withClient(sessionId: String = defaultSessionId, userId:
String = defaultUserId)(
+ f: SparkConnectClient => Unit): Unit = {
+ withClient(f, sessionId, userId)
+ }
+
protected def withClient(f: SparkConnectClient => Unit): Unit = {
+ withClient(f, defaultSessionId, defaultUserId)
+ }
+
+ protected def withClient(
+ f: SparkConnectClient => Unit,
+ sessionId: String,
+ userId: String): Unit = {
val client = SparkConnectClient
.builder()
.port(serverPort)
- .sessionId(defaultSessionId)
- .userId(defaultUserId)
+ .sessionId(sessionId)
+ .userId(userId)
.enableReattachableExecute()
.build()
try f(client)
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
index 0e29a07b719a..784b978f447d 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
@@ -347,6 +347,10 @@ class ReattachableExecuteSuite extends
SparkConnectServerTest {
}
test("long sleeping query") {
+ // register udf directly on the server, we're not testing client UDFs
here...
+ val serverSession =
+ SparkConnectService.getOrCreateIsolatedSession(defaultUserId,
defaultSessionId).session
+ serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms
}))
// query will be sleeping and not returning results, while having multiple
reattach
withSparkEnvConfs(
(Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION.key,
"1s")) {
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index ce452623e6b8..b314e7d8d483 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -841,12 +841,12 @@ class SparkConnectServiceSuite
spark.sparkContext.addSparkListener(verifyEvents.listener)
Utils.tryWithSafeFinally({
f(verifyEvents)
- SparkConnectService.invalidateAllSessions()
+ SparkConnectService.sessionManager.invalidateAllSessions()
verifyEvents.onSessionClosed()
}) {
verifyEvents.waitUntilEmpty()
spark.sparkContext.removeSparkListener(verifyEvents.listener)
- SparkConnectService.invalidateAllSessions()
+ SparkConnectService.sessionManager.invalidateAllSessions()
SparkConnectPluginRegistry.reset()
}
}
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
index 14ecc9a2e95e..cc0481dab0f4 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
@@ -16,13 +16,171 @@
*/
package org.apache.spark.sql.connect.service
+import java.util.UUID
+
import org.scalatest.concurrent.Eventually
import org.scalatest.time.SpanSugar._
+import org.apache.spark.SparkException
import org.apache.spark.sql.connect.SparkConnectServerTest
class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
+  // Results of these queries are made large enough that they do not all fit in the server
+  // buffers and are not pushed out immediately even when the client does not consume them.
+  // Otherwise, if all results were already buffered, the client could observe a query as
+  // succeeded even after the connection was closed.
+
+ test("ReleaseSession releases all queries and does not allow more requests
in the session") {
+ withClient { client =>
+ val query1 = client.execute(buildPlan(BIG_ENOUGH_QUERY))
+ val query2 = client.execute(buildPlan(BIG_ENOUGH_QUERY))
+ val query3 = client.execute(buildPlan("select 1"))
+ // just creating the iterator is lazy, trigger query1 and query2 to be
sent.
+ query1.hasNext
+ query2.hasNext
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ SparkConnectService.executionManager.listExecuteHolders.length == 2
+ }
+
+ // Close session
+ client.releaseSession()
+
+ // Check that queries get cancelled
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ SparkConnectService.executionManager.listExecuteHolders.length == 0
+ // SparkConnectService.sessionManager.
+ }
+
+ // query1 and query2 could get either an:
+ // OPERATION_CANCELED if it happens fast - when closing the session
interrupted the queries,
+ // and that error got pushed to the client buffers before the client got
disconnected.
+ // OPERATION_ABANDONED if it happens slow - when closing the session
interrupted the client
+ // RPCs before it pushed out the error above. The client would then get
an
+ // INVALID_CURSOR.DISCONNECTED, which it will retry with a
ReattachExecute, and then get an
+ // INVALID_HANDLE.OPERATION_ABANDONED.
+ val query1Error = intercept[SparkException] {
+ while (query1.hasNext) query1.next()
+ }
+ assert(
+ query1Error.getMessage.contains("OPERATION_CANCELED") ||
+
query1Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
+ val query2Error = intercept[SparkException] {
+ while (query2.hasNext) query2.next()
+ }
+ assert(
+ query2Error.getMessage.contains("OPERATION_CANCELED") ||
+
query2Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
+
+ // query3 has not been submitted before, so it should now fail with
SESSION_CLOSED
+ val query3Error = intercept[SparkException] {
+ query3.hasNext
+ }
+ assert(query3Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
+
+ // No other requests should be allowed in the session, failing with
SESSION_CLOSED
+ val requestError = intercept[SparkException] {
+ client.interruptAll()
+ }
+ assert(requestError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
+ }
+ }
+
+ private def testReleaseSessionTwoSessions(
+ sessionIdA: String,
+ userIdA: String,
+ sessionIdB: String,
+ userIdB: String): Unit = {
+ withClient(sessionId = sessionIdA, userId = userIdA) { clientA =>
+ withClient(sessionId = sessionIdB, userId = userIdB) { clientB =>
+ val queryA = clientA.execute(buildPlan(BIG_ENOUGH_QUERY))
+ val queryB = clientB.execute(buildPlan(BIG_ENOUGH_QUERY))
+ // just creating the iterator is lazy, trigger query1 and query2 to be
sent.
+ queryA.hasNext
+ queryB.hasNext
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ SparkConnectService.executionManager.listExecuteHolders.length == 2
+ }
+ // Close session A
+ clientA.releaseSession()
+
+ // A's query gets kicked out.
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ SparkConnectService.executionManager.listExecuteHolders.length == 1
+ }
+ val queryAError = intercept[SparkException] {
+ while (queryA.hasNext) queryA.next()
+ }
+ assert(
+ queryAError.getMessage.contains("OPERATION_CANCELED") ||
+
queryAError.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
+
+ // B's query can run.
+ while (queryB.hasNext) queryB.next()
+
+ // B can submit more queries.
+ val queryB2 = clientB.execute(buildPlan("SELECT 1"))
+ while (queryB2.hasNext) queryB2.next()
+ // A can't submit more queries.
+ val queryA2 = clientA.execute(buildPlan("SELECT 1"))
+ val queryA2Error = intercept[SparkException] {
+ clientA.interruptAll()
+ }
+
assert(queryA2Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
+ }
+ }
+ }
+
+ test("ReleaseSession for different user_id with same session_id do not
affect each other") {
+ testReleaseSessionTwoSessions(defaultSessionId, "A", defaultSessionId, "B")
+ }
+
+ test("ReleaseSession for different session_id with same user_id do not
affect each other") {
+ val sessionIdA = UUID.randomUUID.toString()
+ val sessionIdB = UUID.randomUUID.toString()
+ testReleaseSessionTwoSessions(sessionIdA, "X", sessionIdB, "X")
+ }
+
+ test("ReleaseSession: can't create a new session with the same id and user
after release") {
+ val sessionId = UUID.randomUUID.toString()
+ val userId = "Y"
+ withClient(sessionId = sessionId, userId = userId) { client =>
+ // this will create the session, and then ReleaseSession at the end of
withClient.
+ val query = client.execute(buildPlan("SELECT 1"))
+ query.hasNext // trigger execution
+ client.releaseSession()
+ }
+ withClient(sessionId = sessionId, userId = userId) { client =>
+ // shall not be able to create a new session with the same id and user.
+ val query = client.execute(buildPlan("SELECT 1"))
+ val queryError = intercept[SparkException] {
+ while (query.hasNext) query.next()
+ }
+ assert(queryError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
+ }
+ }
+
+ test("ReleaseSession: session with different session_id or user_id allowed
after release") {
+ val sessionId = UUID.randomUUID.toString()
+ val userId = "Y"
+ withClient(sessionId = sessionId, userId = userId) { client =>
+ val query = client.execute(buildPlan("SELECT 1"))
+ query.hasNext // trigger execution
+ client.releaseSession()
+ }
+ withClient(sessionId = UUID.randomUUID.toString, userId = userId) { client
=>
+ val query = client.execute(buildPlan("SELECT 1"))
+ query.hasNext // trigger execution
+ client.releaseSession()
+ }
+ withClient(sessionId = sessionId, userId = "YY") { client =>
+ val query = client.execute(buildPlan("SELECT 1"))
+ query.hasNext // trigger execution
+ client.releaseSession()
+ }
+ }
+
test("SPARK-45133 query should reach FINISHED state when results are not
consumed") {
withRawBlockingStub { stub =>
val iter =
diff --git a/docs/sql-error-conditions-invalid-handle-error-class.md
b/docs/sql-error-conditions-invalid-handle-error-class.md
index c4cbb48035ff..14526cd53724 100644
--- a/docs/sql-error-conditions-invalid-handle-error-class.md
+++ b/docs/sql-error-conditions-invalid-handle-error-class.md
@@ -45,6 +45,10 @@ Operation not found.
Session already exists.
+## SESSION_CLOSED
+
+Session was closed.
+
## SESSION_NOT_FOUND
Session not found.
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index 318f7d7ade4a..11a1112ad1fe 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -19,7 +19,7 @@ __all__ = [
"SparkConnectClient",
]
-from pyspark.loose_version import LooseVersion
+
from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
@@ -61,6 +61,7 @@ import grpc
from google.protobuf import text_format
from google.rpc import error_details_pb2
+from pyspark.loose_version import LooseVersion
from pyspark.version import __version__
from pyspark.resource.information import ResourceInformation
from pyspark.sql.connect.client.artifact import ArtifactManager
@@ -1471,6 +1472,26 @@ class SparkConnectClient(object):
except Exception as error:
self._handle_error(error)
+ def release_session(self) -> None:
+ req = pb2.ReleaseSessionRequest()
+ req.session_id = self._session_id
+ req.client_type = self._builder.userAgent
+ if self._user_id:
+ req.user_context.user_id = self._user_id
+ try:
+ for attempt in self._retrying():
+ with attempt:
+ resp = self._stub.ReleaseSession(req,
metadata=self._builder.metadata())
+ if resp.session_id != self._session_id:
+ raise SparkConnectException(
+ "Received incorrect session identifier for
request:"
+ f"{resp.session_id} != {self._session_id}"
+ )
+ return
+ raise SparkConnectException("Invalid state during retry exception
handling.")
+ except Exception as error:
+ self._handle_error(error)
+
def add_tag(self, tag: str) -> None:
self._throw_if_invalid_tag(tag)
if not hasattr(self.thread_local, "tags"):
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py
b/python/pyspark/sql/connect/proto/base_pb2.py
index 0ea02525f78f..0e374e7aa2cc 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as
spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
+
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -199,22 +199,26 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11234
_RELEASEEXECUTERESPONSE._serialized_start = 11263
_RELEASEEXECUTERESPONSE._serialized_end = 11375
- _FETCHERRORDETAILSREQUEST._serialized_start = 11378
- _FETCHERRORDETAILSREQUEST._serialized_end = 11579
- _FETCHERRORDETAILSRESPONSE._serialized_start = 11582
- _FETCHERRORDETAILSRESPONSE._serialized_end = 13052
- _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11727
- _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11901
- _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11904
- _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12271
- _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start =
12234
- _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12271
- _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12274
- _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12683
-
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start
= 12585
-
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end
= 12653
- _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12686
- _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13033
- _SPARKCONNECTSERVICE._serialized_start = 13055
- _SPARKCONNECTSERVICE._serialized_end = 13904
+ _RELEASESESSIONREQUEST._serialized_start = 11378
+ _RELEASESESSIONREQUEST._serialized_end = 11549
+ _RELEASESESSIONRESPONSE._serialized_start = 11551
+ _RELEASESESSIONRESPONSE._serialized_end = 11606
+ _FETCHERRORDETAILSREQUEST._serialized_start = 11609
+ _FETCHERRORDETAILSREQUEST._serialized_end = 11810
+ _FETCHERRORDETAILSRESPONSE._serialized_start = 11813
+ _FETCHERRORDETAILSRESPONSE._serialized_end = 13283
+ _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11958
+ _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 12132
+ _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 12135
+ _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12502
+ _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start =
12465
+ _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12502
+ _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12505
+ _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12914
+
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start
= 12816
+
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end
= 12884
+ _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12917
+ _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13264
+ _SPARKCONNECTSERVICE._serialized_start = 13286
+ _SPARKCONNECTSERVICE._serialized_end = 14232
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi
b/python/pyspark/sql/connect/proto/base_pb2.pyi
index c29feb4164cf..20abbcb348bd 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -2763,6 +2763,84 @@ class
ReleaseExecuteResponse(google.protobuf.message.Message):
global___ReleaseExecuteResponse = ReleaseExecuteResponse
+class ReleaseSessionRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SESSION_ID_FIELD_NUMBER: builtins.int
+ USER_CONTEXT_FIELD_NUMBER: builtins.int
+ CLIENT_TYPE_FIELD_NUMBER: builtins.int
+ session_id: builtins.str
+ """(Required)
+
+    The session_id of the session to release.
+    This must be an id of existing session.
+ """
+ @property
+ def user_context(self) -> global___UserContext:
+ """(Required) User context
+
+    user_context.user_id and session_id both identify a unique remote Spark session on the
+    server side.
+ """
+ client_type: builtins.str
+ """Provides optional information about the client sending the request.
This field
+ can be used for language or version specific information and is only
intended for
+ logging purposes and will not be interpreted by the server.
+ """
+ def __init__(
+ self,
+ *,
+ session_id: builtins.str = ...,
+ user_context: global___UserContext | None = ...,
+ client_type: builtins.str | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_client_type",
+ b"_client_type",
+ "client_type",
+ b"client_type",
+ "user_context",
+ b"user_context",
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_client_type",
+ b"_client_type",
+ "client_type",
+ b"client_type",
+ "session_id",
+ b"session_id",
+ "user_context",
+ b"user_context",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_client_type",
b"_client_type"]
+ ) -> typing_extensions.Literal["client_type"] | None: ...
+
+global___ReleaseSessionRequest = ReleaseSessionRequest
+
+class ReleaseSessionResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SESSION_ID_FIELD_NUMBER: builtins.int
+ session_id: builtins.str
+ """Session id of the session on which the release executed."""
+ def __init__(
+ self,
+ *,
+ session_id: builtins.str = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["session_id",
b"session_id"]
+ ) -> None: ...
+
+global___ReleaseSessionResponse = ReleaseSessionResponse
+
class FetchErrorDetailsRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
index f6c5573ded6b..12675747e0f9 100644
--- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -70,6 +70,11 @@ class SparkConnectServiceStub(object):
request_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.SerializeToString,
response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.FromString,
)
+ self.ReleaseSession = channel.unary_unary(
+ "/spark.connect.SparkConnectService/ReleaseSession",
+
request_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString,
+
response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString,
+ )
self.FetchErrorDetails = channel.unary_unary(
"/spark.connect.SparkConnectService/FetchErrorDetails",
request_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString,
@@ -141,6 +146,16 @@ class SparkConnectServiceServicer(object):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")
+ def ReleaseSession(self, request, context):
+ """Release a session.
+ All the executions in the session will be released. Any further
requests for the session with
+ that session_id for the given user_id will fail. If the session didn't
exist or was already
+ released, this is a noop.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
def FetchErrorDetails(self, request, context):
"""FetchErrorDetails retrieves the matched exception with details
based on a provided error id."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -190,6 +205,11 @@ def add_SparkConnectServiceServicer_to_server(servicer,
server):
request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.FromString,
response_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.SerializeToString,
),
+ "ReleaseSession": grpc.unary_unary_rpc_method_handler(
+ servicer.ReleaseSession,
+
request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.FromString,
+
response_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.SerializeToString,
+ ),
"FetchErrorDetails": grpc.unary_unary_rpc_method_handler(
servicer.FetchErrorDetails,
request_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.FromString,
@@ -438,6 +458,35 @@ class SparkConnectService(object):
metadata,
)
+ @staticmethod
+ def ReleaseSession(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/spark.connect.SparkConnectService/ReleaseSession",
+
spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString,
+ spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
+
@staticmethod
def FetchErrorDetails(
request,
diff --git a/python/pyspark/sql/connect/session.py
b/python/pyspark/sql/connect/session.py
index 09bd60606c76..1aa857b4f617 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -254,6 +254,9 @@ class SparkSession:
self._client = SparkConnectClient(connection=connection,
user_id=userId)
self._session_id = self._client._session_id
+ # Set to false to prevent client.release_session on close() (testing
only)
+ self.release_session_on_close = True
+
@classmethod
def _set_default_and_active_session(cls, session: "SparkSession") -> None:
"""
@@ -645,15 +648,16 @@ class SparkSession:
clearTags.__doc__ = PySparkSession.clearTags.__doc__
def stop(self) -> None:
- # Stopping the session will only close the connection to the current
session (and
- # the life cycle of the session is maintained by the server),
- # whereas the regular PySpark session immediately terminates the Spark
Context
- # itself, meaning that stopping all Spark sessions.
+ # Whereas the regular PySpark session immediately terminates the Spark
Context
+ # itself, thereby stopping all Spark sessions, this will only
stop this one session
+ # on the server.
# It is controversial to follow the existing regular Spark
session's behavior
# specifically in Spark Connect the Spark Connect server is designed
for
# multi-tenancy - the remote client side cannot just stop the server
and stop
# other remote clients being used from other users.
with SparkSession._lock:
+ if not self.is_stopped and self.release_session_on_close:
+ self.client.release_session()
self.client.close()
if self is SparkSession._default_session:
SparkSession._default_session = None
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 34bd314c76f7..f024a03c2686 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3437,6 +3437,7 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
# Gets currently active session.
same =
PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate()
self.assertEquals(other, same)
+ same.release_session_on_close = False # avoid sending release to
dummy connection
same.stop()
# Make sure the environment is clean.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]