This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 96045b1175 [GLUTEN-7778][CH] Make aggregation output schema same as CH 
native (#7811)
96045b1175 is described below

commit 96045b117531a2a0db93bf4f9ea7bac5d7fba117
Author: lgbo <[email protected]>
AuthorDate: Thu Nov 7 09:51:19 2024 +0800

    [GLUTEN-7778][CH] Make aggregation output schema same as CH native (#7811)
    
    * unify agg output
    
    * update
    
    * update
    
    * update
    
    * fixed
---
 .../clickhouse/CHSparkPlanExecApi.scala            |  9 +++--
 .../backendsapi/clickhouse/CHTransformerApi.scala  | 11 +++---
 .../execution/CHHashAggregateExecTransformer.scala | 39 ++++++++++++++++++++++
 3 files changed, 53 insertions(+), 6 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index ba165d936e..190fcb13ea 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -158,16 +158,21 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with 
Logging {
       aggregateAttributes: Seq[Attribute],
       initialInputBufferOffset: Int,
       resultExpressions: Seq[NamedExpression],
-      child: SparkPlan): HashAggregateExecBaseTransformer =
+      child: SparkPlan): HashAggregateExecBaseTransformer = {
+    val replacedResultExpressions = 
CHHashAggregateExecTransformer.getCHAggregateResultExpressions(
+      groupingExpressions,
+      aggregateExpressions,
+      resultExpressions)
     CHHashAggregateExecTransformer(
       requiredChildDistributionExpressions,
       groupingExpressions.distinct,
       aggregateExpressions,
       aggregateAttributes,
       initialInputBufferOffset,
-      resultExpressions.distinct,
+      replacedResultExpressions.distinct,
       child
     )
+  }
 
   /** Generate HashAggregateExecPullOutHelper */
   override def genHashAggregateExecPullOutHelper(
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
index e5b7182585..bf909c52ac 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
@@ -177,10 +177,13 @@ class CHTransformerApi extends TransformerApi with 
Logging {
         // output name will be different from grouping expressions,
         // so using output attribute instead of grouping expression
         val groupingExpressions = 
hash.output.splitAt(hash.groupingExpressions.size)._1
-        val aggResultAttributes = 
CHHashAggregateExecTransformer.getAggregateResultAttributes(
-          groupingExpressions,
-          hash.aggregateExpressions
-        )
+        val aggResultAttributes = CHHashAggregateExecTransformer
+          .getCHAggregateResultExpressions(
+            groupingExpressions,
+            hash.aggregateExpressions,
+            hash.resultExpressions
+          )
+          .map(_.toAttribute)
         if (aggResultAttributes.size == hash.output.size) {
           aggResultAttributes
         } else {
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index f5e64330cd..fcf6320f8e 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -43,6 +43,45 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 
 object CHHashAggregateExecTransformer {
+  /**
+   * Rewrites aggregate result expressions so they match the output schema of CH native.
+   *
+   * Vanilla Spark may expose more columns per aggregate than CH native; e.g. the partial
+   * result of `avg(x)` is `sum(x)` and `count(x)`. This could bring unexpected issues, so
+   * a partial `avg` is replaced by its result attribute and a partial decimal `sum` keeps
+   * only the first of its buffer columns.
+   *
+   * @param groupingExpressions grouping keys; that many leading result columns are kept
+   * @param aggregateExpressions aggregates, in the order their columns appear in results
+   * @param resultExpressions assumed to list grouping columns, then aggregate columns
+   * @return the result expressions adjusted to the CH native layout
+   */
+  def getCHAggregateResultExpressions(
+      groupingExpressions: Seq[NamedExpression],
+      aggregateExpressions: Seq[AggregateExpression],
+      resultExpressions: Seq[NamedExpression]): Seq[NamedExpression] = {
+    val groupingCount = groupingExpressions.length
+    // Thread the offset through the fold rather than mutating a captured counter.
+    val (_, aggregateColumns) = aggregateExpressions
+      .foldLeft((groupingCount, Vector.empty[NamedExpression])) {
+        case ((offset, acc), aggExpr) =>
+          aggExpr.mode match {
+            case Partial | PartialMerge =>
+              // Partial results occupy one column per aggregation buffer attribute.
+              val bufferCount = aggExpr.aggregateFunction.aggBufferAttributes.length
+              val columns: Seq[NamedExpression] = aggExpr.aggregateFunction match {
+                case _: Average => Seq(aggExpr.resultAttribute)
+                case s: Sum if s.dataType.isInstanceOf[DecimalType] =>
+                  Seq(resultExpressions(offset))
+                case _ => resultExpressions.slice(offset, offset + bufferCount)
+              }
+              (offset + bufferCount, acc ++ columns)
+            case _ => (offset + 1, acc :+ resultExpressions(offset))
+          }
+      }
+    resultExpressions.take(groupingCount) ++ aggregateColumns
+  }
+
   def getAggregateResultAttributes(
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression]): Seq[Attribute] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to