This is an automated email from the ASF dual-hosted git repository. sunchao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c63ba6c4f43 [SPARK-42454][SQL] SPJ: encapsulate all SPJ related parameters in BatchScanExec c63ba6c4f43 is described below commit c63ba6c4f43f5872ee6804361ee443c29a739d9d Author: Szehon Ho <szehon.apa...@gmail.com> AuthorDate: Fri Jul 14 21:45:37 2023 -0700 [SPARK-42454][SQL] SPJ: encapsulate all SPJ related parameters in BatchScanExec ### What changes were proposed in this pull request? Pull out the SPJ-related attributes of BatchScanExec into a case class ### Why are the changes needed? We plan to have further evolution of SPJ parameters to support more SPJ features. So we want to stabilize the definition of BatchScanExec to not have to touch the many places in the code where it is pattern-matched/unapplied, etc. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing unit tests to verify no behavior change. Closes #41990 from szehon-ho/spj_refactor. Authored-by: Szehon Ho <szehon.apa...@gmail.com> Signed-off-by: Chao Sun <sunc...@apple.com> --- .../apache/spark/sql/avro/AvroRowReaderSuite.scala | 2 +- .../org/apache/spark/sql/avro/AvroSuite.scala | 6 +-- .../execution/datasources/v2/BatchScanExec.scala | 63 +++++++++++++++------- .../datasources/v2/DataSourceV2Strategy.scala | 3 +- .../execution/exchange/EnsureRequirements.scala | 8 +-- .../spark/sql/FileBasedDataSourceSuite.scala | 4 +- .../PruneFileSourcePartitionsSuite.scala | 2 +- .../datasources/PrunePartitionSuiteBase.scala | 2 +- .../datasources/orc/OrcV2SchemaPruningSuite.scala | 2 +- 9 files changed, 59 insertions(+), 33 deletions(-) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index 046ff4ef088..cc0e178c617 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++
b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -59,7 +59,7 @@ class AvroRowReaderSuite val df = spark.read.format("avro").load(dir.getCanonicalPath) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f + case BatchScanExec(_, f: AvroScan, _, _, _, _) => f } val filePath = fileScan.get.fileIndex.inputFiles(0) val fileSize = new File(new URI(filePath)).length diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index c5e52292caf..35e9f43289c 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -2778,7 +2778,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { }) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f + case BatchScanExec(_, f: AvroScan, _, _, _, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.nonEmpty) @@ -2812,7 +2812,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { assert(filterCondition.isDefined) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f + case BatchScanExec(_, f: AvroScan, _, _, _, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.isEmpty) @@ -2893,7 +2893,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { .where("value = 'a'") val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f + case BatchScanExec(_, f: AvroScan, _, _, _, _) => f } assert(fileScan.nonEmpty) if (filtersPushdown) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index d43331d57c4..4b538197392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -37,23 +37,19 @@ case class BatchScanExec( output: Seq[AttributeReference], @transient scan: Scan, runtimeFilters: Seq[Expression], - keyGroupedPartitioning: Option[Seq[Expression]] = None, ordering: Option[Seq[SortOrder]] = None, @transient table: Table, - commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, - applyPartialClustering: Boolean = false, - replicatePartitions: Boolean = false) extends DataSourceV2ScanExecBase { + spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams() + ) extends DataSourceV2ScanExecBase { - @transient lazy val batch = if (scan == null) null else scan.toBatch + @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
override def equals(other: Any): Boolean = other match { case other: BatchScanExec => this.batch != null && this.batch == other.batch && this.runtimeFilters == other.runtimeFilters && - this.commonPartitionValues == other.commonPartitionValues && - this.replicatePartitions == other.replicatePartitions && - this.applyPartialClustering == other.applyPartialClustering + this.spjParams == other.spjParams case _ => false } @@ -119,11 +115,11 @@ case class BatchScanExec( override def outputPartitioning: Partitioning = { super.outputPartitioning match { - case k: KeyGroupedPartitioning if commonPartitionValues.isDefined => + case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined => // We allow duplicated partition values if // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true - val newPartValues = commonPartitionValues.get.flatMap { case (partValue, numSplits) => - Seq.fill(numSplits)(partValue) + val newPartValues = spjParams.commonPartitionValues.get.flatMap { + case (partValue, numSplits) => Seq.fill(numSplits)(partValue) } k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues) case p => p @@ -148,15 +144,17 @@ case class BatchScanExec( s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + "is enabled") - val groupedPartitions = groupPartitions(finalPartitions.map(_.head), true).get + val groupedPartitions = groupPartitions(finalPartitions.map(_.head), + groupSplits = true).get // This means the input partitions are not grouped by partition values. We'll need to // check `groupByPartitionValues` and decide whether to group and replicate splits // within a partition. - if (commonPartitionValues.isDefined && applyPartialClustering) { + if (spjParams.commonPartitionValues.isDefined && + spjParams.applyPartialClustering) { // A mapping from the common partition values to how many splits the partition // should contain. Note this no longer maintain the partition key ordering. 
- val commonPartValuesMap = commonPartitionValues + val commonPartValuesMap = spjParams.commonPartitionValues .get .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) .toMap @@ -168,7 +166,7 @@ case class BatchScanExec( assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + "common partition values from Spark plan") - val newSplits = if (replicatePartitions) { + val newSplits = if (spjParams.replicatePartitions) { // We need to also replicate partitions according to the other side of join Seq.fill(numSplits.get)(splits) } else { @@ -184,11 +182,12 @@ case class BatchScanExec( // Now fill missing partition keys with empty partitions val partitionMapping = nestGroupedPartitions.toMap - finalPartitions = commonPartitionValues.get.flatMap { case (partValue, numSplits) => - // Use empty partition for those partition values that are not present. - partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), - Seq.fill(numSplits)(Seq.empty)) + finalPartitions = spjParams.commonPartitionValues.get.flatMap { + case (partValue, numSplits) => + // Use empty partition for those partition values that are not present. 
+ partitionMapping.getOrElse( + InternalRowComparableWrapper(partValue, p.expressions), + Seq.fill(numSplits)(Seq.empty)) } } else { val partitionMapping = groupedPartitions.map { case (row, parts) => @@ -222,6 +221,9 @@ case class BatchScanExec( rdd } + override def keyGroupedPartitioning: Option[Seq[Expression]] = + spjParams.keyGroupedPartitioning + override def doCanonicalize(): BatchScanExec = { this.copy( output = output.map(QueryPlan.normalizeExpressions(_, output)), @@ -241,3 +243,24 @@ case class BatchScanExec( s"BatchScan ${table.name()}".trim } } + +case class StoragePartitionJoinParams( + keyGroupedPartitioning: Option[Seq[Expression]] = None, + commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, + applyPartialClustering: Boolean = false, + replicatePartitions: Boolean = false) { + override def equals(other: Any): Boolean = other match { + case other: StoragePartitionJoinParams => + this.commonPartitionValues == other.commonPartitionValues && + this.replicatePartitions == other.replicatePartitions && + this.applyPartialClustering == other.applyPartialClustering + case _ => + false + } + + override def hashCode(): Int = Objects.hashCode( + commonPartitionValues: Option[Seq[(InternalRow, Int)]], + applyPartialClustering: java.lang.Boolean, + replicatePartitions: java.lang.Boolean) +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 542ac2e6748..abd70f322c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -142,7 +142,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case _ => false } val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters, - 
relation.keyGroupedPartitioning, relation.ordering, relation.relation.table) + relation.ordering, relation.relation.table, + StoragePartitionJoinParams(relation.keyGroupedPartitioning)) withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil case PhysicalOperation(p, f, r: StreamingDataSourceV2Relation) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 457a9e0a868..42c880e7c62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -531,9 +531,11 @@ case class EnsureRequirements( replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => scan.copy( - commonPartitionValues = Some(values), - applyPartialClustering = applyPartialClustering, - replicatePartitions = replicatePartitions + spjParams = scan.spjParams.copy( + commonPartitionValues = Some(values), + applyPartialClustering = applyPartialClustering, + replicatePartitions = replicatePartitions + ) ) case node => node.mapChildren(child => populatePartitionValues( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index d69a68f5726..93275487f29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -1014,7 +1014,7 @@ class FileBasedDataSourceSuite extends QueryTest }) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f + case BatchScanExec(_, f: FileScan, _, _, _, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.nonEmpty) @@ -1055,7 
+1055,7 @@ class FileBasedDataSourceSuite extends QueryTest assert(filterCondition.isDefined) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f + case BatchScanExec(_, f: FileScan, _, _, _, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.isEmpty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala index 1b30205a418..3a70bfc7f4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala @@ -170,7 +170,7 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with Shared override def getScanExecPartitionSize(plan: SparkPlan): Long = { plan.collectFirst { case p: FileSourceScanExec => p.selectedPartitions.length - case BatchScanExec(_, scan: FileScan, _, _, _, _, _, _, _) => + case BatchScanExec(_, scan: FileScan, _, _, _, _) => scan.fileIndex.listFiles(scan.partitionFilters, scan.dataFilters).length }.get } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala index 9a61e6517f7..430e9f848e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala @@ -95,7 +95,7 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase { assert(getScanExecPartitionSize(plan) == expectedPartitionCount) val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse { - case 
BatchScanExec(_, scan: FileScan, _, _, _, _, _, _, _) => scan.partitionFilters + case BatchScanExec(_, scan: FileScan, _, _, _, _) => scan.partitionFilters } val pushedDownPartitionFilters = plan.collectFirst(collectFn) .map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull])) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala index 1fba772f5a8..8d503d64e30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala @@ -42,7 +42,7 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { val fileSourceScanSchemata = collect(df.queryExecution.executedPlan) { - case BatchScanExec(_, scan: OrcScan, _, _, _, _, _, _, _) => scan.readDataSchema + case BatchScanExec(_, scan: OrcScan, _, _, _, _) => scan.readDataSchema } assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org