hvanhovell commented on a change in pull request #27493: [SPARK-30751][SQL] Combine the skewed readers into one in AQE skew join optimizations
URL: https://github.com/apache/spark/pull/27493#discussion_r378554332
 
 

 ##########
 File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
 ##########
 @@ -156,60 +155,126 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends 
Rule[SparkPlan] {
           |${getSizeInfo(rightMedSize, rightStats.bytesByPartitionId.max)}
         """.stripMargin)
 
-      val skewedPartitions = mutable.HashSet[Int]()
-      val subJoins = mutable.ArrayBuffer[SparkPlan]()
-      for (partitionId <- 0 until numPartitions) {
-        val isLeftSkew = isSkewed(leftStats, partitionId, leftMedSize)
-        val isRightSkew = isSkewed(rightStats, partitionId, rightMedSize)
-        val leftMapIdStartIndices = if (isLeftSkew && 
supportSplitOnLeftPartition(joinType)) {
-          getMapStartIndices(left, partitionId)
-        } else {
-          Array(0)
-        }
-        val rightMapIdStartIndices = if (isRightSkew && 
supportSplitOnRightPartition(joinType)) {
-          getMapStartIndices(right, partitionId)
-        } else {
-          Array(0)
-        }
-
-        if (leftMapIdStartIndices.length > 1 || rightMapIdStartIndices.length 
> 1) {
-          skewedPartitions += partitionId
-          for (i <- 0 until leftMapIdStartIndices.length;
-               j <- 0 until rightMapIdStartIndices.length) {
-            val leftEndMapId = if (i == leftMapIdStartIndices.length - 1) {
-              getNumMappers(left)
-            } else {
-              leftMapIdStartIndices(i + 1)
+      val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+      val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+      // This is used to delay the creation of non-skew partitions so that we 
can potentially
+      // coalesce them like `ReduceNumShufflePartitions` does.
+      val nonSkewPartitionIndices = mutable.ArrayBuffer.empty[Int]
+      val leftSkewDesc = new SkewDesc
+      val rightSkewDesc = new SkewDesc
+      for (partitionIndex <- 0 until numPartitions) {
+        val leftSize = leftStats.bytesByPartitionId(partitionIndex)
+        val isLeftSkew = isSkewed(leftSize, leftMedSize) && 
canSplitLeftSide(joinType)
+        val rightSize = rightStats.bytesByPartitionId(partitionIndex)
+        val isRightSkew = isSkewed(rightSize, rightMedSize) && 
canSplitRightSide(joinType)
+        if (isLeftSkew || isRightSkew) {
+          if (nonSkewPartitionIndices.nonEmpty) {
+            // As soon as we see a skew, we'll "flush" out unhandled non-skew 
partitions.
+            createNonSkewPartitions(leftStats, rightStats, 
nonSkewPartitionIndices).foreach { p =>
+              leftSidePartitions += p
+              rightSidePartitions += p
             }
-            val rightEndMapId = if (j == rightMapIdStartIndices.length - 1) {
-              getNumMappers(right)
-            } else {
-              rightMapIdStartIndices(j + 1)
+            nonSkewPartitionIndices.clear()
+          }
+
+          val leftParts = if (isLeftSkew) {
+            leftSkewDesc.addPartitionSize(leftSize)
+            createSkewPartitions(
+              partitionIndex,
+              getMapStartIndices(left, partitionIndex),
+              getNumMappers(left))
+          } else {
+            Seq(NormalPartitionSpec(partitionIndex))
+          }
+
+          val rightParts = if (isRightSkew) {
+            rightSkewDesc.addPartitionSize(rightSize)
+            createSkewPartitions(
+              partitionIndex,
+              getMapStartIndices(right, partitionIndex),
+              getNumMappers(right))
+          } else {
+            Seq(NormalPartitionSpec(partitionIndex))
+          }
+
+          for {
+            leftSidePartition <- leftParts
+            rightSidePartition <- rightParts
+          } {
+            leftSidePartitions += leftSidePartition
+            rightSidePartitions += rightSidePartition
+          }
+        } else {
+          // Add to `nonSkewPartitionIndices` first, and add real partitions 
later, in case we can
+          // coalesce the non-skew partitions.
+          nonSkewPartitionIndices += partitionIndex
+          // If this is the last partition, add real partition immediately.
+          if (partitionIndex == numPartitions - 1) {
+            createNonSkewPartitions(leftStats, rightStats, 
nonSkewPartitionIndices).foreach { p =>
+              leftSidePartitions += p
+              rightSidePartitions += p
             }
-            // TODO: we may can optimize the sort merge join to broad cast 
join after
-            //       obtaining the raw data size of per partition,
-            val leftSkewedReader = SkewedPartitionReaderExec(
-              left, partitionId, leftMapIdStartIndices(i), leftEndMapId)
-            val rightSkewedReader = SkewedPartitionReaderExec(right, 
partitionId,
-              rightMapIdStartIndices(j), rightEndMapId)
-            subJoins += SortMergeJoinExec(leftKeys, rightKeys, joinType, 
condition,
-              s1.copy(child = leftSkewedReader), s2.copy(child = 
rightSkewedReader), true)
+            nonSkewPartitionIndices.clear()
           }
         }
       }
-      logDebug(s"number of skewed partitions is ${skewedPartitions.size}")
-      if (skewedPartitions.nonEmpty) {
-        val optimizedSmj = smj.copy(
-          left = s1.copy(child = PartialShuffleReaderExec(left, 
skewedPartitions.toSet)),
-          right = s2.copy(child = PartialShuffleReaderExec(right, 
skewedPartitions.toSet)),
-          isPartial = true)
-        subJoins += optimizedSmj
-        UnionExec(subJoins)
+
+      logDebug("number of skewed partitions: " +
+        s"left ${leftSkewDesc.numPartitions}, right 
${rightSkewDesc.numPartitions}")
+      if (leftSkewDesc.numPartitions > 0 || rightSkewDesc.numPartitions > 0) {
+        val newLeft = SkewJoinShuffleReaderExec(
+          left, leftSidePartitions.toArray, leftSkewDesc.toString)
+        val newRight = SkewJoinShuffleReaderExec(
+          right, rightSidePartitions.toArray, rightSkewDesc.toString)
+        smj.copy(
+          left = s1.copy(child = newLeft), right = s2.copy(child = newRight), 
isSkewJoin = true)
       } else {
         smj
       }
   }
 
+  private def createNonSkewPartitions(
+      leftStats: MapOutputStatistics,
+      rightStats: MapOutputStatistics,
+      nonSkewPartitionIndices: Seq[Int]): Seq[ShufflePartitionSpec] = {
+    assert(nonSkewPartitionIndices.nonEmpty)
+    if (nonSkewPartitionIndices.length == 1) {
+      Seq(NormalPartitionSpec(nonSkewPartitionIndices.head))
+    } else {
+      val startIndices = ShufflePartitionsCoalescer.coalescePartitions(
 
 Review comment:
   This is pretty neat.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to