This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9087d84b5127 [SPARK-50710][CONNECT] Add support for optional client reconnection to sessions after release
9087d84b5127 is described below

commit 9087d84b51278fc8e7429ab5ec27b73254cf60ad
Author: vicennial <[email protected]>
AuthorDate: Fri Jan 3 09:12:57 2025 +0900

    [SPARK-50710][CONNECT] Add support for optional client reconnection to sessions after release
    
    ### What changes were proposed in this pull request?
    
    Adds a new boolean `allow_reconnect` field to `ReleaseSessionRequest`.
    When it is set to `true`, the server does not place the released session in the `closedSessionsCache` of `SparkConnectSessionManager`.
    The session's clean-up process is unchanged.
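The server-side effect of the flag can be sketched with a simplified, hypothetical Python model. `SessionManager`, `get_or_create`, `close_session` and `SessionClosedError` are illustrative stand-ins (not Spark APIs) for `SparkConnectSessionManager`, `getOrCreateIsolatedSession`, `closeSession` and the `INVALID_HANDLE.SESSION_CLOSED` error; the real `closedSessionsCache` is size-bounded and stores `SessionHolderInfo` rather than an id.

```python
import uuid
from dataclasses import dataclass, field


class SessionClosedError(Exception):
    """Hypothetical stand-in for INVALID_HANDLE.SESSION_CLOSED."""


@dataclass
class SessionManager:
    # (user_id, session_id) -> server-side session id of the live session.
    sessions: dict = field(default_factory=dict)
    # Tombstones of released sessions (unbounded here, bounded in Spark).
    closed_sessions_cache: dict = field(default_factory=dict)

    def get_or_create(self, key):
        if key in self.closed_sessions_cache:
            raise SessionClosedError(f"Session {key[1]} was closed.")
        if key not in self.sessions:
            # A fresh session: no temp views, UDFs, or modified configs.
            self.sessions[key] = str(uuid.uuid4())
        return self.sessions[key]

    def close_session(self, key, allow_reconnect=False):
        server_side_id = self.sessions.pop(key, None)
        if server_side_id is not None and not allow_reconnect:
            # Default path: tombstone the key so it cannot be recreated.
            self.closed_sessions_cache[key] = server_side_id
        # Releasing the session's state would happen here in both cases.
```

In this model, as with the `if (!allowReconnect)` guards added to `removeSessionHolder` and `shutdownSessionHolder`, the flag only decides whether a tombstone is written; a reconnect after `allow_reconnect=True` gets a new server-side id and none of the old state.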
    
    ### Why are the changes needed?
    
    Currently, the Connect server by default tombstones every session that has either been released explicitly (through a `ReleaseSession` request) or cleaned up by the periodic inactivity (idleness) check.
    
    Tombstoning prevents clients from reconnecting with the same `userId` and `sessionId`. This ensures that clients do not accidentally end up with a 'fresh' server-side session, which can be disastrous because all previously held state is lost (e.g., temporary views, temporary UDFs, modified configs, the current catalog).
    
    Consider a client that runs simple, state-independent queries (e.g., `select count(*) from ...`), perhaps only sporadically during the lifetime of the application. Such a client may prefer to opt out of tombstoning for the following reasons:
    
    - Queries do not depend on any custom server-side state.
    - Modifying `userId`/`sessionId` on each reconnect may be inconvenient for tracking/observability purposes.
    - On resource-constrained servers, clients may want to minimize their memory footprint by explicitly releasing their state, especially when they expect their requests to be sparse.
    
    Currently, the only way to allow clients to reconnect is to set `spark.connect.session.manager.closedSessionsTombstonesSize` to `0`. However, this is not ideal because it allows all clients to reconnect, which, as pointed out above, may be dangerous.
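The difference between the global setting and a per-request opt-out can be illustrated with a hypothetical size-bounded tombstone store. `TombstoneCache` and `release` are illustrative names, and the oldest-first eviction is a simplification of the real cache. With a capacity of `0`, no tombstone survives, so every released session can be recreated; with a per-request flag, only the client that asked for it skips the tombstone.

```python
from collections import OrderedDict


class TombstoneCache:
    """Hypothetical size-bounded tombstone store. `capacity` plays the role
    of spark.connect.session.manager.closedSessionsTombstonesSize."""

    def __init__(self, capacity):
        self.capacity = capacity
        self._entries = OrderedDict()

    def put(self, key, info):
        self._entries[key] = info
        self._entries.move_to_end(key)
        # Evict the oldest entries once the capacity is exceeded.
        while len(self._entries) > self.capacity:
            self._entries.popitem(last=False)

    def __contains__(self, key):
        return key in self._entries


def release(cache, key, allow_reconnect=False):
    # Per-request opt-out: only this client's key skips the tombstone.
    if not allow_reconnect:
        cache.put(key, "closed")
```

With `TombstoneCache(capacity=0)` every `release` call leaves nothing behind, which is exactly the "all clients may reconnect" situation; a non-zero capacity combined with `allow_reconnect=True` confines that relaxation to one client.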
    
    Allowing individual clients to explicitly request the ability to reconnect addresses the needs above while keeping the safer default for all other clients.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    When a client releases a session with `allow_reconnect` set to `true`, a subsequent reconnection causes the server to generate a fresh session instead of failing with an error like `[INVALID_HANDLE.SESSION_CLOSED] The handle 271dab46-a9a0-4458-ad3a-71442eaa9a21 is invalid. Session was closed. SQLSTATE: HY000`.
    
    Full example (gRPC-based):
    
    #### Default/`allow_reconnect` set to `false`
    
    Create a session via a `Config` request:
    ```gRPC
    {
        "operation": {
            "get": {
                "keys": ["spark.sql.ansi.enabled"]
            }
        },
        "session_id": "271dab46-a9a0-4458-ad3a-71442eaa9a21",
        "user_context": {
            "user_id": "vicennial",
            "user_name": "Akhil"
        }
    }
    ```
    
    Release session via `ReleaseSession` request:
    ```
    {
        "session_id": "271dab46-a9a0-4458-ad3a-71442eaa9a21",
        "user_context": {
            "user_id": "vicennial",
            "user_name": "Akhil"
        },
        "allow_reconnect": false
    }
    ```
    
    Retrying the earlier config request fails with the error `[INVALID_HANDLE.SESSION_CLOSED] The handle 271dab46-a9a0-4458-ad3a-71442eaa9a21 is invalid. Session was closed. SQLSTATE: HY000`.
    
    #### `allow_reconnect` set to `true`
    
    Create a session via a `Config` request:
    ```gRPC
    {
        "operation": {
            "get": {
                "keys": ["spark.sql.ansi.enabled"]
            }
        },
        "session_id": "ff1410b9-0a75-4634-820a-b46c14f30896",
        "user_context": {
            "user_id": "vicennial",
            "user_name": "Akhil"
        }
    }
    ```
    
    Release session via `ReleaseSession` request:
    ```
    {
        "session_id": "ff1410b9-0a75-4634-820a-b46c14f30896",
        "user_context": {
            "user_id": "vicennial",
            "user_name": "Akhil"
        },
        "allow_reconnect": true
    }
    ```
    
    Retrying the earlier config request succeeds. The `server_side_session_id` in the response to the retried request differs from the one in the first response, because the server generated a new server-side session.
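The two flows can be replayed against a hypothetical in-memory fake that accepts the same JSON-shaped payloads (`FakeConnectServer` and `replay` are illustrative, not part of Spark Connect or its Python client). It also shows a proto3 detail: `allow_reconnect` is declared as a plain, non-`optional` `bool`, so a request that omits it, such as one from a client built before this field existed, is read as `false` and keeps the default tombstoning behavior.

```python
import uuid


class FakeConnectServer:
    """Hypothetical in-memory stand-in for the Connect gRPC endpoint."""

    def __init__(self):
        self.live = {}  # (user_id, session_id) -> server_side_session_id
        self.tombstones = set()

    @staticmethod
    def _key(request):
        return (request["user_context"]["user_id"], request["session_id"])

    def config(self, request):
        key = self._key(request)
        if key in self.tombstones:
            raise RuntimeError("[INVALID_HANDLE.SESSION_CLOSED] Session was closed.")
        self.live.setdefault(key, str(uuid.uuid4()))
        return {"session_id": request["session_id"],
                "server_side_session_id": self.live[key]}

    def release_session(self, request):
        key = self._key(request)
        # proto3 scalar default: an omitted allow_reconnect reads as false.
        allow_reconnect = request.get("allow_reconnect", False)
        if self.live.pop(key, None) is not None and not allow_reconnect:
            self.tombstones.add(key)


def replay(session_id, allow_reconnect=None):
    """Config -> ReleaseSession -> retried Config, as in the examples above."""
    server = FakeConnectServer()
    ctx = {"user_id": "vicennial", "user_name": "Akhil"}
    config_req = {"operation": {"get": {"keys": ["spark.sql.ansi.enabled"]}},
                  "session_id": session_id, "user_context": ctx}
    release_req = {"session_id": session_id, "user_context": ctx}
    if allow_reconnect is not None:  # None models a client omitting the field
        release_req["allow_reconnect"] = allow_reconnect
    first = server.config(config_req)
    server.release_session(release_req)
    return first, server.config(config_req)
```

Comparing `server_side_session_id` across responses, as done here, is how a client can tell that it has landed on a fresh session and must not rely on earlier state.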
    
    ### How was this patch tested?
    
    New unit test + existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49342 from vicennial/allowReconnect.
    
    Authored-by: vicennial <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/proto/base_pb2.py       | 48 +++++++++++-----------
 python/pyspark/sql/connect/proto/base_pb2.pyi      | 18 ++++++++
 .../src/main/protobuf/spark/connect/base.proto     | 14 +++++++
 .../SparkConnectReleaseSessionHandler.scala        |  3 +-
 .../service/SparkConnectSessionManager.scala       | 30 +++++++++-----
 .../spark/sql/connect/SparkConnectServerTest.scala | 11 +++++
 .../service/SparkConnectServiceE2ESuite.scala      | 22 ++++++++++
 7 files changed, 110 insertions(+), 36 deletions(-)

diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 97694c33abeb..6e946a5bd4ae 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -43,7 +43,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...]
 )
 
 _globals = globals()
@@ -227,31 +227,31 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_RELEASEEXECUTERESPONSE"]._serialized_start = 13786
     _globals["_RELEASEEXECUTERESPONSE"]._serialized_end = 13951
     _globals["_RELEASESESSIONREQUEST"]._serialized_start = 13954
-    _globals["_RELEASESESSIONREQUEST"]._serialized_end = 14125
-    _globals["_RELEASESESSIONRESPONSE"]._serialized_start = 14127
-    _globals["_RELEASESESSIONRESPONSE"]._serialized_end = 14235
-    _globals["_FETCHERRORDETAILSREQUEST"]._serialized_start = 14238
-    _globals["_FETCHERRORDETAILSREQUEST"]._serialized_end = 14570
-    _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_start = 14573
-    _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_end = 16128
-    _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_start = 14802
-    _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_end = 14976
-    _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_start = 14979
-    _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_end = 15347
-    _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_start = 15310
-    _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_end = 15347
-    _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_start = 15350
-    _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_end = 15759
+    _globals["_RELEASESESSIONREQUEST"]._serialized_end = 14166
+    _globals["_RELEASESESSIONRESPONSE"]._serialized_start = 14168
+    _globals["_RELEASESESSIONRESPONSE"]._serialized_end = 14276
+    _globals["_FETCHERRORDETAILSREQUEST"]._serialized_start = 14279
+    _globals["_FETCHERRORDETAILSREQUEST"]._serialized_end = 14611
+    _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_start = 14614
+    _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_end = 16169
+    _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_start = 14843
+    _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_end = 15017
+    _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_start = 15020
+    _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_end = 15388
+    _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_start = 15351
+    _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_end = 15388
+    _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_start = 15391
+    _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_end = 15800
     _globals[
         "_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
-    ]._serialized_start = 15661
+    ]._serialized_start = 15702
     _globals[
         "_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
-    ]._serialized_end = 15729
-    _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_start = 15762
-    _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_end = 16109
-    _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_start = 16130
-    _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_end = 16220
-    _globals["_SPARKCONNECTSERVICE"]._serialized_start = 16223
-    _globals["_SPARKCONNECTSERVICE"]._serialized_end = 17169
+    ]._serialized_end = 15770
+    _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_start = 15803
+    _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_end = 16150
+    _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_start = 16171
+    _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_end = 16261
+    _globals["_SPARKCONNECTSERVICE"]._serialized_start = 16264
+    _globals["_SPARKCONNECTSERVICE"]._serialized_end = 17210
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 253f8a58166a..fc3a7e804f27 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -3216,6 +3216,7 @@ class ReleaseSessionRequest(google.protobuf.message.Message):
     SESSION_ID_FIELD_NUMBER: builtins.int
     USER_CONTEXT_FIELD_NUMBER: builtins.int
     CLIENT_TYPE_FIELD_NUMBER: builtins.int
+    ALLOW_RECONNECT_FIELD_NUMBER: builtins.int
     session_id: builtins.str
     """(Required)
 
@@ -3234,12 +3235,27 @@ class ReleaseSessionRequest(google.protobuf.message.Message):
     can be used for language or version specific information and is only intended for
     logging purposes and will not be interpreted by the server.
     """
+    allow_reconnect: builtins.bool
+    """Signals the server to allow the client to reconnect to the session after it is released.
+
+    By default, the server tombstones the session upon release, preventing reconnections and
+    fully cleaning the session state.
+
+    If this flag is set to true, the server may permit the client to reconnect to the session
+    post-release, even if the session state has been cleaned. This can result in missing state,
+    such as Temporary Views, Temporary UDFs, or the Current Catalog, in the reconnected session.
+
+    Use this option sparingly and only when the client fully understands the implications of
+    reconnecting to a released session. The client must ensure that any queries executed do not
+    rely on the session state prior to its release.
+    """
     def __init__(
         self,
         *,
         session_id: builtins.str = ...,
         user_context: global___UserContext | None = ...,
         client_type: builtins.str | None = ...,
+        allow_reconnect: builtins.bool = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -3257,6 +3273,8 @@ class ReleaseSessionRequest(google.protobuf.message.Message):
         field_name: typing_extensions.Literal[
             "_client_type",
             b"_client_type",
+            "allow_reconnect",
+            b"allow_reconnect",
             "client_type",
             b"client_type",
             "session_id",
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/base.proto b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
index e27049d2114d..c308c7e21b66 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -924,6 +924,20 @@ message ReleaseSessionRequest {
   // can be used for language or version specific information and is only intended for
   // logging purposes and will not be interpreted by the server.
   optional string client_type = 3;
+
+  // Signals the server to allow the client to reconnect to the session after it is released.
+  //
+  // By default, the server tombstones the session upon release, preventing reconnections and
+  // fully cleaning the session state.
+  //
+  // If this flag is set to true, the server may permit the client to reconnect to the session
+  // post-release, even if the session state has been cleaned. This can result in missing state,
+  // such as Temporary Views, Temporary UDFs, or the Current Catalog, in the reconnected session.
+  //
+  // Use this option sparingly and only when the client fully understands the implications of
+  // reconnecting to a released session. The client must ensure that any queries executed do not
+  // rely on the session state prior to its release.
+  bool allow_reconnect = 4;
 }
 
 // Next ID: 3
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala
index ec7a7f3bd242..c36f07fc67f8 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala
@@ -37,7 +37,8 @@ class SparkConnectReleaseSessionHandler(
     val maybeSession = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key)
     maybeSession.foreach(f => responseBuilder.setServerSideSessionId(f.serverSessionId))
 
-    SparkConnectService.sessionManager.closeSession(key)
+    val allowReconnect = v.getAllowReconnect
+    SparkConnectService.sessionManager.closeSession(key, allowReconnect)
 
     responseObserver.onNext(responseBuilder.build())
     responseObserver.onCompleted()
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index b0b74a36e187..c59fd02a829a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -134,7 +134,9 @@ class SparkConnectSessionManager extends Logging {
   }
 
   // Removes session from sessionStore and returns it.
-  private def removeSessionHolder(key: SessionKey): Option[SessionHolder] = {
+  private def removeSessionHolder(
+      key: SessionKey,
+      allowReconnect: Boolean = false): Option[SessionHolder] = {
     var sessionHolder: Option[SessionHolder] = None
 
     // The session holder should remain in the session store until it is added to the closed session
@@ -144,9 +146,11 @@ class SparkConnectSessionManager extends Logging {
     sessionHolder = Option(sessionStore.get(key))
 
     sessionHolder.foreach { s =>
-      // Put into closedSessionsCache to prevent the same session from being recreated by
-      // getOrCreateIsolatedSession.
-      closedSessionsCache.put(s.key, s.getSessionHolderInfo)
+      if (!allowReconnect) {
+        // Put into closedSessionsCache to prevent the same session from being recreated by
+        // getOrCreateIsolatedSession when reconnection isn't allowed.
+        closedSessionsCache.put(s.key, s.getSessionHolderInfo)
+      }
 
       // Then, remove the session holder from the session store.
       sessionStore.remove(key)
@@ -154,17 +158,21 @@ class SparkConnectSessionManager extends Logging {
     sessionHolder
   }
 
-  // Shut downs the session after removing.
-  private def shutdownSessionHolder(sessionHolder: SessionHolder): Unit = {
+  // Shuts down the session after removing.
+  private def shutdownSessionHolder(
+      sessionHolder: SessionHolder,
+      allowReconnect: Boolean = false): Unit = {
     sessionHolder.close()
-    // Update in closedSessionsCache: above it wasn't updated with closedTime etc. yet.
-    closedSessionsCache.put(sessionHolder.key, sessionHolder.getSessionHolderInfo)
+    if (!allowReconnect) {
+      // Update in closedSessionsCache: above it wasn't updated with closedTime etc. yet.
+      closedSessionsCache.put(sessionHolder.key, sessionHolder.getSessionHolderInfo)
+    }
   }
 
-  def closeSession(key: SessionKey): Unit = {
-    val sessionHolder = removeSessionHolder(key)
+  def closeSession(key: SessionKey, allowReconnect: Boolean = false): Unit = {
+    val sessionHolder = removeSessionHolder(key, allowReconnect)
     // Rest of the cleanup: the session cannot be accessed anymore by getOrCreateIsolatedSession.
-    sessionHolder.foreach(shutdownSessionHolder(_))
+    sessionHolder.foreach(shutdownSessionHolder(_, allowReconnect))
   }
 
   private[connect] def shutdown(): Unit = {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index b04c42a73078..3c857554dc75 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -128,6 +128,17 @@ trait SparkConnectServerTest extends SharedSparkSession {
     req.build()
   }
 
+  protected def buildReleaseSessionRequest(
+      sessionId: String = defaultSessionId,
+      allowReconnect: Boolean = false) = {
+    proto.ReleaseSessionRequest
+      .newBuilder()
+      .setUserContext(userContext)
+      .setSessionId(sessionId)
+      .setAllowReconnect(allowReconnect)
+      .build()
+  }
+
   protected def buildPlan(query: String) = {
     proto.Plan.newBuilder().setRoot(dsl.sql(query)).build()
   }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
index f86298a8b5b9..f24560259a88 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
@@ -245,4 +245,26 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
       assert(queryError.getMessage.contains("INVALID_HANDLE.SESSION_CHANGED"))
     }
   }
+
+  test("Client is allowed to reconnect to released session if allow_reconnect is set") {
+    withRawBlockingStub { stub =>
+      val sessionId = UUID.randomUUID.toString()
+      val iter =
+        stub.executePlan(
+          buildExecutePlanRequest(
+            buildPlan("select * from range(1000000)"),
+            sessionId = sessionId))
+      iter.hasNext // guarantees the request was received by server.
+
+      stub.releaseSession(buildReleaseSessionRequest(sessionId, allowReconnect = true))
+
+      val iter2 =
+        stub.executePlan(
+          buildExecutePlanRequest(
+            buildPlan("select * from range(1000000)"),
+            sessionId = sessionId))
+      // guarantees the request was received by server. No exception should be thrown on reuse
+      iter2.hasNext
+    }
+  }
 }

