This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new c1f7aa2 [SPARK-33482][SPARK-34756][SQL] Fix FileScan equality check
c1f7aa2 is described below
commit c1f7aa286a64f650f1dc9fc85bde33b683f9dd2e
Author: Peter Toth <[email protected]>
AuthorDate: Tue Mar 23 17:01:16 2021 +0800
[SPARK-33482][SPARK-34756][SQL] Fix FileScan equality check
### What changes were proposed in this pull request?
This bug was introduced by SPARK-30428 in Apache Spark 3.0.0.
This PR fixes `FileScan.equals()`.
### Why are the changes needed?
- Without this fix `FileScan.equals` doesn't take `fileIndex` and `readSchema` into account (see the sketch below).
- Partition filters and data filters added to `FileScan` (in #27112 and #27157) caused the canonicalized forms of some `BatchScanExec` nodes to no longer match, which prevents some reuse opportunities.
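A minimal standalone sketch of the Scala pitfall behind the broken check, using hypothetical tuple arguments instead of the real `FileScan` fields: a Scala block evaluates to its last expression, so the old `equals` computed and discarded the `fileIndex`/`readSchema` comparison because its line did not end with `&&`.

```scala
// Hypothetical stand-in for the broken FileScan.equals.
def buggyEquals(a: (Int, Int), b: (Int, Int)): Boolean = {
  a._1 == b._1 // no trailing && -> this comparison is a discarded value
  a._2 == b._2 // only this comparison is actually returned
}

// With the trailing && both comparisons contribute to the result.
def fixedEquals(a: (Int, Int), b: (Int, Int)): Boolean = {
  a._1 == b._1 &&
    a._2 == b._2
}

assert(buggyEquals((1, 42), (2, 42))) // true although the first fields differ
assert(!fixedEquals((1, 42), (2, 42)))
```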
### Does this PR introduce _any_ user-facing change?
Yes. Before this fix, incorrect reuse of `FileScan`, and therefore of `BatchScanExec`, could have happened, causing correctness issues.
### How was this patch tested?
Added new UTs.
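Beyond the missing `&&`, the fix compares the partition and data filters through `QueryPlan.normalizeExpressions` (see the `FileScan.scala` hunk below), because the same column can carry a different `exprId` each time a plan is built. A minimal standalone sketch of why raw `ExpressionSet` comparison isn't enough (an illustration against catalyst internals, not one of the added UTs):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpressionSet}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.types.IntegerType

// Two references to the same column that differ only in exprId, as
// happens when the same scan is planned twice.
val a1 = AttributeReference("data", IntegerType)()
val a2 = AttributeReference("data", IntegerType)()

// Raw comparison fails: canonicalization does not erase exprIds.
assert(ExpressionSet(Seq(a1 > 0)) != ExpressionSet(Seq(a2 > 0)))

// Normalizing each filter against its own output rewrites exprIds by
// ordinal, after which the two filters compare equal.
val n1 = QueryPlan.normalizeExpressions(a1 > 0, Seq(a1))
val n2 = QueryPlan.normalizeExpressions(a2 > 0, Seq(a2))
assert(n1 == n2)
```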
Closes #31848 from peter-toth/SPARK-34756-fix-filescan-equality-check.
Authored-by: Peter Toth <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 93a5d34f84c362110ef7d8853e59ce597faddad9)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/sql/avro/AvroScanSuite.scala | 30 ++
.../sql/execution/datasources/v2/FileScan.scala | 22 +-
.../scala/org/apache/spark/sql/FileScanSuite.scala | 374 +++++++++++++++++++++
.../scala/org/apache/spark/sql/SQLQuerySuite.scala | 24 ++
4 files changed, 446 insertions(+), 4 deletions(-)
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala
new file mode 100644
index 0000000..98a7190
--- /dev/null
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.spark.sql.FileScanSuiteBase
+import org.apache.spark.sql.v2.avro.AvroScan
+
+class AvroScanSuite extends FileScanSuiteBase {
+ val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
+ ("AvroScan",
+ (s, fi, ds, rds, rps, f, o, pf, df) => AvroScan(s, fi, ds, rds, rps, o, f, pf, df),
+ Seq.empty))
+
+ run(scanBuilders)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 363dd15..ac63725 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -24,8 +24,9 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.IO_WARNING_LARGEFILETHRESHOLD
import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet}
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ExpressionSet}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics}
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources._
@@ -84,11 +85,24 @@ trait FileScan extends Scan
  protected def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
+  private lazy val (normalizedPartitionFilters, normalizedDataFilters) = {
+    val output = readSchema().toAttributes
+    val partitionFilterAttributes = AttributeSet(partitionFilters).map(a => a.name -> a).toMap
+    val dataFiltersAttributes = AttributeSet(dataFilters).map(a => a.name -> a).toMap
+    val normalizedPartitionFilters = ExpressionSet(partitionFilters.map(
+      QueryPlan.normalizeExpressions(_,
+        output.map(a => partitionFilterAttributes.getOrElse(a.name, a)))))
+    val normalizedDataFilters = ExpressionSet(dataFilters.map(
+      QueryPlan.normalizeExpressions(_,
+        output.map(a => dataFiltersAttributes.getOrElse(a.name, a)))))
+    (normalizedPartitionFilters, normalizedDataFilters)
+  }
+
  override def equals(obj: Any): Boolean = obj match {
    case f: FileScan =>
-      fileIndex == f.fileIndex && readSchema == f.readSchema
-      ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters) &&
-      ExpressionSet(dataFilters) == ExpressionSet(f.dataFilters)
+      fileIndex == f.fileIndex && readSchema == f.readSchema &&
+        normalizedPartitionFilters == f.normalizedPartitionFilters &&
+        normalizedDataFilters == f.normalizedDataFilters
    case _ => false
  }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
new file mode 100644
index 0000000..4e7fe84
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.mutable
+
+import com.google.common.collect.ImmutableMap
+import org.apache.hadoop.fs.{FileStatus, Path}
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, IsNull, LessThan}
+import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitionSpec}
+import org.apache.spark.sql.execution.datasources.v2.FileScan
+import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
+import org.apache.spark.sql.execution.datasources.v2.json.JsonScan
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.execution.datasources.v2.text.TextScan
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+trait FileScanSuiteBase extends SharedSparkSession {
+ private def newPartitioningAwareFileIndex() = {
+ new PartitioningAwareFileIndex(spark, Map.empty, None) {
+ override def partitionSpec(): PartitionSpec = {
+ PartitionSpec.emptySpec
+ }
+
+ override protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] = {
+ mutable.LinkedHashMap.empty
+ }
+
+ override protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = {
+ Map.empty
+ }
+
+ override def rootPaths: Seq[Path] = {
+ Seq.empty
+ }
+
+ override def refresh(): Unit = {}
+ }
+ }
+
+ type ScanBuilder = (
+ SparkSession,
+ PartitioningAwareFileIndex,
+ StructType,
+ StructType,
+ StructType,
+ Array[Filter],
+ CaseInsensitiveStringMap,
+ Seq[Expression],
+ Seq[Expression]) => FileScan
+
+ def run(scanBuilders: Seq[(String, ScanBuilder, Seq[String])]): Unit = {
+ val dataSchema = StructType.fromDDL("data INT, partition INT, other INT")
+ val dataSchemaNotEqual = StructType.fromDDL("data INT, partition INT, other INT, new INT")
+ val readDataSchema = StructType.fromDDL("data INT")
+ val readDataSchemaNotEqual = StructType.fromDDL("data INT, other INT")
+ val readPartitionSchema = StructType.fromDDL("partition INT")
+ val readPartitionSchemaNotEqual = StructType.fromDDL("partition INT, other INT")
+ val pushedFilters =
+ Array[Filter](sources.And(sources.IsNull("data"), sources.LessThan("data", 0)))
+ val pushedFiltersNotEqual =
+ Array[Filter](sources.And(sources.IsNull("data"), sources.LessThan("data", 1)))
+ val optionsMap = ImmutableMap.of("key", "value")
+ val options = new CaseInsensitiveStringMap(ImmutableMap.copyOf(optionsMap))
+ val optionsNotEqual =
+ new CaseInsensitiveStringMap(ImmutableMap.copyOf(ImmutableMap.of("key2", "value2")))
+ val partitionFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0)))
+ val partitionFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1)))
+ val dataFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0)))
+ val dataFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1)))
+
+ scanBuilders.foreach { case (name, scanBuilder, exclusions) =>
+ test(s"SPARK-33482: Test $name equals") {
+ val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+ val scan = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ val scanEquals = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema.copy(),
+ readDataSchema.copy(),
+ readPartitionSchema.copy(),
+ pushedFilters.clone(),
+ new CaseInsensitiveStringMap(ImmutableMap.copyOf(optionsMap)),
+ Seq(partitionFilters: _*),
+ Seq(dataFilters: _*))
+
+ assert(scan === scanEquals)
+ }
+
+ test(s"SPARK-33482: Test $name fileIndex not equals") {
+ val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+ val scan = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ val partitioningAwareFileIndexNotEqual = newPartitioningAwareFileIndex()
+
+ val scanNotEqual = scanBuilder(
+ spark,
+ partitioningAwareFileIndexNotEqual,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ assert(scan !== scanNotEqual)
+ }
+
+ if (!exclusions.contains("dataSchema")) {
+ test(s"SPARK-33482: Test $name dataSchema not equals") {
+ val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+ val scan = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ val scanNotEqual = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchemaNotEqual,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ assert(scan !== scanNotEqual)
+ }
+ }
+
+ test(s"SPARK-33482: Test $name readDataSchema not equals") {
+ val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+ val scan = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ val scanNotEqual = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchemaNotEqual,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ assert(scan !== scanNotEqual)
+ }
+
+ test(s"SPARK-33482: Test $name readPartitionSchema not equals") {
+ val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+ val scan = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ val scanNotEqual = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchemaNotEqual,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ assert(scan !== scanNotEqual)
+ }
+
+ if (!exclusions.contains("pushedFilters")) {
+ test(s"SPARK-33482: Test $name pushedFilters not equals") {
+ val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+ val scan = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ val scanNotEqual = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFiltersNotEqual,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ assert(scan !== scanNotEqual)
+ }
+ }
+
+ test(s"SPARK-33482: Test $name options not equals") {
+ val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+ val scan = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ val scanNotEqual = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ optionsNotEqual,
+ partitionFilters,
+ dataFilters)
+
+ assert(scan !== scanNotEqual)
+ }
+
+ test(s"SPARK-33482: Test $name partitionFilters not equals") {
+ val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+ val scan = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ val scanNotEqual = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFiltersNotEqual,
+ dataFilters)
+ assert(scan !== scanNotEqual)
+ }
+
+ test(s"SPARK-33482: Test $name dataFilters not equals") {
+ val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+ val scan = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFilters)
+
+ val scanNotEqual = scanBuilder(
+ spark,
+ partitioningAwareFileIndex,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ pushedFilters,
+ options,
+ partitionFilters,
+ dataFiltersNotEqual)
+ assert(scan !== scanNotEqual)
+ }
+ }
+ }
+}
+
+class FileScanSuite extends FileScanSuiteBase {
+ val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
+ ("ParquetScan",
+ (s, fi, ds, rds, rps, f, o, pf, df) =>
+ ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df),
+ Seq.empty),
+ ("OrcScan",
+ (s, fi, ds, rds, rps, f, o, pf, df) =>
+ OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, df),
+ Seq.empty),
+ ("CSVScan",
+ (s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f, pf, df),
+ Seq.empty),
+ ("JsonScan",
+ (s, fi, ds, rds, rps, f, o, pf, df) => JsonScan(s, fi, ds, rds, rps, o, f, pf, df),
+ Seq.empty),
+ ("TextScan",
+ (s, fi, _, rds, rps, _, o, pf, df) => TextScan(s, fi, rds, rps, o, pf, df),
+ Seq("dataSchema", "pushedFilters")))
+
+ run(scanBuilders)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c29eac2..aa673dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupporte
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -3945,6 +3946,29 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
}
+
+ test("SPARK-33482: Fix FileScan canonicalization") {
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+ withTempPath { path =>
+ spark.range(5).toDF().write.mode("overwrite").parquet(path.toString)
+ withTempView("t") {
+ spark.read.parquet(path.toString).createOrReplaceTempView("t")
+ val df = sql(
+ """
+ |SELECT *
+ |FROM t AS t1
+ |JOIN t AS t2 ON t2.id = t1.id
+ |JOIN t AS t3 ON t3.id = t2.id
+ |""".stripMargin)
+ df.collect()
+ val reusedExchanges = collect(df.queryExecution.executedPlan) {
+ case r: ReusedExchangeExec => r
+ }
+ assert(reusedExchanges.size == 1)
+ }
+ }
+ }
+ }
}
case class Foo(bar: Option[String])