This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 014c60fb5bf [SPARK-42477][CONNECT][PYTHON] accept user_agent in spark connect's connection string 014c60fb5bf is described below commit 014c60fb5bf712afafca4eef884665d4245d4aaf Author: Niranjan Jayakar <n...@databricks.com> AuthorDate: Mon Feb 20 23:02:11 2023 +0900 [SPARK-42477][CONNECT][PYTHON] accept user_agent in spark connect's connection string ### What changes were proposed in this pull request? Currently, the Spark Connect service's `client_type` attribute (which is effectively a [user agent]) is always set to `_SPARK_CONNECT_PYTHON` to signify PySpark. With this change, the connection string for the Spark Connect remote accepts an optional `user_agent` parameter, which is then passed down to the service. [user agent]: https://www.w3.org/WAI/UA/work/wiki/Definition_of_User_Agent ### Why are the changes needed? This enables partners using Spark Connect to set their application as the user agent, which in turn provides visibility into, and measurement of, integrations with and usage of Spark Connect. ### Does this PR introduce _any_ user-facing change? A new optional `user_agent` parameter is now recognized as part of the Spark Connect connection string. ### How was this patch tested? - Unit tests attached. - Manually ran the `pyspark` binary with the `user_agent` parameter set in the connection string and verified the payload sent to the server; repeated the same check for the default value. Closes #40054 from nija-at/user-agent. 
Authored-by: Niranjan Jayakar <n...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit b887d3de954ae5b2482087fe08affcc4ac60c669) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- connector/connect/docs/client-connection-string.md | 12 +++- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/connect/client.py | 29 +++++++++- python/pyspark/sql/tests/connect/test_client.py | 67 ++++++++++++++++++++++ .../sql/tests/connect/test_connect_basic.py | 28 ++++++++- 5 files changed, 132 insertions(+), 5 deletions(-) diff --git a/connector/connect/docs/client-connection-string.md b/connector/connect/docs/client-connection-string.md index 8f1f0b8c631..6e5b0c80db7 100644 --- a/connector/connect/docs/client-connection-string.md +++ b/connector/connect/docs/client-connection-string.md @@ -58,7 +58,8 @@ sc://hostname:port/;param1=value;param2=value <td>token</td> <td>String</td> <td>When this param is set in the URL, it will enable standard - bearer token authentication using GRPC. By default this value is not set.</td> + bearer token authentication using GRPC. By default this value is not set. 
+ Setting this value enables SSL.</td> <td><pre>token=ABCDEFGH</pre></td> </tr> <tr> @@ -81,6 +82,15 @@ sc://hostname:port/;param1=value;param2=value <pre>user_id=Martin</pre> </td> </tr> + <tr> + <td>user_agent</td> + <td>String</td> + <td>The user agent acting on behalf of the user, typically applications + that use Spark Connect to implement its functionality and execute Spark + requests on behalf of the user.<br/> + <i>Default: </i><pre>_SPARK_CONNECT_PYTHON</pre> in the Python client</td> + <td><pre>user_agent=my_data_query_app</pre></td> + </tr> </table> ## Examples diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 94ae1ffbce6..75a6b4401b8 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -516,6 +516,7 @@ pyspark_connect = Module( "pyspark.sql.connect.dataframe", "pyspark.sql.connect.functions", # unittests + "pyspark.sql.tests.connect.test_client", "pyspark.sql.tests.connect.test_connect_plan", "pyspark.sql.tests.connect.test_connect_basic", "pyspark.sql.tests.connect.test_connect_function", diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index aade0f6e050..78190b2c488 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -19,6 +19,8 @@ __all__ = [ "SparkConnectClient", ] +import string + from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__, __file__) @@ -120,6 +122,7 @@ class ChannelBuilder: PARAM_USE_SSL = "use_ssl" PARAM_TOKEN = "token" PARAM_USER_ID = "user_id" + PARAM_USER_AGENT = "user_agent" @staticmethod def default_port() -> int: @@ -215,6 +218,7 @@ class ChannelBuilder: ChannelBuilder.PARAM_TOKEN, ChannelBuilder.PARAM_USE_SSL, ChannelBuilder.PARAM_USER_ID, + ChannelBuilder.PARAM_USER_AGENT, ] ] @@ -244,6 +248,27 @@ class ChannelBuilder: """ return self.params.get(ChannelBuilder.PARAM_USER_ID, None) + @property + def userAgent(self) -> str: + """ + Returns + ------- 
+ user_agent : str + The user_agent parameter specified in the connection string, + or "_SPARK_CONNECT_PYTHON" when not specified. + """ + user_agent = self.params.get(ChannelBuilder.PARAM_USER_AGENT, "_SPARK_CONNECT_PYTHON") + allowed_chars = string.ascii_letters + string.punctuation + if len(user_agent) > 200: + raise SparkConnectException( + "'user_agent' parameter cannot exceed 200 characters in length" + ) + if set(user_agent).difference(allowed_chars): + raise SparkConnectException( + "Only alphanumeric and common punctuations are allowed for 'user_agent'" + ) + return user_agent + def get(self, key: str) -> Any: """ Parameters @@ -559,7 +584,7 @@ class SparkConnectClient(object): def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest: req = pb2.ExecutePlanRequest() req.client_id = self._session_id - req.client_type = "_SPARK_CONNECT_PYTHON" + req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id return req @@ -567,7 +592,7 @@ class SparkConnectClient(object): def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: req = pb2.AnalyzePlanRequest() req.client_id = self._session_id - req.client_type = "_SPARK_CONNECT_PYTHON" + req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id return req diff --git a/python/pyspark/sql/tests/connect/test_client.py b/python/pyspark/sql/tests/connect/test_client.py new file mode 100644 index 00000000000..41b2888eb74 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_client.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +from typing import Optional + +from pyspark.sql.connect.client import SparkConnectClient +import pyspark.sql.connect.proto as proto + + +class SparkConnectClientTestCase(unittest.TestCase): + def test_user_agent_passthrough(self): + client = SparkConnectClient("sc://foo/;user_agent=bar") + mock = MockService(client._session_id) + client._stub = mock + + command = proto.Command() + client.execute_command(command) + + self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") + self.assertEqual(mock.req.client_type, "bar") + + def test_user_agent_default(self): + client = SparkConnectClient("sc://foo/") + mock = MockService(client._session_id) + client._stub = mock + + command = proto.Command() + client.execute_command(command) + + self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") + self.assertEqual(mock.req.client_type, "_SPARK_CONNECT_PYTHON") + + +class MockService: + # Simplest mock of the SparkConnectService. + # If this needs more complex logic, it needs to be replaced with Python mocking. 
+ + req: Optional[proto.ExecutePlanRequest] + + def __init__(self, session_id: str): + self._session_id = session_id + self.req = None + + def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): + self.req = req + resp = proto.ExecutePlanResponse() + resp.client_id = self._session_id + return [resp] + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 8bfffee1ac1..adcd457a105 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -2949,10 +2949,33 @@ class ChannelBuilderTests(unittest.TestCase): chan = ChannelBuilder("sc://host/;token=abcs") self.assertTrue(chan.secure, "specifying a token must set the channel to secure") - + self.assertEqual(chan.userAgent, "_SPARK_CONNECT_PYTHON") chan = ChannelBuilder("sc://host/;use_ssl=abcs") self.assertFalse(chan.secure, "Garbage in, false out") + def test_invalid_user_agent_charset(self): + # fmt: off + invalid_user_agents = [ + "agent»", # non standard symbol + "age nt", # whitespace + "ägent", # non-ascii alphabet + ] + # fmt: on + for user_agent in invalid_user_agents: + with self.subTest(user_agent=user_agent): + chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}") + with self.assertRaises(SparkConnectException) as err: + chan.userAgent + + self.assertRegex(err.exception.message, "alphanumeric and common punctuations") + + def test_invalid_user_agent_len(self): + user_agent = "x" * 201 + chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}") + with self.assertRaises(SparkConnectException) as err: + chan.userAgent + self.assertRegex(err.exception.message, "characters in length") + def test_valid_channel_creation(self): chan = ChannelBuilder("sc://host").toChannel() self.assertIsInstance(chan, grpc.Channel) @@ -2965,8 +2988,9 @@ class ChannelBuilderTests(unittest.TestCase): 
self.assertIsInstance(chan, grpc.Channel) def test_channel_properties(self): - chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;param1=120%2021") + chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;user_agent=foo;param1=120%2021") self.assertEqual("host:15002", chan.endpoint) + self.assertEqual("foo", chan.userAgent) self.assertEqual(True, chan.secure) self.assertEqual("120 21", chan.get("param1")) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org