This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 7bc4794947 [GLUTEN-8974][CH] Replace special `join + aggregate` case 
with `any join` (#9059)
7bc4794947 is described below

commit 7bc479494718e2078f9c56171834dcde1292938b
Author: lgbo <[email protected]>
AuthorDate: Tue Mar 25 09:04:24 2025 +0800

    [GLUTEN-8974][CH] Replace special `join + aggregate` case with `any join` 
(#9059)
    
    * wip
    
    * wip
    
    * take advantage of `any join` to accelerate join + aggregate
    
    * rewrite by rule
    
    * remove debug log
---
 .../gluten/backendsapi/clickhouse/CHBackend.scala  |  7 ++
 .../gluten/backendsapi/clickhouse/CHRuleApi.scala  |  1 +
 .../execution/CHHashJoinExecTransformer.scala      |  5 ++
 .../EliminateDeduplicateAggregateWithAnyJoin.scala | 83 ++++++++++++++++++++++
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 34 +++++++++
 .../Parser/AdvancedParametersParseUtil.cpp         |  1 +
 .../Parser/AdvancedParametersParseUtil.h           |  2 +-
 .../Parser/RelParsers/JoinRelParser.cpp            | 11 +--
 cpp-ch/local-engine/Rewriter/RelRewriter.h         |  6 +-
 9 files changed, 143 insertions(+), 7 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index c0380452de..728c5c9a76 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -407,6 +407,13 @@ object CHBackendSettings extends BackendSettingsApi with 
Logging {
     )
   }
 
+  def eliminateDeduplicateAggregateWithAnyJoin(): Boolean = {
+    SparkEnv.get.conf.getBoolean(
+      CHConfig.runtimeConfig("eliminate_deduplicate_aggregate_with_any_join"),
+      defaultValue = true
+    )
+  }
+
   override def enableNativeWriteFiles(): Boolean = {
     GlutenConfig.get.enableNativeWriter.getOrElse(false)
   }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 2c47d1e00e..1b56c003dc 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -122,6 +122,7 @@ object CHRuleApi {
     injector.injectPostTransform(c => RemoveDuplicatedColumns.apply(c.session))
     injector.injectPostTransform(c => 
AddPreProjectionForHashJoin.apply(c.session))
     injector.injectPostTransform(c => 
ReplaceSubStringComparison.apply(c.session))
+    injector.injectPostTransform(c => 
EliminateDeduplicateAggregateWithAnyJoin(c.session))
 
     // Gluten columnar: Fallback policies.
     injector.injectFallbackPolicy(c => p => 
ExpandFallbackPolicy(c.caller.isAqe(), p))
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 6bf2248ebe..21eab86da5 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -100,6 +100,8 @@ case class CHShuffledHashJoinExecTransformer(
     left,
     right,
     isSkewJoin) {
+  // `any join` is used to accelerate the case when the right table is the 
aggregate result.
+  var isAnyJoin = false
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan,
       newRight: SparkPlan): CHShuffledHashJoinExecTransformer =
@@ -139,6 +141,9 @@ case class CHShuffledHashJoinExecTransformer(
       .append("isExistenceJoin=")
       .append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0)
       .append("\n")
+      .append("isAnyJoin=")
+      .append(if (isAnyJoin) 1 else 0)
+      .append("\n")
 
     CHAQEUtil.getShuffleQueryStageStats(streamedPlan) match {
       case Some(stats) =>
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
new file mode 100644
index 0000000000..06a4199d53
--- /dev/null
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
+import org.apache.gluten.execution._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+case class EliminateDeduplicateAggregateWithAnyJoin(spark: SparkSession)
+  extends Rule[SparkPlan]
+  with Logging {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    if (!CHBackendSettings.eliminateDeduplicateAggregateWithAnyJoin()) {
+      return plan
+    }
+
+    plan.transformUp {
+      case hashJoin: CHShuffledHashJoinExecTransformer =>
+        hashJoin.right match {
+          case aggregate: CHHashAggregateExecTransformer =>
+            if (
+              hashJoin.joinType == LeftOuter &&
+              isDeduplicateAggregate(aggregate) && 
allGroupingKeysAreJoinKeys(hashJoin, aggregate)
+            ) {
+              val newHashJoin = hashJoin.copy(right = aggregate.child)
+              newHashJoin.isAnyJoin = true
+              newHashJoin
+            } else {
+              hashJoin
+            }
+          case project @ ProjectExecTransformer(_, aggregate: 
CHHashAggregateExecTransformer) =>
+            if (
+              hashJoin.joinType == LeftOuter &&
+              isDeduplicateAggregate(aggregate) &&
+              allGroupingKeysAreJoinKeys(hashJoin, aggregate) && 
project.projectList.forall(
+                _.isInstanceOf[AttributeReference])
+            ) {
+              val newHashJoin =
+                hashJoin.copy(right = project.copy(child = aggregate.child))
+              newHashJoin.isAnyJoin = true
+              newHashJoin
+            } else {
+              hashJoin
+            }
+          case _ => hashJoin
+        }
+    }
+  }
+
+  def isDeduplicateAggregate(aggregate: CHHashAggregateExecTransformer): 
Boolean = {
+    aggregate.aggregateExpressions.isEmpty && 
aggregate.groupingExpressions.forall(
+      _.isInstanceOf[AttributeReference])
+  }
+
+  def allGroupingKeysAreJoinKeys(
+      join: CHShuffledHashJoinExecTransformer,
+      aggregate: CHHashAggregateExecTransformer): Boolean = {
+    val rightKeys = join.rightKeys
+    val groupingKeys = aggregate.groupingExpressions
+    groupingKeys.forall(key => rightKeys.exists(_.semanticEquals(key))) &&
+    groupingKeys.length == rightKeys.length
+  }
+}
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 3ce03565e8..81aa9e8fc7 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3396,5 +3396,39 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
     compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
   }
 
+  test("GLUTEN-8974 accelerate join + aggregate by any join") {
+    withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1")) {
+      // check EliminateDeduplicateAggregateWithAnyJoin is effective
+      def checkOnlyOneAggregate(df: DataFrame): Unit = {
+        val aggregates = collectWithSubqueries(df.queryExecution.executedPlan) 
{
+          case e: HashAggregateExecBaseTransformer => e
+        }
+        assert(aggregates.size == 1)
+      }
+      val sql1 =
+        """
+          |select t1.*, t2.* from nation as t1
+          |left join (select n_regionkey, n_nationkey from nation group by 
n_regionkey, n_nationkey) t2
+          |on t1.n_regionkey = t2.n_regionkey and t1.n_nationkey = 
t2.n_nationkey
+          |""".stripMargin
+      compareResultsAgainstVanillaSpark(sql1, true, checkOnlyOneAggregate)
+
+      val sql2 =
+        """
+          |select t1.*, t2.* from nation as t1
+          |left join (select n_nationkey, n_regionkey from nation group by 
n_regionkey, n_nationkey) t2
+          |on t1.n_regionkey = t2.n_regionkey and t1.n_nationkey = 
t2.n_nationkey
+          |""".stripMargin
+      compareResultsAgainstVanillaSpark(sql2, true, checkOnlyOneAggregate)
+
+      val sql3 =
+        """
+          |select t1.*, t2.* from nation as t1
+          |left join (select n_regionkey from nation group by n_regionkey) t2
+          |on t1.n_regionkey = t2.n_regionkey
+          |""".stripMargin
+      compareResultsAgainstVanillaSpark(sql3, true, checkOnlyOneAggregate)
+    }
+  }
 }
 // scalastyle:on line.size.limit
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp 
b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
index 49bbe02c55..d65f5b5a50 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
@@ -140,6 +140,7 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const 
String & advance)
     tryAssign(kvs, "buildHashTableId", info.storage_join_key);
     tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join);
     tryAssign(kvs, "isExistenceJoin", info.is_existence_join);
