This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gluten.git
The following commit(s) were added to refs/heads/main by this push:
new d1b5e1c15b [GLUTEN-8227][VL] fix: Update sort elimination rules for Hash Aggregate (#9473)
d1b5e1c15b is described below
commit d1b5e1c15bf86c01d183ca782f3d6007ebff12a0
Author: Ankita Victor <[email protected]>
AuthorDate: Sat Mar 21 15:28:39 2026 +0530
[GLUTEN-8227][VL] fix: Update sort elimination rules for Hash Aggregate (#9473)
Closes #8227
---
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 17 +++++++++++
.../execution/HashAggregateExecTransformer.scala | 33 ++++++++++++++++++++++
.../extension/FlushableHashAggregateRule.scala | 28 ++++++++++++++++--
.../execution/VeloxAggregateFunctionsSuite.scala | 31 ++++++++++++++++++++
.../gluten/backendsapi/SparkPlanExecApi.scala | 23 +++++++++++++++
.../HashAggregateExecBaseTransformer.scala | 20 +++++++++++++
.../extension/columnar/EliminateLocalSort.scala | 4 +--
.../columnar/offload/OffloadSingleNodeRules.scala | 2 +-
8 files changed, 152 insertions(+), 6 deletions(-)
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 08b9964094..6cd5aef035 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -353,6 +353,23 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
       resultExpressions,
       child)
 
+  override def genSortAggregateExecTransformer(
+      requiredChildDistributionExpressions: Option[Seq[Expression]],
+      groupingExpressions: Seq[NamedExpression],
+      aggregateExpressions: Seq[AggregateExpression],
+      aggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): HashAggregateExecBaseTransformer =
+    SortHashAggregateExecTransformer(
+      requiredChildDistributionExpressions,
+      groupingExpressions,
+      aggregateExpressions,
+      aggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      child)
+
   /** Generate HashAggregateExecPullOutHelper */
   override def genHashAggregateExecPullOutHelper(
       aggregateExpressions: Seq[AggregateExpression],
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index 8097ea925d..a0c52c7909 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -623,6 +623,39 @@ case class RegularHashAggregateExecTransformer(
   }
 }
 
+// Hash aggregation that was offloaded from a SortAggregateExec. Preserves sort-aggregate semantics
+// so that upstream sort elimination rules can safely remove the preceding sort.
+case class SortHashAggregateExecTransformer(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends HashAggregateExecTransformer(
+    requiredChildDistributionExpressions,
+    groupingExpressions,
+    aggregateExpressions,
+    aggregateAttributes,
+    initialInputBufferOffset,
+    resultExpressions,
+    child)
+  with SortAggregateExecTransformer {
+
+  override protected def allowFlush: Boolean = false
+
+  override def simpleString(maxFields: Int): String =
+    s"SortToHash${super.simpleString(maxFields)}"
+
+  override def verboseString(maxFields: Int): String =
+    s"SortToHash${super.verboseString(maxFields)}"
+
+  override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer = {
+    copy(child = newChild)
+  }
+}
+
 // Hash aggregation that emits pre-aggregated data which allows duplications on grouping keys
 // among its output rows.
 case class FlushableHashAggregateExecTransformer(
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
index 0aa48d8d37..6216dd8747 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
@@ -82,21 +82,43 @@ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkP
     aggExprs.exists(isUnsupportedAggregation)
   }
 
+  /**
+   * Walks the plan downward, applying func to each RegularHashAggregateExecTransformer or
+   * SortHashAggregateExecTransformer that is eligible for flushable conversion. An aggregate is
+   * eligible when all expressions are Partial/PartialMerge, input is not already partitioned by the
+   * grouping keys, and no aggregate function disallows flushing.
+   */
   private def replaceEligibleAggregates(plan: SparkPlan)(
-      func: RegularHashAggregateExecTransformer => SparkPlan): SparkPlan = {
+      func: HashAggregateExecTransformer => SparkPlan): SparkPlan = {
     def transformDown: SparkPlan => SparkPlan = {
       case agg: RegularHashAggregateExecTransformer
           if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) =>
-        // Not a intermediate agg. Skip.
+        // Not an intermediate agg. Skip.
         agg
       case agg: RegularHashAggregateExecTransformer
           if isAggInputAlreadyDistributedWithAggKeys(agg) =>
-        // Data already grouped by aggregate keys, Skip.
+        // Data already grouped by aggregate keys. Skip.
         agg
       case agg: RegularHashAggregateExecTransformer
           if aggregatesNotSupportFlush(agg.aggregateExpressions) =>
+        // Aggregate uses a function that is unsafe to flush. Skip.
         agg
       case agg: RegularHashAggregateExecTransformer =>
+        // All guards passed; replace with the flushable variant.
+        func(agg)
+      case agg: SortHashAggregateExecTransformer
+          if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) =>
+        // Not an intermediate agg. Skip.
+        agg
+      case agg: SortHashAggregateExecTransformer if isAggInputAlreadyDistributedWithAggKeys(agg) =>
+        // Data already grouped by aggregate keys. Skip.
+        agg
+      case agg: SortHashAggregateExecTransformer
+          if aggregatesNotSupportFlush(agg.aggregateExpressions) =>
+        // Aggregate uses a function that is unsafe to flush. Skip.
+        agg
+      case agg: SortHashAggregateExecTransformer =>
+        // All guards passed; replace with the flushable variant.
         func(agg)
       case p if !canPropagate(p) => p
       case other => other.withNewChildren(other.children.map(transformDown))
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
index e1c89f9869..9f1c8d5cc1 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
@@ -1180,6 +1180,37 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
       }
     }
   }
