Repository: spark
Updated Branches:
  refs/heads/master af441ddbd -> b6a873d6d


http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
new file mode 100644
index 0000000..de21d77
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import java.util.Random
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
+import 
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution._
+import org.apache.spark.util.MutablePair
+
+/**
+ * Performs a shuffle that will result in the desired `newPartitioning`.
+ */
+case class ShuffleExchange(
+    var newPartitioning: Partitioning,
+    child: SparkPlan,
+    @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode {
+
+  /**
+   * Displays as "Exchange". When an [[ExchangeCoordinator]] is attached, the name is followed by
+   * the coordinator's id.
+   */
+  override def nodeName: String = {
+    // Take the id from the coordinator itself rather than from the `Option` wrapping it, so that
+    // exchanges sharing one coordinator report the same id even if each holds its own `Some`.
+    val extraInfo = coordinator match {
+      case Some(exchangeCoordinator) =>
+        s"(coordinator id: ${System.identityHashCode(exchangeCoordinator)})"
+      case None => ""
+    }
+
+    val simpleNodeName = "Exchange"
+    s"$simpleNodeName$extraInfo"
+  }
+
+  override def outputPartitioning: Partitioning = newPartitioning
+
+  override def output: Seq[Attribute] = child.output
+
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+
+  override protected def doPrepare(): Unit = {
+    // If an ExchangeCoordinator is needed, we register this Exchange operator
+    // to the coordinator when we do prepare. It is important to make sure
+    // we register this operator right before the execution instead of registering it
+    // in the constructor because it is possible that we create new instances of
+    // Exchange operators when we transform the physical plan
+    // (then the ExchangeCoordinator will hold references of unneeded Exchanges).
+    // So, we should only call registerExchange just before we start to execute
+    // the plan.
+    coordinator.foreach(_.registerExchange(this))
+  }
+
+  /**
+   * Returns a [[ShuffleDependency]] that will partition rows of its child based on
+   * the partitioning scheme defined in `newPartitioning`. Those partitions of
+   * the returned ShuffleDependency will be the input of shuffle.
+   */
+  private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
+    ShuffleExchange.prepareShuffleDependency(
+      child.execute(), child.output, newPartitioning, serializer)
+  }
+
+  /**
+   * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset.
+   * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional
+   * partition start indices array. If this optional array is defined, the returned
+   * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array.
+   *
+   * Note: when start indices are supplied, this method reassigns `newPartitioning` to an
+   * [[UnknownPartitioning]] with one partition per start index.
+   */
+  private[sql] def preparePostShuffleRDD(
+      shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
+      specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
+    // If an array of partition start indices is provided, we need to use this array
+    // to create the ShuffledRowRDD. Also, we need to update newPartitioning to
+    // update the number of post-shuffle partitions.
+    specifiedPartitionStartIndices.foreach { indices =>
+      assert(newPartitioning.isInstanceOf[HashPartitioning])
+      newPartitioning = UnknownPartitioning(indices.length)
+    }
+    new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
+  }
+
+  /**
+   * Produces the shuffled output. With a coordinator, the coordinator builds the post-shuffle
+   * RDD; otherwise the dependency and post-shuffle RDD are built directly.
+   */
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    coordinator match {
+      case Some(exchangeCoordinator) =>
+        val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
+        assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
+        shuffleRDD
+      case None =>
+        val shuffleDependency = prepareShuffleDependency()
+        preparePostShuffleRDD(shuffleDependency)
+    }
+  }
+}
+
+object ShuffleExchange {
+  /** Creates a [[ShuffleExchange]] that is not attached to any [[ExchangeCoordinator]]. */
+  def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = {
+    ShuffleExchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
+  }
+
+  /**
+   * Determines whether records must be defensively copied before being sent to the shuffle.
+   * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
+   * shuffle code assumes that objects are immutable and hence does not perform its own defensive
+   * copying. In Spark SQL, however, operators' iterators return the same mutable `Row` object. In
+   * order to properly shuffle the output of these operators, we need to perform our own copying
+   * prior to sending records to the shuffle. This copying is expensive, so we try to avoid it
+   * whenever possible. This method encapsulates the logic for choosing when to copy.
+   *
+   * In the long run, we might want to push this logic into core's shuffle APIs so that we don't
+   * have to rely on knowledge of core internals here in SQL.
+   *
+   * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue.
+   *
+   * @param partitioner the partitioner for the shuffle
+   * @param serializer the serializer that will be used to write rows
+   * @return true if rows should be copied before being shuffled, false otherwise
+   */
+  private def needToCopyObjectsBeforeShuffle(
+      partitioner: Partitioner,
+      serializer: Serializer): Boolean = {
+    // Note: even though we only use the partitioner's `numPartitions` field, we require it to be
+    // passed instead of directly passing the number of partitions in order to guard against
+    // corner-cases where a partitioner constructed with `numPartitions` partitions may output
+    // fewer partitions (like RangePartitioner, for example).
+    val conf = SparkEnv.get.conf
+    SparkEnv.get.shuffleManager match {
+      case _: SortShuffleManager =>
+        val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+        if (partitioner.numPartitions <= bypassMergeThreshold) {
+          // If we're using the original SortShuffleManager and the number of output partitions is
+          // sufficiently small, then Spark will fall back to the hash-based shuffle write path,
+          // which doesn't buffer deserialized records.
+          // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
+          false
+        } else if (serializer.supportsRelocationOfSerializedObjects) {
+          // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records
+          // prior to sorting them. This optimization is only applied in cases where shuffle
+          // dependency does not specify an aggregator or ordering and the record serializer has
+          // certain properties. If this optimization is enabled, we can safely avoid the copy.
+          //
+          // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we
+          // only need to check whether the optimization is enabled and supported by our
+          // serializer.
+          false
+        } else {
+          // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we
+          // must copy.
+          true
+        }
+      case _: HashShuffleManager =>
+        // We're using hash-based shuffle, so we don't need to copy.
+        false
+      case _ =>
+        // Catch-all case to safely handle any future ShuffleManager implementations.
+        true
+    }
+  }
+
+  /**
+   * Returns a [[ShuffleDependency]] that will partition the rows of `rdd` based on
+   * the partitioning scheme defined in `newPartitioning`. Those partitions of
+   * the returned ShuffleDependency will be the input of shuffle.
+   *
+   * @param rdd the rows to shuffle
+   * @param outputAttributes the attributes of the rows in `rdd`; partitioning and sorting
+   *                         expressions are bound against them
+   * @param newPartitioning the desired output partitioning
+   * @param serializer the serializer used to write rows to the shuffle
+   */
+  private[sql] def prepareShuffleDependency(
+      rdd: RDD[InternalRow],
+      outputAttributes: Seq[Attribute],
+      newPartitioning: Partitioning,
+      serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
+    val part: Partitioner = newPartitioning match {
+      case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
+      case HashPartitioning(_, n) =>
+        new Partitioner {
+          override def numPartitions: Int = n
+          // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
+          // `HashPartitioning.partitionIdExpression` to produce partitioning key.
+          override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+        }
+      case RangePartitioning(sortingExpressions, numPartitions) =>
+        // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+        // partition bounds. To get accurate samples, we need to copy the mutable keys.
+        val rddForSampling = rdd.mapPartitionsInternal { iter =>
+          val mutablePair = new MutablePair[InternalRow, Null]()
+          iter.map(row => mutablePair.update(row.copy(), null))
+        }
+        implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes)
+        new RangePartitioner(numPartitions, rddForSampling, ascending = true)
+      case SinglePartition =>
+        new Partitioner {
+          override def numPartitions: Int = 1
+          override def getPartition(key: Any): Int = 0
+        }
+      case _ => sys.error(s"Exchange not implemented for $newPartitioning")
+      // TODO: Handle BroadcastPartitioning.
+    }
+    def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
+      case RoundRobinPartitioning(numPartitions) =>
+        // Distributes elements evenly across output partitions, starting from a random partition.
+        // The generator is seeded with the task's partition id, so the starting point is
+        // reproducible for a given input partition.
+        var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions)
+        (row: InternalRow) => {
+          // The HashPartitioner will handle the `mod` by the number of partitions
+          position += 1
+          position
+        }
+      case h: HashPartitioning =>
+        val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
+        row => projection(row).getInt(0)
+      case RangePartitioning(_, _) | SinglePartition => identity
+      case _ => sys.error(s"Exchange not implemented for $newPartitioning")
+    }
+    val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
+      if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+        rdd.mapPartitionsInternal { iter =>
+          val getPartitionKey = getPartitionKeyExtractor()
+          iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
+        }
+      } else {
+        rdd.mapPartitionsInternal { iter =>
+          val getPartitionKey = getPartitionKeyExtractor()
+          val mutablePair = new MutablePair[Int, InternalRow]()
+          iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
+        }
+      }
+    }
+
+    // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds
+    // are in the form of (partitionId, row) and every partitionId is in the expected range
+    // [0, part.numPartitions - 1], the partitioner of this dependency is a
+    // PartitionIdPassthrough.
+    new ShuffleDependency[Int, InternalRow, InternalRow](
+      rddWithPartitionIds,
+      new PartitionIdPassthrough(part.numPartitions),
+      Some(serializer))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index a64da22..ddc0882 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -17,9 +17,6 @@
 
 package org.apache.spark.sql.execution.joins
 
