lidavidm commented on a change in pull request #10845:
URL: https://github.com/apache/arrow/pull/10845#discussion_r691238077
##########
File path: cpp/src/arrow/compute/exec/options.h
##########
@@ -111,5 +111,32 @@ class ARROW_EXPORT SinkNodeOptions : public
ExecNodeOptions {
std::function<Future<util::optional<ExecBatch>>()>* generator;
};
+enum JoinType {
Review comment:
nit: should this be an `enum class`? (Though the compute functions are already somewhat
inconsistent about this.)
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -368,30 +370,43 @@ struct GrouperImpl : Grouper {
return std::move(impl);
}
- Result<Datum> Consume(const ExecBatch& batch) override {
- std::vector<int32_t> offsets_batch(batch.length + 1);
+ Status PopulateKeyData(const ExecBatch& batch, std::vector<int32_t>*
offsets_batch,
+ std::vector<uint8_t>* key_bytes_batch,
+ std::vector<uint8_t*>* key_buf_ptrs) const {
+ offsets_batch->resize(batch.length + 1);
for (int i = 0; i < batch.num_values(); ++i) {
- encoders_[i]->AddLength(*batch[i].array(), offsets_batch.data());
+ encoders_[i]->AddLength(*batch[i].array(), offsets_batch->data());
}
int32_t total_length = 0;
for (int64_t i = 0; i < batch.length; ++i) {
auto total_length_before = total_length;
- total_length += offsets_batch[i];
- offsets_batch[i] = total_length_before;
+ total_length += offsets_batch->at(i);
Review comment:
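Note that `at` forces bounds checking (it throws `std::out_of_range` on a bad index). Since
`i` is already bounded by the loop condition, you might prefer `(*offsets_batch)[i]` instead.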
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,552 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+Status ValidateJoinInputs(const std::shared_ptr<Schema>& left_schema,
+ const std::shared_ptr<Schema>& right_schema,
+ const std::vector<int>& left_keys,
+ const std::vector<int>& right_keys) {
+ if (left_keys.size() != right_keys.size()) {
+ return Status::Invalid("left and right key sizes do not match");
+ }
+
+ for (size_t i = 0; i < left_keys.size(); i++) {
+ auto l_type = left_schema->field(left_keys[i])->type();
+ auto r_type = right_schema->field(right_keys[i])->type();
+
+ if (!l_type->Equals(r_type)) {
+ return Status::Invalid("build and probe types do not match: " +
l_type->ToString() +
+ "!=" + r_type->ToString());
+ }
+ }
+
+ return Status::OK();
+}
+
+Result<std::vector<int>> PopulateKeys(const Schema& schema,
+ const std::vector<FieldRef>& keys) {
+ std::vector<int> key_field_ids(keys.size());
+ // Find input field indices for left key fields
+ for (size_t i = 0; i < keys.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema));
+ key_field_ids[i] = match[0];
+ }
+ return key_field_ids;
+}
+} // namespace
+
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {
+ HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext*
ctx,
+ const std::vector<int>&& build_index_field_ids,
+ const std::vector<int>&& probe_index_field_ids)
+ : ExecNode(build_input->plan(), {build_input, probe_input},
+ {"hash_join_build", "hash_join_probe"},
probe_input->output_schema(),
+ /*num_outputs=*/1),
+ ctx_(ctx),
+ build_index_field_ids_(build_index_field_ids),
+ probe_index_field_ids_(probe_index_field_ids),
+ build_result_index(-1),
+ hash_table_built_(false),
+ cached_probe_batches_consumed(false) {}
+
+ private:
+ struct ThreadLocalState;
+
+ public:
+ const char* kind_name() override { return "HashSemiJoinNode"; }
+
+ Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+ ARROW_LOG(DEBUG) << "init state";
+
+ // Get input schema
+ auto build_schema = inputs_[0]->output_schema();
+
+ if (state->grouper != nullptr) return Status::OK();
+
+ // Build vector of key field data types
+ std::vector<ValueDescr> key_descrs(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ auto build_type = build_schema->field(build_index_field_ids_[i])->type();
+ key_descrs[i] = ValueDescr(build_type);
+ }
+
+ // Construct grouper
+ ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs,
ctx_));
+
+ return Status::OK();
+ }
+
+ // Finds an appropriate index which could accumulate all build indices (i.e.
the grouper
+ // which has the highest # of groups)
+ void CalculateBuildResultIndex() {
+ int32_t curr_max = -1;
+ for (int i = 0; i < static_cast<int>(local_states_.size()); i++) {
+ auto* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (state->grouper &&
+ curr_max < static_cast<int32_t>(state->grouper->num_groups())) {
+ curr_max = static_cast<int32_t>(state->grouper->num_groups());
+ build_result_index = i;
+ }
+ }
+ ARROW_DCHECK(build_result_index > -1);
+ ARROW_LOG(DEBUG) << "build_result_index " << build_result_index;
+ }
+
+ // Performs the housekeeping work after the build-side is completed.
+ // Note: this method is not thread safe, and hence should be guaranteed that
it is
+ // not accessed concurrently!
+ Status BuildSideCompleted() {
+ ARROW_LOG(DEBUG) << "build side merge";
+
+ // if the hash table has already been built, return
+ if (hash_table_built_) return Status::OK();
+
+ CalculateBuildResultIndex();
+
+ // merge every group into the build_result_index grouper
+ ThreadLocalState* result_state = &local_states_[build_result_index];
+ for (int i = 0; i < static_cast<int>(local_states_.size()); ++i) {
+ ThreadLocalState* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (i == build_result_index || !state->grouper) {
+ continue;
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys,
state->grouper->GetUniques());
+
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _,
result_state->grouper->Consume(other_keys));
+ state->grouper.reset();
+ }
+
+ // enable flag that build side is completed
+ hash_table_built_ = true;
+
+ // since the build side is completed, consume cached probe batches
+ RETURN_NOT_OK(ConsumeCachedProbeBatches());
+
+ return Status::OK();
+ }
+
+ // consumes a build batch and increments the build_batches count. if the
build batches
+ // total reached at the end of consumption, all the local states will be
merged, before
+ // incrementing the total batches
+ Status ConsumeBuildBatch(ExecBatch batch) {
+ size_t thread_index = get_thread_index_();
+ ARROW_DCHECK(thread_index < local_states_.size());
+
+ ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index
+ << " len:" << batch.length;
+
+ auto state = &local_states_[thread_index];
+ RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+ // Create a batch with key columns
+ std::vector<Datum> keys(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ keys[i] = batch.values[build_index_field_ids_[i]];
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+ // Create a batch with group ids
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch));
+
+ if (build_counter_.Increment()) {
+ // only one thread would get inside this block!
+ // while incrementing, if the total is reached, call BuildSideCompleted.
+ RETURN_NOT_OK(BuildSideCompleted());
+ }
+
+ return Status::OK();
+ }
+
+ // consumes cached probe batches by invoking executor::Spawn.
+ Status ConsumeCachedProbeBatches() {
+ ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_()
+ << " len:" << cached_probe_batches.size();
+
+ // acquire the mutex to access cached_probe_batches, because while
consuming, other
+ // batches should not be cached!
+ std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
+
+ if (!cached_probe_batches_consumed) {
+ auto executor = ctx_->executor();
+ for (auto&& cached : cached_probe_batches) {
+ if (executor) {
+ Status lambda_status;
+ RETURN_NOT_OK(executor->Spawn([&] {
+ lambda_status = ConsumeProbeBatch(cached.first,
std::move(cached.second));
Review comment:
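Won't `lambda_status` be invalid here? `Spawn` only schedules the lambda, so
`RETURN_NOT_OK(lambda_status)` may be evaluated before the task has run. Also, the lambda
captures by reference both `lambda_status` (a loop-local) and `cached` (whose container is
cleared right after the loop), so the task can end up touching dangling references.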
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,552 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+Status ValidateJoinInputs(const std::shared_ptr<Schema>& left_schema,
+ const std::shared_ptr<Schema>& right_schema,
+ const std::vector<int>& left_keys,
+ const std::vector<int>& right_keys) {
+ if (left_keys.size() != right_keys.size()) {
+ return Status::Invalid("left and right key sizes do not match");
+ }
+
+ for (size_t i = 0; i < left_keys.size(); i++) {
+ auto l_type = left_schema->field(left_keys[i])->type();
+ auto r_type = right_schema->field(right_keys[i])->type();
+
+ if (!l_type->Equals(r_type)) {
+ return Status::Invalid("build and probe types do not match: " +
l_type->ToString() +
+ "!=" + r_type->ToString());
+ }
+ }
+
+ return Status::OK();
+}
+
+Result<std::vector<int>> PopulateKeys(const Schema& schema,
+ const std::vector<FieldRef>& keys) {
+ std::vector<int> key_field_ids(keys.size());
+ // Find input field indices for left key fields
+ for (size_t i = 0; i < keys.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema));
+ key_field_ids[i] = match[0];
+ }
+ return key_field_ids;
+}
+} // namespace
+
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {
+ HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext*
ctx,
+ const std::vector<int>&& build_index_field_ids,
+ const std::vector<int>&& probe_index_field_ids)
+ : ExecNode(build_input->plan(), {build_input, probe_input},
+ {"hash_join_build", "hash_join_probe"},
probe_input->output_schema(),
+ /*num_outputs=*/1),
+ ctx_(ctx),
+ build_index_field_ids_(build_index_field_ids),
+ probe_index_field_ids_(probe_index_field_ids),
+ build_result_index(-1),
+ hash_table_built_(false),
+ cached_probe_batches_consumed(false) {}
+
+ private:
+ struct ThreadLocalState;
+
+ public:
+ const char* kind_name() override { return "HashSemiJoinNode"; }
+
+ Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+ ARROW_LOG(DEBUG) << "init state";
+
+ // Get input schema
+ auto build_schema = inputs_[0]->output_schema();
+
+ if (state->grouper != nullptr) return Status::OK();
+
+ // Build vector of key field data types
+ std::vector<ValueDescr> key_descrs(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ auto build_type = build_schema->field(build_index_field_ids_[i])->type();
+ key_descrs[i] = ValueDescr(build_type);
+ }
+
+ // Construct grouper
+ ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs,
ctx_));
+
+ return Status::OK();
+ }
+
+ // Finds an appropriate index which could accumulate all build indices (i.e.
the grouper
+ // which has the highest # of groups)
+ void CalculateBuildResultIndex() {
+ int32_t curr_max = -1;
+ for (int i = 0; i < static_cast<int>(local_states_.size()); i++) {
+ auto* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (state->grouper &&
+ curr_max < static_cast<int32_t>(state->grouper->num_groups())) {
+ curr_max = static_cast<int32_t>(state->grouper->num_groups());
+ build_result_index = i;
+ }
+ }
+ ARROW_DCHECK(build_result_index > -1);
+ ARROW_LOG(DEBUG) << "build_result_index " << build_result_index;
+ }
+
+ // Performs the housekeeping work after the build-side is completed.
+ // Note: this method is not thread safe, and hence should be guaranteed that
it is
+ // not accessed concurrently!
+ Status BuildSideCompleted() {
+ ARROW_LOG(DEBUG) << "build side merge";
+
+ // if the hash table has already been built, return
+ if (hash_table_built_) return Status::OK();
Review comment:
I think the use of AtomicCounter means that this check is unnecessary, or could at most be a
DCHECK.
Though I also think this might need to be an `std::atomic<bool>` to be safe anyway, since I
don't think there is any happens-before relationship between the plain bool here and any
atomic/mutex.
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,552 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+Status ValidateJoinInputs(const std::shared_ptr<Schema>& left_schema,
+ const std::shared_ptr<Schema>& right_schema,
+ const std::vector<int>& left_keys,
+ const std::vector<int>& right_keys) {
+ if (left_keys.size() != right_keys.size()) {
+ return Status::Invalid("left and right key sizes do not match");
+ }
+
+ for (size_t i = 0; i < left_keys.size(); i++) {
+ auto l_type = left_schema->field(left_keys[i])->type();
+ auto r_type = right_schema->field(right_keys[i])->type();
+
+ if (!l_type->Equals(r_type)) {
+ return Status::Invalid("build and probe types do not match: " +
l_type->ToString() +
+ "!=" + r_type->ToString());
+ }
+ }
+
+ return Status::OK();
+}
+
+Result<std::vector<int>> PopulateKeys(const Schema& schema,
+ const std::vector<FieldRef>& keys) {
+ std::vector<int> key_field_ids(keys.size());
+ // Find input field indices for left key fields
+ for (size_t i = 0; i < keys.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema));
+ key_field_ids[i] = match[0];
+ }
+ return key_field_ids;
+}
+} // namespace
+
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {
+ HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext*
ctx,
+ const std::vector<int>&& build_index_field_ids,
+ const std::vector<int>&& probe_index_field_ids)
+ : ExecNode(build_input->plan(), {build_input, probe_input},
+ {"hash_join_build", "hash_join_probe"},
probe_input->output_schema(),
+ /*num_outputs=*/1),
+ ctx_(ctx),
+ build_index_field_ids_(build_index_field_ids),
+ probe_index_field_ids_(probe_index_field_ids),
+ build_result_index(-1),
+ hash_table_built_(false),
+ cached_probe_batches_consumed(false) {}
+
+ private:
+ struct ThreadLocalState;
+
+ public:
+ const char* kind_name() override { return "HashSemiJoinNode"; }
+
+ Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+ ARROW_LOG(DEBUG) << "init state";
+
+ // Get input schema
+ auto build_schema = inputs_[0]->output_schema();
+
+ if (state->grouper != nullptr) return Status::OK();
+
+ // Build vector of key field data types
+ std::vector<ValueDescr> key_descrs(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ auto build_type = build_schema->field(build_index_field_ids_[i])->type();
+ key_descrs[i] = ValueDescr(build_type);
+ }
+
+ // Construct grouper
+ ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs,
ctx_));
+
+ return Status::OK();
+ }
+
+ // Finds an appropriate index which could accumulate all build indices (i.e.
the grouper
+ // which has the highest # of groups)
+ void CalculateBuildResultIndex() {
+ int32_t curr_max = -1;
+ for (int i = 0; i < static_cast<int>(local_states_.size()); i++) {
+ auto* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (state->grouper &&
+ curr_max < static_cast<int32_t>(state->grouper->num_groups())) {
+ curr_max = static_cast<int32_t>(state->grouper->num_groups());
+ build_result_index = i;
+ }
+ }
+ ARROW_DCHECK(build_result_index > -1);
+ ARROW_LOG(DEBUG) << "build_result_index " << build_result_index;
+ }
+
+ // Performs the housekeeping work after the build-side is completed.
+ // Note: this method is not thread safe, and hence should be guaranteed that
it is
+ // not accessed concurrently!
+ Status BuildSideCompleted() {
+ ARROW_LOG(DEBUG) << "build side merge";
+
+ // if the hash table has already been built, return
+ if (hash_table_built_) return Status::OK();
+
+ CalculateBuildResultIndex();
+
+ // merge every group into the build_result_index grouper
+ ThreadLocalState* result_state = &local_states_[build_result_index];
+ for (int i = 0; i < static_cast<int>(local_states_.size()); ++i) {
+ ThreadLocalState* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (i == build_result_index || !state->grouper) {
+ continue;
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys,
state->grouper->GetUniques());
+
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _,
result_state->grouper->Consume(other_keys));
+ state->grouper.reset();
+ }
+
+ // enable flag that build side is completed
+ hash_table_built_ = true;
+
+ // since the build side is completed, consume cached probe batches
+ RETURN_NOT_OK(ConsumeCachedProbeBatches());
+
+ return Status::OK();
+ }
+
+ // consumes a build batch and increments the build_batches count. if the
build batches
+ // total reached at the end of consumption, all the local states will be
merged, before
+ // incrementing the total batches
+ Status ConsumeBuildBatch(ExecBatch batch) {
+ size_t thread_index = get_thread_index_();
+ ARROW_DCHECK(thread_index < local_states_.size());
+
+ ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index
+ << " len:" << batch.length;
+
+ auto state = &local_states_[thread_index];
+ RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+ // Create a batch with key columns
+ std::vector<Datum> keys(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ keys[i] = batch.values[build_index_field_ids_[i]];
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+ // Create a batch with group ids
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch));
+
+ if (build_counter_.Increment()) {
+ // only one thread would get inside this block!
+ // while incrementing, if the total is reached, call BuildSideCompleted.
+ RETURN_NOT_OK(BuildSideCompleted());
+ }
+
+ return Status::OK();
+ }
+
+ // consumes cached probe batches by invoking executor::Spawn.
+ Status ConsumeCachedProbeBatches() {
+ ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_()
+ << " len:" << cached_probe_batches.size();
+
+ // acquire the mutex to access cached_probe_batches, because while
consuming, other
+ // batches should not be cached!
+ std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
+
+ if (!cached_probe_batches_consumed) {
+ auto executor = ctx_->executor();
+ for (auto&& cached : cached_probe_batches) {
+ if (executor) {
+ Status lambda_status;
+ RETURN_NOT_OK(executor->Spawn([&] {
+ lambda_status = ConsumeProbeBatch(cached.first,
std::move(cached.second));
+ }));
+
+ // if the lambda execution failed internally, return status
+ RETURN_NOT_OK(lambda_status);
+ } else {
+ RETURN_NOT_OK(ConsumeProbeBatch(cached.first,
std::move(cached.second)));
+ }
+ }
+ // cached vector will be cleared. exec batches are expected to be moved
to the
+ // lambdas
+ cached_probe_batches.clear();
+ }
+
+ // set flag
+ cached_probe_batches_consumed = true;
+ return Status::OK();
+ }
+
+ Status GenerateOutput(int seq, const ArrayData& group_ids_data, ExecBatch
batch) {
+ if (group_ids_data.GetNullCount() == batch.length) {
+ // All NULLS! hence, there are no valid outputs!
+ ARROW_LOG(DEBUG) << "output seq:" << seq << " 0";
+ outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0));
+ } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered
+ auto filter_arr =
+ std::make_shared<BooleanArray>(group_ids_data.length,
group_ids_data.buffers[0],
+ /*null_bitmap=*/nullptr,
/*null_count=*/0,
+ /*offset=*/group_ids_data.offset);
+ ARROW_ASSIGN_OR_RAISE(auto rec_batch,
+ batch.ToRecordBatch(output_schema_,
ctx_->memory_pool()));
+ ARROW_ASSIGN_OR_RAISE(
+ auto filtered,
+ Filter(rec_batch, filter_arr,
+ /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_));
+ auto out_batch = ExecBatch(*filtered.record_batch());
+ ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length;
+ outputs_[0]->InputReceived(this, seq, std::move(out_batch));
+ } else { // all values are valid for output
+ ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length;
+ outputs_[0]->InputReceived(this, seq, std::move(batch));
+ }
+
+ return Status::OK();
+ }
+
+ // consumes a probe batch and increment probe batches count. Probing would
query the
+ // grouper[build_result_index] which have been merged with all others.
+ Status ConsumeProbeBatch(int seq, ExecBatch batch) {
+ ARROW_LOG(DEBUG) << "ConsumeProbeBatch seq:" << seq;
+
+ auto& final_grouper = *local_states_[build_result_index].grouper;
+
+ // Create a batch with key columns
+ std::vector<Datum> keys(probe_index_field_ids_.size());
+ for (size_t i = 0; i < probe_index_field_ids_.size(); ++i) {
+ keys[i] = batch.values[probe_index_field_ids_[i]];
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+ // Query the grouper with key_batch. If no match was found, returning
group_ids would
+ // have null.
+ ARROW_ASSIGN_OR_RAISE(Datum group_ids, final_grouper.Find(key_batch));
+ auto group_ids_data = *group_ids.array();
+
+ RETURN_NOT_OK(GenerateOutput(seq, group_ids_data, std::move(batch)));
+
+ if (out_counter_.Increment()) {
+ finished_.MarkFinished();
+ }
+ return Status::OK();
+ }
+
+ // Attempt to cache a probe batch. If it is not cached, return false.
+ // if cached_probe_batches_consumed is true, by the time a thread acquires
+ // cached_probe_batches_mutex, it should no longer be cached! instead, it
can be
+ // directly consumed!
+ bool AttemptToCacheProbeBatch(int seq_num, ExecBatch* batch) {
Review comment:
Just a minor nit, but the current UnionNode doesn't renumber batches, so there can be
duplicate seq_nums. It sounds like we want to get rid of seq_nums anyway, so it might be
worth getting ahead of that and generating our own indices here.
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,552 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+Status ValidateJoinInputs(const std::shared_ptr<Schema>& left_schema,
+ const std::shared_ptr<Schema>& right_schema,
+ const std::vector<int>& left_keys,
+ const std::vector<int>& right_keys) {
+ if (left_keys.size() != right_keys.size()) {
+ return Status::Invalid("left and right key sizes do not match");
+ }
+
+ for (size_t i = 0; i < left_keys.size(); i++) {
+ auto l_type = left_schema->field(left_keys[i])->type();
+ auto r_type = right_schema->field(right_keys[i])->type();
+
+ if (!l_type->Equals(r_type)) {
+ return Status::Invalid("build and probe types do not match: " +
l_type->ToString() +
+ "!=" + r_type->ToString());
+ }
+ }
+
+ return Status::OK();
+}
+
+Result<std::vector<int>> PopulateKeys(const Schema& schema,
+ const std::vector<FieldRef>& keys) {
+ std::vector<int> key_field_ids(keys.size());
+ // Find input field indices for left key fields
+ for (size_t i = 0; i < keys.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema));
+ key_field_ids[i] = match[0];
+ }
+ return key_field_ids;
+}
+} // namespace
+
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {
+ HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext*
ctx,
+ const std::vector<int>&& build_index_field_ids,
+ const std::vector<int>&& probe_index_field_ids)
+ : ExecNode(build_input->plan(), {build_input, probe_input},
+ {"hash_join_build", "hash_join_probe"},
probe_input->output_schema(),
+ /*num_outputs=*/1),
+ ctx_(ctx),
+ build_index_field_ids_(build_index_field_ids),
+ probe_index_field_ids_(probe_index_field_ids),
+ build_result_index(-1),
+ hash_table_built_(false),
+ cached_probe_batches_consumed(false) {}
+
+ private:
+ struct ThreadLocalState;
+
+ public:
+ const char* kind_name() override { return "HashSemiJoinNode"; }
+
+ Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+ ARROW_LOG(DEBUG) << "init state";
+
+ // Get input schema
+ auto build_schema = inputs_[0]->output_schema();
+
+ if (state->grouper != nullptr) return Status::OK();
+
+ // Build vector of key field data types
+ std::vector<ValueDescr> key_descrs(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ auto build_type = build_schema->field(build_index_field_ids_[i])->type();
+ key_descrs[i] = ValueDescr(build_type);
+ }
+
+ // Construct grouper
+ ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs,
ctx_));
+
+ return Status::OK();
+ }
+
+ // Finds an appropriate index which could accumulate all build indices (i.e.
the grouper
+ // which has the highest # of groups)
+ void CalculateBuildResultIndex() {
+ int32_t curr_max = -1;
+ for (int i = 0; i < static_cast<int>(local_states_.size()); i++) {
+ auto* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (state->grouper &&
+ curr_max < static_cast<int32_t>(state->grouper->num_groups())) {
+ curr_max = static_cast<int32_t>(state->grouper->num_groups());
+ build_result_index = i;
+ }
+ }
+ ARROW_DCHECK(build_result_index > -1);
+ ARROW_LOG(DEBUG) << "build_result_index " << build_result_index;
+ }
+
+ // Performs the housekeeping work after the build-side is completed.
+ // Note: this method is not thread safe, and hence should be guaranteed that
it is
+ // not accessed concurrently!
+ Status BuildSideCompleted() {
+ ARROW_LOG(DEBUG) << "build side merge";
+
+ // if the hash table has already been built, return
+ if (hash_table_built_) return Status::OK();
+
+ CalculateBuildResultIndex();
+
+ // merge every group into the build_result_index grouper
+ ThreadLocalState* result_state = &local_states_[build_result_index];
+ for (int i = 0; i < static_cast<int>(local_states_.size()); ++i) {
+ ThreadLocalState* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (i == build_result_index || !state->grouper) {
+ continue;
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys,
state->grouper->GetUniques());
+
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _,
result_state->grouper->Consume(other_keys));
+ state->grouper.reset();
+ }
+
+ // enable flag that build side is completed
+ hash_table_built_ = true;
+
+ // since the build side is completed, consume cached probe batches
+ RETURN_NOT_OK(ConsumeCachedProbeBatches());
+
+ return Status::OK();
+ }
+
+ // consumes a build batch and increments the build_batches count. if the
build batches
+ // total reached at the end of consumption, all the local states will be
merged, before
+ // incrementing the total batches
+ Status ConsumeBuildBatch(ExecBatch batch) {
+ size_t thread_index = get_thread_index_();
+ ARROW_DCHECK(thread_index < local_states_.size());
+
+ ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index
+ << " len:" << batch.length;
+
+ auto state = &local_states_[thread_index];
+ RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+ // Create a batch with key columns
+ std::vector<Datum> keys(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ keys[i] = batch.values[build_index_field_ids_[i]];
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+ // Create a batch with group ids
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch));
+
+ if (build_counter_.Increment()) {
+ // only one thread would get inside this block!
+ // while incrementing, if the total is reached, call BuildSideCompleted.
+ RETURN_NOT_OK(BuildSideCompleted());
+ }
+
+ return Status::OK();
+ }
+
+ // consumes cached probe batches by invoking executor::Spawn.
+ Status ConsumeCachedProbeBatches() {
+ ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_()
+ << " len:" << cached_probe_batches.size();
+
+ // acquire the mutex to access cached_probe_batches, because while
consuming, other
+ // batches should not be cached!
+ std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
+
+ if (!cached_probe_batches_consumed) {
+ auto executor = ctx_->executor();
+ for (auto&& cached : cached_probe_batches) {
+ if (executor) {
+ Status lambda_status;
+ RETURN_NOT_OK(executor->Spawn([&] {
+ lambda_status = ConsumeProbeBatch(cached.first,
std::move(cached.second));
+ }));
+
+ // if the lambda execution failed internally, return status
+ RETURN_NOT_OK(lambda_status);
+ } else {
+ RETURN_NOT_OK(ConsumeProbeBatch(cached.first,
std::move(cached.second)));
+ }
+ }
+ // cached vector will be cleared. exec batches are expected to be moved
to the
+ // lambdas
+ cached_probe_batches.clear();
+ }
+
+ // set flag
+ cached_probe_batches_consumed = true;
+ return Status::OK();
+ }
+
+ Status GenerateOutput(int seq, const ArrayData& group_ids_data, ExecBatch
batch) {
+ if (group_ids_data.GetNullCount() == batch.length) {
+ // All NULLS! hence, there are no valid outputs!
+ ARROW_LOG(DEBUG) << "output seq:" << seq << " 0";
+ outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0));
+ } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered
+ auto filter_arr =
+ std::make_shared<BooleanArray>(group_ids_data.length,
group_ids_data.buffers[0],
+ /*null_bitmap=*/nullptr,
/*null_count=*/0,
+ /*offset=*/group_ids_data.offset);
+ ARROW_ASSIGN_OR_RAISE(auto rec_batch,
+ batch.ToRecordBatch(output_schema_,
ctx_->memory_pool()));
+ ARROW_ASSIGN_OR_RAISE(
+ auto filtered,
+ Filter(rec_batch, filter_arr,
+ /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_));
+ auto out_batch = ExecBatch(*filtered.record_batch());
+ ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length;
+ outputs_[0]->InputReceived(this, seq, std::move(out_batch));
+ } else { // all values are valid for output
+ ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length;
+ outputs_[0]->InputReceived(this, seq, std::move(batch));
+ }
+
+ return Status::OK();
+ }
+
+ // consumes a probe batch and increment probe batches count. Probing would
query the
+ // grouper[build_result_index] which have been merged with all others.
+ Status ConsumeProbeBatch(int seq, ExecBatch batch) {
+ ARROW_LOG(DEBUG) << "ConsumeProbeBatch seq:" << seq;
+
+ auto& final_grouper = *local_states_[build_result_index].grouper;
+
+ // Create a batch with key columns
+ std::vector<Datum> keys(probe_index_field_ids_.size());
+ for (size_t i = 0; i < probe_index_field_ids_.size(); ++i) {
+ keys[i] = batch.values[probe_index_field_ids_[i]];
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+ // Query the grouper with key_batch. If no match was found, returning
group_ids would
+ // have null.
+ ARROW_ASSIGN_OR_RAISE(Datum group_ids, final_grouper.Find(key_batch));
+ auto group_ids_data = *group_ids.array();
+
+ RETURN_NOT_OK(GenerateOutput(seq, group_ids_data, std::move(batch)));
+
+ if (out_counter_.Increment()) {
+ finished_.MarkFinished();
+ }
+ return Status::OK();
+ }
+
+ // Attempt to cache a probe batch. If it is not cached, return false.
+ // if cached_probe_batches_consumed is true, by the time a thread acquires
+ // cached_probe_batches_mutex, it should no longer be cached! instead, it
can be
+ // directly consumed!
+ bool AttemptToCacheProbeBatch(int seq_num, ExecBatch* batch) {
Review comment:
Or really, since we're guarding with a mutex already, why not just use a
vector?
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,552 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+Status ValidateJoinInputs(const std::shared_ptr<Schema>& left_schema,
+ const std::shared_ptr<Schema>& right_schema,
+ const std::vector<int>& left_keys,
+ const std::vector<int>& right_keys) {
+ if (left_keys.size() != right_keys.size()) {
+ return Status::Invalid("left and right key sizes do not match");
+ }
+
+ for (size_t i = 0; i < left_keys.size(); i++) {
+ auto l_type = left_schema->field(left_keys[i])->type();
+ auto r_type = right_schema->field(right_keys[i])->type();
+
+ if (!l_type->Equals(r_type)) {
+ return Status::Invalid("build and probe types do not match: " +
l_type->ToString() +
+ "!=" + r_type->ToString());
+ }
+ }
+
+ return Status::OK();
+}
+
+Result<std::vector<int>> PopulateKeys(const Schema& schema,
+ const std::vector<FieldRef>& keys) {
+ std::vector<int> key_field_ids(keys.size());
+ // Find input field indices for left key fields
+ for (size_t i = 0; i < keys.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema));
+ key_field_ids[i] = match[0];
+ }
+ return key_field_ids;
+}
+} // namespace
+
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {
+ HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext*
ctx,
+ const std::vector<int>&& build_index_field_ids,
+ const std::vector<int>&& probe_index_field_ids)
+ : ExecNode(build_input->plan(), {build_input, probe_input},
+ {"hash_join_build", "hash_join_probe"},
probe_input->output_schema(),
+ /*num_outputs=*/1),
+ ctx_(ctx),
+ build_index_field_ids_(build_index_field_ids),
+ probe_index_field_ids_(probe_index_field_ids),
+ build_result_index(-1),
+ hash_table_built_(false),
+ cached_probe_batches_consumed(false) {}
+
+ private:
+ struct ThreadLocalState;
+
+ public:
+ const char* kind_name() override { return "HashSemiJoinNode"; }
+
+ Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+ ARROW_LOG(DEBUG) << "init state";
+
+ // Get input schema
+ auto build_schema = inputs_[0]->output_schema();
+
+ if (state->grouper != nullptr) return Status::OK();
+
+ // Build vector of key field data types
+ std::vector<ValueDescr> key_descrs(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ auto build_type = build_schema->field(build_index_field_ids_[i])->type();
+ key_descrs[i] = ValueDescr(build_type);
+ }
+
+ // Construct grouper
+ ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs,
ctx_));
+
+ return Status::OK();
+ }
+
+ // Finds an appropriate index which could accumulate all build indices (i.e.
the grouper
+ // which has the highest # of groups)
+ void CalculateBuildResultIndex() {
+ int32_t curr_max = -1;
+ for (int i = 0; i < static_cast<int>(local_states_.size()); i++) {
+ auto* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (state->grouper &&
+ curr_max < static_cast<int32_t>(state->grouper->num_groups())) {
+ curr_max = static_cast<int32_t>(state->grouper->num_groups());
+ build_result_index = i;
+ }
+ }
+ ARROW_DCHECK(build_result_index > -1);
+ ARROW_LOG(DEBUG) << "build_result_index " << build_result_index;
+ }
+
+ // Performs the housekeeping work after the build-side is completed.
+ // Note: this method is not thread safe, and hence should be guaranteed that
it is
+ // not accessed concurrently!
+ Status BuildSideCompleted() {
+ ARROW_LOG(DEBUG) << "build side merge";
+
+ // if the hash table has already been built, return
+ if (hash_table_built_) return Status::OK();
+
+ CalculateBuildResultIndex();
+
+ // merge every group into the build_result_index grouper
+ ThreadLocalState* result_state = &local_states_[build_result_index];
+ for (int i = 0; i < static_cast<int>(local_states_.size()); ++i) {
+ ThreadLocalState* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (i == build_result_index || !state->grouper) {
+ continue;
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys,
state->grouper->GetUniques());
+
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _,
result_state->grouper->Consume(other_keys));
+ state->grouper.reset();
+ }
+
+ // enable flag that build side is completed
+ hash_table_built_ = true;
+
+ // since the build side is completed, consume cached probe batches
+ RETURN_NOT_OK(ConsumeCachedProbeBatches());
+
+ return Status::OK();
+ }
+
+ // consumes a build batch and increments the build_batches count. if the
build batches
+ // total reached at the end of consumption, all the local states will be
merged, before
+ // incrementing the total batches
+ Status ConsumeBuildBatch(ExecBatch batch) {
+ size_t thread_index = get_thread_index_();
+ ARROW_DCHECK(thread_index < local_states_.size());
+
+ ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index
+ << " len:" << batch.length;
+
+ auto state = &local_states_[thread_index];
+ RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+ // Create a batch with key columns
+ std::vector<Datum> keys(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ keys[i] = batch.values[build_index_field_ids_[i]];
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+ // Create a batch with group ids
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch));
+
+ if (build_counter_.Increment()) {
+ // only one thread would get inside this block!
+ // while incrementing, if the total is reached, call BuildSideCompleted.
+ RETURN_NOT_OK(BuildSideCompleted());
+ }
+
+ return Status::OK();
+ }
+
+ // consumes cached probe batches by invoking executor::Spawn.
+ Status ConsumeCachedProbeBatches() {
+ ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_()
+ << " len:" << cached_probe_batches.size();
+
+ // acquire the mutex to access cached_probe_batches, because while
consuming, other
+ // batches should not be cached!
+ std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
+
+ if (!cached_probe_batches_consumed) {
+ auto executor = ctx_->executor();
+ for (auto&& cached : cached_probe_batches) {
+ if (executor) {
+ Status lambda_status;
+ RETURN_NOT_OK(executor->Spawn([&] {
+ lambda_status = ConsumeProbeBatch(cached.first,
std::move(cached.second));
Review comment:
I think this should use `ErrorIfNotOk`, so that a failure inside the spawned task is propagated instead of being lost.
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,552 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+Status ValidateJoinInputs(const std::shared_ptr<Schema>& left_schema,
+ const std::shared_ptr<Schema>& right_schema,
+ const std::vector<int>& left_keys,
+ const std::vector<int>& right_keys) {
+ if (left_keys.size() != right_keys.size()) {
+ return Status::Invalid("left and right key sizes do not match");
+ }
+
+ for (size_t i = 0; i < left_keys.size(); i++) {
+ auto l_type = left_schema->field(left_keys[i])->type();
+ auto r_type = right_schema->field(right_keys[i])->type();
+
+ if (!l_type->Equals(r_type)) {
+ return Status::Invalid("build and probe types do not match: " +
l_type->ToString() +
+ "!=" + r_type->ToString());
+ }
+ }
+
+ return Status::OK();
+}
+
+Result<std::vector<int>> PopulateKeys(const Schema& schema,
+ const std::vector<FieldRef>& keys) {
+ std::vector<int> key_field_ids(keys.size());
+ // Find input field indices for left key fields
+ for (size_t i = 0; i < keys.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema));
+ key_field_ids[i] = match[0];
+ }
+ return key_field_ids;
+}
+} // namespace
+
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {
+ HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext*
ctx,
+ const std::vector<int>&& build_index_field_ids,
+ const std::vector<int>&& probe_index_field_ids)
+ : ExecNode(build_input->plan(), {build_input, probe_input},
+ {"hash_join_build", "hash_join_probe"},
probe_input->output_schema(),
+ /*num_outputs=*/1),
+ ctx_(ctx),
+ build_index_field_ids_(build_index_field_ids),
+ probe_index_field_ids_(probe_index_field_ids),
+ build_result_index(-1),
+ hash_table_built_(false),
+ cached_probe_batches_consumed(false) {}
+
+ private:
+ struct ThreadLocalState;
+
+ public:
+ const char* kind_name() override { return "HashSemiJoinNode"; }
+
+ Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+ ARROW_LOG(DEBUG) << "init state";
+
+ // Get input schema
+ auto build_schema = inputs_[0]->output_schema();
+
+ if (state->grouper != nullptr) return Status::OK();
+
+ // Build vector of key field data types
+ std::vector<ValueDescr> key_descrs(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ auto build_type = build_schema->field(build_index_field_ids_[i])->type();
+ key_descrs[i] = ValueDescr(build_type);
+ }
+
+ // Construct grouper
+ ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs,
ctx_));
+
+ return Status::OK();
+ }
+
+ // Finds an appropriate index which could accumulate all build indices (i.e.
the grouper
+ // which has the highest # of groups)
+ void CalculateBuildResultIndex() {
+ int32_t curr_max = -1;
+ for (int i = 0; i < static_cast<int>(local_states_.size()); i++) {
+ auto* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (state->grouper &&
+ curr_max < static_cast<int32_t>(state->grouper->num_groups())) {
+ curr_max = static_cast<int32_t>(state->grouper->num_groups());
+ build_result_index = i;
+ }
+ }
+ ARROW_DCHECK(build_result_index > -1);
+ ARROW_LOG(DEBUG) << "build_result_index " << build_result_index;
+ }
+
+ // Performs the housekeeping work after the build-side is completed.
+ // Note: this method is not thread safe, and hence should be guaranteed that
it is
+ // not accessed concurrently!
+ Status BuildSideCompleted() {
+ ARROW_LOG(DEBUG) << "build side merge";
+
+ // if the hash table has already been built, return
+ if (hash_table_built_) return Status::OK();
+
+ CalculateBuildResultIndex();
+
+ // merge every group into the build_result_index grouper
+ ThreadLocalState* result_state = &local_states_[build_result_index];
+ for (int i = 0; i < static_cast<int>(local_states_.size()); ++i) {
+ ThreadLocalState* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (i == build_result_index || !state->grouper) {
+ continue;
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys,
state->grouper->GetUniques());
+
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _,
result_state->grouper->Consume(other_keys));
+ state->grouper.reset();
+ }
+
+ // enable flag that build side is completed
+ hash_table_built_ = true;
+
+ // since the build side is completed, consume cached probe batches
+ RETURN_NOT_OK(ConsumeCachedProbeBatches());
+
+ return Status::OK();
+ }
+
+ // consumes a build batch and increments the build_batches count. if the
build batches
+ // total reached at the end of consumption, all the local states will be
merged, before
+ // incrementing the total batches
+ Status ConsumeBuildBatch(ExecBatch batch) {
+ size_t thread_index = get_thread_index_();
+ ARROW_DCHECK(thread_index < local_states_.size());
+
+ ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index
+ << " len:" << batch.length;
+
+ auto state = &local_states_[thread_index];
+ RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+ // Create a batch with key columns
+ std::vector<Datum> keys(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ keys[i] = batch.values[build_index_field_ids_[i]];
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+ // Create a batch with group ids
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch));
+
+ if (build_counter_.Increment()) {
+ // only one thread would get inside this block!
+ // while incrementing, if the total is reached, call BuildSideCompleted.
+ RETURN_NOT_OK(BuildSideCompleted());
+ }
+
+ return Status::OK();
+ }
+
+ // consumes cached probe batches by invoking executor::Spawn.
+ Status ConsumeCachedProbeBatches() {
+ ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_()
+ << " len:" << cached_probe_batches.size();
+
+ // acquire the mutex to access cached_probe_batches, because while
consuming, other
+ // batches should not be cached!
+ std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
+
+ if (!cached_probe_batches_consumed) {
+ auto executor = ctx_->executor();
+ for (auto&& cached : cached_probe_batches) {
+ if (executor) {
+ Status lambda_status;
+ RETURN_NOT_OK(executor->Spawn([&] {
+ lambda_status = ConsumeProbeBatch(cached.first,
std::move(cached.second));
Review comment:
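This lambda captures `lambda_status`, a stack local declared inside the loop body, by reference. That local goes out of scope at the end of the iteration, possibly before the spawned task has even run. So the reference can dangle, and I think this is unsafe.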
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,552 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+Status ValidateJoinInputs(const std::shared_ptr<Schema>& left_schema,
+ const std::shared_ptr<Schema>& right_schema,
+ const std::vector<int>& left_keys,
+ const std::vector<int>& right_keys) {
+ if (left_keys.size() != right_keys.size()) {
+ return Status::Invalid("left and right key sizes do not match");
+ }
+
+ for (size_t i = 0; i < left_keys.size(); i++) {
+ auto l_type = left_schema->field(left_keys[i])->type();
+ auto r_type = right_schema->field(right_keys[i])->type();
+
+ if (!l_type->Equals(r_type)) {
+ return Status::Invalid("build and probe types do not match: " +
l_type->ToString() +
+ "!=" + r_type->ToString());
+ }
+ }
+
+ return Status::OK();
+}
+
+Result<std::vector<int>> PopulateKeys(const Schema& schema,
+ const std::vector<FieldRef>& keys) {
+ std::vector<int> key_field_ids(keys.size());
+ // Find input field indices for left key fields
+ for (size_t i = 0; i < keys.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema));
+ key_field_ids[i] = match[0];
+ }
+ return key_field_ids;
+}
+} // namespace
+
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {
+ HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext*
ctx,
+ const std::vector<int>&& build_index_field_ids,
+ const std::vector<int>&& probe_index_field_ids)
+ : ExecNode(build_input->plan(), {build_input, probe_input},
+ {"hash_join_build", "hash_join_probe"},
probe_input->output_schema(),
+ /*num_outputs=*/1),
+ ctx_(ctx),
+ build_index_field_ids_(build_index_field_ids),
+ probe_index_field_ids_(probe_index_field_ids),
+ build_result_index(-1),
+ hash_table_built_(false),
+ cached_probe_batches_consumed(false) {}
+
+ private:
+ struct ThreadLocalState;
+
+ public:
+ const char* kind_name() override { return "HashSemiJoinNode"; }
+
+ Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+ ARROW_LOG(DEBUG) << "init state";
+
+ // Get input schema
+ auto build_schema = inputs_[0]->output_schema();
+
+ if (state->grouper != nullptr) return Status::OK();
+
+ // Build vector of key field data types
+ std::vector<ValueDescr> key_descrs(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ auto build_type = build_schema->field(build_index_field_ids_[i])->type();
+ key_descrs[i] = ValueDescr(build_type);
+ }
+
+ // Construct grouper
+ ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs,
ctx_));
+
+ return Status::OK();
+ }
+
+ // Finds an appropriate index which could accumulate all build indices (i.e.
the grouper
+ // which has the highest # of groups)
+ void CalculateBuildResultIndex() {
+ int32_t curr_max = -1;
+ for (int i = 0; i < static_cast<int>(local_states_.size()); i++) {
+ auto* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (state->grouper &&
+ curr_max < static_cast<int32_t>(state->grouper->num_groups())) {
+ curr_max = static_cast<int32_t>(state->grouper->num_groups());
+ build_result_index = i;
+ }
+ }
+ ARROW_DCHECK(build_result_index > -1);
+ ARROW_LOG(DEBUG) << "build_result_index " << build_result_index;
+ }
+
+ // Performs the housekeeping work after the build-side is completed.
+ // Note: this method is not thread safe, and hence should be guaranteed that
it is
+ // not accessed concurrently!
+ Status BuildSideCompleted() {
+ ARROW_LOG(DEBUG) << "build side merge";
+
+ // if the hash table has already been built, return
+ if (hash_table_built_) return Status::OK();
+
+ CalculateBuildResultIndex();
+
+ // merge every group into the build_result_index grouper
+ ThreadLocalState* result_state = &local_states_[build_result_index];
+ for (int i = 0; i < static_cast<int>(local_states_.size()); ++i) {
+ ThreadLocalState* state = &local_states_[i];
+ ARROW_DCHECK(state);
+ if (i == build_result_index || !state->grouper) {
+ continue;
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys,
state->grouper->GetUniques());
+
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _,
result_state->grouper->Consume(other_keys));
+ state->grouper.reset();
+ }
+
+ // enable flag that build side is completed
+ hash_table_built_ = true;
+
+ // since the build side is completed, consume cached probe batches
+ RETURN_NOT_OK(ConsumeCachedProbeBatches());
+
+ return Status::OK();
+ }
+
+ // consumes a build batch and increments the build_batches count. if the
build batches
+ // total reached at the end of consumption, all the local states will be
merged, before
+ // incrementing the total batches
+ Status ConsumeBuildBatch(ExecBatch batch) {
+ size_t thread_index = get_thread_index_();
+ ARROW_DCHECK(thread_index < local_states_.size());
+
+ ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index
+ << " len:" << batch.length;
+
+ auto state = &local_states_[thread_index];
+ RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+ // Create a batch with key columns
+ std::vector<Datum> keys(build_index_field_ids_.size());
+ for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+ keys[i] = batch.values[build_index_field_ids_[i]];
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+ // Create a batch with group ids
+ // TODO(niranda) replace with void consume method
+ ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch));
+
+ if (build_counter_.Increment()) {
+ // only one thread would get inside this block!
+ // while incrementing, if the total is reached, call BuildSideCompleted.
+ RETURN_NOT_OK(BuildSideCompleted());
+ }
+
+ return Status::OK();
+ }
+
+ // consumes cached probe batches by invoking executor::Spawn.
+ Status ConsumeCachedProbeBatches() {
+ ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_()
+ << " len:" << cached_probe_batches.size();
+
+ // acquire the mutex to access cached_probe_batches, because while
consuming, other
+ // batches should not be cached!
+ std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
+
+ if (!cached_probe_batches_consumed) {
+ auto executor = ctx_->executor();
+ for (auto&& cached : cached_probe_batches) {
+ if (executor) {
+ Status lambda_status;
+ RETURN_NOT_OK(executor->Spawn([&] {
+ lambda_status = ConsumeProbeBatch(cached.first,
std::move(cached.second));
+ }));
+
+ // if the lambda execution failed internally, return status
+ RETURN_NOT_OK(lambda_status);
+ } else {
+ RETURN_NOT_OK(ConsumeProbeBatch(cached.first,
std::move(cached.second)));
+ }
+ }
+ // cached vector will be cleared. exec batches are expected to be moved
to the
+ // lambdas
+ cached_probe_batches.clear();
+ }
+
+ // set flag
+ cached_probe_batches_consumed = true;
+ return Status::OK();
+ }
+
+ Status GenerateOutput(int seq, const ArrayData& group_ids_data, ExecBatch
batch) {
+ if (group_ids_data.GetNullCount() == batch.length) {
+ // All NULLS! hence, there are no valid outputs!
+ ARROW_LOG(DEBUG) << "output seq:" << seq << " 0";
+ outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0));
+ } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered
+ auto filter_arr =
+ std::make_shared<BooleanArray>(group_ids_data.length,
group_ids_data.buffers[0],
+ /*null_bitmap=*/nullptr,
/*null_count=*/0,
+ /*offset=*/group_ids_data.offset);
+ ARROW_ASSIGN_OR_RAISE(auto rec_batch,
+ batch.ToRecordBatch(output_schema_,
ctx_->memory_pool()));
+ ARROW_ASSIGN_OR_RAISE(
+ auto filtered,
+ Filter(rec_batch, filter_arr,
+ /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_));
+ auto out_batch = ExecBatch(*filtered.record_batch());
+ ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length;
+ outputs_[0]->InputReceived(this, seq, std::move(out_batch));
+ } else { // all values are valid for output
+ ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length;
+ outputs_[0]->InputReceived(this, seq, std::move(batch));
+ }
+
+ return Status::OK();
+ }
+
+ // consumes a probe batch and increment probe batches count. Probing would
query the
+ // grouper[build_result_index] which have been merged with all others.
+ Status ConsumeProbeBatch(int seq, ExecBatch batch) {
+ ARROW_LOG(DEBUG) << "ConsumeProbeBatch seq:" << seq;
+
+ auto& final_grouper = *local_states_[build_result_index].grouper;
+
+ // Create a batch with key columns
+ std::vector<Datum> keys(probe_index_field_ids_.size());
+ for (size_t i = 0; i < probe_index_field_ids_.size(); ++i) {
+ keys[i] = batch.values[probe_index_field_ids_[i]];
+ }
+ ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+ // Query the grouper with key_batch. If no match was found, returning
group_ids would
+ // have null.
+ ARROW_ASSIGN_OR_RAISE(Datum group_ids, final_grouper.Find(key_batch));
+ auto group_ids_data = *group_ids.array();
+
+ RETURN_NOT_OK(GenerateOutput(seq, group_ids_data, std::move(batch)));
+
+ if (out_counter_.Increment()) {
+ finished_.MarkFinished();
+ }
+ return Status::OK();
+ }
+
+ // Attempt to cache a probe batch. If it is not cached, return false.
+ // if cached_probe_batches_consumed is true, by the time a thread acquires
+ // cached_probe_batches_mutex, it should no longer be cached! instead, it
can be
+ // directly consumed!
+ bool AttemptToCacheProbeBatch(int seq_num, ExecBatch* batch) {
Review comment:
Fortunately, ARROW-13660 will get rid of seq_num.
##########
File path: cpp/src/arrow/compute/exec/options.h
##########
@@ -111,5 +111,32 @@ class ARROW_EXPORT SinkNodeOptions : public
ExecNodeOptions {
std::function<Future<util::optional<ExecBatch>>()>* generator;
};
+enum JoinType {
Review comment:
Hmm. The C++ style guide doesn't say anything about this. I would say enum class is
generally preferred in modern C++. We semi-intentionally use plain enum in a
few places (when nested inside a struct), but I'm not sure why we have that
pattern.
This isn't the only place where this happens, so it's not a big deal, but I
would prefer enum class so that the values aren't implicitly converted to integers and
so that namespacing is explicit.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]