Github user hvanhovell commented on a diff in the pull request:
https://github.com/apache/spark/pull/16677#discussion_r197284779
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala ---
@@ -93,25 +98,101 @@ trait BaseLimitExec extends UnaryExecNode with
CodegenSupport {
}
/**
- * Take the first `limit` elements of each child partition, but do not
collect or shuffle them.
+ * Take the `limit` elements of the child output.
*/
-case class LocalLimitExec(limit: Int, child: SparkPlan) extends
BaseLimitExec {
+case class GlobalLimitExec(limit: Int, child: SparkPlan) extends
UnaryExecNode {
- override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
-}
-/**
- * Take the first `limit` elements of the child's single output partition.
- */
-case class GlobalLimitExec(limit: Int, child: SparkPlan) extends
BaseLimitExec {
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+ private val serializer: Serializer = new
UnsafeRowSerializer(child.output.size)
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val childRDD = child.execute()
+ val partitioner = LocalPartitioning(child.outputPartitioning,
+ childRDD.getNumPartitions)
+ val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency(
+ childRDD, child.output, partitioner, serializer)
+ val numberOfOutput: Seq[Long] = if
(shuffleDependency.rdd.getNumPartitions != 0) {
+ // submitMapStage does not accept RDD with 0 partition.
+ // So, we will not submit this dependency.
+ val submittedStageFuture =
sparkContext.submitMapStage(shuffleDependency)
+ submittedStageFuture.get().recordsByPartitionId.toSeq
+ } else {
+ Nil
+ }
- override def requiredChildDistribution: List[Distribution] = AllTuples
:: Nil
+ // Try to keep child plan's original data parallelism or not. It is
enabled by default.
+ // If child output has certain ordering, we can't evenly pick up rows
from each parititon.
+ val respectChildParallelism =
sqlContext.conf.enableParallelGlobalLimit &&
+ child.outputOrdering != Nil
- override def outputPartitioning: Partitioning = child.outputPartitioning
+ val shuffled = new ShuffledRowRDD(shuffleDependency)
- override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ val sumOfOutput = numberOfOutput.sum
+ if (sumOfOutput <= limit) {
+ shuffled
+ } else if (!respectChildParallelism) {
+ // This is mainly for tests.
+ // Some tests like hive compatibility tests assume that the rows are
returned by a specified
+ // order that the partitions are scaned sequentially until we reach
the required number of
+ // rows. However, logically a Limit operator should not care the row
scan order.
+ // Thus we take the rows of each partition until we reach the
required limit number.
+ var numTakenRow = 0
+ val takeAmounts = new mutable.HashMap[Int, Int]()
--- End diff --
Why a hash map? An array would also work, right?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]