Repository: spark
Updated Branches:
  refs/heads/master c34c27fe9 -> d728d5c98


[SPARK-9858][SPARK-9859][SPARK-9861][SQL] Add an ExchangeCoordinator to 
estimate the number of post-shuffle partitions for aggregates and joins

https://issues.apache.org/jira/browse/SPARK-9858
https://issues.apache.org/jira/browse/SPARK-9859
https://issues.apache.org/jira/browse/SPARK-9861

Author: Yin Huai <[email protected]>

Closes #9276 from yhuai/numReducer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d728d5c9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d728d5c9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d728d5c9

Branch: refs/heads/master
Commit: d728d5c98658c44ed2949b55d36edeaa46f8c980
Parents: c34c27f
Author: Yin Huai <[email protected]>
Authored: Tue Nov 3 00:12:49 2015 -0800
Committer: Yin Huai <[email protected]>
Committed: Tue Nov 3 00:12:49 2015 -0800

----------------------------------------------------------------------
 .../catalyst/plans/physical/partitioning.scala  |   8 +
 .../scala/org/apache/spark/sql/SQLConf.scala    |  27 ++
 .../apache/spark/sql/execution/Exchange.scala   | 217 ++++++++-
 .../sql/execution/ExchangeCoordinator.scala     | 260 ++++++++++
 .../spark/sql/execution/ShuffledRowRDD.scala    | 134 +++++-
 .../execution/ExchangeCoordinatorSuite.scala    | 479 +++++++++++++++++++
 .../spark/sql/execution/PlannerSuite.scala      |   8 +-
 .../execution/UnsafeRowSerializerSuite.scala    |   7 +-
 .../sql/execution/joins/InnerJoinSuite.scala    |  19 +-
 9 files changed, 1115 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d728d5c9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 86b9417..9312c81 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -165,6 +165,11 @@ sealed trait Partitioning {
    * produced by `A` could have also been produced by `B`.
    */
   def guarantees(other: Partitioning): Boolean = this == other
+
+  /**
+   * Returns this partitioning with its number of partitions changed to `newNumPartitions`.
+   * The base implementation does not support resizing and always throws an
+   * [[IllegalStateException]]; partitionings that can be resized override it.
+   */
+  def withNumPartitions(newNumPartitions: Int): Partitioning = {
+    val partitioningName = this.getClass.getSimpleName
+    throw new IllegalStateException(
+      s"It is not allowed to call withNumPartitions method of a $partitioningName")
+  }
 }
 
 object Partitioning {
@@ -249,6 +254,9 @@ case class HashPartitioning(expressions: Seq[Expression], 
numPartitions: Int)
     case _ => false
   }
 
+  /**
+   * Returns a [[HashPartitioning]] over the same `expressions` that produces
+   * `newNumPartitions` partitions.
+   */
+  override def withNumPartitions(newNumPartitions: Int): HashPartitioning = {
+    copy(numPartitions = newNumPartitions)
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/d728d5c9/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 6f28920..ed8b634 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -233,6 +233,25 @@ private[spark] object SQLConf {
     defaultValue = Some(200),
     doc = "The default number of partitions to use when shuffling data for 
joins or aggregations.")
 
+  val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE =
+    longConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize",
+      defaultValue = Some(64 * 1024 * 1024),
+      doc = "The target post-shuffle input size in bytes of a task.")
+
+  val ADAPTIVE_EXECUTION_ENABLED = booleanConf("spark.sql.adaptive.enabled",
+    defaultValue = Some(false),
+    doc = "When true, enable adaptive query execution.")
+
+  // Internal setting (isPublic = false). The default of -1, like any non-positive value,
+  // means that no minimum is passed to ExchangeCoordinator.
+  val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS =
+    intConf("spark.sql.adaptive.minNumPostShufflePartitions",
+      defaultValue = Some(-1),
+      doc = "The advisory minimal number of post-shuffle partitions provided to " +
+        "ExchangeCoordinator. This setting is used in our test to make sure we " +
+        "have enough parallelism to expose issues that will not be exposed with a " +
+        "single partition. When the value is a non-positive value, this setting will " +
+        "not be provided to ExchangeCoordinator.",
+      isPublic = false)
+
   val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled",
     defaultValue = Some(true),
     doc = "When true, use the optimized Tungsten physical execution backend 
which explicitly " +
@@ -487,6 +506,14 @@ private[sql] class SQLConf extends Serializable with 
CatalystConf {
 
   private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
 
+  /** The target post-shuffle input size, in bytes, of a task. */
+  private[spark] def targetPostShuffleInputSize: Long =
+    getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE)
+
+  /** Whether adaptive query execution is enabled. */
+  private[spark] def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)
+
+  /** The advisory minimal number of post-shuffle partitions; non-positive means unset. */
+  private[spark] def minNumPostShufflePartitions: Int =
+    getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS)
+
   private[spark] def parquetFilterPushDown: Boolean = 
getConf(PARQUET_FILTER_PUSHDOWN_ENABLED)
 
   private[spark] def orcFilterPushDown: Boolean = 
getConf(ORC_FILTER_PUSHDOWN_ENABLED)

http://git-wip-us.apache.org/repos/asf/spark/blob/d728d5c9/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index e81108b..0f72ec6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -36,9 +36,23 @@ import org.apache.spark.util.MutablePair
 /**
  * Performs a shuffle that will result in the desired `newPartitioning`.
  */
-case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends 
UnaryNode {
+case class Exchange(
+    var newPartitioning: Partitioning,
+    child: SparkPlan,
+    @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode {
 
-  override def nodeName: String = if (tungstenMode) "TungstenExchange" else 
"Exchange"
+  /**
+   * The name shown for this operator: `TungstenExchange` or `Exchange`, followed by
+   * whether the shuffle has a coordinator and, if so, whether that coordinator has
+   * already decided how to shuffle data.
+   */
+  override def nodeName: String = {
+    // Read `isEstimated` exactly once. It is backed by a @volatile flag that another thread
+    // may set concurrently. Testing it in two separate guards (`if isEstimated` and
+    // `if !isEstimated`) could see it flip between the guards, fail both, and throw a
+    // MatchError.
+    val extraInfo = coordinator match {
+      case Some(exchangeCoordinator) =>
+        if (exchangeCoordinator.isEstimated) "Shuffle" else "May shuffle"
+      case None => "Shuffle without coordinator"
+    }
+
+    val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange"
+    s"$simpleNodeName($extraInfo)"
+  }
 
   /**
    * Returns true iff we can support the data type, and we are not doing range 
partitioning.
@@ -129,7 +143,27 @@ case class Exchange(newPartitioning: Partitioning, child: 
SparkPlan) extends Una
     }
   }
 
-  protected override def doExecute(): RDD[InternalRow] = attachTree(this , 
"execute") {
+  /**
+   * Registers this Exchange with its [[ExchangeCoordinator]], if it has one.
+   */
+  override protected def doPrepare(): Unit = {
+    // Registration happens here, just before execution, rather than in the constructor.
+    // Transforming the physical plan may create new Exchange instances, and registering
+    // in the constructor would leave the ExchangeCoordinator holding references to
+    // Exchanges that are never executed.
+    coordinator.foreach(_.registerExchange(this))
+  }
+
+  /**
+   * Returns a [[ShuffleDependency]] that will partition rows of its child 
based on
+   * the partitioning scheme defined in `newPartitioning`. Those partitions of
+   * the returned ShuffleDependency will be the input of shuffle.
+   */
+  private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, 
InternalRow, InternalRow] = {
     val rdd = child.execute()
     val part: Partitioner = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) => new 
HashPartitioner(numPartitions)
@@ -181,7 +215,54 @@ case class Exchange(newPartitioning: Partitioning, child: 
SparkPlan) extends Una
         }
       }
     }
