Repository: spark
Updated Branches:
  refs/heads/branch-1.5 1ce5061bb -> 323d68606


[SPARK-9703] [SQL] Refactor EnsureRequirements to avoid certain unnecessary 
shuffles

This pull request refactors the `EnsureRequirements` planning rule in order to 
avoid the addition of certain unnecessary shuffles.

As an example of how unnecessary shuffles can occur, consider SortMergeJoin, 
which requires clustered distribution and sorted ordering of its children's 
input rows. Say that both of SMJ's children produce unsorted output but are 
both SinglePartition. In this case, we need to inject sort operators but 
should not need to inject Exchanges. Unfortunately, `EnsureRequirements` 
unnecessarily repartitions both children using a hash partitioning.

This patch solves this problem by refactoring `EnsureRequirements` to properly 
implement the `compatibleWith` checks that were broken in earlier 
implementations. See the significant inline comments for a better description 
of how this works. The majority of this PR is new comments and test cases, with 
few actual changes to the code.
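
As a rough illustration of the change in planning logic, the sketch below 
models the SinglePartition case with simplified stand-ins for the Catalyst 
classes. It is not the code from this patch: `ShuffleSketch`, 
`oldRuleShuffles`, `newRuleShuffles`, the string-based clustering keys, and the 
fixed `numShufflePartitions = 200` are all assumptions made for the example, 
and only `ClusteredDistribution` is modelled.

```scala
object ShuffleSketch {
  // Simplified stand-ins for the Catalyst classes touched by this patch.
  // Clustering keys are plain strings here rather than Expressions.
  final case class ClusteredDistribution(keys: Set[String])

  sealed trait Partitioning {
    def numPartitions: Int
    def satisfies(required: ClusteredDistribution): Boolean
    def compatibleWith(other: Partitioning): Boolean
    // Default: a partitioning only guarantees partitionings equal to itself.
    def guarantees(other: Partitioning): Boolean = this == other
  }

  case object SinglePartition extends Partitioning {
    val numPartitions = 1
    def satisfies(required: ClusteredDistribution): Boolean = true
    def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1
    override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1
  }

  final case class HashPartitioning(keys: Set[String], numPartitions: Int)
    extends Partitioning {
    // Assumption: hashing on a subset of the required keys satisfies the clustering.
    def satisfies(required: ClusteredDistribution): Boolean = keys.subsetOf(required.keys)
    def compatibleWith(other: Partitioning): Boolean = other match {
      case HashPartitioning(k, n) => k == keys && n == numPartitions
      case _ => false
    }
  }

  // Illustrative stand-in for sqlContext.conf.numShufflePartitions.
  val numShufflePartitions = 200

  def canonicalPartitioning(d: ClusteredDistribution): Partitioning =
    HashPartitioning(d.keys, numShufflePartitions)

  // Old-style check: shuffle unless the child guarantees the canonical target.
  def oldRuleShuffles(child: Partitioning, d: ClusteredDistribution): Boolean =
    !child.guarantees(canonicalPartitioning(d))

  // New-style check: returns, per child, whether an Exchange would be inserted.
  def newRuleShuffles(children: Seq[Partitioning], d: ClusteredDistribution): Seq[Boolean] = {
    val target = canonicalPartitioning(d)
    // Step 1: shuffle only the children that do not satisfy the distribution.
    val step1 = children.map(c => if (c.satisfies(d)) (c, false) else (target, true))
    // Step 2: if siblings are not mutually compatible, shuffle every child
    // whose partitioning does not already guarantee the target.
    val compatible = step1.map(_._1).sliding(2).forall {
      case Seq(a, b) => a.compatibleWith(b) && b.compatibleWith(a)
      case _ => true
    }
    val step2 =
      if (children.length > 1 && !compatible) {
        step1.map { case (p, s) => if (p.guarantees(target)) (p, s) else (target, true) }
      } else {
        step1
      }
    step2.map(_._2)
  }
}
```

Under these assumptions, `oldRuleShuffles(SinglePartition, dist)` reports a 
shuffle because SinglePartition does not guarantee a 200-partition hash 
partitioning, whereas `newRuleShuffles(Seq(SinglePartition, SinglePartition), 
dist)` reports none: both children already satisfy the distribution and are 
compatible with each other, so only sorts would be needed.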

Author: Josh Rosen <[email protected]>

Closes #7988 from JoshRosen/exchange-fixes and squashes the following commits:

38006e7 [Josh Rosen] Rewrite EnsureRequirements _yet again_ to make things even 
simpler
0983f75 [Josh Rosen] More guarantees vs. compatibleWith cleanup; delete 
BroadcastPartitioning.
8784bd9 [Josh Rosen] Giant comment explaining compatibleWith vs. guarantees
1307c50 [Josh Rosen] Update conditions for requiring child compatibility.
18cddeb [Josh Rosen] Rename DummyPlan to DummySparkPlan.
2c7e126 [Josh Rosen] Merge remote-tracking branch 'origin/master' into 
exchange-fixes
fee65c4 [Josh Rosen] Further refinement to comments / reasoning
642b0bb [Josh Rosen] Further expand comment / reasoning
06aba0c [Josh Rosen] Add more comments
8dbc845 [Josh Rosen] Add even more tests.
4f08278 [Josh Rosen] Fix the test by adding the compatibility check to 
EnsureRequirements
a1c12b9 [Josh Rosen] Add failing test to demonstrate allCompatible bug
0725a34 [Josh Rosen] Small assertion cleanup.
5172ac5 [Josh Rosen] Add test for 
requiresChildrenToProduceSameNumberOfPartitions.
2e0f33a [Josh Rosen] Write a more generic test for EnsureRequirements.
752b8de [Josh Rosen] style fix
c628daf [Josh Rosen] Revert accidental ExchangeSuite change.
c9fb231 [Josh Rosen] Rewrite exchange to fix better handle this case.
adcc742 [Josh Rosen] Move test to PlannerSuite.
0675956 [Josh Rosen] Preserving ordering and partitioning in row format 
converters also does not help.
cc5669c [Josh Rosen] Adding outputPartitioning to Repartition does not fix the 
test.
2dfc648 [Josh Rosen] Add failing test illustrating bad exchange planning.

(cherry picked from commit 23cf5af08d98da771c41571c00a2f5cafedfebdd)
Signed-off-by: Yin Huai <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/323d6860
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/323d6860
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/323d6860

Branch: refs/heads/branch-1.5
Commit: 323d686063937adb34729de16ea181bd56a97705
Parents: 1ce5061
Author: Josh Rosen <[email protected]>
Authored: Sun Aug 9 14:26:01 2015 -0700
Committer: Yin Huai <[email protected]>
Committed: Sun Aug 9 14:26:16 2015 -0700

----------------------------------------------------------------------
 .../catalyst/plans/physical/partitioning.scala  | 128 ++++++++++++++--
 .../apache/spark/sql/execution/Exchange.scala   | 104 +++++++------
 .../spark/sql/execution/basicOperators.scala    |   5 +
 .../sql/execution/rowFormatConverters.scala     |   5 +
 .../spark/sql/execution/PlannerSuite.scala      | 151 +++++++++++++++++++
 5 files changed, 328 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/323d6860/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index ec659ce..5a89a90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -75,6 +75,37 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) 
extends Distribution {
   def clustering: Set[Expression] = ordering.map(_.child).toSet
 }
 
+/**
+ * Describes how an operator's output is split across partitions. The 
`compatibleWith`,
+ * `guarantees`, and `satisfies` methods describe relationships between child 
partitionings,
+ * target partitionings, and [[Distribution]]s. These relations are described 
more precisely in
+ * their individual method docs, but at a high level:
+ *
+ *  - `satisfies` is a relationship between partitionings and distributions.
+ *  - `compatibleWith` is a relationship between an operator's child output partitionings.
+ *  - `guarantees` is a relationship between a child's existing output 
partitioning and a target
+ *     output partitioning.
+ *
+ *  Diagrammatically:
+ *
+ *            +--------------+
+ *            | Distribution |
+ *            +--------------+
+ *                    ^
+ *                    |
+ *               satisfies
+ *                    |
+ *            +--------------+                  +--------------+
+ *            |    Child     |                  |    Target    |
+ *       +----| Partitioning |----guarantees--->| Partitioning |
+ *       |    +--------------+                  +--------------+
+ *       |            ^
+ *       |            |
+ *       |     compatibleWith
+ *       |            |
+ *       +------------+
+ *
+ */
 sealed trait Partitioning {
   /** Returns the number of partitions that the data is split across */
   val numPartitions: Int
@@ -90,9 +121,66 @@ sealed trait Partitioning {
   /**
    * Returns true iff we can say that the partitioning scheme of this 
[[Partitioning]]
    * guarantees the same partitioning scheme described by `other`.
+   *
+   * Compatibility of partitionings is only checked for operators that have 
multiple children
+   * and that require a specific child output [[Distribution]], such as joins.
+   *
+   * Intuitively, partitionings are compatible if they route the same 
partitioning key to the same
+   * partition. For instance, two hash partitionings are only compatible if 
they produce the same
+   * number of output partitions and hash records according to the same hash function and
+   * same partitioning key schema.
+   *
+   * Put another way, two partitionings are compatible with each other if they 
satisfy all of the
+   * same distribution guarantees.
    */
-  // TODO: Add an example once we have the `nullSafe` concept.
-  def guarantees(other: Partitioning): Boolean
+  def compatibleWith(other: Partitioning): Boolean
+
+  /**
+   * Returns true iff we can say that the partitioning scheme of this 
[[Partitioning]] guarantees
+   * the same partitioning scheme described by `other`. If `A.guarantees(B)`, then repartitioning
+   * the child's output according to `B` will be unnecessary. `guarantees` is 
used as a performance
+   * optimization to allow the exchange planner to avoid redundant 
repartitionings. By default,
+   * a partitioning only guarantees partitionings that are equal to itself 
(i.e. the same number
+   * of partitions, same strategy (range or hash), etc).
+   *
+   * In order to enable more aggressive optimization, this strict equality 
check can be relaxed.
+   * For example, say that the planner needs to repartition all of an 
operator's children so that
+   * they satisfy the [[AllTuples]] distribution. One way to do this is to 
repartition all children
+   * to have the [[SinglePartition]] partitioning. If one of the operator's 
children already happens
+   * to be hash-partitioned with a single partition then we do not need to 
re-shuffle this child;
+   * this repartitioning can be avoided if a single-partition 
[[HashPartitioning]] `guarantees`
+   * [[SinglePartition]].
+   *
+   * The SinglePartition example given above is not particularly interesting; 
guarantees' real
+   * value occurs for more advanced partitioning strategies. SPARK-7871 will 
introduce a notion
+   * of null-safe partitionings, under which partitionings can specify whether 
rows whose
+   * partitioning keys contain null values will be grouped into the same 
partition or whether they
+   * will have an unknown / random distribution. If a partitioning does not 
require nulls to be
+   * clustered then a partitioning which _does_ cluster nulls will guarantee 
the null clustered
+   * partitioning. The converse is not true, however: a partitioning which 
clusters nulls cannot
+   * be guaranteed by one which does not cluster them. Thus, in general 
`guarantees` is not a
+   * symmetric relation.
+   *
+   * Another way to think about `guarantees`: if `A.guarantees(B)`, then any 
partitioning of rows
+   * produced by `A` could have also been produced by `B`.
+   */
+  def guarantees(other: Partitioning): Boolean = this == other
+}
+
+object Partitioning {
+  def allCompatible(partitionings: Seq[Partitioning]): Boolean = {
+    // Note: this assumes transitivity
+    partitionings.sliding(2).map {
+      case Seq(a) => true
+      case Seq(a, b) =>
+        if (a.numPartitions != b.numPartitions) {
+          assert(!a.compatibleWith(b) && !b.compatibleWith(a))
+          false
+        } else {
+          a.compatibleWith(b) && b.compatibleWith(a)
+        }
+    }.forall(_ == true)
+  }
 }
 
 case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -101,6 +189,8 @@ case class UnknownPartitioning(numPartitions: Int) extends 
Partitioning {
     case _ => false
   }
 
+  override def compatibleWith(other: Partitioning): Boolean = false
+
   override def guarantees(other: Partitioning): Boolean = false
 }
 
@@ -109,21 +199,9 @@ case object SinglePartition extends Partitioning {
 
   override def satisfies(required: Distribution): Boolean = true
 
-  override def guarantees(other: Partitioning): Boolean = other match {
-    case SinglePartition => true
-    case _ => false
-  }
-}
-
-case object BroadcastPartitioning extends Partitioning {
-  val numPartitions = 1
+  override def compatibleWith(other: Partitioning): Boolean = 
other.numPartitions == 1
 
-  override def satisfies(required: Distribution): Boolean = true
-
-  override def guarantees(other: Partitioning): Boolean = other match {
-    case BroadcastPartitioning => true
-    case _ => false
-  }
+  override def guarantees(other: Partitioning): Boolean = other.numPartitions 
== 1
 }
 
 /**
@@ -147,6 +225,12 @@ case class HashPartitioning(expressions: Seq[Expression], 
numPartitions: Int)
     case _ => false
   }
 
+  override def compatibleWith(other: Partitioning): Boolean = other match {
+    case o: HashPartitioning =>
+      this.clusteringSet == o.clusteringSet && this.numPartitions == 
o.numPartitions
+    case _ => false
+  }
+
   override def guarantees(other: Partitioning): Boolean = other match {
     case o: HashPartitioning =>
       this.clusteringSet == o.clusteringSet && this.numPartitions == 
o.numPartitions
@@ -185,6 +269,11 @@ case class RangePartitioning(ordering: Seq[SortOrder], 
numPartitions: Int)
     case _ => false
   }
 
+  override def compatibleWith(other: Partitioning): Boolean = other match {
+    case o: RangePartitioning => this == o
+    case _ => false
+  }
+
   override def guarantees(other: Partitioning): Boolean = other match {
     case o: RangePartitioning => this == o
     case _ => false
@@ -229,6 +318,13 @@ case class PartitioningCollection(partitionings: 
Seq[Partitioning])
     partitionings.exists(_.satisfies(required))
 
   /**
+   * Returns true if any `partitioning` of this collection is compatible with
+   * the given [[Partitioning]].
+   */
+  override def compatibleWith(other: Partitioning): Boolean =
+    partitionings.exists(_.compatibleWith(other))
+
+  /**
    * Returns true if any `partitioning` of this collection guarantees
    * the given [[Partitioning]].
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/323d6860/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 49bb729..b89e634 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -190,66 +190,72 @@ case class Exchange(newPartitioning: Partitioning, child: 
SparkPlan) extends Una
  * of input data meets the
  * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] 
requirements for
  * each operator by inserting [[Exchange]] Operators where required.  Also 
ensure that the
- * required input partition ordering requirements are met.
+ * input partition ordering requirements are met.
  */
 private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends 
Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
-  def numPartitions: Int = sqlContext.conf.numShufflePartitions
+  private def numPartitions: Int = sqlContext.conf.numShufflePartitions
 
-  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator: SparkPlan =>
-      // Adds Exchange or Sort operators as required
-      def addOperatorsIfNecessary(
-          partitioning: Partitioning,
-          rowOrdering: Seq[SortOrder],
-          child: SparkPlan): SparkPlan = {
-
-        def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
-          if (!child.outputPartitioning.guarantees(partitioning)) {
-            Exchange(partitioning, child)
-          } else {
-            child
-          }
-        }
+  /**
+   * Given a required distribution, returns a partitioning that satisfies that 
distribution.
+   */
+  private def canonicalPartitioning(requiredDistribution: Distribution): 
Partitioning = {
+    requiredDistribution match {
+      case AllTuples => SinglePartition
+      case ClusteredDistribution(clustering) => HashPartitioning(clustering, 
numPartitions)
+      case OrderedDistribution(ordering) => RangePartitioning(ordering, 
numPartitions)
+      case dist => sys.error(s"Do not know how to satisfy distribution $dist")
+    }
+  }
 
-        def addSortIfNecessary(child: SparkPlan): SparkPlan = {
-
-          if (rowOrdering.nonEmpty) {
-            // If child.outputOrdering is [a, b] and rowOrdering is [a], we do 
not need to sort.
-            val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
-            if (minSize == 0 || rowOrdering.take(minSize) != 
child.outputOrdering.take(minSize)) {
-              sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, 
global = false, child)
-            } else {
-              child
-            }
-          } else {
-            child
-          }
-        }
+  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
+    val requiredChildDistributions: Seq[Distribution] = 
operator.requiredChildDistribution
+    val requiredChildOrderings: Seq[Seq[SortOrder]] = 
operator.requiredChildOrdering
+    var children: Seq[SparkPlan] = operator.children
 
-        addSortIfNecessary(addShuffleIfNecessary(child))
+    // Ensure that the operator's children satisfy their output distribution 
requirements:
+    children = children.zip(requiredChildDistributions).map { case (child, 
distribution) =>
+      if (child.outputPartitioning.satisfies(distribution)) {
+        child
+      } else {
+        Exchange(canonicalPartitioning(distribution), child)
       }
+    }
 
-      val requirements =
-        (operator.requiredChildDistribution, operator.requiredChildOrdering, 
operator.children)
-
-      val fixedChildren = requirements.zipped.map {
-        case (AllTuples, rowOrdering, child) =>
-          addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
-        case (ClusteredDistribution(clustering), rowOrdering, child) =>
-          addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), 
rowOrdering, child)
-        case (OrderedDistribution(ordering), rowOrdering, child) =>
-          addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), 
rowOrdering, child)
-
-        case (UnspecifiedDistribution, Seq(), child) =>
+    // If the operator has multiple children and specifies child output 
distributions (e.g. join),
+    // then the children's output partitionings must be compatible:
+    if (children.length > 1
+        && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
+        && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+      children = children.zip(requiredChildDistributions).map { case (child, 
distribution) =>
+        val targetPartitioning = canonicalPartitioning(distribution)
+        if (child.outputPartitioning.guarantees(targetPartitioning)) {
           child
-        case (UnspecifiedDistribution, rowOrdering, child) =>
-          sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, 
global = false, child)
+        } else {
+          Exchange(targetPartitioning, child)
+        }
+      }
+    }
 
-        case (dist, ordering, _) =>
-          sys.error(s"Don't know how to ensure $dist with ordering $ordering")
+    // Now that we've performed any necessary shuffles, add sorts to guarantee 
output orderings:
+    children = children.zip(requiredChildOrderings).map { case (child, 
requiredOrdering) =>
+      if (requiredOrdering.nonEmpty) {
+        // If child.outputOrdering is [a, b] and requiredOrdering is [a], we 
do not need to sort.
+        val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min
+        if (minSize == 0 || requiredOrdering.take(minSize) != 
child.outputOrdering.take(minSize)) {
+          sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, 
global = false, child)
+        } else {
+          child
+        }
+      } else {
+        child
       }
+    }
 
-      operator.withNewChildren(fixedChildren)
+    operator.withNewChildren(children)
+  }
+
+  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+    case operator: SparkPlan => ensureDistributionAndOrdering(operator)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/323d6860/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index c5d1ed0..24950f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -256,6 +256,11 @@ case class Repartition(numPartitions: Int, shuffle: 
Boolean, child: SparkPlan)
   extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
+  override def outputPartitioning: Partitioning = {
+    if (numPartitions == 1) SinglePartition
+    else UnknownPartitioning(numPartitions)
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     child.execute().map(_.copy()).coalesce(numPartitions, shuffle)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/323d6860/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
index 29f3beb..855555d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
@@ -33,6 +34,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends 
UnaryNode {
   require(UnsafeProjection.canSupport(child.schema), s"Cannot convert 
${child.schema} to Unsafe")
 
   override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
   override def outputsUnsafeRows: Boolean = true
   override def canProcessUnsafeRows: Boolean = false
   override def canProcessSafeRows: Boolean = true
@@ -51,6 +54,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends 
UnaryNode {
 @DeveloperApi
 case class ConvertToSafe(child: SparkPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
   override def outputsUnsafeRows: Boolean = false
   override def canProcessUnsafeRows: Boolean = true
   override def canProcessSafeRows: Boolean = false

http://git-wip-us.apache.org/repos/asf/spark/blob/323d6860/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 18b0e54..5582caa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,9 +18,13 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, 
Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, 
ShuffledHashJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
@@ -202,4 +206,151 @@ class PlannerSuite extends SparkFunSuite with 
SQLTestUtils {
       }
     }
   }
+
+  // --- Unit tests of EnsureRequirements 
---------------------------------------------------------
+
+  // When it comes to testing whether EnsureRequirements properly ensures 
distribution requirements,
+  // there are two dimensions that need to be considered: are the child partitionings compatible and
+  // do they satisfy the distribution requirements? As a result, we need at 
least four test cases.
+
+  private def assertDistributionRequirementsAreSatisfied(outputPlan: 
SparkPlan): Unit = {
+    if (outputPlan.children.length > 1
+        && outputPlan.requiredChildDistribution.toSet != 
Set(UnspecifiedDistribution)) {
+      val childPartitionings = outputPlan.children.map(_.outputPartitioning)
+      if (!Partitioning.allCompatible(childPartitionings)) {
+        fail(s"Partitionings are not compatible: $childPartitionings")
+      }
+    }
+    outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach {
+      case (child, requiredDist) =>
+        assert(child.outputPartitioning.satisfies(requiredDist),
+          s"$child output partitioning does not satisfy 
$requiredDist:\n$outputPlan")
+    }
+  }
+
+  test("EnsureRequirements with incompatible child partitionings which satisfy 
distribution") {
+    // Consider an operator that requires inputs that are clustered by two 
expressions (e.g.
+    // sort merge join where there are multiple columns in the equi-join 
condition)
+    val clusteringA = Literal(1) :: Nil
+    val clusteringB = Literal(2) :: Nil
+    val distribution = ClusteredDistribution(clusteringA ++ clusteringB)
+    // Say that the left and right inputs are each partitioned by _one_ of the 
two join columns:
+    val leftPartitioning = HashPartitioning(clusteringA, 1)
+    val rightPartitioning = HashPartitioning(clusteringB, 1)
+    // Individually, each input's partitioning satisfies the clustering 
distribution:
+    assert(leftPartitioning.satisfies(distribution))
+    assert(rightPartitioning.satisfies(distribution))
+    // However, these partitionings are not compatible with each other, so we 
still need to
+    // repartition both inputs prior to performing the join:
+    assert(!leftPartitioning.compatibleWith(rightPartitioning))
+    assert(!rightPartitioning.compatibleWith(leftPartitioning))
+    val inputPlan = DummySparkPlan(
+      children = Seq(
+        DummySparkPlan(outputPartitioning = leftPartitioning),
+        DummySparkPlan(outputPartitioning = rightPartitioning)
+      ),
+      requiredChildDistribution = Seq(distribution, distribution),
+      requiredChildOrdering = Seq(Seq.empty, Seq.empty)
+    )
+    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    assertDistributionRequirementsAreSatisfied(outputPlan)
+    if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) {
+      fail(s"Exchange should have been added:\n$outputPlan")
+    }
+  }
+
+  test("EnsureRequirements with child partitionings with different numbers of 
output partitions") {
+    // This is similar to the previous test, except it checks that 
partitionings are not compatible
+    // unless they produce the same number of partitions.
+    val clustering = Literal(1) :: Nil
+    val distribution = ClusteredDistribution(clustering)
+    val inputPlan = DummySparkPlan(
+      children = Seq(
+        DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 1)),
+        DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 2))
+      ),
+      requiredChildDistribution = Seq(distribution, distribution),
+      requiredChildOrdering = Seq(Seq.empty, Seq.empty)
+    )
+    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    assertDistributionRequirementsAreSatisfied(outputPlan)
+  }
+
+  test("EnsureRequirements with compatible child partitionings that do not 
satisfy distribution") {
+    val distribution = ClusteredDistribution(Literal(1) :: Nil)
+    // The left and right inputs have compatible partitionings but they do not 
satisfy the
+    // distribution because they are clustered on different columns. Thus, we 
need to shuffle.
+    val childPartitioning = HashPartitioning(Literal(2) :: Nil, 1)
+    assert(!childPartitioning.satisfies(distribution))
+    val inputPlan = DummySparkPlan(
+      children = Seq(
+        DummySparkPlan(outputPartitioning = childPartitioning),
+        DummySparkPlan(outputPartitioning = childPartitioning)
+      ),
+      requiredChildDistribution = Seq(distribution, distribution),
+      requiredChildOrdering = Seq(Seq.empty, Seq.empty)
+    )
+    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    assertDistributionRequirementsAreSatisfied(outputPlan)
+    if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) {
+      fail(s"Exchange should have been added:\n$outputPlan")
+    }
+  }
+
+  test("EnsureRequirements with compatible child partitionings that satisfy 
distribution") {
+    // In this case, all requirements are satisfied and no exchange should be 
added.
+    val distribution = ClusteredDistribution(Literal(1) :: Nil)
+    val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
+    assert(childPartitioning.satisfies(distribution))
+    val inputPlan = DummySparkPlan(
+      children = Seq(
+        DummySparkPlan(outputPartitioning = childPartitioning),
+        DummySparkPlan(outputPartitioning = childPartitioning)
+      ),
+      requiredChildDistribution = Seq(distribution, distribution),
+      requiredChildOrdering = Seq(Seq.empty, Seq.empty)
+    )
+    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    assertDistributionRequirementsAreSatisfied(outputPlan)
+    if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) {
+      fail(s"Exchange should not have been added:\n$outputPlan")
+    }
+  }
+
+  // This is a regression test for SPARK-9703
+  test("EnsureRequirements should not repartition if only ordering requirement 
is unsatisfied") {
+    // Consider an operator that imposes both output distribution and ordering requirements on its
+    // children, such as sort merge join. If the distribution requirements are satisfied but
+    // the output ordering requirements are unsatisfied, then the planner 
should only add sorts and
+    // should not need to add additional shuffles / exchanges.
+    val outputOrdering = Seq(SortOrder(Literal(1), Ascending))
+    val distribution = ClusteredDistribution(Literal(1) :: Nil)
+    val inputPlan = DummySparkPlan(
+      children = Seq(
+        DummySparkPlan(outputPartitioning = SinglePartition),
+        DummySparkPlan(outputPartitioning = SinglePartition)
+      ),
+      requiredChildDistribution = Seq(distribution, distribution),
+      requiredChildOrdering = Seq(outputOrdering, outputOrdering)
+    )
+    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    assertDistributionRequirementsAreSatisfied(outputPlan)
+    if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) {
+      fail(s"No Exchanges should have been added:\n$outputPlan")
+    }
+  }
+
+  // 
---------------------------------------------------------------------------------------------
+}
+
+// Used for unit-testing EnsureRequirements
+private case class DummySparkPlan(
+    override val children: Seq[SparkPlan] = Nil,
+    override val outputOrdering: Seq[SortOrder] = Nil,
+    override val outputPartitioning: Partitioning = UnknownPartitioning(0),
+    override val requiredChildDistribution: Seq[Distribution] = Nil,
+    override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil
+  ) extends SparkPlan {
+  override protected def doExecute(): RDD[InternalRow] = throw new 
NotImplementedError
+  override def output: Seq[Attribute] = Seq.empty
 }

