This is an automated email from the ASF dual-hosted git repository.

liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 5f4e5585f [GLUTEN-6768][CH] Try to use multi join on clauses instead 
of inequal join condition (#6787)
5f4e5585f is described below

commit 5f4e5585f5fc7f6134b0163125ef72a8a8a7efe1
Author: lgbo <[email protected]>
AuthorDate: Fri Aug 16 11:48:14 2024 +0800

    [GLUTEN-6768][CH] Try to use multi join on clauses instead of inequal join 
condition (#6787)
    
    What changes were proposed in this pull request?
    (Please fill in changes proposed in this fix)
    
    Fixes: #6768
    
    Transform a join with inequal condition into multi join on clauses as 
possible, it could be more efficient. For example convert
    
    on t1.key = t2.key and (t1.a1 = t2.a1 or t1.a2 = t1.a2 or t1.a3 = t2.a3)
    to
    
    on (t1.key = t2.key and t1.a1 = t2.a1) or (t1.key = t2.key and t1.a2 = 
t1.a2) or (t1.key = t2.key and t1.a3 = t2.a3)
    We need to limit the right table size to avoid OOM, because we can only use 
hash join algorithm on multi join on clauses.
    
    How was this patch tested?
    (Please explain how this patch was tested. E.g. unit tests, integration 
tests, manual tests)
    
    unit tests
    
    (If this patch involves UI changes, please attach a screenshot; otherwise, 
remove this)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |   3 +-
 .../execution/CHHashJoinExecTransformer.scala      |  58 ++++
 .../scala/org/apache/gluten/utils/CHAQEUtil.scala  |  36 +--
 .../GlutenClickHouseColumnarShuffleAQESuite.scala  |  45 +++
 cpp-ch/local-engine/Common/GlutenConfig.h          |  21 ++
 .../Parser/AdvancedParametersParseUtil.cpp         |  23 ++
 .../Parser/AdvancedParametersParseUtil.h           |   5 +
 cpp-ch/local-engine/Parser/JoinRelParser.cpp       | 302 +++++++++++++++++----
 cpp-ch/local-engine/Parser/JoinRelParser.h         |  19 ++
 9 files changed, 440 insertions(+), 72 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 03e5aaa53..5a49d6ea3 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -309,7 +309,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
       condition: Option[Expression],
       left: SparkPlan,
       right: SparkPlan,
-      isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase =
+      isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase = {
     CHShuffledHashJoinExecTransformer(
       leftKeys,
       rightKeys,
@@ -319,6 +319,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
       left,
       right,
       isSkewJoin)
+  }
 
   /** Generate BroadcastHashJoinExecTransformer. */
   def genBroadcastHashJoinExecTransformer(
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 7080e55dc..adb824804 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.gluten.execution
 
+import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, 
ShuffleHashJoinStrategy}
 
@@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.{SparkPlan, 
SQLExecution}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+import com.google.protobuf.{Any, StringValue}
 import io.substrait.proto.JoinRel
 
 object JoinTypeTransform {
@@ -104,6 +106,62 @@ case class CHShuffledHashJoinExecTransformer(
   private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
   override protected lazy val substraitJoinType: JoinRel.JoinType =
     JoinTypeTransform.toSubstraitType(joinType, buildSide)
+
+  override def genJoinParameters(): Any = {
+    val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = 
(0, 0, "")
+
+    // Start with "JoinParameters:"
+    val joinParametersStr = new StringBuffer("JoinParameters:")
+    // isBHJ: 0 for SHJ, 1 for BHJ
+    // isNullAwareAntiJoin: 0 for false, 1 for true
+    // buildHashTableId: the unique id for the hash table of build plan
+    joinParametersStr
+      .append("isBHJ=")
+      .append(isBHJ)
+      .append("\n")
+      .append("isNullAwareAntiJoin=")
+      .append(isNullAwareAntiJoin)
+      .append("\n")
+      .append("buildHashTableId=")
+      .append(buildHashTableId)
+      .append("\n")
+      .append("isExistenceJoin=")
+      .append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0)
+      .append("\n")
+
+    CHAQEUtil.getShuffleQueryStageStats(streamedPlan) match {
+      case Some(stats) =>
+        joinParametersStr
+          .append("leftRowCount=")
+          .append(stats.rowCount.getOrElse(-1))
+          .append("\n")
+          .append("leftSizeInBytes=")
+          .append(stats.sizeInBytes)
+          .append("\n")
+      case _ =>
+    }
+    CHAQEUtil.getShuffleQueryStageStats(buildPlan) match {
+      case Some(stats) =>
+        joinParametersStr
+          .append("rightRowCount=")
+          .append(stats.rowCount.getOrElse(-1))
+          .append("\n")
+          .append("rightSizeInBytes=")
+          .append(stats.sizeInBytes)
+          .append("\n")
+      case _ =>
+    }
+    joinParametersStr
+      .append("numPartitions=")
+      .append(outputPartitioning.numPartitions)
+      .append("\n")
+
+    val message = StringValue
+      .newBuilder()
+      .setValue(joinParametersStr.toString)
+      .build()
+    BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
+  }
 }
 
 case class CHBroadcastBuildSideRDD(
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala
similarity index 53%
copy from cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
copy to 
backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala
index 5a15a3ea8..9a35517f5 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala
@@ -14,24 +14,26 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#pragma once
-#include <unordered_map>
+package org.apache.gluten.execution
 
-namespace local_engine
-{
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive._
 
-std::unordered_map<String, std::unordered_map<String, String>> 
convertToKVs(const String & advance);
+object CHAQEUtil {
 
-
-struct JoinOptimizationInfo
-{
-    bool is_broadcast = false;
-    bool is_smj = false;
-    bool is_null_aware_anti_join = false;
-    bool is_existence_join = false;
-    String storage_join_key;
-
-    static JoinOptimizationInfo parse(const String & advance);
-};
+  // All TransformSupports have lost the logicalLink. So we need iterate the 
plan to find the
+  // first ShuffleQueryStageExec and get the runtime stats.
+  def getShuffleQueryStageStats(plan: SparkPlan): Option[Statistics] = {
+    plan match {
+      case stage: ShuffleQueryStageExec =>
+        Some(stage.getRuntimeStatistics)
+      case _ =>
+        if (plan.children.length == 1) {
+          getShuffleQueryStageStats(plan.children.head)
+        } else {
+          None
+        }
+    }
+  }
 }
-
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
index fc22add2d..10e5c7534 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
@@ -32,6 +32,7 @@ class GlutenClickHouseColumnarShuffleAQESuite
   override protected val tablesPath: String = basePath + "/tpch-data-ch"
   override protected val tpchQueries: String = rootPath + 
"queries/tpch-queries-ch"
   override protected val queriesResults: String = rootPath + 
"mergetree-queries-output"
+  private val backendConfigPrefix = "spark.gluten.sql.columnar.backend.ch."
 
   /** Run Gluten + ClickHouse Backend with ColumnarShuffleManager */
   override protected def sparkConf: SparkConf = {
@@ -261,4 +262,48 @@ class GlutenClickHouseColumnarShuffleAQESuite
       spark.sql("drop table t2")
     }
   }
+
+  test("GLUTEN-6768 change mixed join condition into multi join on clauses") {
+    withSQLConf(
+      (backendConfigPrefix + "runtime_config.prefer_multi_join_on_clauses", 
"true"),
+      (backendConfigPrefix + 
"runtime_config.multi_join_on_clauses_build_side_row_limit", "1000000")
+    ) {
+
+      spark.sql("create table t1(a int, b int, c int, d int) using parquet")
+      spark.sql("create table t2(a int, b int, c int, d int) using parquet")
+
+      spark.sql("""
+                  |insert into t1
+                  |select id % 2 as a, id as b, id + 1 as c, id + 2 as d from 
range(1000)
+                  |""".stripMargin)
+      spark.sql("""
+                  |insert into t2
+                  |select id % 2 as a, id as b, id + 1 as c, id + 2 as d from 
range(1000)
+                  |""".stripMargin)
+
+      var sql = """
+                  |select * from t1 join t2 on
+                  |t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or t1.d = t2.d)
+                  |order by t1.a, t1.b, t1.c, t1.d
+                  |""".stripMargin
+      compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+      sql = """
+              |select * from t1 join t2 on
+              |t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or (t1.c = t2.c and 
t1.d = t2.d))
+              |order by t1.a, t1.b, t1.c, t1.d
+              |""".stripMargin
+      compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+      sql = """
+              |select * from t1 join t2 on
+              |t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or (t1.d = t2.d and 
t1.c >= t2.c))
+              |order by t1.a, t1.b, t1.c, t1.d
+              |""".stripMargin
+      compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+      spark.sql("drop table t1")
+      spark.sql("drop table t2")
+    }
+  }
 }
diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h 
b/cpp-ch/local-engine/Common/GlutenConfig.h
index abb7295ad..84744dab2 100644
--- a/cpp-ch/local-engine/Common/GlutenConfig.h
+++ b/cpp-ch/local-engine/Common/GlutenConfig.h
@@ -92,6 +92,27 @@ struct StreamingAggregateConfig
     }
 };
 
