This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 751b22aba44 branch-4.0: [Improvement](join) add direct mapping opt for 
join #57960 (#58309)
751b22aba44 is described below

commit 751b22aba44488ea8fd6e56d11c1fce5ca0eb9b5
Author: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Nov 25 09:22:58 2025 +0800

    branch-4.0: [Improvement](join) add direct mapping opt for join #57960 
(#58309)
    
    Cherry-picked from #57960
    
    Co-authored-by: Pxl <[email protected]>
---
 be/src/pipeline/common/join_utils.h                | 104 ++++++++++++++++++++-
 be/src/pipeline/exec/hashjoin_build_sink.cpp       |  44 +++++----
 be/src/pipeline/exec/hashjoin_build_sink.h         |   6 +-
 .../exec/join/process_hash_table_probe_impl.h      |  33 ++++---
 .../exec/partitioned_hash_join_probe_operator.cpp  |   2 +-
 be/src/vec/common/hash_table/hash_map_context.h    |  65 +++++++++++++
 be/src/vec/common/hash_table/join_hash_table.h     |  59 ++++++++----
 7 files changed, 253 insertions(+), 60 deletions(-)

diff --git a/be/src/pipeline/common/join_utils.h 
b/be/src/pipeline/common/join_utils.h
index aac9932c058..c10b748f82f 100644
--- a/be/src/pipeline/common/join_utils.h
+++ b/be/src/pipeline/common/join_utils.h
@@ -38,13 +38,21 @@ using JoinOpVariants =
                      std::integral_constant<TJoinOp::type, 
TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN>>;
 
 template <class T>
-using PrimaryTypeHashTableContext = vectorized::MethodOneNumber<T, 
JoinHashMap<T, HashCRC32<T>>>;
+using PrimaryTypeHashTableContext =
+        vectorized::MethodOneNumber<T, JoinHashMap<T, HashCRC32<T>, false>>;
+
+template <class T>
+using DirectPrimaryTypeHashTableContext =
+        vectorized::MethodOneNumberDirect<T, JoinHashMap<T, HashCRC32<T>, 
true>>;
 
 template <class Key>
-using FixedKeyHashTableContext = vectorized::MethodKeysFixed<JoinHashMap<Key, 
HashCRC32<Key>>>;
+using FixedKeyHashTableContext =
+        vectorized::MethodKeysFixed<JoinHashMap<Key, HashCRC32<Key>, false>>;
 
-using SerializedHashTableContext = 
vectorized::MethodSerialized<JoinHashMap<StringRef>>;
-using MethodOneString = 
vectorized::MethodStringNoCache<JoinHashMap<StringRef>>;
+using SerializedHashTableContext =
+        vectorized::MethodSerialized<JoinHashMap<StringRef, 
DefaultHash<StringRef>, false>>;
+using MethodOneString =
+        vectorized::MethodStringNoCache<JoinHashMap<StringRef, 
DefaultHash<StringRef>, false>>;
 
 using HashTableVariants = std::variant<
         std::monostate, SerializedHashTableContext, 
PrimaryTypeHashTableContext<vectorized::UInt8>,
@@ -53,6 +61,11 @@ using HashTableVariants = std::variant<
         PrimaryTypeHashTableContext<vectorized::UInt64>,
         PrimaryTypeHashTableContext<vectorized::UInt128>,
         PrimaryTypeHashTableContext<vectorized::UInt256>,
+        DirectPrimaryTypeHashTableContext<vectorized::UInt8>,
+        DirectPrimaryTypeHashTableContext<vectorized::UInt16>,
+        DirectPrimaryTypeHashTableContext<vectorized::UInt32>,
+        DirectPrimaryTypeHashTableContext<vectorized::UInt64>,
+        DirectPrimaryTypeHashTableContext<vectorized::UInt128>,
         FixedKeyHashTableContext<vectorized::UInt64>, 
FixedKeyHashTableContext<vectorized::UInt128>,
         FixedKeyHashTableContext<vectorized::UInt136>,
         FixedKeyHashTableContext<vectorized::UInt256>, MethodOneString>;
@@ -109,4 +122,87 @@ struct JoinDataVariants {
     }
 };
 
