This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new d93128ce10 [GLUTEN-9178][CH] Fix cse in aggregate operator not working (#9301)
d93128ce10 is described below

commit d93128ce10c2e74a46baf112cd388b613ce59c12
Author: Shuai li <[email protected]>
AuthorDate: Mon Apr 14 11:13:57 2025 +0800

    [GLUTEN-9178][CH] Fix cse in aggregate operator not working (#9301)
    
    * [GLUTEN-9178][CH] Fix cse in aggregate operator not working
    
    * fix ci
---
 .../CommonSubexpressionEliminateRule.scala         | 16 ++++++-------
 .../execution/GlutenFunctionValidateSuite.scala    | 26 +++++++++++++++++++---
 2 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
index e10e2bba00..7e674367f1 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
@@ -89,16 +89,16 @@ class CommonSubexpressionEliminateRule(spark: SparkSession) extends Rule[Logical
 
   private def replaceAggCommonExprWithAttribute(
       expr: Expression,
-      commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute]): Expression = {
+      commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute],
+      inAgg: Boolean = false): Expression = {
     val exprEquals = commonExprMap.get(ExpressionEquals(expr))
-    if (expr.isInstanceOf[AggregateExpression]) {
-      if (exprEquals.isDefined) {
+    expr match {
+      case _ if exprEquals.isDefined && inAgg =>
         exprEquals.get.attribute
-      } else {
-        expr
-      }
-    } else {
-      expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap))
+      case _: AggregateExpression =>
+        expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap, true))
+      case _ =>
+        expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap, inAgg))
     }
   }
 
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 7bbcc3c363..02e5b34a37 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{DataFrame, GlutenTestUtils, Row}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation}
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
 import org.apache.spark.sql.internal.SQLConf
@@ -825,8 +825,28 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
                             |FROM (select id, cast(id as string) name from range(10))
                             |GROUP BY name
                             |""".stripMargin) {
-        df => checkOperatorCount[ProjectExecTransformer](3)(df)
+        df => checkOperatorCount[ProjectExecTransformer](4)(df)
       }
+
+      runQueryAndCompare(
+        s"""
+           |select id % 2, max(hash(id)), min(hash(id)) from range(10) group by id % 2
+           |""".stripMargin)(
+        df => {
+          df.queryExecution.optimizedPlan.collect {
+            case Aggregate(_, aggregateExpressions, _) =>
+              val result =
+                aggregateExpressions
+                  .map(a => a.asInstanceOf[Alias].child)
+                  .filter(_.isInstanceOf[AggregateExpression])
+                  .map(expr => expr.asInstanceOf[AggregateExpression].aggregateFunction)
+                  .filter(aggFunc => aggFunc.children.head.isInstanceOf[AttributeReference])
+                  .map(aggFunc => aggFunc.children.head.asInstanceOf[AttributeReference].name)
+                  .distinct
+              assertResult(1)(result.size)
+          }
+          checkOperatorCount[ProjectExecTransformer](1)(df)
+        })
     }
   }
 
@@ -1339,7 +1359,7 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
 
   test("Test rewrite aggregate if to aggregate with filter") {
     val sql = "select sum(if(id % 2=0, id, null)), count(if(id % 2 = 0, 1, null)), " +
-      "avg(if(id % 2 = 0, id, null)), sum(if(id % 3 = 0, id, 0)) from range(10)"
+      "avg(if(id % 4 = 0, id, null)), sum(if(id % 3 = 0, id, 0)) from range(10)"
 
     def checkAggregateWithFilter(df: DataFrame): Unit = {
       val aggregates = collectWithSubqueries(df.queryExecution.executedPlan) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to