This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new d1b3e9918 fixed missing columns when there is mixed join conditions 
(#5997)
d1b3e9918 is described below

commit d1b3e9918fcd087f348d0faf787e42246650b502
Author: lgbo <[email protected]>
AuthorDate: Thu Jun 6 13:43:20 2024 +0800

    fixed missing columns when there is mixed join conditions (#5997)
---
 .../gluten/vectorized/StorageJoinBuilder.java      |  2 +
 .../execution/CHHashJoinExecTransformer.scala      | 20 +++++++-
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 12 ++++-
 .../benchmarks/CHHashBuildBenchmark.scala          |  2 +-
 cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp  |  2 +
 cpp-ch/local-engine/Join/BroadCastJoinBuilder.h    |  1 +
 .../Join/StorageJoinFromReadBuffer.cpp             | 54 +++++++++++++++++++++-
 .../local-engine/Join/StorageJoinFromReadBuffer.h  |  6 +++
 cpp-ch/local-engine/local_engine_jni.cpp           | 14 ++++--
 9 files changed, 104 insertions(+), 9 deletions(-)

diff --git 
a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
 
b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
index 065be9de2..9cb49b6a2 100644
--- 
a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
+++ 
b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
@@ -44,6 +44,7 @@ public class StorageJoinBuilder {
       long rowCount,
       String joinKeys,
       int joinType,
+      boolean hasMixedFiltCondition,
       byte[] namedStruct);
 
   private StorageJoinBuilder() {}
@@ -79,6 +80,7 @@ public class StorageJoinBuilder {
         rowCount,
         joinKey,
         SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(),
+        broadCastContext.hasMixedFiltCondition(),
         toNameStruct(output).toByteArray());
   }
 
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 6004f7f86..a7e7769e7 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -82,6 +82,7 @@ case class CHBroadcastBuildSideRDD(
 case class BroadCastHashJoinContext(
     buildSideJoinKeys: Seq[Expression],
     joinType: JoinType,
+    hasMixedFiltCondition: Boolean,
     buildSideStructure: Seq[Attribute],
     buildHashTableId: String)
 
@@ -139,9 +140,26 @@ case class CHBroadcastHashJoinExecTransformer(
     }
     val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
     val context =
-      BroadCastHashJoinContext(buildKeyExprs, joinType, buildPlan.output, 
buildHashTableId)
+      BroadCastHashJoinContext(
+        buildKeyExprs,
+        joinType,
+        isMixedCondition(condition),
+        buildPlan.output,
+        buildHashTableId)
     val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, 
context)
     // FIXME: Do we have to make build side a RDD?
     streamedRDD :+ broadcastRDD
   }
+
+  def isMixedCondition(cond: Option[Expression]): Boolean = {
+    val res = if (cond.isDefined) {
+      val leftOutputSet = left.outputSet
+      val rightOutputSet = right.outputSet
+      val allReferences = cond.get.references
+      !(allReferences.subsetOf(leftOutputSet) || 
allReferences.subsetOf(rightOutputSet))
+    } else {
+      false
+    }
+    res
+  }
 }
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index ada980a20..ee495457e 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2593,13 +2593,21 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
       spark.sql("create table ineq_join_t2 (key bigint, value bigint) using 
parquet");
       spark.sql("insert into ineq_join_t1 values(1, 1), (2, 2), (3, 3), (4, 
4), (5, 5)");
       spark.sql("insert into ineq_join_t2 values(2, 2), (2, 1), (3, 3), (4, 
6), (5, 3)");
-      val sql =
+      val sql1 =
         """
           | select t1.key, t1.value, t2.key, t2.value from ineq_join_t1 as t1
           | left join ineq_join_t2 as t2
           | on t1.key = t2.key and t1.value > t2.value
           |""".stripMargin
-      compareResultsAgainstVanillaSpark(sql, true, { _ => })
+      compareResultsAgainstVanillaSpark(sql1, true, { _ => })
+
+      val sql2 =
+        """
+          | select t1.key, t1.value from ineq_join_t1 as t1
+          | left join ineq_join_t2 as t2
+          | on t1.key = t2.key and t1.value > t2.value and t1.value > t2.key
+          |""".stripMargin
+      compareResultsAgainstVanillaSpark(sql2, true, { _ => })
       spark.sql("drop table ineq_join_t1")
       spark.sql("drop table ineq_join_t2")
     }
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
index 487433c46..8d4bee554 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
@@ -104,7 +104,7 @@ object CHHashBuildBenchmark extends SqlBasedBenchmark with 
CHSqlBasedBenchmark w
     (
       countsAndBytes.flatMap(_._2),
       countsAndBytes.map(_._1).sum,
-      BroadCastHashJoinContext(Seq(child.output.head), Inner, child.output, "")
+      BroadCastHashJoinContext(Seq(child.output.head), Inner, false, 
child.output, "")
     )
   }
 }
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp 
b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
index f1b3ac2fb..1c79a00a7 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
@@ -82,6 +82,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
     jlong row_count,
     const std::string & join_keys,
     substrait::JoinRel_JoinType join_type,