-import scala.concurrent._
-import scala.concurrent.duration._
-
 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -27,10 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, 
RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, 
Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, 
SQLExecution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, 
Distribution, Partitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.ThreadUtils
 import org.apache.spark.util.collection.CompactBuffer
 
 /**
@@ -52,60 +48,25 @@ case class BroadcastHashJoin(
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of 
output rows"))
 
-  val timeout: Duration = {
-    val timeoutValue = sqlContext.conf.broadcastTimeout
-    if (timeoutValue < 0) {
-      Duration.Inf
-    } else {
-      timeoutValue.seconds
-    }
-  }
-
   override def outputPartitioning: Partitioning = 
streamedPlan.outputPartitioning
 
-  override def requiredChildDistribution: Seq[Distribution] =
-    UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
-
-  // Use lazy so that we won't do broadcast when calling explain but still 
cache the broadcast value
-  // for the same query.
-  @transient
-  private lazy val broadcastFuture = {
-    // broadcastFuture is used in "doExecute". Therefore we can get the 
execution id correctly here.
-    val executionId = 
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
-      // This will run in another thread. Set the execution id so that we can 
connect these jobs
-      // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
-        // Note that we use .execute().collect() because we don't want to 
convert data to Scala
-        // types
-        val input: Array[InternalRow] = buildPlan.execute().map { row =>
-          row.copy()
-        }.collect()
-        // The following line doesn't run in a job so we cannot track the 
metric value. However, we
-        // have already tracked it in the above lines. So here we can use
-        // `SQLMetrics.nullLongMetric` to ignore it.
-        // TODO: move this check into HashedRelation
-        val hashed = if (canJoinKeyFitWithinLong) {
-          LongHashedRelation(
-            input.iterator, buildSideKeyGenerator, input.size)
-        } else {
-          HashedRelation(
-            input.iterator, buildSideKeyGenerator, input.size)
-        }
-        sparkContext.broadcast(hashed)
-      }
-    }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
-  }
-
-  protected override def doPrepare(): Unit = {
-    broadcastFuture
+  override def requiredChildDistribution: Seq[Distribution] = {
+    val mode = HashedRelationBroadcastMode(
+      canJoinKeyFitWithinLong,
+      rewriteKeyExpr(buildKeys),
+      buildPlan.output)
+    buildSide match {
+      case BuildLeft =>
+        BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
+      case BuildRight =>
+        UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
+    }
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
-    val broadcastRelation = Await.result(broadcastFuture, timeout)
-
+    val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
     streamedPlan.execute().mapPartitions { streamedIter =>
       val joinedRow = new JoinedRow()
       val hashTable = broadcastRelation.value
@@ -160,7 +121,7 @@ case class BroadcastHashJoin(
     */
   private def prepareBroadcast(ctx: CodegenContext): 
