This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 45b3dd5  fix: Only trigger Comet Final aggregation on Comet partial aggregation (#264)
45b3dd5 is described below

commit 45b3dd58b33f7127b73a8ee980cddd8f22d9e881
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Apr 12 23:28:56 2024 -0700

    fix: Only trigger Comet Final aggregation on Comet partial aggregation (#264)
---
 .../apache/comet/CometSparkSessionExtensions.scala | 71 ++++++++++++++++------
 .../apache/comet/exec/CometAggregateSuite.scala    | 13 ++++
 2 files changed, 64 insertions(+), 20 deletions(-)

diff --git 
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala 
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index a10ac57..275b9eb 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -26,13 +26,14 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometNativeShuffle}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, 
ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, 
BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -319,26 +320,42 @@ class CometSparkSessionExtensions
           }
 
         case op @ HashAggregateExec(_, _, _, groupingExprs, aggExprs, _, _, _, 
child) =>
-          val newOp = transform1(op)
-          newOp match {
-            case Some(nativeOp) =>
-              val modes = aggExprs.map(_.mode).distinct
-              // The aggExprs could be empty. For example, if the aggregate 
functions only have
-              // distinct aggregate functions or only have group by, the 
aggExprs is empty and
-              // modes is empty too. If aggExprs is not empty, we need to 
verify all the aggregates
-              // have the same mode.
-              assert(modes.length == 1 || modes.length == 0)
-              CometHashAggregateExec(
-                nativeOp,
-                op,
-                groupingExprs,
-                aggExprs,
-                child.output,
-                if (modes.nonEmpty) Some(modes.head) else None,
-                child,
-                SerializedPlan(None))
-            case None =>
+          val modes = aggExprs.map(_.mode).distinct
+
+          if (!modes.isEmpty && modes.size != 1) {
+            // This shouldn't happen as all aggregation expressions should 
share the same mode.
+            // Fallback to Spark nevertheless here.
+            op
+          } else {
+            val sparkFinalMode = {
+              !modes.isEmpty && modes.head == Final && 
findPartialAgg(child).isEmpty
+            }
+
+            if (sparkFinalMode) {
               op
+            } else {
+              val newOp = transform1(op)
+              newOp match {
+                case Some(nativeOp) =>
+                  val modes = aggExprs.map(_.mode).distinct
+                  // The aggExprs could be empty. For example, if the 
aggregate functions only have
+                  // distinct aggregate functions or only have group by, the 
aggExprs is empty and
+                  // modes is empty too. If aggExprs is not empty, we need to 
verify all the
+                  // aggregates have the same mode.
+                  assert(modes.length == 1 || modes.length == 0)
+                  CometHashAggregateExec(
+                    nativeOp,
+                    op,
+                    groupingExprs,
+                    aggExprs,
+                    child.output,
+                    if (modes.nonEmpty) Some(modes.head) else None,
+                    child,
+                    SerializedPlan(None))
+                case None =>
+                  op
+              }
+            }
           }
 
         case op: ShuffledHashJoinExec
@@ -596,6 +613,20 @@ class CometSparkSessionExtensions
         }
       }
     }
+
+    /**
+     * Finds the first partial-mode aggregate in `plan` and returns it only if it is a Comet
+     * aggregate.
+     *
+     * The search uses `collectFirst`, so the first node matched by one of the cases decides the
+     * result:
+     *   - a `CometHashAggregateExec` whose aggregate expressions are all `Partial` is returned
+     *     as `Some(agg)`;
+     *   - a Spark `HashAggregateExec` whose aggregate expressions are all `Partial` yields `None`,
+     *     i.e. the partial side runs in Spark;
+     *   - `AQEShuffleReadExec` and `ShuffleQueryStageExec` are searched explicitly by recursing
+     *     into `child` / `plan` respectively (presumably because `collectFirst` does not descend
+     *     into a query stage's wrapped plan on its own — confirm).
+     * If no node matches, the result is `None`.
+     *
+     * Note that `forall` is vacuously true for an empty `aggregateExpressions`, so an aggregate
+     * with no aggregate expressions (e.g. group-by only) is also treated as "partial" here.
+     *
+     * NOTE(review): a Spark `HashAggregateExec` whose expressions are not all `Partial`
+     * (e.g. `PartialMerge`) is not matched by any case, so the search continues beneath it and
+     * may still return a Comet partial aggregate found further down — confirm this is intended.
+     *
+     * @param plan the subtree to search (the caller passes the child of a `Final`-mode aggregate)
+     * @return the Comet partial aggregate, or `None` if a Spark partial aggregate was reached
+     *         first or nothing matched
+     */
+    def findPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+      plan.collectFirst {
+        case agg: CometHashAggregateExec if 
agg.aggregateExpressions.forall(_.mode == Partial) =>
+          Some(agg)
+        case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode 
== Partial) => None
+        case a: AQEShuffleReadExec => findPartialAgg(a.child)
+        case s: ShuffleQueryStageExec => findPartialAgg(s.plan)
+      }.flatten
+    }
   }
 
   // This rule is responsible for eliminating redundant transitions between 
row-based and
diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index b95ce9b..89681d3 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -40,6 +40,19 @@ import 
org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
 class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   import testImplicits._
 
+  test("Only trigger Comet Final aggregation on Comet partial aggregation") { // regression test for #264
+    withTempView("lowerCaseData") {
+      lowerCaseData.createOrReplaceTempView("lowerCaseData")
+      withSQLConf( // turn on the three Comet flags below for this query only
+        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+        val df = sql("SELECT LAST(n) FROM lowerCaseData") // NOTE(review): presumably LAST keeps the partial agg in Spark — confirm
+        checkSparkAnswer(df) // presumably compares the result against vanilla Spark — confirm
+      }
+    }
+  }
+
   test(
     "Average expression in Comet Final should handle " +
       "all null inputs from partial Spark aggregation") {

Reply via email to