This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4911a5bad4ac [SPARK-46505][CONNECT] Make bytes threshold configurable
in `ProtoUtils.abbreviate`
4911a5bad4ac is described below
commit 4911a5bad4ac4665772bafbc45ea18cc03e64f3c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Dec 26 10:52:54 2023 +0800
[SPARK-46505][CONNECT] Make bytes threshold configurable in
`ProtoUtils.abbreviate`
### What changes were proposed in this pull request?
Make the bytes threshold configurable in `ProtoUtils.abbreviate` by adding an overload that accepts a map of per-type thresholds (`"STRING"`, `"BYTES"`).
### Why are the changes needed?
The bytes threshold should also be configurable, as the string threshold already is.
### Does this PR introduce _any_ user-facing change?
No, this function is only used internally.
### How was this patch tested?
Added a unit test in `AbbreviateSuite`.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44486 from zhengruifeng/connect_ab_config.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../spark/sql/connect/common/ProtoUtils.scala | 23 ++++++++++++-----
.../sql/connect/messages/AbbreviateSuite.scala | 30 ++++++++++++++++++++++
2 files changed, 46 insertions(+), 7 deletions(-)
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
index 18739ed54a29..44de2350b9fd 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
@@ -24,18 +24,25 @@ import com.google.protobuf.Descriptors.FieldDescriptor
private[connect] object ProtoUtils {
private val format = java.text.NumberFormat.getInstance()
+ private val BYTES = "BYTES"
+ private val STRING = "STRING"
private val MAX_BYTES_SIZE = 8
private val MAX_STRING_SIZE = 1024
def abbreviate(message: Message, maxStringSize: Int = MAX_STRING_SIZE):
Message = {
+ abbreviate(message, Map(STRING -> maxStringSize))
+ }
+
+ def abbreviate(message: Message, thresholds: Map[String, Int]): Message = {
val builder = message.toBuilder
message.getAllFields.asScala.iterator.foreach {
case (field: FieldDescriptor, string: String)
if field.getJavaType == FieldDescriptor.JavaType.STRING && string !=
null =>
val size = string.length
- if (size > maxStringSize) {
- builder.setField(field, createString(string.take(maxStringSize),
size))
+ val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE)
+ if (size > threshold) {
+ builder.setField(field, createString(string.take(threshold), size))
} else {
builder.setField(field, string)
}
@@ -43,11 +50,12 @@ private[connect] object ProtoUtils {
case (field: FieldDescriptor, byteString: ByteString)
if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING &&
byteString != null =>
val size = byteString.size
- if (size > MAX_BYTES_SIZE) {
+ val threshold = thresholds.getOrElse(BYTES, MAX_BYTES_SIZE)
+ if (size > threshold) {
builder.setField(
field,
byteString
- .substring(0, MAX_BYTES_SIZE)
+ .substring(0, threshold)
.concat(createTruncatedByteString(size)))
} else {
builder.setField(field, byteString)
@@ -56,11 +64,12 @@ private[connect] object ProtoUtils {
case (field: FieldDescriptor, byteArray: Array[Byte])
if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING &&
byteArray != null =>
val size = byteArray.length
- if (size > MAX_BYTES_SIZE) {
+ val threshold = thresholds.getOrElse(BYTES, MAX_BYTES_SIZE)
+ if (size > threshold) {
builder.setField(
field,
ByteString
- .copyFrom(byteArray, 0, MAX_BYTES_SIZE)
+ .copyFrom(byteArray, 0, threshold)
.concat(createTruncatedByteString(size)))
} else {
builder.setField(field, byteArray)
@@ -69,7 +78,7 @@ private[connect] object ProtoUtils {
// TODO(SPARK-43117): should also support 1, repeated msg; 2, map<xxx,
msg>
case (field: FieldDescriptor, msg: Message)
if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg !=
null =>
- builder.setField(field, abbreviate(msg, maxStringSize))
+ builder.setField(field, abbreviate(msg, thresholds))
case (field: FieldDescriptor, value: Any) => builder.setField(field,
value)
}
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
index 9a712e9b7bf1..6dca2c1e8907 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/AbbreviateSuite.scala
@@ -118,4 +118,34 @@ class AbbreviateSuite extends SparkFunSuite {
}
}
}
+
+ test("truncate bytes with threshold: simple python udf") {
+ val bytes = Array.ofDim[Byte](1024)
+ val message = proto.PythonUDF
+ .newBuilder()
+ .setEvalType(1)
+ .setOutputType(ProtoDataTypes.BinaryType)
+ .setCommand(ByteString.copyFrom(bytes))
+ .setPythonVer("3.12")
+ .build()
+
+ Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold =>
+ val truncated = ProtoUtils.abbreviate(message, Map("BYTES" -> threshold))
+ assert(truncated.isInstanceOf[proto.PythonUDF])
+
+ val truncatedUDF = truncated.asInstanceOf[proto.PythonUDF]
+ assert(truncatedUDF.getEvalType === 1)
+ assert(truncatedUDF.getOutputType === ProtoDataTypes.BinaryType)
+ assert(truncatedUDF.getPythonVer === "3.12")
+
+ if (threshold < 1024) {
+ // with suffix: [truncated(size=...)]
+ assert(
+ threshold < truncatedUDF.getCommand.size() &&
+ truncatedUDF.getCommand.size() < threshold + 64)
+ } else {
+ assert(truncatedUDF.getCommand.size() === 1024)
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]