This is an automated email from the ASF dual-hosted git repository.

biyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new a128251d4 [spark] Fix resolve merge into with alias (#4189)
a128251d4 is described below

commit a128251d4d3380c7ec4a9c81071a2a407385a79e
Author: Zouxxyy <[email protected]>
AuthorDate: Mon Sep 23 10:01:56 2024 +0800

    [spark] Fix resolve merge into with alias (#4189)
---
 .../analysis/PaimonMergeIntoResolver.scala         |   2 -
 .../analysis/PaimonMergeIntoResolver.scala         |   2 -
 .../analysis/PaimonMergeIntoResolver.scala         |   2 -
 .../analysis/PaimonMergeIntoResolver.scala         |  40 ++----
 .../analysis/PaimonMergeIntoResolverBase.scala     | 154 +++++++++++++--------
 .../sql/MergeIntoNotMatchedBySourceTest.scala      |  41 ++++++
 .../paimon/spark/sql/MergeIntoTableTestBase.scala  |  38 +++++
 7 files changed, 187 insertions(+), 92 deletions(-)

diff --git a/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala b/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
index 031ea0a18..e0869a608 100644
--- a/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
+++ b/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
@@ -25,8 +25,6 @@ object PaimonMergeIntoResolver extends PaimonMergeIntoResolverBase {
 
   def resolveNotMatchedBySourceActions(
       merge: MergeIntoTable,
-      target: LogicalPlan,
-      source: LogicalPlan,
       resolve: (Expression, LogicalPlan) => Expression): Seq[MergeAction] = {
     Seq.empty
   }
diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
index 031ea0a18..e0869a608 100644
--- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
+++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
@@ -25,8 +25,6 @@ object PaimonMergeIntoResolver extends PaimonMergeIntoResolverBase {
 
   def resolveNotMatchedBySourceActions(
       merge: MergeIntoTable,
-      target: LogicalPlan,
-      source: LogicalPlan,
       resolve: (Expression, LogicalPlan) => Expression): Seq[MergeAction] = {
     Seq.empty
   }
diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
index 031ea0a18..e0869a608 100644
--- a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
+++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
@@ -25,8 +25,6 @@ object PaimonMergeIntoResolver extends PaimonMergeIntoResolverBase {
 
   def resolveNotMatchedBySourceActions(
       merge: MergeIntoTable,
-      target: LogicalPlan,
-      source: LogicalPlan,
       resolve: (Expression, LogicalPlan) => Expression): Seq[MergeAction] = {
     Seq.empty
   }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
index bb66e3826..4525393bd 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolver.scala
@@ -27,36 +27,20 @@ object PaimonMergeIntoResolver extends PaimonMergeIntoResolverBase {
 
   def resolveNotMatchedBySourceActions(
       merge: MergeIntoTable,
-      target: LogicalPlan,
-      source: LogicalPlan,
       resolve: (Expression, LogicalPlan) => Expression): Seq[MergeAction] = {
-    val fakeSource = Project(source.output, source)
-
-    def resolveMergeAction(action: MergeAction): MergeAction = {
-      action match {
-        case DeleteAction(condition) =>
-          val resolvedCond = condition.map(resolve(_, target))
-          DeleteAction(resolvedCond)
-        case UpdateAction(condition, assignments) =>
-          val resolvedCond = condition.map(resolve(_, target))
-          val resolvedAssignments = assignments.map {
-            assignment =>
-              assignment.copy(
-                key = resolve(assignment.key, target),
-                value = resolve(assignment.value, target))
-          }
-          UpdateAction(resolvedCond, resolvedAssignments)
-        case UpdateStarAction(condition) =>
-          val resolvedCond = condition.map(resolve(_, target))
-          val resolvedAssignments = target.output.map {
-            attr =>
-              Assignment(attr, resolve(UnresolvedAttribute.quotedString(attr.name), fakeSource))
-          }
-          UpdateAction(resolvedCond, resolvedAssignments)
-      }
+    merge.notMatchedBySourceActions.map {
+      case DeleteAction(condition) =>
+        // The condition must be from the target table
+        val resolvedCond = condition.map(resolveCondition(resolve, _, merge, TARGET_ONLY))
+        DeleteAction(resolvedCond)
+      case UpdateAction(condition, assignments) =>
+        // The condition and value must be from the target table
+        val resolvedCond = condition.map(resolveCondition(resolve, _, merge, TARGET_ONLY))
+        val resolvedAssignments = resolveAssignments(resolve, assignments, merge, TARGET_ONLY)
+        UpdateAction(resolvedCond, resolvedAssignments)
+      case action =>
+        throw new RuntimeException(s"Can't recognize this action: $action")
     }
-
-    merge.notMatchedBySourceActions.map(resolveMergeAction)
   }
 
   def build(
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolverBase.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolverBase.scala
index cbd6b52c0..218fc9c0f 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolverBase.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoResolverBase.scala
@@ -23,7 +23,7 @@ import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, InsertStarAction, LogicalPlan, MergeAction, MergeIntoTable, Project, UpdateAction, UpdateStarAction}
+import org.apache.spark.sql.catalyst.plans.logical._
 
 trait PaimonMergeIntoResolverBase extends ExpressionHelper {
 
@@ -33,73 +33,111 @@ trait PaimonMergeIntoResolverBase extends ExpressionHelper {
     assert(target.resolved, "Target should have been resolved here.")
     assert(source.resolved, "Source should have been resolved here.")
 
-    val condition = merge.mergeCondition
-    val matched = merge.matchedActions
-    val notMatched = merge.notMatchedActions
-    val fakeSource = Project(source.output, source)
-
     val resolve: (Expression, LogicalPlan) => Expression = resolveExpression(spark)
 
-    def resolveMergeAction(action: MergeAction): MergeAction = {
-      action match {
-        case DeleteAction(condition) =>
-          val resolvedCond = condition.map(resolve(_, merge))
-          DeleteAction(resolvedCond)
-        case UpdateAction(condition, assignments) =>
-          val resolvedCond = condition.map(resolve(_, merge))
-          val resolvedAssignments = assignments.map {
-            assignment =>
-              assignment.copy(
-                key = resolve(assignment.key, target),
-                value = resolve(assignment.value, merge))
-          }
-          UpdateAction(resolvedCond, resolvedAssignments)
-        case UpdateStarAction(condition) =>
-          val resolvedCond = condition.map(resolve(_, merge))
-          val resolvedAssignments = target.output.map {
-            attr => Assignment(attr, resolve(UnresolvedAttribute.quotedString(attr.name), source))
-          }
-          UpdateAction(resolvedCond, resolvedAssignments)
-        case InsertAction(condition, assignments) =>
-          val resolvedCond = condition.map(resolve(_, fakeSource))
-          val resolvedAssignments = assignments.map {
-            assignment =>
-              assignment.copy(
-                key = resolve(assignment.key, fakeSource),
-                value = resolve(assignment.value, fakeSource))
-          }
-          InsertAction(resolvedCond, resolvedAssignments)
-        case InsertStarAction(condition) =>
-          val resolvedCond = condition.map(resolve(_, fakeSource))
-          val resolvedAssignments = target.output.map {
-            attr =>
-              Assignment(attr, resolve(UnresolvedAttribute.quotedString(attr.name), fakeSource))
-          }
-          InsertAction(resolvedCond, resolvedAssignments)
-        case _ =>
-          throw new RuntimeException(s"Can't recognize this action: $action")
-      }
-    }
-
-    val resolvedCond = resolve(condition, merge)
-    val resolvedMatched: Seq[MergeAction] = matched.map(resolveMergeAction)
-    val resolvedNotMatched: Seq[MergeAction] = notMatched.map(resolveMergeAction)
-    val resolvedNotMatchedBySource: Seq[MergeAction] =
-      resolveNotMatchedBySourceActions(merge, target, source, resolve)
+    val resolvedCond = resolveCondition(resolve, merge.mergeCondition, merge, ALL)
+    val resolvedMatched = resolveMatchedByTargetActions(merge, resolve)
+    val resolvedNotMatched = resolveNotMatchedByTargetActions(merge, resolve)
+    val resolvedNotMatchedBySource = resolveNotMatchedBySourceActions(merge, resolve)
 
     build(merge, resolvedCond, resolvedMatched, resolvedNotMatched, resolvedNotMatchedBySource)
   }
 
-  def resolveNotMatchedBySourceActions(
-      merge: MergeIntoTable,
-      target: LogicalPlan,
-      source: LogicalPlan,
-      resolve: (Expression, LogicalPlan) => Expression): Seq[MergeAction]
-
   def build(
       merge: MergeIntoTable,
       resolvedCond: Expression,
       resolvedMatched: Seq[MergeAction],
       resolvedNotMatched: Seq[MergeAction],
       resolvedNotMatchedBySource: Seq[MergeAction]): MergeIntoTable
+
+  private def resolveMatchedByTargetActions(
+      merge: MergeIntoTable,
+      resolve: (Expression, LogicalPlan) => Expression): Seq[MergeAction] = {
+    merge.matchedActions.map {
+      case DeleteAction(condition) =>
+        // The condition can be from both target and source tables
+        val resolvedCond = condition.map(resolveCondition(resolve, _, merge, ALL))
+        DeleteAction(resolvedCond)
+      case UpdateAction(condition, assignments) =>
+        // The condition and value can be from both target and source tables
+        val resolvedCond = condition.map(resolveCondition(resolve, _, merge, ALL))
+        val resolvedAssignments = resolveAssignments(resolve, assignments, merge, ALL)
+        UpdateAction(resolvedCond, resolvedAssignments)
+      case UpdateStarAction(condition) =>
+        // The condition can be from both target and source tables, but the value must be from the source table
+        val resolvedCond = condition.map(resolveCondition(resolve, _, merge, ALL))
+        val assignments = merge.targetTable.output.map {
+          attr => Assignment(attr, UnresolvedAttribute(Seq(attr.name)))
+        }
+        val resolvedAssignments =
+          resolveAssignments(resolve, assignments, merge, SOURCE_ONLY)
+        UpdateAction(resolvedCond, resolvedAssignments)
+      case action =>
+        throw new RuntimeException(s"Can't recognize this action: $action")
+    }
+  }
+
+  private def resolveNotMatchedByTargetActions(
+      merge: MergeIntoTable,
+      resolve: (Expression, LogicalPlan) => Expression): Seq[MergeAction] = {
+    merge.notMatchedActions.map {
+      case InsertAction(condition, assignments) =>
+        // The condition and value must be from the source table
+        val resolvedCond =
+          condition.map(resolveCondition(resolve, _, merge, SOURCE_ONLY))
+        val resolvedAssignments =
+          resolveAssignments(resolve, assignments, merge, SOURCE_ONLY)
+        InsertAction(resolvedCond, resolvedAssignments)
+      case InsertStarAction(condition) =>
+        // The condition and value must be from the source table
+        val resolvedCond =
+          condition.map(resolveCondition(resolve, _, merge, SOURCE_ONLY))
+        val assignments = merge.targetTable.output.map {
+          attr => Assignment(attr, UnresolvedAttribute(Seq(attr.name)))
+        }
+        val resolvedAssignments =
+          resolveAssignments(resolve, assignments, merge, SOURCE_ONLY)
+        InsertAction(resolvedCond, resolvedAssignments)
+      case action =>
+        throw new RuntimeException(s"Can't recognize this action: $action")
+    }
+  }
+
+  def resolveNotMatchedBySourceActions(
+      merge: MergeIntoTable,
+      resolve: (Expression, LogicalPlan) => Expression): Seq[MergeAction]
+
+  sealed trait ResolvedWith
+  case object ALL extends ResolvedWith
+  case object SOURCE_ONLY extends ResolvedWith
+  case object TARGET_ONLY extends ResolvedWith
+
+  def resolveCondition(
+      resolve: (Expression, LogicalPlan) => Expression,
+      condition: Expression,
+      mergeInto: MergeIntoTable,
+      resolvedWith: ResolvedWith): Expression = {
+    resolvedWith match {
+      case ALL => resolve(condition, mergeInto)
+      case SOURCE_ONLY => resolve(condition, Project(Nil, mergeInto.sourceTable))
+      case TARGET_ONLY => resolve(condition, Project(Nil, mergeInto.targetTable))
+    }
+  }
+
+  def resolveAssignments(
+      resolve: (Expression, LogicalPlan) => Expression,
+      assignments: Seq[Assignment],
+      mergeInto: MergeIntoTable,
+      resolvedWith: ResolvedWith): Seq[Assignment] = {
+    assignments.map {
+      assign =>
+        val resolvedKey = resolve(assign.key, Project(Nil, mergeInto.targetTable))
+        val resolvedValue = resolvedWith match {
+          case ALL => resolve(assign.value, mergeInto)
+          case SOURCE_ONLY => resolve(assign.value, Project(Nil, mergeInto.sourceTable))
+          case TARGET_ONLY => resolve(assign.value, Project(Nil, mergeInto.targetTable))
+        }
+        Assignment(resolvedKey, resolvedValue)
+    }
+  }
 }
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoNotMatchedBySourceTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoNotMatchedBySourceTest.scala
index 6a2918165..2c46af0ae 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoNotMatchedBySourceTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoNotMatchedBySourceTest.scala
@@ -142,4 +142,45 @@ trait MergeIntoNotMatchedBySourceTest extends PaimonSparkTestBase with PaimonTab
       )
     }
   }
+
+  test(s"Paimon MergeInto: multiple clauses with not matched by source with alias") {
+    withTable("source", "target") {
+
+      Seq((1, 100, "c11"), (3, 300, "c33"), (5, 500, "c55"), (7, 700, "c77"), (9, 900, "c99"))
+        .toDF("a", "b", "c")
+        .createOrReplaceTempView("source")
+
+      createTable("target", "a INT, b INT, c STRING", Seq("a"))
+      spark.sql(
+        "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
+
+      spark.sql(s"""
+                   |MERGE INTO target t
+                   |USING source s
+                   |ON t.a = s.a
+                   |WHEN MATCHED AND t.a = 5 THEN
+                   |UPDATE SET t.b = s.b + t.b
+                   |WHEN MATCHED AND s.c > 'c2' THEN
+                   |UPDATE SET *
+                   |WHEN MATCHED THEN
+                   |DELETE
+                   |WHEN NOT MATCHED AND s.c > 'c9' THEN
+                   |INSERT (t.a, t.b, t.c) VALUES (s.a, s.b * 1.1, s.c)
+                   |WHEN NOT MATCHED THEN
+                   |INSERT *
+                   |WHEN NOT MATCHED BY SOURCE AND t.a = 2 THEN
+                   |UPDATE SET t.b = t.b * 10
+                   |WHEN NOT MATCHED BY SOURCE THEN
+                   |DELETE
+                   |""".stripMargin)
+
+      checkAnswer(
+        spark.sql("SELECT * FROM target ORDER BY a, b"),
+        Row(2, 200, "c2") :: Row(3, 300, "c33") :: Row(5, 550, "c5") :: Row(7, 700, "c77") :: Row(
+          9,
+          990,
+          "c99") :: Nil
+      )
+    }
+  }
 }
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala
index 1a4eae51d..8973ea93d 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala
@@ -497,6 +497,7 @@ abstract class MergeIntoTableTestBase extends PaimonSparkTestBase with PaimonTab
         Row(1, 10, Row("x1", "y")) :: Row(2, 20, Row("x", "y")) :: Nil)
     }
   }
+
   test(s"Paimon MergeInto: update on source eq target condition") {
     withTable("source", "target") {
       Seq((1, 100, "c11"), (3, 300, "c33")).toDF("a", "b", "c").createOrReplaceTempView("source")
@@ -517,6 +518,43 @@ abstract class MergeIntoTableTestBase extends PaimonSparkTestBase with PaimonTab
         Row(1, 100, "c11") :: Row(2, 20, "c2") :: Nil)
     }
   }
