This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 7bc4794947 [GLUTEN-8974][CH] Replace special `join + aggregate` case
with `any join` (#9059)
7bc4794947 is described below
commit 7bc479494718e2078f9c56171834dcde1292938b
Author: lgbo <[email protected]>
AuthorDate: Tue Mar 25 09:04:24 2025 +0800
[GLUTEN-8974][CH] Replace special `join + aggregate` case with `any join`
(#9059)
* wip
* wip
* take advantage of any join to accelerate join + aggregate
* rewrite by rule
* remove debug log
---
.../gluten/backendsapi/clickhouse/CHBackend.scala | 7 ++
.../gluten/backendsapi/clickhouse/CHRuleApi.scala | 1 +
.../execution/CHHashJoinExecTransformer.scala | 5 ++
.../EliminateDeduplicateAggregateWithAnyJoin.scala | 83 ++++++++++++++++++++++
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 34 +++++++++
.../Parser/AdvancedParametersParseUtil.cpp | 1 +
.../Parser/AdvancedParametersParseUtil.h | 2 +-
.../Parser/RelParsers/JoinRelParser.cpp | 11 +--
cpp-ch/local-engine/Rewriter/RelRewriter.h | 6 +-
9 files changed, 143 insertions(+), 7 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index c0380452de..728c5c9a76 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -407,6 +407,13 @@ object CHBackendSettings extends BackendSettingsApi with
Logging {
)
}
+ def eliminateDeduplicateAggregateWithAnyJoin(): Boolean = {
+ SparkEnv.get.conf.getBoolean(
+ CHConfig.runtimeConfig("eliminate_deduplicate_aggregate_with_any_join"),
+ defaultValue = true
+ )
+ }
+
override def enableNativeWriteFiles(): Boolean = {
GlutenConfig.get.enableNativeWriter.getOrElse(false)
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 2c47d1e00e..1b56c003dc 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -122,6 +122,7 @@ object CHRuleApi {
injector.injectPostTransform(c => RemoveDuplicatedColumns.apply(c.session))
injector.injectPostTransform(c =>
AddPreProjectionForHashJoin.apply(c.session))
injector.injectPostTransform(c =>
ReplaceSubStringComparison.apply(c.session))
+ injector.injectPostTransform(c =>
EliminateDeduplicateAggregateWithAnyJoin(c.session))
// Gluten columnar: Fallback policies.
injector.injectFallbackPolicy(c => p =>
ExpandFallbackPolicy(c.caller.isAqe(), p))
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 6bf2248ebe..21eab86da5 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -100,6 +100,8 @@ case class CHShuffledHashJoinExecTransformer(
left,
right,
isSkewJoin) {
+ // `any join` is used to accelerate the case when the right table is the
aggregate result.
+ var isAnyJoin = false
override protected def withNewChildrenInternal(
newLeft: SparkPlan,
newRight: SparkPlan): CHShuffledHashJoinExecTransformer =
@@ -139,6 +141,9 @@ case class CHShuffledHashJoinExecTransformer(
.append("isExistenceJoin=")
.append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0)
.append("\n")
+ .append("isAnyJoin=")
+ .append(if (isAnyJoin) 1 else 0)
+ .append("\n")
CHAQEUtil.getShuffleQueryStageStats(streamedPlan) match {
case Some(stats) =>
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
new file mode 100644
index 0000000000..06a4199d53
--- /dev/null
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
+import org.apache.gluten.execution._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+case class EliminateDeduplicateAggregateWithAnyJoin(spark: SparkSession)
+ extends Rule[SparkPlan]
+ with Logging {
+ override def apply(plan: SparkPlan): SparkPlan = {
+ if (!CHBackendSettings.eliminateDeduplicateAggregateWithAnyJoin()) {
+ return plan
+ }
+
+ plan.transformUp {
+ case hashJoin: CHShuffledHashJoinExecTransformer =>
+ hashJoin.right match {
+ case aggregate: CHHashAggregateExecTransformer =>
+ if (
+ hashJoin.joinType == LeftOuter &&
+ isDeduplicateAggregate(aggregate) &&
allGroupingKeysAreJoinKeys(hashJoin, aggregate)
+ ) {
+ val newHashJoin = hashJoin.copy(right = aggregate.child)
+ newHashJoin.isAnyJoin = true
+ newHashJoin
+ } else {
+ hashJoin
+ }
+ case project @ ProjectExecTransformer(_, aggregate:
CHHashAggregateExecTransformer) =>
+ if (
+ hashJoin.joinType == LeftOuter &&
+ isDeduplicateAggregate(aggregate) &&
+ allGroupingKeysAreJoinKeys(hashJoin, aggregate) &&
project.projectList.forall(
+ _.isInstanceOf[AttributeReference])
+ ) {
+ val newHashJoin =
+ hashJoin.copy(right = project.copy(child = aggregate.child))
+ newHashJoin.isAnyJoin = true
+ newHashJoin
+ } else {
+ hashJoin
+ }
+ case _ => hashJoin
+ }
+ }
+ }
+
+ def isDeduplicateAggregate(aggregate: CHHashAggregateExecTransformer):
Boolean = {
+ aggregate.aggregateExpressions.isEmpty &&
aggregate.groupingExpressions.forall(
+ _.isInstanceOf[AttributeReference])
+ }
+
+ def allGroupingKeysAreJoinKeys(
+ join: CHShuffledHashJoinExecTransformer,
+ aggregate: CHHashAggregateExecTransformer): Boolean = {
+ val rightKeys = join.rightKeys
+ val groupingKeys = aggregate.groupingExpressions
+ groupingKeys.forall(key => rightKeys.exists(_.semanticEquals(key))) &&
+ groupingKeys.length == rightKeys.length
+ }
+}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 3ce03565e8..81aa9e8fc7 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3396,5 +3396,39 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
}
+ test("GLUTEN-8974 accelerate join + aggregate by any join") {
+ withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1")) {
+ // check EliminateDeduplicateAggregateWithAnyJoin is effective
+ def checkOnlyOneAggregate(df: DataFrame): Unit = {
+ val aggregates = collectWithSubqueries(df.queryExecution.executedPlan)
{
+ case e: HashAggregateExecBaseTransformer => e
+ }
+ assert(aggregates.size == 1)
+ }
+ val sql1 =
+ """
+ |select t1.*, t2.* from nation as t1
+ |left join (select n_regionkey, n_nationkey from nation group by
n_regionkey, n_nationkey) t2
+ |on t1.n_regionkey = t2.n_regionkey and t1.n_nationkey =
t2.n_nationkey
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql1, true, checkOnlyOneAggregate)
+
+ val sql2 =
+ """
+ |select t1.*, t2.* from nation as t1
+ |left join (select n_nationkey, n_regionkey from nation group by
n_regionkey, n_nationkey) t2
+ |on t1.n_regionkey = t2.n_regionkey and t1.n_nationkey =
t2.n_nationkey
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql2, true, checkOnlyOneAggregate)
+
+ val sql3 =
+ """
+ |select t1.*, t2.* from nation as t1
+ |left join (select n_regionkey from nation group by n_regionkey) t2
+ |on t1.n_regionkey = t2.n_regionkey
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql3, true, checkOnlyOneAggregate)
+ }
+ }
}
// scalastyle:on line.size.limit
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
index 49bbe02c55..d65f5b5a50 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
@@ -140,6 +140,7 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const
String & advance)
tryAssign(kvs, "buildHashTableId", info.storage_join_key);
tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join);
tryAssign(kvs, "isExistenceJoin", info.is_existence_join);
+ tryAssign(kvs, "isAnyJoin", info.is_any_join);
tryAssign(kvs, "leftRowCount", info.left_table_rows);
tryAssign(kvs, "leftSizeInBytes", info.left_table_bytes);
tryAssign(kvs, "rightRowCount", info.right_table_rows);
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
index 795577328f..4b01c1ac12 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
@@ -22,13 +22,13 @@ namespace local_engine
{
std::unordered_map<String, std::unordered_map<String, String>>
convertToKVs(const String & advance);
-
struct JoinOptimizationInfo
{
bool is_broadcast = false;
bool is_smj = false;
bool is_null_aware_anti_join = false;
bool is_existence_join = false;
+ bool is_any_join = false;
Int64 left_table_rows = -1;
Int64 left_table_bytes = -1;
Int64 right_table_rows = -1;
diff --git a/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
b/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
index 9355777b44..d0f60a4d37 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
@@ -61,14 +61,17 @@ using namespace DB;
namespace local_engine
{
-std::shared_ptr<DB::TableJoin>
createDefaultTableJoin(substrait::JoinRel_JoinType join_type, bool
is_existence_join, ContextPtr & context)
+std::shared_ptr<DB::TableJoin>
createDefaultTableJoin(substrait::JoinRel_JoinType join_type, const
JoinOptimizationInfo & join_opt_info, ContextPtr & context)
{
auto table_join
= std::make_shared<TableJoin>(context->getSettingsRef(),
context->getGlobalTemporaryVolume(), context->getTempDataOnDisk());
- std::pair<DB::JoinKind, DB::JoinStrictness> kind_and_strictness =
JoinUtil::getJoinKindAndStrictness(join_type, is_existence_join);
+ std::pair<DB::JoinKind, DB::JoinStrictness> kind_and_strictness =
JoinUtil::getJoinKindAndStrictness(join_type, join_opt_info.is_existence_join);
table_join->setKind(kind_and_strictness.first);
- table_join->setStrictness(kind_and_strictness.second);
+ if (!join_opt_info.is_any_join)
+ table_join->setStrictness(kind_and_strictness.second);
+ else
+ table_join->setStrictness(DB::JoinStrictness::Any);
return table_join;
}
@@ -206,7 +209,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const
substrait::JoinRel & join, DB::Q
if (storage_join)
renamePlanColumns(*left, *right, *storage_join);
- auto table_join = createDefaultTableJoin(join.type(),
join_opt_info.is_existence_join, context);
+ auto table_join = createDefaultTableJoin(join.type(), join_opt_info,
context);
DB::Block right_header_before_convert_step = right->getCurrentHeader();
addConvertStep(*table_join, *left, *right);
diff --git a/cpp-ch/local-engine/Rewriter/RelRewriter.h
b/cpp-ch/local-engine/Rewriter/RelRewriter.h
index 62f57a5789..e370e7eea1 100644
--- a/cpp-ch/local-engine/Rewriter/RelRewriter.h
+++ b/cpp-ch/local-engine/Rewriter/RelRewriter.h
@@ -29,7 +29,10 @@ namespace local_engine
class RelRewriter
{
public:
- RelRewriter(ParserContextPtr parser_context_) :
parser_context(parser_context_) { }
+ RelRewriter(ParserContextPtr parser_context_)
+ : parser_context(parser_context_)
+ {
+ }
virtual ~RelRewriter() = default;
virtual void rewrite(substrait::Rel & rel) = 0;
@@ -38,5 +41,4 @@ protected:
inline DB::ContextPtr getContext() const { return
parser_context->queryContext(); }
};
-
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]