+
+  test("test collect_list with ordering") {
+    withTempView("t1") {
+      Seq((2, "d"), (2, "e"), (2, "f"), (1, "b"), (1, "a"), (1, "c"), (3, "i"), (3, "h"), (3, "g"))
+        .toDF("id", "value")
+        .createOrReplaceTempView("t1")
+      runQueryAndCompare(
+        """
+          | SELECT 1 - id, collect_list(value) AS values_list
+          | FROM (
+          | select * from
+          | (SELECT id, value
+          | FROM t1
+          | DISTRIBUTE BY rand())
+          | DISTRIBUTE BY id sort by id,value
+          | ) t
+          | GROUP BY 1
+          |""".stripMargin,
+        false
+      ) {
+        df =>
+          {
+            assert(
+              getExecutedPlan(df).count(
+                plan => {
+                  plan.isInstanceOf[SortHashAggregateExecTransformer]
+                }) == 2)
+          }
+      }
+    }
+  }
 }
 
 class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 002c8dad7b..27dbb6ce4c 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -81,6 +81,29 @@ trait SparkPlanExecApi {
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): HashAggregateExecBaseTransformer
 
+  /**
+   * Generate a transformer for a SortAggregateExec that is being offloaded to a native hash
+   * aggregate. Backends should return a SortAggregateExecTransformer so that sort elimination
+   * rules such as EliminateLocalSort can distinguish it from a regular hash aggregate. The
+   * default implementation simply delegates to genHashAggregateExecTransformer.
+   */
+  def genSortAggregateExecTransformer(
+      requiredChildDistributionExpressions: Option[Seq[Expression]],
+      groupingExpressions: Seq[NamedExpression],
+      aggregateExpressions: Seq[AggregateExpression],
+      aggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): HashAggregateExecBaseTransformer =
+    genHashAggregateExecTransformer(
+      requiredChildDistributionExpressions,
+      groupingExpressions,
+      aggregateExpressions,
+      aggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      child)
+
   /** Generate HashAggregateExecPullOutHelper */
   def genHashAggregateExecPullOutHelper(
       aggregateExpressions: Seq[AggregateExpression],
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
index a4bcc6081e..f4e174d9f5 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
@@ -196,8 +196,28 @@ object HashAggregateExecBaseTransformer {
         agg.child
       )
   }
+
+  def fromSortAggregate(agg: BaseAggregateExec): HashAggregateExecBaseTransformer = {
+    BackendsApiManager.getSparkPlanExecApiInstance
+      .genSortAggregateExecTransformer(
+        agg.requiredChildDistributionExpressions,
+        agg.groupingExpressions,
+        agg.aggregateExpressions,
+        agg.aggregateAttributes,
+        getInitialInputBufferOffset(agg),
+        agg.resultExpressions,
+        agg.child
+      )
+  }
 }
 
+/**
+ * Marker trait for hash aggregate transformers that were offloaded from a SortAggregateExec. This
+ * allows sort elimination rules to distinguish aggregates that were originally sort-based (and thus
+ * can safely eliminate their upstream sort) from regular hash aggregates (which must not).
+ */
+trait SortAggregateExecTransformer extends HashAggregateExecBaseTransformer {}
+
 trait HashAggregateExecPullOutBaseHelper {
   // The direct outputs of Aggregation.
   def allAggregateResultAttributes(groupingExpressions: Seq[NamedExpression]): List[Attribute] =
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala
index 8a2c731e5e..17e7f29eec 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.extension.columnar
 
-import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, ProjectExecTransformer, ShuffledHashJoinExecTransformerBase, SortExecTransformer, WindowGroupLimitExecTransformer}
+import org.apache.gluten.execution.{ProjectExecTransformer, ShuffledHashJoinExecTransformerBase, SortAggregateExecTransformer, SortExecTransformer, WindowGroupLimitExecTransformer}
 
 import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, UnaryEx
  */
 object EliminateLocalSort extends Rule[SparkPlan] {
   private def canEliminateLocalSort(p: SparkPlan): Boolean = p match {
-    case _: HashAggregateExecBaseTransformer => true
+    case _: SortAggregateExecTransformer => true
     case _: ShuffledHashJoinExecTransformerBase => true
     case _: WindowGroupLimitExecTransformer => true
     case s: SortExec if s.global == false => true
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala
index 684fbd36f1..3d844607b3 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala
@@ -215,7 +215,7 @@ object OffloadOthers {
         case plan: HashAggregateExec =>
           HashAggregateExecBaseTransformer.from(plan)
         case plan: SortAggregateExec =>
-          HashAggregateExecBaseTransformer.from(plan)
+          HashAggregateExecBaseTransformer.fromSortAggregate(plan)
         case plan: ObjectHashAggregateExec =>
           HashAggregateExecBaseTransformer.from(plan)
         case plan: UnionExec =>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]