GitHub user rxin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/16677#discussion_r218640368
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala ---
    @@ -93,25 +96,93 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
     }
     
     /**
    - * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
    + * Take the first `limit` elements of the child's output.
      */
    -case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
    +case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
     
    -  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
    +  override def output: Seq[Attribute] = child.output
     
       override def outputPartitioning: Partitioning = child.outputPartitioning
    -}
     
    -/**
    - * Take the first `limit` elements of the child's single output partition.
    - */
    -case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
    +  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
     
    -  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
    +  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
     
    -  override def outputPartitioning: Partitioning = child.outputPartitioning
    +  protected override def doExecute(): RDD[InternalRow] = {
    +    val childRDD = child.execute()
    +    val partitioner = LocalPartitioning(childRDD)
    +    val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency(
    +      childRDD, child.output, partitioner, serializer)
    +    val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) {
    +      // submitMapStage does not accept an RDD with 0 partitions,
    +      // so we will not submit this dependency.
    +      val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency)
    +      submittedStageFuture.get().recordsByPartitionId.toSeq
    +    } else {
    +      Nil
    +    }
     
    -  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
    +    // During global limit, try to evenly distribute the limited rows across data
    +    // partitions. If disabled, scan the data partitions sequentially until the limit is reached.
    +    // Besides, if the child output has a certain ordering, we can't evenly pick up rows from
    +    // each partition.
    +    val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && child.outputOrdering == Nil
    +
    +    val shuffled = new ShuffledRowRDD(shuffleDependency)
    +
    +    val sumOfOutput = numberOfOutput.sum
    +    if (sumOfOutput <= limit) {
    +      shuffled
    +    } else if (!flatGlobalLimit) {
    +      var numRowTaken = 0
    +      val takeAmounts = numberOfOutput.map { num =>
    +        if (numRowTaken + num < limit) {
    +          numRowTaken += num.toInt
    +          num.toInt
    +        } else {
    +          val toTake = limit - numRowTaken
    +          numRowTaken += toTake
    +          toTake
    +        }
    +      }
    +      val broadMap = sparkContext.broadcast(takeAmounts)
    +      shuffled.mapPartitionsWithIndexInternal { case (index, iter) =>
    +        iter.take(broadMap.value(index).toInt)
    +      }
    +    } else {
    +      // We try to evenly distribute the requested limit of rows across all of the child RDD's partitions.
    +      var rowsNeedToTake: Long = limit
    +      val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L)
    +      val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*)
    +
    +      while (rowsNeedToTake > 0) {
    +        val nonEmptyParts = remainingRowsByPartition.count(_ > 0)
    +        // If the number of rows needed is less than the number of non-empty partitions, take one row from
    +        // each non-empty partition until we reach `limit` rows.
    +        // Otherwise, evenly divide the needed rows among the non-empty partitions.
    +        val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts)
    +        remainingRowsByPartition.zipWithIndex.foreach { case (num, index) =>
    +          // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during
    +          // the traversal, so we need to add this check.
    +          if (rowsNeedToTake > 0 && num > 0) {
    +            if (num >= takePerPart) {
    +              rowsNeedToTake -= takePerPart
    +              takeAmountByPartition(index) += takePerPart
    +              remainingRowsByPartition(index) -= takePerPart
    +            } else {
    +              rowsNeedToTake -= num
    +              takeAmountByPartition(index) += num
    +              remainingRowsByPartition(index) -= num
    +            }
    +          }
    +        }
    +      }
    +      val broadMap = sparkContext.broadcast(takeAmountByPartition)
    --- End diff --
    
    We also broadcast closures automatically, don't we? So just putting a variable in a closure would accomplish this.
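    
    For reference, a minimal standalone sketch of that idea (hypothetical data and names, not
    the PR's code; e.g. pasted into spark-shell): a small local array referenced inside the
    closure is serialized and shipped with the task automatically, so an explicit broadcast
    of `takeAmountByPartition` would not be strictly necessary.
    
        import org.apache.spark.sql.SparkSession
    
        val spark = SparkSession.builder().master("local[2]").appName("closure-capture").getOrCreate()
    
        // Hypothetical per-partition take counts, standing in for takeAmountByPartition.
        val takeAmounts = Array(2L, 3L)
        // Two partitions of dummy rows: [1..5] and [6..10].
        val rdd = spark.sparkContext.parallelize(1 to 10, 2)
    
        // `takeAmounts` is captured by the task closure and serialized with it;
        // no sparkContext.broadcast(...) call is needed for a value this small.
        val limited = rdd.mapPartitionsWithIndex { (index, iter) =>
          iter.take(takeAmounts(index).toInt)
        }
    
        println(limited.collect().mkString(", "))  // 1, 2, 6, 7, 8
        spark.stop()
    
    An explicit broadcast presumably only pays off when the captured value is large or reused
    across many tasks and stages; for a small per-partition count array, closure capture keeps
    the code simpler.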