-    new ShuffledRowRDD(rddWithPartitionIds, serializer, part.numPartitions)
+
+    // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds
+    // are in the form of (partitionId, row) and every partitionId is already in the
+    // expected range [0, part.numPartitions - 1], the partitioner of this dependency
+    // is a PartitionIdPassthrough.
+    val dependency =
+      new ShuffleDependency[Int, InternalRow, InternalRow](
+        rddWithPartitionIds,
+        new PartitionIdPassthrough(part.numPartitions),
+        Some(serializer))
+
+    dependency
+  }
+
+  /**
+   * Builds the post-shuffle [[ShuffledRowRDD]] that reads the output of `shuffleDependency`.
+   * When `specifiedPartitionStartIndices` is given, the returned [[ShuffledRowRDD]] fetches
+   * pre-shuffle partitions grouped by those start indices. In that case `newPartitioning` is
+   * also updated so that its number of partitions equals the number of start indices.
+   */
+  private[sql] def preparePostShuffleRDD(
+      shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
+      specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
+    specifiedPartitionStartIndices match {
+      case Some(startIndices) =>
+        // The coordinator only resizes HashPartitionings.
+        assert(newPartitioning.isInstanceOf[HashPartitioning])
+        newPartitioning = newPartitioning.withNumPartitions(startIndices.length)
+      case None => // Keep newPartitioning as planned.
+    }
+    new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
+  }
+
+  /**
+   * Produces the post-shuffle RDD. Without a coordinator, the shuffle is prepared directly.
+   * With one, the RDD is obtained from the coordinator.
+   */
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    coordinator.fold[RDD[InternalRow]] {
+      // Without a coordinator, shuffle using newPartitioning as planned.
+      preparePostShuffleRDD(prepareShuffleDependency())
+    } { exchangeCoordinator =>
+      // The coordinator may have resized newPartitioning; the RDD must agree with it.
+      val postShuffleRDD = exchangeCoordinator.postShuffleRDD(this)
+      assert(postShuffleRDD.partitions.length == newPartitioning.numPartitions)
+      postShuffleRDD
+    }
+  }
+}
+
+object Exchange {
+  /** Creates an [[Exchange]] that does not use an [[ExchangeCoordinator]]. */
+  def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
+    Exchange(newPartitioning, child, Option.empty[ExchangeCoordinator])
   }
 }
 
@@ -193,13 +274,22 @@ case class Exchange(newPartitioning: Partitioning, child: 
SparkPlan) extends Una
  * input partition ordering requirements are met.
  */
 private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends 
