Github user hvanhovell commented on a diff in the pull request:

    https://github.com/apache/spark/pull/16677#discussion_r197284779
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala ---
    @@ -93,25 +98,101 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
     }
     
     /**
    - * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
    + * Take `limit` elements of the child output.
      */
    -case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
    +case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
     
    -  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
    +  override def output: Seq[Attribute] = child.output
     
       override def outputPartitioning: Partitioning = child.outputPartitioning
    -}
     
    -/**
    - * Take the first `limit` elements of the child's single output partition.
    - */
    -case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
    +  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
    +
    +  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
    +
    +  protected override def doExecute(): RDD[InternalRow] = {
    +    val childRDD = child.execute()
    +    val partitioner = LocalPartitioning(child.outputPartitioning,
    +      childRDD.getNumPartitions)
    +    val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency(
    +      childRDD, child.output, partitioner, serializer)
    +    val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) {
    +      // submitMapStage does not accept an RDD with 0 partitions,
    +      // so we will not submit this dependency.
    +      val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency)
    +      submittedStageFuture.get().recordsByPartitionId.toSeq
    +    } else {
    +      Nil
    +    }
     
    -  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
    +    // Try to keep the child plan's original data parallelism. It is enabled by default.
    +    // If the child output has a certain ordering, we can't evenly pick up rows from
    +    // each partition.
    +    val respectChildParallelism = sqlContext.conf.enableParallelGlobalLimit &&
    +      child.outputOrdering == Nil
     
    -  override def outputPartitioning: Partitioning = child.outputPartitioning
    +    val shuffled = new ShuffledRowRDD(shuffleDependency)
     
    -  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
    +    val sumOfOutput = numberOfOutput.sum
    +    if (sumOfOutput <= limit) {
    +      shuffled
    +    } else if (!respectChildParallelism) {
    +      // This is mainly for tests.
    +      // Some tests, like the Hive compatibility tests, assume that rows are returned in a
    +      // specific order: the partitions are scanned sequentially until we reach the required
    +      // number of rows. However, logically a Limit operator should not care about the row
    +      // scan order. Thus we take the rows of each partition until we reach the required limit.
    +      var numTakenRow = 0
    +      val takeAmounts = new mutable.HashMap[Int, Int]()
    --- End diff --
    
    Why a hash map? An array would also work, right?
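    
    For example, here is a minimal sketch of the array-based alternative (hypothetical code, not part of this PR; it assumes `numberOfOutput` holds the per-partition record counts obtained from the map output statistics above):
    
    ```scala
    // Compute how many rows to take from each partition when scanning the
    // partitions sequentially, using an Array indexed by partition id.
    def computeTakeAmounts(numberOfOutput: Seq[Long], limit: Int): Array[Int] = {
      // Partition ids are dense (0 until numPartitions), so a plain array
      // avoids the hashing and boxing overhead of a HashMap[Int, Int].
      val takeAmounts = new Array[Int](numberOfOutput.length)
      var numTakenRows = 0
      var i = 0
      while (i < numberOfOutput.length && numTakenRows < limit) {
        // Take everything this partition has, capped by what is still needed.
        val take = math.min(numberOfOutput(i), (limit - numTakenRows).toLong).toInt
        takeAmounts(i) = take
        numTakenRows += take
        i += 1
      }
      takeAmounts
    }
    ```
    
    Since partition ids are contiguous integers starting at 0, the array gives direct indexing with no hashing.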

