This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 78bd4b3  [SPARK-30751][SQL] Combine the skewed readers into one in AQE skew join optimizations
78bd4b3 is described below

commit 78bd4b34ca7e0834b8b3878cd74b3f59b46b4f90
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Thu Feb 13 20:09:24 2020 +0100

    [SPARK-30751][SQL] Combine the skewed readers into one in AQE skew join optimizations
    
    ### What changes were proposed in this pull request?
    This is a followup of https://github.com/apache/spark/pull/26434
    
    This PR uses one special shuffle reader for skew join, so that we only have one join after optimization. In order to do that, this PR
    1. adds a very general `CustomShuffledRowRDD` which supports all kinds of partition arrangements.
    2. moves the logic of coalescing shuffle partitions to a util function, and calls it during skew join optimization, to fully decouple it from the `ReduceNumShufflePartitions` rule. It's too complicated to make skew join interact with `ReduceNumShufflePartitions`, as you would need to consider the size of split partitions, which already don't respect the target size.
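    A minimal sketch of the greedy coalescing idea that item 2 moves into a util function (called as `ShufflePartitionsCoalescer.coalescePartitions` in the diff below). It assumes, following the doc comment this PR removes from `ReduceNumShufflePartitions`, that the size of a reducer index is summed across all shuffle stages being joined, and that a new coalesced partition starts whenever adding the next index would push the current one over the advisory target size. `CoalesceSketch` and `coalesce` are hypothetical names; this illustrates the idea and is not Spark's implementation.

    ```scala
    import scala.collection.mutable.ArrayBuffer

    object CoalesceSketch {
      /**
       * Greedily packs consecutive reducer indices in [first, last) into coalesced
       * partitions and returns the start index of each coalesced partition.
       * `bytesByPartitionId` holds one per-reducer size array per shuffle stage.
       */
      def coalesce(
          bytesByPartitionId: Seq[Array[Long]],
          first: Int,
          last: Int,
          targetSize: Long): Array[Int] = {
        require(first < last, "need at least one reducer index")
        val starts = ArrayBuffer(first)
        var currentSize = 0L
        for (i <- first until last) {
          // Size of reducer i summed over all stages being joined.
          val size = bytesByPartitionId.map(_(i)).sum
          if (i > first && currentSize + size > targetSize) {
            // Adding reducer i would exceed the target: start a new partition at i.
            starts += i
            currentSize = size
          } else {
            currentSize += size
          }
        }
        starts.toArray
      }
    }
    ```

    With the example from that removed doc comment (stage 1: [100, 20, 100, 10, 30] MiB, stage 2: [10, 10, 70, 5, 5] MiB, target 128 MiB), `coalesce(Seq(stage1, stage2), 0, 5, 128)` returns start indices [0, 1, 2, 3], so the last coalesced partition covers pre-shuffle partitions 3 and 4 (50 MiB).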
    
    ### Why are the changes needed?
    The current skew join optimization has a serious performance issue: the size of the query plan depends on the number and size of skewed partitions.
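    To illustrate the scaling: in the rule being replaced (the removed code in `OptimizeSkewedJoin.scala` below), each skewed partition that was actually split became one sort-merge join per (left mapper range, right mapper range) pair, and those sub-joins plus one join over the remaining partitions were combined under a `UnionExec`. After this PR there is always a single join. A hypothetical counting helper, where each tuple gives the number of left and right mapper ranges for one split partition:

    ```scala
    object PlanSizeSketch {
      /**
       * Number of SortMergeJoinExec nodes the replaced rule produced. Each tuple is
       * (left mapper ranges, right mapper ranges) for one split partition; a side
       * that is not split counts as 1 range.
       */
      def oldJoinCount(splits: Seq[(Int, Int)]): Int =
        if (splits.isEmpty) 1 // nothing split: the original join is kept as-is
        else splits.map { case (l, r) => l * r }.sum + 1 // sub-joins + one partial join

      // The new rule keeps exactly one join; the splits live in the readers' partition specs.
      def newJoinCount: Int = 1
    }
    ```

    For example, two partitions each split into 3 left ranges (and 1 right range) used to produce 3 + 3 + 1 = 7 joins under a union; with this PR they produce one join whose two readers each carry 6 entries for those partitions.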
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests.
    
    Tested the UI manually:
    
![image](https://user-images.githubusercontent.com/3182036/74357390-cfb30480-4dfa-11ea-83f6-825d1b9379ca.png)
    
    Explain output:
    ```
    AdaptiveSparkPlan(isFinalPlan=true)
    +- OverwriteByExpression org.apache.spark.sql.execution.datasources.noop.NoopTable$403a2ed5, [AlwaysTrue()], org.apache.spark.sql.util.CaseInsensitiveStringMap1f
       +- *(5) SortMergeJoin(skew=true) [key1#2L], [key2#6L], Inner
          :- *(3) Sort [key1#2L ASC NULLS FIRST], false, 0
          :  +- SkewJoinShuffleReader 2 skewed partitions with size(max=5 KB, min=5 KB, avg=5 KB)
          :     +- ShuffleQueryStage 0
          :        +- Exchange hashpartitioning(key1#2L, 200), true, [id=#53]
          :           +- *(1) Project [(id#0L % 2) AS key1#2L]
          :              +- *(1) Filter isnotnull((id#0L % 2))
          :                 +- *(1) Range (0, 100000, step=1, splits=6)
          +- *(4) Sort [key2#6L ASC NULLS FIRST], false, 0
             +- SkewJoinShuffleReader 2 skewed partitions with size(max=5 KB, min=5 KB, avg=5 KB)
                +- ShuffleQueryStage 1
                   +- Exchange hashpartitioning(key2#6L, 200), true, [id=#64]
                      +- *(2) Project [((id#4L % 2) + 1) AS key2#6L]
                         +- *(2) Filter isnotnull(((id#4L % 2) + 1))
                            +- *(2) Range (0, 100000, step=1, splits=6)
    ```
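    The `SkewJoinShuffleReader ... skewed partitions with size(max=..., min=..., avg=...)` lines above are rendered from the `SkewDesc` helper added to `OptimizeSkewedJoin.scala` in this PR, which tracks the count, total, maximum and minimum of the skewed partition sizes. A dependency-free sketch of that bookkeeping; the real class formats sizes with commons-io `FileUtils.byteCountToDisplaySize`, which this sketch replaces with a simplified, hypothetical `display` helper that rounds down to bytes, KB or MB.

    ```scala
    class SkewDescSketch {
      private var numSkewed = 0
      private var totalSize = 0L
      private var maxSize = 0L
      private var minSize = 0L

      def numPartitions: Int = numSkewed

      def addPartitionSize(size: Long): Unit = {
        // The first size seeds both extremes.
        if (numSkewed == 0) { maxSize = size; minSize = size }
        numSkewed += 1
        totalSize += size
        maxSize = math.max(maxSize, size)
        minSize = math.min(minSize, size)
      }

      // Hypothetical stand-in for byteCountToDisplaySize: floor to MB, KB or bytes.
      private def display(bytes: Long): String =
        if (bytes >= (1L << 20)) s"${bytes >> 20} MB"
        else if (bytes >= (1L << 10)) s"${bytes >> 10} KB"
        else s"$bytes bytes"

      override def toString: String =
        if (numSkewed == 0) "no skewed partition"
        else s"$numSkewed skewed partitions with " +
          s"size(max=${display(maxSize)}, min=${display(minSize)}, " +
          s"avg=${display(totalSize / numSkewed)})"
    }
    ```

    For example, two hypothetical skewed partitions of 5120 and 5300 bytes render as `2 skewed partitions with size(max=5 KB, min=5 KB, avg=5 KB)`, the same shape as the plan above.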
    
    Closes #27493 from cloud-fan/aqe.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: herman <her...@databricks.com>
    (cherry picked from commit a4ceea6868002b88161517b14b94a2006be8af1b)
    Signed-off-by: herman <her...@databricks.com>
---
 .../spark/sql/execution/ShuffledRowRDD.scala       |  23 +-
 .../execution/adaptive/CustomShuffledRowRDD.scala  | 113 +++++++++
 .../adaptive/OptimizeLocalShuffleReader.scala      |   2 +-
 .../execution/adaptive/OptimizeSkewedJoin.scala    | 276 +++++++++++++--------
 .../adaptive/ReduceNumShufflePartitions.scala      | 157 ++----------
 .../adaptive/ShufflePartitionsCoalescer.scala      | 112 +++++++++
 .../execution/adaptive/SkewedShuffledRowRDD.scala  |  78 ------
 .../execution/exchange/ShuffleExchangeExec.scala   |  21 +-
 .../sql/execution/joins/SortMergeJoinExec.scala    |  13 +-
 .../ReduceNumShufflePartitionsSuite.scala          | 210 +---------------
 .../ShufflePartitionsCoalescerSuite.scala          | 220 ++++++++++++++++
 .../adaptive/AdaptiveQueryExecSuite.scala          | 219 ++++++----------
 12 files changed, 741 insertions(+), 703 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index efa4939..4c19f95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -116,7 +116,7 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A
 class ShuffledRowRDD(
     var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
     metrics: Map[String, SQLMetric],
-    specifiedPartitionIndices: Option[Array[(Int, Int)]] = None)
+    specifiedPartitionStartIndices: Option[Array[Int]] = None)
   extends RDD[InternalRow](dependency.rdd.context, Nil) {
 
   if (SQLConf.get.fetchShuffleBlocksInBatchEnabled) {
@@ -126,8 +126,8 @@ class ShuffledRowRDD(
 
   private[this] val numPreShufflePartitions = dependency.partitioner.numPartitions
 
-  private[this] val partitionStartIndices: Array[Int] = specifiedPartitionIndices match {
-    case Some(indices) => indices.map(_._1)
+  private[this] val partitionStartIndices: Array[Int] = specifiedPartitionStartIndices match {
+    case Some(indices) => indices
     case None =>
       // When specifiedPartitionStartIndices is not defined, every post-shuffle partition
       // corresponds to a pre-shuffle partition.
@@ -142,15 +142,16 @@ class ShuffledRowRDD(
   override val partitioner: Option[Partitioner] = Some(part)
 
   override def getPartitions: Array[Partition] = {
-    specifiedPartitionIndices match {
-      case Some(indices) =>
-        Array.tabulate[Partition](indices.length) { i =>
-          new ShuffledRowRDDPartition(i, indices(i)._1, indices(i)._2)
-        }
-      case None =>
-        Array.tabulate[Partition](numPreShufflePartitions) { i =>
-          new ShuffledRowRDDPartition(i, i, i + 1)
+    assert(partitionStartIndices.length == part.numPartitions)
+    Array.tabulate[Partition](partitionStartIndices.length) { i =>
+      val startIndex = partitionStartIndices(i)
+      val endIndex =
+        if (i < partitionStartIndices.length - 1) {
+          partitionStartIndices(i + 1)
+        } else {
+          numPreShufflePartitions
         }
+      new ShuffledRowRDDPartition(i, startIndex, endIndex)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala
new file mode 100644
index 0000000..5aba574
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.{Dependency, MapOutputTrackerMaster, Partition, ShuffleDependency, SparkEnv, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter}
+
+sealed trait ShufflePartitionSpec
+
+// A partition that reads data of one reducer.
+case class SinglePartitionSpec(reducerIndex: Int) extends ShufflePartitionSpec
+
+// A partition that reads data of multiple reducers, from `startReducerIndex` (inclusive) to
+// `endReducerIndex` (exclusive).
+case class CoalescedPartitionSpec(
+    startReducerIndex: Int, endReducerIndex: Int) extends ShufflePartitionSpec
+
+// A partition that reads partial data of one reducer, from `startMapIndex` (inclusive) to
+// `endMapIndex` (exclusive).
+case class PartialPartitionSpec(
+    reducerIndex: Int, startMapIndex: Int, endMapIndex: Int) extends ShufflePartitionSpec
+
+private final case class CustomShufflePartition(
+    index: Int, spec: ShufflePartitionSpec) extends Partition
+
+// TODO: merge this with `ShuffledRowRDD`, and replace `LocalShuffledRowRDD` with this RDD.
+class CustomShuffledRowRDD(
+    var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
+    metrics: Map[String, SQLMetric],
+    partitionSpecs: Array[ShufflePartitionSpec])
+  extends RDD[InternalRow](dependency.rdd.context, Nil) {
+
+  override def getDependencies: Seq[Dependency[_]] = List(dependency)
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    dependency = null
+  }
+
+  override def getPartitions: Array[Partition] = {
+    Array.tabulate[Partition](partitionSpecs.length) { i =>
+      CustomShufflePartition(i, partitionSpecs(i))
+    }
+  }
+
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+    partition.asInstanceOf[CustomShufflePartition].spec match {
+      case SinglePartitionSpec(reducerIndex) =>
+        tracker.getPreferredLocationsForShuffle(dependency, reducerIndex)
+
+      case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
+        startReducerIndex.until(endReducerIndex).flatMap { reducerIndex =>
+          tracker.getPreferredLocationsForShuffle(dependency, reducerIndex)
+        }
+
+      case PartialPartitionSpec(_, startMapIndex, endMapIndex) =>
+        tracker.getMapLocation(dependency, startMapIndex, endMapIndex)
+    }
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+    val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
+    // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
+    // as well as the `tempMetrics` for basic shuffle metrics.
+    val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
+    val reader = split.asInstanceOf[CustomShufflePartition].spec match {
+      case SinglePartitionSpec(reducerIndex) =>
+        SparkEnv.get.shuffleManager.getReader(
+          dependency.shuffleHandle,
+          reducerIndex,
+          reducerIndex + 1,
+          context,
+          sqlMetricsReporter)
+
+      case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
+        SparkEnv.get.shuffleManager.getReader(
+          dependency.shuffleHandle,
+          startReducerIndex,
+          endReducerIndex,
+          context,
+          sqlMetricsReporter)
+
+      case PartialPartitionSpec(reducerIndex, startMapIndex, endMapIndex) =>
+        SparkEnv.get.shuffleManager.getReaderForRange(
+          dependency.shuffleHandle,
+          startMapIndex,
+          endMapIndex,
+          reducerIndex,
+          reducerIndex + 1,
+          context,
+          sqlMetricsReporter)
+    }
+    reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
+  }
+}
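The three `ShufflePartitionSpec` cases above differ only in which slice of the shuffle output a task fetches: one reducer from all mappers, a contiguous range of reducers from all mappers, or one reducer from a contiguous range of mappers (all ranges end-exclusive). The following self-contained model of that mapping re-declares the three spec classes and adds a hypothetical `ReadRange` to stand in for the arguments `compute` passes to `getReader` / `getReaderForRange`; it does not call Spark.

```scala
object SpecSketch {
  sealed trait ShufflePartitionSpec
  case class SinglePartitionSpec(reducerIndex: Int) extends ShufflePartitionSpec
  case class CoalescedPartitionSpec(startReducerIndex: Int, endReducerIndex: Int)
    extends ShufflePartitionSpec
  case class PartialPartitionSpec(reducerIndex: Int, startMapIndex: Int, endMapIndex: Int)
    extends ShufflePartitionSpec

  // Hypothetical: mapper range [startMap, endMap) and reducer range
  // [startReducer, endReducer) that one task would fetch.
  case class ReadRange(startMap: Int, endMap: Int, startReducer: Int, endReducer: Int)

  def readRange(spec: ShufflePartitionSpec, numMappers: Int): ReadRange = spec match {
    // One reducer, output of every mapper.
    case SinglePartitionSpec(r) => ReadRange(0, numMappers, r, r + 1)
    // Several consecutive reducers, output of every mapper.
    case CoalescedPartitionSpec(s, e) => ReadRange(0, numMappers, s, e)
    // One reducer, output of a subset of mappers (a skewed split).
    case PartialPartitionSpec(r, sm, em) => ReadRange(sm, em, r, r + 1)
  }
}
```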
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index a8d8f35..e95441e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -71,7 +71,7 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
     plan match {
       case c @ CoalescedShuffleReaderExec(s: ShuffleQueryStageExec, _) =>
         LocalShuffleReaderExec(
-          s, getPartitionStartIndices(s, Some(c.partitionIndices.length)))
+          s, getPartitionStartIndices(s, Some(c.partitionStartIndices.length)))
       case s: ShuffleQueryStageExec =>
         LocalShuffleReaderExec(s, getPartitionStartIndices(s, None))
     }
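With this change, `ShuffledRowRDD` and the `partitionStartIndices` of `CoalescedShuffleReaderExec` used above describe coalesced partitions by start indices only; `ShuffledRowRDD.getPartitions` derives each partition's exclusive end from the next start index, or from the number of pre-shuffle partitions for the last one. A standalone sketch of that derivation; `StartIndicesSketch.toRanges` is a hypothetical name.

```scala
object StartIndicesSketch {
  /**
   * Turns ascending post-shuffle partition start indices into [start, end)
   * pre-shuffle ranges; the last range ends at numPreShufflePartitions.
   */
  def toRanges(startIndices: Array[Int], numPreShufflePartitions: Int): Seq[(Int, Int)] =
    startIndices.indices.map { i =>
      val end =
        if (i < startIndices.length - 1) startIndices(i + 1)
        else numPreShufflePartitions
      (startIndices(i), end)
    }
}
```

When no start indices are specified, every post-shuffle partition maps to one pre-shuffle partition, which corresponds to `toRanges((0 until n).toArray, n)`.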
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 74b7fbd..a716497 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.adaptive
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.commons.io.FileUtils
+
 import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -44,11 +46,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
    * partition size * spark.sql.adaptive.skewedPartitionFactor and also larger than
    * spark.sql.adaptive.skewedPartitionSizeThreshold.
    */
-  private def isSkewed(
-      stats: MapOutputStatistics,
-      partitionId: Int,
-      medianSize: Long): Boolean = {
-    val size = stats.bytesByPartitionId(partitionId)
+  private def isSkewed(size: Long, medianSize: Long): Boolean = {
     size > medianSize * conf.getConf(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR) &&
       size > conf.getConf(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD)
   }
@@ -108,12 +106,12 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     stage.resultOption.get.asInstanceOf[MapOutputStatistics]
   }
 
-  private def supportSplitOnLeftPartition(joinType: JoinType) = {
+  private def canSplitLeftSide(joinType: JoinType) = {
     joinType == Inner || joinType == Cross || joinType == LeftSemi ||
       joinType == LeftAnti || joinType == LeftOuter
   }
 
-  private def supportSplitOnRightPartition(joinType: JoinType) = {
+  private def canSplitRightSide(joinType: JoinType) = {
     joinType == Inner || joinType == Cross || joinType == RightOuter
   }
 
@@ -130,17 +128,18 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
    * 1. Check whether the shuffle partition is skewed based on the median size
    *    and the skewed partition threshold in origin smj.
    * 2. Assuming partition0 is skewed in left side, and it has 5 mappers (Map0, Map1...Map4).
-   *    And we will split the 5 Mappers into 3 mapper ranges [(Map0, Map1), (Map2, Map3), (Map4)]
+   *    And we may split the 5 Mappers into 3 mapper ranges [(Map0, Map1), (Map2, Map3), (Map4)]
    *    based on the map size and the max split number.
-   * 3. Create the 3 smjs with separately reading the above mapper ranges and then join with
-   *    the Partition0 in right side.
-   * 4. Finally union the above 3 split smjs and the origin smj.
+   * 3. Wrap the join left child with a special shuffle reader that reads each mapper range with one
+   *    task, so total 3 tasks.
+   * 4. Wrap the join right child with a special shuffle reader that reads partition0 3 times by
+   *    3 tasks separately.
    */
   def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
+    case smj @ SortMergeJoinExec(_, _, joinType, _,
         s1 @ SortExec(_, _, left: ShuffleQueryStageExec, _),
         s2 @ SortExec(_, _, right: ShuffleQueryStageExec, _), _)
-      if (supportedJoinTypes.contains(joinType)) =>
+        if supportedJoinTypes.contains(joinType) =>
       val leftStats = getStatistics(left)
       val rightStats = getStatistics(right)
       val numPartitions = leftStats.bytesByPartitionId.length
@@ -155,61 +154,134 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
           |Right side partition size:
           |${getSizeInfo(rightMedSize, rightStats.bytesByPartitionId.max)}
         """.stripMargin)
+      val canSplitLeft = canSplitLeftSide(joinType)
+      val canSplitRight = canSplitRightSide(joinType)
+
+      val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+      val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+      // This is used to delay the creation of non-skew partitions so that we can potentially
+      // coalesce them like `ReduceNumShufflePartitions` does.
+      val nonSkewPartitionIndices = mutable.ArrayBuffer.empty[Int]
+      val leftSkewDesc = new SkewDesc
+      val rightSkewDesc = new SkewDesc
+      for (partitionIndex <- 0 until numPartitions) {
+        val leftSize = leftStats.bytesByPartitionId(partitionIndex)
+        val isLeftSkew = isSkewed(leftSize, leftMedSize) && canSplitLeft
+        val rightSize = rightStats.bytesByPartitionId(partitionIndex)
+        val isRightSkew = isSkewed(rightSize, rightMedSize) && canSplitRight
+        if (isLeftSkew || isRightSkew) {
+          if (nonSkewPartitionIndices.nonEmpty) {
+            // As soon as we see a skew, we'll "flush" out unhandled non-skew partitions.
+            createNonSkewPartitions(leftStats, rightStats, nonSkewPartitionIndices).foreach { p =>
+              leftSidePartitions += p
+              rightSidePartitions += p
+            }
+            nonSkewPartitionIndices.clear()
+          }
 
-      val skewedPartitions = mutable.HashSet[Int]()
-      val subJoins = mutable.ArrayBuffer[SparkPlan]()
-      for (partitionId <- 0 until numPartitions) {
-        val isLeftSkew = isSkewed(leftStats, partitionId, leftMedSize)
-        val isRightSkew = isSkewed(rightStats, partitionId, rightMedSize)
-        val leftMapIdStartIndices = if (isLeftSkew && supportSplitOnLeftPartition(joinType)) {
-          getMapStartIndices(left, partitionId)
-        } else {
-          Array(0)
-        }
-        val rightMapIdStartIndices = if (isRightSkew && supportSplitOnRightPartition(joinType)) {
-          getMapStartIndices(right, partitionId)
-        } else {
-          Array(0)
-        }
+          val leftParts = if (isLeftSkew) {
+            leftSkewDesc.addPartitionSize(leftSize)
+            createSkewPartitions(
+              partitionIndex,
+              getMapStartIndices(left, partitionIndex),
+              getNumMappers(left))
+          } else {
+            Seq(SinglePartitionSpec(partitionIndex))
+          }
 
-        if (leftMapIdStartIndices.length > 1 || rightMapIdStartIndices.length > 1) {
-          skewedPartitions += partitionId
-          for (i <- 0 until leftMapIdStartIndices.length;
-               j <- 0 until rightMapIdStartIndices.length) {
-            val leftEndMapId = if (i == leftMapIdStartIndices.length - 1) {
-              getNumMappers(left)
-            } else {
-              leftMapIdStartIndices(i + 1)
-            }
-            val rightEndMapId = if (j == rightMapIdStartIndices.length - 1) {
-              getNumMappers(right)
-            } else {
-              rightMapIdStartIndices(j + 1)
+          val rightParts = if (isRightSkew) {
+            rightSkewDesc.addPartitionSize(rightSize)
+            createSkewPartitions(
+              partitionIndex,
+              getMapStartIndices(right, partitionIndex),
+              getNumMappers(right))
+          } else {
+            Seq(SinglePartitionSpec(partitionIndex))
+          }
+
+          for {
+            leftSidePartition <- leftParts
+            rightSidePartition <- rightParts
+          } {
+            leftSidePartitions += leftSidePartition
+            rightSidePartitions += rightSidePartition
+          }
+        } else {
+          // Add to `nonSkewPartitionIndices` first, and add real partitions later, in case we can
+          // coalesce the non-skew partitions.
+          nonSkewPartitionIndices += partitionIndex
+          // If this is the last partition, add real partition immediately.
+          if (partitionIndex == numPartitions - 1) {
+            createNonSkewPartitions(leftStats, rightStats, nonSkewPartitionIndices).foreach { p =>
+              leftSidePartitions += p
+              rightSidePartitions += p
             }
-            // TODO: we may can optimize the sort merge join to broad cast join after
-            //       obtaining the raw data size of per partition,
-            val leftSkewedReader = SkewedPartitionReaderExec(
-              left, partitionId, leftMapIdStartIndices(i), leftEndMapId)
-            val rightSkewedReader = SkewedPartitionReaderExec(right, partitionId,
-              rightMapIdStartIndices(j), rightEndMapId)
-            subJoins += SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
-              s1.copy(child = leftSkewedReader), s2.copy(child = rightSkewedReader), true)
+            nonSkewPartitionIndices.clear()
           }
         }
       }
-      logDebug(s"number of skewed partitions is ${skewedPartitions.size}")
-      if (skewedPartitions.nonEmpty) {
-        val optimizedSmj = smj.copy(
-          left = s1.copy(child = PartialShuffleReaderExec(left, skewedPartitions.toSet)),
-          right = s2.copy(child = PartialShuffleReaderExec(right, skewedPartitions.toSet)),
-          isPartial = true)
-        subJoins += optimizedSmj
-        UnionExec(subJoins)
+
+      logDebug("number of skewed partitions: " +
+        s"left ${leftSkewDesc.numPartitions}, right ${rightSkewDesc.numPartitions}")
+      if (leftSkewDesc.numPartitions > 0 || rightSkewDesc.numPartitions > 0) {
+        val newLeft = SkewJoinShuffleReaderExec(
+          left, leftSidePartitions.toArray, leftSkewDesc.toString)
+        val newRight = SkewJoinShuffleReaderExec(
+          right, rightSidePartitions.toArray, rightSkewDesc.toString)
+        smj.copy(
+          left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
       } else {
         smj
       }
   }
 
+  private def createNonSkewPartitions(
+      leftStats: MapOutputStatistics,
+      rightStats: MapOutputStatistics,
+      nonSkewPartitionIndices: Seq[Int]): Seq[ShufflePartitionSpec] = {
+    assert(nonSkewPartitionIndices.nonEmpty)
+    if (nonSkewPartitionIndices.length == 1) {
+      Seq(SinglePartitionSpec(nonSkewPartitionIndices.head))
+    } else {
+      val startIndices = ShufflePartitionsCoalescer.coalescePartitions(
+        Array(leftStats, rightStats),
+        firstPartitionIndex = nonSkewPartitionIndices.head,
+        // `lastPartitionIndex` is exclusive.
+        lastPartitionIndex = nonSkewPartitionIndices.last + 1,
+        advisoryTargetSize = conf.targetPostShuffleInputSize)
+      startIndices.indices.map { i =>
+        val startIndex = startIndices(i)
+        val endIndex = if (i == startIndices.length - 1) {
+          // `endIndex` is exclusive.
+          nonSkewPartitionIndices.last + 1
+        } else {
+          startIndices(i + 1)
+        }
+        // Do not create `CoalescedPartitionSpec` if only need to read a singe partition.
+        if (startIndex + 1 == endIndex) {
+          SinglePartitionSpec(startIndex)
+        } else {
+          CoalescedPartitionSpec(startIndex, endIndex)
+        }
+      }
+    }
+  }
+
+  private def createSkewPartitions(
+      reducerIndex: Int,
+      mapStartIndices: Array[Int],
+      numMappers: Int): Seq[PartialPartitionSpec] = {
+    mapStartIndices.indices.map { i =>
+      val startMapIndex = mapStartIndices(i)
+      val endMapIndex = if (i == mapStartIndices.length - 1) {
+        numMappers
+      } else {
+        mapStartIndices(i + 1)
+      }
+      PartialPartitionSpec(reducerIndex, startMapIndex, endMapIndex)
+    }
+  }
+
   override def apply(plan: SparkPlan): SparkPlan = {
     if (!conf.getConf(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED)) {
       return plan
@@ -248,79 +320,69 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
   }
 }
 
-/**
- * A wrapper of shuffle query stage, which submits one reduce task to read a single
- * shuffle partition 'partitionIndex' produced by the mappers in range [startMapIndex, endMapIndex).
- * This is used to increase the parallelism when reading skewed partitions.
- *
- * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange
- *              node during canonicalization.
- * @param partitionIndex The pre shuffle partition index.
- * @param startMapIndex The start map index.
- * @param endMapIndex The end map index.
- */
-case class SkewedPartitionReaderExec(
-    child: QueryStageExec,
-    partitionIndex: Int,
-    startMapIndex: Int,
-    endMapIndex: Int) extends LeafExecNode {
+private class SkewDesc {
+  private[this] var numSkewedPartitions: Int = 0
+  private[this] var totalSize: Long = 0
+  private[this] var maxSize: Long = 0
+  private[this] var minSize: Long = 0
 
-  override def output: Seq[Attribute] = child.output
+  def numPartitions: Int = numSkewedPartitions
 
-  override def outputPartitioning: Partitioning = {
-    UnknownPartitioning(1)
+  def addPartitionSize(size: Long): Unit = {
+    if (numSkewedPartitions == 0) {
+      maxSize = size
+      minSize = size
+    }
+    numSkewedPartitions += 1
+    totalSize += size
+    if (size > maxSize) maxSize = size
+    if (size < minSize) minSize = size
   }
-  private var cachedSkewedShuffleRDD: SkewedShuffledRowRDD = null
 
-  override def doExecute(): RDD[InternalRow] = {
-    if (cachedSkewedShuffleRDD == null) {
-      cachedSkewedShuffleRDD = child match {
-        case stage: ShuffleQueryStageExec =>
-          stage.shuffle.createSkewedShuffleRDD(partitionIndex, startMapIndex, endMapIndex)
-        case _ =>
-          throw new IllegalStateException("operating on canonicalization plan")
-      }
+  override def toString: String = {
+    if (numSkewedPartitions == 0) {
+      "no skewed partition"
+    } else {
+      val maxSizeStr = FileUtils.byteCountToDisplaySize(maxSize)
+      val minSizeStr = FileUtils.byteCountToDisplaySize(minSize)
+      val avgSizeStr = FileUtils.byteCountToDisplaySize(totalSize / numSkewedPartitions)
+      s"$numSkewedPartitions skewed partitions with " +
+        s"size(max=$maxSizeStr, min=$minSizeStr, avg=$avgSizeStr)"
     }
-    cachedSkewedShuffleRDD
   }
 }
 
 /**
- * A wrapper of shuffle query stage, which skips some partitions when reading the shuffle blocks.
+ * A wrapper of shuffle query stage, which follows the given partition arrangement.
  *
  * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange node during
  *              canonicalization.
- * @param excludedPartitions The partitions to skip when reading.
+ * @param partitionSpecs The partition specs that defines the arrangement.
+ * @param skewDesc The description of the skewed partitions.
  */
-case class PartialShuffleReaderExec(
-    child: QueryStageExec,
-    excludedPartitions: Set[Int]) extends UnaryExecNode {
+case class SkewJoinShuffleReaderExec(
+    child: SparkPlan,
+    partitionSpecs: Array[ShufflePartitionSpec],
+    skewDesc: String) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = {
-    UnknownPartitioning(1)
+    UnknownPartitioning(partitionSpecs.length)
   }
 
-  private def shuffleExchange(): ShuffleExchangeExec = child match {
-    case stage: ShuffleQueryStageExec => stage.shuffle
-    case _ =>
-      throw new IllegalStateException("operating on canonicalization plan")
-  }
-
-  private def getPartitionIndexRanges(): Array[(Int, Int)] = {
-    val length = shuffleExchange().shuffleDependency.partitioner.numPartitions
-    (0 until length).filterNot(excludedPartitions.contains).map(i => (i, i + 1)).toArray
-  }
+  override def stringArgs: Iterator[Any] = Iterator(skewDesc)
 
   private var cachedShuffleRDD: RDD[InternalRow] = null
 
-  override def doExecute(): RDD[InternalRow] = {
+  override protected def doExecute(): RDD[InternalRow] = {
     if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = if (excludedPartitions.isEmpty) {
-        child.execute()
-      } else {
-        shuffleExchange().createShuffledRDD(Some(getPartitionIndexRanges()))
+      cachedShuffleRDD = child match {
+        case stage: ShuffleQueryStageExec =>
+          new CustomShuffledRowRDD(
+            stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs)
+        case _ =>
+          throw new IllegalStateException("operating on canonicalization plan")
       }
     }
     cachedShuffleRDD
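The rewritten `OptimizeSkewedJoin` logic can be summarized as: a partition is skewed when it is larger than both `medianSize * skewedPartitionFactor` and `skewedPartitionSizeThreshold`; a skewed side is split into mapper ranges (`PartialPartitionSpec`); and every left split is paired with every right split at the same positions of the two `SkewJoinShuffleReaderExec` readers, so one join processes all pairs. A simplified, self-contained sketch of that pairing, assuming the mapper start indices are given (the real rule computes them with `getMapStartIndices`, not shown in this excerpt) and ignoring join-type restrictions (`canSplitLeftSide` / `canSplitRightSide`) and the coalescing of non-skewed partitions. All names in it are hypothetical.

```scala
object SkewSplitSketch {
  // Simplified stand-ins for SinglePartitionSpec and PartialPartitionSpec.
  sealed trait Part
  case class Whole(reducer: Int) extends Part
  case class MapRange(reducer: Int, startMap: Int, endMap: Int) extends Part

  // Per-side inputs; mapStarts(p) gives the mapper start indices for reducer p
  // (assumed precomputed here; the real rule derives them from map output sizes).
  case class Side(sizes: Array[Long], median: Long, mapStarts: Int => Seq[Int], numMappers: Int)

  // Skewed only if larger than both median * factor and the absolute threshold.
  def isSkewed(size: Long, median: Long, factor: Int, threshold: Long): Boolean =
    size > median * factor && size > threshold

  // Splits one reducer into [start, end) mapper ranges; the last range ends at numMappers.
  def splitByMappers(reducer: Int, mapStarts: Seq[Int], numMappers: Int): Seq[Part] =
    mapStarts.indices.map { i =>
      val end = if (i == mapStarts.length - 1) numMappers else mapStarts(i + 1)
      MapRange(reducer, mapStarts(i), end)
    }

  /**
   * Returns equally long left and right partition lists for a single join:
   * task i reads left(i) and right(i). Every left split of a reducer is paired
   * with every right split of the same reducer.
   */
  def plan(left: Side, right: Side, factor: Int, threshold: Long): (Seq[Part], Seq[Part]) = {
    def parts(side: Side, p: Int): Seq[Part] =
      if (isSkewed(side.sizes(p), side.median, factor, threshold)) {
        splitByMappers(p, side.mapStarts(p), side.numMappers)
      } else {
        Seq(Whole(p))
      }
    val pairs = left.sizes.indices.flatMap { p =>
      for (l <- parts(left, p); r <- parts(right, p)) yield (l, r)
    }
    (pairs.map(_._1), pairs.map(_._2))
  }
}
```

If only left partition 1 is skewed and split into three mapper ranges, the left list holds those three ranges while the right list repeats the whole partition 1 three times, mirroring steps 3 and 4 of the doc comment in the diff above.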
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
index 2c50b63..5bbcb14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import scala.collection.mutable.{ArrayBuffer, HashSet}
-
 import org.apache.spark.MapOutputStatistics
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -29,24 +27,8 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
- * A rule to adjust the post shuffle partitions based on the map output statistics.
- *
- * The strategy used to determine the number of post-shuffle partitions is described as follows.
- * To determine the number of post-shuffle partitions, we have a target input size for a
- * post-shuffle partition. Once we have size statistics of all pre-shuffle partitions, we will do
- * a pass of those statistics and pack pre-shuffle partitions with continuous indices to a single
- * post-shuffle partition until adding another pre-shuffle partition would cause the size of a
- * post-shuffle partition to be greater than the target size.
- *
- * For example, we have two stages with the following pre-shuffle partition size statistics:
- * stage 1: [100 MiB, 20 MiB, 100 MiB, 10MiB, 30 MiB]
- * stage 2: [10 MiB,  10 MiB, 70 MiB,  5 MiB, 5 MiB]
- * assuming the target input size is 128 MiB, we will have four post-shuffle partitions,
- * which are:
- *  - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MiB)
- *  - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MiB)
- *  - post-shuffle partition 2: pre-shuffle partition 2 (size 170 MiB)
- *  - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MiB)
+ * A rule to reduce the post shuffle partitions based on the map output statistics, which can
+ * avoid many small reduce tasks that hurt performance.
  */
 case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
 
@@ -54,28 +36,21 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
     if (!conf.reducePostShufflePartitionsEnabled) {
       return plan
     }
-    // 'SkewedShufflePartitionReader' is added by us, so it's safe to ignore it when changing
-    // number of reducers.
-    val leafNodes = plan.collectLeaves().filter(!_.isInstanceOf[SkewedPartitionReaderExec])
-    if (!leafNodes.forall(_.isInstanceOf[QueryStageExec])) {
+    if (!plan.collectLeaves().forall(_.isInstanceOf[QueryStageExec])) {
       // If not all leaf nodes are query stages, it's not safe to reduce the number of
       // shuffle partitions, because we may break the assumption that all children of a spark plan
       // have same number of output partitions.
       return plan
     }
 
-    def collectShuffles(plan: SparkPlan): Seq[SparkPlan] = plan match {
+    def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
       case _: LocalShuffleReaderExec => Nil
-      case p: PartialShuffleReaderExec => Seq(p)
+      case _: SkewJoinShuffleReaderExec => Nil
       case stage: ShuffleQueryStageExec => Seq(stage)
-      case _ => plan.children.flatMap(collectShuffles)
+      case _ => plan.children.flatMap(collectShuffleStages)
     }
 
-    val shuffles = collectShuffles(plan)
-    val shuffleStages = shuffles.map {
-      case PartialShuffleReaderExec(s: ShuffleQueryStageExec, _) => s
-      case s: ShuffleQueryStageExec => s
-    }
+    val shuffleStages = collectShuffleStages(plan)
     // ShuffleExchanges introduced by repartition do not support changing the number of partitions.
     // We change the number of partitions in the stage only if all the ShuffleExchanges support it.
     if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) {
@@ -94,110 +69,27 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
       // partition) and a result of a SortMergeJoin (multiple partitions).
       val distinctNumPreShufflePartitions =
         validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
-      val distinctExcludedPartitions = shuffles.map {
-        case PartialShuffleReaderExec(_, excludedPartitions) => excludedPartitions
-        case _: ShuffleQueryStageExec => Set.empty[Int]
-      }.distinct
-      if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1
-        && distinctExcludedPartitions.length == 1) {
-        val excludedPartitions = distinctExcludedPartitions.head
-        val partitionIndices = estimatePartitionStartAndEndIndices(
-          validMetrics.toArray, excludedPartitions)
+      if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
+        val partitionStartIndices = ShufflePartitionsCoalescer.coalescePartitions(
+          validMetrics.toArray,
+          firstPartitionIndex = 0,
+          lastPartitionIndex = distinctNumPreShufflePartitions.head,
+          advisoryTargetSize = conf.targetPostShuffleInputSize,
+          minNumPartitions = conf.minNumPostShufflePartitions)
         // This transformation adds new nodes, so we must use `transformUp` here.
-        // Even for shuffle exchange whose input RDD has 0 partition, we should still update its
-        // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
-        // number of output partitions.
-        val visitedStages = HashSet.empty[Int]
-        plan.transformDown {
-          // Replace `PartialShuffleReaderExec` with `CoalescedShuffleReaderExec`, which keeps the
-          // "excludedPartition" requirement and also merges some partitions.
-          case PartialShuffleReaderExec(stage: ShuffleQueryStageExec, _) =>
-            visitedStages.add(stage.id)
-            CoalescedShuffleReaderExec(stage, partitionIndices)
-
-          // We are doing `transformDown`, so the `ShuffleQueryStageExec` may already be optimized
-          // and wrapped by `CoalescedShuffleReaderExec`.
-          case stage: ShuffleQueryStageExec if !visitedStages.contains(stage.id) =>
-            visitedStages.add(stage.id)
-            CoalescedShuffleReaderExec(stage, partitionIndices)
+        val stageIds = shuffleStages.map(_.id).toSet
+        plan.transformUp {
+          // even for shuffle exchange whose input RDD has 0 partition, we should still update its
+          // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
+          // number of output partitions.
+          case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
+            CoalescedShuffleReaderExec(stage, partitionStartIndices)
         }
       } else {
         plan
       }
     }
   }
-
-  /**
-   * Estimates partition start and end indices for post-shuffle partitions based on
-   * mapOutputStatistics provided by all pre-shuffle stages and skip the omittedPartitions
-   * already handled in skewed partition optimization.
-   */
-  // visible for testing.
-  private[sql] def estimatePartitionStartAndEndIndices(
-      mapOutputStatistics: Array[MapOutputStatistics],
-      excludedPartitions: Set[Int] = Set.empty): Array[(Int, Int)] = {
-    val minNumPostShufflePartitions = conf.minNumPostShufflePartitions - excludedPartitions.size
-    val advisoryTargetPostShuffleInputSize = conf.targetPostShuffleInputSize
-    // If minNumPostShufflePartitions is defined, it is possible that we need to use a
-    // value less than advisoryTargetPostShuffleInputSize as the target input size of
-    // a post shuffle task.
-    val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum
-    // The max at here is to make sure that when we have an empty table, we
-    // only have a single post-shuffle partition.
-    // There is no particular reason that we pick 16. We just need a number to
-    // prevent maxPostShuffleInputSize from being set to 0.
-    val maxPostShuffleInputSize = math.max(
-      math.ceil(totalPostShuffleInputSize / minNumPostShufflePartitions.toDouble).toLong, 16)
-    val targetPostShuffleInputSize =
-      math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize)
-
-    logInfo(
-      s"advisoryTargetPostShuffleInputSize: $advisoryTargetPostShuffleInputSize, " +
-        s"targetPostShuffleInputSize $targetPostShuffleInputSize.")
-
-    // Make sure we do get the same number of pre-shuffle partitions for those stages.
-    val distinctNumPreShufflePartitions =
-      mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct
-    // The reason that we are expecting a single value of the number of pre-shuffle partitions
-    // is that when we add Exchanges, we set the number of pre-shuffle partitions
-    // (i.e. map output partitions) using a static setting, which is the value of
-    // spark.sql.shuffle.partitions. Even if two input RDDs are having different
-    // number of partitions, they will have the same number of pre-shuffle partitions
-    // (i.e. map output partitions).
-    assert(
-      distinctNumPreShufflePartitions.length == 1,
-      "There should be only one distinct value of the number pre-shuffle partitions " +
-        "among registered Exchange operator.")
-
-    val partitionStartIndices = ArrayBuffer[Int]()
-    val partitionEndIndices = ArrayBuffer[Int]()
-    val numPartitions = distinctNumPreShufflePartitions.head
-    val includedPartitions = (0 until numPartitions).filter(!excludedPartitions.contains(_))
-    val firstStartIndex = includedPartitions(0)
-    partitionStartIndices += firstStartIndex
-    var postShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId(firstStartIndex)).sum
-    var i = firstStartIndex
-    includedPartitions.drop(1).foreach { nextPartitionIndex =>
-        val nextShuffleInputSize =
-          mapOutputStatistics.map(_.bytesByPartitionId(nextPartitionIndex)).sum
-        // If nextPartitionIndices is skewed and omitted, or including
-        // the nextShuffleInputSize would exceed the target partition size,
-        // then start a new partition.
-        if (nextPartitionIndex != i + 1 ||
-          (postShuffleInputSize + nextShuffleInputSize > targetPostShuffleInputSize)) {
-          partitionEndIndices += i + 1
-          partitionStartIndices += nextPartitionIndex
-          // reset postShuffleInputSize.
-          postShuffleInputSize = nextShuffleInputSize
-          i = nextPartitionIndex
-        } else {
-          postShuffleInputSize += nextShuffleInputSize
-          i += 1
-        }
-    }
-    partitionEndIndices += i + 1
-    partitionStartIndices.zip(partitionEndIndices).toArray
-  }
 }
 
 /**
@@ -206,15 +98,16 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
  *
  * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange node during
  *              canonicalization.
+ * @param partitionStartIndices The start partition indices for the coalesced partitions.
  */
 case class CoalescedShuffleReaderExec(
     child: SparkPlan,
-    partitionIndices: Array[(Int, Int)]) extends UnaryExecNode {
+    partitionStartIndices: Array[Int]) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = {
-    UnknownPartitioning(partitionIndices.length)
+    UnknownPartitioning(partitionStartIndices.length)
   }
 
   private var cachedShuffleRDD: ShuffledRowRDD = null
@@ -223,7 +116,7 @@ case class CoalescedShuffleReaderExec(
     if (cachedShuffleRDD == null) {
       cachedShuffleRDD = child match {
         case stage: ShuffleQueryStageExec =>
-          stage.shuffle.createShuffledRDD(Some(partitionIndices))
+          stage.shuffle.createShuffledRDD(Some(partitionStartIndices))
         case _ =>
           throw new IllegalStateException("operating on canonicalization plan")
       }
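With this change, `CoalescedShuffleReaderExec` carries only `partitionStartIndices`. Each coalesced partition presumably ends where the next one starts, and the last one ends at the shuffle's partition count. This matches the `[0, 2), [2, 3), [3, ...)` reading in the `ShufflePartitionsCoalescer` scaladoc. The helper below is hypothetical and not part of this commit. It sketches how start indices expand back into the kind of half-open ranges that the old `partitionIndices: Array[(Int, Int)]` field held.

```scala
// Hypothetical helper, not part of this commit.
object StartIndicesSketch {
  // Expands coalesced-partition start indices into half-open [start, end) reducer ranges:
  // each range ends where the next begins, and the last ends at `numShufflePartitions`.
  def toRanges(partitionStartIndices: Array[Int], numShufflePartitions: Int): Array[(Int, Int)] = {
    val ends = partitionStartIndices.drop(1) :+ numShufflePartitions
    partitionStartIndices.zip(ends)
  }
}
```

For example, start indices `[0, 2, 3]` over 5 shuffle partitions expand to `(0, 2), (2, 3), (3, 5)`.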
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala
new file mode 100644
index 0000000..18f0585
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.MapOutputStatistics
+import org.apache.spark.internal.Logging
+
+object ShufflePartitionsCoalescer extends Logging {
+
+  /**
+   * Coalesce the same range of partitions (`firstPartitionIndex` to `lastPartitionIndex`, the
+   * start is inclusive and the end is exclusive) from multiple shuffles. This method assumes that
+   * all the shuffles have the same number of partitions, and the partitions of the same index will
+   * be read together by one task.
+   *
+   * The strategy used to determine the number of coalesced partitions is described as follows.
+   * To determine the number of coalesced partitions, we have a target size for a coalesced
+   * partition. Once we have size statistics of all shuffle partitions, we will do
+   * a pass of those statistics and pack shuffle partitions with contiguous indices into a single
+   * coalesced partition until adding another shuffle partition would cause the size of a
+   * coalesced partition to be greater than the target size.
+   *
+   * For example, suppose we have two shuffles with the following partition size statistics:
+   *  - shuffle 1 (5 partitions): [100 MiB, 20 MiB, 100 MiB, 10 MiB, 30 MiB]
+   *  - shuffle 2 (5 partitions): [10 MiB,  10 MiB, 70 MiB,  5 MiB, 5 MiB]
+   * Assuming the target size is 128 MiB, we will have 4 coalesced partitions, which are:
+   *  - coalesced partition 0: shuffle partition 0 (size 110 MiB)
+   *  - coalesced partition 1: shuffle partition 1 (size 30 MiB)
+   *  - coalesced partition 2: shuffle partition 2 (size 170 MiB)
+   *  - coalesced partition 3: shuffle partition 3 and 4 (size 50 MiB)
+   *
+   *  @return An array of partition indices that represent the coalesced partitions. For example,
+   *          [0, 2, 3] means 3 coalesced partitions: [0, 2), [2, 3), [3, lastPartitionIndex)
+   */
+  def coalescePartitions(
+      mapOutputStatistics: Array[MapOutputStatistics],
+      firstPartitionIndex: Int,
+      lastPartitionIndex: Int,
+      advisoryTargetSize: Long,
+      minNumPartitions: Int = 1): Array[Int] = {
+    // If `minNumPartitions` is very large, it is possible that we need to use a value less than
+    // `advisoryTargetSize` as the target size of a coalesced task.
+    val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum
+    // The max here makes sure that when we have an empty table, we only have a single
+    // coalesced partition.
+    // There is no particular reason that we pick 16. We just need a number to prevent
+    // `maxTargetSize` from being set to 0.
+    val maxTargetSize = math.max(
+      math.ceil(totalPostShuffleInputSize / minNumPartitions.toDouble).toLong, 16)
+    val targetSize = math.min(maxTargetSize, advisoryTargetSize)
+
+    logInfo(s"advisory target size: $advisoryTargetSize, actual target size $targetSize.")
+
+    // Make sure these shuffles have the same number of partitions.
+    val distinctNumShufflePartitions =
+      mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct
+    // The reason that we are expecting a single value of the number of shuffle partitions
+    // is that when we add Exchanges, we set the number of shuffle partitions
+    // (i.e. map output partitions) using a static setting, which is the value of
+    // `spark.sql.shuffle.partitions`. Even if two input RDDs have a different
+    // number of partitions, they will have the same number of shuffle partitions
+    // (i.e. map output partitions).
+    assert(
+      distinctNumShufflePartitions.length == 1,
+      "There should be only one distinct value of the number of shuffle partitions " +
+        "among registered Exchange operators.")
+
+    val splitPoints = ArrayBuffer[Int]()
+    splitPoints += firstPartitionIndex
+    var coalescedSize = 0L
+    var i = firstPartitionIndex
+    while (i < lastPartitionIndex) {
+      // We calculate the total size of the i-th shuffle partition across all shuffles.
+      var totalSizeOfCurrentPartition = 0L
+      var j = 0
+      while (j < mapOutputStatistics.length) {
+        totalSizeOfCurrentPartition += mapOutputStatistics(j).bytesByPartitionId(i)
+        j += 1
+      }
+
+      // If including the `totalSizeOfCurrentPartition` would exceed the target size, then start a
+      // new coalesced partition.
+      if (i > firstPartitionIndex && coalescedSize + totalSizeOfCurrentPartition > targetSize) {
+        splitPoints += i
+        // Reset `coalescedSize` for the new coalesced partition.
+        coalescedSize = totalSizeOfCurrentPartition
+      } else {
+        coalescedSize += totalSizeOfCurrentPartition
+      }
+      i += 1
+    }
+
+    splitPoints.toArray
+  }
+}
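You can exercise the greedy pass documented in `coalescePartitions` without Spark. The standalone sketch below assumes that per-shuffle partition sizes arrive as a plain `Array[Array[Long]]` instead of `MapOutputStatistics`. `CoalesceSketch` and its method names are hypothetical. The sketch keeps the patch's target-size rule: the smaller of the advisory size and `max(ceil(total / minNumPartitions), 16)`. It also keeps the same split condition. It uses `require` where the patch uses `assert`.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical Spark-free sketch of the coalescing pass.
object CoalesceSketch {
  // Smaller of the advisory size and max(ceil(total / minNumPartitions), 16).
  def targetSize(totalBytes: Long, advisoryTargetSize: Long, minNumPartitions: Int): Long = {
    val maxTargetSize = math.max(math.ceil(totalBytes / minNumPartitions.toDouble).toLong, 16L)
    math.min(maxTargetSize, advisoryTargetSize)
  }

  // bytesByPartitionId(s)(i) is the size of shuffle partition i written by shuffle s.
  // Returns the start index of every coalesced partition within [first, last).
  def coalesce(
      bytesByPartitionId: Array[Array[Long]],
      first: Int,
      last: Int,
      advisoryTargetSize: Long,
      minNumPartitions: Int = 1): Array[Int] = {
    require(bytesByPartitionId.map(_.length).distinct.length == 1,
      "all shuffles must have the same number of partitions")
    val target =
      targetSize(bytesByPartitionId.map(_.sum).sum, advisoryTargetSize, minNumPartitions)
    val splitPoints = ArrayBuffer(first)
    var coalescedSize = 0L
    for (i <- first until last) {
      // Partitions with the same index are read by one task, so sum them across shuffles.
      val size = bytesByPartitionId.map(_(i)).sum
      if (i > first && coalescedSize + size > target) {
        splitPoints += i // adding partition i would overshoot: start a new coalesced partition
        coalescedSize = size
      } else {
        coalescedSize += size
      }
    }
    splitPoints.toArray
  }
}
```

Try it on the two shuffles from the scaladoc example (sizes in MiB, 128 MiB target). It yields start indices `0, 1, 2, 3`, so only shuffle partitions 3 and 4 are merged.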
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala
deleted file mode 100644
index 52f793b..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.adaptive
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter}
-
-/**
- * The [[Partition]] used by [[SkewedShuffledRowRDD]].
- */
-class SkewedShuffledRowRDDPartition(override val index: Int) extends Partition
-
-/**
- * This is a specialized version of [[org.apache.spark.sql.execution.ShuffledRowRDD]]. This is used
- * in Spark SQL adaptive execution to solve data skew issues. This RDD includes rearranged
- * partitions from mappers.
- *
- * This RDD takes a [[ShuffleDependency]] (`dependency`), a partitionIndex
- * and the range of startMapIndex to endMapIndex.
- */
-class SkewedShuffledRowRDD(
-     var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
-     partitionIndex: Int,
-     startMapIndex: Int,
-     endMapIndex: Int,
-     metrics: Map[String, SQLMetric])
-  extends RDD[InternalRow](dependency.rdd.context, Nil) {
-
-  override def getDependencies: Seq[Dependency[_]] = List(dependency)
-
-  override def getPartitions: Array[Partition] = {
-    Array(new SkewedShuffledRowRDDPartition(0))
-  }
-
-  override def getPreferredLocations(partition: Partition): Seq[String] = {
-    val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
-    tracker.getMapLocation(dependency, startMapIndex, endMapIndex)
-  }
-
-  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
-    val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
-    // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
-    // as well as the `tempMetrics` for basic shuffle metrics.
-    val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
-
-    val reader = SparkEnv.get.shuffleManager.getReaderForRange(
-      dependency.shuffleHandle,
-      startMapIndex,
-      endMapIndex,
-      partitionIndex,
-      partitionIndex + 1,
-      context,
-      sqlMetricsReporter)
-    reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
-  }
-
-  override def clearDependencies() {
-    super.clearDependencies()
-    dependency = null
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index ffcd6c7..4b08da0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -30,11 +30,11 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProces
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Divide, Literal, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{LocalShuffledRowRDD, SkewedShuffledRowRDD}
+import org.apache.spark.sql.execution.adaptive.LocalShuffledRowRDD
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -49,11 +49,9 @@ case class ShuffleExchangeExec(
     child: SparkPlan,
     canChangeNumPartitions: Boolean = true) extends Exchange {
 
-  // NOTE: coordinator can be null after serialization/deserialization,
-  //       e.g. it can be null on the Executor side
   private lazy val writeMetrics =
     SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
-  private lazy val readMetrics =
+  private[sql] lazy val readMetrics =
     SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
   override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")
@@ -90,9 +88,8 @@ case class ShuffleExchangeExec(
       writeMetrics)
   }
 
-  def createShuffledRDD(
-      partitionRanges: Option[Array[(Int, Int)]]): ShuffledRowRDD = {
-    new ShuffledRowRDD(shuffleDependency, readMetrics, partitionRanges)
+  def createShuffledRDD(partitionStartIndices: Option[Array[Int]]): ShuffledRowRDD = {
+    new ShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndices)
   }
 
   def createLocalShuffleRDD(
@@ -100,14 +97,6 @@ case class ShuffleExchangeExec(
     new LocalShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndicesPerMapper)
   }
 
-  def createSkewedShuffleRDD(
-      partitionIndex: Int,
-      startMapIndex: Int,
-      endMapIndex: Int): SkewedShuffledRowRDD = {
-    new SkewedShuffledRowRDD(shuffleDependency,
-      partitionIndex, startMapIndex, endMapIndex, readMetrics)
-  }
-
   /**
    * Caches the created ShuffleRowRDD so we can reuse that.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 6384aed..62eea61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{PartialShuffleReaderExec, SkewedPartitionReaderExec}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.util.collection.BitSet
 
@@ -42,11 +41,17 @@ case class SortMergeJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    isPartial: Boolean = false) extends BinaryExecNode with CodegenSupport {
+    isSkewJoin: Boolean = false) extends BinaryExecNode with CodegenSupport {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
+  override def nodeName: String = {
+    if (isSkewJoin) super.nodeName + "(skew=true)" else super.nodeName
+  }
+
+  override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator
+
   override def simpleStringWithNodeId(): String = {
     val opId = ExplainUtils.getOpId(this)
     s"$nodeName $joinType ($opId)".trim
@@ -98,7 +103,9 @@ case class SortMergeJoinExec(
   }
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    if (isPartial) {
+    if (isSkewJoin) {
+      // We re-arrange the shuffle partitions to deal with skew join, and the new children
+      // partitioning doesn't satisfy `HashClusteredDistribution`.
       UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
     } else {
       HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
index 04b4d4f..5565a0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.execution
 
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.adaptive._
-import org.apache.spark.sql.execution.adaptive.{CoalescedShuffleReaderExec, ReduceNumShufflePartitions}
+import org.apache.spark.sql.execution.adaptive.CoalescedShuffleReaderExec
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -52,212 +52,6 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
     }
   }
 
-  private def checkEstimation(
-      rule: ReduceNumShufflePartitions,
-      bytesByPartitionIdArray: Array[Array[Long]],
-      expectedPartitionStartIndices: Array[Int]): Unit = {
-    val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map {
-      case (bytesByPartitionId, index) =>
-        new MapOutputStatistics(index, bytesByPartitionId)
-    }
-    val estimatedPartitionStartIndices =
-      rule.estimatePartitionStartAndEndIndices(mapOutputStatistics).map(_._1)
-    assert(estimatedPartitionStartIndices === expectedPartitionStartIndices)
-  }
-
-  private def createReduceNumShufflePartitionsRule(
-      advisoryTargetPostShuffleInputSize: Long,
-      minNumPostShufflePartitions: Int = 1): ReduceNumShufflePartitions = {
-    val conf = new SQLConf().copy(
-      SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE -> advisoryTargetPostShuffleInputSize,
-      SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS -> minNumPostShufflePartitions)
-    ReduceNumShufflePartitions(conf)
-  }
-
-  test("test estimatePartitionStartIndices - 1 Exchange") {
-    val rule = createReduceNumShufflePartitionsRule(100L)
-
-    {
-      // All bytes per partition are 0.
-      val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0)
-      val expectedPartitionStartIndices = Array[Int](0)
-      checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices)
-    }
-
-    {
-      // Some bytes per partition are 0 and total size is less than the target size.
-      // 1 post-shuffle partition is needed.
-      val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0)
-      val expectedPartitionStartIndices = Array[Int](0)
-      checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices)
-    }
-
-    {
-      // 2 post-shuffle partitions are needed.
-      val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0)
-      val expectedPartitionStartIndices = Array[Int](0, 3)
-      checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices)
-    }
-
-    {
-      // There are a few large pre-shuffle partitions.
-      val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
-      checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices)
-    }
-
-    {
-      // All pre-shuffle partitions are larger than the targeted size.
-      val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
-      checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices)
-    }
-
-    {
-      // The last pre-shuffle partition is in a single post-shuffle partition.
-      val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110)
-      val expectedPartitionStartIndices = Array[Int](0, 4)
-      checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices)
-    }
-  }
-
-  test("test estimatePartitionStartIndices - 2 Exchanges") {
-    val rule = createReduceNumShufflePartitionsRule(100L)
-
-    {
-      // If there are multiple values of the number of pre-shuffle partitions,
-      // we should see an assertion error.
-      val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
-      val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0)
-      val mapOutputStatistics =
-        Array(
-          new MapOutputStatistics(0, bytesByPartitionId1),
-          new MapOutputStatistics(1, bytesByPartitionId2))
-      intercept[AssertionError](rule.estimatePartitionStartAndEndIndices(
-        mapOutputStatistics))
-    }
-
-    {
-      // All bytes per partition are 0.
-      val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
-      val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
-      val expectedPartitionStartIndices = Array[Int](0)
-      checkEstimation(
-        rule,
-        Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices)
-    }
-
-    {
-      // Some bytes per partition are 0.
-      // 1 post-shuffle partition is needed.
-      val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0)
-      val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20)
-      val expectedPartitionStartIndices = Array[Int](0)
-      checkEstimation(
-        rule,
-        Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices)
-    }
-
-    {
-      // 2 post-shuffle partition are needed.
-      val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0)
-      val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 2, 4)
-      checkEstimation(
-        rule,
-        Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices)
-    }
-
-    {
-      // 4 post-shuffle partition are needed.
-      val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0)
-      val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4)
-      checkEstimation(
-        rule,
-        Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices)
-    }
-
-    {
-      // 2 post-shuffle partition are needed.
-      val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0)
-      val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4)
-      checkEstimation(
-        rule,
-        Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices)
-    }
-
-    {
-      // There are a few large pre-shuffle partitions.
-      val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0)
-      val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
-      checkEstimation(
-        rule,
-        Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices)
-    }
-
-    {
-      // All pairs of pre-shuffle partitions are larger than the targeted size.
-      val bytesByPartitionId1 = Array[Long](100, 100, 40, 30, 0)
-      val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
-      checkEstimation(
-        rule,
-        Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices)
-    }
-  }
-
-  test("test estimatePartitionStartIndices and enforce minimal number of reducers") {
-    val rule = createReduceNumShufflePartitionsRule(100L, 2)
-
-    {
-      // The minimal number of post-shuffle partitions is not enforced because
-      // the size of data is 0.
-      val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
-      val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
-      val expectedPartitionStartIndices = Array[Int](0)
-      checkEstimation(
-        rule,
-        Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices)
-    }
-
-    {
-      // The minimal number of post-shuffle partitions is enforced.
-      val bytesByPartitionId1 = Array[Long](10, 5, 5, 0, 20)
-      val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5)
-      val expectedPartitionStartIndices = Array[Int](0, 3)
-      checkEstimation(
-        rule,
-        Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices)
-    }
-
-    {
-      // The number of post-shuffle partitions is determined by the coordinator.
-      val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20)
-      val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4)
-      checkEstimation(
-        rule,
-        Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices)
-    }
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Query tests
-  ///////////////////////////////////////////////////////////////////////////
-
   val numInputPartitions: Int = 10
 
   def withSparkSession(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala
new file mode 100644
index 0000000..fcfde83
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.{MapOutputStatistics, SparkFunSuite}
+import org.apache.spark.sql.execution.adaptive.ShufflePartitionsCoalescer
+
+class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
+
+  private def checkEstimation(
+      bytesByPartitionIdArray: Array[Array[Long]],
+      expectedPartitionStartIndices: Array[Int],
+      targetSize: Long,
+      minNumPartitions: Int = 1): Unit = {
+    val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map {
+      case (bytesByPartitionId, index) =>
+        new MapOutputStatistics(index, bytesByPartitionId)
+    }
+    val estimatedPartitionStartIndices = ShufflePartitionsCoalescer.coalescePartitions(
+      mapOutputStatistics,
+      0,
+      bytesByPartitionIdArray.head.length,
+      targetSize,
+      minNumPartitions)
+    assert(estimatedPartitionStartIndices === expectedPartitionStartIndices)
+  }
+
+  test("1 shuffle") {
+    val targetSize = 100
+
+    {
+      // All bytes per partition are 0.
+      val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0)
+      val expectedPartitionStartIndices = Array[Int](0)
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+    }
+
+    {
+      // Some bytes per partition are 0 and total size is less than the target size.
+      // 1 coalesced partition is expected.
+      val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0)
+      val expectedPartitionStartIndices = Array[Int](0)
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+    }
+
+    {
+      // 2 coalesced partitions are expected.
+      val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0)
+      val expectedPartitionStartIndices = Array[Int](0, 3)
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+    }
+
+    {
+      // There are a few large shuffle partitions.
+      val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+    }
+
+    {
+      // All shuffle partitions are at least as large as the target size.
+      val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+    }
+
+    {
+      // The last shuffle partition is in a single coalesced partition.
+      val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110)
+      val expectedPartitionStartIndices = Array[Int](0, 4)
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+    }
+  }
+
+  test("2 shuffles") {
+    val targetSize = 100
+
+    {
+      // If the shuffles have different numbers of partitions,
+      // an assertion error is expected.
+      val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
+      val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0)
+      intercept[AssertionError] {
+        checkEstimation(Array(bytesByPartitionId1, bytesByPartitionId2), Array.empty, targetSize)
+      }
+    }
+
+    {
+      // All bytes per partition are 0.
+      val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
+      val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
+      val expectedPartitionStartIndices = Array[Int](0)
+      checkEstimation(
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices,
+        targetSize)
+    }
+
+    {
+      // Some bytes per partition are 0.
+      // 1 coalesced partition is expected.
+      val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20)
+      val expectedPartitionStartIndices = Array[Int](0)
+      checkEstimation(
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices,
+        targetSize)
+    }
+
+    {
+      // 3 coalesced partitions are expected.
+      val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
+      val expectedPartitionStartIndices = Array[Int](0, 2, 4)
+      checkEstimation(
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices,
+        targetSize)
+    }
+
+    {
+      // 4 coalesced partitions are expected.
+      val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4)
+      checkEstimation(
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices,
+        targetSize)
+    }
+
+    {
+      // 4 coalesced partitions are expected.
+      val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4)
+      checkEstimation(
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices,
+        targetSize)
+    }
+
+    {
+      // There are a few large shuffle partitions.
+      val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
+      checkEstimation(
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices,
+        targetSize)
+    }
+
+    {
+      // The combined size of each pair of shuffle partitions is at least the target size.
+      val bytesByPartitionId1 = Array[Long](100, 100, 40, 30, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
+      checkEstimation(
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices,
+        targetSize)
+    }
+  }
+
+  test("enforce minimal number of coalesced partitions") {
+    val targetSize = 100
+    val minNumPartitions = 2
+
+    {
+      // The minimal number of coalesced partitions is not enforced because
+      // the size of data is 0.
+      val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
+      val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
+      val expectedPartitionStartIndices = Array[Int](0)
+      checkEstimation(
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices,
+        targetSize, minNumPartitions)
+    }
+
+    {
+      // The minimal number of coalesced partitions is enforced.
+      val bytesByPartitionId1 = Array[Long](10, 5, 5, 0, 20)
+      val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5)
+      val expectedPartitionStartIndices = Array[Int](0, 3)
+      checkEstimation(
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices,
+        targetSize, minNumPartitions)
+    }
+
+    {
+      // The number of coalesced partitions is determined by the algorithm.
+      val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20)
+      val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4)
+      checkEstimation(
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices,
+        targetSize, minNumPartitions)
+    }
+  }
+}
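The expectations in `ShufflePartitionsCoalescerSuite` are consistent with a greedy left-to-right scan. A minimal Scala sketch, assuming that reading of the tests (it is not the `ShufflePartitionsCoalescer` source, and `CoalesceSketch.startIndices` is a hypothetical name): sum each partition index across all shuffles, then start a new coalesced partition whenever adding the next index would push the running total past the target size. For `minNumPartitions`, the target is assumed to shrink to about total size / `minNumPartitions`, with an assumed 16-byte floor so that all-zero input still yields a single partition.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical sketch, written to reproduce the suite's expected start indices.
object CoalesceSketch {
  /**
   * @param bytesByPartitionId one array per shuffle; element i is the size of partition i
   * @param targetSize         advisory size of one coalesced partition
   * @param minNumPartitions   lower bound, enforced (by assumption) by shrinking the target
   * @return the start index of each coalesced partition
   */
  def startIndices(
      bytesByPartitionId: Array[Array[Long]],
      targetSize: Long,
      minNumPartitions: Int = 1): Array[Int] = {
    val numPartitions = bytesByPartitionId.head.length
    // Shuffles coalesced together must agree on the number of partitions;
    // Predef.assert throws java.lang.AssertionError otherwise.
    assert(bytesByPartitionId.forall(_.length == numPartitions))

    // Size of partition index i summed over all shuffles.
    val combined = Array.tabulate(numPartitions)(i => bytesByPartitionId.map(_(i)).sum)

    // Assumed target adjustment: aim for at least minNumPartitions groups,
    // never dropping below 16 bytes so empty input stays one partition.
    val maxTargetSize =
      math.max(math.ceil(combined.sum / minNumPartitions.toDouble).toLong, 16L)
    val actualTargetSize = math.min(maxTargetSize, targetSize)

    val starts = ArrayBuffer(0)
    var currentSize = 0L
    for (i <- 0 until numPartitions) {
      // Split before index i if it would overflow the current group,
      // unless i is already the first index of that group.
      if (i > starts.last && currentSize + combined(i) > actualTargetSize) {
        starts += i
        currentSize = combined(i)
      } else {
        currentSize += combined(i)
      }
    }
    starts.toArray
  }

  def main(args: Array[String]): Unit = {
    println(startIndices(Array(Array[Long](10, 0, 90, 20, 0)), 100L).mkString(", ")) // 0, 3
  }
}
```

Under these assumptions the sketch returns the indices asserted in the suite, for example `0, 1, 2, 4` for the two-shuffle case whose per-index totals are `30, 100, 70, 30, 30`, and `0, 3` for the `minNumPartitions = 2` case, where the target shrinks from 100 to 35.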
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index a207190..4edb35e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -23,7 +23,7 @@ import java.net.URI
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildRight, SortMergeJoinExec}
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.internal.SQLConf
@@ -594,160 +594,84 @@ class AdaptiveQueryExecSuite
           .range(0, 1000, 1, 10)
           .selectExpr("id % 1 as key2", "id as value2")
           .createOrReplaceTempView("skewData2")
-        val (innerPlan, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
+        val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
           "SELECT key1 FROM skewData1 join skewData2 ON key1 = key2 group by key1")
-        val innerSmj = findTopLevelSortMergeJoin(innerPlan)
-        assert(innerSmj.size == 1)
         // Additional shuffle introduced, so disable the "OptimizeSkewedJoin" optimization
-        val innerSmjAfter = findTopLevelSortMergeJoin(innerAdaptivePlan)
-        assert(innerSmjAfter.size == 1)
+        val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan)
+        assert(innerSmj.size == 1 && !innerSmj.head.isSkewJoin)
       }
     }
   }
 
+  // TODO: we need a way to customize data distribution after shuffle, to improve test coverage
+  //       of this case.
   test("SPARK-29544: adaptive skew join with different join types") {
-    Seq("false", "true").foreach { reducePostShufflePartitionsEnabled =>
-      withSQLConf(
-        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-        SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD.key -> "100",
-        SQLConf.REDUCE_POST_SHUFFLE_PARTITIONS_ENABLED.key -> reducePostShufflePartitionsEnabled,
-        SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "700") {
-        withTempView("skewData1", "skewData2") {
-          spark
-            .range(0, 1000, 1, 10)
-            .selectExpr("id % 2 as key1", "id as value1")
-            .createOrReplaceTempView("skewData1")
-          spark
-            .range(0, 1000, 1, 10)
-            .selectExpr("id % 1 as key2", "id as value2")
-            .createOrReplaceTempView("skewData2")
-          // skewed inner join optimization
-          val (innerPlan, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
-            "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
-          val innerSmj = findTopLevelSortMergeJoin(innerPlan)
-          assert(innerSmj.size == 1)
-          // left stats: [3496, 0, 0, 0, 4014]
-          // right stats:[6292, 0, 0, 0, 0]
-          // the partition 0 in both left and right side are all skewed.
-          // And divide into 5 splits both in left and right (the max splits number).
-          // So there are 5 x 5 smjs for partition 0.
-          // Partition 4 in left side is skewed and is divided into 5 splits.
-          // The right side of partition 4 is not skewed.
-          // So there are 5 smjs for partition 4.
-          // So total (25 + 5 + 1) smjs.
-          // Union
-          // +- SortMergeJoin
-          //   +- Sort
-          //     +- CoalescedShuffleReader
-          //       +- ShuffleQueryStage
-          //   +- Sort
-          //     +- CoalescedShuffleReader
-          //       +- ShuffleQueryStage
-          // +- SortMergeJoin
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-          //             .
-          //             .
-          //             .
-          // +- SortMergeJoin
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-
-          val innerSmjAfter = findTopLevelSortMergeJoin(innerAdaptivePlan)
-          assert(innerSmjAfter.size == 31)
-
-          // skewed left outer join optimization
-          val (leftPlan, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
-            "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
-          val leftSmj = findTopLevelSortMergeJoin(leftPlan)
-          assert(leftSmj.size == 1)
-          // left stats: [3496, 0, 0, 0, 4014]
-          // right stats:[6292, 0, 0, 0, 0]
-          // The partition 0 in both left and right are all skewed.
-          // The partition 4 in left side is skewed.
-          // But for left outer join, we don't split the right partition even skewed.
-          // So the partition 0 in left side is divided into 5 splits(the max split number).
-          // the partition 4 in left side is divided into 5 splits(the max split number).
-          // So total (5 + 5 + 1) smjs.
-          // Union
-          // +- SortMergeJoin
-          //   +- Sort
-          //     +- CoalescedShuffleReader
-          //       +- ShuffleQueryStage
-          //   +- Sort
-          //     +- CoalescedShuffleReader
-          //       +- ShuffleQueryStage
-          // +- SortMergeJoin
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-          //             .
-          //             .
-          //             .
-          // +- SortMergeJoin
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-
-          val leftSmjAfter = findTopLevelSortMergeJoin(leftAdaptivePlan)
-          assert(leftSmjAfter.size == 11)
-
-          // skewed right outer join optimization
-          val (rightPlan, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
-            "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2")
-          val rightSmj = findTopLevelSortMergeJoin(rightPlan)
-          assert(rightSmj.size == 1)
-          // left stats: [3496, 0, 0, 0, 4014]
-          // right stats:[6292, 0, 0, 0, 0]
-          // The partition 0 in both left and right side are all skewed.
-          // And the partition 4 in left side is skewed.
-          // But for right outer join, we don't split the left partition even skewed.
-          // And divide right side into 5 splits(the max split number)
-          // So total 6 smjs.
-          // Union
-          // +- SortMergeJoin
-          //   +- Sort
-          //     +- CoalescedShuffleReader
-          //       +- ShuffleQueryStage
-          //   +- Sort
-          //     +- CoalescedShuffleReader
-          //       +- ShuffleQueryStage
-          // +- SortMergeJoin
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-          //             .
-          //             .
-          //             .
-          // +- SortMergeJoin
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-          //   +- Sort
-          //     +- SkewedShuffleReader
-          //       +- ShuffleQueryStage
-
-          val rightSmjAfter = findTopLevelSortMergeJoin(rightAdaptivePlan)
-          assert(rightSmjAfter.size == 6)
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD.key -> "100",
+      SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "700") {
+      withTempView("skewData1", "skewData2") {
+        spark
+          .range(0, 1000, 1, 10)
+          .selectExpr("id % 2 as key1", "id as value1")
+          .createOrReplaceTempView("skewData1")
+        spark
+          .range(0, 1000, 1, 10)
+          .selectExpr("id % 1 as key2", "id as value2")
+          .createOrReplaceTempView("skewData2")
+
+        def checkSkewJoin(joins: Seq[SortMergeJoinExec], expectedNumPartitions: Int): Unit = {
+          assert(joins.size == 1 && joins.head.isSkewJoin)
+          assert(joins.head.left.collect {
+            case r: SkewJoinShuffleReaderExec => r
+          }.head.partitionSpecs.length == expectedNumPartitions)
+          assert(joins.head.right.collect {
+            case r: SkewJoinShuffleReaderExec => r
+          }.head.partitionSpecs.length == expectedNumPartitions)
         }
+
+        // skewed inner join optimization
+        val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
+          "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
+        // left stats: [3496, 0, 0, 0, 4014]
+        // right stats:[6292, 0, 0, 0, 0]
+        // Partition 0: both left and right sides are skewed, and each is divided into 5 splits, so
+        //              5 x 5 sub-partitions.
+        // Partition 1, 2, 3: not skewed, and coalesced into 1 partition.
+        // Partition 4: only the left side is skewed, and it is divided into 5 splits, so
+        //              5 sub-partitions.
+        // So total (25 + 1 + 5) partitions.
+        val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan)
+        checkSkewJoin(innerSmj, 25 + 1 + 5)
+
+        // skewed left outer join optimization
+        val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
+          "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
+        // left stats: [3496, 0, 0, 0, 4014]
+        // right stats:[6292, 0, 0, 0, 0]
+        // Partition 0: both left and right sides are skewed, but a left join can't split the right side,
+        //              so only the left side is divided into 5 splits, and thus 5 sub-partitions.
+        // Partition 1, 2, 3: not skewed, and coalesced into 1 partition.
+        // Partition 4: only the left side is skewed, and it is divided into 5 splits, so
+        //              5 sub-partitions.
+        // So total (5 + 1 + 5) partitions.
+        val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan)
+        checkSkewJoin(leftSmj, 5 + 1 + 5)
+
+        // skewed right outer join optimization
+        val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
+          "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2")
+        // left stats: [3496, 0, 0, 0, 4014]
+        // right stats:[6292, 0, 0, 0, 0]
+        // Partition 0: both left and right sides are skewed, but a right join can't split the left side,
+        //              so only the right side is divided into 5 splits, and thus 5 sub-partitions.
+        // Partition 1, 2, 3: not skewed, and coalesced into 1 partition.
+        // Partition 4: only the left side is skewed, but a right join can't split the left side, so just
+        //              1 partition.
+        // So total (5 + 1 + 1) partitions.
+        val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan)
+        checkSkewJoin(rightSmj, 5 + 1 + 1)
       }
     }
   }
@@ -805,3 +729,4 @@ class AdaptiveQueryExecSuite
         s" enabled but is not supported for")))
   }
 }
+
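The partition counts asserted by `checkSkewJoin` in the rewritten `SPARK-29544` test follow from simple arithmetic: within one skewed partition index, every left split is joined with every right split, so the split counts multiply, while an outer join never splits its non-preserved side (a left outer join keeps the right side whole, a right outer join keeps the left side whole). The Scala sketch below reproduces the `25 + 1 + 5`, `5 + 1 + 5` and `5 + 1 + 1` totals. `SkewJoinCountSketch` and its members are hypothetical names, and the split counts and the single coalesced group for partitions 1 to 3 are taken from the test comments rather than computed from shuffle statistics.

```scala
// Hypothetical sketch: reproduces the sub-partition arithmetic in the test comments.
object SkewJoinCountSketch {
  sealed trait JoinKind
  case object Inner extends JoinKind
  case object LeftOuter extends JoinKind
  case object RightOuter extends JoinKind

  /** (canSplitLeft, canSplitRight), following the test comments. */
  def splittableSides(kind: JoinKind): (Boolean, Boolean) = kind match {
    case Inner      => (true, true)
    case LeftOuter  => (true, false) // the right side is never split
    case RightOuter => (false, true) // the left side is never split
  }

  /** Sub-partitions for one skewed index: left splits x right splits. */
  def subPartitions(kind: JoinKind, leftSplits: Int, rightSplits: Int): Int = {
    val (canSplitLeft, canSplitRight) = splittableSides(kind)
    val l = if (canSplitLeft) leftSplits else 1
    val r = if (canSplitRight) rightSplits else 1
    l * r
  }

  /** Sum over skewed indices, plus the coalesced groups of non-skewed indices. */
  def totalPartitions(kind: JoinKind, skewedSplits: Seq[(Int, Int)], nonSkewGroups: Int): Int =
    skewedSplits.map { case (l, r) => subPartitions(kind, l, r) }.sum + nonSkewGroups

  def main(args: Array[String]): Unit = {
    // Partition 0: 5 splits per side; partition 4: 5 left splits, right not skewed (1).
    // Partitions 1-3 coalesce into a single group.
    val splits = Seq((5, 5), (5, 1))
    println(totalPartitions(Inner, splits, 1))      // 31
    println(totalPartitions(LeftOuter, splits, 1))  // 11
    println(totalPartitions(RightOuter, splits, 1)) // 7
  }
}
```

Multiplying rather than adding the split counts reflects that the skew join pairs each left split with each right split of the same partition index; the outer-join restriction exists because splitting the non-preserved side could make a preserved row look unmatched in some sub-joins.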

