This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 02be3928d [spark] Support MERGE INTO command (#2331)
02be3928d is described below
commit 02be3928d4d4a7dd91c531bb767eb2ed1b12870c
Author: Yann Byron <[email protected]>
AuthorDate: Fri Nov 17 15:24:45 2023 +0800
[spark] Support MERGE INTO command (#2331)
---
.../spark/commands/MergeIntoPaimonTable.scala | 235 ++++++++
.../extensions/PaimonSparkSessionExtensions.scala | 4 +-
.../main/scala/org/apache/spark/sql/Utils.scala | 5 +
.../sql/catalyst/analysis/AnalysisHelper.scala | 50 ++
.../sql/catalyst/analysis/PaimonAnalysis.scala | 8 +-
.../sql/catalyst/analysis/PaimonMergeInto.scala | 229 ++++++++
.../analysis/PaimonMergeIntoResolver.scala | 86 +++
.../analysis/expressions/ExpressionHelper.scala | 96 ++++
.../paimon/spark/sql/MergeIntoTableTest.scala | 622 +++++++++++++++++++++
9 files changed, 1332 insertions(+), 3 deletions(-)
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala
new file mode 100644
index 000000000..7239baea2
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.commands
+
+import org.apache.paimon.options.Options
+import org.apache.paimon.spark.{InsertInto, SparkTable}
+import org.apache.paimon.spark.schema.SparkSystemColumns
+import org.apache.paimon.spark.util.EncoderUtils
+import org.apache.paimon.table.FileStoreTable
+import org.apache.paimon.types.RowKind
+
+import org.apache.spark.sql.{Column, Dataset, Row, SparkSession}
+import org.apache.spark.sql.Utils._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.expressions.ExpressionHelper
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute,
BasePredicate, Expression, Literal, PredicateHelper, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, Filter,
InsertAction, LogicalPlan, MergeAction, UpdateAction}
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.functions.{col, lit, monotonically_increasing_id,
sum}
+import org.apache.spark.sql.types.{ByteType, StructField, StructType}
+
+/**
+ * Command that executes a `MERGE INTO` statement against a Paimon table.
+ *
+ * The source is full-outer-joined with the (pre-filtered) target on the merge condition. Each
+ * joined row is mapped by [[MergeIntoPaimonTable.MergeIntoProcessor]] to at most one output row
+ * tagged with a row kind, and the resulting rows are written through [[WriteIntoPaimonTable]].
+ */
+case class MergeIntoPaimonTable(
+    v2Table: SparkTable,
+    targetTable: LogicalPlan,
+    sourceTable: LogicalPlan,
+    mergeCondition: Expression,
+    matchedActions: Seq[MergeAction],
+    notMatchedActions: Seq[MergeAction])
+  extends LeafRunnableCommand
+  with WithFileStoreTable
+  with ExpressionHelper
+  with PredicateHelper {
+
+  import MergeIntoPaimonTable._
+
+  override val table: FileStoreTable = v2Table.getTable.asInstanceOf[FileStoreTable]
+
+  lazy val tableSchema: StructType = v2Table.schema()
+
+  /** The target plan, narrowed by the part of the merge condition that only touches the target. */
+  lazy val filteredTargetPlan: LogicalPlan =
+    getExpressionOnlyRelated(mergeCondition, targetTable)
+      .map(targetOnlyFilter => Filter(targetOnlyFilter, targetTable))
+      .getOrElse(targetTable)
+
+  override def run(sparkSession: SparkSession): Seq[Row] = {
+    // Fail fast if any target row is matched by more than one source row.
+    checkMatchRationality(sparkSession)
+
+    val changedRows = constructChangedRows(sparkSession)
+    WriteIntoPaimonTable(table, InsertInto, changedRows, new Options()).run(sparkSession)
+
+    Seq.empty[Row]
+  }
+
+  /**
+   * Builds the rows to write: the table's columns followed by a `ROW_KIND_COL` column that holds
+   * the [[RowKind]] byte of the change.
+   */
+  private def constructChangedRows(sparkSession: SparkSession): Dataset[Row] = {
+    val targetDS = createDataset(sparkSession, filteredTargetPlan)
+      .withColumn(TARGET_ROW_COL, lit(true))
+    val sourceDS = createDataset(sparkSession, sourceTable)
+      .withColumn(SOURCE_ROW_COL, lit(true))
+
+    // After the full outer join, a marker column is null on the side that had no partner row.
+    val joinedDS = sourceDS.join(targetDS, new Column(mergeCondition), "fullOuter")
+    val joinedPlan = joinedDS.queryExecution.analyzed
+
+    def resolveOnJoinedPlan(column: Column): Expression =
+      resolveExpressions(sparkSession)(Seq(column.expr), joinedPlan).head
+
+    val targetOutput = filteredTargetPlan.output
+    val targetRowNotMatched = resolveOnJoinedPlan(col(SOURCE_ROW_COL).isNull)
+    val sourceRowNotMatched = resolveOnJoinedPlan(col(TARGET_ROW_COL).isNull)
+
+    val matchedExprs = matchedActions.map(conditionOrTrue)
+    val notMatchedExprs = notMatchedActions.map(conditionOrTrue)
+    val matchedOutputs = matchedActions.map(matchedActionOutput(_, targetOutput))
+    val notMatchedOutputs = notMatchedActions.map(notMatchedActionOutput)
+
+    val noopOutput = targetOutput :+ Alias(Literal(NOOP_ROW_KIND_VALUE), ROW_KIND_COL)()
+    val outputSchema = StructType(tableSchema.fields :+ StructField(ROW_KIND_COL, ByteType))
+
+    val joinedRowEncoder = EncoderUtils.encode(joinedPlan.schema)
+    val outputEncoder = EncoderUtils.encode(outputSchema).resolveAndBind()
+
+    val processor = MergeIntoProcessor(
+      joinedAttributes = joinedPlan.output,
+      targetRowHasNoMatch = targetRowNotMatched,
+      sourceRowHasNoMatch = sourceRowNotMatched,
+      matchedConditions = matchedExprs,
+      matchedOutputs = matchedOutputs,
+      notMatchedConditions = notMatchedExprs,
+      notMatchedOutputs = notMatchedOutputs,
+      noopCopyOutput = noopOutput,
+      joinedRowEncoder = joinedRowEncoder,
+      outputRowEncoder = outputEncoder
+    )
+    joinedDS.mapPartitions(processor.processPartition)(outputEncoder)
+  }
+
+  /** The action's condition, or `true` when the clause has no condition. */
+  private def conditionOrTrue(action: MergeAction): Expression =
+    action.condition.getOrElse(TrueLiteral)
+
+  /** Output expressions (table columns, then the row kind) for a WHEN MATCHED action. */
+  private def matchedActionOutput(
+      action: MergeAction,
+      targetOutput: Seq[Attribute]): Seq[Expression] = action match {
+    case UpdateAction(_, assignments) =>
+      assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue)
+    case DeleteAction(_) =>
+      targetOutput :+ Literal(RowKind.DELETE.toByteValue)
+    case _ =>
+      throw new RuntimeException("should not be here.")
+  }
+
+  /** Output expressions (table columns, then the row kind) for a WHEN NOT MATCHED action. */
+  private def notMatchedActionOutput(action: MergeAction): Seq[Expression] = action match {
+    case InsertAction(_, assignments) =>
+      assignments.map(_.value) :+ Literal(RowKind.INSERT.toByteValue)
+    case _ =>
+      throw new RuntimeException("should not be here.")
+  }
+
+  /**
+   * When WHEN MATCHED clauses exist, throws if any target row is matched by more than one source
+   * row, since the outcome for that row would be ambiguous.
+   */
+  private def checkMatchRationality(sparkSession: SparkSession): Unit = {
+    if (matchedActions.nonEmpty) {
+      // Tag every target row with an id, then count ids that join with several source rows.
+      val targetWithRowId = createDataset(sparkSession, filteredTargetPlan)
+        .withColumn(ROW_ID_COL, monotonically_increasing_id())
+      val sourceDS = createDataset(sparkSession, sourceTable)
+      val multiMatchedTargetRows = sourceDS
+        .join(targetWithRowId, new Column(mergeCondition), "inner")
+        .select(col(ROW_ID_COL), lit(1).as("one"))
+        .groupBy(ROW_ID_COL)
+        .agg(sum("one").as("count"))
+        .filter("count > 1")
+        .count()
+      if (multiMatchedTargetRows > 0) {
+        throw new RuntimeException(
+          "Can't execute this MergeInto when there are some target rows that each of them match more then one source rows. It may lead to an unexpected result.")
+      }
+    }
+  }
+}
+
+object MergeIntoPaimonTable {
+  val ROW_ID_COL = "_row_id_"
+  val SOURCE_ROW_COL = "_source_row_"
+  val TARGET_ROW_COL = "_target_row_"
+  // +I, +U, -U, -D
+  val ROW_KIND_COL: String = SparkSystemColumns.ROW_KIND_COL
+  // Marker row kind for joined rows that produce no change; such rows are dropped, never written.
+  val NOOP_ROW_KIND_VALUE: Byte = -1
+
+  /**
+   * Serializable per-partition logic that maps each row of the source/target full outer join to
+   * at most one output row. All expressions are bound to `joinedAttributes` and compiled once per
+   * call to [[processPartition]].
+   *
+   * @param joinedAttributes output attributes of the joined plan
+   * @param targetRowHasNoMatch predicate built by the caller; expected to be true for joined rows
+   *                            that have a target side but no source side
+   * @param sourceRowHasNoMatch predicate built by the caller; expected to be true for joined rows
+   *                            that have a source side but no target side
+   * @param matchedConditions WHEN MATCHED conditions, in clause order
+   * @param matchedOutputs output expressions of each WHEN MATCHED clause
+   * @param notMatchedConditions WHEN NOT MATCHED conditions, in clause order
+   * @param notMatchedOutputs output expressions of each WHEN NOT MATCHED clause
+   * @param noopCopyOutput output used when no clause applies; its `ROW_KIND_COL` value is expected
+   *                       to be NOOP_ROW_KIND_VALUE so the row is filtered out
+   * @param joinedRowEncoder encoder of the joined rows fed to [[processPartition]]
+   * @param outputRowEncoder encoder of the produced rows (table columns plus `ROW_KIND_COL`)
+   */
+  case class MergeIntoProcessor(
+      joinedAttributes: Seq[Attribute],
+      targetRowHasNoMatch: Expression,
+      sourceRowHasNoMatch: Expression,
+      matchedConditions: Seq[Expression],
+      matchedOutputs: Seq[Seq[Expression]],
+      notMatchedConditions: Seq[Expression],
+      notMatchedOutputs: Seq[Seq[Expression]],
+      noopCopyOutput: Seq[Expression],
+      joinedRowEncoder: ExpressionEncoder[Row],
+      outputRowEncoder: ExpressionEncoder[Row]
+  ) extends Serializable {
+
+    private def generateProjection(exprs: Seq[Expression]): UnsafeProjection =
+      UnsafeProjection.create(exprs, joinedAttributes)
+
+    private def generatePredicate(expr: Expression): BasePredicate =
+      GeneratePredicate.generate(expr, joinedAttributes)
+
+    /** Maps the joined rows of one partition to the changed rows that must be written. */
+    def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = {
+      val targetRowHasNoMatchPred = generatePredicate(targetRowHasNoMatch)
+      val sourceRowHasNoMatchPred = generatePredicate(sourceRowHasNoMatch)
+      // Pair each clause's predicate with its projection once, instead of re-zipping per row.
+      val matchedClauses =
+        matchedConditions.map(generatePredicate).zip(matchedOutputs.map(generateProjection))
+      val notMatchedClauses =
+        notMatchedConditions.map(generatePredicate).zip(notMatchedOutputs.map(generateProjection))
+      val noopCopyProj = generateProjection(noopCopyOutput)
+      val outputProj = UnsafeProjection.create(outputRowEncoder.schema)
+      // Loop-invariant: resolve the row-kind position once per partition, not per row.
+      val rowKindOrdinal = outputRowEncoder.schema.fieldIndex(ROW_KIND_COL)
+
+      // Applies the projection of the first clause whose predicate holds, else the no-op copy.
+      def applyFirstMatchingClause(
+          clauses: Seq[(BasePredicate, UnsafeProjection)],
+          inputRow: InternalRow): InternalRow = {
+        clauses
+          .find { case (predicate, _) => predicate.eval(inputRow) }
+          .map { case (_, projection) => projection(inputRow) }
+          .getOrElse(noopCopyProj(inputRow))
+      }
+
+      def processRow(inputRow: InternalRow): InternalRow = {
+        if (targetRowHasNoMatchPred.eval(inputRow)) {
+          // A target row without a source partner is tagged as no-op and dropped below.
+          noopCopyProj(inputRow)
+        } else if (sourceRowHasNoMatchPred.eval(inputRow)) {
+          applyFirstMatchingClause(notMatchedClauses, inputRow)
+        } else {
+          applyFirstMatchingClause(matchedClauses, inputRow)
+        }
+      }
+
+      val toRow = joinedRowEncoder.createSerializer()
+      val fromRow = outputRowEncoder.createDeserializer()
+      rowIterator
+        .map(toRow)
+        .map(processRow)
+        .filterNot(_.getByte(rowKindOrdinal) == NOOP_ROW_KIND_VALUE)
+        .map(changedRow => fromRow(outputProj(changedRow)))
+    }
+  }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
index d5006ea64..93278ddb2 100644
---
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
@@ -18,7 +18,7 @@
package org.apache.paimon.spark.extensions
import org.apache.spark.sql.SparkSessionExtensions
-import org.apache.spark.sql.catalyst.analysis.{CoerceArguments,
PaimonAnalysis, ResolveProcedures, RewriteRowLevelCommands}
+import org.apache.spark.sql.catalyst.analysis.{CoerceArguments,
PaimonAnalysis, PaimonMergeInto, ResolveProcedures, RewriteRowLevelCommands}
import
org.apache.spark.sql.catalyst.parser.extensions.PaimonSparkSqlExtensionsParser
import org.apache.spark.sql.catalyst.plans.logical.PaimonTableValuedFunctions
import
org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy
@@ -34,7 +34,9 @@ class PaimonSparkSessionExtensions extends
(SparkSessionExtensions => Unit) {
extensions.injectResolutionRule(sparkSession => new
PaimonAnalysis(sparkSession))
extensions.injectResolutionRule(spark => ResolveProcedures(spark))
extensions.injectResolutionRule(_ => CoerceArguments)
+
extensions.injectPostHocResolutionRule(_ => RewriteRowLevelCommands)
+ extensions.injectPostHocResolutionRule(spark => PaimonMergeInto(spark))
// table function extensions
PaimonTableValuedFunctions.supportedFnNames.foreach {
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/Utils.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/Utils.scala
index e1bbdce7e..e371a18ed 100644
---
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/Utils.scala
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/Utils.scala
@@ -17,6 +17,8 @@
*/
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
/**
* Some classes or methods defined in the spark project are marked as private
under
* [[org.apache.spark.sql]] package, Hence, use this class to adapt then so
that we can use them
@@ -40,4 +42,7 @@ object Utils {
data.sqlContext.internalCreateDataFrame(data.queryExecution.toRdd,
data.schema)
}
+  /**
+   * Builds a DataFrame (`Dataset[Row]`) for `logicalPlan` via `Dataset.ofRows`, which is only
+   * accessible from inside the [[org.apache.spark.sql]] package.
+   */
+  def createDataset(sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[Row] =
+    Dataset.ofRows(sparkSession, logicalPlan)
}
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnalysisHelper.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnalysisHelper.scala
new file mode 100644
index 000000000..2ac7c3734
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnalysisHelper.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.paimon.spark.SparkTable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+import scala.util.control.NonFatal
+
+/** Helpers for recognising Paimon tables in (possibly aliased) logical plans. */
+trait AnalysisHelper extends Logging {
+
+  /**
+   * Whether `plan`, once subquery aliases are stripped, is a `DataSourceV2Relation` over a
+   * [[SparkTable]]. A non-fatal error during the check is logged and treated as `false`.
+   */
+  def isPaimonTable(plan: LogicalPlan): Boolean =
+    try {
+      findPaimonRelation(plan).isDefined
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Can't check if this plan is a paimon table", e)
+        false
+    }
+
+  /** Returns the Paimon relation underneath `plan`, or throws a RuntimeException if there is none. */
+  def getPaimonTableRelation(plan: LogicalPlan): DataSourceV2Relation =
+    findPaimonRelation(plan).getOrElse(
+      throw new RuntimeException(s"It's not a paimon table, $plan"))
+
+  // Shared matcher: the relation over a SparkTable once subquery aliases are removed, if any.
+  private def findPaimonRelation(plan: LogicalPlan): Option[DataSourceV2Relation] =
+    EliminateSubqueryAliases(plan) match {
+      case relation @ DataSourceV2Relation(_: SparkTable, _, _, _, _) => Some(relation)
+      case _ => None
+    }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonAnalysis.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonAnalysis.scala
index 04cf36fd1..6498fb5ae 100644
---
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonAnalysis.scala
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonAnalysis.scala
@@ -21,11 +21,11 @@ import org.apache.paimon.spark.SparkTable
import org.apache.paimon.table.FileStoreTable
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan,
OverwritePartitionsDynamic, PaimonDynamicPartitionOverwriteCommand,
PaimonTableValuedFunctions, PaimonTableValueFunction}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan,
MergeIntoTable, OverwritePartitionsDynamic,
PaimonDynamicPartitionOverwriteCommand, PaimonTableValuedFunctions,
PaimonTableValueFunction}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
+class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] with
AnalysisHelper {
override def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveOperatorsDown {
@@ -34,7 +34,11 @@ class PaimonAnalysis(session: SparkSession) extends
Rule[LogicalPlan] {
case o @ PaimonDynamicPartitionOverwrite(r, d) if o.resolved =>
PaimonDynamicPartitionOverwriteCommand(r, d, o.query, o.writeOptions,
o.isByName)
+
+ case merge: MergeIntoTable if isPaimonTable(merge.targetTable) &&
merge.childrenResolved =>
+ PaimonMergeIntoResolver(merge, session)
}
+
}
object PaimonDynamicPartitionOverwrite {
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonMergeInto.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonMergeInto.scala
new file mode 100644
index 000000000..7fa050398
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonMergeInto.scala
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.paimon.CoreOptions
+import org.apache.paimon.spark.SparkTable
+import org.apache.paimon.spark.commands.MergeIntoPaimonTable
+import org.apache.paimon.table.{FileStoreTable, PrimaryKeyFileStoreTable}
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.expressions.ExpressionHelper
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo,
Expression, SubqueryExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction,
InsertAction, InsertStarAction, LogicalPlan, MergeAction, MergeIntoTable,
UpdateAction, UpdateStarAction}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+import scala.collection.mutable
+
+/**
+ * A post-hoc resolution rule that validates a `MERGE INTO` whose target is a Paimon table and
+ * rewrites it into [[MergeIntoPaimonTable]]. Every other plan is returned unchanged.
+ */
+case class PaimonMergeInto(spark: SparkSession)
+  extends Rule[LogicalPlan]
+  with AnalysisHelper
+  with ExpressionHelper {
+
+  private val resolver: Resolver = spark.sessionState.conf.resolver
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan match {
+      // This rule is injected for every analyzed plan, so a MERGE INTO a non-Paimon table must
+      // pass through untouched instead of failing inside getPaimonTableRelation.
+      case merge: MergeIntoTable if isPaimonTable(merge.targetTable) =>
+        val relation = getPaimonTableRelation(merge.targetTable)
+        val v2Table = relation.table.asInstanceOf[SparkTable]
+        val targetOutput = relation.output
+
+        v2Table.getTable match {
+          case _: PrimaryKeyFileStoreTable =>
+          case _: FileStoreTable =>
+            throw new RuntimeException("Only support to merge into table with primary keys.")
+          case _ =>
+            throw new RuntimeException("Can't merge into a non-file store table.")
+        }
+
+        checkCondition(merge.mergeCondition)
+        merge.matchedActions.flatMap(_.condition).foreach(checkCondition)
+        merge.notMatchedActions.flatMap(_.condition).foreach(checkCondition)
+
+        val updateActions = merge.matchedActions.collect { case a: UpdateAction => a }
+        val primaryKeys = v2Table.properties().get(CoreOptions.PRIMARY_KEY.key).split(",")
+        checkUpdateActionValidity(
+          merge.targetTable,
+          merge.sourceTable,
+          merge.mergeCondition,
+          updateActions,
+          primaryKeys)
+
+        val alignedMatchedActions =
+          merge.matchedActions.map(checkAndAlignActionAssignment(_, targetOutput))
+        val alignedNotMatchedActions =
+          merge.notMatchedActions.map(checkAndAlignActionAssignment(_, targetOutput))
+
+        MergeIntoPaimonTable(
+          v2Table,
+          merge.targetTable,
+          merge.sourceTable,
+          merge.mergeCondition,
+          alignedMatchedActions,
+          alignedNotMatchedActions)
+
+      case _ => plan
+    }
+  }
+
+  /**
+   * Validates an action's assignments and aligns them with the target columns, in table order,
+   * wrapping each value with `castIfNeeded` for the column's data type.
+   *
+   * An UPDATE keeps unassigned columns unchanged. An INSERT must assign every column exactly once.
+   * DELETE actions are returned as is. Star actions are expected to have been expanded already.
+   */
+  private def checkAndAlignActionAssignment(
+      action: MergeAction,
+      targetOutput: Seq[AttributeReference]): MergeAction = {
+    action match {
+      case d @ DeleteAction(_) => d
+
+      case u @ UpdateAction(_, assignments) =>
+        val attrNameAndUpdateExpr = checkAndConvertAssignments(assignments, targetOutput)
+        val newAssignments = targetOutput.map {
+          field =>
+            attrNameAndUpdateExpr.find { case (name, _) => resolver(field.name, name) } match {
+              case Some((_, expr)) => Assignment(field, castIfNeeded(expr, field.dataType))
+              // A column not mentioned in the SET clause keeps its current value.
+              case None => Assignment(field, field)
+            }
+        }
+        u.copy(assignments = newAssignments)
+
+      case i @ InsertAction(_, assignments) =>
+        val attrNameAndInsertExpr = checkAndConvertAssignments(assignments, targetOutput)
+        if (assignments.length != targetOutput.length) {
+          throw new RuntimeException("Can't align the table's columns in insert clause.")
+        }
+
+        val newAssignments = targetOutput.map {
+          field =>
+            attrNameAndInsertExpr.find { case (name, _) => resolver(field.name, name) } match {
+              case Some((_, expr)) => Assignment(field, castIfNeeded(expr, field.dataType))
+              case None =>
+                throw new RuntimeException(s"Can't find the expression for ${field.name}.")
+            }
+        }
+        i.copy(assignments = newAssignments)
+
+      case _: UpdateStarAction =>
+        throw new RuntimeException(s"UpdateStarAction should not be here.")
+
+      case _: InsertStarAction =>
+        throw new RuntimeException(s"InsertStarAction should not be here.")
+
+      case _ =>
+        throw new RuntimeException(s"Can't recognize this action: $action")
+    }
+  }
+
+  /** Throws if `condition` is unresolved or contains a subquery. */
+  private def checkCondition(condition: Expression): Unit = {
+    if (!condition.resolved) {
+      throw new RuntimeException(s"Condition $condition should have been resolved.")
+    }
+    if (SubqueryExpression.hasSubquery(condition)) {
+      throw new RuntimeException(s"Condition $condition with subquery can't be supported.")
+    }
+  }
+
+  /**
+   * Returns a (column name, value) pair for each assignment, in order.
+   *
+   * Throws if a key is not a plain attribute of the target table, or if the same column name is
+   * assigned more than once.
+   */
+  private def checkAndConvertAssignments(
+      assignments: Seq[Assignment],
+      targetOutput: Seq[AttributeReference]): Seq[(String, Expression)] = {
+    // Insertion-ordered, so the duplicate-column error lists columns in statement order.
+    val columnToAssign = mutable.LinkedHashMap.empty[String, Int]
+    val pairs = assignments.map {
+      assignment =>
+        assignment.key match {
+          case a: AttributeReference =>
+            if (!targetOutput.exists(attr => resolver(attr.name, a.name))) {
+              throw new RuntimeException(
+                s"Ths key of assignment doesn't belong to the target table, $assignment")
+            }
+            columnToAssign(a.name) = columnToAssign.getOrElse(a.name, 0) + 1
+            (a.name, assignment.value)
+          case _ =>
+            throw new RuntimeException(
+              s"Only primitive type is supported in update/insert clause, $assignment")
+        }
+    }
+
+    val duplicatedColumns = columnToAssign.collect { case (column, count) if count > 1 => column }
+    if (duplicatedColumns.nonEmpty) {
+      val partOfMsg = duplicatedColumns.mkString(",")
+      throw new RuntimeException(
+        s"Can't update/insert the same column ($partOfMsg) multiple times.")
+    }
+
+    pairs
+  }
+
+  /**
+   * Rejects UPDATE actions that could modify primary-key columns.
+   *
+   * Updates are accepted when the merge condition equates every primary key of the target with a
+   * source column. Otherwise, no UPDATE assignment may pair a target primary-key column with a
+   * source column.
+   *
+   * NOTE(review): only attribute-to-attribute pairs are inspected, so assigning a non-attribute
+   * value (e.g. a literal) to a primary key is not caught here.
+   */
+  private def checkUpdateActionValidity(
+      target: LogicalPlan,
+      source: LogicalPlan,
+      mergeCondition: Expression,
+      actions: Seq[UpdateAction],
+      primaryKeys: Seq[String]): Unit = {
+    val targetOutput = target.outputSet
+    val sourceOutput = source.outputSet
+
+    // Whether `attr` is the given primary-key column and comes from the target table.
+    def isTargetPrimaryKey(attr: AttributeReference, primaryKey: String): Boolean =
+      resolver(primaryKey, attr.name) && targetOutput.contains(attr)
+
+    // Whether some expression is an `EqualTo` that pairs the target's primary key with a source
+    // column, in either operand order.
+    def existsPrimaryKeyEqualToExpression(
+        expressions: Seq[Expression],
+        primaryKey: String): Boolean = {
+      expressions.exists {
+        case EqualTo(left: AttributeReference, right: AttributeReference) =>
+          if (isTargetPrimaryKey(left, primaryKey)) {
+            sourceOutput.contains(right)
+          } else if (isTargetPrimaryKey(right, primaryKey)) {
+            // The key is on the right, so its partner on the left must come from the source.
+            sourceOutput.contains(left)
+          } else {
+            false
+          }
+        case _ => false
+      }
+    }
+
+    // The merge condition equates every primary key between target and source.
+    val isMergeConditionValid = {
+      val mergeExpressions = splitConjunctivePredicates(mergeCondition)
+      primaryKeys.forall(existsPrimaryKeyEqualToExpression(mergeExpressions, _))
+    }
+
+    // The action has no assignment that pairs a target primary key with a source column.
+    def isUpdateActionValid(action: UpdateAction): Boolean = {
+      val assignmentsAsEqualities = action.assignments.map(a => EqualTo(a.key, a.value))
+      !primaryKeys.exists(existsPrimaryKeyEqualToExpression(assignmentsAsEqualities, _))
+    }
+
+    val valid = isMergeConditionValid || actions.forall(isUpdateActionValid)
+    if (!valid) {
+      throw new RuntimeException("Can't update the primary key column in update clause.")
+    }
+  }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonMergeIntoResolver.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonMergeIntoResolver.scala
new file mode 100644
index 000000000..5331a9c40
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonMergeIntoResolver.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.expressions.ExpressionHelper
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction,
InsertAction, InsertStarAction, LogicalPlan, MergeAction, MergeIntoTable,
UpdateAction, UpdateStarAction}
+
+/**
+ * Resolves all the expressions of a `MERGE INTO` whose target and source are already resolved:
+ *   - the merge condition, and MATCHED conditions and values: against both target and source;
+ *   - INSERT conditions and values, and the values of `*` actions: against the source only;
+ *   - assignment keys: against the target only.
+ */
+object PaimonMergeIntoResolver extends AnalysisHelper with ExpressionHelper {
+
+  import org.apache.spark.sql.catalyst.plans.logical.Project
+
+  def apply(merge: MergeIntoTable, spark: SparkSession): LogicalPlan = {
+    val target = merge.targetTable
+    val source = merge.sourceTable
+    assert(target.resolved, "Target should have been resolved here.")
+    assert(source.resolved, "Source should have been resolved here.")
+
+    val resolve = resolveExpression(spark) _
+
+    // `resolveExpression` binds against the *children* of the plan it is given. Wrapping one side
+    // in an empty Project makes that side the only child, so its (possibly aliased) output is in
+    // scope. Passing `source` itself would expose only the source's own children.
+    val targetOnly: LogicalPlan = Project(Nil, target)
+    val sourceOnly: LogicalPlan = Project(Nil, source)
+
+    // Keys always name target columns; values are resolved in the given scope.
+    def resolveAssignments(
+        assignments: Seq[Assignment],
+        valueScope: LogicalPlan): Seq[Assignment] = {
+      assignments.map {
+        assignment =>
+          assignment.copy(
+            key = resolve(assignment.key, targetOnly),
+            value = resolve(assignment.value, valueScope))
+      }
+    }
+
+    // Expands star actions into one assignment per target column, taken by name from the source.
+    def starAssignments(): Seq[Assignment] = {
+      target.output.map {
+        attr => Assignment(attr, resolve(UnresolvedAttribute.quotedString(attr.name), sourceOnly))
+      }
+    }
+
+    def resolveMergeAction(action: MergeAction): MergeAction = {
+      action match {
+        case DeleteAction(actionCond) =>
+          DeleteAction(actionCond.map(resolve(_, merge)))
+        case UpdateAction(actionCond, assignments) =>
+          UpdateAction(actionCond.map(resolve(_, merge)), resolveAssignments(assignments, merge))
+        case UpdateStarAction(actionCond) =>
+          UpdateAction(actionCond.map(resolve(_, merge)), starAssignments())
+        case InsertAction(actionCond, assignments) =>
+          // INSERT applies only to unmatched source rows, so it can only see the source.
+          InsertAction(
+            actionCond.map(resolve(_, sourceOnly)),
+            resolveAssignments(assignments, sourceOnly))
+        case InsertStarAction(actionCond) =>
+          InsertAction(actionCond.map(resolve(_, sourceOnly)), starAssignments())
+        case _ =>
+          throw new RuntimeException(s"Can't recognize this action: $action")
+      }
+    }
+
+    val resolvedCond = resolve(merge.mergeCondition, merge)
+    val resolvedMatched: Seq[MergeAction] = merge.matchedActions.map(resolveMergeAction)
+    val resolvedNotMatched: Seq[MergeAction] = merge.notMatchedActions.map(resolveMergeAction)
+
+    merge.copy(
+      mergeCondition = resolvedCond,
+      matchedActions = resolvedMatched,
+      notMatchedActions = resolvedNotMatched)
+  }
+
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/expressions/ExpressionHelper.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/expressions/ExpressionHelper.scala
new file mode 100644
index 000000000..a4e2b6755
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/expressions/ExpressionHelper.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis.expressions
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Cast,
Expression, Literal, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, NullType}
+
+/**
+ * Helpers for resolving Catalyst expressions against a logical plan, extracting the conjuncts of a
+ * predicate that a plan can evaluate on its own, and casting expressions to a target type.
+ */
+trait ExpressionHelper extends PredicateHelper {
+
+  import ExpressionHelper._
+
+  /**
+   * Resolves `expr` with the session analyzer, using the children of `plan` (not `plan` itself)
+   * as the input scope. An expression that is already resolved is returned untouched.
+   *
+   * NOTE(review): the analyzer is invoked via `execute`, so the returned expression may presumably
+   * still be unresolved when a reference cannot be bound; callers relying on resolution should
+   * check `resolved` — confirm.
+   *
+   * @throws RuntimeException if the analyzer turns the placeholder plan into a different node
+   */
+  protected def resolveExpression(
+      spark: SparkSession)(expr: Expression, plan: LogicalPlan): Expression = {
+    if (expr.resolved) {
+      expr
+    } else {
+      analyzeInScopeOf(spark, Seq(expr), plan)
+        .map(_.head)
+        .getOrElse(throw new RuntimeException(s"Could not resolve expression $expr in plan: $plan"))
+    }
+  }
+
+  /**
+   * Resolves all of `exprs` in one analyzer pass, using the children of `plan` as the input scope.
+   * Unlike `resolveExpression`, the analyzer is always run, even if every input is resolved.
+   *
+   * @throws RuntimeException if the analyzer turns the placeholder plan into a different node
+   */
+  protected def resolveExpressions(
+      spark: SparkSession)(exprs: Seq[Expression], plan: LogicalPlan): Seq[Expression] =
+    analyzeInScopeOf(spark, exprs, plan).getOrElse(
+      throw new RuntimeException(s"Could not resolve expressions $exprs in plan: $plan"))
+
+  /**
+   * Wraps `exprs` in a `FakeLogicalPlan` over `plan.children`, runs the session analyzer on it and
+   * returns the analyzed expressions, or None if the analyzed plan is no longer a FakeLogicalPlan.
+   */
+  private def analyzeInScopeOf(
+      spark: SparkSession,
+      exprs: Seq[Expression],
+      plan: LogicalPlan): Option[Seq[Expression]] = {
+    spark.sessionState.analyzer.execute(FakeLogicalPlan(exprs, plan.children)) match {
+      case FakeLogicalPlan(analyzed, _) => Some(analyzed)
+      case _ => None
+    }
+  }
+
+  /**
+   * Splits `expression` into its AND-ed conjuncts, keeps only those that `plan` can evaluate by
+   * itself and that may be evaluated as a join condition, and ANDs the survivors back together.
+   *
+   * @return the recombined predicate, or None when no conjunct qualifies
+   */
+  protected def getExpressionOnlyRelated(
+      expression: Expression,
+      plan: LogicalPlan): Option[Expression] = {
+    val evaluableByPlan: Expression => Boolean =
+      conjunct => canEvaluate(conjunct, plan) && canEvaluateWithinJoin(conjunct)
+    splitConjunctivePredicates(expression).filter(evaluableByPlan).reduceOption(And)
+  }
+
+  /**
+   * Returns `fromExpression` typed as `toDataType`:
+   *   - an untyped NULL literal is simply re-typed;
+   *   - an expression whose type already matches (ignoring case and nullability) is returned as-is;
+   *   - anything else is wrapped in a `Cast` that uses the session-local time zone.
+   *
+   * @throws RuntimeException if Spark cannot cast the expression's type to `toDataType`
+   */
+  protected def castIfNeeded(fromExpression: Expression, toDataType: DataType): Expression =
+    fromExpression match {
+      case Literal(null, NullType) =>
+        Literal(null, toDataType)
+      case _ if !Cast.canCast(fromExpression.dataType, toDataType) =>
+        throw new RuntimeException(s"Can't cast from ${fromExpression.dataType} to $toDataType.")
+      case _ if DataType.equalsIgnoreCaseAndNullability(fromExpression.dataType, toDataType) =>
+        fromExpression
+      case _ =>
+        Cast(fromExpression, toDataType, Option(SQLConf.get.sessionLocalTimeZone))
+    }
+}
+
+object ExpressionHelper {
+
+  /**
+   * Placeholder logical plan used by `ExpressionHelper` to run the session analyzer over standalone
+   * expressions: it carries `exprs` on top of `children` and exposes no output attributes itself.
+   */
+  case class FakeLogicalPlan(exprs: Seq[Expression], children: Seq[LogicalPlan])
+    extends LogicalPlan {
+
+    override def output: Seq[Attribute] = Nil
+
+    override protected def withNewChildrenInternal(
+        newChildren: IndexedSeq[LogicalPlan]): FakeLogicalPlan =
+      FakeLogicalPlan(exprs, newChildren)
+  }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
new file mode 100644
index 000000000..c07131ec3
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
@@ -0,0 +1,622 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+
+import org.apache.spark.sql.Row
+
+/** Spark SQL tests for `MERGE INTO` on Paimon tables, covering supported and rejected statements. */
+class MergeIntoTableTest extends PaimonSparkTestBase {
+
+  import testImplicits._
+
+  /** Registers `rows` as the temp view `source` with columns (a, b, c). */
+  private def prepareSourceView(
+      rows: Seq[(Int, Int, String)] = Seq((1, 100, "c11"), (3, 300, "c33"))): Unit = {
+    rows.toDF("a", "b", "c").createOrReplaceTempView("source")
+  }
+
+  /**
+   * Creates the table `target` (a INT, b INT, c STRING) with primary key `a` and 2 buckets and,
+   * unless `values` is empty, inserts the given SQL VALUES list into it.
+   */
+  private def prepareTargetTable(values: String = "(1, 10, 'c1'), (2, 20, 'c2')"): Unit = {
+    spark.sql("""
+                |CREATE TABLE target (a INT, b INT, c STRING)
+                |TBLPROPERTIES ('primary-key'='a', 'bucket'='2')
+                |""".stripMargin)
+    if (values.nonEmpty) {
+      spark.sql(s"INSERT INTO target values $values")
+    }
+  }
+
+  /** Asserts that `target`, read ordered by (a, b), contains exactly `expected`. */
+  private def checkTargetRows(expected: Seq[Row]): Unit = {
+    checkAnswer(spark.sql("SELECT * FROM target ORDER BY a, b"), expected)
+  }
+
+  test("Paimon MergeInto: only update") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED THEN
+                  |UPDATE SET a = source.a, b = source.b, c = source.c
+                  |""".stripMargin)
+
+      checkTargetRows(Row(1, 100, "c11") :: Row(2, 20, "c2") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: only delete") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED THEN DELETE
+                  |""".stripMargin)
+
+      checkTargetRows(Row(2, 20, "c2") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: only insert") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN NOT MATCHED
+                  |THEN INSERT (a, b, c) values (a, b, c)
+                  |""".stripMargin)
+
+      checkTargetRows(Row(1, 10, "c1") :: Row(2, 20, "c2") :: Row(3, 300, "c33") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: update + insert") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED THEN
+                  |UPDATE SET a = source.a, b = source.b, c = source.c
+                  |WHEN NOT MATCHED
+                  |THEN INSERT (a, b, c) values (a, b, c)
+                  |""".stripMargin)
+
+      checkTargetRows(Row(1, 100, "c11") :: Row(2, 20, "c2") :: Row(3, 300, "c33") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: delete + insert") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED THEN
+                  |DELETE
+                  |WHEN NOT MATCHED
+                  |THEN INSERT (a, b, c) values (a, b, c)
+                  |""".stripMargin)
+
+      checkTargetRows(Row(2, 20, "c2") :: Row(3, 300, "c33") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: conditional update") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable("(1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3')")
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED AND source.b > 200 THEN
+                  |UPDATE SET b = source.b, c = source.c
+                  |WHEN MATCHED THEN
+                  |DELETE
+                  |WHEN NOT MATCHED
+                  |THEN INSERT (a, b, c) values (a, b, c)
+                  |""".stripMargin)
+
+      checkTargetRows(Row(2, 20, "c2") :: Row(3, 300, "c33") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: conditional insert") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED THEN
+                  |UPDATE SET b = source.b, c = source.c
+                  |WHEN NOT MATCHED AND b < 300 THEN
+                  |INSERT (a, b, c) values (a, b, c)
+                  |""".stripMargin)
+
+      checkTargetRows(Row(1, 100, "c11") :: Row(2, 20, "c2") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: conditional delete") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED AND target.c < 'c1' THEN
+                  |DELETE
+                  |WHEN NOT MATCHED THEN
+                  |INSERT (a, b, c) values (a, b, c)
+                  |""".stripMargin)
+
+      checkTargetRows(Row(1, 10, "c1") :: Row(2, 20, "c2") :: Row(3, 300, "c33") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: star") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED THEN
+                  |UPDATE SET *
+                  |WHEN NOT MATCHED
+                  |THEN INSERT *
+                  |""".stripMargin)
+
+      checkTargetRows(Row(1, 100, "c11") :: Row(2, 20, "c2") :: Row(3, 300, "c33") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: multiple clauses") {
+    withTable("source", "target") {
+      prepareSourceView(
+        Seq((1, 100, "c11"), (3, 300, "c33"), (5, 500, "c55"), (7, 700, "c77"), (9, 900, "c99")))
+      prepareTargetTable("(1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED AND target.a = 5 THEN
+                  |UPDATE SET b = source.b + target.b
+                  |WHEN MATCHED AND source.c > 'c2' THEN
+                  |UPDATE SET *
+                  |WHEN MATCHED THEN
+                  |DELETE
+                  |WHEN NOT MATCHED AND c > 'c9' THEN
+                  |INSERT (a, b, c) VALUES (a, b * 1.1, c)
+                  |WHEN NOT MATCHED THEN
+                  |INSERT *
+                  |""".stripMargin)
+
+      checkTargetRows(
+        Row(2, 20, "c2") :: Row(3, 300, "c33") :: Row(4, 40, "c4") :: Row(5, 550, "c5") ::
+          Row(7, 700, "c77") :: Row(9, 990, "c99") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: source and target are empty") {
+    withTable("source", "target") {
+      prepareSourceView(Seq.empty)
+      // Empty VALUES list: create the target table without inserting any row.
+      prepareTargetTable(values = "")
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED THEN
+                  |UPDATE SET a = source.a, b = source.b, c = source.c
+                  |WHEN NOT MATCHED
+                  |THEN INSERT *
+                  |""".stripMargin)
+
+      checkTargetRows(Nil)
+    }
+  }
+
+  test("Paimon MergeInto: update value from both source and target table") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED THEN
+                  |UPDATE SET b = target.b * 11, c = source.c
+                  |WHEN NOT MATCHED
+                  |THEN INSERT (a, b, c) values (source.a, source.b * 2, source.c)
+                  |""".stripMargin)
+
+      checkTargetRows(Row(1, 110, "c11") :: Row(2, 20, "c2") :: Row(3, 600, "c33") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: insert/update columns in wrong order") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED THEN
+                  |UPDATE SET c = source.c, b = source.b
+                  |WHEN NOT MATCHED
+                  |THEN INSERT (b, c, a) values (b, c, a)
+                  |""".stripMargin)
+
+      checkTargetRows(Row(1, 100, "c11") :: Row(2, 20, "c2") :: Row(3, 300, "c33") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: miss some columns in update") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.a = source.a
+                  |WHEN MATCHED THEN
+                  |UPDATE SET c = source.c
+                  |WHEN NOT MATCHED
+                  |THEN INSERT *
+                  |""".stripMargin)
+
+      checkTargetRows(Row(1, 10, "c11") :: Row(2, 20, "c2") :: Row(3, 300, "c33") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: fail in case that miss some columns in insert") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      val error = intercept[RuntimeException] {
+        spark.sql("""
+                    |MERGE INTO target
+                    |USING source
+                    |ON target.a = source.a
+                    |WHEN MATCHED THEN
+                    |UPDATE SET *
+                    |WHEN NOT MATCHED
+                    |THEN INSERT (a, b) VALUES (a, b)
+                    |""".stripMargin)
+      }.getMessage
+      assert(error.contains("Can't align the table's columns in insert clause."))
+    }
+  }
+
+  test("Paimon MergeInto: source is a query") {
+    withTable("source", "target") {
+      prepareSourceView(Seq((1, 100, "c11"), (3, 300, "c33"), (4, 400, "c44")))
+      prepareTargetTable()
+
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING (SELECT a, b, c FROM source WHERE a % 2 = 1) AS src
+                  |ON target.a = src.a
+                  |WHEN MATCHED THEN
+                  |UPDATE SET b = src.b, c = src.c
+                  |WHEN NOT MATCHED
+                  |THEN INSERT *
+                  |""".stripMargin)
+
+      checkTargetRows(Row(1, 100, "c11") :: Row(2, 20, "c2") :: Row(3, 300, "c33") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: fail in case that more than one source rows match the same target row") {
+    withTable("source", "target") {
+      prepareSourceView(Seq((1, 100, "c11"), (1, 1000, "c111"), (3, 300, "c33")))
+      prepareTargetTable()
+
+      val error = intercept[RuntimeException] {
+        spark.sql("""
+                    |MERGE INTO target
+                    |USING source
+                    |ON target.a = source.a
+                    |WHEN MATCHED THEN
+                    |UPDATE SET b = source.b, c = source.c
+                    |WHEN NOT MATCHED
+                    |THEN INSERT (a, b, c) values (a, b, c)
+                    |""".stripMargin)
+      }.getMessage
+      // The expected substring must match the command's error text verbatim.
+      assert(error.contains("match more then one source rows"))
+    }
+  }
+
+  test("Paimon MergeInto: fail in case that update/insert same column multiple times") {
+    withTable("source", "target") {
+      prepareSourceView()
+      prepareTargetTable()
+
+      val error1 = intercept[RuntimeException] {
+        spark.sql("""
+                    |MERGE INTO target
+                    |USING source
+                    |ON target.a = source.a
+                    |WHEN MATCHED THEN
+                    |UPDATE SET a = source.a, b = source.b, b = source.c
+                    |WHEN NOT MATCHED
+                    |THEN INSERT *
+                    |""".stripMargin)
+      }.getMessage
+      assert(error1.contains("Can't update/insert the same column (b) multiple times."))
+
+      val error2 = intercept[RuntimeException] {
+        spark.sql("""
+                    |MERGE INTO target
+                    |USING source
+                    |ON target.a = source.a
+                    |WHEN MATCHED THEN
+                    |UPDATE SET *
+                    |WHEN NOT MATCHED
+                    |THEN INSERT (a, a, c) VALUES (a, b, c)
+                    |""".stripMargin)
+      }.getMessage
+      assert(error2.contains("Can't update/insert the same column (a) multiple times."))
+    }
+  }
+
+  test("Paimon MergeInto: fail in case that update nested column") {
+    withTable("source", "target") {
+      Seq((1, 100, "x1", "y1"), (3, 300, "x3", "y3"))
+        .toDF("a", "b", "c1", "c2")
+        .createOrReplaceTempView("source")
+
+      spark.sql("""
+                  |CREATE TABLE target (a INT, b INT, c STRUCT<c1:STRING, c2:STRING>)
+                  |TBLPROPERTIES ('primary-key'='a', 'bucket'='2')
+                  |""".stripMargin)
+      spark.sql("INSERT INTO target values (1, 10, struct('x', 'y')), (2, 20, struct('x', 'y'))")
+
+      val error = intercept[RuntimeException] {
+        spark.sql("""
+                    |MERGE INTO target
+                    |USING source
+                    |ON target.a = source.a
+                    |WHEN MATCHED THEN
+                    |UPDATE SET c.c1 = source.c1
+                    |""".stripMargin)
+      }.getMessage
+      assert(error.contains("Only primitive type is supported"))
+    }
+  }
+
+  test("Paimon MergeInto: fail in case that maybe update primary key column") {
+    withTable("source", "target") {
+      prepareSourceView(Seq((101, 10, "c111"), (103, 30, "c333")))
+      prepareTargetTable()
+
+      // Joining on a non-key column while assigning the key column must be rejected.
+      val error = intercept[RuntimeException] {
+        spark.sql("""
+                    |MERGE INTO target
+                    |USING source
+                    |ON target.b = source.b
+                    |WHEN MATCHED THEN
+                    |UPDATE SET a = source.a, b = source.b, c = source.c
+                    |WHEN NOT MATCHED
+                    |THEN INSERT *
+                    |""".stripMargin)
+      }.getMessage
+      assert(error.contains("Can't update the primary key column"))
+
+      // The same join is accepted when the key column is left untouched.
+      spark.sql("""
+                  |MERGE INTO target
+                  |USING source
+                  |ON target.b = source.b
+                  |WHEN MATCHED THEN
+                  |UPDATE SET c = source.c
+                  |WHEN NOT MATCHED
+                  |THEN INSERT *
+                  |""".stripMargin)
+      checkTargetRows(Row(1, 10, "c111") :: Row(2, 20, "c2") :: Row(103, 30, "c333") :: Nil)
+    }
+  }
+
+  test("Paimon MergeInto: not support in table without primary keys") {
+    withTable("source", "target") {
+      prepareSourceView()
+      // Unlike prepareTargetTable(), this target has no primary key.
+      spark.sql("CREATE TABLE target (a INT, b INT, c STRING)")
+      spark.sql("INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2')")
+
+      val error = intercept[RuntimeException] {
+        spark.sql("""
+                    |MERGE INTO target
+                    |USING source
+                    |ON target.a = source.a
+                    |WHEN MATCHED THEN
+                    |UPDATE SET a = source.a, b = source.b, c = source.c
+                    |WHEN NOT MATCHED
+                    |THEN INSERT (a, b, c) values (a, b, c)
+                    |""".stripMargin)
+      }.getMessage
+      assert(error.contains("Only support to merge into table with primary keys."))
+    }
+  }
+}