This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 51e947d6a02 [SPARK-43192][CONNECT] Remove user agent charset validation
51e947d6a02 is described below

commit 51e947d6a02e4578f73902f8d487c1601c2f8dae
Author: Niranjan Jayakar <[email protected]>
AuthorDate: Fri Apr 21 15:23:42 2023 +0900

    [SPARK-43192][CONNECT] Remove user agent charset validation
    
    ### Why are the changes needed?
    
    The current charset validation is too restrictive: it does not allow
    whitespace or digits, both of which are common in user agent
    strings.
    
    Secondly, it restricts the length to 200 characters. That limit was
    introduced mostly as a simple protection mechanism. We've looked into
    different specifications for what can be part of a user agent, and
    longer user agents are common, so this change raises the limit to
    2KB (2048 characters).
    The server should enforce restrictions on its side; we still keep a
    client-side limit as fallback protection, while allowing large values.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests
    
    Closes #40853 from nija-at/user-agent-charset.
    
    Authored-by: Niranjan Jayakar <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/client.py               | 13 +++-------
 .../sql/tests/connect/test_connect_basic.py        | 29 ++++++++--------------
 2 files changed, 15 insertions(+), 27 deletions(-)

diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 60f3f1ac2ba..7585c8124bf 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -19,8 +19,6 @@ __all__ = [
     "SparkConnectClient",
 ]
 
-import string
-
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -270,16 +268,13 @@ class ChannelBuilder:
         user_agent : str
             The user_agent parameter specified in the connection string,
             or "_SPARK_CONNECT_PYTHON" when not specified.
+            The returned value will be percent encoded.
         """
         user_agent = self.params.get(ChannelBuilder.PARAM_USER_AGENT, "_SPARK_CONNECT_PYTHON")
-        allowed_chars = string.ascii_letters + string.punctuation
-        if len(user_agent) > 200:
-            raise SparkConnectException(
-                "'user_agent' parameter cannot exceed 200 characters in length"
-            )
-        if set(user_agent).difference(allowed_chars):
+        ua_len = len(urllib.parse.quote(user_agent))
+        if ua_len > 2048:
             raise SparkConnectException(
-                "Only alphanumeric and common punctuations are allowed for 'user_agent'"
+                f"'user_agent' parameter should not exceed 2048 characters, found {len} characters."
             )
         return user_agent
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9d12eb2b26e..b316f0f3b4c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3293,28 +3293,21 @@ class ChannelBuilderTests(unittest.TestCase):
         chan = ChannelBuilder("sc://host/;use_ssl=abcs")
         self.assertFalse(chan.secure, "Garbage in, false out")
 
-    def test_invalid_user_agent_charset(self):
-        # fmt: off
-        invalid_user_agents = [
-            "agent»",  # non standard symbol
-            "age nt",  # whitespace
-            "ägent",   # non-ascii alphabet
-        ]
-        # fmt: on
-        for user_agent in invalid_user_agents:
-            with self.subTest(user_agent=user_agent):
-                chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
-                with self.assertRaises(SparkConnectException) as err:
-                    chan.userAgent
-
-                self.assertRegex(err.exception.message, "alphanumeric and common punctuations")
+    def test_user_agent(self):
+        chan = ChannelBuilder("sc://host/;user_agent=Agent123%20%2F3.4")
+        self.assertEqual("Agent123 /3.4", chan.userAgent)
 
-    def test_invalid_user_agent_len(self):
-        user_agent = "x" * 201
+    def test_user_agent_len(self):
+        user_agent = "x" * 2049
         chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
         with self.assertRaises(SparkConnectException) as err:
             chan.userAgent
-        self.assertRegex(err.exception.message, "characters in length")
+        self.assertRegex(err.exception.message, "'user_agent' parameter should not exceed")
+
+        user_agent = "%C3%A4" * 341  # "%C3%A4" -> "ä"; (341 * 6 = 2046) < 2048
+        expected = "ä" * 341
+        chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
+        self.assertEqual(expected, chan.userAgent)
 
     def test_valid_channel_creation(self):
         chan = ChannelBuilder("sc://host").toChannel()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to