zhengruifeng commented on code in PR #46101:
URL: https://github.com/apache/spark/pull/46101#discussion_r1568725499
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala:
##########
@@ -118,18 +118,52 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
* logical plan, which happens when the user is collecting results back to the driver.
*/
case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
+ assert(limit >= 0)
+
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = SinglePartition
override def executeCollect(): Array[InternalRow] = child.executeTail(limit)
+ private val serializer: Serializer = new
UnsafeRowSerializer(child.output.size)
+ private lazy val writeMetrics =
+ SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+ private lazy val readMetrics =
+ SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+ override lazy val metrics = readMetrics ++ writeMetrics
protected override def doExecute(): RDD[InternalRow] = {
- // This is a bit hacky way to avoid a shuffle and scanning all data when
it performs
- // at `Dataset.tail`.
- // Since this execution plan and `execute` are currently called only when
- // `Dataset.tail` is invoked, the jobs are always executed when they are
supposed to be.
-
- // If we use this execution plan separately like `Dataset.limit` without
an actual
- // job launch, we might just have to mimic the implementation of
`CollectLimitExec`.
- sparkContext.parallelize(executeCollect().toImmutableArraySeq, numSlices =
1)
+ val childRDD = child.execute().map(_.copy())
+ if (childRDD.getNumPartitions == 0 || limit == 0) {
+ new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1,
Map.empty)
+ } else {
+ val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+ childRDD
+ } else {
+ val locallyLimited = childRDD.mapPartitionsInternal(takeRight)
+ new ShuffledRowRDD(
+ ShuffleExchangeExec.prepareShuffleDependency(
+ locallyLimited,
+ child.output,
+ SinglePartition,
+ serializer,
+ writeMetrics),
+ readMetrics)
+ }
+ singlePartitionRDD.mapPartitionsInternal(takeRight)
+ }
+ }
+
+ private def takeRight(iter: Iterator[InternalRow]): Iterator[InternalRow] = {
+ if (iter.isEmpty) {
+ Iterator.empty[InternalRow]
+ } else {
+ val queue = HybridRowQueue.apply(output.size)
+ while (iter.hasNext) {
+ queue.add(iter.next().asInstanceOf[UnsafeRow])
+ while (queue.size() > limit) {
+ queue.remove()
+ }
+ }
+ queue.destructiveIterator()
Review Comment:
Good point
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]