cloud-fan commented on a change in pull request #32875:
URL: https://github.com/apache/spark/pull/32875#discussion_r761769999



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
##########
@@ -70,61 +70,98 @@ case class EnsureRequirements(
     val childrenIndexes = requiredChildDistributions.zipWithIndex.filter {
       case (UnspecifiedDistribution, _) => false
       case (_: BroadcastDistribution, _) => false
+      case (AllTuples, _) => false
       case _ => true
     }.map(_._2)
 
-    val childrenNumPartitions =
-      childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet
-
-    if (childrenNumPartitions.size > 1) {
-      // Get the number of partitions which is explicitly required by the distributions.
-      val requiredNumPartitions = {
-        val numPartitionsSet = childrenIndexes.flatMap {
-          index => requiredChildDistributions(index).requiredNumPartitions
-        }.toSet
-        assert(numPartitionsSet.size <= 1,
-          s"$requiredChildDistributions have incompatible requirements of the number of partitions")
-        numPartitionsSet.headOption
+    // If there are more than one children, we'll need to check partitioning & distribution of them
+    // and see if extra shuffles are necessary.
+    if (childrenIndexes.length > 1) {
+      childrenIndexes.map(requiredChildDistributions(_)).foreach { d =>
+        if (!d.isInstanceOf[ClusteredDistribution]) {
+          throw new IllegalStateException(s"Expected ClusteredDistribution but found " +
+              s"${d.getClass.getSimpleName}")
+        }
       }
+      val specs = childrenIndexes.map(i =>
+        i -> children(i).outputPartitioning.createShuffleSpec(
+          requiredChildDistributions(i).asInstanceOf[ClusteredDistribution])
+      ).toMap

Review comment:
       nit: we can combine the above code
   ```
   val specs = childrenIndexes.map { i =>
     val dist = requiredChildDistributions(i)
     assert(dist.isInstanceOf[ClusteredDistribution])
     i -> children(i).outputPartitioning.createShuffleSpec(dist)
   }
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to