This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 488c523 fix: Aggregation without aggregation expressions should use
correct result expressions (#175)
488c523 is described below
commit 488c52365fa331734785ec675f02b6a3c4605587
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Mar 8 22:36:02 2024 -0800
fix: Aggregation without aggregation expressions should use correct result
expressions (#175)
---
.../scala/org/apache/comet/serde/QueryPlanSerde.scala | 7 +++++++
.../org/apache/comet/exec/CometAggregateSuite.scala | 16 ++++++++++++++++
2 files changed, 23 insertions(+)
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 56b6690..b27fa3a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1783,6 +1783,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
if (aggregateExpressions.isEmpty) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
+ val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
+ val resultExprs = resultExpressions.map(exprToProto(_, attributes))
+ if (resultExprs.exists(_.isEmpty)) {
+ emitWarning(s"Unsupported result expressions found in: ${resultExpressions}")
+ return None
+ }
+ hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
Some(result.setHashAgg(hashAggBuilder).build())
} else {
val modes = aggregateExpressions.map(_.mode).distinct
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index bc645cb..8a68a92 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -40,6 +40,22 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._
+ test("Aggregation without aggregate expressions should use correct result expressions") {
+ withSQLConf(
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+ CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test")
+ makeParquetFile(path, 10000, 10, false)
+ withParquetTable(path.toUri.toString, "tbl") {
+ val df = sql("SELECT _g5 FROM tbl GROUP BY _g1, _g2, _g3, _g4, _g5")
+ checkSparkAnswer(df)
+ }
+ }
+ }
+ }
+
test("Final aggregation should not bind to the input of partial aggregation") {
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",