+template <typename Method>
+void primary_to_direct_mapping(Method* context, const 
vectorized::ColumnRawPtrs& key_columns,
+                               const 
std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+    using FieldType = typename Method::Base::Key;
+    FieldType max_key = std::numeric_limits<FieldType>::min();
+    FieldType min_key = std::numeric_limits<FieldType>::max();
+
+    size_t num_rows = key_columns[0]->size();
+    if (key_columns[0]->is_nullable()) {
+        const FieldType* input_keys =
+                (FieldType*)assert_cast<const 
vectorized::ColumnNullable*>(key_columns[0])
+                        ->get_nested_column_ptr()
+                        ->get_raw_data()
+                        .data;
+        const vectorized::NullMap& null_map =
+                assert_cast<const 
vectorized::ColumnNullable*>(key_columns[0])->get_null_map_data();
+        // skip first mocked row
+        for (size_t i = 1; i < num_rows; i++) {
+            if (null_map[i]) {
+                continue;
+            }
+            max_key = std::max(max_key, input_keys[i]);
+            min_key = std::min(min_key, input_keys[i]);
+        }
+    } else {
+        const FieldType* input_keys = 
(FieldType*)key_columns[0]->get_raw_data().data;
+        // skip first mocked row
+        for (size_t i = 1; i < num_rows; i++) {
+            max_key = std::max(max_key, input_keys[i]);
+            min_key = std::min(min_key, input_keys[i]);
+        }
+    }
+
+    constexpr auto MAX_MAPPING_RANGE = 1 << 23;
+    bool allow_direct_mapping = (max_key >= min_key && max_key - min_key < 
MAX_MAPPING_RANGE - 1);
+    if (allow_direct_mapping) {
+        for (const auto& variant_ptr : variant_ptrs) {
+            
variant_ptr->method_variant.emplace<DirectPrimaryTypeHashTableContext<FieldType>>(
+                    max_key, min_key);
+        }
+    }
+}
+
+template <typename Method>
+void try_convert_to_direct_mapping(
+        Method* method, const vectorized::ColumnRawPtrs& key_columns,
+        const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {}
+
+inline void try_convert_to_direct_mapping(
+        PrimaryTypeHashTableContext<vectorized::UInt8>* context,
+        const vectorized::ColumnRawPtrs& key_columns,
+        const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+    primary_to_direct_mapping(context, key_columns, variant_ptrs);
+}
+
+inline void try_convert_to_direct_mapping(
+        PrimaryTypeHashTableContext<vectorized::UInt16>* context,
+        const vectorized::ColumnRawPtrs& key_columns,
+        const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+    primary_to_direct_mapping(context, key_columns, variant_ptrs);
+}
+
+inline void try_convert_to_direct_mapping(
+        PrimaryTypeHashTableContext<vectorized::UInt32>* context,
+        const vectorized::ColumnRawPtrs& key_columns,
+        const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+    primary_to_direct_mapping(context, key_columns, variant_ptrs);
+}
+
+inline void try_convert_to_direct_mapping(
+        PrimaryTypeHashTableContext<vectorized::UInt64>* context,
+        const vectorized::ColumnRawPtrs& key_columns,
+        const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+    primary_to_direct_mapping(context, key_columns, variant_ptrs);
+}
+
+inline void try_convert_to_direct_mapping(
+        PrimaryTypeHashTableContext<vectorized::UInt128>* context,
+        const vectorized::ColumnRawPtrs& key_columns,
+        const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
+    primary_to_direct_mapping(context, key_columns, variant_ptrs);
+}
+
 } // namespace doris
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp 
b/be/src/pipeline/exec/hashjoin_build_sink.cpp
index b7e9757b515..cbc22f7168d 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.cpp
+++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp
@@ -19,6 +19,7 @@
 
 #include <cstdlib>
 #include <string>
+#include <variant>
 
 #include "pipeline/exec/hashjoin_probe_operator.h"
 #include "pipeline/exec/operator.h"
