Repository: spark
Updated Branches:
  refs/heads/master af441ddbd -> b6a873d6d
http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala new file mode 100644 index 0000000..de21d77 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import java.util.Random + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ +import org.apache.spark.util.MutablePair + +/** + * Performs a shuffle that will result in the desired `newPartitioning`. + */ +case class ShuffleExchange( + var newPartitioning: Partitioning, + child: SparkPlan, + @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode { + + override def nodeName: String = { + val extraInfo = coordinator match { + case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => + s"(coordinator id: ${System.identityHashCode(coordinator)})" + case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => + s"(coordinator id: ${System.identityHashCode(coordinator)})" + case None => "" + } + + val simpleNodeName = "Exchange" + s"$simpleNodeName$extraInfo" + } + + override def outputPartitioning: Partitioning = newPartitioning + + override def output: Seq[Attribute] = child.output + + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + + override protected def doPrepare(): Unit = { + // If an ExchangeCoordinator is needed, we register this Exchange operator + // to the coordinator when we do prepare. It is important to make sure + // we register this operator right before the execution instead of register it + // in the constructor because it is possible that we create new instances of + // Exchange operators when we transform the physical plan + // (then the ExchangeCoordinator will hold references of unneeded Exchanges). 
+ // So, we should only call registerExchange just before we start to execute + // the plan. + coordinator match { + case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) + case None => + } + } + + /** + * Returns a [[ShuffleDependency]] that will partition rows of its child based on + * the partitioning scheme defined in `newPartitioning`. Those partitions of + * the returned ShuffleDependency will be the input of shuffle. + */ + private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { + ShuffleExchange.prepareShuffleDependency( + child.execute(), child.output, newPartitioning, serializer) + } + + /** + * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset. + * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional + * partition start indices array. If this optional array is defined, the returned + * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. + */ + private[sql] def preparePostShuffleRDD( + shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], + specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { + // If an array of partition start indices is provided, we need to use this array + // to create the ShuffledRowRDD. Also, we need to update newPartitioning to + // update the number of post-shuffle partitions. + specifiedPartitionStartIndices.foreach { indices => + assert(newPartitioning.isInstanceOf[HashPartitioning]) + newPartitioning = UnknownPartitioning(indices.length) + } + new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + coordinator match { + case Some(exchangeCoordinator) => + val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) + assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) + shuffleRDD + case None => + val shuffleDependency = prepareShuffleDependency() + preparePostShuffleRDD(shuffleDependency) + } + } +} + +object ShuffleExchange { + def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = { + ShuffleExchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) + } + + /** + * Determines whether records must be defensively copied before being sent to the shuffle. + * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The + * shuffle code assumes that objects are immutable and hence does not perform its own defensive + * copying. In Spark SQL, however, operators' iterators return the same mutable `Row` object. In + * order to properly shuffle the output of these operators, we need to perform our own copying + * prior to sending records to the shuffle. This copying is expensive, so we try to avoid it + * whenever possible. This method encapsulates the logic for choosing when to copy. + * + * In the long run, we might want to push this logic into core's shuffle APIs so that we don't + * have to rely on knowledge of core internals here in SQL. + * + * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue. 
+ * + * @param partitioner the partitioner for the shuffle + * @param serializer the serializer that will be used to write rows + * @return true if rows should be copied before being shuffled, false otherwise + */ + private def needToCopyObjectsBeforeShuffle( + partitioner: Partitioner, + serializer: Serializer): Boolean = { + // Note: even though we only use the partitioner's `numPartitions` field, we require it to be + // passed instead of directly passing the number of partitions in order to guard against + // corner-cases where a partitioner constructed with `numPartitions` partitions may output + // fewer partitions (like RangePartitioner, for example). + val conf = SparkEnv.get.conf + val shuffleManager = SparkEnv.get.shuffleManager + val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] + val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + if (sortBasedShuffleOn) { + val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { + // If we're using the original SortShuffleManager and the number of output partitions is + // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which + // doesn't buffer deserialized records. + // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. + false + } else if (serializer.supportsRelocationOfSerializedObjects) { + // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records + // prior to sorting them. This optimization is only applied in cases where shuffle + // dependency does not specify an aggregator or ordering and the record serializer has + // certain properties. If this optimization is enabled, we can safely avoid the copy. + // + // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only + // need to check whether the optimization is enabled and supported by our serializer. + false + } else { + // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must + // copy. + true + } + } else if (shuffleManager.isInstanceOf[HashShuffleManager]) { + // We're using hash-based shuffle, so we don't need to copy. + false + } else { + // Catch-all case to safely handle any future ShuffleManager implementations. + true + } + } + + /** + * Returns a [[ShuffleDependency]] that will partition rows of its child based on + * the partitioning scheme defined in `newPartitioning`. Those partitions of + * the returned ShuffleDependency will be the input of shuffle. + */ + private[sql] def prepareShuffleDependency( + rdd: RDD[InternalRow], + outputAttributes: Seq[Attribute], + newPartitioning: Partitioning, + serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = { + val part: Partitioner = newPartitioning match { + case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) + case HashPartitioning(_, n) => + new Partitioner { + override def numPartitions: Int = n + // For HashPartitioning, the partitioning key is already a valid partition ID, as we use + // `HashPartitioning.partitionIdExpression` to produce partitioning key. + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } + case RangePartitioning(sortingExpressions, numPartitions) => + // Internally, RangePartitioner runs a job on the RDD that samples keys to compute + // partition bounds. 
To get accurate samples, we need to copy the mutable keys. + val rddForSampling = rdd.mapPartitionsInternal { iter => + val mutablePair = new MutablePair[InternalRow, Null]() + iter.map(row => mutablePair.update(row.copy(), null)) + } + implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes) + new RangePartitioner(numPartitions, rddForSampling, ascending = true) + case SinglePartition => + new Partitioner { + override def numPartitions: Int = 1 + override def getPartition(key: Any): Int = 0 + } + case _ => sys.error(s"Exchange not implemented for $newPartitioning") + // TODO: Handle BroadcastPartitioning. + } + def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match { + case RoundRobinPartitioning(numPartitions) => + // Distributes elements evenly across output partitions, starting from a random partition. + var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions) + (row: InternalRow) => { + // The HashPartitioner will handle the `mod` by the number of partitions + position += 1 + position + } + case h: HashPartitioning => + val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) + row => projection(row).getInt(0) + case RangePartitioning(_, _) | SinglePartition => identity + case _ => sys.error(s"Exchange not implemented for $newPartitioning") + } + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { + if (needToCopyObjectsBeforeShuffle(part, serializer)) { + rdd.mapPartitionsInternal { iter => + val getPartitionKey = getPartitionKeyExtractor() + iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } + } + } else { + rdd.mapPartitionsInternal { iter => + val getPartitionKey = getPartitionKeyExtractor() + val mutablePair = new MutablePair[Int, InternalRow]() + iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } + } + } + } + + // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds + // are in the form of (partitionId, row) and every partitionId is in the expected range + // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. 
+ val dependency = + new ShuffleDependency[Int, InternalRow, InternalRow]( + rddWithPartitionIds, + new PartitionIdPassthrough(part.numPartitions), + Some(serializer)) + + dependency + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index a64da22..ddc0882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.execution.joins -import scala.concurrent._ -import scala.concurrent.duration._ - import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -27,10 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.ThreadUtils import org.apache.spark.util.collection.CompactBuffer /** @@ -52,60 +48,25 @@ case class BroadcastHashJoin( override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - val timeout: Duration = { - val timeoutValue = sqlContext.conf.broadcastTimeout - if (timeoutValue < 0) { - Duration.Inf - } else { - timeoutValue.seconds - } - } - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - override def requiredChildDistribution: Seq[Distribution] = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value - // for the same query. - @transient - private lazy val broadcastFuture = { - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - Future { - // This will run in another thread. Set the execution id so that we can connect these jobs - // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { - // Note that we use .execute().collect() because we don't want to convert data to Scala - // types - val input: Array[InternalRow] = buildPlan.execute().map { row => - row.copy() - }.collect() - // The following line doesn't run in a job so we cannot track the metric value. However, we - // have already tracked it in the above lines. So here we can use - // `SQLMetrics.nullLongMetric` to ignore it. 
- // TODO: move this check into HashedRelation - val hashed = if (canJoinKeyFitWithinLong) { - LongHashedRelation( - input.iterator, buildSideKeyGenerator, input.size) - } else { - HashedRelation( - input.iterator, buildSideKeyGenerator, input.size) - } - sparkContext.broadcast(hashed) - } - }(BroadcastHashJoin.broadcastHashJoinExecutionContext) - } - - protected override def doPrepare(): Unit = { - broadcastFuture + override def requiredChildDistribution: Seq[Distribution] = { + val mode = HashedRelationBroadcastMode( + canJoinKeyFitWithinLong, + rewriteKeyExpr(buildKeys), + buildPlan.output) + buildSide match { + case BuildLeft => + BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil + } } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val broadcastRelation = Await.result(broadcastFuture, timeout) - + val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow() val hashTable = broadcastRelation.value @@ -160,7 +121,7 @@ case class BroadcastHashJoin( */ private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { // create a name for HashedRelation - val broadcastRelation = Await.result(broadcastFuture, timeout) + val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName @@ -362,9 +323,3 @@ case class BroadcastHashJoin( } } } - -object BroadcastHashJoin { - - private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)) -} http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 4f1cfd2..1f99fbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -38,25 +39,25 @@ case class BroadcastLeftSemiJoinHash( override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: Seq[Distribution] = { + val mode = if (condition.isEmpty) { + HashSetBroadcastMode(rightKeys, right.output) + } else { + HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output) + } + UnspecifiedDistribution 
:: BroadcastDistribution(mode) :: Nil + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val input = right.execute().map { row => - row.copy() - }.collect() - if (condition.isEmpty) { - val hashSet = buildKeyHashSet(input.toIterator) - val broadcastedRelation = sparkContext.broadcast(hashSet) - + val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]() left.execute().mapPartitionsInternal { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows) } } else { - val hashRelation = - HashedRelation(input.toIterator, rightKeyGenerator, input.size) - val broadcastedRelation = sparkContext.broadcast(hashRelation) - + val broadcastedRelation = right.executeBroadcast[HashedRelation]() left.execute().mapPartitionsInternal { streamIter => val hashedRelation = broadcastedRelation.value TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 4585cbd..e8bd7f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.{BitSet, CompactBuffer} @@ -33,7 +33,6 @@ case class BroadcastNestedLoopJoin( buildSide: BuildSide, joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - // TODO: Override requiredChildDistribution. override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -44,8 +43,15 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } + override def requiredChildDistribution: Seq[Distribution] = buildSide match { + case BuildLeft => + BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil + } + private[this] def genResultProjection: InternalRow => InternalRow = { - UnsafeProjection.create(schema) + UnsafeProjection.create(schema) } override def outputPartitioning: Partitioning = streamed.outputPartitioning @@ -73,15 +79,14 @@ case class BroadcastNestedLoopJoin( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map { row => - row.copy() - }.collect().toIndexedSeq) + val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() /** All rows that either match both-way, or rows from streamed joined with nulls. 
*/ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => + val relation = broadcastedRelation.value + val matchedRows = new CompactBuffer[InternalRow] - val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = new BitSet(relation.length) val joinedRow = new JoinedRow val leftNulls = new GenericMutableRow(left.output.size) @@ -92,8 +97,8 @@ case class BroadcastNestedLoopJoin( var i = 0 var streamRowMatched = false - while (i < broadcastedRelation.value.size) { - val broadcastedRow = broadcastedRelation.value(i) + while (i < relation.length) { + val broadcastedRow = relation(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 0220e0b..1cb6a00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.LongSQLMetric @@ -44,22 +45,7 @@ trait HashSemiJoin { protected def buildKeyHashSet( buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { - val hashSet = new java.util.HashSet[InternalRow]() - - // Create a Hash set of buildKeys - val rightKey = rightKeyGenerator - while (buildIter.hasNext) { - val currentRow = buildIter.next() - val rowKey = rightKey(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey.copy()) - } - } - } - - hashSet + HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter) } protected def hashSemiJoin( @@ -92,3 +78,36 @@ trait HashSemiJoin { } } } + +private[execution] object HashSemiJoin { + def buildKeyHashSet( + keys: Seq[Expression], + attributes: Seq[Attribute], + rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = { + val hashSet = new java.util.HashSet[InternalRow]() + + // Create a Hash set of buildKeys + val key = UnsafeProjection.create(keys, attributes) + while (rows.hasNext) { + val currentRow = rows.next() + val rowKey = key(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey.copy()) + } + } + } + hashSet + } +} + +/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. 
*/ +private[execution] case class HashSetBroadcastMode( + keys: Seq[Expression], + attributes: Seq[Attribute]) extends BroadcastMode { + + override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = { + HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0978570..606269b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,12 +25,11 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} import org.apache.spark.sql.execution.local.LocalNode -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.unsafe.memory.MemoryLocation import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils} import org.apache.spark.util.collection.CompactBuffer @@ -675,3 +674,20 @@ private[joins] object LongHashedRelation { } } } + +/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */ +private[execution] case class HashedRelationBroadcastMode( + canJoinKeyFitWithinLong: Boolean, + keys: Seq[Expression], + attributes: Seq[Attribute]) extends BroadcastMode { + + def transform(rows: Array[InternalRow]): HashedRelation = { + val generator = UnsafeProjection.create(keys, attributes) + if (canJoinKeyFitWithinLong) { + LongHashedRelation(rows.iterator, generator, rows.length) + } else { + HashedRelation(rows.iterator, generator, rows.length) + } + } +} + http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index ce758d6..df6dac8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -29,9 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics * for hash join. 
*/ case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - extends BinaryNode { - // TODO: Override requiredChildDistribution. + streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode { override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -46,27 +44,28 @@ case class LeftSemiJoinBNL( /** The Broadcast relation */ override def right: SparkPlan = broadcast + override def requiredChildDistribution: Seq[Distribution] = { + UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil + } + @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map { row => - row.copy() - }.collect().toIndexedSeq) + val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow + val relation = broadcastedRelation.value streamedIter.filter(streamedRow => { var i = 0 var matched = false - while (i < broadcastedRelation.value.size && !matched) { - val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { + while (i < relation.length && !matched) { + if (boundCondition(joinedRow(streamedRow, relation(i)))) { matched = true } i += 1 http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index ef76847..cd543d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.exchange.ShuffleExchange /** @@ -38,7 +39,8 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) protected override def doExecute(): RDD[InternalRow] = { val shuffled = new ShuffledRowRDD( - Exchange.prepareShuffleDependency(child.execute(), child.output, SinglePartition, serializer)) + ShuffleExchange.prepareShuffleDependency( + child.execute(), child.output, SinglePartition, serializer)) shuffled.mapPartitionsInternal(_.take(limit)) } } @@ -110,7 +112,8 @@ case class TakeOrderedAndProject( } } val shuffled = new ShuffledRowRDD( - Exchange.prepareShuffleDependency(localTopK, child.output, SinglePartition, serializer)) + ShuffleExchange.prepareShuffleDependency( + localTopK, child.output, SinglePartition, serializer)) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) if (projectList.isDefined) { http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index e8d0678..83d7953 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -23,9 +23,9 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -357,7 +357,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext * Verifies that the plan for `df` contains `expected` number of Exchange operators. */ private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { - assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected) + assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }.size == expected) } test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 99ba2e2..50a2464 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -26,8 +26,8 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} -import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} import org.apache.spark.sql.test.SQLTestData.TestData2 @@ -1119,7 +1119,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } atFirstAgg = true } - case e: Exchange => atFirstAgg = false + case e: ShuffleExchange => atFirstAgg = false case _ => } } http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 35ff1c4..b1c588a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql._ +import 
org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext @@ -297,13 +298,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = agg.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 1) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -311,7 +312,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -348,13 +349,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -362,7 +363,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -404,13 +405,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 4) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -456,13 +457,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. 
val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 3) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 87bff32..d4f22de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { @@ -28,7 +29,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => Exchange(SinglePartition, plan), + plan => ShuffleExchange(SinglePartition, plan), input.map(Row.fromTuple) ) } http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 250ce8f..4de5678 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange} import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -212,7 +213,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (small.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: Exchange => exchange + case exchange: ShuffleExchange => exchange }.length assert(numExchanges === 5) } @@ -227,7 +228,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (normal.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: Exchange => exchange + case exchange: ShuffleExchange => exchange }.length assert(numExchanges === 5) } @@ -295,7 +296,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => 
true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -333,7 +334,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -353,7 +354,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") } } @@ -376,7 +377,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") } } @@ -435,7 +436,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = Exchange(finalPartitioning, + val inputPlan = ShuffleExchange(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -444,7 +445,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.size == 2) { + if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") } } @@ -455,7 +456,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = Exchange(finalPartitioning, + val inputPlan = ShuffleExchange(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -464,7 +465,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.size == 1) { + if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") } } http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index e25b5e0..a256ee9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,8 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -import org.apache.spark.sql.{QueryTest, SQLConf, SQLContext} +import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ /** @@ -62,7 +63,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) - val plan = df3.queryExecution.sparkPlan + val plan = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan) assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index e22a810..6dfff37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -88,7 +89,15 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan, side: BuildSide) = { - joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan) + val broadcastJoin = joins.BroadcastHashJoin( + leftKeys, + rightKeys, + Inner, + side, + boundCondition, + leftPlan, + rightPlan) + EnsureRequirements(sqlContext).apply(broadcastJoin) } def makeSortMergeJoin( http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index f4b01fb..cd6b6fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.test.SharedSQLContext 
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} http://git-wip-us.apache.org/repos/asf/spark/blob/b6a873d6/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 9ba6456..a05a57c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -22,8 +22,9 @@ import java.io.File import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.{Exchange, PhysicalRDD} +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy} +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.SortMergeJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -252,8 +253,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin]) val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin] - assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft) - assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight) + assert(joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft) + assert(joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight) } } } @@ -312,7 +313,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) } } @@ -326,7 +327,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet agged.sort("i", "j"), df1.groupBy("i",
"j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) } } @@ -339,7 +340,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet val agged = hiveContext.table("bucketed_table").groupBy("i").count() // make sure we fall back to non-bucketing mode and can't avoid shuffle - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isDefined) checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i")) } }
