This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 014c60fb5bf [SPARK-42477][CONNECT][PYTHON] accept user_agent in spark connect's connection string
014c60fb5bf is described below

commit 014c60fb5bf712afafca4eef884665d4245d4aaf
Author: Niranjan Jayakar <n...@databricks.com>
AuthorDate: Mon Feb 20 23:02:11 2023 +0900

    [SPARK-42477][CONNECT][PYTHON] accept user_agent in spark connect's connection string
    
    ### What changes were proposed in this pull request?
    
    Currently, the Spark Connect service's `client_type` attribute (which is
    really a [user agent]) is set to `_SPARK_CONNECT_PYTHON` to signify PySpark.
    
    With this change, the Spark Connect remote connection string accepts an
    optional `user_agent` parameter, which is then passed down to the service.
    
    [user agent]: https://www.w3.org/WAI/UA/work/wiki/Definition_of_User_Agent
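
    For illustration, a minimal sketch of the new flow from PySpark (this
    assumes a Spark Connect server reachable on the default port;
    `my_data_query_app` is a hypothetical application name):

    ```python
    # Sketch: pass a custom user agent through the connection string.
    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder
        .remote("sc://localhost/;user_agent=my_data_query_app")
        .getOrCreate()
    )
    # Requests issued by this session carry client_type="my_data_query_app"
    # instead of the default "_SPARK_CONNECT_PYTHON".
    ```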
    
    ### Why are the changes needed?
    
    This enables partners using Spark Connect to set their application as the
    user agent, which then allows visibility and measurement of integrations
    and usage of Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    A new optional `user_agent` parameter is now recognized as part of the
    Spark Connect connection string.
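
    For example, a sketch of the parsing behavior using the `ChannelBuilder`
    names from this diff (`my_app` is an arbitrary value):

    ```python
    # Sketch: ChannelBuilder extracts user_agent from the connection string.
    from pyspark.sql.connect.client import ChannelBuilder

    chan = ChannelBuilder("sc://host/;user_agent=my_app")
    assert chan.userAgent == "my_app"

    # When the parameter is omitted, the Python client's default is used.
    assert ChannelBuilder("sc://host/").userAgent == "_SPARK_CONNECT_PYTHON"
    ```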
    
    ### How was this patch tested?
    
    - unit tests attached
    - manually running the `pyspark` binary with the `user_agent` connection
      string set and verifying the payload sent to the server. Similar testing
      for the default (see the validation sketch below).
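
    A sketch of the validation path exercised by the unit tests (per this diff,
    invalid values raise `SparkConnectException` lazily, when `userAgent` is
    read; the `pyspark.errors` import path is an assumption):

    ```python
    # Sketch: invalid user_agent values are rejected on property access.
    from pyspark.errors import SparkConnectException
    from pyspark.sql.connect.client import ChannelBuilder

    chan = ChannelBuilder("sc://host/;user_agent=bad agent")  # whitespace
    try:
        chan.userAgent
    except SparkConnectException as e:
        print(e)  # mentions "alphanumeric and common punctuations"
    ```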
    
    Closes #40054 from nija-at/user-agent.
    
    Authored-by: Niranjan Jayakar <n...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit b887d3de954ae5b2482087fe08affcc4ac60c669)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 connector/connect/docs/client-connection-string.md | 12 +++-
 dev/sparktestsupport/modules.py                    |  1 +
 python/pyspark/sql/connect/client.py               | 29 +++++++++-
 python/pyspark/sql/tests/connect/test_client.py    | 67 ++++++++++++++++++++++
 .../sql/tests/connect/test_connect_basic.py        | 28 ++++++++-
 5 files changed, 132 insertions(+), 5 deletions(-)

diff --git a/connector/connect/docs/client-connection-string.md b/connector/connect/docs/client-connection-string.md
index 8f1f0b8c631..6e5b0c80db7 100644
--- a/connector/connect/docs/client-connection-string.md
+++ b/connector/connect/docs/client-connection-string.md
@@ -58,7 +58,8 @@ sc://hostname:port/;param1=value;param2=value
     <td>token</td>
     <td>String</td>
     <td>When this param is set in the URL, it will enable standard
-    bearer token authentication using GRPC. By default this value is not set.</td>
+    bearer token authentication using GRPC. By default this value is not set.
+    Setting this value enables SSL.</td>
     <td><pre>token=ABCDEFGH</pre></td>
   </tr>
   <tr>
@@ -81,6 +82,15 @@ sc://hostname:port/;param1=value;param2=value
     <pre>user_id=Martin</pre>
     </td>
   </tr>
+  <tr>
+    <td>user_agent</td>
+    <td>String</td>
+    <td>The user agent acting on behalf of the user, typically applications
+    that use Spark Connect to implement their functionality and execute Spark
+    requests on behalf of the user.<br/>
+    <i>Default: </i><pre>_SPARK_CONNECT_PYTHON</pre> in the Python client</td>
+    <td><pre>user_agent=my_data_query_app</pre></td>
+  </tr>
 </table>
 
 ## Examples
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 94ae1ffbce6..75a6b4401b8 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -516,6 +516,7 @@ pyspark_connect = Module(
         "pyspark.sql.connect.dataframe",
         "pyspark.sql.connect.functions",
         # unittests
+        "pyspark.sql.tests.connect.test_client",
         "pyspark.sql.tests.connect.test_connect_plan",
         "pyspark.sql.tests.connect.test_connect_basic",
         "pyspark.sql.tests.connect.test_connect_function",
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index aade0f6e050..78190b2c488 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -19,6 +19,8 @@ __all__ = [
     "SparkConnectClient",
 ]
 
+import string
+
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__, __file__)
@@ -120,6 +122,7 @@ class ChannelBuilder:
     PARAM_USE_SSL = "use_ssl"
     PARAM_TOKEN = "token"
     PARAM_USER_ID = "user_id"
+    PARAM_USER_AGENT = "user_agent"
 
     @staticmethod
     def default_port() -> int:
@@ -215,6 +218,7 @@ class ChannelBuilder:
                 ChannelBuilder.PARAM_TOKEN,
                 ChannelBuilder.PARAM_USE_SSL,
                 ChannelBuilder.PARAM_USER_ID,
+                ChannelBuilder.PARAM_USER_AGENT,
             ]
         ]
 
@@ -244,6 +248,27 @@ class ChannelBuilder:
         """
         return self.params.get(ChannelBuilder.PARAM_USER_ID, None)
 
+    @property
+    def userAgent(self) -> str:
+        """
+        Returns
+        -------
+        user_agent : str
+            The user_agent parameter specified in the connection string,
+            or "_SPARK_CONNECT_PYTHON" when not specified.
+        """
+        user_agent = self.params.get(ChannelBuilder.PARAM_USER_AGENT, "_SPARK_CONNECT_PYTHON")
+        allowed_chars = string.ascii_letters + string.punctuation
+        if len(user_agent) > 200:
+            raise SparkConnectException(
+                "'user_agent' parameter cannot exceed 200 characters in length"
+            )
+        if set(user_agent).difference(allowed_chars):
+            raise SparkConnectException(
+                "Only alphanumeric and common punctuations are allowed for 
'user_agent'"
+            )
+        return user_agent
+
     def get(self, key: str) -> Any:
         """
         Parameters
@@ -559,7 +584,7 @@ class SparkConnectClient(object):
     def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
         req = pb2.ExecutePlanRequest()
         req.client_id = self._session_id
-        req.client_type = "_SPARK_CONNECT_PYTHON"
+        req.client_type = self._builder.userAgent
         if self._user_id:
             req.user_context.user_id = self._user_id
         return req
@@ -567,7 +592,7 @@ class SparkConnectClient(object):
     def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
         req = pb2.AnalyzePlanRequest()
         req.client_id = self._session_id
-        req.client_type = "_SPARK_CONNECT_PYTHON"
+        req.client_type = self._builder.userAgent
         if self._user_id:
             req.user_context.user_id = self._user_id
         return req
diff --git a/python/pyspark/sql/tests/connect/test_client.py b/python/pyspark/sql/tests/connect/test_client.py
new file mode 100644
index 00000000000..41b2888eb74
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_client.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from typing import Optional
+
+from pyspark.sql.connect.client import SparkConnectClient
+import pyspark.sql.connect.proto as proto
+
+
+class SparkConnectClientTestCase(unittest.TestCase):
+    def test_user_agent_passthrough(self):
+        client = SparkConnectClient("sc://foo/;user_agent=bar")
+        mock = MockService(client._session_id)
+        client._stub = mock
+
+        command = proto.Command()
+        client.execute_command(command)
+
+        self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected")
+        self.assertEqual(mock.req.client_type, "bar")
+
+    def test_user_agent_default(self):
+        client = SparkConnectClient("sc://foo/")
+        mock = MockService(client._session_id)
+        client._stub = mock
+
+        command = proto.Command()
+        client.execute_command(command)
+
+        self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected")
+        self.assertEqual(mock.req.client_type, "_SPARK_CONNECT_PYTHON")
+
+
+class MockService:
+    # Simplest mock of the SparkConnectService.
+    # If this needs more complex logic, it needs to be replaced with Python mocking.
+
+    req: Optional[proto.ExecutePlanRequest]
+
+    def __init__(self, session_id: str):
+        self._session_id = session_id
+        self.req = None
+
+    def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
+        self.req = req
+        resp = proto.ExecutePlanResponse()
+        resp.client_id = self._session_id
+        return [resp]
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 8bfffee1ac1..adcd457a105 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2949,10 +2949,33 @@ class ChannelBuilderTests(unittest.TestCase):
 
         chan = ChannelBuilder("sc://host/;token=abcs")
         self.assertTrue(chan.secure, "specifying a token must set the channel to secure")
-
+        self.assertEqual(chan.userAgent, "_SPARK_CONNECT_PYTHON")
         chan = ChannelBuilder("sc://host/;use_ssl=abcs")
         self.assertFalse(chan.secure, "Garbage in, false out")
 
+    def test_invalid_user_agent_charset(self):
+        # fmt: off
+        invalid_user_agents = [
+            "agent»",  # non standard symbol
+            "age nt",  # whitespace
+            "ägent",   # non-ascii alphabet
+        ]
+        # fmt: on
+        for user_agent in invalid_user_agents:
+            with self.subTest(user_agent=user_agent):
+                chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
+                with self.assertRaises(SparkConnectException) as err:
+                    chan.userAgent
+
+                self.assertRegex(err.exception.message, "alphanumeric and common punctuations")
+
+    def test_invalid_user_agent_len(self):
+        user_agent = "x" * 201
+        chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
+        with self.assertRaises(SparkConnectException) as err:
+            chan.userAgent
+        self.assertRegex(err.exception.message, "characters in length")
+
     def test_valid_channel_creation(self):
         chan = ChannelBuilder("sc://host").toChannel()
         self.assertIsInstance(chan, grpc.Channel)
@@ -2965,8 +2988,9 @@ class ChannelBuilderTests(unittest.TestCase):
         self.assertIsInstance(chan, grpc.Channel)
 
     def test_channel_properties(self):
-        chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;param1=120%2021")
+        chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;user_agent=foo;param1=120%2021")
         self.assertEqual("host:15002", chan.endpoint)
+        self.assertEqual("foo", chan.userAgent)
         self.assertEqual(True, chan.secure)
         self.assertEqual("120 21", chan.get("param1"))
 

