This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
     new 2614db9f3e fix (#9293)
2614db9f3e is described below

commit 2614db9f3e190053e0e46fc3a823ee6ade01a37a
Author: lgbo <lgbo.u...@gmail.com>
AuthorDate: Fri Apr 11 09:12:55 2025 +0800

    fix (#9293)
---
 .../gluten/backendsapi/clickhouse/CHBackend.scala  |  2 +
 .../clickhouse/CHSparkPlanExecApi.scala            |  3 +-
 .../execution/CHHashJoinExecTransformer.scala      |  5 +-
 .../EliminateDeduplicateAggregateWithAnyJoin.scala | 43 ++++++++++---
 .../extension/JoinAggregateToAggregateUnion.scala  |  4 +-
 .../RewriteSortMergeJoinToHashJoinRule.scala       |  3 +-
 .../execution/GlutenEliminateJoinSuite.scala       | 72 ++++++++++++++++++++++
 7 files changed, 116 insertions(+), 16 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index dcdb5dcc5d..1addfe23ea 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -163,6 +163,8 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     CHConfig.prefixOf("enable.coalesce.project.union")
   val GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION: String =
     CHConfig.prefixOf("join.aggregate.to.aggregate.union")
+  val GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN: String =
+    CHConfig.prefixOf("eliminate_deduplicate_aggregate_with_any_join")
 
   def affinityMode: String = {
     SparkEnv.get.conf
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index dcf19204d2..c77e7bc31b 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -315,7 +315,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
       condition,
       left,
       right,
-      isSkewJoin)
+      isSkewJoin,
+      false)
   }
 
   /** Generate BroadcastHashJoinExecTransformer. */
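[For context: the new CHBackendSettings key is read by the rewrite rule further below
with a default of "true", so the optimization stays on unless explicitly disabled. A
minimal sketch of toggling it per session -- illustrative only, not code from this
commit, and it assumes an active SparkSession named `spark`:

    import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings

    // Disable the deduplicate-aggregate elimination for the current session.
    spark.conf.set(
      CHBackendSettings.GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN,
      "false")
]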
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 21eab86da5..cbf3a3b6ea 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -90,7 +90,8 @@ case class CHShuffledHashJoinExecTransformer(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    isSkewJoin: Boolean)
+    isSkewJoin: Boolean,
+    isAnyJoin: Boolean)
   extends ShuffledHashJoinExecTransformerBase(
     leftKeys,
     rightKeys,
     joinType,
     condition,
     left,
     right,
     isSkewJoin) {
-  // `any join` is used to accelerate the case when the right table is the aggregate result.
-  var isAnyJoin = false
 
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan,
       newRight: SparkPlan): CHShuffledHashJoinExecTransformer =
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
index 06a4199d53..6a7f8a165c 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.execution._
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
@@ -30,21 +31,48 @@ case class EliminateDeduplicateAggregateWithAnyJoin(spark: SparkSession)
   extends Rule[SparkPlan]
   with Logging {
   override def apply(plan: SparkPlan): SparkPlan = {
-    if (!CHBackendSettings.eliminateDeduplicateAggregateWithAnyJoin()) {
+
+    if (
+      !spark.conf
+        .get(CHBackendSettings.GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN, "true")
+        .toBoolean
+    ) {
       return plan
     }
 
     plan.transformUp {
-      case hashJoin: CHShuffledHashJoinExecTransformer =>
+      case hashJoin: CHShuffledHashJoinExecTransformer
+          if (hashJoin.buildSide == BuildRight && hashJoin.joinType == LeftOuter) =>
         hashJoin.right match {
           case aggregate: CHHashAggregateExecTransformer =>
+            if (
+              isDeduplicateAggregate(aggregate) && allGroupingKeysAreJoinKeys(hashJoin, aggregate)
+            ) {
+              hashJoin.copy(right = aggregate.child, isAnyJoin = true)
+            } else {
+              hashJoin
+            }
+          case project @ ProjectExecTransformer(_, aggregate: CHHashAggregateExecTransformer) =>
             if (
               hashJoin.joinType == LeftOuter &&
+              isDeduplicateAggregate(aggregate) &&
+              allGroupingKeysAreJoinKeys(hashJoin, aggregate) && project.projectList.forall(
+                _.isInstanceOf[AttributeReference])
+            ) {
+              hashJoin.copy(right = project.copy(child = aggregate.child), isAnyJoin = true)
+            } else {
+              hashJoin
+            }
+          case _ => hashJoin
+        }
+      case hashJoin: CHShuffledHashJoinExecTransformer
+          if (hashJoin.buildSide == BuildLeft && hashJoin.joinType == LeftOuter) =>
+        hashJoin.left match {
+          case aggregate: CHHashAggregateExecTransformer =>
+            if (
               isDeduplicateAggregate(aggregate) && allGroupingKeysAreJoinKeys(hashJoin, aggregate)
             ) {
-              val newHashJoin = hashJoin.copy(right = aggregate.child)
-              newHashJoin.isAnyJoin = true
-              newHashJoin
+              hashJoin.copy(left = aggregate.child, isAnyJoin = true)
             } else {
               hashJoin
             }
@@ -55,10 +83,7 @@ case class EliminateDeduplicateAggregateWithAnyJoin(spark: SparkSession)
               allGroupingKeysAreJoinKeys(hashJoin, aggregate) && project.projectList.forall(
                 _.isInstanceOf[AttributeReference])
             ) {
-              val newHashJoin =
-                hashJoin.copy(right = project.copy(child = aggregate.child))
-              newHashJoin.isAnyJoin = true
-              newHashJoin
+              hashJoin.copy(left = project.copy(child = aggregate.child), isAnyJoin = true)
             } else {
               hashJoin
             }
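[For context: moving `isAnyJoin` from a mutable member into the constructor matters
because Catalyst rules rebuild operators through copy()/withNewChildren, which re-run
the primary constructor; a plain `var` on a case class is not a constructor field and
silently resets to its default on every rebuild. A minimal, self-contained sketch of
that failure mode -- illustrative only, not code from this commit:

    case class Op(isSkewJoin: Boolean) {
      var isAnyJoin: Boolean = false // old design: set after construction
    }

    object VarLossDemo extends App {
      val op = Op(isSkewJoin = false)
      op.isAnyJoin = true
      val rebuilt = op.copy() // what a downstream rule effectively does to a plan node
      assert(op.isAnyJoin) // the original node still carries the flag
      assert(!rebuilt.isAnyJoin) // the copy lost it; a constructor field would survive
    }

The "lost any join setting" test added at the end of this commit exercises exactly
this scenario against the real plan nodes.]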
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
index 3fcc2d5369..da80d58f26 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
@@ -334,7 +334,7 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo
       joinKeys.length != aggregate.groupingExpressions.length ||
       !joinKeys.forall(k => outputGroupingKeys.exists(_.semanticEquals(k)))
     ) {
-      logError(
+      logDebug(
         s"xxx Join keys and grouping keys are not matched. joinKeys: $joinKeys" +
           s" outputGroupingKeys: $outputGroupingKeys")
       return false
@@ -955,7 +955,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession)
           analyzedAggregates.insert(0, rightAggregateAnalyzer.get)
           collectSameKeysJoinedAggregates(join.left, analyzedAggregates)
         } else {
-          logError(
+          logDebug(
             s"xxx Not have same keys. join keys:" +
               s"${analyzedAggregates.head.getPrimeJoinKeys()} vs. " +
               s"${rightAggregateAnalyzer.get.getPrimeJoinKeys()}")
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
index 8c5ada043f..441a181fca 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
@@ -80,7 +80,8 @@ case class RewriteSortMergeJoinToHashJoinRule(session: SparkSession)
       smj.condition,
       newLeft,
       newRight,
-      smj.isSkewJoin)
+      smj.isSkewJoin,
+      false)
     val validateResult = hashJoin.doValidate()
     if (!validateResult.ok()) {
       logError(s"Validation failed for ShuffledHashJoinExec: ${validateResult.reason()}")
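[For context on the tests that follow: when every grouping key of the build-side
aggregate is also a join key, the `group by` only deduplicates rows, so a left outer
join that keeps at most one build-side match per key (an "any join") returns the same
result without materializing the aggregate. A sketch of the targeted query shape --
illustrative only, reusing the t_9267_* tables defined in the suite below:

    // The inner `group by a, b` merely deduplicates (a, b) pairs, and both
    // columns are join keys, so the aggregate can be dropped once the join
    // is marked as an any join.
    val df = spark.sql(
      """select t1.a, t1.b, t2.a
        |from t_9267_1 t1
        |left join (select a, b from t_9267_2 group by a, b) t2
        |  on t1.a = t2.a and t1.b = t2.b""".stripMargin)
]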
spark.sql("insert into t_9267_2 select id as a, id as b from range(5000000)") + spark.sql("insert into t_9267_2 select id as a, id as b from range(5000000)") + + val sql = + """ + |select count(1) as n1, count(a1, b1, a2) as n2 from( + | select t1.a as a1, t1.b as b1, t2.a as a2 from ( + | select * from t_9267_1 where a >= 0 and b < 100000000 and b >= 0 + | ) t1 left join ( + | select a, b from t_9267_2 group by a, b + | ) t2 on t1.a = t2.a and t1.b = t2.b + |)""".stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + }) + + spark.sql("drop table t_9267_1") + spark.sql("drop table t_9267_2") + } + + test("build left side") { + spark.sql("drop table if exists t_9267_1") + spark.sql("drop table if exists t_9267_2") + spark.sql("create table t_9267_1 (a bigint, b bigint) using parquet") + spark.sql("create table t_9267_2 (a bigint, b bigint) using parquet") + spark.sql("insert into t_9267_1 select id as a, id as b from range(2000000)") + spark.sql("insert into t_9267_2 select id as a, id as b from range(500000)") + spark.sql("insert into t_9267_2 select id as a, id as b from range(500000)") + + // left table is smaller, it will be used as the build side. + val sql = + """ + |select count(1) as n1, count(a1, b1, a2) as n2 from( + | select t1.a as a1, t1.b as b1, t2.a as a2 from ( + | select a, b from t_9267_2 group by a, b + | ) t1 left join ( + | select * from t_9267_1 where a >= 0 and b != 100000000 and b >= 0 + | ) t2 on t1.a = t2.a and t1.b = t2.b + |)""".stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + }) + + spark.sql("drop table t_9267_1") + spark.sql("drop table t_9267_2") + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org For additional commands, e-mail: commits-h...@gluten.apache.org