This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2a49feeb5d7 [SPARK-43427][PROTOBUF] spark protobuf: allow upcasting unsigned integer types
2a49feeb5d7 is described below
commit 2a49feeb5d727552758a75fdcfbc49e8f6eed72f
Author: Parth Upadhyay <[email protected]>
AuthorDate: Mon Dec 11 16:37:04 2023 -0800
[SPARK-43427][PROTOBUF] spark protobuf: allow upcasting unsigned integer types
### What changes were proposed in this pull request?
JIRA: https://issues.apache.org/jira/browse/SPARK-43427
Protobuf supports unsigned integer types, including uint32 and uint64. When
deserializing protobuf values with fields of these types, `from_protobuf`
currently maps them to the following Spark types:
```
uint32 => IntegerType
uint64 => LongType
```
IntegerType and LongType are
[signed](https://spark.apache.org/docs/latest/sql-ref-datatypes.html) integer
types, so this can lead to confusing results. Namely, if a uint32 value in a
stored proto is above 2^31, or a uint64 value is above 2^63, its binary
representation will have a 1 in the highest bit, which, when interpreted as a
signed integer, is negative (i.e. overflow). No information is lost, as
`IntegerType` and `LongType` contain 32 and 64 bits
respectively, however [...]
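For intuition, this is exactly the reinterpretation that the JDK's unsigned helpers make visible (the same helpers this PR uses for the widening, `Integer.toUnsignedLong` and `java.lang.Long.toUnsignedString`); a small standalone sketch:
```scala
// 2^63 + 1 does not fit in a signed Long: the same 64 bits read back negative.
val asLong: Long = java.lang.Long.parseUnsignedLong("9223372036854775809")
println(asLong)                                  // -9223372036854775807
println(java.lang.Long.toUnsignedString(asLong)) // 9223372036854775809

// The analogous uint32 case: values above 2^31 flip the Int sign,
// but widen losslessly into a Long.
val asInt: Int = Integer.parseUnsignedInt("2147483649") // 2^31 + 1
println(asInt)                                   // -2147483647
println(Integer.toUnsignedLong(asInt))           // 2147483649
```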
In this PR, we add an option (`upcast.unsigned.ints`) to allow upcasting
unsigned integer types into a larger integer type that can represent them
natively, i.e.
```
uint32 => LongType
uint64 => Decimal(20, 0)
```
This behavior is gated behind an option so that it doesn't break any existing clients.
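For example, a minimal sketch of passing the option through `from_protobuf` (the descriptor path, message name, and input column are placeholders, assuming a DataFrame `df` whose "raw_proto" column holds serialized protobuf bytes):
```scala
import scala.jdk.CollectionConverters._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

// Placeholder descriptor file and message name, for illustration only.
val parsed = df.select(
  from_protobuf(
    col("raw_proto"),
    "SomeMessage",
    "/path/to/messages.desc",
    Map("upcast.unsigned.ints" -> "true").asJava).as("proto"))

// With the option enabled, uint32 fields surface as LongType and uint64 fields
// as Decimal(20, 0), so large unsigned values no longer appear negative.
parsed.printSchema()
```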
**Example of current behavior**
Consider a protobuf message like:
```
syntax = "proto3";
message Test {
  uint64 val = 1;
}
```
If we compile the above and then generate a message with a value for `val`
above 2^63:
```
import test_pb2
s = test_pb2.Test()
s.val = 9223372036854775809 # 2**63 + 1
serialized = s.SerializeToString()
print(serialized)
```
This generates the binary representation:
```
b'\x08\x81\x80\x80\x80\x80\x80\x80\x80\x80\x01'
```
Then, deserializing this using `from_protobuf`, we can see that it is
represented as a negative number. I did this in a notebook so it's easier to
see, but it could be reproduced in a Scala test as well.
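A rough Scala equivalent (assuming a `spark` session and a placeholder descriptor file compiled from the `Test` message above):
```scala
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.protobuf.functions.from_protobuf

// The serialized bytes printed above, as a one-row DataFrame.
val serialized = Array(
  0x08, 0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01).map(_.toByte)
val df = spark.range(1).select(lit(serialized).as("raw_proto"))

// Placeholder path to a descriptor file compiled from the Test message.
df.select(from_protobuf(col("raw_proto"), "Test", "/path/to/test.desc").as("proto"))
  .select("proto.val")
  .show()
// Shows -9223372036854775807, i.e. 2^63 + 1 wrapped around into a signed Long.
```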

**Precedent**
I believe that unsigned integer types in Parquet are deserialized in a
similar manner, i.e. placed into a larger type so that the unsigned
representation fits natively; see https://issues.apache.org/jira/browse/SPARK-34817
and https://github.com/apache/spark/pull/31921. So an option to get similar
behavior here would be useful.
### Why are the changes needed?
Improve unsigned integer deserialization behavior.
### Does this PR introduce any user-facing change?
Yes, adds a new option.
### How was this patch tested?
Unit Testing
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43773 from justaparth/parth/43427-add-option-to-expand-unsigned-integers.
Authored-by: Parth Upadhyay <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../spark/sql/protobuf/ProtobufDeserializer.scala | 12 ++++++
.../spark/sql/protobuf/ProtobufSerializer.scala | 11 +++++-
.../spark/sql/protobuf/utils/ProtobufOptions.scala | 12 ++++++
.../sql/protobuf/utils/SchemaConverters.scala | 18 ++++++++-
.../sql/protobuf/ProtobufFunctionsSuite.scala | 46 ++++++++++++++++++++++
5 files changed, 96 insertions(+), 3 deletions(-)
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
index a46baf51379..45f3419edf9 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
@@ -193,6 +193,11 @@ private[sql] class ProtobufDeserializer(
      case (INT, ShortType) =>
        (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short])
+      case (INT, LongType) =>
+        (updater, ordinal, value) =>
+          updater.setLong(
+            ordinal,
+            Integer.toUnsignedLong(value.asInstanceOf[Int]))
      case (
          MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
          ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
@@ -201,6 +206,13 @@ private[sql] class ProtobufDeserializer(
      case (LONG, LongType) =>
        (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long])
+      case (LONG, DecimalType.LongDecimal) =>
+        (updater, ordinal, value) =>
+          updater.setDecimal(
+            ordinal,
+            Decimal.fromString(
+              UTF8String.fromString(java.lang.Long.toUnsignedString(value.asInstanceOf[Long]))))
+
      case (FLOAT, FloatType) =>
        (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float])
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
index 4684934a565..432f948a902 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf
import scala.jdk.CollectionConverters._
-import com.google.protobuf.{Duration, DynamicMessage, Timestamp}
+import com.google.protobuf.{Duration, DynamicMessage, Timestamp, WireFormat}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
@@ -91,8 +91,17 @@ private[sql] class ProtobufSerializer(
        (getter, ordinal) => {
          getter.getInt(ordinal)
        }
+      case (LongType, INT) if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT32 =>
+        (getter, ordinal) => {
+          getter.getLong(ordinal).toInt
+        }
      case (LongType, LONG) =>
        (getter, ordinal) => getter.getLong(ordinal)
+      case (DecimalType(), LONG)
+          if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT64 =>
+        (getter, ordinal) => {
+          getter.getDecimal(ordinal, 20, 0).toUnscaledLong
+        }
      case (FloatType, FLOAT) =>
        (getter, ordinal) => getter.getFloat(ordinal)
      case (DoubleType, DOUBLE) =>
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
index dfdef1f0ec3..f08dfabb606 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
@@ -168,6 +168,18 @@ private[sql] class ProtobufOptions(
  // instead of string, so use caution if changing existing parsing logic.
  val enumsAsInts: Boolean =
    parameters.getOrElse("enums.as.ints", false.toString).toBoolean
+
+  // Protobuf supports unsigned integer types uint32 and uint64. By default this library
+  // will serialize them as the signed IntegerType and LongType respectively. For very
+  // large unsigned values this can cause overflow, causing these numbers
+  // to be represented as negative (above 2^31 for uint32
+  // and above 2^63 for uint64).
+  //
+  // Enabling this option will upcast unsigned integers into a larger type,
+  // i.e. LongType for uint32 and Decimal(20, 0) for uint64 so their representation
+  // can contain large unsigned values without overflow.
+  val upcastUnsignedInts: Boolean =
+    parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean
 }
private[sql] object ProtobufOptions {
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index aa3ac998a74..083d1dac081 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.protobuf.utils
import scala.jdk.CollectionConverters._
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+import com.google.protobuf.WireFormat
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
@@ -67,9 +68,22 @@ object SchemaConverters extends Logging {
      existingRecordNames: Map[String, Int],
      protobufOptions: ProtobufOptions): Option[StructField] = {
    import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+
    val dataType = fd.getJavaType match {
-      case INT => Some(IntegerType)
-      case LONG => Some(LongType)
+      // When the protobuf type is unsigned and upcastUnsignedIntegers has been set,
+      // use a larger type (LongType and Decimal(20,0) for uint32 and uint64).
+      case INT =>
+        if (fd.getLiteType == WireFormat.FieldType.UINT32 && protobufOptions.upcastUnsignedInts) {
+          Some(LongType)
+        } else {
+          Some(IntegerType)
+        }
+      case LONG => if (fd.getLiteType == WireFormat.FieldType.UINT64
+        && protobufOptions.upcastUnsignedInts) {
+        Some(DecimalType.LongDecimal)
+      } else {
+        Some(LongType)
+      }
      case FLOAT => Some(FloatType)
      case DOUBLE => Some(DoubleType)
      case BOOLEAN => Some(BooleanType)
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 9c397597984..67f6568107e 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -1600,6 +1600,52 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
}
}
+ test("test unsigned integer types") {
+ // Test that we correctly handle unsigned integer parsing.
+ // We're using Integer/Long's `MIN_VALUE` as it has a 1 in the sign bit.
+ val sample = spark.range(1).select(
+ lit(
+ SimpleMessage
+ .newBuilder()
+ .setUint32Value(Integer.MIN_VALUE)
+ .setUint64Value(Long.MinValue)
+ .build()
+ .toByteArray
+ ).as("raw_proto"))
+
+ val expectedWithoutFlag = spark.range(1).select(
+ lit(Integer.MIN_VALUE).as("uint32_value"),
+ lit(Long.MinValue).as("uint64_value")
+ )
+
+ val expectedWithFlag = spark.range(1).select(
+
lit(Integer.toUnsignedLong(Integer.MIN_VALUE).longValue).as("uint32_value"),
+
lit(BigDecimal(java.lang.Long.toUnsignedString(Long.MinValue))).as("uint64_value")
+ )
+
+ checkWithFileAndClassName("SimpleMessage") { case (name, descFilePathOpt)
=>
+ List(
+ Map.empty[String, String],
+ Map("upcast.unsigned.ints" -> "false")).foreach(opts => {
+ checkAnswer(
+ sample.select(
+ from_protobuf_wrapper($"raw_proto", name, descFilePathOpt,
opts).as("proto"))
+ .select("proto.uint32_value", "proto.uint64_value"),
+ expectedWithoutFlag)
+ })
+
+ checkAnswer(
+ sample.select(
+ from_protobuf_wrapper(
+ $"raw_proto",
+ name,
+ descFilePathOpt,
+ Map("upcast.unsigned.ints" -> "true")).as("proto"))
+ .select("proto.uint32_value", "proto.uint64_value"),
+ expectedWithFlag)
+ }
+ }
+
def testFromProtobufWithOptions(
df: DataFrame,