This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 9205f0d1 chore: Improve ObjectHashAggregate fallback error message (#849)
9205f0d1 is described below
commit 9205f0d1913933f2cc8767c02a7728a4e318dd49
Author: Andy Grove <[email protected]>
AuthorDate: Tue Aug 20 14:37:42 2024 -0600
chore: Improve ObjectHashAggregate fallback error message (#849)
* add support for ObjectHashAggregate
* Revert a change
---
.../apache/comet/CometSparkSessionExtensions.scala | 19 ++++++++-----------
.../org/apache/comet/serde/QueryPlanSerde.scala | 22 +++++++++++-----------
2 files changed, 19 insertions(+), 22 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index e5b0b0cd..6e3663e5 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, Comet
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -424,16 +424,13 @@ class CometSparkSessionExtensions
op
}
- case op @ HashAggregateExec(
- _,
- _,
- _,
- groupingExprs,
- aggExprs,
- _,
- _,
- resultExpressions,
- child) =>
+ case op: BaseAggregateExec
+ if op.isInstanceOf[HashAggregateExec] ||
+ op.isInstanceOf[ObjectHashAggregateExec] =>
+ val groupingExprs = op.groupingExpressions
+ val aggExprs = op.aggregateExpressions
+ val resultExpressions = op.resultExpressions
+ val child = op.child
val modes = aggExprs.map(_.mode).distinct
if (!modes.isEmpty && modes.size != 1) {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index caeda516..5ef924f6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
@@ -2663,16 +2663,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
None
}
- case HashAggregateExec(
- _,
- _,
- _,
- groupingExpressions,
- aggregateExpressions,
- aggregateAttributes,
- _,
- resultExpressions,
- child) if isCometOperatorEnabled(op.conf, CometConf.OPERATOR_AGGREGATE) =>
+ case aggregate: BaseAggregateExec
+ if (aggregate.isInstanceOf[HashAggregateExec] ||
+ aggregate.isInstanceOf[ObjectHashAggregateExec]) &&
+ isCometOperatorEnabled(op.conf, CometConf.OPERATOR_AGGREGATE) =>
+ val groupingExpressions = aggregate.groupingExpressions
+ val aggregateExpressions = aggregate.aggregateExpressions
+ val aggregateAttributes = aggregate.aggregateAttributes
+ val resultExpressions = aggregate.resultExpressions
+ val child = aggregate.child
+
if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
withInfo(op, "No group by or aggregation")
return None
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]