This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 45b3dd5 fix: Only trigger Comet Final aggregation on Comet partial aggregation (#264)
45b3dd5 is described below
commit 45b3dd58b33f7127b73a8ee980cddd8f22d9e881
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Apr 12 23:28:56 2024 -0700
fix: Only trigger Comet Final aggregation on Comet partial aggregation (#264)
---
.../apache/comet/CometSparkSessionExtensions.scala | 71 ++++++++++++++++------
.../apache/comet/exec/CometAggregateSuite.scala | 13 ++++
2 files changed, 64 insertions(+), 20 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index a10ac57..275b9eb 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -26,13 +26,14 @@ import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle,
CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec,
ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec,
BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -319,26 +320,42 @@ class CometSparkSessionExtensions
}
case op @ HashAggregateExec(_, _, _, groupingExprs, aggExprs, _, _, _,
child) =>
- val newOp = transform1(op)
- newOp match {
- case Some(nativeOp) =>
- val modes = aggExprs.map(_.mode).distinct
- // The aggExprs could be empty. For example, if the aggregate
functions only have
- // distinct aggregate functions or only have group by, the
aggExprs is empty and
- // modes is empty too. If aggExprs is not empty, we need to
verify all the aggregates
- // have the same mode.
- assert(modes.length == 1 || modes.length == 0)
- CometHashAggregateExec(
- nativeOp,
- op,
- groupingExprs,
- aggExprs,
- child.output,
- if (modes.nonEmpty) Some(modes.head) else None,
- child,
- SerializedPlan(None))
- case None =>
+ val modes = aggExprs.map(_.mode).distinct
+
+ if (!modes.isEmpty && modes.size != 1) {
+ // This shouldn't happen as all aggregation expressions should
share the same mode.
+ // Fallback to Spark nevertheless here.
+ op
+ } else {
+ val sparkFinalMode = { // true for a Final-mode agg when findPartialAgg finds no Comet partial agg under `child` (NOTE(review): PartialMerge is not checked)
+ !modes.isEmpty && modes.head == Final &&
findPartialAgg(child).isEmpty
+ }
+
+ if (sparkFinalMode) { // keep it in Spark; presumably a Comet Final cannot consume Spark partial-aggregate output — confirm
op
+ } else {
+ val newOp = transform1(op)
+ newOp match {
+ case Some(nativeOp) =>
+ val modes = aggExprs.map(_.mode).distinct // NOTE(review): recomputes and shadows the identical outer `modes`
+ // The aggExprs could be empty. For example, if the
aggregate functions only have
+ // distinct aggregate functions or only have group by, the
aggExprs is empty and
+ // modes is empty too. If aggExprs is not empty, we need to
verify all the
+ // aggregates have the same mode.
+ assert(modes.length == 1 || modes.length == 0) // always holds here: mixed modes already fell back to Spark above
+ CometHashAggregateExec(
+ nativeOp,
+ op,
+ groupingExprs,
+ aggExprs,
+ child.output,
+ if (modes.nonEmpty) Some(modes.head) else None,
+ child,
+ SerializedPlan(None))
+ case None =>
+ op
+ }
+ }
}
case op: ShuffledHashJoinExec
@@ -596,6 +613,20 @@ class CometSparkSessionExtensions
}
}
}
+
+ /**
+ * Find the first Comet partial aggregate that feeds `plan`, searching pre-order via
+ * `collectFirst` and looking through `AQEShuffleReadExec` and `ShuffleQueryStageExec`.
+ *
+ * The search stops at the first node matched by one of these cases:
+ * - a `CometHashAggregateExec` whose aggregate expressions are all `Partial` => `Some(agg)`;
+ * - a Spark `HashAggregateExec` whose aggregate expressions are all `Partial` => `None`;
+ * - an `AQEShuffleReadExec` / `ShuffleQueryStageExec` => the result of recursing into it.
+ * `forall` is vacuously true for an aggregate with no aggregate expressions, so such
+ * aggregates (e.g. group-by only) also match.
+ *
+ * NOTE(review): Spark aggregates in other modes (e.g. `PartialMerge`, or mixed modes) do not
+ * stop the search, so a Comet partial aggregate further below them can still be returned.
+ * Confirm this is intended.
+ *
+ * @param plan plan to search, e.g. the child of a Final-mode aggregate
+ * @return the Comet partial aggregate found, or None if there is none or a Spark partial
+ *         aggregate is reached first
+ */
+ def findPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+ plan.collectFirst {
+ case agg: CometHashAggregateExec if
agg.aggregateExpressions.forall(_.mode == Partial) =>
+ Some(agg)
+ case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode
== Partial) => None
+ case a: AQEShuffleReadExec => findPartialAgg(a.child)
+ case s: ShuffleQueryStageExec => findPartialAgg(s.plan) // presumably collectFirst does not descend into a stage's plan itself
+ }.flatten // collectFirst yields Option[Option[_]]; a matched None still ends the search
+ }
}
// This rule is responsible for eliminating redundant transitions between
row-based and
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index b95ce9b..89681d3 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -40,6 +40,19 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._
+ test("Only trigger Comet Final aggregation on Comet partial aggregation") { // regression test for #264
+ withTempView("lowerCaseData") {
+ lowerCaseData.createOrReplaceTempView("lowerCaseData")
+ withSQLConf(
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+ CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+ val df = sql("SELECT LAST(n) FROM lowerCaseData") // ungrouped aggregate: presumably planned as Partial -> shuffle -> Final
+ checkSparkAnswer(df) // NOTE(review): presumably compares results only; which aggregates ran in Comet is not asserted
+ }
+ }
+ }
+
test(
"Average expression in Comet Final should handle " +
"all null inputs from partial Spark aggregation") {