lidavidm commented on a change in pull request #10845:
URL: https://github.com/apache/arrow/pull/10845#discussion_r691393145



##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,552 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+Status ValidateJoinInputs(const std::shared_ptr<Schema>& left_schema,
+                          const std::shared_ptr<Schema>& right_schema,
+                          const std::vector<int>& left_keys,
+                          const std::vector<int>& right_keys) {
+  if (left_keys.size() != right_keys.size()) {
+    return Status::Invalid("left and right key sizes do not match");
+  }
+
+  for (size_t i = 0; i < left_keys.size(); i++) {
+    auto l_type = left_schema->field(left_keys[i])->type();
+    auto r_type = right_schema->field(right_keys[i])->type();
+
+    if (!l_type->Equals(r_type)) {
+      return Status::Invalid("build and probe types do not match: " +
+                             l_type->ToString() + " != " + r_type->ToString());
+    }
+  }
+
+  return Status::OK();
+}
+
+Result<std::vector<int>> PopulateKeys(const Schema& schema,
+                                      const std::vector<FieldRef>& keys) {
+  std::vector<int> key_field_ids(keys.size());
+  // Find the input field index for each key field; FieldRef::FindOne errors out
+  // if a reference is missing or ambiguous in the schema
+  for (size_t i = 0; i < keys.size(); ++i) {
+    ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema));
+    key_field_ids[i] = match[0];
+  }
+  return key_field_ids;
+}
+}  // namespace
+
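+// Implements a hash semi join: build-side batches are accumulated into per-thread
+// groupers, which are merged into a single hash table once the build side finishes;
+// each probe-side batch is then filtered down to the rows whose key is found in
+// that table. When anti_join is true the node is intended to emit the non-matching
+// rows instead (that path is outside this excerpt).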
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {
+  HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext* ctx,
+                   std::vector<int>&& build_index_field_ids,
+                   std::vector<int>&& probe_index_field_ids)
+      : ExecNode(build_input->plan(), {build_input, probe_input},
+                 {"hash_join_build", "hash_join_probe"}, probe_input->output_schema(),
+                 /*num_outputs=*/1),
+        ctx_(ctx),
+        build_index_field_ids_(std::move(build_index_field_ids)),
+        probe_index_field_ids_(std::move(probe_index_field_ids)),
+        build_result_index(-1),
+        hash_table_built_(false),
+        cached_probe_batches_consumed(false) {}
+
+ private:
+  struct ThreadLocalState;
+
+ public:
+  const char* kind_name() override { return "HashSemiJoinNode"; }
+
+  Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+    ARROW_LOG(DEBUG) << "init state";
+
+    if (state->grouper != nullptr) return Status::OK();
+
+    // Get the build-side input schema
+    auto build_schema = inputs_[0]->output_schema();
+
+    // Build vector of key field data types
+    std::vector<ValueDescr> key_descrs(build_index_field_ids_.size());
+    for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+      auto build_type = build_schema->field(build_index_field_ids_[i])->type();
+      key_descrs[i] = ValueDescr(build_type);
+    }
+
+    // Construct grouper
+    ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_));
+
+    return Status::OK();
+  }
+
+  // Picks the local state whose grouper will accumulate all build-side groups
+  // (i.e. the grouper that currently holds the highest number of groups).
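+  // Merging the smaller groupers into the largest one minimizes the number of
+  // unique keys that have to be re-inserted during the merge.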
+  void CalculateBuildResultIndex() {
+    int32_t curr_max = -1;
+    for (int i = 0; i < static_cast<int>(local_states_.size()); i++) {
+      auto* state = &local_states_[i];
+      ARROW_DCHECK(state);
+      if (state->grouper &&
+          curr_max < static_cast<int32_t>(state->grouper->num_groups())) {
+        curr_max = static_cast<int32_t>(state->grouper->num_groups());
+        build_result_index = i;
+      }
+    }
+    ARROW_DCHECK(build_result_index > -1);
+    ARROW_LOG(DEBUG) << "build_result_index " << build_result_index;
+  }
+
+  // Performs the housekeeping work after the build side is completed.
+  // Note: this method is not thread safe; callers must guarantee that it is never
+  // invoked concurrently.
+  Status BuildSideCompleted() {
+    ARROW_LOG(DEBUG) << "build side merge";
+
+    // if the hash table has already been built, return
+    if (hash_table_built_) return Status::OK();
+
+    CalculateBuildResultIndex();
+
+    // merge every group into the build_result_index grouper
+    ThreadLocalState* result_state = &local_states_[build_result_index];
+    for (int i = 0; i < static_cast<int>(local_states_.size()); ++i) {
+      ThreadLocalState* state = &local_states_[i];
+      ARROW_DCHECK(state);
+      if (i == build_result_index || !state->grouper) {
+        continue;
+      }
+      ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques());
+
+      // TODO(niranda) replace with void consume method
+      ARROW_ASSIGN_OR_RAISE(Datum _, result_state->grouper->Consume(other_keys));
+      state->grouper.reset();
+    }
+
+    // enable flag that build side is completed
+    hash_table_built_ = true;
+
+    // since the build side is completed, consume cached probe batches
+    RETURN_NOT_OK(ConsumeCachedProbeBatches());
+
+    return Status::OK();
+  }
+
+  // Consumes a build batch and increments the build batch counter. The thread that
+  // consumes the last build batch merges all the thread-local groupers (via
+  // BuildSideCompleted) before the probe side proceeds.
+  Status ConsumeBuildBatch(ExecBatch batch) {
+    size_t thread_index = get_thread_index_();
+    ARROW_DCHECK(thread_index < local_states_.size());
+
+    ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index
+                     << " len:" << batch.length;
+
+    auto state = &local_states_[thread_index];
+    RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+    // Create a batch with key columns
+    std::vector<Datum> keys(build_index_field_ids_.size());
+    for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+      keys[i] = batch.values[build_index_field_ids_[i]];
+    }
+    ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+    // Create a batch with group ids
+    // TODO(niranda) replace with void consume method
+    ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch));
+
+    if (build_counter_.Increment()) {
+      // only one thread would get inside this block!
+      // while incrementing, if the total is reached, call BuildSideCompleted.
+      RETURN_NOT_OK(BuildSideCompleted());
+    }
+
+    return Status::OK();
+  }
+
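+  // Probe batches can arrive before the build side has been fully consumed; until
+  // the hash table is ready they are buffered in cached_probe_batches and replayed
+  // here once the build side completes.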
+  // Consumes the cached probe batches, using executor::Spawn when an executor is
+  // available.
+  Status ConsumeCachedProbeBatches() {
+    ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_()
+                     << " len:" << cached_probe_batches.size();
+
+    // Acquire the mutex guarding cached_probe_batches: while the cache is being
+    // drained, no new batches may be added to it.
+    std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
+
+    if (!cached_probe_batches_consumed) {
+      auto executor = ctx_->executor();
+      for (auto&& cached : cached_probe_batches) {
+        if (executor) {
+          // Move the cached batch into the task. Spawn is fire-and-forget, so a
+          // failure inside the task cannot be returned from here; it is propagated
+          // to the downstream node instead.
+          RETURN_NOT_OK(executor->Spawn([this, cached = std::move(cached)]() mutable {
+            Status status = ConsumeProbeBatch(cached.first, std::move(cached.second));
+            if (!status.ok()) {
+              outputs_[0]->ErrorReceived(this, std::move(status));
+            }
+          }));
+        } else {
+          RETURN_NOT_OK(ConsumeProbeBatch(cached.first, std::move(cached.second)));
+        }
+      }
+      // the cached vector can now be cleared; the exec batches have been moved into
+      // the tasks
+      cached_probe_batches.clear();
+    }
+
+    // set flag
+    cached_probe_batches_consumed = true;
+    return Status::OK();
+  }
+
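+  // For a semi join, a probe row is emitted iff its key was found in the build-side
+  // hash table, i.e. iff its group id is non-null. The validity bitmap of group_ids
+  // can therefore be reinterpreted directly as the boolean filter over the probe
+  // batch. (The anti-join variant would invert this filter; that path is outside
+  // this excerpt.)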
+  Status GenerateOutput(int seq, const ArrayData& group_ids_data, ExecBatch batch) {
+    if (group_ids_data.GetNullCount() == batch.length) {
+      // all group ids are null, i.e. no probe row matched; emit an empty batch
+      ARROW_LOG(DEBUG) << "output seq:" << seq << " 0";
+      outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0));
+    } else if (group_ids_data.MayHaveNulls()) {  // values need to be filtered
+      auto filter_arr =
+          std::make_shared<BooleanArray>(group_ids_data.length, group_ids_data.buffers[0],
+                                         /*null_bitmap=*/nullptr, /*null_count=*/0,
+                                         /*offset=*/group_ids_data.offset);
+      ARROW_ASSIGN_OR_RAISE(auto rec_batch,
+                            batch.ToRecordBatch(output_schema_, ctx_->memory_pool()));
+      ARROW_ASSIGN_OR_RAISE(
+          auto filtered,
+          Filter(rec_batch, filter_arr,
+                 /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_));
+      auto out_batch = ExecBatch(*filtered.record_batch());
+      ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length;
+      outputs_[0]->InputReceived(this, seq, std::move(out_batch));
+    } else {  // all values are valid for output
+      ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length;
+      outputs_[0]->InputReceived(this, seq, std::move(batch));
+    }
+
+    return Status::OK();
+  }
+
+  // Consumes a probe batch and increments the output batch counter. Probing queries
+  // the grouper at build_result_index, into which all other groupers have been
+  // merged.
+  Status ConsumeProbeBatch(int seq, ExecBatch batch) {
+    ARROW_LOG(DEBUG) << "ConsumeProbeBatch seq:" << seq;
+
+    auto& final_grouper = *local_states_[build_result_index].grouper;
+
+    // Create a batch with key columns
+    std::vector<Datum> keys(probe_index_field_ids_.size());
+    for (size_t i = 0; i < probe_index_field_ids_.size(); ++i) {
+      keys[i] = batch.values[probe_index_field_ids_[i]];
+    }
+    ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+    // Query the grouper with key_batch. Rows whose key was not found come back as
+    // nulls in group_ids.
+    ARROW_ASSIGN_OR_RAISE(Datum group_ids, final_grouper.Find(key_batch));
+    auto group_ids_data = *group_ids.array();
+
+    RETURN_NOT_OK(GenerateOutput(seq, group_ids_data, std::move(batch)));
+
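+    // out_counter_ tracks emitted probe batches; the thread that emits the last one
+    // marks the node finished.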
+    if (out_counter_.Increment()) {
+      finished_.MarkFinished();
+    }
+    return Status::OK();
+  }
+
+  // Attempts to cache a probe batch, returning false if it could not be cached.
+  // If cached_probe_batches_consumed is already true by the time a thread acquires
+  // cached_probe_batches_mutex, the batch must no longer be cached; it can be
+  // consumed directly instead.
+  bool AttemptToCacheProbeBatch(int seq_num, ExecBatch* batch) {

Review comment:
       Fortunately, we'll get rid of seq_num in ARROW-13660.