Rule[SparkPlan] {
-  // TODO: Determine the number of partitions.
-  private def defaultPartitions: Int = sqlContext.conf.numShufflePartitions
+  /** Number of partitions an Exchange shuffles into before any coordination. */
+  private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions
+
+  private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize
+
+  private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled
+
+  /** The configured minimum, or None when the configured value is non-positive. */
+  private def minNumPostShufflePartitions: Option[Int] =
+    Some(sqlContext.conf.minNumPostShufflePartitions).filter(_ > 0)
 
   /**
    * Given a required distribution, returns a partitioning that satisfies that 
distribution.
    */
-  private def createPartitioning(requiredDistribution: Distribution,
+  private def createPartitioning(
+      requiredDistribution: Distribution,
       numPartitions: Int): Partitioning = {
     requiredDistribution match {
       case AllTuples => SinglePartition
@@ -209,6 +299,98 @@ private[sql] case class EnsureRequirements(sqlContext: 
SQLContext) extends Rule[
     }
   }
 
+  /**
+   * Adds [[ExchangeCoordinator]] to [[Exchange]]s if adaptive query execution 
is enabled
+   * and partitioning schemes of these [[Exchange]]s support 
[[ExchangeCoordinator]].
+   */
+  private def withExchangeCoordinator(
+      children: Seq[SparkPlan],
+      requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = {
+    val supportsCoordinator =
+      if (children.exists(_.isInstanceOf[Exchange])) {
+        // Right now, ExchangeCoordinator only supports HashPartitionings.
+        children.forall {
+          case e @ Exchange(hash: HashPartitioning, _, _) => true
+          case child =>
+            child.outputPartitioning match {
+              case hash: HashPartitioning => true
+              case collection: PartitioningCollection =>
+                
collection.partitionings.exists(_.isInstanceOf[HashPartitioning])
+              case _ => false
+            }
+        }
+      } else {
+        // In this case, although we do not have Exchange operators, we may 
still need to
+        // shuffle data when we have more than one children because data 
generated by
+        // these children may not be partitioned in the same way.
+        // See the comment in the non-Exchange case of `withCoordinator` below for details.
+        val supportsDistribution =
+          
requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
+        children.length > 1 && supportsDistribution
+      }
+
+    val withCoordinator =
+      if (adaptiveExecutionEnabled && supportsCoordinator) {
+        // All returned Exchanges (existing or newly added) share this single coordinator.
+        val coordinator =
+          new ExchangeCoordinator(
+            children.length,
+            targetPostShuffleInputSize,
+            minNumPostShufflePartitions)
+        children.zip(requiredChildDistributions).map {
+          case (e: Exchange, _) =>
+            // This child is already an Exchange, so we only need to add the coordinator.
+            e.copy(coordinator = Some(coordinator))
+          case (child, distribution) =>
+            // If this child is not an Exchange, we need to add an Exchange 
for now.
+            // Ideally, we can try to avoid this Exchange. However, when we 
reach here,
+            // there are at least two children operators (because if there is 
a single child
+            // and we can avoid Exchange, supportsCoordinator will be false 
and we
+            // will not reach here.). Although we can make two children have 
the same number of
+            // post-shuffle partitions, their numbers of pre-shuffle 
partitions may be different.
+            // For example, let's say we have the following plan
+            //         Join
+            //         /  \
+            //       Agg  Exchange
+            //       /      \
+            //    Exchange  t2
+            //      /
+            //     t1
+            // In this case, because a post-shuffle partition can include 
multiple pre-shuffle
+            // partitions, a HashPartitioning will not be strictly partitioned 
by the hashcodes
+            // after shuffle. So, even if we can use the child Exchange operator 
of the Join to
+            // have a number of post-shuffle partitions that matches the 
number of partitions of
+            // Agg, we cannot say these two children are partitioned in the 
same way.
+            // Here is another case
+            //         Join
+            //         /  \
+            //       Agg1  Agg2
+            //       /      \
+            //   Exchange1  Exchange2
+            //       /       \
+            //      t1       t2
+            // In this case, two Aggs shuffle data with the same column of the 
join condition.
+            // After we use ExchangeCoordinator, these two Aggs may not be 
partitioned in the same
+            // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle 
partitions and 2
+            // post-shuffle partitions. It is possible that Agg1 fetches those 
pre-shuffle
+            // partitions by using a partitionStartIndices [0, 3]. However, 
Agg2 may fetch its
+            // pre-shuffle partitions by using another partitionStartIndices 
[0, 4].
+            // So, Agg1 and Agg2 are actually not co-partitioned.
+            //
+            // It would be great to introduce a new Partitioning to represent 
the post-shuffle
+            // partitions when one post-shuffle partition includes multiple 
pre-shuffle partitions.
+            val targetPartitioning =
+              createPartitioning(distribution, defaultNumPreShufflePartitions)
+            assert(targetPartitioning.isInstanceOf[HashPartitioning])
+            Exchange(targetPartitioning, child, Some(coordinator))
+        }
+      } else {
+        // If we do not need ExchangeCoordinator, the original children are 
returned.
+        children
+      }
+
+    withCoordinator
+  }
+
   private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = 
operator.requiredChildDistribution
     val requiredChildOrderings: Seq[Seq[SortOrder]] = 
operator.requiredChildOrdering
@@ -221,7 +403,7 @@ private[sql] case class EnsureRequirements(sqlContext: 
SQLContext) extends Rule[
       if (child.outputPartitioning.satisfies(distribution)) {
         child
       } else {
-        Exchange(createPartitioning(distribution, defaultPartitions), child)
+        Exchange(createPartitioning(distribution, 
defaultNumPreShufflePartitions), child)
       }
     }
 
@@ -234,7 +416,7 @@ private[sql] case class EnsureRequirements(sqlContext: 
SQLContext) extends Rule[
       // First check if the existing partitions of the children all match. 
This means they are
       // partitioned by the same partitioning into the same number of 
partitions. In that case,
       // don't try to make them match `defaultPartitions`, just use the 
existing partitioning.
-      // TODO: this should be a cost based descision. For example, a big 
relation should probably
+      // TODO: this should be a cost based decision. For example, a big 
relation should probably
       // maintain its existing number of partitions and smaller partitions 
should be shuffled.
       // defaultPartitions is arbitrary.
       val numPartitions = children.head.outputPartitioning.numPartitions
@@ -250,7 +432,8 @@ private[sql] case class EnsureRequirements(sqlContext: 
SQLContext) extends Rule[
       } else {
         children.zip(requiredChildDistributions).map {
           case (child, distribution) => {
-            val targetPartitioning = createPartitioning(distribution, 
defaultPartitions)
+            val targetPartitioning =
+              createPartitioning(distribution, defaultNumPreShufflePartitions)
             if (child.outputPartitioning.guarantees(targetPartitioning)) {
               child
             } else {
@@ -261,12 +444,24 @@ private[sql] case class EnsureRequirements(sqlContext: 
SQLContext) extends Rule[
       }
     }
 
+    // Now, we need to add ExchangeCoordinators if necessary.
+    // Ideally, we would not add ExchangeCoordinators while we are adding Exchanges.
+    // However, with the way that we plan the query, we do not have a place where we have
+    // a global picture of all shuffle dependencies of a post-shuffle stage. So, we add the
+    // coordinator here for now.
+    // Once we finish https://issues.apache.org/jira/browse/SPARK-10665,
+    // we can first add Exchanges and then add coordinators once we have a DAG of query
+    // fragments.
+    children = withExchangeCoordinator(children, requiredChildDistributions)
+
     // Now that we've performed any necessary shuffles, add sorts to guarantee 
output orderings:
     children = children.zip(requiredChildOrderings).map { case (child, 
requiredOrdering) =>
       if (requiredOrdering.nonEmpty) {
         // If child.outputOrdering is [a, b] and requiredOrdering is [a], we 
do not need to sort.
         if (requiredOrdering != 
child.outputOrdering.take(requiredOrdering.length)) {
-          sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, 
global = false, child)
+          sqlContext.planner.BasicOperators.getSortOperator(
+            requiredOrdering,
+            global = false,
+            child)
         } else {
           child
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/d728d5c9/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
new file mode 100644
index 0000000..8dbd69e
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.{Map => JMap, HashMap => JHashMap}
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Logging, SimpleFutureAction, ShuffleDependency, 
MapOutputStatistics}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * A coordinator used to determine how we shuffle data between stages generated by Spark SQL.
+ * Right now, the work of this coordinator is to determine the number of 
post-shuffle partitions
+ * for a stage that needs to fetch shuffle data from one or multiple stages.
+ *
+ * A coordinator is constructed with three parameters, `numExchanges`,
+ * `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`.
+ *  - `numExchanges` indicates how many [[Exchange]]s will be registered to this
+ *    coordinator. So, when we start to do any actual work, we have a way to make sure
+ *    that we have got the expected number of [[Exchange]]s.
+ *  - `targetPostShuffleInputSize` is the targeted size of a post-shuffle 
partition's
+ *    input data size. With this parameter, we can estimate the number of 
post-shuffle partitions.
+ *    This parameter is configured through
+ *    `spark.sql.adaptive.shuffle.targetPostShuffleInputSize`.
+ *  - `minNumPostShufflePartitions` is an optional parameter. If it is 
defined, this coordinator
+ *    will try to make sure that there are at least 
`minNumPostShufflePartitions` post-shuffle
+ *    partitions.
+ *
+ * The workflow of this coordinator is described as follows:
+ *  - Before the execution of a [[SparkPlan]], for an [[Exchange]] operator,
+ *    if an [[ExchangeCoordinator]] is assigned to it, it registers itself to 
this coordinator.
+ *    This happens in the `doPrepare` method.
+ *  - Once we start to execute a physical plan, an [[Exchange]] registered to 
this coordinator will
+ *    call `postShuffleRDD` to get its corresponding post-shuffle 
[[ShuffledRowRDD]].
+ *    If this coordinator has made the decision on how to shuffle data, this 
[[Exchange]] will
+ *    immediately get its corresponding post-shuffle [[ShuffledRowRDD]].
+ *  - If this coordinator has not made the decision on how to shuffle data, it will ask
+ *    those registered [[Exchange]]s to submit their pre-shuffle stages. Then, based on the
+ *    size statistics of pre-shuffle partitions, this coordinator will determine the number
+ *    of post-shuffle partitions and pack multiple pre-shuffle partitions with continuous
+ *    indices into a single post-shuffle partition whenever necessary.
+ *  - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s 
for all registered
+ *    [[Exchange]]s. So, when an [[Exchange]] calls `postShuffleRDD`, this 
coordinator can
+ *    lookup the corresponding [[RDD]].
+ *
+ * The strategy used to determine the number of post-shuffle partitions is 
described as follows.
+ * To determine the number of post-shuffle partitions, we have a target input 
size for a
+ * post-shuffle partition. Once we have size statistics of pre-shuffle 
partitions from stages
+ * corresponding to the registered [[Exchange]]s, we will do a pass of those 
statistics and
+ * pack pre-shuffle partitions with continuous indices to a single 
post-shuffle partition until
+ * the size of a post-shuffle partition is equal or greater than the target 
size.
+ * For example, we have two stages with the following pre-shuffle partition 
size statistics:
+ * stage 1: [100 MB, 20 MB, 100 MB, 10MB, 30 MB]
+ * stage 2: [10 MB,  10 MB, 70 MB,  5 MB, 5 MB]
+ * assuming the target input size is 128 MB, we will have three post-shuffle 
partitions,
+ * which are:
+ *  - post-shuffle partition 0: pre-shuffle partition 0 and 1
+ *  - post-shuffle partition 1: pre-shuffle partition 2
+ *  - post-shuffle partition 2: pre-shuffle partition 3 and 4
+ */
+private[sql] class ExchangeCoordinator(
+    numExchanges: Int,
+    advisoryTargetPostShuffleInputSize: Long,
+    minNumPostShufflePartitions: Option[Int] = None)
+  extends Logging {
+
+  // The registered Exchange operators.
+  private[this] val exchanges = ArrayBuffer[Exchange]()
+
+  // This map is used to lookup the post-shuffle ShuffledRowRDD for an 
Exchange operator.
+  private[this] val postShuffleRDDs: JMap[Exchange, ShuffledRowRDD] =
+    new JHashMap[Exchange, ShuffledRowRDD](numExchanges)
+
+  // A boolean that indicates if this coordinator has made a decision on how to 
shuffle data.
+  // This variable will only be updated by doEstimationIfNecessary, which is 
protected by
+  // synchronized.
+  @volatile private[this] var estimated: Boolean = false
+
+  /**
+   * Registers an [[Exchange]] operator to this coordinator. This method is 
only allowed to be
+   * called in the `doPrepare` method of an [[Exchange]] operator.
+   */
+  def registerExchange(exchange: Exchange): Unit = synchronized {
+    exchanges += exchange
+  }
+
+  /** Returns true once this coordinator has decided how to shuffle data. */
+  def isEstimated: Boolean = estimated
+
+  /**
+   * Estimates partition start indices for post-shuffle partitions based on
+   * mapOutputStatistics provided by all pre-shuffle stages.
+   */
+  private[sql] def estimatePartitionStartIndices(
+      mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = {
+    // If we have mapOutputStatistics.length < numExchanges, it is because we 
do not submit
+    // a stage when the number of partitions of a dependency is 0.
+    assert(mapOutputStatistics.length <= numExchanges)
+
+    // If minNumPostShufflePartitions is defined, it is possible that we need 
to use a
+    // value less than advisoryTargetPostShuffleInputSize as the target input 
size of
+    // a post shuffle task.
+    val targetPostShuffleInputSize = minNumPostShufflePartitions match {
+      case Some(numPartitions) =>
+        val totalPostShuffleInputSize = 
mapOutputStatistics.map(_.bytesByPartitionId.sum).sum
+        // The max here is to make sure that when we have an empty table, we
+        // only have a single post-shuffle partition.
+        val maxPostShuffleInputSize =
+          math.max(math.ceil(totalPostShuffleInputSize / 
numPartitions.toDouble).toLong, 16)
+        math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize)
+
+      case None => advisoryTargetPostShuffleInputSize
+    }
+
+    logInfo(
+      s"advisoryTargetPostShuffleInputSize: 
$advisoryTargetPostShuffleInputSize, " +
+      s"targetPostShuffleInputSize $targetPostShuffleInputSize.")
+
+    // Make sure we do get the same number of pre-shuffle partitions for those 
stages.
+    val distinctNumPreShufflePartitions =
+      mapOutputStatistics.map(stats => 
stats.bytesByPartitionId.length).distinct
+    assert(
+      distinctNumPreShufflePartitions.length == 1,
+      "There should be only one distinct value of the number pre-shuffle 
partitions " +
+        "among registered Exchange operator.")
+    val numPreShufflePartitions = distinctNumPreShufflePartitions.head
+
+    val partitionStartIndices = ArrayBuffer[Int]()
+    // The first element of partitionStartIndices is always 0.
+    partitionStartIndices += 0
+
+    var postShuffleInputSize = 0L
+
+    var i = 0
+    while (i < numPreShufflePartitions) {
+      // We calculate the total size of ith pre-shuffle partitions from all 
pre-shuffle stages.
+      // Then, we add the total size to postShuffleInputSize.
+      var j = 0
+      while (j < mapOutputStatistics.length) {
+        postShuffleInputSize += mapOutputStatistics(j).bytesByPartitionId(i)
+        j += 1
+      }
+
+      // If the current postShuffleInputSize is equal or greater than the
+      // targetPostShuffleInputSize, we need to add a new element in 
partitionStartIndices.
+      if (postShuffleInputSize >= targetPostShuffleInputSize) {
+        if (i < numPreShufflePartitions - 1) {
+          // Next start index.
+          partitionStartIndices += i + 1
+        } else {
+          // This is the last element. So, we do not need to append the next 
start index to
+          // partitionStartIndices.
+        }
+        // reset postShuffleInputSize.
+        postShuffleInputSize = 0L
+      }
+
+      i += 1
+    }
+
+    partitionStartIndices.toArray
+  }
+
+  /**
+   * On the first call, submits the map stages of all registered [[Exchange]]s, waits for
+   * their MapOutputStatistics, estimates the post-shuffle partition start indices, and
+   * creates one post-shuffle [[ShuffledRowRDD]] per [[Exchange]]. Later calls do nothing.
+   */
+  private def doEstimationIfNecessary(): Unit = synchronized {
+    // It is unlikely that this method will be called from multiple threads
+    // (when multiple threads trigger the execution of THIS physical plan)
+    // because in common use cases, we will create a new physical plan after
+    // users apply operations (e.g. projection) to an existing DataFrame.
+    // However, if it happens, we have synchronized to make sure only one
+    // thread will trigger the job submission.
+    if (!estimated) {
+      // Make sure we have the expected number of registered Exchange 
operators.
+      assert(exchanges.length == numExchanges)
+
+      val newPostShuffleRDDs = new JHashMap[Exchange, 
ShuffledRowRDD](numExchanges)
+
+      // Submit all map stages
+      val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, 
InternalRow, InternalRow]]()
+      val submittedStageFutures = 
ArrayBuffer[SimpleFutureAction[MapOutputStatistics]]()
+      var i = 0
+      while (i < numExchanges) {
+        val exchange = exchanges(i)
+        val shuffleDependency = exchange.prepareShuffleDependency()
+        shuffleDependencies += shuffleDependency
+        if (shuffleDependency.rdd.partitions.length != 0) {
+          // submitMapStage does not accept RDD with 0 partition.
+          // So, we will not submit this dependency.
+          submittedStageFutures +=
+            exchange.sqlContext.sparkContext.submitMapStage(shuffleDependency)
+        }
+        i += 1
+      }
+
+      // Wait for those submitted map stages to finish.
+      val mapOutputStatistics = new 
Array[MapOutputStatistics](submittedStageFutures.length)
+      i = 0
+      while (i < submittedStageFutures.length) {
+        // This call is a blocking call. If the stage has not finished, we 
will wait at here.
+        mapOutputStatistics(i) = submittedStageFutures(i).get()
+        i += 1
+      }
+
+      // Now, we estimate partitionStartIndices. partitionStartIndices.length 
will be the
+      // number of post-shuffle partitions.
+      val partitionStartIndices =
+        if (mapOutputStatistics.length == 0) {
+          None
+        } else {
+          Some(estimatePartitionStartIndices(mapOutputStatistics))
+        }
+
+      i = 0
+      while (i < numExchanges) {
+        val exchange = exchanges(i)
+        val rdd =
+          exchange.preparePostShuffleRDD(shuffleDependencies(i), 
partitionStartIndices)
+        newPostShuffleRDDs.put(exchange, rdd)
+
+        i += 1
+      }
+
+      // Finally, we set postShuffleRDDs and estimated.
+      assert(postShuffleRDDs.isEmpty)
+      assert(newPostShuffleRDDs.size() == numExchanges)
+      postShuffleRDDs.putAll(newPostShuffleRDDs)
+      estimated = true
+    }
+  }
+
+  /**
+   * Returns the post-shuffle [[ShuffledRowRDD]] for `exchange`, running the estimation first
+   * if it has not happened yet. Throws an IllegalStateException if no RDD was created for
+   * `exchange`, i.e. it was not registered to this coordinator.
+   */
+  def postShuffleRDD(exchange: Exchange): ShuffledRowRDD = {
+    doEstimationIfNecessary()
+
+    if (!postShuffleRDDs.containsKey(exchange)) {
+      throw new IllegalStateException(
+        s"The given $exchange is not registered in this coordinator.")
+    }
+
+    postShuffleRDDs.get(exchange)
+  }
+
+  override def toString: String = {
+    s"coordinator[target post-shuffle partition size: 
$advisoryTargetPostShuffleInputSize]"
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d728d5c9/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index fb338b9..4289128 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -17,14 +17,23 @@
 
 package org.apache.spark.sql.execution
 
+import java.util.Arrays
+
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.catalyst.InternalRow
 
-private class ShuffledRowRDDPartition(val idx: Int) extends Partition {
-  override val index: Int = idx
-  override def hashCode(): Int = idx
+/**
+ * The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition
+ * (identified by `postShufflePartitionIndex`) contains a range of pre-shuffle partitions
+ * (`startPreShufflePartitionIndex` to `endPreShufflePartitionIndex - 1`, inclusive).
+ *
+ * @param postShufflePartitionIndex the index of this partition within the [[ShuffledRowRDD]];
+ *   also used as this partition's `index` and as its hash code.
+ * @param startPreShufflePartitionIndex the first pre-shuffle partition to read (inclusive).
+ * @param endPreShufflePartitionIndex the end of the pre-shuffle range to read (exclusive).
+ */
+private final class ShuffledRowRDDPartition(
+    val postShufflePartitionIndex: Int,
+    val startPreShufflePartitionIndex: Int,
+    val endPreShufflePartitionIndex: Int) extends Partition {
+  override val index: Int = postShufflePartitionIndex
+  override def hashCode(): Int = postShufflePartitionIndex
 }
 
 /**
@@ -36,32 +45,106 @@ private class PartitionIdPassthrough(override val 
numPartitions: Int) extends Pa
 }
 
 /**
+ * A [[Partitioner]] that might group together one or more partitions from the parent.
+ *
+ * @param parent a parent partitioner
+ * @param partitionStartIndices indices of partitions in parent that should create new partitions
+ *   in child. This must be a strictly increasing array of partition IDs starting at 0 (it may
+ *   only be empty when `parent` has no partitions), and no index may exceed
+ *   `parent.numPartitions`. For example, if we have a parent with 5 partitions, and
+ *   partitionStartIndices is [0, 2, 4], we get three output partitions, corresponding to
+ *   partition ranges [0, 1], [2, 3] and [4] of the parent partitioner.
+ * @throws IllegalArgumentException if `partitionStartIndices` does not start at 0 or is not
+ *   strictly increasing (checked at construction), or does not fit `parent` (checked on the
+ *   first call to `getPartition`).
+ */
+class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: Array[Int])
+  extends Partitioner {
+
+  // Fail fast on malformed indices. Otherwise parent partitions would silently be assigned to
+  // the wrong child partition by parentPartitionMapping.
+  require(
+    partitionStartIndices.isEmpty || partitionStartIndices(0) == 0,
+    "partitionStartIndices must start at 0, but got " +
+      partitionStartIndices.mkString("[", ", ", "]"))
+  require(
+    (1 until partitionStartIndices.length).forall { i =>
+      partitionStartIndices(i - 1) < partitionStartIndices(i)
+    },
+    "partitionStartIndices must be strictly increasing, but got " +
+      partitionStartIndices.mkString("[", ", ", "]"))
+
+  // parentPartitionMapping(p) is the child partition that parent partition p is coalesced into.
+  // Being @transient and lazy, it is rebuilt on first use after deserialization.
+  @transient private lazy val parentPartitionMapping: Array[Int] = {
+    val n = parent.numPartitions
+    require(
+      if (partitionStartIndices.isEmpty) n == 0 else partitionStartIndices.last <= n,
+      s"partitionStartIndices ${partitionStartIndices.mkString("[", ", ", "]")} do not fit " +
+        s"a parent partitioner with $n partitions")
+    val result = new Array[Int](n)
+    partitionStartIndices.indices.foreach { i =>
+      val start = partitionStartIndices(i)
+      val end = if (i < partitionStartIndices.length - 1) partitionStartIndices(i + 1) else n
+      // Parent partitions [start, end) all go to child partition i.
+      Arrays.fill(result, start, end, i)
+    }
+    result
+  }
+
+  override def numPartitions: Int = partitionStartIndices.length
+
+  override def getPartition(key: Any): Int = {
+    parentPartitionMapping(parent.getPartition(key))
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case c: CoalescedPartitioner =>
+      c.parent == parent && Arrays.equals(c.partitionStartIndices, partitionStartIndices)
+    case _ =>
+      false
+  }
+
+  override def hashCode(): Int = 31 * parent.hashCode() + Arrays.hashCode(partitionStartIndices)
+}
+
+/**
  * This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that 
is optimized for
  * shuffling rows instead of Java key-value pairs. Note that something like 
this should eventually
  * be implemented in Spark core, but that is blocked by some more general 
refactorings to shuffle
  * interfaces / internals.
  *
- * @param prev the RDD being shuffled. Elements of this RDD are (partitionId, 
Row) pairs.
- *             Partition ids should be in the range [0, numPartitions - 1].
- * @param serializer the serializer used during the shuffle.
- * @param numPartitions the number of post-shuffle partitions.
+ * This RDD takes a [[ShuffleDependency]] (`dependency`)
+ * and an optional array of partition start indices (`specifiedPartitionStartIndices`)
+ * as input arguments.
+ *
+ * The `dependency` holds the parent RDD of this RDD, which represents the dataset before the
+ * shuffle (i.e. the map output). Elements of that parent RDD are (partitionId, Row) pairs, and
+ * partition ids should be in the range [0, numPartitions - 1].
+ * `dependency.partitioner` is the original partitioner used to partition the
+ * map output, and `dependency.partitioner.numPartitions` is the number of pre-shuffle partitions
+ * (i.e. the number of partitions of the map output).
+ *
+ * When `specifiedPartitionStartIndices` is defined, `specifiedPartitionStartIndices.length`
+ * will be the number of post-shuffle partitions. In this case, the `i`th post-shuffle
+ * partition includes pre-shuffle partitions `specifiedPartitionStartIndices[i]` to
+ * `specifiedPartitionStartIndices[i+1] - 1` (inclusive). The last post-shuffle partition
+ * extends to the last pre-shuffle partition.
+ *
+ * When `specifiedPartitionStartIndices` is not defined, there will be
+ * `dependency.partitioner.numPartitions` post-shuffle partitions. In this case,
+ * a post-shuffle partition is created for every pre-shuffle partition.
  */
 class ShuffledRowRDD(
-    @transient var prev: RDD[Product2[Int, InternalRow]],
-    serializer: Serializer,
-    numPartitions: Int)
-  extends RDD[InternalRow](prev.context, Nil) {
+    var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
+    specifiedPartitionStartIndices: Option[Array[Int]] = None)
+  extends RDD[InternalRow](dependency.rdd.context, Nil) {
 
-  private val part: Partitioner = new PartitionIdPassthrough(numPartitions)
+  // Number of partitions of the map output, as defined by the original shuffle partitioner.
+  private[this] val numPreShufflePartitions = 
dependency.partitioner.numPartitions
 
-  override def getDependencies: Seq[Dependency[_]] = {
-    List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, 
Some(serializer)))
+  // Start index (a pre-shuffle partition id) of every post-shuffle partition; its length is
+  // the number of post-shuffle partitions.
+  private[this] val partitionStartIndices: Array[Int] = 
specifiedPartitionStartIndices match {
+    case Some(indices) => indices
+    case None =>
+      // When specifiedPartitionStartIndices is not defined, every 
post-shuffle partition
+      // corresponds to a pre-shuffle partition.
+      (0 until numPreShufflePartitions).toArray
   }
 
-  override val partitioner = Some(part)
+  // Output partitioner: keys go through the original partitioner, and the resulting
+  // pre-shuffle partition ids are coalesced according to partitionStartIndices.
+  private[this] val part: Partitioner =
+    new CoalescedPartitioner(dependency.partitioner, partitionStartIndices)
+
+  // The only dependency of this RDD is the shuffle dependency passed to the constructor.
+  override def getDependencies: Seq[Dependency[_]] = List(dependency)
+
+  override val partitioner: Option[Partitioner] = Some(part)
 
   override def getPartitions: Array[Partition] = {
-    Array.tabulate[Partition](part.numPartitions)(i => new 
ShuffledRowRDDPartition(i))
+    assert(partitionStartIndices.length == part.numPartitions)
+    // Post-shuffle partition i reads pre-shuffle partitions [start(i), start(i + 1)); the last
+    // one reads up to the total number of pre-shuffle partitions.
+    val endIndices = partitionStartIndices.drop(1) :+ numPreShufflePartitions
+    partitionStartIndices.indices.map { postShuffleIndex =>
+      new ShuffledRowRDDPartition(
+        postShuffleIndex,
+        partitionStartIndices(postShuffleIndex),
+        endIndices(postShuffleIndex)): Partition
+    }.toArray
   }
 
   override def getPreferredLocations(partition: Partition): Seq[String] = {
@@ -71,15 +154,20 @@ class ShuffledRowRDD(
   }
 
   override def compute(split: Partition, context: TaskContext): 
Iterator[InternalRow] = {
-    val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, 
InternalRow, InternalRow]]
-    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, 
split.index + 1, context)
-      .read()
-      .asInstanceOf[Iterator[Product2[Int, InternalRow]]]
-      .map(_._2)
+    val partition = split.asInstanceOf[ShuffledRowRDDPartition]
+    // Read pre-shuffle partitions [startPreShufflePartitionIndex, endPreShufflePartitionIndex)
+    // of the shuffle output. The reader yields (Int, InternalRow) pairs; only the rows are kept.
+    SparkEnv.get.shuffleManager
+      .getReader(
+        dependency.shuffleHandle,
+        partition.startPreShufflePartitionIndex,
+        partition.endPreShufflePartitionIndex,
+        context)
+      .read()
+      .asInstanceOf[Iterator[Product2[Int, InternalRow]]]
+      .map(_._2)
   }
 
   override def clearDependencies() {
     super.clearDependencies()
-    prev = null
+    // Also drop this RDD's own reference to the shuffle dependency, which holds the parent RDD
+    // (`dependency.rdd`).
+    dependency = null
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d728d5c9/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
new file mode 100644
index 0000000..25f2f5c
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -0,0 +1,479 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql._
+import org.apache.spark.{SparkFunSuite, SparkContext, SparkConf, 
MapOutputStatistics}
+
+class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
+
+  // Global SQLContext state captured in beforeAll() and restored in afterAll().
+  private var originalActiveSQLContext: Option[SQLContext] = _
+  private var originalInstantiatedSQLContext: Option[SQLContext] = _
+
+  override protected def beforeAll(): Unit = {
+    // Remember the current global SQLContext state, then clear it.
+    originalActiveSQLContext = SQLContext.getActiveContextOption()
+    originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
+
+    SQLContext.clearActive()
+    originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx))
+  }
+
+  override protected def afterAll(): Unit = {
+    // Set these states back.
+    originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx))
+    originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx))
+  }
+
+  /**
+   * Builds one [[MapOutputStatistics]] per element of `bytesByPartitionIdArray` and asserts that
+   * `coordinator` estimates `expectedPartitionStartIndices` from them.
+   */
+  private def checkEstimation(
+      coordinator: ExchangeCoordinator,
+      bytesByPartitionIdArray: Array[Array[Long]],
+      expectedPartitionStartIndices: Array[Int]): Unit = {
+    val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map {
+      case (bytesByPartitionId, index) =>
+        new MapOutputStatistics(index, bytesByPartitionId)
+    }
+    val estimatedPartitionStartIndices =
+      coordinator.estimatePartitionStartIndices(mapOutputStatistics)
+    assert(estimatedPartitionStartIndices === expectedPartitionStartIndices)
+  }
+
+  test("test estimatePartitionStartIndices - 1 Exchange") {
+    val coordinator = new ExchangeCoordinator(1, 100L)
+
+    {
+      // All bytes per partition are 0.
+      val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0)
+      val expectedPartitionStartIndices = Array[Int](0)
+      checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices)
+    }
+
+    {
+      // Some bytes per partition are 0 and total size is less than the target size.
+      // 1 post-shuffle partition is needed.
+      val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0)
+      val expectedPartitionStartIndices = Array[Int](0)
+      checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices)
+    }
+
+    {
+      // 2 post-shuffle partitions are needed.
+      val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0)
+      val expectedPartitionStartIndices = Array[Int](0, 3)
+      checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices)
+    }
+
+    {
+      // There are a few large pre-shuffle partitions.
+      val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4)
+      checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices)
+    }
+
+    {
+      // All pre-shuffle partitions are larger than the targeted size.
+      val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
+      checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices)
+    }
+
+    {
+      // The last pre-shuffle partition is in a single post-shuffle partition.
+      val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110)
+      val expectedPartitionStartIndices = Array[Int](0, 4)
+      checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices)
+    }
+  }
+
+  test("test estimatePartitionStartIndices - 2 Exchanges") {
+    val coordinator = new ExchangeCoordinator(2, 100L)
+
+    {
+      // If there are multiple values of the number of pre-shuffle partitions,
+      // we should see an assertion error.
+      val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
+      val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0)
+      val mapOutputStatistics =
+        Array(
+          new MapOutputStatistics(0, bytesByPartitionId1),
+          new MapOutputStatistics(1, bytesByPartitionId2))
+      intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics))
+    }
+
+    {
+      // All bytes per partition are 0.
+      val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
+      val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
+      val expectedPartitionStartIndices = Array[Int](0)
+      checkEstimation(
+        coordinator,
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices)
+    }
+
+    {
+      // Some bytes per partition are 0.
+      // 1 post-shuffle partition is needed.
+      val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20)
+      val expectedPartitionStartIndices = Array[Int](0)
+      checkEstimation(
+        coordinator,
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices)
+    }
+
+    {
+      // 2 post-shuffle partitions are needed.
+      val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
+      val expectedPartitionStartIndices = Array[Int](0, 3)
+      checkEstimation(
+        coordinator,
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices)
+    }
+
+    {
+      // 2 post-shuffle partitions are needed.
+      val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
+      val expectedPartitionStartIndices = Array[Int](0, 2)
+      checkEstimation(
+        coordinator,
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices)
+    }
+
+    {
+      // 3 post-shuffle partitions are needed.
+      val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
+      val expectedPartitionStartIndices = Array[Int](0, 2, 4)
+      checkEstimation(
+        coordinator,
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices)
+    }
+
+    {
+      // There are a few large pre-shuffle partitions.
+      val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110)
+      val expectedPartitionStartIndices = Array[Int](0, 2, 3)
+      checkEstimation(
+        coordinator,
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices)
+    }
+
+    {
+      // All pairs of pre-shuffle partitions are larger than the targeted size.
+      val bytesByPartitionId1 = Array[Long](100, 100, 40, 30, 0)
+      val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
+      checkEstimation(
+        coordinator,
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices)
+    }
+  }
+
+  test("test estimatePartitionStartIndices and enforce minimal number of reducers") {
+    val coordinator = new ExchangeCoordinator(2, 100L, Some(2))
+
+    {
+      // The minimal number of post-shuffle partitions is not enforced because
+      // the size of data is 0.
+      val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
+      val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
+      val expectedPartitionStartIndices = Array[Int](0)
+      checkEstimation(
+        coordinator,
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices)
+    }
+
+    {
+      // The minimal number of post-shuffle partitions is enforced.
+      val bytesByPartitionId1 = Array[Long](10, 5, 5, 0, 20)
+      val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5)
+      val expectedPartitionStartIndices = Array[Int](0, 3)
+      checkEstimation(
+        coordinator,
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices)
+    }
+
+    {
+      // The number of post-shuffle partitions is determined by the coordinator.
+      val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20)
+      val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30)
+      val expectedPartitionStartIndices = Array[Int](0, 2, 4)
+      checkEstimation(
+        coordinator,
+        Array(bytesByPartitionId1, bytesByPartitionId2),
+        expectedPartitionStartIndices)
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Query tests
+  ///////////////////////////////////////////////////////////////////////////
+
+  val numInputPartitions: Int = 10
+
+  /** Fails the current test with QueryTest's error message if `actual` differs from the answer. */
+  def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
+    QueryTest.checkAnswer(actual, expectedAnswer).foreach(errorMessage => fail(errorMessage))
+  }
+
+  /**
+   * Asserts that every [[Exchange]] in `exchanges` has a coordinator and that the distinct
+   * numbers of post-shuffle partitions across them are exactly `expectedNumPartitions`.
+   */
+  private def checkCoordinatedExchanges(
+      exchanges: Seq[Exchange],
+      expectedNumPartitions: Set[Int]): Unit = {
+    assert(exchanges.forall(_.coordinator.isDefined))
+    assert(exchanges.map(_.outputPartitioning.numPartitions).toSet === expectedNumPartitions)
+  }
+
+  /**
+   * Runs `f` against a fresh [[TestSQLContext]] backed by a new local [[SparkContext]] with
+   * adaptive execution enabled. The [[SparkContext]] is always stopped afterwards.
+   *
+   * @param targetNumPostShufflePartitions despite its name, the value used for
+   *   `SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE` (the advisory target post-shuffle input size).
+   * @param minNumPostShufflePartitions if defined, the value of
+   *   `SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS`; otherwise that setting is "-1".
+   */
+  def withSQLContext(
+      f: SQLContext => Unit,
+      targetNumPostShufflePartitions: Int,
+      minNumPostShufflePartitions: Option[Int]): Unit = {
+    val sparkConf =
+      new SparkConf(false)
+        .setMaster("local[*]")
+        .setAppName("test")
+        .set("spark.ui.enabled", "false")
+        .set("spark.driver.allowMultipleContexts", "true")
+        .set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
+        .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+        .set(
+          SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key,
+          targetNumPostShufflePartitions.toString)
+        .set(
+          SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key,
+          minNumPostShufflePartitions.getOrElse(-1).toString)
+    val sparkContext = new SparkContext(sparkConf)
+    // Create the SQLContext inside the try so the SparkContext is stopped even if that fails.
+    try f(new TestSQLContext(sparkContext)) finally sparkContext.stop()
+  }
+
+  Seq(Some(3), None).foreach { minNumPostShufflePartitions =>
+    val testNameNote =
+      minNumPostShufflePartitions.fold("")(n => s" (minNumPostShufflePartitions: $n)")
+
+    test(s"determining the number of reducers: aggregate operator$testNameNote") {
+      val test = { sqlContext: SQLContext =>
+        val df =
+          sqlContext
+            .range(0, 1000, 1, numInputPartitions)
+            .selectExpr("id % 20 as key", "id as value")
+        val agg = df.groupBy("key").count
+
+        // Check the answer first.
+        checkAnswer(
+          agg,
+          sqlContext.range(0, 20).selectExpr("id", "50 as cnt").collect())
+
+        // Then, let's look at the number of post-shuffle partitions estimated
+        // by the ExchangeCoordinator.
+        val exchanges = agg.queryExecution.executedPlan.collect {
+          case e: Exchange => e
+        }
+        assert(exchanges.length === 1)
+        minNumPostShufflePartitions match {
+          case Some(numPartitions) => checkCoordinatedExchanges(exchanges, Set(numPartitions))
+          case None => checkCoordinatedExchanges(exchanges, Set(2))
+        }
+      }
+
+      withSQLContext(test, 1536, minNumPostShufflePartitions)
+    }
+
+    test(s"determining the number of reducers: join operator$testNameNote") {
+      val test = { sqlContext: SQLContext =>
+        val df1 =
+          sqlContext
+            .range(0, 1000, 1, numInputPartitions)
+            .selectExpr("id % 500 as key1", "id as value1")
+        val df2 =
+          sqlContext
+            .range(0, 1000, 1, numInputPartitions)
+            .selectExpr("id % 500 as key2", "id as value2")
+
+        val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
+
+        // Check the answer first.
+        val expectedAnswer =
+          sqlContext
+            .range(0, 1000)
+            .selectExpr("id % 500 as key", "id as value")
+            .unionAll(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
+        checkAnswer(
+          join,
+          expectedAnswer.collect())
+
+        // Then, let's look at the number of post-shuffle partitions estimated
+        // by the ExchangeCoordinator.
+        val exchanges = join.queryExecution.executedPlan.collect {
+          case e: Exchange => e
+        }
+        assert(exchanges.length === 2)
+        minNumPostShufflePartitions match {
+          case Some(numPartitions) => checkCoordinatedExchanges(exchanges, Set(numPartitions))
+          case None => checkCoordinatedExchanges(exchanges, Set(2))
+        }
+      }
+
+      withSQLContext(test, 16384, minNumPostShufflePartitions)
+    }
+
+    test(s"determining the number of reducers: complex query 1$testNameNote") {
+      val test = { sqlContext: SQLContext =>
+        val df1 =
+          sqlContext
+            .range(0, 1000, 1, numInputPartitions)
+            .selectExpr("id % 500 as key1", "id as value1")
+            .groupBy("key1")
+            .count
+            .toDF("key1", "cnt1")
+        val df2 =
+          sqlContext
+            .range(0, 1000, 1, numInputPartitions)
+            .selectExpr("id % 500 as key2", "id as value2")
+            .groupBy("key2")
+            .count
+            .toDF("key2", "cnt2")
+
+        val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("cnt2"))
+
+        // Check the answer first.
+        val expectedAnswer =
+          sqlContext
+            .range(0, 500)
+            .selectExpr("id", "2 as cnt")
+        checkAnswer(
+          join,
+          expectedAnswer.collect())
+
+        // Then, let's look at the number of post-shuffle partitions estimated
+        // by the ExchangeCoordinator.
+        val exchanges = join.queryExecution.executedPlan.collect {
+          case e: Exchange => e
+        }
+        assert(exchanges.length === 4)
+        minNumPostShufflePartitions match {
+          case Some(numPartitions) => checkCoordinatedExchanges(exchanges, Set(numPartitions))
+          case None => checkCoordinatedExchanges(exchanges, Set(1, 2))
+        }
+      }
+
+      withSQLContext(test, 6144, minNumPostShufflePartitions)
+    }
+
+    test(s"determining the number of reducers: complex query 2$testNameNote") {
+      val test = { sqlContext: SQLContext =>
+        val df1 =
+          sqlContext
+            .range(0, 1000, 1, numInputPartitions)
+            .selectExpr("id % 500 as key1", "id as value1")
+            .groupBy("key1")
+            .count
+            .toDF("key1", "cnt1")
+        val df2 =
+          sqlContext
+            .range(0, 1000, 1, numInputPartitions)
+            .selectExpr("id % 500 as key2", "id as value2")
+
+        val join =
+          df1
+            .join(df2, col("key1") === col("key2"))
+            .select(col("key1"), col("cnt1"), col("value2"))
+
+        // Check the answer first.
+        val expectedAnswer =
+          sqlContext
+            .range(0, 1000)
+            .selectExpr("id % 500 as key", "2 as cnt", "id as value")
+        checkAnswer(
+          join,
+          expectedAnswer.collect())
+
+        // Then, let's look at the number of post-shuffle partitions estimated
+        // by the ExchangeCoordinator.
+        val exchanges = join.queryExecution.executedPlan.collect {
+          case e: Exchange => e
+        }
+        assert(exchanges.length === 3)
+        minNumPostShufflePartitions match {
+          case Some(numPartitions) => checkCoordinatedExchanges(exchanges, Set(numPartitions))
+          case None => checkCoordinatedExchanges(exchanges, Set(2, 3))
+        }
+      }
+
+      withSQLContext(test, 6144, minNumPostShufflePartitions)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d728d5c9/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index ebdab1c..2076c57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -268,7 +268,7 @@ class PlannerSuite extends SharedSQLContext {
     )
     val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) {
+    if (outputPlan.collect { case e: Exchange => true }.isEmpty) {
       fail(s"Exchange should have been added:\n$outputPlan")
     }
   }
@@ -306,7 +306,7 @@ class PlannerSuite extends SharedSQLContext {
     )
     val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) {
+    if (outputPlan.collect { case e: Exchange => true }.isEmpty) {
       fail(s"Exchange should have been added:\n$outputPlan")
     }
   }
