This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2a49feeb5d7 [SPARK-43427][PROTOBUF] spark protobuf: allow upcasting unsigned integer types 2a49feeb5d7 is described below commit 2a49feeb5d727552758a75fdcfbc49e8f6eed72f Author: Parth Upadhyay <parth.upadh...@gmail.com> AuthorDate: Mon Dec 11 16:37:04 2023 -0800 [SPARK-43427][PROTOBUF] spark protobuf: allow upcasting unsigned integer types ### What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-43427 Protobuf supports unsigned integer types, including uint32 and uint64. When deserializing protobuf values with fields of these types, `from_protobuf` currently transforms them to the following Spark types: ``` uint32 => IntegerType uint64 => LongType ``` IntegerType and LongType are [signed](https://spark.apache.org/docs/latest/sql-ref-datatypes.html) integer types, so this can lead to confusing results. Namely, if a uint32 value in a stored proto is at or above 2^31 or a uint64 value is at or above 2^63, their representation in binary will contain a 1 in the highest bit, which when interpreted as a signed integer will be negative (i.e. overflow). No information is lost, as `IntegerType` and `LongType` contain 32 and 64 bits respectively, however [...] In this PR, we add an option (`upcast.unsigned.ints`) to allow upcasting unsigned integer types into a larger integer type that can represent them natively, i.e. ``` uint32 => LongType uint64 => Decimal(20, 0) ``` I added an option so that it doesn't break any existing clients. 
**Example of current behavior** Consider a protobuf message like: ``` syntax = "proto3"; message Test { uint64 val = 1; } ``` If we compile the above and then generate a message with a value for `val` above 2^63: ``` import test_pb2 s = test_pb2.Test() s.val = 9223372036854775809 # 2**63 + 1 serialized = s.SerializeToString() print(serialized) ``` This generates the binary representation: b'\x08\x81\x80\x80\x80\x80\x80\x80\x80\x80\x01' Then, deserializing this using `from_protobuf`, we can see that it is represented as a negative number. I did this in a notebook so it's easier to see, but could reproduce in a Scala test as well: ![image](https://github.com/apache/spark/assets/1002986/7144e6a9-3f43-455e-94c3-9065ae88206e) **Precedent** I believe that unsigned integer types in Parquet are deserialized in a similar manner, i.e. put into a larger type so that the unsigned representation natively fits. https://issues.apache.org/jira/browse/SPARK-34817 and https://github.com/apache/spark/pull/31921. So an option to get similar behavior would be useful. ### Why are the changes needed? Improve unsigned integer deserialization behavior. ### Does this PR introduce any user-facing change? Yes, adds a new option. ### How was this patch tested? Unit Testing ### Was this patch authored or co-authored using generative AI tooling? No Closes #43773 from justaparth/parth/43427-add-option-to-expand-unsigned-integers. 
Authored-by: Parth Upadhyay <parth.upadh...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../spark/sql/protobuf/ProtobufDeserializer.scala | 12 ++++++ .../spark/sql/protobuf/ProtobufSerializer.scala | 11 +++++- .../spark/sql/protobuf/utils/ProtobufOptions.scala | 12 ++++++ .../sql/protobuf/utils/SchemaConverters.scala | 18 ++++++++- .../sql/protobuf/ProtobufFunctionsSuite.scala | 46 ++++++++++++++++++++++ 5 files changed, 96 insertions(+), 3 deletions(-) diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala index a46baf51379..45f3419edf9 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala @@ -193,6 +193,11 @@ private[sql] class ProtobufDeserializer( case (INT, ShortType) => (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short]) + case (INT, LongType) => + (updater, ordinal, value) => + updater.setLong( + ordinal, + Integer.toUnsignedLong(value.asInstanceOf[Int])) case ( MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING, ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated => @@ -201,6 +206,13 @@ private[sql] class ProtobufDeserializer( case (LONG, LongType) => (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long]) + case (LONG, DecimalType.LongDecimal) => + (updater, ordinal, value) => + updater.setDecimal( + ordinal, + Decimal.fromString( + UTF8String.fromString(java.lang.Long.toUnsignedString(value.asInstanceOf[Long])))) + case (FLOAT, FloatType) => (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float]) diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala index 4684934a565..432f948a902 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf import scala.jdk.CollectionConverters._ -import com.google.protobuf.{Duration, DynamicMessage, Timestamp} +import com.google.protobuf.{Duration, DynamicMessage, Timestamp, WireFormat} import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ @@ -91,8 +91,17 @@ private[sql] class ProtobufSerializer( (getter, ordinal) => { getter.getInt(ordinal) } + case (LongType, INT) if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT32 => + (getter, ordinal) => { + getter.getLong(ordinal).toInt + } case (LongType, LONG) => (getter, ordinal) => getter.getLong(ordinal) + case (DecimalType(), LONG) + if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT64 => + (getter, ordinal) => { + getter.getDecimal(ordinal, 20, 0).toUnscaledLong + } case (FloatType, FLOAT) => (getter, ordinal) => getter.getFloat(ordinal) case (DoubleType, DOUBLE) => diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala index dfdef1f0ec3..f08dfabb606 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala @@ -168,6 +168,18 @@ private[sql] class ProtobufOptions( // instead of string, so use caution if changing existing parsing logic. 
val enumsAsInts: Boolean = parameters.getOrElse("enums.as.ints", false.toString).toBoolean + + // Protobuf supports unsigned integer types uint32 and uint64. By default this library + // will serialize them as the signed IntegerType and LongType respectively. For very + // large unsigned values this can cause overflow, causing these numbers + // to be represented as negative (above 2^31 for uint32 + // and above 2^63 for uint64). + // + // Enabling this option will upcast unsigned integers into a larger type, + // i.e. LongType for uint32 and Decimal(20, 0) for uint64 so their representation + // can contain large unsigned values without overflow. + val upcastUnsignedInts: Boolean = + parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean } private[sql] object ProtobufOptions { diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala index aa3ac998a74..083d1dac081 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.protobuf.utils import scala.jdk.CollectionConverters._ import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} +import com.google.protobuf.WireFormat import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging @@ -67,9 +68,22 @@ object SchemaConverters extends Logging { existingRecordNames: Map[String, Int], protobufOptions: ProtobufOptions): Option[StructField] = { import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ + val dataType = fd.getJavaType match { - case INT => Some(IntegerType) - case LONG => Some(LongType) + // When the protobuf type is unsigned and upcastUnsignedIntegers has been set, + // use a larger type (LongType and Decimal(20,0) 
for uint32 and uint64). + case INT => + if (fd.getLiteType == WireFormat.FieldType.UINT32 && protobufOptions.upcastUnsignedInts) { + Some(LongType) + } else { + Some(IntegerType) + } + case LONG => if (fd.getLiteType == WireFormat.FieldType.UINT64 + && protobufOptions.upcastUnsignedInts) { + Some(DecimalType.LongDecimal) + } else { + Some(LongType) + } case FLOAT => Some(FloatType) case DOUBLE => Some(DoubleType) case BOOLEAN => Some(BooleanType) diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala index 9c397597984..67f6568107e 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -1600,6 +1600,52 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } } + test("test unsigned integer types") { + // Test that we correctly handle unsigned integer parsing. + // We're using Integer/Long's `MIN_VALUE` as it has a 1 in the sign bit. 
+ val sample = spark.range(1).select( + lit( + SimpleMessage + .newBuilder() + .setUint32Value(Integer.MIN_VALUE) + .setUint64Value(Long.MinValue) + .build() + .toByteArray + ).as("raw_proto")) + + val expectedWithoutFlag = spark.range(1).select( + lit(Integer.MIN_VALUE).as("uint32_value"), + lit(Long.MinValue).as("uint64_value") + ) + + val expectedWithFlag = spark.range(1).select( + lit(Integer.toUnsignedLong(Integer.MIN_VALUE).longValue).as("uint32_value"), + lit(BigDecimal(java.lang.Long.toUnsignedString(Long.MinValue))).as("uint64_value") + ) + + checkWithFileAndClassName("SimpleMessage") { case (name, descFilePathOpt) => + List( + Map.empty[String, String], + Map("upcast.unsigned.ints" -> "false")).foreach(opts => { + checkAnswer( + sample.select( + from_protobuf_wrapper($"raw_proto", name, descFilePathOpt, opts).as("proto")) + .select("proto.uint32_value", "proto.uint64_value"), + expectedWithoutFlag) + }) + + checkAnswer( + sample.select( + from_protobuf_wrapper( + $"raw_proto", + name, + descFilePathOpt, + Map("upcast.unsigned.ints" -> "true")).as("proto")) + .select("proto.uint32_value", "proto.uint64_value"), + expectedWithFlag) + } + } + def testFromProtobufWithOptions( df: DataFrame, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org