This is an automated email from the ASF dual-hosted git repository.

czweng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 3ae1414128 [core] [fix] Partial update should not return empty row 
with non-null column type exist (#5077)
3ae1414128 is described below

commit 3ae1414128afc7c1303ceaefa01064cc9b6f9343
Author: YeJunHao <[email protected]>
AuthorDate: Tue Feb 25 11:09:29 2025 +0800

    [core] [fix] Partial update should not return empty row with non-null 
column type exist (#5077)
    
    This closes #5077.
---
 .../compact/PartialUpdateMergeFunction.java        |  62 ++++++-
 .../compact/PartialUpdateMergeFunctionTest.java    |  25 +++
 .../paimon/table/PartialUpdateTableTest.java       | 132 ++++++++++++++
 .../parquet/writer/ParquetRowDataWriter.java       | 193 +++++++++++++++------
 4 files changed, 350 insertions(+), 62 deletions(-)

diff --git 
a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java
 
b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java
index 2497de0893..faddc8c4ec 100644
--- 
a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java
+++ 
b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java
@@ -31,6 +31,7 @@ import org.apache.paimon.types.DataField;
 import org.apache.paimon.types.DataType;
 import org.apache.paimon.types.RowKind;
 import org.apache.paimon.types.RowType;
+import org.apache.paimon.utils.ArrayUtils;
 import org.apache.paimon.utils.FieldsComparator;
 import org.apache.paimon.utils.Preconditions;
 import org.apache.paimon.utils.Projection;
@@ -73,12 +74,20 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
     private final Map<Integer, FieldAggregator> fieldAggregators;
     private final boolean removeRecordOnDelete;
     private final Set<Integer> sequenceGroupPartialDelete;
+    private final boolean[] nullables;
 
     private InternalRow currentKey;
     private long latestSequenceNumber;
     private GenericRow row;
     private KeyValue reused;
     private boolean currentDeleteRow;
+    private boolean notNullColumnFilled;
+    /**
+     * Whether a non-retract record has been merged for the current key. If only retract records
+     * were received, the result row kind must be RowKind.DELETE, because the sequence-group
+     * retract path may leave currentDeleteRow unset when no insert record has been seen.
+     */
+    private boolean meetInsert;
 
     protected PartialUpdateMergeFunction(
             InternalRow.FieldGetter[] getters,
@@ -87,7 +96,8 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
             Map<Integer, FieldAggregator> fieldAggregators,
             boolean fieldSequenceEnabled,
             boolean removeRecordOnDelete,
-            Set<Integer> sequenceGroupPartialDelete) {
+            Set<Integer> sequenceGroupPartialDelete,
+            boolean[] nullables) {
         this.getters = getters;
         this.ignoreDelete = ignoreDelete;
         this.fieldSeqComparators = fieldSeqComparators;
@@ -95,11 +105,14 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
         this.fieldSequenceEnabled = fieldSequenceEnabled;
         this.removeRecordOnDelete = removeRecordOnDelete;
         this.sequenceGroupPartialDelete = sequenceGroupPartialDelete;
+        this.nullables = nullables;
     }
 
     @Override
     public void reset() {
         this.currentKey = null;
+        this.meetInsert = false;
+        this.notNullColumnFilled = false;
         this.row = new GenericRow(getters.length);
         fieldAggregators.values().forEach(FieldAggregator::reset);
     }
@@ -109,14 +122,21 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
         // refresh key object to avoid reference overwritten
         currentKey = kv.key();
         currentDeleteRow = false;
-
         if (kv.valueKind().isRetract()) {
+
+            if (!notNullColumnFilled) {
+                initRow(row, kv.value());
+                notNullColumnFilled = true;
+            }
+
             // In 0.7- versions, the delete records might be written into data 
file even when
             // ignore-delete configured, so ignoreDelete still needs to be 
checked
             if (ignoreDelete) {
                 return;
             }
 
+            latestSequenceNumber = kv.sequenceNumber();
+
             if (fieldSequenceEnabled) {
                 retractWithSequenceGroup(kv);
                 return;
@@ -126,6 +146,7 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
                 if (kv.valueKind() == RowKind.DELETE) {
                     currentDeleteRow = true;
                     row = new GenericRow(getters.length);
+                    initRow(row, kv.value());
                 }
                 return;
             }
@@ -148,6 +169,8 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
         } else {
             updateWithSequenceGroup(kv);
         }
+        meetInsert = true;
+        notNullColumnFilled = true;
     }
 
     private void updateNonNullFields(KeyValue kv) {
@@ -155,6 +178,10 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
             Object field = getters[i].getFieldOrNull(kv.value());
             if (field != null) {
                 row.setField(i, field);
+            } else {
+                if (!nullables[i]) {
+                    throw new IllegalArgumentException("Field " + i + " can 
not be null");
+                }
             }
         }
     }
@@ -232,6 +259,7 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
                                         && 
sequenceGroupPartialDelete.contains(field)) {
                                     currentDeleteRow = true;
                                     row = new GenericRow(getters.length);
+                                    initRow(row, kv.value());
                                     return;
                                 } else {
                                     row.setField(field, 
getters[field].getFieldOrNull(kv.value()));
@@ -263,13 +291,26 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
         }
     }
 
