This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 6c9fc228e2 [GLUTEN-9728][CH] Fix: Inconsistent result from
JoinAggregateToAggregateUnion (#9729)
6c9fc228e2 is described below
commit 6c9fc228e277339f00be06551ee7b45c9a2a9905
Author: lgbo <[email protected]>
AuthorDate: Mon May 26 10:08:55 2025 +0800
[GLUTEN-9728][CH] Fix: Inconsistent result from
JoinAggregateToAggregateUnion (#9729)
---
.../extension/CoalesceAggregationUnion.scala | 4 +-
.../extension/JoinAggregateToAggregateUnion.scala | 128 +++++++++++++++------
.../GlutenCoalesceAggregationUnionSuite.scala | 6 +-
.../execution/GlutenEliminateJoinSuite.scala | 42 +++++++
4 files changed, 138 insertions(+), 42 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
index d5faa48d07..90d647f437 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
@@ -533,7 +533,7 @@ case class CoalesceAggregationUnion(spark: SparkSession)
extends Rule[LogicalPla
override def apply(plan: LogicalPlan): LogicalPlan = {
if (
spark.conf
- .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION,
"true")
+ .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION,
"false")
.toBoolean && CoalesceUnionUtil.isResolvedPlan(plan)
) {
Try {
@@ -953,7 +953,7 @@ case class CoalesceProjectionUnion(spark: SparkSession)
extends Rule[LogicalPlan
override def apply(plan: LogicalPlan): LogicalPlan = {
if (
spark.conf
- .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_PROJECT_UNION, "true")
+ .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_PROJECT_UNION, "false")
.toBoolean && CoalesceUnionUtil.isResolvedPlan(plan)
) {
Try {
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
index 61dd9ca869..b7c7f29c98 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
@@ -638,9 +638,10 @@ case class JoinAggregateToAggregateUnion(spark:
SparkSession)
}
override def apply(plan: LogicalPlan): LogicalPlan = {
+ // It's an experimental feature, disabled by default
if (
spark.conf
- .get(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION,
"true")
+ .get(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION,
"false")
.toBoolean && isResolvedPlan(plan)
) {
val reorderedPlan = ReorderJoinSubqueries().apply(plan)
@@ -689,9 +690,23 @@ case class JoinAggregateToAggregateUnion(spark:
SparkSession)
} else {
val unionedAggregates =
unionAllJoinedAggregates(analyzedAggregates.toSeq)
if (remainedPlan.isDefined) {
+ // The left join clause is not an aggregate query.
+
+ // Use the new right keys to build the join condition
+ val newRightKeys =
+ unionedAggregates.output.slice(0,
analyzedAggregates.head.getPrimeJoinKeys.length)
+ val newJoinCondition =
+ buildJoinCondition(analyzedAggregates.head.getPrimeJoinKeys(),
newRightKeys)
val lastJoin = analyzedAggregates.head.join
- lastJoin.copy(left = visitPlan(lastJoin.left), right =
unionedAggregates)
+ lastJoin.copy(
+ left = visitPlan(lastJoin.left),
+ right = unionedAggregates,
+ condition = Some(newJoinCondition))
} else {
+ /*
+ * The left join clause is also an aggregate query.
+ * If flag_0 is null, filter out the rows.
+ */
buildPrimeJoinKeysFilterOnAggregateUnion(unionedAggregates,
analyzedAggregates.toSeq)
}
}
@@ -702,25 +717,20 @@ case class JoinAggregateToAggregateUnion(spark:
SparkSession)
}
}
- def buildAggregateExpressionWithNewChildren(
- ne: NamedExpression,
- inputs: Seq[Attribute]): NamedExpression = {
- val aggregateExpression =
- ne.asInstanceOf[Alias].child.asInstanceOf[AggregateExpression]
- val newAggregateFunction = aggregateExpression.aggregateFunction
- .withNewChildren(inputs)
- .asInstanceOf[AggregateFunction]
- val newAggregateExpression = aggregateExpression.copy(aggregateFunction =
newAggregateFunction)
- RuleExpressionHelper.makeNamedExpression(newAggregateExpression, ne.name)
- }
-
def unionAllJoinedAggregates(analyzedAggregates:
Seq[JoinedAggregateAnalyzer]): LogicalPlan = {
val extendProjects = buildExtendProjects(analyzedAggregates)
val union = buildUnionOnExtendedProjects(extendProjects)
+ // The output is {keys_0}, {keys_1}, ..., {aggs_0}, {aggs_1}, ... , flag_0,
flag_1, ...
val aggregateUnion = buildAggregateOnUnion(union, analyzedAggregates)
- logDebug(s"xxx aggregateUnion $aggregateUnion")
+ // Push a duplication of {keys_0} to the head. These keys will be used as
join keys later.
+ val duplicateRightKeysProject =
+ buildJoinRightKeysProject(aggregateUnion, analyzedAggregates)
+ /*
+ * If flag_0 is null, let {keys_0} be null too.
+ * If flag_i is null, let {aggs_i} be null too.
+ */
val setNullsProject =
- buildMakeNotMatchedRowsNullProject(aggregateUnion, analyzedAggregates,
Set())
+ buildMakeNotMatchedRowsNullProject(duplicateRightKeysProject,
analyzedAggregates, Set())
buildRenameProject(setNullsProject, analyzedAggregates)
}
@@ -799,34 +809,52 @@ case class JoinAggregateToAggregateUnion(spark:
SparkSession)
analyzedAggregates: Seq[JoinedAggregateAnalyzer],
ignoreAggregates: Set[Int]): LogicalPlan = {
val input = plan.output
- val flagExpressions =
- input.slice(plan.output.length - analyzedAggregates.length,
plan.output.length)
- val aggregateExprsStart =
- analyzedAggregates.length *
analyzedAggregates.head.getGroupingKeys.length
- var fieldIndex = aggregateExprsStart
- val aggregatesIfNullExpressions = analyzedAggregates.zipWithIndex.map {
+
+ val keysNumber = analyzedAggregates.head.getGroupingKeys().length
+ val keysStartIndex = keysNumber
+ val aggregateExpressionsStatIndex = keysStartIndex +
analyzedAggregates.length * keysNumber
+ val flagExpressionsStartIndex = input.length - analyzedAggregates.length
+
+ val dupPrimeKeys = input.slice(0, keysStartIndex)
+ val keys = input.slice(keysStartIndex, aggregateExpressionsStatIndex)
+ val aggregateExpressions = input.slice(aggregateExpressionsStatIndex,
flagExpressionsStartIndex)
+ val flagExpressions = input.slice(flagExpressionsStartIndex, input.length)
+
+ val newProjectList = ArrayBuffer[NamedExpression]()
+ newProjectList ++= dupPrimeKeys
+ val newFirstClauseKeys = keys.slice(0, keysNumber).map {
+ case key =>
+ val ifNull =
+ If(IsNull(flagExpressions(0)),
RuleExpressionHelper.makeNullLiteral(key.dataType), key)
+ RuleExpressionHelper.makeNamedExpression(ifNull, key.name)
+ }
+ newProjectList ++= newFirstClauseKeys
+ newProjectList ++= keys.slice(keysNumber, keys.length)
+
+ var fieldIndex = 0
+ val newAggregateExpressions = analyzedAggregates.zipWithIndex.map {
case (analyzedAggregate, i) =>
- val flagExpr = flagExpressions(i)
- val aggregateExpressions = analyzedAggregate.getAggregateExpressions()
+ val localAggregateExpressions =
analyzedAggregate.getAggregateExpressions()
val aggregateFunctionAnalyzers =
analyzedAggregate.getAggregateFunctionAnalyzers()
- aggregateExpressions.zipWithIndex.map {
- case (e, i) =>
- val valueExpr = input(fieldIndex)
+ localAggregateExpressions.zipWithIndex.map {
+ case (e, j) =>
+ val aggregateExpr = aggregateExpressions(fieldIndex)
fieldIndex += 1
- if (ignoreAggregates(i) ||
aggregateFunctionAnalyzers(i).ignoreNulls()) {
- valueExpr.asInstanceOf[NamedExpression]
+
+ if (ignoreAggregates.contains(i) ||
aggregateFunctionAnalyzers(j).ignoreNulls()) {
+ aggregateExpr.asInstanceOf[NamedExpression]
} else {
- val clearExpr = If(
- IsNull(flagExpr),
- RuleExpressionHelper.makeNullLiteral(valueExpr.dataType),
- valueExpr)
- RuleExpressionHelper.makeNamedExpression(clearExpr,
valueExpr.name)
+ val ifNullExpr = If(
+ IsNull(flagExpressions(i)),
+ RuleExpressionHelper.makeNullLiteral(aggregateExpr.dataType),
+ aggregateExpr)
+ RuleExpressionHelper.makeNamedExpression(ifNullExpr,
aggregateExpr.name)
}
}
}
- val ifNullExpressions = aggregatesIfNullExpressions.flatten
- val projectList = input.slice(0, aggregateExprsStart) ++ ifNullExpressions
++ flagExpressions
- Project(projectList, plan)
+ newProjectList ++= newAggregateExpressions.flatten
+ newProjectList ++= flagExpressions
+ Project(newProjectList.toSeq, plan)
}
/**
@@ -838,7 +866,8 @@ case class JoinAggregateToAggregateUnion(spark:
SparkSession)
val input = plan.output
val projectList = ArrayBuffer[NamedExpression]()
val keysNum = analyzedAggregates.head.getGroupingKeys.length
- var fieldIndex = 0
+ projectList ++= input.slice(0, keysNum)
+ var fieldIndex = keysNum
for (i <- 0 until analyzedAggregates.length) {
val keys = analyzedAggregates(i).getGroupingKeys()
for (j <- 0 until keys.length) {
@@ -867,6 +896,29 @@ case class JoinAggregateToAggregateUnion(spark:
SparkSession)
Project(projectList.toSeq, plan)
}
+ def buildJoinCondition(leftKeys: Seq[Attribute], rightKeys: Seq[Attribute]):
Expression = {
+ leftKeys
+ .zip(rightKeys)
+ .map {
+ case (leftKey, rightKey) =>
+ EqualTo(leftKey, rightKey)
+ }
+ .reduceLeft(And)
+ }
+
+ def buildJoinRightKeysProject(
+ plan: LogicalPlan,
+ analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = {
+ val input = plan.output
+ val keysNum = analyzedAggregates.head.getGroupingKeys.length
+ // Make a duplication of the prime aggregate keys, and put them in the
front
+ val projectList = input.slice(0, keysNum).map {
+ case key =>
+ RuleExpressionHelper.makeNamedExpression(key, "_dup_prime_" + key.name)
+ }
+ Project(projectList ++ input, plan)
+ }
+
/**
* Build an extended project list, which contains three parts.
* - The grouping keys of all tables.
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
index 6e73d12230..fdd7482757 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
@@ -16,6 +16,8 @@
*/
package org.apache.gluten.execution
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
+
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row}
import
org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
@@ -47,8 +49,8 @@ class GlutenCoalesceAggregationUnionSuite extends
GlutenClickHouseWholeStageTran
.set("spark.io.compression.codec", "snappy")
.set("spark.sql.shuffle.partitions", "5")
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
-
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union",
"true")
-
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.aggregation.union",
"true")
+ .set(CHBackendSettings.GLUTEN_ENABLE_COALESCE_PROJECT_UNION, "true")
+ .set(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION, "true")
}
def createTestTable(tableName: String, data: DataFrame): Unit = {
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
index ea5de96a43..2f6f0023ef 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
@@ -17,6 +17,7 @@
package org.apache.gluten.execution
import org.apache.gluten.backendsapi.clickhouse._
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
@@ -58,6 +59,7 @@ class GlutenEliminateJoinSuite extends
GlutenClickHouseWholeStageTransformerSuit
.set("spark.gluten.supported.scala.udfs",
"compare_substrings:compare_substrings")
.set(CHConfig.runtimeSettings("max_memory_usage_ratio_for_streaming_aggregating"),
"0.01")
.set(CHConfig.runtimeSettings("high_cardinality_threshold_for_streaming_aggregating"),
"0.2")
+ .set(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION, "true")
.set(
SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
ConstantFolding.ruleName + "," + NullPropagation.ruleName)
@@ -563,4 +565,44 @@ class GlutenEliminateJoinSuite extends
GlutenClickHouseWholeStageTransformerSuit
spark.sql("drop table t_9267_1")
spark.sql("drop table t_9267_2")
}
+
+ test("right keys are in used") {
+ spark.sql("drop table if exists t_join_1")
+ spark.sql("drop table if exists t_join_2")
+ spark.sql("drop table if exists t_join_3")
+
+ spark.sql("create table t_join_1 (a bigint, b bigint) using parquet")
+ spark.sql("create table t_join_2 (a bigint, b bigint) using parquet")
+ spark.sql("create table t_join_3 (a bigint, b bigint) using parquet")
+
+ spark.sql("insert into t_join_1 select id % 10 as a, id as b from
range(10)")
+ spark.sql("insert into t_join_2 select id % 7 as a, id as b from
range(20)")
+ spark.sql("insert into t_join_3 select id % 10 as a, id as b from
range(20)")
+
+ val sql =
+ """
+ |select a1, b, a2, a3, s2, s3 from (
+ | select a as a1, b from t_join_1
+ |) t1 left join (
+ | select a as a2, sum(b) as s2 from t_join_2 group by a
+ |) t2 on a1 = a2 left join (
+ | select a as a3, sum(b) as s3 from t_join_3 group by a
+ |) t3 on a1 = a3
+ |order by a1, b, a2, a3, s2, s3
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(
+ sql,
+ true,
+ {
+ df =>
+ val joins = df.queryExecution.executedPlan.collect {
+ case join: ShuffledHashJoinExecTransformerBase => join
+ }
+ assert(joins.length == 1)
+ })
+
+ spark.sql("drop table t_join_1")
+ spark.sql("drop table t_join_2")
+ spark.sql("drop table t_join_3")
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]