This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 860849d  [SPARK-31365][SQL] Enable nested predicate pushdown per data sources
860849d is described below

commit 860849d73c994a6f8970dab631e097cb234c7abe
Author:     Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Wed May 6 04:50:06 2020 +0000

    [SPARK-31365][SQL] Enable nested predicate pushdown per data sources

    ### What changes were proposed in this pull request?

    This patch replaces `NESTED_PREDICATE_PUSHDOWN_ENABLED` with
    `NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST`, which configures which v1 data sources
    have nested predicate pushdown enabled.

    ### Why are the changes needed?

    We added a nested predicate pushdown feature that is configured by
    `NESTED_PREDICATE_PUSHDOWN_ENABLED`. However, that config is all-or-nothing and applies
    to all data sources. In order not to introduce an API-breaking change after enabling
    nested predicate pushdown, we'd like to enable nested predicate pushdown per data source.
    Please also refer to the comments at
    https://github.com/apache/spark/pull/27728#discussion_r410829720.

    ### Does this PR introduce any user-facing change?

    No.

    ### How was this patch tested?

    Added/modified unit tests.

    Closes #28366 from viirya/SPARK-31365.

    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 4952f1a03cc48d9f1c3d2539ffa19bf051e398bf)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  21 ++--
 .../spark/sql/execution/DataSourceScanExec.scala   |   5 +-
 .../execution/datasources/DataSourceStrategy.scala |  94 ++++++++++++-------
 .../execution/datasources/DataSourceUtils.scala    |  16 ++++
 .../execution/datasources/FileSourceStrategy.scala |   7 +-
 .../datasources/v2/DataSourceV2Strategy.scala      |   7 +-
 .../execution/datasources/v2/PushDownUtils.scala   |   3 +-
 .../datasources/v2/V2ScanRelationPushDown.scala    |   2 +-
 .../ParquetNestedPredicatePushDownBenchmark.scala  |   8 +-
 .../datasources/DataSourceStrategySuite.scala      |  20 +++-
 .../datasources/parquet/ParquetFilterSuite.scala   | 103 +++++++++++++--------
 11 files changed, 186 insertions(+), 100 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 72946a9..51404a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2063,16 +2063,17 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)

-  val NESTED_PREDICATE_PUSHDOWN_ENABLED =
-    buildConf("spark.sql.optimizer.nestedPredicatePushdown.enabled")
-      .internal()
-      .doc("When true, Spark tries to push down predicates for nested columns and or names " +
-        "containing `dots` to data sources. Currently, Parquet implements both optimizations " +
-        "while ORC only supports predicates for names containing `dots`. The other data sources" +
-        "don't support this feature yet.")
+  val NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST =
+    buildConf("spark.sql.optimizer.nestedPredicatePushdown.supportedFileSources")
+      .internal()
+      .doc("A comma-separated list of data source short names or fully qualified data source " +
+        "implementation class names for which Spark tries to push down predicates for nested " +
+        "columns and/or names containing `dots` to data sources. Currently, Parquet implements " +
+        "both optimizations while ORC only supports predicates for names containing `dots`. The " +
+        "other data sources don't support this feature yet. So the default value is 'parquet,orc'.")
       .version("3.0.0")
-      .booleanConf
-      .createWithDefault(true)
+      .stringConf
+      .createWithDefault("parquet,orc")

   val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED =
     buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled")
@@ -3077,8 +3078,6 @@ class SQLConf extends Serializable with Logging {

   def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED)

-  def nestedPredicatePushdownEnabled: Boolean = getConf(NESTED_PREDICATE_PUSHDOWN_ENABLED)
-
   def serializerNestedSchemaPruningEnabled: Boolean =
     getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 0be76ad..90a3f97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -326,7 +326,10 @@ case class FileSourceScanExec(
   }

   @transient
-  private lazy val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter)
+  private lazy val pushedDownFilters = {
+    val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
+    dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
+  }

   override lazy val metadata: Map[String, String] = {
     def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]")
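Before the remaining file diffs, a minimal usage sketch of the new knob (not part of the patch; the SparkSession `spark`, the input path, and the nested column are illustrative assumptions):

```scala
// Nested predicate pushdown is now controlled per v1 file source through a list
// of source names rather than a single boolean. "parquet,orc" is the default.
spark.conf.set(
  "spark.sql.optimizer.nestedPredicatePushdown.supportedFileSources", "parquet,orc")

// A filter on a nested field may now be pushed down to the Parquet reader.
val df = spark.read.parquet("/tmp/people")        // hypothetical input path
df.where("person.name = 'Alice'").explain(true)   // inspect PushedFilters in the plan

// Passing an empty list disables the optimization for every v1 file source.
spark.conf.set("spark.sql.optimizer.nestedPredicatePushdown.supportedFileSources", "")
```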
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a58038d..23454d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -448,60 +448,62 @@ object DataSourceStrategy {
     }
   }

-  private def translateLeafNodeFilter(predicate: Expression): Option[Filter] = predicate match {
-    case expressions.EqualTo(PushableColumn(name), Literal(v, t)) =>
+  private def translateLeafNodeFilter(
+      predicate: Expression,
+      pushableColumn: PushableColumnBase): Option[Filter] = predicate match {
+    case expressions.EqualTo(pushableColumn(name), Literal(v, t)) =>
       Some(sources.EqualTo(name, convertToScala(v, t)))
-    case expressions.EqualTo(Literal(v, t), PushableColumn(name)) =>
+    case expressions.EqualTo(Literal(v, t), pushableColumn(name)) =>
       Some(sources.EqualTo(name, convertToScala(v, t)))
-    case expressions.EqualNullSafe(PushableColumn(name), Literal(v, t)) =>
+    case expressions.EqualNullSafe(pushableColumn(name), Literal(v, t)) =>
       Some(sources.EqualNullSafe(name, convertToScala(v, t)))
-    case expressions.EqualNullSafe(Literal(v, t), PushableColumn(name)) =>
+    case expressions.EqualNullSafe(Literal(v, t), pushableColumn(name)) =>
       Some(sources.EqualNullSafe(name, convertToScala(v, t)))
-    case expressions.GreaterThan(PushableColumn(name), Literal(v, t)) =>
+    case expressions.GreaterThan(pushableColumn(name), Literal(v, t)) =>
       Some(sources.GreaterThan(name, convertToScala(v, t)))
-    case expressions.GreaterThan(Literal(v, t), PushableColumn(name)) =>
+    case expressions.GreaterThan(Literal(v, t), pushableColumn(name)) =>
       Some(sources.LessThan(name, convertToScala(v, t)))
-    case expressions.LessThan(PushableColumn(name), Literal(v, t)) =>
+    case expressions.LessThan(pushableColumn(name), Literal(v, t)) =>
       Some(sources.LessThan(name, convertToScala(v, t)))
-    case expressions.LessThan(Literal(v, t), PushableColumn(name)) =>
+    case expressions.LessThan(Literal(v, t), pushableColumn(name)) =>
       Some(sources.GreaterThan(name, convertToScala(v, t)))
-    case expressions.GreaterThanOrEqual(PushableColumn(name), Literal(v, t)) =>
+    case expressions.GreaterThanOrEqual(pushableColumn(name), Literal(v, t)) =>
       Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
-    case expressions.GreaterThanOrEqual(Literal(v, t), PushableColumn(name)) =>
+    case expressions.GreaterThanOrEqual(Literal(v, t), pushableColumn(name)) =>
       Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
-    case expressions.LessThanOrEqual(PushableColumn(name), Literal(v, t)) =>
+    case expressions.LessThanOrEqual(pushableColumn(name), Literal(v, t)) =>
       Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
-    case expressions.LessThanOrEqual(Literal(v, t), PushableColumn(name)) =>
+    case expressions.LessThanOrEqual(Literal(v, t), pushableColumn(name)) =>
       Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))

-    case expressions.InSet(e @ PushableColumn(name), set) =>
+    case expressions.InSet(e @ pushableColumn(name), set) =>
       val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
       Some(sources.In(name, set.toArray.map(toScala)))

     // Because we only convert In to InSet in Optimizer when there are more than certain
     // items. So it is possible we still get an In expression here that needs to be pushed
     // down.
-    case expressions.In(e @ PushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) =>
+    case expressions.In(e @ pushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) =>
       val hSet = list.map(_.eval(EmptyRow))
       val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
       Some(sources.In(name, hSet.toArray.map(toScala)))

-    case expressions.IsNull(PushableColumn(name)) =>
+    case expressions.IsNull(pushableColumn(name)) =>
       Some(sources.IsNull(name))
-    case expressions.IsNotNull(PushableColumn(name)) =>
+    case expressions.IsNotNull(pushableColumn(name)) =>
       Some(sources.IsNotNull(name))
-    case expressions.StartsWith(PushableColumn(name), Literal(v: UTF8String, StringType)) =>
+    case expressions.StartsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) =>
       Some(sources.StringStartsWith(name, v.toString))
-    case expressions.EndsWith(PushableColumn(name), Literal(v: UTF8String, StringType)) =>
+    case expressions.EndsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) =>
       Some(sources.StringEndsWith(name, v.toString))
-    case expressions.Contains(PushableColumn(name), Literal(v: UTF8String, StringType)) =>
+    case expressions.Contains(pushableColumn(name), Literal(v: UTF8String, StringType)) =>
       Some(sources.StringContains(name, v.toString))

     case expressions.Literal(true, BooleanType) =>
@@ -518,8 +520,9 @@ object DataSourceStrategy {
    *
    * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`.
    */
-  protected[sql] def translateFilter(predicate: Expression): Option[Filter] = {
-    translateFilterWithMapping(predicate, None)
+  protected[sql] def translateFilter(
+      predicate: Expression, supportNestedPredicatePushdown: Boolean): Option[Filter] = {
+    translateFilterWithMapping(predicate, None, supportNestedPredicatePushdown)
   }

   /**
@@ -529,11 +532,13 @@ object DataSourceStrategy {
    * @param translatedFilterToExpr An optional map from leaf node filter expressions to its
    *                               translated [[Filter]]. The map is used for rebuilding
    *                               [[Expression]] from [[Filter]].
+   * @param nestedPredicatePushdownEnabled Whether nested predicate pushdown is enabled.
    * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`.
    */
   protected[sql] def translateFilterWithMapping(
       predicate: Expression,
-      translatedFilterToExpr: Option[mutable.HashMap[sources.Filter, Expression]])
+      translatedFilterToExpr: Option[mutable.HashMap[sources.Filter, Expression]],
+      nestedPredicatePushdownEnabled: Boolean)
     : Option[Filter] = {
     predicate match {
       case expressions.And(left, right) =>
@@ -547,21 +552,26 @@ object DataSourceStrategy {
         // Pushing one leg of AND down is only safe to do at the top level.
         // You can see ParquetFilters' createFilter for more details.
         for {
-          leftFilter <- translateFilterWithMapping(left, translatedFilterToExpr)
-          rightFilter <- translateFilterWithMapping(right, translatedFilterToExpr)
+          leftFilter <- translateFilterWithMapping(
+            left, translatedFilterToExpr, nestedPredicatePushdownEnabled)
+          rightFilter <- translateFilterWithMapping(
+            right, translatedFilterToExpr, nestedPredicatePushdownEnabled)
         } yield sources.And(leftFilter, rightFilter)

       case expressions.Or(left, right) =>
         for {
-          leftFilter <- translateFilterWithMapping(left, translatedFilterToExpr)
-          rightFilter <- translateFilterWithMapping(right, translatedFilterToExpr)
+          leftFilter <- translateFilterWithMapping(
+            left, translatedFilterToExpr, nestedPredicatePushdownEnabled)
+          rightFilter <- translateFilterWithMapping(
+            right, translatedFilterToExpr, nestedPredicatePushdownEnabled)
         } yield sources.Or(leftFilter, rightFilter)

       case expressions.Not(child) =>
-        translateFilterWithMapping(child, translatedFilterToExpr).map(sources.Not)
+        translateFilterWithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled)
+          .map(sources.Not)

       case other =>
-        val filter = translateLeafNodeFilter(other)
+        val filter = translateLeafNodeFilter(other, PushableColumn(nestedPredicatePushdownEnabled))
         if (filter.isDefined && translatedFilterToExpr.isDefined) {
           translatedFilterToExpr.get(filter.get) = predicate
         }
@@ -608,8 +618,9 @@ object DataSourceStrategy {

     // A map from original Catalyst expressions to corresponding translated data source filters.
     // If a predicate is not in this map, it means it cannot be pushed down.
+    val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
     val translatedMap: Map[Expression, Filter] = predicates.flatMap { p =>
-      translateFilter(p).map(f => p -> f)
+      translateFilter(p, supportNestedPredicatePushdown).map(f => p -> f)
     }.toMap

     val pushedFilters: Seq[Filter] = translatedMap.values.toSeq
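A rough sketch of what the extra parameter changes for the internal `translateFilter` API in the hunk above. Because the method is `protected[sql]`, this would only compile inside the `org.apache.spark.sql` package (as in `DataSourceStrategySuite`); the column names are made up:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources

// A predicate on a plain top-level column translates regardless of the flag.
val simple = 'age.int === 30
assert(DataSourceStrategy.translateFilter(simple, supportNestedPredicatePushdown = true)
  .contains(sources.EqualTo("age", 30)))

// A column whose name contains a dot is only treated as pushable when the flag is true;
// with the flag off the translation is expected to yield None.
val dotted = Symbol("a.b").int === 1
assert(DataSourceStrategy.translateFilter(dotted, supportNestedPredicatePushdown = false).isEmpty)
```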
@@ -650,9 +661,10 @@

 /**
  * Find the column name of an expression that can be pushed down.
  */
-object PushableColumn {
+abstract class PushableColumnBase {
+  val nestedPredicatePushdownEnabled: Boolean
+
   def unapply(e: Expression): Option[String] = {
-    val nestedPredicatePushdownEnabled = SQLConf.get.nestedPredicatePushdownEnabled
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
     def helper(e: Expression): Option[Seq[String]] = e match {
       case a: Attribute =>
@@ -668,3 +680,21 @@ object PushableColumn {
     helper(e).map(_.quoted)
   }
 }
+
+object PushableColumn {
+  def apply(nestedPredicatePushdownEnabled: Boolean): PushableColumnBase = {
+    if (nestedPredicatePushdownEnabled) {
+      PushableColumnAndNestedColumn
+    } else {
+      PushableColumnWithoutNestedColumn
+    }
+  }
+}
+
+object PushableColumnAndNestedColumn extends PushableColumnBase {
+  override val nestedPredicatePushdownEnabled = true
+}
+
+object PushableColumnWithoutNestedColumn extends PushableColumnBase {
+  override val nestedPredicatePushdownEnabled = false
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index b19de6d..45a9b1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.execution.datasources

+import java.util.Locale
+
 import org.apache.hadoop.fs.Path
 import org.json4s.NoTypeHints
 import org.json4s.jackson.Serialization
@@ -24,6 +26,7 @@ import org.json4s.jackson.Serialization
 import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_VERSION_METADATA_KEY}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils

@@ -68,6 +71,19 @@ object DataSourceUtils {
   private[sql] def isDataFile(fileName: String) =
     !(fileName.startsWith("_") || fileName.startsWith("."))

+  /**
+   * Returns if the given relation's V1 datasource provider supports nested predicate pushdown.
+   */
+  private[sql] def supportNestedPredicatePushdown(relation: BaseRelation): Boolean =
+    relation match {
+      case hs: HadoopFsRelation =>
+        val supportedDatasources =
+          Utils.stringToSeq(SQLConf.get.getConf(SQLConf.NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST)
+            .toLowerCase(Locale.ROOT))
+        supportedDatasources.contains(hs.toString)
+      case _ => false
+    }
+
   def needRebaseDateTime(lookupFileMeta: String => String): Option[Boolean] = {
     if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") {
       return Some(false)
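To make the new extractor hierarchy concrete, a small sketch of how the two `PushableColumn` variants are expected to behave (illustrative names; assumes the catalyst DSL imports available to Spark's internal test code):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.datasources.{
  PushableColumn, PushableColumnAndNestedColumn, PushableColumnWithoutNestedColumn}

// The factory simply picks one of the two singletons.
assert(PushableColumn(nestedPredicatePushdownEnabled = true) == PushableColumnAndNestedColumn)
assert(PushableColumn(nestedPredicatePushdownEnabled = false) == PushableColumnWithoutNestedColumn)

// A plain column name is pushable either way...
assert(PushableColumnAndNestedColumn.unapply('price.int) == Some("price"))
assert(PushableColumnWithoutNestedColumn.unapply('price.int) == Some("price"))

// ...but a name containing a dot is only accepted by the nested-aware variant,
// which should return it quoted (e.g. Some("`col.with.dot`")).
assert(PushableColumnWithoutNestedColumn.unapply(Symbol("col.with.dot").int).isEmpty)
```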
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index f454951..477937d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -178,8 +178,11 @@ object FileSourceStrategy extends Strategy with Logging {
       // Partition keys are not available in the statistics of the files.
       val dataFilters =
         normalizedFiltersWithoutSubqueries.filter(_.references.intersect(partitionSet).isEmpty)
-      logInfo(s"Pushed Filters: " +
-        s"${dataFilters.flatMap(DataSourceStrategy.translateFilter).mkString(",")}")
+      val supportNestedPredicatePushdown =
+        DataSourceUtils.supportNestedPredicatePushdown(fsRelation)
+      val pushedFilters = dataFilters
+        .flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
+      logInfo(s"Pushed Filters: ${pushedFilters.mkString(",")}")

       // Predicates with both partition keys and attributes need to be evaluated after the scan.
       val afterScanFilters = filterSet -- partitionKeyFilters.filter(_.references.nonEmpty)
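The eligibility check added in `DataSourceUtils.supportNestedPredicatePushdown` boils down to a membership test on the configured list. A simplified, self-contained model of that test (not the actual Spark code):

```scala
import java.util.Locale

// Simplified model: a v1 file source is eligible when its short name (or fully
// qualified implementation class name) appears in the comma-separated config value.
def supportsNestedPushdown(sourceName: String, confValue: String): Boolean =
  confValue.toLowerCase(Locale.ROOT)
    .split(",").map(_.trim).filter(_.nonEmpty)
    .contains(sourceName.toLowerCase(Locale.ROOT))

assert(supportsNestedPushdown("parquet", "parquet,orc"))   // default: enabled
assert(!supportsNestedPushdown("csv", "parquet,orc"))      // other sources: disabled
assert(!supportsNestedPushdown("parquet", ""))             // empty list turns it off
```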
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 8f4e2d2..cca80c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -180,8 +180,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) =>
       // fail if any filter cannot be converted. correctness depends on removing all matching data.
       val filters = splitConjunctivePredicates(deleteExpr).map {
-        filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse(
-          throw new AnalysisException(s"Cannot translate expression to source filter: $filter"))
+        filter => DataSourceStrategy.translateFilter(deleteExpr,
+          supportNestedPredicatePushdown = true).getOrElse(
+          throw new AnalysisException(s"Cannot translate expression to source filter: $filter"))
       }.toArray
       r.table.asWritable match {
         case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
@@ -205,7 +206,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       // correctness depends on removing all matching data.
       val filters = DataSourceStrategy.normalizeExprs(condition.toSeq, output)
         .flatMap(splitConjunctivePredicates(_).map {
-          f => DataSourceStrategy.translateFilter(f).getOrElse(
+          f => DataSourceStrategy.translateFilter(f, true).getOrElse(
             throw new AnalysisException(s"Exec update failed:" +
               s" cannot translate expression to source filter: $f"))
         }).toArray
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 33338b0..1a6f03f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -48,7 +48,8 @@ object PushDownUtils extends PredicateHelper {

     for (filterExpr <- filters) {
       val translated =
-        DataSourceStrategy.translateFilterWithMapping(filterExpr, Some(translatedFilterToExpr))
+        DataSourceStrategy.translateFilterWithMapping(filterExpr, Some(translatedFilterToExpr),
+          nestedPredicatePushdownEnabled = true)
       if (translated.isEmpty) {
         untranslatableExprs += filterExpr
       } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 59089fa..b168e84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -59,7 +59,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {

     val wrappedScan = scan match {
       case v1: V1Scan =>
-        val translated = filters.flatMap(DataSourceStrategy.translateFilter)
+        val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true))
         V1ScanWrapper(v1, translated, pushedFilters)
       case _ => scan
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetNestedPredicatePushDownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetNestedPredicatePushDownBenchmark.scala
index 11bc91a..d2bd962 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetNestedPredicatePushDownBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetNestedPredicatePushDownBenchmark.scala
@@ -48,12 +48,12 @@ object ParquetNestedPredicatePushDownBenchmark extends SqlBasedBenchmark {
   private def addCase(
       benchmark: Benchmark,
       inputPath: String,
-      enableNestedPD: Boolean,
+      enableNestedPD: String,
       name: String,
       withFilter: DataFrame => DataFrame): Unit = {
     val loadDF = spark.read.parquet(inputPath)
     benchmark.addCase(name) { _ =>
-      withSQLConf((SQLConf.NESTED_PREDICATE_PUSHDOWN_ENABLED.key, enableNestedPD.toString)) {
+      withSQLConf((SQLConf.NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST.key, enableNestedPD)) {
         withFilter(loadDF).noop()
       }
     }
@@ -67,13 +67,13 @@ object ParquetNestedPredicatePushDownBenchmark extends SqlBasedBenchmark {
     addCase(
       benchmark,
       outputPath,
-      enableNestedPD = false,
+      enableNestedPD = "",
       "Without nested predicate Pushdown",
       withFilter)
     addCase(
       benchmark,
       outputPath,
-      enableNestedPD = true,
+      enableNestedPD = "parquet",
       "With nested predicate Pushdown",
       withFilter)
     benchmark.run()
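For tests and benchmarks, toggling the feature now means switching the source list instead of a boolean, along the lines of the sketch below (`withSQLConf` is provided by Spark's test helpers such as `SQLHelper`/`SqlBasedBenchmark`, so this assumes it runs inside such a class):

```scala
import org.apache.spark.sql.internal.SQLConf

// Disabled for every v1 file source within the block.
withSQLConf(SQLConf.NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST.key -> "") {
  // run the query or benchmark case without nested predicate pushdown
}

// Enabled only for Parquet within the block.
withSQLConf(SQLConf.NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST.key -> "parquet") {
  // run the same workload with nested predicate pushdown
}
```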
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
index a775a97..b94918e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
@@ -289,14 +289,26 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession {
   test("SPARK-31027 test `PushableColumn.unapply` that finds the column name of " +
     "an expression that can be pushed down") {
     attrInts.foreach { case (attrInt, colName) =>
-      assert(PushableColumn.unapply(attrInt) === Some(colName))
+      assert(PushableColumnAndNestedColumn.unapply(attrInt) === Some(colName))
+
+      if (colName.contains(".")) {
+        assert(PushableColumnWithoutNestedColumn.unapply(attrInt) === None)
+      } else {
+        assert(PushableColumnWithoutNestedColumn.unapply(attrInt) === Some(colName))
+      }
     }
     attrStrs.foreach { case (attrStr, colName) =>
-      assert(PushableColumn.unapply(attrStr) === Some(colName))
+      assert(PushableColumnAndNestedColumn.unapply(attrStr) === Some(colName))
+
+      if (colName.contains(".")) {
+        assert(PushableColumnWithoutNestedColumn.unapply(attrStr) === None)
+      } else {
+        assert(PushableColumnWithoutNestedColumn.unapply(attrStr) === Some(colName))
+      }
     }

     // `Abs(col)` can not be pushed down, so it returns `None`
-    assert(PushableColumn.unapply(Abs('col.int)) === None)
+    assert(PushableColumnAndNestedColumn.unapply(Abs('col.int)) === None)
   }

   /**
@@ -305,7 +317,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession {
    */
   def testTranslateFilter(catalystFilter: Expression, result: Option[sources.Filter]): Unit = {
     assertResult(result) {
-      DataSourceStrategy.translateFilter(catalystFilter)
+      DataSourceStrategy.translateFilter(catalystFilter, true)
    }
  }
 }
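The ParquetFilterSuite changes below rely on the string form that pushed filters use for column names. As a hedged illustration of that convention (inferred from `parseColumnPath` and the quoting in `PushableColumnBase`, not stated verbatim in the patch):

```scala
import org.apache.spark.sql.sources

// A nested field is referenced as a dot-separated path of parts...
val onNestedField: sources.Filter = sources.EqualTo("person.name", "Alice")
// ...while a single top-level column that literally contains a dot is kept as one
// back-quoted part, so the two cases remain distinguishable when parsed back.
val onDottedColumn: sources.Filter = sources.EqualTo("`person.name`", "Bob")
```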
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 20bfb32..5cf2129 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -33,7 +33,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.InferFiltersFromConstraints
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation, PushableColumnAndNestedColumn}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.functions._
@@ -1588,47 +1589,67 @@ class ParquetV1FilterSuite extends ParquetFilterSuite {
       expected: Seq[Row]): Unit = {
     val output = predicate.collect { case a: Attribute => a }.distinct

-    withSQLConf(
-      SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
-      SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true",
-      SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true",
-      SQLConf.PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED.key -> "true",
-      SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true",
-      // Disable adding filters from constraints because it adds, for instance,
-      // is-not-null to pushed filters, which makes it hard to test if the pushed
-      // filter is expected or not (this had to be fixed with SPARK-13495).
-      SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> InferFiltersFromConstraints.ruleName,
-      SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
-      val query = df
-        .select(output.map(e => Column(e)): _*)
-        .where(Column(predicate))
+    Seq(("parquet", true), ("", false)).map { case (pushdownDsList, nestedPredicatePushdown) =>
+      withSQLConf(
+        SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
+        SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true",
+        SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true",
+        SQLConf.PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED.key -> "true",
+        SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true",
+        // Disable adding filters from constraints because it adds, for instance,
+        // is-not-null to pushed filters, which makes it hard to test if the pushed
+        // filter is expected or not (this had to be fixed with SPARK-13495).
+        SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> InferFiltersFromConstraints.ruleName,
+        SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false",
+        SQLConf.NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST.key -> pushdownDsList) {
+        val query = df
+          .select(output.map(e => Column(e)): _*)
+          .where(Column(predicate))
+
+        val nestedOrAttributes = predicate.collectFirst {
+          case g: GetStructField => g
+          case a: Attribute => a
+        }
+        assert(nestedOrAttributes.isDefined, "No GetStructField nor Attribute is detected.")
+
+        val parsed = parseColumnPath(
+          PushableColumnAndNestedColumn.unapply(nestedOrAttributes.get).get)
+
+        val containsNestedColumnOrDot = parsed.length > 1 || parsed(0).contains(".")
+
+        var maybeRelation: Option[HadoopFsRelation] = None
+        val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
+          case PhysicalOperation(_, filters,
+              LogicalRelation(relation: HadoopFsRelation, _, _, _)) =>
+            maybeRelation = Some(relation)
+            filters
+        }.flatten.reduceLeftOption(_ && _)
+        assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query")
+
+        val (_, selectedFilters, _) =
+          DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq)
+        // If predicates contains nested column or dot, we push down the predicates only if
+        // "parquet" is in `NESTED_PREDICATE_PUSHDOWN_V1_SOURCE_LIST`.
+        if (nestedPredicatePushdown || !containsNestedColumnOrDot) {
+          assert(selectedFilters.nonEmpty, "No filter is pushed down")
+          val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
+          val parquetFilters = createParquetFilters(schema)
+          // In this test suite, all the simple predicates are convertible here.
+          assert(parquetFilters.convertibleFilters(selectedFilters) === selectedFilters)
+          val pushedParquetFilters = selectedFilters.map { pred =>
+            val maybeFilter = parquetFilters.createFilter(pred)
+            assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred")
+            maybeFilter.get
+          }
+          // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`)
+          assert(pushedParquetFilters.exists(_.getClass === filterClass),
+            s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.")

-      var maybeRelation: Option[HadoopFsRelation] = None
-      val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
-        case PhysicalOperation(_, filters,
-          LogicalRelation(relation: HadoopFsRelation, _, _, _)) =>
-          maybeRelation = Some(relation)
-          filters
-      }.flatten.reduceLeftOption(_ && _)
-      assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query")
-
-      val (_, selectedFilters, _) =
-        DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq)
-      assert(selectedFilters.nonEmpty, "No filter is pushed down")
-      val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
-      val parquetFilters = createParquetFilters(schema)
-      // In this test suite, all the simple predicates are convertible here.
-      assert(parquetFilters.convertibleFilters(selectedFilters) === selectedFilters)
-      val pushedParquetFilters = selectedFilters.map { pred =>
-        val maybeFilter = parquetFilters.createFilter(pred)
-        assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred")
-        maybeFilter.get
+          checker(stripSparkFilter(query), expected)
+        } else {
+          assert(selectedFilters.isEmpty, "There is filter pushed down")
+        }
       }
-      // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`)
-      assert(pushedParquetFilters.exists(_.getClass === filterClass),
-        s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.")
-
-      checker(stripSparkFilter(query), expected)
     }
   }
 }
@@ -1667,7 +1688,7 @@ class ParquetV2FilterSuite extends ParquetFilterSuite {
         case PhysicalOperation(_, filters,
             DataSourceV2ScanRelation(_, scan: ParquetScan, _)) =>
           assert(filters.nonEmpty, "No filter is analyzed from the given query")
-          val sourceFilters = filters.flatMap(DataSourceStrategy.translateFilter).toArray
+          val sourceFilters = filters.flatMap(DataSourceStrategy.translateFilter(_, true)).toArray
           val pushedFilters = scan.pushedFilters
           assert(pushedFilters.nonEmpty, "No filter is pushed down")
           val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org