This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new cf03901ba8 ARROW-16523: [C++] Part 1 of ExecPlan cleanup: Centralized 
Task Group (#13143)
cf03901ba8 is described below

commit cf03901ba8fa6364d68dc6b9a31942a329ac6c89
Author: Sasha Krassovsky <[email protected]>
AuthorDate: Thu Jul 14 15:11:45 2022 -0800

    ARROW-16523: [C++] Part 1 of ExecPlan cleanup: Centralized Task Group 
(#13143)
    
    Authored-by: Sasha Krassovsky <[email protected]>
    Signed-off-by: Weston Pace <[email protected]>
---
 cpp/src/arrow/compute/exec/aggregate_node.cc      |  80 ++++++--------
 cpp/src/arrow/compute/exec/exec_plan.cc           | 125 ++++++++++++++++++----
 cpp/src/arrow/compute/exec/exec_plan.h            |  80 +++++++++++---
 cpp/src/arrow/compute/exec/hash_join.cc           |  26 +++--
 cpp/src/arrow/compute/exec/hash_join.h            |   8 +-
 cpp/src/arrow/compute/exec/hash_join_benchmark.cc |  25 +++--
 cpp/src/arrow/compute/exec/hash_join_node.cc      | 113 +++++++++----------
 cpp/src/arrow/compute/exec/hash_join_node_test.cc |   3 +
 cpp/src/arrow/compute/exec/sink_node.cc           |  17 +--
 cpp/src/arrow/compute/exec/source_node.cc         | 115 +++++++++-----------
 cpp/src/arrow/compute/exec/swiss_join.cc          |  39 ++++---
 cpp/src/arrow/compute/exec/test_util.cc           |   3 +-
 cpp/src/arrow/compute/exec/tpch_node.cc           |  77 +++++--------
 cpp/src/arrow/compute/exec/union_node.cc          |   3 +-
 cpp/src/arrow/dataset/scanner_test.cc             |  31 ++++--
 cpp/src/arrow/engine/substrait/util.cc            |   9 +-
 cpp/src/arrow/util/future.cc                      |  18 +++-
 cpp/src/arrow/util/future.h                       |  13 +++
 cpp/src/arrow/util/tracing_internal.h             |  15 +--
 python/pyarrow/tests/test_dataset.py              |   2 +-
 20 files changed, 465 insertions(+), 337 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc 
b/cpp/src/arrow/compute/exec/aggregate_node.cc
index 0131319be3..96aa56b80c 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -113,7 +113,7 @@ class ScalarAggregateNode : public ExecNode {
       }
 
       KernelContext kernel_ctx{exec_ctx};
-      states[i].resize(ThreadIndexer::Capacity());
+      states[i].resize(plan->max_concurrency());
       RETURN_NOT_OK(Kernel::InitAll(&kernel_ctx,
                                     KernelInitArgs{kernels[i],
                                                    {
@@ -168,7 +168,7 @@ class ScalarAggregateNode : public ExecNode {
                                     {"batch.length", batch.length}});
     DCHECK_EQ(input, inputs_[0]);
 
-    auto thread_index = get_thread_index_();
+    auto thread_index = plan_->GetThreadIndex();
 
     if (ErrorIfNotOk(DoConsume(std::move(batch), thread_index))) return;
 
@@ -196,8 +196,6 @@ class ScalarAggregateNode : public ExecNode {
                        {{"node.label", label()},
                         {"node.detail", ToString()},
                         {"node.kind", kind_name()}});
-    finished_ = Future<>::Make();
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
     // Scalar aggregates will only output a single batch
     outputs_[0]->InputFinished(this, 1);
     return Status::OK();
@@ -224,8 +222,6 @@ class ScalarAggregateNode : public ExecNode {
     inputs_[0]->StopProducing(this);
   }
 
-  Future<> finished() override { return finished_; }
-
  protected:
   std::string ToStringExtra(int indent = 0) const override {
     std::stringstream ss;
@@ -266,7 +262,6 @@ class ScalarAggregateNode : public ExecNode {
 
   std::vector<std::vector<std::unique_ptr<KernelState>>> states_;
 
-  ThreadIndexer get_thread_index_;
   AtomicCounter input_counter_;
 };
 
@@ -284,6 +279,19 @@ class GroupByNode : public ExecNode {
         aggs_(std::move(aggs)),
         agg_kernels_(std::move(agg_kernels)) {}
 
+  Status Init() override {
+    output_task_group_id_ = plan_->RegisterTaskGroup(
+        [this](size_t, int64_t task_id) {
+          OutputNthBatch(task_id);
+          return Status::OK();
+        },
+        [this](size_t) {
+          finished_.MarkFinished();
+          return Status::OK();
+        });
+    return Status::OK();
+  }
+
   static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
                                 const ExecNodeOptions& options) {
     RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "GroupByNode"));
@@ -358,7 +366,7 @@ class GroupByNode : public ExecNode {
                        {{"group_by", ToStringExtra()},
                         {"node.label", label()},
                         {"batch.length", batch.length}});
-    size_t thread_index = get_thread_index_();
+    size_t thread_index = plan_->GetThreadIndex();
     if (thread_index >= local_states_.size()) {
       return Status::IndexError("thread index ", thread_index, " is out of 
range [0, ",
                                 local_states_.size(), ")");
@@ -465,47 +473,32 @@ class GroupByNode : public ExecNode {
     std::move(out_keys.values.begin(), out_keys.values.end(),
               out_data.values.begin() + agg_kernels_.size());
     state->grouper.reset();
-
-    if (output_counter_.SetTotal(
-            static_cast<int>(bit_util::CeilDiv(out_data.length, 
output_batch_size())))) {
-      // this will be hit if out_data.length == 0
-      finished_.MarkFinished();
-    }
     return out_data;
   }
 
-  void OutputNthBatch(int n) {
+  void OutputNthBatch(int64_t n) {
     // bail if StopProducing was called
     if (finished_.is_finished()) return;
 
     int64_t batch_size = output_batch_size();
     outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, 
batch_size));
-
-    if (output_counter_.Increment()) {
-      finished_.MarkFinished();
-    }
   }
 
   Status OutputResult() {
-    RETURN_NOT_OK(Merge());
-    ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());
-
-    int num_output_batches = *output_counter_.total();
-    outputs_[0]->InputFinished(this, num_output_batches);
-
-    auto executor = ctx_->executor();
-    for (int i = 0; i < num_output_batches; ++i) {
-      if (executor) {
-        // bail if StopProducing was called
-        if (finished_.is_finished()) break;
-
-        auto plan = this->plan()->shared_from_this();
-        RETURN_NOT_OK(executor->Spawn([plan, this, i] { OutputNthBatch(i); }));
-      } else {
-        OutputNthBatch(i);
+    // To simplify merging, ensure that the first grouper is nonempty
+    for (size_t i = 0; i < local_states_.size(); i++) {
+      if (local_states_[i].grouper) {
+        std::swap(local_states_[i], local_states_[0]);
+        break;
       }
     }
 
+    RETURN_NOT_OK(Merge());
+    ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());
+
+    int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, 
output_batch_size());
+    outputs_[0]->InputFinished(this, static_cast<int>(num_output_batches));
+    RETURN_NOT_OK(plan_->StartTaskGroup(output_task_group_id_, 
num_output_batches));
     return Status::OK();
   }
 
@@ -555,10 +548,8 @@ class GroupByNode : public ExecNode {
                        {{"node.label", label()},
                         {"node.detail", ToString()},
                         {"node.kind", kind_name()}});
-    finished_ = Future<>::Make();
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
 
-    local_states_.resize(ThreadIndexer::Capacity());
+    local_states_.resize(plan_->max_concurrency());
     return Status::OK();
   }
 
@@ -576,17 +567,12 @@ class GroupByNode : public ExecNode {
     EVENT(span_, "StopProducing");
     DCHECK_EQ(output, outputs_[0]);
 
-    ARROW_UNUSED(input_counter_.Cancel());
-    if (output_counter_.Cancel()) {
-      finished_.MarkFinished();
-    }
+    if (input_counter_.Cancel()) finished_.MarkFinished();
     inputs_[0]->StopProducing(this);
   }
 
   void StopProducing() override { StopProducing(outputs_[0]); }
 
-  Future<> finished() override { return finished_; }
-
  protected:
   std::string ToStringExtra(int indent = 0) const override {
     std::stringstream ss;
@@ -608,7 +594,7 @@ class GroupByNode : public ExecNode {
   };
 
   ThreadLocalState* GetLocalState() {
-    size_t thread_index = get_thread_index_();
+    size_t thread_index = plan_->GetThreadIndex();
     return &local_states_[thread_index];
   }
 
@@ -650,14 +636,14 @@ class GroupByNode : public ExecNode {
   }
 
   ExecContext* ctx_;
+  int output_task_group_id_;
 
   const std::vector<int> key_field_ids_;
   const std::vector<int> agg_src_field_ids_;
   const std::vector<Aggregate> aggs_;
   const std::vector<const HashAggregateKernel*> agg_kernels_;
 
-  ThreadIndexer get_thread_index_;
-  AtomicCounter input_counter_, output_counter_;
+  AtomicCounter input_counter_;
 
   std::vector<ThreadLocalState> local_states_;
   ExecBatch out_data_;
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc 
b/cpp/src/arrow/compute/exec/exec_plan.cc
index 468d0accea..e248782ded 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -24,6 +24,7 @@
 #include "arrow/compute/exec.h"
 #include "arrow/compute/exec/expression.h"
 #include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/task_util.h"
 #include "arrow/compute/exec_internal.h"
 #include "arrow/compute/registry.h"
 #include "arrow/datum.h"
@@ -56,6 +57,9 @@ struct ExecPlanImpl : public ExecPlan {
     }
   }
 
+  size_t GetThreadIndex() { return thread_indexer_(); }
+  size_t max_concurrency() const { return thread_indexer_.Capacity(); }
+
   ExecNode* AddNode(std::unique_ptr<ExecNode> node) {
     if (node->label().empty()) {
       node->SetLabel(std::to_string(auto_label_counter_++));
@@ -70,6 +74,34 @@ struct ExecPlanImpl : public ExecPlan {
     return nodes_.back().get();
   }
 
+  Status AddFuture(Future<> fut) { return 
task_group_.AddTaskIfNotEnded(std::move(fut)); }
+
+  Status ScheduleTask(std::function<Status()> fn) {
+    auto executor = exec_context_->executor();
+    if (!executor) return fn();
+    // Atomically submit fn to the executor, and if successful
+    // add it to the task group.
+    return task_group_.AddTaskIfNotEnded(
+        [executor, fn]() { return executor->Submit(std::move(fn)); });
+  }
+
+  Status ScheduleTask(std::function<Status(size_t)> fn) {
+    std::function<Status()> indexed_fn = [this, fn]() {
+      size_t thread_index = GetThreadIndex();
+      return fn(thread_index);
+    };
+    return ScheduleTask(std::move(indexed_fn));
+  }
+
+  int RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
+                        std::function<Status(size_t)> on_finished) {
+    return task_scheduler_->RegisterTaskGroup(std::move(task), 
std::move(on_finished));
+  }
+
+  Status StartTaskGroup(int task_group_id, int64_t num_tasks) {
+    return task_scheduler_->StartTaskGroup(GetThreadIndex(), task_group_id, 
num_tasks);
+  }
+
   Status Validate() const {
     if (nodes_.empty()) {
       return Status::Invalid("ExecPlan has no node");
@@ -96,15 +128,35 @@ struct ExecPlanImpl : public ExecPlan {
     if (started_) {
       return Status::Invalid("restarted ExecPlan");
     }
-    started_ = true;
 
-    // producers precede consumers
-    sorted_nodes_ = TopoSort();
-    for (ExecNode* node : sorted_nodes_) {
-      RETURN_NOT_OK(node->PrepareToProduce());
+    std::vector<Future<>> futures;
+    for (auto& n : nodes_) {
+      RETURN_NOT_OK(n->Init());
+      futures.push_back(n->finished());
     }
 
-    std::vector<Future<>> futures;
+    AllFinished(futures).AddCallback([this](const Status& st) {
+      error_st_ = st;
+      EndTaskGroup();
+    });
+
+    task_scheduler_->RegisterEnd();
+    int num_threads = 1;
+    bool sync_execution = true;
+    if (auto executor = exec_context()->executor()) {
+      num_threads = executor->GetCapacity();
+      sync_execution = false;
+    }
+    RETURN_NOT_OK(task_scheduler_->StartScheduling(
+        0 /* thread_index */,
+        [this](std::function<Status(size_t)> fn) -> Status {
+          return this->ScheduleTask(std::move(fn));
+        },
+        /*concurrent_tasks=*/2 * num_threads, sync_execution));
+
+    started_ = true;
+    // producers precede consumers
+    sorted_nodes_ = TopoSort();
 
     Status st = Status::OK();
 
@@ -120,23 +172,34 @@ struct ExecPlanImpl : public ExecPlan {
         // Stop nodes that successfully started, in reverse order
         stopped_ = true;
         StopProducingImpl(it.base(), sorted_nodes_.end());
-        break;
+        for (NodeVector::iterator fw_it = sorted_nodes_.begin(); fw_it != 
it.base();
+             ++fw_it) {
+          Future<> fut = (*fw_it)->finished();
+          if (!fut.is_finished()) fut.MarkFinished();
+        }
+        return st;
       }
-
-      futures.push_back(node->finished());
     }
-
-    finished_ = AllFinished(futures);
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
     return st;
   }
 
+  void EndTaskGroup() {
+    bool expected = false;
+    if (group_ended_.compare_exchange_strong(expected, true)) {
+      task_group_.End().AddCallback([this](const Status& st) {
+        MARK_SPAN(span_, error_st_ & st);
+        END_SPAN(span_);
+        finished_.MarkFinished(error_st_ & st);
+      });
+    }
+  }
+
   void StopProducing() {
     DCHECK(started_) << "stopped an ExecPlan which never started";
     EVENT(span_, "StopProducing");
     stopped_ = true;
-
-    StopProducingImpl(sorted_nodes_.begin(), sorted_nodes_.end());
+    task_scheduler_->Abort(
+        [this]() { StopProducingImpl(sorted_nodes_.begin(), 
sorted_nodes_.end()); });
   }
 
   template <typename It>
@@ -242,7 +305,8 @@ struct ExecPlanImpl : public ExecPlan {
     return ss.str();
   }
 
-  Future<> finished_ = Future<>::MakeFinished();
+  Status error_st_;
+  Future<> finished_ = Future<>::Make();
   bool started_ = false, stopped_ = false;
   std::vector<std::unique_ptr<ExecNode>> nodes_;
   NodeVector sources_, sinks_;
@@ -250,6 +314,11 @@ struct ExecPlanImpl : public ExecPlan {
   uint32_t auto_label_counter_ = 0;
   util::tracing::Span span_;
   std::shared_ptr<const KeyValueMetadata> metadata_;
+
+  ThreadIndexer thread_indexer_;
+  std::atomic<bool> group_ended_{false};
+  util::AsyncTaskGroup task_group_;
+  std::unique_ptr<TaskScheduler> task_scheduler_ = TaskScheduler::Make();
 };
 
 ExecPlanImpl* ToDerived(ExecPlan* ptr) { return 
checked_cast<ExecPlanImpl*>(ptr); }
@@ -283,6 +352,26 @@ const ExecPlan::NodeVector& ExecPlan::sources() const {
 
 const ExecPlan::NodeVector& ExecPlan::sinks() const { return 
ToDerived(this)->sinks_; }
 
+size_t ExecPlan::GetThreadIndex() { return ToDerived(this)->GetThreadIndex(); }
+size_t ExecPlan::max_concurrency() const { return 
ToDerived(this)->max_concurrency(); }
+
+Status ExecPlan::AddFuture(Future<> fut) {
+  return ToDerived(this)->AddFuture(std::move(fut));
+}
+Status ExecPlan::ScheduleTask(std::function<Status()> fn) {
+  return ToDerived(this)->ScheduleTask(std::move(fn));
+}
+Status ExecPlan::ScheduleTask(std::function<Status(size_t)> fn) {
+  return ToDerived(this)->ScheduleTask(std::move(fn));
+}
+int ExecPlan::RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
+                                std::function<Status(size_t)> on_finished) {
+  return ToDerived(this)->RegisterTaskGroup(std::move(task), 
std::move(on_finished));
+}
+Status ExecPlan::StartTaskGroup(int task_group_id, int64_t num_tasks) {
+  return ToDerived(this)->StartTaskGroup(task_group_id, num_tasks);
+}
+
 Status ExecPlan::Validate() { return ToDerived(this)->Validate(); }
 
 Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); }
@@ -312,6 +401,8 @@ ExecNode::ExecNode(ExecPlan* plan, NodeVector inputs,
   }
 }
 
+Status ExecNode::Init() { return Status::OK(); }
+
 Status ExecNode::Validate() const {
   if (inputs_.size() != input_labels_.size()) {
     return Status::Invalid("Invalid number of inputs for '", label(), "' 
(expected ",
@@ -395,8 +486,6 @@ Status MapNode::StartProducing() {
   START_COMPUTE_SPAN(
       span_, std::string(kind_name()) + ":" + label(),
       {{"node.label", label()}, {"node.detail", ToString()}, {"node.kind", 
kind_name()}});
-  finished_ = Future<>::Make();
-  END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
   return Status::OK();
 }
 
@@ -424,8 +513,6 @@ void MapNode::StopProducing() {
   inputs_[0]->StopProducing(this);
 }
 
-Future<> MapNode::finished() { return finished_; }
-
 void MapNode::SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn,
                          ExecBatch batch) {
   Status status;
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h 
b/cpp/src/arrow/compute/exec/exec_plan.h
index dcf271bd36..c8599748de 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -61,6 +61,60 @@ class ARROW_EXPORT ExecPlan : public 
std::enable_shared_from_this<ExecPlan> {
     return out;
   }
 
+  /// \brief Returns the index of the current thread.
+  size_t GetThreadIndex();
+  /// \brief Returns the maximum number of threads that the plan could use.
+  ///
+  /// GetThreadIndex will always return something less than this, so it is 
safe to
+  /// e.g. make an array of thread-locals off this.
+  size_t max_concurrency() const;
+
+  /// \brief Add a future to the plan's task group.
+  ///
+  /// \param fut The future to add
+  ///
+  /// Use this when interfacing with anything that returns a future (such as 
IO), but
+  /// prefer ScheduleTask/StartTaskGroup inside of ExecNodes.
+  /// The below API interfaces with the scheduler to add tasks to the task 
group. Tasks
+  /// should be added sparingly! Prefer just doing the work immediately rather 
than adding
+  /// a task for it. Tasks are used in pipeline breakers that may output many 
more rows
+  /// than they received (such as a full outer join).
+  Status AddFuture(Future<> fut);
+
+  /// \brief Add a single function as a task to the plan's task group.
+  ///
+  /// \param fn The task to run. Takes no arguments and returns a Status.
+  Status ScheduleTask(std::function<Status()> fn);
+
+  /// \brief Add a single function as a task to the plan's task group.
+  ///
+  /// \param fn The task to run. Takes the thread index and returns a Status.
+  Status ScheduleTask(std::function<Status(size_t)> fn);
+  // Register/Start TaskGroup is a way of performing a "Parallel For" pattern:
+  // - The task function takes the thread index and the index of the task
+  // - The on_finished function takes the thread index
+  // Returns an integer ID that will be used to reference the task group in
+  // StartTaskGroup. At runtime, call StartTaskGroup with the ID and the 
number of times
+  // you'd like the task to be executed. The need to register a task group 
before use will
+  // be removed after we rewrite the scheduler.
+  /// \brief Register a "parallel for" task group with the scheduler
+  ///
+  /// \param task The function implementing the task. Takes the thread_index 
and
+  ///             the task index.
+  /// \param on_finished The function that gets run once all tasks have been 
completed.
+  /// Takes the thread_index.
+  ///
+  /// Must be called inside of ExecNode::Init.
+  int RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
+                        std::function<Status(size_t)> on_finished);
+
+  /// \brief Start the task group with the specified ID. This can only
+  ///        be called once per task_group_id.
+  ///
+  /// \param task_group_id The ID  of the task group to run
+  /// \param num_tasks The number of times to run the task
+  Status StartTaskGroup(int task_group_id, int64_t num_tasks);
+
   /// The initial inputs
   const NodeVector& sources() const;
 
@@ -157,6 +211,16 @@ class ARROW_EXPORT ExecNode {
   /// knows when it has received all input, regardless of order.
   virtual void InputFinished(ExecNode* input, int total_batches) = 0;
 
+  /// \brief Perform any needed initialization
+  ///
+  /// This hook performs any actions in between creation of ExecPlan and the 
call to
+  /// StartProducing. An example could be Bloom filter pushdown. The order of 
ExecNodes
+  /// that executes this method is undefined, but the calls are made 
synchronously.
+  ///
+  /// At this point a node can rely on all inputs & outputs (and the input 
schemas)
+  /// being well defined.
+  virtual Status Init();
+
   /// Lifecycle API:
   /// - start / stop to initiate and terminate production
   /// - pause / resume to apply backpressure
@@ -212,16 +276,6 @@ class ARROW_EXPORT ExecNode {
   // A node with multiple outputs will also need to ensure it is applying 
backpressure if
   // any of its outputs is asking to pause
 
-  /// \brief Perform any needed initialization
-  ///
-  /// This hook performs any actions in between creation of ExecPlan and the 
call to
-  /// StartProducing. An example could be Bloom filter pushdown. The order of 
ExecNodes
-  /// that executes this method is undefined, but the calls are made 
synchronously.
-  ///
-  /// At this point a node can rely on all inputs & outputs (and the input 
schemas)
-  /// being well defined.
-  virtual Status PrepareToProduce() { return Status::OK(); }
-
   /// \brief Start producing
   ///
   /// This must only be called once.  If this fails, then other lifecycle
@@ -263,7 +317,7 @@ class ARROW_EXPORT ExecNode {
   virtual void StopProducing() = 0;
 
   /// \brief A future which will be marked finished when this node has stopped 
producing.
-  virtual Future<> finished() = 0;
+  virtual Future<> finished() { return finished_; }
 
   std::string ToString(int indent = 0) const;
 
@@ -289,7 +343,7 @@ class ARROW_EXPORT ExecNode {
   NodeVector outputs_;
 
   // Future to sync finished
-  Future<> finished_ = Future<>::MakeFinished();
+  Future<> finished_ = Future<>::Make();
 
   util::tracing::Span span_;
 };
@@ -321,8 +375,6 @@ class ARROW_EXPORT MapNode : public ExecNode {
 
   void StopProducing() override;
 
-  Future<> finished() override;
-
  protected:
   void SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn, 
ExecBatch batch);
 
diff --git a/cpp/src/arrow/compute/exec/hash_join.cc 
b/cpp/src/arrow/compute/exec/hash_join.cc
index e821979ae1..07a3083fb9 100644
--- a/cpp/src/arrow/compute/exec/hash_join.cc
+++ b/cpp/src/arrow/compute/exec/hash_join.cc
@@ -44,8 +44,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
               const HashJoinProjectionMaps* proj_map_left,
               const HashJoinProjectionMaps* proj_map_right,
               std::vector<JoinKeyCmp> key_cmp, Expression filter,
+              RegisterTaskGroupCallback register_task_group_callback,
+              StartTaskGroupCallback start_task_group_callback,
               OutputBatchCallback output_batch_callback,
-              FinishedCallback finished_callback, TaskScheduler* scheduler) 
override {
+              FinishedCallback finished_callback) override {
     START_COMPUTE_SPAN(span_, "HashJoinBasicImpl",
                        {{"detail", filter.ToString()},
                         {"join.kind", arrow::compute::ToString(join_type)},
@@ -58,10 +60,12 @@ class HashJoinBasicImpl : public HashJoinImpl {
     schema_[1] = proj_map_right;
     key_cmp_ = std::move(key_cmp);
     filter_ = std::move(filter);
+    register_task_group_callback_ = std::move(register_task_group_callback);
+    start_task_group_callback_ = std::move(start_task_group_callback);
     output_batch_callback_ = std::move(output_batch_callback);
     finished_callback_ = std::move(finished_callback);
-    scheduler_ = scheduler;
     local_states_.resize(num_threads_);
+
     for (size_t i = 0; i < local_states_.size(); ++i) {
       local_states_[i].is_initialized = false;
       local_states_[i].is_has_match_initialized = false;
@@ -78,11 +82,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
     return Status::OK();
   }
 
-  void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) override 
{
+  void Abort(AbortContinuationImpl pos_abort_callback) override {
     EVENT(span_, "Abort");
     END_SPAN(span_);
     cancelled_ = true;
-    scheduler_->Abort(std::move(pos_abort_callback));
+    pos_abort_callback();
   }
 
   std::string ToString() const override { return "HashJoinBasicImpl"; }
@@ -546,7 +550,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
   }
 
   void RegisterBuildHashTable() {
-    task_group_build_ = scheduler_->RegisterTaskGroup(
+    task_group_build_ = register_task_group_callback_(
         [this](size_t thread_index, int64_t task_id) -> Status {
           return BuildHashTable_exec_task(thread_index, task_id);
         },
@@ -605,16 +609,16 @@ class HashJoinBasicImpl : public HashJoinImpl {
     return build_finished_callback_(thread_index);
   }
 
-  Status BuildHashTable(size_t thread_index, AccumulationQueue batches,
+  Status BuildHashTable(size_t /*thread_index*/, AccumulationQueue batches,
                         BuildFinishedCallback on_finished) override {
     build_finished_callback_ = std::move(on_finished);
     build_batches_ = std::move(batches);
-    return scheduler_->StartTaskGroup(thread_index, task_group_build_,
+    return start_task_group_callback_(task_group_build_,
                                       /*num_tasks=*/1);
   }
 
   void RegisterScanHashTable() {
-    task_group_scan_ = scheduler_->RegisterTaskGroup(
+    task_group_scan_ = register_task_group_callback_(
         [this](size_t thread_index, int64_t task_id) -> Status {
           return ScanHashTable_exec_task(thread_index, task_id);
         },
@@ -689,8 +693,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
 
   Status ScanHashTable(size_t thread_index) {
     MergeHasMatch();
-    return scheduler_->StartTaskGroup(thread_index, task_group_scan_,
-                                      ScanHashTable_num_tasks());
+    return start_task_group_callback_(task_group_scan_, 
ScanHashTable_num_tasks());
   }
 
   Status ProbingFinished(size_t thread_index) override {
@@ -741,12 +744,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
   const HashJoinProjectionMaps* schema_[2];
   std::vector<JoinKeyCmp> key_cmp_;
   Expression filter_;
-  TaskScheduler* scheduler_;
   int task_group_build_;
   int task_group_scan_;
 
   // Callbacks
   //
+  RegisterTaskGroupCallback register_task_group_callback_;
+  StartTaskGroupCallback start_task_group_callback_;
   OutputBatchCallback output_batch_callback_;
   BuildFinishedCallback build_finished_callback_;
   FinishedCallback finished_callback_;
diff --git a/cpp/src/arrow/compute/exec/hash_join.h 
b/cpp/src/arrow/compute/exec/hash_join.h
index 19add7d440..0c5e43467e 100644
--- a/cpp/src/arrow/compute/exec/hash_join.h
+++ b/cpp/src/arrow/compute/exec/hash_join.h
@@ -41,14 +41,20 @@ class HashJoinImpl {
   using OutputBatchCallback = std::function<void(int64_t, ExecBatch)>;
   using BuildFinishedCallback = std::function<Status(size_t)>;
   using FinishedCallback = std::function<void(int64_t)>;
+  using RegisterTaskGroupCallback = std::function<int(
+      std::function<Status(size_t, int64_t)>, std::function<Status(size_t)>)>;
+  using StartTaskGroupCallback = std::function<Status(int, int64_t)>;
+  using AbortContinuationImpl = std::function<void()>;
 
   virtual ~HashJoinImpl() = default;
   virtual Status Init(ExecContext* ctx, JoinType join_type, size_t num_threads,
                       const HashJoinProjectionMaps* proj_map_left,
                       const HashJoinProjectionMaps* proj_map_right,
                       std::vector<JoinKeyCmp> key_cmp, Expression filter,
+                      RegisterTaskGroupCallback register_task_group_callback,
+                      StartTaskGroupCallback start_task_group_callback,
                       OutputBatchCallback output_batch_callback,
-                      FinishedCallback finished_callback, TaskScheduler* 
scheduler) = 0;
+                      FinishedCallback finished_callback) = 0;
 
   virtual Status BuildHashTable(size_t thread_index, AccumulationQueue batches,
                                 BuildFinishedCallback on_finished) = 0;
diff --git a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc 
b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc
index 97badb8423..94201a849f 100644
--- a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc
@@ -19,6 +19,7 @@
 
 #include "arrow/api.h"
 #include "arrow/compute/exec/hash_join.h"
+#include "arrow/compute/exec/hash_join_node.h"
 #include "arrow/compute/exec/options.h"
 #include "arrow/compute/exec/test_util.h"
 #include "arrow/compute/exec/util.h"
@@ -56,8 +57,6 @@ struct BenchmarkSettings {
 class JoinBenchmark {
  public:
   explicit JoinBenchmark(BenchmarkSettings& settings) {
-    bool is_parallel = settings.num_threads != 1;
-
     SchemaBuilder l_schema_builder, r_schema_builder;
     std::vector<FieldRef> left_keys, right_keys;
     std::vector<JoinKeyCmp> key_cmp;
@@ -127,9 +126,8 @@ class JoinBenchmark {
 
     stats_.num_probe_rows = settings.num_probe_batches * settings.batch_size;
 
-    ctx_ = arrow::internal::make_unique<ExecContext>(
-        default_memory_pool(),
-        is_parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
+    ctx_ = arrow::internal::make_unique<ExecContext>(default_memory_pool(),
+                                                     
arrow::internal::GetCpuThreadPool());
 
     schema_mgr_ = arrow::internal::make_unique<HashJoinSchema>();
     Expression filter = literal(true);
@@ -151,10 +149,22 @@ class JoinBenchmark {
     };
 
     scheduler_ = TaskScheduler::Make();
+
+    auto register_task_group_callback = [&](std::function<Status(size_t, 
int64_t)> task,
+                                            std::function<Status(size_t)> 
cont) {
+      return scheduler_->RegisterTaskGroup(std::move(task), std::move(cont));
+    };
+
+    auto start_task_group_callback = [&](int task_group_id, int64_t num_tasks) 
{
+      return scheduler_->StartTaskGroup(omp_get_thread_num(), task_group_id, 
num_tasks);
+    };
+
     DCHECK_OK(join_->Init(
         ctx_.get(), settings.join_type, settings.num_threads,
         &(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), 
std::move(key_cmp),
-        std::move(filter), [](ExecBatch) {}, [](int64_t x) {}, 
scheduler_.get()));
+        std::move(filter), std::move(register_task_group_callback),
+        std::move(start_task_group_callback), [](int64_t, ExecBatch) {},
+        [](int64_t x) {}));
 
     task_group_probe_ = scheduler_->RegisterTaskGroup(
         [this](size_t thread_index, int64_t task_id) -> Status {
@@ -168,7 +178,8 @@ class JoinBenchmark {
 
     DCHECK_OK(scheduler_->StartScheduling(
         0 /*thread index*/, std::move(schedule_callback),
-        static_cast<int>(2 * settings.num_threads) /*concurrent tasks*/, 
!is_parallel));
+        static_cast<int>(2 * settings.num_threads) /*concurrent tasks*/,
+        settings.num_threads == 1));
   }
 
   void RunJoin() {
diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc 
b/cpp/src/arrow/compute/exec/hash_join_node.cc
index 73df78b46e..1785df8784 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -483,10 +483,15 @@ class HashJoinNode;
 // on every batch that has been queued so far as well as any new probe-side 
batch that
 // comes in.
 struct BloomFilterPushdownContext {
+  using RegisterTaskGroupCallback = std::function<int(
+      std::function<Status(size_t, int64_t)>, std::function<Status(size_t)>)>;
+  using StartTaskGroupCallback = std::function<Status(int, int64_t)>;
   using BuildFinishedCallback = std::function<Status(size_t, 
AccumulationQueue)>;
   using FiltersReceivedCallback = std::function<Status()>;
   using FilterFinishedCallback = std::function<Status(size_t, 
AccumulationQueue)>;
-  void Init(HashJoinNode* owner, size_t num_threads, TaskScheduler* scheduler,
+  void Init(HashJoinNode* owner, size_t num_threads,
+            RegisterTaskGroupCallback register_task_group_callback,
+            StartTaskGroupCallback start_task_group_callback,
             FiltersReceivedCallback on_bloom_filters_received, bool 
disable_bloom_filter,
             bool use_sync_execution);
 
@@ -531,7 +536,7 @@ struct BloomFilterPushdownContext {
     if (eval_.num_expected_bloom_filters_ == 0)
       return eval_.on_finished_(thread_index, std::move(eval_.batches_));
 
-    return scheduler_->StartTaskGroup(thread_index, eval_.task_id_,
+    return start_task_group_callback_(eval_.task_id_,
                                       
/*num_tasks=*/eval_.batches_.batch_count());
   }
 
@@ -621,10 +626,10 @@ struct BloomFilterPushdownContext {
     return &tld_[thread_index].stack;
   }
 
+  StartTaskGroupCallback start_task_group_callback_;
   bool disable_bloom_filter_;
   HashJoinSchema* schema_mgr_;
   ExecContext* ctx_;
-  TaskScheduler* scheduler_;
 
   struct ThreadLocalData {
     bool is_init = false;
@@ -834,7 +839,7 @@ class HashJoinNode : public ExecNode {
   Status OnFiltersReceived() {
     std::unique_lock<std::mutex> guard(probe_side_mutex_);
     bloom_filters_ready_ = true;
-    size_t thread_index = thread_indexer_();
+    size_t thread_index = plan_->GetThreadIndex();
     AccumulationQueue batches = std::move(probe_accumulator_);
     guard.unlock();
     return pushdown_context_.FilterBatches(
@@ -863,8 +868,8 @@ class HashJoinNode : public ExecNode {
       std::lock_guard<std::mutex> guard(probe_side_mutex_);
       queued_batches_to_probe_ = std::move(probe_accumulator_);
     }
-    return scheduler_->StartTaskGroup(thread_index, task_group_probe_,
-                                      queued_batches_to_probe_.batch_count());
+    return plan_->StartTaskGroup(task_group_probe_,
+                                 queued_batches_to_probe_.batch_count());
   }
 
   Status OnQueuedBatchesProbed(size_t thread_index) {
@@ -885,7 +890,7 @@ class HashJoinNode : public ExecNode {
       return;
     }
 
-    size_t thread_index = thread_indexer_();
+    size_t thread_index = plan_->GetThreadIndex();
     int side = (input == inputs_[0]) ? 0 : 1;
 
     EVENT(span_, "InputReceived", {{"batch.length", batch.length}, {"side", 
side}});
@@ -923,8 +928,7 @@ class HashJoinNode : public ExecNode {
 
   void InputFinished(ExecNode* input, int total_batches) override {
     ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != 
inputs_.end());
-
-    size_t thread_index = thread_indexer_();
+    size_t thread_index = plan_->GetThreadIndex();
     int side = (input == inputs_[0]) ? 0 : 1;
 
     EVENT(span_, "InputFinished", {{"side", side}, {"batches.length", 
total_batches}});
@@ -941,28 +945,42 @@ class HashJoinNode : public ExecNode {
     }
   }
 
-  Status PrepareToProduce() override {
+  Status Init() override {
+    RETURN_NOT_OK(ExecNode::Init());
     bool use_sync_execution = !(plan_->exec_context()->executor());
     // TODO(ARROW-15732)
     // Each side of join might have an IO thread being called from. Once this 
is fixed
     // we will change it back to just the CPU's thread pool capacity.
     size_t num_threads = (GetCpuThreadPoolCapacity() + 
io::GetIOThreadPoolCapacity() + 1);
 
-    scheduler_ = TaskScheduler::Make();
     pushdown_context_.Init(
-        this, num_threads, scheduler_.get(), [this]() { return 
OnFiltersReceived(); },
-        disable_bloom_filter_, use_sync_execution);
+        this, num_threads,
+        [this](std::function<Status(size_t, int64_t)> fn,
+               std::function<Status(size_t)> on_finished) {
+          return plan_->RegisterTaskGroup(std::move(fn), 
std::move(on_finished));
+        },
+        [this](int task_group_id, int64_t num_tasks) {
+          return plan_->StartTaskGroup(task_group_id, num_tasks);
+        },
+        [this]() { return OnFiltersReceived(); }, disable_bloom_filter_,
+        use_sync_execution);
 
     RETURN_NOT_OK(impl_->Init(
         plan_->exec_context(), join_type_, num_threads, 
&(schema_mgr_->proj_maps[0]),
         &(schema_mgr_->proj_maps[1]), key_cmp_, filter_,
-        [this](int64_t /*ignored*/, ExecBatch batch) {
-          this->OutputBatchCallback(batch);
+        [this](std::function<Status(size_t, int64_t)> fn,
+               std::function<Status(size_t)> on_finished) {
+          return plan_->RegisterTaskGroup(std::move(fn), 
std::move(on_finished));
+        },
+        [this](int task_group_id, int64_t num_tasks) {
+          return plan_->StartTaskGroup(task_group_id, num_tasks);
         },
-        [this](int64_t total_num_batches) { 
this->FinishedCallback(total_num_batches); },
-        scheduler_.get()));
+        [this](int64_t, ExecBatch batch) { this->OutputBatchCallback(batch); },
+        [this](int64_t total_num_batches) {
+          this->FinishedCallback(total_num_batches);
+        }));
 
-    task_group_probe_ = scheduler_->RegisterTaskGroup(
+    task_group_probe_ = plan_->RegisterTaskGroup(
         [this](size_t thread_index, int64_t task_id) -> Status {
           return impl_->ProbeSingleBatch(thread_index,
                                          
std::move(queued_batches_to_probe_[task_id]));
@@ -971,14 +989,6 @@ class HashJoinNode : public ExecNode {
           return OnQueuedBatchesProbed(thread_index);
         });
 
-    scheduler_->RegisterEnd();
-
-    RETURN_NOT_OK(scheduler_->StartScheduling(
-        0 /*thread index*/,
-        [this](std::function<Status(size_t)> func) -> Status {
-          return this->ScheduleTaskCallback(std::move(func));
-        },
-        static_cast<int>(2 * num_threads) /*concurrent tasks*/, 
use_sync_execution));
     return Status::OK();
   }
 
@@ -987,7 +997,7 @@ class HashJoinNode : public ExecNode {
                        {{"node.label", label()},
                         {"node.detail", ToString()},
                         {"node.kind", kind_name()}});
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished(), this);
+    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
     RETURN_NOT_OK(pushdown_context_.StartProducing());
     return Status::OK();
   }
@@ -1012,12 +1022,10 @@ class HashJoinNode : public ExecNode {
       for (auto&& input : inputs_) {
         input->StopProducing(this);
       }
-      impl_->Abort([this]() { ARROW_UNUSED(task_group_.End()); });
+      impl_->Abort([this]() { finished_.MarkFinished(); });
     }
   }
 
-  Future<> finished() override { return task_group_.OnFinished(); }
-
  protected:
   std::string ToStringExtra(int indent = 0) const override {
     return "implementation=" + impl_->ToString();
@@ -1032,42 +1040,18 @@ class HashJoinNode : public ExecNode {
     bool expected = false;
     if (complete_.compare_exchange_strong(expected, true)) {
       outputs_[0]->InputFinished(this, static_cast<int>(total_num_batches));
-      ARROW_UNUSED(task_group_.End());
+      finished_.MarkFinished();
     }
   }
 
-  Status ScheduleTaskCallback(std::function<Status(size_t)> func) {
-    auto executor = plan_->exec_context()->executor();
-    if (executor) {
-      return task_group_.AddTask([this, executor, func] {
-        return DeferNotOk(executor->Submit([this, func] {
-          size_t thread_index = thread_indexer_();
-          Status status = func(thread_index);
-          if (!status.ok()) {
-            StopProducing();
-            ErrorIfNotOk(status);
-            return;
-          }
-        }));
-      });
-    } else {
-      // We should not get here in serial execution mode
-      ARROW_DCHECK(false);
-    }
-    return Status::OK();
-  }
-
  private:
   AtomicCounter batch_count_[2];
   std::atomic<bool> complete_;
   JoinType join_type_;
   std::vector<JoinKeyCmp> key_cmp_;
   Expression filter_;
-  ThreadIndexer thread_indexer_;
   std::unique_ptr<HashJoinSchema> schema_mgr_;
   std::unique_ptr<HashJoinImpl> impl_;
-  util::AsyncTaskGroup task_group_;
-  std::unique_ptr<TaskScheduler> scheduler_;
   util::AccumulationQueue build_accumulator_;
   util::AccumulationQueue probe_accumulator_;
   util::AccumulationQueue queued_batches_to_probe_;
@@ -1087,14 +1071,14 @@ class HashJoinNode : public ExecNode {
   BloomFilterPushdownContext pushdown_context_;
 };
 
-void BloomFilterPushdownContext::Init(HashJoinNode* owner, size_t num_threads,
-                                      TaskScheduler* scheduler,
-                                      FiltersReceivedCallback 
on_bloom_filters_received,
-                                      bool disable_bloom_filter,
-                                      bool use_sync_execution) {
+void BloomFilterPushdownContext::Init(
+    HashJoinNode* owner, size_t num_threads,
+    RegisterTaskGroupCallback register_task_group_callback,
+    StartTaskGroupCallback start_task_group_callback,
+    FiltersReceivedCallback on_bloom_filters_received, bool 
disable_bloom_filter,
+    bool use_sync_execution) {
   schema_mgr_ = owner->schema_mgr_.get();
   ctx_ = owner->plan_->exec_context();
-  scheduler_ = scheduler;
   tld_.resize(num_threads);
   disable_bloom_filter_ = disable_bloom_filter;
   std::tie(push_.pushdown_target_, push_.column_map_) = 
GetPushdownTarget(owner);
@@ -1108,7 +1092,7 @@ void BloomFilterPushdownContext::Init(HashJoinNode* 
owner, size_t num_threads,
         use_sync_execution ? BloomFilterBuildStrategy::SINGLE_THREADED
                            : BloomFilterBuildStrategy::PARALLEL);
 
-    build_.task_id_ = scheduler_->RegisterTaskGroup(
+    build_.task_id_ = register_task_group_callback(
         [this](size_t thread_index, int64_t task_id) {
           return BuildBloomFilter_exec_task(thread_index, task_id);
         },
@@ -1117,13 +1101,14 @@ void BloomFilterPushdownContext::Init(HashJoinNode* 
owner, size_t num_threads,
         });
   }
 
-  eval_.task_id_ = scheduler_->RegisterTaskGroup(
+  eval_.task_id_ = register_task_group_callback(
       [this](size_t thread_index, int64_t task_id) {
         return FilterSingleBatch(thread_index, &eval_.batches_[task_id]);
       },
       [this](size_t thread_index) {
         return eval_.on_finished_(thread_index, std::move(eval_.batches_));
       });
+  start_task_group_callback_ = std::move(start_task_group_callback);
 }
 
 Status BloomFilterPushdownContext::StartProducing() {
@@ -1145,7 +1130,7 @@ Status 
BloomFilterPushdownContext::BuildBloomFilter(size_t thread_index,
       ctx_->memory_pool(), build_.batches_.row_count(), 
build_.batches_.batch_count(),
       push_.bloom_filter_.get()));
 
-  return scheduler_->StartTaskGroup(thread_index, build_.task_id_,
+  return start_task_group_callback_(build_.task_id_,
                                     
/*num_tasks=*/build_.batches_.batch_count());
 }
 
diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc 
b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
index 9a3c734278..b4fd7ee643 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
@@ -1747,6 +1747,9 @@ TEST(HashJoin, DictNegative) {
     EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
         NotImplemented, ::testing::HasSubstr("Unifying differing 
dictionaries"),
         StartAndCollect(plan.get(), sink_gen));
+    // Since we returned an error, the StartAndCollect future may return before
+    // the plan is done finishing.
+    plan->finished().Wait();
   }
 }
 
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc 
b/cpp/src/arrow/compute/exec/sink_node.cc
index eae12bf729..9118d5a50e 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -136,9 +136,7 @@ class SinkNode : public ExecNode {
                        {{"node.label", label()},
                         {"node.detail", ToString()},
                         {"node.kind", kind_name()}});
-    finished_ = Future<>::Make();
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
-
+    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
     return Status::OK();
   }
 
@@ -161,8 +159,6 @@ class SinkNode : public ExecNode {
     inputs_[0]->StopProducing(this);
   }
 
-  Future<> finished() override { return finished_; }
-
   void RecordBackpressureBytesUsed(const ExecBatch& batch) {
     if (backpressure_queue_.enabled()) {
       uint64_t bytes_used = static_cast<uint64_t>(batch.TotalBufferSize());
@@ -287,6 +283,7 @@ class ConsumingSinkNode : public ExecNode, public 
BackpressureControl {
                        {{"node.label", label()},
                         {"node.detail", ToString()},
                         {"node.kind", kind_name()}});
+    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
     DCHECK_GT(inputs_.size(), 0);
     auto output_schema = inputs_[0]->output_schema();
     if (names_.size() > 0) {
@@ -303,8 +300,6 @@ class ConsumingSinkNode : public ExecNode, public 
BackpressureControl {
       output_schema = schema(std::move(fields));
     }
     RETURN_NOT_OK(consumer_->Init(output_schema, this));
-    finished_ = Future<>::Make();
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
     return Status::OK();
   }
 
@@ -326,12 +321,10 @@ class ConsumingSinkNode : public ExecNode, public 
BackpressureControl {
 
   void StopProducing() override {
     EVENT(span_, "StopProducing");
-    Finish(Status::Invalid("ExecPlan was stopped early"));
+    Finish(Status::OK());
     inputs_[0]->StopProducing(this);
   }
 
-  Future<> finished() override { return finished_; }
-
   void InputReceived(ExecNode* input, ExecBatch batch) override {
     EVENT(span_, "InputReceived", {{"batch.length", batch.length}});
     util::tracing::Span span;
@@ -365,9 +358,7 @@ class ConsumingSinkNode : public ExecNode, public 
BackpressureControl {
     EVENT(span_, "ErrorReceived", {{"error", error.message()}});
     DCHECK_EQ(input, inputs_[0]);
 
-    if (input_counter_.Cancel()) {
-      Finish(std::move(error));
-    }
+    if (input_counter_.Cancel()) Finish(error);
 
     inputs_[0]->StopProducing(this);
   }
diff --git a/cpp/src/arrow/compute/exec/source_node.cc 
b/cpp/src/arrow/compute/exec/source_node.cc
index ec2b91050d..33072f0026 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -75,6 +75,7 @@ struct SourceNode : ExecNode {
                         {"node.label", label()},
                         {"node.output_schema", output_schema()->ToString()},
                         {"node.detail", ToString()}});
+    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
     {
       // If another exec node encountered an error during its StartProducing 
call
       // it might have already called StopProducing on all of its inputs 
(including this
@@ -96,66 +97,55 @@ struct SourceNode : ExecNode {
       options.executor = executor;
       options.should_schedule = ShouldSchedule::IfDifferentExecutor;
     }
-    finished_ = Loop([this, executor, options] {
-                  std::unique_lock<std::mutex> lock(mutex_);
-                  int total_batches = batch_count_++;
-                  if (stop_requested_) {
-                    return 
Future<ControlFlow<int>>::MakeFinished(Break(total_batches));
-                  }
-                  lock.unlock();
-
-                  return generator_().Then(
-                      [=](const util::optional<ExecBatch>& maybe_batch)
-                          -> Future<ControlFlow<int>> {
-                        std::unique_lock<std::mutex> lock(mutex_);
-                        if (IsIterationEnd(maybe_batch) || stop_requested_) {
-                          stop_requested_ = true;
-                          return Break(total_batches);
-                        }
-                        lock.unlock();
-                        ExecBatch batch = std::move(*maybe_batch);
-
-                        if (executor) {
-                          auto status = task_group_.AddTask(
-                              [this, executor, batch]() -> Result<Future<>> {
-                                return executor->Submit([=]() {
-                                  outputs_[0]->InputReceived(this, 
std::move(batch));
-                                  return Status::OK();
-                                });
-                              });
-                          if (!status.ok()) {
-                            outputs_[0]->ErrorReceived(this, 
std::move(status));
-                            return Break(total_batches);
-                          }
-                        } else {
-                          outputs_[0]->InputReceived(this, std::move(batch));
-                        }
-                        lock.lock();
-                        if (!backpressure_future_.is_finished()) {
-                          EVENT(span_, "Source paused due to backpressure");
-                          return backpressure_future_.Then(
-                              []() -> ControlFlow<int> { return Continue(); });
-                        }
-                        return 
Future<ControlFlow<int>>::MakeFinished(Continue());
-                      },
-                      [=](const Status& error) -> ControlFlow<int> {
-                        // NB: ErrorReceived is independent of InputFinished, 
but
-                        // ErrorReceived will usually prompt StopProducing 
which will
-                        // prompt InputFinished. ErrorReceived may still be 
called from a
-                        // node which was requested to stop (indeed, the 
request to stop
-                        // may prompt an error).
-                        std::unique_lock<std::mutex> lock(mutex_);
-                        stop_requested_ = true;
-                        lock.unlock();
-                        outputs_[0]->ErrorReceived(this, error);
-                        return Break(total_batches);
-                      },
-                      options);
-                }).Then([&](int total_batches) {
-      outputs_[0]->InputFinished(this, total_batches);
-      return task_group_.End();
-    });
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
+    started_ = true;
+    auto fut = Loop([this, options] {
+                 std::unique_lock<std::mutex> lock(mutex_);
+                 int total_batches = batch_count_++;
+                 if (stop_requested_) {
+                   return 
Future<ControlFlow<int>>::MakeFinished(Break(total_batches));
+                 }
+                 lock.unlock();
+
+                 return generator_().Then(
+                     [=](const util::optional<ExecBatch>& maybe_batch)
+                         -> Future<ControlFlow<int>> {
+                       std::unique_lock<std::mutex> lock(mutex_);
+                       if (IsIterationEnd(maybe_batch) || stop_requested_) {
+                         stop_requested_ = true;
+                         return Break(total_batches);
+                       }
+                       lock.unlock();
+                       ExecBatch batch = std::move(*maybe_batch);
+                       RETURN_NOT_OK(plan_->ScheduleTask([=]() {
+                         outputs_[0]->InputReceived(this, std::move(batch));
+                         return Status::OK();
+                       }));
+                       lock.lock();
+                       if (!backpressure_future_.is_finished()) {
+                         EVENT(span_, "Source paused due to backpressure");
+                         return backpressure_future_.Then(
+                             []() -> ControlFlow<int> { return Continue(); });
+                       }
+                       return 
Future<ControlFlow<int>>::MakeFinished(Continue());
+                     },
+                     [=](const Status& error) -> ControlFlow<int> {
+                       std::unique_lock<std::mutex> lock(mutex_);
+                       stop_requested_ = true;
+                       lock.unlock();
+                       outputs_[0]->ErrorReceived(this, error);
+                       finished_.MarkFinished(error);
+                       return Break(total_batches);
+                     },
+                     options);
+               })
+                   .Then(
+                       [=](int total_batches) {
+                         outputs_[0]->InputFinished(this, total_batches);
+                         if (!finished_.is_finished()) 
finished_.MarkFinished();
+                       },
+                       {}, options);
+    if (!executor && finished_.is_finished()) return finished_.status();
+    RETURN_NOT_OK(plan_->AddFuture(fut));
     return Status::OK();
   }
 
@@ -196,17 +186,16 @@ struct SourceNode : ExecNode {
   void StopProducing() override {
     std::unique_lock<std::mutex> lock(mutex_);
     stop_requested_ = true;
+    if (!started_) finished_.MarkFinished();
   }
 
-  Future<> finished() override { return finished_; }
-
  private:
   std::mutex mutex_;
   int32_t backpressure_counter_{0};
   Future<> backpressure_future_ = Future<>::MakeFinished();
   bool stop_requested_{false};
+  bool started_ = false;
   int batch_count_{0};
-  util::AsyncTaskGroup task_group_;
   AsyncGenerator<util::optional<ExecBatch>> generator_;
 };
 
diff --git a/cpp/src/arrow/compute/exec/swiss_join.cc 
b/cpp/src/arrow/compute/exec/swiss_join.cc
index 5d70e01b1e..5b01edb119 100644
--- a/cpp/src/arrow/compute/exec/swiss_join.cc
+++ b/cpp/src/arrow/compute/exec/swiss_join.cc
@@ -2026,8 +2026,10 @@ class SwissJoin : public HashJoinImpl {
               const HashJoinProjectionMaps* proj_map_left,
               const HashJoinProjectionMaps* proj_map_right,
               std::vector<JoinKeyCmp> key_cmp, Expression filter,
+              RegisterTaskGroupCallback register_task_group_callback,
+              StartTaskGroupCallback start_task_group_callback,
               OutputBatchCallback output_batch_callback,
-              FinishedCallback finished_callback, TaskScheduler* scheduler) 
override {
+              FinishedCallback finished_callback) override {
     START_COMPUTE_SPAN(span_, "SwissJoinImpl",
                        {{"detail", filter.ToString()},
                         {"join.kind", arrow::compute::ToString(join_type)},
@@ -2043,11 +2045,15 @@ class SwissJoin : public HashJoinImpl {
     for (size_t i = 0; i < key_cmp.size(); ++i) {
       key_cmp_[i] = key_cmp[i];
     }
+
     schema_[0] = proj_map_left;
     schema_[1] = proj_map_right;
-    output_batch_callback_ = output_batch_callback;
-    finished_callback_ = finished_callback;
-    scheduler_ = scheduler;
+
+    register_task_group_callback_ = std::move(register_task_group_callback);
+    start_task_group_callback_ = std::move(start_task_group_callback);
+    output_batch_callback_ = std::move(output_batch_callback);
+    finished_callback_ = std::move(finished_callback);
+
     hash_table_ready_.store(false);
     cancelled_.store(false);
     {
@@ -2081,17 +2087,17 @@ class SwissJoin : public HashJoinImpl {
   }
 
   void InitTaskGroups() {
-    task_group_build_ = scheduler_->RegisterTaskGroup(
+    task_group_build_ = register_task_group_callback_(
         [this](size_t thread_index, int64_t task_id) -> Status {
           return BuildTask(thread_index, task_id);
         },
         [this](size_t thread_index) -> Status { return 
BuildFinished(thread_index); });
-    task_group_merge_ = scheduler_->RegisterTaskGroup(
+    task_group_merge_ = register_task_group_callback_(
         [this](size_t thread_index, int64_t task_id) -> Status {
           return MergeTask(thread_index, task_id);
         },
         [this](size_t thread_index) -> Status { return 
MergeFinished(thread_index); });
-    task_group_scan_ = scheduler_->RegisterTaskGroup(
+    task_group_scan_ = register_task_group_callback_(
         [this](size_t thread_index, int64_t task_id) -> Status {
           return ScanTask(thread_index, task_id);
         },
@@ -2136,11 +2142,11 @@ class SwissJoin : public HashJoinImpl {
     return CancelIfNotOK(StartBuildHashTable(static_cast<int64_t>(thread_id)));
   }
 
-  void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) override 
{
+  void Abort(AbortContinuationImpl pos_abort_callback) override {
     EVENT(span_, "Abort");
     END_SPAN(span_);
     std::ignore = CancelIfNotOK(Status::Cancelled("Hash Join Cancelled"));
-    scheduler_->Abort(std::move(pos_abort_callback));
+    pos_abort_callback();
   }
 
   std::string ToString() const override { return "SwissJoin"; }
@@ -2176,9 +2182,8 @@ class SwissJoin : public HashJoinImpl {
 
     // Process all input batches
     //
-    return 
CancelIfNotOK(scheduler_->StartTaskGroup(static_cast<size_t>(thread_id),
-                                                    task_group_build_,
-                                                    
build_side_batches_.batch_count()));
+    return CancelIfNotOK(
+        start_task_group_callback_(task_group_build_, 
build_side_batches_.batch_count()));
   }
 
   Status BuildTask(size_t thread_id, int64_t batch_id) {
@@ -2240,8 +2245,8 @@ class SwissJoin : public HashJoinImpl {
     // table.
     //
     RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.PreparePrtnMerge()));
-    return CancelIfNotOK(scheduler_->StartTaskGroup(thread_id, 
task_group_merge_,
-                                                    
hash_table_build_.num_prtns()));
+    return CancelIfNotOK(
+        start_task_group_callback_(task_group_merge_, 
hash_table_build_.num_prtns()));
   }
 
   Status MergeTask(size_t /*thread_id*/, int64_t prtn_id) {
@@ -2286,8 +2291,7 @@ class SwissJoin : public HashJoinImpl {
       hash_table_.MergeHasMatch();
       int64_t num_tasks = bit_util::CeilDiv(hash_table_.num_rows(), 
kNumRowsPerScanTask);
 
-      return 
CancelIfNotOK(scheduler_->StartTaskGroup(static_cast<size_t>(thread_id),
-                                                      task_group_scan_, 
num_tasks));
+      return CancelIfNotOK(start_task_group_callback_(task_group_scan_, 
num_tasks));
     } else {
       return CancelIfNotOK(OnScanHashTableFinished());
     }
@@ -2472,12 +2476,13 @@ class SwissJoin : public HashJoinImpl {
   const HashJoinProjectionMaps* schema_[2];
 
   // Task scheduling
-  TaskScheduler* scheduler_;
   int task_group_build_;
   int task_group_merge_;
   int task_group_scan_;
 
   // Callbacks
+  RegisterTaskGroupCallback register_task_group_callback_;
+  StartTaskGroupCallback start_task_group_callback_;
   OutputBatchCallback output_batch_callback_;
   BuildFinishedCallback build_finished_callback_;
   FinishedCallback finished_callback_;
diff --git a/cpp/src/arrow/compute/exec/test_util.cc 
b/cpp/src/arrow/compute/exec/test_util.cc
index 330ee47112..b3e5d85b53 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -67,6 +67,7 @@ struct DummyNode : ExecNode {
     for (size_t i = 0; i < input_labels_.size(); ++i) {
       input_labels_[i] = std::to_string(i);
     }
+    finished_.MarkFinished();
   }
 
   const char* kind_name() const override { return "Dummy"; }
@@ -111,8 +112,6 @@ struct DummyNode : ExecNode {
     }
   }
 
-  Future<> finished() override { return Future<>::MakeFinished(); }
-
  private:
   void AssertIsOutput(ExecNode* output) {
     auto it = std::find(outputs_.begin(), outputs_.end(), output);
diff --git a/cpp/src/arrow/compute/exec/tpch_node.cc 
b/cpp/src/arrow/compute/exec/tpch_node.cc
index d8f2c60312..d19f20eea7 100644
--- a/cpp/src/arrow/compute/exec/tpch_node.cc
+++ b/cpp/src/arrow/compute/exec/tpch_node.cc
@@ -663,8 +663,7 @@ class PartAndPartSupplierGenerator {
     return SetOutputColumns(cols, kPartsuppTypes, kPartsuppNameMap, 
partsupp_cols_);
   }
 
-  Result<util::optional<ExecBatch>> NextPartBatch() {
-    size_t thread_index = thread_indexer_();
+  Result<util::optional<ExecBatch>> NextPartBatch(size_t thread_index) {
     ThreadLocalData& tld = thread_local_data_[thread_index];
     {
       std::lock_guard<std::mutex> lock(part_output_queue_mutex_);
@@ -719,8 +718,7 @@ class PartAndPartSupplierGenerator {
     return ExecBatch::Make(std::move(part_result));
   }
 
-  Result<util::optional<ExecBatch>> NextPartSuppBatch() {
-    size_t thread_index = thread_indexer_();
+  Result<util::optional<ExecBatch>> NextPartSuppBatch(size_t thread_index) {
     ThreadLocalData& tld = thread_local_data_[thread_index];
     {
       std::lock_guard<std::mutex> lock(partsupp_output_queue_mutex_);
@@ -1284,7 +1282,6 @@ class PartAndPartSupplierGenerator {
   int64_t part_rows_generated_{0};
   std::vector<int> part_cols_;
   std::vector<int> partsupp_cols_;
-  ThreadIndexer thread_indexer_;
 
   std::atomic<int64_t> part_batches_generated_ = {0};
   std::atomic<int64_t> partsupp_batches_generated_ = {0};
@@ -1326,8 +1323,7 @@ class OrdersAndLineItemGenerator {
     return SetOutputColumns(cols, kLineitemTypes, kLineitemNameMap, 
lineitem_cols_);
   }
 
-  Result<util::optional<ExecBatch>> NextOrdersBatch() {
-    size_t thread_index = thread_indexer_();
+  Result<util::optional<ExecBatch>> NextOrdersBatch(size_t thread_index) {
     ThreadLocalData& tld = thread_local_data_[thread_index];
     {
       std::lock_guard<std::mutex> lock(orders_output_queue_mutex_);
@@ -1382,8 +1378,7 @@ class OrdersAndLineItemGenerator {
     return ExecBatch::Make(std::move(orders_result));
   }
 
-  Result<util::optional<ExecBatch>> NextLineItemBatch() {
-    size_t thread_index = thread_indexer_();
+  Result<util::optional<ExecBatch>> NextLineItemBatch(size_t thread_index) {
     ThreadLocalData& tld = thread_local_data_[thread_index];
     ExecBatch queued;
     bool from_queue = false;
@@ -2397,7 +2392,6 @@ class OrdersAndLineItemGenerator {
   int64_t orders_rows_generated_;
   std::vector<int> orders_cols_;
   std::vector<int> lineitem_cols_;
-  ThreadIndexer thread_indexer_;
 
   std::atomic<size_t> orders_batches_generated_ = {0};
   std::atomic<size_t> lineitem_batches_generated_ = {0};
@@ -2712,9 +2706,10 @@ class PartGenerator : public TpchTableGenerator {
   std::shared_ptr<Schema> schema() const override { return schema_; }
 
  private:
-  Status ProduceCallback(size_t) {
+  Status ProduceCallback(size_t thread_index) {
     if (done_.load()) return Status::OK();
-    ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch, 
gen_->NextPartBatch());
+    ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
+                          gen_->NextPartBatch(thread_index));
     if (!maybe_batch.has_value()) {
       int64_t batches_generated = gen_->part_batches_generated();
       if (batches_generated == batches_outputted_.load()) {
@@ -2773,10 +2768,10 @@ class PartSuppGenerator : public TpchTableGenerator {
   std::shared_ptr<Schema> schema() const override { return schema_; }
 
  private:
-  Status ProduceCallback(size_t) {
+  Status ProduceCallback(size_t thread_index) {
     if (done_.load()) return Status::OK();
     ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
-                          gen_->NextPartSuppBatch());
+                          gen_->NextPartSuppBatch(thread_index));
     if (!maybe_batch.has_value()) {
       int64_t batches_generated = gen_->partsupp_batches_generated();
       if (batches_generated == batches_outputted_.load()) {
@@ -3092,9 +3087,10 @@ class OrdersGenerator : public TpchTableGenerator {
   std::shared_ptr<Schema> schema() const override { return schema_; }
 
  private:
-  Status ProduceCallback(size_t) {
+  Status ProduceCallback(size_t thread_index) {
     if (done_.load()) return Status::OK();
-    ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch, gen_->NextOrdersBatch());
+    ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
+                          gen_->NextOrdersBatch(thread_index));
     if (!maybe_batch.has_value()) {
       int64_t batches_generated = gen_->orders_batches_generated();
       if (batches_generated == batches_outputted_.load()) {
@@ -3153,10 +3149,10 @@ class LineitemGenerator : public TpchTableGenerator {
   std::shared_ptr<Schema> schema() const override { return schema_; }
 
  private:
-  Status ProduceCallback(size_t) {
+  Status ProduceCallback(size_t thread_index) {
     if (done_.load()) return Status::OK();
     ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
-                          gen_->NextLineItemBatch());
+                          gen_->NextLineItemBatch(thread_index));
     if (!maybe_batch.has_value()) {
       int64_t batches_generated = gen_->lineitem_batches_generated();
       if (batches_generated == batches_outputted_.load()) {
@@ -3379,7 +3375,7 @@ class TpchNode : public ExecNode {
 
   Status StartProducing() override {
     return generator_->StartProducing(
-        thread_indexer_.Capacity(),
+        plan_->max_concurrency(),
        [this](ExecBatch batch) { this->OutputBatchCallback(std::move(batch)); },
         [this](int64_t num_batches) { this->FinishedCallback(num_batches); },
         [this](std::function<Status(size_t)> func) -> Status {
@@ -3400,10 +3396,10 @@ class TpchNode : public ExecNode {
   }
 
   void StopProducing() override {
-    if (generator_->Abort()) std::ignore = task_group_.End();
+    if (generator_->Abort()) finished_.MarkFinished();
   }
 
-  Future<> finished() override { return task_group_.OnFinished(); }
+  Future<> finished() override { return finished_; }
 
  private:
   void OutputBatchCallback(ExecBatch batch) {
@@ -3412,42 +3408,23 @@ class TpchNode : public ExecNode {
 
   void FinishedCallback(int64_t total_num_batches) {
     outputs_[0]->InputFinished(this, static_cast<int>(total_num_batches));
-    std::ignore = task_group_.End();
+    finished_.MarkFinished();
   }
 
   Status ScheduleTaskCallback(std::function<Status(size_t)> func) {
-    auto executor = plan_->exec_context()->executor();
-
-    // Due to the way that the generators schedule tasks, there may be more tasks
-    // than output batches. After outputting the last batch, the generator will
-    // end the task group, but there may still be other threads that try to schedule
-    // tasks while the task group is being ended. This can result in adding tasks after
-    // the task group is ended. If those tasks were to be executed, correctness would
-    // not be affected as they'd see the generator is done and exit immediately. As such,
-    // if the task group is ended we can just skip scheduling these tasks in general.
-    if (executor) {
-      RETURN_NOT_OK(task_group_.AddTaskIfNotEnded([&] {
-        return executor->Submit([this, func] {
-          size_t thread_index = thread_indexer_();
-          Status status = func(thread_index);
-          if (!status.ok()) {
-            StopProducing();
-            ErrorIfNotOk(status);
-            return;
-          }
-        });
-      }));
-    } else {
-      return func(0);
-    }
-    return Status::OK();
+    if (finished_.is_finished()) return Status::OK();
+    return plan_->ScheduleTask([this, func](size_t thread_index) {
+      Status status = func(thread_index);
+      if (!status.ok()) {
+        StopProducing();
+        ErrorIfNotOk(status);
+      }
+      return status;
+    });
   }
 
   const char* name_;
   std::unique_ptr<TpchTableGenerator> generator_;
-
-  util::AsyncTaskGroup task_group_;
-  ThreadIndexer thread_indexer_;
 };
 
 class TpchGenImpl : public TpchGen {
diff --git a/cpp/src/arrow/compute/exec/union_node.cc b/cpp/src/arrow/compute/exec/union_node.cc
index 22df39aac6..e5170c2bc9 100644
--- a/cpp/src/arrow/compute/exec/union_node.cc
+++ b/cpp/src/arrow/compute/exec/union_node.cc
@@ -116,8 +116,7 @@ class UnionNode : public ExecNode {
                        {{"node.label", label()},
                         {"node.detail", ToString()},
                         {"node.kind", kind_name()}});
-    finished_ = Future<>::Make();
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
+    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
     return Status::OK();
   }
 
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 5316f63d08..f5db16f694 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -662,11 +662,18 @@ TEST_P(TestScanner, ScanBatchesFailure) {
          [](const EnumeratedRecordBatch& batch) { return batch.record_batch.value; }))
           << "ScanBatchesUnordered() did not raise an error";
     }
-    ASSERT_OK_AND_ASSIGN(auto tagged_batch_it, scanner->ScanBatches());
-    EXPECT_TRUE(CheckIteratorRaises(
-        batch, std::move(tagged_batch_it),
-        [](const TaggedRecordBatch& batch) { return batch.record_batch; }))
-        << "ScanBatches() did not raise an error";
+
+    auto maybe_tagged_batch_it = scanner->ScanBatches();
+    if (!maybe_tagged_batch_it.ok()) {
+      EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we failed!"),
+                                      std::move(maybe_tagged_batch_it));
+    } else {
+      ASSERT_OK_AND_ASSIGN(auto tagged_batch_it, std::move(maybe_tagged_batch_it));
+      EXPECT_TRUE(CheckIteratorRaises(
+          batch, std::move(tagged_batch_it),
+          [](const TaggedRecordBatch& batch) { return batch.record_batch; }))
+          << "ScanBatches() did not raise an error";
+    }
   };
 
   // Case 1: failure when getting next scan task
@@ -748,10 +755,16 @@ TEST_P(TestScanner, FromReader) {
   AssertScannerEquals(target_reader.get(), scanner.get());
 
  // Such datasets can only be scanned once (but you can get fragments multiple times)
-  ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
-  EXPECT_RAISES_WITH_MESSAGE_THAT(
-      Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
-      batch_it.Next());
+  auto maybe_batch_it = scanner->ScanBatches();
+  if (maybe_batch_it.ok()) {
+    EXPECT_RAISES_WITH_MESSAGE_THAT(
+        Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+        (*maybe_batch_it).Next());
+  } else {
+    EXPECT_RAISES_WITH_MESSAGE_THAT(
+        Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+        std::move(maybe_batch_it));
+  }
 }
 
 INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner,
diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc
index 27b61f0b34..36240d4682 100644
--- a/cpp/src/arrow/engine/substrait/util.cc
+++ b/cpp/src/arrow/engine/substrait/util.cc
@@ -66,7 +66,7 @@ class SubstraitExecutor {
  public:
   explicit SubstraitExecutor(std::shared_ptr<compute::ExecPlan> plan,
                              compute::ExecContext exec_context)
-      : plan_(std::move(plan)), exec_context_(exec_context) {}
+      : plan_(std::move(plan)), plan_started_(false), exec_context_(exec_context) {}
 
   ~SubstraitExecutor() { ARROW_UNUSED(this->Close()); }
 
@@ -75,6 +75,7 @@ class SubstraitExecutor {
       RETURN_NOT_OK(decl.AddToPlan(plan_.get()).status());
     }
     RETURN_NOT_OK(plan_->Validate());
+    plan_started_ = true;
     RETURN_NOT_OK(plan_->StartProducing());
     auto schema = sink_consumer_->schema();
    std::shared_ptr<RecordBatchReader> sink_reader = compute::MakeGeneratorReader(
@@ -82,7 +83,10 @@ class SubstraitExecutor {
     return sink_reader;
   }
 
-  Status Close() { return plan_->finished().status(); }
+  Status Close() {
+    if (plan_started_) return plan_->finished().status();
+    return Status::OK();
+  }
 
  Status Init(const Buffer& substrait_buffer, const ExtensionIdRegistry* registry) {
     if (substrait_buffer.size() == 0) {
@@ -102,6 +106,7 @@ class SubstraitExecutor {
   arrow::PushGenerator<util::optional<compute::ExecBatch>> generator_;
   std::vector<compute::Declaration> declarations_;
   std::shared_ptr<compute::ExecPlan> plan_;
+  bool plan_started_;
   compute::ExecContext exec_context_;
   std::shared_ptr<SubstraitSinkConsumer> sink_consumer_;
 };
diff --git a/cpp/src/arrow/util/future.cc b/cpp/src/arrow/util/future.cc
index ca4290c5b0..ab59234dea 100644
--- a/cpp/src/arrow/util/future.cc
+++ b/cpp/src/arrow/util/future.cc
@@ -314,22 +314,34 @@ class ConcreteFutureImpl : public FutureImpl {
   }
 
   void DoMarkFinishedOrFailed(FutureState state) {
+    std::vector<CallbackRecord> callbacks;
+    std::shared_ptr<FutureImpl> self;
     {
       // Lock the hypothetical waiter first, and the future after.
       // This matches the locking order done in FutureWaiter constructor.
       std::unique_lock<std::mutex> waiter_lock(global_waiter_mutex);
       std::unique_lock<std::mutex> lock(mutex_);
+#ifdef ARROW_WITH_OPENTELEMETRY
+      if (this->span_) {
+        util::tracing::Span& span = *span_;
+        END_SPAN(span);
+      }
+#endif
 
       DCHECK(!IsFutureFinished(state_)) << "Future already marked finished";
+      if (!callbacks_.empty()) {
+        callbacks = std::move(callbacks_);
+        auto self_inner = shared_from_this();
+        self = std::move(self_inner);
+      }
+
       state_ = state;
       if (waiter_ != nullptr) {
         waiter_->MarkFutureFinishedUnlocked(waiter_arg_, state);
       }
     }
     cv_.notify_all();
-
-    auto callbacks = std::move(callbacks_);
-    auto self = shared_from_this();
+    if (callbacks.empty()) return;
 
     // run callbacks, lock not needed since the future is finished by this
     // point so nothing else can modify the callbacks list and it is safe
diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h
index b374c77c81..2ac26b7f20 100644
--- a/cpp/src/arrow/util/future.h
+++ b/cpp/src/arrow/util/future.h
@@ -29,9 +29,11 @@
 #include "arrow/status.h"
 #include "arrow/type_fwd.h"
 #include "arrow/type_traits.h"
+#include "arrow/util/config.h"
 #include "arrow/util/functional.h"
 #include "arrow/util/macros.h"
 #include "arrow/util/optional.h"
+#include "arrow/util/tracing.h"
 #include "arrow/util/type_fwd.h"
 #include "arrow/util/visibility.h"
 
@@ -263,6 +265,10 @@ class ARROW_EXPORT FutureImpl : public std::enable_shared_from_this<FutureImpl>
   static std::unique_ptr<FutureImpl> Make();
   static std::unique_ptr<FutureImpl> MakeFinished(FutureState state);
 
+#ifdef ARROW_WITH_OPENTELEMETRY
+  void SetSpan(util::tracing::Span* span) { span_ = span; }
+#endif
+
   // Future API
   void MarkFinished();
   void MarkFailed();
@@ -294,6 +300,9 @@ class ARROW_EXPORT FutureImpl : public std::enable_shared_from_this<FutureImpl>
     CallbackOptions options;
   };
   std::vector<CallbackRecord> callbacks_;
+#ifdef ARROW_WITH_OPENTELEMETRY
+  util::tracing::Span* span_ = NULLPTR;
+#endif
 };
 
 // An object that waits on multiple futures at once.  Only one waiter
@@ -378,6 +387,10 @@ class ARROW_MUST_USE_TYPE Future {
   // of being able to presize a vector of Futures.
   Future() = default;
 
+#ifdef ARROW_WITH_OPENTELEMETRY
+  void SetSpan(util::tracing::Span* span) { impl_->SetSpan(span); }
+#endif
+
   // Consumer API
 
   bool is_valid() const { return impl_ != NULLPTR; }
diff --git a/cpp/src/arrow/util/tracing_internal.h b/cpp/src/arrow/util/tracing_internal.h
index 2898fd245f..d1da05671a 100644
--- a/cpp/src/arrow/util/tracing_internal.h
+++ b/cpp/src/arrow/util/tracing_internal.h
@@ -164,17 +164,8 @@ opentelemetry::trace::StartSpanOptions SpanOptionsWithParent(
 #define END_SPAN(target_span) \
   ::arrow::internal::tracing::UnwrapSpan(target_span.details.get())->End()
 
-#define END_SPAN_ON_FUTURE_COMPLETION(target_span, target_future, target_capture) \
-  target_future = target_future.Then(                                             \
-      [target_capture]() {                                                        \
-        MARK_SPAN(target_span, Status::OK());                                     \
-        END_SPAN(target_span);                                                    \
-      },                                                                          \
-      [target_capture](const Status& st) {                                        \
-        MARK_SPAN(target_span, st);                                               \
-        END_SPAN(target_span);                                                    \
-        return st;                                                                \
-      })
+#define END_SPAN_ON_FUTURE_COMPLETION(target_span, target_future) \
+  target_future.SetSpan(&target_span)
 
 #define PROPAGATE_SPAN_TO_GENERATOR(generator)                                \
   generator = ::arrow::internal::tracing::PropagateSpanThroughAsyncGenerator( \
@@ -207,7 +198,7 @@ class SpanImpl {};
 #define MARK_SPAN(target_span, status)
 #define EVENT(target_span, ...)
 #define END_SPAN(target_span)
#define END_SPAN_ON_FUTURE_COMPLETION(target_span, target_future, target_capture)
+#define END_SPAN_ON_FUTURE_COMPLETION(target_span, target_future)
 #define PROPAGATE_SPAN_TO_GENERATOR(generator)
 #define WRAP_ASYNC_GENERATOR(generator)
 #define WRAP_ASYNC_GENERATOR_WITH_CHILD_SPAN(generator, name)
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index b35d8f3178..277c6866f1 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -2928,9 +2928,9 @@ def test_incompatible_schema_hang(tempdir, dataset_reader):
     dataset = ds.dataset([str(fn)] * 100, schema=schema)
     assert dataset.schema.equals(schema)
     scanner = dataset_reader.scanner(dataset)
-    reader = scanner.to_reader()
     with pytest.raises(NotImplementedError,
                        match='Unsupported cast from int64 to null'):
+        reader = scanner.to_reader()
         reader.read_all()
 
 

Reply via email to