This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4911a5bad4ac [SPARK-46505][CONNECT] Make bytes threshold configurable in `ProtoUtils.abbreviate`
4911a5bad4ac is described below

commit 4911a5bad4ac4665772bafbc45ea18cc03e64f3c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Dec 26 10:52:54 2023 +0800

    [SPARK-46505][CONNECT] Make bytes threshold configurable in `ProtoUtils.abbreviate`
    
    ### What changes were proposed in this pull request?
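    Make the bytes truncation threshold configurable in `ProtoUtils.abbreviate` by adding an overload that accepts a map of per-type thresholds (keys `STRING` and `BYTES`).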
    
    ### Why are the changes needed?
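    The bytes threshold should also be configurable, like the string threshold. Previously it was hard-coded to 8 bytes (`MAX_BYTES_SIZE`), while the string threshold could already be passed in.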
    
    ### Does this PR introduce _any_ user-facing change?
    No, this function is only used internally.
    
    ### How was this patch tested?
    Added a unit test to `AbbreviateSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44486 from zhengruifeng/connect_ab_config.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../spark/sql/connect/common/ProtoUtils.scala      | 23 ++++++++++++-----
 .../sql/connect/messages/AbbreviateSuite.scala     | 30 ++++++++++++++++++++++
 2 files changed, 46 insertions(+), 7 deletions(-)

diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
index 18739ed54a29..44de2350b9fd 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
@@ -24,18 +24,25 @@ import com.google.protobuf.Descriptors.FieldDescriptor
 
 private[connect] object ProtoUtils {
   private val format = java.text.NumberFormat.getInstance()
+  private val BYTES = "BYTES"
+  private val STRING = "STRING"
   private val MAX_BYTES_SIZE = 8
   private val MAX_STRING_SIZE = 1024
 
   def abbreviate(message: Message, maxStringSize: Int = MAX_STRING_SIZE): 
Message = {
+    abbreviate(message, Map(STRING -> maxStringSize))
+  }
+
+  def abbreviate(message: Message, thresholds: Map[String, Int]): Message = {
     val builder = message.toBuilder
 
     message.getAllFields.asScala.iterator.foreach {
       case (field: FieldDescriptor, string: String)
           if field.getJavaType == FieldDescriptor.JavaType.STRING && string != 
null =>
         val size = string.length
-        if (size > maxStringSize) {
-          builder.setField(field, createString(string.take(maxStringSize), 
size))
+        val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE)
+        if (size > threshold) {
+          builder.setField(field, createString(string.take(threshold), size))
         } else {
           builder.setField(field, string)
         }
@@ -43,11 +50,12 @@ private[connect] object ProtoUtils {
       case (field: FieldDescriptor, byteString: ByteString)
           if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && 
byteString != null =>
         val size = byteString.size
-        if (size > MAX_BYTES_SIZE) {
+        val threshold = thresholds.getOrElse(BYTES, MAX_BYTES_SIZE)
+        if (size > threshold) {
           builder.setField(
             field,
             byteString
-              .substring(0, MAX_BYTES_SIZE)
+              .substring(0, threshold)
               .concat(createTruncatedByteString(size)))
         } else {
           builder.setField(field, byteString)
@@ -56,11 +64,12 @@ private[connect] object ProtoUtils {
       case (field: FieldDescriptor, byteArray: Array[Byte])
           if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && 
byteArray != null =>
         val size = byteArray.length
-        if (size > MAX_BYTES_SIZE) {
+        val threshold = thresholds.getOrElse(BYTES, MAX_BYTES_SIZE)
+        if (size > threshold) {
           builder.setField(
             field,
             ByteString
-              .copyFrom(byteArray, 0, MAX_BYTES_SIZE)
+              .copyFrom(byteArray, 0, threshold)
               .concat(createTruncatedByteString(size)))
         } else {
           builder.setField(field, byteArray)
@@ -69,7 +78,7 @@ private[connect] object ProtoUtils {
       // TODO(SPARK-43117): should also support 1, repeated msg; 2, map<xxx, 
msg>
       case (field: FieldDescriptor, msg: Message)
           if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg != 
null =>
-        builder.setField(field, abbreviate(msg, maxStringSize))
+        builder.setField(field, abbreviate(msg, thresholds))
 
       case (field: FieldDescriptor, value: Any) => builder.setField(field, 
value)
     }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
index 9a712e9b7bf1..6dca2c1e8907 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
@@ -118,4 +118,34 @@ class AbbreviateSuite extends SparkFunSuite {
       }
     }
   }
+
+  test("truncate bytes with threshold: simple python udf") {
+    val bytes = Array.ofDim[Byte](1024)
+    val message = proto.PythonUDF
+      .newBuilder()
+      .setEvalType(1)
+      .setOutputType(ProtoDataTypes.BinaryType)
+      .setCommand(ByteString.copyFrom(bytes))
+      .setPythonVer("3.12")
+      .build()
+
+    Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold =>
+      val truncated = ProtoUtils.abbreviate(message, Map("BYTES" -> threshold))
+      assert(truncated.isInstanceOf[proto.PythonUDF])
+
+      val truncatedUDF = truncated.asInstanceOf[proto.PythonUDF]
+      assert(truncatedUDF.getEvalType === 1)
+      assert(truncatedUDF.getOutputType === ProtoDataTypes.BinaryType)
+      assert(truncatedUDF.getPythonVer === "3.12")
+
+      if (threshold < 1024) {
+        // with suffix: [truncated(size=...)]
+        assert(
+          threshold < truncatedUDF.getCommand.size() &&
+            truncatedUDF.getCommand.size() < threshold + 64)
+      } else {
+        assert(truncatedUDF.getCommand.size() === 1024)
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to