This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 71b76dc3e66a [SPARK-43117][CONNECT] Make `ProtoUtils.abbreviate`
support repeated fields
71b76dc3e66a is described below
commit 71b76dc3e66a9fdd99f961876c503776e8085325
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Feb 8 15:23:34 2024 +0800
[SPARK-43117][CONNECT] Make `ProtoUtils.abbreviate` support repeated fields
### What changes were proposed in this pull request?
Make `ProtoUtils.abbreviate` support repeated fields
### Why are the changes needed?
The existing implementation does not work for repeated fields (strings/messages).
We don't have `repeated bytes` in Spark Connect for now, so we leave that case alone.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Added unit tests.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #45056 from zhengruifeng/proto_abbr_repeat.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../spark/sql/connect/common/ProtoUtils.scala | 34 ++++++++---
.../sql/connect/messages/AbbreviateSuite.scala | 71 ++++++++++++++++++++++
2 files changed, 98 insertions(+), 7 deletions(-)
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
index 2f31b63acf87..66146698b701 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
@@ -42,7 +42,17 @@ private[connect] object ProtoUtils {
val size = string.length
val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE)
if (size > threshold) {
- builder.setField(field, createString(string.take(threshold), size))
+ builder.setField(field, truncateString(string, threshold))
+ }
+
+ case (field: FieldDescriptor, strings: java.lang.Iterable[_])
+ if field.getJavaType == FieldDescriptor.JavaType.STRING &&
field.isRepeated
+ && strings != null =>
+ val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE)
+ strings.iterator().asScala.zipWithIndex.foreach {
+ case (string: String, i) if string != null && string.length >
threshold =>
+ builder.setRepeatedField(field, i, truncateString(string,
threshold))
+ case _ =>
}
case (field: FieldDescriptor, byteString: ByteString)
@@ -69,23 +79,33 @@ private[connect] object ProtoUtils {
.concat(createTruncatedByteString(size)))
}
- // TODO(SPARK-43117): should also support 1, repeated msg; 2, map<xxx,
msg>
+ // TODO(SPARK-46988): should support map<xxx, msg>
case (field: FieldDescriptor, msg: Message)
- if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg !=
null =>
+ if field.getJavaType == FieldDescriptor.JavaType.MESSAGE &&
!field.isRepeated
+ && msg != null =>
builder.setField(field, abbreviate(msg, thresholds))
+ case (field: FieldDescriptor, msgs: java.lang.Iterable[_])
+ if field.getJavaType == FieldDescriptor.JavaType.MESSAGE &&
field.isRepeated
+ && msgs != null =>
+ msgs.iterator().asScala.zipWithIndex.foreach {
+ case (msg: Message, i) if msg != null =>
+ builder.setRepeatedField(field, i, abbreviate(msg, thresholds))
+ case _ =>
+ }
+
case _ =>
}
builder.build()
}
- private def createTruncatedByteString(size: Int): ByteString = {
- ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]")
+ private def truncateString(string: String, threshold: Int): String = {
+
s"${string.take(threshold)}[truncated(size=${format.format(string.length)})]"
}
- private def createString(prefix: String, size: Int): String = {
- s"$prefix[truncated(size=${format.format(size)})]"
+ private def createTruncatedByteString(size: Int): ByteString = {
+ ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]")
}
// Because Spark Connect operation tags are also set as SparkContext Job
tags, they cannot contain
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
index 6dca2c1e8907..0b7104f6c67e 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
@@ -92,6 +92,77 @@ class AbbreviateSuite extends SparkFunSuite {
}
}
+ test("truncate repeated strings") {
+ val sql = proto.Relation
+ .newBuilder()
+ .setSql(proto.SQL.newBuilder().setQuery("SELECT * FROM T"))
+ .build()
+ val names = Seq.range(0, 10).map(i => i.toString * 1024)
+ val drop =
proto.Drop.newBuilder().setInput(sql).addAllColumnNames(names.asJava).build()
+
+ Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold =>
+ val truncated = ProtoUtils.abbreviate(drop, threshold)
+ assert(drop.isInstanceOf[proto.Drop])
+
+ val truncatedNames =
truncated.asInstanceOf[proto.Drop].getColumnNamesList.asScala.toSeq
+ assert(truncatedNames.length === 10)
+
+ if (threshold < 1024) {
+ truncatedNames.foreach { truncatedName =>
+ assert(truncatedName.indexOf("[truncated") === threshold)
+ }
+ } else {
+ truncatedNames.foreach { truncatedName =>
+ assert(truncatedName.indexOf("[truncated") === -1)
+ assert(truncatedName.length === 1024)
+ }
+ }
+
+ }
+ }
+
+ test("truncate repeated messages") {
+ val sql = proto.Relation
+ .newBuilder()
+ .setSql(proto.SQL.newBuilder().setQuery("SELECT * FROM T"))
+ .build()
+
+ val cols = Seq.range(0, 10).map { i =>
+ proto.Expression
+ .newBuilder()
+ .setUnresolvedAttribute(
+ proto.Expression.UnresolvedAttribute
+ .newBuilder()
+ .setUnparsedIdentifier(i.toString * 1024)
+ .build())
+ .build()
+ }
+ val drop =
proto.Drop.newBuilder().setInput(sql).addAllColumns(cols.asJava).build()
+
+ Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold =>
+ val truncated = ProtoUtils.abbreviate(drop, threshold)
+ assert(drop.isInstanceOf[proto.Drop])
+
+ val truncatedCols =
truncated.asInstanceOf[proto.Drop].getColumnsList.asScala.toSeq
+ assert(truncatedCols.length === 10)
+
+ if (threshold < 1024) {
+ truncatedCols.foreach { truncatedCol =>
+ assert(truncatedCol.isInstanceOf[proto.Expression])
+ val truncatedName =
truncatedCol.getUnresolvedAttribute.getUnparsedIdentifier
+ assert(truncatedName.indexOf("[truncated") === threshold)
+ }
+ } else {
+ truncatedCols.foreach { truncatedCol =>
+ assert(truncatedCol.isInstanceOf[proto.Expression])
+ val truncatedName =
truncatedCol.getUnresolvedAttribute.getUnparsedIdentifier
+ assert(truncatedName.indexOf("[truncated") === -1)
+ assert(truncatedName.length === 1024)
+ }
+ }
+ }
+ }
+
test("truncate bytes: simple python udf") {
Seq(1, 8, 16, 64, 256).foreach { numBytes =>
val bytes = Array.ofDim[Byte](numBytes)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]