This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6d8235f3b2bb [SPARK-49357][CONNECT][PYTHON] Vertically truncate deeply nested protobuf message
6d8235f3b2bb is described below

commit 6d8235f3b2bbaa88b10c35d6eecddffa4d1b04a4
Author: Changgyoo Park <[email protected]>
AuthorDate: Wed Aug 28 10:58:41 2024 +0800

    [SPARK-49357][CONNECT][PYTHON] Vertically truncate deeply nested protobuf message
    
    ### What changes were proposed in this pull request?
    
    Add a new message truncation strategy that limits the nesting level, since the existing truncation strategies do not handle deeply nested and large protobuf messages well.
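    
    As an illustration only (not part of the change itself), here is a minimal sketch of the depth-limited truncation idea on a plain nested dict; the helper name truncate_nested and the key "child" are invented for this example, while the depth cap of 8 matches the max_level used in the client code below:
    
        def truncate_nested(node, allowed_recursion_depth=8):
            """Sketch only: drop everything nested deeper than the allowed depth."""
            if not isinstance(node, dict):
                return node
            if allowed_recursion_depth == 0:
                # Roughly analogous to ClearField on nested message fields:
                # children at this depth are dropped entirely.
                return {}
            return {
                key: truncate_nested(value, allowed_recursion_depth - 1)
                for key, value in node.items()
            }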
    
    ### Why are the changes needed?
    
    Deeply nested protobuf messages can cause performance problems on the client side when the logger is turned on.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Add a new test scenario to test_truncate_message in test_connect_basic.py.
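    
    As a rough sketch of what the new scenario exercises, using the illustrative truncate_nested helper above rather than the real client (the 64-level loop mirrors the test, but the size checks here are simplified):
    
        deep = node = {}
        for _ in range(64):
            node["child"] = {}
            node = node["child"]
        node["payload"] = "a" * 4096  # large leaf buried 64 levels deep
    
        truncated = truncate_nested(deep, allowed_recursion_depth=8)
        # Nesting is capped after 8 levels, so the large leaf never reaches
        # the rendered string, which therefore shrinks drastically.
        assert len(str(truncated)) < len(str(deep))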
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #47891 from changgyoopark-db/SPARK-49357.
    
    Authored-by: Changgyoo Park <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/client/core.py           | 21 ++++++++++++++++-----
 .../pyspark/sql/tests/connect/test_connect_basic.py | 10 ++++++++++
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 723a11b35c26..35dcf677fdb7 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -993,20 +993,25 @@ class SparkConnectClient(object):
         ----------
         p : google.protobuf.message.Message
             Generic Message type
+        truncate: bool
+            Indicates whether to truncate the message
 
         Returns
         -------
         Single line string of the serialized proto message.
         """
         try:
-            p2 = self._truncate(p) if truncate else p
+            max_level = 8 if truncate else sys.maxsize
+            p2 = self._truncate(p, max_level) if truncate else p
             return text_format.MessageToString(p2, as_one_line=True)
         except RecursionError:
             return "<Truncated message due to recursion error>"
         except Exception:
             return "<Truncated message due to truncation error>"
 
-    def _truncate(self, p: google.protobuf.message.Message) -> google.protobuf.message.Message:
+    def _truncate(
+        self, p: google.protobuf.message.Message, allowed_recursion_depth: int
+    ) -> google.protobuf.message.Message:
         """
         Helper method to truncate the protobuf message.
         Refer to 'org.apache.spark.sql.connect.common.Abbreviator' in the server side.
@@ -1029,11 +1034,17 @@ class SparkConnectClient(object):
                 field_name = descriptor.name
 
                 if descriptor.type == descriptor.TYPE_MESSAGE:
-                    if descriptor.label == descriptor.LABEL_REPEATED:
+                    if allowed_recursion_depth == 0:
+                        p2.ClearField(field_name)
+                    elif descriptor.label == descriptor.LABEL_REPEATED:
                         p2.ClearField(field_name)
-                        getattr(p2, field_name).extend([self._truncate(v) for v in value])
+                        getattr(p2, field_name).extend(
+                            [self._truncate(v, allowed_recursion_depth - 1) for v in value]
+                        )
                     else:
-                        getattr(p2, field_name).CopyFrom(self._truncate(value))
+                        getattr(p2, field_name).CopyFrom(
+                            self._truncate(value, allowed_recursion_depth - 1)
+                        )
 
                 elif descriptor.type == descriptor.TYPE_STRING:
                     if descriptor.label == descriptor.LABEL_REPEATED:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 07fda95e6548..f084601d2e7b 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1434,6 +1434,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         proto_string_truncated_2 = self.connect._client._proto_to_string(plan2, True)
         self.assertTrue(len(proto_string_truncated_2) < 8000, len(proto_string_truncated_2))
 
+        cdf3 = cdf1.select("a" * 4096)
+        for _ in range(64):
+            cdf3 = cdf3.select("a" * 4096)
+        plan3 = cdf3._plan.to_proto(self.connect._client)
+
+        proto_string_3 = self.connect._client._proto_to_string(plan3, False)
+        self.assertTrue(len(proto_string_3) > 128000, len(proto_string_3))
+        proto_string_truncated_3 = self.connect._client._proto_to_string(plan3, True)
+        self.assertTrue(len(proto_string_truncated_3) < 64000, len(proto_string_truncated_3))
+
 
 class SparkConnectGCTests(SparkConnectSQLTestCase):
     @classmethod


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