+    private void initRow(GenericRow row, InternalRow value) {
+        for (int i = 0; i < getters.length; i++) {
+            Object field = getters[i].getFieldOrNull(value);
+            if (!nullables[i]) {
+                if (field != null) {
+                    row.setField(i, field);
+                } else {
+                    throw new IllegalArgumentException("Field " + i + " can 
not be null");
+                }
+            }
+        }
+    }
+
     @Override
     public KeyValue getResult() {
         if (reused == null) {
             reused = new KeyValue();
         }
 
-        RowKind rowKind = currentDeleteRow ? RowKind.DELETE : RowKind.INSERT;
+        RowKind rowKind = currentDeleteRow || !meetInsert ? RowKind.DELETE : 
RowKind.INSERT;
         return reused.replace(currentKey, latestSequenceNumber, rowKind, row);
     }
 
@@ -442,14 +483,19 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
                     }
                 }
 
+                List<DataType> projectedTypes = 
Projection.of(projection).project(tableTypes);
                 return new PartialUpdateMergeFunction(
-                        
createFieldGetters(Projection.of(projection).project(tableTypes)),
+                        createFieldGetters(projectedTypes),
                         ignoreDelete,
                         projectedSeqComparators,
                         projectedAggregators,
                         !fieldSeqComparators.isEmpty(),
                         removeRecordOnDelete,
-                        sequenceGroupPartialDelete);
+                        sequenceGroupPartialDelete,
+                        ArrayUtils.toPrimitiveBoolean(
+                                projectedTypes.stream()
+                                        .map(DataType::isNullable)
+                                        .toArray(Boolean[]::new)));
             } else {
                 Map<Integer, FieldsComparator> fieldSeqComparators = new 
HashMap<>();
                 this.fieldSeqComparators.forEach(
@@ -464,7 +510,11 @@ public class PartialUpdateMergeFunction implements 
MergeFunction<KeyValue> {
                         fieldAggregators,
                         !fieldSeqComparators.isEmpty(),
                         removeRecordOnDelete,
-                        sequenceGroupPartialDelete);
+                        sequenceGroupPartialDelete,
+                        ArrayUtils.toPrimitiveBoolean(
+                                rowType.getFieldTypes().stream()
+                                        .map(DataType::isNullable)
+                                        .toArray(Boolean[]::new)));
             }
         }
 
diff --git 
a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java
 
b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java
index 28625a9bf3..5e88d2758e 100644
--- 
a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java
+++ 
b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java
@@ -857,6 +857,31 @@ public class PartialUpdateMergeFunctionTest {
                                 "Must use sequence group for aggregation 
functions but not found for field f1."));
     }
 
