This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5d6e9dd6b121 [SPARK-47986][CONNECT][FOLLOW-UP] Unable to create a new 
session when the default session is closed by the server
5d6e9dd6b121 is described below

commit 5d6e9dd6b1212823dd3aa148935723151027f911
Author: Changgyoo Park <[email protected]>
AuthorDate: Thu Jun 20 08:49:43 2024 +0900

    [SPARK-47986][CONNECT][FOLLOW-UP] Unable to create a new session when the 
default session is closed by the server
    
    ### What changes were proposed in this pull request?
    
    This is a Scala port of https://github.com/apache/spark/pull/46221 and 
https://github.com/apache/spark/pull/46435.
    
    A client is unaware of a server restart or the server having closed the 
client until it receives an error. However, at this point, the client is unable 
to create a new session to the same connect endpoint, since the stale session 
is still recorded
    as the active and default session.
    
    With this change, when the server communicates that the session has changed 
via a GRPC error, the session and the respective client are marked as stale, 
thereby allowing a new default connection to be created via the session 
builder.
    
    In some cases, particularly when running older versions of the Spark 
cluster (3.5), the error actually manifests as a mismatch in the observed 
server-side session id between calls. With this fix, we also capture this case 
and ensure that this case is
    also handled.
    
    ### Why are the changes needed?
    
    Being unable to use getOrCreate() after an error is unacceptable and should 
