This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new d1b5e1c15b [GLUTEN-8227][VL] fix: Update sort elimination rules for Hash Aggregate (#9473)
d1b5e1c15b is described below

commit d1b5e1c15bf86c01d183ca782f3d6007ebff12a0
Author: Ankita Victor <[email protected]>
AuthorDate: Sat Mar 21 15:28:39 2026 +0530

    [GLUTEN-8227][VL] fix: Update sort elimination rules for Hash Aggregate (#9473)
    
    Closes #8227
---
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  | 17 +++++++++++
 .../execution/HashAggregateExecTransformer.scala   | 33 ++++++++++++++++++++++
 .../extension/FlushableHashAggregateRule.scala     | 28 ++++++++++++++++--
 .../execution/VeloxAggregateFunctionsSuite.scala   | 31 ++++++++++++++++++++
 .../gluten/backendsapi/SparkPlanExecApi.scala      | 23 +++++++++++++++
 .../HashAggregateExecBaseTransformer.scala         | 20 +++++++++++++
 .../extension/columnar/EliminateLocalSort.scala    |  4 +--
 .../columnar/offload/OffloadSingleNodeRules.scala  |  2 +-
 8 files changed, 152 insertions(+), 6 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 08b9964094..6cd5aef035 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -353,6 +353,23 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with 
Logging {
       resultExpressions,
       child)
 
+  override def genSortAggregateExecTransformer(
+      requiredChildDistributionExpressions: Option[Seq[Expression]],
+      groupingExpressions: Seq[NamedExpression],
+      aggregateExpressions: Seq[AggregateExpression],
+      aggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): HashAggregateExecBaseTransformer =
+    SortHashAggregateExecTransformer(
+      requiredChildDistributionExpressions,
+      groupingExpressions,
+      aggregateExpressions,
+      aggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      child)
+
   /** Generate HashAggregateExecPullOutHelper */
   override def genHashAggregateExecPullOutHelper(
       aggregateExpressions: Seq[AggregateExpression],
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index 8097ea925d..a0c52c7909 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -623,6 +623,39 @@ case class RegularHashAggregateExecTransformer(
   }
 }
 
+// Hash aggregation offloaded from a SortAggregateExec. The SortAggregateExecTransformer mixin
+// marks it as originally sort-based, so EliminateLocalSort can safely drop the sort feeding it.
+case class SortHashAggregateExecTransformer(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends HashAggregateExecTransformer(
+    requiredChildDistributionExpressions,
+    groupingExpressions,
+    aggregateExpressions,
+    aggregateAttributes,
+    initialInputBufferOffset,
+    resultExpressions,
+    child)
+  with SortAggregateExecTransformer {
+
+  override protected def allowFlush: Boolean = false
+
+  override def simpleString(maxFields: Int): String =
+    s"SortToHash${super.simpleString(maxFields)}"
+
+  override def verboseString(maxFields: Int): String =
+    s"SortToHash${super.verboseString(maxFields)}"
+
+  override protected def withNewChildInternal(newChild: SparkPlan): 
HashAggregateExecTransformer = {
+    copy(child = newChild)
+  }
+}
+
 // Hash aggregation that emits pre-aggregated data which allows duplications on grouping keys
 // among its output rows.
 case class FlushableHashAggregateExecTransformer(
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
index 0aa48d8d37..6216dd8747 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
@@ -82,21 +82,43 @@ case class FlushableHashAggregateRule(session: 
SparkSession) extends Rule[SparkP
     aggExprs.exists(isUnsupportedAggregation)
   }
 
+  /**
+   * Walks the plan downward, applying func to each RegularHashAggregateExecTransformer or
+   * SortHashAggregateExecTransformer that is eligible for flushable conversion. An aggregate is
+   * eligible when all expressions are Partial/PartialMerge, input is not already partitioned by the
+   * grouping keys, and no aggregate function disallows flushing.
+   */
   private def replaceEligibleAggregates(plan: SparkPlan)(
-      func: RegularHashAggregateExecTransformer => SparkPlan): SparkPlan = {
+      func: HashAggregateExecTransformer => SparkPlan): SparkPlan = {
     def transformDown: SparkPlan => SparkPlan = {
       case agg: RegularHashAggregateExecTransformer
           if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode 
== PartialMerge) =>
-        // Not a intermediate agg. Skip.
+        // Not an intermediate agg. Skip.
         agg
       case agg: RegularHashAggregateExecTransformer
           if isAggInputAlreadyDistributedWithAggKeys(agg) =>
-        // Data already grouped by aggregate keys, Skip.
+        // Data already grouped by aggregate keys. Skip.
         agg
       case agg: RegularHashAggregateExecTransformer
           if aggregatesNotSupportFlush(agg.aggregateExpressions) =>
+        // Aggregate uses a function that is unsafe to flush. Skip.
         agg
       case agg: RegularHashAggregateExecTransformer =>
+        // All guards passed; replace with the flushable variant.
+        func(agg)
+      case agg: SortHashAggregateExecTransformer
+          if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode 
== PartialMerge) =>
+        // Not an intermediate agg. Skip.
+        agg
+      case agg: SortHashAggregateExecTransformer if 
isAggInputAlreadyDistributedWithAggKeys(agg) =>
+        // Data already grouped by aggregate keys. Skip.
+        agg
+      case agg: SortHashAggregateExecTransformer
+          if aggregatesNotSupportFlush(agg.aggregateExpressions) =>
+        // Aggregate uses a function that is unsafe to flush. Skip.
+        agg
+      case agg: SortHashAggregateExecTransformer =>
+        // All guards passed; replace with the flushable variant.
         func(agg)
       case p if !canPropagate(p) => p
       case other => other.withNewChildren(other.children.map(transformDown))
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
index e1c89f9869..9f1c8d5cc1 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
@@ -1180,6 +1180,37 @@ abstract class VeloxAggregateFunctionsSuite extends 
VeloxWholeStageTransformerSu
       }
     }
   }
+
+  test("test collect_list with ordering") {
+    withTempView("t1") {
+      Seq((2, "d"), (2, "e"), (2, "f"), (1, "b"), (1, "a"), (1, "c"), (3, 
"i"), (3, "h"), (3, "g"))
+        .toDF("id", "value")
+        .createOrReplaceTempView("t1")
+      runQueryAndCompare(
+        """
+          | SELECT 1 - id, collect_list(value) AS values_list
+          |        FROM (
+          |        select * from
+          |        (SELECT id, value
+          |          FROM t1
+          |          DISTRIBUTE BY rand())
+          |          DISTRIBUTE BY id sort by id,value
+          |        ) t
+          |        GROUP BY 1
+          |""".stripMargin,
+        false
+      ) {
+        df =>
+          {
+            assert(
+              getExecutedPlan(df).count(
+                plan => {
+                  plan.isInstanceOf[SortHashAggregateExecTransformer]
+                }) == 2)
+          }
+      }
+    }
+  }
 }
 
 class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite 
{
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 002c8dad7b..27dbb6ce4c 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -81,6 +81,29 @@ trait SparkPlanExecApi {
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): HashAggregateExecBaseTransformer
 
+  /**
+   * Generate a transformer for a SortAggregateExec that is being offloaded to a native hash
+   * aggregate. Backends should return a SortAggregateExecTransformer so that EliminateLocalSort
+   * can drop the local sort below it. The default delegates to genHashAggregateExecTransformer,
+   * so backends that do not override this keep that sort, as for a regular hash aggregate.
+   */
+  def genSortAggregateExecTransformer(
+      requiredChildDistributionExpressions: Option[Seq[Expression]],
+      groupingExpressions: Seq[NamedExpression],
+      aggregateExpressions: Seq[AggregateExpression],
+      aggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): HashAggregateExecBaseTransformer =
+    genHashAggregateExecTransformer(
+      requiredChildDistributionExpressions,
+      groupingExpressions,
+      aggregateExpressions,
+      aggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      child)
+
   /** Generate HashAggregateExecPullOutHelper */
   def genHashAggregateExecPullOutHelper(
       aggregateExpressions: Seq[AggregateExpression],
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
index a4bcc6081e..f4e174d9f5 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
@@ -196,8 +196,28 @@ object HashAggregateExecBaseTransformer {
         agg.child
       )
   }
+
+  def fromSortAggregate(agg: BaseAggregateExec): 
HashAggregateExecBaseTransformer = {
+    BackendsApiManager.getSparkPlanExecApiInstance
+      .genSortAggregateExecTransformer(
+        agg.requiredChildDistributionExpressions,
+        agg.groupingExpressions,
+        agg.aggregateExpressions,
+        agg.aggregateAttributes,
+        getInitialInputBufferOffset(agg),
+        agg.resultExpressions,
+        agg.child
+      )
+  }
 }
 
