This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d439e34d6bda [SPARK-40876][SQL] Widening type promotion for decimals with larger scale in Parquet readers
d439e34d6bda is described below

commit d439e34d6bda661e89cd29914db1a5d3abaafc67
Author: Johan Lasperas <johan.laspe...@databricks.com>
AuthorDate: Mon Jan 8 16:58:23 2024 +0800

    [SPARK-40876][SQL] Widening type promotion for decimals with larger scale in Parquet readers
    
    ### What changes were proposed in this pull request?
    This is a follow-up from https://github.com/apache/spark/pull/44368 implementing an additional type promotion to decimals with larger precision and scale.
    As long as the precision increases by at least as much as the scale, decimal values can be promoted without loss of precision, e.g. Decimal(6, 2) -> Decimal(8, 4): 1234.56 -> 1234.5600.
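    As an illustration (not part of this change), the upscaling this relies on is a plain `java.math.BigDecimal` rescale; the new vectorized updaters apply the same `setScale` with `RoundingMode.UNNECESSARY`, which never needs to round because the precision grows by at least as much as the scale:
    ```
      import java.math.RoundingMode
      import java.math.{BigDecimal => JBigDecimal}

      // Promotion is allowed when scaleIncrease >= 0 and precisionIncrease >= scaleIncrease.
      val value = JBigDecimal.valueOf(123456, 2)                // 1234.56   as Decimal(6, 2)
      val widened = value.setScale(4, RoundingMode.UNNECESSARY) // 1234.5600 as Decimal(8, 4)
    ```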
    
    The non-vectorized reader (parquet-mr) can already do this type promotion; this PR implements it for the vectorized reader.
    
    ### Why are the changes needed?
    This allows reading multiple Parquet files that contain decimals with different precisions/scales.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the following now succeeds when using the vectorized Parquet reader:
    ```
      Seq(20).toDF("a").select($"a".cast(DecimalType(4, 2)).as("a")).write.parquet(path)
      spark.read.schema("a decimal(6, 4)").parquet(path).collect()
    ```
    It failed before with the vectorized reader and succeeded with the non-vectorized reader.
    
    ### How was this patch tested?
    - Tests added to `ParquetTypeWideningSuite` to cover decimal promotion between decimals with different physical types: INT32, INT64, FIXED_LEN_BYTE_ARRAY (a representative case is sketched below).
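    A representative widening case from the new tests, simplified here for illustration (helper and parameter names as used in `ParquetTypeWideningSuite`):
    ```
      checkAllParquetReaders(
        values = Seq("1.23", "10.34"),
        fromType = DecimalType(5, 2),   // written as an INT32-backed decimal
        toType = DecimalType(20, 17),   // read back as a FIXED_LEN_BYTE_ARRAY-backed decimal
        expectError = false)
    ```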
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #44513 from johanl-db/SPARK-40876-parquet-type-promotion-decimal-scale.
    
    Authored-by: Johan Lasperas <johan.laspe...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../parquet/ParquetVectorUpdaterFactory.java       | 224 ++++++++++++++++++++-
 .../parquet/VectorizedColumnReader.java            |  31 ++-
 .../datasources/parquet/ParquetQuerySuite.scala    |   4 +-
 .../parquet/ParquetTypeWideningSuite.scala         |  36 +++-
 4 files changed, 281 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 4961b52f4bb5..3863818b0255 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -34,7 +34,9 @@ import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupporte
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
 import org.apache.spark.sql.types.*;
 
+import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.math.RoundingMode;
 import java.time.ZoneId;
 import java.time.ZoneOffset;
 import java.util.Arrays;
@@ -108,6 +110,8 @@ public class ParquetVectorUpdaterFactory {
           }
         } else if (sparkType instanceof YearMonthIntervalType) {
           return new IntegerUpdater();
+        } else if (canReadAsDecimal(descriptor, sparkType)) {
+          return new IntegerToDecimalUpdater(descriptor, (DecimalType) sparkType);
         }
       }
       case INT64 -> {
@@ -153,6 +157,8 @@ public class ParquetVectorUpdaterFactory {
           return new LongAsMicrosUpdater();
         } else if (sparkType instanceof DayTimeIntervalType) {
           return new LongUpdater();
+        } else if (canReadAsDecimal(descriptor, sparkType)) {
+          return new LongToDecimalUpdater(descriptor, (DecimalType) sparkType);
         }
       }
       case FLOAT -> {
@@ -194,6 +200,8 @@ public class ParquetVectorUpdaterFactory {
         if (sparkType == DataTypes.StringType || sparkType == DataTypes.BinaryType ||
           canReadAsBinaryDecimal(descriptor, sparkType)) {
           return new BinaryUpdater();
+        } else if (canReadAsDecimal(descriptor, sparkType)) {
+          return new BinaryToDecimalUpdater(descriptor, (DecimalType) sparkType);
         }
       }
       case FIXED_LEN_BYTE_ARRAY -> {
@@ -206,6 +214,8 @@ public class ParquetVectorUpdaterFactory {
           return new FixedLenByteArrayUpdater(arrayLen);
         } else if (sparkType == DataTypes.BinaryType) {
           return new FixedLenByteArrayUpdater(arrayLen);
+        } else if (canReadAsDecimal(descriptor, sparkType)) {
+          return new FixedLenByteArrayToDecimalUpdater(descriptor, (DecimalType) sparkType);
         }
       }
       default -> {}
@@ -1358,6 +1368,188 @@ public class ParquetVectorUpdaterFactory {
     }
   }
 
+  private abstract static class DecimalUpdater implements ParquetVectorUpdater {
+
+    private final DecimalType sparkType;
+
+    DecimalUpdater(DecimalType sparkType) {
+      this.sparkType = sparkType;
+    }
+
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; i++) {
+        readValue(offset + i, values, valuesReader);
+      }
+    }
+
+    protected void writeDecimal(int offset, WritableColumnVector values, BigDecimal decimal) {
+      BigDecimal scaledDecimal = decimal.setScale(sparkType.scale(), RoundingMode.UNNECESSARY);
+      if (DecimalType.is32BitDecimalType(sparkType)) {
+        values.putInt(offset, scaledDecimal.unscaledValue().intValue());
+      } else if (DecimalType.is64BitDecimalType(sparkType)) {
+        values.putLong(offset, scaledDecimal.unscaledValue().longValue());
+      } else {
+        values.putByteArray(offset, scaledDecimal.unscaledValue().toByteArray());
+      }
+    }
+  }
+
+  private static class IntegerToDecimalUpdater extends DecimalUpdater {
+    private final int parquetScale;
+
+    IntegerToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+      super(sparkType);
+      LogicalTypeAnnotation typeAnnotation =
+        descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+        valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      BigDecimal decimal = BigDecimal.valueOf(valuesReader.readInteger(), parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigDecimal decimal =
+        BigDecimal.valueOf(dictionary.decodeToInt(dictionaryIds.getDictId(offset)), parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+  }
+
+private static class LongToDecimalUpdater extends DecimalUpdater {
+    private final int parquetScale;
+
+   LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+      super(sparkType);
+      LogicalTypeAnnotation typeAnnotation =
+        descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+        valuesReader.skipLongs(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      BigDecimal decimal = BigDecimal.valueOf(valuesReader.readLong(), parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigDecimal decimal =
+        BigDecimal.valueOf(dictionary.decodeToLong(dictionaryIds.getDictId(offset)), parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+  }
+
+private static class BinaryToDecimalUpdater extends DecimalUpdater {
+    private final int parquetScale;
+
+  BinaryToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+      super(sparkType);
+      LogicalTypeAnnotation typeAnnotation =
+        descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+        valuesReader.skipBinary(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      valuesReader.readBinary(1, values, offset);
+      BigInteger value = new BigInteger(values.getBinary(offset));
+      BigDecimal decimal = new BigDecimal(value, parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigInteger value =
+        new BigInteger(dictionary.decodeToBinary(dictionaryIds.getDictId(offset)).getBytes());
+      BigDecimal decimal = new BigDecimal(value, parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+  }
+
+private static class FixedLenByteArrayToDecimalUpdater extends DecimalUpdater {
+    private final int parquetScale;
+    private final int arrayLen;
+
+   FixedLenByteArrayToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+      super(sparkType);
+      LogicalTypeAnnotation typeAnnotation =
+        descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      this.arrayLen = descriptor.getPrimitiveType().getTypeLength();
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+        valuesReader.skipFixedLenByteArray(total, arrayLen);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      BigInteger value = new BigInteger(valuesReader.readBinary(arrayLen).getBytes());
+      BigDecimal decimal = new BigDecimal(value, this.parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigInteger value =
+        new BigInteger(dictionary.decodeToBinary(dictionaryIds.getDictId(offset)).getBytes());
+      BigDecimal decimal = new BigDecimal(value, this.parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+  }
+
   private static int rebaseDays(int julianDays, final boolean failIfRebase) {
     if (failIfRebase) {
       if (julianDays < RebaseDateTime.lastSwitchJulianDay()) {
@@ -1418,16 +1610,21 @@ public class ParquetVectorUpdaterFactory {
 
   private static boolean canReadAsIntDecimal(ColumnDescriptor descriptor, DataType dt) {
     if (!DecimalType.is32BitDecimalType(dt)) return false;
-    return isDecimalTypeMatched(descriptor, dt);
+    return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
   }
 
   private static boolean canReadAsLongDecimal(ColumnDescriptor descriptor, DataType dt) {
     if (!DecimalType.is64BitDecimalType(dt)) return false;
-    return isDecimalTypeMatched(descriptor, dt);
+    return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
   }
 
   private static boolean canReadAsBinaryDecimal(ColumnDescriptor descriptor, DataType dt) {
     if (!DecimalType.isByteArrayDecimalType(dt)) return false;
+    return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
+  }
+
+  private static boolean canReadAsDecimal(ColumnDescriptor descriptor, DataType dt) {
+    if (!(dt instanceof DecimalType)) return false;
     return isDecimalTypeMatched(descriptor, dt);
   }
 
@@ -1444,14 +1641,29 @@ public class ParquetVectorUpdaterFactory {
   }
 
   private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
+    DecimalType requestedType = (DecimalType) dt;
+    LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+    if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+      DecimalLogicalTypeAnnotation parquetType = (DecimalLogicalTypeAnnotation) typeAnnotation;
+      // If the required scale is larger than or equal to the physical decimal scale in the Parquet
+      // metadata, we can upscale the value as long as the precision also increases by as much so
+      // that there is no loss of precision.
+      int scaleIncrease = requestedType.scale() - parquetType.getScale();
+      int precisionIncrease = requestedType.precision() - parquetType.getPrecision();
+      return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
+    }
+    return false;
+  }
+
+  private static boolean isSameDecimalScale(ColumnDescriptor descriptor, DataType dt) {
     DecimalType d = (DecimalType) dt;
     LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-    if (typeAnnotation instanceof DecimalLogicalTypeAnnotation decimalType) {
-      // It's OK if the required decimal precision is larger than or equal to the physical decimal
-      // precision in the Parquet metadata, as long as the decimal scale is the same.
-      return decimalType.getPrecision() <= d.precision() && decimalType.getScale() == d.scale();
+    if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+      DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
+      return decimalType.getScale() == d.scale();
     }
     return false;
   }
+
 }
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index 7c9bca6710aa..6479644968ed 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -152,32 +152,51 @@ public class VectorizedColumnReader {
     switch (typeName) {
       case INT32: {
         boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
-        boolean needsUpcast = sparkType == LongType || (isDate && sparkType == TimestampNTZType) ||
-          !DecimalType.is32BitDecimalType(sparkType);
+        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
+          (isDate && sparkType == TimestampNTZType) ||
+          (isDecimal && !DecimalType.is32BitDecimalType(sparkType));
         boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
           !"CORRECTED".equals(datetimeRebaseMode);
-        isSupported = !needsUpcast && !needsRebase;
+        isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
         break;
       }
       case INT64: {
-        boolean needsUpcast = !DecimalType.is64BitDecimalType(sparkType) ||
+        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
           updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
         boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
           !"CORRECTED".equals(datetimeRebaseMode);
-        isSupported = !needsUpcast && !needsRebase;
+        isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
         break;
       }
       case FLOAT:
         isSupported = sparkType == FloatType;
         break;
       case DOUBLE:
-      case BINARY:
         isSupported = true;
         break;
+      case BINARY:
+        isSupported = !needsDecimalScaleRebase(sparkType);
+        break;
     }
     return isSupported;
   }
 
+  /**
+   * Returns whether the Parquet type of this column and the given spark type are two decimal types
+   * with different scale.
+   */
+  private boolean needsDecimalScaleRebase(DataType sparkType) {
+      LogicalTypeAnnotation typeAnnotation =
+        descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+      if (!(typeAnnotation instanceof DecimalLogicalTypeAnnotation)) return false;
+      if (!(sparkType instanceof DecimalType)) return false;
+      DecimalLogicalTypeAnnotation parquetDecimal = (DecimalLogicalTypeAnnotation) typeAnnotation;
+      DecimalType sparkDecimal = (DecimalType) sparkType;
+      return parquetDecimal.getScale() != sparkDecimal.scale();
+  }
+
   /**
    * Reads `total` rows from this columnReader into column.
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 73a9222c7338..b306a526818e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -1049,7 +1049,9 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
       }
 
       withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
-        Seq("a DECIMAL(3, 2)", "b DECIMAL(18, 1)", "c DECIMAL(37, 1)").foreach 
{ schema =>
+       val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
+        checkAnswer(readParquet(schema1, path), df)
+        Seq("a DECIMAL(3, 0)", "b DECIMAL(18, 1)", "c DECIMAL(37, 1)").foreach 
{ schema =>
           val e = intercept[SparkException] {
             readParquet(schema, path).collect()
           }.getCause.getCause
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
index 0a8618944241..1f56b51de3dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
@@ -209,10 +209,44 @@ class ParquetTypeWideningSuite
         toType = DecimalType(toPrecision, 2),
         expectError = fromPrecision > toPrecision &&
           // parquet-mr allows reading decimals into a smaller precision decimal type without
-          // checking for overflows. See test below.
+          // checking for overflows. See test below checking for the overflow case in parquet-mr.
           spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
     }
 
+  for {
+    ((fromPrecision, fromScale), (toPrecision, toScale)) <-
+      // Test changing decimal types for decimals backed by different physical parquet types:
+      // - INT32: precisions 5, 7
+      // - INT64: precisions 10, 12
+      // - FIXED_LEN_BYTE_ARRAY: precisions 20, 22
+      // Widening  precision and scale by the same amount.
+      Seq((5, 2) -> (7, 4), (5, 2) -> (10, 7), (5, 2) -> (20, 17), (10, 2) -> (12, 4),
+        (10, 2) -> (20, 12), (20, 2) -> (22, 4)) ++
+      // Narrowing precision and scale by the same amount.
+      Seq((7, 4) -> (5, 2), (10, 7) -> (5, 2), (20, 17) -> (5, 2), (12, 4) -> (10, 2),
+        (20, 17) -> (10, 2), (22, 4) -> (20, 2)) ++
+      // Increasing precision and decreasing scale.
+      Seq((5, 4) -> (7, 2), (10, 6) -> (12, 4), (20, 7) -> (22, 5)) ++
+      // Decreasing precision and increasing scale.
+      Seq((7, 2) -> (5, 4), (12, 4) -> (10, 6), (22, 5) -> (20, 7)) ++
+      // Increasing precision by a smaller amount than scale.
+      Seq((5, 2) -> (6, 4), (10, 4) -> (12, 7), (20, 5) -> (22, 8))
+  }
+  test(s"parquet decimal precision and scale change Decimal($fromPrecision, 
$fromScale) -> " +
+    s"Decimal($toPrecision, $toScale)"
+  ) {
+    checkAllParquetReaders(
+      values = Seq("1.23", "10.34"),
+      fromType = DecimalType(fromPrecision, fromScale),
+      toType = DecimalType(toPrecision, toScale),
+      expectError =
+        (toScale < fromScale || toPrecision - toScale < fromPrecision - fromScale) &&
+          // parquet-mr allows reading decimals into a smaller precision decimal type without
+          // checking for overflows. See test below checking for the overflow case in parquet-mr.
+          spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean
+    )
+  }
+
   test("parquet decimal type change Decimal(5, 2) -> Decimal(3, 2) overflows 
with parquet-mr") {
     withTempDir { dir =>
       withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {

