This is an automated email from the ASF dual-hosted git repository.
taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new d93128ce10 [GLUTEN-9178][CH] Fix cse in aggregate operator not working (#9301)
d93128ce10 is described below
commit d93128ce10c2e74a46baf112cd388b613ce59c12
Author: Shuai li <[email protected]>
AuthorDate: Mon Apr 14 11:13:57 2025 +0800
[GLUTEN-9178][CH] Fix cse in aggregate operator not working (#9301)
* [GLUTEN-9178][CH] Fix cse in aggregate operator not working
* fix ci
---
.../CommonSubexpressionEliminateRule.scala | 16 ++++++-------
.../execution/GlutenFunctionValidateSuite.scala | 26 +++++++++++++++++++---
2 files changed, 31 insertions(+), 11 deletions(-)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
index e10e2bba00..7e674367f1 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
@@ -89,16 +89,16 @@ class CommonSubexpressionEliminateRule(spark: SparkSession) extends Rule[Logical
private def replaceAggCommonExprWithAttribute(
expr: Expression,
- commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute]):
Expression = {
+ commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute],
+ inAgg: Boolean = false): Expression = {
val exprEquals = commonExprMap.get(ExpressionEquals(expr))
- if (expr.isInstanceOf[AggregateExpression]) {
- if (exprEquals.isDefined) {
+ expr match {
+ case _ if exprEquals.isDefined && inAgg =>
exprEquals.get.attribute
- } else {
- expr
- }
- } else {
- expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap))
+ case _: AggregateExpression =>
+ expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap,
true))
+ case _ =>
+ expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap,
inAgg))
}
}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 7bbcc3c363..02e5b34a37 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{DataFrame, GlutenTestUtils, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding,
NullPropagation}
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan,
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter,
LogicalPlan, Project}
import org.apache.spark.sql.execution._
import
org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
import org.apache.spark.sql.internal.SQLConf
@@ -825,8 +825,28 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
|FROM (select id, cast(id as string) name from
range(10))
|GROUP BY name
|""".stripMargin) {
- df => checkOperatorCount[ProjectExecTransformer](3)(df)
+ df => checkOperatorCount[ProjectExecTransformer](4)(df)
}
+
+ runQueryAndCompare(
+ s"""
+ |select id % 2, max(hash(id)), min(hash(id)) from range(10) group
by id % 2
+ |""".stripMargin)(
+ df => {
+ df.queryExecution.optimizedPlan.collect {
+ case Aggregate(_, aggregateExpressions, _) =>
+ val result =
+ aggregateExpressions
+ .map(a => a.asInstanceOf[Alias].child)
+ .filter(_.isInstanceOf[AggregateExpression])
+ .map(expr =>
expr.asInstanceOf[AggregateExpression].aggregateFunction)
+ .filter(aggFunc =>
aggFunc.children.head.isInstanceOf[AttributeReference])
+ .map(aggFunc =>
aggFunc.children.head.asInstanceOf[AttributeReference].name)
+ .distinct
+ assertResult(1)(result.size)
+ }
+ checkOperatorCount[ProjectExecTransformer](1)(df)
+ })
}
}
@@ -1339,7 +1359,7 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
test("Test rewrite aggregate if to aggregate with filter") {
val sql = "select sum(if(id % 2=0, id, null)), count(if(id % 2 = 0, 1,
null)), " +
- "avg(if(id % 2 = 0, id, null)), sum(if(id % 3 = 0, id, 0)) from
range(10)"
+ "avg(if(id % 4 = 0, id, null)), sum(if(id % 3 = 0, id, 0)) from
range(10)"
def checkAggregateWithFilter(df: DataFrame): Unit = {
val aggregates = collectWithSubqueries(df.queryExecution.executedPlan) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]