This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 751b22aba44 branch-4.0: [Improvement](join) add direct mapping opt for
join #57960 (#58309)
751b22aba44 is described below
commit 751b22aba44488ea8fd6e56d11c1fce5ca0eb9b5
Author: github-actions[bot]
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Nov 25 09:22:58 2025 +0800
branch-4.0: [Improvement](join) add direct mapping opt for join #57960
(#58309)
Cherry-picked from #57960
Co-authored-by: Pxl <[email protected]>
---
be/src/pipeline/common/join_utils.h | 104 ++++++++++++++++++++-
be/src/pipeline/exec/hashjoin_build_sink.cpp | 44 +++++----
be/src/pipeline/exec/hashjoin_build_sink.h | 6 +-
.../exec/join/process_hash_table_probe_impl.h | 33 ++++---
.../exec/partitioned_hash_join_probe_operator.cpp | 2 +-
be/src/vec/common/hash_table/hash_map_context.h | 65 +++++++++++++
be/src/vec/common/hash_table/join_hash_table.h | 59 ++++++++----
7 files changed, 253 insertions(+), 60 deletions(-)
diff --git a/be/src/pipeline/common/join_utils.h
b/be/src/pipeline/common/join_utils.h
index aac9932c058..c10b748f82f 100644
--- a/be/src/pipeline/common/join_utils.h
+++ b/be/src/pipeline/common/join_utils.h
@@ -38,13 +38,21 @@ using JoinOpVariants =
std::integral_constant<TJoinOp::type,
TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN>>;
template <class T>
-using PrimaryTypeHashTableContext = vectorized::MethodOneNumber<T,
JoinHashMap<T, HashCRC32<T>>>;
+using PrimaryTypeHashTableContext =
+ vectorized::MethodOneNumber<T, JoinHashMap<T, HashCRC32<T>, false>>;
+
+template <class T>
+using DirectPrimaryTypeHashTableContext =
+ vectorized::MethodOneNumberDirect<T, JoinHashMap<T, HashCRC32<T>,
true>>;
template <class Key>
-using FixedKeyHashTableContext = vectorized::MethodKeysFixed<JoinHashMap<Key,
HashCRC32<Key>>>;
+using FixedKeyHashTableContext =
+ vectorized::MethodKeysFixed<JoinHashMap<Key, HashCRC32<Key>, false>>;
-using SerializedHashTableContext =
vectorized::MethodSerialized<JoinHashMap<StringRef>>;
-using MethodOneString =
vectorized::MethodStringNoCache<JoinHashMap<StringRef>>;
+using SerializedHashTableContext =
+ vectorized::MethodSerialized<JoinHashMap<StringRef,
DefaultHash<StringRef>, false>>;
+using MethodOneString =
+ vectorized::MethodStringNoCache<JoinHashMap<StringRef,
DefaultHash<StringRef>, false>>;
using HashTableVariants = std::variant<
std::monostate, SerializedHashTableContext,
PrimaryTypeHashTableContext<vectorized::UInt8>,
@@ -53,6 +61,11 @@ using HashTableVariants = std::variant<
PrimaryTypeHashTableContext<vectorized::UInt64>,
PrimaryTypeHashTableContext<vectorized::UInt128>,
PrimaryTypeHashTableContext<vectorized::UInt256>,
+ DirectPrimaryTypeHashTableContext<vectorized::UInt8>,
+ DirectPrimaryTypeHashTableContext<vectorized::UInt16>,
+ DirectPrimaryTypeHashTableContext<vectorized::UInt32>,
+ DirectPrimaryTypeHashTableContext<vectorized::UInt64>,
+ DirectPrimaryTypeHashTableContext<vectorized::UInt128>,
FixedKeyHashTableContext<vectorized::UInt64>,
FixedKeyHashTableContext<vectorized::UInt128>,
FixedKeyHashTableContext<vectorized::UInt136>,
FixedKeyHashTableContext<vectorized::UInt256>, MethodOneString>;
@@ -109,4 +122,87 @@ struct JoinDataVariants {
}
};
+template <typename Method> // Scans the build keys and, if their span is small enough, switches every variant in variant_ptrs to a direct-mapping context.
+void primary_to_direct_mapping(Method* context, const
vectorized::ColumnRawPtrs& key_columns,
+ const
std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+ using FieldType = typename Method::Base::Key; // only Method's key type is used; the `context` pointer itself is never read in this body
+ FieldType max_key = std::numeric_limits<FieldType>::min(); // sentinel, replaced by the first scanned key
+ FieldType min_key = std::numeric_limits<FieldType>::max(); // sentinel, replaced by the first scanned key
+ // Only the first key column (key_columns[0]) is examined.
+ size_t num_rows = key_columns[0]->size();
+ if (key_columns[0]->is_nullable()) {
+ const FieldType* input_keys =
+ (FieldType*)assert_cast<const
vectorized::ColumnNullable*>(key_columns[0])
+ ->get_nested_column_ptr()
+ ->get_raw_data()
+ .data;
+ const vectorized::NullMap& null_map =
+ assert_cast<const
vectorized::ColumnNullable*>(key_columns[0])->get_null_map_data();
+ // skip first mocked row (row 0); rows flagged in null_map do not contribute to min/max
+ for (size_t i = 1; i < num_rows; i++) {
+ if (null_map[i]) {
+ continue;
+ }
+ max_key = std::max(max_key, input_keys[i]);
+ min_key = std::min(min_key, input_keys[i]);
+ }
+ } else {
+ const FieldType* input_keys =
(FieldType*)key_columns[0]->get_raw_data().data;
+ // skip first mocked row (row 0)
+ for (size_t i = 1; i < num_rows; i++) {
+ max_key = std::max(max_key, input_keys[i]);
+ min_key = std::min(min_key, input_keys[i]);
+ }
+ }
+ // Direct mapping is allowed only when max_key - min_key < (1 << 23) - 1; otherwise the variants are left untouched.
+ constexpr auto MAX_MAPPING_RANGE = 1 << 23;
+ bool allow_direct_mapping = (max_key >= min_key && max_key - min_key <
MAX_MAPPING_RANGE - 1);
+ if (allow_direct_mapping) { // NOTE(review): max_key < min_key presumably means no non-null key was scanned (sentinels untouched) — confirm
+ for (const auto& variant_ptr : variant_ptrs) { // emplace replaces the alternative currently held by each variant
+
variant_ptr->method_variant.emplace<DirectPrimaryTypeHashTableContext<FieldType>>(
+ max_key, min_key); // constructor argument order is (max_key, min_key)
+ }
+ }
+}
+
+template <typename Method> // Generic fallback of try_convert_to_direct_mapping for contexts without a dedicated overload.
+void try_convert_to_direct_mapping(
+ Method* method, const vectorized::ColumnRawPtrs& key_columns,
+ const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {} // empty body: the variants are left unchanged
+
+inline void try_convert_to_direct_mapping( // UInt8 single-number context: delegate to primary_to_direct_mapping
+ PrimaryTypeHashTableContext<vectorized::UInt8>* context,
+ const vectorized::ColumnRawPtrs& key_columns,
+ const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+ primary_to_direct_mapping(context, key_columns, variant_ptrs);
+}
+
+inline void try_convert_to_direct_mapping( // UInt16 single-number context
+ PrimaryTypeHashTableContext<vectorized::UInt16>* context,
+ const vectorized::ColumnRawPtrs& key_columns,
+ const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+ primary_to_direct_mapping(context, key_columns, variant_ptrs);
+}
+
+inline void try_convert_to_direct_mapping( // UInt32 single-number context
+ PrimaryTypeHashTableContext<vectorized::UInt32>* context,
+ const vectorized::ColumnRawPtrs& key_columns,
+ const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+ primary_to_direct_mapping(context, key_columns, variant_ptrs);
+}
+
+inline void try_convert_to_direct_mapping( // UInt64 single-number context
+ PrimaryTypeHashTableContext<vectorized::UInt64>* context,
+ const vectorized::ColumnRawPtrs& key_columns,
+ const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+ primary_to_direct_mapping(context, key_columns, variant_ptrs);
+}
+ // UInt128 is the widest key type with an overload in this hunk; every other context type binds to the empty template version.
+inline void try_convert_to_direct_mapping( // UInt128 single-number context
+ PrimaryTypeHashTableContext<vectorized::UInt128>* context,
+ const vectorized::ColumnRawPtrs& key_columns,
+ const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+ primary_to_direct_mapping(context, key_columns, variant_ptrs);
+}
+
} // namespace doris
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp
b/be/src/pipeline/exec/hashjoin_build_sink.cpp
index b7e9757b515..cbc22f7168d 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.cpp
+++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp
@@ -19,6 +19,7 @@
#include <cstdlib>
#include <string>
+#include <variant>
#include "pipeline/exec/hashjoin_probe_operator.h"
#include "pipeline/exec/operator.h"
@@ -87,8 +88,6 @@ Status HashJoinBuildSinkLocalState::init(RuntimeState* state,
LocalSinkStateInfo
_build_table_insert_timer = ADD_TIMER(record_profile,
"BuildTableInsertTime");
_build_expr_call_timer = ADD_TIMER(record_profile, "BuildExprCallTime");
- // Hash Table Init
- RETURN_IF_ERROR(_hash_table_init(state));
_runtime_filter_producer_helper =
std::make_shared<RuntimeFilterProducerHelper>(
_should_build_hash_table, p._is_broadcast_join);
RETURN_IF_ERROR(_runtime_filter_producer_helper->init(state,
_build_expr_ctxs,
@@ -137,7 +136,7 @@ size_t
HashJoinBuildSinkLocalState::get_reserve_mem_size(RuntimeState* state, bo
if (eos) {
const size_t rows = build_block_rows + state->batch_size();
- const auto bucket_size =
JoinHashTable<StringRef>::calc_bucket_size(rows);
+ const auto bucket_size = hash_join_table_calc_bucket_size(rows);
size_to_reserve += bucket_size * sizeof(uint32_t); //
JoinHashTable::first
size_to_reserve += rows * sizeof(uint32_t); //
JoinHashTable::next
@@ -184,10 +183,7 @@ size_t
HashJoinBuildSinkLocalState::get_reserve_mem_size(RuntimeState* state, bo
throw Exception(st);
}
- std::visit(vectorized::Overload {[&](std::monostate& arg) {
- LOG(FATAL) << "FATAL:
uninited hash table";
- __builtin_unreachable();
- },
+ std::visit(vectorized::Overload {[&](std::monostate& arg) {},
[&](auto&& hash_map_context) {
size_to_reserve +=
hash_map_context.estimated_size(
raw_ptrs,
(uint32_t)block.rows(), true,
@@ -361,13 +357,6 @@ Status
HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state,
auto& p = _parent->cast<HashJoinBuildSinkOperatorX>();
SCOPED_TIMER(_build_table_timer);
auto rows = (uint32_t)block.rows();
- if (UNLIKELY(rows == 0)) {
- return Status::OK();
- }
-
- LOG(INFO) << "build block rows: " << block.rows() << ", columns count: "
<< block.columns()
- << ", bytes/allocated_bytes: " <<
PrettyPrinter::print_bytes(block.bytes()) << "/"
- << PrettyPrinter::print_bytes(block.allocated_bytes());
// 1. Dispose the overflow of ColumnString
// 2. Finalize the ColumnVariant to speed up
for (auto& data : block) {
@@ -400,6 +389,8 @@ Status
HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state,
// Get the key column that needs to be built
RETURN_IF_ERROR(_extract_join_column(block, null_map_val, raw_ptrs,
_build_col_ids));
+ RETURN_IF_ERROR(_hash_table_init(state, raw_ptrs));
+
Status st = std::visit(
vectorized::Overload {
[&](std::monostate& arg, auto join_op,
@@ -448,7 +439,8 @@ void
HashJoinBuildSinkLocalState::_set_build_side_has_external_nullmap(
}
}
-Status HashJoinBuildSinkLocalState::_hash_table_init(RuntimeState* state) {
+Status HashJoinBuildSinkLocalState::_hash_table_init(RuntimeState* state,
+ const
vectorized::ColumnRawPtrs& raw_ptrs) {
auto& p = _parent->cast<HashJoinBuildSinkOperatorX>();
std::vector<vectorized::DataTypePtr> data_types;
for (size_t i = 0; i < _build_expr_ctxs.size(); ++i) {
@@ -465,10 +457,21 @@ Status
HashJoinBuildSinkLocalState::_hash_table_init(RuntimeState* state) {
if (_build_expr_ctxs.size() == 1) {
p._should_keep_hash_key_column = true;
}
- return init_hash_method<JoinDataVariants>(
- _shared_state->hash_table_variant_vector[p._use_shared_hash_table
? _task_idx : 0]
- .get(),
- data_types, true);
+
+ std::vector<std::shared_ptr<JoinDataVariants>> variant_ptrs;
+ if (p._is_broadcast_join && p._use_shared_hash_table) {
+ variant_ptrs = _shared_state->hash_table_variant_vector;
+ } else {
+ variant_ptrs.emplace_back(
+
_shared_state->hash_table_variant_vector[p._use_shared_hash_table ? _task_idx :
0]);
+ }
+
+ for (auto& variant_ptr : variant_ptrs) {
+ RETURN_IF_ERROR(init_hash_method<JoinDataVariants>(variant_ptr.get(),
data_types, true));
+ }
+ std::visit([&](auto&& arg) { try_convert_to_direct_mapping(&arg, raw_ptrs,
variant_ptrs); },
+ variant_ptrs[0]->method_variant);
+ return Status::OK();
}
HashJoinBuildSinkOperatorX::HashJoinBuildSinkOperatorX(ObjectPool* pool, int
operator_id,
@@ -636,6 +639,9 @@ Status HashJoinBuildSinkOperatorX::sink(RuntimeState*
state, vectorized::Block*
std::is_same_v<std::decay_t<decltype(src)>,
std::decay_t<decltype(dst)>>)
{
dst.hash_table = src.hash_table;
+ } else {
+ throw Exception(Status::InternalError(
+ "Hash table type mismatch when share hash
table"));
}
},
local_state._shared_state->hash_table_variant_vector[local_state._task_idx]
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h
b/be/src/pipeline/exec/hashjoin_build_sink.h
index aec6adf084c..f1c94f24c38 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.h
+++ b/be/src/pipeline/exec/hashjoin_build_sink.h
@@ -50,7 +50,7 @@ public:
[[nodiscard]] MOCK_FUNCTION size_t get_reserve_mem_size(RuntimeState*
state, bool eos);
protected:
- Status _hash_table_init(RuntimeState* state);
+ Status _hash_table_init(RuntimeState* state, const
vectorized::ColumnRawPtrs& raw_ptrs);
void _set_build_side_has_external_nullmap(vectorized::Block& block,
const std::vector<int>&
res_col_ids);
Status _do_evaluate(vectorized::Block& block,
vectorized::VExprContextSPtrs& exprs,
@@ -204,8 +204,8 @@ struct ProcessHashTableBuild {
}
SCOPED_TIMER(_parent->_build_table_insert_timer);
- hash_table_ctx.hash_table->template prepare_build<JoinOpType>(_rows,
_batch_size,
-
*has_null_key);
+ hash_table_ctx.hash_table->template prepare_build<JoinOpType>(
+ _rows, _batch_size, *has_null_key,
hash_table_ctx.direct_mapping_range());
// In order to make the null keys equal when using single null eq, all
null keys need to be set to default value.
if (_build_raw_ptrs.size() == 1 && null_map) {
diff --git a/be/src/pipeline/exec/join/process_hash_table_probe_impl.h
b/be/src/pipeline/exec/join/process_hash_table_probe_impl.h
index 0424660db26..1f1edec4335 100644
--- a/be/src/pipeline/exec/join/process_hash_table_probe_impl.h
+++ b/be/src/pipeline/exec/join/process_hash_table_probe_impl.h
@@ -787,20 +787,25 @@ struct ExtractType<T(U)> {
ExtractType<void(T)>::Type & hash_table_ctx,
vectorized::MutableBlock & mutable_block, \
vectorized::Block * output_block, bool* eos, bool is_mark_join);
-#define INSTANTIATION_FOR(JoinOpType)
\
- template struct ProcessHashTableProbe<JoinOpType>;
\
-
\
- INSTANTIATION(JoinOpType, (SerializedHashTableContext));
\
- INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt8>)); \
- INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt16>)); \
- INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt32>)); \
- INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt64>)); \
- INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt128>)); \
- INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt256>)); \
- INSTANTIATION(JoinOpType, (FixedKeyHashTableContext<vectorized::UInt64>));
\
- INSTANTIATION(JoinOpType,
(FixedKeyHashTableContext<vectorized::UInt128>)); \
- INSTANTIATION(JoinOpType,
(FixedKeyHashTableContext<vectorized::UInt136>)); \
- INSTANTIATION(JoinOpType,
(FixedKeyHashTableContext<vectorized::UInt256>)); \
+#define INSTANTIATION_FOR(JoinOpType)
\
+ template struct ProcessHashTableProbe<JoinOpType>;
\
+
\
+ INSTANTIATION(JoinOpType, (SerializedHashTableContext));
\
+ INSTANTIATION(JoinOpType,
(DirectPrimaryTypeHashTableContext<vectorized::UInt8>)); \
+ INSTANTIATION(JoinOpType,
(DirectPrimaryTypeHashTableContext<vectorized::UInt16>)); \
+ INSTANTIATION(JoinOpType,
(DirectPrimaryTypeHashTableContext<vectorized::UInt32>)); \
+ INSTANTIATION(JoinOpType,
(DirectPrimaryTypeHashTableContext<vectorized::UInt64>)); \
+ INSTANTIATION(JoinOpType,
(DirectPrimaryTypeHashTableContext<vectorized::UInt128>)); \
+ INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt8>)); \
+ INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt16>)); \
+ INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt32>)); \
+ INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt64>)); \
+ INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt128>)); \
+ INSTANTIATION(JoinOpType,
(PrimaryTypeHashTableContext<vectorized::UInt256>)); \
+ INSTANTIATION(JoinOpType, (FixedKeyHashTableContext<vectorized::UInt64>));
\
+ INSTANTIATION(JoinOpType,
(FixedKeyHashTableContext<vectorized::UInt128>)); \
+ INSTANTIATION(JoinOpType,
(FixedKeyHashTableContext<vectorized::UInt136>)); \
+ INSTANTIATION(JoinOpType,
(FixedKeyHashTableContext<vectorized::UInt256>)); \
INSTANTIATION(JoinOpType, (MethodOneString));
#include "common/compile_check_end.h"
} // namespace doris::pipeline
diff --git a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
index 223a7f24013..dd244fedf57 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
@@ -804,7 +804,7 @@ size_t
PartitionedHashJoinProbeOperatorX::get_reserve_mem_size(RuntimeState* sta
(local_state._recovered_build_block ?
local_state._recovered_build_block->rows()
: 0) +
state->batch_size();
- size_t bucket_size = JoinHashTable<StringRef>::calc_bucket_size(rows);
+ size_t bucket_size = hash_join_table_calc_bucket_size(rows);
size_to_reserve += bucket_size * sizeof(uint32_t); //
JoinHashTable::first
size_to_reserve += rows * sizeof(uint32_t); //
JoinHashTable::next
diff --git a/be/src/vec/common/hash_table/hash_map_context.h
b/be/src/vec/common/hash_table/hash_map_context.h
index 519e326268d..8419cea4341 100644
--- a/be/src/vec/common/hash_table/hash_map_context.h
+++ b/be/src/vec/common/hash_table/hash_map_context.h
@@ -17,6 +17,7 @@
#pragma once
+#include <cstdint>
#include <type_traits>
#include <utility>
@@ -157,6 +158,8 @@ struct MethodBaseInner {
virtual void insert_keys_into_columns(std::vector<Key>& keys,
MutableColumns& key_columns,
uint32_t num_rows) = 0;
+
+ virtual uint32_t direct_mapping_range() { return 0; }
};
template <typename T>
@@ -411,6 +414,68 @@ struct MethodOneNumber : public MethodBase<TData> {
}
};
+template <typename FieldType, typename TData> // Single-number key method that computes the bucket index as key - _min_key + 1 instead of hashing.
+struct MethodOneNumberDirect : public MethodOneNumber<FieldType, TData> {
+ using Base = MethodOneNumber<FieldType, TData>;
+ using Base::init_iterator;
+ using Base::hash_table;
+ using State = ColumnsHashing::HashMethodOneNumber<typename Base::Value,
typename Base::Mapped,
+ FieldType>;
+ FieldType _max_key; // upper bound of the probe-side range check
+ FieldType _min_key; // the key equal to _min_key maps to bucket 1
+ // Note the argument order: max_key first, then min_key.
+ MethodOneNumberDirect(FieldType max_key, FieldType min_key)
+ : _max_key(max_key), _min_key(min_key) {}
+ // Points Base::keys at the raw data of key_columns[0] (nested column when nullable) and fills Base::bucket_nums for num_rows rows.
+ void init_serialized_keys(const ColumnRawPtrs& key_columns, uint32_t
num_rows,
+ const uint8_t* null_map = nullptr, bool is_join
= false,
+ bool is_build = false, uint32_t bucket_size = 0)
override {
+ Base::keys = (FieldType*)(key_columns[0]->is_nullable()
+ ? assert_cast<const
ColumnNullable*>(key_columns[0])
+ ->get_nested_column_ptr()
+ ->get_raw_data()
+ .data
+ :
key_columns[0]->get_raw_data().data);
+ CHECK(is_join); // callers must pass is_join == true despite the false default
+ CHECK_EQ(bucket_size, direct_mapping_range()); // callers must pass exactly the range reported by direct_mapping_range()
+ Base::bucket_nums.resize(num_rows);
+ // Bucket layout: 0 = probe key outside [_min_key, _max_key]; key - _min_key + 1 = in-range key; bucket_size = null row.
+ if (null_map == nullptr) {
+ if (is_build) { // build side: rows start at 1 (bucket_nums[0] is not written) and keys are not range-checked
+ for (uint32_t k = 1; k < num_rows; ++k) {
+ Base::bucket_nums[k] = uint32_t(Base::keys[k] - _min_key +
1);
+ }
+ } else { // probe side: rows start at 0 and out-of-range keys go to bucket 0
+ for (uint32_t k = 0; k < num_rows; ++k) {
+ Base::bucket_nums[k] = (Base::keys[k] >= _min_key &&
Base::keys[k] <= _max_key)
+ ? uint32_t(Base::keys[k] -
_min_key + 1)
+ : 0;
+ }
+ }
+ } else {
+ if (is_build) { // build side with nulls; NOTE(review): presumably safe without a range check because min/max come from this same build column — confirm
+ for (uint32_t k = 1; k < num_rows; ++k) {
+ Base::bucket_nums[k] =
+ null_map[k] ? bucket_size : uint32_t(Base::keys[k]
- _min_key + 1);
+ }
+ } else { // probe side with nulls: null rows go to bucket `bucket_size`, out-of-range keys to bucket 0
+ for (uint32_t k = 0; k < num_rows; ++k) {
+ Base::bucket_nums[k] =
+ null_map[k] ? bucket_size
+ : (Base::keys[k] >= _min_key && Base::keys[k] <=
_max_key)
+ ? uint32_t(Base::keys[k] - _min_key + 1)
+ : 0;
+ }
+ }
+ }
+ }
+
+ uint32_t direct_mapping_range() override {
+ // +2: one because max_key is inclusive, one because bucket 0 is kept for out-of-range probe values
+ return static_cast<uint32_t>(_max_key - _min_key + 2); // NOTE(review): narrows to uint32_t; relies on the constructor's caller having bounded max_key - min_key — confirm
+ }
+};
+
template <typename TData>
struct MethodKeysFixed : public MethodBase<TData> {
using Base = MethodBase<TData>;
diff --git a/be/src/vec/common/hash_table/join_hash_table.h
b/be/src/vec/common/hash_table/join_hash_table.h
index c6227591545..9426829e056 100644
--- a/be/src/vec/common/hash_table/join_hash_table.h
+++ b/be/src/vec/common/hash_table/join_hash_table.h
@@ -25,11 +25,17 @@
#include "common/status.h"
#include "vec/columns/column_filter_helper.h"
#include "vec/common/custom_allocator.h"
-#include "vec/common/hash_table/hash.h"
namespace doris {
#include "common/compile_check_begin.h"
-template <typename Key, typename Hash = DefaultHash<Key>>
+
+inline uint32_t hash_join_table_calc_bucket_size(size_t num_elem) { // Bucket count for num_elem rows, capped at int32 max + 1.
+ size_t expect_bucket_size = num_elem + (num_elem - 1) / 7; // NOTE(review): num_elem == 0 makes (num_elem - 1) wrap around (size_t is unsigned) — confirm callers never pass 0
+ return
(uint32_t)std::min(phmap::priv::NormalizeCapacity(expect_bucket_size) + 1,
+
static_cast<size_t>(std::numeric_limits<int32_t>::max()) + 1);
+}
+
+template <typename Key, typename Hash, bool DirectMapping>
class JoinHashTable {
public:
using key_type = Key;
@@ -37,25 +43,24 @@ public:
using value_type = void*;
size_t hash(const Key& x) const { return Hash()(x); }
- static uint32_t calc_bucket_size(size_t num_elem) {
- size_t expect_bucket_size = num_elem + (num_elem - 1) / 7;
- return
(uint32_t)std::min(phmap::priv::NormalizeCapacity(expect_bucket_size) + 1,
-
static_cast<size_t>(std::numeric_limits<int32_t>::max()) + 1);
- }
-
size_t get_byte_size() const {
auto cal_vector_mem = [](const auto& vec) { return vec.capacity() *
sizeof(vec[0]); };
return cal_vector_mem(visited) + cal_vector_mem(first) +
cal_vector_mem(next);
}
template <int JoinOpType>
- void prepare_build(size_t num_elem, int batch_size, bool has_null_key) {
+ void prepare_build(size_t num_elem, int batch_size, bool has_null_key,
+ uint32_t force_bucket_size) {
_has_null_key = has_null_key;
// the first row in build side is not really from build side table
_empty_build_side = num_elem <= 1;
max_batch_size = batch_size;
- bucket_size = calc_bucket_size(num_elem + 1);
+ if constexpr (DirectMapping) {
+ bucket_size = force_bucket_size;
+ } else {
+ bucket_size = hash_join_table_calc_bucket_size(num_elem + 1);
+ }
first.resize(bucket_size + 1);
next.resize(num_elem);
@@ -217,6 +222,13 @@ public:
}
private:
+ bool _eq(const Key& lhs, const Key& rhs) const { // Key comparison used by the probe loops; always true when DirectMapping is set.
+ if (DirectMapping) { // NOTE(review): plain `if` on a template constant; prepare_build uses `if constexpr (DirectMapping)` — consider matching it
+ return true; // presumably valid because each direct-mapped bucket corresponds to a single key value — confirm
+ }
+ return lhs == rhs;
+ }
+
template <int JoinOpType>
auto _process_null_aware_left_half_join_for_empty_build_side(int
probe_idx, int probe_rows,
uint32_t*
__restrict probe_idxs,
@@ -245,11 +257,20 @@ private:
while (probe_idx < probe_rows) {
auto build_idx = build_idx_map[probe_idx];
- while (build_idx) {
- if (!visited[build_idx] && keys[probe_idx] ==
build_keys[build_idx]) {
- visited[build_idx] = 1;
+ if constexpr (DirectMapping) {
+ if (!visited[build_idx]) {
+ while (build_idx) {
+ visited[build_idx] = 1;
+ build_idx = next[build_idx];
+ }
+ }
+ } else {
+ while (build_idx) {
+ if (!visited[build_idx] && _eq(keys[probe_idx],
build_keys[build_idx])) {
+ visited[build_idx] = 1;
+ }
+ build_idx = next[build_idx];
}
- build_idx = next[build_idx];
}
probe_idx++;
}
@@ -293,7 +314,7 @@ private:
auto do_the_probe = [&]() {
while (build_idx && matched_cnt < batch_size) {
- if (keys[probe_idx] == build_keys[build_idx]) {
+ if (_eq(keys[probe_idx], build_keys[build_idx])) {
build_idxs[matched_cnt] = build_idx;
probe_idxs[matched_cnt] = probe_idx;
matched_cnt++;
@@ -347,7 +368,7 @@ private:
auto do_the_probe = [&]() {
while (build_idx && matched_cnt < batch_size) {
- if (keys[probe_idx] == build_keys[build_idx]) {
+ if (_eq(keys[probe_idx], build_keys[build_idx])) {
probe_idxs[matched_cnt] = probe_idx;
build_idxs[matched_cnt] = build_idx;
matched_cnt++;
@@ -408,7 +429,7 @@ private:
}
while (build_idx && matched_cnt < batch_size) {
- if (picking_null_keys || keys[probe_idx] ==
build_keys[build_idx]) {
+ if (picking_null_keys || _eq(keys[probe_idx],
build_keys[build_idx])) {
build_idxs[matched_cnt] = build_idx;
probe_idxs[matched_cnt] = probe_idx;
null_flags[matched_cnt] = picking_null_keys;
@@ -476,7 +497,7 @@ private:
bool _empty_build_side = true;
};
-template <typename Key, typename Hash = DefaultHash<Key>>
-using JoinHashMap = JoinHashTable<Key, Hash>;
+template <typename Key, typename Hash, bool DirectMapping>
+using JoinHashMap = JoinHashTable<Key, Hash, DirectMapping>;
#include "common/compile_check_end.h"
} // namespace doris
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]