This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0b5e92b7815d [SPARK-54886] Add base session created in SparkConnectService
0b5e92b7815d is described below
commit 0b5e92b7815db9e5ea3c5faf122fe2ceccfd5246
Author: Garland Zhang <[email protected]>
AuthorDate: Mon Jan 5 10:16:03 2026 -0400
[SPARK-54886] Add base session created in SparkConnectService
### What changes were proposed in this pull request?
This PR makes SparkConnectService rely on its own private SparkSession, which is used only for copying session configs when creating new sessions.
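In sketch form, the new flow looks roughly like this (a simplified rendering of the diff below; the real code also falls back to the default session when the SparkContext is stopped):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.classic.SparkSession

class SparkConnectSessionManager {
  // Private base session: never handed to clients, only copied.
  @volatile private var baseSession: Option[SparkSession] = None

  // Called once during SparkConnectService startup.
  def initializeBaseSession(sc: SparkContext): Unit = {
    if (baseSession.isEmpty) {
      baseSession = Some(SparkSession.builder().sparkContext(sc).getOrCreate().newSession())
    }
  }

  // Every client session is a fresh copy of the base session's configs.
  private def newIsolatedSession(): SparkSession = baseSession.get.newSession()
}
```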
### Why are the changes needed?
The default session can get cleaned up, in which case SparkConnectService cannot recover: session creation fails on all subsequent RPCs.
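For illustration, the failure mode and the fix (this mirrors the new test added below; `sessionManager` stands for an already-initialized SparkConnectSessionManager):

```scala
import java.util.UUID
import org.apache.spark.sql.classic.SparkSession

// Previously, newIsolatedSession() started from SparkSession.active, so once
// the process-global default/active session was cleared...
SparkSession.clearDefaultSession()
SparkSession.clearActiveSession()

// ...every later session-creation RPC failed. With the private baseSession,
// creating an isolated session keeps working:
val key = SessionKey("user", UUID.randomUUID().toString)
val holder = sessionManager.getOrCreateIsolatedSession(key, None)
```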
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added basic testing
### Was this patch authored or co-authored using generative AI tooling?
Yes
Closes #52895 from garlandz-db/isolated_root_session.
Authored-by: Garland Zhang <[email protected]>
Signed-off-by: Herman van Hövell <[email protected]>
---
.../sql/connect/service/SparkConnectService.scala | 1 +
.../service/SparkConnectSessionManager.scala | 21 ++++++++--
.../connect/planner/SparkConnectServiceSuite.scala | 6 +++
.../service/ArtifactStatusesHandlerSuite.scala | 6 +++
.../service/SparkConnectCloneSessionSuite.scala | 1 +
.../service/SparkConnectSessionManagerSuite.scala | 49 ++++++++++++++++++++++
6 files changed, 80 insertions(+), 4 deletions(-)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 13ce2d64256b..4641bc0a1106 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -436,6 +436,7 @@ object SparkConnectService extends Logging {
return
}
+ sessionManager.initializeBaseSession(sc)
startGRPCService()
createListenerAndUI(sc)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index f28af0379a04..6c468ba46cc0 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
import com.google.common.cache.CacheBuilder
-import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{INTERVAL, SESSION_HOLD_INFO}
import org.apache.spark.sql.classic.SparkSession
@@ -39,6 +39,9 @@ import org.apache.spark.util.ThreadUtils
*/
class SparkConnectSessionManager extends Logging {
+ // Base SparkSession created from the SparkContext, used to create new isolated sessions
+ @volatile private var baseSession: Option[SparkSession] = None
+
private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
new ConcurrentHashMap[SessionKey, SessionHolder]()
@@ -48,6 +51,16 @@ class SparkConnectSessionManager extends Logging {
.maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
.build[SessionKey, SessionHolderInfo]()
+ /**
+ * Initialize the base SparkSession from the provided SparkContext. This should be called once
+ * during SparkConnectService startup.
+ */
+ def initializeBaseSession(sc: SparkContext): Unit = {
+ if (baseSession.isEmpty) {
+ baseSession = Some(SparkSession.builder().sparkContext(sc).getOrCreate().newSession())
+ }
+ }
+
/** Executor for the periodic maintenance */
private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
new AtomicReference[ScheduledExecutorService]()
@@ -333,12 +346,12 @@ class SparkConnectSessionManager extends Logging {
}
private def newIsolatedSession(): SparkSession = {
- val active = SparkSession.active
- if (active.sparkContext.isStopped) {
+ val session = baseSession.get
+ if (session.sparkContext.isStopped) {
assert(SparkSession.getDefaultSession.nonEmpty)
SparkSession.getDefaultSession.get.newSession()
} else {
- active.newSession()
+ session.newSession()
}
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 2989471d36a0..6df50b6588c2 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -65,6 +65,12 @@ class SparkConnectServiceSuite
with Logging
with SparkConnectPlanTest {
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ SparkConnectService.sessionManager.invalidateAllSessions()
+ SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
+ }
+
private def sparkSessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093")
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
index 7ce3ff46f553..caa71c644e6a 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
@@ -42,6 +42,12 @@ class ArtifactStatusesHandlerSuite extends SharedSparkSession with ResourceHelpe
val sessionId = UUID.randomUUID().toString
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ SparkConnectService.sessionManager.invalidateAllSessions()
+ SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
+ }
+
def getStatuses(names: Seq[String], exist: Set[String]): ArtifactStatusesResponse = {
val promise = Promise[ArtifactStatusesResponse]()
val handler = new SparkConnectArtifactStatusesHandler(new DummyStreamObserver(promise)) {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
index 922c239526f3..42541b8c5f00 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
@@ -29,6 +29,7 @@ class SparkConnectCloneSessionSuite extends SharedSparkSession with BeforeAndAft
override def beforeEach(): Unit = {
super.beforeEach()
SparkConnectService.sessionManager.invalidateAllSessions()
+ SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
}
test("clone session with invalid target session ID format") {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
index 94deb83f6ad4..04d16a910746 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
@@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterEach
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkSQLException
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
import org.apache.spark.sql.pipelines.logging.PipelineEvent
import org.apache.spark.sql.test.SharedSparkSession
@@ -32,6 +33,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
override def beforeEach(): Unit = {
super.beforeEach()
SparkConnectService.sessionManager.invalidateAllSessions()
+ SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
}
test("sessionId needs to be an UUID") {
@@ -171,4 +173,51 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
sessionHolder.getPipelineExecution(graphId).isEmpty,
"pipeline execution was not removed")
}
+
+ test("baseSession allows creating sessions after default session is
cleared") {
+ // Create a new session manager to test initialization
+ val sessionManager = new SparkConnectSessionManager()
+
+ // Initialize the base session with the test SparkContext
+ sessionManager.initializeBaseSession(spark.sparkContext)
+
+ // Clear the default and active sessions to simulate the scenario where
+ // SparkSession.active or SparkSession.getDefaultSession would fail
+ SparkSession.clearDefaultSession()
+ SparkSession.clearActiveSession()
+
+ // Create an isolated session - this should still work because we have baseSession
+ val key = SessionKey("user", UUID.randomUUID().toString)
+ val sessionHolder = sessionManager.getOrCreateIsolatedSession(key, None)
+
+ // Verify the session was created successfully
+ assert(sessionHolder != null)
+ assert(sessionHolder.session != null)
+
+ // Clean up
+ sessionManager.closeSession(key)
+ }
+
+ test("initializeBaseSession is idempotent") {
+ // Create a new session manager to test initialization
+ val sessionManager = new SparkConnectSessionManager()
+
+ // Initialize the base session multiple times
+ sessionManager.initializeBaseSession(spark.sparkContext)
+ val key1 = SessionKey("user1", UUID.randomUUID().toString)
+ val sessionHolder1 = sessionManager.getOrCreateIsolatedSession(key1, None)
+ val baseSessionUUID1 = sessionHolder1.session.sessionUUID
+
+ // Initialize again - should not change the base session
+ sessionManager.initializeBaseSession(spark.sparkContext)
+ val key2 = SessionKey("user2", UUID.randomUUID().toString)
+ val sessionHolder2 = sessionManager.getOrCreateIsolatedSession(key2, None)
+
+ // Both sessions should be isolated from each other
+ assert(sessionHolder1.session.sessionUUID != sessionHolder2.session.sessionUUID)
+
+ // Clean up
+ sessionManager.closeSession(key1)
+ sessionManager.closeSession(key2)
+ }
}