This is an automated email from the ASF dual-hosted git repository.

wombatukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new bafe1284d5e [HUDI-8629] Improve test coverage around different field 
names in MERGE INTO statements (#12641)
bafe1284d5e is described below

commit bafe1284d5e003978d9cacb5eaf363c74748131d
Author: Y Ethan Guo <[email protected]>
AuthorDate: Thu Jan 16 17:29:33 2025 -0800

    [HUDI-8629] Improve test coverage around different field names in MERGE 
INTO statements (#12641)
    
    Improve test coverage around different field names in MERGE INTO 
statements, fix tests on Spark 3.3 and add more docs
---
 .../sql/hudi/common/HoodieSparkSqlTestBase.scala   |  38 +++++-
 .../spark/sql/hudi/dml/TestInsertTable.scala       |  64 +++++++++
 .../spark/sql/hudi/dml/TestMergeIntoTable.scala    | 149 +++++++++++++++++----
 .../hudi/dml/TestPartialUpdateForMergeInto.scala   | 113 ++++++++++++++--
 4 files changed, 318 insertions(+), 46 deletions(-)

diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/common/HoodieSparkSqlTestBase.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/common/HoodieSparkSqlTestBase.scala
index 7735b658945..16187ba3fd9 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/common/HoodieSparkSqlTestBase.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/common/HoodieSparkSqlTestBase.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.sql.hudi.common
 
-import org.apache.hadoop.fs.Path
-import org.apache.hudi.DefaultSparkRecordMerger
+import org.apache.hudi.{DefaultSparkRecordMerger, HoodieSparkUtils}
 import org.apache.hudi.HoodieFileIndex.DataSkippingFailureMode
 import org.apache.hudi.common.config.HoodieStorageConfig
-import org.apache.hudi.common.model.HoodieAvroRecordMerger
+import org.apache.hudi.common.model.{HoodieAvroRecordMerger, HoodieRecord}
 import org.apache.hudi.common.model.HoodieRecord.HoodieRecordType
 import org.apache.hudi.common.table.timeline.TimelineMetadataUtils
 import org.apache.hudi.config.HoodieWriteConfig
@@ -29,10 +28,13 @@ import org.apache.hudi.exception.ExceptionUtil.getRootCause
 import org.apache.hudi.hadoop.fs.HadoopFSUtils
 import org.apache.hudi.index.inmemory.HoodieInMemoryHashIndex
 import org.apache.hudi.testutils.HoodieClientTestUtils.{createMetaClient, 
getSparkConfForTest}
+
+import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkConf
+import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import 
org.apache.spark.sql.hudi.common.HoodieSparkSqlTestBase.checkMessageContains
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.types.StructField
 import org.apache.spark.util.Utils
 import org.joda.time.DateTimeZone
 import org.scalactic.source
@@ -43,7 +45,6 @@ import java.io.File
 import java.util.TimeZone
 import java.util.concurrent.atomic.AtomicInteger
 import java.util.regex.Pattern
-import scala.util.matching.Regex
 
 class HoodieSparkSqlTestBase extends FunSuite with BeforeAndAfterAll {
   org.apache.log4j.Logger.getRootLogger.setLevel(org.apache.log4j.Level.WARN)
@@ -245,6 +246,33 @@ class HoodieSparkSqlTestBase extends FunSuite with 
BeforeAndAfterAll {
     assertResult(true)(hasException)
   }
 
+  protected def getExpectedUnresolvedColumnExceptionMessage(columnName: String,
+                                                            targetTableName: 
String): String = {
+    val targetTableFields = spark.sql(s"select * from 
$targetTableName").schema.fields
+      .map(e => (e.name, targetTableName, 
s"spark_catalog.default.$targetTableName.${e.name}"))
+    getExpectedUnresolvedColumnExceptionMessage(columnName, targetTableFields)
+  }
+
+  protected def getExpectedUnresolvedColumnExceptionMessage(columnName: String,
+                                                            fieldNameTuples: 
Seq[(String, String, String)]): String = {
+    val fieldNames = fieldNameTuples.sortBy(e => (e._1, e._2))
+      .map(e => e._3).mkString("[", ", ", "]")
+    if (HoodieSparkUtils.gteqSpark3_5) {
+      "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with 
name " +
+        s"$columnName cannot be resolved. Did you mean one of the following? 
$fieldNames."
+    } else {
+      s"cannot resolve $columnName in MERGE command given columns $fieldNames" 
+
+        (if (HoodieSparkUtils.gteqSpark3_4) "." else "")
+    }
+  }
+
+  protected def validateTableSchema(tableName: String,
+                                    expectedStructFields: List[StructField]): 
Unit = {
+    assertResult(expectedStructFields)(
+      spark.sql(s"select * from $tableName").schema.fields
+        .filter(e => 
!HoodieRecord.HOODIE_META_COLUMNS_WITH_OPERATION.contains(e.name)))
+  }
+
   def dropTypeLiteralPrefix(value: Any): Any = {
     value match {
       case s: String =>
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestInsertTable.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestInsertTable.scala
index 68bbf9e1482..52a9e4f20e2 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestInsertTable.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestInsertTable.scala
@@ -30,6 +30,7 @@ import 
org.apache.hudi.exception.{HoodieDuplicateKeyException, HoodieException}
 import org.apache.hudi.execution.bulkinsert.BulkInsertSortMode
 import org.apache.hudi.index.HoodieIndex.IndexType
 import org.apache.hudi.testutils.HoodieClientTestUtils.createMetaClient
+
 import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.hudi.HoodieSqlCommonUtils
@@ -3163,5 +3164,68 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
     }
     spark.sessionState.conf.unsetConf("hoodie.datasource.insert.dup.policy")
   }
+
+  test("Test throwing error on schema evolution in INSERT INTO") {
+    Seq("cow", "mor").foreach { tableType =>
+      withTempDir { tmp =>
+        val tableName = generateTableName
+        // Create a partitioned table
+        spark.sql(
+          s"""
+             |create table $tableName (
+             |  id int,
+             |  dt string,
+             |  name string,
+             |  price double,
+             |  ts long
+             |) using hudi
+             | tblproperties (primaryKey = 'id', type = '$tableType')
+             | partitioned by (dt)
+             | location '${tmp.getCanonicalPath}'
+             | """.stripMargin)
+
+        // INSERT INTO with same set of columns as that in the table schema 
works
+        spark.sql(
+          s"""
+             | insert into $tableName partition(dt = '2024-01-14')
+             | select 1 as id, 'a1' as name, 10 as price, 1000 as ts
+             | union
+             | select 2 as id, 'a2' as name, 20 as price, 1002 as ts
+             | """.stripMargin)
+        checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+          Seq(1, "a1", 10.0, 1000, "2024-01-14"),
+          Seq(2, "a2", 20.0, 1002, "2024-01-14")
+        )
+
+        // INSERT INTO with an additional column that does not exist in the 
table schema
+        // throws an error, as INSERT INTO does not allow schema evolution
+        val sqlStatement =
+          s"""
+             | insert into $tableName partition(dt = '2024-01-14')
+             | select 3 as id, 'a3' as name, 30 as price, 1003 as ts, 'x' as 
new_col
+             | union
+             | select 2 as id, 'a2_updated' as name, 25 as price, 1005 as ts, 
'y' as new_col
+             | """.stripMargin
+        val expectedExceptionMessage = if (HoodieSparkUtils.gteqSpark3_5) {
+          "[INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS] " +
+            s"Cannot write to `spark_catalog`.`default`.`$tableName`, " +
+            "the reason is too many data columns:\n" +
+            "Table columns: `id`, `name`, `price`, `ts`.\n" +
+            "Data columns: `id`, `name`, `price`, `ts`, `new_col`."
+        } else {
+          val endingStr = if (HoodieSparkUtils.gteqSpark3_4) "." else ""
+          val tableId = if (HoodieSparkUtils.gteqSpark3_4) {
+            s"spark_catalog.default.$tableName"
+          } else {
+            s"default.$tableName"
+          }
+          s"Cannot write to '$tableId', too many data columns:\n" +
+            s"Table columns: 'id', 'name', 'price', 'ts'$endingStr\n" +
+            s"Data columns: 'id', 'name', 'price', 'ts', 'new_col'$endingStr"
+        }
+        checkExceptionContain(sqlStatement)(expectedExceptionMessage)
+      }
+    }
+  }
 }
 
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestMergeIntoTable.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestMergeIntoTable.scala
index fb9a7b7e224..cc60c686991 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestMergeIntoTable.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestMergeIntoTable.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.hudi.dml
 
-import org.apache.hudi.{DataSourceReadOptions, HoodieSparkUtils, 
ScalaAssertionSupport}
+import org.apache.hudi.{DataSourceReadOptions, ScalaAssertionSupport}
 import org.apache.hudi.DataSourceWriteOptions.SPARK_SQL_OPTIMIZED_WRITES
 import 
org.apache.hudi.config.HoodieWriteConfig.MERGE_SMALL_FILE_GROUP_CANDIDATES_LIMIT
 import org.apache.hudi.hadoop.fs.HadoopFSUtils
@@ -25,6 +25,7 @@ import org.apache.hudi.testutils.DataSourceTestUtils
 
 import org.apache.spark.sql.hudi.common.HoodieSparkSqlTestBase
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, 
StringType, StructField}
 import org.slf4j.LoggerFactory
 
 class TestMergeIntoTable extends HoodieSparkSqlTestBase with 
ScalaAssertionSupport {
@@ -54,6 +55,11 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with 
ScalaAssertionSuppo
         // test with optimized sql merge enabled / disabled.
         spark.sql(s"set 
${SPARK_SQL_OPTIMIZED_WRITES.key()}=$sparkSqlOptimizedWrites")
 
+        val structFields = List(
+          StructField("id", IntegerType, nullable = true),
+          StructField("name", StringType, nullable = true),
+          StructField("price", DoubleType, nullable = true),
+          StructField("ts", LongType, nullable = true))
        // First merge with an extra input field 'flag' (insert a new record)
         spark.sql(
           s"""
@@ -66,22 +72,23 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase 
with ScalaAssertionSuppo
              | id = s0.id, name = s0.name, price = s0.price, ts = s0.ts
              | when not matched and flag = '1' then insert *
        """.stripMargin)
+        validateTableSchema(tableName, structFields)
         checkAnswer(s"select id, name, price, ts from $tableName")(
           Seq(1, "a1", 10.0, 1000)
         )
 
-        // Second merge (update the record)
+        // Second merge (update the record) with different field names in the 
source
         spark.sql(
           s"""
              | merge into $tableName
              | using (
-             |  select 1 as id, 'a1' as name, 10 as price, 1001 as ts
+             |  select 1 as _id, 'a1' as name, 10 as _price, 1001 as _ts
              | ) s0
-             | on s0.id = $tableName.id
+             | on s0._id = $tableName.id
              | when matched then update set
-             | id = s0.id, name = s0.name, price = s0.price + 
$tableName.price, ts = s0.ts
-             | when not matched then insert *
-       """.stripMargin)
+             | id = s0._id, name = s0.name, price = s0._price + 
$tableName.price, ts = s0._ts
+             | """.stripMargin)
+        validateTableSchema(tableName, structFields)
         checkAnswer(s"select id, name, price, ts from $tableName")(
           Seq(1, "a1", 20.0, 1001)
         )
@@ -102,6 +109,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase 
with ScalaAssertionSuppo
              | id = s0.id, name = s0.name, price = s0.price + 
$tableName.price, ts = s0.ts
              | when not matched and s0.id % 2 = 0 then insert *
        """.stripMargin)
+        validateTableSchema(tableName, structFields)
         checkAnswer(s"select id, name, price, ts from $tableName")(
           Seq(1, "a1", 30.0, 1002),
           Seq(2, "a2", 12.0, 1001)
@@ -446,6 +454,13 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase 
with ScalaAssertionSuppo
            | partitioned by(dt)
            | location '${tmp.getCanonicalPath}'
          """.stripMargin)
+      val structFields = List(
+        StructField("id", IntegerType, nullable = true),
+        StructField("name", StringType, nullable = true),
+        StructField("price", DoubleType, nullable = true),
+        StructField("ts", LongType, nullable = true),
+        StructField("dt", StringType, nullable = true))
+
       // Insert data
       spark.sql(
         s"""
@@ -457,6 +472,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase 
with ScalaAssertionSuppo
            | when not matched and s0.id % 2 = 1 then insert *
          """.stripMargin
       )
+      validateTableSchema(tableName, structFields)
       checkAnswer(s"select id,name,price,dt from $tableName")(
         Seq(1, "a1", 10, "2021-03-21")
       )
@@ -465,12 +481,14 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase 
with ScalaAssertionSuppo
         s"""
            | merge into $tableName as t0
            | using (
-           |  select 1 as id, 'a1' as name, 12 as price, 1001 as ts, 
'2021-03-21' as dt
+           |  select 1 as _id, 'a1' as name, 12 as _price, 1001 as _ts, 
'2021-03-21' as dt
            | ) as s0
-           | on t0.id = s0.id
-           | when matched and s0.id % 2 = 0 then update set *
+           | on t0.id = s0._id
+           | when matched and s0._id % 2 = 0 then update set
+           |  id = s0._id, name = s0.name, price = s0._price, ts = s0._ts, dt 
= s0.dt
          """.stripMargin
       )
+      validateTableSchema(tableName, structFields)
       checkAnswer(s"select id,name,price,dt from $tableName")(
         Seq(1, "a1", 10, "2021-03-21")
       )
@@ -479,12 +497,14 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase 
with ScalaAssertionSuppo
         s"""
            | merge into $tableName as t0
            | using (
-           |  select 1 as id, 'a1' as name, 12 as price, 1001 as ts, 
'2021-03-21' as dt
+           |  select 1 as _id, 'a1' as name, 12 as _price, 1001 as _ts, 
'2021-03-21' as dt
            | ) as s0
-           | on t0.id = s0.id
-           | when matched and s0.id % 2 = 1 then update set *
+           | on t0.id = s0._id
+           | when matched and s0._id % 2 = 1 then update set
+           |  id = s0._id, name = s0.name, price = s0._price, ts = s0._ts, dt 
= s0.dt
          """.stripMargin
       )
+      validateTableSchema(tableName, structFields)
       checkAnswer(s"select id,name,price,dt from $tableName")(
         Seq(1, "a1", 12, "2021-03-21")
       )
@@ -515,6 +535,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase 
with ScalaAssertionSuppo
            | then update set id = s_id, name = s_name, price = s_price, ts = 
s_ts, t0.dt = s0.dt
          """.stripMargin
       )
+      validateTableSchema(tableName, structFields)
       checkAnswer(s"select id,name,price,dt from $tableName order by id")(
         Seq(1, "a1", 12, "2021-03-21"),
         Seq(2, "a2", 15, "2021-03-21")
@@ -1102,10 +1123,17 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase 
with ScalaAssertionSuppo
            |  hoodie.compact.inline = 'true'
            | )
        """.stripMargin)
+      val structFields = List(
+        StructField("id", IntegerType, nullable = true),
+        StructField("name", StringType, nullable = true),
+        StructField("price", DoubleType, nullable = true),
+        StructField("ts", LongType, nullable = true))
+
       spark.sql(s"insert into $tableName values(1, 'a1', 10, 1000)")
       spark.sql(s"insert into $tableName values(2, 'a2', 10, 1000)")
       spark.sql(s"insert into $tableName values(3, 'a3', 10, 1000)")
       spark.sql(s"insert into $tableName values(4, 'a4', 10, 1000)")
+      validateTableSchema(tableName, structFields)
       checkAnswer(s"select id, name, price, ts from $tableName order by id")(
         Seq(1, "a1", 10.0, 1000),
         Seq(2, "a2", 10.0, 1000),
@@ -1117,12 +1145,14 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase 
with ScalaAssertionSuppo
         s"""
            |merge into $tableName h0
            |using (
-           | select 4 as id, 'a4' as name, 11 as price, 1000 as ts
+           | select 4 as _id, 'a4' as name, 11 as _price, 1000 as ts
            | ) s0
-           | on h0.id = s0.id
-           | when matched then update set *
+           | on h0.id = s0._id
+           | when matched then update set
+           |   id = s0._id, name = s0.name, price = s0._price, ts = s0.ts
            |""".stripMargin)
 
+      validateTableSchema(tableName, structFields)
       // 5 commits will trigger compaction.
       checkAnswer(s"select id, name, price, ts from $tableName order by id")(
         Seq(1, "a1", 10.0, 1000),
@@ -1512,23 +1542,86 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase 
with ScalaAssertionSuppo
              | """.stripMargin
 
         checkExceptionContain(sqlStatement1)(
-          getExpectedExceptionMessage("new_col", targetTableFields))
+          getExpectedUnresolvedColumnExceptionMessage("new_col", 
targetTableFields))
         checkExceptionContain(sqlStatement2)(
-          getExpectedExceptionMessage("s0.new_col", sourceTableFields ++ 
targetTableFields))
+          getExpectedUnresolvedColumnExceptionMessage("s0.new_col", 
sourceTableFields ++ targetTableFields))
       }
     }
   }
 
