This is an automated email from the ASF dual-hosted git repository.

csy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new 9f15314b [AURON #1725] Support Native CollectLimit (#1726)
9f15314b is described below

commit 9f15314b72b29a239927f7963c12ad408ef87fa2
Author: Thomas <[email protected]>
AuthorDate: Wed Dec 17 20:52:57 2025 +0800

    [AURON #1725] Support Native CollectLimit (#1726)
    
    # Which issue does this PR close?
    
    Closes #1725.
    
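    # Rationale for this change
    
    Spark plans a top-level `LIMIT n` as `CollectLimitExec`, which Auron had no
    converter for, so such queries ended with a non-native operator.
    
    # What changes are included in this PR?
    
    - Add `NativeCollectLimitBase` and its shim `NativeCollectLimitExec`: a native
      per-partition local limit, a single-partition native shuffle, then a final
      native limit.
    - Add `Shims.createNativeCollectLimitExec` and convert `CollectLimitExec` in
      `AuronConverters` / `AuronConvertStrategy`.
    - Move test sources from the `org.apache.auron` directory to `org/apache/auron`.
    
    # Are there any user-facing changes?
    
    New config `spark.auron.enable.collectLimit` (default `true`).
    
    # How was this patch tested?
    
    New "Collect Limit" test in `AuronExecSuite`.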
    
    ---------
    
    Co-authored-by: cxzl25 <[email protected]>
---
 .../org/apache/spark/sql/auron/ShimsImpl.scala     |  5 ++
 .../auron/plan/NativeCollectLimitExec.scala}       | 27 +++----
 .../AuronCheckConvertBroadcastExchangeSuite.scala  |  0
 .../AuronCheckConvertShuffleExchangeSuite.scala    |  0
 .../apache/auron}/AuronEmptyNativeRddSuite.scala   |  0
 .../apache/auron}/AuronFunctionSuite.scala         |  3 +-
 .../apache/auron}/AuronQuerySuite.scala            |  0
 .../apache/auron}/AuronSQLTestHelper.scala         |  0
 .../apache/auron}/BaseAuronSQLSuite.scala          |  0
 .../apache/auron}/EmptyNativeRddSuite.scala        |  0
 .../apache/auron}/NativeConvertersSuite.scala      |  3 +-
 .../apache/auron/exec/AuronExecSuite.scala}        | 27 +++----
 .../org/apache/spark/sql/AuronQueryTest.scala      |  3 +-
 .../spark/sql/auron/AuronConvertStrategy.scala     | 16 +---
 .../apache/spark/sql/auron/AuronConverters.scala   | 26 +++----
 .../scala/org/apache/spark/sql/auron/Shims.scala   |  2 +
 .../auron/plan/NativeCollectLimitBase.scala        | 88 ++++++++++++++++++++++
 17 files changed, 139 insertions(+), 61 deletions(-)

diff --git 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
index 6aa669de..3acbbed9 100644
--- 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
+++ 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
@@ -294,6 +294,11 @@ class ShimsImpl extends Shims with Logging {
   override def createNativeLocalLimitExec(limit: Long, child: SparkPlan): 
NativeLocalLimitBase =
     NativeLocalLimitExec(limit, child)
 
+  /** Builds this shim's concrete [[NativeCollectLimitExec]] for the given limit and child. */
+  override def createNativeCollectLimitExec(
+      limit: Int,
+      child: SparkPlan): NativeCollectLimitBase = {
+    NativeCollectLimitExec(limit = limit, child = child)
+  }
+
   override def createNativeParquetInsertIntoHiveTableExec(
       cmd: InsertIntoHiveTable,
       child: SparkPlan): NativeParquetInsertIntoHiveTableBase =
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/EmptyNativeRddSuite.scala
 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeCollectLimitExec.scala
similarity index 59%
copy from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/EmptyNativeRddSuite.scala
copy to 
spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeCollectLimitExec.scala
index 6a6716e0..ba514ab7 100644
--- 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/EmptyNativeRddSuite.scala
+++ 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeCollectLimitExec.scala
@@ -14,25 +14,20 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.auron
+package org.apache.spark.sql.execution.auron.plan
 
-import org.apache.spark.sql.AuronQueryTest
-import org.apache.spark.sql.auron.EmptyNativeRDD
+import org.apache.spark.sql.execution.SparkPlan
 
-class EmptyNativeRddSuite extends AuronQueryTest with BaseAuronSQLSuite {
+import org.apache.auron.sparkver
 
-  test("test empty native rdd") {
-    val sc = spark.sparkContext
-    val empty = new EmptyNativeRDD(sc)
-    assert(empty.count === 0)
-    assert(empty.collect().size === 0)
+/**
+ * Shim-layer concrete CollectLimit node. All execution logic is inherited from
+ * [[NativeCollectLimitBase]]; this class only supplies the child-replacement hook,
+ * whose signature differs between Spark versions.
+ */
+case class NativeCollectLimitExec(limit: Int, override val child: SparkPlan)
+    extends NativeCollectLimitBase(limit, child) {
 
-    val thrown = intercept[UnsupportedOperationException] {
-      empty.reduce((row1, _) => {
-        row1
-      })
-    }
-    assert(thrown.getMessage.contains("empty"))
-  }
+  // Spark 3.2+ tree-node API: rebuild this node around a new child
+  // (presumably only compiled for the versions listed in @sparkver).
+  @sparkver("3.2 / 3.3 / 3.4 / 3.5")
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 
+  // Spark 3.0/3.1 API: rebuild from the children list; this unary node uses only the head.
+  @sparkver("3.0 / 3.1")
+  override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
+    copy(child = newChildren.head)
 }
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronCheckConvertBroadcastExchangeSuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronCheckConvertBroadcastExchangeSuite.scala
similarity index 100%
rename from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronCheckConvertBroadcastExchangeSuite.scala
rename to 
spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronCheckConvertBroadcastExchangeSuite.scala
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronCheckConvertShuffleExchangeSuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronCheckConvertShuffleExchangeSuite.scala
similarity index 100%
rename from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronCheckConvertShuffleExchangeSuite.scala
rename to 
spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronCheckConvertShuffleExchangeSuite.scala
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronEmptyNativeRddSuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronEmptyNativeRddSuite.scala
similarity index 100%
rename from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronEmptyNativeRddSuite.scala
rename to 
spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronEmptyNativeRddSuite.scala
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronFunctionSuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
similarity index 99%
rename from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronFunctionSuite.scala
rename to 
spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
index 78bddbba..18a798bc 100644
--- 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronFunctionSuite.scala
+++ 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
@@ -18,8 +18,7 @@ package org.apache.auron
 
 import java.text.SimpleDateFormat
 
-import org.apache.spark.sql.AuronQueryTest
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AuronQueryTest, Row}
 
 import org.apache.auron.util.AuronTestUtils
 
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronQuerySuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala
similarity index 100%
rename from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronQuerySuite.scala
rename to 
spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronSQLTestHelper.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronSQLTestHelper.scala
similarity index 100%
rename from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronSQLTestHelper.scala
rename to 
spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronSQLTestHelper.scala
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/BaseAuronSQLSuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/BaseAuronSQLSuite.scala
similarity index 100%
rename from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/BaseAuronSQLSuite.scala
rename to 
spark-extension-shims-spark/src/test/scala/org/apache/auron/BaseAuronSQLSuite.scala
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/EmptyNativeRddSuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/EmptyNativeRddSuite.scala
similarity index 100%
copy from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/EmptyNativeRddSuite.scala
copy to 
spark-extension-shims-spark/src/test/scala/org/apache/auron/EmptyNativeRddSuite.scala
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/NativeConvertersSuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/NativeConvertersSuite.scala
similarity index 96%
rename from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/NativeConvertersSuite.scala
rename to 
spark-extension-shims-spark/src/test/scala/org/apache/auron/NativeConvertersSuite.scala
index 1b11e8f8..dc96731c 100644
--- 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/NativeConvertersSuite.scala
+++ 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/NativeConvertersSuite.scala
@@ -18,8 +18,7 @@ package org.apache.auron
 
 import org.apache.spark.sql.AuronQueryTest
 import org.apache.spark.sql.auron.{AuronConf, NativeConverters}
-import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
 import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, 
StringType}
 
 import org.apache.auron.protobuf.ScalarFunction
diff --git 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/EmptyNativeRddSuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronExecSuite.scala
similarity index 56%
rename from 
spark-extension-shims-spark/src/test/scala/org.apache.auron/EmptyNativeRddSuite.scala
rename to 
spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronExecSuite.scala
index 6a6716e0..d7adf3a7 100644
--- 
a/spark-extension-shims-spark/src/test/scala/org.apache.auron/EmptyNativeRddSuite.scala
+++ 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronExecSuite.scala
@@ -14,25 +14,26 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.auron
+package org.apache.auron.exec
 
 import org.apache.spark.sql.AuronQueryTest
-import org.apache.spark.sql.auron.EmptyNativeRDD
+import org.apache.spark.sql.execution.auron.plan.NativeCollectLimitExec
 
-class EmptyNativeRddSuite extends AuronQueryTest with BaseAuronSQLSuite {
+import org.apache.auron.BaseAuronSQLSuite
 
-  test("test empty native rdd") {
-    val sc = spark.sparkContext
-    val empty = new EmptyNativeRDD(sc)
-    assert(empty.count === 0)
-    assert(empty.collect().size === 0)
+/**
+ * Checks that Spark physical operators are replaced by Auron native operators. Uses
+ * `checkSparkAnswerAndOperator` from AuronQueryTest, which fails on any non-native operator
+ * and presumably also compares the result with vanilla Spark.
+ */
+class AuronExecSuite extends AuronQueryTest with BaseAuronSQLSuite {
 
-    val thrown = intercept[UnsupportedOperationException] {
-      empty.reduce((row1, _) => {
-        row1
-      })
+  // The table holds 12 rows (with duplicates). The limits cover fewer than, exactly, and
+  // more than the row count, and each executed plan must contain a NativeCollectLimitExec.
+  test("Collect Limit") {
+    withTable("t1") {
+      sql("create table t1(id INT) using parquet")
+      sql("insert into t1 
values(1),(2),(3),(3),(3),(4),(5),(6),(7),(8),(9),(10)")
+      Seq(1, 3, 8, 12, 20).foreach { limit =>
+        val df = checkSparkAnswerAndOperator(s"SELECT id FROM t1 limit $limit")
+        assert(collectFirst(df.queryExecution.executedPlan) { case e: 
NativeCollectLimitExec =>
+          e
+        }.isDefined)
+      }
     }
-    assert(thrown.getMessage.contains("empty"))
   }
 
 }
diff --git 
a/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/AuronQueryTest.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/AuronQueryTest.scala
index 7bd17b62..678faea8 100644
--- 
a/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/AuronQueryTest.scala
+++ 
b/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/AuronQueryTest.scala
@@ -69,7 +69,8 @@ abstract class AuronQueryTest
         .foreach { op: SparkPlan =>
           fail(s"""
                |Found non-native operator: ${op.nodeName}
-               |plan: ${plan}""".stripMargin)
+               |plan:
+               |${plan}""".stripMargin)
         }
     }
 
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
index 0e4dca72..fb7471d4 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
@@ -19,19 +19,7 @@ package org.apache.spark.sql.auron
 import org.apache.commons.lang3.reflect.MethodUtils
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.execution.ExpandExec
-import org.apache.spark.sql.execution.FileSourceScanExec
-import org.apache.spark.sql.execution.FilterExec
-import org.apache.spark.sql.execution.GenerateExec
-import org.apache.spark.sql.execution.GlobalLimitExec
-import org.apache.spark.sql.execution.LocalLimitExec
-import org.apache.spark.sql.execution.LocalTableScanExec
-import org.apache.spark.sql.execution.ProjectExec
-import org.apache.spark.sql.execution.SortExec
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.TakeOrderedAndProjectExec
-import org.apache.spark.sql.execution.UnaryExecNode
-import org.apache.spark.sql.execution.UnionExec
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
@@ -162,6 +150,8 @@ object AuronConvertStrategy extends Logging {
         e.setTagValue(convertStrategyTag, AlwaysConvert)
       case e: TakeOrderedAndProjectExec if isNative(e.child) =>
         e.setTagValue(convertStrategyTag, AlwaysConvert)
+      case e: CollectLimitExec if isNative(e.child) =>
+        e.setTagValue(convertStrategyTag, AlwaysConvert)
       case e: HashAggregateExec if isNative(e.child) =>
         e.setTagValue(convertStrategyTag, AlwaysConvert)
       case e: SortAggregateExec if isNative(e.child) =>
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
index 4e39198d..413ad7be 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
@@ -49,20 +49,7 @@ import 
org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
 import org.apache.spark.sql.catalyst.plans.physical.RoundRobinPartitioning
-import org.apache.spark.sql.execution.ExpandExec
-import org.apache.spark.sql.execution.FileSourceScanExec
-import org.apache.spark.sql.execution.FilterExec
-import org.apache.spark.sql.execution.GenerateExec
-import org.apache.spark.sql.execution.GlobalLimitExec
-import org.apache.spark.sql.execution.LeafExecNode
-import org.apache.spark.sql.execution.LocalLimitExec
-import org.apache.spark.sql.execution.LocalTableScanExec
-import org.apache.spark.sql.execution.ProjectExec
-import org.apache.spark.sql.execution.SortExec
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.TakeOrderedAndProjectExec
-import org.apache.spark.sql.execution.UnaryExecNode
-import org.apache.spark.sql.execution.UnionExec
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.execution.aggregate.SortAggregateExec
@@ -119,6 +106,8 @@ object AuronConverters extends Logging {
     getBooleanConf("spark.auron.enable.global.limit", defaultValue = true)
   def enableTakeOrderedAndProject: Boolean =
     getBooleanConf("spark.auron.enable.take.ordered.and.project", defaultValue 
= true)
+  /** Whether CollectLimitExec may be converted to a native operator (default: true). */
+  def enableCollectLimit: Boolean = {
+    val confKey = "spark.auron.enable.collectLimit"
+    getBooleanConf(confKey, defaultValue = true)
+  }
   def enableAggr: Boolean =
     getBooleanConf("spark.auron.enable.aggr", defaultValue = true)
   def enableExpand: Boolean =
@@ -227,6 +216,8 @@ object AuronConverters extends Logging {
         tryConvert(e, convertGlobalLimitExec)
       case e: TakeOrderedAndProjectExec if enableTakeOrderedAndProject =>
         tryConvert(e, convertTakeOrderedAndProjectExec)
+      case e: CollectLimitExec if enableCollectLimit =>
+        tryConvert(e, convertCollectLimitExec)
 
       case e: HashAggregateExec if enableAggr => // hash aggregate
         val convertedAgg = tryConvert(e, convertHashAggregateExec)
@@ -325,6 +316,8 @@ object AuronConverters extends Logging {
           "Conversion disabled: spark.auron.enable.global.limit=false."
         case _: TakeOrderedAndProjectExec if !enableTakeOrderedAndProject =>
           "Conversion disabled: 
spark.auron.enable.take.ordered.and.project=false."
+        case _: CollectLimitExec if !enableCollectLimit =>
+          "Conversion disabled: spark.auron.enable.collectLimit=false."
         case _: HashAggregateExec if !enableAggr =>
           "Conversion disabled: spark.auron.enable.aggr=false."
         case _: ObjectHashAggregateExec if !enableAggr =>
@@ -796,6 +789,11 @@ object AuronConverters extends Logging {
     }
   }
 
+  /**
+   * Converts Spark's CollectLimitExec into the shim-provided native collect-limit node.
+   *
+   * Spark 3.4+ adds an `offset` field (LIMIT n OFFSET m) and uses `limit = -1` for
+   * offset-only plans. The native node only applies a row limit, so such plans are rejected
+   * here rather than silently returning wrong rows. The failed assertion is meant to make
+   * `tryConvert` keep the original Spark operator.
+   * NOTE(review): confirm tryConvert falls back on AssertionError.
+   */
+  def convertCollectLimitExec(exec: CollectLimitExec): SparkPlan = {
+    logDebugPlanConversion(exec)
+    // `offset` only exists on Spark 3.4+; probe it reflectively so this compiles against
+    // every supported Spark version (absent accessor => no offset).
+    val offset = scala.util
+      .Try(exec.getClass.getMethod("offset").invoke(exec).asInstanceOf[Int])
+      .getOrElse(0)
+    assert(offset == 0, s"native CollectLimit does not support OFFSET (offset=$offset)")
+    assert(
+      exec.limit >= 0,
+      s"native CollectLimit requires a non-negative limit (limit=${exec.limit})")
+    Shims.get.createNativeCollectLimitExec(exec.limit, exec.child)
+  }
+
   def convertHashAggregateExec(exec: HashAggregateExec): SparkPlan = {
     // split non-trivial children exprs in partial-agg to a ProjectExec
     // for enabling filter-project optimization in native side
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
index fbac6a92..a192e198 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
@@ -123,6 +123,8 @@ abstract class Shims {
 
   def createNativeLocalLimitExec(limit: Long, child: SparkPlan): 
NativeLocalLimitBase
 
+  /** Creates the version-specific native replacement for Spark's CollectLimitExec. */
+  def createNativeCollectLimitExec(
+      limit: Int,
+      child: SparkPlan): NativeCollectLimitBase
+
   def createNativeParquetInsertIntoHiveTableExec(
       cmd: InsertIntoHiveTable,
       child: SparkPlan): NativeParquetInsertIntoHiveTableBase
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeCollectLimitBase.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeCollectLimitBase.scala
new file mode 100644
index 00000000..f8315ed1
--- /dev/null
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeCollectLimitBase.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.auron.plan
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.OneToOneDependency
+import org.apache.spark.sql.auron.{NativeHelper, NativeRDD, NativeSupports, 
Shims}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical.{SinglePartition, 
UnknownPartitioning}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+
+import org.apache.auron.metric.SparkMetricNode
+import org.apache.auron.protobuf.{LimitExecNode, PhysicalPlanNode}
+
+/**
+ * Native replacement for Spark's CollectLimitExec: produces at most `limit` rows of `child`.
+ *
+ * `doExecuteNative` runs a per-partition native LocalLimit. When that may yield more than one
+ * partition, it merges the partitions with a single-partition native shuffle and applies a
+ * final native limit. `executeCollect` is the driver-side path that pulls rows partition by
+ * partition until `limit` rows are gathered.
+ *
+ * @param limit maximum number of rows to return
+ * @param child plan producing the input rows
+ */
+abstract class NativeCollectLimitBase(limit: Int, override val child: SparkPlan)
+    extends UnaryExecNode
+    with NativeSupports {
+  override def output: Seq[Attribute] = child.output
+
+  // NOTE(review): only `dataSize`/`numPartitions` are declared; confirm these are the names
+  // the native limit/shuffle nodes report through SparkMetricNode.
+  override lazy val metrics: Map[String, SQLMetric] =
+    (mutable.LinkedHashMap[String, SQLMetric]() ++
+      Map(
+        "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+        "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions"))).toMap
+
+  /**
+   * Collects up to `limit` rows to the driver. Partitions are fetched one at a time through
+   * `toLocalIterator`, so later partitions are only computed while fewer than `limit` rows
+   * have been gathered.
+   */
+  override def executeCollect(): Array[InternalRow] = {
+    val partial = Shims.get.createNativeLocalLimitExec(limit, child)
+    val buf = new ArrayBuffer[InternalRow]
+
+    // Copy rows on the executor, before `toLocalIterator` buffers and serializes them:
+    // SparkPlan.execute() iterators may reuse a single mutable row object, so copying on the
+    // driver would be too late to keep the collected rows distinct.
+    val it = partial.execute().map(_.copy()).toLocalIterator
+    while (buf.size < limit && it.hasNext) {
+      buf += it.next()
+    }
+    buf.toArray
+  }
+
+  /** Builds the native RDD computing the first `limit` rows of `child` in one partition. */
+  override def doExecuteNative(): NativeRDD = {
+    val partial = Shims.get.createNativeLocalLimitExec(limit, child)
+    // With at most one known partition, the local limit already is the global limit.
+    if (!partial.outputPartitioning.isInstanceOf[UnknownPartitioning]
+      && partial.outputPartitioning.numPartitions <= 1) {
+      return NativeHelper.executeNative(partial)
+    }
+
+    // merge all LocalLimit child partitions into a single partition
+    val shuffled = Shims.get.createNativeShuffleExchangeExec(SinglePartition, partial)
+    val singlePartitionRDD = NativeHelper.executeNative(shuffled)
+
+    new NativeRDD(
+      sparkContext,
+      SparkMetricNode(metrics, singlePartitionRDD.metrics :: Nil),
+      singlePartitionRDD.partitions,
+      singlePartitionRDD.partitioner,
+      new OneToOneDependency(singlePartitionRDD) :: Nil,
+      rddShuffleReadFull = false,
+      (partition, taskContext) => {
+        // Re-apply the limit on the merged partition: each input partition contributed up to
+        // `limit` rows, so the merged stream may hold more than `limit` in total.
+        val inputPartition = singlePartitionRDD.partitions(partition.index)
+        val nativeLimitExec = LimitExecNode
+          .newBuilder()
+          .setInput(singlePartitionRDD.nativePlan(inputPartition, taskContext))
+          .setLimit(limit)
+          .build()
+        PhysicalPlanNode.newBuilder().setLimit(nativeLimitExec).build()
+      },
+      friendlyName = "NativeRDD.CollectLimit")
+  }
+
+  override val nodeName: String = "NativeCollectLimit"
+}

Reply via email to