+/**
+ * Marker trait for hash aggregate transformers that were offloaded from a SortAggregateExec. This
+ * lets sort elimination rules distinguish aggregates that were originally sort-based (whose local
+ * sort can safely be removed) from regular hash aggregates (whose input sort must be kept).
+ */
+trait SortAggregateExecTransformer extends HashAggregateExecBaseTransformer {}
+
 trait HashAggregateExecPullOutBaseHelper {
   // The direct outputs of Aggregation.
   def allAggregateResultAttributes(groupingExpressions: Seq[NamedExpression]): 
List[Attribute] =
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala
index 8a2c731e5e..17e7f29eec 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.extension.columnar
 
-import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, 
ProjectExecTransformer, ShuffledHashJoinExecTransformerBase, 
SortExecTransformer, WindowGroupLimitExecTransformer}
+import org.apache.gluten.execution.{ProjectExecTransformer, 
ShuffledHashJoinExecTransformerBase, SortAggregateExecTransformer, 
SortExecTransformer, WindowGroupLimitExecTransformer}
 
 import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.{ProjectExec, SortExec, 
SparkPlan, UnaryEx
  */
 object EliminateLocalSort extends Rule[SparkPlan] {
   private def canEliminateLocalSort(p: SparkPlan): Boolean = p match {
-    case _: HashAggregateExecBaseTransformer => true
+    case _: SortAggregateExecTransformer => true
     case _: ShuffledHashJoinExecTransformerBase => true
     case _: WindowGroupLimitExecTransformer => true
     case s: SortExec if s.global == false => true
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala
index 684fbd36f1..3d844607b3 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala
@@ -215,7 +215,7 @@ object OffloadOthers {
         case plan: HashAggregateExec =>
           HashAggregateExecBaseTransformer.from(plan)
         case plan: SortAggregateExec =>
-          HashAggregateExecBaseTransformer.from(plan)
+          HashAggregateExecBaseTransformer.fromSortAggregate(plan)
         case plan: ObjectHashAggregateExec =>
           HashAggregateExecBaseTransformer.from(plan)
         case plan: UnionExec =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to