cloud-fan commented on a change in pull request #26289: 
[SPARK-28560][SQL][followup] support the build side to local shuffle reader as 
far as possible in BroadcastHashJoin
URL: https://github.com/apache/spark/pull/26289#discussion_r340052732
 
 

 ##########
 File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
 ##########
 @@ -24,31 +24,33 @@ import 
org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, 
ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, 
BuildRight}
 import org.apache.spark.sql.internal.SQLConf
 
 case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
 
-  def canUseLocalShuffleReaderLeft(join: BroadcastHashJoinExec): Boolean = {
-    join.buildSide == BuildRight &&  
ShuffleQueryStageExec.isShuffleQueryStageExec(join.left)
-  }
-
-  def canUseLocalShuffleReaderRight(join: BroadcastHashJoinExec): Boolean = {
-    join.buildSide == BuildLeft &&  
ShuffleQueryStageExec.isShuffleQueryStageExec(join.right)
-  }
-
   override def apply(plan: SparkPlan): SparkPlan = {
     if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) {
       return plan
     }
 
-    val optimizedPlan = plan.transformDown {
-      case join: BroadcastHashJoinExec if canUseLocalShuffleReaderRight(join) 
=>
-        val localReader = 
LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec])
-        join.copy(right = localReader)
-      case join: BroadcastHashJoinExec if canUseLocalShuffleReaderLeft(join) =>
-        val localReader = 
LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec])
-        join.copy(left = localReader)
+    def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = 
plan match {
+      case _: LocalShuffleReaderExec => Nil
+      case stage: ShuffleQueryStageExec => Seq(stage)
+      case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => 
Seq(stage)
+      case _ => plan.children.flatMap(collectShuffleStages)
+    }
+    val shuffleStages = collectShuffleStages(plan)
+
+    val optimizedPlan = if (shuffleStages.isEmpty ||
+      !shuffleStages.forall(_.plan.canChangeNumPartitions)) {
 
 Review comment:
   This is different from `ReduceNumShufflePartitions`. `ReduceNumShufflePartitions` needs to change all the shuffles together, so if even one shuffle is user-added, the whole optimization must be skipped.
   
   `OptimizeLocalShuffleReader`, by contrast, can add a local reader to each shuffle independently, so the logic is much simpler:
   ```
   private def canAddLocalReader(stage: QueryStage): Boolean = stage match {
     case s: ShuffleQueryStage => s.plan.canChangeNumPartitions
     case ReusedQueryStage(s: ShuffleQueryStage) => 
s.plan.canChangeNumPartitions
   }
   
   plan.transformUp {
     case stage: QueryStageExec if canAddLocalReader(stage) =>
       LocalShuffleReaderExec(stage)
   }
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to