@@ -87,8 +88,6 @@ Status HashJoinBuildSinkLocalState::init(RuntimeState* state, 
LocalSinkStateInfo
     _build_table_insert_timer = ADD_TIMER(record_profile, 
"BuildTableInsertTime");
     _build_expr_call_timer = ADD_TIMER(record_profile, "BuildExprCallTime");
 
-    // Hash Table Init
-    RETURN_IF_ERROR(_hash_table_init(state));
     _runtime_filter_producer_helper = 
std::make_shared<RuntimeFilterProducerHelper>(
             _should_build_hash_table, p._is_broadcast_join);
     RETURN_IF_ERROR(_runtime_filter_producer_helper->init(state, 
_build_expr_ctxs,
@@ -137,7 +136,7 @@ size_t 
HashJoinBuildSinkLocalState::get_reserve_mem_size(RuntimeState* state, bo
 
     if (eos) {
         const size_t rows = build_block_rows + state->batch_size();
-        const auto bucket_size = 
JoinHashTable<StringRef>::calc_bucket_size(rows);
+        const auto bucket_size = hash_join_table_calc_bucket_size(rows);
 
         size_to_reserve += bucket_size * sizeof(uint32_t); // 
JoinHashTable::first
         size_to_reserve += rows * sizeof(uint32_t);        // 
JoinHashTable::next
@@ -184,10 +183,7 @@ size_t 
HashJoinBuildSinkLocalState::get_reserve_mem_size(RuntimeState* state, bo
                 throw Exception(st);
             }
 
-            std::visit(vectorized::Overload {[&](std::monostate& arg) {
-                                                 LOG(FATAL) << "FATAL: 
uninited hash table";
-                                                 __builtin_unreachable();
-                                             },
+            std::visit(vectorized::Overload {[&](std::monostate& arg) {},
                                              [&](auto&& hash_map_context) {
                                                  size_to_reserve += 
hash_map_context.estimated_size(
                                                          raw_ptrs, 
(uint32_t)block.rows(), true,
@@ -361,13 +357,6 @@ Status 
HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state,
     auto& p = _parent->cast<HashJoinBuildSinkOperatorX>();
     SCOPED_TIMER(_build_table_timer);
     auto rows = (uint32_t)block.rows();
-    if (UNLIKELY(rows == 0)) {
-        return Status::OK();
-    }
-
-    LOG(INFO) << "build block rows: " << block.rows() << ", columns count: " 
<< block.columns()
-              << ", bytes/allocated_bytes: " << 
PrettyPrinter::print_bytes(block.bytes()) << "/"
-              << PrettyPrinter::print_bytes(block.allocated_bytes());
     // 1. Dispose the overflow of ColumnString
     // 2. Finalize the ColumnVariant to speed up
     for (auto& data : block) {
@@ -400,6 +389,8 @@ Status 
HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state,
     // Get the key column that needs to be built
     RETURN_IF_ERROR(_extract_join_column(block, null_map_val, raw_ptrs, 
_build_col_ids));
 
+    RETURN_IF_ERROR(_hash_table_init(state, raw_ptrs));
+
     Status st = std::visit(
             vectorized::Overload {
                     [&](std::monostate& arg, auto join_op,
@@ -448,7 +439,8 @@ void 
HashJoinBuildSinkLocalState::_set_build_side_has_external_nullmap(
     }
 }
 
-Status HashJoinBuildSinkLocalState::_hash_table_init(RuntimeState* state) {
+Status HashJoinBuildSinkLocalState::_hash_table_init(RuntimeState* state,
+                                                     const 
vectorized::ColumnRawPtrs& raw_ptrs) {
     auto& p = _parent->cast<HashJoinBuildSinkOperatorX>();
     std::vector<vectorized::DataTypePtr> data_types;
     for (size_t i = 0; i < _build_expr_ctxs.size(); ++i) {
@@ -465,10 +457,21 @@ Status 
HashJoinBuildSinkLocalState::_hash_table_init(RuntimeState* state) {
     if (_build_expr_ctxs.size() == 1) {
         p._should_keep_hash_key_column = true;
     }
-    return init_hash_method<JoinDataVariants>(
-            _shared_state->hash_table_variant_vector[p._use_shared_hash_table 
? _task_idx : 0]
-                    .get(),
-            data_types, true);
+
+    std::vector<std::shared_ptr<JoinDataVariants>> variant_ptrs;
+    if (p._is_broadcast_join && p._use_shared_hash_table) {
+        variant_ptrs = _shared_state->hash_table_variant_vector;
+    } else {
+        variant_ptrs.emplace_back(
+                
_shared_state->hash_table_variant_vector[p._use_shared_hash_table ? _task_idx : 
0]);
+    }
+
+    for (auto& variant_ptr : variant_ptrs) {
+        RETURN_IF_ERROR(init_hash_method<JoinDataVariants>(variant_ptr.get(), 
data_types, true));
+    }
+    std::visit([&](auto&& arg) { try_convert_to_direct_mapping(&arg, raw_ptrs, 
variant_ptrs); },
+               variant_ptrs[0]->method_variant);
+    return Status::OK();
 }
 
 HashJoinBuildSinkOperatorX::HashJoinBuildSinkOperatorX(ObjectPool* pool, int 
operator_id,
@@ -636,6 +639,9 @@ Status HashJoinBuildSinkOperatorX::sink(RuntimeState* 
state, vectorized::Block*
                                   std::is_same_v<std::decay_t<decltype(src)>,
                                                  std::decay_t<decltype(dst)>>) 
{
                         dst.hash_table = src.hash_table;
+                    } else {
+                        throw Exception(Status::InternalError(
+                                "Hash table type mismatch when share hash 
table"));
                     }
                 },
                 
local_state._shared_state->hash_table_variant_vector[local_state._task_idx]
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h 
b/be/src/pipeline/exec/hashjoin_build_sink.h
index aec6adf084c..f1c94f24c38 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.h
+++ b/be/src/pipeline/exec/hashjoin_build_sink.h
@@ -50,7 +50,7 @@ public:
     [[nodiscard]] MOCK_FUNCTION size_t get_reserve_mem_size(RuntimeState* 
state, bool eos);
 
 protected:
-    Status _hash_table_init(RuntimeState* state);
+    Status _hash_table_init(RuntimeState* state, const 
vectorized::ColumnRawPtrs& raw_ptrs);
     void _set_build_side_has_external_nullmap(vectorized::Block& block,
                                               const std::vector<int>& 
res_col_ids);
     Status _do_evaluate(vectorized::Block& block, 
vectorized::VExprContextSPtrs& exprs,
@@ -204,8 +204,8 @@ struct ProcessHashTableBuild {
         }
 
         SCOPED_TIMER(_parent->_build_table_insert_timer);
-        hash_table_ctx.hash_table->template prepare_build<JoinOpType>(_rows, 
_batch_size,
-                                                                      
*has_null_key);
+        hash_table_ctx.hash_table->template prepare_build<JoinOpType>(
+                _rows, _batch_size, *has_null_key, 
hash_table_ctx.direct_mapping_range());
 
         // In order to make the null keys equal when using single null eq, all 
null keys need to be set to default value.
         if (_build_raw_ptrs.size() == 1 && null_map) {
diff --git a/be/src/pipeline/exec/join/process_hash_table_probe_impl.h 
b/be/src/pipeline/exec/join/process_hash_table_probe_impl.h
index 0424660db26..1f1edec4335 100644
--- a/be/src/pipeline/exec/join/process_hash_table_probe_impl.h
+++ b/be/src/pipeline/exec/join/process_hash_table_probe_impl.h
@@ -787,20 +787,25 @@ struct ExtractType<T(U)> {
             ExtractType<void(T)>::Type & hash_table_ctx, 
vectorized::MutableBlock & mutable_block, \
             vectorized::Block * output_block, bool* eos, bool is_mark_join);
 
-#define INSTANTIATION_FOR(JoinOpType)                                          
    \
-    template struct ProcessHashTableProbe<JoinOpType>;                         
    \
-                                                                               
    \
-    INSTANTIATION(JoinOpType, (SerializedHashTableContext));                   
    \
-    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt8>));   \
-    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt16>));  \
-    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt32>));  \
-    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt64>));  \
-    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt128>)); \
-    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt256>)); \
-    INSTANTIATION(JoinOpType, (FixedKeyHashTableContext<vectorized::UInt64>)); 
    \
-    INSTANTIATION(JoinOpType, 
(FixedKeyHashTableContext<vectorized::UInt128>));    \
-    INSTANTIATION(JoinOpType, 
(FixedKeyHashTableContext<vectorized::UInt136>));    \
-    INSTANTIATION(JoinOpType, 
(FixedKeyHashTableContext<vectorized::UInt256>));    \
+#define INSTANTIATION_FOR(JoinOpType)                                          
          \
+    template struct ProcessHashTableProbe<JoinOpType>;                         
          \
+                                                                               
          \
+    INSTANTIATION(JoinOpType, (SerializedHashTableContext));                   
          \
+    INSTANTIATION(JoinOpType, 
(DirectPrimaryTypeHashTableContext<vectorized::UInt8>));   \
+    INSTANTIATION(JoinOpType, 
(DirectPrimaryTypeHashTableContext<vectorized::UInt16>));  \
+    INSTANTIATION(JoinOpType, 
(DirectPrimaryTypeHashTableContext<vectorized::UInt32>));  \
+    INSTANTIATION(JoinOpType, 
(DirectPrimaryTypeHashTableContext<vectorized::UInt64>));  \
+    INSTANTIATION(JoinOpType, 
(DirectPrimaryTypeHashTableContext<vectorized::UInt128>)); \
+    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt8>));         \
+    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt16>));        \
+    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt32>));        \
+    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt64>));        \
+    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt128>));       \
+    INSTANTIATION(JoinOpType, 
(PrimaryTypeHashTableContext<vectorized::UInt256>));       \
+    INSTANTIATION(JoinOpType, (FixedKeyHashTableContext<vectorized::UInt64>)); 
          \
