This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 6c9fc228e2 [GLUTEN-9728][CH] Fix: Inconsistent result from 
JoinAggregateToAggregateUnion (#9729)
6c9fc228e2 is described below

commit 6c9fc228e277339f00be06551ee7b45c9a2a9905
Author: lgbo <[email protected]>
AuthorDate: Mon May 26 10:08:55 2025 +0800

    [GLUTEN-9728][CH] Fix: Inconsistent result from 
JoinAggregateToAggregateUnion (#9729)
---
 .../extension/CoalesceAggregationUnion.scala       |   4 +-
 .../extension/JoinAggregateToAggregateUnion.scala  | 128 +++++++++++++++------
 .../GlutenCoalesceAggregationUnionSuite.scala      |   6 +-
 .../execution/GlutenEliminateJoinSuite.scala       |  42 +++++++
 4 files changed, 138 insertions(+), 42 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
index d5faa48d07..90d647f437 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
@@ -533,7 +533,7 @@ case class CoalesceAggregationUnion(spark: SparkSession) 
extends Rule[LogicalPla
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (
       spark.conf
-        .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION, 
"true")
+        .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION, 
"false")
         .toBoolean && CoalesceUnionUtil.isResolvedPlan(plan)
     ) {
       Try {
@@ -953,7 +953,7 @@ case class CoalesceProjectionUnion(spark: SparkSession) 
extends Rule[LogicalPlan
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (
       spark.conf
-        .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_PROJECT_UNION, "true")
+        .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_PROJECT_UNION, "false")
         .toBoolean && CoalesceUnionUtil.isResolvedPlan(plan)
     ) {
       Try {
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
index 61dd9ca869..b7c7f29c98 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
@@ -638,9 +638,10 @@ case class JoinAggregateToAggregateUnion(spark: 
SparkSession)
   }
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
+    // This is an experimental feature, so it is disabled by default.
     if (
       spark.conf
-        .get(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION, 
"true")
+        .get(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION, 
"false")
         .toBoolean && isResolvedPlan(plan)
     ) {
       val reorderedPlan = ReorderJoinSubqueries().apply(plan)
@@ -689,9 +690,23 @@ case class JoinAggregateToAggregateUnion(spark: 
SparkSession)
           } else {
             val unionedAggregates = 
unionAllJoinedAggregates(analyzedAggregates.toSeq)
             if (remainedPlan.isDefined) {
+              // The left side of the join is not an aggregate query.
+
+              // Use the new right keys to build the join condition
+              val newRightKeys =
+                unionedAggregates.output.slice(0, 
analyzedAggregates.head.getPrimeJoinKeys.length)
+              val newJoinCondition =
+                buildJoinCondition(analyzedAggregates.head.getPrimeJoinKeys(), 
newRightKeys)
               val lastJoin = analyzedAggregates.head.join
-              lastJoin.copy(left = visitPlan(lastJoin.left), right = 
unionedAggregates)
+              lastJoin.copy(
+                left = visitPlan(lastJoin.left),
+                right = unionedAggregates,
+                condition = Some(newJoinCondition))
             } else {
+              /*
+               * The left side of the join is also an aggregate query.
+               * Rows whose flag_0 is null are filtered out.
+               */
               buildPrimeJoinKeysFilterOnAggregateUnion(unionedAggregates, 
analyzedAggregates.toSeq)
             }
           }
@@ -702,25 +717,20 @@ case class JoinAggregateToAggregateUnion(spark: 
SparkSession)
     }
   }
 
-  def buildAggregateExpressionWithNewChildren(
-      ne: NamedExpression,
-      inputs: Seq[Attribute]): NamedExpression = {
-    val aggregateExpression =
-      ne.asInstanceOf[Alias].child.asInstanceOf[AggregateExpression]
-    val newAggregateFunction = aggregateExpression.aggregateFunction
-      .withNewChildren(inputs)
-      .asInstanceOf[AggregateFunction]
-    val newAggregateExpression = aggregateExpression.copy(aggregateFunction = 
newAggregateFunction)
-    RuleExpressionHelper.makeNamedExpression(newAggregateExpression, ne.name)
-  }
-
   def unionAllJoinedAggregates(analyzedAggregates: 
Seq[JoinedAggregateAnalyzer]): LogicalPlan = {
     val extendProjects = buildExtendProjects(analyzedAggregates)
     val union = buildUnionOnExtendedProjects(extendProjects)
+    // The output is {keys_0}, {keys_1}, ..., {aggs_0}, {aggs_1}, ..., flag_0,
+    //   flag_1, ...
     val aggregateUnion = buildAggregateOnUnion(union, analyzedAggregates)
-    logDebug(s"xxx aggregateUnion $aggregateUnion")
+    // Prepend a copy of {keys_0}. Unlike {keys_0}, the copies are never nulled
+    // out by the next project, so they can be used as join keys later.
+    val duplicateRightKeysProject =
+      buildJoinRightKeysProject(aggregateUnion, analyzedAggregates)
+    /*
+     * If flag_0 is null, let {keys_0} be null too.
+     * If flag_i is null, let {aggs_i} be null too.
+     */
     val setNullsProject =
-      buildMakeNotMatchedRowsNullProject(aggregateUnion, analyzedAggregates, 
Set())
+      buildMakeNotMatchedRowsNullProject(duplicateRightKeysProject, 
analyzedAggregates, Set())
     buildRenameProject(setNullsProject, analyzedAggregates)
   }
 
@@ -799,34 +809,52 @@ case class JoinAggregateToAggregateUnion(spark: 
SparkSession)
       analyzedAggregates: Seq[JoinedAggregateAnalyzer],
       ignoreAggregates: Set[Int]): LogicalPlan = {
     val input = plan.output
-    val flagExpressions =
-      input.slice(plan.output.length - analyzedAggregates.length, 
plan.output.length)
-    val aggregateExprsStart =
-      analyzedAggregates.length * 
analyzedAggregates.head.getGroupingKeys.length
-    var fieldIndex = aggregateExprsStart
-    val aggregatesIfNullExpressions = analyzedAggregates.zipWithIndex.map {
+
+    val keysNumber = analyzedAggregates.head.getGroupingKeys().length
+    val keysStartIndex = keysNumber
+    val aggregateExpressionsStatIndex = keysStartIndex + 
analyzedAggregates.length * keysNumber
+    val flagExpressionsStartIndex = input.length - analyzedAggregates.length
+
+    val dupPrimeKeys = input.slice(0, keysStartIndex)
+    val keys = input.slice(keysStartIndex, aggregateExpressionsStatIndex)
+    val aggregateExpressions = input.slice(aggregateExpressionsStatIndex, 
flagExpressionsStartIndex)
+    val flagExpressions = input.slice(flagExpressionsStartIndex, input.length)
+
+    val newProjectList = ArrayBuffer[NamedExpression]()
+    newProjectList ++= dupPrimeKeys
+    val newFirstClauseKeys = keys.slice(0, keysNumber).map {
+      case key =>
+        val ifNull =
+          If(IsNull(flagExpressions(0)), 
RuleExpressionHelper.makeNullLiteral(key.dataType), key)
+        RuleExpressionHelper.makeNamedExpression(ifNull, key.name)
+    }
+    newProjectList ++= newFirstClauseKeys
+    newProjectList ++= keys.slice(keysNumber, keys.length)
+
+    var fieldIndex = 0
+    val newAggregateExpressions = analyzedAggregates.zipWithIndex.map {
       case (analyzedAggregate, i) =>
-        val flagExpr = flagExpressions(i)
-        val aggregateExpressions = analyzedAggregate.getAggregateExpressions()
+        val localAggregateExpressions = 
analyzedAggregate.getAggregateExpressions()
         val aggregateFunctionAnalyzers = 
analyzedAggregate.getAggregateFunctionAnalyzers()
-        aggregateExpressions.zipWithIndex.map {
-          case (e, i) =>
-            val valueExpr = input(fieldIndex)
+        localAggregateExpressions.zipWithIndex.map {
+          case (e, j) =>
+            val aggregateExpr = aggregateExpressions(fieldIndex)
             fieldIndex += 1
-            if (ignoreAggregates(i) || 
aggregateFunctionAnalyzers(i).ignoreNulls()) {
-              valueExpr.asInstanceOf[NamedExpression]
+
+            if (ignoreAggregates.contains(i) || 
aggregateFunctionAnalyzers(j).ignoreNulls()) {
+              aggregateExpr.asInstanceOf[NamedExpression]
             } else {
-              val clearExpr = If(
-                IsNull(flagExpr),
-                RuleExpressionHelper.makeNullLiteral(valueExpr.dataType),
-                valueExpr)
-              RuleExpressionHelper.makeNamedExpression(clearExpr, 
valueExpr.name)
+              val ifNullExpr = If(
+                IsNull(flagExpressions(i)),
+                RuleExpressionHelper.makeNullLiteral(aggregateExpr.dataType),
+                aggregateExpr)
+              RuleExpressionHelper.makeNamedExpression(ifNullExpr, 
aggregateExpr.name)
             }
         }
     }
-    val ifNullExpressions = aggregatesIfNullExpressions.flatten
-    val projectList = input.slice(0, aggregateExprsStart) ++ ifNullExpressions 
++ flagExpressions
-    Project(projectList, plan)
+    newProjectList ++= newAggregateExpressions.flatten
+    newProjectList ++= flagExpressions
+    Project(newProjectList.toSeq, plan)
   }
 
   /**
@@ -838,7 +866,8 @@ case class JoinAggregateToAggregateUnion(spark: 
SparkSession)
     val input = plan.output
     val projectList = ArrayBuffer[NamedExpression]()
     val keysNum = analyzedAggregates.head.getGroupingKeys.length
-    var fieldIndex = 0
+    projectList ++= input.slice(0, keysNum)
+    var fieldIndex = keysNum
     for (i <- 0 until analyzedAggregates.length) {
       val keys = analyzedAggregates(i).getGroupingKeys()
       for (j <- 0 until keys.length) {
@@ -867,6 +896,29 @@ case class JoinAggregateToAggregateUnion(spark: 
SparkSession)
     Project(projectList.toSeq, plan)
   }
 
+  /**
+   * Builds the conjunction `leftKeys(0) = rightKeys(0) AND leftKeys(1) = rightKeys(1) AND ...`.
+   *
+   * NOTE(review): keys are paired with `zip`, so surplus keys on the longer side are silently
+   * dropped. `reduceLeft` throws on an empty key list. Callers are expected to pass non-empty key
+   * lists of equal length.
+   */
+  def buildJoinCondition(leftKeys: Seq[Attribute], rightKeys: Seq[Attribute]): 
Expression = {
+    leftKeys
+      .zip(rightKeys)
+      .map {
+        case (leftKey, rightKey) =>
+          EqualTo(leftKey, rightKey)
+      }
+      .reduceLeft(And)
+  }
+
+  /**
+   * Returns a Project over `plan` whose output is:
+   *   - first, copies of the first `keysNum` columns of `plan.output`, each named
+   *     `_dup_prime_<original name>`;
+   *   - then all of `plan.output`, unchanged.
+   *
+   * `keysNum` is the number of grouping keys of the first analyzed aggregate. `analyzedAggregates`
+   * must therefore be non-empty (`head` is used).
+   */
+  def buildJoinRightKeysProject(
+      plan: LogicalPlan,
+      analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = {
+    val input = plan.output
+    val keysNum = analyzedAggregates.head.getGroupingKeys.length
+    // The leading columns of `plan.output` are expected to be the prime aggregate's grouping
+    // keys. Copy them under new names and put the copies in front.
+    val projectList = input.slice(0, keysNum).map {
+      case key =>
+        RuleExpressionHelper.makeNamedExpression(key, "_dup_prime_" + key.name)
+    }
+    Project(projectList ++ input, plan)
+  }
+
   /**
    * Build a extended project list, which contains three parts.
    *   - The grouping keys of all tables.
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
index 6e73d12230..fdd7482757 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.gluten.execution
 
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
+
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, Row}
 import 
org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
@@ -47,8 +49,8 @@ class GlutenCoalesceAggregationUnionSuite extends 
GlutenClickHouseWholeStageTran
       .set("spark.io.compression.codec", "snappy")
       .set("spark.sql.shuffle.partitions", "5")
       .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
-      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union", 
"true")
-      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.aggregation.union", 
"true")
+      .set(CHBackendSettings.GLUTEN_ENABLE_COALESCE_PROJECT_UNION, "true")
+      .set(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION, "true")
   }
 
   def createTestTable(tableName: String, data: DataFrame): Unit = {
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
index ea5de96a43..2f6f0023ef 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
@@ -17,6 +17,7 @@
 package org.apache.gluten.execution
 
 import org.apache.gluten.backendsapi.clickhouse._
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
@@ -58,6 +59,7 @@ class GlutenEliminateJoinSuite extends 
GlutenClickHouseWholeStageTransformerSuit
       .set("spark.gluten.supported.scala.udfs", 
"compare_substrings:compare_substrings")
       
.set(CHConfig.runtimeSettings("max_memory_usage_ratio_for_streaming_aggregating"),
 "0.01")
       
.set(CHConfig.runtimeSettings("high_cardinality_threshold_for_streaming_aggregating"),
 "0.2")
+      .set(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION, "true")
       .set(
         SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
         ConstantFolding.ruleName + "," + NullPropagation.ruleName)
@@ -563,4 +565,44 @@ class GlutenEliminateJoinSuite extends 
GlutenClickHouseWholeStageTransformerSuit
     spark.sql("drop table t_9267_1")
     spark.sql("drop table t_9267_2")
   }
+
+  test("right keys are in used") {
+    // A plain (non-aggregate) subquery is left-joined to two aggregate subqueries, and the
+    // right-side keys (a2, a3) appear in the select list. Results must match vanilla Spark.
+    // The plan is expected to contain exactly one shuffled hash join, i.e. the two joins against
+    // aggregates are rewritten into a single join.
+    // NOTE(review): the tables are only dropped when the test passes; the
+    // `drop table if exists` statements below keep reruns safe.
+    spark.sql("drop table if exists t_join_1")
+    spark.sql("drop table if exists t_join_2")
+    spark.sql("drop table if exists t_join_3")
+
+    spark.sql("create table t_join_1 (a bigint, b bigint) using parquet")
+    spark.sql("create table t_join_2 (a bigint, b bigint) using parquet")
+    spark.sql("create table t_join_3 (a bigint, b bigint) using parquet")
+
+    spark.sql("insert into t_join_1 select id % 10 as a, id as b from 
range(10)")
+    spark.sql("insert into t_join_2 select id % 7 as a, id as b from 
range(20)")
+    spark.sql("insert into t_join_3 select id % 10 as a, id as b from 
range(20)")
+
+    val sql =
+      """
+        |select a1, b, a2, a3, s2, s3 from (
+        |  select a as a1, b from t_join_1
+        |) t1 left join (
+        |  select a as a2, sum(b) as s2 from t_join_2 group by a
+        |) t2 on a1 = a2 left join (
+        |  select a as a3, sum(b) as s3 from t_join_3 group by a
+        |) t3 on a1 = a3
+        |order by a1, b, a2, a3, s2, s3
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(
+      sql,
+      true,
+      {
+        df =>
+          // Count the shuffled hash joins left in the executed plan.
+          val joins = df.queryExecution.executedPlan.collect {
+            case join: ShuffledHashJoinExecTransformerBase => join
+          }
+          assert(joins.length == 1)
+      })
+
+    spark.sql("drop table t_join_1")
+    spark.sql("drop table t_join_2")
+    spark.sql("drop table t_join_3")
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to