This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 9205f0d1 chore: Improve ObjectHashAggregate fallback error message (#849)
9205f0d1 is described below

commit 9205f0d1913933f2cc8767c02a7728a4e318dd49
Author: Andy Grove <[email protected]>
AuthorDate: Tue Aug 20 14:37:42 2024 -0600

    chore: Improve ObjectHashAggregate fallback error message (#849)
    
    * add support for ObjectHashAggregate
    
    * Revert a change
---
 .../apache/comet/CometSparkSessionExtensions.scala | 19 ++++++++-----------
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 22 +++++++++++-----------
 2 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index e5b0b0cd..6e3663e5 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, Comet
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, 
BroadcastQueryStageExec, ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, 
HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -424,16 +424,13 @@ class CometSparkSessionExtensions
               op
           }
 
-        case op @ HashAggregateExec(
-              _,
-              _,
-              _,
-              groupingExprs,
-              aggExprs,
-              _,
-              _,
-              resultExpressions,
-              child) =>
+        case op: BaseAggregateExec
+            if op.isInstanceOf[HashAggregateExec] ||
+              op.isInstanceOf[ObjectHashAggregateExec] =>
+          val groupingExprs = op.groupingExpressions
+          val aggExprs = op.aggregateExpressions
+          val resultExpressions = op.resultExpressions
+          val child = op.child
           val modes = aggExprs.map(_.mode).distinct
 
           if (!modes.isEmpty && modes.size != 1) {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index caeda516..5ef924f6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, 
ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, 
HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, 
ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
@@ -2663,16 +2663,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           None
         }
 
-      case HashAggregateExec(
-            _,
-            _,
-            _,
-            groupingExpressions,
-            aggregateExpressions,
-            aggregateAttributes,
-            _,
-            resultExpressions,
-            child) if isCometOperatorEnabled(op.conf, 
CometConf.OPERATOR_AGGREGATE) =>
+      case aggregate: BaseAggregateExec
+          if (aggregate.isInstanceOf[HashAggregateExec] ||
+            aggregate.isInstanceOf[ObjectHashAggregateExec]) &&
+            isCometOperatorEnabled(op.conf, CometConf.OPERATOR_AGGREGATE) =>
+        val groupingExpressions = aggregate.groupingExpressions
+        val aggregateExpressions = aggregate.aggregateExpressions
+        val aggregateAttributes = aggregate.aggregateAttributes
+        val resultExpressions = aggregate.resultExpressions
+        val child = aggregate.child
+
         if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
           withInfo(op, "No group by or aggregation")
           return None


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to