maropu commented on a change in pull request #32328:
URL: https://github.com/apache/spark/pull/32328#discussion_r620047543



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
##########
@@ -157,98 +157,121 @@ object OptimizeSkewedJoin extends CustomShuffleReaderRule {
    * 4. Wrap the join right child with a special shuffle reader that reads partition0 3 times by
    *    3 tasks separately.
    */
-  def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case smj @ SortMergeJoinExec(_, _, joinType, _,
-        s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
-        s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
-        if supportedJoinTypes.contains(joinType) =>
-      assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
-      val numPartitions = left.partitionsWithSizes.length
-      // We use the median size of the original shuffle partitions to detect skewed partitions.
-      val leftMedSize = medianSize(left.mapStats)
-      val rightMedSize = medianSize(right.mapStats)
-      logDebug(
-        s"""
-          |Optimizing skewed join.
-          |Left side partitions size info:
-          |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
-          |Right side partitions size info:
-          |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
-        """.stripMargin)
-      val canSplitLeft = canSplitLeftSide(joinType)
-      val canSplitRight = canSplitRightSide(joinType)
-      // We use the actual partition sizes (may be coalesced) to calculate target size, so that
-      // the final data distribution is even (coalesced partitions + split partitions).
-      val leftActualSizes = left.partitionsWithSizes.map(_._2)
-      val rightActualSizes = right.partitionsWithSizes.map(_._2)
-      val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
-      val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
-
-      val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      var numSkewedLeft = 0
-      var numSkewedRight = 0
-      for (partitionIndex <- 0 until numPartitions) {
-        val leftActualSize = leftActualSizes(partitionIndex)
-        val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
-        val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
-        val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
-
-        val rightActualSize = rightActualSizes(partitionIndex)
-        val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
-        val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
-        val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
-
-        // A skewed partition should never be coalesced, but skip it here just to be safe.
-        val leftParts = if (isLeftSkew && !isLeftCoalesced) {
-          val reducerId = leftPartSpec.startReducerIndex
-          val skewSpecs = createSkewPartitionSpecs(
-            left.mapStats.shuffleId, reducerId, leftTargetSize)
-          if (skewSpecs.isDefined) {
-            logDebug(s"Left side partition $partitionIndex " +
-              s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
-              s"split it into ${skewSpecs.get.length} parts.")
-            numSkewedLeft += 1
-          }
-          skewSpecs.getOrElse(Seq(leftPartSpec))
-        } else {
-          Seq(leftPartSpec)
+  private def getOptimizedChildren(
+      left: ShuffleStageInfo,
+      right: ShuffleStageInfo,
+      joinType: JoinType): Option[(SparkPlan, SparkPlan)] = {
+    assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
+    val numPartitions = left.partitionsWithSizes.length
+    // We use the median size of the original shuffle partitions to detect skewed partitions.
+    val leftMedSize = medianSize(left.mapStats)
+    val rightMedSize = medianSize(right.mapStats)
+    logDebug(
+      s"""
+         |Optimizing skewed join.
+         |Left side partitions size info:
+         |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
+         |Right side partitions size info:
+         |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
+      """.stripMargin)
+    val canSplitLeft = canSplitLeftSide(joinType)
+    val canSplitRight = canSplitRightSide(joinType)
+    // We use the actual partition sizes (may be coalesced) to calculate target size, so that
+    // the final data distribution is even (coalesced partitions + split partitions).
+    val leftActualSizes = left.partitionsWithSizes.map(_._2)
+    val rightActualSizes = right.partitionsWithSizes.map(_._2)
+    val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
+    val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
+
+    val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+    val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+    var numSkewedLeft = 0
+    var numSkewedRight = 0
+    for (partitionIndex <- 0 until numPartitions) {
+      val leftActualSize = leftActualSizes(partitionIndex)
+      val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
+      val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
+      val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
+
+      val rightActualSize = rightActualSizes(partitionIndex)
+      val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
+      val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
+      val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
+
+      // A skewed partition should never be coalesced, but skip it here just to be safe.
+      val leftParts = if (isLeftSkew && !isLeftCoalesced) {
+        val reducerId = leftPartSpec.startReducerIndex
+        val skewSpecs = createSkewPartitionSpecs(
+          left.mapStats.shuffleId, reducerId, leftTargetSize)
+        if (skewSpecs.isDefined) {
+          logDebug(s"Left side partition $partitionIndex " +
+            s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
+            s"split it into ${skewSpecs.get.length} parts.")
+          numSkewedLeft += 1
         }
+        skewSpecs.getOrElse(Seq(leftPartSpec))
+      } else {
+        Seq(leftPartSpec)
+      }
 
-        // A skewed partition should never be coalesced, but skip it here just to be safe.
-        val rightParts = if (isRightSkew && !isRightCoalesced) {
-          val reducerId = rightPartSpec.startReducerIndex
-          val skewSpecs = createSkewPartitionSpecs(
-            right.mapStats.shuffleId, reducerId, rightTargetSize)
-          if (skewSpecs.isDefined) {
-            logDebug(s"Right side partition $partitionIndex " +
-              s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
-              s"split it into ${skewSpecs.get.length} parts.")
-            numSkewedRight += 1
-          }
-          skewSpecs.getOrElse(Seq(rightPartSpec))
-        } else {
-          Seq(rightPartSpec)
+      // A skewed partition should never be coalesced, but skip it here just to be safe.
+      val rightParts = if (isRightSkew && !isRightCoalesced) {
+        val reducerId = rightPartSpec.startReducerIndex
+        val skewSpecs = createSkewPartitionSpecs(
+          right.mapStats.shuffleId, reducerId, rightTargetSize)
+        if (skewSpecs.isDefined) {
+          logDebug(s"Right side partition $partitionIndex " +
+            s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
+            s"split it into ${skewSpecs.get.length} parts.")
+          numSkewedRight += 1
         }
+        skewSpecs.getOrElse(Seq(rightPartSpec))
+      } else {
+        Seq(rightPartSpec)
+      }
 
-        for {
-          leftSidePartition <- leftParts
-          rightSidePartition <- rightParts
-        } {
-          leftSidePartitions += leftSidePartition
-          rightSidePartitions += rightSidePartition
-        }
+      for {
+        leftSidePartition <- leftParts
+        rightSidePartition <- rightParts
+      } {
+        leftSidePartitions += leftSidePartition
+        rightSidePartitions += rightSidePartition
       }
+    }
+    logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
+    if (numSkewedLeft > 0 || numSkewedRight > 0) {
+      Some((CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions.toSeq),
+        CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions.toSeq)))
+    } else {
+      None
+    }
+  }
 

Review comment:
       Ah, okay. In the last commit, I found an unnecessary change that removes the blank line.



