This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new cf03901ba8 ARROW-16523: [C++] Part 1 of ExecPlan cleanup: Centralized Task Group (#13143)
cf03901ba8 is described below
commit cf03901ba8fa6364d68dc6b9a31942a329ac6c89
Author: Sasha Krassovsky <[email protected]>
AuthorDate: Thu Jul 14 15:11:45 2022 -0800
ARROW-16523: [C++] Part 1 of ExecPlan cleanup: Centralized Task Group (#13143)
Authored-by: Sasha Krassovsky <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
cpp/src/arrow/compute/exec/aggregate_node.cc | 80 ++++++--------
cpp/src/arrow/compute/exec/exec_plan.cc | 125 ++++++++++++++++++----
cpp/src/arrow/compute/exec/exec_plan.h | 80 +++++++++++---
cpp/src/arrow/compute/exec/hash_join.cc | 26 +++--
cpp/src/arrow/compute/exec/hash_join.h | 8 +-
cpp/src/arrow/compute/exec/hash_join_benchmark.cc | 25 +++--
cpp/src/arrow/compute/exec/hash_join_node.cc | 113 +++++++++----------
cpp/src/arrow/compute/exec/hash_join_node_test.cc | 3 +
cpp/src/arrow/compute/exec/sink_node.cc | 17 +--
cpp/src/arrow/compute/exec/source_node.cc | 115 +++++++++-----------
cpp/src/arrow/compute/exec/swiss_join.cc | 39 ++++---
cpp/src/arrow/compute/exec/test_util.cc | 3 +-
cpp/src/arrow/compute/exec/tpch_node.cc | 77 +++++--------
cpp/src/arrow/compute/exec/union_node.cc | 3 +-
cpp/src/arrow/dataset/scanner_test.cc | 31 ++++--
cpp/src/arrow/engine/substrait/util.cc | 9 +-
cpp/src/arrow/util/future.cc | 18 +++-
cpp/src/arrow/util/future.h | 13 +++
cpp/src/arrow/util/tracing_internal.h | 15 +--
python/pyarrow/tests/test_dataset.py | 2 +-
20 files changed, 465 insertions(+), 337 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc
b/cpp/src/arrow/compute/exec/aggregate_node.cc
index 0131319be3..96aa56b80c 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -113,7 +113,7 @@ class ScalarAggregateNode : public ExecNode {
}
KernelContext kernel_ctx{exec_ctx};
- states[i].resize(ThreadIndexer::Capacity());
+ states[i].resize(plan->max_concurrency());
RETURN_NOT_OK(Kernel::InitAll(&kernel_ctx,
KernelInitArgs{kernels[i],
{
@@ -168,7 +168,7 @@ class ScalarAggregateNode : public ExecNode {
{"batch.length", batch.length}});
DCHECK_EQ(input, inputs_[0]);
- auto thread_index = get_thread_index_();
+ auto thread_index = plan_->GetThreadIndex();
if (ErrorIfNotOk(DoConsume(std::move(batch), thread_index))) return;
@@ -196,8 +196,6 @@ class ScalarAggregateNode : public ExecNode {
{{"node.label", label()},
{"node.detail", ToString()},
{"node.kind", kind_name()}});
- finished_ = Future<>::Make();
- END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
// Scalar aggregates will only output a single batch
outputs_[0]->InputFinished(this, 1);
return Status::OK();
@@ -224,8 +222,6 @@ class ScalarAggregateNode : public ExecNode {
inputs_[0]->StopProducing(this);
}
- Future<> finished() override { return finished_; }
-
protected:
std::string ToStringExtra(int indent = 0) const override {
std::stringstream ss;
@@ -266,7 +262,6 @@ class ScalarAggregateNode : public ExecNode {
std::vector<std::vector<std::unique_ptr<KernelState>>> states_;
- ThreadIndexer get_thread_index_;
AtomicCounter input_counter_;
};
@@ -284,6 +279,19 @@ class GroupByNode : public ExecNode {
aggs_(std::move(aggs)),
agg_kernels_(std::move(agg_kernels)) {}
+ Status Init() override {
+ output_task_group_id_ = plan_->RegisterTaskGroup(
+ [this](size_t, int64_t task_id) {
+ OutputNthBatch(task_id);
+ return Status::OK();
+ },
+ [this](size_t) {
+ finished_.MarkFinished();
+ return Status::OK();
+ });
+ return Status::OK();
+ }
+
static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "GroupByNode"));
@@ -358,7 +366,7 @@ class GroupByNode : public ExecNode {
{{"group_by", ToStringExtra()},
{"node.label", label()},
{"batch.length", batch.length}});
- size_t thread_index = get_thread_index_();
+ size_t thread_index = plan_->GetThreadIndex();
if (thread_index >= local_states_.size()) {
return Status::IndexError("thread index ", thread_index, " is out of
range [0, ",
local_states_.size(), ")");
@@ -465,47 +473,32 @@ class GroupByNode : public ExecNode {
std::move(out_keys.values.begin(), out_keys.values.end(),
out_data.values.begin() + agg_kernels_.size());
state->grouper.reset();
-
- if (output_counter_.SetTotal(
- static_cast<int>(bit_util::CeilDiv(out_data.length,
output_batch_size())))) {
- // this will be hit if out_data.length == 0
- finished_.MarkFinished();
- }
return out_data;
}
- void OutputNthBatch(int n) {
+ void OutputNthBatch(int64_t n) {
// bail if StopProducing was called
if (finished_.is_finished()) return;
int64_t batch_size = output_batch_size();
outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n,
batch_size));
-
- if (output_counter_.Increment()) {
- finished_.MarkFinished();
- }
}
Status OutputResult() {
- RETURN_NOT_OK(Merge());
- ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());
-
- int num_output_batches = *output_counter_.total();
- outputs_[0]->InputFinished(this, num_output_batches);
-
- auto executor = ctx_->executor();
- for (int i = 0; i < num_output_batches; ++i) {
- if (executor) {
- // bail if StopProducing was called
- if (finished_.is_finished()) break;
-
- auto plan = this->plan()->shared_from_this();
- RETURN_NOT_OK(executor->Spawn([plan, this, i] { OutputNthBatch(i); }));
- } else {
- OutputNthBatch(i);
+ // To simplify merging, ensure that the first grouper is nonempty
+ for (size_t i = 0; i < local_states_.size(); i++) {
+ if (local_states_[i].grouper) {
+ std::swap(local_states_[i], local_states_[0]);
+ break;
}
}
+ RETURN_NOT_OK(Merge());
+ ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());
+
+ int64_t num_output_batches = bit_util::CeilDiv(out_data_.length,
output_batch_size());
+ outputs_[0]->InputFinished(this, static_cast<int>(num_output_batches));
+ RETURN_NOT_OK(plan_->StartTaskGroup(output_task_group_id_,
num_output_batches));
return Status::OK();
}
@@ -555,10 +548,8 @@ class GroupByNode : public ExecNode {
{{"node.label", label()},
{"node.detail", ToString()},
{"node.kind", kind_name()}});
- finished_ = Future<>::Make();
- END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
- local_states_.resize(ThreadIndexer::Capacity());
+ local_states_.resize(plan_->max_concurrency());
return Status::OK();
}
@@ -576,17 +567,12 @@ class GroupByNode : public ExecNode {
EVENT(span_, "StopProducing");
DCHECK_EQ(output, outputs_[0]);
- ARROW_UNUSED(input_counter_.Cancel());
- if (output_counter_.Cancel()) {
- finished_.MarkFinished();
- }
+ if (input_counter_.Cancel()) finished_.MarkFinished();
inputs_[0]->StopProducing(this);
}
void StopProducing() override { StopProducing(outputs_[0]); }
- Future<> finished() override { return finished_; }
-
protected:
std::string ToStringExtra(int indent = 0) const override {
std::stringstream ss;
@@ -608,7 +594,7 @@ class GroupByNode : public ExecNode {
};
ThreadLocalState* GetLocalState() {
- size_t thread_index = get_thread_index_();
+ size_t thread_index = plan_->GetThreadIndex();
return &local_states_[thread_index];
}
@@ -650,14 +636,14 @@ class GroupByNode : public ExecNode {
}
ExecContext* ctx_;
+ int output_task_group_id_;
const std::vector<int> key_field_ids_;
const std::vector<int> agg_src_field_ids_;
const std::vector<Aggregate> aggs_;
const std::vector<const HashAggregateKernel*> agg_kernels_;
- ThreadIndexer get_thread_index_;
- AtomicCounter input_counter_, output_counter_;
+ AtomicCounter input_counter_;
std::vector<ThreadLocalState> local_states_;
ExecBatch out_data_;
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc
b/cpp/src/arrow/compute/exec/exec_plan.cc
index 468d0accea..e248782ded 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -24,6 +24,7 @@
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/task_util.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/registry.h"
#include "arrow/datum.h"
@@ -56,6 +57,9 @@ struct ExecPlanImpl : public ExecPlan {
}
}
+ size_t GetThreadIndex() { return thread_indexer_(); }
+ size_t max_concurrency() const { return thread_indexer_.Capacity(); }
+
ExecNode* AddNode(std::unique_ptr<ExecNode> node) {
if (node->label().empty()) {
node->SetLabel(std::to_string(auto_label_counter_++));
@@ -70,6 +74,34 @@ struct ExecPlanImpl : public ExecPlan {
return nodes_.back().get();
}
+ Status AddFuture(Future<> fut) { return
task_group_.AddTaskIfNotEnded(std::move(fut)); }
+
+ Status ScheduleTask(std::function<Status()> fn) {
+ auto executor = exec_context_->executor();
+ if (!executor) return fn();
+ // Atomically submit fn to the executor, and if successful
+ // add it to the task group.
+ return task_group_.AddTaskIfNotEnded(
+ [executor, fn]() { return executor->Submit(std::move(fn)); });
+ }
+
+ Status ScheduleTask(std::function<Status(size_t)> fn) {
+ std::function<Status()> indexed_fn = [this, fn]() {
+ size_t thread_index = GetThreadIndex();
+ return fn(thread_index);
+ };
+ return ScheduleTask(std::move(indexed_fn));
+ }
+
+ int RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
+ std::function<Status(size_t)> on_finished) {
+ return task_scheduler_->RegisterTaskGroup(std::move(task),
std::move(on_finished));
+ }
+
+ Status StartTaskGroup(int task_group_id, int64_t num_tasks) {
+ return task_scheduler_->StartTaskGroup(GetThreadIndex(), task_group_id,
num_tasks);
+ }
+
Status Validate() const {
if (nodes_.empty()) {
return Status::Invalid("ExecPlan has no node");
@@ -96,15 +128,35 @@ struct ExecPlanImpl : public ExecPlan {
if (started_) {
return Status::Invalid("restarted ExecPlan");
}
- started_ = true;
- // producers precede consumers
- sorted_nodes_ = TopoSort();
- for (ExecNode* node : sorted_nodes_) {
- RETURN_NOT_OK(node->PrepareToProduce());
+ std::vector<Future<>> futures;
+ for (auto& n : nodes_) {
+ RETURN_NOT_OK(n->Init());
+ futures.push_back(n->finished());
}
- std::vector<Future<>> futures;
+ AllFinished(futures).AddCallback([this](const Status& st) {
+ error_st_ = st;
+ EndTaskGroup();
+ });
+
+ task_scheduler_->RegisterEnd();
+ int num_threads = 1;
+ bool sync_execution = true;
+ if (auto executor = exec_context()->executor()) {
+ num_threads = executor->GetCapacity();
+ sync_execution = false;
+ }
+ RETURN_NOT_OK(task_scheduler_->StartScheduling(
+ 0 /* thread_index */,
+ [this](std::function<Status(size_t)> fn) -> Status {
+ return this->ScheduleTask(std::move(fn));
+ },
+ /*concurrent_tasks=*/2 * num_threads, sync_execution));
+
+ started_ = true;
+ // producers precede consumers
+ sorted_nodes_ = TopoSort();
Status st = Status::OK();
@@ -120,23 +172,34 @@ struct ExecPlanImpl : public ExecPlan {
// Stop nodes that successfully started, in reverse order
stopped_ = true;
StopProducingImpl(it.base(), sorted_nodes_.end());
- break;
+ for (NodeVector::iterator fw_it = sorted_nodes_.begin(); fw_it !=
it.base();
+ ++fw_it) {
+ Future<> fut = (*fw_it)->finished();
+ if (!fut.is_finished()) fut.MarkFinished();
+ }
+ return st;
}
-
- futures.push_back(node->finished());
}
-
- finished_ = AllFinished(futures);
- END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
return st;
}
+ void EndTaskGroup() {
+ bool expected = false;
+ if (group_ended_.compare_exchange_strong(expected, true)) {
+ task_group_.End().AddCallback([this](const Status& st) {
+ MARK_SPAN(span_, error_st_ & st);
+ END_SPAN(span_);
+ finished_.MarkFinished(error_st_ & st);
+ });
+ }
+ }
+
void StopProducing() {
DCHECK(started_) << "stopped an ExecPlan which never started";
EVENT(span_, "StopProducing");
stopped_ = true;
-
- StopProducingImpl(sorted_nodes_.begin(), sorted_nodes_.end());
+ task_scheduler_->Abort(
+ [this]() { StopProducingImpl(sorted_nodes_.begin(),
sorted_nodes_.end()); });
}
template <typename It>
@@ -242,7 +305,8 @@ struct ExecPlanImpl : public ExecPlan {
return ss.str();
}
- Future<> finished_ = Future<>::MakeFinished();
+ Status error_st_;
+ Future<> finished_ = Future<>::Make();
bool started_ = false, stopped_ = false;
std::vector<std::unique_ptr<ExecNode>> nodes_;
NodeVector sources_, sinks_;
@@ -250,6 +314,11 @@ struct ExecPlanImpl : public ExecPlan {
uint32_t auto_label_counter_ = 0;
util::tracing::Span span_;
std::shared_ptr<const KeyValueMetadata> metadata_;
+
+ ThreadIndexer thread_indexer_;
+ std::atomic<bool> group_ended_{false};
+ util::AsyncTaskGroup task_group_;
+ std::unique_ptr<TaskScheduler> task_scheduler_ = TaskScheduler::Make();
};
ExecPlanImpl* ToDerived(ExecPlan* ptr) { return
checked_cast<ExecPlanImpl*>(ptr); }
@@ -283,6 +352,26 @@ const ExecPlan::NodeVector& ExecPlan::sources() const {
const ExecPlan::NodeVector& ExecPlan::sinks() const { return
ToDerived(this)->sinks_; }
+size_t ExecPlan::GetThreadIndex() { return ToDerived(this)->GetThreadIndex(); }
+size_t ExecPlan::max_concurrency() const { return
ToDerived(this)->max_concurrency(); }
+
+Status ExecPlan::AddFuture(Future<> fut) {
+ return ToDerived(this)->AddFuture(std::move(fut));
+}
+Status ExecPlan::ScheduleTask(std::function<Status()> fn) {
+ return ToDerived(this)->ScheduleTask(std::move(fn));
+}
+Status ExecPlan::ScheduleTask(std::function<Status(size_t)> fn) {
+ return ToDerived(this)->ScheduleTask(std::move(fn));
+}
+int ExecPlan::RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
+ std::function<Status(size_t)> on_finished) {
+ return ToDerived(this)->RegisterTaskGroup(std::move(task),
std::move(on_finished));
+}
+Status ExecPlan::StartTaskGroup(int task_group_id, int64_t num_tasks) {
+ return ToDerived(this)->StartTaskGroup(task_group_id, num_tasks);
+}
+
Status ExecPlan::Validate() { return ToDerived(this)->Validate(); }
Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); }
@@ -312,6 +401,8 @@ ExecNode::ExecNode(ExecPlan* plan, NodeVector inputs,
}
}
+Status ExecNode::Init() { return Status::OK(); }
+
Status ExecNode::Validate() const {
if (inputs_.size() != input_labels_.size()) {
return Status::Invalid("Invalid number of inputs for '", label(), "'
(expected ",
@@ -395,8 +486,6 @@ Status MapNode::StartProducing() {
START_COMPUTE_SPAN(
span_, std::string(kind_name()) + ":" + label(),
{{"node.label", label()}, {"node.detail", ToString()}, {"node.kind",
kind_name()}});
- finished_ = Future<>::Make();
- END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
return Status::OK();
}
@@ -424,8 +513,6 @@ void MapNode::StopProducing() {
inputs_[0]->StopProducing(this);
}
-Future<> MapNode::finished() { return finished_; }
-
void MapNode::SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn,
ExecBatch batch) {
Status status;
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h
b/cpp/src/arrow/compute/exec/exec_plan.h
index dcf271bd36..c8599748de 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -61,6 +61,60 @@ class ARROW_EXPORT ExecPlan : public
std::enable_shared_from_this<ExecPlan> {
return out;
}
+ /// \brief Returns the index of the current thread.
+ size_t GetThreadIndex();
+ /// \brief Returns the maximum number of threads that the plan could use.
+ ///
+ /// GetThreadIndex will always return something less than this, so it is
safe to
+ /// e.g. make an array of thread-locals off this.
+ size_t max_concurrency() const;
+
+ /// \brief Add a future to the plan's task group.
+ ///
+ /// \param fut The future to add
+ ///
+ /// Use this when interfacing with anything that returns a future (such as
IO), but
+ /// prefer ScheduleTask/StartTaskGroup inside of ExecNodes.
+ /// The below API interfaces with the scheduler to add tasks to the task
group. Tasks
+ /// should be added sparingly! Prefer just doing the work immediately rather
than adding
+ /// a task for it. Tasks are used in pipeline breakers that may output many
more rows
+ /// than they received (such as a full outer join).
+ Status AddFuture(Future<> fut);
+
+ /// \brief Add a single function as a task to the plan's task group.
+ ///
+ /// \param fn The task to run. Takes no arguments and returns a Status.
+ Status ScheduleTask(std::function<Status()> fn);
+
+ /// \brief Add a single function as a task to the plan's task group.
+ ///
+ /// \param fn The task to run. Takes the thread index and returns a Status.
+ Status ScheduleTask(std::function<Status(size_t)> fn);
+ // Register/Start TaskGroup is a way of performing a "Parallel For" pattern:
+ // - The task function takes the thread index and the index of the task
+ // - The on_finished function takes the thread index
+ // Returns an integer ID that will be used to reference the task group in
+ // StartTaskGroup. At runtime, call StartTaskGroup with the ID and the
number of times
+ // you'd like the task to be executed. The need to register a task group
before use will
+ // be removed after we rewrite the scheduler.
+ /// \brief Register a "parallel for" task group with the scheduler
+ ///
+ /// \param task The function implementing the task. Takes the thread_index
and
+ /// the task index.
+ /// \param on_finished The function that gets run once all tasks have been
completed.
+ /// Takes the thread_index.
+ ///
+ /// Must be called inside of ExecNode::Init.
+ int RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
+ std::function<Status(size_t)> on_finished);
+
+ /// \brief Start the task group with the specified ID. This can only
+ /// be called once per task_group_id.
+ ///
+ /// \param task_group_id The ID of the task group to run
+ /// \param num_tasks The number of times to run the task
+ Status StartTaskGroup(int task_group_id, int64_t num_tasks);
+
/// The initial inputs
const NodeVector& sources() const;
@@ -157,6 +211,16 @@ class ARROW_EXPORT ExecNode {
/// knows when it has received all input, regardless of order.
virtual void InputFinished(ExecNode* input, int total_batches) = 0;
+ /// \brief Perform any needed initialization
+ ///
+ /// This hook performs any actions in between creation of ExecPlan and the
call to
+ /// StartProducing. An example could be Bloom filter pushdown. The order of
ExecNodes
+ /// that executes this method is undefined, but the calls are made
synchronously.
+ ///
+ /// At this point a node can rely on all inputs & outputs (and the input
schemas)
+ /// being well defined.
+ virtual Status Init();
+
/// Lifecycle API:
/// - start / stop to initiate and terminate production
/// - pause / resume to apply backpressure
@@ -212,16 +276,6 @@ class ARROW_EXPORT ExecNode {
// A node with multiple outputs will also need to ensure it is applying
backpressure if
// any of its outputs is asking to pause
- /// \brief Perform any needed initialization
- ///
- /// This hook performs any actions in between creation of ExecPlan and the
call to
- /// StartProducing. An example could be Bloom filter pushdown. The order of
ExecNodes
- /// that executes this method is undefined, but the calls are made
synchronously.
- ///
- /// At this point a node can rely on all inputs & outputs (and the input
schemas)
- /// being well defined.
- virtual Status PrepareToProduce() { return Status::OK(); }
-
/// \brief Start producing
///
/// This must only be called once. If this fails, then other lifecycle
@@ -263,7 +317,7 @@ class ARROW_EXPORT ExecNode {
virtual void StopProducing() = 0;
/// \brief A future which will be marked finished when this node has stopped
producing.
- virtual Future<> finished() = 0;
+ virtual Future<> finished() { return finished_; }
std::string ToString(int indent = 0) const;
@@ -289,7 +343,7 @@ class ARROW_EXPORT ExecNode {
NodeVector outputs_;
// Future to sync finished
- Future<> finished_ = Future<>::MakeFinished();
+ Future<> finished_ = Future<>::Make();
util::tracing::Span span_;
};
@@ -321,8 +375,6 @@ class ARROW_EXPORT MapNode : public ExecNode {
void StopProducing() override;
- Future<> finished() override;
-
protected:
void SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn,
ExecBatch batch);
diff --git a/cpp/src/arrow/compute/exec/hash_join.cc
b/cpp/src/arrow/compute/exec/hash_join.cc
index e821979ae1..07a3083fb9 100644
--- a/cpp/src/arrow/compute/exec/hash_join.cc
+++ b/cpp/src/arrow/compute/exec/hash_join.cc
@@ -44,8 +44,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
+ RegisterTaskGroupCallback register_task_group_callback,
+ StartTaskGroupCallback start_task_group_callback,
OutputBatchCallback output_batch_callback,
- FinishedCallback finished_callback, TaskScheduler* scheduler)
override {
+ FinishedCallback finished_callback) override {
START_COMPUTE_SPAN(span_, "HashJoinBasicImpl",
{{"detail", filter.ToString()},
{"join.kind", arrow::compute::ToString(join_type)},
@@ -58,10 +60,12 @@ class HashJoinBasicImpl : public HashJoinImpl {
schema_[1] = proj_map_right;
key_cmp_ = std::move(key_cmp);
filter_ = std::move(filter);
+ register_task_group_callback_ = std::move(register_task_group_callback);
+ start_task_group_callback_ = std::move(start_task_group_callback);
output_batch_callback_ = std::move(output_batch_callback);
finished_callback_ = std::move(finished_callback);
- scheduler_ = scheduler;
local_states_.resize(num_threads_);
+
for (size_t i = 0; i < local_states_.size(); ++i) {
local_states_[i].is_initialized = false;
local_states_[i].is_has_match_initialized = false;
@@ -78,11 +82,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
return Status::OK();
}
- void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) override
{
+ void Abort(AbortContinuationImpl pos_abort_callback) override {
EVENT(span_, "Abort");
END_SPAN(span_);
cancelled_ = true;
- scheduler_->Abort(std::move(pos_abort_callback));
+ pos_abort_callback();
}
std::string ToString() const override { return "HashJoinBasicImpl"; }
@@ -546,7 +550,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
}
void RegisterBuildHashTable() {
- task_group_build_ = scheduler_->RegisterTaskGroup(
+ task_group_build_ = register_task_group_callback_(
[this](size_t thread_index, int64_t task_id) -> Status {
return BuildHashTable_exec_task(thread_index, task_id);
},
@@ -605,16 +609,16 @@ class HashJoinBasicImpl : public HashJoinImpl {
return build_finished_callback_(thread_index);
}
- Status BuildHashTable(size_t thread_index, AccumulationQueue batches,
+ Status BuildHashTable(size_t /*thread_index*/, AccumulationQueue batches,
BuildFinishedCallback on_finished) override {
build_finished_callback_ = std::move(on_finished);
build_batches_ = std::move(batches);
- return scheduler_->StartTaskGroup(thread_index, task_group_build_,
+ return start_task_group_callback_(task_group_build_,
/*num_tasks=*/1);
}
void RegisterScanHashTable() {
- task_group_scan_ = scheduler_->RegisterTaskGroup(
+ task_group_scan_ = register_task_group_callback_(
[this](size_t thread_index, int64_t task_id) -> Status {
return ScanHashTable_exec_task(thread_index, task_id);
},
@@ -689,8 +693,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
Status ScanHashTable(size_t thread_index) {
MergeHasMatch();
- return scheduler_->StartTaskGroup(thread_index, task_group_scan_,
- ScanHashTable_num_tasks());
+ return start_task_group_callback_(task_group_scan_,
ScanHashTable_num_tasks());
}
Status ProbingFinished(size_t thread_index) override {
@@ -741,12 +744,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
const HashJoinProjectionMaps* schema_[2];
std::vector<JoinKeyCmp> key_cmp_;
Expression filter_;
- TaskScheduler* scheduler_;
int task_group_build_;
int task_group_scan_;
// Callbacks
//
+ RegisterTaskGroupCallback register_task_group_callback_;
+ StartTaskGroupCallback start_task_group_callback_;
OutputBatchCallback output_batch_callback_;
BuildFinishedCallback build_finished_callback_;
FinishedCallback finished_callback_;
diff --git a/cpp/src/arrow/compute/exec/hash_join.h
b/cpp/src/arrow/compute/exec/hash_join.h
index 19add7d440..0c5e43467e 100644
--- a/cpp/src/arrow/compute/exec/hash_join.h
+++ b/cpp/src/arrow/compute/exec/hash_join.h
@@ -41,14 +41,20 @@ class HashJoinImpl {
using OutputBatchCallback = std::function<void(int64_t, ExecBatch)>;
using BuildFinishedCallback = std::function<Status(size_t)>;
using FinishedCallback = std::function<void(int64_t)>;
+ using RegisterTaskGroupCallback = std::function<int(
+ std::function<Status(size_t, int64_t)>, std::function<Status(size_t)>)>;
+ using StartTaskGroupCallback = std::function<Status(int, int64_t)>;
+ using AbortContinuationImpl = std::function<void()>;
virtual ~HashJoinImpl() = default;
virtual Status Init(ExecContext* ctx, JoinType join_type, size_t num_threads,
const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
+ RegisterTaskGroupCallback register_task_group_callback,
+ StartTaskGroupCallback start_task_group_callback,
OutputBatchCallback output_batch_callback,
- FinishedCallback finished_callback, TaskScheduler*
scheduler) = 0;
+ FinishedCallback finished_callback) = 0;
virtual Status BuildHashTable(size_t thread_index, AccumulationQueue batches,
BuildFinishedCallback on_finished) = 0;
diff --git a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc
b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc
index 97badb8423..94201a849f 100644
--- a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc
@@ -19,6 +19,7 @@
#include "arrow/api.h"
#include "arrow/compute/exec/hash_join.h"
+#include "arrow/compute/exec/hash_join_node.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/test_util.h"
#include "arrow/compute/exec/util.h"
@@ -56,8 +57,6 @@ struct BenchmarkSettings {
class JoinBenchmark {
public:
explicit JoinBenchmark(BenchmarkSettings& settings) {
- bool is_parallel = settings.num_threads != 1;
-
SchemaBuilder l_schema_builder, r_schema_builder;
std::vector<FieldRef> left_keys, right_keys;
std::vector<JoinKeyCmp> key_cmp;
@@ -127,9 +126,8 @@ class JoinBenchmark {
stats_.num_probe_rows = settings.num_probe_batches * settings.batch_size;
- ctx_ = arrow::internal::make_unique<ExecContext>(
- default_memory_pool(),
- is_parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
+ ctx_ = arrow::internal::make_unique<ExecContext>(default_memory_pool(),
+
arrow::internal::GetCpuThreadPool());
schema_mgr_ = arrow::internal::make_unique<HashJoinSchema>();
Expression filter = literal(true);
@@ -151,10 +149,22 @@ class JoinBenchmark {
};
scheduler_ = TaskScheduler::Make();
+
+ auto register_task_group_callback = [&](std::function<Status(size_t,
int64_t)> task,
+ std::function<Status(size_t)>
cont) {
+ return scheduler_->RegisterTaskGroup(std::move(task), std::move(cont));
+ };
+
+ auto start_task_group_callback = [&](int task_group_id, int64_t num_tasks)
{
+ return scheduler_->StartTaskGroup(omp_get_thread_num(), task_group_id,
num_tasks);
+ };
+
DCHECK_OK(join_->Init(
ctx_.get(), settings.join_type, settings.num_threads,
&(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]),
std::move(key_cmp),
- std::move(filter), [](ExecBatch) {}, [](int64_t x) {},
scheduler_.get()));
+ std::move(filter), std::move(register_task_group_callback),
+ std::move(start_task_group_callback), [](int64_t, ExecBatch) {},
+ [](int64_t x) {}));
task_group_probe_ = scheduler_->RegisterTaskGroup(
[this](size_t thread_index, int64_t task_id) -> Status {
@@ -168,7 +178,8 @@ class JoinBenchmark {
DCHECK_OK(scheduler_->StartScheduling(
0 /*thread index*/, std::move(schedule_callback),
- static_cast<int>(2 * settings.num_threads) /*concurrent tasks*/,
!is_parallel));
+ static_cast<int>(2 * settings.num_threads) /*concurrent tasks*/,
+ settings.num_threads == 1));
}
void RunJoin() {
diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc
b/cpp/src/arrow/compute/exec/hash_join_node.cc
index 73df78b46e..1785df8784 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -483,10 +483,15 @@ class HashJoinNode;
// on every batch that has been queued so far as well as any new probe-side
batch that
// comes in.
struct BloomFilterPushdownContext {
+ using RegisterTaskGroupCallback = std::function<int(
+ std::function<Status(size_t, int64_t)>, std::function<Status(size_t)>)>;
+ using StartTaskGroupCallback = std::function<Status(int, int64_t)>;
using BuildFinishedCallback = std::function<Status(size_t,
AccumulationQueue)>;
using FiltersReceivedCallback = std::function<Status()>;
using FilterFinishedCallback = std::function<Status(size_t,
AccumulationQueue)>;
- void Init(HashJoinNode* owner, size_t num_threads, TaskScheduler* scheduler,
+ void Init(HashJoinNode* owner, size_t num_threads,
+ RegisterTaskGroupCallback register_task_group_callback,
+ StartTaskGroupCallback start_task_group_callback,
FiltersReceivedCallback on_bloom_filters_received, bool
disable_bloom_filter,
bool use_sync_execution);
@@ -531,7 +536,7 @@ struct BloomFilterPushdownContext {
if (eval_.num_expected_bloom_filters_ == 0)
return eval_.on_finished_(thread_index, std::move(eval_.batches_));
- return scheduler_->StartTaskGroup(thread_index, eval_.task_id_,
+ return start_task_group_callback_(eval_.task_id_,
/*num_tasks=*/eval_.batches_.batch_count());
}
@@ -621,10 +626,10 @@ struct BloomFilterPushdownContext {
return &tld_[thread_index].stack;
}
+ StartTaskGroupCallback start_task_group_callback_;
bool disable_bloom_filter_;
HashJoinSchema* schema_mgr_;
ExecContext* ctx_;
- TaskScheduler* scheduler_;
struct ThreadLocalData {
bool is_init = false;
@@ -834,7 +839,7 @@ class HashJoinNode : public ExecNode {
Status OnFiltersReceived() {
std::unique_lock<std::mutex> guard(probe_side_mutex_);
bloom_filters_ready_ = true;
- size_t thread_index = thread_indexer_();
+ size_t thread_index = plan_->GetThreadIndex();
AccumulationQueue batches = std::move(probe_accumulator_);
guard.unlock();
return pushdown_context_.FilterBatches(
@@ -863,8 +868,8 @@ class HashJoinNode : public ExecNode {
std::lock_guard<std::mutex> guard(probe_side_mutex_);
queued_batches_to_probe_ = std::move(probe_accumulator_);
}
- return scheduler_->StartTaskGroup(thread_index, task_group_probe_,
- queued_batches_to_probe_.batch_count());
+ return plan_->StartTaskGroup(task_group_probe_,
+ queued_batches_to_probe_.batch_count());
}
Status OnQueuedBatchesProbed(size_t thread_index) {
@@ -885,7 +890,7 @@ class HashJoinNode : public ExecNode {
return;
}
- size_t thread_index = thread_indexer_();
+ size_t thread_index = plan_->GetThreadIndex();
int side = (input == inputs_[0]) ? 0 : 1;
EVENT(span_, "InputReceived", {{"batch.length", batch.length}, {"side",
side}});
@@ -923,8 +928,7 @@ class HashJoinNode : public ExecNode {
void InputFinished(ExecNode* input, int total_batches) override {
ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) !=
inputs_.end());
-
- size_t thread_index = thread_indexer_();
+ size_t thread_index = plan_->GetThreadIndex();
int side = (input == inputs_[0]) ? 0 : 1;
EVENT(span_, "InputFinished", {{"side", side}, {"batches.length",
total_batches}});
@@ -941,28 +945,42 @@ class HashJoinNode : public ExecNode {
}
}
- Status PrepareToProduce() override {
+ Status Init() override {
+ RETURN_NOT_OK(ExecNode::Init());
bool use_sync_execution = !(plan_->exec_context()->executor());
// TODO(ARROW-15732)
// Each side of join might have an IO thread being called from. Once this
is fixed
// we will change it back to just the CPU's thread pool capacity.
size_t num_threads = (GetCpuThreadPoolCapacity() +
io::GetIOThreadPoolCapacity() + 1);
- scheduler_ = TaskScheduler::Make();
pushdown_context_.Init(
- this, num_threads, scheduler_.get(), [this]() { return
OnFiltersReceived(); },
- disable_bloom_filter_, use_sync_execution);
+ this, num_threads,
+ [this](std::function<Status(size_t, int64_t)> fn,
+ std::function<Status(size_t)> on_finished) {
+ return plan_->RegisterTaskGroup(std::move(fn),
std::move(on_finished));
+ },
+ [this](int task_group_id, int64_t num_tasks) {
+ return plan_->StartTaskGroup(task_group_id, num_tasks);
+ },
+ [this]() { return OnFiltersReceived(); }, disable_bloom_filter_,
+ use_sync_execution);
RETURN_NOT_OK(impl_->Init(
plan_->exec_context(), join_type_, num_threads,
&(schema_mgr_->proj_maps[0]),
&(schema_mgr_->proj_maps[1]), key_cmp_, filter_,
- [this](int64_t /*ignored*/, ExecBatch batch) {
- this->OutputBatchCallback(batch);
+ [this](std::function<Status(size_t, int64_t)> fn,
+ std::function<Status(size_t)> on_finished) {
+ return plan_->RegisterTaskGroup(std::move(fn),
std::move(on_finished));
+ },
+ [this](int task_group_id, int64_t num_tasks) {
+ return plan_->StartTaskGroup(task_group_id, num_tasks);
},
- [this](int64_t total_num_batches) {
this->FinishedCallback(total_num_batches); },
- scheduler_.get()));
+ [this](int64_t, ExecBatch batch) { this->OutputBatchCallback(batch); },
+ [this](int64_t total_num_batches) {
+ this->FinishedCallback(total_num_batches);
+ }));
- task_group_probe_ = scheduler_->RegisterTaskGroup(
+ task_group_probe_ = plan_->RegisterTaskGroup(
[this](size_t thread_index, int64_t task_id) -> Status {
return impl_->ProbeSingleBatch(thread_index,
std::move(queued_batches_to_probe_[task_id]));
@@ -971,14 +989,6 @@ class HashJoinNode : public ExecNode {
return OnQueuedBatchesProbed(thread_index);
});
- scheduler_->RegisterEnd();
-
- RETURN_NOT_OK(scheduler_->StartScheduling(
- 0 /*thread index*/,
- [this](std::function<Status(size_t)> func) -> Status {
- return this->ScheduleTaskCallback(std::move(func));
- },
- static_cast<int>(2 * num_threads) /*concurrent tasks*/,
use_sync_execution));
return Status::OK();
}
@@ -987,7 +997,7 @@ class HashJoinNode : public ExecNode {
{{"node.label", label()},
{"node.detail", ToString()},
{"node.kind", kind_name()}});
- END_SPAN_ON_FUTURE_COMPLETION(span_, finished(), this);
+ END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
RETURN_NOT_OK(pushdown_context_.StartProducing());
return Status::OK();
}
@@ -1012,12 +1022,10 @@ class HashJoinNode : public ExecNode {
for (auto&& input : inputs_) {
input->StopProducing(this);
}
- impl_->Abort([this]() { ARROW_UNUSED(task_group_.End()); });
+ impl_->Abort([this]() { finished_.MarkFinished(); });
}
}
- Future<> finished() override { return task_group_.OnFinished(); }
-
protected:
std::string ToStringExtra(int indent = 0) const override {
return "implementation=" + impl_->ToString();
@@ -1032,42 +1040,18 @@ class HashJoinNode : public ExecNode {
bool expected = false;
if (complete_.compare_exchange_strong(expected, true)) {
outputs_[0]->InputFinished(this, static_cast<int>(total_num_batches));
- ARROW_UNUSED(task_group_.End());
+ finished_.MarkFinished();
}
}
- Status ScheduleTaskCallback(std::function<Status(size_t)> func) {
- auto executor = plan_->exec_context()->executor();
- if (executor) {
- return task_group_.AddTask([this, executor, func] {
- return DeferNotOk(executor->Submit([this, func] {
- size_t thread_index = thread_indexer_();
- Status status = func(thread_index);
- if (!status.ok()) {
- StopProducing();
- ErrorIfNotOk(status);
- return;
- }
- }));
- });
- } else {
- // We should not get here in serial execution mode
- ARROW_DCHECK(false);
- }
- return Status::OK();
- }
-
private:
AtomicCounter batch_count_[2];
std::atomic<bool> complete_;
JoinType join_type_;
std::vector<JoinKeyCmp> key_cmp_;
Expression filter_;
- ThreadIndexer thread_indexer_;
std::unique_ptr<HashJoinSchema> schema_mgr_;
std::unique_ptr<HashJoinImpl> impl_;
- util::AsyncTaskGroup task_group_;
- std::unique_ptr<TaskScheduler> scheduler_;
util::AccumulationQueue build_accumulator_;
util::AccumulationQueue probe_accumulator_;
util::AccumulationQueue queued_batches_to_probe_;
@@ -1087,14 +1071,14 @@ class HashJoinNode : public ExecNode {
BloomFilterPushdownContext pushdown_context_;
};
-void BloomFilterPushdownContext::Init(HashJoinNode* owner, size_t num_threads,
- TaskScheduler* scheduler,
- FiltersReceivedCallback
on_bloom_filters_received,
- bool disable_bloom_filter,
- bool use_sync_execution) {
+void BloomFilterPushdownContext::Init(
+ HashJoinNode* owner, size_t num_threads,
+ RegisterTaskGroupCallback register_task_group_callback,
+ StartTaskGroupCallback start_task_group_callback,
+ FiltersReceivedCallback on_bloom_filters_received, bool
disable_bloom_filter,
+ bool use_sync_execution) {
schema_mgr_ = owner->schema_mgr_.get();
ctx_ = owner->plan_->exec_context();
- scheduler_ = scheduler;
tld_.resize(num_threads);
disable_bloom_filter_ = disable_bloom_filter;
std::tie(push_.pushdown_target_, push_.column_map_) =
GetPushdownTarget(owner);
@@ -1108,7 +1092,7 @@ void BloomFilterPushdownContext::Init(HashJoinNode*
owner, size_t num_threads,
use_sync_execution ? BloomFilterBuildStrategy::SINGLE_THREADED
: BloomFilterBuildStrategy::PARALLEL);
- build_.task_id_ = scheduler_->RegisterTaskGroup(
+ build_.task_id_ = register_task_group_callback(
[this](size_t thread_index, int64_t task_id) {
return BuildBloomFilter_exec_task(thread_index, task_id);
},
@@ -1117,13 +1101,14 @@ void BloomFilterPushdownContext::Init(HashJoinNode*
owner, size_t num_threads,
});
}
- eval_.task_id_ = scheduler_->RegisterTaskGroup(
+ eval_.task_id_ = register_task_group_callback(
[this](size_t thread_index, int64_t task_id) {
return FilterSingleBatch(thread_index, &eval_.batches_[task_id]);
},
[this](size_t thread_index) {
return eval_.on_finished_(thread_index, std::move(eval_.batches_));
});
+ start_task_group_callback_ = std::move(start_task_group_callback);
}
Status BloomFilterPushdownContext::StartProducing() {
@@ -1145,7 +1130,7 @@ Status
BloomFilterPushdownContext::BuildBloomFilter(size_t thread_index,
ctx_->memory_pool(), build_.batches_.row_count(),
build_.batches_.batch_count(),
push_.bloom_filter_.get()));
- return scheduler_->StartTaskGroup(thread_index, build_.task_id_,
+ return start_task_group_callback_(build_.task_id_,
/*num_tasks=*/build_.batches_.batch_count());
}
diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc
b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
index 9a3c734278..b4fd7ee643 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
@@ -1747,6 +1747,9 @@ TEST(HashJoin, DictNegative) {
EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
NotImplemented, ::testing::HasSubstr("Unifying differing
dictionaries"),
StartAndCollect(plan.get(), sink_gen));
+ // Since we returned an error, the StartAndCollect future may return before
+ // the plan is done finishing.
+ plan->finished().Wait();
}
}
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc
b/cpp/src/arrow/compute/exec/sink_node.cc
index eae12bf729..9118d5a50e 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -136,9 +136,7 @@ class SinkNode : public ExecNode {
{{"node.label", label()},
{"node.detail", ToString()},
{"node.kind", kind_name()}});
- finished_ = Future<>::Make();
- END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
-
+ END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
return Status::OK();
}
@@ -161,8 +159,6 @@ class SinkNode : public ExecNode {
inputs_[0]->StopProducing(this);
}
- Future<> finished() override { return finished_; }
-
void RecordBackpressureBytesUsed(const ExecBatch& batch) {
if (backpressure_queue_.enabled()) {
uint64_t bytes_used = static_cast<uint64_t>(batch.TotalBufferSize());
@@ -287,6 +283,7 @@ class ConsumingSinkNode : public ExecNode, public
BackpressureControl {
{{"node.label", label()},
{"node.detail", ToString()},
{"node.kind", kind_name()}});
+ END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
DCHECK_GT(inputs_.size(), 0);
auto output_schema = inputs_[0]->output_schema();
if (names_.size() > 0) {
@@ -303,8 +300,6 @@ class ConsumingSinkNode : public ExecNode, public
BackpressureControl {
output_schema = schema(std::move(fields));
}
RETURN_NOT_OK(consumer_->Init(output_schema, this));
- finished_ = Future<>::Make();
- END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
return Status::OK();
}
@@ -326,12 +321,10 @@ class ConsumingSinkNode : public ExecNode, public
BackpressureControl {
void StopProducing() override {
EVENT(span_, "StopProducing");
- Finish(Status::Invalid("ExecPlan was stopped early"));
+ Finish(Status::OK());
inputs_[0]->StopProducing(this);
}
- Future<> finished() override { return finished_; }
-
void InputReceived(ExecNode* input, ExecBatch batch) override {
EVENT(span_, "InputReceived", {{"batch.length", batch.length}});
util::tracing::Span span;
@@ -365,9 +358,7 @@ class ConsumingSinkNode : public ExecNode, public
BackpressureControl {
EVENT(span_, "ErrorReceived", {{"error", error.message()}});
DCHECK_EQ(input, inputs_[0]);
- if (input_counter_.Cancel()) {
- Finish(std::move(error));
- }
+ if (input_counter_.Cancel()) Finish(error);
inputs_[0]->StopProducing(this);
}
diff --git a/cpp/src/arrow/compute/exec/source_node.cc
b/cpp/src/arrow/compute/exec/source_node.cc
index ec2b91050d..33072f0026 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -75,6 +75,7 @@ struct SourceNode : ExecNode {
{"node.label", label()},
{"node.output_schema", output_schema()->ToString()},
{"node.detail", ToString()}});
+ END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
{
// If another exec node encountered an error during its StartProducing
call
// it might have already called StopProducing on all of its inputs
(including this
@@ -96,66 +97,55 @@ struct SourceNode : ExecNode {
options.executor = executor;
options.should_schedule = ShouldSchedule::IfDifferentExecutor;
}
- finished_ = Loop([this, executor, options] {
- std::unique_lock<std::mutex> lock(mutex_);
- int total_batches = batch_count_++;
- if (stop_requested_) {
- return
Future<ControlFlow<int>>::MakeFinished(Break(total_batches));
- }
- lock.unlock();
-
- return generator_().Then(
- [=](const util::optional<ExecBatch>& maybe_batch)
- -> Future<ControlFlow<int>> {
- std::unique_lock<std::mutex> lock(mutex_);
- if (IsIterationEnd(maybe_batch) || stop_requested_) {
- stop_requested_ = true;
- return Break(total_batches);
- }
- lock.unlock();
- ExecBatch batch = std::move(*maybe_batch);
-
- if (executor) {
- auto status = task_group_.AddTask(
- [this, executor, batch]() -> Result<Future<>> {
- return executor->Submit([=]() {
- outputs_[0]->InputReceived(this,
std::move(batch));
- return Status::OK();
- });
- });
- if (!status.ok()) {
- outputs_[0]->ErrorReceived(this,
std::move(status));
- return Break(total_batches);
- }
- } else {
- outputs_[0]->InputReceived(this, std::move(batch));
- }
- lock.lock();
- if (!backpressure_future_.is_finished()) {
- EVENT(span_, "Source paused due to backpressure");
- return backpressure_future_.Then(
- []() -> ControlFlow<int> { return Continue(); });
- }
- return
Future<ControlFlow<int>>::MakeFinished(Continue());
- },
- [=](const Status& error) -> ControlFlow<int> {
- // NB: ErrorReceived is independent of InputFinished,
but
- // ErrorReceived will usually prompt StopProducing
which will
- // prompt InputFinished. ErrorReceived may still be
called from a
- // node which was requested to stop (indeed, the
request to stop
- // may prompt an error).
- std::unique_lock<std::mutex> lock(mutex_);
- stop_requested_ = true;
- lock.unlock();
- outputs_[0]->ErrorReceived(this, error);
- return Break(total_batches);
- },
- options);
- }).Then([&](int total_batches) {
- outputs_[0]->InputFinished(this, total_batches);
- return task_group_.End();
- });
- END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
+ started_ = true;
+ auto fut = Loop([this, options] {
+ std::unique_lock<std::mutex> lock(mutex_);
+ int total_batches = batch_count_++;
+ if (stop_requested_) {
+ return
Future<ControlFlow<int>>::MakeFinished(Break(total_batches));
+ }
+ lock.unlock();
+
+ return generator_().Then(
+ [=](const util::optional<ExecBatch>& maybe_batch)
+ -> Future<ControlFlow<int>> {
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (IsIterationEnd(maybe_batch) || stop_requested_) {
+ stop_requested_ = true;
+ return Break(total_batches);
+ }
+ lock.unlock();
+ ExecBatch batch = std::move(*maybe_batch);
+ RETURN_NOT_OK(plan_->ScheduleTask([=]() {
+ outputs_[0]->InputReceived(this, std::move(batch));
+ return Status::OK();
+ }));
+ lock.lock();
+ if (!backpressure_future_.is_finished()) {
+ EVENT(span_, "Source paused due to backpressure");
+ return backpressure_future_.Then(
+ []() -> ControlFlow<int> { return Continue(); });
+ }
+ return
Future<ControlFlow<int>>::MakeFinished(Continue());
+ },
+ [=](const Status& error) -> ControlFlow<int> {
+ std::unique_lock<std::mutex> lock(mutex_);
+ stop_requested_ = true;
+ lock.unlock();
+ outputs_[0]->ErrorReceived(this, error);
+ finished_.MarkFinished(error);
+ return Break(total_batches);
+ },
+ options);
+ })
+ .Then(
+ [=](int total_batches) {
+ outputs_[0]->InputFinished(this, total_batches);
+ if (!finished_.is_finished())
finished_.MarkFinished();
+ },
+ {}, options);
+ if (!executor && finished_.is_finished()) return finished_.status();
+ RETURN_NOT_OK(plan_->AddFuture(fut));
return Status::OK();
}
@@ -196,17 +186,16 @@ struct SourceNode : ExecNode {
void StopProducing() override {
std::unique_lock<std::mutex> lock(mutex_);
stop_requested_ = true;
+ if (!started_) finished_.MarkFinished();
}
- Future<> finished() override { return finished_; }
-
private:
std::mutex mutex_;
int32_t backpressure_counter_{0};
Future<> backpressure_future_ = Future<>::MakeFinished();
bool stop_requested_{false};
+ bool started_ = false;
int batch_count_{0};
- util::AsyncTaskGroup task_group_;
AsyncGenerator<util::optional<ExecBatch>> generator_;
};
diff --git a/cpp/src/arrow/compute/exec/swiss_join.cc
b/cpp/src/arrow/compute/exec/swiss_join.cc
index 5d70e01b1e..5b01edb119 100644
--- a/cpp/src/arrow/compute/exec/swiss_join.cc
+++ b/cpp/src/arrow/compute/exec/swiss_join.cc
@@ -2026,8 +2026,10 @@ class SwissJoin : public HashJoinImpl {
const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
+ RegisterTaskGroupCallback register_task_group_callback,
+ StartTaskGroupCallback start_task_group_callback,
OutputBatchCallback output_batch_callback,
- FinishedCallback finished_callback, TaskScheduler* scheduler)
override {
+ FinishedCallback finished_callback) override {
START_COMPUTE_SPAN(span_, "SwissJoinImpl",
{{"detail", filter.ToString()},
{"join.kind", arrow::compute::ToString(join_type)},
@@ -2043,11 +2045,15 @@ class SwissJoin : public HashJoinImpl {
for (size_t i = 0; i < key_cmp.size(); ++i) {
key_cmp_[i] = key_cmp[i];
}
+
schema_[0] = proj_map_left;
schema_[1] = proj_map_right;
- output_batch_callback_ = output_batch_callback;
- finished_callback_ = finished_callback;
- scheduler_ = scheduler;
+
+ register_task_group_callback_ = std::move(register_task_group_callback);
+ start_task_group_callback_ = std::move(start_task_group_callback);
+ output_batch_callback_ = std::move(output_batch_callback);
+ finished_callback_ = std::move(finished_callback);
+
hash_table_ready_.store(false);
cancelled_.store(false);
{
@@ -2081,17 +2087,17 @@ class SwissJoin : public HashJoinImpl {
}
void InitTaskGroups() {
- task_group_build_ = scheduler_->RegisterTaskGroup(
+ task_group_build_ = register_task_group_callback_(
[this](size_t thread_index, int64_t task_id) -> Status {
return BuildTask(thread_index, task_id);
},
[this](size_t thread_index) -> Status { return
BuildFinished(thread_index); });
- task_group_merge_ = scheduler_->RegisterTaskGroup(
+ task_group_merge_ = register_task_group_callback_(
[this](size_t thread_index, int64_t task_id) -> Status {
return MergeTask(thread_index, task_id);
},
[this](size_t thread_index) -> Status { return
MergeFinished(thread_index); });
- task_group_scan_ = scheduler_->RegisterTaskGroup(
+ task_group_scan_ = register_task_group_callback_(
[this](size_t thread_index, int64_t task_id) -> Status {
return ScanTask(thread_index, task_id);
},
@@ -2136,11 +2142,11 @@ class SwissJoin : public HashJoinImpl {
return CancelIfNotOK(StartBuildHashTable(static_cast<int64_t>(thread_id)));
}
- void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) override
{
+ void Abort(AbortContinuationImpl pos_abort_callback) override {
EVENT(span_, "Abort");
END_SPAN(span_);
std::ignore = CancelIfNotOK(Status::Cancelled("Hash Join Cancelled"));
- scheduler_->Abort(std::move(pos_abort_callback));
+ pos_abort_callback();
}
std::string ToString() const override { return "SwissJoin"; }
@@ -2176,9 +2182,8 @@ class SwissJoin : public HashJoinImpl {
// Process all input batches
//
- return
CancelIfNotOK(scheduler_->StartTaskGroup(static_cast<size_t>(thread_id),
- task_group_build_,
-
build_side_batches_.batch_count()));
+ return CancelIfNotOK(
+ start_task_group_callback_(task_group_build_,
build_side_batches_.batch_count()));
}
Status BuildTask(size_t thread_id, int64_t batch_id) {
@@ -2240,8 +2245,8 @@ class SwissJoin : public HashJoinImpl {
// table.
//
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.PreparePrtnMerge()));
- return CancelIfNotOK(scheduler_->StartTaskGroup(thread_id,
task_group_merge_,
-
hash_table_build_.num_prtns()));
+ return CancelIfNotOK(
+ start_task_group_callback_(task_group_merge_,
hash_table_build_.num_prtns()));
}
Status MergeTask(size_t /*thread_id*/, int64_t prtn_id) {
@@ -2286,8 +2291,7 @@ class SwissJoin : public HashJoinImpl {
hash_table_.MergeHasMatch();
int64_t num_tasks = bit_util::CeilDiv(hash_table_.num_rows(),
kNumRowsPerScanTask);
- return
CancelIfNotOK(scheduler_->StartTaskGroup(static_cast<size_t>(thread_id),
- task_group_scan_,
num_tasks));
+ return CancelIfNotOK(start_task_group_callback_(task_group_scan_,
num_tasks));
} else {
return CancelIfNotOK(OnScanHashTableFinished());
}
@@ -2472,12 +2476,13 @@ class SwissJoin : public HashJoinImpl {
const HashJoinProjectionMaps* schema_[2];
// Task scheduling
- TaskScheduler* scheduler_;
int task_group_build_;
int task_group_merge_;
int task_group_scan_;
// Callbacks
+ RegisterTaskGroupCallback register_task_group_callback_;
+ StartTaskGroupCallback start_task_group_callback_;
OutputBatchCallback output_batch_callback_;
BuildFinishedCallback build_finished_callback_;
FinishedCallback finished_callback_;
diff --git a/cpp/src/arrow/compute/exec/test_util.cc
b/cpp/src/arrow/compute/exec/test_util.cc
index 330ee47112..b3e5d85b53 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -67,6 +67,7 @@ struct DummyNode : ExecNode {
for (size_t i = 0; i < input_labels_.size(); ++i) {
input_labels_[i] = std::to_string(i);
}
+ finished_.MarkFinished();
}
const char* kind_name() const override { return "Dummy"; }
@@ -111,8 +112,6 @@ struct DummyNode : ExecNode {
}
}
- Future<> finished() override { return Future<>::MakeFinished(); }
-
private:
void AssertIsOutput(ExecNode* output) {
auto it = std::find(outputs_.begin(), outputs_.end(), output);
diff --git a/cpp/src/arrow/compute/exec/tpch_node.cc
b/cpp/src/arrow/compute/exec/tpch_node.cc
index d8f2c60312..d19f20eea7 100644
--- a/cpp/src/arrow/compute/exec/tpch_node.cc
+++ b/cpp/src/arrow/compute/exec/tpch_node.cc
@@ -663,8 +663,7 @@ class PartAndPartSupplierGenerator {
return SetOutputColumns(cols, kPartsuppTypes, kPartsuppNameMap,
partsupp_cols_);
}
- Result<util::optional<ExecBatch>> NextPartBatch() {
- size_t thread_index = thread_indexer_();
+ Result<util::optional<ExecBatch>> NextPartBatch(size_t thread_index) {
ThreadLocalData& tld = thread_local_data_[thread_index];
{
std::lock_guard<std::mutex> lock(part_output_queue_mutex_);
@@ -719,8 +718,7 @@ class PartAndPartSupplierGenerator {
return ExecBatch::Make(std::move(part_result));
}
- Result<util::optional<ExecBatch>> NextPartSuppBatch() {
- size_t thread_index = thread_indexer_();
+ Result<util::optional<ExecBatch>> NextPartSuppBatch(size_t thread_index) {
ThreadLocalData& tld = thread_local_data_[thread_index];
{
std::lock_guard<std::mutex> lock(partsupp_output_queue_mutex_);
@@ -1284,7 +1282,6 @@ class PartAndPartSupplierGenerator {
int64_t part_rows_generated_{0};
std::vector<int> part_cols_;
std::vector<int> partsupp_cols_;
- ThreadIndexer thread_indexer_;
std::atomic<int64_t> part_batches_generated_ = {0};
std::atomic<int64_t> partsupp_batches_generated_ = {0};
@@ -1326,8 +1323,7 @@ class OrdersAndLineItemGenerator {
return SetOutputColumns(cols, kLineitemTypes, kLineitemNameMap,
lineitem_cols_);
}
- Result<util::optional<ExecBatch>> NextOrdersBatch() {
- size_t thread_index = thread_indexer_();
+ Result<util::optional<ExecBatch>> NextOrdersBatch(size_t thread_index) {
ThreadLocalData& tld = thread_local_data_[thread_index];
{
std::lock_guard<std::mutex> lock(orders_output_queue_mutex_);
@@ -1382,8 +1378,7 @@ class OrdersAndLineItemGenerator {
return ExecBatch::Make(std::move(orders_result));
}
- Result<util::optional<ExecBatch>> NextLineItemBatch() {
- size_t thread_index = thread_indexer_();
+ Result<util::optional<ExecBatch>> NextLineItemBatch(size_t thread_index) {
ThreadLocalData& tld = thread_local_data_[thread_index];
ExecBatch queued;
bool from_queue = false;
@@ -2397,7 +2392,6 @@ class OrdersAndLineItemGenerator {
int64_t orders_rows_generated_;
std::vector<int> orders_cols_;
std::vector<int> lineitem_cols_;
- ThreadIndexer thread_indexer_;
std::atomic<size_t> orders_batches_generated_ = {0};
std::atomic<size_t> lineitem_batches_generated_ = {0};
@@ -2712,9 +2706,10 @@ class PartGenerator : public TpchTableGenerator {
std::shared_ptr<Schema> schema() const override { return schema_; }
private:
- Status ProduceCallback(size_t) {
+ Status ProduceCallback(size_t thread_index) {
if (done_.load()) return Status::OK();
- ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
gen_->NextPartBatch());
+ ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
+ gen_->NextPartBatch(thread_index));
if (!maybe_batch.has_value()) {
int64_t batches_generated = gen_->part_batches_generated();
if (batches_generated == batches_outputted_.load()) {
@@ -2773,10 +2768,10 @@ class PartSuppGenerator : public TpchTableGenerator {
std::shared_ptr<Schema> schema() const override { return schema_; }
private:
- Status ProduceCallback(size_t) {
+ Status ProduceCallback(size_t thread_index) {
if (done_.load()) return Status::OK();
ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
- gen_->NextPartSuppBatch());
+ gen_->NextPartSuppBatch(thread_index));
if (!maybe_batch.has_value()) {
int64_t batches_generated = gen_->partsupp_batches_generated();
if (batches_generated == batches_outputted_.load()) {
@@ -3092,9 +3087,10 @@ class OrdersGenerator : public TpchTableGenerator {
std::shared_ptr<Schema> schema() const override { return schema_; }
private:
- Status ProduceCallback(size_t) {
+ Status ProduceCallback(size_t thread_index) {
if (done_.load()) return Status::OK();
- ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
gen_->NextOrdersBatch());
+ ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
+ gen_->NextOrdersBatch(thread_index));
if (!maybe_batch.has_value()) {
int64_t batches_generated = gen_->orders_batches_generated();
if (batches_generated == batches_outputted_.load()) {
@@ -3153,10 +3149,10 @@ class LineitemGenerator : public TpchTableGenerator {
std::shared_ptr<Schema> schema() const override { return schema_; }
private:
- Status ProduceCallback(size_t) {
+ Status ProduceCallback(size_t thread_index) {
if (done_.load()) return Status::OK();
ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
- gen_->NextLineItemBatch());
+ gen_->NextLineItemBatch(thread_index));
if (!maybe_batch.has_value()) {
int64_t batches_generated = gen_->lineitem_batches_generated();
if (batches_generated == batches_outputted_.load()) {
@@ -3379,7 +3375,7 @@ class TpchNode : public ExecNode {
Status StartProducing() override {
return generator_->StartProducing(
- thread_indexer_.Capacity(),
+ plan_->max_concurrency(),
[this](ExecBatch batch) { this->OutputBatchCallback(std::move(batch));
},
[this](int64_t num_batches) { this->FinishedCallback(num_batches); },
[this](std::function<Status(size_t)> func) -> Status {
@@ -3400,10 +3396,10 @@ class TpchNode : public ExecNode {
}
void StopProducing() override {
- if (generator_->Abort()) std::ignore = task_group_.End();
+ if (generator_->Abort()) finished_.MarkFinished();
}
- Future<> finished() override { return task_group_.OnFinished(); }
+ Future<> finished() override { return finished_; }
private:
void OutputBatchCallback(ExecBatch batch) {
@@ -3412,42 +3408,23 @@ class TpchNode : public ExecNode {
void FinishedCallback(int64_t total_num_batches) {
outputs_[0]->InputFinished(this, static_cast<int>(total_num_batches));
- std::ignore = task_group_.End();
+ finished_.MarkFinished();
}
Status ScheduleTaskCallback(std::function<Status(size_t)> func) {
- auto executor = plan_->exec_context()->executor();
-
- // Due to the way that the generators schedule tasks, there may be more
tasks
- // than output batches. After outputting the last batch, the generator will
- // end the task group, but there may still be other threads that try to
schedule
- // tasks while the task group is being ended. This can result in adding
tasks after
- // the task group is ended. If those tasks were to be executed,
correctness would
- // not be affected as they'd see the generator is done and exit
immediately. As such,
- // if the task group is ended we can just skip scheduling these tasks in
general.
- if (executor) {
- RETURN_NOT_OK(task_group_.AddTaskIfNotEnded([&] {
- return executor->Submit([this, func] {
- size_t thread_index = thread_indexer_();
- Status status = func(thread_index);
- if (!status.ok()) {
- StopProducing();
- ErrorIfNotOk(status);
- return;
- }
- });
- }));
- } else {
- return func(0);
- }
- return Status::OK();
+ if (finished_.is_finished()) return Status::OK();
+ return plan_->ScheduleTask([this, func](size_t thread_index) {
+ Status status = func(thread_index);
+ if (!status.ok()) {
+ StopProducing();
+ ErrorIfNotOk(status);
+ }
+ return status;
+ });
}
const char* name_;
std::unique_ptr<TpchTableGenerator> generator_;
-
- util::AsyncTaskGroup task_group_;
- ThreadIndexer thread_indexer_;
};
class TpchGenImpl : public TpchGen {
diff --git a/cpp/src/arrow/compute/exec/union_node.cc
b/cpp/src/arrow/compute/exec/union_node.cc
index 22df39aac6..e5170c2bc9 100644
--- a/cpp/src/arrow/compute/exec/union_node.cc
+++ b/cpp/src/arrow/compute/exec/union_node.cc
@@ -116,8 +116,7 @@ class UnionNode : public ExecNode {
{{"node.label", label()},
{"node.detail", ToString()},
{"node.kind", kind_name()}});
- finished_ = Future<>::Make();
- END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
+ END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
return Status::OK();
}
diff --git a/cpp/src/arrow/dataset/scanner_test.cc
b/cpp/src/arrow/dataset/scanner_test.cc
index 5316f63d08..f5db16f694 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -662,11 +662,18 @@ TEST_P(TestScanner, ScanBatchesFailure) {
[](const EnumeratedRecordBatch& batch) { return
batch.record_batch.value; }))
<< "ScanBatchesUnordered() did not raise an error";
}
- ASSERT_OK_AND_ASSIGN(auto tagged_batch_it, scanner->ScanBatches());
- EXPECT_TRUE(CheckIteratorRaises(
- batch, std::move(tagged_batch_it),
- [](const TaggedRecordBatch& batch) { return batch.record_batch; }))
- << "ScanBatches() did not raise an error";
+
+ auto maybe_tagged_batch_it = scanner->ScanBatches();
+ if (!maybe_tagged_batch_it.ok()) {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we
failed!"),
+ std::move(maybe_tagged_batch_it));
+ } else {
+ ASSERT_OK_AND_ASSIGN(auto tagged_batch_it,
std::move(maybe_tagged_batch_it));
+ EXPECT_TRUE(CheckIteratorRaises(
+ batch, std::move(tagged_batch_it),
+ [](const TaggedRecordBatch& batch) { return batch.record_batch; }))
+ << "ScanBatches() did not raise an error";
+ }
};
// Case 1: failure when getting next scan task
@@ -748,10 +755,16 @@ TEST_P(TestScanner, FromReader) {
AssertScannerEquals(target_reader.get(), scanner.get());
// Such datasets can only be scanned once (but you can get fragments
multiple times)
- ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
- batch_it.Next());
+ auto maybe_batch_it = scanner->ScanBatches();
+ if (maybe_batch_it.ok()) {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+ (*maybe_batch_it).Next());
+ } else {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+ std::move(maybe_batch_it));
+ }
}
INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner,
diff --git a/cpp/src/arrow/engine/substrait/util.cc
b/cpp/src/arrow/engine/substrait/util.cc
index 27b61f0b34..36240d4682 100644
--- a/cpp/src/arrow/engine/substrait/util.cc
+++ b/cpp/src/arrow/engine/substrait/util.cc
@@ -66,7 +66,7 @@ class SubstraitExecutor {
public:
explicit SubstraitExecutor(std::shared_ptr<compute::ExecPlan> plan,
compute::ExecContext exec_context)
- : plan_(std::move(plan)), exec_context_(exec_context) {}
+ : plan_(std::move(plan)), plan_started_(false),
exec_context_(exec_context) {}
~SubstraitExecutor() { ARROW_UNUSED(this->Close()); }
@@ -75,6 +75,7 @@ class SubstraitExecutor {
RETURN_NOT_OK(decl.AddToPlan(plan_.get()).status());
}
RETURN_NOT_OK(plan_->Validate());
+ plan_started_ = true;
RETURN_NOT_OK(plan_->StartProducing());
auto schema = sink_consumer_->schema();
std::shared_ptr<RecordBatchReader> sink_reader =
compute::MakeGeneratorReader(
@@ -82,7 +83,10 @@ class SubstraitExecutor {
return sink_reader;
}
- Status Close() { return plan_->finished().status(); }
+ Status Close() {
+ if (plan_started_) return plan_->finished().status();
+ return Status::OK();
+ }
Status Init(const Buffer& substrait_buffer, const ExtensionIdRegistry*
registry) {
if (substrait_buffer.size() == 0) {
@@ -102,6 +106,7 @@ class SubstraitExecutor {
arrow::PushGenerator<util::optional<compute::ExecBatch>> generator_;
std::vector<compute::Declaration> declarations_;
std::shared_ptr<compute::ExecPlan> plan_;
+ bool plan_started_;
compute::ExecContext exec_context_;
std::shared_ptr<SubstraitSinkConsumer> sink_consumer_;
};
diff --git a/cpp/src/arrow/util/future.cc b/cpp/src/arrow/util/future.cc
index ca4290c5b0..ab59234dea 100644
--- a/cpp/src/arrow/util/future.cc
+++ b/cpp/src/arrow/util/future.cc
@@ -314,22 +314,34 @@ class ConcreteFutureImpl : public FutureImpl {
}
void DoMarkFinishedOrFailed(FutureState state) {
+ std::vector<CallbackRecord> callbacks;
+ std::shared_ptr<FutureImpl> self;
{
// Lock the hypothetical waiter first, and the future after.
// This matches the locking order done in FutureWaiter constructor.
std::unique_lock<std::mutex> waiter_lock(global_waiter_mutex);
std::unique_lock<std::mutex> lock(mutex_);
+#ifdef ARROW_WITH_OPENTELEMETRY
+ if (this->span_) {
+ util::tracing::Span& span = *span_;
+ END_SPAN(span);
+ }
+#endif
DCHECK(!IsFutureFinished(state_)) << "Future already marked finished";
+ if (!callbacks_.empty()) {
+ callbacks = std::move(callbacks_);
+ auto self_inner = shared_from_this();
+ self = std::move(self_inner);
+ }
+
state_ = state;
if (waiter_ != nullptr) {
waiter_->MarkFutureFinishedUnlocked(waiter_arg_, state);
}
}
cv_.notify_all();
-
- auto callbacks = std::move(callbacks_);
- auto self = shared_from_this();
+ if (callbacks.empty()) return;
// run callbacks, lock not needed since the future is finished by this
// point so nothing else can modify the callbacks list and it is safe
diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h
index b374c77c81..2ac26b7f20 100644
--- a/cpp/src/arrow/util/future.h
+++ b/cpp/src/arrow/util/future.h
@@ -29,9 +29,11 @@
#include "arrow/status.h"
#include "arrow/type_fwd.h"
#include "arrow/type_traits.h"
+#include "arrow/util/config.h"
#include "arrow/util/functional.h"
#include "arrow/util/macros.h"
#include "arrow/util/optional.h"
+#include "arrow/util/tracing.h"
#include "arrow/util/type_fwd.h"
#include "arrow/util/visibility.h"
@@ -263,6 +265,10 @@ class ARROW_EXPORT FutureImpl : public
std::enable_shared_from_this<FutureImpl>
static std::unique_ptr<FutureImpl> Make();
static std::unique_ptr<FutureImpl> MakeFinished(FutureState state);
+#ifdef ARROW_WITH_OPENTELEMETRY
+ // Attaches a tracing span that ConcreteFutureImpl::DoMarkFinishedOrFailed()
+ // ends (under the future's mutex) when this future is marked finished or
+ // failed. Only a raw, non-owning pointer is stored in span_, so the span must
+ // stay alive until the future completes.
+ // NOTE(review): a span attached after the future has already finished
+ // appears never to be ended -- confirm this is intended.
+ void SetSpan(util::tracing::Span* span) { span_ = span; }
+#endif
+
// Future API
void MarkFinished();
void MarkFailed();
@@ -294,6 +300,9 @@ class ARROW_EXPORT FutureImpl : public
std::enable_shared_from_this<FutureImpl>
CallbackOptions options;
};
std::vector<CallbackRecord> callbacks_;
+#ifdef ARROW_WITH_OPENTELEMETRY
+ util::tracing::Span* span_ = NULLPTR;
+#endif
};
// An object that waits on multiple futures at once. Only one waiter
@@ -378,6 +387,10 @@ class ARROW_MUST_USE_TYPE Future {
// of being able to presize a vector of Futures.
Future() = default;
+#ifdef ARROW_WITH_OPENTELEMETRY
+ // Forwards to FutureImpl::SetSpan(). impl_ is dereferenced unconditionally,
+ // so this must only be called on a valid future (see is_valid()); a
+ // default-constructed Future has a null impl_.
+ void SetSpan(util::tracing::Span* span) { impl_->SetSpan(span); }
+#endif
+
// Consumer API
bool is_valid() const { return impl_ != NULLPTR; }
diff --git a/cpp/src/arrow/util/tracing_internal.h
b/cpp/src/arrow/util/tracing_internal.h
index 2898fd245f..d1da05671a 100644
--- a/cpp/src/arrow/util/tracing_internal.h
+++ b/cpp/src/arrow/util/tracing_internal.h
@@ -164,17 +164,8 @@ opentelemetry::trace::StartSpanOptions
SpanOptionsWithParent(
#define END_SPAN(target_span) \
::arrow::internal::tracing::UnwrapSpan(target_span.details.get())->End()
-#define END_SPAN_ON_FUTURE_COMPLETION(target_span, target_future,
target_capture) \
- target_future = target_future.Then(
\
- [target_capture]() {
\
- MARK_SPAN(target_span, Status::OK());
\
- END_SPAN(target_span);
\
- },
\
- [target_capture](const Status& st) {
\
- MARK_SPAN(target_span, st);
\
- END_SPAN(target_span);
\
- return st;
\
- })
+// Arranges for target_span to be ended when target_future is marked finished
+// or failed, by registering its address via Future::SetSpan(). target_span is
+// captured by address, so it must outlive the future's completion.
+// NOTE(review): this only ends the span; it does not call MARK_SPAN with the
+// future's final status, so error statuses are no longer recorded on the span.
+#define END_SPAN_ON_FUTURE_COMPLETION(target_span, target_future) \
+ target_future.SetSpan(&target_span)
#define PROPAGATE_SPAN_TO_GENERATOR(generator) \
generator = ::arrow::internal::tracing::PropagateSpanThroughAsyncGenerator( \
@@ -207,7 +198,7 @@ class SpanImpl {};
#define MARK_SPAN(target_span, status)
#define EVENT(target_span, ...)
#define END_SPAN(target_span)
-#define END_SPAN_ON_FUTURE_COMPLETION(target_span, target_future,
target_capture)
+#define END_SPAN_ON_FUTURE_COMPLETION(target_span, target_future)
#define PROPAGATE_SPAN_TO_GENERATOR(generator)
#define WRAP_ASYNC_GENERATOR(generator)
#define WRAP_ASYNC_GENERATOR_WITH_CHILD_SPAN(generator, name)
diff --git a/python/pyarrow/tests/test_dataset.py
b/python/pyarrow/tests/test_dataset.py
index b35d8f3178..277c6866f1 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -2928,9 +2928,9 @@ def test_incompatible_schema_hang(tempdir,
dataset_reader):
dataset = ds.dataset([str(fn)] * 100, schema=schema)
assert dataset.schema.equals(schema)
scanner = dataset_reader.scanner(dataset)
- reader = scanner.to_reader()
with pytest.raises(NotImplementedError,
match='Unsupported cast from int64 to null'):
+ reader = scanner.to_reader()
reader.read_all()