Github user hvanhovell commented on a diff in the pull request:
https://github.com/apache/spark/pull/16677#discussion_r197410930
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala ---
@@ -93,25 +98,95 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
 }
 /**
- * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
+ * Take the `limit` elements of the child output.
  */
-case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
+case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+  override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = child.outputPartitioning
-}
-/**
- * Take the first `limit` elements of the child's single output partition.
- */
-case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  protected override def doExecute(): RDD[InternalRow] = {
+    val childRDD = child.execute()
+    val partitioner = LocalPartitioning(child.outputPartitioning,
+      childRDD.getNumPartitions)
+    val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency(
+      childRDD, child.output, partitioner, serializer)
+    val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) {
+      // submitMapStage does not accept an RDD with 0 partitions,
+      // so we will not submit this dependency.
+      val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency)
+      submittedStageFuture.get().recordsByPartitionId.toSeq
+    } else {
+      Nil
+    }
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+    // For a global limit, try to evenly distribute the limited rows across the
+    // data partitions. If this is disabled, scan the data partitions sequentially
+    // until the limit is reached. Also, if the child output has a certain ordering,
+    // we can't evenly pick up rows from each partition.
+    val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && child.outputOrdering == Nil
+
+    val shuffled = new ShuffledRowRDD(shuffleDependency)
+
+    val sumOfOutput = numberOfOutput.sum
+    if (sumOfOutput <= limit) {
+      shuffled
+    } else if (!flatGlobalLimit) {
+      var numRowTaken = 0
+      val takeAmounts = mutable.ArrayBuffer.fill[Long](numberOfOutput.length)(0L)
--- End diff ---
I might be dumb, but why do you need an `ArrayBuffer` here?
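Since the length is fixed by `numberOfOutput.length`, a plain `Array[Long]` would seem to do the job. A minimal sketch of what I have in mind, reusing the names from the diff (the loop body is my own illustration of a sequential take, not necessarily the exact logic that follows in this PR):

```scala
// Fixed-size array: the number of partitions is known up front.
val takeAmounts = new Array[Long](numberOfOutput.length)
var numRowTaken = 0

// Walk the partitions in order, taking rows until the limit is reached.
numberOfOutput.zipWithIndex.foreach { case (count, index) =>
  if (numRowTaken < limit) {
    val take = math.min(count, (limit - numRowTaken).toLong)
    takeAmounts(index) = take
    numRowTaken += take.toInt
  }
}
```

An `ArrayBuffer` only buys you resizing, which doesn't seem needed when the size never changes after `fill`.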