This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c4619b503e5 [SPARK-41738][CONNECT] Mix ClientId in SparkSession cache
c4619b503e5 is described below

commit c4619b503e58da38c6223020a73c2ca2e0a8c0fa
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Wed Dec 28 20:04:45 2022 +0900

    [SPARK-41738][CONNECT] Mix ClientId in SparkSession cache
    
    ### What changes were proposed in this pull request?
    This PR mixes the client ID into the key of the SparkSession cache on the
server. This is necessary to allow concurrent SparkSessions of the same user
to run without interfering with each other.
    
    The client ID field was originally added for this purpose, but until now
the functionality had not been implemented.
    
    On the client side, the Python client now validates that every response it
receives carries the same client (session) ID that it sent with the request.
    
    ### Why are the changes needed?
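    Stability: previously the server-side session cache was keyed by user ID
only, so concurrent sessions of the same user shared one SparkSession and
could interfere with each other.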
    
    ### Does this PR introduce _any_ user-facing change?
    NO
    
    ### How was this patch tested?
    Existing unit tests (SparkConnectServiceSuite updated to pass a client ID).
    
    Closes #39256 from grundprinzip/SPARK-41738.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/connect/service/SparkConnectService.scala    | 14 ++++++++------
 .../sql/connect/service/SparkConnectStreamHandler.scala    |  4 +++-
 .../sql/connect/planner/SparkConnectServiceSuite.scala     |  3 ++-
 python/pyspark/sql/connect/client.py                       | 12 ++++++++++++
 4 files changed, 25 insertions(+), 8 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 3046c8eebfc..bfcea3d2252 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -94,7 +94,9 @@ class SparkConnectService(debug: Boolean)
             s"${request.getPlan.getOpTypeCase} not supported for analysis."))
       }
       val session =
-        
SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session
+        SparkConnectService
+          .getOrCreateIsolatedSession(request.getUserContext.getUserId, 
request.getClientId)
+          .session
 
       val explainMode = request.getExplain.getExplainMode match {
         case proto.Explain.ExplainMode.SIMPLE => SimpleMode
@@ -145,7 +147,7 @@ class SparkConnectService(debug: Boolean)
  * @param userId
  * @param session
  */
-case class SessionHolder(userId: String, session: SparkSession)
+case class SessionHolder(userId: String, sessionId: String, session: 
SparkSession)
 
 /**
  * Static instance of the SparkConnectService.
@@ -161,7 +163,7 @@ object SparkConnectService {
 
   // Type alias for the SessionCacheKey. Right now this is a String but allows 
us to switch to a
   // different or complex type easily.
-  private type SessionCacheKey = String;
+  private type SessionCacheKey = (String, String);
 
   private var server: Server = _
 
@@ -183,11 +185,11 @@ object SparkConnectService {
   /**
    * Based on the `key` find or create a new SparkSession.
    */
-  def getOrCreateIsolatedSession(key: SessionCacheKey): SessionHolder = {
+  def getOrCreateIsolatedSession(userId: String, sessionId: String): 
SessionHolder = {
     userSessionMapping.get(
-      key,
+      (userId, sessionId),
       () => {
-        SessionHolder(key, newIsolatedSession())
+        SessionHolder(userId, sessionId, newIsolatedSession())
       })
   }
 
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index fcae3501cef..9631b93f6e9 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -42,7 +42,9 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[ExecutePlanResp
 
   def handle(v: ExecutePlanRequest): Unit = {
     val session =
-      
SparkConnectService.getOrCreateIsolatedSession(v.getUserContext.getUserId).session
+      SparkConnectService
+        .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getClientId)
+        .session
     v.getPlan.getOpTypeCase match {
       case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
       case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 8f4268b904b..6dcce0926dc 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -153,7 +153,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
     val instance = new SparkConnectService(false)
 
     // Add an always crashing UDF
-    val session = SparkConnectService.getOrCreateIsolatedSession("c1").session
+    val session = SparkConnectService.getOrCreateIsolatedSession("c1", 
"session").session
     val instaKill: Long => Long = { _ =>
       throw new Exception("Kaboom")
     }
@@ -172,6 +172,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
       .newBuilder()
       .setPlan(plan)
       .setUserContext(context)
+      .setClientId("session")
       .build()
 
     // The observer is executed inside this thread. So
diff --git a/python/pyspark/sql/connect/client.py 
b/python/pyspark/sql/connect/client.py
index 2e32a87676f..e258dbd92b4 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -290,6 +290,10 @@ class SparkConnectClient(object):
         # Parse the connection string.
         self._builder = ChannelBuilder(connectionString)
         self._user_id = None
+        # Generate a unique session ID for this client. This UUID must be 
unique to allow
+        # concurrent Spark sessions of the same user. If the channel is 
closed, creating
+        # a new client will create a new session ID.
+        self._session_id = str(uuid.uuid4())
         if self._builder.userId is not None:
             self._user_id = self._builder.userId
         elif userId is not None:
@@ -367,6 +371,7 @@ class SparkConnectClient(object):
 
     def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
         req = pb2.ExecutePlanRequest()
+        req.client_id = self._session_id
         req.client_type = "_SPARK_CONNECT_PYTHON"
         if self._user_id:
             req.user_context.user_id = self._user_id
@@ -374,6 +379,7 @@ class SparkConnectClient(object):
 
     def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
         req = pb2.AnalyzePlanRequest()
+        req.client_id = self._session_id
         req.client_type = "_SPARK_CONNECT_PYTHON"
         if self._user_id:
             req.user_context.user_id = self._user_id
@@ -401,6 +407,8 @@ class SparkConnectClient(object):
             req.explain.explain_mode = pb2.Explain.ExplainMode.FORMATTED
 
         resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
+        if resp.client_id != self._session_id:
+            raise ValueError("Received incorrect session identifier for 
request.")
         return AnalyzeResult.fromProto(resp)
 
     def _process_batch(self, arrow_batch: pb2.ExecutePlanResponse.ArrowBatch) 
-> "pandas.DataFrame":
@@ -409,6 +417,8 @@ class SparkConnectClient(object):
 
     def _execute(self, req: pb2.ExecutePlanRequest) -> None:
         for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
+            if b.client_id != self._session_id:
+                raise ValueError("Received incorrect session identifier for 
request.")
             continue
         return
 
@@ -419,6 +429,8 @@ class SparkConnectClient(object):
         result_dfs = []
 
         for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
+            if b.client_id != self._session_id:
+                raise ValueError("Received incorrect session identifier for 
request.")
             if b.metrics is not None:
                 m = b.metrics
             if b.HasField("arrow_batch"):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to