This is an automated email from the ASF dual-hosted git repository.

biyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 875cadb28 [spark] Support merge into for append table (#3917)
875cadb28 is described below

commit 875cadb284d0d35987974c33fbd0c1548ee64669
Author: Zouxxyy <[email protected]>
AuthorDate: Fri Aug 9 16:08:31 2024 +0800

    [spark] Support merge into for append table (#3917)
---
 .../paimon/spark/sql/MergeIntoTableTest.scala      |  12 +-
 .../paimon/spark/sql/MergeIntoTableTest.scala      |  12 +-
 .../paimon/spark/sql/MergeIntoTableTest.scala      |  14 +-
 .../paimon/spark/sql/MergeIntoTableTest.scala      |  14 +-
 .../catalyst/analysis/PaimonMergeIntoBase.scala    |  18 +-
 .../spark/catalyst/analysis/RowLevelOp.scala       |   2 +-
 .../commands/DeleteFromPaimonTableCommand.scala    |  20 +-
 .../spark/commands/MergeIntoPaimonTable.scala      | 203 +++++++++++++--------
 .../paimon/spark/commands/PaimonCommand.scala      |  32 +++-
 .../spark/commands/UpdatePaimonTableCommand.scala  |  18 +-
 .../org/apache/paimon/spark/PaimonTableTest.scala  |  81 ++++----
 .../paimon/spark/sql/MergeIntoTableTestBase.scala  |  72 +++-----
 12 files changed, 298 insertions(+), 200 deletions(-)

diff --git 
a/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
 
b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
index 1a7dffaf1..f1f0d8c06 100644
--- 
a/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
+++ 
b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
@@ -18,12 +18,22 @@
 
 package org.apache.paimon.spark.sql
 
-import org.apache.paimon.spark.{PaimonPrimaryKeyBucketedTableTest, 
PaimonPrimaryKeyNonBucketTableTest}
+import org.apache.paimon.spark.{PaimonAppendBucketedTableTest, 
PaimonAppendNonBucketTableTest, PaimonPrimaryKeyBucketedTableTest, 
PaimonPrimaryKeyNonBucketTableTest}
 
 class MergeIntoPrimaryKeyBucketedTableTest
   extends MergeIntoTableTestBase
+  with MergeIntoPrimaryKeyTableTest
   with PaimonPrimaryKeyBucketedTableTest {}
 
 class MergeIntoPrimaryKeyNonBucketTableTest
   extends MergeIntoTableTestBase
+  with MergeIntoPrimaryKeyTableTest
   with PaimonPrimaryKeyNonBucketTableTest {}
+
+class MergeIntoAppendBucketedTableTest
+  extends MergeIntoTableTestBase
+  with PaimonAppendBucketedTableTest {}
+
+class MergeIntoAppendNonBucketedTableTest
+  extends MergeIntoTableTestBase
+  with PaimonAppendNonBucketTableTest {}
diff --git 
a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
 
b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
index 1a7dffaf1..f1f0d8c06 100644
--- 
a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
+++ 
b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
@@ -18,12 +18,22 @@
 
 package org.apache.paimon.spark.sql
 
-import org.apache.paimon.spark.{PaimonPrimaryKeyBucketedTableTest, 
PaimonPrimaryKeyNonBucketTableTest}
+import org.apache.paimon.spark.{PaimonAppendBucketedTableTest, 
PaimonAppendNonBucketTableTest, PaimonPrimaryKeyBucketedTableTest, 
PaimonPrimaryKeyNonBucketTableTest}
 
 class MergeIntoPrimaryKeyBucketedTableTest
   extends MergeIntoTableTestBase
+  with MergeIntoPrimaryKeyTableTest
   with PaimonPrimaryKeyBucketedTableTest {}
 
 class MergeIntoPrimaryKeyNonBucketTableTest
   extends MergeIntoTableTestBase
+  with MergeIntoPrimaryKeyTableTest
   with PaimonPrimaryKeyNonBucketTableTest {}
+
+class MergeIntoAppendBucketedTableTest
+  extends MergeIntoTableTestBase
+  with PaimonAppendBucketedTableTest {}
+
+class MergeIntoAppendNonBucketedTableTest
+  extends MergeIntoTableTestBase
+  with PaimonAppendNonBucketTableTest {}
diff --git 
a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
 
b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
index 13b79e744..e1cfe3a39 100644
--- 
a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
+++ 
b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
@@ -18,14 +18,26 @@
 
 package org.apache.paimon.spark.sql
 
-import org.apache.paimon.spark.{PaimonPrimaryKeyBucketedTableTest, 
PaimonPrimaryKeyNonBucketTableTest}
+import org.apache.paimon.spark.{PaimonAppendBucketedTableTest, 
PaimonAppendNonBucketTableTest, PaimonPrimaryKeyBucketedTableTest, 
PaimonPrimaryKeyNonBucketTableTest}
 
 class MergeIntoPrimaryKeyBucketedTableTest
   extends MergeIntoTableTestBase
+  with MergeIntoPrimaryKeyTableTest
   with MergeIntoNotMatchedBySourceTest
   with PaimonPrimaryKeyBucketedTableTest {}
 
 class MergeIntoPrimaryKeyNonBucketTableTest
   extends MergeIntoTableTestBase
+  with MergeIntoPrimaryKeyTableTest
   with MergeIntoNotMatchedBySourceTest
   with PaimonPrimaryKeyNonBucketTableTest {}
+
+class MergeIntoAppendBucketedTableTest
+  extends MergeIntoTableTestBase
+  with MergeIntoNotMatchedBySourceTest
+  with PaimonAppendBucketedTableTest {}
+
+class MergeIntoAppendNonBucketedTableTest
+  extends MergeIntoTableTestBase
+  with MergeIntoNotMatchedBySourceTest
+  with PaimonAppendNonBucketTableTest {}
diff --git 
a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
 
b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
index 13b79e744..e1cfe3a39 100644
--- 
a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
+++ 
b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
@@ -18,14 +18,26 @@
 
 package org.apache.paimon.spark.sql
 
-import org.apache.paimon.spark.{PaimonPrimaryKeyBucketedTableTest, 
PaimonPrimaryKeyNonBucketTableTest}
+import org.apache.paimon.spark.{PaimonAppendBucketedTableTest, 
PaimonAppendNonBucketTableTest, PaimonPrimaryKeyBucketedTableTest, 
PaimonPrimaryKeyNonBucketTableTest}
 
 class MergeIntoPrimaryKeyBucketedTableTest
   extends MergeIntoTableTestBase
+  with MergeIntoPrimaryKeyTableTest
   with MergeIntoNotMatchedBySourceTest
   with PaimonPrimaryKeyBucketedTableTest {}
 
 class MergeIntoPrimaryKeyNonBucketTableTest
   extends MergeIntoTableTestBase
+  with MergeIntoPrimaryKeyTableTest
   with MergeIntoNotMatchedBySourceTest
   with PaimonPrimaryKeyNonBucketTableTest {}
+
+class MergeIntoAppendBucketedTableTest
+  extends MergeIntoTableTestBase
+  with MergeIntoNotMatchedBySourceTest
+  with PaimonAppendBucketedTableTest {}
+
+class MergeIntoAppendNonBucketedTableTest
+  extends MergeIntoTableTestBase
+  with MergeIntoNotMatchedBySourceTest
+  with PaimonAppendNonBucketTableTest {}
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoBase.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoBase.scala
index c07b58399..ba6108395 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoBase.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoBase.scala
@@ -18,7 +18,6 @@
 
 package org.apache.paimon.spark.catalyst.analysis
 
-import org.apache.paimon.CoreOptions
 import org.apache.paimon.spark.SparkTable
 import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper
 import org.apache.paimon.spark.commands.MergeIntoPaimonTable
@@ -28,6 +27,8 @@ import 
org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeS
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 
+import scala.collection.JavaConverters._
+
 trait PaimonMergeIntoBase
   extends Rule[LogicalPlan]
   with RowLevelHelper
@@ -52,13 +53,14 @@ trait PaimonMergeIntoBase
         merge.notMatchedActions.flatMap(_.condition).foreach(checkCondition)
 
         val updateActions = merge.matchedActions.collect { case a: 
UpdateAction => a }
-        val primaryKeys = 
v2Table.properties().get(CoreOptions.PRIMARY_KEY.key).split(",")
-        checkUpdateActionValidity(
-          AttributeSet(targetOutput),
-          merge.mergeCondition,
-          updateActions,
-          primaryKeys)
-
+        val primaryKeys = v2Table.getTable.primaryKeys().asScala
+        if (primaryKeys.nonEmpty) {
+          checkUpdateActionValidity(
+            AttributeSet(targetOutput),
+            merge.mergeCondition,
+            updateActions,
+            primaryKeys)
+        }
         val alignedMatchedActions =
           merge.matchedActions.map(checkAndAlignActionAssignment(_, 
targetOutput))
         val alignedNotMatchedActions =
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelOp.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelOp.scala
index 41881b7b7..3e1e2b52d 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelOp.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelOp.scala
@@ -70,6 +70,6 @@ case object MergeInto extends RowLevelOp {
   override val supportedMergeEngine: Seq[MergeEngine] =
     Seq(MergeEngine.DEDUPLICATE, MergeEngine.PARTIAL_UPDATE)
 
-  override val supportAppendOnlyTable: Boolean = false
+  override val supportAppendOnlyTable: Boolean = true
 
 }
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DeleteFromPaimonTableCommand.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DeleteFromPaimonTableCommand.scala
index 2aef8e576..cc440dd5c 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DeleteFromPaimonTableCommand.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DeleteFromPaimonTableCommand.scala
@@ -18,15 +18,12 @@
 
 package org.apache.paimon.spark.commands
 
-import org.apache.paimon.CoreOptions
 import org.apache.paimon.CoreOptions.MergeEngine
-import org.apache.paimon.spark.PaimonSplitScan
-import org.apache.paimon.spark.catalyst.Compatibility
 import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper
 import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand
 import org.apache.paimon.spark.schema.SparkSystemColumns.ROW_KIND_COL
 import org.apache.paimon.spark.util.SQLHelper
-import org.apache.paimon.table.{BucketMode, FileStoreTable}
+import org.apache.paimon.table.FileStoreTable
 import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessage}
 import org.apache.paimon.types.RowKind
 import org.apache.paimon.utils.InternalRowPartitionComputer
@@ -144,20 +141,11 @@ case class DeleteFromPaimonTableCommand(
         findTouchedFiles(candidateDataSplits, condition, relation, 
sparkSession)
 
       // Step3: the smallest range of data files that need to be rewritten.
-      val touchedFiles = touchedFilePaths.map {
-        file =>
-          dataFilePathToMeta.getOrElse(file, throw new 
RuntimeException(s"Missing file: $file"))
-      }
+      val (touchedFiles, newRelation) =
+        createNewRelation(touchedFilePaths, dataFilePathToMeta, relation)
 
       // Step4: build a dataframe that contains the unchanged data, and write 
out them.
-      val touchedDataSplits =
-        SparkDataFileMeta.convertToDataSplits(touchedFiles, rawConvertible = 
true, pathFactory)
-      val toRewriteScanRelation = Filter(
-        Not(condition),
-        Compatibility.createDataSourceV2ScanRelation(
-          relation,
-          PaimonSplitScan(table, touchedDataSplits),
-          relation.output))
+      val toRewriteScanRelation = Filter(Not(condition), newRelation)
       val data = createDataset(sparkSession, toRewriteScanRelation)
 
       // only write new files, should have no compaction
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala
index a06bc437d..5fec8b997 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala
@@ -18,12 +18,13 @@
 
 package org.apache.paimon.spark.commands
 
-import org.apache.paimon.options.Options
-import org.apache.paimon.spark.{InsertInto, SparkTable}
+import org.apache.paimon.spark.SparkTable
+import org.apache.paimon.spark.catalyst.analysis.PaimonRelation
 import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand
 import org.apache.paimon.spark.schema.SparkSystemColumns
-import org.apache.paimon.spark.util.EncoderUtils
+import org.apache.paimon.spark.util.{EncoderUtils, SparkRowUtils}
 import org.apache.paimon.table.FileStoreTable
+import org.apache.paimon.table.sink.CommitMessage
 import org.apache.paimon.types.RowKind
 
 import org.apache.spark.sql.{Column, Dataset, Row, SparkSession}
@@ -33,10 +34,12 @@ import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
BasePredicate, Expression, Literal, UnsafeProjection}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
-import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, Filter, 
InsertAction, LogicalPlan, MergeAction, UpdateAction}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.functions.{col, lit, monotonically_increasing_id, 
sum}
 import org.apache.spark.sql.types.{ByteType, StructField, StructType}
 