-  private def getExpectedExceptionMessage(columnName: String,
-                                          fieldNameTuples: Seq[(String, 
String, String)]): String = {
-    val fieldNames = fieldNameTuples.sortBy(e => (e._1, e._2))
-      .map(e => e._3).mkString("[", ", ", "]")
-    if (HoodieSparkUtils.gteqSpark3_5) {
-      "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with 
name " +
-        s"$columnName cannot be resolved. Did you mean one of the following? 
$fieldNames."
-    } else {
-      s"cannot resolve $columnName in MERGE command given columns $fieldNames" 
+
-        (if (HoodieSparkUtils.gteqSpark3_4) "." else "")
+  test("Test no schema evolution in MERGE INTO") {
+    Seq("cow", "mor").foreach { tableType =>
+      withTempDir { tmp =>
+        val tableName = generateTableName
+        // Create a partitioned table
+        spark.sql(
+          s"""
+             |create table $tableName (
+             |  id int,
+             |  dt string,
+             |  name string,
+             |  price double,
+             |  ts long
+             |) using hudi
+             | tblproperties (primaryKey = 'id', type = '$tableType')
+             | partitioned by (dt)
+             | location '${tmp.getCanonicalPath}'
+             | """.stripMargin)
+        val structFields = List(
+          StructField("id", IntegerType, nullable = true),
+          StructField("name", StringType, nullable = true),
+          StructField("price", DoubleType, nullable = true),
+          StructField("ts", LongType, nullable = true),
+          StructField("dt", StringType, nullable = true))
+
+        spark.sql(
+          s"""
+             | insert into $tableName partition(dt = '2024-01-14')
+             | select 1 as id, 'a1' as name, 10 as price, 1000 as ts
+             | union
+             | select 2 as id, 'a2' as name, 20 as price, 1002 as ts
+             | """.stripMargin)
+        checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+          Seq(1, "a1", 10.0, 1000, "2024-01-14"),
+          Seq(2, "a2", 20.0, 1002, "2024-01-14")
+        )
+
+        // MERGE INTO with an additional column that does not exist in the 
table schema
+        // throws an error if it tries to set the new column in the "UPDATE 
SET" clause
+        val sqlStatement1 =
+          s"""
+             | merge into $tableName
+             | using (
+             |  select 2 as _id, '2024-01-14' as dt, 'a2_new' as name, 25 as 
_price, 1005 as _ts, 'x' as new_col
+             |  union
+             |  select 3 as _id, '2024-01-14' as dt, 'a3' as name, 30 as 
_price, 1003 as _ts, 'y' as new_col
+             | ) s0
+             | on s0._id = $tableName.id
+             | when matched then update set
+             | id = s0._id, dt = s0.dt, name = s0.name, price = s0._price + 
$tableName.price,
+             | ts = s0._ts, new_col = s0.new_col
+             | """.stripMargin
+        checkExceptionContain(sqlStatement1)(
+          getExpectedUnresolvedColumnExceptionMessage("new_col", tableName))
+        validateTableSchema(tableName, structFields)
+
+        // MERGE INTO with an additional column that does not exist in the 
table schema
+        // works only if it tries to do inserts (the additional column is 
dropped before inserts)
+        spark.sql(
+          s"""
+             | merge into $tableName
+             | using (
+             |  select 3 as id, '2024-01-14' as dt, 'a3' as name, 30 as price, 
1003 as ts, 'y' as new_col
+             | ) s0
+             | on s0.id = $tableName.id
+             | when not matched then insert *
+             | """.stripMargin)
+        validateTableSchema(tableName, structFields)
+        checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+          Seq(1, "a1", 10.0, 1000, "2024-01-14"),
+          Seq(2, "a2", 20.0, 1002, "2024-01-14"),
+          Seq(3, "a3", 30.0, 1003, "2024-01-14"))
+      }
     }
   }
 }
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestPartialUpdateForMergeInto.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestPartialUpdateForMergeInto.scala
index c3894593105..df81df34cf9 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestPartialUpdateForMergeInto.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/TestPartialUpdateForMergeInto.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.hudi.dml
 
