HyukjinKwon commented on code in PR #47008:
URL: https://github.com/apache/spark/pull/47008#discussion_r1645304636
##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala:
##########
@@ -1024,7 +1037,14 @@ object SparkSession extends Logging {
*/
def getOrCreate(): SparkSession = {
val session = tryCreateSessionFromClient()
- .getOrElse(sessions.get(builder.configuration))
+ .getOrElse({
+ var existingSession = sessions.get(builder.configuration)
Review Comment:
Do we need a lock here for `sessions`?
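Assuming the elided lines check the cached session with `hasSessionChanged` and replace it, the lookup, check and replacement form a check-then-act sequence. Even if `sessions` is thread-safe for single calls, two threads could both see a stale entry and each create a new session. Here is a minimal sketch of holding one lock across the whole sequence. `CachedSession` and `SessionCache` are hypothetical stand-ins, not the real Spark classes:

```scala
import scala.collection.mutable

// HYPOTHETICAL stand-in for a Spark Connect session: only the
// "server-side session changed" flag matters for this sketch.
final class CachedSession(val id: Int) {
  @volatile var sessionChanged: Boolean = false
}

// HYPOTHETICAL cache keyed by builder configuration (modelled as a Map).
// The stale check and the replacement run under a single lock, so two
// threads cannot both observe a stale entry and each create a session.
final class SessionCache {
  private val sessions = mutable.HashMap.empty[Map[String, String], CachedSession]
  private var nextId = 0

  def getOrCreate(conf: Map[String, String]): CachedSession = synchronized {
    sessions.get(conf) match {
      // Reuse the cached session only if the server has not replaced it.
      case Some(existing) if !existing.sessionChanged => existing
      // Missing or stale: create a fresh session and overwrite the entry.
      case _ =>
        nextId += 1
        val fresh = new CachedSession(nextId)
        sessions.update(conf, fresh)
        fresh
    }
  }
}
```

A coarse lock like this serializes all `getOrCreate` calls. That is probably acceptable here because session creation is rare.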
##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala:
##########
@@ -1036,7 +1056,8 @@ object SparkSession extends Logging {
*
* @since 3.5.0
*/
- def getDefaultSession: Option[SparkSession] = Option(defaultSession.get())
+ def getDefaultSession: Option[SparkSession] =
+ Option(defaultSession.get()).filterNot(s => s.client != null && s.client.hasSessionChanged)
Review Comment:
I wonder if we can use the same naming as the Python side, so we can easily
land the same fix on both sides in the future.
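For reference when aligning the naming, the new `getDefaultSession` logic reduces to two rules. An unset default yields `None`, and a session whose client reports a changed server-side session is filtered out. A session with a `null` client is still returned. Here is a sketch of those rules. `StubClient`, `StubSession` and `DefaultSessionSketch` are hypothetical stand-ins, not the Spark classes:

```scala
import java.util.concurrent.atomic.AtomicReference

// HYPOTHETICAL stand-ins for the Spark Connect client and session.
final class StubClient(@volatile var hasSessionChanged: Boolean)
final class StubSession(val client: StubClient)

object DefaultSessionSketch {
  val defaultSession = new AtomicReference[StubSession]()

  // Same shape as the diff: drop the default session once its client
  // reports that the server-side session has changed.
  def getDefaultSession: Option[StubSession] =
    Option(defaultSession.get()).filterNot(s => s.client != null && s.client.hasSessionChanged)
}
```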
##########
connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala:
##########
@@ -382,4 +382,43 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
.create()
}
}
+
+ test("get or create after session changed") {
Review Comment:
```suggestion
test("SPARK-47986: get or create after session changed") {
```
##########
connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala:
##########
@@ -382,4 +382,43 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
.create()
}
}
+
+ test("get or create after session changed") {
+ val remote = s"sc://localhost:$serverPort"
+ SparkSession.clearDefaultSession()
+ SparkSession.clearActiveSession()
+
+ val session1 = SparkSession
+ .builder()
+ .remote(remote)
+ .getOrCreate()
+
+ assert(session1 eq SparkSession.getActiveSession.get)
+ assert(session1 eq SparkSession.getDefaultSession.get)
+ assert(session1.range(3).collect().length == 3)
+
+ session1.client.hijackServerSideSessionIdForTesting("-testing")
+
+ try {
Review Comment:
Maybe use `intercept` instead of the `try` block:
```scala
val e = intercept[StatusRuntimeException] {
session1.range(3).analyze
}
assert(e.getMessage.contains("INVALID_HANDLE.SESSION_CHANGED"))
```