zhengruifeng commented on code in PR #46101:
URL: https://github.com/apache/spark/pull/46101#discussion_r1568725499


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala:
##########
@@ -118,18 +118,52 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
  * logical plan, which happens when the user is collecting results back to the driver.
  */
 case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
+  assert(limit >= 0)
+
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def executeCollect(): Array[InternalRow] = child.executeTail(limit)
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics = readMetrics ++ writeMetrics
   protected override def doExecute(): RDD[InternalRow] = {
-    // This is a bit hacky way to avoid a shuffle and scanning all data when it performs
-    // at `Dataset.tail`.
-    // Since this execution plan and `execute` are currently called only when
-    // `Dataset.tail` is invoked, the jobs are always executed when they are supposed to be.
-
-    // If we use this execution plan separately like `Dataset.limit` without an actual
-    // job launch, we might just have to mimic the implementation of `CollectLimitExec`.
-    sparkContext.parallelize(executeCollect().toImmutableArraySeq, numSlices = 1)
+    val childRDD = child.execute().map(_.copy())
+    if (childRDD.getNumPartitions == 0 || limit == 0) {
+      new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        val locallyLimited = childRDD.mapPartitionsInternal(takeRight)
+        new ShuffledRowRDD(
+          ShuffleExchangeExec.prepareShuffleDependency(
+            locallyLimited,
+            child.output,
+            SinglePartition,
+            serializer,
+            writeMetrics),
+          readMetrics)
+      }
+      singlePartitionRDD.mapPartitionsInternal(takeRight)
+    }
+  }
+
+  private def takeRight(iter: Iterator[InternalRow]): Iterator[InternalRow] = {
+    if (iter.isEmpty) {
+      Iterator.empty[InternalRow]
+    } else {
+      val queue = HybridRowQueue.apply(output.size)
+      while (iter.hasNext) {
+        queue.add(iter.next().asInstanceOf[UnsafeRow])
+        while (queue.size() > limit) {
+          queue.remove()
+        }
+      }
+      queue.destructiveIterator()

Review Comment:
   Good point



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to