This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 27d71c8abd [VL] Refactor `getAggRelInternal` in `HashAggregateExecTransformer` (#11040)
27d71c8abd is described below

commit 27d71c8abdf231a1ff3de1358f2a32ecf89da4dd
Author: Zouxxyy <[email protected]>
AuthorDate: Tue Nov 11 08:53:20 2025 +0800

    [VL] Refactor `getAggRelInternal` in `HashAggregateExecTransformer` (#11040)
---
 .../execution/HashAggregateExecTransformer.scala   | 163 +++++++--------------
 1 file changed, 53 insertions(+), 110 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index e46d5340d0..ad9528246c 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -76,15 +76,10 @@ abstract class HashAggregateExecTransformer(
 
   // Return whether the outputs partial aggregation should be combined for 
Velox computing.
   // When the partial outputs are multiple-column, row construct is needed.
-  private def rowConstructNeeded(aggregateExpressions: 
Seq[AggregateExpression]): Boolean = {
-    aggregateExpressions.exists {
-      aggExpr =>
-        aggExpr.mode match {
-          case PartialMerge | Final =>
-            aggExpr.aggregateFunction.inputAggBufferAttributes.size > 1
-          case _ => false
-        }
-    }
+  private def rowConstructNeeded(): Boolean = aggregateExpressions.exists {
+    case AggregateExpression(aggFunc, PartialMerge | Final, _, _, _) =>
+      aggFunc.inputAggBufferAttributes.size > 1
+    case _ => false
   }
 
   /**
@@ -186,13 +181,12 @@ abstract class HashAggregateExecTransformer(
     s"isStreaming=$isStreamingStr\nallowFlush=$allowFlushStr\n"
   }
 
-  // Create aggregate function node and add to list.
-  private def addFunctionNode(
+  // Create aggregate function node.
+  private def makeFunctionNode(
       context: SubstraitContext,
       aggregateFunction: AggregateFunction,
       childrenNodeList: JList[ExpressionNode],
-      aggregateMode: AggregateMode,
-      aggregateNodeList: JList[AggregateFunctionNode]): Unit = {
+      aggregateMode: AggregateMode): AggregateFunctionNode = {
 
     val outputTypeNode = aggregateMode match {
       case Partial | PartialMerge if 
aggregateFunction.aggBufferAttributes.size > 1 =>
@@ -204,13 +198,12 @@ abstract class HashAggregateExecTransformer(
       case Final | Complete =>
         ConverterUtils.getTypeNode(aggregateFunction.dataType, 
aggregateFunction.nullable)
     }
-    val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
+    ExpressionBuilder.makeAggregateFunction(
       VeloxAggregateFunctionsBuilder.create(context, aggregateFunction, 
aggregateMode),
       childrenNodeList,
       modeToKeyWord(aggregateMode),
       outputTypeNode
     )
-    aggregateNodeList.add(aggFunctionNode)
   }
 
   /**
@@ -271,14 +264,13 @@ abstract class HashAggregateExecTransformer(
 
   // Add a projection node before aggregation for row constructing.
   // Mainly used for aggregation whose intermediate type is a compound type in 
Velox.
-  // Pre-projection is always not required for final stage.
-  private def getAggRelWithRowConstruct(
+  // Pre-projection is never required for final stages.
+  private def applyRowConstruct(
       context: SubstraitContext,
       originalInputAttributes: Seq[Attribute],
       operatorId: Long,
       inputRel: RelNode,
       validation: Boolean): RelNode = {
-    // Create a projection for row construct.
     val exprNodes = new JArrayList[ExpressionNode]()
     groupingExpressions.foreach(
       expr => {
@@ -373,69 +365,17 @@ abstract class HashAggregateExecTransformer(
       }
     }
 
-    // Create a project rel.
-    val projectRel = RelBuilder.makeProjectRel(
+    RelBuilder.makeProjectRel(
       originalInputAttributes.asJava,
       inputRel,
       exprNodes,
       context,
       operatorId,
       validation)
-
-    // Create aggregation rel.
-    val groupingList = new JArrayList[ExpressionNode]()
-    var colIdx = 0
-    groupingExpressions.foreach {
-      _ =>
-        groupingList.add(ExpressionBuilder.makeSelection(colIdx))
-        colIdx += 1
-    }
-
-    val aggFilterList = new JArrayList[ExpressionNode]()
-    val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
-    aggregateExpressions.foreach(
-      aggExpr => {
-        if (aggExpr.filter.isDefined) {
-          throw new GlutenNotSupportException("Filter in final aggregation is 
not supported.")
-        } else {
-          // The number of filters should be aligned with that of aggregate 
functions.
-          aggFilterList.add(null)
-        }
-
-        val aggFunc = aggExpr.aggregateFunction
-        val childrenNodes = new JArrayList[ExpressionNode]()
-        aggExpr.mode match {
-          case PartialMerge | Final =>
-            // Only occupies one column due to intermediate results are 
combined
-            // by previous projection.
-            childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
-            colIdx += 1
-          case Partial | Complete =>
-            aggFunc.children.foreach {
-              _ =>
-                childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
-                colIdx += 1
-            }
-          case _ =>
-            throw new GlutenNotSupportException(
-              s"$aggFunc of ${aggExpr.mode.toString} is not supported.")
-        }
-        addFunctionNode(context, aggFunc, childrenNodes, aggExpr.mode, 
aggregateFunctionList)
-      })
-
-    val extensionNode = getAdvancedExtension()
-    RelBuilder.makeAggregateRel(
-      projectRel,
-      groupingList,
-      aggregateFunctionList,
-      aggFilterList,
-      extensionNode,
-      context,
-      operatorId)
   }
 
   /**
-   * Create and return the Rel for the this aggregation.
+   * Create and return the Rel for the aggregation.
    * @param context
    *   the Substrait context
    * @param operatorId
@@ -457,13 +397,21 @@ abstract class HashAggregateExecTransformer(
       validation: Boolean = false): RelNode = {
     val originalInputAttributes = child.output
 
-    var aggRel = if (rowConstructNeeded(aggregateExpressions)) {
+    val finalInput = if (rowConstructNeeded()) {
       aggParams.rowConstructionNeeded = true
-      getAggRelWithRowConstruct(context, originalInputAttributes, operatorId, 
input, validation)
+      applyRowConstruct(context, originalInputAttributes, operatorId, input, 
validation)
     } else {
-      getAggRelInternal(context, originalInputAttributes, operatorId, input, 
validation)
+      input
     }
 
+    var aggRel = getAggRelInternal(
+      context,
+      originalInputAttributes,
+      operatorId,
+      finalInput,
+      validation,
+      aggParams.rowConstructionNeeded)
+
     if (extractStructNeeded()) {
       aggParams.extractionNeeded = true
       aggRel = applyExtractStruct(context, aggRel, operatorId, validation)
@@ -515,27 +463,30 @@ abstract class HashAggregateExecTransformer(
       context: SubstraitContext,
       originalInputAttributes: Seq[Attribute],
       operatorId: Long,
-      input: RelNode = null,
-      validation: Boolean): RelNode = {
-    // Get the grouping nodes.
-    // Use 'child.output' as based Seq[Attribute], the originalInputAttributes
-    // may be different for each backend.
-    val groupingList = groupingExpressions
-      .map(
+      input: RelNode,
+      validation: Boolean,
+      rowConstructed: Boolean): RelNode = {
+    var colIdx = -1
+    val toExpressionNode: Expression => ExpressionNode = if (rowConstructed) {
+      // If the input is row constructed, use selection to get the column.
+      (_: Expression) =>
+        colIdx += 1
+        ExpressionBuilder.makeSelection(colIdx)
+    } else {
+      (expr: Expression) =>
         ExpressionConverter
-          .replaceWithExpressionTransformer(_, child.output)
-          .doTransform(context))
-      .asJava
+          .replaceWithExpressionTransformer(expr, originalInputAttributes)
+          .doTransform(context)
+    }
+
+    val groupingList = groupingExpressions.map(toExpressionNode).asJava
     // Get the aggregate function nodes.
     val aggFilterList = new JArrayList[ExpressionNode]()
     val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
     aggregateExpressions.foreach(
       aggExpr => {
         if (aggExpr.filter.isDefined) {
-          val exprNode = ExpressionConverter
-            .replaceWithExpressionTransformer(aggExpr.filter.get, child.output)
-            .doTransform(context)
-          aggFilterList.add(exprNode)
+          aggFilterList.add(toExpressionNode(aggExpr.filter.get))
         } else {
           // The number of filters should be aligned with that of aggregate 
functions.
           aggFilterList.add(null)
@@ -543,33 +494,25 @@ abstract class HashAggregateExecTransformer(
         val aggregateFunc = aggExpr.aggregateFunction
         val childrenNodes = aggExpr.mode match {
           case Partial | Complete =>
-            aggregateFunc.children.toList.map(
-              expr => {
-                ExpressionConverter
-                  .replaceWithExpressionTransformer(expr, 
originalInputAttributes)
-                  .doTransform(context)
-              })
+            aggregateFunc.children.toList.map(toExpressionNode)
           case PartialMerge | Final =>
-            rewriteAggBufferAttributes(
-              aggregateFunc.inputAggBufferAttributes,
-              originalInputAttributes).map {
-              attr =>
-                ExpressionConverter
-                  .replaceWithExpressionTransformer(attr, 
originalInputAttributes)
-                  .doTransform(context)
+            if (rowConstructed) {
+              // Only occupies one column due to intermediate results are 
combined
+              // by previous row construct projection.
+              Seq(toExpressionNode.apply(null))
+            } else {
+              rewriteAggBufferAttributes(
+                aggregateFunc.inputAggBufferAttributes,
+                originalInputAttributes).map(toExpressionNode)
             }
           case other =>
             throw new GlutenNotSupportException(s"$other not supported.")
         }
-        addFunctionNode(
-          context,
-          aggregateFunc,
-          childrenNodes.asJava,
-          aggExpr.mode,
-          aggregateFunctionList)
+        aggregateFunctionList.add(
+          makeFunctionNode(context, aggregateFunc, childrenNodes.asJava, 
aggExpr.mode))
       })
 
-    val extensionNode = getAdvancedExtension(validation, 
originalInputAttributes)
+    val extensionNode = getAdvancedExtension(validation && !rowConstructed, 
originalInputAttributes)
     RelBuilder.makeAggregateRel(
       input,
       groupingList,
@@ -584,7 +527,7 @@ abstract class HashAggregateExecTransformer(
       validation: Boolean = false,
       originalInputAttributes: Seq[Attribute] = Seq.empty): 
AdvancedExtensionNode = {
     val enhancement = if (validation) {
-      // Use a extension node to send the input types through Substrait plan 
for validation.
+      // Use an extension node to send the input types through Substrait plan 
for validation.
       val inputTypeNodeList = originalInputAttributes
         .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
         .asJava


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to