@@ -326,7 +326,7 @@ class PlannerSuite extends SharedSQLContext {
     )
     val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) {
+    if (outputPlan.collect { case e: Exchange => true }.nonEmpty) {
       fail(s"Exchange should not have been added:\n$outputPlan")
     }
   }
@@ -349,7 +349,7 @@ class PlannerSuite extends SharedSQLContext {
     )
     val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) {
+    if (outputPlan.collect { case e: Exchange => true }.nonEmpty) {
       fail(s"No Exchanges should have been added:\n$outputPlan")
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d728d5c9/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index d32572b..09e2582 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -152,7 +152,12 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with 
LocalSparkContext {
     val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
     val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, 
unsafeRow)))
       .asInstanceOf[RDD[Product2[Int, InternalRow]]]
-    val shuffled = new ShuffledRowRDD(rowsRDD, new UnsafeRowSerializer(2), 2)
+    val dependency =
+      new ShuffleDependency[Int, InternalRow, InternalRow](
+        rowsRDD,
+        new PartitionIdPassthrough(2),
+        Some(new UnsafeRowSerializer(2)))
+    val shuffled = new ShuffledRowRDD(dependency)
     shuffled.count()
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d728d5c9/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index da58e96..066c16e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -49,7 +49,16 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext {
       Row(null, "e")
     )), new StructType().add("n", IntegerType).add("l", StringType))
 
-  private lazy val myTestData = Seq(
+  // NOTE(review): this appears to hold the same rows as myTestData2 but as a separately built
+  // DataFrame; presumably so the left and right inputs of the joins below do not share a
+  // plan. Confirm.
+  private lazy val myTestData1 = Seq(
+    (1, 1),
+    (1, 2),
+    (2, 1),
+    (2, 2),
+    (3, 1),
+    (3, 2)
+  ).toDF("a", "b")
+
+  private lazy val myTestData2 = Seq(
     (1, 1),
     (1, 2),
     (2, 1),
@@ -184,8 +193,8 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext {
   )
 
   {
-    lazy val left = myTestData.where("a = 1")
-    lazy val right = myTestData.where("a = 1")
+    lazy val left = myTestData1.where("a = 1")
+    lazy val right = myTestData2.where("a = 1")
     testInnerJoin(
       "inner join, multiple matches",
       left,
@@ -201,8 +210,8 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext {
   }
 
   {
-    lazy val left = myTestData.where("a = 1")
-    lazy val right = myTestData.where("a = 2")
+    lazy val left = myTestData1.where("a = 1")
+    lazy val right = myTestData2.where("a = 2")
     testInnerJoin(
       "inner join, no matches",
       left,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to