This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 2614db9f3e fix (#9293)
2614db9f3e is described below

commit 2614db9f3e190053e0e46fc3a823ee6ade01a37a
Author: lgbo <lgbo.u...@gmail.com>
AuthorDate: Fri Apr 11 09:12:55 2025 +0800

    fix (#9293)
---
 .../gluten/backendsapi/clickhouse/CHBackend.scala  |  2 +
 .../clickhouse/CHSparkPlanExecApi.scala            |  3 +-
 .../execution/CHHashJoinExecTransformer.scala      |  5 +-
 .../EliminateDeduplicateAggregateWithAnyJoin.scala | 43 ++++++++++---
 .../extension/JoinAggregateToAggregateUnion.scala  |  4 +-
 .../RewriteSortMergeJoinToHashJoinRule.scala       |  3 +-
 .../execution/GlutenEliminateJoinSuite.scala       | 72 ++++++++++++++++++++++
 7 files changed, 116 insertions(+), 16 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index dcdb5dcc5d..1addfe23ea 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -163,6 +163,8 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     CHConfig.prefixOf("enable.coalesce.project.union")
   val GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION: String =
     CHConfig.prefixOf("join.aggregate.to.aggregate.union")
+  val GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN: String =
+    CHConfig.prefixOf("eliminate_deduplicate_aggregate_with_any_join")
 
   def affinityMode: String = {
     SparkEnv.get.conf
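Context: the new key gates the rewrite rule below, and the rule reads it with spark.conf.get(..., "true"), so it defaults to enabled. A minimal sketch of turning it off for one session, assuming CHConfig.prefixOf resolves the key under the ClickHouse backend's spark.gluten configuration prefix:

    // Sketch only: disable the any-join rewrite for the current session.
    // GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN is the key added
    // above; the rule falls back to "true" when the key is unset.
    spark.conf.set(
      CHBackendSettings.GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN,
      "false")
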
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index dcf19204d2..c77e7bc31b 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -315,7 +315,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
       condition,
       left,
       right,
-      isSkewJoin)
+      isSkewJoin,
+      false)
   }
 
   /** Generate BroadcastHashJoinExecTransformer. */
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 21eab86da5..cbf3a3b6ea 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -90,7 +90,8 @@ case class CHShuffledHashJoinExecTransformer(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    isSkewJoin: Boolean)
+    isSkewJoin: Boolean,
+    isAnyJoin: Boolean)
   extends ShuffledHashJoinExecTransformerBase(
     leftKeys,
     rightKeys,
@@ -100,8 +101,6 @@ case class CHShuffledHashJoinExecTransformer(
     left,
     right,
     isSkewJoin) {
-  // `any join` is used to accelerate the case when the right table is the aggregate result.
-  var isAnyJoin = false
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan,
       newRight: SparkPlan): CHShuffledHashJoinExecTransformer =
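
Context: promoting isAnyJoin from a mutable field to a constructor parameter is the core of this fix. Case-class copy() and withNewChildrenInternal rebuild the node from constructor arguments alone, so a var assigned after construction is silently reset whenever another rule rewrites the plan, which is the "lost any join setting" scenario the new test guards against. A minimal, self-contained sketch of the failure mode (Node is an illustrative type, not Gluten code):

    // Illustrative only: a mutable field does not survive case-class copy().
    case class Node(child: Int) {
      var flag = false // assigned after construction, not a constructor argument
    }

    object FlagDemo extends App {
      val a = Node(1)
      a.flag = true
      val b = a.copy(child = 2) // rebuilds the node from constructor args only
      assert(!b.flag) // the flag has silently reset to false
      println(s"a.flag=${a.flag}, b.flag=${b.flag}")
    }
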
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
index 06a4199d53..6a7f8a165c 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.execution._
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
@@ -30,21 +31,48 @@ case class EliminateDeduplicateAggregateWithAnyJoin(spark: SparkSession)
   extends Rule[SparkPlan]
   with Logging {
   override def apply(plan: SparkPlan): SparkPlan = {
-    if (!CHBackendSettings.eliminateDeduplicateAggregateWithAnyJoin()) {
+
+    if (
+      !spark.conf
+        .get(CHBackendSettings.GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN, "true")
+        .toBoolean
+    ) {
       return plan
     }
 
     plan.transformUp {
-      case hashJoin: CHShuffledHashJoinExecTransformer =>
+      case hashJoin: CHShuffledHashJoinExecTransformer
+          if (hashJoin.buildSide == BuildRight && hashJoin.joinType == LeftOuter) =>
         hashJoin.right match {
           case aggregate: CHHashAggregateExecTransformer =>
+            if (
+              isDeduplicateAggregate(aggregate) && allGroupingKeysAreJoinKeys(hashJoin, aggregate)
+            ) {
+              hashJoin.copy(right = aggregate.child, isAnyJoin = true)
+            } else {
+              hashJoin
+            }
+          case project @ ProjectExecTransformer(_, aggregate: CHHashAggregateExecTransformer) =>
             if (
               hashJoin.joinType == LeftOuter &&
+              isDeduplicateAggregate(aggregate) &&
+              allGroupingKeysAreJoinKeys(hashJoin, aggregate) && project.projectList.forall(
+                _.isInstanceOf[AttributeReference])
+            ) {
+              hashJoin.copy(right = project.copy(child = aggregate.child), isAnyJoin = true)
+            } else {
+              hashJoin
+            }
+          case _ => hashJoin
+        }
+      case hashJoin: CHShuffledHashJoinExecTransformer
+          if (hashJoin.buildSide == BuildLeft && hashJoin.joinType == LeftOuter) =>
+        hashJoin.left match {
+          case aggregate: CHHashAggregateExecTransformer =>
+            if (
              isDeduplicateAggregate(aggregate) && allGroupingKeysAreJoinKeys(hashJoin, aggregate)
             ) {
-              val newHashJoin = hashJoin.copy(right = aggregate.child)
-              newHashJoin.isAnyJoin = true
-              newHashJoin
+              hashJoin.copy(left = aggregate.child, isAnyJoin = true)
             } else {
               hashJoin
             }
@@ -55,10 +83,7 @@ case class EliminateDeduplicateAggregateWithAnyJoin(spark: SparkSession)
              allGroupingKeysAreJoinKeys(hashJoin, aggregate) && project.projectList.forall(
                 _.isInstanceOf[AttributeReference])
             ) {
-              val newHashJoin =
-                hashJoin.copy(right = project.copy(child = aggregate.child))
-              newHashJoin.isAnyJoin = true
-              newHashJoin
+              hashJoin.copy(left = project.copy(child = aggregate.child), isAnyJoin = true)
             } else {
               hashJoin
             }
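
Context: the rule now handles both build sides. When a LEFT OUTER shuffled hash join builds on an aggregate (or on a column-pruning project over one) whose only effect is deduplicating the join keys, the aggregate is dropped and the join is flagged as an any join, which under ClickHouse any-join semantics matches at most one build-side row per key, the acceleration the removed comment in CHHashJoinExecTransformer.scala described. A hedged sketch of the query shape that should trigger the rewrite (table names are illustrative):

    // Illustrative only: the group-by on t2 merely deduplicates (a, b), which
    // are exactly the join keys, so the rule can elide it and set isAnyJoin.
    val sql =
      """
        |select t1.a, t1.b, t2.a
        |from t1 left join (
        |  select a, b from t2 group by a, b
        |) t2 on t1.a = t2.a and t1.b = t2.b
        |""".stripMargin
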
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
index 3fcc2d5369..da80d58f26 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
@@ -334,7 +334,7 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo
       joinKeys.length != aggregate.groupingExpressions.length ||
       !joinKeys.forall(k => outputGroupingKeys.exists(_.semanticEquals(k)))
     ) {
-      logError(
+      logDebug(
         s"xxx Join keys and grouping keys are not matched. joinKeys: 
$joinKeys" +
           s" outputGroupingKeys: $outputGroupingKeys")
       return false
@@ -955,7 +955,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession)
           analyzedAggregates.insert(0, rightAggregateAnalyzer.get)
           collectSameKeysJoinedAggregates(join.left, analyzedAggregates)
         } else {
-          logError(
+          logDebug(
             s"xxx Not have same keys. join keys:" +
               s"${analyzedAggregates.head.getPrimeJoinKeys()} vs. " +
               s"${rightAggregateAnalyzer.get.getPrimeJoinKeys()}")
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
index 8c5ada043f..441a181fca 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
@@ -80,7 +80,8 @@ case class RewriteSortMergeJoinToHashJoinRule(session: SparkSession)
       smj.condition,
       newLeft,
       newRight,
-      smj.isSkewJoin)
+      smj.isSkewJoin,
+      false)
     val validateResult = hashJoin.doValidate()
     if (!validateResult.ok()) {
       logError(s"Validation failed for ShuffledHashJoinExec: 
${validateResult.reason()}")
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
index 080fc9b5c9..169892c4da 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.gluten.execution
 
+import org.apache.gluten.backendsapi.clickhouse._
+
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
@@ -59,6 +61,8 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit
       .set("spark.sql.shuffle.partitions", "5")
       .set("spark.sql.autoBroadcastJoinThreshold", "-1")
       .set("spark.gluten.supported.scala.udfs", 
"compare_substrings:compare_substrings")
+      
.set(CHConfig.runtimeSettings("max_memory_usage_ratio_for_streaming_aggregating"),
 "0.01")
+      
.set(CHConfig.runtimeSettings("high_cardinality_threshold_for_streaming_aggregating"),
 "0.2")
       .set(
         SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
         ConstantFolding.ruleName + "," + NullPropagation.ruleName)
@@ -469,4 +473,72 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit
           assert(joins.length == 1)
       })
   }
+
+  // Ensure the isAnyJoin flag is never lost after applying other rules
+  test("lost any join setting") {
+    spark.sql("drop table if exists t_9267_1")
+    spark.sql("drop table if exists t_9267_2")
+    spark.sql("create table t_9267_1 (a bigint, b bigint) using parquet")
+    spark.sql("create table t_9267_2 (a bigint, b bigint) using parquet")
+    spark.sql("insert into t_9267_1 select id as a, id as b from 
range(20000000)")
+    spark.sql("insert into t_9267_2 select id as a, id as b from 
range(5000000)")
+    spark.sql("insert into t_9267_2 select id as a, id as b from 
range(5000000)")
+
+    val sql =
+      """
+        |select count(1) as n1, count(a1, b1, a2) as n2 from(
+        |  select t1.a as a1, t1.b as b1, t2.a as a2 from (
+        |    select * from t_9267_1 where a >= 0 and b < 100000000 and b >= 0
+        |  ) t1 left join (
+        |    select a, b from t_9267_2 group by a, b
+        |  ) t2 on t1.a = t2.a and t1.b = t2.b
+        |)""".stripMargin
+    compareResultsAgainstVanillaSpark(
+      sql,
+      true,
+      {
+        df =>
+          val joins = df.queryExecution.executedPlan.collect {
+            case join: ShuffledHashJoinExecTransformerBase => join
+          }
+          assert(joins.length == 1)
+      })
+
+    spark.sql("drop table t_9267_1")
+    spark.sql("drop table t_9267_2")
+  }
+
+  test("build left side") {
+    spark.sql("drop table if exists t_9267_1")
+    spark.sql("drop table if exists t_9267_2")
+    spark.sql("create table t_9267_1 (a bigint, b bigint) using parquet")
+    spark.sql("create table t_9267_2 (a bigint, b bigint) using parquet")
+    spark.sql("insert into t_9267_1 select id as a, id as b from 
range(2000000)")
+    spark.sql("insert into t_9267_2 select id as a, id as b from 
range(500000)")
+    spark.sql("insert into t_9267_2 select id as a, id as b from 
range(500000)")
+
+    // The left table is smaller, so it will be used as the build side.
+    val sql =
+      """
+        |select count(1) as n1, count(a1, b1, a2) as n2 from(
+        |  select t1.a as a1, t1.b as b1, t2.a as a2 from (
+        |    select a, b from t_9267_2 group by a, b
+        |  ) t1 left join (
+        |    select * from t_9267_1 where a >= 0 and b != 100000000 and b >= 0
+        |  ) t2 on t1.a = t2.a and t1.b = t2.b
+        |)""".stripMargin
+    compareResultsAgainstVanillaSpark(
+      sql,
+      true,
+      {
+        df =>
+          val joins = df.queryExecution.executedPlan.collect {
+            case join: ShuffledHashJoinExecTransformerBase => join
+          }
+          assert(joins.length == 1)
+      })
+
+    spark.sql("drop table t_9267_1")
+    spark.sql("drop table t_9267_2")
+  }
 }