+    @Test
+    public void testDeleteReproduceCorrectSequenceNumber() {
+        Options options = new Options();
+        options.set("partial-update.remove-record-on-delete", "true");
+        RowType rowType =
+                RowType.of(
+                        DataTypes.INT(),
+                        DataTypes.INT(),
+                        DataTypes.INT(),
+                        DataTypes.INT(),
+                        DataTypes.INT());
+
+        MergeFunctionFactory<KeyValue> factory =
+                PartialUpdateMergeFunction.factory(options, rowType, 
ImmutableList.of("f0"));
+
+        MergeFunction<KeyValue> func = factory.create();
+
+        func.reset();
+
+        add(func, RowKind.INSERT, 1, 1, 1, 1, 1);
+        add(func, RowKind.DELETE, 1, 1, 1, 1, 1);
+
+        assertThat(func.getResult().sequenceNumber()).isEqualTo(1);
+    }
+
     private void add(MergeFunction<KeyValue> function, Integer... f) {
         add(function, RowKind.INSERT, f);
     }
diff --git 
a/paimon-core/src/test/java/org/apache/paimon/table/PartialUpdateTableTest.java 
b/paimon-core/src/test/java/org/apache/paimon/table/PartialUpdateTableTest.java
new file mode 100644
index 0000000000..21d824aa99
--- /dev/null
+++ 
b/paimon-core/src/test/java/org/apache/paimon/table/PartialUpdateTableTest.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.table;
+
+import org.apache.paimon.catalog.Catalog;
+import org.apache.paimon.catalog.CatalogContext;
+import org.apache.paimon.catalog.CatalogFactory;
+import org.apache.paimon.catalog.Identifier;
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.GenericRow;
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.disk.IOManagerImpl;
+import org.apache.paimon.fs.Path;
+import org.apache.paimon.options.CatalogOptions;
+import org.apache.paimon.options.Options;
+import org.apache.paimon.reader.RecordReader;
+import org.apache.paimon.schema.Schema;
+import org.apache.paimon.table.sink.StreamTableCommit;
+import org.apache.paimon.table.sink.StreamTableWrite;
+import org.apache.paimon.table.sink.StreamWriteBuilder;
+import org.apache.paimon.table.source.ReadBuilder;
+import org.apache.paimon.table.source.TableScan;
+import org.apache.paimon.types.DataTypes;
+import org.apache.paimon.types.RowKind;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.paimon.types.DataTypesTest.assertThat;
+
+/** Test partial update table. */
+public class PartialUpdateTableTest {
+
+    @TempDir public static java.nio.file.Path tempDir;
+    private Catalog catalog;
+    private final Identifier identifier = Identifier.create("my_db", 
"my_table");
+
+    @BeforeEach
+    public void before() throws Exception {
+        Options options = new Options();
+        options.set(CatalogOptions.WAREHOUSE, new 
Path(path()).toUri().toString());
+        catalog = CatalogFactory.createCatalog(CatalogContext.create(options));
+        catalog.createDatabase("my_db", true);
+        catalog.createTable(identifier, schema(), true);
+    }
+
+    private String path() {
+        return tempDir.toString() + "/" + 
PartialUpdateTableTest.class.getSimpleName();
+    }
+
+    private static Schema schema() {
+        Schema.Builder schemaBuilder = Schema.newBuilder();
+        schemaBuilder.column("biz_no", DataTypes.INT());
+        schemaBuilder.column("customer_id", DataTypes.STRING());
+        schemaBuilder.column("payable_amount", DataTypes.INT());
+        schemaBuilder.column("g1", DataTypes.INT());
+        schemaBuilder.primaryKey("biz_no");
+        schemaBuilder.option("bucket", "1");
+        schemaBuilder.option("file.format", "parquet");
+        schemaBuilder.option("merge-engine", "partial-update");
+        schemaBuilder.option("fields.g1.sequence-group", "payable_amount");
+        schemaBuilder.option("fields.payable_amount.aggregation-function", 
"sum");
+        schemaBuilder.option("deletion-vectors.enabled", "true");
+        schemaBuilder.option("write-buffer-spillable", "true");
+        return schemaBuilder.build();
+    }
+
+    @Test
+    public void testWriteDeleteRecordWithNoInsertData() throws Exception {
+        Table table = catalog.getTable(identifier);
+        StreamWriteBuilder writeBuilder = table.newStreamWriteBuilder();
+        try (StreamTableCommit commit = writeBuilder.newCommit();
+                StreamTableWrite write = writeBuilder.newWrite()) {
+            write.withIOManager(new IOManagerImpl(tempDir.toString()));
+            for (int snapshotId = 0; snapshotId < 100; snapshotId++) {
+                int bizNo = snapshotId;
+                String customerId = String.valueOf(snapshotId);
+                int payableAmount = 1;
+                int g1 = 1;
+                write.write(
+                        GenericRow.ofKind(
+                                snapshotId == 0 || snapshotId == 10
+                                        ? RowKind.DELETE
+                                        : RowKind.INSERT,
+                                bizNo,
+                                BinaryString.fromString(customerId),
+                                payableAmount,
+                                g1));
+                commit.commit(snapshotId, write.prepareCommit(true, 
snapshotId));
+            }
+        }
+
+        ReadBuilder builder = table.newReadBuilder();
+        TableScan scan = builder.newScan();
+        TableScan.Plan plan = scan.plan();
+
+        AtomicInteger i = new AtomicInteger(0);
+        try (RecordReader<InternalRow> reader = 
builder.newRead().createReader(plan)) {
+            reader.forEachRemaining(
+                    row -> {
+                        if (i.get() == 0 || i.get() == 10) {
+                            i.incrementAndGet();
+                        }
+                        int index = i.get();
+                        assertThat(row.getInt(0)).isEqualTo(index);
+                        
assertThat(row.getString(1).toString()).isEqualTo(String.valueOf(index));
+                        assertThat(row.getInt(2)).isEqualTo(1);
+                        assertThat(row.getInt(3)).isEqualTo(1);
+                        i.incrementAndGet();
+                    });
+        }
+    }
+}
diff --git 
a/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java
 
