This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3f3b52969585 [SPARK-49427][CONNECT][SQL] Create a shared interface for MergeIntoWriter
3f3b52969585 is described below
commit 3f3b52969585315f9218d58bd2dc438414e4ad38
Author: Herman van Hovell <[email protected]>
AuthorDate: Thu Sep 5 00:20:47 2024 -0400
[SPARK-49427][CONNECT][SQL] Create a shared interface for MergeIntoWriter
### What changes were proposed in this pull request?
This PR creates a shared interface for MergeIntoWriter.
### Why are the changes needed?
We are creating a shared Scala Spark SQL interface for Classic and Connect.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47963 from hvanhovell/SPARK-49427.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 27 +-
.../org/apache/spark/sql/MergeIntoWriter.scala | 415 ---------------------
.../spark/sql/internal/MergeIntoWriterImpl.scala | 135 +++++++
.../org/apache/spark/sql/MergeIntoWriter.scala | 151 +++-----
.../scala/org/apache/spark/sql/api/Dataset.scala | 26 +-
.../sql/catalyst/plans/logical/v2Commands.scala | 6 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 18 +-
.../main/scala/org/apache/spark/sql/Dataset.scala | 4 +-
.../spark/sql/internal/MergeIntoWriterImpl.scala | 125 +++++++
.../sql/connector/MergeIntoDataFrameSuite.scala | 12 +-
10 files changed, 356 insertions(+), 563 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 778cd153ec2e..552512fee8cd 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevel
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions.{struct, to_json}
-import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter,
DataFrameWriterImpl, DataFrameWriterV2Impl, ToScalaUDF, UDFAdaptors,
UnresolvedAttribute, UnresolvedRegex}
+import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter,
DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, ToScalaUDF,
UDFAdaptors, UnresolvedAttribute, UnresolvedRegex}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types.{Metadata, StructType}
import org.apache.spark.storage.StorageLevel
@@ -1023,28 +1023,7 @@ class Dataset[T] private[sql] (
new DataFrameWriterV2Impl[T](table, this)
}
- /**
- * Merges a set of updates, insertions, and deletions based on a source
table into a target
- * table.
- *
- * Scala Examples:
- * {{{
- * spark.table("source")
- * .mergeInto("target", $"source.id" === $"target.id")
- * .whenMatched($"salary" === 100)
- * .delete()
- * .whenNotMatched()
- * .insertAll()
- * .whenNotMatchedBySource($"salary" === 100)
- * .update(Map(
- * "salary" -> lit(200)
- * ))
- * .merge()
- * }}}
- *
- * @group basic
- * @since 4.0.0
- */
+ /** @inheritdoc */
def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = {
if (isStreaming) {
throw new AnalysisException(
@@ -1052,7 +1031,7 @@ class Dataset[T] private[sql] (
messageParameters = Map("methodName" -> toSQLId("mergeInto")))
}
- new MergeIntoWriter[T](table, this, condition)
+ new MergeIntoWriterImpl[T](table, this, condition)
}
/**
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
deleted file mode 100644
index 71813af1e354..000000000000
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
+++ /dev/null
@@ -1,415 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import scala.jdk.CollectionConverters._
-
-import org.apache.spark.SparkRuntimeException
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{Expression, MergeIntoTableCommand}
-import org.apache.spark.connect.proto.MergeAction
-import org.apache.spark.sql.functions.expr
-
-/**
- * `MergeIntoWriter` provides methods to define and execute merge actions
based on specified
- * conditions.
- *
- * @tparam T
- * the type of data in the Dataset.
- * @param table
- * the name of the target table for the merge operation.
- * @param ds
- * the source Dataset to merge into the target table.
- * @param on
- * the merge condition.
- * @param schemaEvolutionEnabled
- * whether to enable automatic schema evolution for this merge operation.
Default is `false`.
- *
- * @since 4.0.0
- */
-@Experimental
-class MergeIntoWriter[T] private[sql] (
- table: String,
- ds: Dataset[T],
- on: Column,
- schemaEvolutionEnabled: Boolean = false) {
- import ds.sparkSession.RichColumn
-
- private[sql] var matchedActions: Seq[MergeAction] = Seq.empty[MergeAction]
- private[sql] var notMatchedActions: Seq[MergeAction] = Seq.empty[MergeAction]
- private[sql] var notMatchedBySourceActions: Seq[MergeAction] =
Seq.empty[MergeAction]
-
- /**
- * Initialize a `WhenMatched` action without any condition.
- *
- * This `WhenMatched` action will be executed when a source row matches a
target table row based
- * on the merge condition.
- *
- * This `WhenMatched` can be followed by one of the following merge actions:
- * - `updateAll`: Update all the matched target table rows with source
dataset rows.
- * - `update(Map)`: Update all the matched target table rows while
changing only a subset of
- * columns based on the provided assignment.
- * - `delete`: Delete all target rows that have a match in the source
table.
- *
- * @return
- * a new `WhenMatched` object.
- */
- def whenMatched(): WhenMatched[T] = {
- new WhenMatched[T](this, None)
- }
-
- /**
- * Initialize a `WhenMatched` action with a condition.
- *
- * This `WhenMatched` action will be executed when a source row matches a
target table row based
- * on the merge condition and the specified `condition` is satisfied.
- *
- * This `WhenMatched` can be followed by one of the following merge actions:
- * - `updateAll`: Update all the matched target table rows with source
dataset rows.
- * - `update(Map)`: Update all the matched target table rows while
changing only a subset of
- * columns based on the provided assignment.
- * - `delete`: Delete all target rows that have a match in the source
table.
- *
- * @param condition
- * a `Column` representing the condition to be evaluated for the action.
- * @return
- * a new `WhenMatched` object configured with the specified condition.
- */
- def whenMatched(condition: Column): WhenMatched[T] = {
- new WhenMatched[T](this, Some(condition))
- }
-
- /**
- * Initialize a `WhenNotMatched` action without any condition.
- *
- * This `WhenNotMatched` action will be executed when a source row does not
match any target row
- * based on the merge condition.
- *
- * This `WhenNotMatched` can be followed by one of the following merge
actions:
- * - `insertAll`: Insert all rows from the source that are not already in
the target table.
- * - `insert(Map)`: Insert all rows from the source that are not already
in the target table,
- * with the specified columns based on the provided assignment.
- *
- * @return
- * a new `WhenNotMatched` object.
- */
- def whenNotMatched(): WhenNotMatched[T] = {
- new WhenNotMatched[T](this, None)
- }
-
- /**
- * Initialize a `WhenNotMatched` action with a condition.
- *
- * This `WhenNotMatched` action will be executed when a source row does not
match any target row
- * based on the merge condition and the specified `condition` is satisfied.
- *
- * This `WhenNotMatched` can be followed by one of the following merge
actions:
- * - `insertAll`: Insert all rows from the source that are not already in
the target table.
- * - `insert(Map)`: Insert all rows from the source that are not already
in the target table,
- * with the specified columns based on the provided assignment.
- *
- * @param condition
- * a `Column` representing the condition to be evaluated for the action.
- * @return
- * a new `WhenNotMatched` object configured with the specified condition.
- */
- def whenNotMatched(condition: Column): WhenNotMatched[T] = {
- new WhenNotMatched[T](this, Some(condition))
- }
-
- /**
- * Initialize a `WhenNotMatchedBySource` action without any condition.
- *
- * This `WhenNotMatchedBySource` action will be executed when a target row
does not match any
- * rows in the source table based on the merge condition.
- *
- * This `WhenNotMatchedBySource` can be followed by one of the following
merge actions:
- * - `updateAll`: Update all the not matched target table rows with source
dataset rows.
- * - `update(Map)`: Update all the not matched target table rows while
changing only the
- * specified columns based on the provided assignment.
- * - `delete`: Delete all target rows that have no matches in the source
table.
- *
- * @return
- * a new `WhenNotMatchedBySource` object.
- */
- def whenNotMatchedBySource(): WhenNotMatchedBySource[T] = {
- new WhenNotMatchedBySource[T](this, None)
- }
-
- /**
- * Initialize a `WhenNotMatchedBySource` action with a condition.
- *
- * This `WhenNotMatchedBySource` action will be executed when a target row
does not match any
- * rows in the source table based on the merge condition and the specified
`condition` is
- * satisfied.
- *
- * This `WhenNotMatchedBySource` can be followed by one of the following
merge actions:
- * - `updateAll`: Update all the not matched target table rows with source
dataset rows.
- * - `update(Map)`: Update all the not matched target table rows while
changing only the
- * specified columns based on the provided assignment.
- * - `delete`: Delete all target rows that have no matches in the source
table.
- *
- * @param condition
- * a `Column` representing the condition to be evaluated for the action.
- * @return
- * a new `WhenNotMatchedBySource` object configured with the specified
condition.
- */
- def whenNotMatchedBySource(condition: Column): WhenNotMatchedBySource[T] = {
- new WhenNotMatchedBySource[T](this, Some(condition))
- }
-
- /**
- * Enable automatic schema evolution for this merge operation.
- * @return
- * A `MergeIntoWriter` instance with schema evolution enabled.
- */
- def withSchemaEvolution(): MergeIntoWriter[T] = {
- new MergeIntoWriter[T](this.table, this.ds, this.on,
schemaEvolutionEnabled = true)
- .withNewMatchedActions(this.matchedActions: _*)
- .withNewNotMatchedActions(this.notMatchedActions: _*)
- .withNewNotMatchedBySourceActions(this.notMatchedBySourceActions: _*)
- }
-
- /**
- * Executes the merge operation.
- */
- def merge(): Unit = {
- if (matchedActions.isEmpty && notMatchedActions.isEmpty &&
notMatchedBySourceActions.isEmpty) {
- throw new SparkRuntimeException(
- errorClass = "NO_MERGE_ACTION_SPECIFIED",
- messageParameters = Map.empty)
- }
-
- val matchedActionExpressions =
-
matchedActions.map(Expression.newBuilder().setMergeAction(_)).map(_.build())
- val notMatchedActionExpressions =
-
notMatchedActions.map(Expression.newBuilder().setMergeAction(_)).map(_.build())
- val notMatchedBySourceActionExpressions =
-
notMatchedBySourceActions.map(Expression.newBuilder().setMergeAction(_)).map(_.build())
- val mergeIntoCommand = MergeIntoTableCommand
- .newBuilder()
- .setTargetTableName(table)
- .setSourceTablePlan(ds.plan.getRoot)
- .setMergeCondition(on.expr)
- .addAllMatchActions(matchedActionExpressions.asJava)
- .addAllNotMatchedActions(notMatchedActionExpressions.asJava)
-
.addAllNotMatchedBySourceActions(notMatchedBySourceActionExpressions.asJava)
- .setWithSchemaEvolution(schemaEvolutionEnabled)
- .build()
-
- ds.sparkSession.execute(
- proto.Command
- .newBuilder()
- .setMergeIntoTableCommand(mergeIntoCommand)
- .build())
- }
-
- private[sql] def withNewMatchedActions(action: MergeAction*):
MergeIntoWriter[T] = {
- this.matchedActions = this.matchedActions :++ action
- this
- }
-
- private[sql] def withNewNotMatchedActions(action: MergeAction*):
MergeIntoWriter[T] = {
- this.notMatchedActions = this.notMatchedActions :++ action
- this
- }
-
- private[sql] def withNewNotMatchedBySourceActions(action: MergeAction*):
MergeIntoWriter[T] = {
- this.notMatchedBySourceActions = this.notMatchedBySourceActions :++ action
- this
- }
-
- private[sql] def buildMergeAction(
- actionType: MergeAction.ActionType,
- conditionOpt: Option[Column],
- assignmentsOpt: Option[Map[String, Column]] = None): MergeAction = {
- val assignmentsProtoOpt = assignmentsOpt.map {
- _.map { case (k, v) =>
- MergeAction.Assignment
- .newBuilder()
- .setKey(expr(k).expr)
- .setValue(v.expr)
- .build()
- }.toSeq.asJava
- }
-
- val builder = MergeAction.newBuilder().setActionType(actionType)
- conditionOpt.map(c => builder.setCondition(c.expr))
- assignmentsProtoOpt.map(builder.addAllAssignments)
- builder.build()
- }
-}
-
-/**
- * A class for defining actions to be taken when matching rows in a DataFrame
during a merge
- * operation.
- *
- * @param mergeIntoWriter
- * The MergeIntoWriter instance responsible for writing data to a target
DataFrame.
- * @param condition
- * An optional condition Expression that specifies when the actions should
be applied. If the
- * condition is None, the actions will be applied to all matched rows.
- *
- * @tparam T
- * The type of data in the MergeIntoWriter.
- */
-case class WhenMatched[T] private[sql] (
- mergeIntoWriter: MergeIntoWriter[T],
- condition: Option[Column]) {
-
- /**
- * Specifies an action to update all matched rows in the DataFrame.
- *
- * @return
- * The MergeIntoWriter instance with the update all action configured.
- */
- def updateAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE_STAR,
condition))
- }
-
- /**
- * Specifies an action to update matched rows in the DataFrame with the
provided column
- * assignments.
- *
- * @param map
- * A Map of column names to Column expressions representing the updates to
be applied.
- * @return
- * The MergeIntoWriter instance with the update action configured.
- */
- def update(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE,
condition, Some(map)))
- }
-
- /**
- * Specifies an action to delete matched rows from the DataFrame.
- *
- * @return
- * The MergeIntoWriter instance with the delete action configured.
- */
- def delete(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(
-
mergeIntoWriter.buildMergeAction(MergeAction.ActionType.ACTION_TYPE_DELETE,
condition))
- }
-}
-
-/**
- * A class for defining actions to be taken when no matching rows are found in
a DataFrame during
- * a merge operation.
- *
- * @param mergeIntoWriter
- * The MergeIntoWriter instance responsible for writing data to a target
DataFrame.
- * @param condition
- * An optional condition Expression that specifies when the actions defined
in this
- * configuration should be applied. If the condition is None, the actions
will be applied when
- * there are no matching rows.
- * @tparam T
- * The type of data in the MergeIntoWriter.
- */
-case class WhenNotMatched[T] private[sql] (
- mergeIntoWriter: MergeIntoWriter[T],
- condition: Option[Column]) {
-
- /**
- * Specifies an action to insert all non-matched rows into the DataFrame.
- *
- * @return
- * The MergeIntoWriter instance with the insert all action configured.
- */
- def insertAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_INSERT_STAR,
condition))
- }
-
- /**
- * Specifies an action to insert non-matched rows into the DataFrame with
the provided column
- * assignments.
- *
- * @param map
- * A Map of column names to Column expressions representing the values to
be inserted.
- * @return
- * The MergeIntoWriter instance with the insert action configured.
- */
- def insert(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_INSERT,
condition, Some(map)))
- }
-}
-
-/**
- * A class for defining actions to be performed when there is no match by
source during a merge
- * operation in a MergeIntoWriter.
- *
- * @param mergeIntoWriter
- * the MergeIntoWriter instance to which the merge actions will be applied.
- * @param condition
- * an optional condition to be used with the merge actions.
- * @tparam T
- * the type parameter for the MergeIntoWriter.
- */
-case class WhenNotMatchedBySource[T] private[sql] (
- mergeIntoWriter: MergeIntoWriter[T],
- condition: Option[Column]) {
-
- /**
- * Specifies an action to update all non-matched rows in the target
DataFrame when not matched
- * by the source.
- *
- * @return
- * The MergeIntoWriter instance with the update all action configured.
- */
- def updateAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE_STAR,
condition))
- }
-
- /**
- * Specifies an action to update non-matched rows in the target DataFrame
with the provided
- * column assignments when not matched by the source.
- *
- * @param map
- * A Map of column names to Column expressions representing the updates to
be applied.
- * @return
- * The MergeIntoWriter instance with the update action configured.
- */
- def update(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE,
condition, Some(map)))
- }
-
- /**
- * Specifies an action to delete non-matched rows from the target DataFrame
when not matched by
- * the source.
- *
- * @return
- * The MergeIntoWriter instance with the delete action configured.
- */
- def delete(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_DELETE,
condition))
- }
-}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
new file mode 100644
index 000000000000..fba3c6343558
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import org.apache.spark.SparkRuntimeException
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.{Expression, MergeAction,
MergeIntoTableCommand}
+import org.apache.spark.connect.proto.MergeAction.ActionType._
+import org.apache.spark.sql.{Column, Dataset, MergeIntoWriter}
+import org.apache.spark.sql.functions.expr
+
+/**
+ * `MergeIntoWriter` provides methods to define and execute merge actions
based on specified
+ * conditions.
+ *
+ * @tparam T
+ * the type of data in the Dataset.
+ * @param table
+ * the name of the target table for the merge operation.
+ * @param ds
+ * the source Dataset to merge into the target table.
+ * @param on
+ * the merge condition.
+ *
+ * @since 4.0.0
+ */
+@Experimental
+class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on:
Column)
+ extends MergeIntoWriter[T] {
+ import ds.sparkSession.RichColumn
+
+ private val builder = MergeIntoTableCommand
+ .newBuilder()
+ .setTargetTableName(table)
+ .setSourceTablePlan(ds.plan.getRoot)
+ .setMergeCondition(on.expr)
+
+ /**
+ * Executes the merge operation.
+ */
+ def merge(): Unit = {
+ if (builder.getMatchActionsCount == 0 &&
+ builder.getNotMatchedActionsCount == 0 &&
+ builder.getNotMatchedBySourceActionsCount == 0) {
+ throw new SparkRuntimeException(
+ errorClass = "NO_MERGE_ACTION_SPECIFIED",
+ messageParameters = Map.empty)
+ }
+ ds.sparkSession.execute(
+ proto.Command
+ .newBuilder()
+
.setMergeIntoTableCommand(builder.setWithSchemaEvolution(schemaEvolutionEnabled))
+ .build())
+ }
+
+ override protected[sql] def insertAll(condition: Option[Column]):
MergeIntoWriter[T] = {
+ builder.addNotMatchedActions(buildMergeAction(ACTION_TYPE_INSERT_STAR,
condition))
+ this
+ }
+
+ override protected[sql] def insert(
+ condition: Option[Column],
+ map: Map[String, Column]): MergeIntoWriter[T] = {
+ builder.addNotMatchedActions(buildMergeAction(ACTION_TYPE_INSERT,
condition, map))
+ this
+ }
+
+ override protected[sql] def updateAll(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(
+ buildMergeAction(ACTION_TYPE_UPDATE_STAR, condition),
+ notMatchedBySource)
+ }
+
+ override protected[sql] def update(
+ condition: Option[Column],
+ map: Map[String, Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(
+ buildMergeAction(ACTION_TYPE_UPDATE, condition, map),
+ notMatchedBySource)
+ }
+
+ override protected[sql] def delete(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(buildMergeAction(ACTION_TYPE_DELETE, condition),
notMatchedBySource)
+ }
+
+ private def appendUpdateDeleteAction(
+ action: Expression,
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ if (notMatchedBySource) {
+ builder.addNotMatchedBySourceActions(action)
+ } else {
+ builder.addMatchActions(action)
+ }
+ this
+ }
+
+ private def buildMergeAction(
+ actionType: MergeAction.ActionType,
+ condition: Option[Column],
+ assignments: Map[String, Column] = Map.empty): Expression = {
+ val builder = proto.MergeAction.newBuilder().setActionType(actionType)
+ condition.foreach(c => builder.setCondition(c.expr))
+ assignments.foreach { case (k, v) =>
+ builder
+ .addAssignmentsBuilder()
+ .setKey(expr(k).expr)
+ .setValue(v.expr)
+ }
+ Expression
+ .newBuilder()
+ .setMergeAction(builder)
+ .build()
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
similarity index 68%
rename from sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
index 6212a7fdb259..dabd900917e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
@@ -14,48 +14,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.sql
-import org.apache.spark.SparkRuntimeException
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction,
InsertAction, InsertStarAction, MergeAction, MergeIntoTable, UpdateAction,
UpdateStarAction}
-import org.apache.spark.sql.functions.expr
/**
* `MergeIntoWriter` provides methods to define and execute merge actions based
* on specified conditions.
*
- * @tparam T the type of data in the Dataset.
- * @param table the name of the target table for the merge operation.
- * @param ds the source Dataset to merge into the target table.
- * @param on the merge condition.
- * @param schemaEvolutionEnabled whether to enable automatic schema evolution
for this merge
- * operation. Default is `false`.
+ * Please note that schema evolution is disabled by default.
*
+ * @tparam T the type of data in the Dataset.
* @since 4.0.0
*/
@Experimental
-class MergeIntoWriter[T] private[sql] (
- table: String,
- ds: Dataset[T],
- on: Column,
- private[sql] val schemaEvolutionEnabled: Boolean = false) {
-
- private val df: DataFrame = ds.toDF()
-
- private[sql] val sparkSession = ds.sparkSession
- import sparkSession.RichColumn
-
- private val tableName =
sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
-
- private val logicalPlan = df.queryExecution.logical
+abstract class MergeIntoWriter[T] {
+ private var schemaEvolution: Boolean = false
- private[sql] var matchedActions: Seq[MergeAction] = Seq.empty[MergeAction]
- private[sql] var notMatchedActions: Seq[MergeAction] = Seq.empty[MergeAction]
- private[sql] var notMatchedBySourceActions: Seq[MergeAction] =
Seq.empty[MergeAction]
+ private[sql] def schemaEvolutionEnabled: Boolean = schemaEvolution
/**
* Initialize a `WhenMatched` action without any condition.
@@ -91,7 +67,7 @@ class MergeIntoWriter[T] private[sql] (
* @return a new `WhenMatched` object configured with the specified
condition.
*/
def whenMatched(condition: Column): WhenMatched[T] = {
- new WhenMatched[T](this, Some(condition.expr))
+ new WhenMatched[T](this, Some(condition))
}
/**
@@ -126,7 +102,7 @@ class MergeIntoWriter[T] private[sql] (
* @return a new `WhenNotMatched` object configured with the specified
condition.
*/
def whenNotMatched(condition: Column): WhenNotMatched[T] = {
- new WhenNotMatched[T](this, Some(condition.expr))
+ new WhenNotMatched[T](this, Some(condition))
}
/**
@@ -164,57 +140,43 @@ class MergeIntoWriter[T] private[sql] (
* @return a new `WhenNotMatchedBySource` object configured with the
specified condition.
*/
def whenNotMatchedBySource(condition: Column): WhenNotMatchedBySource[T] = {
- new WhenNotMatchedBySource[T](this, Some(condition.expr))
+ new WhenNotMatchedBySource[T](this, Some(condition))
}
/**
* Enable automatic schema evolution for this merge operation.
+ *
* @return A `MergeIntoWriter` instance with schema evolution enabled.
*/
def withSchemaEvolution(): MergeIntoWriter[T] = {
- new MergeIntoWriter[T](this.table, this.ds, this.on,
schemaEvolutionEnabled = true)
- .withNewMatchedActions(this.matchedActions: _*)
- .withNewNotMatchedActions(this.notMatchedActions: _*)
- .withNewNotMatchedBySourceActions(this.notMatchedBySourceActions: _*)
+ schemaEvolution = true
+ this
}
/**
* Executes the merge operation.
*/
- def merge(): Unit = {
- if (matchedActions.isEmpty && notMatchedActions.isEmpty &&
notMatchedBySourceActions.isEmpty) {
- throw new SparkRuntimeException(
- errorClass = "NO_MERGE_ACTION_SPECIFIED",
- messageParameters = Map.empty)
- }
+ def merge(): Unit
- val merge = MergeIntoTable(
-
UnresolvedRelation(tableName).requireWritePrivileges(MergeIntoTable.getWritePrivileges(
- matchedActions, notMatchedActions, notMatchedBySourceActions)),
- logicalPlan,
- on.expr,
- matchedActions,
- notMatchedActions,
- notMatchedBySourceActions,
- schemaEvolutionEnabled)
- val qe = sparkSession.sessionState.executePlan(merge)
- qe.assertCommandExecuted()
- }
+ // Action callbacks.
+ protected[sql] def insertAll(condition: Option[Column]): MergeIntoWriter[T]
- private[sql] def withNewMatchedActions(actions: MergeAction*):
MergeIntoWriter[T] = {
- this.matchedActions ++= actions
- this
- }
+ protected[sql] def insert(
+ condition: Option[Column],
+ map: Map[String, Column]): MergeIntoWriter[T]
- private[sql] def withNewNotMatchedActions(actions: MergeAction*):
MergeIntoWriter[T] = {
- this.notMatchedActions ++= actions
- this
- }
+ protected[sql] def updateAll(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T]
- private[sql] def withNewNotMatchedBySourceActions(actions: MergeAction*):
MergeIntoWriter[T] = {
- this.notMatchedBySourceActions ++= actions
- this
- }
+ protected[sql] def update(
+ condition: Option[Column],
+ map: Map[String, Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T]
+
+ protected[sql] def delete(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T]
}
/**
@@ -227,22 +189,19 @@ class MergeIntoWriter[T] private[sql] (
* should be applied.
* If the condition is None, the actions will be
applied to all matched
* rows.
- *
* @tparam T The type of data in the MergeIntoWriter.
*/
case class WhenMatched[T] private[sql](
mergeIntoWriter: MergeIntoWriter[T],
- condition: Option[Expression]) {
- import mergeIntoWriter.sparkSession.RichColumn
+ condition: Option[Column]) {
/**
* Specifies an action to update all matched rows in the DataFrame.
*
* @return The MergeIntoWriter instance with the update all action
configured.
*/
- def updateAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(UpdateStarAction(condition))
- }
+ def updateAll(): MergeIntoWriter[T] =
+ mergeIntoWriter.updateAll(condition, notMatchedBySource = false)
/**
* Specifies an action to update matched rows in the DataFrame with the
provided column
@@ -251,26 +210,23 @@ case class WhenMatched[T] private[sql](
* @param map A Map of column names to Column expressions representing the
updates to be applied.
* @return The MergeIntoWriter instance with the update action configured.
*/
- def update(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(
- UpdateAction(condition, map.map(x => Assignment(expr(x._1).expr,
x._2.expr)).toSeq))
- }
+ def update(map: Map[String, Column]): MergeIntoWriter[T] =
+ mergeIntoWriter.update(condition, map, notMatchedBySource = false)
/**
* Specifies an action to delete matched rows from the DataFrame.
*
* @return The MergeIntoWriter instance with the delete action configured.
*/
- def delete(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(DeleteAction(condition))
- }
+ def delete(): MergeIntoWriter[T] =
+ mergeIntoWriter.delete(condition, notMatchedBySource = false)
}
/**
* A class for defining actions to be taken when no matching rows are found in
a DataFrame
* during a merge operation.
*
- * @param MergeIntoWriter The MergeIntoWriter instance responsible for
writing data to a
+ * @param mergeIntoWriter The MergeIntoWriter instance responsible for
writing data to a
* target DataFrame.
* @param condition An optional condition Expression that specifies
when the actions
* defined in this configuration should be applied.
@@ -280,17 +236,15 @@ case class WhenMatched[T] private[sql](
*/
case class WhenNotMatched[T] private[sql](
mergeIntoWriter: MergeIntoWriter[T],
- condition: Option[Expression]) {
- import mergeIntoWriter.sparkSession.RichColumn
+ condition: Option[Column]) {
/**
* Specifies an action to insert all non-matched rows into the DataFrame.
*
* @return The MergeIntoWriter instance with the insert all action
configured.
*/
- def insertAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedActions(InsertStarAction(condition))
- }
+ def insertAll(): MergeIntoWriter[T] =
+ mergeIntoWriter.insertAll(condition)
/**
* Specifies an action to insert non-matched rows into the DataFrame with
the provided
@@ -299,10 +253,8 @@ case class WhenNotMatched[T] private[sql](
* @param map A Map of column names to Column expressions representing the
values to be inserted.
* @return The MergeIntoWriter instance with the insert action configured.
*/
- def insert(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedActions(
- InsertAction(condition, map.map(x => Assignment(expr(x._1).expr,
x._2.expr)).toSeq))
- }
+ def insert(map: Map[String, Column]): MergeIntoWriter[T] =
+ mergeIntoWriter.insert(condition, map)
}
@@ -310,14 +262,13 @@ case class WhenNotMatched[T] private[sql](
* A class for defining actions to be performed when there is no match by
source
* during a merge operation in a MergeIntoWriter.
*
- * @param MergeIntoWriter the MergeIntoWriter instance to which the merge
actions will be applied.
+ * @param mergeIntoWriter the MergeIntoWriter instance to which the merge
actions will be applied.
* @param condition an optional condition to be used with the merge
actions.
* @tparam T the type parameter for the MergeIntoWriter.
*/
case class WhenNotMatchedBySource[T] private[sql](
mergeIntoWriter: MergeIntoWriter[T],
- condition: Option[Expression]) {
- import mergeIntoWriter.sparkSession.RichColumn
+ condition: Option[Column]) {
/**
* Specifies an action to update all non-matched rows in the target
DataFrame when
@@ -325,9 +276,8 @@ case class WhenNotMatchedBySource[T] private[sql](
*
* @return The MergeIntoWriter instance with the update all action
configured.
*/
- def updateAll(): MergeIntoWriter[T] = {
-
mergeIntoWriter.withNewNotMatchedBySourceActions(UpdateStarAction(condition))
- }
+ def updateAll(): MergeIntoWriter[T] =
+ mergeIntoWriter.updateAll(condition, notMatchedBySource = true)
/**
* Specifies an action to update non-matched rows in the target DataFrame
with the provided
@@ -336,10 +286,8 @@ case class WhenNotMatchedBySource[T] private[sql](
* @param map A Map of column names to Column expressions representing the
updates to be applied.
* @return The MergeIntoWriter instance with the update action configured.
*/
- def update(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(
- UpdateAction(condition, map.map(x => Assignment(expr(x._1).expr,
x._2.expr)).toSeq))
- }
+ def update(map: Map[String, Column]): MergeIntoWriter[T] =
+ mergeIntoWriter.update(condition, map, notMatchedBySource = true)
/**
* Specifies an action to delete non-matched rows from the target DataFrame
when not matched by
@@ -347,7 +295,6 @@ case class WhenNotMatchedBySource[T] private[sql](
*
* @return The MergeIntoWriter instance with the delete action configured.
*/
- def delete(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(DeleteAction(condition))
- }
+ def delete(): MergeIntoWriter[T] =
+ mergeIntoWriter.delete(condition, notMatchedBySource = true)
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 38abb63c9dcc..6af6cc537e47 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -23,7 +23,7 @@ import _root_.java.util
import org.apache.spark.annotation.{DeveloperApi, Stable}
import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction,
ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction,
ReduceFunction}
-import org.apache.spark.sql.{functions, AnalysisException, Column,
DataFrameWriter, DataFrameWriterV2, Encoder, Observation, Row, TypedColumn}
+import org.apache.spark.sql.{functions, AnalysisException, Column,
DataFrameWriter, DataFrameWriterV2, Encoder, MergeIntoWriter, Observation, Row,
TypedColumn}
import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors}
import org.apache.spark.sql.types.{Metadata, StructType}
import org.apache.spark.storage.StorageLevel
@@ -2837,6 +2837,30 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]]
extends Serializable {
protected def createTempView(viewName: String, replace: Boolean, global:
Boolean): Unit
+ /**
+ * Merges a set of updates, insertions, and deletions based on a source
table into
+ * a target table.
+ *
+ * Scala Examples:
+ * {{{
+ * spark.table("source")
+ * .mergeInto("target", $"source.id" === $"target.id")
+ * .whenMatched($"salary" === 100)
+ * .delete()
+ * .whenNotMatched()
+ * .insertAll()
+ * .whenNotMatchedBySource($"salary" === 100)
+ * .update(Map(
+ * "salary" -> lit(200)
+ * ))
+ * .merge()
+ * }}}
+ *
+ * @group basic
+ * @since 4.0.0
+ */
+ def mergeInto(table: String, condition: Column): MergeIntoWriter[T]
+
/**
* Create a write configuration builder for v2 sources.
*
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 05628d7b1c98..fdd43404e1d9 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -796,9 +796,9 @@ case class MergeIntoTable(
object MergeIntoTable {
def getWritePrivileges(
- matchedActions: Seq[MergeAction],
- notMatchedActions: Seq[MergeAction],
- notMatchedBySourceActions: Seq[MergeAction]): Seq[TableWritePrivilege] =
{
+ matchedActions: Iterable[MergeAction],
+ notMatchedActions: Iterable[MergeAction],
+ notMatchedBySourceActions: Iterable[MergeAction]):
Seq[TableWritePrivilege] = {
val privileges =
scala.collection.mutable.HashSet.empty[TableWritePrivilege]
(matchedActions.iterator ++ notMatchedActions ++
notMatchedBySourceActions).foreach {
case _: DeleteAction => privileges.add(TableWritePrivilege.DELETE)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 58e61badaf37..83a41dad7fa3 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -43,7 +43,7 @@ import
org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile,
TaskResourceProfile, TaskResourceRequest}
-import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation,
RelationalGroupedDataset, SparkSession}
+import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation,
RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier,
FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry,
GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery,
PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute,
UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue,
UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -78,7 +78,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
import
org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator,
SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
-import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils}
+import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl,
TypedAggUtils}
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf,
ProtobufDataToCatalyst}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode,
StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
@@ -3496,16 +3496,14 @@ class SparkConnectPlanner(
val notMatchedBySourceActions =
transformActions(cmd.getNotMatchedBySourceActionsList)
val sourceDs = Dataset.ofRows(session,
transformRelation(cmd.getSourceTablePlan))
- var mergeInto = sourceDs
+ val mergeInto = sourceDs
.mergeInto(cmd.getTargetTableName,
column(transformExpression(cmd.getMergeCondition)))
- .withNewMatchedActions(matchedActions: _*)
- .withNewNotMatchedActions(notMatchedActions: _*)
- .withNewNotMatchedBySourceActions(notMatchedBySourceActions: _*)
-
- mergeInto = if (cmd.getWithSchemaEvolution) {
+ .asInstanceOf[MergeIntoWriterImpl[Row]]
+ mergeInto.matchedActions ++= matchedActions
+ mergeInto.notMatchedActions ++= notMatchedActions
+ mergeInto.notMatchedBySourceActions ++= notMatchedBySourceActions
+ if (cmd.getWithSchemaEvolution) {
mergeInto.withSchemaEvolution()
- } else {
- mergeInto
}
mergeInto.merge()
executeHolder.eventsManager.postFinished()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f62331710d63..2dec4443da71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -60,7 +60,7 @@ import
org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation,
DataSourceV2ScanRelation, FileTable}
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.internal.{DataFrameWriterImpl,
DataFrameWriterV2Impl, SQLConf, ToScalaUDF}
+import org.apache.spark.sql.internal.{DataFrameWriterImpl,
DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf, ToScalaUDF}
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.TypedAggUtils.withInputType
import org.apache.spark.sql.streaming.DataStreamWriter
@@ -1635,7 +1635,7 @@ class Dataset[T] private[sql](
messageParameters = Map("methodName" -> toSQLId("mergeInto")))
}
- new MergeIntoWriter[T](table, this, condition)
+ new MergeIntoWriterImpl[T](table, this, condition)
}
/**
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
new file mode 100644
index 000000000000..bb8146e3e0e3
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkRuntimeException
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.{Column, DataFrame, Dataset, MergeIntoWriter}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.functions.expr
+
+/**
+ * `MergeIntoWriter` provides methods to define and execute merge actions based
+ * on specified conditions.
+ *
+ * @tparam T the type of data in the Dataset.
+ * @param table the name of the target table for the merge operation.
+ * @param ds the source Dataset to merge into the target table.
+ * @param on the merge condition.
+ *
+ * @since 4.0.0
+ */
+@Experimental
+class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on:
Column)
+ extends MergeIntoWriter[T] {
+
+ private val df: DataFrame = ds.toDF()
+
+ private[sql] val sparkSession = ds.sparkSession
+ import sparkSession.RichColumn
+
+ private val tableName =
sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
+
+ private val logicalPlan = df.queryExecution.logical
+
+ private[sql] val matchedActions = mutable.Buffer.empty[MergeAction]
+ private[sql] val notMatchedActions = mutable.Buffer.empty[MergeAction]
+ private[sql] val notMatchedBySourceActions =
mutable.Buffer.empty[MergeAction]
+
+ /** @inheritdoc */
+ def merge(): Unit = {
+ if (matchedActions.isEmpty && notMatchedActions.isEmpty &&
notMatchedBySourceActions.isEmpty) {
+ throw new SparkRuntimeException(
+ errorClass = "NO_MERGE_ACTION_SPECIFIED",
+ messageParameters = Map.empty)
+ }
+
+ val merge = MergeIntoTable(
+
UnresolvedRelation(tableName).requireWritePrivileges(MergeIntoTable.getWritePrivileges(
+ matchedActions, notMatchedActions, notMatchedBySourceActions)),
+ logicalPlan,
+ on.expr,
+ matchedActions.toSeq,
+ notMatchedActions.toSeq,
+ notMatchedBySourceActions.toSeq,
+ schemaEvolutionEnabled)
+ val qe = sparkSession.sessionState.executePlan(merge)
+ qe.assertCommandExecuted()
+ }
+
+ override protected[sql] def insertAll(condition: Option[Column]):
MergeIntoWriter[T] = {
+ this.notMatchedActions += InsertStarAction(condition.map(_.expr))
+ this
+ }
+
+ override protected[sql] def insert(
+ condition: Option[Column],
+ map: Map[String, Column]): MergeIntoWriter[T] = {
+ this.notMatchedActions += InsertAction(condition.map(_.expr),
mapToAssignments(map))
+ this
+ }
+
+ override protected[sql] def updateAll(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(UpdateStarAction(condition.map(_.expr)),
notMatchedBySource)
+ }
+
+ override protected[sql] def update(
+ condition: Option[Column],
+ map: Map[String, Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(
+ UpdateAction(condition.map(_.expr), mapToAssignments(map)),
+ notMatchedBySource)
+ }
+
+ override protected[sql] def delete(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(DeleteAction(condition.map(_.expr)),
notMatchedBySource)
+ }
+
+ private def appendUpdateDeleteAction(
+ action: MergeAction,
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ if (notMatchedBySource) {
+ notMatchedBySourceActions += action
+ } else {
+ matchedActions += action
+ }
+ this
+ }
+
+ private def mapToAssignments(map: Map[String, Column]): Seq[Assignment] = {
+ map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq
+ }
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
index c080a66bce25..8aa8fb21f4ae 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.MergeIntoWriterImpl
class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase {
@@ -950,7 +951,7 @@ class MergeIntoDataFrameSuite extends
RowLevelOperationSuiteBase {
// an arbitrary merge
val writer1 = spark.table("source")
- .mergeInto("dummy", $"col" === $"col")
+ .mergeInto("dummy", $"colA" === $"colB")
.whenMatched(col("col") === 1)
.updateAll()
.whenMatched()
@@ -959,16 +960,15 @@ class MergeIntoDataFrameSuite extends
RowLevelOperationSuiteBase {
.insertAll()
.whenNotMatchedBySource(col("col") === 1)
.delete()
+ .asInstanceOf[MergeIntoWriterImpl[Row]]
val writer2 = writer1.withSchemaEvolution()
+ .asInstanceOf[MergeIntoWriterImpl[Row]]
+ assert(writer1 eq writer2)
assert(writer1.matchedActions.length === 2)
assert(writer1.notMatchedActions.length === 1)
assert(writer1.notMatchedBySourceActions.length === 1)
-
- assert(writer1.matchedActions === writer2.matchedActions)
- assert(writer1.notMatchedActions === writer2.notMatchedActions)
- assert(writer1.notMatchedBySourceActions ===
writer2.notMatchedBySourceActions)
- assert(writer2.schemaEvolutionEnabled)
+ assert(writer1.schemaEvolutionEnabled)
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]