This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 576d1698c7b [SPARK-44740][CONNECT] Support specifying `session_id` in SPARK_REMOTE connection string
576d1698c7b is described below

commit 576d1698c7b52b2d9ce00fa2fc5912c92e8bbe67
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Thu Aug 10 08:41:13 2023 +0900

    [SPARK-44740][CONNECT] Support specifying `session_id` in SPARK_REMOTE connection string
    
    ### What changes were proposed in this pull request?
    To support cross-language session sharing in Spark Connect, we need to be able to inject the session ID into the connection string, because on the server side the client-provided session ID is already used together with the user ID.
    
    ```
    SparkSession.builder.remote("sc://localhost/;session_id=550e8400-e29b-41d4-a716-446655440000").getOrCreate()
    ```
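
    For illustration (beyond the commit text above), a minimal sketch of how two clients could share one server-side session, assuming a Spark Connect server is reachable at `localhost` and that every participating client passes the same valid UUID:

    ```
    import uuid
    from pyspark.sql import SparkSession

    # Per the docs change below, the server keys its session cache on the
    # user ID together with this client-provided session ID.
    shared_id = str(uuid.uuid4())

    # The first client creates the server-side session; any later client
    # (from Python, Scala, or another language) connecting with the same
    # connection string is mapped to the same session.
    spark = SparkSession.builder.remote(
        f"sc://localhost/;session_id={shared_id}"
    ).getOrCreate()
    ```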
    
    ### Why are the changes needed?
    Ease of use.
    
    ### Does this PR introduce _any_ user-facing change?
    Adds a way to configure the Spark Connect connection string with `session_id`.
    
    ### How was this patch tested?
    Added UT for the parameter.
    
    Closes #42415 from grundprinzip/SPARK-44740.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 7af4e358f3f4902cc9601e56c2662b8921a925d6)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/connect/client/SparkConnectClient.scala    | 22 ++++++++++++++--
 .../connect/client/SparkConnectClientParser.scala  |  3 +++
 .../SparkConnectClientBuilderParseTestSuite.scala  |  4 +++
 .../connect/client/SparkConnectClientSuite.scala   |  6 +++++
 connector/connect/docs/client-connection-string.md | 11 ++++++++
 python/pyspark/sql/connect/client/core.py          | 30 +++++++++++++++++++---
 .../sql/tests/connect/client/test_client.py        |  7 +++++
 .../sql/tests/connect/test_connect_basic.py        | 18 ++++++++++++-
 8 files changed, 94 insertions(+), 7 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index a028df536cf..637499f090c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -56,7 +56,7 @@ private[sql] class SparkConnectClient(
   // Generate a unique session ID for this client. This UUID must be unique to allow
   // concurrent Spark sessions of the same user. If the channel is closed, creating
   // a new client will create a new session ID.
-  private[sql] val sessionId: String = UUID.randomUUID.toString
+  private[sql] val sessionId: String = configuration.sessionId.getOrElse(UUID.randomUUID.toString)
 
   private[client] val artifactManager: ArtifactManager = {
     new ArtifactManager(configuration, sessionId, bstub, stub)
@@ -432,6 +432,7 @@ object SparkConnectClient {
       val PARAM_USE_SSL = "use_ssl"
       val PARAM_TOKEN = "token"
       val PARAM_USER_AGENT = "user_agent"
+      val PARAM_SESSION_ID = "session_id"
     }
 
     private def verifyURI(uri: URI): Unit = {
@@ -463,6 +464,21 @@ object SparkConnectClient {
       this
     }
 
+    def sessionId(value: String): Builder = {
+      try {
+        UUID.fromString(value).toString
+      } catch {
+        case e: IllegalArgumentException =>
+          throw new IllegalArgumentException(
+            "Parameter value 'session_id' must be a valid UUID format.",
+            e)
+      }
+      _configuration = _configuration.copy(sessionId = Some(value))
+      this
+    }
+
+    def sessionId: Option[String] = _configuration.sessionId
+
     def userAgent: String = _configuration.userAgent
 
     def option(key: String, value: String): Builder = {
@@ -490,6 +506,7 @@ object SparkConnectClient {
           case URIParams.PARAM_TOKEN => token(value)
           case URIParams.PARAM_USE_SSL =>
             if (java.lang.Boolean.valueOf(value)) enableSsl() else disableSsl()
+          case URIParams.PARAM_SESSION_ID => sessionId(value)
           case _ => option(key, value)
         }
       }
@@ -576,7 +593,8 @@ object SparkConnectClient {
       userAgent: String = DEFAULT_USER_AGENT,
      retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy(),
       useReattachableExecute: Boolean = true,
-      interceptors: List[ClientInterceptor] = List.empty) {
+      interceptors: List[ClientInterceptor] = List.empty,
+      sessionId: Option[String] = None) {
 
     def userContext: proto.UserContext = {
       val builder = proto.UserContext.newBuilder()
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala
index dda769dc2ad..f873e1045bf 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala
@@ -71,6 +71,9 @@ private[sql] object SparkConnectClientParser {
       case "--user_agent" :: tail =>
         val (value, remainder) = extract("--user_agent", tail)
         parse(remainder, builder.userAgent(value))
+      case "--session_id" :: tail =>
+        val (value, remainder) = extract("--session_id", tail)
+        parse(remainder, builder.sessionId(value))
       case "--option" :: tail =>
         if (args.isEmpty) {
           throw new IllegalArgumentException("--option requires a key-value 
pair")
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
index 2c6886d0386..1dc1fd567ec 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.connect.client
 
+import java.util.UUID
+
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
 
 /**
@@ -46,6 +48,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
   argumentTest("user_id", "U1238", _.userId.get)
   argumentTest("user_name", "alice", _.userName.get)
   argumentTest("user_agent", "MY APP", _.userAgent)
+  argumentTest("session_id", UUID.randomUUID().toString, _.sessionId.get)
 
   test("Argument - remote") {
     val builder =
@@ -55,6 +58,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
     assert(builder.token.contains("nahnah"))
     assert(builder.userId.contains("x127"))
     assert(builder.options === Map(("user_name", "Q"), ("param1", "x")))
+    assert(builder.sessionId.isEmpty)
   }
 
   test("Argument - use_ssl") {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 3436037809d..e483e0a7291 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -164,6 +164,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
       client => {
         assert(client.configuration.host == "localhost")
         assert(client.configuration.port == 1234)
+        assert(client.sessionId != null)
+        // Must be able to parse the UUID
+        assert(UUID.fromString(client.sessionId) != null)
       }),
     TestPackURI(
       "sc://localhost/;",
@@ -193,6 +196,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     TestPackURI("sc://host:123/;use_ssl=true", isCorrect = true),
     TestPackURI("sc://host:123/;token=mySecretToken", isCorrect = true),
     TestPackURI("sc://host:123/;token=", isCorrect = false),
+    TestPackURI("sc://host:123/;session_id=", isCorrect = false),
+    TestPackURI("sc://host:123/;session_id=abcdefgh", isCorrect = false),
+    TestPackURI(s"sc://host:123/;session_id=${UUID.randomUUID().toString}", 
isCorrect = true),
     TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = 
true),
     TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = 
true),
     TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect 
= false),
diff --git a/connector/connect/docs/client-connection-string.md b/connector/connect/docs/client-connection-string.md
index 6e5b0c80db7..ebab7cbff4f 100644
--- a/connector/connect/docs/client-connection-string.md
+++ b/connector/connect/docs/client-connection-string.md
@@ -91,6 +91,17 @@ sc://hostname:port/;param1=value;param2=value
     <i>Default: </i><pre>_SPARK_CONNECT_PYTHON</pre> in the Python client</td>
     <td><pre>user_agent=my_data_query_app</pre></td>
   </tr>
+  <tr>
+    <td>session_id</td>
+    <td>String</td>
+    <td>In addition to the user ID, the Spark Connect server uses a session ID as
+    the cache key for Spark sessions. This connection string option lets the client
+    provide that session ID, so that the same user can share a Spark session, for
+    example across multiple languages. The value must be provided as a valid UUID
+    string.<br/>
+    <i>Default: </i>A randomly generated UUID.</td>
+    <td><pre>session_id=550e8400-e29b-41d4-a716-446655440000</pre></td>
+  </tr>
 </table>
 
 ## Examples
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index b62621bc3c2..a9da6723a8c 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -156,6 +156,7 @@ class ChannelBuilder:
     PARAM_TOKEN = "token"
     PARAM_USER_ID = "user_id"
     PARAM_USER_AGENT = "user_agent"
+    PARAM_SESSION_ID = "session_id"
     MAX_MESSAGE_LENGTH = 128 * 1024 * 1024
 
     @staticmethod
@@ -354,6 +355,22 @@ class ChannelBuilder:
         """
         return self.params[key]
 
+    @property
+    def session_id(self) -> Optional[str]:
+        """
+        Returns
+        -------
+        The session_id extracted from the parameters of the connection string or `None` if not
+        specified.
+        """
+        session_id = self.params.get(ChannelBuilder.PARAM_SESSION_ID, None)
+        if session_id is not None:
+            try:
+                uuid.UUID(session_id, version=4)
+            except ValueError as ve:
+                raise ValueError("Parameter value 'session_id' must be a valid UUID format.", ve)
+        return session_id
+
     def toChannel(self) -> grpc.Channel:
         """
         Applies the parameters of the connection string and creates a new
@@ -628,10 +645,15 @@ class SparkConnectClient(object):
         if retry_policy:
             self._retry_policy.update(retry_policy)
 
-        # Generate a unique session ID for this client. This UUID must be unique to allow
-        # concurrent Spark sessions of the same user. If the channel is closed, creating
-        # a new client will create a new session ID.
-        self._session_id = str(uuid.uuid4())
+        if self._builder.session_id is None:
+            # Generate a unique session ID for this client. This UUID must be unique to allow
+            # concurrent Spark sessions of the same user. If the channel is closed, creating
+            # a new client will create a new session ID.
+            self._session_id = str(uuid.uuid4())
+        else:
+            # Use the pre-defined session ID.
+            self._session_id = str(self._builder.session_id)
+
         if self._builder.userId is not None:
             self._user_id = self._builder.userId
         elif user_id is not None:
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 9276b88e153..9782add92f4 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -16,6 +16,7 @@
 #
 
 import unittest
+import uuid
 from typing import Optional
 
 from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
@@ -88,6 +89,12 @@ class SparkConnectClientTestCase(unittest.TestCase):
         client.close()
         self.assertTrue(client.is_closed)
 
+    def test_channel_builder_with_session(self):
+        dummy = str(uuid.uuid4())
+        chan = ChannelBuilder(f"sc://foo/;session_id={dummy}")
+        client = SparkConnectClient(chan)
+        self.assertEqual(client._session_id, chan.session_id)
+
 
 class MockService:
     # Simplest mock of the SparkConnectService.
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0687fc9f313..63b65ecce1a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -23,6 +23,7 @@ import random
 import shutil
 import string
 import tempfile
+import uuid
 from collections import defaultdict
 
 from pyspark.errors import (
@@ -76,7 +77,7 @@ if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
     from pyspark.sql import functions as SF
     from pyspark.sql.connect import functions as CF
-    from pyspark.sql.connect.client.core import Retrying
+    from pyspark.sql.connect.client.core import Retrying, SparkConnectClient
 
 
 class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
@@ -3522,6 +3523,21 @@ class ChannelBuilderTests(unittest.TestCase):
         md = chan.metadata()
         self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md)
 
+    def test_metadata(self):
+        id = str(uuid.uuid4())
+        chan = ChannelBuilder(f"sc://host/;session_id={id}")
+        self.assertEqual(id, chan.session_id)
+
+        with self.assertRaises(ValueError) as ve:
+            chan = ChannelBuilder("sc://host/;session_id=abcd")
+            SparkConnectClient(chan)
+        self.assertIn(
+            "Parameter value 'session_id' must be a valid UUID format.", 
str(ve.exception)
+        )
+
+        chan = ChannelBuilder("sc://host/")
+        self.assertIsNone(chan.session_id)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401

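A quick way to exercise the new parameter from Python, based on the `ChannelBuilder` changes above (a sketch, not part of the commit; it assumes PySpark 3.5 with the Connect client installed, and no server is contacted because only URL parsing runs):

```
import uuid

from pyspark.sql.connect.client import ChannelBuilder

# A valid UUID is accepted and surfaced via the new session_id property.
sid = str(uuid.uuid4())
chan = ChannelBuilder(f"sc://localhost/;session_id={sid}")
assert chan.session_id == sid

# Anything that is not a UUID raises ValueError when the property is read.
try:
    ChannelBuilder("sc://localhost/;session_id=not-a-uuid").session_id
except ValueError as e:
    # The message contains: Parameter value 'session_id' must be a valid UUID format.
    print(e)
```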
