cloud-fan commented on a change in pull request #26289: [SPARK-28560][SQL][followup] support the build side to local shuffle reader as far as possible in BroadcastHashJoin
URL: https://github.com/apache/spark/pull/26289#discussion_r340052732
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
##########
@@ -24,31 +24,33 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight}
 import org.apache.spark.sql.internal.SQLConf
 
 case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
 
-  def canUseLocalShuffleReaderLeft(join: BroadcastHashJoinExec): Boolean = {
-    join.buildSide == BuildRight && ShuffleQueryStageExec.isShuffleQueryStageExec(join.left)
-  }
-
-  def canUseLocalShuffleReaderRight(join: BroadcastHashJoinExec): Boolean = {
-    join.buildSide == BuildLeft && ShuffleQueryStageExec.isShuffleQueryStageExec(join.right)
-  }
-
   override def apply(plan: SparkPlan): SparkPlan = {
     if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) {
       return plan
     }
-    val optimizedPlan = plan.transformDown {
-      case join: BroadcastHashJoinExec if canUseLocalShuffleReaderRight(join) =>
-        val localReader = LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec])
-        join.copy(right = localReader)
-      case join: BroadcastHashJoinExec if canUseLocalShuffleReaderLeft(join) =>
-        val localReader = LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec])
-        join.copy(left = localReader)
+    def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
+      case _: LocalShuffleReaderExec => Nil
+      case stage: ShuffleQueryStageExec => Seq(stage)
+      case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => Seq(stage)
+      case _ => plan.children.flatMap(collectShuffleStages)
+    }
+    val shuffleStages = collectShuffleStages(plan)
+
+    val optimizedPlan = if (shuffleStages.isEmpty ||
+      !shuffleStages.forall(_.plan.canChangeNumPartitions)) {
Review comment:
This is different from `ReduceNumShufflePartitions`. `ReduceNumShufflePartitions` needs to change all the shuffles together, so as long as there is one user-specified shuffle (whose partition number must not change), we have to skip the whole plan. `OptimizeLocalShuffleReader` can add a local reader to each shuffle independently, so it can be as simple as
```
private def canAddLocalReader(stage: QueryStageExec): Boolean = stage match {
  // skip shuffles from user-specified repartition: their partition number must not change
  case s: ShuffleQueryStageExec => s.plan.canChangeNumPartitions
  case ReusedQueryStageExec(_, s: ShuffleQueryStageExec, _) => s.plan.canChangeNumPartitions
  case _ => false
}

plan.transformUp {
  case stage: QueryStageExec if canAddLocalReader(stage) =>
    LocalShuffleReaderExec(stage)
}
```
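
For illustration only, here is a rough sketch of how such a helper could slot into this rule's `apply`, reusing only the names that already appear in the diff above (`SQLConf`, `Rule`, `SparkPlan`, `QueryStageExec`, `ShuffleQueryStageExec`, `ReusedQueryStageExec`, `LocalShuffleReaderExec`). This is an outline under those assumptions, not the code proposed in this PR, and it omits any later check that the added local readers keep the plan's output requirements valid:

```
// Assumed to live in org.apache.spark.sql.execution.adaptive, next to the
// query-stage and local-reader classes referenced in the diff above.
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf

case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {

  // A shuffle stage can get a local reader only if its partition number is
  // allowed to change, i.e. it does not come from a user-specified repartition.
  private def canAddLocalReader(stage: QueryStageExec): Boolean = stage match {
    case s: ShuffleQueryStageExec => s.plan.canChangeNumPartitions
    case ReusedQueryStageExec(_, s: ShuffleQueryStageExec, _) => s.plan.canChangeNumPartitions
    case _ => false
  }

  override def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) {
      return plan
    }
    // Each eligible shuffle stage is wrapped independently, so one
    // non-changeable shuffle only excludes itself from the optimization.
    plan.transformUp {
      case stage: QueryStageExec if canAddLocalReader(stage) =>
        LocalShuffleReaderExec(stage)
    }
  }
}
```

The per-stage check is the point of the suggestion: unlike `ReduceNumShufflePartitions`, one shuffle that cannot change its partition number does not disable the optimization for the rest of the plan.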