ulysses-you commented on a change in pull request #32816:
URL: https://github.com/apache/spark/pull/32816#discussion_r704911852



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
##########
@@ -38,26 +38,40 @@ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin
  *                               but can be false in AQE when AQE optimization may change the plan
  *                               output partitioning and need to retain the user-specified
  *                               repartition shuffles in the plan.
+ * @param requiredDistribution The root required distribution we should ensure. This value is used
+ *                             in AQE in case we change final stage output partitioning.
  */
-case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Rule[SparkPlan] {
-
-  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
-    val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
-    val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
-    var children: Seq[SparkPlan] = operator.children
-    assert(requiredChildDistributions.length == children.length)
-    assert(requiredChildOrderings.length == children.length)
+case class EnsureRequirements(
+    optimizeOutRepartition: Boolean = true,
+    requiredDistribution: Option[Distribution] = None)
+  extends Rule[SparkPlan] {
 
+  private def ensureDistributionAndOrdering(
+      originChildren: Seq[SparkPlan],
+      requiredChildDistributions: Seq[Distribution],
+      requiredChildOrderings: Seq[Seq[SortOrder]],
+      isRootDistribution: Boolean): Seq[SparkPlan] = {
+    assert(requiredChildDistributions.length == originChildren.length)
+    assert(requiredChildOrderings.length == originChildren.length)
     // Ensure that the operator's children satisfy their output distribution requirements.
-    children = children.zip(requiredChildDistributions).map {
+    var children = originChildren.zip(requiredChildDistributions).map {
      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
        child
      case (child, BroadcastDistribution(mode)) =>
        BroadcastExchangeExec(mode, child)
      case (child, distribution) =>
        val numPartitions = distribution.requiredNumPartitions
          .getOrElse(conf.numShufflePartitions)
-        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
+        val shuffleOrigin = if (isRootDistribution) {

Review comment:
       Agreed, pulled the shuffle origin out to make this cleaner.
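
       For readers of the thread, here is a minimal sketch of the refactor being discussed: compute the `ShuffleOrigin` once, outside the per-child pattern match, and pass it to the exchange. This is not the actual patch; the helper name `shuffleToSatisfy` and the parameter `defaultNumPartitions` are made up for illustration, and mapping the root required distribution to `REPARTITION_BY_NUM` (rather than the default `ENSURE_REQUIREMENTS`) is an assumption here.

```scala
// Sketch only, assuming Spark 3.2-era internals; not the actual change in this PR.
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION_BY_NUM, ShuffleExchangeExec, ShuffleOrigin}

// Hypothetical helper, not part of EnsureRequirements itself.
def shuffleToSatisfy(
    child: SparkPlan,
    distribution: Distribution,
    isRootDistribution: Boolean,
    defaultNumPartitions: Int): SparkPlan = {
  // Decided once, up front, instead of inside each case of the match.
  val shuffleOrigin: ShuffleOrigin =
    if (isRootDistribution) REPARTITION_BY_NUM else ENSURE_REQUIREMENTS
  val numPartitions = distribution.requiredNumPartitions.getOrElse(defaultNumPartitions)
  ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child, shuffleOrigin)
}
```

       Pulling the origin out like this keeps each `case (child, distribution)` branch a one-liner, which is what the "cleaner" remark above is getting at.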




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


