This is an automated email from the ASF dual-hosted git repository.

biyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 4aecb01b07 [spark] Support ACCEPT_ANY_SCHEMA for spark v2 write (#6281)
4aecb01b07 is described below

commit 4aecb01b079eb0d4ee4889d51ca5057b7b13f51a
Author: Kerwin Zhang <[email protected]>
AuthorDate: Tue Sep 23 22:23:24 2025 +0800

    [spark] Support ACCEPT_ANY_SCHEMA for spark v2 write (#6281)
---
 .../paimon/spark/SparkInternalRowWrapper.java      | 148 ++++++++--
 .../scala/org/apache/paimon/spark/SparkTable.scala |   3 +-
 .../paimon/spark/commands/SchemaHelper.scala       |  34 ++-
 .../apache/paimon/spark/write/PaimonV2Write.scala  |  35 ++-
 .../paimon/spark/write/PaimonV2WriteBuilder.scala  |   5 +-
 .../paimon/spark/sql/V2WriteMergeSchemaTest.scala  | 319 +++++++++++++++++++++
 6 files changed, 491 insertions(+), 53 deletions(-)
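
The change lets the Spark DataSource v2 write path accept input whose schema differs
from the table schema (extra or missing columns), so that write.merge-schema also works
for v2 writes. A minimal usage sketch, based on the scenarios exercised by the new
V2WriteMergeSchemaTest (table name and data are illustrative; assumes a configured
Paimon catalog, spark.paimon.write.use-v2-write=true, and import spark.implicits._):

    spark.sql("CREATE TABLE t (a INT, b STRING)")
    Seq((1, "1"), (2, "2")).toDF("a", "b").writeTo("t").append()

    // A frame with an extra column "c" no longer fails Spark's schema check; the column
    // is merged into the table schema when write.merge-schema is enabled.
    Seq((3, "3", 3))
      .toDF("a", "b", "c")
      .writeTo("t")
      .option("write.merge-schema", "true")
      .append()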

diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java
index 0d00495c69..7de1695af0 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java
@@ -43,6 +43,8 @@ import org.apache.spark.sql.types.TimestampType;
 
 import java.io.Serializable;
 import java.math.BigDecimal;
+import java.util.HashMap;
+import java.util.Map;
 
 /** Wrapper to fetch value from the spark internal row. */
 public class SparkInternalRowWrapper implements InternalRow, Serializable {
@@ -50,23 +52,32 @@ public class SparkInternalRowWrapper implements InternalRow, Serializable {
     private transient org.apache.spark.sql.catalyst.InternalRow internalRow;
     private final int length;
     private final int rowKindIdx;
-    private final StructType structType;
+    private final StructType tableSchema;
+    private int[] fieldIndexMap = null;
 
     public SparkInternalRowWrapper(
             org.apache.spark.sql.catalyst.InternalRow internalRow,
             int rowKindIdx,
-            StructType structType,
+            StructType tableSchema,
             int length) {
         this.internalRow = internalRow;
         this.rowKindIdx = rowKindIdx;
         this.length = length;
-        this.structType = structType;
+        this.tableSchema = tableSchema;
     }
 
-    public SparkInternalRowWrapper(int rowKindIdx, StructType structType, int length) {
+    public SparkInternalRowWrapper(int rowKindIdx, StructType tableSchema, int length) {
         this.rowKindIdx = rowKindIdx;
         this.length = length;
-        this.structType = structType;
+        this.tableSchema = tableSchema;
+    }
+
+    public SparkInternalRowWrapper(
+            int rowKindIdx, StructType tableSchema, StructType dataSchema, int length) {
+        this.rowKindIdx = rowKindIdx;
+        this.length = length;
+        this.tableSchema = tableSchema;
+        this.fieldIndexMap = buildFieldIndexMap(tableSchema, dataSchema);
     }
 
     public SparkInternalRowWrapper replace(org.apache.spark.sql.catalyst.InternalRow internalRow) {
@@ -74,6 +85,42 @@ public class SparkInternalRowWrapper implements InternalRow, Serializable {
         return this;
     }
 
+    private int[] buildFieldIndexMap(StructType schemaStruct, StructType dataSchema) {
+        int[] mapping = new int[schemaStruct.size()];
+
+        Map<String, Integer> rowFieldIndexMap = new HashMap<>();
+        for (int i = 0; i < dataSchema.size(); i++) {
+            rowFieldIndexMap.put(dataSchema.fields()[i].name(), i);
+        }
+
+        for (int i = 0; i < schemaStruct.size(); i++) {
+            String fieldName = schemaStruct.fields()[i].name();
+            Integer index = rowFieldIndexMap.get(fieldName);
+            mapping[i] = (index != null) ? index : -1;
+        }
+
+        return mapping;
+    }
+
+    private int getActualFieldPosition(int pos) {
+        if (fieldIndexMap == null) {
+            return pos;
+        } else {
+            if (pos < 0 || pos >= fieldIndexMap.length) {
+                return -1;
+            }
+            return fieldIndexMap[pos];
+        }
+    }
+
+    private int validateAndGetActualPosition(int pos) {
+        int actualPos = getActualFieldPosition(pos);
+        if (actualPos == -1) {
+            throw new ArrayIndexOutOfBoundsException("Field index out of bounds: " + pos);
+        }
+        return actualPos;
+    }
+
     @Override
     public int getFieldCount() {
         return length;
@@ -82,10 +129,12 @@ public class SparkInternalRowWrapper implements InternalRow, Serializable {
     @Override
     public RowKind getRowKind() {
         if (rowKindIdx != -1) {
-            return RowKind.fromByteValue(internalRow.getByte(rowKindIdx));
-        } else {
-            return RowKind.INSERT;
+            int actualPos = getActualFieldPosition(rowKindIdx);
+            if (actualPos != -1) {
+                return RowKind.fromByteValue(internalRow.getByte(actualPos));
+            }
         }
+        return RowKind.INSERT;
     }
 
     @Override
@@ -95,69 +144,102 @@ public class SparkInternalRowWrapper implements InternalRow, Serializable {
 
     @Override
     public boolean isNullAt(int pos) {
-        return internalRow.isNullAt(pos);
+        int actualPos = getActualFieldPosition(pos);
+        if (actualPos == -1) {
+            return true;
+        }
+        return internalRow.isNullAt(actualPos);
     }
 
     @Override
     public boolean getBoolean(int pos) {
-        return internalRow.getBoolean(pos);
+        int actualPos = validateAndGetActualPosition(pos);
+        return internalRow.getBoolean(actualPos);
     }
 
     @Override
     public byte getByte(int pos) {
-        return internalRow.getByte(pos);
+        int actualPos = validateAndGetActualPosition(pos);
+        return internalRow.getByte(actualPos);
     }
 
     @Override
     public short getShort(int pos) {
-        return internalRow.getShort(pos);
+        int actualPos = validateAndGetActualPosition(pos);
+        return internalRow.getShort(actualPos);
     }
 
     @Override
     public int getInt(int pos) {
-        return internalRow.getInt(pos);
+        int actualPos = validateAndGetActualPosition(pos);
+        return internalRow.getInt(actualPos);
     }
 
     @Override
     public long getLong(int pos) {
-        return internalRow.getLong(pos);
+        int actualPos = validateAndGetActualPosition(pos);
+        return internalRow.getLong(actualPos);
     }
 
     @Override
     public float getFloat(int pos) {
-        return internalRow.getFloat(pos);
+        int actualPos = validateAndGetActualPosition(pos);
+        return internalRow.getFloat(actualPos);
     }
 
     @Override
     public double getDouble(int pos) {
-        return internalRow.getDouble(pos);
+        int actualPos = validateAndGetActualPosition(pos);
+        return internalRow.getDouble(actualPos);
     }
 
     @Override
     public BinaryString getString(int pos) {
-        return BinaryString.fromBytes(internalRow.getUTF8String(pos).getBytes());
+        int actualPos = getActualFieldPosition(pos);
+        if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+            return null;
+        }
+        return BinaryString.fromBytes(internalRow.getUTF8String(actualPos).getBytes());
     }
 
     @Override
     public Decimal getDecimal(int pos, int precision, int scale) {
-        org.apache.spark.sql.types.Decimal decimal = internalRow.getDecimal(pos, precision, scale);
+        int actualPos = getActualFieldPosition(pos);
+        if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+            return null;
+        }
+        org.apache.spark.sql.types.Decimal decimal =
+                internalRow.getDecimal(actualPos, precision, scale);
         BigDecimal bigDecimal = decimal.toJavaBigDecimal();
         return Decimal.fromBigDecimal(bigDecimal, precision, scale);
     }
 
     @Override
     public Timestamp getTimestamp(int pos, int precision) {
-        return convertToTimestamp(structType.fields()[pos].dataType(), internalRow.getLong(pos));
+        int actualPos = getActualFieldPosition(pos);
+        if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+            return null;
+        }
+        return convertToTimestamp(
+                tableSchema.fields()[pos].dataType(), internalRow.getLong(actualPos));
     }
 
     @Override
     public byte[] getBinary(int pos) {
-        return internalRow.getBinary(pos);
+        int actualPos = getActualFieldPosition(pos);
+        if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+            return null;
+        }
+        return internalRow.getBinary(actualPos);
     }
 
     @Override
     public Variant getVariant(int pos) {
-        return SparkShimLoader.shim().toPaimonVariant(internalRow, pos);
+        int actualPos = getActualFieldPosition(pos);
+        if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+            return null;
+        }
+        return SparkShimLoader.shim().toPaimonVariant(internalRow, actualPos);
     }
 
     @Override
@@ -167,24 +249,36 @@ public class SparkInternalRowWrapper implements InternalRow, Serializable {
 
     @Override
     public InternalArray getArray(int pos) {
+        int actualPos = getActualFieldPosition(pos);
+        if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+            return null;
+        }
         return new SparkInternalArray(
-                internalRow.getArray(pos),
-                ((ArrayType) (structType.fields()[pos].dataType())).elementType());
+                internalRow.getArray(actualPos),
+                ((ArrayType) (tableSchema.fields()[pos].dataType())).elementType());
     }
 
     @Override
     public InternalMap getMap(int pos) {
-        MapType mapType = (MapType) structType.fields()[pos].dataType();
+        int actualPos = getActualFieldPosition(pos);
+        if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+            return null;
+        }
+        MapType mapType = (MapType) tableSchema.fields()[pos].dataType();
         return new SparkInternalMap(
-                internalRow.getMap(pos), mapType.keyType(), mapType.valueType());
+                internalRow.getMap(actualPos), mapType.keyType(), mapType.valueType());
     }
 
     @Override
     public InternalRow getRow(int pos, int numFields) {
+        int actualPos = getActualFieldPosition(pos);
+        if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+            return null;
+        }
         return new SparkInternalRowWrapper(
-                internalRow.getStruct(pos, numFields),
+                internalRow.getStruct(actualPos, numFields),
                 -1,
-                (StructType) structType.fields()[pos].dataType(),
+                (StructType) tableSchema.fields()[actualPos].dataType(),
                 numFields);
     }
 
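The wrapper now carries both the table (write) schema and the incoming data schema and
resolves each table-schema position to the matching field of the Spark row by name;
positions with no match read as null (primitive getters throw). A minimal sketch of the
same name-based index mapping, with illustrative names rather than the Java code above:

    // For each table-schema field, find its position in the data schema, or -1 if absent.
    def buildIndexMap(tableFields: Seq[String], dataFields: Seq[String]): Array[Int] = {
      val byName = dataFields.zipWithIndex.toMap
      tableFields.map(name => byName.getOrElse(name, -1)).toArray
    }

    // e.g. buildIndexMap(Seq("a", "b", "c"), Seq("c", "a")) == Array(1, -1, 0),
    // so column "b", absent from the data, is written as null.
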
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala
index 305a7191d8..e79e148ebb 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala
@@ -108,6 +108,7 @@ case class SparkTable(table: Table)
     )
 
     if (useV2Write) {
+      capabilities.add(TableCapability.ACCEPT_ANY_SCHEMA)
       capabilities.add(TableCapability.BATCH_WRITE)
       capabilities.add(TableCapability.OVERWRITE_DYNAMIC)
     } else {
@@ -152,7 +153,7 @@ case class SparkTable(table: Table)
       case fileStoreTable: FileStoreTable =>
         val options = Options.fromMap(info.options)
         if (useV2Write) {
-          new PaimonV2WriteBuilder(fileStoreTable, info.schema())
+          new PaimonV2WriteBuilder(fileStoreTable, info.schema(), options)
         } else {
           new PaimonWriteBuilder(fileStoreTable, options)
         }
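
Declaring TableCapability.ACCEPT_ANY_SCHEMA makes Spark skip its built-in schema
compatibility check for v2 writes and hand the query's output schema to the connector
as-is; reconciling it with the table is left to Paimon's merge-schema handling (see
SchemaHelper below). For example, with table t(a INT, b STRING) and merge-schema
enabled, a statement like the one in the new test now passes analysis and evolves the
table:

    // Extra column "c" is merged into the table schema instead of being rejected by Spark.
    spark.sql("INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, 3 AS c")
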
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala
index d66a941929..06f749b8ec 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala
@@ -40,19 +40,8 @@ private[spark] trait SchemaHelper extends WithFileStoreTable {
   override def table: FileStoreTable = newTable.getOrElse(originTable)
 
   def mergeSchema(sparkSession: SparkSession, input: DataFrame, options: Options): DataFrame = {
-    val mergeSchemaEnabled =
-      options.get(SparkConnectorOptions.MERGE_SCHEMA) || OptionUtils.writeMergeSchemaEnabled()
-    if (!mergeSchemaEnabled) {
-      return input
-    }
-
     val dataSchema = SparkSystemColumns.filterSparkSystemColumns(input.schema)
-    val allowExplicitCast = options.get(SparkConnectorOptions.EXPLICIT_CAST) || OptionUtils
-      .writeMergeSchemaExplicitCastEnabled()
-    mergeAndCommitSchema(dataSchema, allowExplicitCast)
-
-    // For case that some columns is absent in data, we still allow to write once write.merge-schema is true.
-    val newTableSchema = SparkTypeUtils.fromPaimonRowType(table.schema().logicalRowType())
+    val newTableSchema = mergeSchema(input.schema, options)
     if (!PaimonUtils.sameType(newTableSchema, dataSchema)) {
       val resolve = sparkSession.sessionState.conf.resolver
       val cols = newTableSchema.map {
@@ -68,6 +57,27 @@ private[spark] trait SchemaHelper extends WithFileStoreTable {
     }
   }
 
+  def mergeSchema(dataSchema: StructType, options: Options): StructType = {
+    val mergeSchemaEnabled =
+      options.get(SparkConnectorOptions.MERGE_SCHEMA) || OptionUtils.writeMergeSchemaEnabled()
+    if (!mergeSchemaEnabled) {
+      return dataSchema
+    }
+
+    val filteredDataSchema = SparkSystemColumns.filterSparkSystemColumns(dataSchema)
+    val allowExplicitCast = options.get(SparkConnectorOptions.EXPLICIT_CAST) || OptionUtils
+      .writeMergeSchemaExplicitCastEnabled()
+    mergeAndCommitSchema(filteredDataSchema, allowExplicitCast)
+
+    val writeSchema = SparkTypeUtils.fromPaimonRowType(table.schema().logicalRowType())
+
+    if (!PaimonUtils.sameType(writeSchema, filteredDataSchema)) {
+      writeSchema
+    } else {
+      filteredDataSchema
+    }
+  }
+
   private def mergeAndCommitSchema(dataSchema: StructType, allowExplicitCast: Boolean): Unit = {
     val dataRowType = SparkTypeUtils.toPaimonType(dataSchema).asInstanceOf[RowType]
     if (table.store().mergeSchema(dataRowType, allowExplicitCast)) {
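
The new mergeSchema(dataSchema, options) overload filters Spark system columns out of
the incoming schema, merges it into the Paimon table schema (committing the evolved
schema), and returns the table's resulting row type as the write schema when it differs
from the data schema. A rough illustration of the by-name merge outcome, assuming no
type widening or explicit casting is involved (Paimon's real merge also handles those):

    import org.apache.spark.sql.types.{StructField, StructType}

    // Illustrative only: new columns are appended, existing columns keep the table's
    // types, and columns missing from the data are later written as nulls.
    def mergeByName(tableSchema: StructType, dataSchema: StructType): StructType = {
      val existing = tableSchema.fieldNames.toSet
      val added = dataSchema.fields.filterNot(f => existing.contains(f.name))
      StructType(tableSchema.fields ++ added)
    }
    // e.g. table (a INT, b STRING) + data (d STRING, b STRING, c INT) => (a, b, d, c)
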
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala
index 8eaeffe2fc..9eaa1bf72f 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala
@@ -19,7 +19,9 @@
 package org.apache.paimon.spark.write
 
 import org.apache.paimon.CoreOptions
+import org.apache.paimon.options.Options
 import org.apache.paimon.spark.{SparkInternalRowWrapper, SparkUtils}
+import org.apache.paimon.spark.commands.SchemaHelper
 import org.apache.paimon.table.FileStoreTable
 import org.apache.paimon.table.sink.{BatchTableWrite, BatchWriteBuilder, CommitMessage, CommitMessageSerializer}
 
@@ -36,21 +38,24 @@ import scala.collection.JavaConverters._
 import scala.util.{Failure, Success, Try}
 
 class PaimonV2Write(
-    storeTable: FileStoreTable,
+    override val originTable: FileStoreTable,
     overwriteDynamic: Boolean,
     overwritePartitions: Option[Map[String, String]],
-    writeSchema: StructType
+    dataSchema: StructType,
+    options: Options
 ) extends Write
   with RequiresDistributionAndOrdering
+  with SchemaHelper
   with Logging {
 
   assert(
     !(overwriteDynamic && overwritePartitions.exists(_.nonEmpty)),
     "Cannot overwrite dynamically and by filter both")
 
-  private val table =
-    storeTable.copy(
-      Map(CoreOptions.DYNAMIC_PARTITION_OVERWRITE.key -> overwriteDynamic.toString).asJava)
+  private val writeSchema = mergeSchema(dataSchema, options)
+
+  updateTableWithOptions(
+    Map(CoreOptions.DYNAMIC_PARTITION_OVERWRITE.key -> overwriteDynamic.toString))
 
   private val writeRequirement = PaimonWriteRequirement(table)
 
@@ -66,7 +71,8 @@ class PaimonV2Write(
     ordering
   }
 
-  override def toBatch: BatchWrite = PaimonBatchWrite(table, writeSchema, overwritePartitions)
+  override def toBatch: BatchWrite =
+    PaimonBatchWrite(table, writeSchema, dataSchema, overwritePartitions)
 
   override def toString: String = {
     val overwriteDynamicStr = if (overwriteDynamic) {
@@ -86,6 +92,7 @@ class PaimonV2Write(
 private case class PaimonBatchWrite(
     table: FileStoreTable,
     writeSchema: StructType,
+    dataSchema: StructType,
     overwritePartitions: Option[Map[String, String]])
   extends BatchWrite
   with WriteHelper {
@@ -97,7 +104,7 @@ private case class PaimonBatchWrite(
   }
 
   override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
-    WriterFactory(writeSchema, batchWriteBuilder)
+    WriterFactory(writeSchema, dataSchema, batchWriteBuilder)
 
   override def useCommitCoordinator(): Boolean = false
 
@@ -129,16 +136,22 @@ private case class PaimonBatchWrite(
   }
 }
 
-private case class WriterFactory(writeSchema: StructType, batchWriteBuilder: BatchWriteBuilder)
+private case class WriterFactory(
+    writeSchema: StructType,
+    dataSchema: StructType,
+    batchWriteBuilder: BatchWriteBuilder)
   extends DataWriterFactory {
 
   override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
     val batchTableWrite = batchWriteBuilder.newWrite()
-    new PaimonDataWriter(batchTableWrite, writeSchema)
+    new PaimonDataWriter(batchTableWrite, writeSchema, dataSchema)
   }
 }
 
-private class PaimonDataWriter(batchTableWrite: BatchTableWrite, writeSchema: StructType)
+private class PaimonDataWriter(
+    batchTableWrite: BatchTableWrite,
+    writeSchema: StructType,
+    dataSchema: StructType)
   extends DataWriter[InternalRow] {
 
   private val ioManager = SparkUtils.createIOManager()
@@ -146,7 +159,7 @@ private class PaimonDataWriter(batchTableWrite: BatchTableWrite, writeSchema: St
 
   private val rowConverter: InternalRow => SparkInternalRowWrapper = {
     val numFields = writeSchema.fields.length
-    val reusableWrapper = new SparkInternalRowWrapper(-1, writeSchema, numFields)
+    val reusableWrapper = new SparkInternalRowWrapper(-1, writeSchema, dataSchema, numFields)
     record => reusableWrapper.replace(record)
   }
 
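The data writer now builds the reusable SparkInternalRowWrapper from both the write
(table) schema and the data schema, so every incoming Spark row is re-read through the
table layout without allocating a wrapper per record. A minimal sketch of that reuse
pattern (illustrative types, not the real wrapper API):

    // One mutable holder per writer task, rebound to each record in place; the converted
    // row must be consumed before the next record is pulled.
    final class Reusable[T](private var value: T) {
      def replace(v: T): Reusable[T] = { value = v; this }
      def get: T = value
    }
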
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2WriteBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2WriteBuilder.scala
index 90f30a3955..d6b747a53f 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2WriteBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2WriteBuilder.scala
@@ -18,13 +18,14 @@
 
 package org.apache.paimon.spark.write
 
+import org.apache.paimon.options.Options
 import org.apache.paimon.table.FileStoreTable
 
 import org.apache.spark.sql.connector.write.{SupportsDynamicOverwrite, SupportsOverwrite, WriteBuilder}
 import org.apache.spark.sql.sources.{And, Filter}
 import org.apache.spark.sql.types.StructType
 
-class PaimonV2WriteBuilder(table: FileStoreTable, writeSchema: StructType)
+class PaimonV2WriteBuilder(table: FileStoreTable, dataSchema: StructType, options: Options)
   extends BaseWriteBuilder(table)
   with SupportsOverwrite
   with SupportsDynamicOverwrite {
@@ -33,7 +34,7 @@ class PaimonV2WriteBuilder(table: FileStoreTable, writeSchema: StructType)
   private var overwritePartitions: Option[Map[String, String]] = None
 
   override def build =
-    new PaimonV2Write(table, overwriteDynamic, overwritePartitions, writeSchema)
+    new PaimonV2Write(table, overwriteDynamic, overwritePartitions, dataSchema, options)
 
   override def overwrite(filters: Array[Filter]): WriteBuilder = {
     if (overwriteDynamic) {
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/V2WriteMergeSchemaTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/V2WriteMergeSchemaTest.scala
new file mode 100644
index 0000000000..0b6e589d9f
--- /dev/null
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/V2WriteMergeSchemaTest.scala
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.Row
+
+class V2WriteMergeSchemaTest extends PaimonSparkTestBase {
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.sql.catalog.paimon.cache-enabled", "false")
+      .set("spark.paimon.write.use-v2-write", "true")
+      .set("spark.paimon.write.merge-schema", "true")
+      .set("spark.paimon.write.merge-schema.explicit-cast", "true")
+  }
+
+  import testImplicits._
+
+  test("Write merge schema: dataframe write") {
+    withTable("t") {
+      sql("CREATE TABLE t (a INT, b STRING)")
+      Seq((1, "1"), (2, "2"))
+        .toDF("a", "b")
+        .writeTo("t")
+        .option("write.merge-schema", "true")
+        .append()
+
+      // new columns
+      Seq((3, "3", 3))
+        .toDF("a", "b", "c")
+        .writeTo("t")
+        .option("write.merge-schema", "true")
+        .append()
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(Row(1, "1", null), Row(2, "2", null), Row(3, "3", 3))
+      )
+
+      // missing columns and new columns
+      Seq(("4", "4", 4))
+        .toDF("d", "b", "c")
+        .writeTo("t")
+        .option("write.merge-schema", "true")
+        .append()
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(
+          Row(null, "4", 4, "4"),
+          Row(1, "1", null, null),
+          Row(2, "2", null, null),
+          Row(3, "3", 3, null))
+      )
+    }
+  }
+
+  test("Write merge schema: sql write") {
+    withTable("t") {
+      sql("CREATE TABLE t (a INT, b STRING)")
+      sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+      // new columns
+      sql("INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, 3 AS c")
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(Row(1, "1", null), Row(2, "2", null), Row(3, "3", 3))
+      )
+
+      // missing columns and new columns
+      sql("INSERT INTO t BY NAME SELECT '4' AS d, '4' AS b, 4 AS c")
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(
+          Row(null, "4", 4, "4"),
+          Row(1, "1", null, null),
+          Row(2, "2", null, null),
+          Row(3, "3", 3, null))
+      )
+    }
+  }
+
+  test("Write merge schema: fail when merge schema is disabled but new columns 
are provided") {
+    withTable("t") {
+      withSparkSQLConf("spark.paimon.write.merge-schema" -> "false") {
+        sql("CREATE TABLE t (a INT, b STRING)")
+        sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+        val error = intercept[RuntimeException] {
+          spark.sql("INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, 3 AS c")
+        }.getMessage
+        assert(error.contains("the number of data columns don't match with the table schema's"))
+      }
+    }
+  }
+
+  test("Write merge schema: numeric types") {
+    withTable("t") {
+      sql("CREATE TABLE t (a INT, b STRING)")
+      sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+      // new columns with numeric types
+      sql(
+        "INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, " +
+          "cast(10 as byte) AS byte_col, " +
+          "cast(1000 as short) AS short_col, " +
+          "100000 AS int_col, " +
+          "10000000000L AS long_col, " +
+          "cast(1.23 as float) AS float_col, " +
+          "4.56 AS double_col, " +
+          "cast(7.89 as decimal(10,2)) AS decimal_col")
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(
+          Row(1, "1", null, null, null, null, null, null, null),
+          Row(2, "2", null, null, null, null, null, null, null),
+          Row(
+            3,
+            "3",
+            10.toByte,
+            1000.toShort,
+            100000,
+            10000000000L,
+            1.23f,
+            4.56d,
+            java.math.BigDecimal.valueOf(7.89))
+        )
+      )
+
+      // missing columns and new columns with numeric types
+      sql(
+        "INSERT INTO t BY NAME SELECT '4' AS d, '4' AS b, " +
+          "cast(20 as byte) AS byte_col, " +
+          "cast(2000 as short) AS short_col, " +
+          "200000 AS int_col, " +
+          "20000000000L AS long_col, " +
+          "cast(2.34 as float) AS float_col, " +
+          "5.67 AS double_col, " +
+          "cast(8.96 as decimal(10,2)) AS decimal_col")
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(
+          Row(
+            null,
+            "4",
+            20.toByte,
+            2000.toShort,
+            200000,
+            20000000000L,
+            2.34f,
+            5.67d,
+            java.math.BigDecimal.valueOf(8.96),
+            "4"),
+          Row(1, "1", null, null, null, null, null, null, null, null),
+          Row(2, "2", null, null, null, null, null, null, null, null),
+          Row(
+            3,
+            "3",
+            10.toByte,
+            1000.toShort,
+            100000,
+            10000000000L,
+            1.23f,
+            4.56d,
+            java.math.BigDecimal.valueOf(7.89),
+            null)
+        )
+      )
+    }
+  }
+
+  test("Write merge schema: date and time types") {
+    withTable("t") {
+      sql("CREATE TABLE t (a INT, b STRING)")
+      sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+      // new columns with date and time types
+      sql(
+        "INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, " +
+          "cast('2023-01-01' as date) AS date_col, " +
+          "cast('2023-01-01 12:00:00' as timestamp) AS timestamp_col")
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(
+          Row(1, "1", null, null),
+          Row(2, "2", null, null),
+          Row(
+            3,
+            "3",
+            java.sql.Date.valueOf("2023-01-01"),
+            java.sql.Timestamp.valueOf("2023-01-01 12:00:00"))
+        )
+      )
+
+      // missing columns and new columns with date and time types
+      sql(
+        "INSERT INTO t BY NAME SELECT '4' AS d, '4' AS b, " +
+          "cast('2023-12-31' as date) AS date_col, " +
+          "cast('2023-12-31 23:59:59' as timestamp) AS timestamp_col")
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(
+          Row(
+            null,
+            "4",
+            java.sql.Date.valueOf("2023-12-31"),
+            java.sql.Timestamp.valueOf("2023-12-31 23:59:59"),
+            "4"),
+          Row(1, "1", null, null, null),
+          Row(2, "2", null, null, null),
+          Row(
+            3,
+            "3",
+            java.sql.Date.valueOf("2023-01-01"),
+            java.sql.Timestamp.valueOf("2023-01-01 12:00:00"),
+            null)
+        )
+      )
+    }
+  }
+
+  test("Write merge schema: complex types") {
+    withTable("t") {
+      sql("CREATE TABLE t (a INT, b STRING)")
+      sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+      // new columns with complex types
+      sql(
+        "INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, " +
+          "array(1, 2, 3) AS array_col, " +
+          "map('key1', 'value1', 'key2', 'value2') AS map_col, " +
+          "struct('x', 1) AS struct_col")
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(
+          Row(1, "1", null, null, null),
+          Row(2, "2", null, null, null),
+          Row(3, "3", Array(1, 2, 3), Map("key1" -> "value1", "key2" -> 
"value2"), Row("x", 1))
+        )
+      )
+
+      // missing columns and new columns with complex types
+      sql(
+        "INSERT INTO t BY NAME SELECT '4' AS d, '4' AS b, " +
+          "array(4, 5, 6) AS array_col, " +
+          "map('key3', 'value3') AS map_col, " +
+          "struct('y', 2) AS struct_col")
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(
+          Row(null, "4", Array(4, 5, 6), Map("key3" -> "value3"), Row("y", 2), 
"4"),
+          Row(1, "1", null, null, null, null),
+          Row(2, "2", null, null, null, null),
+          Row(
+            3,
+            "3",
+            Array(1, 2, 3),
+            Map("key1" -> "value1", "key2" -> "value2"),
+            Row("x", 1),
+            null)
+        )
+      )
+    }
+  }
+
+  test("Write merge schema: binary and boolean types") {
+    withTable("t") {
+      sql("CREATE TABLE t (a INT, b STRING)")
+      sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+      // new columns with binary and boolean types
+      sql(
+        "INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, " +
+          "cast('binary_data' as binary) AS binary_col, " +
+          "true AS boolean_col")
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(
+          Row(1, "1", null, null),
+          Row(2, "2", null, null),
+          Row(3, "3", "binary_data".getBytes("UTF-8"), true)
+        )
+      )
+
+      // missing columns and new columns with binary and boolean types
+      sql(
+        "INSERT INTO t BY NAME SELECT '4' AS d, '4' AS b, " +
+          "cast('more_data' as binary) AS binary_col, " +
+          "false AS boolean_col")
+      checkAnswer(
+        sql("SELECT * FROM t ORDER BY a"),
+        Seq(
+          Row(null, "4", "more_data".getBytes("UTF-8"), false, "4"),
+          Row(1, "1", null, null, null),
+          Row(2, "2", null, null, null),
+          Row(3, "3", "binary_data".getBytes("UTF-8"), true, null)
+        )
+      )
+    }
+  }
+
+}
