lidavidm commented on a change in pull request #10845:
URL: https://github.com/apache/arrow/pull/10845#discussion_r695719110
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,555 @@
+  Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+    ARROW_LOG(DEBUG) << "init state";

Review comment:
   Do we still need these log statements?
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,555 @@
+  bool AttemptToCacheProbeBatch(ExecBatch* batch) {
+    ARROW_LOG(DEBUG) << "cache tid:" << get_thread_index_() << " len:" << batch->length;
+    std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
+    if (cached_probe_batches_consumed) {
+      return false;
+    }
+    cached_probe_batches.push_back(std::move(*batch));
+    return true;
+  }
+
+  inline bool IsBuildInput(ExecNode* input) { return input == inputs_[0]; }

Review comment:
   nit, but as far as I understand, this inline doesn't do what it seems like it does (inline is for linkage not optimization)
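   For illustration of that point (placeholder types, not the PR's classes): a member function defined inside the class body is already implicitly inline, so the keyword is redundant there; what `inline` actually controls is the one-definition rule for definitions shared across translation units, not whether calls get inlined.

```cpp
struct Node {
  // Defined in the class body, so it is implicitly inline already; the keyword
  // would add nothing, and it does not force the compiler to inline calls.
  bool IsBuildInput(int input) const { return input == build_input_; }

  int build_input_ = 0;
};

// Here `inline` earns its keep: it lets this definition live in a header that
// is included from several translation units without an ODR violation.
inline bool IsZero(int x) { return x == 0; }
```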
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,555 @@
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {

Review comment:
   for consistency with all the other nodes

##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,555 @@
+  if (left_keys.size() != right_keys.size()) {
+    return Status::Invalid("left and right key sizes do not match");
+  }

Review comment:
   Should we also check that the number of keys is nonzero?
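   A sketch of what that extra check might look like, reusing the validation shown in the diff; the early `empty()` guard and the function name are hypothetical additions, not code from the PR:

```cpp
#include <memory>
#include <string>
#include <vector>

#include "arrow/api.h"

// Hypothetical variant of the PR's ValidateJoinInputs with an added guard
// against an empty key list.
arrow::Status ValidateJoinInputsSketch(const std::shared_ptr<arrow::Schema>& left_schema,
                                       const std::shared_ptr<arrow::Schema>& right_schema,
                                       const std::vector<int>& left_keys,
                                       const std::vector<int>& right_keys) {
  // Suggested addition: a join with zero keys is almost certainly a user error.
  if (left_keys.empty()) {
    return arrow::Status::Invalid("join requires at least one key");
  }
  if (left_keys.size() != right_keys.size()) {
    return arrow::Status::Invalid("left and right key sizes do not match");
  }
  for (size_t i = 0; i < left_keys.size(); i++) {
    auto l_type = left_schema->field(left_keys[i])->type();
    auto r_type = right_schema->field(right_keys[i])->type();
    if (!l_type->Equals(r_type)) {
      return arrow::Status::Invalid("build and probe types do not match: " +
                                    l_type->ToString() + "!=" + r_type->ToString());
    }
  }
  return arrow::Status::OK();
}
```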
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,555 @@
+    for (auto&& input : inputs_) {
+      input->StopProducing(this);
+    }
+  }
+
+  // TODO(niranda) couldn't there be multiple outputs for a Node?

Review comment:
   I would think only if the node supports it explicitly, i.e. we shouldn't worry about that here.

##########
File path: cpp/src/arrow/compute/exec/hash_join_node_test.cc
##########
@@ -0,0 +1,269 @@
+void GenerateBatchesFromString(const std::shared_ptr<Schema>& schema,
+                               const std::vector<util::string_view>& json_strings,
+                               BatchesWithSchema* out_batches, int multiplicity = 1) {

Review comment:
   nit: [prefer to put out parameters last](https://google.github.io/styleguide/cppguide.html#Inputs_and_Outputs), or better, why can't this just return the batches?
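   A sketch of the "just return the batches" alternative; it assumes the test helpers already used in this file (`BatchesWithSchema` and `ExecBatchFromJSON` from `arrow/compute/exec/test_util.h`) and is only a suggestion, not code from the PR:

```cpp
BatchesWithSchema GenerateBatchesFromString(
    const std::shared_ptr<Schema>& schema,
    const std::vector<util::string_view>& json_strings, int multiplicity = 1) {
  BatchesWithSchema out_batches;

  std::vector<ValueDescr> descrs;
  for (auto&& field : schema->fields()) {
    descrs.emplace_back(field->type());
  }

  for (auto&& s : json_strings) {
    out_batches.batches.push_back(ExecBatchFromJSON(descrs, s));
  }

  // Repeat the generated batches to reach the requested multiplicity.
  size_t batch_count = out_batches.batches.size();
  for (int repeat = 1; repeat < multiplicity; ++repeat) {
    for (size_t i = 0; i < batch_count; ++i) {
      out_batches.batches.push_back(out_batches.batches[i]);
    }
  }

  out_batches.schema = schema;
  return out_batches;
}
```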
##########
File path: cpp/src/arrow/compute/exec/hash_join_node_test.cc
##########
@@ -0,0 +1,269 @@
+  ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}})
+                .AddToPlan(plan.get()));
+
+  ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+
+  // TODO(niranda) add a verification step for res

Review comment:
   Do we plan to tackle this still (presumably with a naive implementation)?
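   For context on what such a naive reference implementation could look like, here is a self-contained sketch (plain STL, with `std::string` keys standing in for the randomly generated Arrow data) of the semi/anti join one would compare the node's output against; the `Row` layout and names are hypothetical:

```cpp
#include <string>
#include <unordered_set>
#include <vector>

struct Row {
  std::string key;      // join key column
  std::string payload;  // remaining columns, flattened for the sketch
};

// Keep probe rows whose key appears in the build side (semi join), or the
// complement of that set (anti join).
std::vector<Row> NaiveSemiJoin(const std::vector<Row>& probe,
                               const std::vector<Row>& build, bool anti_join) {
  std::unordered_set<std::string> build_keys;
  for (const auto& row : build) build_keys.insert(row.key);

  std::vector<Row> out;
  for (const auto& row : probe) {
    const bool matched = build_keys.count(row.key) > 0;
    if (matched != anti_join) out.push_back(row);
  }
  return out;
}
```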
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,555 @@
+    if (IsBuildInput(input)) {  // build input batch is received
+      // if a build input is received when build side is completed, something's wrong!
+      ARROW_DCHECK(!hash_table_built_);
+
+      ErrorIfNotOk(ConsumeBuildBatch(std::move(batch)));
+    } else {  // probe input batch is received
+      if (hash_table_built_) {

Review comment:
   This non-atomic bool makes me nervous since there is nothing synchronizing it, though I suppose things still work because AttemptToCacheProbeBatch actually takes a lock and checks.
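   One way the concern could be addressed (a sketch, not the PR's code): make the flag a `std::atomic<bool>` and publish it with release/acquire ordering, so a probe thread that reads it outside the mutex also observes the grouper merges performed before the flag was set. The function names below are hypothetical stand-ins for the node's members.

```cpp
#include <atomic>

std::atomic<bool> hash_table_built{false};

void OnBuildSideCompleted() {
  // ... merge the thread-local groupers into the final one ...
  // Publish the result: everything written above happens-before a load that
  // observes `true` with acquire ordering.
  hash_table_built.store(true, std::memory_order_release);
}

bool CanProbeDirectly() {
  return hash_table_built.load(std::memory_order_acquire);
}
```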
########## File path: cpp/src/arrow/compute/exec/hash_join_node.cc ##########
+      if (hash_table_built_) {
+        // build side done, continue with probing. when hash_table_built_ is set, it is
+        // guaranteed that some thread has already called the ConsumeCachedProbeBatches
+
+        // consume this probe batch
+        ErrorIfNotOk(ConsumeProbeBatch(std::move(batch)));
+      } else {  // build side not completed. Cache this batch!
+        if (!AttemptToCacheProbeBatch(&batch)) {
+          // if the cache attempt fails, consume the batch
+          ErrorIfNotOk(ConsumeProbeBatch(std::move(batch)));
+        }
+      }
+    }
+  }
+
+  void ErrorReceived(ExecNode* input, Status error) override {
+    ARROW_LOG(DEBUG) << "error received " << error.ToString();
+    DCHECK_EQ(input, inputs_[0]);
+
+    outputs_[0]->ErrorReceived(this, std::move(error));
+    StopProducing();
+  }
+
+  void InputFinished(ExecNode* input, int num_total) override {
+    ARROW_LOG(DEBUG) << "input finished input:" << (IsBuildInput(input) ? "b" : "p")
+                     << " tot:" << num_total;
+
+    // bail if StopProducing was called
+    if (finished_.is_finished()) return;
+
+    ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]);
+
+    // set total for build input
+    if (IsBuildInput(input) && build_counter_.SetTotal(num_total)) {
+      // only one thread would get inside this block!
+      // while incrementing, if the total is reached, call BuildSideCompleted.
+      ErrorIfNotOk(BuildSideCompleted());
+      return;
+    }
+
+    // set total for probe input. If it returns that probe side has completed, nothing to
+    // do, because probing inputs will be streamed to the output
+    // probe_counter_.SetTotal(num_total);

Review comment: nit: comment/code is irrelevant now?
########## File path: cpp/src/arrow/compute/exec/hash_join_node.cc ##########
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {

Review comment:
```suggestion
class HashSemiJoinNode : public ExecNode {
 public:
```
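For what it's worth, the suggestion is behavior-preserving: `struct` already defaults to public members and public inheritance, so spelling out `class ... : public ExecNode` with an explicit `public:` section is a readability change. It also lines up with the common style-guide convention of reserving `struct` for passive data carriers and using `class` for types that own state and enforce invariants, which this node does.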
########## File path: cpp/src/arrow/compute/exec/hash_join_node.cc ##########
+  Status GenerateOutput(const ArrayData& group_ids_data, ExecBatch batch) {
+    if (group_ids_data.GetNullCount() == batch.length) {
+      // All NULLS! hence, there are no valid outputs!
+      ARROW_LOG(DEBUG) << "output seq:";
+      outputs_[0]->InputReceived(this, batch.Slice(0, 0));

Review comment: Not sure if it's a big deal/can be dealt with later, but just slicing may inadvertently retain a reference to a large batch that is otherwise empty.
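A sketch of one possible alternative (a hypothetical helper, not the fix applied in the PR; it assumes `MakeArrayOfNull` and the `ExecBatch(values, length)` constructor available through the file's existing includes): emit a batch built from fresh zero-length arrays so nothing from the probe batch is retained.

```cpp
// Sketch: build a zero-length batch matching `schema` that does not reference
// any buffers of the incoming probe batch.
Result<ExecBatch> MakeEmptyOutputBatch(const std::shared_ptr<Schema>& schema,
                                       MemoryPool* pool) {
  std::vector<Datum> values;
  values.reserve(schema->num_fields());
  for (const auto& field : schema->fields()) {
    // Zero-length array of the right type; allocates nothing of consequence.
    ARROW_ASSIGN_OR_RAISE(auto arr, MakeArrayOfNull(field->type(), /*length=*/0, pool));
    values.emplace_back(std::move(arr));
  }
  return ExecBatch(std::move(values), /*length=*/0);
}
```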
########## File path: cpp/src/arrow/compute/exec/hash_join_node.cc ##########
+    } else if (group_ids_data.MayHaveNulls()) {  // values need to be filtered
+      auto filter_arr =
+          std::make_shared<BooleanArray>(group_ids_data.length, group_ids_data.buffers[0],
+                                         /*null_bitmap=*/nullptr, /*null_count=*/0,
+                                         /*offset=*/group_ids_data.offset);
+      ARROW_ASSIGN_OR_RAISE(auto rec_batch,
+                            batch.ToRecordBatch(output_schema_, ctx_->memory_pool()));
+      ARROW_ASSIGN_OR_RAISE(
+          auto filtered,
+          Filter(rec_batch, filter_arr,
+                 /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_));
+      auto out_batch = ExecBatch(*filtered.record_batch());
+      ARROW_LOG(DEBUG) << "output seq:"
+                       << " " << out_batch.length;
+      outputs_[0]->InputReceived(this, std::move(out_batch));

Review comment: Similarly, this can be dealt with later, but it's a little unfortunate that we are making some heap allocations here and potentially expanding scalars into arrays to do the filtering, especially when the underlying kernel can operate directly on an ExecBatch.
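A rough sketch of what filtering the `ExecBatch` values directly might look like (a hypothetical helper; it assumes `Filter` accepts `Datum` inputs as elsewhere in the compute API, and deliberately leaves scalar columns unresolved, which is part of what makes this non-trivial):

```cpp
// Sketch: filter column-by-column on the ExecBatch's Datum values instead of
// round-tripping through a RecordBatch. Scalar-valued columns would still need
// separate handling, so this only outlines the idea, not a drop-in replacement.
Result<ExecBatch> FilterExecBatch(const ExecBatch& batch, const Datum& filter,
                                  ExecContext* ctx) {
  std::vector<Datum> filtered_values(batch.values.size());
  for (size_t i = 0; i < batch.values.size(); ++i) {
    ARROW_ASSIGN_OR_RAISE(
        filtered_values[i],
        Filter(batch.values[i], filter, FilterOptions::Defaults(), ctx));
  }
  return ExecBatch::Make(std::move(filtered_values));
}
```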
########## File path: cpp/src/arrow/compute/exec/hash_join_node.cc ##########
+Result<std::vector<int>> PopulateKeys(const Schema& schema,
+                                      const std::vector<FieldRef>& keys) {
+  std::vector<int> key_field_ids(keys.size());
+  // Find input field indices for left key fields

Review comment: nit: this comment is a little misleading since it's used below for both sides.
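A hypothetical rewording that would cover both call sites, purely as an example:

```cpp
// Find the input-field indices of the given key field refs in `schema`
// (called for both the left/build and right/probe inputs).
```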
########## File path: cpp/src/arrow/compute/exec/hash_join_node.cc ##########
+    if (IsBuildInput(input)) {  // build input batch is received
+      // if a build input is received when build side is completed, something's wrong!
+      ARROW_DCHECK(!hash_table_built_);
+
+      ErrorIfNotOk(ConsumeBuildBatch(std::move(batch)));
+    } else {  // probe input batch is received
+      if (hash_table_built_) {

Review comment: I would be curious how this fares on ARM in CI, and whether this branch is really getting taken.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at: [email protected]
