This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new a784e4566 [GLUTEN-6544][CH] Support existence join (#6548)
a784e4566 is described below

commit a784e45662117b2192a166d24d463257df7311c4
Author: lgbo <[email protected]>
AuthorDate: Mon Jul 29 10:00:01 2024 +0800

    [GLUTEN-6544][CH] Support existence join (#6548)
    
    * support existence join
    
    * fixed tests
---
 .../gluten/vectorized/StorageJoinBuilder.java      |  2 ++
 .../CHBroadcastNestedLoopJoinExecTransformer.scala | 16 +++++++++-
 .../execution/CHHashJoinExecTransformer.scala      | 37 ++++++++++++++++++++--
 .../apache/gluten/utils/CHJoinValidateUtil.scala   |  4 +++
 .../GlutenClickHouseTPCDSAbstractSuite.scala       | 11 ++++---
 ...nClickHouseTPCDSParquetSortMergeJoinSuite.scala |  5 ++-
 .../execution/GlutenClickHouseTPCHSuite.scala      | 31 ++++++++++++++++++
 .../benchmarks/CHHashBuildBenchmark.scala          |  2 +-
 cpp-ch/local-engine/Common/CHUtil.cpp              |  8 +++--
 cpp-ch/local-engine/Common/CHUtil.h                |  2 +-
 cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp  |  3 +-
 cpp-ch/local-engine/Join/BroadCastJoinBuilder.h    |  1 +
 cpp-ch/local-engine/Parser/JoinRelParser.cpp       | 35 ++++++++++++++++++--
 cpp-ch/local-engine/Parser/JoinRelParser.h         |  2 ++
 cpp-ch/local-engine/local_engine_jni.cpp           |  3 +-
 15 files changed, 144 insertions(+), 18 deletions(-)

diff --git 
a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
 
b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
index 27725998f..ae7b89120 100644
--- 
a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
+++ 
b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
@@ -45,6 +45,7 @@ public class StorageJoinBuilder {
       String joinKeys,
       int joinType,
       boolean hasMixedFiltCondition,
+      boolean isExistenceJoin,
       byte[] namedStruct);
 
   private StorageJoinBuilder() {}
@@ -89,6 +90,7 @@ public class StorageJoinBuilder {
         joinKey,
         joinType,
         broadCastContext.hasMixedFiltCondition(),
+        broadCastContext.isExistenceJoin(),
         toNameStruct(output).toByteArray());
   }
 
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala
index d1dc76045..3aab5a6eb 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.rpc.GlutenDriverEndpoint
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.optimizer.BuildSide
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftSemi}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
@@ -44,6 +45,13 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
     condition
   ) {
 
+  private val finalJoinType = joinType match {
+    case ExistenceJoin(_) =>
+      LeftSemi
+    case _ =>
+      joinType
+  }
+
   override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
     val streamedRDD = getColumnarInputRDDs(streamedPlan)
     val executionId = 
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
@@ -57,7 +65,13 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
     }
     val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
     val context =
-      BroadCastHashJoinContext(Seq.empty, joinType, false, buildPlan.output, 
buildBroadcastTableId)
+      BroadCastHashJoinContext(
+        Seq.empty,
+        finalJoinType,
+        false,
+        joinType.isInstanceOf[ExistenceJoin],
+        buildPlan.output,
+        buildBroadcastTableId)
     val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, 
