github-actions[bot] commented on code in PR #26730:
URL: https://github.com/apache/doris/pull/26730#discussion_r1398225742


##########
be/src/vec/common/hash_table/hash_map.h:
##########
@@ -193,10 +200,346 @@ class HashMapTable : public HashTable<Key, Cell, Hash, 
Grower, Allocator> {
     bool has_null_key_data() const { return false; }
 };
 
+template <typename Key, typename Cell, typename Hash = DefaultHash<Key>,
+          typename Grower = HashTableGrower<>, typename Allocator = 
HashTableAllocator>
+class JoinHashMapTable : public HashMapTable<Key, Cell, Hash, Grower, 
Allocator> {
+public:
+    using Self = JoinHashMapTable;
+    using Base = HashMapTable<Key, Cell, Hash, Grower, Allocator>;
+
+    using key_type = Key;
+    using value_type = typename Cell::value_type;
+    using mapped_type = typename Cell::Mapped;
+
+    using LookupResult = typename Base::LookupResult;
+
+    using HashMapTable<Key, Cell, Hash, Grower, Allocator>::HashMapTable;
+
+    static uint32_t calc_bucket_size(size_t num_elem) {
+        size_t expect_bucket_size = num_elem + (num_elem - 1) / 7;
+        return phmap::priv::NormalizeCapacity(expect_bucket_size) + 1;
+    }
+
+    size_t get_byte_size() const {
+        auto cal_vector_mem = [](const auto& vec) { return vec.capacity() * 
sizeof(vec[0]); };
+        return cal_vector_mem(visited) + cal_vector_mem(first) + 
cal_vector_mem(next);
+    }
+
+    template <int JoinOpType>
+    void prepare_build(size_t num_elem, int batch_size) {
+        max_batch_size = batch_size;
+        bucket_size = calc_bucket_size(num_elem + 1);
+        first.resize(bucket_size + 1);
+        next.resize(num_elem);
+
+        if constexpr (JoinOpType == doris::TJoinOp::FULL_OUTER_JOIN ||
+                      JoinOpType == doris::TJoinOp::RIGHT_OUTER_JOIN ||
+                      JoinOpType == doris::TJoinOp::RIGHT_ANTI_JOIN ||
+                      JoinOpType == doris::TJoinOp::RIGHT_SEMI_JOIN) {
+            visited.resize(num_elem);
+        }
+    }
+
+    uint32_t get_bucket_size() const { return bucket_size; }
+
+    size_t size() const { return Base::size() == 0 ? next.size() : 
Base::size(); }
+
+    std::vector<uint8_t>& get_visited() { return visited; }
+
+    void build(const Key* __restrict keys, const uint32_t* __restrict 
bucket_nums,
+               size_t num_elem) {
+        build_keys = keys;
+        for (size_t i = 1; i < num_elem; i++) {
+            uint32_t bucket_num = bucket_nums[i];
+            next[i] = first[bucket_num];
+            first[bucket_num] = i;
+        }
+        first[bucket_size] = 0; // index = bucket_num means null
+    }
+
+    template <int JoinOpType, bool with_other_conjuncts, bool is_mark_join, 
bool need_judge_null>
+    auto find_batch(const Key* __restrict keys, const uint32_t* __restrict 
bucket_nums,
+                    int probe_idx, uint32_t build_idx, int probe_rows,
+                    uint32_t* __restrict probe_idxs, uint32_t* __restrict 
build_idxs,
+                    doris::vectorized::ColumnFilterHelper* mark_column) {
+        if constexpr (is_mark_join) {
+            return _find_batch_mark<JoinOpType>(keys, bucket_nums, probe_idx, 
probe_rows,
+                                                probe_idxs, build_idxs, 
mark_column);
+        }
+
+        if constexpr (with_other_conjuncts) {
+            return _find_batch_conjunct<JoinOpType>(keys, bucket_nums, 
probe_idx, build_idx,
+                                                    probe_rows, probe_idxs, 
build_idxs);
+        }
+
+        if constexpr (JoinOpType == doris::TJoinOp::INNER_JOIN ||
+                      JoinOpType == doris::TJoinOp::FULL_OUTER_JOIN ||
+                      JoinOpType == doris::TJoinOp::LEFT_OUTER_JOIN ||
+                      JoinOpType == doris::TJoinOp::RIGHT_OUTER_JOIN) {
+            return _find_batch_inner_outer_join<JoinOpType>(keys, bucket_nums, 
probe_idx, build_idx,
+                                                            probe_rows, 
probe_idxs, build_idxs);
+        }
+        if constexpr (JoinOpType == doris::TJoinOp::LEFT_ANTI_JOIN ||
+                      JoinOpType == doris::TJoinOp::LEFT_SEMI_JOIN ||
+                      JoinOpType == doris::TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN) 
{
+            return _find_batch_left_semi_anti<JoinOpType, need_judge_null>(
+                    keys, bucket_nums, probe_idx, probe_rows, probe_idxs);
+        }
+        if constexpr (JoinOpType == doris::TJoinOp::RIGHT_ANTI_JOIN ||
+                      JoinOpType == doris::TJoinOp::RIGHT_SEMI_JOIN) {
+            return _find_batch_right_semi_anti(keys, bucket_nums, probe_idx, 
probe_rows);
+        }
+        return std::tuple {0, 0U, 0};
+    }
+
+    template <int JoinOpType>
+    bool iterate_map(std::vector<uint32_t>& build_idxs) const {
+        const auto batch_size = max_batch_size;
+        const auto elem_num = visited.size();
+        int count = 0;
+        build_idxs.resize(batch_size);
+
+        while (count < batch_size && iter_idx < elem_num) {
+            const auto matched = visited[iter_idx];
+            build_idxs[count] = iter_idx;
+            if constexpr (JoinOpType != doris::TJoinOp::RIGHT_SEMI_JOIN) {
+                count += !matched;
+            } else {
+                count += matched;
+            }
+            iter_idx++;
+        }
+
+        build_idxs.resize(count);
+        return iter_idx >= elem_num;
+    }
+
+private:
+    // only LEFT_ANTI_JOIN/LEFT_SEMI_JOIN/NULL_AWARE_LEFT_ANTI_JOIN/CROSS_JOIN 
support mark join
+    template <int JoinOpType>
+    auto _find_batch_mark(const Key* __restrict keys, const uint32_t* 
__restrict bucket_nums,
+                          int probe_idx, int probe_rows, uint32_t* __restrict 
probe_idxs,
+                          uint32_t* __restrict build_idxs,
+                          doris::vectorized::ColumnFilterHelper* mark_column) {
+        auto matched_cnt = 0;
+        const auto batch_size = max_batch_size;
+
+        while (probe_idx < probe_rows && matched_cnt < batch_size) {
+            auto build_idx = first[bucket_nums[probe_idx]];
+
+            while (build_idx && keys[probe_idx] != build_keys[build_idx]) {
+                build_idx = next[build_idx];
+            }
+
+            if (bucket_nums[probe_idx] == bucket_size) {
+                // mark result as null when probe row is null
+                mark_column->insert_null();
+            } else {
+                bool matched = JoinOpType == doris::TJoinOp::LEFT_SEMI_JOIN ? 
build_idx != 0
+                                                                            : 
build_idx == 0;
+                mark_column->insert_value(matched);
+            }
+
+            probe_idxs[matched_cnt] = probe_idx++;
+            build_idxs[matched_cnt] = build_idx;
+            matched_cnt++;
+        }
+        return std::tuple {probe_idx, 0U, matched_cnt};
+    }
+
+    auto _find_batch_right_semi_anti(const Key* __restrict keys,
+                                     const uint32_t* __restrict bucket_nums, 
int probe_idx,
+                                     int probe_rows) {
+        while (probe_idx < probe_rows) {
+            auto build_idx = first[bucket_nums[probe_idx]];
+
+            while (build_idx) {
+                if (!visited[build_idx] && keys[probe_idx] == 
build_keys[build_idx]) {
+                    visited[build_idx] = 1;
+                }
+                build_idx = next[build_idx];
+            }
+            probe_idx++;
+        }
+        return std::tuple {probe_idx, 0U, 0};
+    }
+
+    template <int JoinOpType, bool need_judge_null>
+    auto _find_batch_left_semi_anti(const Key* __restrict keys,
+                                    const uint32_t* __restrict bucket_nums, 
int probe_idx,
+                                    int probe_rows, uint32_t* __restrict 
probe_idxs) {
+        auto matched_cnt = 0;
+        const auto batch_size = max_batch_size;
+
+        while (probe_idx < probe_rows && matched_cnt < batch_size) {
+            if constexpr (need_judge_null) {
+                if (bucket_nums[probe_idx] == bucket_size) {
+                    probe_idx++;
+                    continue;
+                }
+            }
+
+            auto build_idx = first[bucket_nums[probe_idx]];
+
+            while (build_idx && keys[probe_idx] != build_keys[build_idx]) {
+                build_idx = next[build_idx];
+            }
+            bool matched =
+                    JoinOpType == doris::TJoinOp::LEFT_SEMI_JOIN ? build_idx 
!= 0 : build_idx == 0;
+            probe_idxs[matched_cnt] = probe_idx++;
+            matched_cnt += matched;
+        }
+        return std::tuple {probe_idx, 0U, matched_cnt};
+    }
+
+    auto _find_batch_left_semi_anti_conjunct(const Key* __restrict keys,
+                                             const uint32_t* __restrict 
bucket_nums, int probe_idx,
+                                             int probe_rows, uint32_t* 
__restrict probe_idxs,
+                                             uint32_t* __restrict build_idxs) {
+        auto matched_cnt = 0;
+        const auto batch_size = max_batch_size;
+
+        while (probe_idx < probe_rows && matched_cnt < batch_size) {
+            auto build_idx = first[bucket_nums[probe_idx]];
+
+            while (build_idx) {
+                if (keys[probe_idx] == build_keys[build_idx]) {
+                    probe_idxs[matched_cnt] = probe_idx;
+                    build_idxs[matched_cnt] = build_idx;
+                    matched_cnt++;
+                }
+                build_idx = next[build_idx];
+            }
+            probe_idx++;
+        }
+        return std::tuple {probe_idx, 0U, matched_cnt};
+    }
+
+    template <int JoinOpType>
+    auto _find_batch_conjunct(const Key* __restrict keys, const uint32_t* 
__restrict bucket_nums,
+                              int probe_idx, uint32_t build_idx, int 
probe_rows,
+                              uint32_t* __restrict probe_idxs, uint32_t* 
__restrict build_idxs) {
+        auto matched_cnt = 0;
+        const auto batch_size = max_batch_size;
+
+        auto do_the_probe = [&]() {
+            auto matched_cnt_old = matched_cnt;
+            while (build_idx && matched_cnt < batch_size) {
+                if constexpr (JoinOpType == doris::TJoinOp::RIGHT_ANTI_JOIN ||
+                              JoinOpType == doris::TJoinOp::RIGHT_SEMI_JOIN) {
+                    if (!visited[build_idx] && keys[probe_idx] == 
build_keys[build_idx]) {
+                        build_idxs[matched_cnt++] = build_idx;
+                    }
+                } else {
+                    build_idxs[matched_cnt++] = build_idx;
+                    matched_cnt += keys[probe_idx] == build_keys[build_idx];
+                }
+                build_idx = next[build_idx];
+            }
+
+            for (auto i = matched_cnt_old; i < matched_cnt; i++) {
+                probe_idxs[i] = probe_idx;
+            }
+
+            if constexpr (JoinOpType == doris::TJoinOp::LEFT_OUTER_JOIN ||
+                          JoinOpType == doris::TJoinOp::FULL_OUTER_JOIN) {
+                if (!build_idx) {
+                    probe_idxs[matched_cnt] = probe_idx;
+                    build_idxs[matched_cnt] = 0;
+                    matched_cnt++;
+                }
+            }
+
+            probe_idx++;
+        };
+
+        if (build_idx) {
+            do_the_probe();
+        }
+
+        while (probe_idx < probe_rows && matched_cnt < batch_size) {
+            build_idx = first[bucket_nums[probe_idx]];
+            do_the_probe();
+        }
+
+        probe_idx -=
+                (matched_cnt >= batch_size &&
+                 build_idx); // FULL_OUTER_JOIN may over batch_size when 
emplace 0 into build_idxs
+        return std::tuple {probe_idx, build_idx, matched_cnt};
+    }
+
+    template <int JoinOpType>
+    auto _find_batch_inner_outer_join(const Key* __restrict keys,
+                                      const uint32_t* __restrict bucket_nums, 
int probe_idx,
+                                      uint32_t build_idx, int probe_rows,
+                                      uint32_t* __restrict probe_idxs,
+                                      uint32_t* __restrict build_idxs) {
+        auto matched_cnt = 0;
+        const auto batch_size = max_batch_size;
+
+        auto do_the_probe = [&]() {
+            while (build_idx && matched_cnt < batch_size) {
+                if (keys[probe_idx] == build_keys[build_idx]) {
+                    probe_idxs[matched_cnt] = probe_idx;
+                    build_idxs[matched_cnt] = build_idx;
+                    matched_cnt++;
+                    if constexpr (JoinOpType == 
doris::TJoinOp::RIGHT_OUTER_JOIN ||
+                                  JoinOpType == 
doris::TJoinOp::FULL_OUTER_JOIN) {
+                        if (!visited[build_idx]) {
+                            visited[build_idx] = 1;
+                        }
+                    }
+                }
+                build_idx = next[build_idx];
+            }
+
+            if constexpr (JoinOpType == doris::TJoinOp::LEFT_OUTER_JOIN ||
+                          JoinOpType == doris::TJoinOp::FULL_OUTER_JOIN) {
+                // `(!matched_cnt || probe_idxs[matched_cnt - 1] != 
probe_idx)` means not match one build side
+                if (!matched_cnt || probe_idxs[matched_cnt - 1] != probe_idx) {
+                    probe_idxs[matched_cnt] = probe_idx;
+                    build_idxs[matched_cnt] = 0;
+                    matched_cnt++;
+                }
+            }
+            probe_idx++;
+        };
+
+        if (build_idx) {
+            do_the_probe();
+        }
+
+        while (probe_idx < probe_rows && matched_cnt < batch_size) {
+            build_idx = first[bucket_nums[probe_idx]];
+            do_the_probe();
+        }
+
+        probe_idx -= (matched_cnt == batch_size && build_idx);
+        return std::tuple {probe_idx, build_idx, matched_cnt};
+    }
+
+    const Key* __restrict build_keys;
+    std::vector<uint8_t> visited;
+
+    uint32_t bucket_size = 1;
+    int max_batch_size = 4064;

Review Comment:
   warning: 4064 is a magic number; consider replacing it with a named constant [readability-magic-numbers]
   ```cpp
       int max_batch_size = 4064;
                            ^
   ```
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to