This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6d8235f3b2bb [SPARK-49357][CONNECT][PYTHON] Vertically truncate deeply
nested protobuf message
6d8235f3b2bb is described below
commit 6d8235f3b2bbaa88b10c35d6eecddffa4d1b04a4
Author: Changgyoo Park <[email protected]>
AuthorDate: Wed Aug 28 10:58:41 2024 +0800
[SPARK-49357][CONNECT][PYTHON] Vertically truncate deeply nested protobuf
message
### What changes were proposed in this pull request?
Add a new message truncation strategy that limits the nesting level, because
the existing truncation strategies do not work well for large, deeply nested
protobuf messages.
### Why are the changes needed?
There are instances where deeply nested protobuf messages cause performance
problems on the client side when the logger is turned on.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Add a new test scenario to test_truncate_message in test_connect_basic.py.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47891 from changgyoopark-db/SPARK-49357.
Authored-by: Changgyoo Park <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/client/core.py | 21 ++++++++++++++++-----
.../pyspark/sql/tests/connect/test_connect_basic.py | 10 ++++++++++
2 files changed, 26 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index 723a11b35c26..35dcf677fdb7 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -993,20 +993,25 @@ class SparkConnectClient(object):
----------
p : google.protobuf.message.Message
Generic Message type
+ truncate: bool
+ Indicates whether to truncate the message
Returns
-------
Single line string of the serialized proto message.
"""
try:
- p2 = self._truncate(p) if truncate else p
+ max_level = 8 if truncate else sys.maxsize
+ p2 = self._truncate(p, max_level) if truncate else p
return text_format.MessageToString(p2, as_one_line=True)
except RecursionError:
return "<Truncated message due to recursion error>"
except Exception:
return "<Truncated message due to truncation error>"
- def _truncate(self, p: google.protobuf.message.Message) ->
google.protobuf.message.Message:
+ def _truncate(
+ self, p: google.protobuf.message.Message, allowed_recursion_depth: int
+ ) -> google.protobuf.message.Message:
"""
Helper method to truncate the protobuf message.
Refer to 'org.apache.spark.sql.connect.common.Abbreviator' in the
server side.
@@ -1029,11 +1034,17 @@ class SparkConnectClient(object):
field_name = descriptor.name
if descriptor.type == descriptor.TYPE_MESSAGE:
- if descriptor.label == descriptor.LABEL_REPEATED:
+ if allowed_recursion_depth == 0:
+ p2.ClearField(field_name)
+ elif descriptor.label == descriptor.LABEL_REPEATED:
p2.ClearField(field_name)
- getattr(p2, field_name).extend([self._truncate(v) for
v in value])
+ getattr(p2, field_name).extend(
+ [self._truncate(v, allowed_recursion_depth - 1)
for v in value]
+ )
else:
- getattr(p2, field_name).CopyFrom(self._truncate(value))
+ getattr(p2, field_name).CopyFrom(
+ self._truncate(value, allowed_recursion_depth - 1)
+ )
elif descriptor.type == descriptor.TYPE_STRING:
if descriptor.label == descriptor.LABEL_REPEATED:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 07fda95e6548..f084601d2e7b 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1434,6 +1434,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
proto_string_truncated_2 =
self.connect._client._proto_to_string(plan2, True)
self.assertTrue(len(proto_string_truncated_2) < 8000,
len(proto_string_truncated_2))
+ cdf3 = cdf1.select("a" * 4096)
+ for _ in range(64):
+ cdf3 = cdf3.select("a" * 4096)
+ plan3 = cdf3._plan.to_proto(self.connect._client)
+
+ proto_string_3 = self.connect._client._proto_to_string(plan3, False)
+ self.assertTrue(len(proto_string_3) > 128000, len(proto_string_3))
+ proto_string_truncated_3 =
self.connect._client._proto_to_string(plan3, True)
+ self.assertTrue(len(proto_string_truncated_3) < 64000,
len(proto_string_truncated_3))
+
class SparkConnectGCTests(SparkConnectSQLTestCase):
@classmethod
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]