+    bool has_mixed_join_condition,
     const std::string & named_struct)
 {
     auto join_key_list = Poco::StringTokenizer(join_keys, ",");
@@ -105,6 +106,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
         true,
         kind,
         strictness,
+        has_mixed_join_condition,
         columns_description,
         ConstraintsDescription(),
         key,
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h 
b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
index 5aa1e0876..9a6837e35 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
@@ -36,6 +36,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
     jlong row_count,
     const std::string & join_keys,
     substrait::JoinRel_JoinType join_type,
+    bool has_mixed_join_condition,
     const std::string & named_struct);
 void cleanBuildHashTable(const std::string & hash_table_id, jlong instance);
 std::shared_ptr<StorageJoinFromReadBuffer> getJoin(const std::string & 
hash_table_id);
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp 
b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
index f0aec6af6..af306564a 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
@@ -74,6 +74,7 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer(
     bool use_nulls_,
     DB::JoinKind kind,
     DB::JoinStrictness strictness,
+    bool has_mixed_join_condition,
     const ColumnsDescription & columns,
     const ConstraintsDescription & constraints,
     const String & comment,
@@ -91,7 +92,11 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer(
         key_names.push_back(RIHGT_COLUMN_PREFIX + name);
     auto table_join = std::make_shared<DB::TableJoin>(SizeLimits(), true, 
kind, strictness, key_names);
     right_sample_block = rightSampleBlock(use_nulls, storage_metadata, 
table_join->kind());
-    buildJoin(in, right_sample_block, table_join);
+    /// If there are mixed join conditions, the hash join must be built lazily, because it relies on the real table join.
+    if (!has_mixed_join_condition)
+        buildJoin(in, right_sample_block, table_join);
+    else
+        collectAllInputs(in, right_sample_block);
 }
 
 /// The column names may be different in two blocks.
@@ -135,6 +140,51 @@ void StorageJoinFromReadBuffer::buildJoin(DB::ReadBuffer & 
in, const Block heade
     }
 }
 
+void StorageJoinFromReadBuffer::collectAllInputs(DB::ReadBuffer & in, const 
DB::Block header)
+{
+    local_engine::NativeReader block_stream(in);
+    ProfileInfo info;
+    while (Block block = block_stream.read())
+    {
+        DB::ColumnsWithTypeAndName columns;
+        for (size_t i = 0; i < block.columns(); ++i)
+        {
+            const auto & column = block.getByPosition(i);
+            columns.emplace_back(convertColumnAsNecessary(column, 
header.getByPosition(i)));
+        }
+        DB::Block final_block(columns);
+        info.update(final_block);
+        input_blocks.emplace_back(std::move(final_block));
+    }
+}
+
+void StorageJoinFromReadBuffer::buildJoinLazily(DB::Block header, 
std::shared_ptr<DB::TableJoin> analyzed_join)
+{
+    {
+        std::shared_lock lock(join_mutex);
+        if (join)
+            return;
+    }
+    std::unique_lock lock(join_mutex);
+    if (join)
+        return;
+    join = std::make_shared<HashJoin>(analyzed_join, header, overwrite, 
row_count);
+    while(!input_blocks.empty())
+    {
+        auto & block = *input_blocks.begin();
+        DB::ColumnsWithTypeAndName columns;
+        for (size_t i = 0; i < block.columns(); ++i)
+        {
+            const auto & column = block.getByPosition(i);
+            columns.emplace_back(convertColumnAsNecessary(column, 
header.getByPosition(i)));
+        }
+        DB::Block final_block(columns);
+        join->addBlockToJoin(final_block, true);
+        input_blocks.pop_front();
+    }
+}
+
+
 /// The column names of 'rgiht_header' could be different from the ones in 
