This is an automated email from the ASF dual-hosted git repository.

etudenhoefner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new 40fbd8dc5c Spark 3.3: Move the Writer to a visitor (#9672)
40fbd8dc5c is described below

commit 40fbd8dc5cbb47ba2f2bba0a771bb7c0a0d50a10
Author: Fokko Driesprong <[email protected]>
AuthorDate: Wed Feb 7 09:43:41 2024 +0100

    Spark 3.3: Move the Writer to a visitor (#9672)
    
    Backport of https://github.com/apache/iceberg/pull/9440
---
 .../iceberg/spark/data/SparkParquetWriters.java    | 155 ++++++++++++++++-----
 1 file changed, 118 insertions(+), 37 deletions(-)

diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
@@ -24,6 +24,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
+import java.util.Optional;
 import java.util.UUID;
 import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry;
 import org.apache.iceberg.parquet.ParquetValueWriter;
@@ -136,46 +137,126 @@ public class SparkParquetWriters {
       return ParquetValueWriters.option(fieldType, maxD, writer);
     }
 
+    /**
+     * Selects a value writer for a primitive column from its logical type annotation.
+     *
+     * <p>Unsupported annotations yield an empty Optional so the caller can raise a uniform error.
+     * Integer-backed annotations pick the writer from the Spark type (see {@code ints}), so Spark
+     * ByteType and ShortType values keep their tinyint and short writers.
+     */
+    private static class LogicalTypeAnnotationParquetValueWriterVisitor
+        implements LogicalTypeAnnotation.LogicalTypeAnnotationVisitor<ParquetValueWriter<?>> {
+
+      private final ColumnDescriptor desc;
+      private final PrimitiveType primitive;
+      private final DataType sType;
+
+      LogicalTypeAnnotationParquetValueWriterVisitor(
+          ColumnDescriptor desc, PrimitiveType primitive, DataType sType) {
+        this.desc = desc;
+        this.primitive = primitive;
+        this.sType = sType;
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.StringLogicalTypeAnnotation stringLogicalType) {
+        return Optional.of(utf8Strings(desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.EnumLogicalTypeAnnotation enumLogicalType) {
+        return Optional.of(utf8Strings(desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.JsonLogicalTypeAnnotation jsonLogicalType) {
+        return Optional.of(utf8Strings(desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.UUIDLogicalTypeAnnotation uuidLogicalType) {
+        return Optional.of(uuids(desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(DecimalLogicalTypeAnnotation decimal) {
+        switch (primitive.getPrimitiveTypeName()) {
+          case INT32:
+            return Optional.of(decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale()));
+          case INT64:
+            return Optional.of(decimalAsLong(desc, decimal.getPrecision(), decimal.getScale()));
+          case BINARY:
+          case FIXED_LEN_BYTE_ARRAY:
+            return Optional.of(decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale()));
+          default:
+            throw new UnsupportedOperationException(
+                "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName());
+        }
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.DateLogicalTypeAnnotation dateLogicalType) {
+        return Optional.of(ints(sType, desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.TimeLogicalTypeAnnotation timeLogicalType) {
+        if (timeLogicalType.getUnit() == LogicalTypeAnnotation.TimeUnit.MICROS) {
+          return Optional.of(ParquetValueWriters.longs(desc));
+        }
+        return Optional.empty();
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.TimestampLogicalTypeAnnotation timestampLogicalType) {
+        if (timestampLogicalType.getUnit() == LogicalTypeAnnotation.TimeUnit.MICROS) {
+          return Optional.of(ParquetValueWriters.longs(desc));
+        }
+        return Optional.empty();
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.IntLogicalTypeAnnotation intLogicalType) {
+        // unsigned widths were unsupported by the original-type switch; keep reporting them
+        if (!intLogicalType.isSigned()) {
+          return Optional.empty();
+        }
+
+        if (intLogicalType.getBitWidth() == 64) {
+          return Optional.of(ParquetValueWriters.longs(desc));
+        }
+
+        // like the un-annotated INT32 case, choose by Spark type rather than by bit width
+        return Optional.of(ints(sType, desc));
+      }
+
+      @Override
+      public Optional<ParquetValueWriter<?>> visit(
+          LogicalTypeAnnotation.BsonLogicalTypeAnnotation bsonLogicalType) {
+        return Optional.of(byteArrays(desc));
+      }
+    }
+
     @Override
     public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive) {
       ColumnDescriptor desc = type.getColumnDescription(currentPath());
-
-      if (primitive.getOriginalType() != null) {
-        switch (primitive.getOriginalType()) {
-          case ENUM:
-          case JSON:
-          case UTF8:
-            return utf8Strings(desc);
-          case DATE:
-          case INT_8:
-          case INT_16:
-          case INT_32:
-            return ints(sType, desc);
-          case INT_64:
-          case TIME_MICROS:
-          case TIMESTAMP_MICROS:
-            return ParquetValueWriters.longs(desc);
-          case DECIMAL:
-            DecimalLogicalTypeAnnotation decimal =
-                (DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation();
-            switch (primitive.getPrimitiveTypeName()) {
-              case INT32:
-                return decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale());
-              case INT64:
-                return decimalAsLong(desc, decimal.getPrecision(), decimal.getScale());
-              case BINARY:
-              case FIXED_LEN_BYTE_ARRAY:
-                return decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale());
-              default:
-                throw new UnsupportedOperationException(
-                    "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName());
-            }
-          case BSON:
-            return byteArrays(desc);
-          default:
-            throw new UnsupportedOperationException(
-                "Unsupported logical type: " + primitive.getOriginalType());
-        }
+      LogicalTypeAnnotation logicalTypeAnnotation = primitive.getLogicalTypeAnnotation();
+
+      if (logicalTypeAnnotation != null) {
+        return logicalTypeAnnotation
+            .accept(new LogicalTypeAnnotationParquetValueWriterVisitor(desc, primitive, sType))
+            .orElseThrow(
+                () ->
+                    new UnsupportedOperationException(
+                        "Unsupported logical type: " + logicalTypeAnnotation));
       }
 
       switch (primitive.getPrimitiveTypeName()) {

Reply via email to