This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 2820327  feat: Support Count(Distinct) and similar aggregation functions (#42)
2820327 is described below

commit 2820327db4a7324067497110c2a655997b39f4f0
Author: Huaxin Gao <[email protected]>
AuthorDate: Tue Feb 20 09:01:34 2024 -0800

    feat: Support Count(Distinct) and similar aggregation functions (#42)
    
    Co-authored-by: Huaxin Gao <[email protected]>
---
 .../apache/comet/CometSparkSessionExtensions.scala |   8 +-
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 123 ++++++++++++++-------
 .../org/apache/spark/sql/comet/operators.scala     |   2 +-
 .../apache/comet/exec/CometAggregateSuite.scala    |  67 ++++++++++-
 4 files changed, 153 insertions(+), 47 deletions(-)

diff --git 
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala 
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 69d1fb3..f4f56f0 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -292,14 +292,18 @@ class CometSparkSessionExtensions
           newOp match {
             case Some(nativeOp) =>
               val modes = aggExprs.map(_.mode).distinct
-              assert(modes.length == 1)
+              // The aggExprs could be empty. For example, if the aggregate 
functions only have
+              // distinct aggregate functions or only have group by, the 
aggExprs is empty and
+              // modes is empty too. If aggExprs is not empty, we need to 
verify all the aggregates
+              // have the same mode.
+              assert(modes.length == 1 || modes.length == 0)
               CometHashAggregateExec(
                 nativeOp,
                 op,
                 groupingExprs,
                 aggExprs,
                 child.output,
-                modes.head,
+                if (modes.nonEmpty) Some(modes.head) else None,
                 child)
             case None =>
               op
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index f178a2f..15a26a0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -28,7 +28,7 @@ import 
org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometHashAggregateExec, 
CometSinkPlaceHolder, DecimalPrecision}
+import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, 
CometSinkPlaceHolder, DecimalPrecision}
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -1653,60 +1653,97 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde {
             _,
             groupingExpressions,
             aggregateExpressions,
-            _,
+            aggregateAttributes,
             _,
             resultExpressions,
             child) if isCometOperatorEnabled(op.conf, "aggregate") =>
-        val modes = aggregateExpressions.map(_.mode).distinct
-
-        if (modes.size != 1) {
-          // This shouldn't happen as all aggregation expressions should share 
the same mode.
-          // Fallback to Spark nevertheless here.
+        if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
           return None
         }
 
-        val mode = modes.head match {
-          case Partial => CometAggregateMode.Partial
-          case Final => CometAggregateMode.Final
-          case _ => return None
-        }
-
-        val output = mode match {
-          case CometAggregateMode.Partial => child.output
-          case CometAggregateMode.Final =>
-            // Assuming `Final` always follows `Partial` aggregation, this 
find the first
-            // `Partial` aggregation and get the input attributes from it.
-            child.collectFirst { case CometHashAggregateExec(_, _, _, _, 
input, Partial, _) =>
-              input
-            } match {
-              case Some(input) => input
-              case _ => return None
-            }
-          case _ => return None
-        }
-
-        val aggExprs = aggregateExpressions.map(aggExprToProto(_, output))
         val groupingExprs = groupingExpressions.map(exprToProto(_, 
child.output))
 
-        if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
-          aggExprs.forall(_.isDefined)) {
+        // In some of the cases, the aggregateExpressions could be empty.
+        // For example, if the aggregate functions only have group by or if 
the aggregate
+        // functions only have distinct aggregate functions:
+        //
+        // SELECT COUNT(distinct col2), col1 FROM test group by col1
+        //  +- HashAggregate (keys =[col1# 6], functions =[count (distinct 
col2#7)] )
+        //    +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, 
[plan_id = 36]
+        //      +- HashAggregate (keys =[col1#6], functions =[partial_count 
(distinct col2#7)] )
+        //        +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
+        //          +- Exchange hashpartitioning (col1#6, col2#7, 10), 
ENSURE_REQUIREMENTS, ...
+        //            +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
+        //              +- FileScan parquet spark_catalog.default.test[col1#6, 
col2#7] ......
+        // If the aggregateExpressions is empty, we only want to build 
groupingExpressions,
+        // and skip processing of aggregateExpressions.
+        if (aggregateExpressions.isEmpty) {
           val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
           hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
-          hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
-          if (mode == CometAggregateMode.Final) {
-            val attributes = groupingExpressions.map(_.toAttribute) ++
-              aggregateExpressions.map(_.resultAttribute)
-            val resultExprs = resultExpressions.map(exprToProto(_, attributes))
-            if (resultExprs.exists(_.isEmpty)) {
-              emitWarning(s"Unsupported result expressions found in: 
${resultExpressions}")
-              return None
-            }
-            hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
-          }
-          hashAggBuilder.setModeValue(mode.getNumber)
           Some(result.setHashAgg(hashAggBuilder).build())
         } else {
-          None
+          val modes = aggregateExpressions.map(_.mode).distinct
+
+          if (modes.size != 1) {
+            // This shouldn't happen as all aggregation expressions should 
share the same mode.
+            // Fallback to Spark nevertheless here.
+            return None
+          }
+
+          val mode = modes.head match {
+            case Partial => CometAggregateMode.Partial
+            case Final => CometAggregateMode.Final
+            case _ => return None
+          }
+
+          val output = mode match {
+            case CometAggregateMode.Partial => child.output
+            case CometAggregateMode.Final =>
+              // Assuming `Final` always follows `Partial` aggregation, this 
find the first
+              // `Partial` aggregation and get the input attributes from it.
+              // During finding partial aggregation, we must ensure all 
traversed op are
+              // native operators. If not, we should fallback to Spark.
+              var seenNonNativeOp = false
+              var partialAggInput: Option[Seq[Attribute]] = None
+              child.transformDown {
+                case op if !op.isInstanceOf[CometPlan] =>
+                  seenNonNativeOp = true
+                  op
+                case op @ CometHashAggregateExec(_, _, _, _, input, 
Some(Partial), _) =>
+                  if (!seenNonNativeOp && partialAggInput.isEmpty) {
+                    partialAggInput = Some(input)
+                  }
+                  op
+              }
+
+              if (partialAggInput.isDefined) {
+                partialAggInput.get
+              } else {
+                return None
+              }
+            case _ => return None
+          }
+
+          val aggExprs = aggregateExpressions.map(aggExprToProto(_, output))
+          if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
+            aggExprs.forall(_.isDefined)) {
+            val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
+            hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
+            hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
+            if (mode == CometAggregateMode.Final) {
+              val attributes = groupingExpressions.map(_.toAttribute) ++ 
aggregateAttributes
+              val resultExprs = resultExpressions.map(exprToProto(_, 
attributes))
+              if (resultExprs.exists(_.isEmpty)) {
+                emitWarning(s"Unsupported result expressions found in: 
${resultExpressions}")
+                return None
+              }
+              hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
+            }
+            hashAggBuilder.setModeValue(mode.getNumber)
+            Some(result.setHashAgg(hashAggBuilder).build())
+          } else {
+            None
+          }
         }
 
       case op if isCometSink(op) =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index eac013e..7ac1084 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -421,7 +421,7 @@ case class CometHashAggregateExec(
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
     input: Seq[Attribute],
-    mode: AggregateMode,
+    mode: Option[AggregateMode],
     child: SparkPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 3465406..9098fe2 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -402,7 +402,7 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
               "tbl",
               dictionaryEnabled) {
               checkSparkAnswer(
-                "SELECT _2, SUM(_1), MIN(_1), MAX(_1), COUNT(_1), AVG(_1) FROM 
tbl GROUP BY _2")
+                "SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), 
COUNT(_1), COUNT(DISTINCT _1), AVG(_1) FROM tbl GROUP BY _2")
             }
           }
         }
@@ -423,6 +423,8 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
               withParquetTable(path.toUri.toString, "tbl") {
                 checkSparkAnswer("SELECT _g1, _g2, SUM(_3) FROM tbl GROUP BY 
_g1, _g2")
                 checkSparkAnswer("SELECT _g1, _g2, COUNT(_3) FROM tbl GROUP BY 
_g1, _g2")
+                checkSparkAnswer("SELECT _g1, _g2, SUM(DISTINCT _3) FROM tbl 
GROUP BY _g1, _g2")
+                checkSparkAnswer("SELECT _g1, _g2, COUNT(DISTINCT _3) FROM tbl 
GROUP BY _g1, _g2")
                 checkSparkAnswer("SELECT _g1, _g2, MIN(_3), MAX(_3) FROM tbl 
GROUP BY _g1, _g2")
                 checkSparkAnswer("SELECT _g1, _g2, AVG(_3) FROM tbl GROUP BY 
_g1, _g2")
               }
@@ -453,8 +455,12 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
               makeParquetFile(path, numValues, numGroups, dictionaryEnabled)
               withParquetTable(path.toUri.toString, "tbl") {
                 checkSparkAnswer("SELECT _g3, _g4, SUM(_3), SUM(_4) FROM tbl 
GROUP BY _g3, _g4")
+                checkSparkAnswer(
+                  "SELECT _g3, _g4, SUM(DISTINCT _3), SUM(DISTINCT _4) FROM 
tbl GROUP BY _g3, _g4")
                 checkSparkAnswer(
                   "SELECT _g3, _g4, COUNT(_3), COUNT(_4) FROM tbl GROUP BY 
_g3, _g4")
+                checkSparkAnswer(
+                  "SELECT _g3, _g4, COUNT(DISTINCT _3), COUNT(DISTINCT _4) 
FROM tbl GROUP BY _g3, _g4")
                 checkSparkAnswer(
                   "SELECT _g3, _g4, MIN(_3), MAX(_3), MIN(_4), MAX(_4) FROM 
tbl GROUP BY _g3, _g4")
                 checkSparkAnswer("SELECT _g3, _g4, AVG(_3), AVG(_4) FROM tbl 
GROUP BY _g3, _g4")
@@ -482,7 +488,11 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                 (1 to 4).foreach { col =>
                   (1 to 14).foreach { gCol =>
                     checkSparkAnswer(s"SELECT _g$gCol, SUM(_$col) FROM tbl 
GROUP BY _g$gCol")
+                    checkSparkAnswer(
+                      s"SELECT _g$gCol, SUM(DISTINCT _$col) FROM tbl GROUP BY 
_g$gCol")
                     checkSparkAnswer(s"SELECT _g$gCol, COUNT(_$col) FROM tbl 
GROUP BY _g$gCol")
+                    checkSparkAnswer(
+                      s"SELECT _g$gCol, COUNT(DISTINCT _$col) FROM tbl GROUP 
BY _g$gCol")
                     checkSparkAnswer(
                       s"SELECT _g$gCol, MIN(_$col), MAX(_$col) FROM tbl GROUP 
BY _g$gCol")
                     checkSparkAnswer(s"SELECT _g$gCol, AVG(_$col) FROM tbl 
GROUP BY _g$gCol")
@@ -722,6 +732,61 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
+  test("distinct") {
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+      Seq(true, false).foreach { bosonColumnShuffleEnabled =>
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> 
bosonColumnShuffleEnabled.toString) {
+          Seq(true, false).foreach { dictionary =>
+            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+              val table = "test"
+              withTable(table) {
+                sql(s"create table $table(col1 int, col2 int, col3 int) using 
parquet")
+                sql(
+                  s"insert into $table values(1, 1, 1), (1, 1, 1), (1, 3, 1), 
(1, 4, 2), (5, 3, 2)")
+
+                var expectedNumOfBosonAggregates = 2
+
+                checkSparkAnswerAndNumOfAggregates(
+                  s"SELECT DISTINCT(col2) FROM $table",
+                  expectedNumOfBosonAggregates)
+
+                expectedNumOfBosonAggregates = 4
+
+                checkSparkAnswerAndNumOfAggregates(
+                  s"SELECT COUNT(distinct col2) FROM $table",
+                  expectedNumOfBosonAggregates)
+
+                checkSparkAnswerAndNumOfAggregates(
+                  s"SELECT COUNT(distinct col2), col1 FROM $table group by 
col1",
+                  expectedNumOfBosonAggregates)
+
+                checkSparkAnswerAndNumOfAggregates(
+                  s"SELECT SUM(distinct col2) FROM $table",
+                  expectedNumOfBosonAggregates)
+
+                checkSparkAnswerAndNumOfAggregates(
+                  s"SELECT SUM(distinct col2), col1 FROM $table group by col1",
+                  expectedNumOfBosonAggregates)
+
+                checkSparkAnswerAndNumOfAggregates(
+                  "SELECT COUNT(distinct col2), SUM(distinct col2), col1, 
COUNT(distinct col2)," +
+                    s" SUM(distinct col2) FROM $table group by col1",
+                  expectedNumOfBosonAggregates)
+
+                expectedNumOfBosonAggregates = 1
+                checkSparkAnswerAndNumOfAggregates(
+                  "SELECT COUNT(col2), MIN(col2), COUNT(DISTINCT col2), 
SUM(col2)," +
+                    s" SUM(DISTINCT col2), COUNT(DISTINCT col2), col1 FROM 
$table group by col1",
+                  expectedNumOfBosonAggregates)
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, 
numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)

Reply via email to