-import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions, 
HoodieSparkUtils}
+import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions}
 import org.apache.hudi.avro.HoodieAvroUtils
 import org.apache.hudi.common.config.{HoodieCommonConfig, 
HoodieMetadataConfig, HoodieReaderConfig, HoodieStorageConfig}
 import org.apache.hudi.common.engine.HoodieLocalEngineContext
@@ -32,8 +32,10 @@ import org.apache.hudi.common.util.CompactionUtils
 import org.apache.hudi.config.{HoodieClusteringConfig, HoodieCompactionConfig, 
HoodieIndexConfig, HoodieWriteConfig}
 import org.apache.hudi.exception.HoodieNotSupportedException
 import org.apache.hudi.metadata.HoodieTableMetadata
+
 import org.apache.avro.Schema
 import org.apache.spark.sql.hudi.common.HoodieSparkSqlTestBase
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, 
StringType, StructField}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
 
 import java.util.{Collections, List, Optional}
@@ -181,6 +183,80 @@ class TestPartialUpdateForMergeInto extends 
HoodieSparkSqlTestBase {
     }
   }
 
+  test("Test MERGE INTO with partial updates containing non-existent columns 
on COW table") {
+    testPartialUpdateWithNonExistentColumns("cow")
+  }
+
+  test("Test MERGE INTO with partial updates containing non-existent columns 
on MOR table") {
+    testPartialUpdateWithNonExistentColumns("mor")
+  }
+
+  def testPartialUpdateWithNonExistentColumns(tableType: String): Unit = {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      val basePath = tmp.getCanonicalPath + "/" + tableName
+      spark.sql(s"set 
${HoodieWriteConfig.MERGE_SMALL_FILE_GROUP_CANDIDATES_LIMIT.key} = 0")
+      spark.sql(s"set 
${DataSourceWriteOptions.ENABLE_MERGE_INTO_PARTIAL_UPDATES.key} = true")
+
+      // Create a table with five data fields
+      spark.sql(
+        s"""
+           |create table $tableName (
+           | id int,
+           | name string,
+           | price double,
+           | _ts long,
+           | description string
+           |) using hudi
+           |tblproperties(
+           | type ='$tableType',
+           | primaryKey = 'id',
+           | preCombineField = '_ts'
+           |)
+           |location '$basePath'
+        """.stripMargin)
+      val structFields = scala.collection.immutable.List(
+        StructField("id", IntegerType, nullable = true),
+        StructField("name", StringType, nullable = true),
+        StructField("price", DoubleType, nullable = true),
+        StructField("_ts", LongType, nullable = true),
+        StructField("description", StringType, nullable = true))
+
+      spark.sql(s"insert into $tableName values (1, 'a1', 10, 1000, 'a1: 
desc1')," +
+        "(2, 'a2', 20, 1200, 'a2: desc2'), (3, 'a3', 30, 1250, 'a3: desc3')")
+      validateTableSchema(tableName, structFields)
+
+      // Partial updates using MERGE INTO statement with changed fields: 
"price", "_ts"
+      // This is OK since the "UPDATE SET" clause does not contain the new 
column
+      spark.sql(
+        s"""
+           |merge into $tableName t0
+           |using ( select 1 as id, 'a1' as name, 12.0 as price, 1001 as ts, 
'x' as new_col
+           |union select 3 as id, 'a3' as name, 25.0 as price, 1260 as ts, 'y' 
as new_col) s0
+           |on t0.id = s0.id
+           |when matched then update set price = s0.price, _ts = s0.ts
+           |""".stripMargin)
+
+      validateTableSchema(tableName, structFields)
+      checkAnswer(s"select id, name, price, _ts, description from $tableName")(
+        Seq(1, "a1", 12.0, 1001, "a1: desc1"),
+        Seq(2, "a2", 20.0, 1200, "a2: desc2"),
+        Seq(3, "a3", 25.0, 1260, "a3: desc3")
+      )
+
+      // Partial updates using MERGE INTO statement with changed fields: 
"price", "_ts", "new_col"
+      // This throws an error since the "UPDATE SET" clause contains the new 
column
+      checkExceptionContain(
+        s"""
+           |merge into $tableName
+           |using ( select 1 as id, 'a1' as name, 12.0 as price, 1001 as ts, 
'x' as new_col
+           |union select 3 as id, 'a3' as name, 25.0 as price, 1260 as ts, 'y' 
as new_col) s0
+           |on $tableName.id = s0.id
+           |when matched then update set price = s0.price, _ts = s0.ts, 
new_col = s0.new_col
+           
|""".stripMargin)(getExpectedUnresolvedColumnExceptionMessage("new_col", 
tableName))
+    }
+  }
+
   def testPartialUpdate(tableType: String,
                         logDataBlockFormat: String): Unit = {
     withTempDir { tmp =>
@@ -208,19 +284,27 @@ class TestPartialUpdateForMergeInto extends 
HoodieSparkSqlTestBase {
            |)
            |location '$basePath'
         """.stripMargin)
+      val structFields = scala.collection.immutable.List(
+        StructField("id", IntegerType, nullable = true),
+        StructField("name", StringType, nullable = true),
+        StructField("price", DoubleType, nullable = true),
+        StructField("_ts", LongType, nullable = true),
+        StructField("description", StringType, nullable = true))
       spark.sql(s"insert into $tableName values (1, 'a1', 10, 1000, 'a1: 
desc1')," +
         "(2, 'a2', 20, 1200, 'a2: desc2'), (3, 'a3', 30, 1250, 'a3: desc3')")
+      validateTableSchema(tableName, structFields)
 
       // Partial updates using MERGE INTO statement with changed fields: 
"price" and "_ts"
       spark.sql(
         s"""
            |merge into $tableName t0
-           |using ( select 1 as id, 'a1' as name, 12.0 as price, 1001 as _ts
-           |union select 3 as id, 'a3' as name, 25.0 as price, 1260 as _ts) s0
+           |using ( select 1 as id, 'a1' as name, 12.0 as price, 1001 as ts
+           |union select 3 as id, 'a3' as name, 25.0 as price, 1260 as ts) s0
            |on t0.id = s0.id
-           |when matched then update set price = s0.price, _ts = s0._ts
+           |when matched then update set price = s0.price, _ts = s0.ts
            |""".stripMargin)
 
+      validateTableSchema(tableName, structFields)
       checkAnswer(s"select id, name, price, _ts, description from $tableName")(
         Seq(1, "a1", 12.0, 1001, "a1: desc1"),
         Seq(2, "a2", 20.0, 1200, "a2: desc2"),
@@ -235,12 +319,13 @@ class TestPartialUpdateForMergeInto extends 
HoodieSparkSqlTestBase {
       spark.sql(
         s"""
            |merge into $tableName t0
-           |using ( select 1 as id, 'a1' as name, 'a1: updated desc1' as 
description, 1023 as _ts
-           |union select 2 as id, 'a2' as name, 'a2: updated desc2' as 
description, 1270 as _ts) s0
+           |using ( select 1 as id, 'a1' as name, 'a1: updated desc1' as 
new_description, 1023 as _ts
+           |union select 2 as id, 'a2' as name, 'a2: updated desc2' as 
new_description, 1270 as _ts) s0
            |on t0.id = s0.id
-           |when matched then update set description = s0.description, _ts = 
s0._ts
+           |when matched then update set description = s0.new_description, _ts 
= s0._ts
            |""".stripMargin)
 
+      validateTableSchema(tableName, structFields)
       checkAnswer(s"select id, name, price, _ts, description from $tableName")(
         Seq(1, "a1", 12.0, 1023, "a1: updated desc1"),
         Seq(2, "a2", 20.0, 1270, "a2: updated desc2"),
@@ -255,12 +340,13 @@ class TestPartialUpdateForMergeInto extends 
HoodieSparkSqlTestBase {
         spark.sql(
           s"""
              |merge into $tableName t0
-             |using ( select 2 as id, '_a2' as name, 18.0 as price, 1275 as _ts
-             |union select 3 as id, '_a3' as name, 28.0 as price, 1280 as _ts) 
s0
+             |using ( select 2 as id, '_a2' as name, 18.0 as _price, 1275 as 
_ts
+             |union select 3 as id, '_a3' as name, 28.0 as _price, 1280 as 
_ts) s0
              |on t0.id = s0.id
-             |when matched then update set price = s0.price, _ts = s0._ts
+             |when matched then update set price = s0._price, _ts = s0._ts
              |""".stripMargin)
         validateCompactionExecuted(basePath)
+        validateTableSchema(tableName, structFields)
         checkAnswer(s"select id, name, price, _ts, description from 
$tableName")(
           Seq(1, "a1", 12.0, 1023, "a1: updated desc1"),
           Seq(2, "a2", 18.0, 1275, "a2: updated desc2"),
@@ -273,13 +359,14 @@ class TestPartialUpdateForMergeInto extends 
HoodieSparkSqlTestBase {
         spark.sql(
           s"""
              |merge into $tableName t0
-             |using ( select 2 as id, '_a2' as name, 48.0 as price, 1275 as _ts
-             |union select 3 as id, '_a3' as name, 58.0 as price, 1280 as _ts) 
s0
+             |using ( select 2 as id, '_a2' as name, 48.0 as _price, 1275 as 
_ts
+             |union select 3 as id, '_a3' as name, 58.0 as _price, 1280 as 
_ts) s0
              |on t0.id = s0.id
-             |when matched then update set price = s0.price, _ts = s0._ts
+             |when matched then update set price = s0._price, _ts = s0._ts
              |""".stripMargin)
 
         validateClusteringExecuted(basePath)
+        validateTableSchema(tableName, structFields)
         checkAnswer(s"select id, name, price, _ts, description from 
$tableName")(
           Seq(1, "a1", 12.0, 1023, "a1: updated desc1"),
           Seq(2, "a2", 48.0, 1275, "a2: updated desc2"),

Reply via email to