context)
     streamedRDD :+ broadcastRDD
   }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 48870892d..c44156373 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.execution.{SparkPlan, 
SQLExecution}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+import io.substrait.proto.JoinRel
+
 case class CHShuffledHashJoinExecTransformer(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
@@ -82,6 +84,7 @@ case class BroadCastHashJoinContext(
     buildSideJoinKeys: Seq[Expression],
     joinType: JoinType,
     hasMixedFiltCondition: Boolean,
+    isExistenceJoin: Boolean,
     buildSideStructure: Seq[Attribute],
     buildHashTableId: String)
 
@@ -112,7 +115,7 @@ case class CHBroadcastHashJoinExecTransformer(
   override protected def doValidateInternal(): ValidationResult = {
     val shouldFallback =
       CHJoinValidateUtil.shouldFallback(
-        BroadcastHashJoinStrategy(joinType),
+        BroadcastHashJoinStrategy(finalJoinType),
         left.outputSet,
         right.outputSet,
         condition)
@@ -141,8 +144,9 @@ case class CHBroadcastHashJoinExecTransformer(
     val context =
       BroadCastHashJoinContext(
         buildKeyExprs,
-        joinType,
+        finalJoinType,
         isMixedCondition(condition),
+        joinType.isInstanceOf[ExistenceJoin],
         buildPlan.output,
         buildHashTableId)
     val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, 
context)
@@ -161,4 +165,33 @@ case class CHBroadcastHashJoinExecTransformer(
     }
     res
   }
+
+  // ExistenceJoin is introduced in #SPARK-14781. It returns all rows from the 
left table with
+  // a new column to indicate whether the row is matched in the right table.
+  // Indeed, the ExistenceJoin is transformed into left any join in CH.
+  // We don't have left any join in substrait, so use left semi join instead.
+  // and isExistenceJoin is set to true to indicate that it is an existence 
join.
+  private val finalJoinType = joinType match {
+    case ExistenceJoin(_) =>
+      LeftSemi
+    case _ =>
+      joinType
+  }
+  override protected lazy val substraitJoinType: JoinRel.JoinType = {
+    joinType match {
+      case _: InnerLike =>
+        JoinRel.JoinType.JOIN_TYPE_INNER
+      case FullOuter =>
+        JoinRel.JoinType.JOIN_TYPE_OUTER
+      case LeftOuter | RightOuter =>
+        JoinRel.JoinType.JOIN_TYPE_LEFT
+      case LeftSemi | ExistenceJoin(_) =>
+        JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
+      case LeftAnti =>
+        JoinRel.JoinType.JOIN_TYPE_ANTI
+      case _ =>
+        // TODO: Support cross join with Cross Rel
+        JoinRel.JoinType.UNRECOGNIZED
+    }
+  }
 }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
index dae8e6e07..08b5ef5b2 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
@@ -55,6 +55,7 @@ object CHJoinValidateUtil extends Logging {
     var shouldFallback = false
     val joinType = joinStrategy.joinType
     if (joinType.toString.contains("ExistenceJoin")) {
+      logError("Fallback for join type ExistenceJoin")
       return true
     }
     if (joinType.sql.contains("INNER")) {
@@ -78,6 +79,9 @@ object CHJoinValidateUtil extends Logging {
         case _ => false
       }
     }
+    if (shouldFallback) {
+      logError(s"Fallback for join type $joinType")
+    }
     shouldFallback
   }
 }
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
index f0712bf5a..9787182ed 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
@@ -58,13 +58,14 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
             Seq("q" + "%d".format(queryNum))
           }
           val noFallBack = queryNum match {
-            case i if i == 10 || i == 16 || i == 35 || i == 45 || i == 94 =>
-              // Q10 BroadcastHashJoin, ExistenceJoin
-              // Q16 ShuffledHashJoin, NOT condition
-              // Q35 BroadcastHashJoin, ExistenceJoin
-              // Q45 BroadcastHashJoin, ExistenceJoin
+            case i if !isAqe && (i == 10 || i == 16 || i == 35 || i == 94) =>
+              // q10 smj + existence join
+              // q16 smj + left semi + not condition
+              // q35 smj + existence join
               // Q94 BroadcastHashJoin, LeftSemi, NOT condition
               (false, false)
+            case i if isAqe && (i == 16 || i == 94) =>
+              (false, false)
             case other => (true, false)
           }
           sqlNums.map((_, noFallBack._1, noFallBack._2))
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
index 3f7816cb8..3ec4e31a4 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
@@ -23,12 +23,15 @@ import org.apache.spark.SparkConf
 class GlutenClickHouseTPCDSParquetSortMergeJoinSuite extends 