+    INSTANTIATION(JoinOpType, 
(FixedKeyHashTableContext<vectorized::UInt128>));          \
+    INSTANTIATION(JoinOpType, 
(FixedKeyHashTableContext<vectorized::UInt136>));          \
+    INSTANTIATION(JoinOpType, 
(FixedKeyHashTableContext<vectorized::UInt256>));          \
     INSTANTIATION(JoinOpType, (MethodOneString));
 #include "common/compile_check_end.h"
 } // namespace doris::pipeline
diff --git a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp 
b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
index 223a7f24013..dd244fedf57 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
@@ -804,7 +804,7 @@ size_t 
PartitionedHashJoinProbeOperatorX::get_reserve_mem_size(RuntimeState* sta
                 (local_state._recovered_build_block ? 
local_state._recovered_build_block->rows()
                                                     : 0) +
                 state->batch_size();
-        size_t bucket_size = JoinHashTable<StringRef>::calc_bucket_size(rows);
+        size_t bucket_size = hash_join_table_calc_bucket_size(rows);
 
         size_to_reserve += bucket_size * sizeof(uint32_t); // 
JoinHashTable::first
         size_to_reserve += rows * sizeof(uint32_t);        // 
JoinHashTable::next
diff --git a/be/src/vec/common/hash_table/hash_map_context.h 
b/be/src/vec/common/hash_table/hash_map_context.h
index 519e326268d..8419cea4341 100644
--- a/be/src/vec/common/hash_table/hash_map_context.h
+++ b/be/src/vec/common/hash_table/hash_map_context.h
@@ -17,6 +17,7 @@
 
 #pragma once
 
+#include <cstdint>
 #include <type_traits>
 #include <utility>
 
@@ -157,6 +158,8 @@ struct MethodBaseInner {
 
     virtual void insert_keys_into_columns(std::vector<Key>& keys, 
MutableColumns& key_columns,
                                           uint32_t num_rows) = 0;
+
+    virtual uint32_t direct_mapping_range() { return 0; }
 };
 
 template <typename T>
@@ -411,6 +414,68 @@ struct MethodOneNumber : public MethodBase<TData> {
     }
 };
 
