This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2a49feeb5d7 [SPARK-43427][PROTOBUF] spark protobuf: allow upcasting unsigned integer types
2a49feeb5d7 is described below

commit 2a49feeb5d727552758a75fdcfbc49e8f6eed72f
Author: Parth Upadhyay <parth.upadh...@gmail.com>
AuthorDate: Mon Dec 11 16:37:04 2023 -0800

    [SPARK-43427][PROTOBUF] spark protobuf: allow upcasting unsigned integer types
    
    ### What changes were proposed in this pull request?
    
    JIRA: https://issues.apache.org/jira/browse/SPARK-43427
    
    Protobuf supports unsigned integer types, including uint32 and uint64. When deserializing protobuf values with fields of these types, `from_protobuf` currently maps them to the following Spark types:
    
    ```
    uint32 => IntegerType
    uint64 => LongType
    ```
    
    IntegerType and LongType are [signed](https://spark.apache.org/docs/latest/sql-ref-datatypes.html) integer types, so this can lead to confusing results. Namely, if a uint32 value in a stored proto is above 2^31 or a uint64 value is above 2^63, its binary representation will contain a 1 in the highest bit, which when interpreted as a signed integer is negative (i.e. it overflows). No information is lost, as `IntegerType` and `LongType` contain 32 and 64 bits respectively, however [...]
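    
    For illustration, a minimal Scala sketch of that reinterpretation (not part of this patch; the value is arbitrary):
    
    ```
    // uint32 value 2^31 + 1 stored bit-for-bit in a signed Int:
    val raw: Int = 2147483649L.toInt
    println(raw)                          // -2147483647 (signed reinterpretation)
    // No bits are lost, so the unsigned value is recoverable:
    println(Integer.toUnsignedLong(raw))  // 2147483649
    ```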
    
    In this PR, we add an option (`upcast.unsigned.ints`) to allow upcasting unsigned integer types into a larger integer type that can represent them natively, i.e.
    
    ```
    uint32 => LongType
    uint64 => Decimal(20, 0)
    ```
    
    I added an option so that it doesn't break any existing clients.
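    
    A usage sketch (the DataFrame `df`, its `raw_proto` column, and the descriptor path are illustrative, not part of this patch; `Test` is the message type from the example below):
    
    ```
    import scala.jdk.CollectionConverters._
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.protobuf.functions.from_protobuf
    
    // Parse with unsigned upcasting enabled.
    val parsed = df.select(
      from_protobuf(
        col("raw_proto"),
        "Test",
        "/path/to/test.desc",
        Map("upcast.unsigned.ints" -> "true").asJava).as("proto"))
    // proto.val is now Decimal(20, 0) (uint64) instead of the default signed
    // LongType, so values above 2^63 come back unchanged.
    ```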
    
    **Example of current behavior**
    
    Consider a protobuf message like:
    
    ```
    syntax = "proto3";
    
    message Test {
      uint64 val = 1;
    }
    ```
    
    If we compile the above and then generate a message with a value for `val` above 2^63:
    
    ```
    import test_pb2
    
    s = test_pb2.Test()
    s.val = 9223372036854775809 # 2**63 + 1
    serialized = s.SerializeToString()
    print(serialized)
    ```
    
    This generates the binary representation:
    
    ```
    b'\x08\x81\x80\x80\x80\x80\x80\x80\x80\x80\x01'
    ```
    
    Then, deserializing this using `from_protobuf`, we can see that it is represented as a negative number. I did this in a notebook so it's easier to see, but it could be reproduced in a Scala test as well (a sketch follows the screenshot):
    
    
![image](https://github.com/apache/spark/assets/1002986/7144e6a9-3f43-455e-94c3-9065ae88206e)
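    
    A rough Scala equivalent of the notebook above (a sketch, assuming an active `spark` session and a descriptor file `test.desc` compiled from the `Test` message):
    
    ```
    import org.apache.spark.sql.functions.{col, lit}
    import org.apache.spark.sql.protobuf.functions.from_protobuf
    
    // The serialized bytes from the Python snippet above (2^63 + 1 in field 1).
    val bytes = Array(0x08, 0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
      0x80, 0x01).map(_.toByte)
    
    spark.range(1)
      .select(lit(bytes).as("raw_proto"))
      .select(from_protobuf(col("raw_proto"), "Test", "test.desc").as("proto"))
      .select("proto.val")
      .show()
    // val comes back as -9223372036854775807, the signed reinterpretation
    // of 2^63 + 1.
    ```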
    
    **Precedent**
    I believe that unsigned integer types in Parquet are deserialized in a similar manner, i.e. put into a larger type so that the unsigned representation fits natively; see https://issues.apache.org/jira/browse/SPARK-34817 and https://github.com/apache/spark/pull/31921. So an option to get similar behavior here would be useful.
    
    ### Why are the changes needed?
    Improve unsigned integer deserialization behavior.
    
    ### Does this PR introduce any user-facing change?
    Yes, adds a new option.
    
    ### How was this patch tested?
    Unit Testing
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43773 from justaparth/parth/43427-add-option-to-expand-unsigned-integers.
    
    Authored-by: Parth Upadhyay <parth.upadh...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/protobuf/ProtobufDeserializer.scala  | 12 ++++++
 .../spark/sql/protobuf/ProtobufSerializer.scala    | 11 +++++-
 .../spark/sql/protobuf/utils/ProtobufOptions.scala | 12 ++++++
 .../sql/protobuf/utils/SchemaConverters.scala      | 18 ++++++++-
 .../sql/protobuf/ProtobufFunctionsSuite.scala      | 46 ++++++++++++++++++++++
 5 files changed, 96 insertions(+), 3 deletions(-)

diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
index a46baf51379..45f3419edf9 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
@@ -193,6 +193,11 @@ private[sql] class ProtobufDeserializer(
       case (INT, ShortType) =>
         (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short])
 
+      case (INT, LongType) =>
+        (updater, ordinal, value) =>
+          updater.setLong(
+            ordinal,
+            Integer.toUnsignedLong(value.asInstanceOf[Int]))
       case (
           MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
           ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
@@ -201,6 +206,13 @@ private[sql] class ProtobufDeserializer(
       case (LONG, LongType) =>
         (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long])
 
+      case (LONG, DecimalType.LongDecimal) =>
+        (updater, ordinal, value) =>
+          updater.setDecimal(
+            ordinal,
+            Decimal.fromString(
+              UTF8String.fromString(java.lang.Long.toUnsignedString(value.asInstanceOf[Long]))))
+
       case (FLOAT, FloatType) =>
         (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float])
 
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
index 4684934a565..432f948a902 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf
 
 import scala.jdk.CollectionConverters._
 
-import com.google.protobuf.{Duration, DynamicMessage, Timestamp}
+import com.google.protobuf.{Duration, DynamicMessage, Timestamp, WireFormat}
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
 import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
 
@@ -91,8 +91,17 @@ private[sql] class ProtobufSerializer(
         (getter, ordinal) => {
           getter.getInt(ordinal)
         }
+      case (LongType, INT) if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT32 =>
+        (getter, ordinal) => {
+          getter.getLong(ordinal).toInt
+        }
       case (LongType, LONG) =>
         (getter, ordinal) => getter.getLong(ordinal)
+      case (DecimalType(), LONG)
+        if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT64 =>
+        (getter, ordinal) => {
+          getter.getDecimal(ordinal, 20, 0).toUnscaledLong
+        }
       case (FloatType, FLOAT) =>
         (getter, ordinal) => getter.getFloat(ordinal)
       case (DoubleType, DOUBLE) =>
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
index dfdef1f0ec3..f08dfabb606 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
@@ -168,6 +168,18 @@ private[sql] class ProtobufOptions(
   // instead of string, so use caution if changing existing parsing logic.
   val enumsAsInts: Boolean =
     parameters.getOrElse("enums.as.ints", false.toString).toBoolean
+
+  // Protobuf supports unsigned integer types uint32 and uint64. By default this library
+  // will serialize them as the signed IntegerType and LongType respectively. For very
+  // large unsigned values this can cause overflow, causing these numbers
+  // to be represented as negative (above 2^31 for uint32
+  // and above 2^63 for uint64).
+  //
+  // Enabling this option will upcast unsigned integers into a larger type,
+  // i.e. LongType for uint32 and Decimal(20, 0) for uint64 so their representation
+  // can contain large unsigned values without overflow.
+  val upcastUnsignedInts: Boolean =
+    parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean
 }
 
 private[sql] object ProtobufOptions {
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index aa3ac998a74..083d1dac081 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.protobuf.utils
 import scala.jdk.CollectionConverters._
 
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+import com.google.protobuf.WireFormat
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.Logging
@@ -67,9 +68,22 @@ object SchemaConverters extends Logging {
       existingRecordNames: Map[String, Int],
       protobufOptions: ProtobufOptions): Option[StructField] = {
     import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+
     val dataType = fd.getJavaType match {
-      case INT => Some(IntegerType)
-      case LONG => Some(LongType)
+      // When the protobuf type is unsigned and upcastUnsignedInts has been set,
+      // use a larger type (LongType and Decimal(20,0) for uint32 and uint64).
+      case INT =>
+        if (fd.getLiteType == WireFormat.FieldType.UINT32 && protobufOptions.upcastUnsignedInts) {
+          Some(LongType)
+        } else {
+          Some(IntegerType)
+        }
+      case LONG => if (fd.getLiteType == WireFormat.FieldType.UINT64
+          && protobufOptions.upcastUnsignedInts) {
+        Some(DecimalType.LongDecimal)
+      } else {
+        Some(LongType)
+      }
       case FLOAT => Some(FloatType)
       case DOUBLE => Some(DoubleType)
       case BOOLEAN => Some(BooleanType)
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 9c397597984..67f6568107e 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -1600,6 +1600,52 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }
 
+  test("test unsigned integer types") {
+    // Test that we correctly handle unsigned integer parsing.
+    // We're using Integer/Long's `MIN_VALUE` as it has a 1 in the sign bit.
+    val sample = spark.range(1).select(
+      lit(
+        SimpleMessage
+          .newBuilder()
+          .setUint32Value(Integer.MIN_VALUE)
+          .setUint64Value(Long.MinValue)
+          .build()
+          .toByteArray
+      ).as("raw_proto"))
+
+    val expectedWithoutFlag = spark.range(1).select(
+      lit(Integer.MIN_VALUE).as("uint32_value"),
+      lit(Long.MinValue).as("uint64_value")
+    )
+
+    val expectedWithFlag = spark.range(1).select(
+      lit(Integer.toUnsignedLong(Integer.MIN_VALUE).longValue).as("uint32_value"),
+      lit(BigDecimal(java.lang.Long.toUnsignedString(Long.MinValue))).as("uint64_value")
+    )
+
+    checkWithFileAndClassName("SimpleMessage") { case (name, descFilePathOpt) =>
+      List(
+        Map.empty[String, String],
+        Map("upcast.unsigned.ints" -> "false")).foreach(opts => {
+        checkAnswer(
+          sample.select(
+              from_protobuf_wrapper($"raw_proto", name, descFilePathOpt, opts).as("proto"))
+            .select("proto.uint32_value", "proto.uint64_value"),
+          expectedWithoutFlag)
+      })
+
+      checkAnswer(
+        sample.select(
+            from_protobuf_wrapper(
+              $"raw_proto",
+              name,
+              descFilePathOpt,
+              Map("upcast.unsigned.ints" -> "true")).as("proto"))
+          .select("proto.uint32_value", "proto.uint64_value"),
+        expectedWithFlag)
+    }
+  }
+
 
   def testFromProtobufWithOptions(
     df: DataFrame,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
