Github user scwf commented on a diff in the pull request:
https://github.com/apache/spark/pull/16633#discussion_r96773557
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala ---
@@ -90,21 +94,74 @@ trait BaseLimitExec extends UnaryExecNode with
CodegenSupport {
}
/**
- * Take the first `limit` elements of each child partition, but do not
collect or shuffle them.
+ * Take the first `limit` elements of the child's partitions.
*/
-case class LocalLimitExec(limit: Int, child: SparkPlan) extends
BaseLimitExec {
-
- override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-
- override def outputPartitioning: Partitioning = child.outputPartitioning
-}
+case class GlobalLimitExec(limit: Int, child: SparkPlan) extends
UnaryExecNode {
+ override def output: Seq[Attribute] = child.output
-/**
- * Take the first `limit` elements of the child's single output partition.
- */
-case class GlobalLimitExec(limit: Int, child: SparkPlan) extends
BaseLimitExec {
+ protected override def doExecute(): RDD[InternalRow] = {
+ // This logic is mainly copyed from `SparkPlan.executeTake`.
+ // TODO: combine this with `SparkPlan.executeTake`, if possible.
+ val childRDD = child.execute()
+ val totalParts = childRDD.partitions.length
+ var partsScanned = 0
+ var totalNum = 0
+ var resultRDD: RDD[InternalRow] = null
+ while (totalNum < limit && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for
this number to be
+ // greater than totalParts because we actually cap it at totalParts
in runJob.
+ var numPartsToTry = 1L
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the previous iteration,
quadruple and retry.
+ // Otherwise, interpolate the number of partitions we need to try,
but overestimate
+ // it by 50%. We also cap the estimation in the end.
+ val limitScaleUpFactor =
Math.max(sqlContext.conf.limitScaleUpFactor, 2)
+ if (totalNum == 0) {
+ numPartsToTry = partsScanned * limitScaleUpFactor
+ } else {
+ // the left side of max is >=1 whenever partsScanned >= 2
+ numPartsToTry = Math.max((1.5 * limit * partsScanned /
totalNum).toInt - partsScanned, 1)
+ numPartsToTry = Math.min(numPartsToTry, partsScanned *
limitScaleUpFactor)
+ }
+ }
- override def requiredChildDistribution: List[Distribution] = AllTuples
:: Nil
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry,
totalParts).toInt)
+ val sc = sqlContext.sparkContext
+ val res = sc.runJob(childRDD,
+ (it: Iterator[InternalRow]) => Array[Int](it.size), p)
+
+ totalNum += res.map(_.head).sum
+ partsScanned += p.size
+
+ if (totalNum >= limit) {
+ // If we scan more rows than the limit number, we need to reduce
that from scanned.
+ // We calculate how many rows need to be reduced for each
partition,
+ // until all redunant rows are reduced.
+ var numToReduce = (totalNum - limit)
+ val reduceAmounts = new HashMap[Int, Int]()
+ val partitionsToReduce = p.zip(res.map(_.head)).foreach { case
(part, size) =>
+ val toReduce = if (size > numToReduce) numToReduce else size
+ reduceAmounts += ((part, toReduce))
+ numToReduce -= toReduce
+ }
+ resultRDD = childRDD.mapPartitionsWithIndexInternal { case (index,
iter) =>
+ if (index < partsScanned) {
--- End diff --
An example: `select xxx from table where xxx > 99 limit 1000`.
If the table is big and the actual number of rows with xxx > 99 is less than
100, you still need to compute all the partitions, so you will run the filter
and scan the big table twice.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]