Repository: spark
Updated Branches:
  refs/heads/master 2c73d2a94 -> eb45b52e8


[SPARK-21865][SQL] simplify the distribution semantic of Spark SQL

## What changes were proposed in this pull request?

**The current shuffle planning logic**

1. Each operator specifies the distribution requirements for its children, via 
the `Distribution` interface.
2. Each operator specifies its output partitioning, via the `Partitioning` 
interface.
3. `Partitioning.satisfies` determines whether a `Partitioning` can satisfy a 
`Distribution`.
4. For each operator, check each of its children and add a shuffle node above 
the child if its partitioning cannot satisfy the required distribution.
5. For each operator, check whether its children's output partitionings are 
compatible with each other, via `Partitioning.compatibleWith`.
6. If the check in step 5 fails, add a shuffle above each child.
7. Try to eliminate the shuffles added in step 6, via `Partitioning.guarantees`.

This design has a major problem with the definition of "compatible".

`Partitioning.compatibleWith` is not well defined. In principle, a 
`Partitioning` cannot know whether it is compatible with another 
`Partitioning` without more information from the operator. For example, in 
`t1 join t2 on t1.a = t2.b`, `HashPartitioning(a, 10)` should be compatible 
with `HashPartitioning(b, 10)`, but the partitioning itself doesn't know that.
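
To make the join example concrete, here is a minimal sketch. It uses toy types 
rather than the real Spark classes: `ToyHashPartitioning`, 
`compatibleWithoutContext` and `coPartitionedForJoin` are hypothetical names 
introduced only for illustration, and columns are modeled as plain strings.

```scala
object CompatibleSketch {
  // Toy stand-in for a hash partitioning; NOT Spark's HashPartitioning.
  final case class ToyHashPartitioning(keys: Seq[String], numPartitions: Int)

  // All a partitioning can check on its own: is the other one identical?
  def compatibleWithoutContext(
      l: ToyHashPartitioning,
      r: ToyHashPartitioning): Boolean =
    l == r

  // What the join operator can check, because it knows which left column is
  // equated with which right column (the pairs in joinKeys).
  // Assumption: both sides use the same hash function and key types.
  def coPartitionedForJoin(
      l: ToyHashPartitioning,
      r: ToyHashPartitioning,
      joinKeys: Seq[(String, String)]): Boolean =
    l.numPartitions == r.numPartitions &&
      l.keys == joinKeys.map(_._1) &&
      r.keys == joinKeys.map(_._2)
}
```

With `ToyHashPartitioning(Seq("a"), 10)` on the left and 
`ToyHashPartitioning(Seq("b"), 10)` on the right, the context-free check 
rejects the pair, while the join-aware check accepts it once it is told that 
`a` is joined with `b`.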

As a result, `Partitioning.compatibleWith` currently always returns false 
except for literals, which makes it almost useless. This also means that if an 
operator has distribution requirements for multiple children, Spark always 
adds shuffle nodes to all the children (although some of them can be 
eliminated). However, there is no guarantee that the children's output 
partitionings are compatible with each other after adding these shuffles; we 
just assume that the operator will only specify `ClusteredDistribution` for 
multiple children.

I think it's very hard to guarantee that children are co-partitioned for all 
kinds of operators, and we cannot even give a clear definition of 
co-partitioning between distributions like `ClusteredDistribution(a,b)` and 
`ClusteredDistribution(c)`.

I think we should drop the "compatible" concept from the distribution model 
and let the operator achieve co-partitioning through special distribution 
requirements.

**Proposed shuffle planning logic after this PR**
(The first 4 steps are the same as before.)
1. Each operator specifies the distribution requirements for its children, via 
the `Distribution` interface.
2. Each operator specifies its output partitioning, via the `Partitioning` 
interface.
3. `Partitioning.satisfies` determines whether a `Partitioning` can satisfy a 
`Distribution`.
4. For each operator, check each of its children and add a shuffle node above 
the child if its partitioning cannot satisfy the required distribution.
5. For each operator, check whether its children's output partitionings have 
the same number of partitions.
6. If the check in step 5 fails, pick the maximum number of partitions among 
the children's output partitionings, and add a shuffle to each child whose 
number of partitions doesn't equal that maximum.
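
Steps 5 and 6 above can be sketched roughly as follows. This is a simplified 
model, not the actual `EnsureRequirements` code: `Child` and 
`alignNumPartitions` are hypothetical names, a shuffle is modeled as a boolean 
flag, and the sketch ignores distributions that explicitly require a number of 
partitions.

```scala
object AlignPartitionsSketch {
  // Hypothetical stand-in for a child plan and its output partition count.
  final case class Child(name: String, numPartitions: Int, shuffled: Boolean = false)

  // Step 5: check whether all children have the same number of partitions.
  // Step 6: if not, re-shuffle every child that differs from the maximum.
  def alignNumPartitions(children: Seq[Child]): Seq[Child] = {
    val distinctCounts = children.map(_.numPartitions).toSet
    if (distinctCounts.size <= 1) {
      children // already aligned (or no children), nothing to do
    } else {
      val target = distinctCounts.max
      children.map { c =>
        if (c.numPartitions == target) c
        else c.copy(numPartitions = target, shuffled = true)
      }
    }
  }
}
```

For example, given children with 5 and 10 partitions, only the 5-partition 
child is marked as shuffled and both end up with 10 partitions.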

The new distribution model is very simple: there is only one kind of 
relationship, `Partitioning.satisfies`. For multiple children, Spark only 
guarantees that they have the same number of partitions, and it's the 
operator's responsibility to leverage this guarantee to achieve more 
complicated requirements. For example, non-broadcast joins can use the newly 
added `HashClusteredDistribution` to achieve co-partitioning.
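
The difference between the two clustering requirements can be illustrated 
with a toy model of the `satisfies` check for the `HashClusteredDistribution` 
class added in the diff below. The `Toy*` types here are hypothetical, 
expressions are modeled as strings, and the required number of partitions is 
left out for brevity.

```scala
object SatisfiesSketch {
  sealed trait ToyDistribution
  // Rows with equal clustering values must land in the same partition.
  final case class ToyClustered(clustering: Seq[String]) extends ToyDistribution
  // Rows must be placed by hashing exactly these expressions, in this order.
  final case class ToyHashClustered(expressions: Seq[String]) extends ToyDistribution

  final case class ToyHashPartitioning(expressions: Seq[String], numPartitions: Int) {
    def satisfies(required: ToyDistribution): Boolean = required match {
      // Strict: same expressions in the same order.
      case ToyHashClustered(exprs) => expressions == exprs
      // Loose: hashing on a subset of the clustering keys is enough.
      case ToyClustered(clustering) => expressions.forall(e => clustering.contains(e))
    }
  }
}
```

In this model, hashing on `a` alone satisfies `ToyClustered(Seq("a", "b"))` 
but not `ToyHashClustered(Seq("a", "b"))`. Because the stricter requirement 
pins down the hashed expressions, two children with the same number of 
partitions and matching hash-clustered keys end up co-partitioned.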

## How was this patch tested?

Existing tests.

Author: Wenchen Fan

Closes #19080 from cloud-fan/exchange.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eb45b52e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eb45b52e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eb45b52e

Branch: refs/heads/master
Commit: eb45b52e826ea9cea48629760db35ef87f91fea0
Parents: 2c73d2a
Author: Wenchen Fan
Authored: Mon Jan 8 19:41:41 2018 +0800
Committer: Wenchen Fan
Committed: Mon Jan 8 19:41:41 2018 +0800

----------------------------------------------------------------------
 .../catalyst/plans/physical/partitioning.scala  | 286 +++++++------------
 .../spark/sql/catalyst/PartitioningSuite.scala  |  55 ----
 .../apache/spark/sql/execution/SparkPlan.scala  |  16 +-
 .../execution/exchange/EnsureRequirements.scala | 120 +++-----
 .../execution/joins/ShuffledHashJoinExec.scala  |   2 +-
 .../sql/execution/joins/SortMergeJoinExec.scala |   2 +-
 .../apache/spark/sql/execution/objects.scala    |   2 +-
 .../spark/sql/execution/PlannerSuite.scala      |  81 ++----
 8 files changed, 194 insertions(+), 370 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index e57c842..0189bd7 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -30,18 +30,43 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
  *  - Intra-partition ordering of data: In this case the distribution 
describes guarantees made
  *    about how tuples are distributed within a single partition.
  */
-sealed trait Distribution
+sealed trait Distribution {
+  /**
+   * The required number of partitions for this distribution. If it's None, 
then any number of
+   * partitions is allowed for this distribution.
+   */
+  def requiredNumPartitions: Option[Int]
+
+  /**
+   * Creates a default partitioning for this distribution, which can satisfy 
this distribution while
+   * matching the given number of partitions.
+   */
+  def createPartitioning(numPartitions: Int): Partitioning
+}
 
 /**
  * Represents a distribution where no promises are made about co-location of 
data.
  */
-case object UnspecifiedDistribution extends Distribution
+case object UnspecifiedDistribution extends Distribution {
+  override def requiredNumPartitions: Option[Int] = None
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    throw new IllegalStateException("UnspecifiedDistribution does not have 
default partitioning.")
+  }
+}
 
 /**
  * Represents a distribution that only has a single partition and all tuples 
of the dataset
  * are co-located.
  */
-case object AllTuples extends Distribution
+case object AllTuples extends Distribution {
+  override def requiredNumPartitions: Option[Int] = Some(1)
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    assert(numPartitions == 1, "The default partitioning of AllTuples can only 
have 1 partition.")
+    SinglePartition
+  }
+}
 
 /**
  * Represents data where tuples that share the same values for the `clustering`
@@ -51,12 +76,41 @@ case object AllTuples extends Distribution
  */
 case class ClusteredDistribution(
     clustering: Seq[Expression],
-    numPartitions: Option[Int] = None) extends Distribution {
+    requiredNumPartitions: Option[Int] = None) extends Distribution {
   require(
     clustering != Nil,
     "The clustering expressions of a ClusteredDistribution should not be Nil. 
" +
       "An AllTuples should be used to represent a distribution that only has " 
+
       "a single partition.")
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == 
numPartitions,
+      s"This ClusteredDistribution requires ${requiredNumPartitions.get} 
partitions, but " +
+        s"the actual number of partitions is $numPartitions.")
+    HashPartitioning(clustering, numPartitions)
+  }
+}
+
+/**
+ * Represents data where tuples have been clustered according to the hash of 
the given
+ * `expressions`. The hash function is defined as 
`HashPartitioning.partitionIdExpression`, so only
+ * [[HashPartitioning]] can satisfy this distribution.
+ *
+ * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given 
a tuple and the
+ * number of partitions, this distribution strictly requires which partition 
the tuple should be in.
+ */
+case class HashClusteredDistribution(expressions: Seq[Expression]) extends 
Distribution {
+  require(
+    expressions != Nil,
+    "The expressions for hash of a HashPartitionedDistribution should not be 
Nil. " +
+      "An AllTuples should be used to represent a distribution that only has " 
+
+      "a single partition.")
+
+  override def requiredNumPartitions: Option[Int] = None
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    HashPartitioning(expressions, numPartitions)
+  }
 }
 
 /**
@@ -73,46 +127,31 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) 
extends Distribution {
       "An AllTuples should be used to represent a distribution that only has " 
+
       "a single partition.")
 
-  // TODO: This is not really valid...
-  def clustering: Set[Expression] = ordering.map(_.child).toSet
+  override def requiredNumPartitions: Option[Int] = None
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    RangePartitioning(ordering, numPartitions)
+  }
 }
 
 /**
  * Represents data where tuples are broadcasted to every node. It is quite 
common that the
  * entire set of tuples is transformed into different data structure.
  */
-case class BroadcastDistribution(mode: BroadcastMode) extends Distribution
+case class BroadcastDistribution(mode: BroadcastMode) extends Distribution {
+  override def requiredNumPartitions: Option[Int] = Some(1)
+
+  override def createPartitioning(numPartitions: Int): Partitioning = {
+    assert(numPartitions == 1,
+      "The default partitioning of BroadcastDistribution can only have 1 
partition.")
+    BroadcastPartitioning(mode)
+  }
+}
 
 /**
- * Describes how an operator's output is split across partitions. The 
`compatibleWith`,
- * `guarantees`, and `satisfies` methods describe relationships between child 
partitionings,
- * target partitionings, and [[Distribution]]s. These relations are described 
more precisely in
- * their individual method docs, but at a high level:
- *
- *  - `satisfies` is a relationship between partitionings and distributions.
- *  - `compatibleWith` is relationships between an operator's child output 
partitionings.
- *  - `guarantees` is a relationship between a child's existing output 
partitioning and a target
- *     output partitioning.
- *
- *  Diagrammatically:
- *
- *            +--------------+
- *            | Distribution |
- *            +--------------+
- *                    ^
- *                    |
- *               satisfies
- *                    |
- *            +--------------+                  +--------------+
- *            |    Child     |                  |    Target    |
- *       +----| Partitioning |----guarantees--->| Partitioning |
- *       |    +--------------+                  +--------------+
- *       |            ^
- *       |            |
- *       |     compatibleWith
- *       |            |
- *       +------------+
- *
+ * Describes how an operator's output is split across partitions. It has 2 
major properties:
+ *   1. number of partitions.
+ *   2. if it can satisfy a given distribution.
  */
 sealed trait Partitioning {
   /** Returns the number of partitions that the data is split across */
@@ -123,113 +162,35 @@ sealed trait Partitioning {
    * to satisfy the partitioning scheme mandated by the `required` 
[[Distribution]],
    * i.e. the current dataset does not need to be re-partitioned for the 
`required`
    * Distribution (it is possible that tuples within a partition need to be 
reorganized).
-   */
-  def satisfies(required: Distribution): Boolean
-
-  /**
-   * Returns true iff we can say that the partitioning scheme of this 
[[Partitioning]]
-   * guarantees the same partitioning scheme described by `other`.
-   *
-   * Compatibility of partitionings is only checked for operators that have 
multiple children
-   * and that require a specific child output [[Distribution]], such as joins.
-   *
-   * Intuitively, partitionings are compatible if they route the same 
partitioning key to the same
-   * partition. For instance, two hash partitionings are only compatible if 
they produce the same
-   * number of output partitionings and hash records according to the same 
hash function and
-   * same partitioning key schema.
-   *
-   * Put another way, two partitionings are compatible with each other if they 
satisfy all of the
-   * same distribution guarantees.
-   */
-  def compatibleWith(other: Partitioning): Boolean
-
-  /**
-   * Returns true iff we can say that the partitioning scheme of this 
[[Partitioning]] guarantees
-   * the same partitioning scheme described by `other`. If a 
`A.guarantees(B)`, then repartitioning
-   * the child's output according to `B` will be unnecessary. `guarantees` is 
used as a performance
-   * optimization to allow the exchange planner to avoid redundant 
repartitionings. By default,
-   * a partitioning only guarantees partitionings that are equal to itself 
(i.e. the same number
-   * of partitions, same strategy (range or hash), etc).
-   *
-   * In order to enable more aggressive optimization, this strict equality 
check can be relaxed.
-   * For example, say that the planner needs to repartition all of an 
operator's children so that
-   * they satisfy the [[AllTuples]] distribution. One way to do this is to 
repartition all children
-   * to have the [[SinglePartition]] partitioning. If one of the operator's 
children already happens
-   * to be hash-partitioned with a single partition then we do not need to 
re-shuffle this child;
-   * this repartitioning can be avoided if a single-partition 
[[HashPartitioning]] `guarantees`
-   * [[SinglePartition]].
-   *
-   * The SinglePartition example given above is not particularly interesting; 
guarantees' real
-   * value occurs for more advanced partitioning strategies. SPARK-7871 will 
introduce a notion
-   * of null-safe partitionings, under which partitionings can specify whether 
rows whose
-   * partitioning keys contain null values will be grouped into the same 
partition or whether they
-   * will have an unknown / random distribution. If a partitioning does not 
require nulls to be
-   * clustered then a partitioning which _does_ cluster nulls will guarantee 
the null clustered
-   * partitioning. The converse is not true, however: a partitioning which 
clusters nulls cannot
-   * be guaranteed by one which does not cluster them. Thus, in general 
`guarantees` is not a
-   * symmetric relation.
    *
-   * Another way to think about `guarantees`: if `A.guarantees(B)`, then any 
partitioning of rows
-   * produced by `A` could have also been produced by `B`.
+   * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], 
and [[AllTuples]] if
+   * the [[Partitioning]] only have one partition. Implementations can 
overwrite this method with
+   * special logic.
    */
-  def guarantees(other: Partitioning): Boolean = this == other
-}
-
-object Partitioning {
-  def allCompatible(partitionings: Seq[Partitioning]): Boolean = {
-    // Note: this assumes transitivity
-    partitionings.sliding(2).map {
-      case Seq(a) => true
-      case Seq(a, b) =>
-        if (a.numPartitions != b.numPartitions) {
-          assert(!a.compatibleWith(b) && !b.compatibleWith(a))
-          false
-        } else {
-          a.compatibleWith(b) && b.compatibleWith(a)
-        }
-    }.forall(_ == true)
-  }
-}
-
-case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
-  override def satisfies(required: Distribution): Boolean = required match {
+  def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
+    case AllTuples => numPartitions == 1
     case _ => false
   }
-
-  override def compatibleWith(other: Partitioning): Boolean = false
-
-  override def guarantees(other: Partitioning): Boolean = false
 }
 
+case class UnknownPartitioning(numPartitions: Int) extends Partitioning
+
 /**
  * Represents a partitioning where rows are distributed evenly across output 
partitions
  * by starting from a random target partition number and distributing rows in 
a round-robin
  * fashion. This partitioning is used when implementing the 
DataFrame.repartition() operator.
  */
-case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning {
-  override def satisfies(required: Distribution): Boolean = required match {
-    case UnspecifiedDistribution => true
-    case _ => false
-  }
-
-  override def compatibleWith(other: Partitioning): Boolean = false
-
-  override def guarantees(other: Partitioning): Boolean = false
-}
+case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning
 
 case object SinglePartition extends Partitioning {
   val numPartitions = 1
 
   override def satisfies(required: Distribution): Boolean = required match {
     case _: BroadcastDistribution => false
-    case ClusteredDistribution(_, desiredPartitions) => 
desiredPartitions.forall(_ == 1)
+    case ClusteredDistribution(_, Some(requiredNumPartitions)) => 
requiredNumPartitions == 1
     case _ => true
   }
-
-  override def compatibleWith(other: Partitioning): Boolean = 
other.numPartitions == 1
-
-  override def guarantees(other: Partitioning): Boolean = other.numPartitions 
== 1
 }
 
 /**
@@ -244,22 +205,19 @@ case class HashPartitioning(expressions: Seq[Expression], 
numPartitions: Int)
   override def nullable: Boolean = false
   override def dataType: DataType = IntegerType
 
-  override def satisfies(required: Distribution): Boolean = required match {
-    case UnspecifiedDistribution => true
-    case ClusteredDistribution(requiredClustering, desiredPartitions) =>
-      expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) 
&&
-        desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = 
None, returns true
-    case _ => false
-  }
-
-  override def compatibleWith(other: Partitioning): Boolean = other match {
-    case o: HashPartitioning => this.semanticEquals(o)
-    case _ => false
-  }
-
-  override def guarantees(other: Partitioning): Boolean = other match {
-    case o: HashPartitioning => this.semanticEquals(o)
-    case _ => false
+  override def satisfies(required: Distribution): Boolean = {
+    super.satisfies(required) || {
+      required match {
+        case h: HashClusteredDistribution =>
+          expressions.length == h.expressions.length && 
expressions.zip(h.expressions).forall {
+            case (l, r) => l.semanticEquals(r)
+          }
+        case ClusteredDistribution(requiredClustering, requiredNumPartitions) 
=>
+          expressions.forall(x => 
requiredClustering.exists(_.semanticEquals(x))) &&
+            (requiredNumPartitions.isEmpty || requiredNumPartitions.get == 
numPartitions)
+        case _ => false
+      }
+    }
   }
 
   /**
@@ -288,25 +246,18 @@ case class RangePartitioning(ordering: Seq[SortOrder], 
numPartitions: Int)
   override def nullable: Boolean = false
   override def dataType: DataType = IntegerType
 
-  override def satisfies(required: Distribution): Boolean = required match {
-    case UnspecifiedDistribution => true
-    case OrderedDistribution(requiredOrdering) =>
-      val minSize = Seq(requiredOrdering.size, ordering.size).min
-      requiredOrdering.take(minSize) == ordering.take(minSize)
-    case ClusteredDistribution(requiredClustering, desiredPartitions) =>
-      ordering.map(_.child).forall(x => 
requiredClustering.exists(_.semanticEquals(x))) &&
-        desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = 
None, returns true
-    case _ => false
-  }
-
-  override def compatibleWith(other: Partitioning): Boolean = other match {
-    case o: RangePartitioning => this.semanticEquals(o)
-    case _ => false
-  }
-
-  override def guarantees(other: Partitioning): Boolean = other match {
-    case o: RangePartitioning => this.semanticEquals(o)
-    case _ => false
+  override def satisfies(required: Distribution): Boolean = {
+    super.satisfies(required) || {
+      required match {
+        case OrderedDistribution(requiredOrdering) =>
+          val minSize = Seq(requiredOrdering.size, ordering.size).min
+          requiredOrdering.take(minSize) == ordering.take(minSize)
+        case ClusteredDistribution(requiredClustering, requiredNumPartitions) 
=>
+          ordering.map(_.child).forall(x => 
requiredClustering.exists(_.semanticEquals(x))) &&
+            (requiredNumPartitions.isEmpty || requiredNumPartitions.get == 
numPartitions)
+        case _ => false
+      }
+    }
   }
 }
 
@@ -347,20 +298,6 @@ case class PartitioningCollection(partitionings: 
Seq[Partitioning])
   override def satisfies(required: Distribution): Boolean =
     partitionings.exists(_.satisfies(required))
 
-  /**
-   * Returns true if any `partitioning` of this collection is compatible with
-   * the given [[Partitioning]].
-   */
-  override def compatibleWith(other: Partitioning): Boolean =
-    partitionings.exists(_.compatibleWith(other))
-
-  /**
-   * Returns true if any `partitioning` of this collection guarantees
-   * the given [[Partitioning]].
-   */
-  override def guarantees(other: Partitioning): Boolean =
-    partitionings.exists(_.guarantees(other))
-
   override def toString: String = {
     partitionings.map(_.toString).mkString("(", " or ", ")")
   }
@@ -377,9 +314,4 @@ case class BroadcastPartitioning(mode: BroadcastMode) 
extends Partitioning {
     case BroadcastDistribution(m) if m == mode => true
     case _ => false
   }
-
-  override def compatibleWith(other: Partitioning): Boolean = other match {
-    case BroadcastPartitioning(m) if m == mode => true
-    case _ => false
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
deleted file mode 100644
index 5b802cc..0000000
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import org.apache.spark.SparkFunSuite
-import 
org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, 
Literal}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, 
HashPartitioning}
-
-class PartitioningSuite extends SparkFunSuite {
-  test("HashPartitioning compatibility should be sensitive to expression 
ordering (SPARK-9785)") {
-    val expressions = Seq(Literal(2), Literal(3))
-    // Consider two HashPartitionings that have the same _set_ of hash 
expressions but which are
-    // created with different orderings of those expressions:
-    val partitioningA = HashPartitioning(expressions, 100)
-    val partitioningB = HashPartitioning(expressions.reverse, 100)
-    // These partitionings are not considered equal:
-    assert(partitioningA != partitioningB)
-    // However, they both satisfy the same clustered distribution:
-    val distribution = ClusteredDistribution(expressions)
-    assert(partitioningA.satisfies(distribution))
-    assert(partitioningB.satisfies(distribution))
-    // These partitionings compute different hashcodes for the same input row:
-    def computeHashCode(partitioning: HashPartitioning): Int = {
-      val hashExprProj = new 
InterpretedMutableProjection(partitioning.expressions, Seq.empty)
-      hashExprProj.apply(InternalRow.empty).hashCode()
-    }
-    assert(computeHashCode(partitioningA) != computeHashCode(partitioningB))
-    // Thus, these partitionings are incompatible:
-    assert(!partitioningA.compatibleWith(partitioningB))
-    assert(!partitioningB.compatibleWith(partitioningA))
-    assert(!partitioningA.guarantees(partitioningB))
-    assert(!partitioningB.guarantees(partitioningA))
-
-    // Just to be sure that we haven't cheated by having these methods always 
return false,
-    // check that identical partitionings are still compatible with and 
guarantee each other:
-    assert(partitioningA === partitioningA)
-    assert(partitioningA.guarantees(partitioningA))
-    assert(partitioningA.compatibleWith(partitioningA))
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 787c1cf..82300ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -94,7 +94,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with 
Logging with Serializ
   /** Specifies how data is partitioned across different nodes in the cluster. 
*/
   def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG 
WIDTH!
 
-  /** Specifies any partition requirements on the input data for this 
operator. */
+  /**
+   * Specifies the data distribution requirements of all the children for this 
operator. By default
+   * it's [[UnspecifiedDistribution]] for each child, which means each child 
can have any
+   * distribution.
+   *
+   * If an operator overwrites this method, and specifies distribution 
requirements(excluding
+   * [[UnspecifiedDistribution]] and [[BroadcastDistribution]]) for more than 
one child, Spark
+   * guarantees that the outputs of these children will have same number of 
partitions, so that the
+   * operator can safely zip partitions of these children's result RDDs. Some 
operators can leverage
+   * this guarantee to satisfy some interesting requirement, e.g., 
non-broadcast joins can specify
+   * HashClusteredDistribution(a,b) for its left child, and specify 
HashClusteredDistribution(c,d)
+   * for its right child, then it's guaranteed that left and right child are 
co-partitioned by
+   * a,b/c,d, which means tuples of same value are in the partitions of same 
index, e.g.,
+   * (a=1,b=2) and (c=1,d=2) are both in the second partition of left and 
right child.
+   */
   def requiredChildDistribution: Seq[Distribution] =
     Seq.fill(children.size)(UnspecifiedDistribution)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index c8e236b..e3d2838 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -47,23 +47,6 @@ case class EnsureRequirements(conf: SQLConf) extends 
Rule[SparkPlan] {
   }
 
   /**
-   * Given a required distribution, returns a partitioning that satisfies that 
distribution.
-   * @param requiredDistribution The distribution that is required by the 
operator
-   * @param numPartitions Used when the distribution doesn't require a 
specific number of partitions
-   */
-  private def createPartitioning(
-      requiredDistribution: Distribution,
-      numPartitions: Int): Partitioning = {
-    requiredDistribution match {
-      case AllTuples => SinglePartition
-      case ClusteredDistribution(clustering, desiredPartitions) =>
-        HashPartitioning(clustering, 
desiredPartitions.getOrElse(numPartitions))
-      case OrderedDistribution(ordering) => RangePartitioning(ordering, 
numPartitions)
-      case dist => sys.error(s"Do not know how to satisfy distribution $dist")
-    }
-  }
-
-  /**
    * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive 
query execution is enabled
    * and partitioning schemes of these [[ShuffleExchangeExec]]s support 
[[ExchangeCoordinator]].
    */
@@ -88,8 +71,9 @@ case class EnsureRequirements(conf: SQLConf) extends 
Rule[SparkPlan] {
         // shuffle data when we have more than one children because data 
generated by
         // these children may not be partitioned in the same way.
         // Please see the comment in withCoordinator for more details.
-        val supportsDistribution =
-          
requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
+        val supportsDistribution = requiredChildDistributions.forall { dist =>
+          dist.isInstanceOf[ClusteredDistribution] || 
dist.isInstanceOf[HashClusteredDistribution]
+        }
         children.length > 1 && supportsDistribution
       }
 
@@ -142,8 +126,7 @@ case class EnsureRequirements(conf: SQLConf) extends 
Rule[SparkPlan] {
             //
             // It will be great to introduce a new Partitioning to represent 
the post-shuffle
             // partitions when one post-shuffle partition includes multiple 
pre-shuffle partitions.
-            val targetPartitioning =
-              createPartitioning(distribution, defaultNumPreShufflePartitions)
+            val targetPartitioning = 
distribution.createPartitioning(defaultNumPreShufflePartitions)
             assert(targetPartitioning.isInstanceOf[HashPartitioning])
             ShuffleExchangeExec(targetPartitioning, child, Some(coordinator))
         }
@@ -162,71 +145,56 @@ case class EnsureRequirements(conf: SQLConf) extends 
Rule[SparkPlan] {
     assert(requiredChildDistributions.length == children.length)
     assert(requiredChildOrderings.length == children.length)
 
-    // Ensure that the operator's children satisfy their output distribution 
requirements:
+    // Ensure that the operator's children satisfy their output distribution 
requirements.
     children = children.zip(requiredChildDistributions).map {
       case (child, distribution) if 
child.outputPartitioning.satisfies(distribution) =>
         child
       case (child, BroadcastDistribution(mode)) =>
         BroadcastExchangeExec(mode, child)
       case (child, distribution) =>
-        ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
+        val numPartitions = distribution.requiredNumPartitions
+          .getOrElse(defaultNumPreShufflePartitions)
+        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
     }
 
-    // If the operator has multiple children and specifies child output distributions (e.g. join),
-    // then the children's output partitionings must be compatible:
-    def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match {
-      case UnspecifiedDistribution => false
-      case BroadcastDistribution(_) => false
+    // Get the indexes of children which have specified distribution requirements and need to have
+    // same number of partitions.
+    val childrenIndexes = requiredChildDistributions.zipWithIndex.filter {
+      case (UnspecifiedDistribution, _) => false
+      case (_: BroadcastDistribution, _) => false
       case _ => true
-    }
-    if (children.length > 1
-        && requiredChildDistributions.exists(requireCompatiblePartitioning)
-        && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+    }.map(_._2)
 
-      // First check if the existing partitions of the children all match. This means they are
-      // partitioned by the same partitioning into the same number of partitions. In that case,
-      // don't try to make them match `defaultPartitions`, just use the existing partitioning.
-      val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
-      val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
-        case (child, distribution) =>
-          child.outputPartitioning.guarantees(
-            createPartitioning(distribution, maxChildrenNumPartitions))
+    val childrenNumPartitions =
+      childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet
+
+    if (childrenNumPartitions.size > 1) {
+      // Get the number of partitions which is explicitly required by the distributions.
+      val requiredNumPartitions = {
+        val numPartitionsSet = childrenIndexes.flatMap {
+          index => requiredChildDistributions(index).requiredNumPartitions
+        }.toSet
+        assert(numPartitionsSet.size <= 1,
+          s"$operator have incompatible requirements of the number of partitions for its children")
+        numPartitionsSet.headOption
       }
 
-      children = if (useExistingPartitioning) {
-        // We do not need to shuffle any child's output.
-        children
-      } else {
-        // We need to shuffle at least one child's output.
-        // Now, we will determine the number of partitions that will be used by created
-        // partitioning schemes.
-        val numPartitions = {
-          // Let's see if we need to shuffle all child's outputs when we use
-          // maxChildrenNumPartitions.
-          val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
-            case (child, distribution) =>
-              !child.outputPartitioning.guarantees(
-                createPartitioning(distribution, maxChildrenNumPartitions))
-          }
-          // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
-          // number of partitions. Otherwise, we use maxChildrenNumPartitions.
-          if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
-        }
+      val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max)
 
-        children.zip(requiredChildDistributions).map {
-          case (child, distribution) =>
-            val targetPartitioning = createPartitioning(distribution, numPartitions)
-            if (child.outputPartitioning.guarantees(targetPartitioning)) {
-              child
-            } else {
-              child match {
-                // If child is an exchange, we replace it with
-                // a new one having targetPartitioning.
-                case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c)
-                case _ => ShuffleExchangeExec(targetPartitioning, child)
-              }
+      children = children.zip(requiredChildDistributions).zipWithIndex.map {
+        case ((child, distribution), index) if childrenIndexes.contains(index) =>
+          if (child.outputPartitioning.numPartitions == targetNumPartitions) {
+            child
+          } else {
+            val defaultPartitioning = distribution.createPartitioning(targetNumPartitions)
+            child match {
+              // If child is an exchange, we replace it with a new one having defaultPartitioning.
+              case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c)
+              case _ => ShuffleExchangeExec(defaultPartitioning, child)
+            }
           }
-        }
+
+        case ((child, _), _) => child
       }
     }
 
@@ -324,10 +292,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator @ ShuffleExchangeExec(partitioning, child, _) =>
-      child.children match {
-        case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil =>
-          if (childPartitioning.guarantees(partitioning)) child else operator
+    // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
+    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
+      child.outputPartitioning match {
+        case lower: HashPartitioning if upper.semanticEquals(lower) => child
         case _ => operator
       }
     case operator: SparkPlan =>

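The partition-count alignment introduced in the `EnsureRequirements` hunk above can be sketched outside Spark as follows. This is a minimal, self-contained model under stated assumptions, not the Spark implementation: `Dist`, `Clustered`, `Child`, and `AlignPartitions.align` are hypothetical stand-ins for `Distribution`, `ClusteredDistribution`, `SparkPlan`, and the planner logic, and adding a shuffle is modeled as setting a flag on the child.

```scala
// Hypothetical stand-ins for Spark's Distribution hierarchy (illustration only).
sealed trait Dist { def requiredNumPartitions: Option[Int] = None }
case object Unspecified extends Dist
case object Broadcast extends Dist
final case class Clustered(
    keys: Seq[String],
    override val requiredNumPartitions: Option[Int] = None) extends Dist

// Hypothetical stand-in for a child plan; `shuffled` marks an added exchange.
final case class Child(name: String, numPartitions: Int, shuffled: Boolean = false)

object AlignPartitions {
  def align(children: Seq[Child], dists: Seq[Dist]): Seq[Child] = {
    require(children.length == dists.length)

    // Only children with a real distribution requirement must agree
    // on the number of partitions.
    val indexes = dists.zipWithIndex.collect { case (_: Clustered, i) => i }
    val counts = indexes.map(i => children(i).numPartitions).toSet

    if (counts.size <= 1) {
      children // already aligned, nothing to shuffle
    } else {
      // At most one explicitly required partition count is allowed.
      val required = indexes.flatMap(i => dists(i).requiredNumPartitions).toSet
      require(required.size <= 1, "incompatible requirements for the number of partitions")

      // Fall back to the largest existing count when none is required.
      val target = required.headOption.getOrElse(counts.max)
      children.zipWithIndex.map {
        case (c, i) if indexes.contains(i) && c.numPartitions != target =>
          c.copy(numPartitions = target, shuffled = true)
        case (c, _) => c
      }
    }
  }
}
```

For instance, with children of 5, 10, and 3 partitions where only the first two carry a `Clustered` requirement, the sketch reshuffles only the first child to 10 partitions and leaves the third untouched.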
http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 66e8031..897a4da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -46,7 +46,7 @@ case class ShuffledHashJoinExec(
     "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
 
   private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
     val buildDataSize = longMetric("buildDataSize")

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 9440541..2de2f30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -78,7 +78,7 @@ case class SortMergeJoinExec(
   }
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
 
   override def outputOrdering: Seq[SortOrder] = joinType match {
     // For inner join, orders of both sides keys should be kept.

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index d1bd8a7..03d1bbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -456,7 +456,7 @@ case class CoGroupExec(
     right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
+    HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
     leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/eb45b52e/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index b50642d..f8b26f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -260,11 +260,16 @@ class PlannerSuite extends SharedSQLContext {
   // do they satisfy the distribution requirements? As a result, we need at least four test cases.
 
   private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = {
-    if (outputPlan.children.length > 1
-        && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) {
-      val childPartitionings = outputPlan.children.map(_.outputPartitioning)
-      if (!Partitioning.allCompatible(childPartitionings)) {
-        fail(s"Partitionings are not compatible: $childPartitionings")
+    if (outputPlan.children.length > 1) {
+      val childPartitionings = outputPlan.children.zip(outputPlan.requiredChildDistribution)
+        .filter {
+          case (_, UnspecifiedDistribution) => false
+          case (_, _: BroadcastDistribution) => false
+          case _ => true
+        }.map(_._1.outputPartitioning)
+
+      if (childPartitionings.map(_.numPartitions).toSet.size > 1) {
+        fail(s"Partitionings doesn't have same number of partitions: $childPartitionings")
       }
     }
     outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach {
@@ -274,40 +279,7 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
-  test("EnsureRequirements with incompatible child partitionings which satisfy distribution") {
-    // Consider an operator that requires inputs that are clustered by two expressions (e.g.
-    // sort merge join where there are multiple columns in the equi-join condition)
-    val clusteringA = Literal(1) :: Nil
-    val clusteringB = Literal(2) :: Nil
-    val distribution = ClusteredDistribution(clusteringA ++ clusteringB)
-    // Say that the left and right inputs are each partitioned by _one_ of the two join columns:
-    val leftPartitioning = HashPartitioning(clusteringA, 1)
-    val rightPartitioning = HashPartitioning(clusteringB, 1)
-    // Individually, each input's partitioning satisfies the clustering distribution:
-    assert(leftPartitioning.satisfies(distribution))
-    assert(rightPartitioning.satisfies(distribution))
-    // However, these partitionings are not compatible with each other, so we still need to
-    // repartition both inputs prior to performing the join:
-    assert(!leftPartitioning.compatibleWith(rightPartitioning))
-    assert(!rightPartitioning.compatibleWith(leftPartitioning))
-    val inputPlan = DummySparkPlan(
-      children = Seq(
-        DummySparkPlan(outputPartitioning = leftPartitioning),
-        DummySparkPlan(outputPartitioning = rightPartitioning)
-      ),
-      requiredChildDistribution = Seq(distribution, distribution),
-      requiredChildOrdering = Seq(Seq.empty, Seq.empty)
-    )
-    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
-    assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) {
-      fail(s"Exchange should have been added:\n$outputPlan")
-    }
-  }
-
   test("EnsureRequirements with child partitionings with different numbers of output partitions") {
-    // This is similar to the previous test, except it checks that partitionings are not compatible
-    // unless they produce the same number of partitions.
     val clustering = Literal(1) :: Nil
     val distribution = ClusteredDistribution(clustering)
     val inputPlan = DummySparkPlan(
@@ -386,18 +358,15 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
-  test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") {
+  test("EnsureRequirements eliminates Exchange if child has same partitioning") {
     val distribution = ClusteredDistribution(Literal(1) :: Nil)
-    val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
-    val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
-    assert(!childPartitioning.satisfies(distribution))
-    val inputPlan = ShuffleExchangeExec(finalPartitioning,
-      DummySparkPlan(
-        children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
-        requiredChildDistribution = Seq(distribution),
-        requiredChildOrdering = Seq(Seq.empty)),
-      None)
+    val partitioning = HashPartitioning(Literal(1) :: Nil, 5)
+    assert(partitioning.satisfies(distribution))
 
+    val inputPlan = ShuffleExchangeExec(
+      partitioning,
+      DummySparkPlan(outputPartitioning = partitioning),
+      None)
     val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) {
@@ -407,17 +376,13 @@ class PlannerSuite extends SharedSQLContext {
 
   test("EnsureRequirements does not eliminate Exchange with different partitioning") {
     val distribution = ClusteredDistribution(Literal(1) :: Nil)
-    // Number of partitions differ
-    val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8)
-    val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
-    assert(!childPartitioning.satisfies(distribution))
-    val inputPlan = ShuffleExchangeExec(finalPartitioning,
-      DummySparkPlan(
-        children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
-        requiredChildDistribution = Seq(distribution),
-        requiredChildOrdering = Seq(Seq.empty)),
-      None)
+    val partitioning = HashPartitioning(Literal(2) :: Nil, 5)
+    assert(!partitioning.satisfies(distribution))
 
+    val inputPlan = ShuffleExchangeExec(
+      partitioning,
+      DummySparkPlan(outputPartitioning = partitioning),
+      None)
     val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) {

