This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new 4ccd530f639 [SPARK-38085][SQL] DataSource V2: Handle DELETE commands for group-based sources
4ccd530f639 is described below
commit 4ccd530f639e3652b7aad7c8bcfa379847dc2b68
Author: Anton Okolnychyi <[email protected]>
AuthorDate: Wed Apr 13 13:47:00 2022 +0800
[SPARK-38085][SQL] DataSource V2: Handle DELETE commands for group-based sources
This PR contains changes to rewrite DELETE operations for V2 data sources that can replace groups of data (e.g. files, partitions).

These changes are needed to support row-level operations in Spark per SPIP SPARK-35801.
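
For illustration, a DELETE against a group-based source (such as the in-memory test catalog added in this patch) is rewritten into a plan that reads the matching groups, keeps the rows that do NOT match the condition, and writes them back. Conceptually (plan shape simplified):

    DELETE FROM cat.ns1.test_table WHERE id <= 1

    -- is rewritten, roughly, into
    ReplaceData cat.ns1.test_table
    +- Filter NOT ((id <= 1) <=> true)      -- rows to carry over
       +- Relation cat.ns1.test_table       -- reads only the matching groups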
Does this PR introduce any user-facing change? No.

How was this patch tested? This PR comes with tests.
Closes #35395 from aokolnychyi/spark-38085.
Authored-by: Anton Okolnychyi <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 5a92eccd514b7bc0513feaecb041aee2f8cd5a24)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 1 +
.../catalyst/analysis/RewriteDeleteFromTable.scala | 89 +++
.../catalyst/analysis/RewriteRowLevelCommand.scala | 71 +++
.../ReplaceNullWithFalseInPredicate.scala | 3 +-
.../SimplifyConditionalsInPredicate.scala | 1 +
.../spark/sql/catalyst/planning/patterns.scala | 51 ++
.../sql/catalyst/plans/logical/v2Commands.scala | 92 ++-
.../write/RowLevelOperationInfoImpl.scala | 25 +
.../connector/write/RowLevelOperationTable.scala | 51 ++
.../spark/sql/errors/QueryCompilationErrors.scala | 4 +
.../datasources/v2/DataSourceV2Implicits.scala | 10 +
.../catalog/InMemoryRowLevelOperationTable.scala | 96 ++++
.../InMemoryRowLevelOperationTableCatalog.scala | 46 ++
.../sql/connector/catalog/InMemoryTable.scala | 22 +-
.../spark/sql/execution/SparkOptimizer.scala | 7 +-
.../datasources/v2/DataSourceV2Strategy.scala | 22 +-
.../GroupBasedRowLevelOperationScanPlanning.scala | 83 +++
.../v2/OptimizeMetadataOnlyDeleteFromTable.scala | 84 +++
.../execution/datasources/v2/PushDownUtils.scala | 2 +-
.../sql/execution/datasources/v2/V2Writes.scala | 24 +-
.../datasources/v2/WriteToDataSourceV2Exec.scala | 15 +
.../spark/sql/connector/DeleteFromTableSuite.scala | 629 +++++++++++++++++++++
.../execution/command/PlanResolutionSuite.scala | 4 +-
23 files changed, 1407 insertions(+), 25 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6b44483ab1d..9fdc466b425 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -318,6 +318,7 @@ class Analyzer(override val catalogManager: CatalogManager)
ResolveRandomSeed ::
ResolveBinaryArithmetic ::
ResolveUnion ::
+ RewriteDeleteFromTable ::
typeCoercionRules ++
Seq(ResolveWithCTE) ++
extendedResolutionRules : _*),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala
new file mode 100644
index 00000000000..85af999902e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{EqualNullSafe, Expression, Not}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
+import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, LogicalPlan, ReplaceData}
+import org.apache.spark.sql.connector.catalog.{SupportsDelete, SupportsRowLevelOperations, TruncatableTable}
+import org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE
+import org.apache.spark.sql.connector.write.RowLevelOperationTable
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * A rule that rewrites DELETE operations using plans that operate on individual or groups of rows.
+ *
+ * If a table implements [[SupportsDelete]] and [[SupportsRowLevelOperations]], this rule will
+ * still rewrite the DELETE operation but the optimizer will check whether this particular DELETE
+ * statement can be handled by simply passing delete filters to the connector. If so, the optimizer
+ * will discard the rewritten plan and will allow the data source to delete using filters.
+ */
+object RewriteDeleteFromTable extends RewriteRowLevelCommand {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case d @ DeleteFromTable(aliasedTable, cond) if d.resolved =>
+      EliminateSubqueryAliases(aliasedTable) match {
+        case DataSourceV2Relation(_: TruncatableTable, _, _, _, _) if cond == TrueLiteral =>
+          // don't rewrite as the table supports truncation
+          d
+
+        case r @ DataSourceV2Relation(t: SupportsRowLevelOperations, _, _, _, _) =>
+          val table = buildOperationTable(t, DELETE, CaseInsensitiveStringMap.empty())
+          buildReplaceDataPlan(r, table, cond)
+
+        case DataSourceV2Relation(_: SupportsDelete, _, _, _, _) =>
+          // don't rewrite as the table supports deletes only with filters
+          d
+
+        case DataSourceV2Relation(t, _, _, _, _) =>
+          throw QueryCompilationErrors.tableDoesNotSupportDeletesError(t)
+
+        case _ =>
+          d
+      }
+  }
+
+  // build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions)
+  private def buildReplaceDataPlan(
+      relation: DataSourceV2Relation,
+      operationTable: RowLevelOperationTable,
+      cond: Expression): ReplaceData = {
+
+    // resolve all required metadata attrs that may be used for grouping data on write
+    // for instance, JDBC data source may cluster data by shard/host before writing
+    val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation)
+
+    // construct a read relation and include all required metadata columns
+    val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs)
+
+    // construct a plan that contains unmatched rows in matched groups that must be carried over
+    // such rows do not match the condition but have to be copied over as the source can replace
+    // only groups of rows (e.g. if a source supports replacing files, unmatched rows in matched
+    // files must be carried over)
+    // it is safe to negate the condition here as the predicate pushdown for group-based row-level
+    // operations is handled in a special way
+    val remainingRowsFilter = Not(EqualNullSafe(cond, TrueLiteral))
+    val remainingRowsPlan = Filter(remainingRowsFilter, readRelation)
+
+    // build a plan to replace read groups in the table
+    val writeRelation = relation.copy(table = operationTable)
+    ReplaceData(writeRelation, cond, remainingRowsPlan, relation)
+  }
+}
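
Note the rewrite uses Not(EqualNullSafe(cond, TrueLiteral)) rather than a plain Not(cond): the null-safe comparison ensures rows for which the condition evaluates to NULL are carried over rather than silently dropped. A small worked example of the three-valued logic involved (illustrative only):

    -- DELETE FROM t WHERE id = 1, for a row with id = NULL:
    --   id = 1                    evaluates to NULL
    --   (id = 1) <=> true         evaluates to false (null-safe comparison)
    --   NOT ((id = 1) <=> true)   evaluates to true, so the row is kept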
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
new file mode 100644
index 00000000000..bf8c3e27f4d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
+import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable}
+import org.apache.spark.sql.connector.write.RowLevelOperation.Command
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
+
+ protected def buildOperationTable(
+ table: SupportsRowLevelOperations,
+ command: Command,
+ options: CaseInsensitiveStringMap): RowLevelOperationTable = {
+ val info = RowLevelOperationInfoImpl(command, options)
+ val operation = table.newRowLevelOperationBuilder(info).build()
+ RowLevelOperationTable(table, operation)
+ }
+
+ protected def buildRelationWithAttrs(
+ relation: DataSourceV2Relation,
+ table: RowLevelOperationTable,
+ metadataAttrs: Seq[AttributeReference]): DataSourceV2Relation = {
+
+ val attrs = dedupAttrs(relation.output ++ metadataAttrs)
+ relation.copy(table = table, output = attrs)
+ }
+
+  protected def dedupAttrs(attrs: Seq[AttributeReference]): Seq[AttributeReference] = {
+ val exprIds = mutable.Set.empty[ExprId]
+ attrs.flatMap { attr =>
+ if (exprIds.contains(attr.exprId)) {
+ None
+ } else {
+ exprIds += attr.exprId
+ Some(attr)
+ }
+ }
+ }
+
+ protected def resolveRequiredMetadataAttrs(
+ relation: DataSourceV2Relation,
+ operation: RowLevelOperation): Seq[AttributeReference] = {
+
+ V2ExpressionUtils.resolveRefs[AttributeReference](
+ operation.requiredMetadataAttributes,
+ relation)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
index 9ec498aa14e..d060a8be5da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, In, InSet, LambdaFunction, Literal, MapFilter, Not, Or}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
-import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, InsertStarAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction, UpdateTable}
+import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, InsertStarAction, Join, LogicalPlan, MergeAction, MergeIntoTable, ReplaceData, UpdateAction, UpdateStarAction, UpdateTable}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{INSET, NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
import org.apache.spark.sql.types.BooleanType
@@ -54,6 +54,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
     _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET), ruleId) {
     case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
     case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
+    case rd @ ReplaceData(_, cond, _, _, _) => rd.copy(condition = replaceNullWithFalse(cond))
     case d @ DeleteFromTable(_, cond) => d.copy(condition = replaceNullWithFalse(cond))
     case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
     case m @ MergeIntoTable(_, _, mergeCond, matchedActions, notMatchedActions) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicate.scala
index e1972b997c2..34773b24cac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicate.scala
@@ -48,6 +48,7 @@ object SimplifyConditionalsInPredicate extends Rule[LogicalPlan] {
     _.containsAnyPattern(CASE_WHEN, IF), ruleId) {
     case f @ Filter(cond, _) => f.copy(condition = simplifyConditional(cond))
     case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(simplifyConditional(cond)))
+    case rd @ ReplaceData(_, cond, _, _, _) => rd.copy(condition = simplifyConditional(cond))
     case d @ DeleteFromTable(_, cond) => d.copy(condition = simplifyConditional(cond))
     case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(simplifyConditional(cond)))
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 8c41ab2797b..382909d6d6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -18,12 +18,15 @@
package org.apache.spark.sql.catalyst.planning
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.internal.SQLConf
trait OperationHelper extends AliasHelper with PredicateHelper {
@@ -388,3 +391,51 @@ object ExtractSingleColumnNullAwareAntiJoin extends JoinSelectionHelper with Pre
case _ => None
}
}
+
+/**
+ * An extractor for row-level commands such as DELETE, UPDATE, MERGE that were rewritten using plans
+ * that operate on groups of rows.
+ *
+ * This class extracts the following entities:
+ *  - the group-based rewrite plan;
+ *  - the condition that defines matching groups;
+ *  - the read relation that can be either [[DataSourceV2Relation]] or [[DataSourceV2ScanRelation]]
+ *    depending on whether the planning has already happened;
+ */
+object GroupBasedRowLevelOperation {
+ type ReturnType = (ReplaceData, Expression, LogicalPlan)
+
+ def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+    case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _), cond, query, _, _) =>
+      val readRelation = findReadRelation(table, query)
+      readRelation.map((rd, cond, _))
+
+ case _ =>
+ None
+ }
+
+ private def findReadRelation(
+ table: Table,
+ plan: LogicalPlan): Option[LogicalPlan] = {
+
+ val readRelations = plan.collect {
+ case r: DataSourceV2Relation if r.table eq table => r
+ case r: DataSourceV2ScanRelation if r.relation.table eq table => r
+ }
+
+    // in some cases, the optimizer replaces the v2 read relation with a local relation
+    // for example, there is no reason to query the table if the condition is always false
+    // that's why it is valid not to find the corresponding v2 read relation
+
+ readRelations match {
+ case relations if relations.isEmpty =>
+ None
+
+ case Seq(relation) =>
+ Some(relation)
+
+ case relations =>
+        throw new AnalysisException(s"Expected only one row-level read relation: $relations")
+ }
+ }
+}
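
As a usage sketch (mirroring how the scan planning rule later in this commit consumes the extractor), a rule can match the rewritten command and reach all three entities in one pattern:

    plan match {
      case GroupBasedRowLevelOperation(rd, cond, relation) =>
        // rd: the ReplaceData rewrite plan
        // cond: the condition that defines matching groups
        // relation: the v2 read relation (before or after scan planning)
        ...
    }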
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index b2ca34668a6..b1b8843aa33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -17,16 +17,18 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, FieldName, NamedRelation, PartitionSpec, ResolvedDBObjectName, UnresolvedException}
+import org.apache.spark.sql.{sources, AnalysisException}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedDBObjectName, UnresolvedException}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.catalog.FunctionResource
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, Unevaluable}
 import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
 import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.connector.write.Write
+import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationTable, Write}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.{BooleanType, DataType, MetadataBuilder, StringType, StructType}
/**
@@ -176,6 +178,80 @@ object OverwritePartitionsDynamic {
}
}
+trait RowLevelWrite extends V2WriteCommand with SupportsSubquery {
+ def operation: RowLevelOperation
+ def condition: Expression
+ def originalTable: NamedRelation
+}
+
+/**
+ * Replace groups of data in an existing table during a row-level operation.
+ *
+ * This node is constructed in rules that rewrite DELETE, UPDATE, MERGE operations for data sources
+ * that can replace groups of data (e.g. files, partitions).
+ *
+ * @param table a plan that references a row-level operation table
+ * @param condition a condition that defines matching groups
+ * @param query a query with records that should replace the records that were read
+ * @param originalTable a plan for the original table for which the row-level command was triggered
+ * @param write a logical write, if already constructed
+ */
+case class ReplaceData(
+ table: NamedRelation,
+ condition: Expression,
+ query: LogicalPlan,
+ originalTable: NamedRelation,
+ write: Option[Write] = None) extends RowLevelWrite {
+
+ override val isByName: Boolean = false
+ override val stringArgs: Iterator[Any] = Iterator(table, query, write)
+
+ override lazy val references: AttributeSet = query.outputSet
+
+ lazy val operation: RowLevelOperation = {
+ EliminateSubqueryAliases(table) match {
+      case DataSourceV2Relation(RowLevelOperationTable(_, operation), _, _, _, _) =>
+        operation
+      case _ =>
+        throw new AnalysisException(s"Cannot retrieve row-level operation from $table")
+ }
+ }
+
+ // the incoming query may include metadata columns
+ lazy val dataInput: Seq[Attribute] = {
+ query.output.filter {
+ case MetadataAttribute(_) => false
+ case _ => true
+ }
+ }
+
+  override def outputResolved: Boolean = {
+    assert(table.resolved && query.resolved,
+      "`outputResolved` can only be called when `table` and `query` are both resolved.")
+
+    // take into account only incoming data columns and ignore metadata columns in the query
+    // they will be discarded after the logical write is built in the optimizer
+    // metadata columns may be needed to request a correct distribution or ordering
+    // but are not passed back to the data source during writes
+
+    table.skipSchemaResolution || (dataInput.size == table.output.size &&
+      dataInput.zip(table.output).forall { case (inAttr, outAttr) =>
+        val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
+        // names and types must match, nullability must be compatible
+        inAttr.name == outAttr.name &&
+          DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outType) &&
+          (outAttr.nullable || !inAttr.nullable)
+      })
+  }
+
+  override def withNewQuery(newQuery: LogicalPlan): ReplaceData = copy(query = newQuery)
+
+  override def withNewTable(newTable: NamedRelation): ReplaceData = copy(table = newTable)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): ReplaceData = {
+    copy(query = newChild)
+  }
+}
/** A trait used for logical plan nodes that create or replace V2 table definitions. */
trait V2CreateTablePlan extends LogicalPlan {
@@ -457,6 +533,16 @@ case class DeleteFromTable(
copy(table = newChild)
}
+/**
+ * The logical plan of the DELETE FROM command that can be executed using data source filters.
+ *
+ * As opposed to [[DeleteFromTable]], this node represents a DELETE operation where the condition
+ * was converted into filters and the data source reported that it can handle all of them.
+ */
+case class DeleteFromTableWithFilters(
+ table: LogicalPlan,
+ condition: Seq[sources.Filter]) extends LeafCommand
+
/**
* The logical plan of the UPDATE TABLE command.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala
new file mode 100644
index 00000000000..9d499cdef36
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.write
+
+import org.apache.spark.sql.connector.write.RowLevelOperation.Command
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+private[sql] case class RowLevelOperationInfoImpl(
+ command: Command,
+ options: CaseInsensitiveStringMap) extends RowLevelOperationInfo
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationTable.scala
new file mode 100644
index 00000000000..d1f7ba000c6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationTable.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.write
+
+import java.util
+
+import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsRowLevelOperations, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * An internal v2 table implementation that wraps the original table and a logical row-level
+ * operation for DELETE, UPDATE, MERGE commands that require rewriting data.
+ *
+ * The purpose of this table is to make the existing scan and write planning rules work
+ * with commands that require coordination between the scan and the write (so that the write
+ * knows what to replace).
+ */
+private[sql] case class RowLevelOperationTable(
+    table: Table with SupportsRowLevelOperations,
+    operation: RowLevelOperation) extends Table with SupportsRead with SupportsWrite {
+
+ override def name: String = table.name
+ override def schema: StructType = table.schema
+ override def capabilities: util.Set[TableCapability] = table.capabilities
+ override def toString: String = table.toString
+
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ operation.newScanBuilder(options)
+ }
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ operation.newWriteBuilder(info)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 57ed7da7b20..0532a953ef4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -926,6 +926,10 @@ object QueryCompilationErrors {
tableDoesNotSupportError("atomic partition management", table)
}
+ def tableIsNotRowLevelOperationTableError(table: Table): Throwable = {
+    throw new AnalysisException(s"Table ${table.name} is not a row-level operation table")
+ }
+
def cannotRenameTableWithAlterViewError(): Throwable = {
new AnalysisException(
"Cannot rename a table with ALTER VIEW. Please use ALTER TABLE instead.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
index efd3ffebf5c..16d5a9cc70d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{PartitionSpec, ResolvedPartitionS
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY
 import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsAtomicPartitionManagement, SupportsDelete, SupportsPartitionManagement, SupportsRead, SupportsWrite, Table, TableCapability, TruncatableTable}
+import org.apache.spark.sql.connector.write.RowLevelOperationTable
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -82,6 +83,15 @@ object DataSourceV2Implicits {
}
}
+ def asRowLevelOperationTable: RowLevelOperationTable = {
+ table match {
+ case rowLevelOperationTable: RowLevelOperationTable =>
+ rowLevelOperationTable
+ case _ =>
+          throw QueryCompilationErrors.tableIsNotRowLevelOperationTableError(table)
+ }
+ }
+
    def supports(capability: TableCapability): Boolean = table.capabilities.contains(capability)

    def supportsAny(capabilities: TableCapability*): Boolean = capabilities.exists(supports)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
new file mode 100644
index 00000000000..cb061602ec1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
+import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NamedReference, SortDirection, SortOrder, Transform}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
+import org.apache.spark.sql.connector.write.{BatchWrite, LogicalWriteInfo, RequiresDistributionAndOrdering, RowLevelOperation, RowLevelOperationBuilder, RowLevelOperationInfo, Write, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.RowLevelOperation.Command
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class InMemoryRowLevelOperationTable(
+ name: String,
+ schema: StructType,
+ partitioning: Array[Transform],
+ properties: util.Map[String, String])
+  extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations {
+
+ override def newRowLevelOperationBuilder(
+ info: RowLevelOperationInfo): RowLevelOperationBuilder = {
+ () => PartitionBasedOperation(info.command)
+ }
+
+  case class PartitionBasedOperation(command: Command) extends RowLevelOperation {
+    private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name)
+
+ var configuredScan: InMemoryBatchScan = _
+
+ override def requiredMetadataAttributes(): Array[NamedReference] = {
+ Array(PARTITION_COLUMN_REF)
+ }
+
+    override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new InMemoryScanBuilder(schema) {
+ override def build: Scan = {
+ val scan = super.build()
+ configuredScan = scan.asInstanceOf[InMemoryBatchScan]
+ scan
+ }
+ }
+ }
+
+    override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = new WriteBuilder {
+
+      override def build(): Write = new Write with RequiresDistributionAndOrdering {
+ override def requiredDistribution(): Distribution = {
+ Distributions.clustered(Array(PARTITION_COLUMN_REF))
+ }
+
+ override def requiredOrdering(): Array[SortOrder] = {
+ Array[SortOrder](
+ LogicalExpressions.sort(
+ PARTITION_COLUMN_REF,
+ SortDirection.ASCENDING,
+ SortDirection.ASCENDING.defaultNullOrdering())
+ )
+ }
+
+        override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan)
+
+ override def description(): String = "InMemoryWrite"
+ }
+ }
+
+ override def description(): String = "InMemoryPartitionReplaceOperation"
+ }
+
+  private case class PartitionBasedReplaceData(scan: InMemoryBatchScan) extends TestBatchWrite {
+
+    override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+ val newData = messages.map(_.asInstanceOf[BufferedRows])
+ val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows)
+ val readPartitions = readRows.map(r => getKey(r, schema))
+ dataMap --= readPartitions
+ withData(newData, schema)
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
new file mode 100644
index 00000000000..2d9a9f04785
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.StructType
+
+class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog {
+ import CatalogV2Implicits._
+
+ override def createTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ if (tables.containsKey(ident)) {
+ throw new TableAlreadyExistsException(ident)
+ }
+
+ InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
+
+ val tableName = s"$name.${ident.quoted}"
+    val table = new InMemoryRowLevelOperationTable(tableName, schema, partitions, properties)
+ tables.put(ident, table)
+ namespaces.putIfAbsent(ident.namespace.toList, Map())
+ table
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index a762b0f8783..beed9111a30 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -56,7 +56,7 @@ class InMemoryTable(
extends Table with SupportsRead with SupportsWrite with SupportsDelete
with SupportsMetadataColumns {
- private object PartitionKeyColumn extends MetadataColumn {
+ protected object PartitionKeyColumn extends MetadataColumn {
override def name: String = "_partition"
override def dataType: DataType = StringType
override def comment: String = "Partition key used to store the row"
@@ -104,7 +104,11 @@ class InMemoryTable(
private val UTC = ZoneId.of("UTC")
private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate
- private def getKey(row: InternalRow): Seq[Any] = {
+ protected def getKey(row: InternalRow): Seq[Any] = {
+ getKey(row, schema)
+ }
+
+ protected def getKey(row: InternalRow, rowSchema: StructType): Seq[Any] = {
@scala.annotation.tailrec
def extractor(
fieldNames: Array[String],
@@ -124,7 +128,7 @@ class InMemoryTable(
}
}
-    val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)
+    val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(rowSchema)
partitioning.map {
case IdentityTransform(ref) =>
extractor(ref.fieldNames, cleanedSchema, row)._1
@@ -219,9 +223,15 @@ class InMemoryTable(
dataMap(key).clear()
}
-  def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
+ def withData(data: Array[BufferedRows]): InMemoryTable = {
+ withData(data, schema)
+ }
+
+ def withData(
+ data: Array[BufferedRows],
+ writeSchema: StructType): InMemoryTable = dataMap.synchronized {
data.foreach(_.rows.foreach { row =>
- val key = getKey(row)
+ val key = getKey(row, writeSchema)
dataMap += dataMap.get(key)
.map(key -> _.withRow(row))
.getOrElse(key -> new BufferedRows(key).withRow(row))
@@ -372,7 +382,7 @@ class InMemoryTable(
}
}
- private abstract class TestBatchWrite extends BatchWrite {
+ protected abstract class TestBatchWrite extends BatchWrite {
    override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
BufferedRowsWriterFactory
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index bfe4bd29241..8c134363af1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.SchemaPruning
-import org.apache.spark.sql.execution.datasources.v2.{V2ScanPartitioning, V2ScanRelationPushDown, V2Writes}
+import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioning, V2ScanRelationPushDown, V2Writes}
 import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
 import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
@@ -38,11 +38,15 @@ class SparkOptimizer(
override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] =
// TODO: move SchemaPruning into catalyst
Seq(SchemaPruning) :+
+ GroupBasedRowLevelOperationScanPlanning :+
V2ScanRelationPushDown :+
V2ScanPartitioning :+
V2Writes :+
PruneFileSourcePartitions
+ override def preCBORules: Seq[Rule[LogicalPlan]] =
+ OptimizeMetadataOnlyDeleteFromTable :: Nil
+
   override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("PartitionPruning", Once,
@@ -78,6 +82,7 @@ class SparkOptimizer(
ExtractPythonUDFFromJoinCondition.ruleName :+
ExtractPythonUDFFromAggregate.ruleName :+
ExtractGroupingPythonUDFFromAggregate.ruleName :+
ExtractPythonUDFs.ruleName :+
+ GroupBasedRowLevelOperationScanPlanning.ruleName :+
V2ScanRelationPushDown.ruleName :+
V2ScanPartitioning.ruleName :+
V2Writes.ruleName
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 45540fb4a11..95418027187 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.analysis.{ResolvedDBObjectName, ResolvedNam
 import org.apache.spark.sql.catalyst.catalog.CatalogUtils
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, V2ExpressionBuilder}
-import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDelete, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable}
 import org.apache.spark.sql.connector.catalog.index.SupportsIndex
 import org.apache.spark.sql.connector.expressions.{FieldReference}
 import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
@@ -254,6 +255,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _, _, Some(write)) =>
       OverwritePartitionsDynamicExec(planLater(query), refreshCache(r), write) :: Nil
 
+    case DeleteFromTableWithFilters(r: DataSourceV2Relation, filters) =>
+      DeleteFromTableExec(r.table.asDeletable, filters.toArray, refreshCache(r)) :: Nil
+
case DeleteFromTable(relation, condition) =>
relation match {
case DataSourceV2ScanRelation(r, _, output, _) =>
@@ -269,15 +273,25 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
               throw QueryCompilationErrors.cannotTranslateExpressionToSourceFilterError(f))
           }).toArray
-          if (!table.asDeletable.canDeleteWhere(filters)) {
-            throw QueryCompilationErrors.cannotDeleteTableWhereFiltersError(table, filters)
+          table match {
+            case t: SupportsDelete if t.canDeleteWhere(filters) =>
+              DeleteFromTableExec(t, filters, refreshCache(r)) :: Nil
+            case t: SupportsDelete =>
+              throw QueryCompilationErrors.cannotDeleteTableWhereFiltersError(t, filters)
+            case t: TruncatableTable if condition == TrueLiteral =>
+              TruncateTableExec(t, refreshCache(r)) :: Nil
+            case _ =>
+              throw QueryCompilationErrors.tableDoesNotSupportDeletesError(table)
           }
-          DeleteFromTableExec(table.asDeletable, filters, refreshCache(r)) :: Nil
         case _ =>
           throw QueryCompilationErrors.deleteOnlySupportedWithV2TablesError()
       }
+    case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, Some(write)) =>
+      // use the original relation to refresh the cache
+      ReplaceDataExec(planLater(query), refreshCache(r), write) :: Nil
+
     case WriteToContinuousDataSource(writer, query, customMetrics) =>
       WriteToContinuousDataSourceExec(writer, planLater(query), customMetrics) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
new file mode 100644
index 00000000000..48dee3f652c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.expressions.filter.{Predicate => V2Filter}
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.sources.Filter
+
+/**
+ * A rule that builds scans for group-based row-level operations.
+ *
+ * Note this rule must be run before [[V2ScanRelationPushDown]] as scans for group-based
+ * row-level operations must be planned in a special way.
+ */
+object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with PredicateHelper {
+
+ import DataSourceV2Implicits._
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    // push down the filter from the command condition instead of the filter in the rewrite plan,
+    // which is negated for data sources that only support replacing groups of data (e.g. files)
+    case GroupBasedRowLevelOperation(rd: ReplaceData, cond, relation: DataSourceV2Relation) =>
+      val table = relation.table.asRowLevelOperationTable
+      val scanBuilder = table.newScanBuilder(relation.options)
+
+      val (pushedFilters, remainingFilters) = pushFilters(cond, relation.output, scanBuilder)
+      val pushedFiltersStr = if (pushedFilters.isLeft) {
+        pushedFilters.left.get.mkString(", ")
+      } else {
+        pushedFilters.right.get.mkString(", ")
+      }
+
+      val (scan, output) = PushDownUtils.pruneColumns(scanBuilder, relation, relation.output, Nil)
+
+ logInfo(
+ s"""
+ |Pushing operators to ${relation.name}
+ |Pushed filters: $pushedFiltersStr
+ |Filters that were not pushed: ${remainingFilters.mkString(", ")}
+ |Output: ${output.mkString(", ")}
+ """.stripMargin)
+
+      // replace DataSourceV2Relation with DataSourceV2ScanRelation for the row operation table
+      rd transform {
+        case r: DataSourceV2Relation if r eq relation =>
+          DataSourceV2ScanRelation(r, scan, PushDownUtils.toOutputAttrs(scan.readSchema(), r))
+      }
+ }
+
+  private def pushFilters(
+      cond: Expression,
+      tableAttrs: Seq[AttributeReference],
+      scanBuilder: ScanBuilder): (Either[Seq[Filter], Seq[V2Filter]], Seq[Expression]) = {
+
+    val tableAttrSet = AttributeSet(tableAttrs)
+    val filters = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet))
+    val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, tableAttrs)
+    val (_, normalizedFiltersWithoutSubquery) =
+      normalizedFilters.partition(SubqueryExpression.hasSubquery)
+
+    PushDownUtils.pushFilters(scanBuilder, normalizedFiltersWithoutSubquery)
+  }
+}
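
To make the pushdown note above concrete (an illustrative example): for DELETE FROM t WHERE id <= 10, the rewrite plan keeps rows satisfying the negated, null-safe condition, but this rule pushes the original command condition into the scan instead:

    -- pushed to the scan:       id <= 10                    (selects the groups to rewrite)
    -- kept in the rewrite plan: NOT ((id <= 10) <=> true)   (rows to carry over)

Pushing the negated rewrite filter would select the groups to keep rather than the groups to delete, which is why group-based scans bypass the generic V2ScanRelationPushDown path.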
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala
new file mode 100644
index 00000000000..bc45dbe9fef
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
+import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, DeleteFromTableWithFilters, LogicalPlan, ReplaceData, RowLevelWrite}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.{SupportsDelete, TruncatableTable}
+import org.apache.spark.sql.connector.write.RowLevelOperation
+import org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.sources
+
+/**
+ * A rule that replaces a rewritten DELETE operation with a delete using filters if the data source
+ * can handle this DELETE command without executing the plan that operates on individual or groups
+ * of rows.
+ *
+ * Note this rule must be run after expression optimization but before scan planning.
+ */
+object OptimizeMetadataOnlyDeleteFromTable extends Rule[LogicalPlan] with PredicateHelper {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case RewrittenRowLevelCommand(rowLevelPlan, DELETE, cond, relation: DataSourceV2Relation) =>
+      relation.table match {
+        case table: SupportsDelete if !SubqueryExpression.hasSubquery(cond) =>
+          val predicates = splitConjunctivePredicates(cond)
+          val normalizedPredicates = DataSourceStrategy.normalizeExprs(predicates, relation.output)
+          val filters = toDataSourceFilters(normalizedPredicates)
+          val allPredicatesTranslated = normalizedPredicates.size == filters.length
+          if (allPredicatesTranslated && table.canDeleteWhere(filters)) {
+            logDebug(s"Switching to delete with filters: ${filters.mkString("[", ", ", "]")}")
+            DeleteFromTableWithFilters(relation, filters)
+          } else {
+            rowLevelPlan
+          }
+
+ case _: TruncatableTable if cond == TrueLiteral =>
+ DeleteFromTable(relation, cond)
+
+ case _ =>
+ rowLevelPlan
+ }
+ }
+
+  private def toDataSourceFilters(predicates: Seq[Expression]): Array[sources.Filter] = {
+    predicates.flatMap { p =>
+      val filter = DataSourceStrategy.translateFilter(p, supportNestedPredicatePushdown = true)
+ if (filter.isEmpty) {
+ logDebug(s"Cannot translate expression to data source filter: $p")
+ }
+ filter
+ }.toArray
+ }
+
+ private object RewrittenRowLevelCommand {
+    type ReturnType = (RowLevelWrite, RowLevelOperation.Command, Expression, LogicalPlan)
+
+ def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+ case rd @ ReplaceData(_, cond, _, originalTable, _) =>
+ val command = rd.operation.command
+ Some(rd, command, cond, originalTable)
+
+ case _ =>
+ None
+ }
+ }
+}
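
For example (an illustrative scenario): if the optimized condition of a rewritten DELETE translates fully into data source filters and the connector accepts them, the row-level plan is discarded:

    -- DELETE FROM t WHERE part = 'a'
    -- translated filters: Array(EqualTo("part", "a"))
    -- if table.canDeleteWhere(filters) returns true, the ReplaceData plan is
    -- replaced by DeleteFromTableWithFilters and no rows are read or rewritten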
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 862189ed3af..8ac91e02579 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -187,7 +187,7 @@ object PushDownUtils extends PredicateHelper {
}
}
- private def toOutputAttrs(
+ def toOutputAttrs(
schema: StructType,
relation: DataSourceV2Relation): Seq[AttributeReference] = {
val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 38f741532d7..2fd1d52fd98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
import java.util.UUID
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceData}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, WriteT
import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
import org.apache.spark.sql.sources.{AlwaysTrue, Filter}
import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
/**
* A rule that constructs logical writes.
@@ -41,7 +42,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) =>
- val writeBuilder = newWriteBuilder(r.table, query, options)
+ val writeBuilder = newWriteBuilder(r.table, options, query.schema)
val write = writeBuilder.build()
      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf)
a.copy(write = Some(write), query = newQuery)
@@ -57,7 +58,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
}.toArray
val table = r.table
- val writeBuilder = newWriteBuilder(table, query, options)
+ val writeBuilder = newWriteBuilder(table, options, query.schema)
val write = writeBuilder match {
case builder: SupportsTruncate if isTruncate(filters) =>
builder.truncate().build()
@@ -72,7 +73,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
    case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
val table = r.table
- val writeBuilder = newWriteBuilder(table, query, options)
+ val writeBuilder = newWriteBuilder(table, options, query.schema)
val write = writeBuilder match {
case builder: SupportsDynamicOverwrite =>
builder.overwriteDynamicPartitions().build()
@@ -85,12 +86,21 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
case WriteToMicroBatchDataSource(
        relation, table, query, queryId, writeOptions, outputMode, Some(batchId)) =>
- val writeBuilder = newWriteBuilder(table, query, writeOptions, queryId)
+      val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId)
val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
val customMetrics = write.supportedCustomMetrics.toSeq
      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf)
WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)
+
+ case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, None) =>
+ val rowSchema = StructType.fromAttributes(rd.dataInput)
+ val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
+ val write = writeBuilder.build()
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf)
+      // project away any metadata columns that could be used for distribution and ordering
+      rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))
+
}
private def buildWriteForMicroBatch(
@@ -119,11 +129,11 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
private def newWriteBuilder(
table: Table,
- query: LogicalPlan,
writeOptions: Map[String, String],
+ rowSchema: StructType,
queryId: String = UUID.randomUUID().toString): WriteBuilder = {
-    val info = LogicalWriteInfoImpl(queryId, query.schema, writeOptions.asOptions)
+ val info = LogicalWriteInfoImpl(queryId, rowSchema, writeOptions.asOptions)
table.asWritable.newWriteBuilder(info)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 65c49283dd7..d23a9e51f65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -284,6 +284,21 @@ case class OverwritePartitionsDynamicExec(
copy(query = newChild)
}
+/**
+ * Physical plan node to replace data in existing tables.
+ */
+case class ReplaceDataExec(
+ query: SparkPlan,
+ refreshCache: () => Unit,
+ write: Write) extends V2ExistingTableWriteExec {
+
+ override val stringArgs: Iterator[Any] = Iterator(query, write)
+
+ override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = {
+ copy(query = newChild)
+ }
+}
+
case class WriteToDataSourceV2Exec(
batchWrite: BatchWrite,
refreshCache: () => Unit,
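
ReplaceDataExec plugs into the existing V2ExistingTableWriteExec write path, so the user-visible difference is only which physical node runs a DELETE. A rough sketch of the two outcomes, with predicates and table names borrowed from the test suite below (the "cat" catalog and test_table are that suite's fixtures, not something this patch creates for end users):

  // Predicate the source can apply as delete filters: stays a metadata-only delete
  // and executes as DeleteFromTableExec.
  spark.sql("DELETE FROM cat.ns1.test_table WHERE dep = 100")

  // Predicate the source cannot apply directly: rewritten into a group-based plan
  // that replaces the affected groups and executes as ReplaceDataExec.
  spark.sql("DELETE FROM cat.ns1.test_table WHERE dep = 100 OR dep < 200")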
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala
new file mode 100644
index 00000000000..a2cfdde2671
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala
@@ -0,0 +1,629 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import java.util.Collections
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, QueryTest, Row}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTableCatalog}
+import org.apache.spark.sql.connector.expressions.LogicalExpressions._
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.QueryExecutionListener
+
+abstract class DeleteFromTableSuiteBase
+ extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper {
+
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+ import testImplicits._
+
+ before {
+ spark.conf.set("spark.sql.catalog.cat", classOf[InMemoryRowLevelOperationTableCatalog].getName)
+ }
+
+ after {
+ spark.sessionState.catalogManager.reset()
+ spark.sessionState.conf.unsetConf("spark.sql.catalog.cat")
+ }
+
+ private val namespace = Array("ns1")
+ private val ident = Identifier.of(namespace, "test_table")
+ private val tableNameAsString = "cat." + ident.toString
+
+ private def catalog: InMemoryRowLevelOperationTableCatalog = {
+ val catalog = spark.sessionState.catalogManager.catalog("cat")
+ catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog]
+ }
+
+ test("EXPLAIN only delete") {
+ createAndInitTable("id INT, dep STRING", """{ "id": 1, "dep": "hr" }""")
+
+ sql(s"EXPLAIN DELETE FROM $tableNameAsString WHERE id <= 10")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Nil)
+ }
+
+ test("delete from empty tables") {
+ createTable("id INT, dep STRING")
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1")
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil)
+ }
+
+ test("delete with basic filters") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "software" }
+ |{ "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "software") :: Row(3, "hr") :: Nil)
+ }
+
+ test("delete with aliases") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "software" }
+ |{ "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString AS t WHERE t.id <= 1 OR t.dep = 'hr'")
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "software")
:: Nil)
+ }
+
+ test("delete with IN predicates") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "software" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id IN (1, null)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "software") :: Row(null, "hr") :: Nil)
+ }
+
+ test("delete with NOT IN predicates") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "software" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id NOT IN (null, 1)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(2, "software") :: Row(null, "hr") :: Nil)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id NOT IN (1, 10)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(null, "hr") :: Nil)
+ }
+
+ test("delete with conditions on nested columns") {
+ createAndInitTable("id INT, complex STRUCT<c1:INT,c2:STRING>, dep STRING",
+ """{ "id": 1, "complex": { "c1": 3, "c2": "v1" }, "dep": "hr" }
+ |{ "id": 2, "complex": { "c1": 2, "c2": "v2" }, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE complex.c1 = id + 2")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, Row(2, "v2"), "software") :: Nil)
+
+ sql(s"DELETE FROM $tableNameAsString t WHERE t.complex.c1 = id")
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil)
+ }
+
+ test("delete with IN subqueries") {
+ withTempView("deleted_id", "deleted_dep") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(0), Some(1), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ val deletedDepDF = Seq("software", "hr").toDF()
+ deletedDepDF.createOrReplaceTempView("deleted_dep")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id IN (SELECT * FROM deleted_id)
+ | AND
+ | dep IN (SELECT * FROM deleted_dep)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Row(null, "hr") :: Nil)
+
+ append("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": -1, "dep": "hr" }
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(-1, "hr") :: Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr")
:: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id IS NULL
+ | OR
+ | id IN (SELECT value + 2 FROM deleted_id)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(-1, "hr") :: Row(1, "hr") :: Nil)
+
+ append("id INT, dep STRING",
+ """{ "id": null, "dep": "hr" }
+ |{ "id": 2, "dep": "hr" }
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(-1, "hr") :: Row(1, "hr") :: Row(2, "hr") :: Row(null, "hr") ::
Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id IN (SELECT value + 2 FROM deleted_id)
+ | AND
+ | dep = 'hr'
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(-1, "hr") :: Row(1, "hr") :: Row(null, "hr") :: Nil)
+ }
+ }
+
+ test("delete with multi-column IN subqueries") {
+ withTempView("deleted_employee") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedEmployeeDF = Seq((None, "hr"), (Some(1), "hr")).toDF()
+ deletedEmployeeDF.createOrReplaceTempView("deleted_employee")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | (id, dep) IN (SELECT * FROM deleted_employee)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Row(null, "hr") :: Nil)
+ }
+ }
+
+ test("delete with NOT IN subqueries") {
+ withTempView("deleted_id", "deleted_dep") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ val deletedDepDF = Seq("software", "hr").toDF()
+ deletedDepDF.createOrReplaceTempView("deleted_dep")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id NOT IN (SELECT * FROM deleted_id)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(null, "hr") ::
Nil)
+
+ append("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Row(null,
"hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id NOT IN (SELECT * FROM deleted_id)
+ | OR
+ | dep IN ('software', 'hr')
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "hardware")
:: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)
+ | AND
+ | EXISTS (SELECT 1 FROM deleted_dep WHERE dep = deleted_dep.value)
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "hardware")
:: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | t.id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)
+ | OR
+ | EXISTS (SELECT 1 FROM deleted_dep WHERE t.dep = deleted_dep.value)
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil)
+ }
+ }
+
+ test("delete with EXISTS subquery") {
+ withTempView("deleted_id", "deleted_dep") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ val deletedDepDF = Seq("software", "hr").toDF()
+ deletedDepDF.createOrReplaceTempView("deleted_dep")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) OR t.id IS NULL
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value)
+ | AND
+ | EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Nil)
+ }
+ }
+
+ test("delete with NOT EXISTS subquery") {
+ withTempView("deleted_id", "deleted_dep") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ val deletedDepDF = Seq("software", "hr").toDF()
+ deletedDepDF.createOrReplaceTempView("deleted_dep")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | NOT EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value + 2)
+ | AND
+ | NOT EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(1, "hr") ::
Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)
+ | OR
+ | t.id = 1
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil)
+ }
+ }
+
+ test("delete with a scalar subquery") {
+ withTempView("deleted_id") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(1), Some(100), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | id <= (SELECT min(value) FROM deleted_id)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Row(null, "hr") :: Nil)
+ }
+ }
+
+ test("delete refreshes relation cache") {
+ withTempView("temp") {
+ withCache("temp") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 1, "dep": "hardware" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ // define a view on top of the table
+ val query = sql(s"SELECT * FROM $tableNameAsString WHERE id = 1")
+ query.createOrReplaceTempView("temp")
+
+ // cache the view
+ sql("CACHE TABLE temp")
+
+ // verify the view returns expected results
+ checkAnswer(
+ sql("SELECT * FROM temp"),
+ Row(1, "hr") :: Row(1, "hardware") :: Nil)
+
+ // delete some records from the table
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1")
+
+ // verify the delete was successful
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Row(3, "hr") :: Nil)
+
+ // verify the view reflects the changes in the table
+ checkAnswer(sql("SELECT * FROM temp"), Nil)
+ }
+ }
+ }
+
+ test("delete with nondeterministic conditions") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "software" }
+ |{ "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ val e = intercept[AnalysisException] {
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1 AND rand() > 0.5")
+ }
+ assert(e.message.contains("nondeterministic expressions are only allowed"))
+ }
+
+ test("delete without condition executed as delete with filters") {
+ createAndInitTable("id INT, dep INT",
+ """{ "id": 1, "dep": 100 }
+ |{ "id": 2, "dep": 200 }
+ |{ "id": 3, "dep": 100 }
+ |""".stripMargin)
+
+ executeDeleteWithFilters(s"DELETE FROM $tableNameAsString")
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil)
+ }
+
+ test("delete with supported predicates gets converted into delete with
filters") {
+ createAndInitTable("id INT, dep INT",
+ """{ "id": 1, "dep": 100 }
+ |{ "id": 2, "dep": 200 }
+ |{ "id": 3, "dep": 100 }
+ |""".stripMargin)
+
+ executeDeleteWithFilters(s"DELETE FROM $tableNameAsString WHERE dep = 100")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, 200) :: Nil)
+ }
+
+ test("delete with unsupported predicates cannot be converted into delete
with filters") {
+ createAndInitTable("id INT, dep INT",
+ """{ "id": 1, "dep": 100 }
+ |{ "id": 2, "dep": 200 }
+ |{ "id": 3, "dep": 100 }
+ |""".stripMargin)
+
+ executeDeleteWithRewrite(s"DELETE FROM $tableNameAsString WHERE dep = 100
OR dep < 200")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, 200) :: Nil)
+ }
+
+ test("delete with subquery cannot be converted into delete with filters") {
+ withTempView("deleted_id") {
+ createAndInitTable("id INT, dep INT",
+ """{ "id": 1, "dep": 100 }
+ |{ "id": 2, "dep": 200 }
+ |{ "id": 3, "dep": 100 }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(1), Some(100), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ val q = s"DELETE FROM $tableNameAsString WHERE dep = 100 AND id IN
(SELECT * FROM deleted_id)"
+ executeDeleteWithRewrite(q)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, 200) :: Row(3, 100) :: Nil)
+ }
+ }
+
+ private def createTable(schemaString: String): Unit = {
+ val schema = StructType.fromDDL(schemaString)
+ val tableProps = Collections.emptyMap[String, String]
+ catalog.createTable(ident, schema, Array(identity(reference(Seq("dep")))), tableProps)
+ }
+
+ private def createAndInitTable(schemaString: String, jsonData: String): Unit = {
+ createTable(schemaString)
+ append(schemaString, jsonData)
+ }
+
+ private def append(schemaString: String, jsonData: String): Unit = {
+ val df = toDF(jsonData, schemaString)
+ df.coalesce(1).writeTo(tableNameAsString).append()
+ }
+
+ private def toDF(jsonData: String, schemaString: String = null): DataFrame = {
+ val jsonRows = jsonData.split("\\n").filter(str => str.trim.nonEmpty)
+ val jsonDS = spark.createDataset(jsonRows)(Encoders.STRING)
+ if (schemaString == null) {
+ spark.read.json(jsonDS)
+ } else {
+ spark.read.schema(schemaString).json(jsonDS)
+ }
+ }
+
+ private def executeDeleteWithFilters(query: String): Unit = {
+ val executedPlan = executeAndKeepPlan {
+ sql(query)
+ }
+
+ executedPlan match {
+ case _: DeleteFromTableExec =>
+ // OK
+ case other =>
+ fail("unexpected executed plan: " + other)
+ }
+ }
+
+ private def executeDeleteWithRewrite(query: String): Unit = {
+ val executedPlan = executeAndKeepPlan {
+ sql(query)
+ }
+
+ executedPlan match {
+ case _: ReplaceDataExec =>
+ // OK
+ case other =>
+ fail("unexpected executed plan: " + other)
+ }
+ }
+
+ // executes an operation and keeps the executed plan
+ private def executeAndKeepPlan(func: => Unit): SparkPlan = {
+ var executedPlan: SparkPlan = null
+
+ val listener = new QueryExecutionListener {
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+ executedPlan = qe.executedPlan
+ }
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+ }
+ }
+ spark.listenerManager.register(listener)
+
+ func
+
+ sparkContext.listenerBus.waitUntilEmpty()
+
+ stripAQEPlan(executedPlan)
+ }
+}
+
+class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase
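
For anyone who wants to try the new path interactively, a condensed version of the suite's setup can be run from a spark-shell that has these sql/core test classes on the classpath; every call below mirrors the before/createTable/append helpers above, and the table layout is the suite's fixture, not a recommendation.

  import java.util.Collections
  import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTableCatalog}
  import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference}
  import org.apache.spark.sql.types.StructType
  import spark.implicits._

  // Register the row-level-operation-capable in-memory catalog under the name "cat".
  spark.conf.set("spark.sql.catalog.cat", classOf[InMemoryRowLevelOperationTableCatalog].getName)
  val catalog = spark.sessionState.catalogManager.catalog("cat")
    .asInstanceOf[InMemoryRowLevelOperationTableCatalog]

  // Create a table partitioned by "dep", as the tests do, and load a couple of rows.
  catalog.createTable(
    Identifier.of(Array("ns1"), "test_table"),
    StructType.fromDDL("id INT, dep STRING"),
    Array(identity(reference(Seq("dep")))),
    Collections.emptyMap[String, String])
  Seq((1, "hr"), (2, "software")).toDF("id", "dep").writeTo("cat.ns1.test_table").append()

  // id is not a partition column, so this DELETE goes through the group-based rewrite.
  spark.sql("DELETE FROM cat.ns1.test_table WHERE id <= 1")
  spark.sql("SELECT * FROM cat.ns1.test_table").show()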
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index 24b6be07619..6a20ee21294 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn,
AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect,
DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction,
LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project,
SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias,
UnsetTableProperties, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.FakeV2Provider
-import org.apache.spark.sql.connector.catalog.{CatalogManager,
CatalogNotFoundException, Identifier, Table, TableCapability, TableCatalog,
V1Table}
+import org.apache.spark.sql.connector.catalog.{CatalogManager,
CatalogNotFoundException, Identifier, SupportsDelete, Table, TableCapability,
TableCatalog, V1Table}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -49,7 +49,7 @@ class PlanResolutionSuite extends AnalysisTest {
private val v2Format = classOf[FakeV2Provider].getName
private val table: Table = {
- val t = mock(classOf[Table])
+ val t = mock(classOf[SupportsDelete])
when(t.schema()).thenReturn(new StructType().add("i", "int").add("s", "string"))
when(t.partitioning()).thenReturn(Array.empty[Transform])
t