JkSelf commented on a change in pull request #27893: [SPARK-31134][SQL] 
optimize skew join after shuffle partitions are coalesced
URL: https://github.com/apache/spark/pull/27893#discussion_r392799589
 
 

 ##########
 File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
 ##########
 @@ -150,72 +146,65 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends 
Rule[SparkPlan] {
    */
   def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
     case smj @ SortMergeJoinExec(_, _, joinType, _,
-        s1 @ SortExec(_, _, left: ShuffleQueryStageExec, _),
-        s2 @ SortExec(_, _, right: ShuffleQueryStageExec, _), _)
+        s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
+        s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
         if supportedJoinTypes.contains(joinType) =>
-      val leftStats = getStatistics(left)
-      val rightStats = getStatistics(right)
-      val numPartitions = leftStats.bytesByPartitionId.length
-
-      val leftMedSize = medianSize(leftStats)
-      val rightMedSize = medianSize(rightStats)
+      assert(left.partitions.length == right.partitions.length)
+      val numPartitions = left.partitions.length
+      val leftShuffleId = left.shuffleStage.shuffle.shuffleDependency.shuffleId
+      val rightShuffleId = 
right.shuffleStage.shuffle.shuffleDependency.shuffleId
+      // We use the median size of the original shuffle partitions to detect 
skewed partitions.
+      val leftMedSize = medianSize(left.mapStats)
+      val rightMedSize = medianSize(right.mapStats)
       logDebug(
         s"""
           |Try to optimize skewed join.
           |Left side partition size:
-          |${getSizeInfo(leftMedSize, leftStats.bytesByPartitionId.max)}
+          |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId.max)}
           |Right side partition size:
-          |${getSizeInfo(rightMedSize, rightStats.bytesByPartitionId.max)}
+          |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId.max)}
         """.stripMargin)
       val canSplitLeft = canSplitLeftSide(joinType)
       val canSplitRight = canSplitRightSide(joinType)
-      val leftTargetSize = targetSize(leftStats, leftMedSize)
-      val rightTargetSize = targetSize(rightStats, rightMedSize)
+      // We use the actual partition sizes (may be coalesced) to calculate 
target size, so that
+      // the final data distribution is even (coalesced partitions + split 
partitions).
+      val leftSizes = left.partitions.map(_._2)
+      val rightSizes = right.partitions.map(_._2)
+      val leftTargetSize = targetSize(leftSizes, leftMedSize)
+      val rightTargetSize = targetSize(rightSizes, rightMedSize)
 
       val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
       val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      // This is used to delay the creation of non-skew partitions so that we 
can potentially
-      // coalesce them like `CoalesceShufflePartitions` does.
-      val nonSkewPartitionIndices = mutable.ArrayBuffer.empty[Int]
       val leftSkewDesc = new SkewDesc
       val rightSkewDesc = new SkewDesc
       for (partitionIndex <- 0 until numPartitions) {
-        val leftSize = leftStats.bytesByPartitionId(partitionIndex)
+        val leftSize = leftSizes(partitionIndex)
         val isLeftSkew = isSkewed(leftSize, leftMedSize) && canSplitLeft
-        val rightSize = rightStats.bytesByPartitionId(partitionIndex)
+        val rightSize = rightSizes(partitionIndex)
         val isRightSkew = isSkewed(rightSize, rightMedSize) && canSplitRight
         if (isLeftSkew || isRightSkew) {
-          if (nonSkewPartitionIndices.nonEmpty) {
-            // As soon as we see a skew, we'll "flush" out unhandled non-skew 
partitions.
-            createNonSkewPartitions(leftStats, rightStats, 
nonSkewPartitionIndices).foreach { p =>
-              leftSidePartitions += p
-              rightSidePartitions += p
-            }
-            nonSkewPartitionIndices.clear()
-          }
-
           val leftParts = if (isLeftSkew) {
-            val mapStartIndices = getMapStartIndices(left, partitionIndex, 
leftTargetSize)
-            if (mapStartIndices.length > 1) {
+            val CoalescedPartitionSpec(start, end) = 
left.partitions(partitionIndex)._1
+            assert(start + 1 == end, "coalesced partition should never be 
skewed.")
+            val specsAfterSplit = splitAndCreateSpecs(leftShuffleId, start, 
leftTargetSize)
+            if (specsAfterSplit.isDefined) {
               leftSkewDesc.addPartitionSize(leftSize)
-              createSkewPartitions(partitionIndex, mapStartIndices, 
getNumMappers(left))
-            } else {
-              Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1))
             }
+            specsAfterSplit.getOrElse(Seq(left.partitions(partitionIndex)._1))
           } else {
-            Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1))
+            Seq(left.partitions(partitionIndex)._1)
           }
 
           val rightParts = if (isRightSkew) {
-            val mapStartIndices = getMapStartIndices(right, partitionIndex, 
rightTargetSize)
-            if (mapStartIndices.length > 1) {
+            val CoalescedPartitionSpec(start, end) = 
right.partitions(partitionIndex)._1
 
 Review comment:
   The code that computes `leftParts` and `rightParts` is almost the 
same. It would be better to extract it into a shared helper method.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to