This is an automated email from the ASF dual-hosted git repository.

tgraves pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 7c91b15  [SPARK-32332][SQL][3.0] Support columnar exchanges
7c91b15 is described below

commit 7c91b15c22fe875a44e08781caf3422bfca81b19
Author: Andy Grove <andygr...@nvidia.com>
AuthorDate: Fri Jul 31 11:14:33 2020 -0500

    [SPARK-32332][SQL][3.0] Support columnar exchanges

    ### What changes were proposed in this pull request?

    Backports SPARK-32332 to 3.0 branch.

    ### Why are the changes needed?

    Plugins cannot replace exchanges with columnar versions when AQE is enabled without this patch.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Tests included.

    Closes #29310 from andygrove/backport-SPARK-32332.

    Authored-by: Andy Grove <andygr...@nvidia.com>
    Signed-off-by: Thomas Graves <tgra...@apache.org>
---
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  30 ++++--
 .../adaptive/CustomShuffleReaderExec.scala         |  37 ++++---
 .../adaptive/OptimizeLocalShuffleReader.scala      |   5 +-
 .../execution/adaptive/OptimizeSkewedJoin.scala    |  17 +--
 .../sql/execution/adaptive/QueryStageExec.scala    |  24 +++--
 .../sql/execution/adaptive/simpleCosting.scala     |   6 +-
 .../execution/exchange/BroadcastExchangeExec.scala |  42 +++++++-
 .../execution/exchange/ShuffleExchangeExec.scala   |  55 +++++++++-
 .../execution/streaming/IncrementalExecution.scala |   4 +-
 .../spark/sql/SparkSessionExtensionSuite.scala     | 120 +++++++++++++++++----
 10 files changed, 272 insertions(+), 68 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 5714c33..8b59b12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -100,7 +100,12 @@ case class AdaptiveSparkPlanExec(
     // The following two rules need to make use of 'CustomShuffleReaderExec.partitionSpecs'
     // added by `CoalesceShufflePartitions`. So they must be executed after it.
     OptimizeSkewedJoin(conf),
-    OptimizeLocalShuffleReader(conf),
+    OptimizeLocalShuffleReader(conf)
+  )
+
+  // A list of physical optimizer rules to be applied right after a new stage is created. The input
+  // plan to these rules has exchange as its root node.
+  @transient private val postStageCreationRules = Seq(
     ApplyColumnarRulesAndInsertTransitions(conf, context.session.sessionState.columnarRules),
     CollapseCodegenStages(conf)
   )
@@ -227,7 +232,8 @@ case class AdaptiveSparkPlanExec(
       }
 
       // Run the final plan when there's no more unfinished stages.
-      currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules)
+      currentPhysicalPlan = applyPhysicalRules(
+        result.newPlan, queryStageOptimizerRules ++ postStageCreationRules)
       isFinalPlan = true
       executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
       currentPhysicalPlan
@@ -375,10 +381,22 @@ case class AdaptiveSparkPlanExec(
   private def newQueryStage(e: Exchange): QueryStageExec = {
     val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
     val queryStage = e match {
-      case s: ShuffleExchangeExec =>
-        ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan))
-      case b: BroadcastExchangeExec =>
-        BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan))
+      case s: ShuffleExchangeLike =>
+        val newShuffle = applyPhysicalRules(
+          s.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
+        if (!newShuffle.isInstanceOf[ShuffleExchangeLike]) {
+          throw new IllegalStateException(
+            "Custom columnar rules cannot transform shuffle node to something else.")
+        }
+        ShuffleQueryStageExec(currentStageId, newShuffle)
+      case b: BroadcastExchangeLike =>
+        val newBroadcast = applyPhysicalRules(
+          b.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
+        if (!newBroadcast.isInstanceOf[BroadcastExchangeLike]) {
+          throw new IllegalStateException(
+            "Custom columnar rules cannot transform broadcast node to something else.")
+        }
+        BroadcastQueryStageExec(currentStageId, newBroadcast)
     }
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
index ba3f725..8fd5720 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 
 /**
@@ -38,6 +39,8 @@ case class CustomShuffleReaderExec private(
     partitionSpecs: Seq[ShufflePartitionSpec],
     description: String) extends UnaryExecNode {
 
+  override def supportsColumnar: Boolean = child.supportsColumnar
+
   override def output: Seq[Attribute] = child.output
   override lazy val outputPartitioning: Partitioning = {
     // If it is a local shuffle reader with one mapper per task, then the output partitioning is
@@ -47,9 +50,9 @@ case class CustomShuffleReaderExec private(
         partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
           partitionSpecs.length) {
       child match {
-        case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
+        case ShuffleQueryStageExec(_, s: ShuffleExchangeLike) =>
           s.child.outputPartitioning
-        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
+        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike)) =>
           s.child.outputPartitioning match {
             case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
             case other => other
@@ -64,18 +67,24 @@ case class CustomShuffleReaderExec private(
 
   override def stringArgs: Iterator[Any] = Iterator(description)
 
-  private var cachedShuffleRDD: RDD[InternalRow] = null
+  private def shuffleStage = child match {
+    case stage: ShuffleQueryStageExec => Some(stage)
+    case _ => None
+  }
 
-  override protected def doExecute(): RDD[InternalRow] = {
-    if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = child match {
-        case stage: ShuffleQueryStageExec =>
-          new ShuffledRowRDD(
-            stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs.toArray)
-        case _ =>
-          throw new IllegalStateException("operating on canonicalization plan")
-      }
+  private lazy val shuffleRDD: RDD[_] = {
+    shuffleStage.map { stage =>
+      stage.shuffle.getShuffleRDD(partitionSpecs.toArray)
+    }.getOrElse {
+      throw new IllegalStateException("operating on canonicalized plan")
     }
-    cachedShuffleRDD
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    shuffleRDD.asInstanceOf[RDD[InternalRow]]
+  }
+
+  override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    shuffleRDD.asInstanceOf[RDD[ColumnarBatch]]
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index fb6b40c..6684376 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -78,10 +78,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
   private def getPartitionSpecs(
       shuffleStage: ShuffleQueryStageExec,
       advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = {
-    val shuffleDep = shuffleStage.shuffle.shuffleDependency
-    val numReducers = shuffleDep.partitioner.numPartitions
+    val numMappers = shuffleStage.shuffle.numMappers
+    val numReducers = shuffleStage.shuffle.numPartitions
     val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
-    val numMappers = shuffleDep.rdd.getNumPartitions
     val splitPoints = if (numMappers == 0) {
       Seq.empty
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 91ae0b9..b3b3eb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.commons.io.FileUtils
 
-import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
+import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
@@ -197,7 +197,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
         val leftParts = if (isLeftSkew && !isLeftCoalesced) {
           val reducerId = leftPartSpec.startReducerIndex
           val skewSpecs = createSkewPartitionSpecs(
-            left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
+            left.mapStats.shuffleId, reducerId, leftTargetSize)
           if (skewSpecs.isDefined) {
             logDebug(s"Left side partition $partitionIndex is skewed, split it into " +
               s"${skewSpecs.get.length} parts.")
@@ -212,7 +212,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
         val rightParts = if (isRightSkew && !isRightCoalesced) {
           val reducerId = rightPartSpec.startReducerIndex
           val skewSpecs = createSkewPartitionSpecs(
-            right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
+            right.mapStats.shuffleId, reducerId, rightTargetSize)
           if (skewSpecs.isDefined) {
             logDebug(s"Right side partition $partitionIndex is skewed, split it into " +
               s"${skewSpecs.get.length} parts.")
@@ -287,15 +287,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
 private object ShuffleStage {
   def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match {
     case s: ShuffleQueryStageExec if s.mapStats.isDefined =>
-      val sizes = s.mapStats.get.bytesByPartitionId
+      val mapStats = s.mapStats.get
+      val sizes = mapStats.bytesByPartitionId
       val partitions = sizes.zipWithIndex.map {
         case (size, i) => CoalescedPartitionSpec(i, i + 1) -> size
       }
-      Some(ShuffleStageInfo(s, partitions))
+      Some(ShuffleStageInfo(s, mapStats, partitions))
 
     case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs, _)
         if s.mapStats.isDefined && partitionSpecs.nonEmpty =>
-      val sizes = s.mapStats.get.bytesByPartitionId
+      val mapStats = s.mapStats.get
+      val sizes = mapStats.bytesByPartitionId
       val partitions = partitionSpecs.map {
         case spec @ CoalescedPartitionSpec(start, end) =>
           var sum = 0L
@@ -308,7 +310,7 @@ private object ShuffleStage {
         case other => throw new IllegalArgumentException(
           s"Expect CoalescedPartitionSpec but got $other")
       }
-      Some(ShuffleStageInfo(s, partitions))
+      Some(ShuffleStageInfo(s, mapStats, partitions))
 
     case _ => None
   }
@@ -316,6 +318,7 @@ private object ShuffleStage {
 
 private case class ShuffleStageInfo(
     shuffleStage: ShuffleQueryStageExec,
+    mapStats: MapOutputStatistics,
     partitionsWithSizes: Seq[(CoalescedPartitionSpec, Long)])
 
 private class SkewDesc {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index 9a9a8b1..74fe1ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -81,6 +82,11 @@ abstract class QueryStageExec extends LeafExecNode {
   def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
 
   /**
+   * Returns the runtime statistics after stage materialization.
+   */
+  def getRuntimeStatistics: Statistics
+
+  /**
    * Compute the statistics of the query stage if executed, otherwise None.
   */
  def computeStats(): Option[Statistics] = resultOption.map { _ =>
@@ -107,6 +113,8 @@ abstract class QueryStageExec extends LeafExecNode {
 
   protected override def doPrepare(): Unit = plan.prepare()
   protected override def doExecute(): RDD[InternalRow] = plan.execute()
+  override def supportsColumnar: Boolean = plan.supportsColumnar
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
   override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
   override def doCanonicalize(): SparkPlan = plan.canonicalized
 
@@ -135,15 +143,15 @@ abstract class QueryStageExec extends LeafExecNode {
 }
 
 /**
- * A shuffle query stage whose child is a [[ShuffleExchangeExec]] or [[ReusedExchangeExec]].
+ * A shuffle query stage whose child is a [[ShuffleExchangeLike]] or [[ReusedExchangeExec]].
  */
 case class ShuffleQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val shuffle = plan match {
-    case s: ShuffleExchangeExec => s
-    case ReusedExchangeExec(_, s: ShuffleExchangeExec) => s
+    case s: ShuffleExchangeLike => s
+    case ReusedExchangeExec(_, s: ShuffleExchangeLike) => s
     case _ =>
       throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString)
   }
@@ -176,18 +184,20 @@ case class ShuffleQueryStageExec(
     val stats = resultOption.get.asInstanceOf[MapOutputStatistics]
     Option(stats)
   }
+
+  override def getRuntimeStatistics: Statistics = shuffle.runtimeStatistics
 }
 
 /**
- * A broadcast query stage whose child is a [[BroadcastExchangeExec]] or [[ReusedExchangeExec]].
+ * A broadcast query stage whose child is a [[BroadcastExchangeLike]] or [[ReusedExchangeExec]].
  */
 case class BroadcastQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val broadcast = plan match {
-    case b: BroadcastExchangeExec => b
-    case ReusedExchangeExec(_, b: BroadcastExchangeExec) => b
+    case b: BroadcastExchangeLike => b
+    case ReusedExchangeExec(_, b: BroadcastExchangeLike) => b
     case _ =>
       throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString)
   }
@@ -224,6 +234,8 @@ case class BroadcastQueryStageExec(
       broadcast.relationFuture.cancel(true)
     }
   }
+
+  override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics
 }
 
 object BroadcastQueryStageExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
index 67cd720..cdc57db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
 
 /**
  * A simple implementation of [[Cost]], which takes a number of [[Long]] as the cost value.
@@ -35,13 +35,13 @@ case class SimpleCost(value: Long) extends Cost {
 
 /**
  * A simple implementation of [[CostEvaluator]], which counts the number of
- * [[ShuffleExchangeExec]] nodes in the plan.
+ * [[ShuffleExchangeLike]] nodes in the plan.
  */
 object SimpleCostEvaluator extends CostEvaluator {
 
   override def evaluateCost(plan: SparkPlan): Cost = {
     val cost = plan.collect {
-      case s: ShuffleExchangeExec => s
+      case s: ShuffleExchangeLike => s
     }.size
     SimpleCost(cost)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index d35bbe9..bcdaf61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -29,6 +29,7 @@ import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.HashedRelation
@@ -38,15 +39,42 @@ import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.util.{SparkFatalException, ThreadUtils}
 
 /**
+ * Common trait for all broadcast exchange implementations to facilitate pattern matching.
+ */
+trait BroadcastExchangeLike extends Exchange {
+
+  /**
+   * The broadcast job group ID
+   */
+  def runId: UUID = UUID.randomUUID
+
+  /**
+   * The asynchronous job that prepares the broadcast relation.
+   */
+  def relationFuture: Future[broadcast.Broadcast[Any]]
+
+  /**
+   * For registering callbacks on `relationFuture`.
+   * Note that calling this method may not start the execution of broadcast job.
+   */
+  def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
+
+  /**
+   * Returns the runtime statistics after broadcast materialization.
+   */
+  def runtimeStatistics: Statistics
+}
+
+/**
  * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
  * a transformed SparkPlan.
  */
 case class BroadcastExchangeExec(
     mode: BroadcastMode,
-    child: SparkPlan) extends Exchange {
+    child: SparkPlan) extends BroadcastExchangeLike {
   import BroadcastExchangeExec._
 
-  private[sql] val runId: UUID = UUID.randomUUID
+  override val runId: UUID = UUID.randomUUID
 
   override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
@@ -60,6 +88,11 @@ case class BroadcastExchangeExec(
     BroadcastExchangeExec(mode.canonicalized, child.canonicalized)
   }
 
+  override def runtimeStatistics: Statistics = {
+    val dataSize = metrics("dataSize").value
+    Statistics(dataSize)
+  }
+
   @transient
   private lazy val promise = Promise[broadcast.Broadcast[Any]]()
 
@@ -68,13 +101,14 @@ case class BroadcastExchangeExec(
    * Note that calling this field will not start the execution of broadcast job.
   */
  @transient
-  lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = promise.future
+  override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
+    promise.future
 
   @transient
   private val timeout: Long = SQLConf.get.broadcastTimeout
 
   @transient
-  private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+  override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
     SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
       sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
       try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index b06742e..b7da78c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Divide, Literal, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
@@ -41,12 +42,48 @@ import org.apache.spark.util.MutablePair
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
 
 /**
+ * Common trait for all shuffle exchange implementations to facilitate pattern matching.
+ */
+trait ShuffleExchangeLike extends Exchange {
+
+  /**
+   * Returns the number of mappers of this shuffle.
+   */
+  def numMappers: Int
+
+  /**
+   * Returns the shuffle partition number.
+   */
+  def numPartitions: Int
+
+  /**
+   * Returns whether the shuffle partition number can be changed.
+   */
+  def canChangeNumPartitions: Boolean
+
+  /**
+   * The asynchronous job that materializes the shuffle.
+   */
+  def mapOutputStatisticsFuture: Future[MapOutputStatistics]
+
+  /**
+   * Returns the shuffle RDD with specified partition specs.
+   */
+  def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_]
+
+  /**
+   * Returns the runtime statistics after shuffle materialization.
+   */
+  def runtimeStatistics: Statistics
+}
+
+/**
  * Performs a shuffle that will result in the desired partitioning.
  */
 case class ShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
-    canChangeNumPartitions: Boolean = true) extends Exchange {
+    canChangeNumPartitions: Boolean = true) extends ShuffleExchangeLike {
 
   private lazy val writeMetrics =
     SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
@@ -64,7 +101,7 @@ case class ShuffleExchangeExec(
   @transient lazy val inputRDD: RDD[InternalRow] = child.execute()
 
   // 'mapOutputStatisticsFuture' is only needed when enable AQE.
-  @transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
+  @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
     if (inputRDD.getNumPartitions == 0) {
       Future.successful(null)
     } else {
@@ -72,6 +109,20 @@ case class ShuffleExchangeExec(
     }
   }
 
+  override def numMappers: Int = shuffleDependency.rdd.getNumPartitions
+
+  override def numPartitions: Int = shuffleDependency.partitioner.numPartitions
+
+  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = {
+    new ShuffledRowRDD(shuffleDependency, readMetrics, partitionSpecs)
+  }
+
+  override def runtimeStatistics: Statistics = {
+    val dataSize = metrics("dataSize").value
+    val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
+    Statistics(dataSize, Some(rowCount))
+  }
+
   /**
    * A [[ShuffleDependency]] that will partition rows of its child based on
    * the partitioning scheme defined in `newPartitioning`. Those partitions of
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 7773ac7..bfa60cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{LeafExecNode, LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.Utils
@@ -118,7 +118,7 @@ class IncrementalExecution(
           case s: StatefulOperator =>
             statefulOpFound = true
 
-          case e: ShuffleExchangeExec =>
+          case e: ShuffleExchangeLike =>
             // Don't search recursively any further as any child stateful operator as we
             // are only looking for stateful subplans that this plan has narrow dependencies on.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 44e784d..e5e8bc6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -16,19 +16,24 @@
  */
 package org.apache.spark.sql
 
-import java.util.Locale
+import java.util.{Locale, UUID}
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import scala.concurrent.Future
+
+import org.apache.spark.{MapOutputStatistics, SparkFunSuite, TaskContext}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -169,33 +174,61 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
-  test("inject columnar") {
+  test("inject columnar AQE on") {
+    testInjectColumnar(true)
+  }
+
+  test("inject columnar AQE off") {
+    testInjectColumnar(false)
+  }
+
+  private def testInjectColumnar(enableAQE: Boolean): Unit = {
+    def collectPlanSteps(plan: SparkPlan): Seq[Int] = plan match {
+      case a: AdaptiveSparkPlanExec =>
+        assert(a.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))
+        collectPlanSteps(a.executedPlan)
+      case _ => plan.collect {
+        case _: ReplacedRowToColumnarExec => 1
+        case _: ColumnarProjectExec => 10
+        case _: ColumnarToRowExec => 100
+        case s: QueryStageExec => collectPlanSteps(s.plan).sum
+        case _: MyShuffleExchangeExec => 1000
+        case _: MyBroadcastExchangeExec => 10000
+      }
+    }
+
     val extensions = create { extensions =>
       extensions.injectColumnar(session =>
         MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
     }
     withSession(extensions) { session =>
-      // The ApplyColumnarRulesAndInsertTransitions rule is not applied when enable AQE
-      session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false)
+      session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE)
       assert(session.sessionState.columnarRules.contains(
         MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
       import session.sqlContext.implicits._
-      // repartitioning avoids having the add operation pushed up into the LocalTableScan
-      val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1)
-      val df = data.selectExpr("vals + 1")
-      // Verify that both pre and post processing of the plan worked.
-      val found = df.queryExecution.executedPlan.collect {
-        case rep: ReplacedRowToColumnarExec => 1
-        case proj: ColumnarProjectExec => 10
-        case c2r: ColumnarToRowExec => 100
-      }.sum
-      assert(found == 111)
+      // perform a join to inject a broadcast exchange
+      val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
+      val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
+      val data = left.join(right, $"l1" === $"r1")
+        // repartitioning avoids having the add operation pushed up into the LocalTableScan
+        .repartition(1)
+      val df = data.selectExpr("l2 + r2")
+      // execute the plan so that the final adaptive plan is available when AQE is on
+      df.collect()
+      val found = collectPlanSteps(df.queryExecution.executedPlan).sum
+      // 1 MyBroadcastExchangeExec
+      // 1 MyShuffleExchangeExec
+      // 1 ColumnarToRowExec
+      // 2 ColumnarProjectExec
+      // 1 ReplacedRowToColumnarExec
+      // so 11121 is expected.
+      assert(found == 11121)
 
       // Verify that we get back the expected, wrong, result
       val result = df.collect()
-      assert(result(0).getLong(0) == 102L) // Check that broken columnar Add was used.
-      assert(result(1).getLong(0) == 202L)
-      assert(result(2).getLong(0) == 302L)
+      assert(result(0).getLong(0) == 101L) // Check that broken columnar Add was used.
+      assert(result(1).getLong(0) == 201L)
+      assert(result(2).getLong(0) == 301L)
     }
   }
 
@@ -695,6 +728,16 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
 
   def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan =
     try {
       plan match {
+        case e: ShuffleExchangeExec =>
+          // note that this is not actually columnar but demonstrates that exchanges can
+          // be replaced.
+          val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan))
+          MyShuffleExchangeExec(replaced.asInstanceOf[ShuffleExchangeExec])
+        case e: BroadcastExchangeExec =>
+          // note that this is not actually columnar but demonstrates that exchanges can
+          // be replaced.
+          val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan))
+          MyBroadcastExchangeExec(replaced.asInstanceOf[BroadcastExchangeExec])
         case plan: ProjectExec =>
           new ColumnarProjectExec(plan.projectList.map((exp) =>
             replaceWithColumnarExpression(exp).asInstanceOf[NamedExpression]),
@@ -713,6 +756,41 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
 
   override def apply(plan: SparkPlan): SparkPlan = replaceWithColumnarPlan(plan)
 }
 
+/**
+ * Custom Exchange used in tests to demonstrate that shuffles can be replaced regardless of
+ * whether AQE is enabled.
+ */
+case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
+  override def numMappers: Int = delegate.numMappers
+  override def numPartitions: Int = delegate.numPartitions
+  override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
+  override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
+    delegate.mapOutputStatisticsFuture
+  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
+    delegate.getShuffleRDD(partitionSpecs)
+  override def runtimeStatistics: Statistics = delegate.runtimeStatistics
+  override def child: SparkPlan = delegate.child
+  override protected def doExecute(): RDD[InternalRow] = delegate.execute()
+  override def outputPartitioning: Partitioning = delegate.outputPartitioning
+}
+
+/**
+ * Custom Exchange used in tests to demonstrate that broadcasts can be replaced regardless of
+ * whether AQE is enabled.
+ */
+case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends BroadcastExchangeLike {
+  override def runId: UUID = delegate.runId
+  override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] =
+    delegate.relationFuture
+  override def completionFuture: Future[Broadcast[Any]] = delegate.completionFuture
+  override def runtimeStatistics: Statistics = delegate.runtimeStatistics
+  override def child: SparkPlan = delegate.child
+  override protected def doPrepare(): Unit = delegate.prepare()
+  override protected def doExecute(): RDD[InternalRow] = delegate.execute()
+  override def doExecuteBroadcast[T](): Broadcast[T] = delegate.executeBroadcast()
+  override def outputPartitioning: Partitioning = delegate.outputPartitioning
+}
+
 class ReplacedRowToColumnarExec(override val child: SparkPlan)
   extends RowToColumnarExec(child) {
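
The test's `MyShuffleExchangeExec` and `MyBroadcastExchangeExec` show the delegation pattern the new traits enable; what a plugin ultimately needs is to wire such a replacement into a session via `SparkSessionExtensions`. Below is a minimal sketch of that wiring, assuming `MyShuffleExchangeExec` from the test suite above is on the classpath; the `MyExtensions`, `MyExchangeColumnarRule`, and `ReplaceExchangeRule` names are illustrative, not part of this commit:

```scala
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

// Pre-transition rule that swaps every built-in shuffle for a custom
// ShuffleExchangeLike node (here reusing MyShuffleExchangeExec from the test
// suite above; a real plugin would substitute its own columnar shuffle).
case class ReplaceExchangeRule() extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
    case e: ShuffleExchangeExec => MyShuffleExchangeExec(e)
  }
}

// ColumnarRule exposing the replacement as a pre-columnar-transition rule.
case class MyExchangeColumnarRule() extends ColumnarRule {
  override def preColumnarTransitions: Rule[SparkPlan] = ReplaceExchangeRule()
}

// Extension entry point, e.g. registered through the spark.sql.extensions config.
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectColumnar((_: SparkSession) => MyExchangeColumnarRule())
  }
}
```

With adaptive execution enabled, `AdaptiveSparkPlanExec` now applies injected columnar rules through `postStageCreationRules` as each exchange becomes a query stage, and `newQueryStage` verifies that the replacement node still implements `ShuffleExchangeLike` or `BroadcastExchangeLike`.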