Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/16633#discussion_r96782248 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala --- @@ -90,21 +94,74 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } /** - * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + * Take the first `limit` elements of the child's partitions. */ -case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputPartitioning: Partitioning = child.outputPartitioning -} +case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { + override def output: Seq[Attribute] = child.output -/** - * Take the first `limit` elements of the child's single output partition. - */ -case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { + protected override def doExecute(): RDD[InternalRow] = { + // This logic is mainly copyed from `SparkPlan.executeTake`. + // TODO: combine this with `SparkPlan.executeTake`, if possible. + val childRDD = child.execute() + val totalParts = childRDD.partitions.length + var partsScanned = 0 + var totalNum = 0 + var resultRDD: RDD[InternalRow] = null + while (totalNum < limit && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 1L + if (partsScanned > 0) { + // If we didn't find any rows after the previous iteration, quadruple and retry. + // Otherwise, interpolate the number of partitions we need to try, but overestimate + // it by 50%. We also cap the estimation in the end. 
+ val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2) + if (totalNum == 0) { + numPartsToTry = partsScanned * limitScaleUpFactor + } else { + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max((1.5 * limit * partsScanned / totalNum).toInt - partsScanned, 1) + numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor) + } + } - override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) + val sc = sqlContext.sparkContext + val res = sc.runJob(childRDD, + (it: Iterator[InternalRow]) => Array[Int](it.size), p) + + totalNum += res.map(_.head).sum + partsScanned += p.size + + if (totalNum >= limit) { + // If we scan more rows than the limit number, we need to reduce that from scanned. + // We calculate how many rows need to be reduced for each partition, + // until all redunant rows are reduced. + var numToReduce = (totalNum - limit) + val reduceAmounts = new HashMap[Int, Int]() + val partitionsToReduce = p.zip(res.map(_.head)).foreach { case (part, size) => + val toReduce = if (size > numToReduce) numToReduce else size + reduceAmounts += ((part, toReduce)) + numToReduce -= toReduce + } + resultRDD = childRDD.mapPartitionsWithIndexInternal { case (index, iter) => + if (index < partsScanned) { --- End diff -- When there are no rows to drop, it just returns the iterator without consuming it, so no extra scan will happen.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and would like it, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org