Github user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22219#discussion_r213645236
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---
    @@ -329,49 +329,52 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
        *
        * This is modeled after `RDD.take` but never runs any job locally on the driver.
        */
    -  def executeTake(n: Int): Array[InternalRow] = {
    +  def executeTake(n: Int): Array[InternalRow] = executeTakeIterator(n)._2.toArray
    +
    +  private[spark] def executeTakeIterator(n: Int): (Long, Iterator[InternalRow]) = {
         if (n == 0) {
    -      return new Array[InternalRow](0)
    +      return (0, Iterator.empty)
         }
     
    -    val childRDD = getByteArrayRdd(n).map(_._2)
    -
    -    val buf = new ArrayBuffer[InternalRow]
    +    val childRDD = getByteArrayRdd(n)
    +    val encodedBuf = new ArrayBuffer[Array[Byte]]
         val totalParts = childRDD.partitions.length
    +    var scannedRowCount = 0L
         var partsScanned = 0
    -    while (buf.size < n && partsScanned < totalParts) {
    +    while (scannedRowCount < n && partsScanned < totalParts) {
           // The number of partitions to try in this iteration. It is ok for this number to be
           // greater than totalParts because we actually cap it at totalParts in runJob.
    -      var numPartsToTry = 1L
    -      if (partsScanned > 0) {
    +      val numPartsToTry = if (partsScanned > 0) {
             // If we didn't find any rows after the previous iteration, quadruple and retry.
             // Otherwise, interpolate the number of partitions we need to try, but overestimate
             // it by 50%. We also cap the estimation in the end.
             val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
    -        if (buf.isEmpty) {
    -          numPartsToTry = partsScanned * limitScaleUpFactor
    +        if (scannedRowCount == 0) {
    +          partsScanned * limitScaleUpFactor
             } else {
    -          val left = n - buf.size
    +          val left = n - scannedRowCount
               // As left > 0, numPartsToTry is always >= 1
    -          numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
    -          numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
    +          Math.min(Math.ceil(1.5 * left * partsScanned / scannedRowCount).toInt,
    +            partsScanned * limitScaleUpFactor)
             }
    +      } else {
    +        1L
    --- End diff --
    
    It is also fine to revert this.
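    
    For reference, here is a minimal standalone sketch of the scale-up estimate that the comments in the diff describe (the `nextPartsToTry` helper and its signature are made up for illustration and are not part of `SparkPlan`):
    
    ```scala
    // Hypothetical helper mirroring the numPartsToTry estimation described above.
    object TakeScaleUpSketch {
      def nextPartsToTry(
          partsScanned: Int,
          rowsSoFar: Long,
          rowsWanted: Int,
          limitScaleUpFactor: Int): Long = {
        if (partsScanned == 0) {
          // First iteration: probe a single partition.
          1L
        } else if (rowsSoFar == 0) {
          // No rows found yet: grow aggressively (quadruple with the default factor of 4).
          partsScanned.toLong * math.max(limitScaleUpFactor, 2)
        } else {
          // Interpolate from the observed row density, overestimate by 50%, then cap.
          val left = rowsWanted - rowsSoFar
          val estimate = math.ceil(1.5 * left * partsScanned / rowsSoFar).toLong
          math.min(estimate, partsScanned.toLong * math.max(limitScaleUpFactor, 2))
        }
      }
    
      def main(args: Array[String]): Unit = {
        // 1 partition scanned, 3 of 20 rows found: the interpolated 9 is capped to 1 * 4 = 4.
        println(nextPartsToTry(partsScanned = 1, rowsSoFar = 3L, rowsWanted = 20, limitScaleUpFactor = 4))
      }
    }
    ```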

