This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit c9ab243153b78e4a5df305acb4d8dcb4efd98d1c
Author: Jerry Hu <[email protected]>
AuthorDate: Sun Feb 4 10:11:06 2024 +0800

    [feat-wip](join) support mark join for right semi join(without mark join conjunct) (#30767)
---
 be/src/pipeline/exec/hashjoin_probe_operator.cpp   |  2 +-
 be/src/pipeline/exec/join_build_sink_operator.cpp  |  6 ++--
 be/src/vec/common/hash_table/join_hash_table.h     | 34 +++++++++++++---------
 be/src/vec/exec/join/process_hash_table_probe.h    |  2 +-
 .../vec/exec/join/process_hash_table_probe_impl.h  | 22 ++++++++++----
 be/src/vec/exec/join/vhash_join_node.cpp           |  2 +-
 be/src/vec/exec/join/vjoin_node_base.cpp           |  6 ++--
 7 files changed, 48 insertions(+), 26 deletions(-)

diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.cpp b/be/src/pipeline/exec/hashjoin_probe_operator.cpp
index a2fb0012ffa..1cf3e54572e 100644
--- a/be/src/pipeline/exec/hashjoin_probe_operator.cpp
+++ b/be/src/pipeline/exec/hashjoin_probe_operator.cpp
@@ -343,7 +343,7 @@ Status HashJoinProbeOperatorX::pull(doris::RuntimeState* 
state, vectorized::Bloc
                             if constexpr (!std::is_same_v<HashTableCtxType, 
std::monostate>) {
                                 bool eos = false;
                                 st = 
process_hashtable_ctx.process_data_in_hashtable(
-                                        arg, mutable_join_block, &temp_block, 
&eos);
+                                        arg, mutable_join_block, &temp_block, 
&eos, _is_mark_join);
                                 source_state = eos ? SourceState::FINISHED : 
source_state;
                             } else {
                                 st = Status::InternalError("uninited hash 
table");
diff --git a/be/src/pipeline/exec/join_build_sink_operator.cpp b/be/src/pipeline/exec/join_build_sink_operator.cpp
index 73ebfb947a3..6b930ff8a5e 100644
--- a/be/src/pipeline/exec/join_build_sink_operator.cpp
+++ b/be/src/pipeline/exec/join_build_sink_operator.cpp
@@ -83,8 +83,10 @@ 
JoinBuildSinkOperatorX<LocalStateType>::JoinBuildSinkOperatorX(ObjectPool* pool,
     if (_is_mark_join) {
         DCHECK(_join_op == TJoinOp::LEFT_ANTI_JOIN || _join_op == 
TJoinOp::LEFT_SEMI_JOIN ||
                _join_op == TJoinOp::CROSS_JOIN || _join_op == 
TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN ||
-               _join_op == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN)
-                << "Mark join is only supported for null aware left semi/anti 
join and cross join "
+               _join_op == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN ||
+               _join_op == TJoinOp::RIGHT_SEMI_JOIN)
+                << "Mark join is only supported for null aware left semi/anti 
join and right semi "
+                   "join and cross join "
                    "but this is "
                 << _join_op;
     }
diff --git a/be/src/vec/common/hash_table/join_hash_table.h b/be/src/vec/common/hash_table/join_hash_table.h
index 85665e76853..5ae0a13acef 100644
--- a/be/src/vec/common/hash_table/join_hash_table.h
+++ b/be/src/vec/common/hash_table/join_hash_table.h
@@ -98,7 +98,8 @@ public:
             }
         }
 
-        if constexpr (with_other_conjuncts || is_mark_join) {
+        if constexpr (with_other_conjuncts ||
+                      (is_mark_join && JoinOpType != 
TJoinOp::RIGHT_SEMI_JOIN)) {
             return _find_batch_conjunct<JoinOpType, need_judge_null>(
                     keys, build_idx_map, probe_idx, build_idx, probe_rows, 
probe_idxs, build_idxs);
         }
@@ -203,8 +204,9 @@ public:
         return std::tuple {probe_idx, build_idx, matched_cnt, 
picking_null_keys};
     }
 
-    template <int JoinOpType>
-    bool iterate_map(std::vector<uint32_t>& build_idxs) const {
+    template <int JoinOpType, bool is_mark_join>
+    bool iterate_map(std::vector<uint32_t>& build_idxs,
+                     vectorized::ColumnFilterHelper* mark_column_helper) const 
{
         const auto batch_size = max_batch_size;
         const auto elem_num = visited.size();
         int count = 0;
@@ -213,10 +215,15 @@ public:
         while (count < batch_size && iter_idx < elem_num) {
             const auto matched = visited[iter_idx];
             build_idxs[count] = iter_idx;
-            if constexpr (JoinOpType != TJoinOp::RIGHT_SEMI_JOIN) {
-                count += !matched;
+            if constexpr (JoinOpType == TJoinOp::RIGHT_SEMI_JOIN) {
+                if constexpr (is_mark_join) {
+                    mark_column_helper->insert_value(matched);
+                    ++count;
+                } else {
+                    count += matched;
+                }
             } else {
-                count += matched;
+                count += !matched;
             }
             iter_idx++;
         }
@@ -228,15 +235,16 @@ public:
     bool has_null_key() { return _has_null_key; }
 
     void pre_build_idxs(std::vector<uint32>& buckets, const uint8_t* null_map) 
{
+        const auto first_at_bucket_size = first[bucket_size];
         if (null_map) {
-            for (unsigned int& bucket : buckets) {
-                bucket = bucket == bucket_size ? bucket_size : first[bucket];
-            }
-        } else {
-            for (unsigned int& bucket : buckets) {
-                bucket = first[bucket];
-            }
+            first[bucket_size] = bucket_size; // distinguish between not 
matched and null
+        }
+
+        for (unsigned int& bucket : buckets) {
+            bucket = first[bucket];
         }
+
+        first[bucket_size] = first_at_bucket_size;
     }
 
 private:
diff --git a/be/src/vec/exec/join/process_hash_table_probe.h b/be/src/vec/exec/join/process_hash_table_probe.h
index 9f4ddbabdcb..924974ca915 100644
--- a/be/src/vec/exec/join/process_hash_table_probe.h
+++ b/be/src/vec/exec/join/process_hash_table_probe.h
@@ -82,7 +82,7 @@ struct ProcessHashTableProbe {
     // in hash table
     template <typename HashTableType>
     Status process_data_in_hashtable(HashTableType& hash_table_ctx, 
MutableBlock& mutable_block,
-                                     Block* output_block, bool* eos);
+                                     Block* output_block, bool* eos, bool 
is_mark_join);
 
     /// For null aware join with other conjuncts, if the probe key of one row 
on left side is null,
     /// we should make this row match with all rows in build side.
diff --git a/be/src/vec/exec/join/process_hash_table_probe_impl.h b/be/src/vec/exec/join/process_hash_table_probe_impl.h
index 06dfdac9074..d7b47bed9c0 100644
--- a/be/src/vec/exec/join/process_hash_table_probe_impl.h
+++ b/be/src/vec/exec/join/process_hash_table_probe_impl.h
@@ -174,7 +174,8 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_process(HashTableType& hash
                 need_null_map_for_probe && ignore_null &&
                         (JoinOpType == doris::TJoinOp::LEFT_ANTI_JOIN ||
                          JoinOpType == doris::TJoinOp::LEFT_SEMI_JOIN ||
-                         JoinOpType == 
doris::TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN || is_mark_join));
+                         JoinOpType == 
doris::TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN ||
+                         (is_mark_join && JoinOpType != 
doris::TJoinOp::RIGHT_SEMI_JOIN)));
     }
 
     auto& mcol = mutable_block.mutable_columns();
@@ -253,7 +254,7 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_process(HashTableType& hash
 
     output_block->swap(mutable_block.to_block());
 
-    if constexpr (is_mark_join) {
+    if constexpr (is_mark_join && JoinOpType != TJoinOp::RIGHT_SEMI_JOIN) {
         return do_mark_join_conjuncts<with_other_conjuncts>(
                 output_block, hash_table_ctx.hash_table->get_bucket_size());
     } else if constexpr (with_other_conjuncts) {
@@ -572,11 +573,20 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_other_join_conjuncts(
 template <int JoinOpType, typename Parent>
 template <typename HashTableType>
 Status ProcessHashTableProbe<JoinOpType, Parent>::process_data_in_hashtable(
-        HashTableType& hash_table_ctx, MutableBlock& mutable_block, Block* 
output_block,
-        bool* eos) {
+        HashTableType& hash_table_ctx, MutableBlock& mutable_block, Block* 
output_block, bool* eos,
+        bool is_mark_join) {
     SCOPED_TIMER(_probe_process_hashtable_timer);
     auto& mcol = mutable_block.mutable_columns();
-    *eos = hash_table_ctx.hash_table->template 
iterate_map<JoinOpType>(_build_indexs);
+    if (is_mark_join) {
+        std::unique_ptr<ColumnFilterHelper> mark_column =
+                std::make_unique<ColumnFilterHelper>(*mcol[mcol.size() - 1]);
+        *eos = hash_table_ctx.hash_table->template iterate_map<JoinOpType, 
true>(_build_indexs,
+                                                                               
  mark_column.get());
+    } else {
+        *eos = hash_table_ctx.hash_table->template iterate_map<JoinOpType, 
false>(_build_indexs,
+                                                                               
   nullptr);
+    }
+
     auto block_size = _build_indexs.size();
 
     if (block_size) {
@@ -661,7 +671,7 @@ struct ExtractType<T(U)> {
     template Status ProcessHashTableProbe<JoinOpType, 
Parent>::process_data_in_hashtable<         \
             ExtractType<void(T)>::Type>(ExtractType<void(T)>::Type & 
hash_table_ctx,              \
                                         MutableBlock & mutable_block, Block * 
output_block,       \
-                                        bool* eos)
+                                        bool* eos, bool is_mark_join);
 
 #define INSTANTIATION_FOR1(JoinOpType, Parent)                                \
     template struct ProcessHashTableProbe<JoinOpType, Parent>;                \
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp
index ec630f3fe32..ffb01aed552 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -431,7 +431,7 @@ Status HashJoinNode::pull(doris::RuntimeState* state, 
vectorized::Block* output_
                             using HashTableCtxType = 
std::decay_t<decltype(arg)>;
                             if constexpr (!std::is_same_v<HashTableCtxType, 
std::monostate>) {
                                 st = 
process_hashtable_ctx.process_data_in_hashtable(
-                                        arg, mutable_join_block, &temp_block, 
eos);
+                                        arg, mutable_join_block, &temp_block, 
eos, _is_mark_join);
                             } else {
                                 st = Status::InternalError("uninited hash 
table");
                             }
diff --git a/be/src/vec/exec/join/vjoin_node_base.cpp b/be/src/vec/exec/join/vjoin_node_base.cpp
index e81851b7d93..3436209c4cd 100644
--- a/be/src/vec/exec/join/vjoin_node_base.cpp
+++ b/be/src/vec/exec/join/vjoin_node_base.cpp
@@ -84,8 +84,10 @@ VJoinNodeBase::VJoinNodeBase(ObjectPool* pool, const 
TPlanNode& tnode, const Des
     if (_is_mark_join) {
         DCHECK(_join_op == TJoinOp::LEFT_ANTI_JOIN || _join_op == 
TJoinOp::LEFT_SEMI_JOIN ||
                _join_op == TJoinOp::CROSS_JOIN || _join_op == 
TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN ||
-               _join_op == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN)
-                << "Mark join is only supported for null aware left semi/anti 
join and cross join "
+               _join_op == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN ||
+               _join_op == TJoinOp::RIGHT_SEMI_JOIN)
+                << "Mark join is only supported for null aware left semi/anti 
join and right semi "
+                   "join "
                    "but this is "
                 << _join_op;
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to