Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/16677#discussion_r218639483
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala ---
@@ -93,25 +96,93 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
 }
 
 /**
- * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
+ * Take the `limit` elements of the child output.
  */
-case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
+case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
 
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+  override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
-}
 
-/**
- * Take the first `limit` elements of the child's single output partition.
- */
-case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
-  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  protected override def doExecute(): RDD[InternalRow] = {
+    val childRDD = child.execute()
+    val partitioner = LocalPartitioning(childRDD)
+    val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency(
+      childRDD, child.output, partitioner, serializer)
+    val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) {
+      // submitMapStage does not accept RDD with 0 partition.
+      // So, we will not submit this dependency.
+      val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency)
+      submittedStageFuture.get().recordsByPartitionId.toSeq
+    } else {
+      Nil
+    }
 
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+    // During global limit, try to evenly distribute the limited rows across data
+    // partitions. If disabled, scan data partitions sequentially until the limit is reached.
+    // Besides, if the child output has a certain ordering, we can't evenly pick up rows from
+    // each partition.
+    val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && child.outputOrdering == Nil
+
+    val shuffled = new ShuffledRowRDD(shuffleDependency)
+
+    val sumOfOutput = numberOfOutput.sum
+    if (sumOfOutput <= limit) {
+      shuffled
+    } else if (!flatGlobalLimit) {
+      var numRowTaken = 0
+      val takeAmounts = numberOfOutput.map { num =>
+        if (numRowTaken + num < limit) {
+          numRowTaken += num.toInt
+          num.toInt
+        } else {
+          val toTake = limit - numRowTaken
+          numRowTaken += toTake
+          toTake
+        }
+      }
+      val broadMap = sparkContext.broadcast(takeAmounts)
+      shuffled.mapPartitionsWithIndexInternal { case (index, iter) =>
+        iter.take(broadMap.value(index).toInt)
+      }
+    } else {
+      // We try to evenly spread the requested limit of rows across all of the child RDD's
+      // partitions.
+      var rowsNeedToTake: Long = limit
+      val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L)
+      val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*)
+
+      while (rowsNeedToTake > 0) {
+        val nonEmptyParts = remainingRowsByPartition.count(_ > 0)
+        // If the number of rows still needed is less than the number of non-empty partitions,
+        // take one row from each non-empty partition until we reach `limit` rows.
+        // Otherwise, evenly divide the needed rows among the non-empty partitions.
+        val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts)
+        remainingRowsByPartition.zipWithIndex.foreach { case (num, index) =>
+          // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during
+          // the traversal, so we need to add this check.
+          if (rowsNeedToTake > 0 && num > 0) {
+            if (num >= takePerPart) {
+              rowsNeedToTake -= takePerPart
+              takeAmountByPartition(index) += takePerPart
+              remainingRowsByPartition(index) -= takePerPart
+            } else {
+              rowsNeedToTake -= num
+              takeAmountByPartition(index) += num
+              remainingRowsByPartition(index) -= num
+            }
+          }
+        }
+      }
+      val broadMap = sparkContext.broadcast(takeAmountByPartition)
--- End diff --
Because we want the map to be sent to each node just once?
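
For readers following along, here is a small, self-contained sketch (not part of this PR) of the pattern being asked about: a driver-side array of per-partition "take" amounts is broadcast so that each executor fetches it once, instead of the array being re-serialized into every task closure. The object name, the sample data, and the takeAmounts values below are invented for illustration only.

import org.apache.spark.sql.SparkSession

object BroadcastTakeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("broadcast-take-sketch")
      .getOrCreate()
    val sc = spark.sparkContext

    // Pretend these are per-partition take amounts computed on the driver,
    // already trimmed so that they sum to the global limit (here limit = 10).
    val takeAmounts: Array[Long] = Array(3L, 3L, 2L, 2L)

    val data = sc.parallelize(1 to 100, numSlices = 4)

    // Broadcast the small array: it is shipped to each executor once and
    // reused by all tasks, rather than captured in every task closure.
    val broadMap = sc.broadcast(takeAmounts)

    val limited = data.mapPartitionsWithIndex { (index, iter) =>
      iter.take(broadMap.value(index).toInt)
    }

    println(limited.collect().mkString(", "))  // 10 rows total, drawn from all 4 partitions
    spark.stop()
  }
}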
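And for the flat-global-limit branch reviewed above, here is a plain-Scala sketch of the even-distribution bookkeeping on its own, outside of Spark. The variable names mirror the diff; `limit` and `numberOfOutput` are made-up sample inputs, and the inner if/else is collapsed into a `math.min`, which takes the same amounts. Each round hands out `max(1, remaining / nonEmptyParts)` rows to every partition that still has rows, until the limit is exhausted.

object EvenLimitAllocationSketch {
  def main(args: Array[String]): Unit = {
    val limit = 10
    // Per-partition record counts; the real code only enters this branch
    // when their sum exceeds the limit.
    val numberOfOutput: Seq[Long] = Seq(1L, 50L, 3L, 0L, 20L)

    var rowsNeedToTake: Long = limit
    val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L)
    val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*)

    while (rowsNeedToTake > 0) {
      val nonEmptyParts = remainingRowsByPartition.count(_ > 0)
      // Hand out at least one row per non-empty partition per round.
      val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts)
      remainingRowsByPartition.zipWithIndex.foreach { case (num, index) =>
        if (rowsNeedToTake > 0 && num > 0) {
          // Take the per-round share, capped by what the partition still has.
          val toTake = math.min(num, takePerPart)
          rowsNeedToTake -= toTake
          takeAmountByPartition(index) += toTake
          remainingRowsByPartition(index) -= toTake
        }
      }
    }

    // Prints 1, 3, 3, 0, 3: the takes sum to the limit of 10 and no
    // partition is asked for more rows than it holds.
    println(takeAmountByPartition.mkString(", "))
  }
}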