b/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java
index 3f9f14af45..e563a36b56 100644
--- 
a/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java
+++ 
b/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java
@@ -47,6 +47,7 @@ import java.nio.ByteOrder;
 import java.util.Arrays;
 import java.util.List;
 
+import static java.lang.String.format;
 import static 
org.apache.paimon.format.parquet.ParquetSchemaConverter.computeMinBytesForDecimalPrecision;
 import static 
org.apache.paimon.format.parquet.reader.TimestampColumnReader.JULIAN_EPOCH_OFFSET_DAYS;
 import static 
org.apache.paimon.format.parquet.reader.TimestampColumnReader.MILLIS_IN_DAY;
@@ -61,7 +62,7 @@ public class ParquetRowDataWriter {
 
     public ParquetRowDataWriter(RecordConsumer recordConsumer, RowType 
rowType, GroupType schema) {
         this.recordConsumer = recordConsumer;
-        this.rowWriter = new RowWriter(rowType, schema);
+        this.rowWriter = new RowWriter(rowType, schema, false);
     }
 
     /**
@@ -80,35 +81,37 @@ public class ParquetRowDataWriter {
             switch (t.getTypeRoot()) {
                 case CHAR:
                 case VARCHAR:
-                    return new StringWriter();
+                    return new StringWriter(t.isNullable());
                 case BOOLEAN:
-                    return new BooleanWriter();
+                    return new BooleanWriter(t.isNullable());
                 case BINARY:
                 case VARBINARY:
-                    return new BinaryWriter();
+                    return new BinaryWriter(t.isNullable());
                 case DECIMAL:
                     DecimalType decimalType = (DecimalType) t;
-                    return createDecimalWriter(decimalType.getPrecision(), 
decimalType.getScale());
+                    return createDecimalWriter(
+                            decimalType.getPrecision(), 
decimalType.getScale(), t.isNullable());
                 case TINYINT:
-                    return new ByteWriter();
+                    return new ByteWriter(t.isNullable());
                 case SMALLINT:
-                    return new ShortWriter();
+                    return new ShortWriter(t.isNullable());
                 case DATE:
                 case TIME_WITHOUT_TIME_ZONE:
                 case INTEGER:
-                    return new IntWriter();
+                    return new IntWriter(t.isNullable());
                 case BIGINT:
-                    return new LongWriter();
+                    return new LongWriter(t.isNullable());
                 case FLOAT:
-                    return new FloatWriter();
+                    return new FloatWriter(t.isNullable());
                 case DOUBLE:
-                    return new DoubleWriter();
+                    return new DoubleWriter(t.isNullable());
                 case TIMESTAMP_WITHOUT_TIME_ZONE:
                     TimestampType timestampType = (TimestampType) t;
-                    return createTimestampWriter(timestampType.getPrecision());
+                    return createTimestampWriter(timestampType.getPrecision(), 
t.isNullable());
                 case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
                     LocalZonedTimestampType localZonedTimestampType = 
(LocalZonedTimestampType) t;
-                    return 
createTimestampWriter(localZonedTimestampType.getPrecision());
+                    return createTimestampWriter(
+                            localZonedTimestampType.getPrecision(), 
t.isNullable());
                 default:
                     throw new UnsupportedOperationException("Unsupported type: 
" + type);
             }
@@ -118,43 +121,63 @@ public class ParquetRowDataWriter {
 
             if (t instanceof ArrayType
                     && annotation instanceof 
LogicalTypeAnnotation.ListLogicalTypeAnnotation) {
-                return new ArrayWriter(((ArrayType) t).getElementType(), 
groupType);
+                return new ArrayWriter(((ArrayType) t).getElementType(), 
groupType, t.isNullable());
             } else if (t instanceof MapType
                     && annotation instanceof 
LogicalTypeAnnotation.MapLogicalTypeAnnotation) {
                 return new MapWriter(
-                        ((MapType) t).getKeyType(), ((MapType) 
t).getValueType(), groupType);
+                        ((MapType) t).getKeyType(),
+                        ((MapType) t).getValueType(),
+                        groupType,
+                        t.isNullable());
             } else if (t instanceof MultisetType
                     && annotation instanceof 
LogicalTypeAnnotation.MapLogicalTypeAnnotation) {
                 return new MapWriter(
-                        ((MultisetType) t).getElementType(), new 
IntType(false), groupType);
+                        ((MultisetType) t).getElementType(),
+                        new IntType(false),
+                        groupType,
+                        t.isNullable());
             } else if (t instanceof RowType && type instanceof GroupType) {
-                return new RowWriter((RowType) t, groupType);
+                return new RowWriter((RowType) t, groupType, t.isNullable());
             } else if (t instanceof VariantType && type instanceof GroupType) {
-                return new VariantWriter();
+                return new VariantWriter(t.isNullable());
             } else {
                 throw new UnsupportedOperationException("Unsupported type: " + 
type);
             }
         }
     }
 
-    private FieldWriter createTimestampWriter(int precision) {
+    private FieldWriter createTimestampWriter(int precision, boolean 
isNullable) {
         if (precision <= 3) {
-            return new TimestampMillsWriter(precision);
+            return new TimestampMillsWriter(precision, isNullable);
         } else if (precision > 6) {
-            return new TimestampInt96Writer(precision);
+            return new TimestampInt96Writer(precision, isNullable);
         } else {
-            return new TimestampMicrosWriter(precision);
+            return new TimestampMicrosWriter(precision, isNullable);
         }
     }
 
-    private interface FieldWriter {
+    private abstract static class FieldWriter {
 
-        void write(InternalRow row, int ordinal);
+        private final boolean isNullable;
 
-        void write(InternalArray arrayData, int ordinal);
+        public FieldWriter(boolean isNullable) {
+            this.isNullable = isNullable;
+        }
+
+        abstract void write(InternalRow row, int ordinal);
+
+        abstract void write(InternalArray arrayData, int ordinal);
+
+        public boolean isNullable() {
+            return isNullable;
+        }
     }
 
-    private class BooleanWriter implements FieldWriter {
+    private class BooleanWriter extends FieldWriter {
+
+        public BooleanWriter(boolean isNullable) {
+            super(isNullable);
+        }
 
         @Override
         public void write(InternalRow row, int ordinal) {
@@ -171,7 +194,11 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class ByteWriter implements FieldWriter {
+    private class ByteWriter extends FieldWriter {
+
+        public ByteWriter(boolean isNullable) {
+            super(isNullable);
+        }
 
         @Override
         public void write(InternalRow row, int ordinal) {
@@ -188,7 +215,11 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class ShortWriter implements FieldWriter {
+    private class ShortWriter extends FieldWriter {
+
+        public ShortWriter(boolean isNullable) {
+            super(isNullable);
+        }
 
         @Override
         public void write(InternalRow row, int ordinal) {
@@ -205,7 +236,11 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class LongWriter implements FieldWriter {
+    private class LongWriter extends FieldWriter {
+
+        public LongWriter(boolean isNullable) {
+            super(isNullable);
+        }
 
         @Override
         public void write(InternalRow row, int ordinal) {
@@ -222,7 +257,11 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class FloatWriter implements FieldWriter {
+    private class FloatWriter extends FieldWriter {
+
+        public FloatWriter(boolean isNullable) {
+            super(isNullable);
+        }
 
         @Override
         public void write(InternalRow row, int ordinal) {
@@ -239,7 +278,11 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class DoubleWriter implements FieldWriter {
+    private class DoubleWriter extends FieldWriter {
+
+        public DoubleWriter(boolean isNullable) {
+            super(isNullable);
+        }
 
         @Override
         public void write(InternalRow row, int ordinal) {
@@ -256,7 +299,11 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class StringWriter implements FieldWriter {
+    private class StringWriter extends FieldWriter {
+
+        public StringWriter(boolean isNullable) {
+            super(isNullable);
+        }
 
         @Override
         public void write(InternalRow row, int ordinal) {
@@ -273,7 +320,11 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class BinaryWriter implements FieldWriter {
+    private class BinaryWriter extends FieldWriter {
+
+        public BinaryWriter(boolean isNullable) {
+            super(isNullable);
+        }
 
         @Override
         public void write(InternalRow row, int ordinal) {
@@ -290,7 +341,11 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class IntWriter implements FieldWriter {
+    private class IntWriter extends FieldWriter {
+
+        public IntWriter(boolean isNullable) {
+            super(isNullable);
+        }
 
         @Override
         public void write(InternalRow row, int ordinal) {
@@ -307,11 +362,12 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class TimestampMillsWriter implements FieldWriter {
+    private class TimestampMillsWriter extends FieldWriter {
 
         private final int precision;
 
-        private TimestampMillsWriter(int precision) {
+        private TimestampMillsWriter(int precision, boolean isNullable) {
+            super(isNullable);
             checkArgument(precision <= 3);
             this.precision = precision;
         }
@@ -331,11 +387,12 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class TimestampMicrosWriter implements FieldWriter {
+    private class TimestampMicrosWriter extends FieldWriter {
 
         private final int precision;
 
-        private TimestampMicrosWriter(int precision) {
+        private TimestampMicrosWriter(int precision, boolean isNullable) {
+            super(isNullable);
             checkArgument(precision > 3);
             checkArgument(precision <= 6);
             this.precision = precision;
@@ -356,11 +413,12 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class TimestampInt96Writer implements FieldWriter {
+    private class TimestampInt96Writer extends FieldWriter {
 
         private final int precision;
 
-        private TimestampInt96Writer(int precision) {
+        private TimestampInt96Writer(int precision, boolean isNullable) {
+            super(isNullable);
             checkArgument(precision > 6);
             this.precision = precision;
         }
@@ -381,7 +439,7 @@ public class ParquetRowDataWriter {
     }
 
     /** It writes a map field to parquet, both key and value are nullable. */
-    private class MapWriter implements FieldWriter {
+    private class MapWriter extends FieldWriter {
 
         private final String repeatedGroupName;
         private final String keyName;
@@ -389,7 +447,9 @@ public class ParquetRowDataWriter {
         private final FieldWriter keyWriter;
         private final FieldWriter valueWriter;
 
-        private MapWriter(DataType keyType, DataType valueType, GroupType 
groupType) {
+        private MapWriter(
+                DataType keyType, DataType valueType, GroupType groupType, 
boolean isNullable) {
+            super(isNullable);
             // Get the internal map structure (MAP_KEY_VALUE)
             GroupType repeatedType = groupType.getType(0).asGroupType();
             this.repeatedGroupName = repeatedType.getName();
@@ -453,14 +513,14 @@ public class ParquetRowDataWriter {
     }
 
     /** It writes an array type field to parquet. */
-    private class ArrayWriter implements FieldWriter {
+    private class ArrayWriter extends FieldWriter {
 
         private final String elementName;
         private final FieldWriter elementWriter;
         private final String repeatedGroupName;
 
-        private ArrayWriter(DataType t, GroupType groupType) {
-
+        private ArrayWriter(DataType t, GroupType groupType, boolean 
isNullable) {
+            super(isNullable);
             // Get the internal array structure
             GroupType repeatedType = groupType.getType(0).asGroupType();
             this.repeatedGroupName = repeatedType.getName();
@@ -504,11 +564,12 @@ public class ParquetRowDataWriter {
     }
 
     /** It writes a row type field to parquet. */
-    private class RowWriter implements FieldWriter {
+    private class RowWriter extends FieldWriter {
         private final FieldWriter[] fieldWriters;
         private final String[] fieldNames;
 
-        public RowWriter(RowType rowType, GroupType groupType) {
+        public RowWriter(RowType rowType, GroupType groupType, boolean 
isNullable) {
+            super(isNullable);
             this.fieldNames = rowType.getFieldNames().toArray(new String[0]);
             List<DataType> fieldTypes = rowType.getFieldTypes();
             this.fieldWriters = new FieldWriter[rowType.getFieldCount()];
@@ -526,6 +587,13 @@ public class ParquetRowDataWriter {
                     recordConsumer.startField(fieldName, i);
                     writer.write(row, i);
                     recordConsumer.endField(fieldName, i);
+                } else {
+                    if (!fieldWriters[i].isNullable()) {
+                        throw new IllegalArgumentException(
+                                format(
+                                        "Parquet does not support null values 
in non-nullable fields. Field name : %s expected not null but found null",
+                                        fieldNames[i]));
+                    }
                 }
             }
         }
@@ -547,7 +615,11 @@ public class ParquetRowDataWriter {
         }
     }
 
-    private class VariantWriter implements FieldWriter {
+    private class VariantWriter extends FieldWriter {
+
+        public VariantWriter(boolean isNullable) {
+            super(isNullable);
+        }
 
         @Override
         public void write(InternalRow row, int ordinal) {
@@ -587,14 +659,18 @@ public class ParquetRowDataWriter {
         return Binary.fromConstantByteBuffer(buf);
     }
 
-    private FieldWriter createDecimalWriter(int precision, int scale) {
+    private FieldWriter createDecimalWriter(int precision, int scale, boolean 
isNullable) {
         checkArgument(
                 precision <= DecimalType.MAX_PRECISION,
                 "Decimal precision %s exceeds max precision %s",
                 precision,
                 DecimalType.MAX_PRECISION);
 
-        class Int32Writer implements FieldWriter {
+        class Int32Writer extends FieldWriter {
+
+            public Int32Writer(boolean isNullable) {
+                super(isNullable);
+            }
 
             @Override
             public void write(InternalArray arrayData, int ordinal) {
@@ -614,7 +690,11 @@ public class ParquetRowDataWriter {
             }
         }
 
-        class Int64Writer implements FieldWriter {
+        class Int64Writer extends FieldWriter {
+
+            public Int64Writer(boolean isNullable) {
+                super(isNullable);
+            }
 
             @Override
             public void write(InternalArray arrayData, int ordinal) {
@@ -634,11 +714,12 @@ public class ParquetRowDataWriter {
             }
         }
 
-        class UnscaledBytesWriter implements FieldWriter {
+        class UnscaledBytesWriter extends FieldWriter {
             private final int numBytes;
             private final byte[] decimalBuffer;
 
-            private UnscaledBytesWriter() {
+            private UnscaledBytesWriter(boolean isNullable) {
+                super(isNullable);
                 this.numBytes = computeMinBytesForDecimalPrecision(precision);
                 this.decimalBuffer = new byte[numBytes];
             }
@@ -672,11 +753,11 @@ public class ParquetRowDataWriter {
         }
 
         if (ParquetSchemaConverter.is32BitDecimal(precision)) {
-            return new Int32Writer();
+            return new Int32Writer(isNullable);
         } else if (ParquetSchemaConverter.is64BitDecimal(precision)) {
-            return new Int64Writer();
+            return new Int64Writer(isNullable);
         } else {
-            return new UnscaledBytesWriter();
+            return new UnscaledBytesWriter(isNullable);
         }
     }
 }

Reply via email to