LuciferYang commented on code in PR #46101:
URL: https://github.com/apache/spark/pull/46101#discussion_r1570111390


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala:
##########
@@ -118,18 +119,57 @@ case class CollectLimitExec(limit: Int = -1, child: 
SparkPlan, offset: Int = 0)
  * logical plan, which happens when the user is collecting results back to the 
driver.
  */
 case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
+  assert(limit >= 0)
+
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def executeCollect(): Array[InternalRow] = child.executeTail(limit)
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics = readMetrics ++ writeMetrics
   protected override def doExecute(): RDD[InternalRow] = {
-    // This is a bit hacky way to avoid a shuffle and scanning all data when 
it performs
-    // at `Dataset.tail`.
-    // Since this execution plan and `execute` are currently called only when
-    // `Dataset.tail` is invoked, the jobs are always executed when they are 
supposed to be.
-
-    // If we use this execution plan separately like `Dataset.limit` without 
an actual
-    // job launch, we might just have to mimic the implementation of 
`CollectLimitExec`.
-    sparkContext.parallelize(executeCollect().toImmutableArraySeq, numSlices = 
1)
+    val childRDD = child.execute()
+    if (childRDD.getNumPartitions == 0 || limit == 0) {
+      new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        val locallyLimited = childRDD.mapPartitionsInternal(takeRight)
+        new ShuffledRowRDD(
+          ShuffleExchangeExec.prepareShuffleDependency(
+            locallyLimited,
+            child.output,
+            SinglePartition,
+            serializer,
+            writeMetrics),
+          readMetrics)
+      }
+      singlePartitionRDD.mapPartitionsInternal(takeRight)
+    }
+  }
+
+  private def takeRight(iter: Iterator[InternalRow]): Iterator[InternalRow] = {
+    if (iter.isEmpty) {
+      Iterator.empty[InternalRow]
+    } else {
+      val context = TaskContext.get()
+      val queue = HybridRowQueue.apply(context.taskMemoryManager(), output.size)
+      context.addTaskCompletionListener[Unit](ctx => queue.close())
+      var count = 0
+      while (iter.hasNext) {
+        queue.add(iter.next().copy().asInstanceOf[UnsafeRow])
+        if (count < limit) {
+          count += 1
+        } else {
+          queue.remove()
+        }
+      }
+      Iterator.range(0, count).map(i => queue.remove())

Review Comment:
   ```suggestion
         Iterator.range(0, count).map(_ => queue.remove())
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala:
##########
@@ -118,18 +119,57 @@ case class CollectLimitExec(limit: Int = -1, child: 
SparkPlan, offset: Int = 0)
  * logical plan, which happens when the user is collecting results back to the 
driver.
  */
 case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
+  assert(limit >= 0)
+
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def executeCollect(): Array[InternalRow] = child.executeTail(limit)
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics = readMetrics ++ writeMetrics
   protected override def doExecute(): RDD[InternalRow] = {
-    // This is a bit hacky way to avoid a shuffle and scanning all data when 
it performs
-    // at `Dataset.tail`.
-    // Since this execution plan and `execute` are currently called only when
-    // `Dataset.tail` is invoked, the jobs are always executed when they are 
supposed to be.
-
-    // If we use this execution plan separately like `Dataset.limit` without 
an actual
-    // job launch, we might just have to mimic the implementation of 
`CollectLimitExec`.
-    sparkContext.parallelize(executeCollect().toImmutableArraySeq, numSlices = 
1)
+    val childRDD = child.execute()
+    if (childRDD.getNumPartitions == 0 || limit == 0) {
+      new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        val locallyLimited = childRDD.mapPartitionsInternal(takeRight)
+        new ShuffledRowRDD(
+          ShuffleExchangeExec.prepareShuffleDependency(
+            locallyLimited,
+            child.output,
+            SinglePartition,
+            serializer,
+            writeMetrics),
+          readMetrics)
+      }
+      singlePartitionRDD.mapPartitionsInternal(takeRight)
+    }
+  }
+
+  private def takeRight(iter: Iterator[InternalRow]): Iterator[InternalRow] = {
+    if (iter.isEmpty) {
+      Iterator.empty[InternalRow]
+    } else {
+      val context = TaskContext.get()
+      val queue = HybridRowQueue.apply(context.taskMemoryManager(), output.size)
+      context.addTaskCompletionListener[Unit](ctx => queue.close())

Review Comment:
   ```suggestion
         context.addTaskCompletionListener[Unit](_ => queue.close())
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala:
##########
@@ -26,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
-import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.sql.execution.python.HybridRowQueue

Review Comment:
   On a side note, will we be moving `HybridRowQueue` out of the `python` package to a more common package in the future?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to