This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 0d63bc13c minor: Small refactor for consistent serde for hash aggregate (#2764)
0d63bc13c is described below
commit 0d63bc13c20e47570b4f2dc1a994df6920e8bc4e
Author: Andy Grove <[email protected]>
AuthorDate: Wed Nov 12 15:15:39 2025 -0700
minor: Small refactor for consistent serde for hash aggregate (#2764)
---
.../org/apache/comet/rules/CometExecRule.scala | 88 ++++++++--------------
.../comet/serde/operator/CometAggregate.scala | 41 +++++++++-
.../org/apache/spark/sql/comet/operators.scala | 10 ++-
3 files changed, 81 insertions(+), 58 deletions(-)
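
For orientation, a minimal sketch (hypothetical names and simplified types, not the actual Comet API) of the division of labor this commit moves toward: CometExecRule only pattern-matches on the operator type, while the serde layer decides convertibility and signals "keep the Spark operator" by returning None.

    // Hypothetical sketch, not the real Comet interfaces: the serde owns
    // the eligibility decision; None means the Spark operator is kept.
    object SerdeSketch {
      trait OperatorSerde[T] {
        def convert(op: T): Option[String] // None = leave op untouched
      }
      def applyRule[T](op: T, serde: OperatorSerde[T]): Either[T, String] =
        serde.convert(op).toRight(op) // Right = converted, Left = original op
    }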
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 708a48d6a..c10a8b5af 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -23,7 +23,6 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -32,7 +31,7 @@ import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -232,44 +231,37 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
op,
CometExpandExec(_, op, op.output, op.projections, op.child, SerializedPlan(None)))
- // When Comet shuffle is disabled, we don't want to transform the HashAggregate
- // to CometHashAggregate. Otherwise, we probably get partial Comet aggregation
- // and final Spark aggregation.
- case op: BaseAggregateExec
- if op.isInstanceOf[HashAggregateExec] ||
- op.isInstanceOf[ObjectHashAggregateExec] &&
- isCometShuffleEnabled(conf) =>
- val modes = op.aggregateExpressions.map(_.mode).distinct
- // In distinct aggregates there can be a combination of modes
- val multiMode = modes.size > 1
- // For a final mode HashAggregate, we only need to transform the HashAggregate
- // if there is Comet partial aggregation.
- val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(op.child).isEmpty
-
- if (multiMode || sparkFinalMode) {
- op
- } else {
- newPlanWithProto(
- op,
- nativeOp => {
- // The aggExprs could be empty. For example, if the aggregate functions only have
- // distinct aggregate functions or only have group by, the aggExprs is empty and
- // modes is empty too. If aggExprs is not empty, we need to verify all the
- // aggregates have the same mode.
- assert(modes.length == 1 || modes.isEmpty)
- CometHashAggregateExec(
- nativeOp,
- op,
- op.output,
- op.groupingExpressions,
- op.aggregateExpressions,
- op.resultExpressions,
- op.child.output,
- modes.headOption,
- op.child,
- SerializedPlan(None))
- })
- }
+ case op: HashAggregateExec =>
+ newPlanWithProto(
+ op,
+ nativeOp => {
+ CometHashAggregateExec(
+ nativeOp,
+ op,
+ op.output,
+ op.groupingExpressions,
+ op.aggregateExpressions,
+ op.resultExpressions,
+ op.child.output,
+ op.child,
+ SerializedPlan(None))
+ })
+
+ case op: ObjectHashAggregateExec =>
+ newPlanWithProto(
+ op,
+ nativeOp => {
+ CometHashAggregateExec(
+ nativeOp,
+ op,
+ op.output,
+ op.groupingExpressions,
+ op.aggregateExpressions,
+ op.resultExpressions,
+ op.child.output,
+ op.child,
+ SerializedPlan(None))
+ })
case op: ShuffledHashJoinExec
if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
@@ -738,22 +730,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
}
}
- /**
- * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
- * partial mode, it will return None.
- */
- private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
- plan.collectFirst {
- case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
- Some(agg)
- case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None
- case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
- None
- case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
- case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
- }.flatten
- }
-
/**
* Returns true if a given spark plan is Comet shuffle operator.
*/
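
The findCometPartialAgg helper removed above moves into CometBaseAggregate in the next file. Its collectFirst(...).flatten idiom is easy to misread, so here is a standalone, simplified sketch of its semantics (a toy Node tree, not Spark's TreeNode, and without the mode checks or AQE recursion): the first matching node decides, and a Spark partial aggregate ends the search with None rather than letting it continue deeper.

    // Toy sketch of collectFirst(...).flatten: first match wins, and a
    // matching case that yields None terminates the search with "not found".
    object CollectFirstFlattenDemo extends App {
      sealed trait Node { def children: Seq[Node] }
      case class CometAgg(children: Seq[Node] = Nil) extends Node
      case class SparkAgg(children: Seq[Node] = Nil) extends Node
      case class Other(children: Seq[Node] = Nil) extends Node

      def findCometPartial(plan: Node): Option[CometAgg] = {
        def preorder(n: Node): Seq[Node] = n +: n.children.flatMap(preorder)
        preorder(plan).collectFirst {
          case c: CometAgg => Some(c) // Comet partial aggregate: found it
          case _: SparkAgg => None    // Spark partial aggregate: stop searching
        }.flatten
      }

      println(findCometPartial(Other(Seq(CometAgg()))))                // Some(CometAgg(List()))
      println(findCometPartial(Other(Seq(SparkAgg(Seq(CometAgg())))))) // None
    }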
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala
index 93e5d52c8..b0c359f08 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala
@@ -23,11 +23,14 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
+import org.apache.spark.sql.comet.CometHashAggregateExec
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.types.MapType
import org.apache.comet.{CometConf, ConfigEntry}
-import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo}
import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto}
@@ -38,6 +41,18 @@ trait CometBaseAggregate {
aggregate: BaseAggregateExec,
builder: Operator.Builder,
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+
+ val modes = aggregate.aggregateExpressions.map(_.mode).distinct
+ // In distinct aggregates there can be a combination of modes
+ val multiMode = modes.size > 1
+ // For a final mode HashAggregate, we only need to transform the HashAggregate
+ // if there is Comet partial aggregation.
+ val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
+
+ if (multiMode || sparkFinalMode) {
+ return None
+ }
+
val groupingExpressions = aggregate.groupingExpressions
val aggregateExpressions = aggregate.aggregateExpressions
val aggregateAttributes = aggregate.aggregateAttributes
@@ -163,6 +178,22 @@ trait CometBaseAggregate {
}
+ /**
+ * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
+ * partial mode, it will return None.
+ */
+ private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+ plan.collectFirst {
+ case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
+ Some(agg)
+ case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None
+ case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
+ None
+ case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
+ case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
+ }.flatten
+ }
+
}
object CometHashAggregate extends CometOperatorSerde[HashAggregateExec] with CometBaseAggregate {
@@ -189,6 +220,14 @@ object CometObjectHashAggregate
aggregate: ObjectHashAggregateExec,
builder: Operator.Builder,
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
+
+ if (!isCometShuffleEnabled(aggregate.conf)) {
+ // When Comet shuffle is disabled, we don't want to transform the HashAggregate
+ // to CometHashAggregate. Otherwise, we probably get partial Comet aggregation
+ // and final Spark aggregation.
+ return None
+ }
+
doConvert(aggregate, builder, childOp: _*)
}
}
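
In concrete terms, the guard now applied in doConvert can be read as the predicate below, a hedged sketch using Spark's real AggregateMode values but a hypothetical `eligible` helper (the real check lives inline in CometBaseAggregate.doConvert).

    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final}

    object EligibilitySketch {
      def eligible(modes: Seq[AggregateMode], cometPartialBelow: Boolean): Boolean = {
        val distinctModes = modes.distinct
        // Distinct aggregates can be planned with a mix of modes: not convertible.
        val multiMode = distinctModes.size > 1
        // A Final-mode aggregate is only converted when a Comet partial aggregate
        // feeds it; otherwise we would pair a Spark partial stage with a Comet
        // final stage (or vice versa).
        val sparkFinalMode = distinctModes.contains(Final) && !cometPartialBelow
        !(multiMode || sparkFinalMode)
      }
    }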
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index d7a743eb2..3d1fccc98 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -739,11 +739,19 @@ case class CometHashAggregateExec(
aggregateExpressions: Seq[AggregateExpression],
resultExpressions: Seq[NamedExpression],
input: Seq[Attribute],
- mode: Option[AggregateMode],
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec
with PartitioningPreservingUnaryExecNode {
+
+ // The aggExprs could be empty. For example, if the aggregate functions only have
+ // distinct aggregate functions or only have group by, the aggExprs is empty and
+ // modes is empty too. If aggExprs is not empty, we need to verify all the
+ // aggregates have the same mode.
+ val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct
+ assert(modes.length == 1 || modes.isEmpty)
+ val mode = modes.headOption
+
override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(resultExpressions)
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
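
The upshot for CometHashAggregateExec: the `mode` constructor parameter is gone, and the operator derives it from its own aggregate expressions, asserting the invariant the serde now guarantees. A minimal standalone illustration of that derivation, using strings as stand-ins for AggregateMode:

    object ModeSketch {
      // Stand-in sketch: after the serde-side checks, either all aggregates
      // share one mode, or there are none (e.g. group-by only, or only
      // distinct aggregate functions), so headOption summarizes faithfully.
      def deriveMode(modes: Seq[String]): Option[String] = {
        val distinctModes = modes.distinct
        assert(distinctModes.length == 1 || distinctModes.isEmpty)
        distinctModes.headOption
      }
    }
    // ModeSketch.deriveMode(Seq("Partial", "Partial")) == Some("Partial")
    // ModeSketch.deriveMode(Nil)                       == None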
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]