openinx commented on a change in pull request #1272:
URL: https://github.com/apache/iceberg/pull/1272#discussion_r476050599



##########
File path: flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetWriters.java
##########
@@ -19,38 +19,436 @@
 
 package org.apache.iceberg.flink.data;
 
+import java.util.Iterator;
 import java.util.List;
-import org.apache.flink.types.Row;
-import org.apache.iceberg.data.parquet.BaseParquetWriter;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.DecimalData;
+import org.apache.flink.table.data.MapData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.TimestampData;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.RowType.RowField;
+import org.apache.flink.table.types.logical.SmallIntType;
+import org.apache.flink.table.types.logical.TinyIntType;
+import org.apache.iceberg.parquet.ParquetValueReaders;
 import org.apache.iceberg.parquet.ParquetValueWriter;
 import org.apache.iceberg.parquet.ParquetValueWriters;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.types.TypeUtil;
+import org.apache.iceberg.util.DecimalUtil;
+import org.apache.parquet.column.ColumnDescriptor;
+import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.schema.GroupType;
+import 
org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
 import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
 
-public class FlinkParquetWriters extends BaseParquetWriter<Row> {
+public class FlinkParquetWriters {
+  private FlinkParquetWriters() {
+    // Static utility holder: every member used by callers is static, so this is never instantiated.
+  }
 
-  private static final FlinkParquetWriters INSTANCE = new 
FlinkParquetWriters();
+  @SuppressWarnings("unchecked")
+  public static <T> ParquetValueWriter<T> buildWriter(LogicalType schema, 
MessageType type) {
+    return (ParquetValueWriter<T>) ParquetWithFlinkSchemaVisitor.visit(schema, 
type, new WriteBuilder(type));
+  }
 
-  private FlinkParquetWriters() {
+  /**
+   * Schema visitor that walks a Flink {@link LogicalType} alongside the Parquet {@link MessageType} and
+   * assembles the matching tree of {@link ParquetValueWriter}s.
+   */
+  private static class WriteBuilder extends ParquetWithFlinkSchemaVisitor<ParquetValueWriter<?>> {
+    // Full Parquet file schema; used to resolve max definition/repetition levels and column descriptors.
+    private final MessageType type;
+
+    WriteBuilder(MessageType type) {
+      this.type = type;
+    }
+
+    @Override
+    public ParquetValueWriter<?> message(RowType sStruct, MessageType message, List<ParquetValueWriter<?>> fields) {
+      // The top-level message is written exactly like any other struct.
+      return struct(sStruct, message.asGroupType(), fields);
+    }
+
+    @Override
+    public ParquetValueWriter<?> struct(RowType sStruct, GroupType struct,
+                                        List<ParquetValueWriter<?>> fieldWriters) {
+      List<Type> fields = struct.getFields();
+      List<RowField> flinkFields = sStruct.getFields();
+      List<ParquetValueWriter<?>> writers = Lists.newArrayListWithExpectedSize(fieldWriters.size());
+      List<LogicalType> flinkTypes = Lists.newArrayListWithExpectedSize(fields.size());
+      for (int i = 0; i < fields.size(); i += 1) {
+        // Each field writer is wrapped by newOption, which supplies the field's max definition level.
+        writers.add(newOption(fields.get(i), fieldWriters.get(i)));
+        flinkTypes.add(flinkFields.get(i).getType());
+      }
+
+      return new RowDataWriter(writers, flinkTypes);
+    }
+
+    @Override
+    public ParquetValueWriter<?> list(ArrayType sArray, GroupType array, ParquetValueWriter<?> elementWriter) {
+      // The list group's first child is the repeated group; that group's first child is the element.
+      GroupType repeated = array.getFields().get(0).asGroupType();
+      String[] repeatedPath = currentPath();
+
+      int repeatedD = type.getMaxDefinitionLevel(repeatedPath);
+      int repeatedR = type.getMaxRepetitionLevel(repeatedPath);
+
+      return new ArrayDataWriter<>(repeatedD, repeatedR,
+          newOption(repeated.getType(0), elementWriter),
+          sArray.getElementType());
+    }
+
+    @Override
+    public ParquetValueWriter<?> map(MapType sMap, GroupType map,
+                                     ParquetValueWriter<?> keyWriter, ParquetValueWriter<?> valueWriter) {
+      // The map group's first child is the repeated group holding the key (index 0) and value (index 1).
+      GroupType repeatedKeyValue = map.getFields().get(0).asGroupType();
+      String[] repeatedPath = currentPath();
+
+      int repeatedD = type.getMaxDefinitionLevel(repeatedPath);
+      int repeatedR = type.getMaxRepetitionLevel(repeatedPath);
+
+      return new MapDataWriter<>(repeatedD, repeatedR,
+          newOption(repeatedKeyValue.getType(0), keyWriter),
+          newOption(repeatedKeyValue.getType(1), valueWriter),
+          sMap.getKeyType(), sMap.getValueType());
+    }
+
+    // Wraps a writer with ParquetValueWriters.option, passing the field's max definition level in the file schema.
+    private ParquetValueWriter<?> newOption(Type fieldType, ParquetValueWriter<?> writer) {
+      int maxD = type.getMaxDefinitionLevel(path(fieldType.getName()));
+      return ParquetValueWriters.option(fieldType, maxD, writer);
+    }
+
+    /**
+     * Chooses a leaf writer for a Parquet primitive column, dispatching first on its original-type
+     * annotation (when present) and otherwise on its physical type.
+     */
+    @Override
+    public ParquetValueWriter<?> primitive(LogicalType sType, PrimitiveType primitive) {
+      ColumnDescriptor desc = type.getColumnDescription(currentPath());
+
+      if (primitive.getOriginalType() != null) {
+        switch (primitive.getOriginalType()) {
+          case ENUM:
+          case JSON:
+          case UTF8:
+            return strings(desc);
+          case DATE:
+          case INT_8:
+          case INT_16:
+          case INT_32:
+            return ints(sType, desc);
+          case INT_64:
+            return ParquetValueWriters.longs(desc);
+          case TIME_MICROS:
+            return timeMicros(desc);
+          case TIMESTAMP_MICROS:
+            return timestamps(desc);
+          case DECIMAL:
+            return decimalWriter(desc, primitive);
+          case BSON:
+            return byteArrays(desc);
+          default:
+            throw new UnsupportedOperationException(
+                "Unsupported logical type: " + primitive.getOriginalType());
+        }
+      }
+
+      switch (primitive.getPrimitiveTypeName()) {
+        case FIXED_LEN_BYTE_ARRAY:
+        case BINARY:
+          return byteArrays(desc);
+        case BOOLEAN:
+          return ParquetValueWriters.booleans(desc);
+        case INT32:
+          return ints(sType, desc);
+        case INT64:
+          return ParquetValueWriters.longs(desc);
+        case FLOAT:
+          return ParquetValueWriters.floats(desc);
+        case DOUBLE:
+          return ParquetValueWriters.doubles(desc);
+        default:
+          throw new UnsupportedOperationException("Unsupported type: " + primitive);
+      }
+    }
+
+    // Selects the DECIMAL writer that matches the physical type backing the decimal annotation.
+    private ParquetValueWriter<?> decimalWriter(ColumnDescriptor desc, PrimitiveType primitive) {
+      DecimalLogicalTypeAnnotation decimal = (DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation();
+      switch (primitive.getPrimitiveTypeName()) {
+        case INT32:
+          return decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale());
+        case INT64:
+          return decimalAsLong(desc, decimal.getPrecision(), decimal.getScale());
+        case BINARY:
+        case FIXED_LEN_BYTE_ARRAY:
+          return decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale());
+        default:
+          throw new UnsupportedOperationException(
+              "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName());
+      }
+    }
+  }
+
+  /**
+   * Picks the INT32-backed writer matching the Flink type: shorts for SMALLINT, tinyints for TINYINT,
+   * and plain ints for anything else (INT, DATE, ...).
+   */
+  private static ParquetValueWriters.PrimitiveWriter<?> ints(LogicalType type, ColumnDescriptor desc) {
+    if (type instanceof SmallIntType) {
+      return ParquetValueWriters.shorts(desc);
+    }
+    if (type instanceof TinyIntType) {
+      return ParquetValueWriters.tinyints(desc);
+    }
+    return ParquetValueWriters.ints(desc);
+  }
+
+  /** Writer for ENUM/JSON/UTF8-annotated columns, fed with Flink {@link StringData}. */
+  private static ParquetValueWriters.PrimitiveWriter<StringData> strings(ColumnDescriptor desc) {
+    return new StringDataWriter(desc);
+  }
+
+  /** Writer for TIME_MICROS columns, fed with Flink TIME values held as {@link Integer}. */
+  private static ParquetValueWriters.PrimitiveWriter<Integer> timeMicros(ColumnDescriptor desc) {
+    return new TimeMicrosWriter(desc);
+  }
+
+  /** Writer for DECIMAL columns whose physical type is INT32. */
+  private static ParquetValueWriters.PrimitiveWriter<DecimalData> decimalAsInteger(
+      ColumnDescriptor desc, int precision, int scale) {
+    return new IntegerDecimalWriter(desc, precision, scale);
+  }
+  /** Writer for DECIMAL columns whose physical type is INT64. */
+  private static ParquetValueWriters.PrimitiveWriter<DecimalData> decimalAsLong(
+      ColumnDescriptor desc, int precision, int scale) {
+    return new LongDecimalWriter(desc, precision, scale);
+  }
+
+  /** Writer for DECIMAL columns whose physical type is BINARY or FIXED_LEN_BYTE_ARRAY. */
+  private static ParquetValueWriters.PrimitiveWriter<DecimalData> decimalAsFixed(
+      ColumnDescriptor desc, int precision, int scale) {
+    return new FixedDecimalWriter(desc, precision, scale);
+  }
+
+  /** Writer for TIMESTAMP_MICROS columns, fed with Flink {@link TimestampData}. */
+  private static ParquetValueWriters.PrimitiveWriter<TimestampData> timestamps(ColumnDescriptor desc) {
+    return new TimestampDataWriter(desc);
+  }
+
+  /** Writer for raw binary columns (BSON, BINARY, FIXED_LEN_BYTE_ARRAY), fed with {@code byte[]}. */
+  private static ParquetValueWriters.PrimitiveWriter<byte[]> byteArrays(ColumnDescriptor desc) {
+    return new ByteArrayWriter(desc);
+  }
+
+  private static class StringDataWriter extends 
ParquetValueWriters.PrimitiveWriter<StringData> {
+    private StringDataWriter(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public void write(int repetitionLevel, StringData value) {
+      column.writeBinary(repetitionLevel, 
Binary.fromReusedByteArray(value.toBytes()));
+    }
+  }
+
+  /**
+   * Writes a Flink TIME value, passed as an {@link Integer} and treated as milliseconds, to a long column
+   * after scaling it by 1000 to microseconds.
+   */
+  private static class TimeMicrosWriter extends ParquetValueWriters.PrimitiveWriter<Integer> {
+    private TimeMicrosWriter(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public void write(int repetitionLevel, Integer value) {
+      // Widen to long before multiplying; avoids the intermediate Long box that Long.valueOf(value) created
+      // on every write. A null value still fails with NullPointerException, as before.
+      long micros = value.longValue() * 1000L;
+      column.writeLong(repetitionLevel, micros);
+    }
+  }
+
+  private static class IntegerDecimalWriter extends 
ParquetValueWriters.PrimitiveWriter<DecimalData> {
+    private final int precision;
+    private final int scale;
+
+    private IntegerDecimalWriter(ColumnDescriptor desc, int precision, int 
scale) {
+      super(desc);
+      this.precision = precision;
+      this.scale = scale;
+    }
+
+    @Override
+    public void write(int repetitionLevel, DecimalData decimal) {
+      Preconditions.checkArgument(decimal.scale() == scale,
+          "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, 
scale, decimal);
+      Preconditions.checkArgument(decimal.precision() <= precision,
+          "Cannot write value as decimal(%s,%s), too large: %s", precision, 
scale, decimal);
+
+      column.writeInteger(repetitionLevel, (int) decimal.toUnscaledLong());
+    }
+  }
+
+  private static class LongDecimalWriter extends 
ParquetValueWriters.PrimitiveWriter<DecimalData> {
+    private final int precision;
+    private final int scale;
+
+    private LongDecimalWriter(ColumnDescriptor desc, int precision, int scale) 
{
+      super(desc);
+      this.precision = precision;
+      this.scale = scale;
+    }
+
+    @Override
+    public void write(int repetitionLevel, DecimalData decimal) {
+      Preconditions.checkArgument(decimal.scale() == scale,
+          "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, 
scale, decimal);
+      Preconditions.checkArgument(decimal.precision() <= precision,
+          "Cannot write value as decimal(%s,%s), too large: %s", precision, 
scale, decimal);
+
+      column.writeLong(repetitionLevel, decimal.toUnscaledLong());
+    }
+  }
+
+  private static class FixedDecimalWriter extends 
ParquetValueWriters.PrimitiveWriter<DecimalData> {
+    private final int precision;
+    private final int scale;
+    private final ThreadLocal<byte[]> bytes;
+
+    private FixedDecimalWriter(ColumnDescriptor desc, int precision, int 
scale) {
+      super(desc);
+      this.precision = precision;
+      this.scale = scale;
+      this.bytes = ThreadLocal.withInitial(() -> new 
byte[TypeUtil.decimalRequiredBytes(precision)]);
+    }
+
+    @Override
+    public void write(int repetitionLevel, DecimalData decimal) {
+      byte[] binary = DecimalUtil.toReusedFixLengthBytes(precision, scale, 
decimal.toBigDecimal(), bytes.get());
+      column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(binary));
+    }
+  }
+
+  private static class TimestampDataWriter extends 
ParquetValueWriters.PrimitiveWriter<TimestampData> {
+    private TimestampDataWriter(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public void write(int repetitionLevel, TimestampData value) {
+      column.writeLong(repetitionLevel, value.getMillisecond() * 1000 + 
value.getNanoOfMillisecond() / 1000);
+    }
+  }
+
+  private static class ByteArrayWriter extends 
ParquetValueWriters.PrimitiveWriter<byte[]> {
+    private ByteArrayWriter(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public void write(int repetitionLevel, byte[] bytes) {
+      column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(bytes));
+    }
   }
 
-  public static ParquetValueWriter<Row> buildWriter(MessageType type) {
-    return INSTANCE.createWriter(type);
+  private static class ArrayDataWriter<E> extends 
ParquetValueWriters.RepeatedWriter<ArrayData, E> {
+    private final LogicalType elementType;
+
+    private ArrayDataWriter(int definitionLevel, int repetitionLevel,
+                            ParquetValueWriter<E> writer, LogicalType 
elementType) {
+      super(definitionLevel, repetitionLevel, writer);
+      this.elementType = elementType;
+    }
+
+    @Override
+    protected Iterator<E> elements(ArrayData list) {
+      return new ElementIterator<>(list);
+    }
+
+    private class ElementIterator<E> implements Iterator<E> {
+      private final int size;
+      private final ArrayData list;
+      private int index;
+
+      private ElementIterator(ArrayData list) {
+        this.list = list;
+        size = list.size();
+        index = 0;
+      }
+
+      @Override
+      public boolean hasNext() {
+        return index != size;
+      }
+
+      @Override
+      @SuppressWarnings("unchecked")
+      public E next() {
+        if (index >= size) {
+          throw new NoSuchElementException();
+        }
+
+        E element;
+        if (list.isNullAt(index)) {
+          element = null;
+        } else {
+          element = (E) 
ArrayData.createElementGetter(elementType).getElementOrNull(list, index);

Review comment:
       > That means this getter should be created in the constructor and stored as an instance field. Then it can be called here.
   
   Yeah, that sounds good to me. Great point.
   
   > does this need to call getElementOrNull or should it just call a get variant that assumes the value is non-null?

   The `ElementGetter` returned by `ArrayData.createElementGetter` doesn't have a `get` method; the interface only declares `getElementOrNull`:
   
   ```java
        /**
         * Accessor for getting the elements of an array during runtime.
         *
         * @see #createElementGetter(LogicalType)
         */
        interface ElementGetter extends Serializable {
                @Nullable Object getElementOrNull(ArrayData array, int pos);
        }
   ```
   
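   Replacing the `if-else` with `E element = (E) getter.getElementOrNull(list, index);` sounds reasonable to me: `getElementOrNull` already returns `null` for a null position, so the explicit `isNullAt` check becomes redundant.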
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to