+template <typename FieldType, typename TData>
+struct MethodOneNumberDirect : public MethodOneNumber<FieldType, TData> {
+    using Base = MethodOneNumber<FieldType, TData>;
+    using Base::init_iterator;
+    using Base::hash_table;
+    using State = ColumnsHashing::HashMethodOneNumber<typename Base::Value, 
typename Base::Mapped,
+                                                      FieldType>;
+    FieldType _max_key;
+    FieldType _min_key;
+
+    MethodOneNumberDirect(FieldType max_key, FieldType min_key)
+            : _max_key(max_key), _min_key(min_key) {}
+
+    void init_serialized_keys(const ColumnRawPtrs& key_columns, uint32_t 
num_rows,
+                              const uint8_t* null_map = nullptr, bool is_join 
= false,
+                              bool is_build = false, uint32_t bucket_size = 0) 
override {
+        Base::keys = (FieldType*)(key_columns[0]->is_nullable()
+                                          ? assert_cast<const 
ColumnNullable*>(key_columns[0])
+                                                    ->get_nested_column_ptr()
+                                                    ->get_raw_data()
+                                                    .data
+                                          : 
key_columns[0]->get_raw_data().data);
+        CHECK(is_join);
+        CHECK_EQ(bucket_size, direct_mapping_range());
+        Base::bucket_nums.resize(num_rows);
+
+        if (null_map == nullptr) {
+            if (is_build) {
+                for (uint32_t k = 1; k < num_rows; ++k) {
+                    Base::bucket_nums[k] = uint32_t(Base::keys[k] - _min_key + 
1);
+                }
+            } else {
+                for (uint32_t k = 0; k < num_rows; ++k) {
+                    Base::bucket_nums[k] = (Base::keys[k] >= _min_key && 
Base::keys[k] <= _max_key)
+                                                   ? uint32_t(Base::keys[k] - 
_min_key + 1)
+                                                   : 0;
+                }
+            }
+        } else {
+            if (is_build) {
+                for (uint32_t k = 1; k < num_rows; ++k) {
+                    Base::bucket_nums[k] =
+                            null_map[k] ? bucket_size : uint32_t(Base::keys[k] 
- _min_key + 1);
+                }
+            } else {
+                for (uint32_t k = 0; k < num_rows; ++k) {
+                    Base::bucket_nums[k] =
+                            null_map[k] ? bucket_size
+                            : (Base::keys[k] >= _min_key && Base::keys[k] <= 
_max_key)
+                                    ? uint32_t(Base::keys[k] - _min_key + 1)
+                                    : 0;
+                }
+            }
+        }
+    }
+
+    uint32_t direct_mapping_range() override {
+        // +2 to include max_key and one slot for out of range value
+        return static_cast<uint32_t>(_max_key - _min_key + 2);
+    }
+};
+
 template <typename TData>
 struct MethodKeysFixed : public MethodBase<TData> {
     using Base = MethodBase<TData>;
diff --git a/be/src/vec/common/hash_table/join_hash_table.h 
b/be/src/vec/common/hash_table/join_hash_table.h
index c6227591545..9426829e056 100644
--- a/be/src/vec/common/hash_table/join_hash_table.h
+++ b/be/src/vec/common/hash_table/join_hash_table.h
@@ -25,11 +25,17 @@
 #include "common/status.h"
 #include "vec/columns/column_filter_helper.h"
 #include "vec/common/custom_allocator.h"
-#include "vec/common/hash_table/hash.h"
 
 namespace doris {
 #include "common/compile_check_begin.h"
-template <typename Key, typename Hash = DefaultHash<Key>>
+
+inline uint32_t hash_join_table_calc_bucket_size(size_t num_elem) {
+    size_t expect_bucket_size = num_elem + (num_elem - 1) / 7;
+    return 
(uint32_t)std::min(phmap::priv::NormalizeCapacity(expect_bucket_size) + 1,
+                              
static_cast<size_t>(std::numeric_limits<int32_t>::max()) + 1);
+}
+
+template <typename Key, typename Hash, bool DirectMapping>
 class JoinHashTable {
 public:
     using key_type = Key;
@@ -37,25 +43,24 @@ public:
     using value_type = void*;
     size_t hash(const Key& x) const { return Hash()(x); }
 
-    static uint32_t calc_bucket_size(size_t num_elem) {
-        size_t expect_bucket_size = num_elem + (num_elem - 1) / 7;
-        return 
(uint32_t)std::min(phmap::priv::NormalizeCapacity(expect_bucket_size) + 1,
-                                  
static_cast<size_t>(std::numeric_limits<int32_t>::max()) + 1);
-    }
-
     size_t get_byte_size() const {
         auto cal_vector_mem = [](const auto& vec) { return vec.capacity() * 
sizeof(vec[0]); };
         return cal_vector_mem(visited) + cal_vector_mem(first) + 
cal_vector_mem(next);
     }
 
     template <int JoinOpType>
-    void prepare_build(size_t num_elem, int batch_size, bool has_null_key) {
+    void prepare_build(size_t num_elem, int batch_size, bool has_null_key,
+                       uint32_t force_bucket_size) {
         _has_null_key = has_null_key;
 
         // the first row in build side is not really from build side table
         _empty_build_side = num_elem <= 1;
         max_batch_size = batch_size;
-        bucket_size = calc_bucket_size(num_elem + 1);
+        if constexpr (DirectMapping) {
+            bucket_size = force_bucket_size;
+        } else {
+            bucket_size = hash_join_table_calc_bucket_size(num_elem + 1);
+        }
         first.resize(bucket_size + 1);
         next.resize(num_elem);
 
@@ -217,6 +222,13 @@ public:
     }
 
 private:
+    bool _eq(const Key& lhs, const Key& rhs) const {
+        if (DirectMapping) {
+            return true;
+        }
+        return lhs == rhs;
+    }
+
     template <int JoinOpType>
     auto _process_null_aware_left_half_join_for_empty_build_side(int 
probe_idx, int probe_rows,
                                                                  uint32_t* 
__restrict probe_idxs,
@@ -245,11 +257,20 @@ private:
         while (probe_idx < probe_rows) {
             auto build_idx = build_idx_map[probe_idx];
 
-            while (build_idx) {
-                if (!visited[build_idx] && keys[probe_idx] == 
build_keys[build_idx]) {
-                    visited[build_idx] = 1;
+            if constexpr (DirectMapping) {
+                if (!visited[build_idx]) {
+                    while (build_idx) {
+                        visited[build_idx] = 1;
+                        build_idx = next[build_idx];
+                    }
+                }
+            } else {
+                while (build_idx) {
+                    if (!visited[build_idx] && _eq(keys[probe_idx], 
build_keys[build_idx])) {
+                        visited[build_idx] = 1;
+                    }
+                    build_idx = next[build_idx];
                 }
-                build_idx = next[build_idx];
             }
             probe_idx++;
         }
@@ -293,7 +314,7 @@ private:
 
         auto do_the_probe = [&]() {
             while (build_idx && matched_cnt < batch_size) {
-                if (keys[probe_idx] == build_keys[build_idx]) {
+                if (_eq(keys[probe_idx], build_keys[build_idx])) {
                     build_idxs[matched_cnt] = build_idx;
                     probe_idxs[matched_cnt] = probe_idx;
                     matched_cnt++;
@@ -347,7 +368,7 @@ private:
 
         auto do_the_probe = [&]() {
             while (build_idx && matched_cnt < batch_size) {
-                if (keys[probe_idx] == build_keys[build_idx]) {
+                if (_eq(keys[probe_idx], build_keys[build_idx])) {
                     probe_idxs[matched_cnt] = probe_idx;
                     build_idxs[matched_cnt] = build_idx;
                     matched_cnt++;
@@ -408,7 +429,7 @@ private:
             }
 
             while (build_idx && matched_cnt < batch_size) {
-                if (picking_null_keys || keys[probe_idx] == 
build_keys[build_idx]) {
+                if (picking_null_keys || _eq(keys[probe_idx], 
build_keys[build_idx])) {
                     build_idxs[matched_cnt] = build_idx;
                     probe_idxs[matched_cnt] = probe_idx;
                     null_flags[matched_cnt] = picking_null_keys;
@@ -476,7 +497,7 @@ private:
     bool _empty_build_side = true;
 };
 
-template <typename Key, typename Hash = DefaultHash<Key>>
-using JoinHashMap = JoinHashTable<Key, Hash>;
+template <typename Key, typename Hash, bool DirectMapping>
+using JoinHashMap = JoinHashTable<Key, Hash, DirectMapping>;
 #include "common/compile_check_end.h"
 } // namespace doris
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to