This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 2820327 feat: Support Count(Distinct) and similar aggregation functions (#42)
2820327 is described below
commit 2820327db4a7324067497110c2a655997b39f4f0
Author: Huaxin Gao <[email protected]>
AuthorDate: Tue Feb 20 09:01:34 2024 -0800
feat: Support Count(Distinct) and similar aggregation functions (#42)
Co-authored-by: Huaxin Gao <[email protected]>
---
.../apache/comet/CometSparkSessionExtensions.scala | 8 +-
.../org/apache/comet/serde/QueryPlanSerde.scala | 123 ++++++++++++++-------
.../org/apache/spark/sql/comet/operators.scala | 2 +-
.../apache/comet/exec/CometAggregateSuite.scala | 67 ++++++++++-
4 files changed, 153 insertions(+), 47 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 69d1fb3..f4f56f0 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -292,14 +292,18 @@ class CometSparkSessionExtensions
newOp match {
case Some(nativeOp) =>
val modes = aggExprs.map(_.mode).distinct
- assert(modes.length == 1)
+ // The aggExprs could be empty. For example, if the aggregate
functions only have
+ // distinct aggregate functions or only have group by, the
aggExprs is empty and
+ // modes is empty too. If aggExprs is not empty, we need to
verify all the aggregates
+ // have the same mode.
+ assert(modes.length == 1 || modes.length == 0)
CometHashAggregateExec(
nativeOp,
op,
groupingExprs,
aggExprs,
child.output,
- modes.head,
+ if (modes.nonEmpty) Some(modes.head) else None,
child)
case None =>
op
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index f178a2f..15a26a0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometHashAggregateExec,
CometSinkPlaceHolder, DecimalPrecision}
+import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan,
CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -1653,60 +1653,97 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
_,
groupingExpressions,
aggregateExpressions,
- _,
+ aggregateAttributes,
_,
resultExpressions,
child) if isCometOperatorEnabled(op.conf, "aggregate") =>
- val modes = aggregateExpressions.map(_.mode).distinct
-
- if (modes.size != 1) {
- // This shouldn't happen as all aggregation expressions should share
the same mode.
- // Fallback to Spark nevertheless here.
+ if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
return None
}
- val mode = modes.head match {
- case Partial => CometAggregateMode.Partial
- case Final => CometAggregateMode.Final
- case _ => return None
- }
-
- val output = mode match {
- case CometAggregateMode.Partial => child.output
- case CometAggregateMode.Final =>
- // Assuming `Final` always follows `Partial` aggregation, this
find the first
- // `Partial` aggregation and get the input attributes from it.
- child.collectFirst { case CometHashAggregateExec(_, _, _, _,
input, Partial, _) =>
- input
- } match {
- case Some(input) => input
- case _ => return None
- }
- case _ => return None
- }
-
- val aggExprs = aggregateExpressions.map(aggExprToProto(_, output))
val groupingExprs = groupingExpressions.map(exprToProto(_,
child.output))
- if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
- aggExprs.forall(_.isDefined)) {
+ // In some of the cases, the aggregateExpressions could be empty.
+ // For example, if the aggregate functions only have group by or if
the aggregate
+ // functions only have distinct aggregate functions:
+ //
+ // SELECT COUNT(distinct col2), col1 FROM test group by col1
+ // +- HashAggregate (keys =[col1# 6], functions =[count (distinct
col2#7)] )
+ // +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS,
[plan_id = 36]
+ // +- HashAggregate (keys =[col1#6], functions =[partial_count
(distinct col2#7)] )
+ // +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
+ // +- Exchange hashpartitioning (col1#6, col2#7, 10),
ENSURE_REQUIREMENTS, ...
+ // +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
+ // +- FileScan parquet spark_catalog.default.test[col1#6,
col2#7] ......
+ // If the aggregateExpressions is empty, we only want to build
groupingExpressions,
+ // and skip processing of aggregateExpressions.
+ if (aggregateExpressions.isEmpty) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
- hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
- if (mode == CometAggregateMode.Final) {
- val attributes = groupingExpressions.map(_.toAttribute) ++
- aggregateExpressions.map(_.resultAttribute)
- val resultExprs = resultExpressions.map(exprToProto(_, attributes))
- if (resultExprs.exists(_.isEmpty)) {
- emitWarning(s"Unsupported result expressions found in:
${resultExpressions}")
- return None
- }
- hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
- }
- hashAggBuilder.setModeValue(mode.getNumber)
Some(result.setHashAgg(hashAggBuilder).build())
} else {
- None
+ val modes = aggregateExpressions.map(_.mode).distinct
+
+ if (modes.size != 1) {
+ // This shouldn't happen as all aggregation expressions should
share the same mode.
+ // Fallback to Spark nevertheless here.
+ return None
+ }
+
+ val mode = modes.head match {
+ case Partial => CometAggregateMode.Partial
+ case Final => CometAggregateMode.Final
+ case _ => return None
+ }
+
+ val output = mode match {
+ case CometAggregateMode.Partial => child.output
+ case CometAggregateMode.Final =>
+ // Assuming `Final` always follows `Partial` aggregation, this
find the first
+ // `Partial` aggregation and get the input attributes from it.
+ // During finding partial aggregation, we must ensure all
traversed op are
+ // native operators. If not, we should fallback to Spark.
+ var seenNonNativeOp = false
+ var partialAggInput: Option[Seq[Attribute]] = None
+ child.transformDown {
+ case op if !op.isInstanceOf[CometPlan] =>
+ seenNonNativeOp = true
+ op
+ case op @ CometHashAggregateExec(_, _, _, _, input,
Some(Partial), _) =>
+ if (!seenNonNativeOp && partialAggInput.isEmpty) {
+ partialAggInput = Some(input)
+ }
+ op
+ }
+
+ if (partialAggInput.isDefined) {
+ partialAggInput.get
+ } else {
+ return None
+ }
+ case _ => return None
+ }
+
+ val aggExprs = aggregateExpressions.map(aggExprToProto(_, output))
+ if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
+ aggExprs.forall(_.isDefined)) {
+ val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
+ hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
+ hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
+ if (mode == CometAggregateMode.Final) {
+ val attributes = groupingExpressions.map(_.toAttribute) ++
aggregateAttributes
+ val resultExprs = resultExpressions.map(exprToProto(_,
attributes))
+ if (resultExprs.exists(_.isEmpty)) {
+ emitWarning(s"Unsupported result expressions found in:
${resultExpressions}")
+ return None
+ }
+ hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
+ }
+ hashAggBuilder.setModeValue(mode.getNumber)
+ Some(result.setHashAgg(hashAggBuilder).build())
+ } else {
+ None
+ }
}
case op if isCometSink(op) =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index eac013e..7ac1084 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -421,7 +421,7 @@ case class CometHashAggregateExec(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
input: Seq[Attribute],
- mode: AggregateMode,
+ mode: Option[AggregateMode],
child: SparkPlan)
extends CometUnaryExec {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 3465406..9098fe2 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -402,7 +402,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
"tbl",
dictionaryEnabled) {
checkSparkAnswer(
- "SELECT _2, SUM(_1), MIN(_1), MAX(_1), COUNT(_1), AVG(_1) FROM
tbl GROUP BY _2")
+ "SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1),
COUNT(_1), COUNT(DISTINCT _1), AVG(_1) FROM tbl GROUP BY _2")
}
}
}
@@ -423,6 +423,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withParquetTable(path.toUri.toString, "tbl") {
checkSparkAnswer("SELECT _g1, _g2, SUM(_3) FROM tbl GROUP BY
_g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, COUNT(_3) FROM tbl GROUP BY
_g1, _g2")
+ checkSparkAnswer("SELECT _g1, _g2, SUM(DISTINCT _3) FROM tbl
GROUP BY _g1, _g2")
+ checkSparkAnswer("SELECT _g1, _g2, COUNT(DISTINCT _3) FROM tbl
GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, MIN(_3), MAX(_3) FROM tbl
GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, AVG(_3) FROM tbl GROUP BY
_g1, _g2")
}
@@ -453,8 +455,12 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
makeParquetFile(path, numValues, numGroups, dictionaryEnabled)
withParquetTable(path.toUri.toString, "tbl") {
checkSparkAnswer("SELECT _g3, _g4, SUM(_3), SUM(_4) FROM tbl
GROUP BY _g3, _g4")
+ checkSparkAnswer(
+ "SELECT _g3, _g4, SUM(DISTINCT _3), SUM(DISTINCT _4) FROM
tbl GROUP BY _g3, _g4")
checkSparkAnswer(
"SELECT _g3, _g4, COUNT(_3), COUNT(_4) FROM tbl GROUP BY
_g3, _g4")
+ checkSparkAnswer(
+ "SELECT _g3, _g4, COUNT(DISTINCT _3), COUNT(DISTINCT _4)
FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer(
"SELECT _g3, _g4, MIN(_3), MAX(_3), MIN(_4), MAX(_4) FROM
tbl GROUP BY _g3, _g4")
checkSparkAnswer("SELECT _g3, _g4, AVG(_3), AVG(_4) FROM tbl
GROUP BY _g3, _g4")
@@ -482,7 +488,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
(1 to 4).foreach { col =>
(1 to 14).foreach { gCol =>
checkSparkAnswer(s"SELECT _g$gCol, SUM(_$col) FROM tbl
GROUP BY _g$gCol")
+ checkSparkAnswer(
+ s"SELECT _g$gCol, SUM(DISTINCT _$col) FROM tbl GROUP BY
_g$gCol")
checkSparkAnswer(s"SELECT _g$gCol, COUNT(_$col) FROM tbl
GROUP BY _g$gCol")
+ checkSparkAnswer(
+ s"SELECT _g$gCol, COUNT(DISTINCT _$col) FROM tbl GROUP
BY _g$gCol")
checkSparkAnswer(
s"SELECT _g$gCol, MIN(_$col), MAX(_$col) FROM tbl GROUP
BY _g$gCol")
checkSparkAnswer(s"SELECT _g$gCol, AVG(_$col) FROM tbl
GROUP BY _g$gCol")
@@ -722,6 +732,61 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("distinct") {
+ withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+ Seq(true, false).foreach { bosonColumnShuffleEnabled =>
+ withSQLConf(
+ CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key ->
bosonColumnShuffleEnabled.toString) {
+ Seq(true, false).foreach { dictionary =>
+ withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+ val table = "test"
+ withTable(table) {
+ sql(s"create table $table(col1 int, col2 int, col3 int) using
parquet")
+ sql(
+ s"insert into $table values(1, 1, 1), (1, 1, 1), (1, 3, 1),
(1, 4, 2), (5, 3, 2)")
+
+ var expectedNumOfBosonAggregates = 2
+
+ checkSparkAnswerAndNumOfAggregates(
+ s"SELECT DISTINCT(col2) FROM $table",
+ expectedNumOfBosonAggregates)
+
+ expectedNumOfBosonAggregates = 4
+
+ checkSparkAnswerAndNumOfAggregates(
+ s"SELECT COUNT(distinct col2) FROM $table",
+ expectedNumOfBosonAggregates)
+
+ checkSparkAnswerAndNumOfAggregates(
+ s"SELECT COUNT(distinct col2), col1 FROM $table group by
col1",
+ expectedNumOfBosonAggregates)
+
+ checkSparkAnswerAndNumOfAggregates(
+ s"SELECT SUM(distinct col2) FROM $table",
+ expectedNumOfBosonAggregates)
+
+ checkSparkAnswerAndNumOfAggregates(
+ s"SELECT SUM(distinct col2), col1 FROM $table group by col1",
+ expectedNumOfBosonAggregates)
+
+ checkSparkAnswerAndNumOfAggregates(
+ "SELECT COUNT(distinct col2), SUM(distinct col2), col1,
COUNT(distinct col2)," +
+ s" SUM(distinct col2) FROM $table group by col1",
+ expectedNumOfBosonAggregates)
+
+ expectedNumOfBosonAggregates = 1
+ checkSparkAnswerAndNumOfAggregates(
+ "SELECT COUNT(col2), MIN(col2), COUNT(DISTINCT col2),
SUM(col2)," +
+ s" SUM(DISTINCT col2), COUNT(DISTINCT col2), col1 FROM
$table group by col1",
+ expectedNumOfBosonAggregates)
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
protected def checkSparkAnswerAndNumOfAggregates(query: String,
numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)