This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f310f4fcc955 [SPARK-54886] Lazily create Spark Connect reuse session
f310f4fcc955 is described below

commit f310f4fcc95580a6824bc7d22b76006f79b8804a
Author: Garland Zhang <[email protected]>
AuthorDate: Thu Jan 15 14:25:13 2026 -0400

    [SPARK-54886] Lazily create Spark Connect reuse session
    
    ### What changes were proposed in this pull request?
    This is a follow up to https://github.com/apache/spark/pull/52895
    
    Make the creation of the Spark Connect base session lazy.
    
    ### Why are the changes needed?
    Previously the session was created while initializing the 
SparkConnectService, which happens during SparkContext initialization. This PR 
makes the creation lazy, so the session is only created once the SparkContext 
is fully initialized.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #53812 from garlandz-db/SPARK-54886_lazy.
    
    Authored-by: Garland Zhang <[email protected]>
    Signed-off-by: Herman van Hövell <[email protected]>
---
 .../sql/connect/service/SparkConnectService.scala    |  5 ++++-
 .../connect/service/SparkConnectSessionManager.scala | 20 +++++++++++++++-----
 .../connect/planner/SparkConnectServiceSuite.scala   |  2 +-
 .../service/ArtifactStatusesHandlerSuite.scala       |  2 +-
 .../service/SparkConnectCloneSessionSuite.scala      |  2 +-
 .../service/SparkConnectSessionManagerSuite.scala    |  8 ++++----
 6 files changed, 26 insertions(+), 13 deletions(-)

diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 4641bc0a1106..00b93c19b2c7 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -38,6 +38,8 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.classic.ClassicConversions._
 import org.apache.spark.sql.connect.config.Connect.{getAuthenticateToken, 
CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, 
CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, 
CONNECT_GRPC_PORT_MAX_RETRIES}
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, 
SparkConnectServerListener, SparkConnectServerTab}
@@ -436,7 +438,8 @@ object SparkConnectService extends Logging {
       return
     }
 
-    sessionManager.initializeBaseSession(sc)
+    sessionManager.initializeBaseSession(() =>
+      SparkSession.builder().sparkContext(sc).getOrCreate().newSession())
     startGRPCService()
     createListenerAndUI(sc)
 
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index 6c468ba46cc0..d3ddf592e9e7 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
 
 import com.google.common.cache.CacheBuilder
 
-import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException}
+import org.apache.spark.{SparkEnv, SparkSQLException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{INTERVAL, SESSION_HOLD_INFO}
 import org.apache.spark.sql.classic.SparkSession
@@ -39,8 +39,11 @@ import org.apache.spark.util.ThreadUtils
  */
 class SparkConnectSessionManager extends Logging {
 
+  // Used to lazily initialize the base session
+  @volatile private var baseSessionCreator: Option[() => SparkSession] = None
+
   // Base SparkSession created from the SparkContext, used to create new 
isolated sessions
-  @volatile private var baseSession: Option[SparkSession] = None
+  @volatile private var _baseSession: Option[SparkSession] = None
 
   private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
     new ConcurrentHashMap[SessionKey, SessionHolder]()
@@ -51,13 +54,20 @@ class SparkConnectSessionManager extends Logging {
       
.maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
       .build[SessionKey, SessionHolderInfo]()
 
+  private def baseSession: Option[SparkSession] = {
+    if (_baseSession.isEmpty && baseSessionCreator.isDefined) {
+      _baseSession = Some(baseSessionCreator.get())
+    }
+    _baseSession
+  }
+
   /**
    * Initialize the base SparkSession from the provided SparkContext. This 
should be called once
    * during SparkConnectService startup.
    */
-  def initializeBaseSession(sc: SparkContext): Unit = {
-    if (baseSession.isEmpty) {
-      baseSession = 
Some(SparkSession.builder().sparkContext(sc).getOrCreate().newSession())
+  def initializeBaseSession(createSession: () => SparkSession): Unit = {
+    if (baseSessionCreator.isEmpty) {
+      baseSessionCreator = Some(createSession)
     }
   }
 
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 6df50b6588c2..0e5488e31222 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -68,7 +68,7 @@ class SparkConnectServiceSuite
   override def beforeEach(): Unit = {
     super.beforeEach()
     SparkConnectService.sessionManager.invalidateAllSessions()
-    
SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
+    SparkConnectService.sessionManager.initializeBaseSession(() => 
spark.newSession())
   }
 
   private def sparkSessionHolder = 
SparkConnectTestUtils.createDummySessionHolder(spark)
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
index caa71c644e6a..275808942d37 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
@@ -45,7 +45,7 @@ class ArtifactStatusesHandlerSuite extends SharedSparkSession 
with ResourceHelpe
   override def beforeEach(): Unit = {
     super.beforeEach()
     SparkConnectService.sessionManager.invalidateAllSessions()
-    
SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
+    SparkConnectService.sessionManager.initializeBaseSession(() => 
spark.newSession())
   }
 
   def getStatuses(names: Seq[String], exist: Set[String]): 
ArtifactStatusesResponse = {
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
index 09292ec2a227..b132c6d64c16 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
@@ -27,7 +27,7 @@ class SparkConnectCloneSessionSuite extends 
SharedSparkSession {
   override def beforeEach(): Unit = {
     super.beforeEach()
     SparkConnectService.sessionManager.invalidateAllSessions()
-    
SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
+    SparkConnectService.sessionManager.initializeBaseSession(() => 
spark.newSession())
   }
 
   test("clone session with invalid target session ID format") {
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
index 1716fbb34b96..4b846631d7b7 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
@@ -32,7 +32,7 @@ class SparkConnectSessionManagerSuite extends 
SharedSparkSession {
   override def beforeEach(): Unit = {
     super.beforeEach()
     SparkConnectService.sessionManager.invalidateAllSessions()
-    
SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
+    SparkConnectService.sessionManager.initializeBaseSession(() => 
spark.newSession())
   }
 
   test("sessionId needs to be an UUID") {
@@ -178,7 +178,7 @@ class SparkConnectSessionManagerSuite extends 
SharedSparkSession {
     val sessionManager = new SparkConnectSessionManager()
 
     // Initialize the base session with the test SparkContext
-    sessionManager.initializeBaseSession(spark.sparkContext)
+    sessionManager.initializeBaseSession(() => spark.newSession())
 
     // Clear the default and active sessions to simulate the scenario where
     // SparkSession.active or SparkSession.getDefaultSession would fail
@@ -202,13 +202,13 @@ class SparkConnectSessionManagerSuite extends 
SharedSparkSession {
     val sessionManager = new SparkConnectSessionManager()
 
     // Initialize the base session multiple times
-    sessionManager.initializeBaseSession(spark.sparkContext)
+    sessionManager.initializeBaseSession(() => spark.newSession())
     val key1 = SessionKey("user1", UUID.randomUUID().toString)
     val sessionHolder1 = sessionManager.getOrCreateIsolatedSession(key1, None)
     val baseSessionUUID1 = sessionHolder1.session.sessionUUID
 
     // Initialize again - should not change the base session
-    sessionManager.initializeBaseSession(spark.sparkContext)
+    sessionManager.initializeBaseSession(() => spark.newSession())
     val key2 = SessionKey("user2", UUID.randomUUID().toString)
     val sessionHolder2 = sessionManager.getOrCreateIsolatedSession(key2, None)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to