This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 4584885d60 [spark] Optimize MERGE INTO self-merge updates on dataEvolution table (#6827)
4584885d60 is described below

commit 4584885d60775d4e229e6471ce88ca47b0e4b1c9
Author: Weitai Li <[email protected]>
AuthorDate: Fri Dec 19 17:55:34 2025 +0800

    [spark] Optimize MERGE INTO self-merge updates on dataEvolution table (#6827)
---
 .../MergeIntoPaimonDataEvolutionTable.scala        | 196 ++++++++++++++++-----
 .../paimon/spark/sql/RowTrackingTestBase.scala     |  57 +++++-
 2 files changed, 207 insertions(+), 46 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala
index e3c2d3ead0..e2eaed8fe5 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala
@@ -32,7 +32,7 @@ import org.apache.paimon.table.source.DataSplit
 import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.sql.PaimonUtils._
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.resolver
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, EqualTo, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, EqualTo, Expression, ExprId, Literal}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftOuter}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -114,6 +114,39 @@ case class MergeIntoPaimonDataEvolutionTable(
     map
   }
 
+  /**
+   * Self-Merge pattern:
+   * {{{
+   * MERGE INTO T AS t
+   * USING T AS s
+   * ON t._ROW_ID = s._ROW_ID
+   * WHEN MATCHED THEN UPDATE ... SET ...
+   * }}}
+   * For this pattern, the execution can be optimized to:
+   *
+   * `Scan -> MergeRows -> Write`
+   *
+   * without any extra shuffle, join, or sort.
+   */
+  private lazy val isSelfMergeOnRowId: Boolean = {
+    if (!targetRelation.name.equals(sourceRelation.name)) {
+      false
+    } else {
+      matchedCondition match {
+        case EqualTo(left: AttributeReference, right: AttributeReference)
+            if left.name == ROW_ID_NAME && right.name == ROW_ID_NAME =>
+          true
+        case _ => false
+      }
+    }
+  }
+
+  assert(
+    !(isSelfMergeOnRowId && (notMatchedActions.nonEmpty || notMatchedBySourceActions.nonEmpty)),
+    "Self-Merge on _ROW_ID only supports WHEN MATCHED THEN UPDATE. WHEN NOT 
MATCHED and WHEN " +
+      "NOT MATCHED BY SOURCE are not supported."
+  )
+
   lazy val targetRelation: DataSourceV2Relation = PaimonRelation.getPaimonRelation(targetTable)
   lazy val sourceRelation: DataSourceV2Relation = PaimonRelation.getPaimonRelation(sourceTable)
 
@@ -148,6 +181,18 @@ case class MergeIntoPaimonDataEvolutionTable(
   }
 
   private def targetRelatedSplits(sparkSession: SparkSession): Seq[DataSplit] = {
+    // Self-Merge shortcut:
+    // In Self-Merge mode, every row in the table may be updated, so we scan all splits.
+    if (isSelfMergeOnRowId) {
+      return table
+        .newSnapshotReader()
+        .read()
+        .splits()
+        .asScala
+        .map(_.asInstanceOf[DataSplit])
+        .toSeq
+    }
+
     val sourceDss = createDataset(sparkSession, sourceRelation)
 
     val firstRowIdsTouched = extractSourceRowIdMapping match {
@@ -226,52 +271,112 @@ case class MergeIntoPaimonDataEvolutionTable(
       allFields ++= action.references.flatMap(r => extractFields(r)).seq
     }
 
-    val allReadFieldsOnTarget = allFields.filter(
-      field =>
-        targetTable.output.exists(attr => attr.exprId.equals(field.exprId))) ++ metadataColumns
-    val allReadFieldsOnSource =
-      allFields.filter(field => sourceTable.output.exists(attr => attr.exprId.equals(field.exprId)))
-
-    val targetReadPlan =
-      touchedFileTargetRelation.copy(targetRelation.table, allReadFieldsOnTarget.toSeq)
-    val targetTableProjExprs = targetReadPlan.output :+ Alias(TrueLiteral, ROW_FROM_TARGET)()
-    val targetTableProj = Project(targetTableProjExprs, targetReadPlan)
+    val toWrite = if (isSelfMergeOnRowId) {
+      // Self-Merge shortcut:
+      // - Scan the target table only (no source scan, no join), and read all columns required by
+      //   merge condition and update expressions.
+      // - Rewrite all source-side AttributeReferences to the corresponding target attributes.
+      // - The scan output already satisfies the required partitioning and ordering for partial
+      //   updates, so no extra shuffle or sort is needed.
+
+      val targetAttrsDedup: Seq[AttributeReference] =
+        (targetRelation.output ++ targetRelation.metadataOutput)
+          .groupBy(_.exprId)
+          .map { case (_, attrs) => attrs.head }
+          .toSeq
+
+      val neededNames: Set[String] = (allFields ++ metadataColumns).map(_.name).toSet
+      val allReadFieldsOnTarget: Seq[AttributeReference] =
+        targetAttrsDedup.filter(a => neededNames.exists(n => resolver(n, a.name)))
+      val readPlan = touchedFileTargetRelation.copy(output = allReadFieldsOnTarget)
+
+      // Build mapping: source exprId -> target attr (matched by column name).
+      val sourceToTarget = {
+        val targetAttrs = targetRelation.output ++ targetRelation.metadataOutput
+        val sourceAttrs = sourceRelation.output ++ sourceRelation.metadataOutput
+        sourceAttrs.flatMap {
+          s => targetAttrs.find(t => resolver(t.name, s.name)).map(t => s.exprId -> t)
+        }.toMap
+      }
 
-    val sourceReadPlan = sourceRelation.copy(sourceRelation.table, allReadFieldsOnSource.toSeq)
-    val sourceTableProjExprs = sourceReadPlan.output :+ Alias(TrueLiteral, ROW_FROM_SOURCE)()
-    val sourceTableProj = Project(sourceTableProjExprs, sourceReadPlan)
+      def rewriteSourceToTarget(
+          expr: Expression,
+          m: Map[ExprId, AttributeReference]): Expression = {
+        expr.transform {
+          case a: AttributeReference if m.contains(a.exprId) => m(a.exprId)
+        }
+      }
 
-    val joinPlan =
-      Join(targetTableProj, sourceTableProj, LeftOuter, Some(matchedCondition), JoinHint.NONE)
+      val rewrittenUpdateActions: Seq[UpdateAction] = realUpdateActions.map {
+        ua =>
+          val newCond = ua.condition.map(c => rewriteSourceToTarget(c, sourceToTarget))
+          val newAssignments = ua.assignments.map {
+            a => Assignment(a.key, rewriteSourceToTarget(a.value, sourceToTarget))
+          }
+          ua.copy(condition = newCond, assignments = newAssignments)
+      }
 
-    val rowFromSourceAttr = attribute(ROW_FROM_SOURCE, joinPlan)
-    val rowFromTargetAttr = attribute(ROW_FROM_TARGET, joinPlan)
+      val mergeRows = MergeRows(
+        isSourceRowPresent = TrueLiteral,
+        isTargetRowPresent = TrueLiteral,
+        matchedInstructions = rewrittenUpdateActions
+          .map(
+            action => {
+              Keep(action.condition.getOrElse(TrueLiteral), action.assignments.map(a => a.value))
+            }) ++ Seq(Keep(TrueLiteral, output)),
+        notMatchedInstructions = Nil,
+        notMatchedBySourceInstructions = Seq(Keep(TrueLiteral, output)),
+        checkCardinality = false,
+        output = output,
+        child = readPlan
+      )
 
-    val mergeRows = MergeRows(
-      isSourceRowPresent = rowFromSourceAttr,
-      isTargetRowPresent = rowFromTargetAttr,
-      matchedInstructions = realUpdateActions
-        .map(
-          action => {
-            Keep(action.condition.getOrElse(TrueLiteral), action.assignments.map(a => a.value))
-          }) ++ Seq(Keep(TrueLiteral, output)),
-      notMatchedInstructions = Nil,
-      notMatchedBySourceInstructions = Seq(Keep(TrueLiteral, output)).toSeq,
-      checkCardinality = false,
-      output = output,
-      child = joinPlan
-    )
+      val withFirstRowId = addFirstRowId(sparkSession, mergeRows)
+      assert(withFirstRowId.schema.fields.length == updateColumnsSorted.size + 2)
+      withFirstRowId
+    } else {
+      val allReadFieldsOnTarget = allFields.filter(
+        field =>
+          targetTable.output.exists(attr => attr.exprId.equals(field.exprId))) ++ metadataColumns
+      val allReadFieldsOnSource =
+        allFields.filter(
+          field => sourceTable.output.exists(attr => attr.exprId.equals(field.exprId)))
+
+      val targetReadPlan =
+        touchedFileTargetRelation.copy(output = allReadFieldsOnTarget.toSeq)
+      val targetTableProjExprs = targetReadPlan.output :+ Alias(TrueLiteral, ROW_FROM_TARGET)()
+      val targetTableProj = Project(targetTableProjExprs, targetReadPlan)
+
+      val sourceReadPlan = sourceRelation.copy(output = allReadFieldsOnSource.toSeq)
+      val sourceTableProjExprs = sourceReadPlan.output :+ Alias(TrueLiteral, ROW_FROM_SOURCE)()
+      val sourceTableProj = Project(sourceTableProjExprs, sourceReadPlan)
+
+      val joinPlan =
+        Join(targetTableProj, sourceTableProj, LeftOuter, Some(matchedCondition), JoinHint.NONE)
+      val rowFromSourceAttr = attribute(ROW_FROM_SOURCE, joinPlan)
+      val rowFromTargetAttr = attribute(ROW_FROM_TARGET, joinPlan)
+      val mergeRows = MergeRows(
+        isSourceRowPresent = rowFromSourceAttr,
+        isTargetRowPresent = rowFromTargetAttr,
+        matchedInstructions = realUpdateActions
+          .map(
+            action => {
+              Keep(action.condition.getOrElse(TrueLiteral), action.assignments.map(a => a.value))
+            }) ++ Seq(Keep(TrueLiteral, output)),
+        notMatchedInstructions = Nil,
+        notMatchedBySourceInstructions = Seq(Keep(TrueLiteral, output)).toSeq,
+        checkCardinality = false,
+        output = output,
+        child = joinPlan
+      )
+      val withFirstRowId = addFirstRowId(sparkSession, mergeRows)
+      assert(withFirstRowId.schema.fields.length == updateColumnsSorted.size + 2)
+      withFirstRowId
+        .repartitionByRange(col(FIRST_ROW_ID_NAME))
+        .sortWithinPartitions(FIRST_ROW_ID_NAME, ROW_ID_NAME)
+    }
 
-    val firstRowIdsFinal = firstRowIds
-    val firstRowIdUdf = udf((rowId: Long) => floorBinarySearch(firstRowIdsFinal, rowId))
-    val firstRowIdColumn = firstRowIdUdf(col(ROW_ID_NAME))
-    val toWrite =
-      createDataset(sparkSession, mergeRows).withColumn(FIRST_ROW_ID_NAME, firstRowIdColumn)
-    assert(toWrite.schema.fields.length == updateColumnsSorted.size + 2)
-    val sortedDs = toWrite
-      .repartitionByRange(firstRowIdColumn)
-      .sortWithinPartitions(FIRST_ROW_ID_NAME, ROW_ID_NAME)
-    partialColumnWriter.writePartialFields(sortedDs, updateColumnsSorted.map(_.name))
+    partialColumnWriter.writePartialFields(toWrite, updateColumnsSorted.map(_.name))
   }
 
   private def insertActionInvoke(
@@ -394,6 +499,13 @@ case class MergeIntoPaimonDataEvolutionTable(
   private def attribute(name: String, plan: LogicalPlan) =
     plan.output.find(attr => resolver(name, attr.name)).get
 
+  private def addFirstRowId(sparkSession: SparkSession, plan: LogicalPlan): Dataset[Row] = {
+    assert(plan.output.exists(_.name.equals(ROW_ID_NAME)))
+    val firstRowIdsFinal = firstRowIds
+    val firstRowIdUdf = udf((rowId: Long) => floorBinarySearch(firstRowIdsFinal, rowId))
+    val firstRowIdColumn = firstRowIdUdf(col(ROW_ID_NAME))
+    createDataset(sparkSession, plan).withColumn(FIRST_ROW_ID_NAME, firstRowIdColumn)
+  }
 }
 
 object MergeIntoPaimonDataEvolutionTable {
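
For readers following the self-merge branch above: the key step is rewriting source-side attributes onto the target relation by ExprId before the update expressions reach MergeRows. Below is a minimal standalone sketch of that technique using only Spark Catalyst classes; the concrete attributes and the Add expression are illustrative and not taken from Paimon:

    import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, ExprId, Expression}
    import org.apache.spark.sql.types.IntegerType

    // Replace every attribute whose ExprId appears in the mapping with its
    // target-side counterpart; all other sub-expressions are left untouched.
    def rewriteSourceToTarget(
        expr: Expression,
        mapping: Map[ExprId, AttributeReference]): Expression =
      expr.transform {
        case a: AttributeReference if mapping.contains(a.exprId) => mapping(a.exprId)
      }

    // Illustrative usage: `source.b + target.b` becomes `target.b + target.b`
    // once source.b's ExprId is mapped onto the target attribute.
    val targetB = AttributeReference("b", IntegerType)()
    val sourceB = AttributeReference("b", IntegerType)()
    val rewritten = rewriteSourceToTarget(Add(sourceB, targetB), Map(sourceB.exprId -> targetB))

Because both sides of a self-merge resolve to the same table, this mapping lets every update expression be evaluated against a single scan of the target, which is what removes the join, shuffle, and sort from the plan.
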
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala
index 611a936b6f..4d819cd7b6 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala
@@ -22,10 +22,8 @@ import org.apache.paimon.Snapshot.CommitKind
 import org.apache.paimon.spark.PaimonSparkTestBase
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
-import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
-import org.apache.spark.sql.execution.joins.BaseJoinExec
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, RepartitionByExpression, Sort}
+import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 
 import scala.collection.mutable
@@ -436,6 +434,57 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
     }
   }
 
+  test("Data Evolution: merge into table with data-evolution for Self-Merge 
with _ROW_ID shortcut") {
+    withTable("target") {
+      sql(
+        "CREATE TABLE target (a INT, b INT, c STRING) TBLPROPERTIES 
('row-tracking.enabled' = 'true', 'data-evolution.enabled' = 'true')")
+      sql(
+        "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 
'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
+
+      val capturedPlans: mutable.ListBuffer[LogicalPlan] = mutable.ListBuffer.empty
+      val listener = new QueryExecutionListener {
+        override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+          capturedPlans += qe.analyzed
+        }
+        override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+          capturedPlans += qe.analyzed
+        }
+      }
+      spark.listenerManager.register(listener)
+      sql(s"""
+             |MERGE INTO target
+             |USING target AS source
+             |ON target._ROW_ID = source._ROW_ID
+             |WHEN MATCHED AND target.a = 5 THEN UPDATE SET b = source.b + target.b
+             |WHEN MATCHED AND source.c > 'c2' THEN UPDATE SET b = source.b * 3,
+             |c = concat(target.c, source.c)
+             |""".stripMargin)
+      // Assert no shuffle/join/sort was used in
+      // 'org.apache.paimon.spark.commands.MergeIntoPaimonDataEvolutionTable.updateActionInvoke'
+      assert(
+        capturedPlans.forall(
+          plan =>
+            plan.collectFirst {
+              case p: Join => p
+              case p: Sort => p
+              case p: RepartitionByExpression => p
+            }.isEmpty),
+        s"Found unexpected Join/Sort/Exchange in plan:\n${capturedPlans.head}"
+      )
+      spark.listenerManager.unregister(listener)
+
+      checkAnswer(
+        sql("SELECT *, _ROW_ID, _SEQUENCE_NUMBER FROM target ORDER BY a"),
+        Seq(
+          Row(1, 10, "c1", 0, 2),
+          Row(2, 20, "c2", 1, 2),
+          Row(3, 90, "c3c3", 2, 2),
+          Row(4, 120, "c4c4", 3, 2),
+          Row(5, 100, "c5", 4, 2))
+      )
+    }
+  }
+
   test("Data Evolution: update table throws exception") {
     withTable("t") {
       sql(

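A note on the first-row-id mapping used by both the removed inline code and the new addFirstRowId helper: each _ROW_ID is mapped to the first row id of the data file it falls into via floorBinarySearch over the sorted firstRowIds. That helper is not part of this diff; the sketch below only illustrates the assumed semantics (greatest first row id <= the given row id), not the Paimon implementation:

    // Assumed semantics, for illustration only: given the sorted first row ids
    // of all data files, return the greatest one that is <= rowId.
    def floorBinarySearch(sortedFirstRowIds: Seq[Long], rowId: Long): Long = {
      var lo = 0
      var hi = sortedFirstRowIds.length - 1
      var result = sortedFirstRowIds.head
      while (lo <= hi) {
        val mid = lo + (hi - lo) / 2
        if (sortedFirstRowIds(mid) <= rowId) {
          result = sortedFirstRowIds(mid)
          lo = mid + 1
        } else {
          hi = mid - 1
        }
      }
      result
    }

    // Example: with files starting at row ids 0, 100 and 250,
    // row id 137 belongs to the file whose first row id is 100.
    assert(floorBinarySearch(Seq(0L, 100L, 250L), 137L) == 100L)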