+struct JoinConfig
+{
+    /// If the join condition is like `t1.k = t2.k and (t1.id1 = t2.id2 or 
t1.id2 = t2.id2)`, try to join with multi
+    /// join on clauses `(t1.k = t2.k and t1.id1 = t2.id2) or (t1.k = t2.k or 
t1.id2 = t2.id2)`
+    inline static const String PREFER_MULTI_JOIN_ON_CLAUSES = 
"prefer_multi_join_on_clauses";
+    /// Only hash join supports multi join on clauses, the right table cannot 
be too large. If the row number of right
+    /// table is larger then this limit, this transform will not work.
+    inline static const String MULTI_JOIN_ON_CLAUSES_BUILD_SIDE_ROWS_LIMIT = 
"multi_join_on_clauses_build_side_row_limit";
+
+    bool prefer_multi_join_on_clauses = true;
+    size_t multi_join_on_clauses_build_side_rows_limit = 10000000;
+
+    static JoinConfig loadFromContext(DB::ContextPtr context)
+    {
+        JoinConfig config;
+        config.prefer_multi_join_on_clauses = 
context->getConfigRef().getBool(PREFER_MULTI_JOIN_ON_CLAUSES, true);
+        config.multi_join_on_clauses_build_side_rows_limit = 
context->getConfigRef().getUInt64(MULTI_JOIN_ON_CLAUSES_BUILD_SIDE_ROWS_LIMIT, 
10000000);
+        return config;
+    }
+};
+
 struct ExecutorConfig
 {
     inline static const String DUMP_PIPELINE = "dump_pipeline";
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp 
b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
index a7a07c0bf..42d4f4d4d 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
@@ -57,6 +57,24 @@ void tryAssign<bool>(const std::unordered_map<String, 
String> & kvs, const Strin
     }
 }
 
+template<>
+void tryAssign<Int64>(const std::unordered_map<String, String> & kvs, const 
String & key, Int64 & v)
+{
+    auto it = kvs.find(key);
+    if (it != kvs.end())
+    {
+        try
+        {
+            v = std::stol(it->second);
+        }
+        catch (...)
+        {
+            LOG_ERROR(getLogger("tryAssign"), "Invalid number: {}", 
it->second);
+            throw;
+        }
+    }
+}
+
 template <char... chars>
 void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf)
 {
@@ -121,6 +139,11 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const 
String & advance)
     tryAssign(kvs, "buildHashTableId", info.storage_join_key);
     tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join);
     tryAssign(kvs, "isExistenceJoin", info.is_existence_join);
+    tryAssign(kvs, "leftRowCount", info.left_table_rows);
+    tryAssign(kvs, "leftSizeInBytes", info.left_table_bytes);
+    tryAssign(kvs, "rightRowCount", info.right_table_rows);
+    tryAssign(kvs, "rightSizeInBytes", info.right_table_bytes);
+    tryAssign(kvs, "numPartitions", info.partitions_num);
     return info;
 }
 }
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h 
b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
index 5a15a3ea8..5f6fe6d25 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
@@ -29,6 +29,11 @@ struct JoinOptimizationInfo
     bool is_smj = false;
     bool is_null_aware_anti_join = false;
     bool is_existence_join = false;
+    Int64 left_table_rows = -1;
+    Int64 left_table_bytes = -1;
+    Int64 right_table_rows = -1;
+    Int64 right_table_bytes = -1;
+    Int64 partitions_num = -1;
     String storage_join_key;
 
     static JoinOptimizationInfo parse(const String & advance);
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp 
b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
index b217a9bd9..ef19e007d 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
@@ -16,6 +16,8 @@
  */
 #include "JoinRelParser.h"
 
+#include <Core/Block.h>
+#include <Functions/FunctionFactory.h>
 #include <IO/ReadBufferFromString.h>
 #include <IO/ReadHelpers.h>
 #include <Interpreters/CollectJoinOnKeysVisitor.h>
@@ -25,15 +27,15 @@
 #include <Interpreters/TableJoin.h>
 #include <Join/BroadCastJoinBuilder.h>
 #include <Join/StorageJoinFromReadBuffer.h>
-#include <Parser/SerializedPlanParser.h>
 #include <Parser/AdvancedParametersParseUtil.h>
+#include <Parser/SerializedPlanParser.h>
 #include <Parsers/ASTIdentifier.h>
 #include <Processors/QueryPlan/ExpressionStep.h>
 #include <Processors/QueryPlan/FilterStep.h>
 #include <Processors/QueryPlan/JoinStep.h>
 #include <google/protobuf/wrappers.pb.h>
 #include <Common/CHUtil.h>
-#include <Functions/FunctionFactory.h>
+#include <Common/GlutenConfig.h>
 
 #include <Common/logger_useful.h>
 
@@ -42,9 +44,9 @@ namespace DB
 {
 namespace ErrorCodes
 {
-    extern const int LOGICAL_ERROR;
-    extern const int UNKNOWN_TYPE;
-    extern const int BAD_ARGUMENTS;
+extern const int LOGICAL_ERROR;
+extern const int UNKNOWN_TYPE;
+extern const int BAD_ARGUMENTS;
 }
 }
 using namespace DB;
@@ -98,7 +100,8 @@ DB::QueryPlanPtr JoinRelParser::parseOp(const substrait::Rel 
& rel, std::list<co
     return parseJoin(join, std::move(left_plan), std::move(right_plan));
 }
 
-std::unordered_set<DB::JoinTableSide> 
JoinRelParser::extractTableSidesFromExpression(const substrait::Expression & 
expr, const DB::Block & left_header, const DB::Block & right_header)
+std::unordered_set<DB::JoinTableSide> 
JoinRelParser::extractTableSidesFromExpression(
+    const substrait::Expression & expr, const DB::Block & left_header, const 
DB::Block & right_header)
 {
     std::unordered_set<DB::JoinTableSide> table_sides;
     if (expr.has_scalar_function())
@@ -169,8 +172,7 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, 
DB::QueryPlan & righ
         storage_join.getRightSampleBlock().getColumnsWithTypeAndName(),
         ActionsDAG::MatchColumnsMode::Position);
 
-    QueryPlanStepPtr right_project_step =
-        std::make_unique<ExpressionStep>(right.getCurrentDataStream(), 
std::move(right_project));
+    QueryPlanStepPtr right_project_step = 
std::make_unique<ExpressionStep>(right.getCurrentDataStream(), 
std::move(right_project));
     right_project_step->setStepDescription("Rename Broadcast Table Name");
     steps.emplace_back(right_project_step.get());
     right.addStep(std::move(right_project_step));
@@ -193,12 +195,9 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & 
left, DB::QueryPlan & righ
         }
     }
     ActionsDAG left_project = ActionsDAG::makeConvertingActions(
-        left.getCurrentDataStream().header.getColumnsWithTypeAndName(),
-        new_left_cols,
-        ActionsDAG::MatchColumnsMode::Position);
+        left.getCurrentDataStream().header.getColumnsWithTypeAndName(), 
new_left_cols, ActionsDAG::MatchColumnsMode::Position);
 
-    QueryPlanStepPtr left_project_step =
-        std::make_unique<ExpressionStep>(left.getCurrentDataStream(), 
std::move(left_project));
+    QueryPlanStepPtr left_project_step = 
std::make_unique<ExpressionStep>(left.getCurrentDataStream(), 
std::move(left_project));
     left_project_step->setStepDescription("Rename Left Table Name for 
broadcast join");
     steps.emplace_back(left_project_step.get());
     left.addStep(std::move(left_project_step));
@@ -206,9 +205,11 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & 
left, DB::QueryPlan & righ
 
 DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, 
DB::QueryPlanPtr left, DB::QueryPlanPtr right)
 {
+    auto join_config = JoinConfig::loadFromContext(getContext());
     google::protobuf::StringValue optimization_info;
     
optimization_info.ParseFromString(join.advanced_extension().optimization().value());
     auto join_opt_info = 
JoinOptimizationInfo::parse(optimization_info.value());
+    LOG_ERROR(getLogger("JoinRelParser"), "optimizaiton info:{}", 
optimization_info.value());
     auto storage_join = join_opt_info.is_broadcast ? 
BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr;
     if (storage_join)
     {
@@ -239,7 +240,9 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
         }
         if (is_col_names_changed)
         {
-            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast 
join, we must not change the columns name in the right table.\nleft 
header:{},\nright header: {} -> {}",
+            throw DB::Exception(
+                DB::ErrorCodes::LOGICAL_ERROR,
+                "For broadcast join, we must not change the columns name in 
the right table.\nleft header:{},\nright header: {} -> {}",
                 left->getCurrentDataStream().header.dumpStructure(),
                 right_header_before_convert_step.dumpStructure(),
                 right->getCurrentDataStream().header.dumpStructure());
@@ -266,7 +269,6 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
 
     if (storage_join)
     {
-
         applyJoinFilter(*table_join, join, *left, *right, true);
         auto broadcast_hash_join = storage_join->getJoinLocked(table_join, 
context);
 
@@ -288,15 +290,13 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
         /// TODO: make smj support mixed conditions
         if (need_post_filter && table_join->kind() != DB::JoinKind::Inner)
         {
-            throw DB::Exception(
-                DB::ErrorCodes::LOGICAL_ERROR,
-                "Sort merge join doesn't support mixed join conditions, except 
inner join.");
+            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Sort merge 
join doesn't support mixed join conditions, except inner join.");
         }
 
         JoinPtr smj_join = std::make_shared<FullSortingMergeJoin>(table_join, 
right->getCurrentDataStream().header.cloneEmpty(), -1);
         MultiEnum<DB::JoinAlgorithm> join_algorithm = 
context->getSettingsRef().join_algorithm;
         QueryPlanStepPtr join_step
-                 = 
std::make_unique<DB::JoinStep>(left->getCurrentDataStream(), 
right->getCurrentDataStream(), smj_join, 8192, 1, false);
+            = std::make_unique<DB::JoinStep>(left->getCurrentDataStream(), 
right->getCurrentDataStream(), smj_join, 8192, 1, false);
 
         join_step->setStepDescription("SORT_MERGE_JOIN");
         steps.emplace_back(join_step.get());
@@ -311,41 +311,22 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
     }
     else
     {
-        applyJoinFilter(*table_join, join, *left, *right, true);
-
-        /// Following is some configurations for grace hash join.
-        /// - 
spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash.
 This will
-        ///   enable grace hash join.
-        /// - 
spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728.
 This setup
-        ///   the memory limitation fro grace hash join. If the memory 
consumption exceeds the limitation,
-        ///   data will be spilled to disk. Don't set the limitation too 
small, otherwise the buckets number
-        ///   will be too large and the performance will be bad.
-        JoinPtr hash_join = nullptr;
-        MultiEnum<DB::JoinAlgorithm> join_algorithm = 
context->getSettingsRef().join_algorithm;
-        if (join_algorithm.isSet(DB::JoinAlgorithm::GRACE_HASH))
+        std::vector<DB::TableJoin::JoinOnClause> join_on_clauses;
+        if (table_join->getClauses().empty())
+            table_join->addDisjunct();
+        bool is_multi_join_on_clauses
+            = couldRewriteToMultiJoinOnClauses(table_join->getOnlyClause(), 
join_on_clauses, join, left_header, right_header);
+        if (is_multi_join_on_clauses && 
join_config.prefer_multi_join_on_clauses && join_opt_info.right_table_rows > 0
+            && join_opt_info.partitions_num > 0
+            && join_opt_info.right_table_rows / join_opt_info.partitions_num
+                < join_config.multi_join_on_clauses_build_side_rows_limit)
         {
-            hash_join = std::make_shared<GraceHashJoin>(
-                context,
-                table_join,
-                left->getCurrentDataStream().header,
-                right->getCurrentDataStream().header,
-                context->getTempDataOnDisk());
+            query_plan = buildMultiOnClauseHashJoin(table_join, 
std::move(left), std::move(right), join_on_clauses);
         }
         else
         {
-            hash_join = std::make_shared<HashJoin>(table_join, 
right->getCurrentDataStream().header.cloneEmpty());
+            query_plan = buildSingleOnClauseHashJoin(join, table_join, 
std::move(left), std::move(right));
         }
-        QueryPlanStepPtr join_step
-            = std::make_unique<DB::JoinStep>(left->getCurrentDataStream(), 
right->getCurrentDataStream(), hash_join, 8192, 1, false);
-
-        join_step->setStepDescription("HASH_JOIN");
-        steps.emplace_back(join_step.get());
-        std::vector<QueryPlanPtr> plans;
-        plans.emplace_back(std::move(left));
-        plans.emplace_back(std::move(right));
-
-        query_plan = std::make_unique<QueryPlan>();
-        query_plan->unitePlans(std::move(join_step), {std::move(plans)});
     }
 
     JoinUtil::reorderJoinOutput(*query_plan, after_join_names);
@@ -508,7 +489,11 @@ void JoinRelParser::collectJoinKeys(
 }
 
 bool JoinRelParser::applyJoinFilter(
-    DB::TableJoin & table_join, const substrait::JoinRel & join_rel, 
DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition)
+    DB::TableJoin & table_join,
+    const substrait::JoinRel & join_rel,
+    DB::QueryPlan & left,
+    DB::QueryPlan & right,
+    bool allow_mixed_condition)
 {
     if (!join_rel.has_post_join_filter())
         return true;
@@ -594,12 +579,13 @@ bool JoinRelParser::applyJoinFilter(
             return false;
         auto mixed_join_expressions_actions = expressionsToActionsDAG({expr}, 
mixed_header);
         mixed_join_expressions_actions.removeUnusedActions();
-        table_join.getMixedJoinExpression()
-            = 
std::make_shared<DB::ExpressionActions>(std::move(mixed_join_expressions_actions),
 ExpressionActionsSettings::fromContext(context));
+        table_join.getMixedJoinExpression() = 
std::make_shared<DB::ExpressionActions>(
+            std::move(mixed_join_expressions_actions), 
ExpressionActionsSettings::fromContext(context));
     }
     else
     {
-        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not any table 
column is used in the join condition.\n{}", join_rel.DebugString());
+        throw DB::Exception(
+            DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in 
the join condition.\n{}", join_rel.DebugString());
     }
     return true;
 }
@@ -610,7 +596,7 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & 
query_plan, const substrait::J
     ActionsDAG 
actions_dag{query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName()};
     if (!join.post_join_filter().has_scalar_function())
     {
-       // It may be singular_or_list
+        // It may be singular_or_list
         auto * in_node = getPlanParser()->parseExpression(actions_dag, 
join.post_join_filter());
         filter_name = in_node->result_name;
     }
@@ -624,6 +610,214 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & 
query_plan, const substrait::J
     query_plan.addStep(std::move(filter_step));
 }
 
+/// Only support following pattern: a1 = b1 or a2 = b2 or (a3 = b3 and a4 = b4)
+bool JoinRelParser::couldRewriteToMultiJoinOnClauses(
+    const DB::TableJoin::JoinOnClause & prefix_clause,
+    std::vector<DB::TableJoin::JoinOnClause> & clauses,
+    const substrait::JoinRel & join_rel,
+    const DB::Block & left_header,
+    const DB::Block & right_header)
+{
+    /// There is only one join clause
+    if (!join_rel.has_post_join_filter())
+        return false;
+
+    const auto & filter_expr = join_rel.post_join_filter();
+    std::list<const substrait::Expression *> expression_stack;
+    expression_stack.push_back(&filter_expr);
+
+    auto check_function = [&](const String function_name_, const 
substrait::Expression & e)
+    {
+        if (!e.has_scalar_function())
+        {
+            return false;
+        }
+        auto function_name = 
parseFunctionName(e.scalar_function().function_reference(), 
e.scalar_function());
+        return function_name.has_value() && *function_name == function_name_;
+    };
+
+    auto get_field_ref = [](const substrait::Expression & e) -> 
std::optional<Int32>
+    {
+        if (e.has_selection() && e.selection().has_direct_reference() && 
e.selection().direct_reference().has_struct_field())
+        {
+            return 
std::optional<Int32>(e.selection().direct_reference().struct_field().field());
+        }
+        return {};
+    };
+
+    auto parse_join_keys = [&](const substrait::Expression & e) -> 
std::optional<std::pair<String, String>>
+    {
+        const auto & args = e.scalar_function().arguments();
+        auto l_field_ref = get_field_ref(args[0].value());
+        auto r_field_ref = get_field_ref(args[1].value());
+        if (!l_field_ref.has_value() || !r_field_ref.has_value())
+            return {};
+        size_t l_pos = static_cast<size_t>(*l_field_ref);
+        size_t r_pos = static_cast<size_t>(*r_field_ref);
+        size_t l_cols = left_header.columns();
+        size_t total_cols = l_cols + right_header.columns();
+
+        if (l_pos < l_cols && r_pos >= l_cols && r_pos < total_cols)
+            return std::make_pair(left_header.getByPosition(l_pos).name, 
right_header.getByPosition(r_pos - l_cols).name);
+        else if (r_pos < l_cols && l_pos >= l_cols && l_pos < total_cols)
+            return std::make_pair(left_header.getByPosition(r_pos).name, 
right_header.getByPosition(l_pos - l_cols).name);
+        return {};
+    };
+
+    auto parse_and_expression = [&](const substrait::Expression & e, 
DB::TableJoin::JoinOnClause & join_on_clause)
+    {
+        std::vector<const substrait::Expression *> and_expression_stack;
+        and_expression_stack.push_back(&e);
+        while (!and_expression_stack.empty())
+        {
+            const auto & current_expr = *(and_expression_stack.back());
+            and_expression_stack.pop_back();
+            if (check_function("and", current_expr))
+            {
+                for (const auto & arg : e.scalar_function().arguments())
+                    and_expression_stack.push_back(&arg.value());
+            }
+            else if (check_function("equals", current_expr))
+            {
+                auto optional_keys = parse_join_keys(current_expr);
+                if (!optional_keys)
+                {
+                    LOG_ERROR(getLogger("JoinRelParser"), "Not equal 
comparison for keys from both tables");
+                    return false;
+                }
+                join_on_clause.addKey(optional_keys->first, 
optional_keys->second, false);
+            }
+            else
+            {
+                LOG_ERROR(getLogger("JoinRelParser"), "And or equals function 
is expected");
+                return false;
+            }
+        }
+        return true;
+    };
+
+    while (!expression_stack.empty())
+    {
+        const auto & current_expr = *(expression_stack.back());
+        expression_stack.pop_back();
+        if (!check_function("or", current_expr))
+        {
+            LOG_ERROR(getLogger("JoinRelParser"), "Not an or expression");
+        }
+
+        auto get_current_join_on_clause = [&]()
+        {
+            DB::TableJoin::JoinOnClause new_clause = prefix_clause;
+            clauses.push_back(new_clause);
+            return &clauses.back();
+        };
+
+        const auto & args = current_expr.scalar_function().arguments();
+        for (const auto & arg : args)
+        {
+            if (check_function("equals", arg.value()))
+            {
+                auto optional_keys = parse_join_keys(arg.value());
+                if (!optional_keys)
+                {
+                    LOG_ERROR(getLogger("JoinRelParser"), "Not equal 
comparison for keys from both tables");
+                    return false;
+                }
+                get_current_join_on_clause()->addKey(optional_keys->first, 
optional_keys->second, false);
+            }
+            else if (check_function("and", arg.value()))
+            {
+                if (!parse_and_expression(arg.value(), 
*get_current_join_on_clause()))
+                {
+                    LOG_ERROR(getLogger("JoinRelParser"), "Parse and 
expression failed");
+                    return false;
+                }
+            }
+            else if (check_function("or", arg.value()))
+            {
+                expression_stack.push_back(&arg.value());
+            }
+            else
+            {
+                LOG_ERROR(getLogger("JoinRelParser"), "Unknow function");
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
+
+DB::QueryPlanPtr JoinRelParser::buildMultiOnClauseHashJoin(
+    std::shared_ptr<DB::TableJoin> table_join,
+    DB::QueryPlanPtr left_plan,
+    DB::QueryPlanPtr right_plan,
+    const std::vector<DB::TableJoin::JoinOnClause> & join_on_clauses)
+{
+    DB::TableJoin::JoinOnClause & base_join_on_clause = 
table_join->getOnlyClause();
+    base_join_on_clause = join_on_clauses[0];
+    for (size_t i = 1; i < join_on_clauses.size(); ++i)
+    {
+        table_join->addDisjunct();
+        auto & join_on_clause = table_join->getClauses().back();
+        join_on_clause = join_on_clauses[i];
+    }
+
+    LOG_INFO(getLogger("JoinRelParser"), "multi join on clauses:\n{}", 
DB::TableJoin::formatClauses(table_join->getClauses()));
+
+    JoinPtr hash_join = std::make_shared<HashJoin>(table_join, 
right_plan->getCurrentDataStream().header);
+    QueryPlanStepPtr join_step
+        = std::make_unique<DB::JoinStep>(left_plan->getCurrentDataStream(), 
right_plan->getCurrentDataStream(), hash_join, 8192, 1, false);
+    join_step->setStepDescription("Multi join on clause hash join");
+    steps.emplace_back(join_step.get());
+    std::vector<QueryPlanPtr> plans;
+    plans.emplace_back(std::move(left_plan));
+    plans.emplace_back(std::move(right_plan));
+    auto query_plan = std::make_unique<QueryPlan>();
+    query_plan->unitePlans(std::move(join_step), {std::move(plans)});
+    return query_plan;
+}
+
+DB::QueryPlanPtr JoinRelParser::buildSingleOnClauseHashJoin(
+    const substrait::JoinRel & join_rel, std::shared_ptr<DB::TableJoin> 
table_join, DB::QueryPlanPtr left_plan, DB::QueryPlanPtr right_plan)
+{
+    applyJoinFilter(*table_join, join_rel, *left_plan, *right_plan, true);
+    /// Following is some configurations for grace hash join.
+    /// - 
spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash.
 This will
+    ///   enable grace hash join.
+    /// - 
spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728.
 This setup
+    ///   the memory limitation fro grace hash join. If the memory consumption 
exceeds the limitation,
+    ///   data will be spilled to disk. Don't set the limitation too small, 
otherwise the buckets number
+    ///   will be too large and the performance will be bad.
+    JoinPtr hash_join = nullptr;
+    MultiEnum<DB::JoinAlgorithm> join_algorithm = 
context->getSettingsRef().join_algorithm;
+    if (join_algorithm.isSet(DB::JoinAlgorithm::GRACE_HASH))
+    {
+        hash_join = std::make_shared<GraceHashJoin>(
+            context,
+            table_join,
+            left_plan->getCurrentDataStream().header,
+            right_plan->getCurrentDataStream().header,
+            context->getTempDataOnDisk());
+    }
+    else
+    {
+        hash_join = std::make_shared<HashJoin>(table_join, 
right_plan->getCurrentDataStream().header.cloneEmpty());
+    }
+    QueryPlanStepPtr join_step
+        = std::make_unique<DB::JoinStep>(left_plan->getCurrentDataStream(), 
right_plan->getCurrentDataStream(), hash_join, 8192, 1, false);
+
+    join_step->setStepDescription("HASH_JOIN");
+    steps.emplace_back(join_step.get());
+    std::vector<QueryPlanPtr> plans;
+    plans.emplace_back(std::move(left_plan));
+    plans.emplace_back(std::move(right_plan));
+
+    auto query_plan = std::make_unique<QueryPlan>();
+    query_plan->unitePlans(std::move(join_step), {std::move(plans)});
+    return query_plan;
+}
+
 void registerJoinRelParser(RelParserFactory & factory)
 {
     auto builder = [](SerializedPlanParser * plan_paser) { return 
std::make_shared<JoinRelParser>(plan_paser); };
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h 
b/cpp-ch/local-engine/Parser/JoinRelParser.h
index ee1155cb4..7e43187be 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.h
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.h
@@ -19,6 +19,7 @@
 #include <memory>
 #include <unordered_set>
 #include <Core/Joins.h>
+#include <Interpreters/TableJoin.h>
 #include <Parser/RelParser.h>
 #include <substrait/algebra.pb.h>
 
@@ -70,6 +71,24 @@ private:
 
     static std::unordered_set<DB::JoinTableSide> 
extractTableSidesFromExpression(
         const substrait::Expression & expr, const DB::Block & left_header, 
const DB::Block & right_header);
+
+    bool couldRewriteToMultiJoinOnClauses(
+        const DB::TableJoin::JoinOnClause & prefix_clause,
+        std::vector<DB::TableJoin::JoinOnClause> & clauses,
+        const substrait::JoinRel & join_rel,
+        const DB::Block & left_header,
+        const DB::Block & right_header);
+
+    DB::QueryPlanPtr buildMultiOnClauseHashJoin(
+        std::shared_ptr<DB::TableJoin> table_join,
+        DB::QueryPlanPtr left_plan,
+        DB::QueryPlanPtr right_plan,
+        const std::vector<DB::TableJoin::JoinOnClause> & join_on_clauses);
+    DB::QueryPlanPtr buildSingleOnClauseHashJoin(
+        const substrait::JoinRel & join_rel,
+        std::shared_ptr<DB::TableJoin> table_join,
+        DB::QueryPlanPtr left_plan,
+        DB::QueryPlanPtr right_plan);
 };
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to