This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 313111d  feat: Support CollectLimit operator (#100)
313111d is described below

commit 313111d779645f63b5a075aeeef6b0b916c162ee
Author: advancedxy <xian...@apache.org>
AuthorDate: Thu Feb 29 02:19:08 2024 +0800

    feat: Support CollectLimit operator (#100)
---
 .../apache/comet/CometSparkSessionExtensions.scala |  37 +++++++
 .../org/apache/comet/serde/QueryPlanSerde.scala    |   1 +
 .../shims/ShimCometSparkSessionExtensions.scala    |  17 ++++
 .../apache/spark/sql/comet/CometCoalesceExec.scala |  20 +---
 .../spark/sql/comet/CometCollectLimitExec.scala    | 112 +++++++++++++++++++++
 .../apache/spark/sql/comet/CometExecUtils.scala    |  43 ++++++++
 .../sql/comet/CometTakeOrderedAndProjectExec.scala |   6 +-
 .../org/apache/comet/exec/CometExecSuite.scala     |  32 +++++-
 .../scala/org/apache/spark/sql/CometTestBase.scala |  19 ++++
 9 files changed, 261 insertions(+), 26 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 10c3328..dae9f3f 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -65,6 +65,9 @@ class CometSparkSessionExtensions
 
   case class CometExecColumnar(session: SparkSession) extends ColumnarRule {
     override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session)
+
+    override def postColumnarTransitions: Rule[SparkPlan] =
+      EliminateRedundantColumnarToRow(session)
   }
 
   case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
@@ -284,6 +287,20 @@ class CometSparkSessionExtensions
               op
           }
 
+        case op: CollectLimitExec
+            if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit")
+              && isCometShuffleEnabled(conf)
+              && getOffset(op) == 0 =>
+          QueryPlanSerde.operator2Proto(op) match {
+            case Some(nativeOp) =>
+              val offset = getOffset(op)
+              val cometOp =
+                CometCollectLimitExec(op, op.limit, offset, op.child)
+              CometSinkPlaceHolder(nativeOp, op, cometOp)
+            case None =>
+              op
+          }
+
         case op: ExpandExec =>
           val newOp = transform1(op)
           newOp match {
@@ -457,6 +474,26 @@ class CometSparkSessionExtensions
       }
     }
   }
+
+  // CometExec already wraps a `ColumnarToRowExec` for row-based operators. Therefore,
+  // `ColumnarToRowExec` is redundant and can be eliminated.
+  //
+  // It was added during ApplyColumnarRulesAndInsertTransitions' insertTransitions phase when Spark
+  // requests row-based output, such as a `collect` call. It's correct to add a redundant
+  // `ColumnarToRowExec` for `CometExec`. However, for certain operators such as
+  // `CometCollectLimitExec`, which overrides `executeCollect`, the redundant `ColumnarToRowExec`
+  // makes the override ineffective. The purpose of this rule is to eliminate the redundant
+  // `ColumnarToRowExec` for such operators.
+  case class EliminateRedundantColumnarToRow(session: SparkSession) extends Rule[SparkPlan] {
+    override def apply(plan: SparkPlan): SparkPlan = {
+      plan match {
+        case ColumnarToRowExec(child: CometCollectLimitExec) =>
+          child
+        case other =>
+          other
+      }
+    }
+  }
 }
 
 object CometSparkSessionExtensions extends Logging {
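
For illustration only (not part of the commit): a minimal standalone sketch of the rewrite that EliminateRedundantColumnarToRow performs, using stand-in case classes instead of Spark's ColumnarToRowExec and CometCollectLimitExec so it runs without a Spark session. It only shows the shape of the pattern match; the real rule operates on SparkPlan nodes as defined above.

object EliminateRedundantC2RSketch {
  sealed trait Plan
  case class ColumnarToRow(child: Plan) extends Plan   // stand-in for ColumnarToRowExec
  case class CollectLimitNode(limit: Int) extends Plan // stand-in for CometCollectLimitExec

  // Drop the transition only when its child is the collect-limit node, so the
  // child's own executeCollect override remains effective.
  def eliminate(plan: Plan): Plan = plan match {
    case ColumnarToRow(child: CollectLimitNode) => child
    case other                                  => other
  }

  def main(args: Array[String]): Unit =
    println(eliminate(ColumnarToRow(CollectLimitNode(2)))) // prints CollectLimitNode(2)
}
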
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index fcc0ca9..46eb1b0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1835,6 +1835,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       case s if isCometScan(s) => true
       case _: CometSinkPlaceHolder => true
       case _: CoalesceExec => true
+      case _: CollectLimitExec => true
       case _: UnionExec => true
       case _: ShuffleExchangeExec => true
       case _: TakeOrderedAndProjectExec => true
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
index 8afed84..85c6413 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
@@ -20,9 +20,11 @@
 package org.apache.comet.shims
 
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.execution.{LimitExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 
 trait ShimCometSparkSessionExtensions {
+  import org.apache.comet.shims.ShimCometSparkSessionExtensions._
 
   /**
   * TODO: delete after dropping Spark 3.2.0 support and directly call scan.pushedAggregate
@@ -32,4 +34,19 @@ trait ShimCometSparkSessionExtensions {
     .map { a => a.setAccessible(true); a }
     .flatMap(_.get(scan).asInstanceOf[Option[Aggregation]])
     .headOption
+
+  /**
+   * TODO: delete after dropping Spark 3.2 and 3.3 support
+   */
+  def getOffset(limit: LimitExec): Int = getOffsetOpt(limit).getOrElse(0)
+
+}
+
+object ShimCometSparkSessionExtensions {
+  private def getOffsetOpt(plan: SparkPlan): Option[Int] = plan.getClass.getDeclaredFields
+    .filter(_.getName == "offset")
+    .map { a => a.setAccessible(true); a.get(plan) }
+    .filter(_.isInstanceOf[Int])
+    .map(_.asInstanceOf[Int])
+    .headOption
 }
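
For illustration only (not part of the commit): a standalone sketch of the reflection approach used by getOffsetOpt above, which reads a field named "offset" only if the plan class declares one, so the same code compiles against Spark versions with and without that field (hence the shim and its TODO). FakeLimit and LegacyLimit here are hypothetical stand-ins, not Spark classes.

object OffsetReflectionSketch {
  final case class FakeLimit(limit: Int, offset: Int) // stand-in for a limit node that has an offset field
  final case class LegacyLimit(limit: Int)            // stand-in for an older node without the field

  private def getOffsetOpt(plan: AnyRef): Option[Int] =
    plan.getClass.getDeclaredFields
      .filter(_.getName == "offset")
      .map { f => f.setAccessible(true); f.get(plan) }
      .collect { case i: Int => i } // the boxed java.lang.Integer unboxes here
      .headOption

  def main(args: Array[String]): Unit = {
    println(getOffsetOpt(FakeLimit(10, 2))) // Some(2)
    println(getOffsetOpt(LegacyLimit(10)))  // None
  }
}
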
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
index fc4f90f..cc635d7 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCoalesceExec.scala
@@ -19,7 +19,6 @@
 
 package org.apache.spark.sql.comet
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
@@ -42,7 +41,7 @@ case class CometCoalesceExec(
     if (numPartitions == 1 && rdd.getNumPartitions < 1) {
      // Make sure we don't output an RDD with 0 partitions, when claiming that we have a
       // `SinglePartition`.
-      new CometCoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
+      CometExecUtils.emptyRDDWithPartitions(sparkContext, 1)
     } else {
       rdd.coalesce(numPartitions, shuffle = false)
     }
@@ -67,20 +66,3 @@ case class CometCoalesceExec(
 
   override def hashCode(): Int = Objects.hashCode(numPartitions: java.lang.Integer, child)
 }
-
-object CometCoalesceExec {
-
-  /** A simple RDD with no data, but with the given number of partitions. */
-  class EmptyRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int)
-      extends RDD[ColumnarBatch](sc, Nil) {
-
-    override def getPartitions: Array[Partition] =
-      Array.tabulate(numPartitions)(i => EmptyPartition(i))
-
-    override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
-      Iterator.empty
-    }
-  }
-
-  case class EmptyPartition(index: Int) extends Partition
-}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
new file mode 100644
index 0000000..83126a7
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import java.util.Objects
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Comet physical plan node for Spark `CollectLimitExec`.
+ *
+ * Similar to `CometTakeOrderedAndProjectExec`, it contains two native executions separated by a
+ * Comet shuffle.
+ *
+ * TODO: support offset semantics
+ */
+case class CometCollectLimitExec(
+    override val originalPlan: SparkPlan,
+    limit: Int,
+    offset: Int,
+    child: SparkPlan)
+    extends CometExec
+    with UnaryExecNode {
+
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "shuffleReadElapsedCompute" ->
+      SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"),
+    "numPartitions" -> SQLMetrics.createMetric(
+      sparkContext,
+      "number of partitions")) ++ readMetrics ++ writeMetrics
+
+  private lazy val serializer: Serializer =
+    new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
+
+  override def executeCollect(): Array[InternalRow] = {
+    ColumnarToRowExec(child).executeTake(limit)
+  }
+
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val childRDD = child.executeColumnar()
+    if (childRDD.getNumPartitions == 0) {
+      CometExecUtils.emptyRDDWithPartitions(sparkContext, 1)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        val localLimitedRDD = if (limit >= 0) {
+          CometExecUtils.getNativeLimitRDD(childRDD, output, limit)
+        } else {
+          childRDD
+        }
+        // Shuffle to Single Partition using Comet shuffle
+        val dep = CometShuffleExchangeExec.prepareShuffleDependency(
+          localLimitedRDD,
+          child.output,
+          outputPartitioning,
+          serializer,
+          metrics)
+        metrics("numPartitions").set(dep.partitioner.numPartitions)
+
+        new CometShuffledBatchRDD(dep, readMetrics)
+      }
+      CometExecUtils.getNativeLimitRDD(singlePartitionRDD, output, limit)
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+
+  override def stringArgs: Iterator[Any] = Iterator(limit, offset, child)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometCollectLimitExec =>
+        this.limit == other.limit && this.offset == other.offset &&
+        this.child == other.child
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int =
+    Objects.hashCode(limit: java.lang.Integer, offset: java.lang.Integer, child)
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index 9f8f215..5931920 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -20,9 +20,13 @@
 package org.apache.spark.sql.comet
 
 import scala.collection.JavaConverters.asJavaIterableConverter
+import scala.reflect.ClassTag
 
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.comet.serde.OperatorOuterClass
 import org.apache.comet.serde.OperatorOuterClass.Operator
@@ -30,6 +34,29 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
 
 object CometExecUtils {
 
+  /**
+   * Create an empty RDD with the given number of partitions.
+   */
+  def emptyRDDWithPartitions[T: ClassTag](
+      sparkContext: SparkContext,
+      numPartitions: Int): RDD[T] = {
+    new EmptyRDDWithPartitions(sparkContext, numPartitions)
+  }
+
+  /**
+   * Transform the given RDD into a new RDD that takes the first `limit` elements of each
+   * partition. The limit operation is performed on the native side.
+   */
+  def getNativeLimitRDD(
+      childPlan: RDD[ColumnarBatch],
+      outputAttribute: Seq[Attribute],
+      limit: Int): RDD[ColumnarBatch] = {
+    childPlan.mapPartitionsInternal { iter =>
+      val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
+      CometExec.getCometIterator(Seq(iter), limitOp)
+    }
+  }
+
   /**
    * Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec.
    */
@@ -119,3 +146,19 @@ object CometExecUtils {
     }
   }
 }
+
+/** A simple RDD with no data, but with the given number of partitions. */
+private class EmptyRDDWithPartitions[T: ClassTag](
+    @transient private val sc: SparkContext,
+    numPartitions: Int)
+    extends RDD[T](sc, Nil) {
+
+  override def getPartitions: Array[Partition] =
+    Array.tabulate(numPartitions)(i => EmptyPartition(i))
+
+  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+    Iterator.empty
+  }
+}
+
+private case class EmptyPartition(index: Int) extends Partition
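
For illustration only (not part of the commit): a hedged usage sketch for the emptyRDDWithPartitions helper added above, against a local SparkContext. It assumes Spark and the Comet jar are on the classpath; getNativeLimitRDD is not exercised here because it needs columnar input and the Comet native library.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.comet.CometExecUtils
import org.apache.spark.sql.vectorized.ColumnarBatch

object EmptyRddSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("empty-rdd-sketch"))
    try {
      // One partition, zero rows: enough to claim SinglePartition without producing data.
      val rdd = CometExecUtils.emptyRDDWithPartitions[ColumnarBatch](sc, 1)
      println(rdd.getNumPartitions) // 1
      println(rdd.count())          // 0
    } finally {
      sc.stop()
    }
  }
}
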
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index 8898438..26ec401 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -77,11 +77,7 @@ case class CometTakeOrderedAndProjectExec(
         childRDD
       } else {
         val localTopK = if (orderingSatisfies) {
-          childRDD.mapPartitionsInternal { iter =>
-            val limitOp =
-              CometExecUtils.getLimitNativePlan(output, limit).get
-            CometExec.getCometIterator(Seq(iter), limitOp)
-          }
+          CometExecUtils.getNativeLimitRDD(childRDD, output, limit)
         } else {
           childRDD.mapPartitionsInternal { iter =>
             val topK =
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 05be34c..d7434d5 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -32,9 +32,9 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
-import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
@@ -1087,6 +1087,34 @@ class CometExecSuite extends CometTestBase {
         }
       })
   }
+
+  test("collect limit") {
+    Seq("true", "false").foreach(aqe => {
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe) {
+        withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
+          val df = sql("SELECT _1 as id, _2 as value FROM tbl limit 2")
+          assert(df.queryExecution.executedPlan.execute().getNumPartitions === 1)
+          checkSparkAnswerAndOperator(df, Seq(classOf[CometCollectLimitExec]))
+          assert(df.collect().length === 2)
+
+          val qe = df.queryExecution
+          // make sure the root node is CometCollectLimitExec
+          assert(qe.executedPlan.isInstanceOf[CometCollectLimitExec])
+          // execute CometCollectLimitExec directly to check the doExecuteColumnar implementation
+          SQLExecution.withNewExecutionId(qe, Some("count")) {
+            qe.executedPlan.resetMetrics()
+            assert(qe.executedPlan.execute().count() === 2)
+          }
+
+          assert(df.isEmpty === false)
+
+          // follow up native operation is possible
+          val df3 = df.groupBy("id").sum("value")
+          checkSparkAnswerAndOperator(df3)
+        }
+      }
+    })
+  }
 }
 
 case class BucketedTableTestSpec(
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 2e523fa..0d7904c 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -150,7 +150,15 @@ abstract class CometTestBase
   protected def checkSparkAnswerAndOperator(
       df: => DataFrame,
       excludedClasses: Class[_]*): Unit = {
+    checkSparkAnswerAndOperator(df, Seq.empty, excludedClasses: _*)
+  }
+
+  protected def checkSparkAnswerAndOperator(
+      df: => DataFrame,
+      includeClasses: Seq[Class[_]],
+      excludedClasses: Class[_]*): Unit = {
     checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), excludedClasses: _*)
+    checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), includeClasses: _*)
     checkSparkAnswer(df)
   }
 
@@ -173,6 +181,17 @@ abstract class CometTestBase
     }
   }
 
+  protected def checkPlanContains(plan: SparkPlan, includePlans: Class[_]*): Unit = {
+    includePlans.foreach { case planClass =>
+      if (!plan.exists(op => planClass.isAssignableFrom(op.getClass))) {
+        assert(
+          false,
+          s"Expected plan to contain ${planClass.getSimpleName}, but not.\n" +
+            s"plan: $plan")
+      }
+    }
+  }
+
   /**
    * Check the answer of a Comet SQL query with Spark result using absolute tolerance.
    */
