This is an automated email from the ASF dual-hosted git repository.
etudenhoefner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/main by this push:
new 4446e4f0ab Spark 3.4: Move the Writer to a visitor (#9673)
4446e4f0ab is described below
commit 4446e4f0abe5c040712ea103594ec31ad8ce902e
Author: Fokko Driesprong <[email protected]>
AuthorDate: Wed Feb 7 09:44:23 2024 +0100
Spark 3.4: Move the Writer to a visitor (#9673)
Backport of https://github.com/apache/iceberg/pull/9440
---
.../iceberg/spark/data/SparkParquetWriters.java | 163 ++++++++++++++++-----
1 file changed, 126 insertions(+), 37 deletions(-)
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
index af6f65a089..8baea6c5ab 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
@@ -24,6 +24,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
+import java.util.Optional;
 import java.util.UUID;
 import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry;
 import org.apache.iceberg.parquet.ParquetValueWriter;
@@ -136,46 +137,134 @@ public class SparkParquetWriters {
       return ParquetValueWriters.option(fieldType, maxD, writer);
     }
 
+    /**
+     * Resolves the Parquet writer for an annotated primitive column. An empty result means the
+     * annotation (or its physical type, unit or signedness) is not supported for writing.
+     */
+    private static class LogicalTypeAnnotationParquetValueWriterVisitor
+        implements LogicalTypeAnnotation.LogicalTypeAnnotationVisitor<ParquetValueWriter<?>> {
+
+      private final ColumnDescriptor desc;
+      private final PrimitiveType primitive;
+
+      LogicalTypeAnnotationParquetValueWriterVisitor(
+          ColumnDescriptor desc, PrimitiveType primitive) {
+        this.desc = desc;
+        this.primitive = primitive;
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.StringLogicalTypeAnnotation stringLogicalType) {
+        return Optional.of(utf8Strings(desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.EnumLogicalTypeAnnotation enumLogicalType) {
+        return Optional.of(utf8Strings(desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.JsonLogicalTypeAnnotation jsonLogicalType) {
+        return Optional.of(utf8Strings(desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.UUIDLogicalTypeAnnotation uuidLogicalType) {
+        return Optional.of(uuids(desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.MapLogicalTypeAnnotation mapLogicalType) {
+        return LogicalTypeAnnotation.LogicalTypeAnnotationVisitor.super.visit(mapLogicalType);
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.ListLogicalTypeAnnotation listLogicalType) {
+        return LogicalTypeAnnotation.LogicalTypeAnnotationVisitor.super.visit(listLogicalType);
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(DecimalLogicalTypeAnnotation decimal) {
+        switch (primitive.getPrimitiveTypeName()) {
+          case INT32:
+            return Optional.of(decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale()));
+          case INT64:
+            return Optional.of(decimalAsLong(desc, decimal.getPrecision(), decimal.getScale()));
+          case BINARY:
+          case FIXED_LEN_BYTE_ARRAY:
+            return Optional.of(decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale()));
+        }
+        return Optional.empty();
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.DateLogicalTypeAnnotation dateLogicalType) {
+        return Optional.of(ParquetValueWriters.ints(desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.TimeLogicalTypeAnnotation timeLogicalType) {
+        if (timeLogicalType.getUnit() == LogicalTypeAnnotation.TimeUnit.MICROS) {
+          return Optional.of(ParquetValueWriters.longs(desc));
+        }
+        return Optional.empty();
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.TimestampLogicalTypeAnnotation timestampLogicalType) {
+        if (timestampLogicalType.getUnit() == LogicalTypeAnnotation.TimeUnit.MICROS) {
+          return Optional.of(ParquetValueWriters.longs(desc));
+        }
+        return Optional.empty();
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.IntLogicalTypeAnnotation intLogicalType) {
+        if (!intLogicalType.isSigned()) {
+          // Unsigned ints had no OriginalType case before this refactor, so they stay unsupported.
+          return Optional.empty();
+        }
+        int bitWidth = intLogicalType.getBitWidth();
+        if (bitWidth <= 8) {
+          return Optional.of(ParquetValueWriters.tinyints(desc));
+        } else if (bitWidth <= 16) {
+          return Optional.of(ParquetValueWriters.shorts(desc));
+        } else if (bitWidth <= 32) {
+          return Optional.of(ParquetValueWriters.ints(desc));
+        } else {
+          return Optional.of(ParquetValueWriters.longs(desc));
+        }
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.BsonLogicalTypeAnnotation bsonLogicalType) {
+        return Optional.of(byteArrays(desc));
+      }
+    }
+
     @Override
     public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive) {
       ColumnDescriptor desc = type.getColumnDescription(currentPath());
-
-      if (primitive.getOriginalType() != null) {
-        switch (primitive.getOriginalType()) {
-          case ENUM:
-          case JSON:
-          case UTF8:
-            return utf8Strings(desc);
-          case DATE:
-          case INT_8:
-          case INT_16:
-          case INT_32:
-            return ints(sType, desc);
-          case INT_64:
-          case TIME_MICROS:
-          case TIMESTAMP_MICROS:
-            return ParquetValueWriters.longs(desc);
-          case DECIMAL:
-            DecimalLogicalTypeAnnotation decimal =
-                (DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation();
-            switch (primitive.getPrimitiveTypeName()) {
-              case INT32:
-                return decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale());
-              case INT64:
-                return decimalAsLong(desc, decimal.getPrecision(), decimal.getScale());
-              case BINARY:
-              case FIXED_LEN_BYTE_ARRAY:
-                return decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale());
-              default:
-                throw new UnsupportedOperationException(
-                    "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName());
-            }
-          case BSON:
-            return byteArrays(desc);
-          default:
-            throw new UnsupportedOperationException(
-                "Unsupported logical type: " + primitive.getOriginalType());
-        }
+      LogicalTypeAnnotation logicalTypeAnnotation = primitive.getLogicalTypeAnnotation();
+
+      if (logicalTypeAnnotation != null) {
+        return logicalTypeAnnotation
+            .accept(new LogicalTypeAnnotationParquetValueWriterVisitor(desc, primitive))
+            .orElseThrow(
+                () ->
+                    new UnsupportedOperationException(
+                        "Unsupported logical type: " + primitive.getLogicalTypeAnnotation()));
       }
 
       switch (primitive.getPrimitiveTypeName()) {