+
+  test(s"Paimon MergeInto: merge into with alias") {
+    withTable("source", "target") {
+
+      Seq((1, 100, "c11"), (3, 300, "c33"), (5, 500, "c55"), (7, 700, "c77"), (9, 900, "c99"))
+        .toDF("a", "b", "c")
+        .createOrReplaceTempView("source")
+
+      createTable("target", "a INT, b INT, c STRING", Seq("a"))
+      spark.sql(
+        "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
+
+      spark.sql(s"""
+                   |MERGE INTO target t
+                   |USING source s
+                   |ON t.a = s.a
+                   |WHEN MATCHED AND t.a = 5 THEN
+                   |UPDATE SET t.b = s.b + t.b
+                   |WHEN MATCHED AND s.c > 'c2' THEN
+                   |UPDATE SET *
+                   |WHEN MATCHED THEN
+                   |DELETE
+                   |WHEN NOT MATCHED AND s.c > 'c9' THEN
+                   |INSERT (t.a, t.b, t.c) VALUES (s.a, s.b * 1.1, s.c)
+                   |WHEN NOT MATCHED THEN
+                   |INSERT *
+                   |""".stripMargin)
+
+      checkAnswer(
+        spark.sql("SELECT * FROM target ORDER BY a, b"),
+        Row(2, 20, "c2") :: Row(3, 300, "c33") :: Row(4, 40, "c4") :: Row(5, 550, "c5") :: Row(
+          7,
+          700,
+          "c77") :: Row(9, 990, "c99") :: Nil
+      )
+    }
+  }
 }
 
 trait MergeIntoPrimaryKeyTableTest extends PaimonSparkTestBase with PaimonPrimaryKeyTable {

Reply via email to