This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 59e291d36c4a [SPARK-45680][CONNECT] Release session
59e291d36c4a is described below

commit 59e291d36c4a9d956b993968a324359b3d75fe5f
Author: Juliusz Sompolski <[email protected]>
AuthorDate: Thu Nov 2 09:11:48 2023 +0900

    [SPARK-45680][CONNECT] Release session
    
    ### What changes were proposed in this pull request?
    
    Introduce a new `ReleaseSession` Spark Connect RPC, which cancels everything
    running in the session and removes the session on the server side. Refactor the
    code for managing the session cache into a new `SparkConnectSessionManager`.
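
    For illustration, a minimal client-side sketch (Scala client; the connection
    string and workload are placeholders, not part of this PR):

    ```scala
    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()
    spark.range(10).count()
    // stop() closes the client connection and, with this change, also sends
    // ReleaseSession, so the server-side session (executions, artifacts, views)
    // is torn down instead of lingering in the server-side cache.
    spark.stop()
    ```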
    
    ### Why are the changes needed?
    
    Better session management.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Not really. The `SparkSession.stop()` API already existed on the client side. It
    closed the client's network connection, but the session remained cached on the
    server side for one hour.
    
    Caveats (none of which were supported user behaviour):
    * After `session.stop()`, the user could create a new session with the same
    session_id in the configuration. That session would be new on the client side,
    but it would connect to the old cached session on the server, and could therefore
    access the old session's state such as views or artifacts.
    * If a session timed out and was removed on the server, a new request used to
    re-create it. The client would see this as the old session, but the server would
    see a new one that, for example, no longer had access to the removed session
    state.
    * With this change, the user is no longer allowed to create a new session with
    the same session_id as a released one (see the sketch below).
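
    A condensed sketch of the new behaviour, mirroring the added
    `SparkConnectServiceE2ESuite` test (the `withClient`/`buildPlan` helpers come
    from `SparkConnectServerTest`; `sessionId`/`userId` stand for a fixed pair):

    ```scala
    withClient(sessionId = sessionId, userId = userId) { client =>
      client.execute(buildPlan("SELECT 1")).hasNext // triggers execution, creates the session
      client.releaseSession()
    }
    withClient(sessionId = sessionId, userId = userId) { client =>
      // Re-using the released (user_id, session_id) pair now fails instead of
      // silently re-creating a fresh server-side session.
      val query = client.execute(buildPlan("SELECT 1"))
      val error = intercept[SparkException] { while (query.hasNext) query.next() }
      assert(error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
    }
    ```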
    
    ### How was this patch tested?
    
    Tests added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43546 from juliuszsompolski/release-session.
    
    Lead-authored-by: Juliusz Sompolski <[email protected]>
    Co-authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../src/main/resources/error/error-classes.json    |   5 +
 .../scala/org/apache/spark/sql/SparkSession.scala  |   8 +
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 +-
 .../org/apache/spark/sql/SparkSessionSuite.scala   |  38 +++--
 .../src/main/protobuf/spark/connect/base.proto     |  30 ++++
 .../client/CustomSparkConnectBlockingStub.scala    |  11 ++
 .../sql/connect/client/SparkConnectClient.scala    |  10 ++
 .../apache/spark/sql/connect/config/Connect.scala  |  18 +++
 .../spark/sql/connect/service/SessionHolder.scala  |  79 ++++++++-
 .../service/SparkConnectExecutionManager.scala     |  23 ++-
 .../SparkConnectReleaseExecuteHandler.scala        |   4 +-
 .../SparkConnectReleaseSessionHandler.scala        |  40 +++++
 .../sql/connect/service/SparkConnectService.scala  | 117 +++-----------
 .../service/SparkConnectSessionManager.scala       | 177 +++++++++++++++++++++
 .../spark/sql/connect/utils/ErrorUtils.scala       |  27 ++--
 .../spark/sql/connect/SparkConnectServerTest.scala |  21 ++-
 .../execution/ReattachableExecuteSuite.scala       |   4 +
 .../connect/planner/SparkConnectServiceSuite.scala |   4 +-
 .../service/SparkConnectServiceE2ESuite.scala      | 158 ++++++++++++++++++
 ...-error-conditions-invalid-handle-error-class.md |   4 +
 python/pyspark/sql/connect/client/core.py          |  23 ++-
 python/pyspark/sql/connect/proto/base_pb2.py       |  42 ++---
 python/pyspark/sql/connect/proto/base_pb2.pyi      |  78 +++++++++
 python/pyspark/sql/connect/proto/base_pb2_grpc.py  |  49 ++++++
 python/pyspark/sql/connect/session.py              |  12 +-
 .../sql/tests/connect/test_connect_basic.py        |   1 +
 26 files changed, 819 insertions(+), 168 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json 
b/common/utils/src/main/resources/error/error-classes.json
index 278011b8cc8f..af32bcf129c0 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -1737,6 +1737,11 @@
           "Session already exists."
         ]
       },
+      "SESSION_CLOSED" : {
+        "message" : [
+          "Session was closed."
+        ]
+      },
       "SESSION_NOT_FOUND" : {
         "message" : [
           "Session not found."
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 969ac017ecb1..1cc1c8400fa8 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -665,6 +665,9 @@ class SparkSession private[sql] (
    * @since 3.4.0
    */
   override def close(): Unit = {
+    if (releaseSessionOnClose) {
+      client.releaseSession()
+    }
     client.shutdown()
     allocator.close()
     SparkSession.onSessionClose(this)
@@ -735,6 +738,11 @@ class SparkSession private[sql] (
    * We null out the instance for now.
    */
   private def writeReplace(): Any = null
+
+  /**
+   * Set to false to prevent client.releaseSession on close() (testing only)
+   */
+  private[sql] var releaseSessionOnClose = true
 }
 
 // The minimal builder needed to create a spark session.
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index cf287088b59f..5cc63bc45a04 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -120,7 +120,9 @@ class PlanGenerationTestSuite
   }
 
   override protected def afterAll(): Unit = {
-    session.close()
+    // Don't call client.releaseSession on close(), because the connection 
details are dummy.
+    session.releaseSessionOnClose = false
+    session.stop()
     if (cleanOrphanedGoldenFiles) {
       cleanOrphanedGoldenFile()
     }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
index 4c858262c6ef..8abc41639fdd 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
@@ -33,18 +33,24 @@ class SparkSessionSuite extends ConnectFunSuite {
   private val connectionString2: String = "sc://test.me:14099"
   private val connectionString3: String = "sc://doit:16845"
 
+  private def closeSession(session: SparkSession): Unit = {
+    // Don't call client.releaseSession on close(), because the connection 
details are dummy.
+    session.releaseSessionOnClose = false
+    session.close()
+  }
+
   test("default") {
     val session = SparkSession.builder().getOrCreate()
     assert(session.client.configuration.host == "localhost")
     assert(session.client.configuration.port == 15002)
-    session.close()
+    closeSession(session)
   }
 
   test("remote") {
     val session = 
SparkSession.builder().remote(connectionString2).getOrCreate()
     assert(session.client.configuration.host == "test.me")
     assert(session.client.configuration.port == 14099)
-    session.close()
+    closeSession(session)
   }
 
   test("getOrCreate") {
@@ -53,8 +59,8 @@ class SparkSessionSuite extends ConnectFunSuite {
     try {
       assert(session1 eq session2)
     } finally {
-      session1.close()
-      session2.close()
+      closeSession(session1)
+      closeSession(session2)
     }
   }
 
@@ -65,8 +71,8 @@ class SparkSessionSuite extends ConnectFunSuite {
       assert(session1 ne session2)
       assert(session1.client.configuration == session2.client.configuration)
     } finally {
-      session1.close()
-      session2.close()
+      closeSession(session1)
+      closeSession(session2)
     }
   }
 
@@ -77,8 +83,8 @@ class SparkSessionSuite extends ConnectFunSuite {
       assert(session1 ne session2)
       assert(session1.client.configuration == session2.client.configuration)
     } finally {
-      session1.close()
-      session2.close()
+      closeSession(session1)
+      closeSession(session2)
     }
   }
 
@@ -98,7 +104,7 @@ class SparkSessionSuite extends ConnectFunSuite {
     assertThrows[RuntimeException] {
       session.range(10).count()
     }
-    session.close()
+    closeSession(session)
   }
 
   test("Default/Active session") {
@@ -136,12 +142,12 @@ class SparkSessionSuite extends ConnectFunSuite {
     assert(SparkSession.getActiveSession.contains(session1))
 
     // Close session1
-    session1.close()
+    closeSession(session1)
     assert(SparkSession.getDefaultSession.contains(session2))
     assert(SparkSession.getActiveSession.isEmpty)
 
     // Close session2
-    session2.close()
+    closeSession(session2)
     assert(SparkSession.getDefaultSession.isEmpty)
     assert(SparkSession.getActiveSession.isEmpty)
   }
@@ -187,7 +193,7 @@ class SparkSessionSuite extends ConnectFunSuite {
 
         // Step 3 - close session 1, no more default session in both scripts
         phaser.arriveAndAwaitAdvance()
-        session1.close()
+        closeSession(session1)
 
         // Step 4 - no default session, same active session.
         phaser.arriveAndAwaitAdvance()
@@ -240,13 +246,13 @@ class SparkSessionSuite extends ConnectFunSuite {
 
         // Step 7 - close active session in script2
         phaser.arriveAndAwaitAdvance()
-        internalSession.close()
+        closeSession(internalSession)
         assert(SparkSession.getActiveSession.isEmpty)
       }
       assert(script1.get())
       assert(script2.get())
       assert(SparkSession.getActiveSession.contains(session2))
-      session2.close()
+      closeSession(session2)
       assert(SparkSession.getActiveSession.isEmpty)
     } finally {
       executor.shutdown()
@@ -254,13 +260,13 @@ class SparkSessionSuite extends ConnectFunSuite {
   }
 
   test("deprecated methods") {
-    SparkSession
+    val session = SparkSession
       .builder()
       .master("yayay")
       .appName("bob")
       .enableHiveSupport()
       .create()
-      .close()
+    closeSession(session)
   }
 
   test("serialize as null") {
diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/base.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 27f51551ba92..19a94a5a429f 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -784,6 +784,30 @@ message ReleaseExecuteResponse {
   optional string operation_id = 2;
 }
 
+message ReleaseSessionRequest {
+  // (Required)
+  //
+  // The session_id of the session to release.
+  // This must be the id of an existing session.
+  string session_id = 1;
+
+  // (Required) User context
+  //
+  // user_context.user_id and session_id both identify a unique remote Spark session on the
+  // server side.
+  UserContext user_context = 2;
+
+  // Provides optional information about the client sending the request. This 
field
+  // can be used for language or version specific information and is only 
intended for
+  // logging purposes and will not be interpreted by the server.
+  optional string client_type = 3;
+}
+
+message ReleaseSessionResponse {
+  // Session id of the session on which the release executed.
+  string session_id = 1;
+}
+
 message FetchErrorDetailsRequest {
 
   // (Required)
@@ -934,6 +958,12 @@ service SparkConnectService {
   // RPC and ReleaseExecute may not be used.
   rpc ReleaseExecute(ReleaseExecuteRequest) returns (ReleaseExecuteResponse) {}
 
+  // Release a session.
+  // All the executions in the session will be released. Any further requests 
for the session with
+  // that session_id for the given user_id will fail. If the session didn't 
exist or was already
+  // released, this is a noop.
+  rpc ReleaseSession(ReleaseSessionRequest) returns (ReleaseSessionResponse) {}
+
   // FetchErrorDetails retrieves the matched exception with details based on a 
provided error id.
   rpc FetchErrorDetails(FetchErrorDetailsRequest) returns 
(FetchErrorDetailsResponse) {}
 }
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index f2efa26f6b60..e963b4136160 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -96,6 +96,17 @@ private[connect] class CustomSparkConnectBlockingStub(
     }
   }
 
+  def releaseSession(request: ReleaseSessionRequest): ReleaseSessionResponse = 
{
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
+      retryHandler.retry {
+        stub.releaseSession(request)
+      }
+    }
+  }
+
   def artifactStatus(request: ArtifactStatusesRequest): 
ArtifactStatusesResponse = {
     grpcExceptionConverter.convert(
       request.getSessionId,
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 42ace003da89..6d3d9420e226 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -243,6 +243,16 @@ private[sql] class SparkConnectClient(
     bstub.interrupt(request)
   }
 
+  private[sql] def releaseSession(): proto.ReleaseSessionResponse = {
+    val builder = proto.ReleaseSessionRequest.newBuilder()
+    val request = builder
+      .setUserContext(userContext)
+      .setSessionId(sessionId)
+      .setClientType(userAgent)
+      .build()
+    bstub.releaseSession(request)
+  }
+
   private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
     override def childValue(parent: mutable.Set[String]): mutable.Set[String] 
= {
       // Note: make a clone such that changes in the parent tags aren't 
reflected in
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 2b3f218362cd..1a5944676f5f 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -74,6 +74,24 @@ object Connect {
       .intConf
       .createWithDefault(1024)
 
+  val CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT =
+    buildStaticConf("spark.connect.session.manager.defaultSessionTimeout")
+      .internal()
+      .doc("Timeout after which sessions without any new incoming RPC will be 
removed.")
+      .version("4.0.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("60m")
+
+  val CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE =
+    
buildStaticConf("spark.connect.session.manager.closedSessionsTombstonesSize")
+      .internal()
+      .doc(
+        "Maximum size of the cache of sessions after which sessions that did 
not receive any " +
+          "requests will be removed.")
+      .version("4.0.0")
+      .intConf
+      .createWithDefaultString("1000")
+
   val CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT =
     buildStaticConf("spark.connect.execute.manager.detachedTimeout")
       .internal()
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index dcced21f3714..792012a682b2 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._
 import com.google.common.base.Ticker
 import com.google.common.cache.CacheBuilder
 
-import org.apache.spark.{JobArtifactSet, SparkException}
+import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SparkSession
@@ -40,12 +40,19 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.util.SystemClock
 import org.apache.spark.util.Utils
 
+// Unique key identifying a session by the combination of user id and session id
+case class SessionKey(userId: String, sessionId: String)
+
 /**
  * Object used to hold the Spark Connect session state.
  */
 case class SessionHolder(userId: String, sessionId: String, session: 
SparkSession)
     extends Logging {
 
+  @volatile private var lastRpcAccessTime: Option[Long] = None
+
+  @volatile private var isClosing: Boolean = false
+
   private val executions: ConcurrentMap[String, ExecuteHolder] =
     new ConcurrentHashMap[String, ExecuteHolder]()
 
@@ -73,8 +80,21 @@ case class SessionHolder(userId: String, sessionId: String, 
session: SparkSessio
   private[connect] lazy val streamingForeachBatchRunnerCleanerCache =
     new StreamingForeachBatchHelper.CleanerCache(this)
 
-  /** Add ExecuteHolder to this session. Called only by 
SparkConnectExecutionManager. */
+  def key: SessionKey = SessionKey(userId, sessionId)
+
+  /**
+   * Add ExecuteHolder to this session.
+   *
+   * Called only by SparkConnectExecutionManager under executionsLock.
+   */
   private[service] def addExecuteHolder(executeHolder: ExecuteHolder): Unit = {
+    if (isClosing) {
+      // Do not accept new executions if the session is closing.
+      throw new SparkSQLException(
+        errorClass = "INVALID_HANDLE.SESSION_CLOSED",
+        messageParameters = Map("handle" -> sessionId))
+    }
+
     val oldExecute = executions.putIfAbsent(executeHolder.operationId, 
executeHolder)
     if (oldExecute != null) {
      // the existence of this should already be checked by SparkConnectExecutionManager
@@ -160,21 +180,55 @@ case class SessionHolder(userId: String, sessionId: 
String, session: SparkSessio
    */
   def classloader: ClassLoader = artifactManager.classloader
 
+  private[connect] def updateAccessTime(): Unit = {
+    lastRpcAccessTime = Some(System.currentTimeMillis())
+  }
+
+  /**
+   * Initialize the session.
+   *
+   * Called only by SparkConnectSessionManager.
+   */
   private[connect] def initializeSession(): Unit = {
+    updateAccessTime()
     eventManager.postStarted()
   }
 
   /**
    * Expire this session and trigger state cleanup mechanisms.
+   *
+   * Called only by SparkConnectSessionManager.
    */
-  private[connect] def expireSession(): Unit = {
-    logDebug(s"Expiring session with userId: $userId and sessionId: 
$sessionId")
+  private[connect] def close(): Unit = {
+    logInfo(s"Closing session with userId: $userId and sessionId: $sessionId")
+
+    // After isClosing=true, SessionHolder.addExecuteHolder() will not allow 
new executions for
+    // this session. Because both SessionHolder.addExecuteHolder() and
+    // SparkConnectExecutionManager.removeAllExecutionsForSession() are 
executed under
+    // executionsLock, this guarantees that removeAllExecutionsForSession 
triggered below will
+    // remove all executions and no new executions will be added in the meantime.
+    isClosing = true
+
+    // A note on concurrency for the cleanup steps below:
+    // While closing the session can potentially race with operations started 
on the session, the
+    // intended use is that the client session will get closed when it's 
really not used anymore,
+    // or that it expires due to inactivity, in which case there should be no 
races.
+
+    // Clean up all artifacts.
+    // Note: there can be concurrent AddArtifact calls still adding something.
     artifactManager.cleanUpResources()
-    eventManager.postClosed()
-    // Clean up running queries
+
+    // Clean up running streaming queries.
+    // Note: there can be concurrent streaming queries being started.
     SparkConnectService.streamingSessionManager.cleanupRunningQueries(this)
     streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any 
streaming workers.
     removeAllListeners() // removes all listeners and stops Python listener processes if necessary.
+
+    // Clean up all executions
+    // It is guaranteed at this point that no new addExecuteHolder calls are getting started.
+    
SparkConnectService.executionManager.removeAllExecutionsForSession(this.key)
+
+    eventManager.postClosed()
   }
 
   /**
@@ -204,6 +258,10 @@ case class SessionHolder(userId: String, sessionId: 
String, session: SparkSessio
     }
   }
 
+  /** Get SessionHolderInfo with information about this SessionHolder. */
+  def getSessionHolderInfo: SessionHolderInfo =
+    SessionHolderInfo(userId, sessionId, eventManager.status, 
lastRpcAccessTime)
+
   /**
    * Caches given DataFrame with the ID. The cache does not expire. The entry 
needs to be
    * explicitly removed by the owners of the DataFrame once it is not needed.
@@ -291,7 +349,14 @@ object SessionHolder {
         userId = "testUser",
         sessionId = UUID.randomUUID().toString,
         session = session)
-    SparkConnectService.putSessionForTesting(ret)
+    SparkConnectService.sessionManager.putSessionForTesting(ret)
     ret
   }
 }
+
+/** Basic information about SessionHolder. */
+case class SessionHolderInfo(
+    userId: String,
+    sessionId: String,
+    status: SessionStatus,
+    lastRpcAccesTime: Option[Long])
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index 3c7254897822..c004358e1cf1 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -95,11 +95,16 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
    * Remove an ExecuteHolder from this global manager and from its session. 
Interrupt the
    * execution if still running, free all resources.
    */
-  private[connect] def removeExecuteHolder(key: ExecuteKey): Unit = {
+  private[connect] def removeExecuteHolder(key: ExecuteKey, abandoned: Boolean 
= false): Unit = {
     var executeHolder: Option[ExecuteHolder] = None
     executionsLock.synchronized {
       executeHolder = executions.remove(key)
-      executeHolder.foreach(e => 
e.sessionHolder.removeExecuteHolder(e.operationId))
+      executeHolder.foreach { e =>
+        if (abandoned) {
+          abandonedTombstones.put(key, e.getExecuteInfo)
+        }
+        e.sessionHolder.removeExecuteHolder(e.operationId)
+      }
       if (executions.isEmpty) {
         lastExecutionTime = Some(System.currentTimeMillis())
       }
@@ -115,6 +120,17 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
     }
   }
 
+  private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = {
+    val sessionExecutionHolders = executionsLock.synchronized {
+      executions.filter(_._2.sessionHolder.key == key)
+    }
+    sessionExecutionHolders.foreach { case (_, executeHolder) =>
+      val info = executeHolder.getExecuteInfo
+      logInfo(s"Execution $info removed in removeSessionExecutions.")
+      removeExecuteHolder(executeHolder.key, abandoned = true)
+    }
+  }
+
   /** Get info about abandoned execution, if there is one. */
   private[connect] def getAbandonedTombstone(key: ExecuteKey): 
Option[ExecuteInfo] = {
     Option(abandonedTombstones.getIfPresent(key))
@@ -204,8 +220,7 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
       toRemove.foreach { executeHolder =>
         val info = executeHolder.getExecuteInfo
         logInfo(s"Found execution $info that was abandoned and expired and 
will be removed.")
-        removeExecuteHolder(executeHolder.key)
-        abandonedTombstones.put(executeHolder.key, info)
+        removeExecuteHolder(executeHolder.key, abandoned = true)
       }
     }
     logInfo("Finished periodic run of SparkConnectExecutionManager 
maintenance.")
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala
index a3a7815609e4..1ca886960d53 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala
@@ -28,8 +28,8 @@ class SparkConnectReleaseExecuteHandler(
     extends Logging {
 
   def handle(v: proto.ReleaseExecuteRequest): Unit = {
-    val sessionHolder = SparkConnectService
-      .getIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
+    val sessionHolder = SparkConnectService.sessionManager
+      .getIsolatedSession(SessionKey(v.getUserContext.getUserId, 
v.getSessionId))
 
     val responseBuilder = proto.ReleaseExecuteResponse
       .newBuilder()
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala
new file mode 100644
index 000000000000..a32852bac45e
--- /dev/null
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+
+class SparkConnectReleaseSessionHandler(
+    responseObserver: StreamObserver[proto.ReleaseSessionResponse])
+    extends Logging {
+
+  def handle(v: proto.ReleaseSessionRequest): Unit = {
+    val responseBuilder = proto.ReleaseSessionResponse.newBuilder()
+    responseBuilder.setSessionId(v.getSessionId)
+
+    // If the session doesn't exist, this will just be a noop.
+    val key = SessionKey(v.getUserContext.getUserId, v.getSessionId)
+    SparkConnectService.sessionManager.closeSession(key)
+
+    responseObserver.onNext(responseBuilder.build())
+    responseObserver.onCompleted()
+  }
+}
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index e82c9cba5626..e4b60eeeff0d 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -18,13 +18,10 @@
 package org.apache.spark.sql.connect.service
 
 import java.net.InetSocketAddress
-import java.util.UUID
-import java.util.concurrent.{Callable, TimeUnit}
+import java.util.concurrent.TimeUnit
 
 import scala.jdk.CollectionConverters._
 
-import com.google.common.base.Ticker
-import com.google.common.cache.{CacheBuilder, RemovalListener, 
RemovalNotification}
 import com.google.protobuf.MessageLite
 import io.grpc.{BindableService, MethodDescriptor, Server, 
ServerMethodDefinition, ServerServiceDefinition}
 import io.grpc.MethodDescriptor.PrototypeMarshaller
@@ -34,13 +31,12 @@ import io.grpc.protobuf.services.ProtoReflectionService
 import io.grpc.stub.StreamObserver
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException}
+import org.apache.spark.{SparkContext, SparkEnv}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AddArtifactsRequest, 
AddArtifactsResponse, SparkConnectServiceGrpc}
 import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.UI.UI_ENABLED
-import org.apache.spark.sql.SparkSession
 import 
org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, 
CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, 
CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE}
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, 
SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -201,6 +197,22 @@ class SparkConnectService(debug: Boolean) extends 
AsyncService with BindableServ
         sessionId = request.getSessionId)
   }
 
+  /**
+   * Release session.
+   */
+  override def releaseSession(
+      request: proto.ReleaseSessionRequest,
+      responseObserver: StreamObserver[proto.ReleaseSessionResponse]): Unit = {
+    try {
+      new SparkConnectReleaseSessionHandler(responseObserver).handle(request)
+    } catch
+      ErrorUtils.handleError(
+        "releaseSession",
+        observer = responseObserver,
+        userId = request.getUserContext.getUserId,
+        sessionId = request.getSessionId)
+  }
+
   override def fetchErrorDetails(
       request: proto.FetchErrorDetailsRequest,
       responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]): Unit 
= {
@@ -268,14 +280,6 @@ class SparkConnectService(debug: Boolean) extends 
AsyncService with BindableServ
  */
 object SparkConnectService extends Logging {
 
-  private val CACHE_SIZE = 100
-
-  private val CACHE_TIMEOUT_SECONDS = 3600
-
-  // Type alias for the SessionCacheKey. Right now this is a String but allows 
us to switch to a
-  // different or complex type easily.
-  private type SessionCacheKey = (String, String)
-
   private[connect] var server: Server = _
 
   private[connect] var uiTab: Option[SparkConnectServerTab] = None
@@ -289,77 +293,18 @@ object SparkConnectService extends Logging {
     server.getPort
   }
 
-  private val userSessionMapping =
-    cacheBuilder(CACHE_SIZE, CACHE_TIMEOUT_SECONDS).build[SessionCacheKey, 
SessionHolder]()
-
   private[connect] lazy val executionManager = new 
SparkConnectExecutionManager()
 
+  private[connect] lazy val sessionManager = new SparkConnectSessionManager()
+
   private[connect] val streamingSessionManager =
     new SparkConnectStreamingQueryCache()
 
-  private class RemoveSessionListener extends RemovalListener[SessionCacheKey, 
SessionHolder] {
-    override def onRemoval(
-        notification: RemovalNotification[SessionCacheKey, SessionHolder]): 
Unit = {
-      notification.getValue.expireSession()
-    }
-  }
-
-  // Simple builder for creating the cache of Sessions.
-  private def cacheBuilder(cacheSize: Int, timeoutSeconds: Int): 
CacheBuilder[Object, Object] = {
-    var cacheBuilder = CacheBuilder.newBuilder().ticker(Ticker.systemTicker())
-    if (cacheSize >= 0) {
-      cacheBuilder = cacheBuilder.maximumSize(cacheSize)
-    }
-    if (timeoutSeconds >= 0) {
-      cacheBuilder.expireAfterAccess(timeoutSeconds, TimeUnit.SECONDS)
-    }
-    cacheBuilder.removalListener(new RemoveSessionListener)
-    cacheBuilder
-  }
-
   /**
    * Based on the userId and sessionId, find or create a new SparkSession.
    */
   def getOrCreateIsolatedSession(userId: String, sessionId: String): 
SessionHolder = {
-    getSessionOrDefault(
-      userId,
-      sessionId,
-      () => {
-        val holder = SessionHolder(userId, sessionId, newIsolatedSession())
-        holder.initializeSession()
-        holder
-      })
-  }
-
-  /**
-   * Based on the userId and sessionId, find an existing SparkSession or throw 
error.
-   */
-  def getIsolatedSession(userId: String, sessionId: String): SessionHolder = {
-    getSessionOrDefault(
-      userId,
-      sessionId,
-      () => {
-        logDebug(s"Session not found: ($userId, $sessionId)")
-        throw new SparkSQLException(
-          errorClass = "INVALID_HANDLE.SESSION_NOT_FOUND",
-          messageParameters = Map("handle" -> sessionId))
-      })
-  }
-
-  private def getSessionOrDefault(
-      userId: String,
-      sessionId: String,
-      default: Callable[SessionHolder]): SessionHolder = {
-    // Validate that sessionId is formatted like UUID before creating session.
-    try {
-      UUID.fromString(sessionId).toString
-    } catch {
-      case _: IllegalArgumentException =>
-        throw new SparkSQLException(
-          errorClass = "INVALID_HANDLE.FORMAT",
-          messageParameters = Map("handle" -> sessionId))
-    }
-    userSessionMapping.get((userId, sessionId), default)
+    sessionManager.getOrCreateIsolatedSession(SessionKey(userId, sessionId))
   }
 
   /**
@@ -368,24 +313,6 @@ object SparkConnectService extends Logging {
    */
   def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = 
executionManager.listActiveExecutions
 
-  /**
-   * Used for testing
-   */
-  private[connect] def invalidateAllSessions(): Unit = {
-    userSessionMapping.invalidateAll()
-  }
-
-  /**
-   * Used for testing.
-   */
-  private[connect] def putSessionForTesting(sessionHolder: SessionHolder): 
Unit = {
-    userSessionMapping.put((sessionHolder.userId, sessionHolder.sessionId), 
sessionHolder)
-  }
-
-  private def newIsolatedSession(): SparkSession = {
-    SparkSession.active.newSession()
-  }
-
   private def createListenerAndUI(sc: SparkContext): Unit = {
     val kvStore = sc.statusStore.store.asInstanceOf[ElementTrackingStore]
     listener = new SparkConnectServerListener(kvStore, sc.conf)
@@ -445,7 +372,7 @@ object SparkConnectService extends Logging {
     }
     streamingSessionManager.shutdown()
     executionManager.shutdown()
-    userSessionMapping.invalidateAll()
+    sessionManager.shutdown()
     uiTab.foreach(_.detach())
   }
 
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
new file mode 100644
index 000000000000..5c8e3c611586
--- /dev/null
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+import java.util.concurrent.{Callable, TimeUnit}
+
+import com.google.common.base.Ticker
+import com.google.common.cache.{CacheBuilder, RemovalListener, 
RemovalNotification}
+
+import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import 
org.apache.spark.sql.connect.config.Connect.{CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE,
 CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT}
+
+/**
+ * Global tracker of all SessionHolders holding Spark Connect sessions.
+ */
+class SparkConnectSessionManager extends Logging {
+
+  private val sessionsLock = new Object
+
+  private val sessionStore =
+    CacheBuilder
+      .newBuilder()
+      .ticker(Ticker.systemTicker())
+      .expireAfterAccess(
+        SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT),
+        TimeUnit.MILLISECONDS)
+      .removalListener(new RemoveSessionListener)
+      .build[SessionKey, SessionHolder]()
+
+  private val closedSessionsCache =
+    CacheBuilder
+      .newBuilder()
+      
.maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
+      .build[SessionKey, SessionHolderInfo]()
+
+  /**
+   * Based on the userId and sessionId, find or create a new SparkSession.
+   */
+  private[connect] def getOrCreateIsolatedSession(key: SessionKey): 
SessionHolder = {
+    // Lock to guard against concurrent removal and insertion into 
closedSessionsCache.
+    sessionsLock.synchronized {
+      getSession(
+        key,
+        Some(() => {
+          validateSessionCreate(key)
+          val holder = SessionHolder(key.userId, key.sessionId, 
newIsolatedSession())
+          holder.initializeSession()
+          holder
+        }))
+    }
+  }
+
+  /**
+   * Based on the userId and sessionId, find an existing SparkSession or throw 
error.
+   */
+  private[connect] def getIsolatedSession(key: SessionKey): SessionHolder = {
+    getSession(
+      key,
+      Some(() => {
+        logDebug(s"Session not found: $key")
+        if (closedSessionsCache.getIfPresent(key) != null) {
+          throw new SparkSQLException(
+            errorClass = "INVALID_HANDLE.SESSION_CLOSED",
+            messageParameters = Map("handle" -> key.sessionId))
+        } else {
+          throw new SparkSQLException(
+            errorClass = "INVALID_HANDLE.SESSION_NOT_FOUND",
+            messageParameters = Map("handle" -> key.sessionId))
+        }
+      }))
+  }
+
+  /**
+   * Based on the userId and sessionId, get an existing SparkSession if 
present.
+   */
+  private[connect] def getIsolatedSessionIfPresent(key: SessionKey): 
Option[SessionHolder] = {
+    Option(getSession(key, None))
+  }
+
+  private def getSession(
+      key: SessionKey,
+      default: Option[Callable[SessionHolder]]): SessionHolder = {
+    val session = default match {
+      case Some(callable) => sessionStore.get(key, callable)
+      case None => sessionStore.getIfPresent(key)
+    }
+    // record access time before returning
+    session match {
+      case null =>
+        null
+      case s: SessionHolder =>
+        s.updateAccessTime()
+        s
+    }
+  }
+
+  def closeSession(key: SessionKey): Unit = {
+    // Invalidate will trigger RemoveSessionListener
+    sessionStore.invalidate(key)
+  }
+
+  private class RemoveSessionListener extends RemovalListener[SessionKey, 
SessionHolder] {
+    override def onRemoval(notification: RemovalNotification[SessionKey, 
SessionHolder]): Unit = {
+      val sessionHolder = notification.getValue
+      sessionsLock.synchronized {
+        // First put into closedSessionsCache, so that it cannot get 
accidentally recreated by
+        // getOrCreateIsolatedSession.
+        closedSessionsCache.put(sessionHolder.key, 
sessionHolder.getSessionHolderInfo)
+      }
+      // The rest of the cleanup happens outside sessionsLock - the session can no longer be accessed by
+      // getOrCreateIsolatedSession.
+      sessionHolder.close()
+    }
+  }
+
+  def shutdown(): Unit = {
+    sessionsLock.synchronized {
+      sessionStore.invalidateAll()
+      closedSessionsCache.invalidateAll()
+    }
+  }
+
+  private def newIsolatedSession(): SparkSession = {
+    SparkSession.active.newSession()
+  }
+
+  private def validateSessionCreate(key: SessionKey): Unit = {
+    // Validate that sessionId is formatted like UUID before creating session.
+    try {
+      UUID.fromString(key.sessionId).toString
+    } catch {
+      case _: IllegalArgumentException =>
+        throw new SparkSQLException(
+          errorClass = "INVALID_HANDLE.FORMAT",
+          messageParameters = Map("handle" -> key.sessionId))
+    }
+    // Validate that the session with that key has not already been closed.
+    if (closedSessionsCache.getIfPresent(key) != null) {
+      throw new SparkSQLException(
+        errorClass = "INVALID_HANDLE.SESSION_CLOSED",
+        messageParameters = Map("handle" -> key.sessionId))
+    }
+  }
+
+  /**
+   * Used for testing
+   */
+  private[connect] def invalidateAllSessions(): Unit = {
+    sessionStore.invalidateAll()
+    closedSessionsCache.invalidateAll()
+  }
+
+  /**
+   * Used for testing.
+   */
+  private[connect] def putSessionForTesting(sessionHolder: SessionHolder): 
Unit = {
+    sessionStore.put(sessionHolder.key, sessionHolder)
+  }
+}
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 741fa97f1787..837ee5a00227 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -41,7 +41,7 @@ import org.apache.spark.api.python.PythonException
 import org.apache.spark.connect.proto.FetchErrorDetailsResponse
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.config.Connect
-import org.apache.spark.sql.connect.service.{ExecuteEventsManager, 
SessionHolder, SparkConnectService}
+import org.apache.spark.sql.connect.service.{ExecuteEventsManager, 
SessionHolder, SessionKey, SparkConnectService}
 import org.apache.spark.sql.internal.SQLConf
 
 private[connect] object ErrorUtils extends Logging {
@@ -153,7 +153,9 @@ private[connect] object ErrorUtils extends Logging {
       .build()
   }
 
-  private def buildStatusFromThrowable(st: Throwable, sessionHolder: 
SessionHolder): RPCStatus = {
+  private def buildStatusFromThrowable(
+      st: Throwable,
+      sessionHolderOpt: Option[SessionHolder]): RPCStatus = {
     val errorInfo = ErrorInfo
       .newBuilder()
       .setReason(st.getClass.getName)
@@ -162,20 +164,20 @@ private[connect] object ErrorUtils extends Logging {
         "classes",
         
JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))
 
-    if (sessionHolder.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)) {
+    if 
(sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)))
 {
       // Generate a new unique key for this exception.
       val errorId = UUID.randomUUID().toString
 
       errorInfo.putMetadata("errorId", errorId)
 
-      sessionHolder.errorIdToError
+      sessionHolderOpt.get.errorIdToError
         .put(errorId, st)
     }
 
     lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st))
     val withStackTrace =
-      if (sessionHolder.session.conf.get(
-          SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty) {
+      if (sessionHolderOpt.exists(
+          _.session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && 
stackTrace.nonEmpty)) {
         val maxSize = 
SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE)
         errorInfo.putMetadata("stackTrace", 
StringUtils.abbreviate(stackTrace.get, maxSize))
       } else {
@@ -215,19 +217,22 @@ private[connect] object ErrorUtils extends Logging {
       sessionId: String,
       events: Option[ExecuteEventsManager] = None,
       isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = {
-    val sessionHolder =
-      SparkConnectService
-        .getOrCreateIsolatedSession(userId, sessionId)
+
+    // SessionHolder may not be present, e.g. if the session was already 
closed.
+    // When SessionHolder is not present error details will not be available 
for FetchErrorDetails.
+    val sessionHolderOpt =
+      SparkConnectService.sessionManager.getIsolatedSessionIfPresent(
+        SessionKey(userId, sessionId))
 
     val partial: PartialFunction[Throwable, (Throwable, Throwable)] = {
       case se: SparkException if isPythonExecutionException(se) =>
         (
           se,
           StatusProto.toStatusRuntimeException(
-            buildStatusFromThrowable(se.getCause, sessionHolder)))
+            buildStatusFromThrowable(se.getCause, sessionHolderOpt)))
 
       case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) 
=>
-        (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, 
sessionHolder)))
+        (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, 
sessionHolderOpt)))
 
       case e: Throwable =>
         (
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 7b02377f4847..120126f20ec2 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -59,10 +59,6 @@ trait SparkConnectServerTest extends SharedSparkSession {
     withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key, 
serverPort.toString)) {
       SparkConnectService.start(spark.sparkContext)
     }
-    // register udf directly on the server, we're not testing client UDFs 
here...
-    val serverSession =
-      SparkConnectService.getOrCreateIsolatedSession(defaultUserId, 
defaultSessionId).session
-    serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms 
}))
   }
 
   override def afterAll(): Unit = {
@@ -84,6 +80,7 @@ trait SparkConnectServerTest extends SharedSparkSession {
   protected def clearAllExecutions(): Unit = {
     SparkConnectService.executionManager.listExecuteHolders.foreach(_.close())
     SparkConnectService.executionManager.periodicMaintenance(0)
+    SparkConnectService.sessionManager.invalidateAllSessions()
     assertNoActiveExecutions()
   }
 
@@ -215,12 +212,24 @@ trait SparkConnectServerTest extends SharedSparkSession {
     }
   }
 
+  protected def withClient(sessionId: String = defaultSessionId, userId: 
String = defaultUserId)(
+      f: SparkConnectClient => Unit): Unit = {
+    withClient(f, sessionId, userId)
+  }
+
   protected def withClient(f: SparkConnectClient => Unit): Unit = {
+    withClient(f, defaultSessionId, defaultUserId)
+  }
+
+  protected def withClient(
+      f: SparkConnectClient => Unit,
+      sessionId: String,
+      userId: String): Unit = {
     val client = SparkConnectClient
       .builder()
       .port(serverPort)
-      .sessionId(defaultSessionId)
-      .userId(defaultUserId)
+      .sessionId(sessionId)
+      .userId(userId)
       .enableReattachableExecute()
       .build()
     try f(client)
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
index 0e29a07b719a..784b978f447d 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
@@ -347,6 +347,10 @@ class ReattachableExecuteSuite extends 
SparkConnectServerTest {
   }
 
   test("long sleeping query") {
+    // register udf directly on the server, we're not testing client UDFs 
here...
+    val serverSession =
+      SparkConnectService.getOrCreateIsolatedSession(defaultUserId, 
defaultSessionId).session
+    serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms 
}))
     // query will be sleeping and not returning results, while having multiple 
reattach
     withSparkEnvConfs(
       (Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION.key, 
"1s")) {
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index ce452623e6b8..b314e7d8d483 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -841,12 +841,12 @@ class SparkConnectServiceSuite
     spark.sparkContext.addSparkListener(verifyEvents.listener)
     Utils.tryWithSafeFinally({
       f(verifyEvents)
-      SparkConnectService.invalidateAllSessions()
+      SparkConnectService.sessionManager.invalidateAllSessions()
       verifyEvents.onSessionClosed()
     }) {
       verifyEvents.waitUntilEmpty()
       spark.sparkContext.removeSparkListener(verifyEvents.listener)
-      SparkConnectService.invalidateAllSessions()
+      SparkConnectService.sessionManager.invalidateAllSessions()
       SparkConnectPluginRegistry.reset()
     }
   }
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
index 14ecc9a2e95e..cc0481dab0f4 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
@@ -16,13 +16,171 @@
  */
 package org.apache.spark.sql.connect.service
 
+import java.util.UUID
+
 import org.scalatest.concurrent.Eventually
 import org.scalatest.time.SpanSugar._
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.connect.SparkConnectServerTest
 
 class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
 
+  // Make the results of these queries large enough that they do not all fit in the buffers
+  // and are not pushed out immediately even when the client doesn't consume them. Otherwise,
+  // even if the connection got closed, the client could see the query as succeeded because
+  // all the results were already in the buffer.
+  val BIG_ENOUGH_QUERY = "select * from range(1000000)"
+
+  test("ReleaseSession releases all queries and does not allow more requests 
in the session") {
+    withClient { client =>
+      val query1 = client.execute(buildPlan(BIG_ENOUGH_QUERY))
+      val query2 = client.execute(buildPlan(BIG_ENOUGH_QUERY))
+      val query3 = client.execute(buildPlan("select 1"))
+      // Creating the iterator is lazy; calling hasNext triggers query1 and query2 to be sent.
+      query1.hasNext
+      query2.hasNext
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        SparkConnectService.executionManager.listExecuteHolders.length == 2
+      }
+
+      // Close session
+      client.releaseSession()
+
+      // Check that queries get cancelled
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        SparkConnectService.executionManager.listExecuteHolders.length == 0
+        // SparkConnectService.sessionManager.
+      }
+
+      // query1 and query2 could get either an:
+      // OPERATION_CANCELED if it happens fast - when closing the session 
interrupted the queries,
+      // and that error got pushed to the client buffers before the client got 
disconnected.
+      // OPERATION_ABANDONED if it happens slow - when closing the session 
interrupted the client
+      // RPCs before it pushed out the error above. The client would then get 
an
+      // INVALID_CURSOR.DISCONNECTED, which it will retry with a 
ReattachExecute, and then get an
+      // INVALID_HANDLE.OPERATION_ABANDONED.
+      val query1Error = intercept[SparkException] {
+        while (query1.hasNext) query1.next()
+      }
+      assert(
+        query1Error.getMessage.contains("OPERATION_CANCELED") ||
+          
query1Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
+      val query2Error = intercept[SparkException] {
+        while (query2.hasNext) query2.next()
+      }
+      assert(
+        query2Error.getMessage.contains("OPERATION_CANCELED") ||
+          
query2Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
+
+      // query3 has not been submitted before, so it should now fail with 
SESSION_CLOSED
+      val query3Error = intercept[SparkException] {
+        query3.hasNext
+      }
+      assert(query3Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
+
+      // No other requests should be allowed in the session, failing with 
SESSION_CLOSED
+      val requestError = intercept[SparkException] {
+        client.interruptAll()
+      }
+      assert(requestError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
+    }
+  }
+
+  private def testReleaseSessionTwoSessions(
+      sessionIdA: String,
+      userIdA: String,
+      sessionIdB: String,
+      userIdB: String): Unit = {
+    withClient(sessionId = sessionIdA, userId = userIdA) { clientA =>
+      withClient(sessionId = sessionIdB, userId = userIdB) { clientB =>
+        val queryA = clientA.execute(buildPlan(BIG_ENOUGH_QUERY))
+        val queryB = clientB.execute(buildPlan(BIG_ENOUGH_QUERY))
+        // Creating the iterator is lazy; calling hasNext triggers queryA and queryB to be sent.
+        queryA.hasNext
+        queryB.hasNext
+        Eventually.eventually(timeout(eventuallyTimeout)) {
+          SparkConnectService.executionManager.listExecuteHolders.length == 2
+        }
+        // Close session A
+        clientA.releaseSession()
+
+        // A's query gets kicked out.
+        Eventually.eventually(timeout(eventuallyTimeout)) {
+          SparkConnectService.executionManager.listExecuteHolders.length == 1
+        }
+        val queryAError = intercept[SparkException] {
+          while (queryA.hasNext) queryA.next()
+        }
+        assert(
+          queryAError.getMessage.contains("OPERATION_CANCELED") ||
+            
queryAError.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
+
+        // B's query can run.
+        while (queryB.hasNext) queryB.next()
+
+        // B can submit more queries.
+        val queryB2 = clientB.execute(buildPlan("SELECT 1"))
+        while (queryB2.hasNext) queryB2.next()
+        // A can't submit more queries.
+        val queryA2 = clientA.execute(buildPlan("SELECT 1"))
+        val queryA2Error = intercept[SparkException] {
+          clientA.interruptAll()
+        }
+        
assert(queryA2Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
+      }
+    }
+  }
+
+  test("ReleaseSession for different user_id with same session_id do not 
affect each other") {
+    testReleaseSessionTwoSessions(defaultSessionId, "A", defaultSessionId, "B")
+  }
+
+  test("ReleaseSession for different session_id with same user_id do not 
affect each other") {
+    val sessionIdA = UUID.randomUUID.toString()
+    val sessionIdB = UUID.randomUUID.toString()
+    testReleaseSessionTwoSessions(sessionIdA, "X", sessionIdB, "X")
+  }
+
+  test("ReleaseSession: can't create a new session with the same id and user 
after release") {
+    val sessionId = UUID.randomUUID.toString()
+    val userId = "Y"
+    withClient(sessionId = sessionId, userId = userId) { client =>
+      // this will create the session, and then ReleaseSession at the end of withClient.
+      val query = client.execute(buildPlan("SELECT 1"))
+      query.hasNext // trigger execution
+      client.releaseSession()
+    }
+    withClient(sessionId = sessionId, userId = userId) { client =>
+      // shall not be able to create a new session with the same id and user.
+      val query = client.execute(buildPlan("SELECT 1"))
+      val queryError = intercept[SparkException] {
+        while (query.hasNext) query.next()
+      }
+      assert(queryError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
+    }
+  }
+
+  test("ReleaseSession: session with different session_id or user_id allowed 
after release") {
+    val sessionId = UUID.randomUUID.toString()
+    val userId = "Y"
+    withClient(sessionId = sessionId, userId = userId) { client =>
+      val query = client.execute(buildPlan("SELECT 1"))
+      query.hasNext // trigger execution
+      client.releaseSession()
+    }
+    withClient(sessionId = UUID.randomUUID.toString, userId = userId) { client =>
+      val query = client.execute(buildPlan("SELECT 1"))
+      query.hasNext // trigger execution
+      client.releaseSession()
+    }
+    withClient(sessionId = sessionId, userId = "YY") { client =>
+      val query = client.execute(buildPlan("SELECT 1"))
+      query.hasNext // trigger execution
+      client.releaseSession()
+    }
+  }
+
   test("SPARK-45133 query should reach FINISHED state when results are not 
consumed") {
     withRawBlockingStub { stub =>
       val iter =
diff --git a/docs/sql-error-conditions-invalid-handle-error-class.md b/docs/sql-error-conditions-invalid-handle-error-class.md
index c4cbb48035ff..14526cd53724 100644
--- a/docs/sql-error-conditions-invalid-handle-error-class.md
+++ b/docs/sql-error-conditions-invalid-handle-error-class.md
@@ -45,6 +45,10 @@ Operation not found.
 
 Session already exists.
 
+## SESSION_CLOSED
+
+Session was closed.
+
 ## SESSION_NOT_FOUND
 
 Session not found.
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 318f7d7ade4a..11a1112ad1fe 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -19,7 +19,7 @@ __all__ = [
     "SparkConnectClient",
 ]
 
-from pyspark.loose_version import LooseVersion
+
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -61,6 +61,7 @@ import grpc
 from google.protobuf import text_format
 from google.rpc import error_details_pb2
 
+from pyspark.loose_version import LooseVersion
 from pyspark.version import __version__
 from pyspark.resource.information import ResourceInformation
 from pyspark.sql.connect.client.artifact import ArtifactManager
@@ -1471,6 +1472,26 @@ class SparkConnectClient(object):
         except Exception as error:
             self._handle_error(error)
 
+    def release_session(self) -> None:
+        req = pb2.ReleaseSessionRequest()
+        req.session_id = self._session_id
+        req.client_type = self._builder.userAgent
+        if self._user_id:
+            req.user_context.user_id = self._user_id
+        try:
+            for attempt in self._retrying():
+                with attempt:
+                    resp = self._stub.ReleaseSession(req, metadata=self._builder.metadata())
+                    if resp.session_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for 
request:"
+                            f"{resp.session_id} != {self._session_id}"
+                        )
+                    return
+            raise SparkConnectException("Invalid state during retry exception 
handling.")
+        except Exception as error:
+            self._handle_error(error)
+
     def add_tag(self, tag: str) -> None:
         self._throw_if_invalid_tag(tag)
         if not hasattr(self.thread_local, "tags"):
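
For reference, a minimal sketch of how the new release_session() can be exercised from the
Python client; the connect URL "sc://localhost:15002" and the placeholder steps are
assumptions for illustration, not part of this change:

    from pyspark.sql.connect.client import SparkConnectClient

    # Build a client against an assumed local Spark Connect server.
    client = SparkConnectClient("sc://localhost:15002")

    # ... run queries through the client ...

    # Release the session on the server: running executions are cancelled and the
    # (user_id, session_id) pair can no longer be reused for a new session afterwards.
    client.release_session()
    client.close()
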
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 0ea02525f78f..0e374e7aa2cc 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -199,22 +199,26 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11234
     _RELEASEEXECUTERESPONSE._serialized_start = 11263
     _RELEASEEXECUTERESPONSE._serialized_end = 11375
-    _FETCHERRORDETAILSREQUEST._serialized_start = 11378
-    _FETCHERRORDETAILSREQUEST._serialized_end = 11579
-    _FETCHERRORDETAILSRESPONSE._serialized_start = 11582
-    _FETCHERRORDETAILSRESPONSE._serialized_end = 13052
-    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11727
-    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11901
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11904
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12271
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 12234
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12271
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12274
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12683
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12585
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12653
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12686
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13033
-    _SPARKCONNECTSERVICE._serialized_start = 13055
-    _SPARKCONNECTSERVICE._serialized_end = 13904
+    _RELEASESESSIONREQUEST._serialized_start = 11378
+    _RELEASESESSIONREQUEST._serialized_end = 11549
+    _RELEASESESSIONRESPONSE._serialized_start = 11551
+    _RELEASESESSIONRESPONSE._serialized_end = 11606
+    _FETCHERRORDETAILSREQUEST._serialized_start = 11609
+    _FETCHERRORDETAILSREQUEST._serialized_end = 11810
+    _FETCHERRORDETAILSRESPONSE._serialized_start = 11813
+    _FETCHERRORDETAILSRESPONSE._serialized_end = 13283
+    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11958
+    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 12132
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 12135
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12502
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 12465
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12502
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12505
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12914
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12816
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12884
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12917
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13264
+    _SPARKCONNECTSERVICE._serialized_start = 13286
+    _SPARKCONNECTSERVICE._serialized_end = 14232
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index c29feb4164cf..20abbcb348bd 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -2763,6 +2763,84 @@ class ReleaseExecuteResponse(google.protobuf.message.Message):
 
 global___ReleaseExecuteResponse = ReleaseExecuteResponse
 
+class ReleaseSessionRequest(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    SESSION_ID_FIELD_NUMBER: builtins.int
+    USER_CONTEXT_FIELD_NUMBER: builtins.int
+    CLIENT_TYPE_FIELD_NUMBER: builtins.int
+    session_id: builtins.str
+    """(Required)
+
+    The session_id of the request to reattach to.
+    This must be an id of existing session.
+    """
+    @property
+    def user_context(self) -> global___UserContext:
+        """(Required) User context
+
+        user_context.user_id and session+id both identify a unique remote spark session on the
+        server side.
+        """
+    client_type: builtins.str
+    """Provides optional information about the client sending the request. 
This field
+    can be used for language or version specific information and is only 
intended for
+    logging purposes and will not be interpreted by the server.
+    """
+    def __init__(
+        self,
+        *,
+        session_id: builtins.str = ...,
+        user_context: global___UserContext | None = ...,
+        client_type: builtins.str | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_client_type",
+            b"_client_type",
+            "client_type",
+            b"client_type",
+            "user_context",
+            b"user_context",
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_client_type",
+            b"_client_type",
+            "client_type",
+            b"client_type",
+            "session_id",
+            b"session_id",
+            "user_context",
+            b"user_context",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_client_type", 
b"_client_type"]
+    ) -> typing_extensions.Literal["client_type"] | None: ...
+
+global___ReleaseSessionRequest = ReleaseSessionRequest
+
+class ReleaseSessionResponse(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    SESSION_ID_FIELD_NUMBER: builtins.int
+    session_id: builtins.str
+    """Session id of the session on which the release executed."""
+    def __init__(
+        self,
+        *,
+        session_id: builtins.str = ...,
+    ) -> None: ...
+    def ClearField(
+        self, field_name: typing_extensions.Literal["session_id", 
b"session_id"]
+    ) -> None: ...
+
+global___ReleaseSessionResponse = ReleaseSessionResponse
+
 class FetchErrorDetailsRequest(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
index f6c5573ded6b..12675747e0f9 100644
--- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -70,6 +70,11 @@ class SparkConnectServiceStub(object):
             request_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.SerializeToString,
             response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.FromString,
         )
+        self.ReleaseSession = channel.unary_unary(
+            "/spark.connect.SparkConnectService/ReleaseSession",
+            request_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString,
+            response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString,
+        )
         self.FetchErrorDetails = channel.unary_unary(
             "/spark.connect.SparkConnectService/FetchErrorDetails",
             
request_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString,
@@ -141,6 +146,16 @@ class SparkConnectServiceServicer(object):
         context.set_details("Method not implemented!")
         raise NotImplementedError("Method not implemented!")
 
+    def ReleaseSession(self, request, context):
+        """Release a session.
+        All the executions in the session will be released. Any further requests for the session with
+        that session_id for the given user_id will fail. If the session didn't exist or was already
+        released, this is a noop.
+        """
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details("Method not implemented!")
+        raise NotImplementedError("Method not implemented!")
+
     def FetchErrorDetails(self, request, context):
         """FetchErrorDetails retrieves the matched exception with details 
based on a provided error id."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -190,6 +205,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server):
             request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.FromString,
             response_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.SerializeToString,
         ),
+        "ReleaseSession": grpc.unary_unary_rpc_method_handler(
+            servicer.ReleaseSession,
+            request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.FromString,
+            response_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.SerializeToString,
+        ),
         "FetchErrorDetails": grpc.unary_unary_rpc_method_handler(
             servicer.FetchErrorDetails,
             request_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.FromString,
@@ -438,6 +458,35 @@ class SparkConnectService(object):
             metadata,
         )
 
+    @staticmethod
+    def ReleaseSession(
+        request,
+        target,
+        options=(),
+        channel_credentials=None,
+        call_credentials=None,
+        insecure=False,
+        compression=None,
+        wait_for_ready=None,
+        timeout=None,
+        metadata=None,
+    ):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            "/spark.connect.SparkConnectService/ReleaseSession",
+            spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString,
+            spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+        )
+
     @staticmethod
     def FetchErrorDetails(
         request,
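
To make the plumbing above concrete, a hedged sketch of calling the new RPC through the
generated stub directly; the endpoint, session id and user id below are assumptions, and in
practice the high-level Scala and Python clients issue this call for you:

    import grpc
    import pyspark.sql.connect.proto.base_pb2 as pb2
    import pyspark.sql.connect.proto.base_pb2_grpc as pb2_grpc

    channel = grpc.insecure_channel("localhost:15002")
    stub = pb2_grpc.SparkConnectServiceStub(channel)

    req = pb2.ReleaseSessionRequest()
    req.session_id = "00000000-0000-0000-0000-000000000000"  # assumed existing session id
    req.user_context.user_id = "some_user"

    # Releasing a missing or already released session is a no-op per the docstring above.
    resp = stub.ReleaseSession(req)
    assert resp.session_id == req.session_id  # the server echoes the released session id
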
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 09bd60606c76..1aa857b4f617 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -254,6 +254,9 @@ class SparkSession:
         self._client = SparkConnectClient(connection=connection, user_id=userId)
         self._session_id = self._client._session_id
 
+        # Set to false to prevent client.release_session on close() (testing only)
+        self.release_session_on_close = True
+
     @classmethod
     def _set_default_and_active_session(cls, session: "SparkSession") -> None:
         """
@@ -645,15 +648,16 @@ class SparkSession:
     clearTags.__doc__ = PySparkSession.clearTags.__doc__
 
     def stop(self) -> None:
-        # Stopping the session will only close the connection to the current session (and
-        # the life cycle of the session is maintained by the server),
-        # whereas the regular PySpark session immediately terminates the Spark Context
-        # itself, meaning that stopping all Spark sessions.
+        # Whereas the regular PySpark session immediately terminates the Spark Context
+        # itself, thereby stopping all Spark sessions, this will only stop this one session
+        # on the server.
         # It is controversial to follow the existing the regular Spark session's behavior
         # specifically in Spark Connect the Spark Connect server is designed for
         # multi-tenancy - the remote client side cannot just stop the server and stop
         # other remote clients being used from other users.
         with SparkSession._lock:
+            if not self.is_stopped and self.release_session_on_close:
+                self.client.release_session()
             self.client.close()
             if self is SparkSession._default_session:
                 SparkSession._default_session = None
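
A small usage sketch of the changed stop() semantics above (the remote URL is an assumption);
by default stop() now releases the server-side session as well, and the testing-only flag
added above can opt out of that:

    from pyspark.sql.connect.session import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    spark.sql("SELECT 1").collect()

    # stop() calls client.release_session() before closing the channel, so views,
    # artifacts and running executions of this session are removed on the server.
    spark.stop()

    # For tests that use a dummy connection, the release can be skipped:
    # spark.release_session_on_close = False
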
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 34bd314c76f7..f024a03c2686 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3437,6 +3437,7 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
         # Gets currently active session.
         same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate()
         self.assertEquals(other, same)
+        same.release_session_on_close = False  # avoid sending release to dummy connection
         same.stop()
 
         # Make sure the environment is clean.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