`input_blocks`, and we must
 /// use 'right_header' to build the HashJoin. Otherwise, it will cause 
exceptions with name mismatches.
 ///
@@ -148,7 +198,7 @@ DB::JoinPtr 
StorageJoinFromReadBuffer::getJoinLocked(std::shared_ptr<DB::TableJo
             ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN,
             "Table {} needs the same join_use_nulls setting as present in LEFT 
or FULL JOIN",
             storage_metadata.comment);
-
+    buildJoinLazily(getRightSampleBlock(), analyzed_join);
     HashJoinPtr join_clone = std::make_shared<HashJoin>(analyzed_join, 
right_sample_block);
     /// reuseJoinedData will set the flag `HashJoin::from_storage_join` which 
is required by `FilledStep`
     join_clone->reuseJoinedData(static_cast<const HashJoin &>(*join));
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h 
b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
index af623c0cd..ddefda69c 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
@@ -15,6 +15,7 @@
  * limitations under the License.
  */
 #pragma once
+#include <shared_mutex>
 #include <Interpreters/JoinUtils.h>
 #include <Storages/StorageInMemoryMetadata.h>
 
@@ -40,6 +41,7 @@ public:
         bool use_nulls_,
         DB::JoinKind kind,
         DB::JoinStrictness strictness,
+        bool has_mixed_join_condition,
         const DB::ColumnsDescription & columns_,
         const DB::ConstraintsDescription & constraints_,
         const String & comment,
@@ -58,9 +60,13 @@ private:
     size_t row_count;
     bool overwrite;
     DB::Block right_sample_block;
+    std::shared_mutex join_mutex;
+    std::list<DB::Block> input_blocks;
     std::shared_ptr<DB::HashJoin> join = nullptr;
 
     void readAllBlocksFromInput(DB::ReadBuffer & in);
     void buildJoin(DB::ReadBuffer & in, const DB::Block header, 
std::shared_ptr<DB::TableJoin> analyzed_join);
+    void collectAllInputs(DB::ReadBuffer & in, const DB::Block header);
+    void buildJoinLazily(DB::Block header, std::shared_ptr<DB::TableJoin> 
analyzed_join);
 };
 }
diff --git a/cpp-ch/local-engine/local_engine_jni.cpp 
b/cpp-ch/local-engine/local_engine_jni.cpp
index be28b9fab..38f188293 100644
--- a/cpp-ch/local-engine/local_engine_jni.cpp
+++ b/cpp-ch/local-engine/local_engine_jni.cpp
@@ -1172,7 +1172,15 @@ JNIEXPORT jobject 
Java_org_apache_spark_sql_execution_datasources_CHDatasourceJn
 }
 
 JNIEXPORT jlong 
Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild(
-    JNIEnv * env, jclass, jstring key, jbyteArray in, jlong row_count_, 
jstring join_key_, jint join_type_, jbyteArray named_struct)
+    JNIEnv * env,
+    jclass,
+    jstring key,
+    jbyteArray in,
+    jlong row_count_,
+    jstring join_key_,
+    jint join_type_,
+    jboolean has_mixed_join_condition,
+    jbyteArray named_struct)
 {
     LOCAL_ENGINE_JNI_METHOD_START
     const auto hash_table_id = jstring2string(env, key);
@@ -1186,8 +1194,8 @@ JNIEXPORT jlong 
Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild
     local_engine::ReadBufferFromByteArray read_buffer_from_java_array(in, 
length);
     DB::CompressedReadBuffer input(read_buffer_from_java_array);
     local_engine::configureCompressedReadBuffer(input);
-    const auto * obj
-        = 
make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(hash_table_id, 
input, row_count_, join_key, join_type, struct_string));
+    const auto * obj = 
make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(
+        hash_table_id, input, row_count_, join_key, join_type, 
has_mixed_join_condition, struct_string));
     env->ReleaseByteArrayElements(named_struct, struct_address, JNI_ABORT);
     return obj->instance();
     LOCAL_ENGINE_JNI_METHOD_END(env, 0)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to