This is an automated email from the ASF dual-hosted git repository.
zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 43b69b6b3a [GLUTEN-8354][CH] Fix cse issue in aggregate[Part2] (#8376)
43b69b6b3a is described below
commit 43b69b6b3a346c44ab13282af134b6b92fd31531
Author: Shuai li <[email protected]>
AuthorDate: Tue Dec 31 09:38:39 2024 +0800
[GLUTEN-8354][CH] Fix cse issue in aggregate[Part2] (#8376)
[CH] Fix cse issue in aggregate[Part2]
---
.../CommonSubexpressionEliminateRule.scala | 24 +++++++++++++++++++++-
.../execution/GlutenFunctionValidateSuite.scala | 18 ++++++++++++++++
2 files changed, 41 insertions(+), 1 deletion(-)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
index ffe3c1af23..e10e2bba00 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
@@ -87,6 +87,21 @@ class CommonSubexpressionEliminateRule(spark: SparkSession) extends Rule[Logical
}
}
+  /**
+   * Replaces each AggregateExpression in `expr` that has an entry in `commonExprMap` with that
+   * entry's attribute; every other node is only recursed into, never replaced itself.
+   */
+  private def replaceAggCommonExprWithAttribute(
+      expr: Expression,
+      commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute]): Expression =
+    expr match {
+      case agg: AggregateExpression =>
+        // Stop descending here: the aggregate's own children are never rewritten.
+        commonExprMap.get(ExpressionEquals(agg)).map(_.attribute).getOrElse(agg)
+      case other =>
+        other.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap))
+    }
+
private def isValidCommonExpr(expr: Expression): Boolean = {
if (
(expr.isInstanceOf[Unevaluable] &&
!expr.isInstanceOf[AttributeReference])
@@ -162,7 +177,14 @@ class CommonSubexpressionEliminateRule(spark: SparkSession) extends Rule[Logical
// Replace the common expressions with the first expression that produces
it.
try {
var newExprs = inputCtx.exprs
- .map(replaceCommonExprWithAttribute(_, commonExprMap))
+ .map(
+ expr => {
+ if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
+ replaceAggCommonExprWithAttribute(expr, commonExprMap)
+ } else {
+ replaceCommonExprWithAttribute(expr, commonExprMap)
+ }
+ })
logTrace(s"newExprs after rewrite: $newExprs")
RewriteContext(newExprs, preProject)
} catch {
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index e9df41ace6..5dcba3b476 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -659,6 +659,24 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
"order by hash(id%10), hash(hash(id%10))") {
df => checkOperatorCount[ProjectExecTransformer](3)(df)
}
+
+ runQueryAndCompare(s"""
+ |SELECT 'test' AS test
+ | , Sum(CASE
+ | WHEN name = '2' THEN 0
+ | ELSE id
+ | END) AS c1
+ | , Sum(CASE
+ | WHEN name = '2' THEN id
+ | ELSE 0
+ | END) AS c2
+ | , CASE WHEN name = '2' THEN Sum(id) ELSE 0
+ | END AS c3
+ |FROM (select id, cast(id as string) name from
range(10))
+ |GROUP BY name
+ |""".stripMargin) {
+ df => checkOperatorCount[ProjectExecTransformer](3)(df)
+ }
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]