This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 1103343e71f [SPARK-40064][SQL] Use V2 Filter in SupportsOverwrite 1103343e71f is described below commit 1103343e71fbcb478fa41941c87d2c28b0c09281 Author: huaxingao <huaxin_...@apple.com> AuthorDate: Mon Aug 15 10:58:14 2022 -0700 [SPARK-40064][SQL] Use V2 Filter in SupportsOverwrite ### What changes were proposed in this pull request? Migrate `SupportsOverwrite` to use V2 Filter ### Why are the changes needed? this is part of the V2Filter migration work ### Does this PR introduce _any_ user-facing change? Yes add `SupportsOverwriteV2` ### How was this patch tested? new tests Closes #37502 from huaxingao/v2overwrite. Authored-by: huaxingao <huaxin_...@apple.com> Signed-off-by: huaxingao <huaxin_...@apple.com> --- .../sql/connector/catalog/TableCapability.java | 2 +- .../connector/write/SupportsDynamicOverwrite.java | 2 +- .../sql/connector/write/SupportsOverwrite.java | 31 ++- ...ortsOverwrite.java => SupportsOverwriteV2.java} | 31 ++- .../sql/connector/catalog/InMemoryBaseTable.scala | 138 +++--------- .../sql/connector/catalog/InMemoryTable.scala | 99 ++++++++- .../catalog/InMemoryTableWithV2Filter.scala | 72 +++++-- .../sql/execution/datasources/v2/V2Writes.scala | 23 +- .../spark/sql/connector/DataSourceV2SQLSuite.scala | 233 +++------------------ .../spark/sql/connector/DeleteFromTests.scala | 132 ++++++++++++ .../spark/sql/connector/V1WriteFallbackSuite.scala | 4 +- 11 files changed, 412 insertions(+), 355 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java index 5bb42fb4b31..5732c0f3af4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java @@ -76,7 +76,7 @@ public enum TableCapability { * Signals that the table can replace existing data that matches a filter with appended data in * a write operation. * <p> - * See {@link org.apache.spark.sql.connector.write.SupportsOverwrite}. + * See {@link org.apache.spark.sql.connector.write.SupportsOverwriteV2}. */ OVERWRITE_BY_FILTER, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java index 422cd71d345..0288a679891 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java @@ -27,7 +27,7 @@ import org.apache.spark.annotation.Evolving; * write does not contain data will remain unchanged. * <p> * This is provided to implement SQL compatible with Hive table operations but is not recommended. - * Instead, use the {@link SupportsOverwrite overwrite by filter API} to explicitly replace data. + * Instead, use the {@link SupportsOverwriteV2 overwrite by filter API} to explicitly replace data. 
* * @since 3.0.0 */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java index b4e60257942..51bec236088 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java @@ -18,6 +18,8 @@ package org.apache.spark.sql.connector.write; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.internal.connector.PredicateUtils; import org.apache.spark.sql.sources.AlwaysTrue$; import org.apache.spark.sql.sources.Filter; @@ -30,7 +32,24 @@ import org.apache.spark.sql.sources.Filter; * @since 3.0.0 */ @Evolving -public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate { +public interface SupportsOverwrite extends SupportsOverwriteV2 { + + /** + * Checks whether it is possible to overwrite data from a data source table that matches filter + * expressions. + * <p> + * Rows should be overwritten from the data source iff all of the filter expressions match. + * That is, the expressions must be interpreted as a set of filters that are ANDed together. + * + * @param filters V2 filter expressions, used to match data to overwrite + * @return true if the delete operation can be performed + * + * @since 3.4.0 + */ + default boolean canOverwrite(Filter[] filters) { + return true; + } + /** * Configures a write to replace data matching the filters with data committed in the write. * <p> @@ -42,6 +61,16 @@ public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate { */ WriteBuilder overwrite(Filter[] filters); + default boolean canOverwrite(Predicate[] predicates) { + Filter[] v1Filters = PredicateUtils.toV1(predicates); + if (v1Filters.length < predicates.length) return false; + return this.canOverwrite(v1Filters); + } + + default WriteBuilder overwrite(Predicate[] predicates) { + return this.overwrite(PredicateUtils.toV1(predicates)); + } + @Override default WriteBuilder truncate() { return overwrite(new Filter[] { AlwaysTrue$.MODULE$ }); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwriteV2.java similarity index 60% copy from sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java copy to sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwriteV2.java index b4e60257942..c1fcbfd38e1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwriteV2.java @@ -18,8 +18,8 @@ package org.apache.spark.sql.connector.write; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.AlwaysTrue$; -import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.connector.expressions.filter.AlwaysTrue; +import org.apache.spark.sql.connector.expressions.filter.Predicate; /** * Write builder trait for tables that support overwrite by filter. @@ -27,23 +27,40 @@ import org.apache.spark.sql.sources.Filter; * Overwriting data by filter will delete any data that matches the filter and replace it with data * that is committed in the write. 
* - * @since 3.0.0 + * @since 3.4.0 */ @Evolving -public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate { +public interface SupportsOverwriteV2 extends WriteBuilder, SupportsTruncate { + + /** + * Checks whether it is possible to overwrite data from a data source table that matches filter + * expressions. + * <p> + * Rows should be overwritten from the data source iff all of the filter expressions match. + * That is, the expressions must be interpreted as a set of filters that are ANDed together. + * + * @param predicates V2 filter expressions, used to match data to overwrite + * @return true if the delete operation can be performed + * + * @since 3.4.0 + */ + default boolean canOverwrite(Predicate[] predicates) { + return true; + } + /** * Configures a write to replace data matching the filters with data committed in the write. * <p> * Rows must be deleted from the data source if and only if all of the filters match. That is, * filters must be interpreted as ANDed together. * - * @param filters filters used to match data to overwrite + * @param predicates filters used to match data to overwrite * @return this write builder for method chaining */ - WriteBuilder overwrite(Filter[] filters); + WriteBuilder overwrite(Predicate[] predicates); @Override default WriteBuilder truncate() { - return overwrite(new Filter[] { AlwaysTrue$.MODULE$ }); + return overwrite(new Predicate[] { new AlwaysTrue() }); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 1f8b416cf55..f139399ed76 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -45,7 +45,7 @@ import org.apache.spark.unsafe.types.UTF8String /** * A simple in-memory table. Rows are stored as a buffered group produced by each output task. 
*/ -class InMemoryBaseTable( +abstract class InMemoryBaseTable( val name: String, val schema: StructType, override val partitioning: Array[Transform], @@ -337,59 +337,39 @@ class InMemoryBaseTable( } } - override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties)) - InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options) + abstract class InMemoryWriterBuilder() extends SupportsTruncate with SupportsDynamicOverwrite + with SupportsStreamingUpdateAsAppend { - new WriteBuilder with SupportsTruncate with SupportsOverwrite - with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend { + protected var writer: BatchWrite = Append + protected var streamingWriter: StreamingWrite = StreamingAppend - private var writer: BatchWrite = Append - private var streamingWriter: StreamingWrite = StreamingAppend - - override def truncate(): WriteBuilder = { - assert(writer == Append) - writer = TruncateAndAppend - streamingWriter = StreamingTruncateAndAppend - this - } - - override def overwrite(filters: Array[Filter]): WriteBuilder = { - assert(writer == Append) - writer = new Overwrite(filters) - streamingWriter = new StreamingNotSupportedOperation( - s"overwrite (${filters.mkString("filters(", ", ", ")")})") - this - } - - override def overwriteDynamicPartitions(): WriteBuilder = { - assert(writer == Append) - writer = DynamicOverwrite - streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions") - this - } + override def overwriteDynamicPartitions(): WriteBuilder = { + assert(writer == Append) + writer = DynamicOverwrite + streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions") + this + } - override def build(): Write = new Write with RequiresDistributionAndOrdering { - override def requiredDistribution: Distribution = distribution + override def build(): Write = new Write with RequiresDistributionAndOrdering { + override def requiredDistribution: Distribution = distribution - override def distributionStrictlyRequired: Boolean = isDistributionStrictlyRequired + override def distributionStrictlyRequired: Boolean = isDistributionStrictlyRequired - override def requiredOrdering: Array[SortOrder] = ordering + override def requiredOrdering: Array[SortOrder] = ordering - override def requiredNumPartitions(): Int = { - numPartitions.getOrElse(0) - } + override def requiredNumPartitions(): Int = { + numPartitions.getOrElse(0) + } - override def toBatch: BatchWrite = writer + override def toBatch: BatchWrite = writer - override def toStreaming: StreamingWrite = streamingWriter match { - case exc: StreamingNotSupportedOperation => exc.throwsException() - case s => s - } + override def toStreaming: StreamingWrite = streamingWriter match { + case exc: StreamingNotSupportedOperation => exc.throwsException() + case s => s + } - override def supportedCustomMetrics(): Array[CustomMetric] = { - Array(new InMemorySimpleCustomMetric) - } + override def supportedCustomMetrics(): Array[CustomMetric] = { + Array(new InMemorySimpleCustomMetric) } } } @@ -402,7 +382,7 @@ class InMemoryBaseTable( override def abort(messages: Array[WriterCommitMessage]): Unit = {} } - private object Append extends TestBatchWrite { + protected object Append extends TestBatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { withData(messages.map(_.asInstanceOf[BufferedRows])) } @@ -416,24 +396,14 @@ class InMemoryBaseTable( } } - private 
class Overwrite(filters: Array[Filter]) extends TestBatchWrite { - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { - val deleteKeys = InMemoryBaseTable.filtersToKeys( - dataMap.keys, partCols.map(_.toSeq.quoted), filters) - dataMap --= deleteKeys - withData(messages.map(_.asInstanceOf[BufferedRows])) - } - } - - private object TruncateAndAppend extends TestBatchWrite { + protected object TruncateAndAppend extends TestBatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { dataMap.clear withData(messages.map(_.asInstanceOf[BufferedRows])) } } - private abstract class TestStreamingWrite extends StreamingWrite { + protected abstract class TestStreamingWrite extends StreamingWrite { def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = { BufferedRowsWriterFactory } @@ -441,7 +411,7 @@ class InMemoryBaseTable( def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} } - private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite { + protected class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite { override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = throwsException() @@ -463,7 +433,7 @@ class InMemoryBaseTable( } } - private object StreamingTruncateAndAppend extends TestStreamingWrite { + protected object StreamingTruncateAndAppend extends TestStreamingWrite { override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { dataMap.synchronized { dataMap.clear @@ -476,46 +446,7 @@ class InMemoryBaseTable( object InMemoryBaseTable { val SIMULATE_FAILED_WRITE_OPTION = "spark.sql.test.simulateFailedWrite" - def filtersToKeys( - keys: Iterable[Seq[Any]], - partitionNames: Seq[String], - filters: Array[Filter]): Iterable[Seq[Any]] = { - keys.filter { partValues => - filters.flatMap(splitAnd).forall { - case EqualTo(attr, value) => - value == extractValue(attr, partitionNames, partValues) - case EqualNullSafe(attr, value) => - val attrVal = extractValue(attr, partitionNames, partValues) - if (attrVal == null && value === null) { - true - } else if (attrVal == null || value === null) { - false - } else { - value == attrVal - } - case IsNull(attr) => - null == extractValue(attr, partitionNames, partValues) - case IsNotNull(attr) => - null != extractValue(attr, partitionNames, partValues) - case AlwaysTrue() => true - case f => - throw new IllegalArgumentException(s"Unsupported filter type: $f") - } - } - } - - def supportsFilters(filters: Array[Filter]): Boolean = { - filters.flatMap(splitAnd).forall { - case _: EqualTo => true - case _: EqualNullSafe => true - case _: IsNull => true - case _: IsNotNull => true - case _: AlwaysTrue => true - case _ => false - } - } - - private def extractValue( + def extractValue( attr: String, partFieldNames: Seq[String], partValues: Seq[Any]): Any = { @@ -527,13 +458,6 @@ object InMemoryBaseTable { } } - private def splitAnd(filter: Filter): Seq[Filter] = { - filter match { - case And(left, right) => splitAnd(left) ++ splitAnd(right) - case _ => filter :: Nil - } - } - def maybeSimulateFailedTableWrite(tableOptions: CaseInsensitiveStringMap): Unit = { if (tableOptions.getBoolean(SIMULATE_FAILED_WRITE_OPTION, false)) { throw new IllegalStateException("Manual write to table failure.") diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index b82641a5d24..cd6821c8739 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -19,10 +19,14 @@ package org.apache.spark.sql.connector.catalog import java.util +import org.scalatest.Assertions.assert + import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A simple in-memory table. Rows are stored as a buffered group produced by each output task. @@ -40,12 +44,12 @@ class InMemoryTable( ordering, numPartitions, isDistributionStrictlyRequired) with SupportsDelete { override def canDeleteWhere(filters: Array[Filter]): Boolean = { - InMemoryBaseTable.supportsFilters(filters) + InMemoryTable.supportsFilters(filters) } override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper - dataMap --= InMemoryBaseTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters) + dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters) } override def withData(data: Array[BufferedRows]): InMemoryTable = { @@ -64,4 +68,93 @@ class InMemoryTable( }) this } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties)) + InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options) + + new InMemoryWriterBuilderWithOverWrite() + } + + private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder + with SupportsOverwrite { + + override def truncate(): WriteBuilder = { + assert(writer == Append) + writer = TruncateAndAppend + streamingWriter = StreamingTruncateAndAppend + this + } + + override def overwrite(filters: Array[Filter]): WriteBuilder = { + assert(writer == Append) + writer = new Overwrite(filters) + streamingWriter = new StreamingNotSupportedOperation( + s"overwrite (${filters.mkString("filters(", ", ", ")")})") + this + } + + override def canOverwrite(filters: Array[Filter]): Boolean = { + InMemoryTable.supportsFilters(filters) + } + } + + private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val deleteKeys = InMemoryTable.filtersToKeys( + dataMap.keys, partCols.map(_.toSeq.quoted), filters) + dataMap --= deleteKeys + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } +} + +object InMemoryTable { + + def filtersToKeys( + keys: Iterable[Seq[Any]], + partitionNames: Seq[String], + filters: Array[Filter]): Iterable[Seq[Any]] = { + keys.filter { partValues => + filters.flatMap(splitAnd).forall { + case EqualTo(attr, value) => + value == InMemoryBaseTable.extractValue(attr, partitionNames, 
partValues) + case EqualNullSafe(attr, value) => + val attrVal = InMemoryBaseTable.extractValue(attr, partitionNames, partValues) + if (attrVal == null && value == null) { + true + } else if (attrVal == null || value == null) { + false + } else { + value == attrVal + } + case IsNull(attr) => + null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues) + case IsNotNull(attr) => + null != InMemoryBaseTable.extractValue(attr, partitionNames, partValues) + case AlwaysTrue() => true + case f => + throw new IllegalArgumentException(s"Unsupported filter type: $f") + } + } + } + + def supportsFilters(filters: Array[Filter]): Boolean = { + filters.flatMap(splitAnd).forall { + case _: EqualTo => true + case _: EqualNullSafe => true + case _: IsNull => true + case _: IsNotNull => true + case _: AlwaysTrue => true + case _ => false + } + } + + private def splitAnd(filter: Filter): Seq[Filter] = { + filter match { + case And(left, right) => splitAnd(left) ++ splitAnd(right) + case _ => filter :: Nil + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala index 48000dd0d98..b4285f31dd7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala @@ -19,9 +19,12 @@ package org.apache.spark.sql.connector.catalog import java.util +import org.scalatest.Assertions.assert + import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference, Transform} import org.apache.spark.sql.connector.expressions.filter.{And, Predicate} import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwriteV2, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -32,8 +35,8 @@ class InMemoryTableWithV2Filter( properties: util.Map[String, String]) extends InMemoryBaseTable(name, schema, partitioning, properties) with SupportsDeleteV2 { - override def canDeleteWhere(filters: Array[Predicate]): Boolean = { - InMemoryTableWithV2Filter.supportsFilters(filters) + override def canDeleteWhere(predicates: Array[Predicate]): Boolean = { + InMemoryTableWithV2Filter.supportsPredicates(predicates) } override def deleteWhere(filters: Array[Predicate]): Unit = dataMap.synchronized { @@ -84,6 +87,46 @@ class InMemoryTableWithV2Filter( } } } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties)) + InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options) + + new InMemoryWriterBuilderWithOverWrite() + } + + private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder + with SupportsOverwriteV2 { + + override def truncate(): WriteBuilder = { + assert(writer == Append) + writer = TruncateAndAppend + streamingWriter = StreamingTruncateAndAppend + this + } + + override def overwrite(predicates: Array[Predicate]): WriteBuilder = { + assert(writer == Append) + writer = new Overwrite(predicates) + streamingWriter = new StreamingNotSupportedOperation( + s"overwrite (${predicates.mkString("filters(", ", ", ")")})") + this + } + + override def 
canOverwrite(predicates: Array[Predicate]): Boolean = { + InMemoryTableWithV2Filter.supportsPredicates(predicates) + } + } + + private class Overwrite(predicates: Array[Predicate]) extends TestBatchWrite { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val deleteKeys = InMemoryTableWithV2Filter.filtersToKeys( + dataMap.keys, partCols.map(_.toSeq.quoted), predicates) + dataMap --= deleteKeys + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } } object InMemoryTableWithV2Filter { @@ -96,9 +139,10 @@ object InMemoryTableWithV2Filter { filters.flatMap(splitAnd).forall { case p: Predicate if p.name().equals("=") => p.children()(1).asInstanceOf[LiteralValue[_]].value == - extractValue(p.children()(0).toString, partitionNames, partValues) + InMemoryBaseTable.extractValue(p.children()(0).toString, partitionNames, partValues) case p: Predicate if p.name().equals("<=>") => - val attrVal = extractValue(p.children()(0).toString, partitionNames, partValues) + val attrVal = InMemoryBaseTable + .extractValue(p.children()(0).toString, partitionNames, partValues) val value = p.children()(1).asInstanceOf[LiteralValue[_]].value if (attrVal == null && value == null) { true @@ -109,10 +153,10 @@ object InMemoryTableWithV2Filter { } case p: Predicate if p.name().equals("IS NULL") => val attr = p.children()(0).toString - null == extractValue(attr, partitionNames, partValues) + null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues) case p: Predicate if p.name().equals("IS NOT NULL") => val attr = p.children()(0).toString - null != extractValue(attr, partitionNames, partValues) + null != InMemoryBaseTable.extractValue(attr, partitionNames, partValues) case p: Predicate if p.name().equals("ALWAYS_TRUE") => true case f => throw new IllegalArgumentException(s"Unsupported filter type: $f") @@ -120,8 +164,8 @@ object InMemoryTableWithV2Filter { } } - def supportsFilters(filters: Array[Predicate]): Boolean = { - filters.flatMap(splitAnd).forall { + def supportsPredicates(predicates: Array[Predicate]): Boolean = { + predicates.flatMap(splitAnd).forall { case p: Predicate if p.name().equals("=") => true case p: Predicate if p.name().equals("<=>") => true case p: Predicate if p.name().equals("IS NULL") => true @@ -131,18 +175,6 @@ object InMemoryTableWithV2Filter { } } - private def extractValue( - attr: String, - partFieldNames: Seq[String], - partValues: Seq[Any]): Any = { - partFieldNames.zipWithIndex.find(_._1 == attr) match { - case Some((_, partIndex)) => - partValues(partIndex) - case _ => - throw new IllegalArgumentException(s"Unknown filter attribute: $attr") - } - } - private def splitAnd(filter: Predicate): Seq[Predicate] = { filter match { case and: And => splitAnd(and.left()) ++ splitAnd(and.right()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index 4422743c5ac..2d47d94ff1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -24,12 +24,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Ove import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import 
org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} -import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, Write, WriteBuilder} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwriteV2, SupportsTruncate, Write, WriteBuilder} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, WriteToMicroBatchDataSource} import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend -import org.apache.spark.sql.sources.{AlwaysTrue, Filter} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -49,21 +48,21 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) => // fail if any filter cannot be converted. correctness depends on removing all matching data. - val filters = splitConjunctivePredicates(deleteExpr).flatMap { pred => - val filter = DataSourceStrategy.translateFilter(pred, supportNestedPredicatePushdown = true) - if (filter.isEmpty) { + val predicates = splitConjunctivePredicates(deleteExpr).flatMap { pred => + val predicate = DataSourceV2Strategy.translateFilterV2(pred) + if (predicate.isEmpty) { throw QueryCompilationErrors.cannotTranslateExpressionToSourceFilterError(pred) } - filter + predicate }.toArray val table = r.table val writeBuilder = newWriteBuilder(table, options, query.schema) val write = writeBuilder match { - case builder: SupportsTruncate if isTruncate(filters) => + case builder: SupportsTruncate if isTruncate(predicates) => builder.truncate().build() - case builder: SupportsOverwrite => - builder.overwrite(filters).build() + case builder: SupportsOverwriteV2 if builder.canOverwrite(predicates) => + builder.overwrite(predicates).build() case _ => throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table) } @@ -123,8 +122,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { } } - private def isTruncate(filters: Array[Filter]): Boolean = { - filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] + private def isTruncate(predicates: Array[Predicate]): Boolean = { + predicates.length == 1 && predicates(0).name().equals("ALWAYS_TRUE") } private def newWriteBuilder( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 9ec5be46fc2..629a5ac83c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -40,18 +40,13 @@ import org.apache.spark.sql.sources.SimpleScanSource import org.apache.spark.sql.types.{LongType, MetadataBuilder, StringType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.Utils -class DataSourceV2SQLSuite +abstract class DataSourceV2SQLSuite extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = true) - with AlterTableTests with DatasourceV2SQLBase { + with DeleteFromTests with DatasourceV2SQLBase { - 
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - - private val v2Source = classOf[FakeV2Provider].getName + protected val v2Source = classOf[FakeV2Provider].getName override protected val v2Format = v2Source - override protected val catalogAndNamespace = "testcat.ns1.ns2." - private val defaultUser: String = Utils.getCurrentUserName() protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode): Unit = { val tmpView = "tmp_view" @@ -66,6 +61,20 @@ class DataSourceV2SQLSuite checkAnswer(spark.table(tableName), expected) } + protected def assertAnalysisError( + sqlStatement: String, + expectedError: String): Unit = { + val ex = intercept[AnalysisException] { + sql(sqlStatement) + } + assert(ex.getMessage.contains(expectedError)) + } +} + +class DataSourceV2SQLSuiteV1Filter extends DataSourceV2SQLSuite with AlterTableTests { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override protected val catalogAndNamespace = "testcat.ns1.ns2." override def getTableMetadata(tableName: String): Table = { val nameParts = spark.sessionState.sqlParser.parseMultipartIdentifier(tableName) val v2Catalog = catalog(nameParts.head).asTableCatalog @@ -622,8 +631,8 @@ class DataSourceV2SQLSuite assert(table.partitioning.isEmpty) assert(table.properties == withDefaultOwnership(Map("provider" -> v2Source)).asJava) assert(table.schema == new StructType() - .add("id", LongType) - .add("data", StringType)) + .add("id", LongType) + .add("data", StringType)) val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) @@ -639,8 +648,8 @@ class DataSourceV2SQLSuite assert(table.partitioning.isEmpty) assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava) assert(table.schema == new StructType() - .add("id", LongType) - .add("data", StringType)) + .add("id", LongType) + .add("data", StringType)) val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) @@ -659,8 +668,8 @@ class DataSourceV2SQLSuite assert(table2.partitioning.isEmpty) assert(table2.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava) assert(table2.schema == new StructType() - .add("id", LongType) - .add("data", StringType)) + .add("id", LongType) + .add("data", StringType)) val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), spark.table("source")) @@ -677,8 +686,8 @@ class DataSourceV2SQLSuite assert(table.partitioning.isEmpty) assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava) assert(table.schema == new StructType() - .add("id", LongType) - .add("data", StringType)) + .add("id", LongType) + .add("data", StringType)) val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) @@ -708,8 +717,8 @@ class DataSourceV2SQLSuite assert(table.partitioning.isEmpty) assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava) assert(table.schema == new StructType() - .add("id", LongType) - .add("data", StringType)) + .add("id", LongType) + .add("data", StringType)) val rdd = sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), 
spark.table("source")) @@ -1526,148 +1535,6 @@ class DataSourceV2SQLSuite assert(e.message.contains("REPLACE TABLE is only supported with v2 tables")) } - test("DeleteFrom: basic - delete all") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - sql(s"DELETE FROM $t") - checkAnswer(spark.table(t), Seq()) - } - } - - test("DeleteFrom with v2 filtering: basic - delete all") { - val t = "testv2filter.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - sql(s"DELETE FROM $t") - checkAnswer(spark.table(t), Seq()) - } - } - - test("DeleteFrom: basic - delete with where clause") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - sql(s"DELETE FROM $t WHERE id = 2") - checkAnswer(spark.table(t), Seq( - Row(3, "c", 3))) - } - } - - test("DeleteFrom with v2 filtering: basic - delete with where clause") { - val t = "testv2filter.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - sql(s"DELETE FROM $t WHERE id = 2") - checkAnswer(spark.table(t), Seq( - Row(3, "c", 3))) - } - } - - test("DeleteFrom: delete from aliased target table") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - sql(s"DELETE FROM $t AS tbl WHERE tbl.id = 2") - checkAnswer(spark.table(t), Seq( - Row(3, "c", 3))) - } - } - - test("DeleteFrom with v2 filtering: delete from aliased target table") { - val t = "testv2filter.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - sql(s"DELETE FROM $t AS tbl WHERE tbl.id = 2") - checkAnswer(spark.table(t), Seq( - Row(3, "c", 3))) - } - } - - test("DeleteFrom: normalize attribute names") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - sql(s"DELETE FROM $t AS tbl WHERE tbl.ID = 2") - checkAnswer(spark.table(t), Seq( - Row(3, "c", 3))) - } - } - - test("DeleteFrom with v2 filtering: normalize attribute names") { - val t = "testv2filter.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - sql(s"DELETE FROM $t AS tbl WHERE tbl.ID = 2") - checkAnswer(spark.table(t), Seq( - Row(3, "c", 3))) - } - } - - test("DeleteFrom: fail if has subquery") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - val exc = intercept[AnalysisException] { - sql(s"DELETE FROM $t WHERE id IN (SELECT id FROM $t)") - } - - 
assert(spark.table(t).count === 3) - assert(exc.getMessage.contains("Delete by condition with subquery is not supported")) - } - } - - test("DeleteFrom with v2 filtering: fail if has subquery") { - val t = "testv2filter.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - val exc = intercept[AnalysisException] { - sql(s"DELETE FROM $t WHERE id IN (SELECT id FROM $t)") - } - - assert(spark.table(t).count === 3) - assert(exc.getMessage.contains("Delete by condition with subquery is not supported")) - } - } - - test("DeleteFrom: delete with unsupported predicates") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - val exc = intercept[AnalysisException] { - sql(s"DELETE FROM $t WHERE id > 3 AND p > 3") - } - - assert(spark.table(t).count === 3) - assert(exc.getMessage.contains(s"Cannot delete from table $t")) - } - } - - test("DeleteFrom with v2 filtering: delete with unsupported predicates") { - val t = "testv2filter.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - val exc = intercept[AnalysisException] { - sql(s"DELETE FROM $t WHERE id > 3 AND p > 3") - } - - assert(spark.table(t).count === 3) - assert(exc.getMessage.contains(s"Cannot delete from table $t")) - } - } - test("DeleteFrom: - delete with invalid predicate") { val t = "testcat.ns1.ns2.tbl" withTable(t) { @@ -1682,37 +1549,6 @@ class DataSourceV2SQLSuite } } - test("DeleteFrom: DELETE is only supported with v2 tables") { - // unset this config to use the default v2 session catalog. 
- spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) - val v1Table = "tbl" - withTable(v1Table) { - sql(s"CREATE TABLE $v1Table" + - s" USING ${classOf[SimpleScanSource].getName} OPTIONS (from=0,to=1)") - val exc = intercept[AnalysisException] { - sql(s"DELETE FROM $v1Table WHERE i = 2") - } - - assert(exc.getMessage.contains("DELETE is only supported with v2 tables")) - } - } - - test("SPARK-33652: DeleteFrom should refresh caches referencing the table") { - val t = "testcat.ns1.ns2.tbl" - val view = "view" - withTable(t) { - withTempView(view) { - sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") - sql(s"CACHE TABLE view AS SELECT id FROM $t") - assert(spark.table(view).count() == 3) - - sql(s"DELETE FROM $t WHERE id = 2") - assert(spark.table(view).count() == 1) - } - } - } - test("UPDATE TABLE") { val t = "testcat.ns1.ns2.tbl" withTable(t) { @@ -2272,7 +2108,7 @@ class DataSourceV2SQLSuite val t1 = s"${catalogAndNamespace}table" withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + - "PARTITIONED BY (bucket(4, id), id)") + "PARTITIONED BY (bucket(4, id), id)") val sqlQuery = spark.sql(s"SELECT * FROM $t1 WHERE index = 0") val dfQuery = spark.table(t1).filter("index = 0") @@ -2796,15 +2632,6 @@ class DataSourceV2SQLSuite assert(e.message.contains(s"$sqlCommand is not supported for v2 tables")) } - private def assertAnalysisError( - sqlStatement: String, - expectedError: String): Unit = { - val ex = intercept[AnalysisException] { - sql(sqlStatement) - } - assert(ex.getMessage.contains(expectedError)) - } - private def assertAnalysisErrorClass( sqlStatement: String, expectedErrorClass: String, @@ -2815,8 +2642,12 @@ class DataSourceV2SQLSuite assert(ex.getErrorClass == expectedErrorClass) assert(ex.messageParameters.sameElements(expectedErrorMessageParameters)) } + } +class DataSourceV2SQLSuiteV2Filter extends DataSourceV2SQLSuite { + override protected val catalogAndNamespace = "testv2filter.ns1.ns2." +} /** Used as a V2 DataSource for V2SessionCatalog DDL */ class FakeV2Provider extends SimpleTableProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala new file mode 100644 index 00000000000..5ed64df6280 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION +import org.apache.spark.sql.sources.SimpleScanSource + +/** + * A collection of "DELETE" tests that can be run through the SQL APIs. + */ +trait DeleteFromTests extends DatasourceV2SQLBase { + + protected val catalogAndNamespace: String + + test("DeleteFrom with v2 filtering: basic - delete all") { + val t = s"${catalogAndNamespace}tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") + sql(s"DELETE FROM $t") + checkAnswer(spark.table(t), Seq()) + } + } + + test("DeleteFrom with v2 filtering: basic - delete with where clause") { + val t = s"${catalogAndNamespace}tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") + sql(s"DELETE FROM $t WHERE id = 2") + checkAnswer(spark.table(t), Seq( + Row(3, "c", 3))) + } + } + + test("DeleteFrom with v2 filtering: delete from aliased target table") { + val t = s"${catalogAndNamespace}tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") + sql(s"DELETE FROM $t AS tbl WHERE tbl.id = 2") + checkAnswer(spark.table(t), Seq( + Row(3, "c", 3))) + } + } + + test("DeleteFrom with v2 filtering: normalize attribute names") { + val t = s"${catalogAndNamespace}tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") + sql(s"DELETE FROM $t AS tbl WHERE tbl.ID = 2") + checkAnswer(spark.table(t), Seq( + Row(3, "c", 3))) + } + } + + test("DeleteFrom with v2 filtering: fail if has subquery") { + val t = s"${catalogAndNamespace}tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") + val exc = intercept[AnalysisException] { + sql(s"DELETE FROM $t WHERE id IN (SELECT id FROM $t)") + } + + assert(spark.table(t).count === 3) + assert(exc.getMessage.contains("Delete by condition with subquery is not supported")) + } + } + + test("DeleteFrom with v2 filtering: delete with unsupported predicates") { + val t = s"${catalogAndNamespace}tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo") + sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") + val exc = intercept[AnalysisException] { + sql(s"DELETE FROM $t WHERE id > 3 AND p > 3") + } + + assert(spark.table(t).count === 3) + assert(exc.getMessage.contains(s"Cannot delete from table $t")) + } + } + + test("DeleteFrom: DELETE is only supported with v2 tables") { + // unset this config to use the default v2 session catalog. 
+ spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) + val v1Table = "tbl" + withTable(v1Table) { + sql(s"CREATE TABLE $v1Table" + + s" USING ${classOf[SimpleScanSource].getName} OPTIONS (from=0,to=1)") + val exc = intercept[AnalysisException] { + sql(s"DELETE FROM $v1Table WHERE i = 2") + } + + assert(exc.getMessage.contains("DELETE is only supported with v2 tables")) + } + } + + test("SPARK-33652: DeleteFrom should refresh caches referencing the table") { + val t = s"${catalogAndNamespace}tbl" + val view = "view" + withTable(t) { + withTempView(view) { + sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)") + sql(s"CACHE TABLE view AS SELECT id FROM $t") + assert(spark.table(view).count() == 3) + + sql(s"DELETE FROM $t WHERE id = 2") + assert(spark.table(view).count() == 1) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index fe4f70e57ef..992c46cc6cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveM import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryBaseTable, SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, SupportsRead, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder} @@ -359,7 +359,7 @@ class InMemoryTableWithV1Fallback( } override def overwrite(filters: Array[Filter]): WriteBuilder = { - val keys = InMemoryBaseTable.filtersToKeys(dataMap.keys, partFieldNames, filters) + val keys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) dataMap --= keys mode = "overwrite" this --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
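For connector authors, the practical effect of this change is that overwrite-by-filter now flows through `org.apache.spark.sql.connector.expressions.filter.Predicate` instead of the V1 `Filter` API. Below is a minimal sketch of a `WriteBuilder` adopting the new `SupportsOverwriteV2` interface; `MyWriteBuilder` and its predicate whitelist are hypothetical illustrations, not part of this commit:

```scala
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.write.{SupportsOverwriteV2, Write, WriteBuilder}

// Hypothetical builder illustrating the two hooks introduced by SupportsOverwriteV2.
class MyWriteBuilder extends WriteBuilder with SupportsOverwriteV2 {
  private var overwritePredicates: Array[Predicate] = Array.empty

  // Reject predicates the sink cannot evaluate; V2Writes only plans the overwrite
  // when canOverwrite returns true, otherwise the query fails instead of leaving
  // stale rows behind. A real implementation would also split AND predicates, as
  // the in-memory test tables in this commit do.
  override def canOverwrite(predicates: Array[Predicate]): Boolean =
    predicates.forall(p => Set("=", "ALWAYS_TRUE").contains(p.name()))

  // Remember which rows to replace: rows matching all predicates are deleted and
  // the newly written data is committed in their place. truncate() is inherited
  // and calls this with a single AlwaysTrue predicate.
  override def overwrite(predicates: Array[Predicate]): WriteBuilder = {
    overwritePredicates = predicates
    this
  }

  // Write's methods all have defaults, so an empty anonymous instance is enough
  // for this sketch; a real sink would return a Write backed by its BatchWrite.
  override def build(): Write = new Write {}
}
```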