save-buffer commented on code in PR #13332:
URL: https://github.com/apache/arrow/pull/13332#discussion_r893823387
##########
cpp/src/arrow/compute/exec/hash_join_benchmark.cc:
##########
@@ -124,67 +128,60 @@ class JoinBenchmark {
schema_mgr_ = arrow::internal::make_unique<HashJoinSchema>();
Expression filter = literal(true);
- DCHECK_OK(schema_mgr_->Init(settings.join_type, *l_batches_.schema, left_keys,
- *r_batches_.schema, right_keys, filter, "l_", "r_"));
+ DCHECK_OK(schema_mgr_->Init(settings.join_type, *l_batches_with_schema.schema,
+ left_keys, *r_batches_with_schema.schema, right_keys,
+ filter, "l_", "r_"));
join_ = *HashJoinImpl::MakeBasic();
- HashJoinImpl* bloom_filter_pushdown_target = nullptr;
- std::vector<int> key_input_map;
-
- bool bloom_filter_does_not_apply_to_join =
- settings.join_type == JoinType::LEFT_ANTI ||
- settings.join_type == JoinType::LEFT_OUTER ||
- settings.join_type == JoinType::FULL_OUTER;
- if (settings.bloom_filter && !bloom_filter_does_not_apply_to_join) {
- bloom_filter_pushdown_target = join_.get();
- SchemaProjectionMap probe_key_to_input = schema_mgr_->proj_maps[0].map(
- HashJoinProjection::KEY, HashJoinProjection::INPUT);
- int num_keys = probe_key_to_input.num_cols;
- for (int i = 0; i < num_keys; i++)
- key_input_map.push_back(probe_key_to_input.get(i));
- }
-
omp_set_num_threads(settings.num_threads);
auto schedule_callback = [](std::function<Status(size_t)> func) -> Status {
#pragma omp task
{ DCHECK_OK(func(omp_get_thread_num())); }
return Status::OK();
};
+ scheduler_ = TaskScheduler::Make();
DCHECK_OK(join_->Init(
- ctx_.get(), settings.join_type, !is_parallel, settings.num_threads,
- schema_mgr_.get(), std::move(key_cmp), std::move(filter), [](ExecBatch) {},
- [](int64_t x) {}, schedule_callback, bloom_filter_pushdown_target,
- std::move(key_input_map)));
+ ctx_.get(), settings.join_type, settings.num_threads, schema_mgr_.get(),
+ std::move(key_cmp), std::move(filter), [](ExecBatch) {}, [](int64_t x) {},
+ scheduler_.get()));
+
+ task_group_probe_ = scheduler_->RegisterTaskGroup(
+ [this](size_t thread_index, int64_t task_id) -> Status {
+ return join_->ProbeSingleBatch(thread_index, std::move(l_batches_[task_id]));
+ },
+ [this](size_t thread_index) -> Status {
+ return join_->ProbingFinished(thread_index);
+ });
+
+ scheduler_->RegisterEnd();
+
+ DCHECK_OK(scheduler_->StartScheduling(
+ 0 /*thread index*/, std::move(schedule_callback),
+ static_cast<int>(2 * settings.num_threads) /*concurrent tasks*/, !is_parallel));
}
void RunJoin() {
#pragma omp parallel
{
int tid = omp_get_thread_num();
-#pragma omp for nowait
- for (auto it = r_batches_.batches.begin(); it != r_batches_.batches.end(); ++it)
- DCHECK_OK(join_->InputReceived(tid, /*side=*/1, *it));
-#pragma omp for nowait
- for (auto it = l_batches_.batches.begin(); it != l_batches_.batches.end(); ++it)
- DCHECK_OK(join_->InputReceived(tid, /*side=*/0, *it));
-
-#pragma omp barrier
-
-#pragma omp single nowait
- { DCHECK_OK(join_->InputFinished(tid, /*side=*/1)); }
-
-#pragma omp single nowait
- { DCHECK_OK(join_->InputFinished(tid, /*side=*/0)); }
+#pragma omp single
Review Comment:
Yep, `#pragma omp parallel` creates a thread pool, and `#pragma omp single` just
makes the enclosed block execute on a single thread. The magic of creating and
scheduling tasks happens inside `schedule_callback` with `#pragma omp task`.
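
For anyone unfamiliar with that OpenMP idiom, here is a minimal, standalone sketch of
the pattern being described (this is not the benchmark code; the tiny payload, loop
bound, and thread count are made up for illustration): `parallel` builds the thread
team, `single` lets one thread drive the scheduling, and each callback invocation
spawns a deferred `task` that any thread in the team may pick up. Compile with
`-fopenmp`.

```cpp
#include <omp.h>

#include <cstdio>
#include <functional>

int main() {
  omp_set_num_threads(4);

  // Roughly analogous to the benchmark's schedule_callback: each invocation
  // spawns one deferred OpenMP task. The std::function is captured firstprivate,
  // so the task can safely run after the callback has returned.
  auto schedule_callback = [](std::function<void(int)> func) {
#pragma omp task
    { func(omp_get_thread_num()); }
  };

#pragma omp parallel
  {
#pragma omp single
    {
      // Only one thread runs this block; it "schedules" the work while the
      // other threads in the team pick up the spawned tasks.
      for (int task_id = 0; task_id < 8; ++task_id) {
        schedule_callback([task_id](int thread_index) {
          std::printf("task %d ran on thread %d\n", task_id, thread_index);
        });
      }
    }  // Implicit barrier: all spawned tasks finish before the region ends.
  }
  return 0;
}
```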