+import scala.collection.mutable
+
 /** Command for Merge Into. */
 case class MergeIntoPaimonTable(
     v2Table: SparkTable,
@@ -55,7 +58,9 @@ case class MergeIntoPaimonTable(
 
   lazy val tableSchema: StructType = v2Table.schema
 
-  lazy val filteredTargetPlan: LogicalPlan = {
+  private lazy val writer = PaimonSparkWriter(table)
+
+  private lazy val filteredTargetPlan: LogicalPlan = {
     val filtersOnlyTarget = getExpressionOnlyRelated(mergeCondition, 
targetTable)
     filtersOnlyTarget
       .map(Filter.apply(_, targetTable))
@@ -63,25 +68,75 @@ case class MergeIntoPaimonTable(
   }
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
-
     // Avoid that more than one source rows match the same target row.
     checkMatchRationality(sparkSession)
+    val commitMessages = if (withPrimaryKeys) {
+      performMergeForPkTable(sparkSession)
+    } else {
+      performMergeForNonPkTable(sparkSession)
+    }
+    writer.commit(commitMessages)
+    Seq.empty[Row]
+  }
 
-    val changed = constructChangedRows(sparkSession)
+  private def performMergeForPkTable(sparkSession: SparkSession): 
Seq[CommitMessage] = {
+    writer.write(
+      constructChangedRows(sparkSession, createDataset(sparkSession, 
filteredTargetPlan)))
+  }
 
-    WriteIntoPaimonTable(
-      table,
-      InsertInto,
-      changed,
-      new Options()
-    ).run(sparkSession)
+  private def performMergeForNonPkTable(sparkSession: SparkSession): 
Seq[CommitMessage] = {
+    val targetDS = createDataset(sparkSession, filteredTargetPlan)
+    val sourceDS = createDataset(sparkSession, sourceTable)
 
-    Seq.empty[Row]
+    val targetFilePaths: Array[String] = findTouchedFiles(targetDS, 
sparkSession)
+
+    val touchedFilePathsSet = mutable.Set.empty[String]
+    def hasUpdate(actions: Seq[MergeAction]): Boolean = {
+      actions.exists {
+        case _: UpdateAction | _: DeleteAction => true
+        case _ => false
+      }
+    }
+    if (hasUpdate(matchedActions)) {
+      touchedFilePathsSet ++= findTouchedFiles(
+        targetDS.join(sourceDS, new Column(mergeCondition), "inner"),
+        sparkSession)
+    }
+    if (hasUpdate(notMatchedBySourceActions)) {
+      touchedFilePathsSet ++= findTouchedFiles(
+        targetDS.join(sourceDS, new Column(mergeCondition), "left_anti"),
+        sparkSession)
+    }
+
+    val touchedFilePaths: Array[String] = touchedFilePathsSet.toArray
+    val unTouchedFilePaths = 
targetFilePaths.filterNot(touchedFilePaths.contains)
+
+    val relation = PaimonRelation.getPaimonRelation(targetTable)
+    val dataFilePathToMeta = 
candidateFileMap(findCandidateDataSplits(TrueLiteral, relation.output))
+    val (touchedFiles, touchedFileRelation) =
+      createNewRelation(touchedFilePaths, dataFilePathToMeta, relation)
+    val (_, unTouchedFileRelation) =
+      createNewRelation(unTouchedFilePaths, dataFilePathToMeta, relation)
+
+    // Add FILE_TOUCHED_COL to mark whether a row comes from a touched file. 
A row that comes from a
+    // touched file but has not been modified must be kept too.
+    val targetDSWithFileTouchedCol = createDataset(sparkSession, 
touchedFileRelation)
+      .withColumn(FILE_TOUCHED_COL, lit(true))
+      .union(
+        createDataset(sparkSession, 
unTouchedFileRelation).withColumn(FILE_TOUCHED_COL, lit(false)))
+
+    val addCommitMessage =
+      writer.write(constructChangedRows(sparkSession, 
targetDSWithFileTouchedCol))
+    val deletedCommitMessage = buildDeletedCommitMessage(touchedFiles)
+
+    addCommitMessage ++ deletedCommitMessage
   }
 
   /** Get a Dataset where each of Row has an additional column called 
_row_kind_. */
-  private def constructChangedRows(sparkSession: SparkSession): Dataset[Row] = 
{
-    val targetDS = createDataset(sparkSession, filteredTargetPlan)
+  private def constructChangedRows(
+      sparkSession: SparkSession,
+      targetDataset: Dataset[Row]): Dataset[Row] = {
+    val targetDS = targetDataset
       .withColumn(TARGET_ROW_COL, lit(true))
 
     val sourceDS = createDataset(sparkSession, sourceTable)
@@ -100,29 +155,30 @@ case class MergeIntoPaimonTable(
     val matchedExprs = matchedActions.map(_.condition.getOrElse(TrueLiteral))
     val notMatchedExprs = 
notMatchedActions.map(_.condition.getOrElse(TrueLiteral))
     val notMatchedBySourceExprs = 
notMatchedBySourceActions.map(_.condition.getOrElse(TrueLiteral))
-    val matchedOutputs = matchedActions.map {
-      case UpdateAction(_, assignments) =>
-        assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue)
-      case DeleteAction(_) =>
-        targetOutput :+ Literal(RowKind.DELETE.toByteValue)
-      case _ =>
-        throw new RuntimeException("should not be here.")
-    }
-    val notMatchedBySourceOutputs = notMatchedBySourceActions.map {
-      case UpdateAction(_, assignments) =>
-        assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue)
-      case DeleteAction(_) =>
-        targetOutput :+ Literal(RowKind.DELETE.toByteValue)
-      case _ =>
-        throw new RuntimeException("should not be here.")
-    }
-    val notMatchedOutputs = notMatchedActions.map {
-      case InsertAction(_, assignments) =>
-        assignments.map(_.value) :+ Literal(RowKind.INSERT.toByteValue)
-      case _ =>
-        throw new RuntimeException("should not be here.")
-    }
     val noopOutput = targetOutput :+ Alias(Literal(NOOP_ROW_KIND_VALUE), 
ROW_KIND_COL)()
+    val keepOutput = targetOutput :+ 
Alias(Literal(RowKind.INSERT.toByteValue), ROW_KIND_COL)()
+
+    def processMergeActions(actions: Seq[MergeAction], applyOnTargetTable: 
Boolean) = {
+      actions.map {
+        case UpdateAction(_, assignments) if applyOnTargetTable =>
+          assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue)
+        case DeleteAction(_) if applyOnTargetTable =>
+          if (withPrimaryKeys) {
+            targetOutput :+ Literal(RowKind.DELETE.toByteValue)
+          } else {
+            noopOutput
+          }
+        case InsertAction(_, assignments) if !applyOnTargetTable =>
+          assignments.map(_.value) :+ Literal(RowKind.INSERT.toByteValue)
+        case _ =>
+          throw new RuntimeException("should not be here.")
+      }
+    }
+
+    val matchedOutputs = processMergeActions(matchedActions, 
applyOnTargetTable = true)
+    val notMatchedBySourceOutputs =
+      processMergeActions(notMatchedBySourceActions, applyOnTargetTable = true)
+    val notMatchedOutputs = processMergeActions(notMatchedActions, 
applyOnTargetTable = false)
     val outputSchema = StructType(tableSchema.fields :+ 
StructField(ROW_KIND_COL, ByteType))
 
     val joinedRowEncoder = EncoderUtils.encode(joinedPlan.schema)
@@ -139,10 +195,11 @@ case class MergeIntoPaimonTable(
       notMatchedExprs,
       notMatchedOutputs,
       noopOutput,
+      keepOutput,
       joinedRowEncoder,
       outputEncoder
     )
-    joinedDS.mapPartitions(processor.processPartition)(outputEncoder)
+    joinedDS.mapPartitions(processor.processPartition)(outputEncoder).toDF()
   }
 
   private def checkMatchRationality(sparkSession: SparkSession): Unit = {
@@ -159,21 +216,23 @@ case class MergeIntoPaimonTable(
         .count()
       if (count > 0) {
         throw new RuntimeException(
-          "Can't execute this MergeInto when there are some target rows that 
each of them match more then one source rows. It may lead to an unexpected 
result.")
+          "Can't execute this MergeInto when there are some target rows that 
each of " +
+            "them match more then one source rows. It may lead to an 
unexpected result.")
       }
     }
   }
 }
 
 object MergeIntoPaimonTable {
-  val ROW_ID_COL = "_row_id_"
-  val SOURCE_ROW_COL = "_source_row_"
-  val TARGET_ROW_COL = "_target_row_"
+  private val ROW_ID_COL = "_row_id_"
+  private val SOURCE_ROW_COL = "_source_row_"
+  private val TARGET_ROW_COL = "_target_row_"
+  private val FILE_TOUCHED_COL = "_file_touched_col_"
   // +I, +U, -U, -D
-  val ROW_KIND_COL: String = SparkSystemColumns.ROW_KIND_COL
-  val NOOP_ROW_KIND_VALUE: Byte = "-1".toByte
+  private val ROW_KIND_COL: String = SparkSystemColumns.ROW_KIND_COL
+  private val NOOP_ROW_KIND_VALUE: Byte = "-1".toByte
 
-  case class MergeIntoProcessor(
+  private case class MergeIntoProcessor(
       joinedAttributes: Seq[Attribute],
       targetRowHasNoMatch: Expression,
       sourceRowHasNoMatch: Expression,
@@ -184,10 +243,15 @@ object MergeIntoPaimonTable {
       notMatchedConditions: Seq[Expression],
       notMatchedOutputs: Seq[Seq[Expression]],
       noopCopyOutput: Seq[Expression],
+      keepOutput: Seq[Expression],
       joinedRowEncoder: ExpressionEncoder[Row],
       outputRowEncoder: ExpressionEncoder[Row]
   ) extends Serializable {
 
+    private val file_touched_col_index: Int =
+      SparkRowUtils.getFieldIndex(joinedRowEncoder.schema, FILE_TOUCHED_COL)
+    private val row_kind_col_index: Int = 
outputRowEncoder.schema.fieldIndex(ROW_KIND_COL)
+
     private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = 
{
       UnsafeProjection.create(exprs, joinedAttributes)
     }
@@ -196,8 +260,12 @@ object MergeIntoPaimonTable {
       GeneratePredicate.generate(expr, joinedAttributes)
     }
 
+    private def fromTouchedFile(row: InternalRow): Boolean = {
+      file_touched_col_index != -1 && row.getBoolean(file_touched_col_index)
+    }
+
     private def unusedRow(row: InternalRow): Boolean = {
-      row.getByte(outputRowEncoder.schema.fieldIndex(ROW_KIND_COL)) == 
NOOP_ROW_KIND_VALUE
+      row.getByte(row_kind_col_index) == NOOP_ROW_KIND_VALUE
     }
 
     def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = {
@@ -210,38 +278,29 @@ object MergeIntoPaimonTable {
       val notMatchedPreds = notMatchedConditions.map(generatePredicate)
       val notMatchedProjs = notMatchedOutputs.map(generateProjection)
       val noopCopyProj = generateProjection(noopCopyOutput)
+      val keepProj = generateProjection(keepOutput)
       val outputProj = UnsafeProjection.create(outputRowEncoder.schema)
 
       def processRow(inputRow: InternalRow): InternalRow = {
-        if (targetRowHasNoMatchPred.eval(inputRow)) {
-          val pair = notMatchedBySourcePreds.zip(notMatchedBySourceProjs).find 
{
-            case (predicate, _) => predicate.eval(inputRow)
+        def applyPreds(preds: Seq[BasePredicate], projs: 
Seq[UnsafeProjection]): InternalRow = {
+          preds.zip(projs).find { case (predicate, _) => 
predicate.eval(inputRow) } match {
+            case Some((_, projections)) => projections.apply(inputRow)
+            case None =>
+              // keep the row if it comes from a touched file and no action matched it
+              if (fromTouchedFile(inputRow)) {
+                keepProj.apply(inputRow)
+              } else {
+                noopCopyProj.apply(inputRow)
+              }
           }
+        }
 
-          pair match {
-            case Some((_, projections)) =>
-              projections.apply(inputRow)
-            case None => noopCopyProj.apply(inputRow)
-          }
+        if (targetRowHasNoMatchPred.eval(inputRow)) {
+          applyPreds(notMatchedBySourcePreds, notMatchedBySourceProjs)
         } else if (sourceRowHasNoMatchPred.eval(inputRow)) {
-          val pair = notMatchedPreds.zip(notMatchedProjs).find {
-            case (predicate, _) => predicate.eval(inputRow)
-          }
-
-          pair match {
-            case Some((_, projections)) =>
-              projections.apply(inputRow)
-            case None => noopCopyProj.apply(inputRow)
-          }
+          applyPreds(notMatchedPreds, notMatchedProjs)
         } else {
-          val pair =
-            matchedPreds.zip(matchedProjs).find { case (predicate, _) => 
predicate.eval(inputRow) }
-
-          pair match {
-            case Some((_, projections)) =>
-              projections.apply(inputRow)
-            case None => noopCopyProj.apply(inputRow)
-          }
+          applyPreds(matchedPreds, matchedProjs)
         }
       }
 
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
index 4a42e4f46..8e341a657 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
@@ -33,12 +33,12 @@ import org.apache.paimon.table.source.DataSplit
 import org.apache.paimon.types.RowType
 import org.apache.paimon.utils.SerializationUtils
 
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.sql.PaimonUtils.createDataset
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.plans.logical.{Filter => 
FilterLogicalNode, LogicalPlan, Project}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, 
DataSourceV2ScanRelation}
 import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualNullSafe, Filter}
 
 import java.net.URI
@@ -101,7 +101,7 @@ trait PaimonCommand extends WithFileStoreTable with 
ExpressionHelper {
       output: Seq[Attribute]): Seq[DataSplit] = {
     // low level snapshot reader, it can not be affected by 'scan.mode'
     val snapshotReader = table.newSnapshotReader()
-    if (condition == TrueLiteral) {
+    if (condition != TrueLiteral) {
       val filter =
         convertConditionToPaimonPredicate(condition, output, rowType, 
ignoreFailure = true)
       filter.foreach(snapshotReader.withFilter)
@@ -115,8 +115,6 @@ trait PaimonCommand extends WithFileStoreTable with 
ExpressionHelper {
       condition: Expression,
       relation: DataSourceV2Relation,
       sparkSession: SparkSession): Array[String] = {
-    import sparkSession.implicits._
-
     for (split <- candidateDataSplits) {
       if (!split.rawConvertible()) {
         throw new IllegalArgumentException(
@@ -126,7 +124,14 @@ trait PaimonCommand extends WithFileStoreTable with 
ExpressionHelper {
 
     val metadataCols = Seq(FILE_PATH)
     val filteredRelation = createNewScanPlan(candidateDataSplits, condition, 
relation, metadataCols)
-    createDataset(sparkSession, filteredRelation)
+    findTouchedFiles(createDataset(sparkSession, filteredRelation), 
sparkSession)
+  }
+
+  protected def findTouchedFiles(
+      dataset: Dataset[Row],
+      sparkSession: SparkSession): Array[String] = {
+    import sparkSession.implicits._
+    dataset
       .select(FILE_PATH_COLUMN)
       .distinct()
       .as[String]
@@ -134,6 +139,21 @@ trait PaimonCommand extends WithFileStoreTable with 
ExpressionHelper {
       .map(relativePath)
   }
 
+  protected def createNewRelation(
+      filePaths: Array[String],
+      filePathToMeta: Map[String, SparkDataFileMeta],
+      relation: DataSourceV2Relation): (Array[SparkDataFileMeta], 
DataSourceV2ScanRelation) = {
+    val files = filePaths.map(
+      file => filePathToMeta.getOrElse(file, throw new 
RuntimeException(s"Missing file: $file")))
+    val touchedDataSplits =
+      SparkDataFileMeta.convertToDataSplits(files, rawConvertible = true, 
fileStore.pathFactory())
+    val newRelation = Compatibility.createDataSourceV2ScanRelation(
+      relation,
+      PaimonSplitScan(table, touchedDataSplits),
+      relation.output)
+    (files, newRelation)
+  }
+
   /** Notice that, the key is a relative path, not just the file name. */
   protected def candidateFileMap(
       candidateDataSplits: Seq[DataSplit]): Map[String, SparkDataFileMeta] = {
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala
index 6c7d07bf5..dd88f388c 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.PaimonUtils.createDataset
 import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, If}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Filter, 
Project, SupportsSubquery}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, 
DataSourceV2ScanRelation}
 import org.apache.spark.sql.functions.lit
 
 case class UpdatePaimonTableCommand(
@@ -116,15 +116,11 @@ case class UpdatePaimonTableCommand(
           findTouchedFiles(candidateDataSplits, condition, relation, 
sparkSession)
 
         // Step3: the smallest range of data files that need to be rewritten.
-        val touchedFiles = touchedFilePaths.map {
-          file =>
-            dataFilePathToMeta.getOrElse(file, throw new 
RuntimeException(s"Missing file: $file"))
-        }
+        val (touchedFiles, touchedFileRelation) =
+          createNewRelation(touchedFilePaths, dataFilePathToMeta, relation)
 
         // Step4: build a dataframe that contains the unchanged and updated 
data, and write out them.
-        val touchedDataSplits =
-          SparkDataFileMeta.convertToDataSplits(touchedFiles, rawConvertible = 
true, pathFactory)
-        val addCommitMessage = writeUpdatedAndUnchangedData(sparkSession, 
touchedDataSplits)
+        val addCommitMessage = writeUpdatedAndUnchangedData(sparkSession, 
touchedFileRelation)
 
         // Step5: convert the deleted files that need to be wrote to commit 
message.
         val deletedCommitMessage = buildDeletedCommitMessage(touchedFiles)
@@ -157,7 +153,7 @@ case class UpdatePaimonTableCommand(
 
   private def writeUpdatedAndUnchangedData(
       sparkSession: SparkSession,
-      touchedDataSplits: Array[DataSplit]): Seq[CommitMessage] = {
+      toUpdateScanRelation: DataSourceV2ScanRelation): Seq[CommitMessage] = {
     val updateColumns = updateExpressions.zip(relation.output).map {
       case (update, origin) =>
         val updated = if (condition == TrueLiteral) {
@@ -168,10 +164,6 @@ case class UpdatePaimonTableCommand(
         new Column(updated).as(origin.name, origin.metadata)
     }
 
-    val toUpdateScanRelation = Compatibility.createDataSourceV2ScanRelation(
-      relation,
-      PaimonSplitScan(table, touchedDataSplits),
-      relation.output)
     val data = createDataset(sparkSession, toUpdateScanRelation).select(updateColumns: _*)
     writer.write(data)
   }
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonTableTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonTableTest.scala
index 53f41833f..0477bcbaf 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonTableTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonTableTest.scala
@@ -20,24 +20,39 @@ package org.apache.paimon.spark
 
 import org.apache.spark.sql.test.SharedSparkSession
 
-import scala.collection.mutable
-
 trait PaimonTableTest extends SharedSparkSession {
 
   val bucket: Int
 
-  def appendPrimaryKey(primaryKeys: Seq[String], props: mutable.Map[String, String]): Unit
-
+  def initProps(primaryOrBucketKeys: Seq[String], partitionKeys: Seq[String]): Map[String, String]
+
+  /**
+   * Create a table configured by the given parameters.
+   *
+   * @param tableName
+   *   table name
+   * @param columns
+   *   columns string, e.g. "a INT, b INT, c STRING"
+   * @param primaryOrBucketKeys
+   *   for [[PaimonPrimaryKeyTable]] they are `primary-key`, if you want to specify additional
+   *   `bucket-key`, you can specify that in extraProps. for [[PaimonAppendTable]] they are
+   *   `bucket-key`
+   * @param partitionKeys
+   *   partition keys seq
+   * @param extraProps
+   *   extra properties map
+   */
   def createTable(
       tableName: String,
       columns: String,
-      primaryKeys: Seq[String],
+      primaryOrBucketKeys: Seq[String],
       partitionKeys: Seq[String] = Seq.empty,
-      props: Map[String, String] = Map.empty): Unit = {
-    val newProps: mutable.Map[String, String] =
-      mutable.Map.empty[String, String] ++ Map("bucket" -> bucket.toString) ++ props
-    appendPrimaryKey(primaryKeys, newProps)
-    createTable0(tableName, columns, partitionKeys, newProps.toMap)
+      extraProps: Map[String, String] = Map.empty): Unit = {
+    createTable0(
+      tableName,
+      columns,
+      partitionKeys,
+      initProps(primaryOrBucketKeys, partitionKeys) ++ extraProps)
   }
 
   private def createTable0(
@@ -72,35 +87,35 @@ trait PaimonNonBucketedTable {
   val bucket: Int = -1
 }
 
-trait PaimonPrimaryKeyTable {
-  def appendPrimaryKey(primaryKeys: Seq[String], props: mutable.Map[String, String]): Unit = {
-    assert(primaryKeys.nonEmpty)
-    props += ("primary-key" -> primaryKeys.mkString(","))
+trait PaimonPrimaryKeyTable extends PaimonTableTest {
+  def initProps(
+      primaryOrBucketKeys: Seq[String],
+      partitionKeys: Seq[String]): Map[String, String] = {
+    assert(primaryOrBucketKeys.nonEmpty)
+    Map("primary-key" -> primaryOrBucketKeys.mkString(","), "bucket" -> bucket.toString)
   }
 }
 
-trait PaimonAppendTable {
-  def appendPrimaryKey(primaryKeys: Seq[String], props: mutable.Map[String, String]): Unit = {
-    // nothing to do
+trait PaimonAppendTable extends PaimonTableTest {
+  def initProps(
+      primaryOrBucketKeys: Seq[String],
+      partitionKeys: Seq[String]): Map[String, String] = {
+    if (bucket == -1) {
+      // Ignore bucket keys for unaware bucket table
+      Map("bucket" -> bucket.toString)
+    } else {
+      // Filter partition keys in bucket keys for fixed bucket table
+      val bucketKeys = primaryOrBucketKeys.filterNot(partitionKeys.contains(_))
+      assert(bucketKeys.nonEmpty)
+      Map("bucket-key" -> bucketKeys.mkString(","), "bucket" -> bucket.toString)
+    }
   }
 }
 
-trait PaimonPrimaryKeyBucketedTableTest
-  extends PaimonTableTest
-  with PaimonPrimaryKeyTable
-  with PaimonBucketedTable
+trait PaimonPrimaryKeyBucketedTableTest extends PaimonPrimaryKeyTable with PaimonBucketedTable
 
-trait PaimonPrimaryKeyNonBucketTableTest
-  extends PaimonTableTest
-  with PaimonPrimaryKeyTable
-  with PaimonNonBucketedTable
+trait PaimonPrimaryKeyNonBucketTableTest extends PaimonPrimaryKeyTable with PaimonNonBucketedTable
 
-trait PaimonAppendBucketedTableTest
-  extends PaimonTableTest
-  with PaimonAppendTable
-  with PaimonBucketedTable
+trait PaimonAppendBucketedTableTest extends PaimonAppendTable with PaimonBucketedTable
 
-trait PaimonAppendNonBucketTableTest
-  extends PaimonTableTest
-  with PaimonAppendTable
-  with PaimonNonBucketedTable
+trait PaimonAppendNonBucketTableTest extends PaimonAppendTable with PaimonNonBucketedTable
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala
index 65670ebd8..1a4eae51d 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala
@@ -18,7 +18,7 @@
 
 package org.apache.paimon.spark.sql
 
-import org.apache.paimon.spark.{PaimonSparkTestBase, PaimonTableTest}
+import org.apache.paimon.spark.{PaimonPrimaryKeyTable, PaimonSparkTestBase, PaimonTableTest}
 
 import org.apache.spark.sql.Row
 
@@ -497,6 +497,30 @@ abstract class MergeIntoTableTestBase extends PaimonSparkTestBase with PaimonTab
         Row(1, 10, Row("x1", "y")) :: Row(2, 20, Row("x", "y")) :: Nil)
     }
   }
+  test(s"Paimon MergeInto: update on source eq target condition") {
+    withTable("source", "target") {
+      Seq((1, 100, "c11"), (3, 300, "c33")).toDF("a", "b", "c").createOrReplaceTempView("source")
+
+      createTable("target", "a INT, b INT, c STRING", Seq("a"))
+      sql("INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2')")
+
+      sql(s"""
+             |MERGE INTO target
+             |USING source
+             |ON source.a = target.a
+             |WHEN MATCHED THEN
+             |UPDATE SET a = source.a, b = source.b, c = source.c
+             |""".stripMargin)
+
+      checkAnswer(
+        sql("SELECT * FROM target ORDER BY a, b"),
+        Row(1, 100, "c11") :: Row(2, 20, "c2") :: Nil)
+    }
+  }
+}
+
+trait MergeIntoPrimaryKeyTableTest extends PaimonSparkTestBase with PaimonPrimaryKeyTable {
+  import testImplicits._
 
   test("Paimon MergeInto: fail in case that maybe update primary key column") {
     withTable("source", "target") {
@@ -535,50 +559,4 @@ abstract class MergeIntoTableTestBase extends PaimonSparkTestBase with PaimonTab
         Row(1, 10, "c111") :: Row(2, 20, "c2") :: Row(103, 30, "c333") :: Nil)
     }
   }
-
-  test("Paimon MergeInto: not support in table without primary keys") {
-    withTable("source", "target") {
-
-      Seq((1, 100, "c11"), (3, 300, "c33")).toDF("a", "b", "c").createOrReplaceTempView("source")
-
-      spark.sql(s"""
-                   |CREATE TABLE target (a INT, b INT, c STRING)
-                   |""".stripMargin)
-      spark.sql("INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2')")
-
-      val error = intercept[RuntimeException] {
-        spark.sql(s"""
-                     |MERGE INTO target
-                     |USING source
-                     |ON target.a = source.a
-                     |WHEN MATCHED THEN
-                     |UPDATE SET a = source.a, b = source.b, c = source.c
-                     |WHEN NOT MATCHED
-                     |THEN INSERT (a, b, c) values (a, b, c)
-                     |""".stripMargin)
-      }.getMessage
-      assert(error.contains("Only support to MergeInto table with primary keys."))
-    }
-  }
-
-  test(s"Paimon MergeInto: update on source eq target condition") {
-    withTable("source", "target") {
-      Seq((1, 100, "c11"), (3, 300, "c33")).toDF("a", "b", "c").createOrReplaceTempView("source")
-
-      createTable("target", "a INT, b INT, c STRING", Seq("a"))
-      sql("INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2')")
-
-      sql(s"""
-             |MERGE INTO target
-             |USING source
-             |ON source.a = target.a
-             |WHEN MATCHED THEN
-             |UPDATE SET a = source.a, b = source.b, c = source.c
-             |""".stripMargin)
-
-      checkAnswer(
-        sql("SELECT * FROM target ORDER BY a, b"),
-        Row(1, 100, "c11") :: Row(2, 20, "c2") :: Nil)
-    }
-  }
 }


Reply via email to