This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0356ac009472 [SPARK-40876][SQL] Widening type promotion from integers to decimal in Parquet vectorized reader
0356ac009472 is described below
commit 0356ac00947282b1a0885ad7eaae1e25e43671fe
Author: Johan Lasperas <[email protected]>
AuthorDate: Tue Jan 23 12:37:18 2024 -0800
[SPARK-40876][SQL] Widening type promotion from integers to decimal in Parquet vectorized reader
### What changes were proposed in this pull request?
This is a follow-up to https://github.com/apache/spark/pull/44368 and
https://github.com/apache/spark/pull/44513, implementing an additional type
promotion from integers to decimals in the Parquet vectorized reader, bringing
it to parity with the non-vectorized reader in that regard.
### Why are the changes needed?
This allows reading Parquet files that have different schemas and mix
decimals and integers - e.g. reading files containing either `Decimal(15, 2)`
or `INT32` as `Decimal(15, 2)` - as long as the requested decimal type is
large enough to accommodate the integer values without precision loss.
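For illustration only (not part of this PR; hypothetical paths, assuming `spark`,
`dir` and `spark.implicits._` are in scope), the mixed-file scenario looks roughly
like this:
```
// One file holds an INT32 column, the other a DECIMAL(15, 2) column.
Seq(1, 2, 3).toDF("a").write.parquet(s"$dir/ints")
Seq("1.50", "2.25").toDF("a")
  .select($"a".cast("decimal(15, 2)").as("a"))
  .write.parquet(s"$dir/decimals")

// Both files can be read back with a single requested schema: DECIMAL(15, 2)
// keeps 15 - 2 = 13 integer digits, enough for any INT32 value (needs >= 10).
spark.read.schema("a DECIMAL(15, 2)")
  .parquet(s"$dir/ints", s"$dir/decimals")
  .collect()
```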
### Does this PR introduce _any_ user-facing change?
Yes, the following now succeeds when using the vectorized Parquet reader:
```
Seq(20).toDF("a").write.parquet(path)  // the column is written as INT32
spark.read.schema("a decimal(12, 0)").parquet(path).collect()
```
Before this change, this failed with the vectorized reader but succeeded with
the non-vectorized reader.
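As a hedged sketch of the rule applied (derived from `isDecimalTypeMatched` in the
diff below, with hypothetical `path` and `spark.implicits._` in scope): the requested
decimal must keep at least 10 integer digits for INT32 data and 20 for INT64,
otherwise the vectorized reader still rejects the read:
```
Seq(20).toDF("a").write.parquet(path)  // written as INT32

// 12 - 0 = 12 integer digits >= 10, so any INT32 value fits: succeeds.
spark.read.schema("a DECIMAL(12, 0)").parquet(path).collect()

// 12 - 5 = 7 integer digits < 10: still rejected by the vectorized reader.
spark.read.schema("a DECIMAL(12, 5)").parquet(path).collect()
```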
### How was this patch tested?
- Tests added to `ParquetTypeWideningSuite`
- Updated the relevant `ParquetQuerySuite` test.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #44803 from johanl-db/SPARK-40876-widening-promotion-int-to-decimal.
Authored-by: Johan Lasperas <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../parquet/ParquetVectorUpdaterFactory.java | 39 ++++++-
.../parquet/VectorizedColumnReader.java | 7 +-
.../datasources/parquet/ParquetQuerySuite.scala | 8 +-
.../parquet/ParquetTypeWideningSuite.scala | 123 ++++++++++++++++++---
4 files changed, 150 insertions(+), 27 deletions(-)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 0d8713b58cec..f369688597b9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -1407,7 +1407,11 @@ public class ParquetVectorUpdaterFactory {
super(sparkType);
LogicalTypeAnnotation typeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
- this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+ if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+ this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+ } else {
+ this.parquetScale = 0;
+ }
}
@Override
@@ -1436,14 +1440,18 @@ public class ParquetVectorUpdaterFactory {
}
}
-private static class LongToDecimalUpdater extends DecimalUpdater {
+ private static class LongToDecimalUpdater extends DecimalUpdater {
private final int parquetScale;
- LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+ LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
super(sparkType);
LogicalTypeAnnotation typeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
- this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+ if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+ this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+ } else {
+ this.parquetScale = 0;
+ }
}
@Override
@@ -1641,6 +1649,12 @@ private static class FixedLenByteArrayToDecimalUpdater extends DecimalUpdater {
return typeAnnotation instanceof DateLogicalTypeAnnotation;
}
+ private static boolean isSignedIntAnnotation(LogicalTypeAnnotation typeAnnotation) {
+ if (!(typeAnnotation instanceof IntLogicalTypeAnnotation)) return false;
+ IntLogicalTypeAnnotation intAnnotation = (IntLogicalTypeAnnotation) typeAnnotation;
+ return intAnnotation.isSigned();
+ }
+
private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
DecimalType requestedType = (DecimalType) dt;
LogicalTypeAnnotation typeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
@@ -1652,6 +1666,20 @@ private static class FixedLenByteArrayToDecimalUpdater extends DecimalUpdater {
int scaleIncrease = requestedType.scale() - parquetType.getScale();
int precisionIncrease = requestedType.precision() - parquetType.getPrecision();
return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
+ } else if (typeAnnotation == null || isSignedIntAnnotation(typeAnnotation)) {
+ // Allow reading signed integers (which may be un-annotated) as decimal as long as the
+ // requested decimal type is large enough to represent all possible values.
+ PrimitiveType.PrimitiveTypeName typeName = descriptor.getPrimitiveType().getPrimitiveTypeName();
+ int integerPrecision = requestedType.precision() - requestedType.scale();
+ switch (typeName) {
+ case INT32:
+ return integerPrecision >= DecimalType$.MODULE$.IntDecimal().precision();
+ case INT64:
+ return integerPrecision >= DecimalType$.MODULE$.LongDecimal().precision();
+ default:
+ return false;
+ }
}
return false;
}
@@ -1662,6 +1690,9 @@ private static class FixedLenByteArrayToDecimalUpdater extends DecimalUpdater {
if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
DecimalLogicalTypeAnnotation decimalType =
(DecimalLogicalTypeAnnotation) typeAnnotation;
return decimalType.getScale() == d.scale();
+ } else if (typeAnnotation == null || isSignedIntAnnotation(typeAnnotation)) {
+ // Consider signed integers (which may be un-annotated) as having scale 0.
+ return d.scale() == 0;
}
return false;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index d580023bc877..731c78cf9450 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -153,10 +153,9 @@ public class VectorizedColumnReader {
// rebasing.
switch (typeName) {
case INT32: {
- boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
- boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+ boolean isDecimal = sparkType instanceof DecimalType;
boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
- (isDate && sparkType == TimestampNTZType) ||
+ sparkType == TimestampNTZType ||
(isDecimal && !DecimalType.is32BitDecimalType(sparkType));
boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
!"CORRECTED".equals(datetimeRebaseMode);
@@ -164,7 +163,7 @@ public class VectorizedColumnReader {
break;
}
case INT64: {
- boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+ boolean isDecimal = sparkType instanceof DecimalType;
boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index b306a526818e..b8a6cb5d0712 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -1037,8 +1037,10 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
withAllParquetReaders {
// We can read the decimal parquet field with a larger precision, if scale is the same.
- val schema = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
- checkAnswer(readParquet(schema, path), df)
+ val schema1 = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
+ checkAnswer(readParquet(schema1, path), df)
+ val schema2 = "a DECIMAL(18, 1), b DECIMAL(38, 2), c DECIMAL(38, 2)"
+ checkAnswer(readParquet(schema2, path), df)
}
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
@@ -1067,10 +1069,12 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
+ checkAnswer(readParquet("a DECIMAL(11, 2)", path), sql("SELECT 1.00"))
checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT
123456.0"))
checkAnswer(readParquet("c DECIMAL(11, 1)", path), Row(null))
checkAnswer(readParquet("c DECIMAL(13, 0)", path), df.select("c"))
+ checkAnswer(readParquet("c DECIMAL(22, 0)", path), df.select("c"))
val e = intercept[SparkException] {
readParquet("d DECIMAL(3, 2)", path).collect()
}.getCause
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
index 7b8357e20774..6302c2703619 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.io.File
import org.apache.hadoop.fs.Path
+import org.apache.parquet.column.{Encoding, ParquetProperties}
import org.apache.parquet.format.converter.ParquetMetadataConverter
import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat}
@@ -31,6 +32,7 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.DecimalType.{ByteDecimal, IntDecimal, LongDecimal, ShortDecimal}
class ParquetTypeWideningSuite
extends QueryTest
@@ -121,6 +123,19 @@ class ParquetTypeWideningSuite
if (dictionaryEnabled && !DecimalType.isByteArrayDecimalType(dataType)) {
assertAllParquetFilesDictionaryEncoded(dir)
}
+
+ // Check which encoding was used when writing Parquet V2 files.
+ val isParquetV2 = spark.conf.getOption(ParquetOutputFormat.WRITER_VERSION)
+ .contains(ParquetProperties.WriterVersion.PARQUET_2_0.toString)
+ if (isParquetV2) {
+ if (dictionaryEnabled) {
+ assertParquetV2Encoding(dir, Encoding.PLAIN)
+ } else if (DecimalType.is64BitDecimalType(dataType)) {
+ assertParquetV2Encoding(dir, Encoding.DELTA_BINARY_PACKED)
+ } else if (DecimalType.isByteArrayDecimalType(dataType)) {
+ assertParquetV2Encoding(dir, Encoding.DELTA_BYTE_ARRAY)
+ }
+ }
df
}
@@ -145,6 +160,27 @@ class ParquetTypeWideningSuite
}
}
+ /**
+ * Asserts that all parquet files in the given directory have all their columns encoded with the
+ * given encoding.
+ */
+ private def assertParquetV2Encoding(dir: File, expected_encoding: Encoding): Unit = {
+ dir.listFiles(_.getName.endsWith(".parquet")).foreach { file =>
+ val parquetMetadata = ParquetFileReader.readFooter(
+ spark.sessionState.newHadoopConf(),
+ new Path(dir.toString, file.getName),
+ ParquetMetadataConverter.NO_FILTER)
+ parquetMetadata.getBlocks.forEach { block =>
+ block.getColumns.forEach { col =>
+ assert(
+ col.getEncodings.contains(expected_encoding),
+ s"Expected column '${col.getPath.toDotString}' to use encoding
$expected_encoding " +
+ s"but found ${col.getEncodings}.")
+ }
+ }
+ }
+ }
+
for {
(values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
(Seq("1", "2", Short.MinValue.toString), ShortType, IntegerType),
@@ -157,24 +193,77 @@ class ParquetTypeWideningSuite
(Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType,
TimestampNTZType)
)
}
- test(s"parquet widening conversion $fromType -> $toType") {
- checkAllParquetReaders(values, fromType, toType, expectError = false)
- }
+ test(s"parquet widening conversion $fromType -> $toType") {
+ checkAllParquetReaders(values, fromType, toType, expectError = false)
+ }
+
+ for {
+ (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+ (Seq("1", Byte.MaxValue.toString), ByteType, IntDecimal),
+ (Seq("1", Byte.MaxValue.toString), ByteType, LongDecimal),
+ (Seq("1", Short.MaxValue.toString), ShortType, IntDecimal),
+ (Seq("1", Short.MaxValue.toString), ShortType, LongDecimal),
+ (Seq("1", Short.MaxValue.toString), ShortType,
DecimalType(DecimalType.MAX_PRECISION, 0)),
+ (Seq("1", Int.MaxValue.toString), IntegerType, IntDecimal),
+ (Seq("1", Int.MaxValue.toString), IntegerType, LongDecimal),
+ (Seq("1", Int.MaxValue.toString), IntegerType,
DecimalType(DecimalType.MAX_PRECISION, 0)),
+ (Seq("1", Long.MaxValue.toString), LongType, LongDecimal),
+ (Seq("1", Long.MaxValue.toString), LongType,
DecimalType(DecimalType.MAX_PRECISION, 0)),
+ (Seq("1", Byte.MaxValue.toString), ByteType,
DecimalType(IntDecimal.precision + 1, 1)),
+ (Seq("1", Short.MaxValue.toString), ShortType,
DecimalType(IntDecimal.precision + 1, 1)),
+ (Seq("1", Int.MaxValue.toString), IntegerType,
DecimalType(IntDecimal.precision + 1, 1)),
+ (Seq("1", Long.MaxValue.toString), LongType,
DecimalType(LongDecimal.precision + 1, 1))
+ )
+ }
+ test(s"parquet widening conversion $fromType -> $toType") {
+ checkAllParquetReaders(values, fromType, toType, expectError = false)
+ }
for {
(values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
(Seq("1", "2", Int.MinValue.toString), LongType, IntegerType),
(Seq("1.23", "10.34"), DoubleType, FloatType),
(Seq("1.23", "10.34"), FloatType, LongType),
+ (Seq("1", "10"), LongType, DoubleType),
(Seq("1", "10"), LongType, DateType),
(Seq("1", "10"), IntegerType, TimestampType),
(Seq("1", "10"), IntegerType, TimestampNTZType),
(Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampType)
)
}
- test(s"unsupported parquet conversion $fromType -> $toType") {
- checkAllParquetReaders(values, fromType, toType, expectError = true)
- }
+ test(s"unsupported parquet conversion $fromType -> $toType") {
+ checkAllParquetReaders(values, fromType, toType, expectError = true)
+ }
+
+ for {
+ (values: Seq[String], fromType: DataType, toType: DecimalType) <- Seq(
+ // Parquet stores byte, short, int values as INT32, which then requires using a decimal that
+ // can hold at least 4 byte integers.
+ (Seq("1", "2"), ByteType, DecimalType(1, 0)),
+ (Seq("1", "2"), ByteType, ByteDecimal),
+ (Seq("1", "2"), ShortType, ByteDecimal),
+ (Seq("1", "2"), ShortType, ShortDecimal),
+ (Seq("1", "2"), IntegerType, ShortDecimal),
+ (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision + 1, 1)),
+ (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision + 1, 1)),
+ (Seq("1", "2"), LongType, IntDecimal),
+ (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision - 1, 0)),
+ (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision - 1, 0)),
+ (Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision - 1, 0)),
+ (Seq("1", "2"), LongType, DecimalType(LongDecimal.precision - 1, 0)),
+ (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision, 1)),
+ (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision, 1)),
+ (Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision, 1)),
+ (Seq("1", "2"), LongType, DecimalType(LongDecimal.precision, 1))
+ )
+ }
+ test(s"unsupported parquet conversion $fromType -> $toType") {
+ checkAllParquetReaders(values, fromType, toType,
+ expectError =
+ // parquet-mr allows reading decimals into a smaller precision decimal type without
+ // checking for overflows. See test below checking for the overflow case in parquet-mr.
+ spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+ }
for {
(values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
@@ -201,17 +290,17 @@ class ParquetTypeWideningSuite
Seq(5 -> 7, 5 -> 10, 5 -> 20, 10 -> 12, 10 -> 20, 20 -> 22) ++
Seq(7 -> 5, 10 -> 5, 20 -> 5, 12 -> 10, 20 -> 10, 22 -> 20)
}
- test(
- s"parquet decimal precision change Decimal($fromPrecision, 2) ->
Decimal($toPrecision, 2)") {
- checkAllParquetReaders(
- values = Seq("1.23", "10.34"),
- fromType = DecimalType(fromPrecision, 2),
- toType = DecimalType(toPrecision, 2),
- expectError = fromPrecision > toPrecision &&
- // parquet-mr allows reading decimals into a smaller precision decimal type without
- // checking for overflows. See test below checking for the overflow case in parquet-mr.
- spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
- }
+ test(
+ s"parquet decimal precision change Decimal($fromPrecision, 2) ->
Decimal($toPrecision, 2)") {
+ checkAllParquetReaders(
+ values = Seq("1.23", "10.34"),
+ fromType = DecimalType(fromPrecision, 2),
+ toType = DecimalType(toPrecision, 2),
+ expectError = fromPrecision > toPrecision &&
+ // parquet-mr allows reading decimals into a smaller precision decimal type without
+ // checking for overflows. See test below checking for the overflow case in parquet-mr.
+ spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+ }
for {
((fromPrecision, fromScale), (toPrecision, toScale)) <-
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]