This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new d1b3e9918 fixed missing columns when there is mixed join conditions
(#5997)
d1b3e9918 is described below
commit d1b3e9918fcd087f348d0faf787e42246650b502
Author: lgbo <[email protected]>
AuthorDate: Thu Jun 6 13:43:20 2024 +0800
fixed missing columns when there is mixed join conditions (#5997)
---
.../gluten/vectorized/StorageJoinBuilder.java | 2 +
.../execution/CHHashJoinExecTransformer.scala | 20 +++++++-
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 12 ++++-
.../benchmarks/CHHashBuildBenchmark.scala | 2 +-
cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp | 2 +
cpp-ch/local-engine/Join/BroadCastJoinBuilder.h | 1 +
.../Join/StorageJoinFromReadBuffer.cpp | 54 +++++++++++++++++++++-
.../local-engine/Join/StorageJoinFromReadBuffer.h | 6 +++
cpp-ch/local-engine/local_engine_jni.cpp | 14 ++++--
9 files changed, 104 insertions(+), 9 deletions(-)
diff --git
a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
index 065be9de2..9cb49b6a2 100644
---
a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
+++
b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
@@ -44,6 +44,7 @@ public class StorageJoinBuilder {
long rowCount,
String joinKeys,
int joinType,
+ boolean hasMixedFiltCondition,
byte[] namedStruct);
private StorageJoinBuilder() {}
@@ -79,6 +80,7 @@ public class StorageJoinBuilder {
rowCount,
joinKey,
SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(),
+ broadCastContext.hasMixedFiltCondition(),
toNameStruct(output).toByteArray());
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 6004f7f86..a7e7769e7 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -82,6 +82,7 @@ case class CHBroadcastBuildSideRDD(
case class BroadCastHashJoinContext(
buildSideJoinKeys: Seq[Expression],
joinType: JoinType,
+ hasMixedFiltCondition: Boolean,
buildSideStructure: Seq[Attribute],
buildHashTableId: String)
@@ -139,9 +140,26 @@ case class CHBroadcastHashJoinExecTransformer(
}
val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
val context =
- BroadCastHashJoinContext(buildKeyExprs, joinType, buildPlan.output,
buildHashTableId)
+ BroadCastHashJoinContext(
+ buildKeyExprs,
+ joinType,
+ isMixedCondition(condition),
+ buildPlan.output,
+ buildHashTableId)
val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast,
context)
// FIXME: Do we have to make build side a RDD?
streamedRDD :+ broadcastRDD
}
+
+ def isMixedCondition(cond: Option[Expression]): Boolean = {
+ val res = if (cond.isDefined) {
+ val leftOutputSet = left.outputSet
+ val rightOutputSet = right.outputSet
+ val allReferences = cond.get.references
+ !(allReferences.subsetOf(leftOutputSet) ||
allReferences.subsetOf(rightOutputSet))
+ } else {
+ false
+ }
+ res
+ }
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index ada980a20..ee495457e 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2593,13 +2593,21 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
spark.sql("create table ineq_join_t2 (key bigint, value bigint) using
parquet");
spark.sql("insert into ineq_join_t1 values(1, 1), (2, 2), (3, 3), (4,
4), (5, 5)");
spark.sql("insert into ineq_join_t2 values(2, 2), (2, 1), (3, 3), (4,
6), (5, 3)");
- val sql =
+ val sql1 =
"""
| select t1.key, t1.value, t2.key, t2.value from ineq_join_t1 as t1
| left join ineq_join_t2 as t2
| on t1.key = t2.key and t1.value > t2.value
|""".stripMargin
- compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ compareResultsAgainstVanillaSpark(sql1, true, { _ => })
+
+ val sql2 =
+ """
+ | select t1.key, t1.value from ineq_join_t1 as t1
+ | left join ineq_join_t2 as t2
+ | on t1.key = t2.key and t1.value > t2.value and t1.value > t2.key
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql2, true, { _ => })
spark.sql("drop table ineq_join_t1")
spark.sql("drop table ineq_join_t2")
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
index 487433c46..8d4bee554 100644
---
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
@@ -104,7 +104,7 @@ object CHHashBuildBenchmark extends SqlBasedBenchmark with
CHSqlBasedBenchmark w
(
countsAndBytes.flatMap(_._2),
countsAndBytes.map(_._1).sum,
- BroadCastHashJoinContext(Seq(child.output.head), Inner, child.output, "")
+ BroadCastHashJoinContext(Seq(child.output.head), Inner, false,
child.output, "")
)
}
}
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
index f1b3ac2fb..1c79a00a7 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
@@ -82,6 +82,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
jlong row_count,
const std::string & join_keys,
substrait::JoinRel_JoinType join_type,
+ bool has_mixed_join_condition,
const std::string & named_struct)
{
auto join_key_list = Poco::StringTokenizer(join_keys, ",");
@@ -105,6 +106,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
true,
kind,
strictness,
+ has_mixed_join_condition,
columns_description,
ConstraintsDescription(),
key,
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
index 5aa1e0876..9a6837e35 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
@@ -36,6 +36,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
jlong row_count,
const std::string & join_keys,
substrait::JoinRel_JoinType join_type,
+ bool has_mixed_join_condition,
const std::string & named_struct);
void cleanBuildHashTable(const std::string & hash_table_id, jlong instance);
std::shared_ptr<StorageJoinFromReadBuffer> getJoin(const std::string &
hash_table_id);
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
index f0aec6af6..af306564a 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
@@ -74,6 +74,7 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer(
bool use_nulls_,
DB::JoinKind kind,
DB::JoinStrictness strictness,
+ bool has_mixed_join_condition,
const ColumnsDescription & columns,
const ConstraintsDescription & constraints,
const String & comment,
@@ -91,7 +92,11 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer(
key_names.push_back(RIHGT_COLUMN_PREFIX + name);
auto table_join = std::make_shared<DB::TableJoin>(SizeLimits(), true,
kind, strictness, key_names);
right_sample_block = rightSampleBlock(use_nulls, storage_metadata,
table_join->kind());
- buildJoin(in, right_sample_block, table_join);
+ /// If there are mixed join conditions, we need to build the hash join lazily,
which relies on the real table join.
+ if (!has_mixed_join_condition)
+ buildJoin(in, right_sample_block, table_join);
+ else
+ collectAllInputs(in, right_sample_block);
}
/// The column names may be different in two blocks.
@@ -135,6 +140,51 @@ void StorageJoinFromReadBuffer::buildJoin(DB::ReadBuffer &
in, const Block heade
}
}
+void StorageJoinFromReadBuffer::collectAllInputs(DB::ReadBuffer & in, const
DB::Block header)
+{
+ local_engine::NativeReader block_stream(in);
+ ProfileInfo info;
+ while (Block block = block_stream.read())
+ {
+ DB::ColumnsWithTypeAndName columns;
+ for (size_t i = 0; i < block.columns(); ++i)
+ {
+ const auto & column = block.getByPosition(i);
+ columns.emplace_back(convertColumnAsNecessary(column,
header.getByPosition(i)));
+ }
+ DB::Block final_block(columns);
+ info.update(final_block);
+ input_blocks.emplace_back(std::move(final_block));
+ }
+}
+
+void StorageJoinFromReadBuffer::buildJoinLazily(DB::Block header,
std::shared_ptr<DB::TableJoin> analyzed_join)
+{
+ {
+ std::shared_lock lock(join_mutex);
+ if (join)
+ return;
+ }
+ std::unique_lock lock(join_mutex);
+ if (join)
+ return;
+ join = std::make_shared<HashJoin>(analyzed_join, header, overwrite,
row_count);
+ while(!input_blocks.empty())
+ {
+ auto & block = *input_blocks.begin();
+ DB::ColumnsWithTypeAndName columns;
+ for (size_t i = 0; i < block.columns(); ++i)
+ {
+ const auto & column = block.getByPosition(i);
+ columns.emplace_back(convertColumnAsNecessary(column,
header.getByPosition(i)));
+ }
+ DB::Block final_block(columns);
+ join->addBlockToJoin(final_block, true);
+ input_blocks.pop_front();
+ }
+}
+
+
/// The column names of 'right_header' could be different from the ones in
`input_blocks`, and we must
/// use 'right_header' to build the HashJoin. Otherwise, it will cause
exceptions with name mismatches.
///
@@ -148,7 +198,7 @@ DB::JoinPtr
StorageJoinFromReadBuffer::getJoinLocked(std::shared_ptr<DB::TableJo
ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN,
"Table {} needs the same join_use_nulls setting as present in LEFT
or FULL JOIN",
storage_metadata.comment);
-
+ buildJoinLazily(getRightSampleBlock(), analyzed_join);
HashJoinPtr join_clone = std::make_shared<HashJoin>(analyzed_join,
right_sample_block);
/// reuseJoinedData will set the flag `HashJoin::from_storage_join` which
is required by `FilledStep`
join_clone->reuseJoinedData(static_cast<const HashJoin &>(*join));
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
index af623c0cd..ddefda69c 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
@@ -15,6 +15,7 @@
* limitations under the License.
*/
#pragma once
+#include <shared_mutex>
#include <Interpreters/JoinUtils.h>
#include <Storages/StorageInMemoryMetadata.h>
@@ -40,6 +41,7 @@ public:
bool use_nulls_,
DB::JoinKind kind,
DB::JoinStrictness strictness,
+ bool has_mixed_join_condition,
const DB::ColumnsDescription & columns_,
const DB::ConstraintsDescription & constraints_,
const String & comment,
@@ -58,9 +60,13 @@ private:
size_t row_count;
bool overwrite;
DB::Block right_sample_block;
+ std::shared_mutex join_mutex;
+ std::list<DB::Block> input_blocks;
std::shared_ptr<DB::HashJoin> join = nullptr;
void readAllBlocksFromInput(DB::ReadBuffer & in);
void buildJoin(DB::ReadBuffer & in, const DB::Block header,
std::shared_ptr<DB::TableJoin> analyzed_join);
+ void collectAllInputs(DB::ReadBuffer & in, const DB::Block header);
+ void buildJoinLazily(DB::Block header, std::shared_ptr<DB::TableJoin>
analyzed_join);
};
}
diff --git a/cpp-ch/local-engine/local_engine_jni.cpp
b/cpp-ch/local-engine/local_engine_jni.cpp
index be28b9fab..38f188293 100644
--- a/cpp-ch/local-engine/local_engine_jni.cpp
+++ b/cpp-ch/local-engine/local_engine_jni.cpp
@@ -1172,7 +1172,15 @@ JNIEXPORT jobject
Java_org_apache_spark_sql_execution_datasources_CHDatasourceJn
}
JNIEXPORT jlong
Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild(
- JNIEnv * env, jclass, jstring key, jbyteArray in, jlong row_count_,
jstring join_key_, jint join_type_, jbyteArray named_struct)
+ JNIEnv * env,
+ jclass,
+ jstring key,
+ jbyteArray in,
+ jlong row_count_,
+ jstring join_key_,
+ jint join_type_,
+ jboolean has_mixed_join_condition,
+ jbyteArray named_struct)
{
LOCAL_ENGINE_JNI_METHOD_START
const auto hash_table_id = jstring2string(env, key);
@@ -1186,8 +1194,8 @@ JNIEXPORT jlong
Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild
local_engine::ReadBufferFromByteArray read_buffer_from_java_array(in,
length);
DB::CompressedReadBuffer input(read_buffer_from_java_array);
local_engine::configureCompressedReadBuffer(input);
- const auto * obj
- =
make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(hash_table_id,
input, row_count_, join_key, join_type, struct_string));
+ const auto * obj =
make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(
+ hash_table_id, input, row_count_, join_key, join_type,
has_mixed_join_condition, struct_string));
env->ReleaseByteArrayElements(named_struct, struct_address, JNI_ABORT);
return obj->instance();
LOCAL_ENGINE_JNI_METHOD_END(env, 0)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]