GlutenClickHouseTPCDSAbstractSuite {
 
   override protected def excludedTpcdsQueries: Set[String] = Set(
-    // fallback due to left semi/anti
+    // fallback due to left semi/anti/existence join
     "q8",
+    "q10",
     "q14a",
     "q14b",
+    "116",
     "q23a",
     "q23b",
+    "q35",
     "q38",
     "q51",
     "q69",
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
index d26891ddb..1c09449c8 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
@@ -500,5 +500,36 @@ class GlutenClickHouseTPCHSuite extends 
GlutenClickHouseTPCHAbstractSuite {
     compareResultsAgainstVanillaSpark(sql2, true, { _ => })
 
   }
+
+  test("existence join") {
+    spark.sql("create table t1(a int, b int) using parquet")
+    spark.sql("create table t2(a int, b int) using parquet")
+    spark.sql("insert into t1 values(0, 0), (1, 2), (2, 3), (3, 4), (null, 5), 
(6, null)")
+    spark.sql("insert into t2 values(0, 0), (1, 2), (2, 3), (2,4), (null, 5), 
(6, null)")
+
+    val sql1 = """
+                 |select * from t1 where exists (select 1 from t2 where t1.a = 
t2.a) or t1.a > 1
+                 |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql1, true, { _ => })
+
+    val sql2 = """
+                 |select * from t1 where exists (select 1 from t2 where t1.a = 
t2.a) or t1.a > 3
+                 |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql2, true, { _ => })
+
+    val sql3 = """
+                 |select * from t1 where exists (select 1 from t2 where t1.a = 
t2.a) or t1.b > 0
+                 |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql3, true, { _ => })
+
+    val sql4 = """
+                 |select * from t1 where exists (select 1 from t2
+                 |where t1.a = t2.a and t1.b = t2.b) or t1.a > 0
+                 |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql4, true, { _ => })
+
+    spark.sql("drop table t1")
+    spark.sql("drop table t2")
+  }
 }
 // scalastyle:off line.size.limit
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
index 8d4bee554..141bf5eea 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
@@ -104,7 +104,7 @@ object CHHashBuildBenchmark extends SqlBasedBenchmark with 
CHSqlBasedBenchmark w
     (
       countsAndBytes.flatMap(_._2),
       countsAndBytes.map(_._1).sum,
-      BroadCastHashJoinContext(Seq(child.output.head), Inner, false, 
child.output, "")
+      BroadCastHashJoinContext(Seq(child.output.head), Inner, false, false, 
child.output, "")
     )
   }
 }
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp 
b/cpp-ch/local-engine/Common/CHUtil.cpp
index 787277dbe..3a699b50e 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -1089,14 +1089,18 @@ void JoinUtil::reorderJoinOutput(DB::QueryPlan & plan, 
DB::Names cols)
     plan.addStep(std::move(project_step));
 }
 
-std::pair<DB::JoinKind, DB::JoinStrictness> 
JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type)
+std::pair<DB::JoinKind, DB::JoinStrictness>
+JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type, bool 
is_existence_join)
 {
     switch (join_type)
     {
         case substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
             return {DB::JoinKind::Inner, DB::JoinStrictness::All};
-        case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
+        case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: {
+            if (is_existence_join)
+                return {DB::JoinKind::Left, DB::JoinStrictness::Any};
             return {DB::JoinKind::Left, DB::JoinStrictness::Semi};
+        }
         case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI:
             return {DB::JoinKind::Left, DB::JoinStrictness::Anti};
         case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
diff --git a/cpp-ch/local-engine/Common/CHUtil.h 
b/cpp-ch/local-engine/Common/CHUtil.h
index 98139fb49..b45c6ab3c 100644
--- a/cpp-ch/local-engine/Common/CHUtil.h
+++ b/cpp-ch/local-engine/Common/CHUtil.h
@@ -313,7 +313,7 @@ class JoinUtil
 {
 public:
     static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols);
-    static std::pair<DB::JoinKind, DB::JoinStrictness> 
getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type);
+    static std::pair<DB::JoinKind, DB::JoinStrictness> 
getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type, bool 
is_existence_join);
     static std::pair<DB::JoinKind, DB::JoinStrictness> 
getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type);
 };
 
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp 
b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
index 4d5eae6dc..c21cc8ba3 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
@@ -99,6 +99,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
     const std::string & join_keys,
     jint join_type,
     bool has_mixed_join_condition,
+    bool is_existence_join,
     const std::string & named_struct)
 {
     auto join_key_list = Poco::StringTokenizer(join_keys, ",");
@@ -112,7 +113,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
     if (key.starts_with("BuiltBNLJBroadcastTable-"))
         std::tie(kind, strictness) = 
JoinUtil::getCrossJoinKindAndStrictness(static_cast<substrait::CrossRel_JoinType>(join_type));
     else
-        std::tie(kind, strictness) = 
JoinUtil::getJoinKindAndStrictness(static_cast<substrait::JoinRel_JoinType>(join_type));
+        std::tie(kind, strictness) = 
JoinUtil::getJoinKindAndStrictness(static_cast<substrait::JoinRel_JoinType>(join_type),
 is_existence_join);
 
 
     substrait::NamedStruct substrait_struct;
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h 
b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
index 3d2e67f9d..a97bd77a8 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
@@ -37,6 +37,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
     const std::string & join_keys,
     jint join_type,
     bool has_mixed_join_condition,
+    bool is_existence_join,
     const std::string & named_struct);
 void cleanBuildHashTable(const std::string & hash_table_id, jlong instance);
 std::shared_ptr<StorageJoinFromReadBuffer> getJoin(const std::string & 
hash_table_id);
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp 
b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
index 460311e28..24ba7acdb 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
@@ -33,6 +33,7 @@
 #include <Processors/QueryPlan/JoinStep.h>
 #include <google/protobuf/wrappers.pb.h>
 #include <Common/CHUtil.h>
+#include <Functions/FunctionFactory.h>
 
 #include <Common/logger_useful.h>
 
@@ -50,13 +51,13 @@ using namespace DB;
 
 namespace local_engine
 {
-std::shared_ptr<DB::TableJoin> 
createDefaultTableJoin(substrait::JoinRel_JoinType join_type)
+std::shared_ptr<DB::TableJoin> 
createDefaultTableJoin(substrait::JoinRel_JoinType join_type, bool 
is_existence_join)
 {
     auto & global_context = SerializedPlanParser::global_context;
     auto table_join = std::make_shared<TableJoin>(
         global_context->getSettings(), 
global_context->getGlobalTemporaryVolume(), 
global_context->getTempDataOnDisk());
 
-    std::pair<DB::JoinKind, DB::JoinStrictness> kind_and_strictness = 
JoinUtil::getJoinKindAndStrictness(join_type);
+    std::pair<DB::JoinKind, DB::JoinStrictness> kind_and_strictness = 
JoinUtil::getJoinKindAndStrictness(join_type, is_existence_join);
     table_join->setKind(kind_and_strictness.first);
     table_join->setStrictness(kind_and_strictness.second);
     return table_join;
@@ -218,7 +219,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
         renamePlanColumns(*left, *right, *storage_join);
     }
 
-    auto table_join = createDefaultTableJoin(join.type());
+    auto table_join = createDefaultTableJoin(join.type(), 
join_opt_info.is_existence_join);
     DB::Block right_header_before_convert_step = 
right->getCurrentDataStream().header;
     addConvertStep(*table_join, *left, *right);
 
@@ -350,11 +351,39 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
         query_plan = std::make_unique<QueryPlan>();
         query_plan->unitePlans(std::move(join_step), {std::move(plans)});
     }