be fixed.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    ./build/sbt testOnly *SparkSessionE2ESuite
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #47008 from changgyoopark-db/SPARK-47986.
    
    Authored-by: Changgyoo Park <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  | 35 ++++++++++++++-----
 .../apache/spark/sql/SparkSessionE2ESuite.scala    | 39 ++++++++++++++++++++++
 .../sql/connect/client/ResponseValidator.scala     | 29 +++++++++++++++-
 .../sql/connect/client/SparkConnectClient.scala    | 11 ++++++
 4 files changed, 105 insertions(+), 9 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 19c5a3f14c64..80336fb1eaea 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -829,10 +829,16 @@ object SparkSession extends Logging {
 
   /**
    * Set the (global) default [[SparkSession]], and (thread-local) active 
[[SparkSession]] when
-   * they are not set yet.
+   * they are not set yet or the associated [[SparkConnectClient]] is unusable.
    */
   private def setDefaultAndActiveSession(session: SparkSession): Unit = {
-    defaultSession.compareAndSet(null, session)
+    val currentDefault = defaultSession.getAcquire
+    if (currentDefault == null || !currentDefault.client.isSessionValid) {
+      // Update `defaultSession` if it is null or the contained session is not 
valid. There is a
+      // chance that the following `compareAndSet` fails if a new default 
session has just been set,
+      // but that does not matter since that event has happened after this 
method was invoked.
+      defaultSession.compareAndSet(currentDefault, session)
+    }
     if (getActiveSession.isEmpty) {
       setActiveSession(session)
     }
@@ -972,7 +978,7 @@ object SparkSession extends Logging {
     def appName(name: String): Builder = this
 
     private def tryCreateSessionFromClient(): Option[SparkSession] = {
-      if (client != null) {
+      if (client != null && client.isSessionValid) {
         Option(new SparkSession(client, planIdGenerator))
       } else {
         None
@@ -1024,7 +1030,16 @@ object SparkSession extends Logging {
      */
     def getOrCreate(): SparkSession = {
       val session = tryCreateSessionFromClient()
-        .getOrElse(sessions.get(builder.configuration))
+        .getOrElse({
+          var existingSession = sessions.get(builder.configuration)
+          if (!existingSession.client.isSessionValid) {
+            // If the cached session has become invalid, e.g., due to a server 
restart, the cache
+            // entry is invalidated.
+            sessions.invalidate(builder.configuration)
+            existingSession = sessions.get(builder.configuration)
+          }
+          existingSession
+        })
       setDefaultAndActiveSession(session)
       applyOptions(session)
       session
@@ -1032,11 +1047,13 @@ object SparkSession extends Logging {
   }
 
   /**
-   * Returns the default SparkSession.
+   * Returns the default SparkSession. If the previously set default 
SparkSession becomes
+   * unusable, returns None.
    *
    * @since 3.5.0
    */
-  def getDefaultSession: Option[SparkSession] = Option(defaultSession.get())
+  def getDefaultSession: Option[SparkSession] =
+    Option(defaultSession.get()).filter(_.client.isSessionValid)
 
   /**
    * Sets the default SparkSession.
@@ -1057,11 +1074,13 @@ object SparkSession extends Logging {
   }
 
   /**
-   * Returns the active SparkSession for the current thread.
+   * Returns the active SparkSession for the current thread. If the previously 
set active
+   * SparkSession becomes unusable, returns None.
    *
    * @since 3.5.0
    */
-  def getActiveSession: Option[SparkSession] = 
Option(activeThreadSession.get())
+  def getActiveSession: Option[SparkSession] =
+    Option(activeThreadSession.get()).filter(_.client.isSessionValid)
 
   /**
    * Changes the SparkSession that will be returned in this thread and its 
children when
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
index 203b1295005a..b28aa905c7a2 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
@@ -382,4 +382,43 @@ class SparkSessionE2ESuite extends ConnectFunSuite with 
RemoteSparkSession {
         .create()
     }
   }
+
+  test("SPARK-47986: get or create after session changed") {
+    val remote = s"sc://localhost:$serverPort"
+
+    SparkSession.clearDefaultSession()
+    SparkSession.clearActiveSession()
+
+    val session1 = SparkSession
+      .builder()
+      .remote(remote)
+      .getOrCreate()
+
+    assert(session1 eq SparkSession.getActiveSession.get)
+    assert(session1 eq SparkSession.getDefaultSession.get)
+    assert(session1.range(3).collect().length == 3)
+
+    session1.client.hijackServerSideSessionIdForTesting("-testing")
+
+    val e = intercept[SparkException] {
+      session1.range(3).analyze
+    }
+
+    assert(e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]"))
+    assert(!session1.client.isSessionValid)
+    assert(SparkSession.getActiveSession.isEmpty)
+    assert(SparkSession.getDefaultSession.isEmpty)
+
+    val session2 = SparkSession
+      .builder()
+      .remote(remote)
+      .getOrCreate()
+
+    assert(session1 ne session2)
+    assert(session2.client.isSessionValid)
+    assert(session2 eq SparkSession.getActiveSession.get)
+    assert(session2 eq SparkSession.getDefaultSession.get)
+    assert(session2.range(3).collect().length == 3)
+  }
+
 }
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
index 29272c96132b..42c3387335be 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
@@ -16,7 +16,10 @@
  */
 package org.apache.spark.sql.connect.client
 
+import java.util.concurrent.atomic.AtomicBoolean
+
 import com.google.protobuf.GeneratedMessageV3
+import io.grpc.{Status, StatusRuntimeException}
 import io.grpc.stub.StreamObserver
 
 import org.apache.spark.internal.Logging
@@ -30,6 +33,12 @@ class ResponseValidator extends Logging {
   // do not use server-side streaming.
   private var serverSideSessionId: Option[String] = None
 
+  // Indicates whether the client and the client information on the server 
correspond to each other.
+  // This flag being false means that the server has restarted and lost the 
client information, or
+  // there is a logic error in the code; in both cases, the user should 
establish a new connection to
+  // the server. Access to the value has to be synchronized since it can be 
shared.
+  private val isSessionActive: AtomicBoolean = new AtomicBoolean(true)
+
   // Returns the server side session ID, used to send it back to the server in 
the follow-up
   // requests so the server can validate it session id against the previous 
requests.
   def getServerSideSessionId: Option[String] = serverSideSessionId
@@ -42,8 +51,25 @@ class ResponseValidator extends Logging {
     serverSideSessionId = Some(serverSideSessionId.getOrElse("") + suffix)
   }
 
+  /**
+   * Returns true if the session is valid on both the client and the server.
+   */
+  private[sql] def isSessionValid: Boolean = {
+    // An active session is considered valid.
+    isSessionActive.getAcquire
+  }
+
   def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
-    val response = fn
+    val response =
+      try {
+        fn
+      } catch {
+        case e: StatusRuntimeException
+            if e.getStatus.getCode == Status.Code.INTERNAL &&
+              e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]") =>
+          isSessionActive.setRelease(false)
+          throw e
+      }
     val field = 
response.getDescriptorForType.findFieldByName("server_side_session_id")
     // If the field does not exist, we ignore it. New / Old message might not 
contain it and this
     // behavior allows us to be compatible.
@@ -54,6 +80,7 @@ class ResponseValidator extends Logging {
         serverSideSessionId match {
           case Some(id) =>
             if (value != id) {
+              isSessionActive.setRelease(false)
               throw new IllegalStateException(
                 s"Server side session ID changed from $id to $value")
             }
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index b5eda024bfb3..7c3108fdb1b0 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -71,6 +71,17 @@ private[sql] class SparkConnectClient(
     stubState.responseValidator.hijackServerSideSessionIdForTesting(suffix)
   }
 
+  /**
+   * Returns true if the session is valid on both the client and the server. A 
session becomes
+   * invalid if the server side information about the client, e.g., session 
ID, does not
+   * correspond to the actual client state.
+   */
+  private[sql] def isSessionValid: Boolean = {
+    // The last known state of the session is stored in `responseValidator`, 
because it is where the
+    // client gets responses from the server.
+    stubState.responseValidator.isSessionValid
+  }
+
   private[sql] val artifactManager: ArtifactManager = {
     new ArtifactManager(configuration, sessionId, bstub, stub)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to