This is an automated email from the ASF dual-hosted git repository. yangjie01 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f2672fcf3cf [SPARK-43923][CONNECT][FOLLOWUP] Correct the message abbreviation f2672fcf3cf is described below commit f2672fcf3cf3019791a0afcf7eff28f86503fcbc Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Oct 26 22:11:11 2023 +0800 [SPARK-43923][CONNECT][FOLLOWUP] Correct the message abbreviation ### What changes were proposed in this pull request? 1. Truncate raw bytes (UDF/UDTF/local relation) with `MAX_BYTES_SIZE`; 2. Pass `maxStringSize` through when abbreviating nested messages; 3. A minor optimization to avoid creating a temporary array. ### Why are the changes needed? 1. There is only one place that specifies `maxStringSize`, with the value `MAX_STATEMENT_TEXT_SIZE = 65535`. As its name suggests, it is meant to truncate SQL statements, which are always strings, so it should not affect raw bytes. 2. According to the implementation of `Message.toString` (https://github.com/protocolbuffers/protobuf/blob/main/java/core/src/main/java/com/google/protobuf/TextFormat.java#L567-L574), the value of a bytes field can be either a `ByteString` or a `byte[]`, so the two branches should behave consistently. 3. `maxStringSize` previously affected only top-level string fields; it should also be applied to nested messages. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43535 from zhengruifeng/connect_abbreviate_fix. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: yangjie01 <yangji...@baidu.com> --- .../spark/sql/connect/common/ProtoUtils.scala | 25 ++++++++++++---------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala index c7bf3f93bd0..4d1be169ae1 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala @@ -43,9 +43,12 @@ private[connect] object ProtoUtils { case (field: FieldDescriptor, byteString: ByteString) if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteString != null => val size = byteString.size - if (size > maxStringSize) { - val prefix = Array.tabulate(maxStringSize)(byteString.byteAt) - builder.setField(field, createByteString(prefix, size)) + if (size > MAX_BYTES_SIZE) { + builder.setField( + field, + byteString + .substring(0, MAX_BYTES_SIZE) + .concat(createTruncatedByteString(size))) } else { builder.setField(field, byteString) } @@ -54,8 +57,11 @@ private[connect] object ProtoUtils { if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteArray != null => val size = byteArray.size if (size > MAX_BYTES_SIZE) { - val prefix = byteArray.take(MAX_BYTES_SIZE) - builder.setField(field, createByteString(prefix, size)) + builder.setField( + field, + ByteString + .copyFrom(byteArray, 0, MAX_BYTES_SIZE) + .concat(createTruncatedByteString(size))) } else { builder.setField(field, byteArray) } @@ -63,7 +69,7 @@ private[connect] object ProtoUtils { // TODO(SPARK-43117): should also support 1, repeated msg; 2, map<xxx, msg> case (field: FieldDescriptor, msg: Message) if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg != null => - builder.setField(field, abbreviate(msg)) 
+ builder.setField(field, abbreviate(msg, maxStringSize)) case (field: FieldDescriptor, value: Any) => builder.setField(field, value) } @@ -71,11 +77,8 @@ private[connect] object ProtoUtils { builder.build() } - private def createByteString(prefix: Array[Byte], size: Int): ByteString = { - ByteString.copyFrom( - List( - ByteString.copyFrom(prefix), - ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]")).asJava) + private def createTruncatedByteString(size: Int): ByteString = { + ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]") } private def createString(prefix: String, size: Int): String = { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org