+
     JoinUtil::reorderJoinOutput(*query_plan, after_join_names);
+    /// Need to project the right table column into boolean type
+    if (join_opt_info.is_existence_join)
+    {
+        existenceJoinPostProject(*query_plan, left_names);
+    }
 
     return query_plan;
 }
 
+
+/// We use left any join to implement ExistenceJoin.
+/// The result columns of ExistenceJoin are left table columns + one flag 
column.
+/// The flag column indicates whether a left row is matched or not. We build 
the flag column here.
+/// The input plan's header is left table columns + right table columns. If 
the right table column of a row is null,
+/// we mark the flag 0, otherwise mark it 1.
+void JoinRelParser::existenceJoinPostProject(DB::QueryPlan & plan, const 
DB::Names & left_input_cols)
+{
+    auto actions_dag = 
std::make_shared<DB::ActionsDAG>(plan.getCurrentDataStream().header.getColumnsWithTypeAndName());
+    const auto * right_col_node = actions_dag->getInputs().back();
+    auto function_builder = DB::FunctionFactory::instance().get("isNotNull", 
getContext());
+    const auto * not_null_node = &actions_dag->addFunction(function_builder, 
{right_col_node}, right_col_node->result_name);
+    actions_dag->addOrReplaceInOutputs(*not_null_node);
+    DB::Names required_cols = left_input_cols;
+    required_cols.emplace_back(not_null_node->result_name);
+    actions_dag->removeUnusedActions(required_cols);
+    auto project_step = 
std::make_unique<DB::ExpressionStep>(plan.getCurrentDataStream(), actions_dag);
+    project_step->setStepDescription("ExistenceJoin Post Project");
+    steps.emplace_back(project_step.get());
+    plan.addStep(std::move(project_step));
+}
+
 void JoinRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & 
left, DB::QueryPlan & right)
 {
     /// If the columns name in right table is duplicated with left table, we 
need to rename the right table's columns.
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h 
b/cpp-ch/local-engine/Parser/JoinRelParser.h
index e6d31e6d3..ee1155cb4 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.h
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.h
@@ -66,6 +66,8 @@ private:
 
     void addPostFilter(DB::QueryPlan & plan, const substrait::JoinRel & join);
 
+    void existenceJoinPostProject(DB::QueryPlan & plan, const DB::Names & 
left_input_cols);
+
     static std::unordered_set<DB::JoinTableSide> 
extractTableSidesFromExpression(
         const substrait::Expression & expr, const DB::Block & left_header, 
const DB::Block & right_header);
 };
diff --git a/cpp-ch/local-engine/local_engine_jni.cpp 
b/cpp-ch/local-engine/local_engine_jni.cpp
index 17d087bb8..a6ca55052 100644
--- a/cpp-ch/local-engine/local_engine_jni.cpp
+++ b/cpp-ch/local-engine/local_engine_jni.cpp
@@ -1094,6 +1094,7 @@ JNIEXPORT jlong 
Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild
     jstring join_key_,
     jint join_type_,
     jboolean has_mixed_join_condition,
+    jboolean is_existence_join,
     jbyteArray named_struct)
 {
     LOCAL_ENGINE_JNI_METHOD_START
@@ -1107,7 +1108,7 @@ JNIEXPORT jlong 
Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild
     DB::CompressedReadBuffer input(read_buffer_from_java_array);
     local_engine::configureCompressedReadBuffer(input);
     const auto * obj = 
make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(
-        hash_table_id, input, row_count_, join_key, join_type_, 
has_mixed_join_condition, struct_string));
+        hash_table_id, input, row_count_, join_key, join_type_, 
has_mixed_join_condition, is_existence_join, struct_string));
     return obj->instance();
     LOCAL_ENGINE_JNI_METHOD_END(env, 0)
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to