This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5d6e9dd6b121 [SPARK-47986][CONNECT][FOLLOW-UP] Unable to create a new
session when the default session is closed by the server
5d6e9dd6b121 is described below
commit 5d6e9dd6b1212823dd3aa148935723151027f911
Author: Changgyoo Park <[email protected]>
AuthorDate: Thu Jun 20 08:49:43 2024 +0900
[SPARK-47986][CONNECT][FOLLOW-UP] Unable to create a new session when the
default session is closed by the server
### What changes were proposed in this pull request?
This is a Scala port of https://github.com/apache/spark/pull/46221 and
https://github.com/apache/spark/pull/46435.
A client is unaware of a server restart or the server having closed the
client until it receives an error. However, at this point, the client is unable
to create a new session to the same connect endpoint, since the stale session
is still recorded
as the active and default session.
With this change, when the server communicates that the session has changed
via a GRPC error, the session and the respective client are marked as stale,
thereby allowing a new default connection to be created via the session
builder.
In some cases, particularly when running older versions of the Spark
cluster (3.5), the error actually manifests as a mismatch in the observed
server-side session id between calls. With this fix, we also capture this case
and ensure that this case is
also handled.
### Why are the changes needed?
Being unable to use getOrCreate() after an error is unacceptable and should
be fixed.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
./build/sbt testOnly *SparkSessionE2ESuite
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47008 from changgyoopark-db/SPARK-47986.
Authored-by: Changgyoo Park <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 35 ++++++++++++++-----
.../apache/spark/sql/SparkSessionE2ESuite.scala | 39 ++++++++++++++++++++++
.../sql/connect/client/ResponseValidator.scala | 29 +++++++++++++++-
.../sql/connect/client/SparkConnectClient.scala | 11 ++++++
4 files changed, 105 insertions(+), 9 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 19c5a3f14c64..80336fb1eaea 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -829,10 +829,16 @@ object SparkSession extends Logging {
/**
* Set the (global) default [[SparkSession]], and (thread-local) active
[[SparkSession]] when
- * they are not set yet.
+ * they are not set yet or the associated [[SparkConnectClient]] is unusable.
*/
private def setDefaultAndActiveSession(session: SparkSession): Unit = {
- defaultSession.compareAndSet(null, session)
+ val currentDefault = defaultSession.getAcquire
+ if (currentDefault == null || !currentDefault.client.isSessionValid) {
+ // Update `defaultSession` if it is null or the contained session is not
valid. There is a
+ // chance that the following `compareAndSet` fails if a new default
session has just been set,
+ // but that does not matter since that event has happened after this
method was invoked.
+ defaultSession.compareAndSet(currentDefault, session)
+ }
if (getActiveSession.isEmpty) {
setActiveSession(session)
}
@@ -972,7 +978,7 @@ object SparkSession extends Logging {
def appName(name: String): Builder = this
private def tryCreateSessionFromClient(): Option[SparkSession] = {
- if (client != null) {
+ if (client != null && client.isSessionValid) {
Option(new SparkSession(client, planIdGenerator))
} else {
None
@@ -1024,7 +1030,16 @@ object SparkSession extends Logging {
*/
def getOrCreate(): SparkSession = {
val session = tryCreateSessionFromClient()
- .getOrElse(sessions.get(builder.configuration))
+ .getOrElse({
+ var existingSession = sessions.get(builder.configuration)
+ if (!existingSession.client.isSessionValid) {
+ // If the cached session has become invalid, e.g., due to a server
restart, the cache
+ // entry is invalidated.
+ sessions.invalidate(builder.configuration)
+ existingSession = sessions.get(builder.configuration)
+ }
+ existingSession
+ })
setDefaultAndActiveSession(session)
applyOptions(session)
session
@@ -1032,11 +1047,13 @@ object SparkSession extends Logging {
}
/**
- * Returns the default SparkSession.
+ * Returns the default SparkSession. If the previously set default
SparkSession becomes
+ * unusable, returns None.
*
* @since 3.5.0
*/
- def getDefaultSession: Option[SparkSession] = Option(defaultSession.get())
+ def getDefaultSession: Option[SparkSession] =
+ Option(defaultSession.get()).filter(_.client.isSessionValid)
/**
* Sets the default SparkSession.
@@ -1057,11 +1074,13 @@ object SparkSession extends Logging {
}
/**
- * Returns the active SparkSession for the current thread.
+ * Returns the active SparkSession for the current thread. If the previously
set active
+ * SparkSession becomes unusable, returns None.
*
* @since 3.5.0
*/
- def getActiveSession: Option[SparkSession] =
Option(activeThreadSession.get())
+ def getActiveSession: Option[SparkSession] =
+ Option(activeThreadSession.get()).filter(_.client.isSessionValid)
/**
* Changes the SparkSession that will be returned in this thread and its
children when
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
index 203b1295005a..b28aa905c7a2 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
@@ -382,4 +382,43 @@ class SparkSessionE2ESuite extends ConnectFunSuite with
RemoteSparkSession {
.create()
}
}
+
+ test("SPARK-47986: get or create after session changed") {
+ val remote = s"sc://localhost:$serverPort"
+
+ SparkSession.clearDefaultSession()
+ SparkSession.clearActiveSession()
+
+ val session1 = SparkSession
+ .builder()
+ .remote(remote)
+ .getOrCreate()
+
+ assert(session1 eq SparkSession.getActiveSession.get)
+ assert(session1 eq SparkSession.getDefaultSession.get)
+ assert(session1.range(3).collect().length == 3)
+
+ session1.client.hijackServerSideSessionIdForTesting("-testing")
+
+ val e = intercept[SparkException] {
+ session1.range(3).analyze
+ }
+
+ assert(e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]"))
+ assert(!session1.client.isSessionValid)
+ assert(SparkSession.getActiveSession.isEmpty)
+ assert(SparkSession.getDefaultSession.isEmpty)
+
+ val session2 = SparkSession
+ .builder()
+ .remote(remote)
+ .getOrCreate()
+
+ assert(session1 ne session2)
+ assert(session2.client.isSessionValid)
+ assert(session2 eq SparkSession.getActiveSession.get)
+ assert(session2 eq SparkSession.getDefaultSession.get)
+ assert(session2.range(3).collect().length == 3)
+ }
+
}
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
index 29272c96132b..42c3387335be 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
@@ -16,7 +16,10 @@
*/
package org.apache.spark.sql.connect.client
+import java.util.concurrent.atomic.AtomicBoolean
+
import com.google.protobuf.GeneratedMessageV3
+import io.grpc.{Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver
import org.apache.spark.internal.Logging
@@ -30,6 +33,12 @@ class ResponseValidator extends Logging {
// do not use server-side streaming.
private var serverSideSessionId: Option[String] = None
+ // Indicates whether the client and the client information on the server
correspond to each other
+ // This flag being false means that the server has restarted and lost the
client information, or
+ // there is a logic error in the code; both cases, the user should establish
a new connection to
+ // the server. Access to the value has to be synchronized since it can be
shared.
+ private val isSessionActive: AtomicBoolean = new AtomicBoolean(true)
+
// Returns the server side session ID, used to send it back to the server in
the follow-up
// requests so the server can validate it session id against the previous
requests.
def getServerSideSessionId: Option[String] = serverSideSessionId
@@ -42,8 +51,25 @@ class ResponseValidator extends Logging {
serverSideSessionId = Some(serverSideSessionId.getOrElse("") + suffix)
}
+ /**
+ * Returns true if the session is valid on both the client and the server.
+ */
+ private[sql] def isSessionValid: Boolean = {
+ // An active session is considered valid.
+ isSessionActive.getAcquire
+ }
+
def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
- val response = fn
+ val response =
+ try {
+ fn
+ } catch {
+ case e: StatusRuntimeException
+ if e.getStatus.getCode == Status.Code.INTERNAL &&
+ e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]") =>
+ isSessionActive.setRelease(false)
+ throw e
+ }
val field =
response.getDescriptorForType.findFieldByName("server_side_session_id")
// If the field does not exist, we ignore it. New / Old message might not
contain it and this
// behavior allows us to be compatible.
@@ -54,6 +80,7 @@ class ResponseValidator extends Logging {
serverSideSessionId match {
case Some(id) =>
if (value != id) {
+ isSessionActive.setRelease(false)
throw new IllegalStateException(
s"Server side session ID changed from $id to $value")
}
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index b5eda024bfb3..7c3108fdb1b0 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -71,6 +71,17 @@ private[sql] class SparkConnectClient(
stubState.responseValidator.hijackServerSideSessionIdForTesting(suffix)
}
+ /**
+ * Returns true if the session is valid on both the client and the server. A
session becomes
+ * invalid if the server side information about the client, e.g., session
ID, does not
+ * correspond to the actual client state.
+ */
+ private[sql] def isSessionValid: Boolean = {
+ // The last known state of the session is store in `responseValidator`,
because it is where the
+ // client gets responses from the server.
+ stubState.responseValidator.isSessionValid
+ }
+
private[sql] val artifactManager: ArtifactManager = {
new ArtifactManager(configuration, sessionId, bstub, stub)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]