(Broadcast[HashedRelation], String) = {
     // create a name for HashedRelation
-    val broadcastRelation = Await.result(broadcastFuture, timeout)
+    val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
     val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
     val relationTerm = ctx.freshName("relation")
     val clsName = broadcastRelation.value.getClass.getName
@@ -362,9 +323,3 @@ case class BroadcastHashJoin(
     }
   }
 }
-
-object BroadcastHashJoin {
-
-  private[joins] val broadcastHashJoinExecutionContext = 
ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128))
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 4f1cfd2..1f99fbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.{InternalAccumulator, TaskContext}
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, 
Distribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -38,25 +39,25 @@ case class BroadcastLeftSemiJoinHash(
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of 
output rows"))
 
+  /**
+   * The left child has no distribution requirement. The right child must be broadcast: as a set
+   * of join keys when there is no extra join condition, or as a [[HashedRelation]] otherwise.
+   */
+  override def requiredChildDistribution: Seq[Distribution] = {
+    val rightMode = condition match {
+      case None => HashSetBroadcastMode(rightKeys, right.output)
+      case Some(_) =>
+        HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
+    }
+    List(UnspecifiedDistribution, BroadcastDistribution(rightMode))
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
-    val input = right.execute().map { row =>
-      row.copy()
-    }.collect()
-
     if (condition.isEmpty) {
-      val hashSet = buildKeyHashSet(input.toIterator)
-      val broadcastedRelation = sparkContext.broadcast(hashSet)
-
+      val broadcastedRelation = 
right.executeBroadcast[java.util.Set[InternalRow]]()
       left.execute().mapPartitionsInternal { streamIter =>
         hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows)
       }
     } else {
-      val hashRelation =
-        HashedRelation(input.toIterator, rightKeyGenerator, input.size)
-      val broadcastedRelation = sparkContext.broadcast(hashRelation)
-
+      val broadcastedRelation = right.executeBroadcast[HashedRelation]()
       left.execute().mapPartitionsInternal { streamIter =>
         val hashedRelation = broadcastedRelation.value
         
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 4585cbd..e8bd7f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.collection.{BitSet, CompactBuffer}
@@ -33,7 +33,6 @@ case class BroadcastNestedLoopJoin(
     buildSide: BuildSide,
     joinType: JoinType,
     condition: Option[Expression]) extends BinaryNode {
-  // TODO: Override requiredChildDistribution.
 
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of 
output rows"))
@@ -44,8 +43,15 @@ case class BroadcastNestedLoopJoin(
     case BuildLeft => (right, left)
   }
 
+  /**
+   * The child on the build side must be broadcast with [[IdentityBroadcastMode]]; the other
+   * (streamed) child has no distribution requirement.
+   */
+  override def requiredChildDistribution: Seq[Distribution] = {
+    val broadcastDist = BroadcastDistribution(IdentityBroadcastMode)
+    buildSide match {
+      case BuildLeft => List(broadcastDist, UnspecifiedDistribution)
+      case BuildRight => List(UnspecifiedDistribution, broadcastDist)
+    }
+  }
+
   private[this] def genResultProjection: InternalRow => InternalRow = {
-      UnsafeProjection.create(schema)
+    UnsafeProjection.create(schema)
   }
 
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
@@ -73,15 +79,14 @@ case class BroadcastNestedLoopJoin(
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
-    val broadcastedRelation =
-      sparkContext.broadcast(broadcast.execute().map { row =>
-        row.copy()
-      }.collect().toIndexedSeq)
+    val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
 
     /** All rows that either match both-way, or rows from streamed joined with 
nulls. */
     val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { 
streamedIter =>
+      val relation = broadcastedRelation.value
+
       val matchedRows = new CompactBuffer[InternalRow]
-      val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
+      val includedBroadcastTuples = new BitSet(relation.length)
       val joinedRow = new JoinedRow
 
       val leftNulls = new GenericMutableRow(left.output.size)
@@ -92,8 +97,8 @@ case class BroadcastNestedLoopJoin(
         var i = 0
         var streamRowMatched = false
 
-        while (i < broadcastedRelation.value.size) {
-          val broadcastedRow = broadcastedRelation.value(i)
+        while (i < relation.length) {
+          val broadcastedRow = relation(i)
           buildSide match {
             case BuildRight if boundCondition(joinedRow(streamedRow, 
broadcastedRow)) =>
               matchedRows += resultProj(joinedRow(streamedRow, 
broadcastedRow)).copy()

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 0220e0b..1cb6a00 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.LongSQLMetric
 
@@ -44,22 +45,7 @@ trait HashSemiJoin {
 
   protected def buildKeyHashSet(
       buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
-    val hashSet = new java.util.HashSet[InternalRow]()
-
-    // Create a Hash set of buildKeys
-    val rightKey = rightKeyGenerator
-    while (buildIter.hasNext) {
-      val currentRow = buildIter.next()
-      val rowKey = rightKey(currentRow)
-      if (!rowKey.anyNull) {
-        val keyExists = hashSet.contains(rowKey)
-        if (!keyExists) {
-          hashSet.add(rowKey.copy())
-        }
-      }
-    }
-
-    hashSet
+    HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter)
   }
 
   protected def hashSemiJoin(
@@ -92,3 +78,36 @@ trait HashSemiJoin {
     }
   }
 }
+
+private[execution] object HashSemiJoin {
+  /**
+   * Collects the distinct join keys of `rows`. Each key is produced by projecting `keys`, bound
+   * against `attributes`; keys containing any null field are skipped.
+   *
+   * @param keys expressions that compute the join key
+   * @param attributes the attributes of the input rows
+   * @param rows the rows whose keys are collected
+   * @return a set holding a copy of each distinct, null-free key
+   */
+  def buildKeyHashSet(
+      keys: Seq[Expression],
+      attributes: Seq[Attribute],
+      rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = {
+    val keyProjection = UnsafeProjection.create(keys, attributes)
+    val distinctKeys = new java.util.HashSet[InternalRow]()
+    rows.foreach { row =>
+      val projected = keyProjection(row)
+      // Store a copy: the projection may reuse its output row between calls.
+      if (!projected.anyNull && !distinctKeys.contains(projected)) {
+        distinctKeys.add(projected.copy())
+      }
+    }
+    distinctKeys
+  }
+}
+
+/**
+ * Broadcast mode that turns the collected rows into a set of their distinct join keys, leaving
+ * out keys that contain nulls.
+ *
+ * @param keys expressions that compute the join key of each row
+ * @param attributes the attributes of the rows being broadcast
+ */
+private[execution] case class HashSetBroadcastMode(
+    keys: Seq[Expression],
+    attributes: Seq[Attribute]) extends BroadcastMode {
+
+  override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] =
+    HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0978570..606269b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -25,12 +25,11 @@ import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
 import org.apache.spark.sql.execution.local.LocalNode
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.unsafe.memory.MemoryLocation
 import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
 import org.apache.spark.util.collection.CompactBuffer
 
@@ -675,3 +674,20 @@ private[joins] object LongHashedRelation {
     }
   }
 }
+
+/**
+ * Broadcast mode that turns the collected rows into a [[HashedRelation]] keyed by `keys`.
+ *
+ * @param canJoinKeyFitWithinLong when true, the relation is built with [[LongHashedRelation]];
+ *                                otherwise with the general [[HashedRelation]] factory
+ * @param keys expressions that compute the join key of each row
+ * @param attributes the attributes of the rows being broadcast
+ */
+private[execution] case class HashedRelationBroadcastMode(
+    canJoinKeyFitWithinLong: Boolean,
+    keys: Seq[Expression],
+    attributes: Seq[Attribute]) extends BroadcastMode {
+
+  override def transform(rows: Array[InternalRow]): HashedRelation = {
+    val generator = UnsafeProjection.create(keys, attributes)
+    if (canJoinKeyFitWithinLong) {
+      LongHashedRelation(rows.iterator, generator, rows.length)
+    } else {
+      HashedRelation(rows.iterator, generator, rows.length)
+    }
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index ce758d6..df6dac8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
@@ -29,9 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
  * for hash join.
  */
 case class LeftSemiJoinBNL(
-    streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
-  extends BinaryNode {
-  // TODO: Override requiredChildDistribution.
+    streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) 
extends BinaryNode {
 
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of 
output rows"))
@@ -46,27 +44,28 @@ case class LeftSemiJoinBNL(
   /** The Broadcast relation */
   override def right: SparkPlan = broadcast
 
+  override def requiredChildDistribution: Seq[Distribution] = {
+    UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: 
Nil
+  }
+
   @transient private lazy val boundCondition =
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ 
right.output)
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
-    val broadcastedRelation =
-      sparkContext.broadcast(broadcast.execute().map { row =>
-        row.copy()
-      }.collect().toIndexedSeq)
+    val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
 
     streamed.execute().mapPartitions { streamedIter =>
       val joinedRow = new JoinedRow
+      val relation = broadcastedRelation.value
 
       streamedIter.filter(streamedRow => {
         var i = 0
         var matched = false
 
-        while (i < broadcastedRelation.value.size && !matched) {
-          val broadcastedRow = broadcastedRelation.value(i)
-          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+        while (i < relation.length && !matched) {
+          if (boundCondition(joinedRow(streamedRow, relation(i)))) {
             matched = true
           }
           i += 1

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index ef76847..cd543d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import 
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 
 
 /**
@@ -38,7 +39,8 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends 
UnaryNode {
   private val serializer: Serializer = new 
UnsafeRowSerializer(child.output.size)
   protected override def doExecute(): RDD[InternalRow] = {
     val shuffled = new ShuffledRowRDD(
-      Exchange.prepareShuffleDependency(child.execute(), child.output, 
SinglePartition, serializer))
+      ShuffleExchange.prepareShuffleDependency(
+        child.execute(), child.output, SinglePartition, serializer))
     shuffled.mapPartitionsInternal(_.take(limit))
   }
 }
@@ -110,7 +112,8 @@ case class TakeOrderedAndProject(
       }
     }
     val shuffled = new ShuffledRowRDD(
-      Exchange.prepareShuffleDependency(localTopK, child.output, 
SinglePartition, serializer))
+      ShuffleExchange.prepareShuffleDependency(
+        localTopK, child.output, SinglePartition, serializer))
     shuffled.mapPartitions { iter =>
       val topK = 
org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), 
limit)(ord)
       if (projectList.isDefined) {

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index e8d0678..83d7953 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -23,9 +23,9 @@ import scala.language.postfixOps
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.Accumulators
-import org.apache.spark.sql.execution.Exchange
 import org.apache.spark.sql.execution.PhysicalRDD
 import org.apache.spark.sql.execution.columnar._
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.storage.{RDDBlockId, StorageLevel}
@@ -357,7 +357,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils 
with SharedSQLContext
    * Verifies that the plan for `df` contains `expected` number of Exchange 
operators.
    */
   private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = {
-    assert(df.queryExecution.executedPlan.collect { case e: Exchange => e 
}.size == expected)
+    assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => 
e }.size == expected)
   }
 
   test("A cached table preserves the partitioning and ordering of its cached 
SparkPlan") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 99ba2e2..50a2464 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -26,8 +26,8 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
-import org.apache.spark.sql.execution.Exchange
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, 
SharedSQLContext}
 import org.apache.spark.sql.test.SQLTestData.TestData2
@@ -1119,7 +1119,7 @@ class DataFrameSuite extends QueryTest with 
SharedSQLContext {
         }
         atFirstAgg = true
       }
-      case e: Exchange => atFirstAgg = false
+      case e: ShuffleExchange => atFirstAgg = false
       case _ =>
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 35ff1c4..b1c588a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, 
SparkFunSuite}
 import org.apache.spark.sql._
+import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, 
ShuffleExchange}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 
@@ -297,13 +298,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with 
BeforeAndAfterAll {
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
         val exchanges = agg.queryExecution.executedPlan.collect {
-          case e: Exchange => e
+          case e: ShuffleExchange => e
         }
         assert(exchanges.length === 1)
         minNumPostShufflePartitions match {
           case Some(numPartitions) =>
             exchanges.foreach {
-              case e: Exchange =>
+              case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
                 assert(e.outputPartitioning.numPartitions === 3)
               case o =>
@@ -311,7 +312,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with 
BeforeAndAfterAll {
 
           case None =>
             exchanges.foreach {
-              case e: Exchange =>
+              case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
                 assert(e.outputPartitioning.numPartitions === 2)
               case o =>
@@ -348,13 +349,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with 
BeforeAndAfterAll {
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
         val exchanges = join.queryExecution.executedPlan.collect {
-          case e: Exchange => e
+          case e: ShuffleExchange => e
         }
         assert(exchanges.length === 2)
         minNumPostShufflePartitions match {
           case Some(numPartitions) =>
             exchanges.foreach {
-              case e: Exchange =>
+              case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
                 assert(e.outputPartitioning.numPartitions === 3)
               case o =>
@@ -362,7 +363,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with 
BeforeAndAfterAll {
 
           case None =>
             exchanges.foreach {
-              case e: Exchange =>
+              case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
                 assert(e.outputPartitioning.numPartitions === 2)
               case o =>
@@ -404,13 +405,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with 
BeforeAndAfterAll {
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
         val exchanges = join.queryExecution.executedPlan.collect {
-          case e: Exchange => e
+          case e: ShuffleExchange => e
         }
         assert(exchanges.length === 4)
         minNumPostShufflePartitions match {
           case Some(numPartitions) =>
             exchanges.foreach {
-              case e: Exchange =>
+              case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
                 assert(e.outputPartitioning.numPartitions === 3)
               case o =>
@@ -456,13 +457,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with 
BeforeAndAfterAll {
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
         val exchanges = join.queryExecution.executedPlan.collect {
-          case e: Exchange => e
+          case e: ShuffleExchange => e
         }
         assert(exchanges.length === 3)
         minNumPostShufflePartitions match {
           case Some(numPartitions) =>
             exchanges.foreach {
-              case e: Exchange =>
+              case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
                 assert(e.outputPartitioning.numPartitions === 3)
               case o =>

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 87bff32..d4f22de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.test.SharedSQLContext
 
 class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
@@ -28,7 +29,7 @@ class ExchangeSuite extends SparkPlanTest with 
SharedSQLContext {
     val input = (1 to 1000).map(Tuple1.apply)
     checkAnswer(
       input.toDF(),
-      plan => Exchange(SinglePartition, plan),
+      plan => ShuffleExchange(SinglePartition, plan),
       input.map(Row.fromTuple)
     )
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 250ce8f..4de5678 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, 
Attribute, Literal,
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, 
InMemoryRelation}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, 
ShuffleExchange}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -212,7 +213,7 @@ class PlannerSuite extends SharedSQLContext {
               |  JOIN tiny ON (small.key = tiny.key)
             """.stripMargin
           ).queryExecution.executedPlan.collect {
-            case exchange: Exchange => exchange
+            case exchange: ShuffleExchange => exchange
           }.length
           assert(numExchanges === 5)
         }
@@ -227,7 +228,7 @@ class PlannerSuite extends SharedSQLContext {
               |  JOIN tiny ON (normal.key = tiny.key)
             """.stripMargin
           ).queryExecution.executedPlan.collect {
-            case exchange: Exchange => exchange
+            case exchange: ShuffleExchange => exchange
           }.length
           assert(numExchanges === 5)
         }
@@ -295,7 +296,7 @@ class PlannerSuite extends SharedSQLContext {
     )
     val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case e: Exchange => true }.isEmpty) {
+    if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
       fail(s"Exchange should have been added:\n$outputPlan")
     }
   }
@@ -333,7 +334,7 @@ class PlannerSuite extends SharedSQLContext {
     )
     val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case e: Exchange => true }.isEmpty) {
+    if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
       fail(s"Exchange should have been added:\n$outputPlan")
     }
   }
@@ -353,7 +354,7 @@ class PlannerSuite extends SharedSQLContext {
     )
     val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case e: Exchange => true }.nonEmpty) {
+    if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
       fail(s"Exchange should not have been added:\n$outputPlan")
     }
   }
@@ -376,7 +377,7 @@ class PlannerSuite extends SharedSQLContext {
     )
     val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case e: Exchange => true }.nonEmpty) {
+    if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
       fail(s"No Exchanges should have been added:\n$outputPlan")
     }
   }
@@ -435,7 +436,7 @@ class PlannerSuite extends SharedSQLContext {
     val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
     val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
     assert(!childPartitioning.satisfies(distribution))
-    val inputPlan = Exchange(finalPartitioning,
+    val inputPlan = ShuffleExchange(finalPartitioning,
       DummySparkPlan(
         children = DummySparkPlan(outputPartitioning = childPartitioning) :: 
Nil,
         requiredChildDistribution = Seq(distribution),
@@ -444,7 +445,7 @@ class PlannerSuite extends SharedSQLContext {
 
     val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case e: Exchange => true }.size == 2) {
+    if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) {
       fail(s"Topmost Exchange should have been eliminated:\n$outputPlan")
     }
   }
@@ -455,7 +456,7 @@ class PlannerSuite extends SharedSQLContext {
     val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8)
     val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
     assert(!childPartitioning.satisfies(distribution))
-    val inputPlan = Exchange(finalPartitioning,
+    val inputPlan = ShuffleExchange(finalPartitioning,
       DummySparkPlan(
         children = DummySparkPlan(outputPartitioning = childPartitioning) :: 
Nil,
         requiredChildDistribution = Seq(distribution),
@@ -464,7 +465,7 @@ class PlannerSuite extends SharedSQLContext {
 
     val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
-    if (outputPlan.collect { case e: Exchange => true }.size == 1) {
+    if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) {
       fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan")
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index e25b5e0..a256ee9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,7 +22,8 @@ import scala.reflect.ClassTag
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
-import org.apache.spark.sql.{QueryTest, SQLConf, SQLContext}
+import org.apache.spark.sql.{QueryTest, SQLContext}
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.functions._
 
 /**
@@ -62,7 +63,7 @@ class BroadcastJoinSuite extends QueryTest with 
BeforeAndAfterAll {
       // Comparison at the end is for broadcast left semi join
       val joinExpression = df1("key") === df2("key") && df1("value") > 
df2("value")
       val df3 = df1.join(broadcast(df2), joinExpression, joinType)
-      val plan = df3.queryExecution.sparkPlan
+      val plan = 
EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan)
       assert(plan.collect { case p: T => p }.size === 1)
       plan.executeCollect()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index e22a810..6dfff37 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -23,6 +23,7 @@ import 
org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
 
@@ -88,7 +89,15 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan,
         side: BuildSide) = {
-      joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, 
boundCondition, leftPlan, rightPlan)
+      val broadcastJoin = joins.BroadcastHashJoin(
+        leftKeys,
+        rightKeys,
+        Inner,
+        side,
+        boundCondition,
+        leftPlan,
+        rightPlan)
+      EnsureRequirements(sqlContext).apply(broadcastJoin)
     }
 
     def makeSortMergeJoin(

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index f4b01fb..cd6b6fc 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, 
Expression, LessThan}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, 
SparkPlanTest}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 9c86084..f3ad840 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, 
Expression, LessThan}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, 
SparkPlanTest}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 9ba6456..a05a57c 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -22,8 +22,9 @@ import java.io.File
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.{Exchange, PhysicalRDD}
+import org.apache.spark.sql.execution.PhysicalRDD
 import org.apache.spark.sql.execution.datasources.{BucketSpec, 
DataSourceStrategy}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.execution.joins.SortMergeJoin
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -252,8 +253,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils 
with TestHiveSinglet
         assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin])
         val joinOperator = 
joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin]
 
-        assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == 
shuffleLeft)
-        assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == 
shuffleRight)
+        
assert(joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == 
shuffleLeft)
+        
assert(joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == 
shuffleRight)
       }
     }
   }
@@ -312,7 +313,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils 
with TestHiveSinglet
         agged.sort("i", "j"),
         df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
 
-      
assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+      
assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty)
     }
   }
 
@@ -326,7 +327,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils 
with TestHiveSinglet
         agged.sort("i", "j"),
         df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
 
-      
assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+      
assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty)
     }
   }
 
@@ -339,7 +340,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils 
with TestHiveSinglet
 
       val agged = hiveContext.table("bucketed_table").groupBy("i").count()
       // make sure we fall back to non-bucketing mode and can't avoid shuffle
-      
assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined)
+      
assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isDefined)
       checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i"))
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to