This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e459674127e7 [SPARK-48683][SQL] Fix schema evolution with
`df.mergeInto` losing `when` clauses
e459674127e7 is described below
commit e459674127e7b21e2767cc62d10ea6f1f941936c
Author: Paddy Xu <[email protected]>
AuthorDate: Mon Jun 24 17:25:56 2024 +0900
[SPARK-48683][SQL] Fix schema evolution with `df.mergeInto` losing `when`
clauses
### What changes were proposed in this pull request?
This PR fixes an issue in the `DataFrame.mergeInto` API where defined
`when` clauses are lost after calling the `withSchemaEvolution()` method.
The issue is caused by not copying over existing clauses to the new writer.
### Why are the changes needed?
It fixes a bug.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New test.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47055 from xupefei/mergeinto-bugfix.
Authored-by: Paddy Xu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../org/apache/spark/sql/MergeIntoWriter.scala | 33 ++++++++++++----------
.../sql/connector/MergeIntoDataFrameSuite.scala | 28 ++++++++++++++++++
2 files changed, 46 insertions(+), 15 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
index 5020d1c88023..b7f9c96f82e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
@@ -42,7 +42,7 @@ class MergeIntoWriter[T] private[sql] (
table: String,
ds: Dataset[T],
on: Column,
- schemaEvolutionEnabled: Boolean = false) {
+ private[sql] val schemaEvolutionEnabled: Boolean = false) {
private val df: DataFrame = ds.toDF()
@@ -172,6 +172,9 @@ class MergeIntoWriter[T] private[sql] (
*/
def withSchemaEvolution(): MergeIntoWriter[T] = {
new MergeIntoWriter[T](this.table, this.ds, this.on,
schemaEvolutionEnabled = true)
+ .withNewMatchedActions(this.matchedActions: _*)
+ .withNewNotMatchedActions(this.notMatchedActions: _*)
+ .withNewNotMatchedBySourceActions(this.notMatchedBySourceActions: _*)
}
/**
@@ -196,18 +199,18 @@ class MergeIntoWriter[T] private[sql] (
qe.assertCommandExecuted()
}
- private[sql] def withNewMatchedAction(action: MergeAction):
MergeIntoWriter[T] = {
- this.matchedActions = this.matchedActions :+ action
+ private[sql] def withNewMatchedActions(actions: MergeAction*):
MergeIntoWriter[T] = {
+ this.matchedActions ++= actions
this
}
- private[sql] def withNewNotMatchedAction(action: MergeAction):
MergeIntoWriter[T] = {
- this.notMatchedActions = this.notMatchedActions :+ action
+ private[sql] def withNewNotMatchedActions(actions: MergeAction*):
MergeIntoWriter[T] = {
+ this.notMatchedActions ++= actions
this
}
- private[sql] def withNewNotMatchedBySourceAction(action: MergeAction):
MergeIntoWriter[T] = {
- this.notMatchedBySourceActions = this.notMatchedBySourceActions :+ action
+ private[sql] def withNewNotMatchedBySourceActions(actions: MergeAction*):
MergeIntoWriter[T] = {
+ this.notMatchedBySourceActions ++= actions
this
}
}
@@ -234,7 +237,7 @@ case class WhenMatched[T] private[sql](
* @return The MergeIntoWriter instance with the update all action
configured.
*/
def updateAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedAction(UpdateStarAction(condition))
+ mergeIntoWriter.withNewMatchedActions(UpdateStarAction(condition))
}
/**
@@ -245,7 +248,7 @@ case class WhenMatched[T] private[sql](
* @return The MergeIntoWriter instance with the update action configured.
*/
def update(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedAction(
+ mergeIntoWriter.withNewMatchedActions(
UpdateAction(condition, map.map(x => Assignment(expr(x._1).expr,
x._2.expr)).toSeq))
}
@@ -255,7 +258,7 @@ case class WhenMatched[T] private[sql](
* @return The MergeIntoWriter instance with the delete action configured.
*/
def delete(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedAction(DeleteAction(condition))
+ mergeIntoWriter.withNewMatchedActions(DeleteAction(condition))
}
}
@@ -281,7 +284,7 @@ case class WhenNotMatched[T] private[sql](
* @return The MergeIntoWriter instance with the insert all action
configured.
*/
def insertAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedAction(InsertStarAction(condition))
+ mergeIntoWriter.withNewNotMatchedActions(InsertStarAction(condition))
}
/**
@@ -292,7 +295,7 @@ case class WhenNotMatched[T] private[sql](
* @return The MergeIntoWriter instance with the insert action configured.
*/
def insert(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedAction(
+ mergeIntoWriter.withNewNotMatchedActions(
InsertAction(condition, map.map(x => Assignment(expr(x._1).expr,
x._2.expr)).toSeq))
}
}
@@ -317,7 +320,7 @@ case class WhenNotMatchedBySource[T] private[sql](
* @return The MergeIntoWriter instance with the update all action
configured.
*/
def updateAll(): MergeIntoWriter[T] = {
-
mergeIntoWriter.withNewNotMatchedBySourceAction(UpdateStarAction(condition))
+
mergeIntoWriter.withNewNotMatchedBySourceActions(UpdateStarAction(condition))
}
/**
@@ -328,7 +331,7 @@ case class WhenNotMatchedBySource[T] private[sql](
* @return The MergeIntoWriter instance with the update action configured.
*/
def update(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceAction(
+ mergeIntoWriter.withNewNotMatchedBySourceActions(
UpdateAction(condition, map.map(x => Assignment(expr(x._1).expr,
x._2.expr)).toSeq))
}
@@ -339,6 +342,6 @@ case class WhenNotMatchedBySource[T] private[sql](
* @return The MergeIntoWriter instance with the delete action configured.
*/
def delete(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceAction(DeleteAction(condition))
+ mergeIntoWriter.withNewNotMatchedBySourceActions(DeleteAction(condition))
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
index ed44111c81d2..c080a66bce25 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
@@ -943,4 +943,32 @@ class MergeIntoDataFrameSuite extends
RowLevelOperationSuiteBase {
Row(3, Row("y1 ", "y2"), "hr"))) // update (not matched by source)
}
}
+
+ test("withSchemaEvolution carries over existing when clauses") {
+ withTempView("source") {
+ Seq(1, 2, 4).toDF("pk").createOrReplaceTempView("source")
+
+ // an arbitrary merge
+ val writer1 = spark.table("source")
+ .mergeInto("dummy", $"col" === $"col")
+ .whenMatched(col("col") === 1)
+ .updateAll()
+ .whenMatched()
+ .delete()
+ .whenNotMatched(col("col") === 1)
+ .insertAll()
+ .whenNotMatchedBySource(col("col") === 1)
+ .delete()
+ val writer2 = writer1.withSchemaEvolution()
+
+ assert(writer1.matchedActions.length === 2)
+ assert(writer1.notMatchedActions.length === 1)
+ assert(writer1.notMatchedBySourceActions.length === 1)
+
+ assert(writer1.matchedActions === writer2.matchedActions)
+ assert(writer1.notMatchedActions === writer2.notMatchedActions)
+ assert(writer1.notMatchedBySourceActions ===
writer2.notMatchedBySourceActions)
+ assert(writer2.schemaEvolutionEnabled)
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]