This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 9e4e780a5820 [SPARK-54496][SQL] Fix Merge Into Schema Evolution for
Dataframe API
9e4e780a5820 is described below
commit 9e4e780a582017feb13dfcc62a68f09a6fe5bc1a
Author: Szehon Ho <[email protected]>
AuthorDate: Tue Nov 25 22:06:52 2025 -0800
[SPARK-54496][SQL] Fix Merge Into Schema Evolution for Dataframe API
### What changes were proposed in this pull request?
Some fixes to allow the DataFrame Merge API to support schema evolution.
The DataFrame API is here:
https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/classic/MergeIntoWriter.scala#L7
The fixes are described inline.
### Why are the changes needed?
Without these fixes, the DataFrame Merge API does not work in schema
evolution mode.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests. A later refactor will consolidate them to improve test
re-use.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53207 from szehon-ho/merge_schema_evolution_bug.
Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 9feb1b2c02202fcf04c2dc9a4f44fcd6c63cdeb8)
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../analysis/ResolveMergeIntoSchemaEvolution.scala | 15 +-
.../sql/catalyst/plans/logical/v2Commands.scala | 30 +-
.../sql/connector/MergeIntoTableSuiteBase.scala | 999 ++++++++++++++++++++-
3 files changed, 1028 insertions(+), 16 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
index ea0883f7928f..bbb8e7852b2c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
@@ -17,10 +17,13 @@
package org.apache.spark.sql.catalyst.analysis
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util,
SupportsRowLevelOperations, TableCatalog, TableChange}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog,
TableChange}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -42,15 +45,19 @@ object ResolveMergeIntoSchemaEvolution extends
Rule[LogicalPlan] {
if (changes.isEmpty) {
m
} else {
- m transformUpWithNewOutput {
- case r @ DataSourceV2Relation(_: SupportsRowLevelOperations, _, _,
_, _, _) =>
+ val finalAttrMapping = ArrayBuffer.empty[(Attribute, Attribute)]
+ val newTarget = m.targetTable.transform {
+ case r: DataSourceV2Relation =>
val referencedSourceSchema =
MergeIntoTable.sourceSchemaForSchemaEvolution(m)
val newTarget = performSchemaEvolution(r, referencedSourceSchema,
changes)
val oldTargetOutput = m.targetTable.output
val newTargetOutput = newTarget.output
val attributeMapping = oldTargetOutput.zip(newTargetOutput)
- newTarget -> attributeMapping
+ finalAttrMapping ++= attributeMapping
+ newTarget
}
+ val res = m.copy(targetTable = newTarget)
+ res.rewriteAttrs(AttributeMap(finalAttrMapping.toSeq))
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 3f9e8da21d28..72274ee9bf17 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -916,19 +916,29 @@ case class MergeIntoTable(
false
} else {
val actions = matchedActions ++ notMatchedActions
- val assignments = actions.collect {
- case a: UpdateAction => a.assignments
- case a: InsertAction => a.assignments
- }.flatten
- val sourcePaths = DataTypeUtils.extractAllFieldPaths(sourceTable.schema)
- assignments.forall { assignment =>
- assignment.resolved ||
- (assignment.value.resolved && sourcePaths.exists {
- path => MergeIntoTable.isEqual(assignment, path)
- })
+ val hasStarActions = actions.exists {
+ case _: UpdateStarAction => true
+ case _: InsertStarAction => true
+ case _ => false
+ }
+ if (hasStarActions) {
+ // need to resolve star actions first
+ false
+ } else {
+ val assignments = actions.collect {
+ case a: UpdateAction => a.assignments
+ case a: InsertAction => a.assignments
+ }.flatten
+ val sourcePaths =
DataTypeUtils.extractAllFieldPaths(sourceTable.schema)
+ assignments.forall { assignment =>
+ assignment.resolved ||
+ (assignment.value.resolved && sourcePaths.exists {
+ path => MergeIntoTable.isEqual(assignment, path)
+ })
}
}
}
+ }
private lazy val sourceSchemaForEvolution: StructType =
MergeIntoTable.sourceSchemaForSchemaEvolution(this)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index b7a8ff374b84..680fa63e0929 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -21,13 +21,14 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo,
In, Not}
import org.apache.spark.sql.catalyst.optimizer.BuildLeft
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column,
ColumnDefaultValue, InMemoryTable, TableInfo}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column,
ColumnDefaultValue, Identifier, InMemoryTable, TableInfo}
import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression,
LiteralValue}
import org.apache.spark.sql.connector.write.MergeSummary
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.MergeRowsExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec,
BroadcastNestedLoopJoinExec, CartesianProductExec}
+import org.apache.spark.sql.functions.{array, col, lit, map, struct, substring}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, BooleanType, IntegerType,
LongType, MapType, StringType, StructField, StructType}
@@ -4411,7 +4412,8 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
}
}
- test("Merge schema evolution should evolve referencing new column assigned
to something else") {
+ test("Merge schema evolution should not evolve when referencing new column" +
+ "assigned to something else") {
Seq(true, false).foreach { withSchemaEvolution =>
withTempView("source") {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
@@ -5233,6 +5235,999 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
+ test("merge with schema evolution using dataframe API: add new column and
set all") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT,
dep STRING)")
+
+ val targetData = Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("salary", IntegerType),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("salary", IntegerType),
+ Column.create("dep", StringType),
+ Column.create("new_col", IntegerType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ sql(s"INSERT INTO $sourceTable VALUES (1, 101, 'support', 1)," +
+ s"(3, 301, 'support', 3), (4, 401, 'finance', 4)")
+
+ val mergeBuilder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString,
+ $"source_table.pk" === col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .updateAll()
+ .whenNotMatched()
+ .insertAll()
+
+ if (withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "support", 1),
+ Row(2, 200, "software", null),
+ Row(3, 301, "support", 3),
+ Row(4, 401, "finance", 4)))
+ } else {
+ mergeBuilder.merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "support"),
+ Row(2, 200, "software"),
+ Row(3, 301, "support"),
+ Row(4, 401, "finance")))
+ }
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+ }
+
+ test("merge schema evolution new column with set explicit column using
dataframe API") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT,
dep STRING)")
+
+ val targetData = Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(3, 300, "hr"),
+ Row(4, 400, "marketing"),
+ Row(5, 500, "executive")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("salary", IntegerType),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("salary", IntegerType),
+ Column.create("dep", StringType),
+ Column.create("active", BooleanType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ sql(s"INSERT INTO $sourceTable VALUES (4, 150, 'dummy', true)," +
+ s"(5, 250, 'dummy', true), (6, 350, 'dummy', false)")
+
+ val mergeBuilder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString, $"source_table.pk" ===
col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .update(Map("dep" -> lit("software"), "active" ->
col("source_table.active")))
+ .whenNotMatched()
+ .insert(Map("pk" -> col("source_table.pk"), "salary" -> lit(0),
+ "dep" -> col("source_table.dep"), "active" ->
col("source_table.active")))
+
+ if (withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr", null),
+ Row(2, 200, "software", null),
+ Row(3, 300, "hr", null),
+ Row(4, 400, "software", true),
+ Row(5, 500, "software", true),
+ Row(6, 0, "dummy", false)))
+ } else {
+ val e = intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+ assert(e.getMessage.contains("A column, variable, or function
parameter with name " +
+ "`active` cannot be resolved"))
+ }
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+ }
+
+ test("merge schema evolution add column with nested struct and set explicit
columns " +
+ "using dataframe API") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING,
STRING>>>,
+ |dep STRING)""".stripMargin)
+
+ val targetData = Seq(
+ Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType))
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType) // new column
+ )))
+ ))),
+ Column.create("dep", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ val data = Seq(
+ Row(1, Row(10, Row(Array(3, 4), Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(Array(4, 5), Map("e" -> "f"), true)),
"engineering")
+ )
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source_temp")
+
+ sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+
+ val mergeBuilder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString, $"source_table.pk" ===
col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .update(Map(
+ "s.c1" -> lit(-1),
+ "s.c2.m" -> map(lit("k"), lit("v")),
+ "s.c2.a" -> array(lit(-1)),
+ "s.c2.c3" -> col("source_table.s.c2.c3")))
+ .whenNotMatched()
+ .insert(Map(
+ "pk" -> col("source_table.pk"),
+ "s" -> struct(
+ col("source_table.s.c1").as("c1"),
+ struct(
+ col("source_table.s.c2.a").as("a"),
+ map(lit("g"), lit("h")).as("m"),
+ lit(true).as("c3")
+ ).as("c2")
+ ),
+ "dep" -> col("source_table.dep")))
+
+ if (withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"),
+ Row(2, Row(20, Row(Seq(4, 5), Map("g" -> "h"), true)),
"engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(exception.errorClass.get == "FIELD_NOT_FOUND")
+ assert(exception.getMessage.contains("No such struct field `c3` in
`a`, `m`. "))
+ }
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+ }
+
+ test("merge schema evolution add column with nested struct and set all
columns " +
+ "using dataframe API") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING,
STRING>>>,
+ |dep STRING)""".stripMargin)
+
+ val targetData = Seq(
+ Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType))
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType) // new column
+ )))
+ ))),
+ Column.create("dep", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ val data = Seq(
+ Row(1, Row(10, Row(Array(3, 4), Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(Array(4, 5), Map("e" -> "f"), true)),
"engineering")
+ )
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source_temp")
+
+ sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+
+ val mergeBuilder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString, $"source_table.pk" ===
col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .updateAll()
+ .whenNotMatched()
+ .insertAll()
+
+ if (withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(1, Row(10, Row(Seq(3, 4), Map("c" -> "d"), false)),
"sales"),
+ Row(2, Row(20, Row(Seq(4, 5), Map("e" -> "f"), true)),
"engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write extra fields `c3` to the struct `s`.`c2`"))
+ }
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+ }
+
+ test("merge schema evolution replace column with nested struct and " +
+ "set explicit columns using dataframe API") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING,
STRING>>>,
+ |dep STRING)""".stripMargin)
+
+ val targetData = Seq(
+ Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType))
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ // removed column 'a'
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType) // new column
+ )))
+ ))),
+ Column.create("dep", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ val data = Seq(
+ Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering")
+ )
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source_temp")
+
+ sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+
+ val mergeBuilder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString, $"source_table.pk" ===
col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .update(Map(
+ "s.c1" -> lit(-1),
+ "s.c2.m" -> map(lit("k"), lit("v")),
+ "s.c2.a" -> array(lit(-1)),
+ "s.c2.c3" -> col("source_table.s.c2.c3")))
+ .whenNotMatched()
+ .insert(Map(
+ "pk" -> col("source_table.pk"),
+ "s" -> struct(
+ col("source_table.s.c1").as("c1"),
+ struct(
+ array(lit(-2)).as("a"),
+ map(lit("g"), lit("h")).as("m"),
+ lit(true).as("c3")
+ ).as("c2")
+ ),
+ "dep" -> col("source_table.dep")))
+
+ if (withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"),
+ Row(2, Row(20, Row(Seq(-2), Map("g" -> "h"), true)),
"engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(exception.errorClass.get == "FIELD_NOT_FOUND")
+ assert(exception.getMessage.contains("No such struct field `c3` in
`a`, `m`. "))
+ }
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+ }
+
+ test("merge schema evolution replace column with nested struct and set all
columns " +
+ "using dataframe API") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING,
STRING>>>,
+ |dep STRING)
+ |PARTITIONED BY (dep)
+ |""".stripMargin)
+
+ val tableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType))
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val targetData = Seq(
+ Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
tableSchema)
+ .coalesce(1).writeTo(tableNameAsString).append()
+
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ // missing column 'a'
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType) // new column
+ )))
+ ))),
+ Column.create("dep", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ val sourceData = Seq(
+ Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering")
+ )
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData),
sourceTableSchema)
+ .createOrReplaceTempView("source_temp")
+
+ sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+
+ val mergeBuilder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString, $"source_table.pk" ===
col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .updateAll()
+ .whenNotMatched()
+ .insertAll()
+
+ if (withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(Seq(1, 2), Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(null, Map("e" -> "f"), true)),
"engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write extra fields `c3` to the struct `s`.`c2`"))
+ }
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+ }
+
+ test("merge schema evolution replace column with nested struct and " +
+ "update top level struct using dataframe API") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING,
STRING>>>,
+ |dep STRING)
+ |PARTITIONED BY (dep)
+ |""".stripMargin)
+
+ val tableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType))
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val targetData = Seq(
+ Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
tableSchema)
+ .coalesce(1).writeTo(tableNameAsString).append()
+
+ // Create source table
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ // missing column 'a'
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType) // new column
+ )))
+ ))),
+ Column.create("dep", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ val sourceData = Seq(
+ Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering")
+ )
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData),
sourceTableSchema)
+ .createOrReplaceTempView("source_temp")
+
+ sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+
+ val mergeBuilder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString, $"source_table.pk" ===
col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .update(Map("s" -> col("source_table.s")))
+ .whenNotMatched()
+ .insertAll()
+
+ if (withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"),
+ Row(2, Row(20, Row(null, Map("e" -> "f"), true)),
"engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write extra fields `c3` to the struct `s`.`c2`"))
+ }
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+ }
+
+ test("merge schema evolution should not evolve referencing new column " +
+ "via transform using dataframe API") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT,
dep STRING)")
+
+ val targetData = Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("salary", IntegerType),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("salary", IntegerType),
+ Column.create("dep", StringType),
+ Column.create("extra", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ sql(s"INSERT INTO $sourceTable VALUES (2, 150, 'dummy', 'blah')," +
+ s"(3, 250, 'dummy', 'blah')")
+
+ val e = intercept[org.apache.spark.sql.AnalysisException] {
+ val builder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString,
+ $"source_table.pk" === col(tableNameAsString + ".pk"))
+
+ val builderWithEvolution = if (withSchemaEvolution) {
+ builder.withSchemaEvolution()
+ } else {
+ builder
+ }
+
+ builderWithEvolution
+ .whenMatched()
+ .update(Map("extra" -> substring(col("source_table.extra"), 1, 2)))
+ .merge()
+ }
+ assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+ assert(e.getMessage.contains(
+ "A column, variable, or function parameter with name " +
+ "`extra` cannot be resolved"))
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+ }
+
+ test("merge into with source missing fields in top-level struct using
dataframe API") {
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ // Target table has struct with 3 fields at top level
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRING, c3: BOOLEAN>,
+ |dep STRING)""".stripMargin)
+
+ val targetData = Seq(
+ Row(0, Row(1, "a", true), "sales")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType),
+ StructField("c3", BooleanType)
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ // Create source table with struct having only 2 fields (c1, c2) -
missing c3
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType)))), // missing c3 field
+ Column.create("dep", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ val data = Seq(
+ Row(1, Row(10, "b"), "hr"),
+ Row(2, Row(20, "c"), "engineering")
+ )
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType)))),
+ StructField("dep", StringType)))
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source_temp")
+
+ sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+
+ spark.table(sourceTable)
+ .mergeInto(tableNameAsString, $"source_table.pk" ===
col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .updateAll()
+ .whenNotMatched()
+ .insertAll()
+ .merge()
+
+ // Missing field c3 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, "a", true), "sales"),
+ Row(1, Row(10, "b", null), "hr"),
+ Row(2, Row(20, "c", null), "engineering")))
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+
+ test("merge with null struct with missing nested field using dataframe API")
{
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ // Target table has nested struct with fields c1 and c2
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>>,
+ |dep STRING)""".stripMargin)
+
+ val targetData = Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, Row(2, Row(20, "y")), "hr")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ // Create source table with missing nested field 'b'
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType)
+ // missing field 'b'
+ )))
+ ))),
+ Column.create("dep", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ // Source table has null for the nested struct
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source_temp")
+
+ sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+ val mergeBuilder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString,
+ $"source_table.pk" === col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .updateAll()
+ .whenNotMatched()
+ .insertAll()
+
+ if (coerceNestedTypes) {
+ mergeBuilder.merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
+ } else {
+ // Without coercion, the merge should fail due to missing field
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
+ assert(exception.getMessage.contains(
+ "Cannot write incompatible data for the table ``: " +
+ "Cannot find data for the output column `s`.`c2`.`b`."))
+ }
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+ }
+ }
+
+ test("merge null struct with schema evolution - " +
+ "source with missing and extra nested fields using dataframe API") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ // Target table has nested struct with fields c1 and c2
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>>,
+ |dep STRING)""".stripMargin)
+
+ val targetData = Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, Row(2, Row(20, "y")), "hr")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+
spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ // Create source table with missing field 'b' and extra field
'c' in nested struct
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType),
+ // missing field 'b'
+ StructField("c", StringType) // extra field 'c'
+ )))
+ ))),
+ Column.create("dep", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
+
+ // Source data has null for the nested struct
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("c", StringType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source_temp")
+
+ sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+
+ val mergeBuilder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString, $"source_table.pk" ===
col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .updateAll()
+ .whenNotMatched()
+ .insertAll()
+
+ if (coerceNestedTypes) {
+ if (withSchemaEvolution) {
+ // extra nested field is added
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x", null)), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
+ } else {
+ // extra nested field is not added
+ val exception =
intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write incompatible data for the table ``: " +
+ "Cannot write extra fields `c` to the struct `s`.`c2`"))
+ }
+ } else {
+ // Without source struct coercion, the merge should fail
+ val exception =
intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
+ assert(exception.getMessage.contains(
+ "Cannot write incompatible data for the table ``: " +
+ "Cannot find data for the output column `s`.`c2`.`b`."))
+ }
+
+ sql(s"DROP TABLE $tableNameAsString")
+ }
+ }
+ }
+ }
+ }
+
test("Merge schema evolution should error on non-existent column in UPDATE
and INSERT") {
withTable(tableNameAsString) {
withTempView("source") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]