This is an automated email from the ASF dual-hosted git repository.

zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 43b69b6b3a [GLUTEN-8354][CH] Fix cse issue in aggregate[Part2] (#8376)
43b69b6b3a is described below

commit 43b69b6b3a346c44ab13282af134b6b92fd31531
Author: Shuai li <[email protected]>
AuthorDate: Tue Dec 31 09:38:39 2024 +0800

    [GLUTEN-8354][CH] Fix cse issue in aggregate[Part2] (#8376)
    
    [CH] Fix cse issue in aggregate[Part2]
---
 .../CommonSubexpressionEliminateRule.scala         | 24 +++++++++++++++++++++-
 .../execution/GlutenFunctionValidateSuite.scala    | 18 ++++++++++++++++
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
index ffe3c1af23..e10e2bba00 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
@@ -87,6 +87,21 @@ class CommonSubexpressionEliminateRule(spark: SparkSession) extends Rule[Logical
     }
   }
 
+  private def replaceAggCommonExprWithAttribute( // Rewrites expr, substituting only AggregateExpression nodes found in commonExprMap with their mapped attribute.
+      expr: Expression,
+      commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute]): 
Expression = {
+    val exprEquals = commonExprMap.get(ExpressionEquals(expr)) // looked up for every node, but only consulted when expr is an AggregateExpression
+    if (expr.isInstanceOf[AggregateExpression]) { // only whole aggregate expressions are candidates for substitution
+      if (exprEquals.isDefined) {
+        exprEquals.get.attribute // common aggregate: reuse the attribute recorded for it
+      } else {
+        expr // not a common aggregate: keep it unchanged and do not descend into its children
+      }
+    } else {
+      expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap)) // non-aggregate node: never replaced itself; recurse so nested aggregates are still found
+    }
+  }
+
   private def isValidCommonExpr(expr: Expression): Boolean = {
     if (
       (expr.isInstanceOf[Unevaluable] && 
!expr.isInstanceOf[AttributeReference])
@@ -162,7 +177,14 @@ class CommonSubexpressionEliminateRule(spark: SparkSession) extends Rule[Logical
     // Replace the common expressions with the first expression that produces 
it.
     try {
       var newExprs = inputCtx.exprs
-        .map(replaceCommonExprWithAttribute(_, commonExprMap))
+        .map(
+          expr => {
+            if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
+              replaceAggCommonExprWithAttribute(expr, commonExprMap)
+            } else {
+              replaceCommonExprWithAttribute(expr, commonExprMap)
+            }
+          })
       logTrace(s"newExprs after rewrite: $newExprs")
       RewriteContext(newExprs, preProject)
     } catch {
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index e9df41ace6..5dcba3b476 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -659,6 +659,24 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
           "order by hash(id%10), hash(hash(id%10))") {
         df => checkOperatorCount[ProjectExecTransformer](3)(df)
       }
+
+      runQueryAndCompare(s"""
+                            |SELECT 'test' AS test
+                            |  , Sum(CASE
+                            |    WHEN name = '2' THEN 0
+                            |      ELSE id
+                            |    END) AS c1
+                            |  , Sum(CASE
+                            |    WHEN name = '2' THEN id
+                            |      ELSE 0
+                            |    END) AS c2
+                            | , CASE WHEN name = '2' THEN Sum(id) ELSE 0
+                            |   END AS c3
+                            |FROM (select id, cast(id as string) name from 
range(10))
+                            |GROUP BY name
+                            |""".stripMargin) {
+        df => checkOperatorCount[ProjectExecTransformer](3)(df)
+      }
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to