This is an automated email from the ASF dual-hosted git repository.

zhanglistar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new a76c92e82 [GLUTEN-5668][CH] Support mixed conditions in shuffle hash join (#5735)
a76c92e82 is described below

commit a76c92e82f75250cd834f51ee88f82a7664c6562
Author: lgbo <[email protected]>
AuthorDate: Mon Jun 3 16:09:41 2024 +0800

    [GLUTEN-5668][CH] Support mixed conditions in shuffle hash join (#5735)
    
    * support non-equi join conditions
    
    * fixed bugs in CI
    
    * fixed performance issue in broadcast join
    
    * broadcast join build changed
---
 .../clickhouse/CHSparkPlanExecApi.scala            |  14 +-
 .../execution/CHHashJoinExecTransformer.scala      |  14 +-
 .../execution/CHSortMergeJoinExecTransformer.scala |   8 +-
 .../apache/gluten/utils/CHJoinValidateUtil.scala   |  94 ++---
 ...nClickHouseTPCDSParquetGraceHashJoinSuite.scala | 153 +-------
 .../GlutenClickHouseTPCDSParquetSuite.scala        |  69 +---
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala |  18 +
 cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp  |   3 +-
 .../Join/StorageJoinFromReadBuffer.cpp             | 129 ++++--
 .../local-engine/Join/StorageJoinFromReadBuffer.h  |  29 +-
 cpp-ch/local-engine/Parser/JoinRelParser.cpp       | 434 +++++++++++++++------
 cpp-ch/local-engine/Parser/JoinRelParser.h         |  24 +-
 .../local-engine/Parser/SerializedPlanParser.cpp   |   2 +-
 13 files changed, 522 insertions(+), 469 deletions(-)
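
For illustration, a "mixed" join condition is an ON clause that combines an
equality key with a predicate referencing columns from both sides. A minimal
sketch of the kind of query this commit targets (table names and the
broadcast-threshold setting mirror the new "Inequal join support" test added
below; whether the planner actually picks a shuffled hash join still depends on
the session configuration):

    // Disable broadcast joins; Gluten then prefers a shuffled hash join over SMJ.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
    val df = spark.sql(
      """
        | select t1.key, t1.value, t2.key, t2.value
        | from ineq_join_t1 as t1
        | left join ineq_join_t2 as t2
        |   on t1.key = t2.key and t1.value > t2.value
        |""".stripMargin)
    // Before this change such a LEFT join fell back to vanilla Spark; now the
    // equality part becomes the hash join key and 't1.value > t2.value' is
    // evaluated as a mixed/post-join filter.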

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 1403c8261..bdbdfed0d 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -27,7 +27,7 @@ import 
org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverri
 import org.apache.gluten.extension.columnar.transition.Convention
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.substrait.expression.{ExpressionBuilder, 
ExpressionNode, WindowFunctionNode}
-import org.apache.gluten.utils.CHJoinValidateUtil
+import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy}
 import org.apache.gluten.vectorized.CHColumnarBatchSerializer
 
 import org.apache.spark.{ShuffleDependency, SparkException}
@@ -694,15 +694,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
   }
 
   /**
-   * Define whether the join operator is fallback because of the join operator is not supported by
-   * backend
+   * At present this is only used to control whether to transform a sort merge join (SMJ) into a
+   * shuffled hash join (SHJ). We always prefer SHJ.
    */
   override def joinFallback(
-      JoinType: JoinType,
+      joinType: JoinType,
       leftOutputSet: AttributeSet,
       rightOutputSet: AttributeSet,
       condition: Option[Expression]): Boolean = {
-    CHJoinValidateUtil.shouldFallback(JoinType, leftOutputSet, rightOutputSet, 
condition)
+    CHJoinValidateUtil.shouldFallback(
+      UnknownJoinStrategy(joinType),
+      leftOutputSet,
+      rightOutputSet,
+      condition)
   }
 
   /** Generate window function node */
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index c3ab89df5..6004f7f86 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -18,7 +18,7 @@ package org.apache.gluten.execution
 
 import org.apache.gluten.backendsapi.clickhouse.CHIteratorApi
 import org.apache.gluten.extension.ValidationResult
-import org.apache.gluten.utils.CHJoinValidateUtil
+import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, 
ShuffleHashJoinStrategy}
 
 import org.apache.spark.{broadcast, SparkContext}
 import org.apache.spark.rdd.RDD
@@ -55,7 +55,11 @@ case class CHShuffledHashJoinExecTransformer(
 
   override protected def doValidateInternal(): ValidationResult = {
     val shouldFallback =
-      CHJoinValidateUtil.shouldFallback(joinType, left.outputSet, 
right.outputSet, condition)
+      CHJoinValidateUtil.shouldFallback(
+        ShuffleHashJoinStrategy(joinType),
+        left.outputSet,
+        right.outputSet,
+        condition)
     if (shouldFallback) {
       return ValidationResult.notOk("ch join validate fail")
     }
@@ -107,7 +111,11 @@ case class CHBroadcastHashJoinExecTransformer(
 
   override protected def doValidateInternal(): ValidationResult = {
     val shouldFallback =
-      CHJoinValidateUtil.shouldFallback(joinType, left.outputSet, 
right.outputSet, condition)
+      CHJoinValidateUtil.shouldFallback(
+        BroadcastHashJoinStrategy(joinType),
+        left.outputSet,
+        right.outputSet,
+        condition)
 
     if (shouldFallback) {
       return ValidationResult.notOk("ch join validate fail")
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
index a5ac5f658..e2b586551 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten.execution
 
 import org.apache.gluten.extension.ValidationResult
-import org.apache.gluten.utils.CHJoinValidateUtil
+import org.apache.gluten.utils.{CHJoinValidateUtil, SortMergeJoinStrategy}
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
@@ -44,7 +44,11 @@ case class CHSortMergeJoinExecTransformer(
 
   override protected def doValidateInternal(): ValidationResult = {
     val shouldFallback =
-      CHJoinValidateUtil.shouldFallback(joinType, left.outputSet, 
right.outputSet, condition, true)
+      CHJoinValidateUtil.shouldFallback(
+        SortMergeJoinStrategy(joinType),
+        left.outputSet,
+        right.outputSet,
+        condition)
     if (shouldFallback) {
       return ValidationResult.notOk("ch join validate fail")
     }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
index 06b2445af..dae8e6e07 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
@@ -17,9 +17,17 @@
 package org.apache.gluten.utils
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.{AttributeSet, EqualTo, 
Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, 
Not, Or}
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
 import org.apache.spark.sql.catalyst.plans.JoinType
 
+trait JoinStrategy {
+  val joinType: JoinType
+}
+case class UnknownJoinStrategy(joinType: JoinType) extends JoinStrategy {}
+case class ShuffleHashJoinStrategy(joinType: JoinType) extends JoinStrategy {}
+case class BroadcastHashJoinStrategy(joinType: JoinType) extends JoinStrategy {}
+case class SortMergeJoinStrategy(joinType: JoinType) extends JoinStrategy {}
+
 /**
  * The logic here is that if it is not an equi-join spark will create BNLJ, which will fallback, if
  * it is an equi-join, spark will create BroadcastHashJoin or ShuffleHashJoin, for these join types,
@@ -34,78 +42,40 @@ object CHJoinValidateUtil extends Logging {
   def hasTwoTableColumn(
       leftOutputSet: AttributeSet,
       rightOutputSet: AttributeSet,
-      l: Expression,
-      r: Expression): Boolean = {
-    val allReferences = l.references ++ r.references
+      expr: Expression): Boolean = {
+    val allReferences = expr.references
     !(allReferences.subsetOf(leftOutputSet) || 
allReferences.subsetOf(rightOutputSet))
   }
 
   def shouldFallback(
-      joinType: JoinType,
+      joinStrategy: JoinStrategy,
       leftOutputSet: AttributeSet,
       rightOutputSet: AttributeSet,
-      condition: Option[Expression],
-      isSMJ: Boolean = false): Boolean = {
+      condition: Option[Expression]): Boolean = {
     var shouldFallback = false
+    val joinType = joinStrategy.joinType
     if (joinType.toString.contains("ExistenceJoin")) {
       return true
     }
-    if (joinType.sql.equals("INNER")) {
-      return shouldFallback
-    }
-    if (isSMJ) {
-      if (
-        joinType.sql.contains("SEMI")
-        || joinType.sql.contains("ANTI")
-      ) {
-        return true
+    if (joinType.sql.contains("INNER")) {
+      shouldFallback = false;
+    } else if (
+      condition.isDefined && hasTwoTableColumn(leftOutputSet, rightOutputSet, 
condition.get)
+    ) {
+      shouldFallback = joinStrategy match {
+        case BroadcastHashJoinStrategy(joinTy) =>
+          joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
+        case SortMergeJoinStrategy(_) => true
+        case ShuffleHashJoinStrategy(joinTy) =>
+          joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
+        case UnknownJoinStrategy(joinTy) =>
+          joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
       }
-    }
-    if (condition.isDefined) {
-      condition.get.transform {
-        case Or(l, r) =>
-          if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
-            shouldFallback = true
-          }
-          Or(l, r)
-        case Not(EqualTo(l, r)) =>
-          if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
-            shouldFallback = true
-          }
-          Not(EqualTo(l, r))
-        case LessThan(l, r) =>
-          if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
-            shouldFallback = true
-          }
-          LessThan(l, r)
-        case LessThanOrEqual(l, r) =>
-          if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
-            shouldFallback = true
-          }
-          LessThanOrEqual(l, r)
-        case GreaterThan(l, r) =>
-          if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
-            shouldFallback = true
-          }
-          GreaterThan(l, r)
-        case GreaterThanOrEqual(l, r) =>
-          if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
-            shouldFallback = true
-          }
-          GreaterThanOrEqual(l, r)
-        case In(l, r) =>
-          r.foreach(
-            e => {
-              if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, e)) {
-                shouldFallback = true
-              }
-            })
-          In(l, r)
-        case EqualTo(l, r) =>
-          if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
-            shouldFallback = true
-          }
-          EqualTo(l, r)
+    } else {
+      shouldFallback = joinStrategy match {
+        case SortMergeJoinStrategy(joinTy) =>
+          joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
+        case _ => false
       }
     }
     shouldFallback
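
To summarize the new validation rules above: ExistenceJoin always falls back;
INNER never falls back; with a condition that references columns from both
tables ("mixed"), sort merge join falls back while shuffled/broadcast hash join
only falls back for SEMI/ANTI; without such a condition only SEMI/ANTI sort
merge joins fall back. A condensed, self-contained sketch of that decision
(simplified stand-in types, not the real Spark/Gluten classes):

    sealed trait Strategy
    case object ShuffleHash extends Strategy
    case object BroadcastHash extends Strategy
    case object SortMerge extends Strategy

    // isSemiOrAnti stands in for joinType.sql.contains("SEMI"/"ANTI"),
    // hasMixedCondition for hasTwoTableColumn(...) on the join condition.
    // (ExistenceJoin always falls back and is omitted from this sketch.)
    def shouldFallback(
        strategy: Strategy,
        isInner: Boolean,
        isSemiOrAnti: Boolean,
        hasMixedCondition: Boolean): Boolean = {
      if (isInner) {
        false
      } else if (hasMixedCondition) {
        strategy match {
          case SortMerge => true        // SMJ cannot evaluate mixed conditions yet
          case _ => isSemiOrAnti        // SHJ/BHJ reject only SEMI/ANTI
        }
      } else {
        strategy match {
          case SortMerge => isSemiOrAnti
          case _ => false
        }
      }
    }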
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetGraceHashJoinSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetGraceHashJoinSuite.scala
index 0b7ad9a6d..04ccda29b 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetGraceHashJoinSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetGraceHashJoinSuite.scala
@@ -16,13 +16,9 @@
  */
 package org.apache.gluten.execution
 
-import org.apache.gluten.test.FallbackUtil
-
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, 
Not}
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
-import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 
 class GlutenClickHouseTPCDSParquetGraceHashJoinSuite extends 
GlutenClickHouseTPCDSAbstractSuite {
 
@@ -39,105 +35,11 @@ class GlutenClickHouseTPCDSParquetGraceHashJoinSuite 
extends GlutenClickHouseTPC
       .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
       .set("spark.memory.offHeap.size", "8g")
       
.set("spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm", 
"grace_hash")
-      
.set("spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join", 
"3145728")
+      
.set("spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join", 
"314572800")
   }
 
   executeTPCDSTest(false);
 
-  test(
-    "test fallback operations not supported by ch backend " +
-      "in CHHashJoinExecTransformer && CHBroadcastHashJoinExecTransformer") {
-    val testSql =
-      """
-        | SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, 
i_manufact,
-        |     sum(ss_ext_sales_price) AS ext_price
-        | FROM date_dim
-        | LEFT JOIN store_sales ON d_date_sk = ss_sold_date_sk
-        | LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
-        | LEFT JOIN customer ON ss_customer_sk = c_customer_sk
-        | LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
-        | LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5) 
<> substr(s_zip,1,5)
-        | WHERE d_moy = 11
-        |   AND d_year = 1999
-        | GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
-        | ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id, 
i_manufact
-        | LIMIT 100;
-        |""".stripMargin
-
-    val df = spark.sql(testSql)
-    val operateWithCondition = df.queryExecution.executedPlan.collect {
-      case f: BroadcastHashJoinExec if f.condition.get.isInstanceOf[Not] => f
-    }
-    assert(
-      operateWithCondition(0).left
-        .asInstanceOf[InputAdapter]
-        .child
-        .isInstanceOf[CHColumnarToRowExec])
-  }
-
-  test("test fallbackutils") {
-    val testSql =
-      """
-        | SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, 
i_manufact,
-        |    sum(ss_ext_sales_price) AS ext_price
-        | FROM date_dim
-        | LEFT JOIN store_sales ON d_date_sk = ss_sold_date_sk
-        | LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
-        | LEFT JOIN customer ON ss_customer_sk = c_customer_sk
-        | LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
-        | LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5) 
<> substr(s_zip,1,5)
-        | WHERE d_moy = 11
-        |   AND d_year = 1999
-        | GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
-        | ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id, 
i_manufact
-        | LIMIT 100;
-        |""".stripMargin
-
-    val df = spark.sql(testSql)
-    assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
-  }
-
-  test("Gluten-4458: test clickhouse not support join with IN condition") {
-    val testSql =
-      """
-        | SELECT *
-        | FROM date_dim t1
-        | LEFT JOIN date_dim t2 ON t1.d_date_sk = t2.d_date_sk
-        |   AND datediff(t1.d_day_name, t2.d_day_name) IN (1, 3)
-        | LIMIT 100;
-        |""".stripMargin
-
-    val df = spark.sql(testSql)
-    assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
-  }
-
-  test("Gluten-4458: test join with Equal computing two table in one side") {
-    val testSql =
-      """
-        | SELECT *
-        | FROM date_dim t1
-        | LEFT JOIN date_dim t2 ON t1.d_date_sk = t2.d_date_sk AND t1.d_year - 
t2.d_year = 1
-        | LIMIT 100;
-        |""".stripMargin
-
-    val df = spark.sql(testSql)
-    assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
-  }
-
-  test("Gluten-4458: test inner join can support join with IN condition") {
-    val testSql =
-      """
-        | SELECT *
-        | FROM date_dim t1
-        | INNER JOIN date_dim t2 ON t1.d_date_sk = t2.d_date_sk
-        |   AND datediff(t1.d_day_name, t2.d_day_name) IN (1, 3)
-        | LIMIT 100;
-        |""".stripMargin
-
-    val df = spark.sql(testSql)
-    assert(!FallbackUtil.hasFallback(df.queryExecution.executedPlan))
-  }
-
   test("Gluten-1235: Fix missing reading from the broadcasted value when 
executing DPP") {
     val testSql =
       """
@@ -198,55 +100,4 @@ class GlutenClickHouseTPCDSParquetGraceHashJoinSuite 
extends GlutenClickHouseTPC
       }
     }
   }
-
-  test("TPCDS Q21 with non-separated scan rdd") {
-    withSQLConf(("spark.gluten.sql.columnar.separate.scan.rdd.for.ch", 
"false")) {
-      runTPCDSQuery("q21") {
-        df =>
-          val foundDynamicPruningExpr = df.queryExecution.executedPlan.find {
-            case f: FileSourceScanExecTransformer =>
-              f.partitionFilters.exists {
-                case _: DynamicPruningExpression => true
-                case _ => false
-              }
-            case _ => false
-          }
-          assert(foundDynamicPruningExpr.nonEmpty == true)
-
-          val reuseExchange = df.queryExecution.executedPlan.find {
-            case r: ReusedExchangeExec => true
-            case _ => false
-          }
-          assert(reuseExchange.nonEmpty == true)
-      }
-    }
-  }
-
-  test("Gluten-4452: Fix get wrong hash table when multi joins in a task") {
-    val testSql =
-      """
-        | SELECT ws_item_sk, ws_sold_date_sk, ws_ship_date_sk,
-        |        t3.d_date_id as sold_date_id, t2.d_date_id as ship_date_id
-        | FROM (
-        | SELECT ws_item_sk, ws_sold_date_sk, ws_ship_date_sk, t1.d_date_id
-        | FROM web_sales
-        | LEFT JOIN
-        |   (SELECT d_date_id, d_date_sk from date_dim GROUP BY d_date_id, 
d_date_sk) t1
-        | ON ws_sold_date_sk == t1.d_date_sk) t3
-        | INNER JOIN
-        |   (SELECT d_date_id, d_date_sk from date_dim GROUP BY d_date_id, 
d_date_sk) t2
-        | ON ws_ship_date_sk == t2.d_date_sk
-        | LIMIT 100;
-        |""".stripMargin
-    compareResultsAgainstVanillaSpark(
-      testSql,
-      true,
-      df => {
-        val foundBroadcastHashJoinExpr = 
df.queryExecution.executedPlan.collect {
-          case f: CHBroadcastHashJoinExecTransformer => f
-        }
-        assert(foundBroadcastHashJoinExpr.size == 2)
-      }
-    )
-  }
 }
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSuite.scala
index a63e47888..e9c27437b 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSuite.scala
@@ -16,13 +16,10 @@
  */
 package org.apache.gluten.execution
 
-import org.apache.gluten.test.FallbackUtil
-
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, 
Not}
+import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
SortMergeJoinExec}
 
 // Some sqls' line length exceeds 100
 // scalastyle:off line.size.limit
@@ -121,38 +118,7 @@ class GlutenClickHouseTPCDSParquetSuite extends 
GlutenClickHouseTPCDSAbstractSui
     assert(result(0).getLong(0) == 73049)
   }
 
-  test(
-    "test fallback operations not supported by ch backend " +
-      "in CHHashJoinExecTransformer && CHBroadcastHashJoinExecTransformer") {
-    val testSql =
-      """
-        |SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, 
i_manufact,
-        |    sum(ss_ext_sales_price) AS ext_price
-        | FROM date_dim
-        | LEFT JOIN store_sales ON d_date_sk = ss_sold_date_sk
-        | LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
-        | LEFT JOIN customer ON ss_customer_sk = c_customer_sk
-        | LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
-        | LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5) 
<> substr(s_zip,1,5)
-        | WHERE d_moy = 11
-        |   AND d_year = 1999
-        | GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
-        | ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id, 
i_manufact
-        | LIMIT 100;
-        |""".stripMargin
-
-    val df = spark.sql(testSql)
-    val operateWithCondition = df.queryExecution.executedPlan.collect {
-      case f: BroadcastHashJoinExec if f.condition.get.isInstanceOf[Not] => f
-    }
-    assert(
-      operateWithCondition(0).left
-        .asInstanceOf[InputAdapter]
-        .child
-        .isInstanceOf[CHColumnarToRowExec])
-  }
-
-  test("test fallbackutils") {
+  test("Test join with mixed condition 1") {
     val testSql =
       """
         |SELECT  i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, 
i_manufact,
@@ -169,36 +135,7 @@ class GlutenClickHouseTPCDSParquetSuite extends 
GlutenClickHouseTPCDSAbstractSui
         | ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id, 
i_manufact
         | LIMIT 100;
         |""".stripMargin
-
-    val df = spark.sql(testSql)
-    assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
-  }
-
-  test(
-    "Test avoid forceShuffledHashJoin when the join condition" +
-      " does not supported by the backend") {
-    val testSql =
-      """
-        |SELECT  /*+  merge(date_dim)*/ i_brand_id AS brand_id, i_brand AS 
brand, i_manufact_id, i_manufact,
-        |    sum(ss_ext_sales_price) AS ext_price
-        | FROM date_dim
-        | LEFT JOIN store_sales ON d_date_sk == ss_sold_date_sk AND (d_date_sk 
= 213232  OR ss_sold_date_sk = 3232)
-        | LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
-        | LEFT JOIN customer ON ss_customer_sk = c_customer_sk
-        | LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
-        | LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5) 
<> substr(s_zip,1,5)
-        | WHERE d_moy = 11
-        |   AND d_year = 1999
-        | GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
-        | ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id, 
i_manufact
-        | LIMIT 100;
-        |""".stripMargin
-
-    val df = spark.sql(testSql)
-    val sortMergeJoinExec = df.queryExecution.executedPlan.collect {
-      case s: SortMergeJoinExec => s
-    }
-    assert(sortMergeJoinExec.nonEmpty)
+    compareResultsAgainstVanillaSpark(testSql, true, _ => {})
   }
 
   test("Gluten-1235: Fix missing reading from the broadcasted value when 
executing DPP") {
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 748bd5a7f..038b170df 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2563,5 +2563,23 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
     compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
     spark.sql("drop table test_tbl_5896")
   }
+
+  test("Inequal join support") {
+    withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1")) {
+      spark.sql("create table ineq_join_t1 (key bigint, value bigint) using 
parquet");
+      spark.sql("create table ineq_join_t2 (key bigint, value bigint) using 
parquet");
+      spark.sql("insert into ineq_join_t1 values(1, 1), (2, 2), (3, 3), (4, 
4), (5, 5)");
+      spark.sql("insert into ineq_join_t2 values(2, 2), (2, 1), (3, 3), (4, 
6), (5, 3)");
+      val sql =
+        """
+          | select t1.key, t1.value, t2.key, t2.value from ineq_join_t1 as t1
+          | left join ineq_join_t2 as t2
+          | on t1.key = t2.key and t1.value > t2.value
+          |""".stripMargin
+      compareResultsAgainstVanillaSpark(sql, true, { _ => })
+      spark.sql("drop table ineq_join_t1")
+      spark.sql("drop table ineq_join_t2")
+    }
+  }
 }
 // scalastyle:on line.size.limit
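
For reference, with the rows inserted in the new test above, the expected
output of the left join (columns t1.key, t1.value, t2.key, t2.value), worked
out by hand from the inserted data:

    // key=1: no matching t2 key            -> (1, 1, null, null)
    // key=2: only (2, 1) satisfies 2 > 1   -> (2, 2, 2, 1)
    // key=3: (3, 3) fails 3 > 3            -> (3, 3, null, null)
    // key=4: (4, 6) fails 4 > 6            -> (4, 4, null, null)
    // key=5: (5, 3) satisfies 5 > 3        -> (5, 5, 5, 3)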
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp 
b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
index d90951241..f1b3ac2fb 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
@@ -103,7 +103,8 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
         row_count,
         key_names,
         true,
-        std::make_shared<DB::TableJoin>(SizeLimits(), true, kind, strictness, 
key_names),
+        kind,
+        strictness,
         columns_description,
         ConstraintsDescription(),
         key,
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp 
b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
index 6d0021adb..f0aec6af6 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
@@ -15,7 +15,9 @@
  * limitations under the License.
  */
 #include "StorageJoinFromReadBuffer.h"
+#include <algorithm>
 
+#include <DataTypes/DataTypeNullable.h>
 #include <Interpreters/Context.h>
 #include <Interpreters/HashJoin.h>
 #include <Interpreters/TableJoin.h>
@@ -23,6 +25,9 @@
 #include <Storages/IO/NativeReader.h>
 #include <Common/Exception.h>
 
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+
 namespace DB
 {
 class HashJoin;
@@ -40,25 +45,23 @@ extern const int DEADLOCK_AVOIDED;
 
 using namespace DB;
 
-void restore(DB::ReadBuffer & in, IJoin & join, const Block & sample_block)
-{
-    local_engine::NativeReader block_stream(in);
-    ProfileInfo info;
-    while (Block block = block_stream.read())
-    {
-        auto final_block = 
sample_block.cloneWithColumns(block.mutateColumns());
-        info.update(final_block);
-        join.addBlockToJoin(final_block, true);
-    }
-}
+constexpr auto RIHGT_COLUMN_PREFIX = "broadcast_right_";
 
 DB::Block rightSampleBlock(bool use_nulls, const StorageInMemoryMetadata & 
storage_metadata_, JoinKind kind)
 {
+    DB::ColumnsWithTypeAndName new_cols;
     DB::Block block = storage_metadata_.getSampleBlock();
-    if (use_nulls && isLeftOrFull(kind))
-        for (auto & col : block)
-            DB::JoinCommon::convertColumnToNullable(col);
-    return block;
+    for (const auto & col : block)
+    {
+        // Add a prefix to avoid column name conflicts with the left table.
+        new_cols.emplace_back(col.column, col.type, RIHGT_COLUMN_PREFIX + col.name);
+        if (use_nulls && isLeftOrFull(kind))
+        {
+            auto & new_col = new_cols.back();
+            DB::JoinCommon::convertColumnToNullable(new_col);
+        }
+    }
+    return DB::Block(new_cols);
 }
 
 namespace local_engine
@@ -67,46 +70,88 @@ namespace local_engine
 StorageJoinFromReadBuffer::StorageJoinFromReadBuffer(
     DB::ReadBuffer & in,
     size_t row_count_,
-    const Names & key_names,
-    bool use_nulls,
-    std::shared_ptr<DB::TableJoin> table_join,
+    const Names & key_names_,
+    bool use_nulls_,
+    DB::JoinKind kind,
+    DB::JoinStrictness strictness,
     const ColumnsDescription & columns,
     const ConstraintsDescription & constraints,
     const String & comment,
-    const bool overwrite)
-    : key_names_(key_names), use_nulls_(use_nulls)
+    const bool overwrite_)
+    : key_names({}), use_nulls(use_nulls_), row_count(row_count_), 
overwrite(overwrite_)
 {
-    storage_metadata_.setColumns(columns);
-    storage_metadata_.setConstraints(constraints);
-    storage_metadata_.setComment(comment);
+    storage_metadata.setColumns(columns);
+    storage_metadata.setConstraints(constraints);
+    storage_metadata.setComment(comment);
 
-    for (const auto & key : key_names)
-        if (!storage_metadata_.getColumns().hasPhysical(key))
+    for (const auto & key : key_names_)
+        if (!storage_metadata.getColumns().hasPhysical(key))
             throw Exception(ErrorCodes::NO_SUCH_COLUMN_IN_TABLE, "Key column 
({}) does not exist in table declaration.", key);
-    right_sample_block_ = rightSampleBlock(use_nulls, storage_metadata_, 
table_join->kind());
-    join_ = std::make_shared<HashJoin>(table_join, right_sample_block_, 
overwrite, row_count_);
-    restore(in, *join_, storage_metadata_.getSampleBlock());
+    for (const auto & name : key_names_)
+        key_names.push_back(RIHGT_COLUMN_PREFIX + name);
+    auto table_join = std::make_shared<DB::TableJoin>(SizeLimits(), true, 
kind, strictness, key_names);
+    right_sample_block = rightSampleBlock(use_nulls, storage_metadata, 
table_join->kind());
+    buildJoin(in, right_sample_block, table_join);
+}
+
+/// The column names may differ between the two blocks, and the nullability can also
+/// differ (TPC-DS Q1 is an example of the latter).
+static DB::ColumnWithTypeAndName convertColumnAsNecessary(const DB::ColumnWithTypeAndName & column, const DB::ColumnWithTypeAndName & sample_column)
+{
+    if (sample_column.type->equals(*column.type))
+        return {column.column, column.type, sample_column.name};
+    else if (
+        sample_column.type->isNullable() && !column.type->isNullable()
+        && DB::removeNullable(sample_column.type)->equals(*column.type))
+    {
+        auto nullable_column = column;
+        DB::JoinCommon::convertColumnToNullable(nullable_column);
+        return {nullable_column.column, sample_column.type, 
sample_column.name};
+    }
+    else
+        throw DB::Exception(
+            DB::ErrorCodes::LOGICAL_ERROR,
+            "Columns have different types. original:{} expected:{}",
+            column.dumpStructure(),
+            sample_column.dumpStructure());
+}
+
+void StorageJoinFromReadBuffer::buildJoin(DB::ReadBuffer & in, const Block 
header, std::shared_ptr<DB::TableJoin> analyzed_join)
+{
+    local_engine::NativeReader block_stream(in);
+    ProfileInfo info;
+    join = std::make_shared<HashJoin>(analyzed_join, header, overwrite, 
row_count);
+    while (Block block = block_stream.read())
+    {
+        DB::ColumnsWithTypeAndName columns;
+        for (size_t i = 0; i < block.columns(); ++i)
+        {
+            const auto & column = block.getByPosition(i);
+            columns.emplace_back(convertColumnAsNecessary(column, 
header.getByPosition(i)));
+        }
+        DB::Block final_block(columns);
+        info.update(final_block);
+        join->addBlockToJoin(final_block, true);
+    }
 }
 
-DB::JoinPtr StorageJoinFromReadBuffer::getJoinLocked(std::shared_ptr<DB::TableJoin> analyzed_join, DB::ContextPtr /*context*/) const
+/// The column names in 'right_header' could differ from the ones in `input_blocks`, and we must
+/// use 'right_header' to build the HashJoin. Otherwise, it will cause exceptions about name mismatches.
+///
+/// In most cases, 'getJoinLocked' is called only once and the input_blocks should not be too large,
+/// so this should be fine.
+DB::JoinPtr StorageJoinFromReadBuffer::getJoinLocked(std::shared_ptr<DB::TableJoin> analyzed_join, DB::ContextPtr /*context*/)
 {
-    if ((analyzed_join->forceNullableRight() && !use_nulls_)
-        || (!analyzed_join->forceNullableRight() && 
isLeftOrFull(analyzed_join->kind()) && use_nulls_))
+    if ((analyzed_join->forceNullableRight() && !use_nulls)
+        || (!analyzed_join->forceNullableRight() && 
isLeftOrFull(analyzed_join->kind()) && use_nulls))
         throw Exception(
             ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN,
             "Table {} needs the same join_use_nulls setting as present in LEFT 
or FULL JOIN",
-            storage_metadata_.comment);
-
-    /// TODO: check key columns
-
-    /// Set names qualifiers: table.column -> column
-    /// It's required because storage join stores non-qualified names
-    /// Qualifies will be added by join implementation (HashJoin)
-    analyzed_join->setRightKeys(key_names_);
-
-    HashJoinPtr join_clone = std::make_shared<HashJoin>(analyzed_join, 
right_sample_block_);
-    join_clone->reuseJoinedData(static_cast<const HashJoin &>(*join_));
+            storage_metadata.comment);
 
+    HashJoinPtr join_clone = std::make_shared<HashJoin>(analyzed_join, right_sample_block);
+    /// reuseJoinedData will set the flag `HashJoin::from_storage_join` which is required by `FilledJoinStep`
+    join_clone->reuseJoinedData(static_cast<const HashJoin &>(*join));
     return join_clone;
 }
 }
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h 
b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
index 2e949fa87..af623c0cd 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
@@ -23,6 +23,8 @@ namespace DB
 class TableJoin;
 class IJoin;
 using JoinPtr = std::shared_ptr<IJoin>;
+class HashJoin;
+class ReadBuffer;
 }
 
 namespace local_engine
@@ -33,23 +35,32 @@ class StorageJoinFromReadBuffer
 public:
     StorageJoinFromReadBuffer(
         DB::ReadBuffer & in_,
-        size_t row_count_,
+        size_t row_count,
         const DB::Names & key_names_,
         bool use_nulls_,
-        std::shared_ptr<DB::TableJoin> table_join_,
+        DB::JoinKind kind,
+        DB::JoinStrictness strictness,
         const DB::ColumnsDescription & columns_,
         const DB::ConstraintsDescription & constraints_,
         const String & comment,
         bool overwrite_);
 
-    DB::JoinPtr getJoinLocked(std::shared_ptr<DB::TableJoin> analyzed_join, 
DB::ContextPtr context) const;
-    const DB::Block & getRightSampleBlock() const { return 
right_sample_block_; }
+    /// The column names in right_header may be different from the names in the ColumnsDescription
+    /// passed to the constructor.
+    /// This should only be called once.
+    DB::JoinPtr getJoinLocked(std::shared_ptr<DB::TableJoin> analyzed_join, DB::ContextPtr context);
+    const DB::Block & getRightSampleBlock() const { return right_sample_block; }
 
 private:
-    DB::StorageInMemoryMetadata storage_metadata_;
-    const DB::Names key_names_;
-    bool use_nulls_;
-    DB::JoinPtr join_;
-    DB::Block right_sample_block_;
+    DB::StorageInMemoryMetadata storage_metadata;
+    DB::Names key_names;
+    bool use_nulls;
+    size_t row_count;
+    bool overwrite;
+    DB::Block right_sample_block;
+    std::shared_ptr<DB::HashJoin> join = nullptr;
+
+    void readAllBlocksFromInput(DB::ReadBuffer & in);
+    void buildJoin(DB::ReadBuffer & in, const DB::Block header, 
std::shared_ptr<DB::TableJoin> analyzed_join);
 };
 }
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp 
b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
index 8f7f35d5e..937e449b0 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
@@ -31,6 +31,10 @@
 #include <Processors/QueryPlan/JoinStep.h>
 #include <google/protobuf/wrappers.pb.h>
 
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+
+
 namespace DB
 {
 namespace ErrorCodes
@@ -179,40 +183,167 @@ DB::QueryPlanPtr JoinRelParser::parseOp(const 
substrait::Rel & rel, std::list<co
     return parseJoin(join, std::move(left_plan), std::move(right_plan));
 }
 
+std::unordered_set<DB::JoinTableSide> 
JoinRelParser::extractTableSidesFromExpression(const substrait::Expression & 
expr, const DB::Block & left_header, const DB::Block & right_header)
+{
+    std::unordered_set<DB::JoinTableSide> table_sides;
+    if (expr.has_scalar_function())
+    {
+        for (const auto & arg : expr.scalar_function().arguments())
+        {
+            auto table_sides_from_arg = 
extractTableSidesFromExpression(arg.value(), left_header, right_header);
+            table_sides.insert(table_sides_from_arg.begin(), 
table_sides_from_arg.end());
+        }
+    }
+    else if (expr.has_selection() && expr.selection().has_direct_reference() 
&& expr.selection().direct_reference().has_struct_field())
+    {
+        auto pos = expr.selection().direct_reference().struct_field().field();
+        if (pos < left_header.columns())
+        {
+            table_sides.insert(DB::JoinTableSide::Left);
+        }
+        else
+        {
+            table_sides.insert(DB::JoinTableSide::Right);
+        }
+    }
+    else if (expr.has_singular_or_list())
+    {
+        auto child_table_sides = 
extractTableSidesFromExpression(expr.singular_or_list().value(), left_header, 
right_header);
+        table_sides.insert(child_table_sides.begin(), child_table_sides.end());
+        for (const auto & option : expr.singular_or_list().options())
+        {
+            child_table_sides = extractTableSidesFromExpression(option, 
left_header, right_header);
+            table_sides.insert(child_table_sides.begin(), 
child_table_sides.end());
+        }
+    }
+    else if (expr.has_cast())
+    {
+        auto child_table_sides = 
extractTableSidesFromExpression(expr.cast().input(), left_header, right_header);
+        table_sides.insert(child_table_sides.begin(), child_table_sides.end());
+    }
+    else if (expr.has_if_then())
+    {
+        for (const auto & if_child : expr.if_then().ifs())
+        {
+            auto child_table_sides = 
extractTableSidesFromExpression(if_child.if_(), left_header, right_header);
+            table_sides.insert(child_table_sides.begin(), 
child_table_sides.end());
+            child_table_sides = 
extractTableSidesFromExpression(if_child.then(), left_header, right_header);
+            table_sides.insert(child_table_sides.begin(), 
child_table_sides.end());
+        }
+        auto child_table_sides = 
extractTableSidesFromExpression(expr.if_then().else_(), left_header, 
right_header);
+        table_sides.insert(child_table_sides.begin(), child_table_sides.end());
+    }
+    else if (expr.has_literal())
+    {
+        // nothing
+    }
+    else
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Illegal expression 
'{}'", expr.DebugString());
+    }
+    return table_sides;
+}
+
+
+void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join)
+{
+    /// To support mixed join conditions, we must make sure that the column names in the right plan are the same as
+    /// the ones in storage_join's right sample block.
+    ActionsDAGPtr project = ActionsDAG::makeConvertingActions(
+        right.getCurrentDataStream().header.getColumnsWithTypeAndName(),
+        storage_join.getRightSampleBlock().getColumnsWithTypeAndName(),
+        ActionsDAG::MatchColumnsMode::Position);
+
+    if (project)
+    {
+        QueryPlanStepPtr project_step = 
std::make_unique<ExpressionStep>(right.getCurrentDataStream(), project);
+        project_step->setStepDescription("Rename Broadcast Table Name");
+        steps.emplace_back(project_step.get());
+        right.addStep(std::move(project_step));
+    }
+
+    /// If a column name in the right table duplicates one in the left table, we need to rename the left table's columns
+    /// to avoid the column names in the right table being changed in `addConvertStep`.
+    /// This could happen in TPC-DS Q44.
+    DB::ColumnsWithTypeAndName new_left_cols;
+    const auto & right_header = right.getCurrentDataStream().header;
+    auto left_prefix = getUniqueName("left");
+    for (const auto & col : left.getCurrentDataStream().header)
+    {
+        if (right_header.has(col.name))
+        {
+            new_left_cols.emplace_back(col.column, col.type, left_prefix + 
col.name);
+        }
+        else
+        {
+            new_left_cols.emplace_back(col.column, col.type, col.name);
+        }
+    }
+    project = ActionsDAG::makeConvertingActions(
+        left.getCurrentDataStream().header.getColumnsWithTypeAndName(),
+        new_left_cols,
+        ActionsDAG::MatchColumnsMode::Position);
+
+    if (project)
+    {
+        QueryPlanStepPtr project_step = 
std::make_unique<ExpressionStep>(left.getCurrentDataStream(), project);
+        project_step->setStepDescription("Rename Left Table Name for broadcast 
join");
+        steps.emplace_back(project_step.get());
+        left.addStep(std::move(project_step));
+    }
+}
+
 DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, 
DB::QueryPlanPtr left, DB::QueryPlanPtr right)
 {
     auto join_opt_info = parseJoinOptimizationInfo(join);
     auto storage_join = join_opt_info.is_broadcast ? 
BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr;
-
     if (storage_join)
     {
-        ActionsDAGPtr project = ActionsDAG::makeConvertingActions(
-            right->getCurrentDataStream().header.getColumnsWithTypeAndName(),
-            storage_join->getRightSampleBlock().getColumnsWithTypeAndName(),
-            ActionsDAG::MatchColumnsMode::Position);
+        renamePlanColumns(*left, *right, *storage_join);
+    }
+
+    auto table_join = createDefaultTableJoin(join.type());
+    DB::Block right_header_before_convert_step = 
right->getCurrentDataStream().header;
+    addConvertStep(*table_join, *left, *right);
 
-        if (project)
+    // Add a check to make such errors easy to find.
+    if (storage_join)
+    {
+        if(!blocksHaveEqualStructure(right_header_before_convert_step, 
right->getCurrentDataStream().header))
         {
-            QueryPlanStepPtr project_step = 
std::make_unique<ExpressionStep>(right->getCurrentDataStream(), project);
-            project_step->setStepDescription("Rename Broadcast Table Name");
-            steps.emplace_back(project_step.get());
-            right->addStep(std::move(project_step));
+            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast 
join, we must not change the columns name in the right table.\nleft 
header:{},\nright header: {} -> {}",
+                left->getCurrentDataStream().header.dumpNames(),
+                right_header_before_convert_step.dumpNames(),
+                right->getCurrentDataStream().header.dumpNames());
         }
     }
 
-    auto table_join = createDefaultTableJoin(join.type());
-    addConvertStep(*table_join, *left, *right);
     Names after_join_names;
     auto left_names = left->getCurrentDataStream().header.getNames();
     after_join_names.insert(after_join_names.end(), left_names.begin(), 
left_names.end());
     auto right_name = table_join->columnsFromJoinedTable().getNames();
     after_join_names.insert(after_join_names.end(), right_name.begin(), 
right_name.end());
-    bool add_filter_step = tryAddPushDownFilter(*table_join, join, *left, 
*right, table_join->columnsFromJoinedTable(), after_join_names);
+
+    auto left_header = left->getCurrentDataStream().header;
+    auto right_header = right->getCurrentDataStream().header;
 
     QueryPlanPtr query_plan;
+
+    /// We support only one join clause.
+    table_join->addDisjunct();
+    /// Some examples of when the post_join_filter is not empty:
+    /// - For `on t1.key = t2.key and t1.v1 > 1 and t2.v1 > 1`, 't1.v1 > 1' is in the post filter, but 't2.v1 > 1'
+    ///   is pushed down into the right table by Spark and is not in the post filter. 't1.key = t2.key' is
+    ///   in JoinRel::expression.
+    /// - For `on t1.key = t2.key and t1.v1 > t2.v2`, 't1.v1 > t2.v2' is in the post filter.
+    collectJoinKeys(*table_join, join, left_header, right_header);
+
     if (storage_join)
     {
+
+        applyJoinFilter(*table_join, join, *left, *right, true);
         auto broadcast_hash_join = storage_join->getJoinLocked(table_join, 
context);
+
         QueryPlanStepPtr join_step = 
std::make_unique<FilledJoinStep>(left->getCurrentDataStream(), 
broadcast_hash_join, 8192);
 
         join_step->setStepDescription("STORAGE_JOIN");
@@ -224,6 +355,18 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
     }
     else if (join_opt_info.is_smj)
     {
+        bool need_post_filter = !applyJoinFilter(*table_join, join, *left, 
*right, false);
+
+        /// If applyJoinFilter returns false, it means there are mixed conditions in the post_join_filter.
+        /// In that case the join must be an inner join.
+        /// TODO: make smj support mixed conditions
+        if (need_post_filter && table_join->kind() != DB::JoinKind::Inner)
+        {
+            throw DB::Exception(
+                DB::ErrorCodes::LOGICAL_ERROR,
+                "Sort merge join doesn't support mixed join conditions, except 
inner join.");
+        }
+
         JoinPtr smj_join = std::make_shared<FullSortingMergeJoin>(table_join, 
right->getCurrentDataStream().header.cloneEmpty(), -1);
         MultiEnum<DB::JoinAlgorithm> join_algorithm = 
context->getSettingsRef().join_algorithm;
         QueryPlanStepPtr join_step
@@ -237,12 +380,14 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
 
         query_plan = std::make_unique<QueryPlan>();
         query_plan->unitePlans(std::move(join_step), {std::move(plans)});
+        if (need_post_filter)
+            addPostFilter(*query_plan, join);
     }
     else
     {
-        /// TODO: make grace hash join be the default hash join algorithm.
-        ///
-        /// Following is some configuration for grace hash join.
+        applyJoinFilter(*table_join, join, *left, *right, true);
+
+        /// Following are some configurations for grace hash join.
         /// - 
spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash.
 This will
         ///   enable grace hash join.
         /// - 
spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728.
 This setup
@@ -278,28 +423,15 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
     }
     reorderJoinOutput(*query_plan, after_join_names);
 
-    if (add_filter_step)
-    {
-        addPostFilter(*query_plan, join);
-    }
     return query_plan;
 }
 
 void JoinRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & 
left, DB::QueryPlan & right)
 {
-
-    /// After https://github.com/ClickHouse/ClickHouse/pull/61216, We will 
failed at tryPushDownFilter() in filterPushDown.cpp
-    /// Here is a workaround, refer to chooseJoinAlgorithm() in 
PlannerJoins.cpp, it always call TableJoin::setRename to
-    /// create aliases for columns in the right table
-    /// By using right table header name sets, so 
TableJoin::deduplicateAndQualifyColumnNames can do same thing as 
chooseJoinAlgorithm()
-    ///
-    /// Affected UT fixed bh this workaround:
-    ///    GlutenClickHouseTPCHParquetRFSuite:TPCH Q17, Q19, Q20, Q21
+    /// If a column name in the right table duplicates one in the left table, we need to rename the right table's columns.
     NameSet left_columns_set;
-    for (const auto & col : right.getCurrentDataStream().header.getNames())
-    {
+    for (const auto & col : left.getCurrentDataStream().header.getNames())
         left_columns_set.emplace(col);
-    }
     table_join.setColumnsFromJoinedTable(
         right.getCurrentDataStream().header.getNamesAndTypesList(), 
left_columns_set, getUniqueName("right") + ".");
 
@@ -360,117 +492,179 @@ void JoinRelParser::addConvertStep(TableJoin & 
table_join, DB::QueryPlan & left,
     }
 }
 
-void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const 
substrait::JoinRel & join)
+/// Join keys are collected from substrait::JoinRel::expression(), which only contains the equality join conditions.
+void JoinRelParser::collectJoinKeys(
+    TableJoin & table_join, const substrait::JoinRel & join_rel, const DB::Block & left_header, const DB::Block & right_header)
 {
-    std::string filter_name;
-    auto actions_dag = 
std::make_shared<ActionsDAG>(query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName());
-    if (!join.post_join_filter().has_scalar_function())
+    if (!join_rel.has_expression())
+        return;
+    const auto & expr = join_rel.expression();
+    auto & join_clause = table_join.getClauses().back();
+    std::list<const substrait::Expression *> expressions_stack;
+    expressions_stack.push_back(&expr);
+    while (!expressions_stack.empty())
     {
-       // It may be singular_or_list
-        auto * in_node = getPlanParser()->parseExpression(actions_dag, 
join.post_join_filter());
-        filter_name = in_node->result_name;
-    }
-    else
-    {
-        
getPlanParser()->parseFunction(query_plan.getCurrentDataStream().header, 
join.post_join_filter(), filter_name, actions_dag, true);
+        /// Must handle the expressions in depth-first order. It matters in sort merge join.
+        const auto * current_expr = expressions_stack.back();
+        expressions_stack.pop_back();
+        if (!current_expr->has_scalar_function())
+            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Function 
expression is expected");
+        auto function_name = 
parseFunctionName(current_expr->scalar_function().function_reference(), 
current_expr->scalar_function());
+        if (!function_name)
+            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid 
function expression");
+        if (*function_name == "equals")
+        {
+            String left_key, right_key;
+            size_t left_pos = 0, right_pos = 0;
+            for (const auto & arg : 
current_expr->scalar_function().arguments())
+            {
+                if (!arg.value().has_selection() || 
!arg.value().selection().has_direct_reference()
+                    || 
!arg.value().selection().direct_reference().has_struct_field())
+                {
+                    throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A 
column reference is expected");
+                }
+                auto col_pos_ref = 
arg.value().selection().direct_reference().struct_field().field();
+                if (col_pos_ref < left_header.columns())
+                {
+                    left_pos = col_pos_ref;
+                    left_key = left_header.getByPosition(col_pos_ref).name;
+                }
+                else
+                {
+                    right_pos = col_pos_ref - left_header.columns();
+                    right_key = right_header.getByPosition(col_pos_ref - 
left_header.columns()).name;
+                }
+            }
+            if (left_key.empty() || right_key.empty())
+                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid 
key equal join condition");
+            join_clause.addKey(left_key, right_key, false);
+        }
+        else if (*function_name == "and")
+        {
+            
expressions_stack.push_back(&current_expr->scalar_function().arguments().at(1).value());
+            
expressions_stack.push_back(&current_expr->scalar_function().arguments().at(0).value());
+        }
+        else
+        {
+            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow 
function: {}", *function_name);
+        }
     }
-    auto filter_step = 
std::make_unique<FilterStep>(query_plan.getCurrentDataStream(), actions_dag, 
filter_name, true);
-    filter_step->setStepDescription("Post Join Filter");
-    steps.emplace_back(filter_step.get());
-    query_plan.addStep(std::move(filter_step));
 }
 
-bool JoinRelParser::tryAddPushDownFilter(
-    TableJoin & table_join,
-    const substrait::JoinRel & join,
-    DB::QueryPlan & left,
-    DB::QueryPlan & right,
-    const NamesAndTypesList & alias_right,
-    const Names & names)
+bool JoinRelParser::applyJoinFilter(
+    DB::TableJoin & table_join, const substrait::JoinRel & join_rel, 
DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition)
 {
-    try
+    if (!join_rel.has_post_join_filter())
+        return true;
+    const auto & expr = join_rel.post_join_filter();
+
+    const auto & left_header = left.getCurrentDataStream().header;
+    const auto & right_header = right.getCurrentDataStream().header;
+    ColumnsWithTypeAndName mixed_columns;
+    std::unordered_set<String> added_column_name;
+    for (const auto & col : left_header.getColumnsWithTypeAndName())
+    {
+        mixed_columns.emplace_back(col);
+        added_column_name.insert(col.name);
+    }
+    for (const auto & col : right_header.getColumnsWithTypeAndName())
     {
-        ASTParser astParser(context, function_mapping, getPlanParser());
-        ASTs args;
+        const auto & renamed_col_name = 
table_join.renamedRightColumnNameWithAlias(col.name);
+        if (added_column_name.find(col.name) != added_column_name.end())
+            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Right column's 
name conflict with left column: {}", col.name);
+        mixed_columns.emplace_back(col);
+        added_column_name.insert(col.name);
+    }
+    DB::Block mixed_header(mixed_columns);
 
-        if (join.has_expression())
-        {
-            args.emplace_back(astParser.parseToAST(names, join.expression()));
-        }
+    auto table_sides = extractTableSidesFromExpression(expr, left_header, 
right_header);
 
-        if (join.has_post_join_filter())
+    auto get_input_expressions = [](const DB::Block & header)
+    {
+        std::vector<substrait::Expression> exprs;
+        for (size_t i = 0; i < header.columns(); ++i)
         {
-            args.emplace_back(astParser.parseToAST(names, 
join.post_join_filter()));
+            substrait::Expression expr;
+            
expr.mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(i);
+            exprs.emplace_back(expr);
         }
-
-        if (args.empty())
-            return false;
-
-        ASTPtr ast = args.size() == 1 ? args.back() : makeASTFunction("and", 
args);
-
-        bool is_asof = (table_join.strictness() == JoinStrictness::Asof);
-
-        Aliases aliases;
-        DatabaseAndTableWithAlias left_table_name;
-        DatabaseAndTableWithAlias right_table_name;
-        TableWithColumnNamesAndTypes left_table(left_table_name, 
left.getCurrentDataStream().header.getNamesAndTypesList());
-        TableWithColumnNamesAndTypes right_table(right_table_name, 
alias_right);
-
-        CollectJoinOnKeysVisitor::Data data{table_join, left_table, 
right_table, aliases, is_asof};
-        if (auto * or_func = ast->as<ASTFunction>(); or_func && or_func->name 
== "or")
+        return exprs;
+    };
+
+    /// If the columns in the expression all come from one table, use analyzer_left_filter_condition_column_name
+    /// or analyzer_right_filter_condition_column_name to filter the join result data. This requires building the
+    /// filter column first.
+    /// If the columns in the expression come from both tables, use mixed_join_expression to filter the join result
+    /// data. The filter columns will be built inside the join step.
+    if (table_sides.size() == 1)
+    {
+        auto table_side = *table_sides.begin();
+        if (table_side == DB::JoinTableSide::Left)
         {
-            for (auto & disjunct : or_func->arguments->children)
-            {
-                table_join.addDisjunct();
-                CollectJoinOnKeysVisitor(data).visit(disjunct);
-            }
-            assert(table_join.getClauses().size() == 
or_func->arguments->children.size());
+            auto input_exprs = get_input_expressions(left_header);
+            input_exprs.push_back(expr);
+            auto actions_dag = expressionsToActionsDAG(input_exprs, 
left_header);
+            
table_join.getClauses().back().analyzer_left_filter_condition_column_name = 
actions_dag->getOutputs().back()->result_name;
+            QueryPlanStepPtr before_join_step = 
std::make_unique<ExpressionStep>(left.getCurrentDataStream(), actions_dag);
+            before_join_step->setStepDescription("Before JOIN LEFT");
+            steps.emplace_back(before_join_step.get());
+            left.addStep(std::move(before_join_step));
         }
         else
         {
-            table_join.addDisjunct();
-            CollectJoinOnKeysVisitor(data).visit(ast);
-            assert(table_join.oneDisjunct());
-        }
-
-        if (join.has_post_join_filter())
-        {
-            auto left_keys = table_join.leftKeysList();
-            auto right_keys = table_join.rightKeysList();
-            if (!left_keys->children.empty())
+            /// Field references in expr index into left_header ++ right_header (their concatenation),
+            /// so we use mixed_header to build the actions_dag.
+            auto input_exprs = get_input_expressions(mixed_header);
+            input_exprs.push_back(expr);
+            auto actions_dag = expressionsToActionsDAG(input_exprs, mixed_header);
+
+            /// clear unused columns in actions_dag
+            for (const auto & col : left_header.getColumnsWithTypeAndName())
             {
-                auto actions = astParser.convertToActions(left.getCurrentDataStream().header.getNamesAndTypesList(), left_keys);
-                QueryPlanStepPtr before_join_step = std::make_unique<ExpressionStep>(left.getCurrentDataStream(), actions);
-                before_join_step->setStepDescription("Before JOIN LEFT");
-                steps.emplace_back(before_join_step.get());
-                left.addStep(std::move(before_join_step));
+                actions_dag->removeUnusedResult(col.name);
             }
+            actions_dag->removeUnusedActions();
 
-            if (!right_keys->children.empty())
-            {
-                auto actions = astParser.convertToActions(right.getCurrentDataStream().header.getNamesAndTypesList(), right_keys);
-                QueryPlanStepPtr before_join_step = std::make_unique<ExpressionStep>(right.getCurrentDataStream(), actions);
-                before_join_step->setStepDescription("Before JOIN RIGHT");
-                steps.emplace_back(before_join_step.get());
-                right.addStep(std::move(before_join_step));
-            }
+            table_join.getClauses().back().analyzer_right_filter_condition_column_name = actions_dag->getOutputs().back()->result_name;
+            QueryPlanStepPtr before_join_step = std::make_unique<ExpressionStep>(right.getCurrentDataStream(), actions_dag);
+            before_join_step->setStepDescription("Before JOIN RIGHT");
+            steps.emplace_back(before_join_step.get());
+            right.addStep(std::move(before_join_step));
         }
     }
-    // if ch does not support the join type or join conditions, it will throw an exception like 'not support'.
-    catch (Poco::Exception & e)
+    else if (table_sides.size() == 2)
     {
-        // CH not support join condition has 'or' and has different table in each side.
-        // But in inner join, we could execute join condition after join. so we have add filter step
-        if (e.code() == ErrorCodes::INVALID_JOIN_ON_EXPRESSION && table_join.kind() == DB::JoinKind::Inner)
-        {
-            return true;
-        }
-        else
-        {
-            throw;
-        }
+        if (!allow_mixed_condition)
+            return false;
+        auto mixed_join_expressions_actions = expressionsToActionsDAG({expr}, mixed_header);
+        table_join.getMixedJoinExpression()
+            = std::make_shared<DB::ExpressionActions>(mixed_join_expressions_actions, ExpressionActionsSettings::fromContext(context));
     }
-    return false;
+    else
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition.\n{}", join_rel.DebugString());
+    }
+    return true;
+}
+
+void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::JoinRel & join)
+{
+    std::string filter_name;
+    auto actions_dag = std::make_shared<ActionsDAG>(query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName());
+    if (!join.post_join_filter().has_scalar_function())
+    {
+       // It may be singular_or_list
+        auto * in_node = getPlanParser()->parseExpression(actions_dag, join.post_join_filter());
+        filter_name = in_node->result_name;
+    }
+    else
+    {
+        getPlanParser()->parseFunction(query_plan.getCurrentDataStream().header, join.post_join_filter(), filter_name, actions_dag, true);
+    }
+    auto filter_step = std::make_unique<FilterStep>(query_plan.getCurrentDataStream(), actions_dag, filter_name, true);
+    filter_step->setStepDescription("Post Join Filter");
+    steps.emplace_back(filter_step.get());
+    query_plan.addStep(std::move(filter_step));
 }
 
 void registerJoinRelParser(RelParserFactory & factory)
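
The applyJoinFilter hunk above decides where a join condition is evaluated by looking at which table sides the condition references: field references in the Substrait expression index into the concatenation of left_header and right_header, so the side can be derived from the field index alone. The snippet below is only an illustrative sketch of that idea; it is not the committed extractTableSidesFromExpression, the include paths are assumed, and the recursion over scalar_function arguments is a guess at how the expression tree is walked.

    // Illustrative sketch only; not the committed extractTableSidesFromExpression.
    #include <unordered_set>
    #include <Core/Joins.h>            // DB::JoinTableSide (assumed include path)
    #include <substrait/algebra.pb.h>  // substrait::Expression

    static void collectTableSides(
        const substrait::Expression & expr, size_t left_columns, std::unordered_set<DB::JoinTableSide> & sides)
    {
        if (expr.has_selection())
        {
            // Field indexes below left_columns belong to the left header, the rest to the right header.
            auto field = expr.selection().direct_reference().struct_field().field();
            sides.insert(static_cast<size_t>(field) < left_columns ? DB::JoinTableSide::Left : DB::JoinTableSide::Right);
        }
        else if (expr.has_scalar_function())
        {
            for (const auto & arg : expr.scalar_function().arguments())
                collectTableSides(arg.value(), left_columns, sides);
        }
        // Literals and other leaf expressions reference no table side.
    }

With a helper like this, one referenced side maps to the analyzer_left/right_filter_condition_column_name path (the filter column is computed before the join), two sides map to the mixed join expression evaluated inside the join, and an empty set corresponds to the LOGICAL_ERROR branch shown above.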
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h
index 445b7e683..c423f4390 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.h
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.h
@@ -17,6 +17,7 @@
 #pragma once
 
 #include <memory>
+#include <unordered_set>
 #include <Parser/RelParser.h>
 #include <substrait/algebra.pb.h>
 
@@ -28,6 +29,8 @@ class TableJoin;
 namespace local_engine
 {
 
+class StorageJoinFromReadBuffer;
+
 std::pair<DB::JoinKind, DB::JoinStrictness> getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type);
 
 class JoinRelParser : public RelParser
@@ -50,15 +53,22 @@ private:
 
 
     DB::QueryPlanPtr parseJoin(const substrait::JoinRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right);
+    void renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join);
     void addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right);
-    bool tryAddPushDownFilter(
-        TableJoin & table_join,
-        const substrait::JoinRel & join,
-        DB::QueryPlan & left,
-        DB::QueryPlan & right,
-        const NamesAndTypesList & alias_right,
-        const Names & names);
+    void collectJoinKeys(
+        TableJoin & table_join, const substrait::JoinRel & join_rel, const DB::Block & left_header, const DB::Block & right_header);
+
+    bool applyJoinFilter(
+        DB::TableJoin & table_join,
+        const substrait::JoinRel & join_rel,
+        DB::QueryPlan & left_plan,
+        DB::QueryPlan & right_plan,
+        bool allow_mixed_condition);
+
     void addPostFilter(DB::QueryPlan & plan, const substrait::JoinRel & join);
+
+    static std::unordered_set<DB::JoinTableSide> extractTableSidesFromExpression(
+        const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header);
 };
 
 }
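
For orientation, here is a hypothetical sketch of how a caller such as parseJoin might wire up applyJoinFilter and addPostFilter as declared above. The actual call site is not part of the hunks in this mail, so treat the control flow as an assumption: a condition that applyJoinFilter cannot push into the join (for example a mixed condition when allow_mixed_condition is false) would instead be evaluated by addPostFilter on the joined stream.

    // Hypothetical call-site sketch; names follow the declarations above, the flow is assumed.
    bool filter_applied = true;
    if (join_rel.has_post_join_filter())
        filter_applied = applyJoinFilter(*table_join, join_rel, *left_plan, *right_plan, /*allow_mixed_condition=*/ true);

    // ... the join step is built from table_join here, producing query_plan ...

    if (join_rel.has_post_join_filter() && !filter_applied)
    {
        // The condition could not be expressed inside the join; filter the joined rows instead.
        addPostFilter(*query_plan, join_rel);
    }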
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index a26f78699..b0d3bbeca 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -214,7 +214,7 @@ std::shared_ptr<ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
                 }
             }
         }
-        else if (expr.has_cast() || expr.has_if_then() || expr.has_literal())
+        else if (expr.has_cast() || expr.has_if_then() || expr.has_literal() || expr.has_singular_or_list())
         {
             const auto * node = parseExpression(actions_dag, expr);
             actions_dag->addOrReplaceInOutputs(*node);

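The SerializedPlanParser change above lets expressionsToActionsDAG handle a top-level singular_or_list, which is Substrait's encoding of an IN-list predicate and also the case addPostFilter now detects by checking has_scalar_function(). A minimal sketch of such an expression, built with the bundled Substrait protos (the field index and literal values are made up for illustration):

    // Sketch: a SingularOrList expression roughly equivalent to "field#0 IN (1, 2)".
    #include <substrait/algebra.pb.h>

    substrait::Expression makeInListExpression()
    {
        substrait::Expression expr;
        auto * in_list = expr.mutable_singular_or_list();
        // Left-hand side: a direct reference to input field 0.
        in_list->mutable_value()->mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(0);
        // Options: the literal values of the IN list.
        in_list->add_options()->mutable_literal()->set_i64(1);
        in_list->add_options()->mutable_literal()->set_i64(2);
        return expr;
    }

Such an expression carries no scalar_function, so it is handled by parseExpression rather than parseFunction, matching the new branch in addPostFilter.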

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