+    tryAssign(kvs, "isAnyJoin", info.is_any_join);
     tryAssign(kvs, "leftRowCount", info.left_table_rows);
     tryAssign(kvs, "leftSizeInBytes", info.left_table_bytes);
     tryAssign(kvs, "rightRowCount", info.right_table_rows);
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h 
b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
index 795577328f..4b01c1ac12 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
@@ -22,13 +22,13 @@ namespace local_engine
 {
 std::unordered_map<String, std::unordered_map<String, String>> 
convertToKVs(const String & advance);
 
-
 struct JoinOptimizationInfo
 {
     bool is_broadcast = false;
     bool is_smj = false;
     bool is_null_aware_anti_join = false;
     bool is_existence_join = false;
+    bool is_any_join = false;
     Int64 left_table_rows = -1;
     Int64 left_table_bytes = -1;
     Int64 right_table_rows = -1;
diff --git a/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp 
b/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
index 9355777b44..d0f60a4d37 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
@@ -61,14 +61,17 @@ using namespace DB;
 
 namespace local_engine
 {
-std::shared_ptr<DB::TableJoin> 
createDefaultTableJoin(substrait::JoinRel_JoinType join_type, bool 
is_existence_join, ContextPtr & context)
+std::shared_ptr<DB::TableJoin> 
createDefaultTableJoin(substrait::JoinRel_JoinType join_type, const 
JoinOptimizationInfo & join_opt_info, ContextPtr & context)
 {
     auto table_join
         = std::make_shared<TableJoin>(context->getSettingsRef(), 
context->getGlobalTemporaryVolume(), context->getTempDataOnDisk());
 
-    std::pair<DB::JoinKind, DB::JoinStrictness> kind_and_strictness = 
JoinUtil::getJoinKindAndStrictness(join_type, is_existence_join);
+    std::pair<DB::JoinKind, DB::JoinStrictness> kind_and_strictness = 
JoinUtil::getJoinKindAndStrictness(join_type, join_opt_info.is_existence_join);
     table_join->setKind(kind_and_strictness.first);
-    table_join->setStrictness(kind_and_strictness.second);
+    if (!join_opt_info.is_any_join)
+        table_join->setStrictness(kind_and_strictness.second);
+    else
+        table_join->setStrictness(DB::JoinStrictness::Any);
     return table_join;
 }
 
@@ -206,7 +209,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
     if (storage_join)
         renamePlanColumns(*left, *right, *storage_join);
 
-    auto table_join = createDefaultTableJoin(join.type(), 
join_opt_info.is_existence_join, context);
+    auto table_join = createDefaultTableJoin(join.type(), join_opt_info, 
context);
     DB::Block right_header_before_convert_step = right->getCurrentHeader();
     addConvertStep(*table_join, *left, *right);
 
diff --git a/cpp-ch/local-engine/Rewriter/RelRewriter.h 
b/cpp-ch/local-engine/Rewriter/RelRewriter.h
index 62f57a5789..e370e7eea1 100644
--- a/cpp-ch/local-engine/Rewriter/RelRewriter.h
+++ b/cpp-ch/local-engine/Rewriter/RelRewriter.h
@@ -29,7 +29,10 @@ namespace local_engine
 class RelRewriter
 {
 public:
-    RelRewriter(ParserContextPtr parser_context_) : 
parser_context(parser_context_) { }
+    RelRewriter(ParserContextPtr parser_context_)
+        : parser_context(parser_context_)
+    {
+    }
     virtual ~RelRewriter() = default;
     virtual void rewrite(substrait::Rel & rel) = 0;
 
@@ -38,5 +41,4 @@ protected:
 
     inline DB::ContextPtr getContext() const { return 
parser_context->queryContext(); }
 };
-
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to