This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new a2453bd50f GH-41190: [C++] support for single threaded joins (#41125)
a2453bd50f is described below

commit a2453bd50fa7bf90b92850d4fddc779289499e96
Author: Joe Marshall <joe.marsh...@nottingham.ac.uk>
AuthorDate: Wed May 29 20:51:39 2024 +0100

    GH-41190: [C++] support for single threaded joins (#41125)
    
    When I initially added single-threading support, I didn't cover asof joins or 
sorted merge joins, because the code for these operations uses threads 
internally. This is a small check-in to add support for them. Tests run okay in 
single-threaded mode; I'm pushing it here to run the full tests and check that I 
didn't break the threaded case.
    
    I'm pushing this now because making this work saves adding a load of 
threading checks in Python (this currently breaks single-threaded Python, i.e. 
Emscripten).
    * GitHub Issue: #41190
    
    Lead-authored-by: Joe Marshall <joe.marsh...@nottingham.ac.uk>
    Co-authored-by: Rossi Sun <zanmato1...@gmail.com>
    Signed-off-by: Weston Pace <weston.p...@gmail.com>
---
 cpp/src/arrow/acero/CMakeLists.txt       | 21 ++------
 cpp/src/arrow/acero/asof_join_node.cc    | 85 +++++++++++++++++++++++++++++---
 cpp/src/arrow/acero/sorted_merge_node.cc | 52 +++++++++++++++----
 3 files changed, 125 insertions(+), 33 deletions(-)

diff --git a/cpp/src/arrow/acero/CMakeLists.txt 
b/cpp/src/arrow/acero/CMakeLists.txt
index 31ed4a6a69..73079059f1 100644
--- a/cpp/src/arrow/acero/CMakeLists.txt
+++ b/cpp/src/arrow/acero/CMakeLists.txt
@@ -173,13 +173,8 @@ add_arrow_acero_test(hash_join_node_test SOURCES 
hash_join_node_test.cc
                      bloom_filter_test.cc)
 add_arrow_acero_test(pivot_longer_node_test SOURCES pivot_longer_node_test.cc)
 
-# asof_join_node and sorted_merge_node use std::thread internally
-# and doesn't use ThreadPool so it will
-# be broken if threading is turned off
-if(ARROW_ENABLE_THREADING)
-  add_arrow_acero_test(asof_join_node_test SOURCES asof_join_node_test.cc)
-  add_arrow_acero_test(sorted_merge_node_test SOURCES 
sorted_merge_node_test.cc)
-endif()
+add_arrow_acero_test(asof_join_node_test SOURCES asof_join_node_test.cc)
+add_arrow_acero_test(sorted_merge_node_test SOURCES sorted_merge_node_test.cc)
 
 add_arrow_acero_test(tpch_node_test SOURCES tpch_node_test.cc)
 add_arrow_acero_test(union_node_test SOURCES union_node_test.cc)
@@ -228,9 +223,7 @@ if(ARROW_BUILD_BENCHMARKS)
   add_arrow_acero_benchmark(project_benchmark SOURCES benchmark_util.cc
                             project_benchmark.cc)
 
-  if(ARROW_ENABLE_THREADING)
-    add_arrow_acero_benchmark(asof_join_benchmark SOURCES 
asof_join_benchmark.cc)
-  endif()
+  add_arrow_acero_benchmark(asof_join_benchmark SOURCES asof_join_benchmark.cc)
 
   add_arrow_acero_benchmark(tpch_benchmark SOURCES tpch_benchmark.cc)
 
@@ -253,9 +246,7 @@ if(ARROW_BUILD_BENCHMARKS)
     target_link_libraries(arrow-acero-expression-benchmark PUBLIC 
arrow_acero_static)
     target_link_libraries(arrow-acero-filter-benchmark PUBLIC 
arrow_acero_static)
     target_link_libraries(arrow-acero-project-benchmark PUBLIC 
arrow_acero_static)
-    if(ARROW_ENABLE_THREADING)
-      target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC 
arrow_acero_static)
-    endif()
+    target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC 
arrow_acero_static)
     target_link_libraries(arrow-acero-tpch-benchmark PUBLIC arrow_acero_static)
     if(ARROW_BUILD_OPENMP_BENCHMARKS)
       target_link_libraries(arrow-acero-hash-join-benchmark PUBLIC 
arrow_acero_static)
@@ -264,9 +255,7 @@ if(ARROW_BUILD_BENCHMARKS)
     target_link_libraries(arrow-acero-expression-benchmark PUBLIC 
arrow_acero_shared)
     target_link_libraries(arrow-acero-filter-benchmark PUBLIC 
arrow_acero_shared)
     target_link_libraries(arrow-acero-project-benchmark PUBLIC 
arrow_acero_shared)
-    if(ARROW_ENABLE_THREADING)
-      target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC 
arrow_acero_shared)
-    endif()
+    target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC 
arrow_acero_shared)
     target_link_libraries(arrow-acero-tpch-benchmark PUBLIC arrow_acero_shared)
     if(ARROW_BUILD_OPENMP_BENCHMARKS)
       target_link_libraries(arrow-acero-hash-join-benchmark PUBLIC 
arrow_acero_shared)
diff --git a/cpp/src/arrow/acero/asof_join_node.cc 
b/cpp/src/arrow/acero/asof_join_node.cc
index 1d94467df9..848cbdf750 100644
--- a/cpp/src/arrow/acero/asof_join_node.cc
+++ b/cpp/src/arrow/acero/asof_join_node.cc
@@ -1014,6 +1014,8 @@ class AsofJoinNode : public ExecNode {
     }
   }
 
+#ifdef ARROW_ENABLE_THREADING
+
   template <typename Callable>
   struct Defer {
     Callable callable;
@@ -1100,6 +1102,7 @@ class AsofJoinNode : public ExecNode {
   }
 
   static void ProcessThreadWrapper(AsofJoinNode* node) { 
node->ProcessThread(); }
+#endif
 
  public:
   AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector<std::string> 
input_labels,
@@ -1131,8 +1134,10 @@ class AsofJoinNode : public ExecNode {
   }
 
   virtual ~AsofJoinNode() {
-    process_.Push(false);  // poison pill
+#ifdef ARROW_ENABLE_THREADING
+    PushProcess(false);
     process_thread_.join();
+#endif
   }
 
   const std::vector<col_index_t>& indices_of_on_key() { return 
indices_of_on_key_; }
@@ -1410,7 +1415,8 @@ class AsofJoinNode : public ExecNode {
                rb->ToString(), DEBUG_MANIP(std::endl));
 
     ARROW_RETURN_NOT_OK(state_.at(k)->Push(rb));
-    process_.Push(true);
+    PushProcess(true);
+
     return Status::OK();
   }
 
@@ -1425,22 +1431,77 @@ class AsofJoinNode : public ExecNode {
     // The reason for this is that there are cases at the end of a table where 
we don't
     // know whether the RHS of the join is up-to-date until we know that the 
table is
     // finished.
-    process_.Push(true);
+    PushProcess(true);
+
     return Status::OK();
   }
+  void PushProcess(bool value) {
+#ifdef ARROW_ENABLE_THREADING
+    process_.Push(value);
+#else
+    if (value) {
+      ProcessNonThreaded();
+    } else if (!process_task_.is_finished()) {
+      EndFromSingleThread();
+    }
+#endif
+  }
 
-  Status StartProducing() override {
 #ifndef ARROW_ENABLE_THREADING
-    return Status::NotImplemented("ASOF join requires threading enabled");
+  bool ProcessNonThreaded() {
+    while (!process_task_.is_finished()) {
+      Result<std::shared_ptr<RecordBatch>> result = ProcessInner();
+
+      if (result.ok()) {
+        auto out_rb = *result;
+        if (!out_rb) break;
+        ExecBatch out_b(*out_rb);
+        out_b.index = batches_produced_++;
+        DEBUG_SYNC(this, "produce batch ", out_b.index, ":", 
DEBUG_MANIP(std::endl),
+                   out_rb->ToString(), DEBUG_MANIP(std::endl));
+        Status st = output_->InputReceived(this, std::move(out_b));
+        if (!st.ok()) {
+          // this isn't really from a thread,
+          // but we call through to this for consistency
+          EndFromSingleThread(std::move(st));
+          return false;
+        }
+      } else {
+        // this isn't really from a thread,
+        // but we call through to this for consistency
+        EndFromSingleThread(result.status());
+        return false;
+      }
+    }
+    auto& lhs = *state_.at(0);
+    if (lhs.Finished() && !process_task_.is_finished()) {
+      EndFromSingleThread(Status::OK());
+    }
+    return true;
+  }
+
+  void EndFromSingleThread(Status st = Status::OK()) {
+    process_task_.MarkFinished(st);
+    if (st.ok()) {
+      st = output_->InputFinished(this, batches_produced_);
+    }
+    for (const auto& s : state_) {
+      st &= s->ForceShutdown();
+    }
+  }
+
 #endif
 
+  Status StartProducing() override {
     ARROW_ASSIGN_OR_RAISE(process_task_, 
plan_->query_context()->BeginExternalTask(
                                              "AsofJoinNode::ProcessThread"));
     if (!process_task_.is_valid()) {
       // Plan has already aborted.  Do not start process thread
       return Status::OK();
     }
+#ifdef ARROW_ENABLE_THREADING
     process_thread_ = std::thread(&AsofJoinNode::ProcessThreadWrapper, this);
+#endif
     return Status::OK();
   }
 
@@ -1448,8 +1509,10 @@ class AsofJoinNode : public ExecNode {
   void ResumeProducing(ExecNode* output, int32_t counter) override {}
 
   Status StopProducingImpl() override {
+#ifdef ARROW_ENABLE_THREADING
     process_.Clear();
-    process_.Push(false);
+#endif
+    PushProcess(false);
     return Status::OK();
   }
 
@@ -1479,11 +1542,13 @@ class AsofJoinNode : public ExecNode {
 
   // Backpressure counter common to all inputs
   std::atomic<int32_t> backpressure_counter_;
+#ifdef ARROW_ENABLE_THREADING
   // Queue for triggering processing of a given input
   // (a false value is a poison pill)
   ConcurrentQueue<bool> process_;
   // Worker thread
   std::thread process_thread_;
+#endif
   Future<> process_task_;
 
   // In-progress batches produced
@@ -1511,9 +1576,13 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector 
inputs,
       debug_os_(join_options.debug_opts ? join_options.debug_opts->os : 
nullptr),
       debug_mutex_(join_options.debug_opts ? join_options.debug_opts->mutex : 
nullptr),
 #endif
-      backpressure_counter_(1),
+      backpressure_counter_(1)
+#ifdef ARROW_ENABLE_THREADING
+      ,
       process_(),
-      process_thread_() {
+      process_thread_()
+#endif
+{
   for (auto& key_hasher : key_hashers_) {
     key_hasher->node_ = this;
   }
diff --git a/cpp/src/arrow/acero/sorted_merge_node.cc 
b/cpp/src/arrow/acero/sorted_merge_node.cc
index 4d4565a6bb..a71ac79efc 100644
--- a/cpp/src/arrow/acero/sorted_merge_node.cc
+++ b/cpp/src/arrow/acero/sorted_merge_node.cc
@@ -262,19 +262,22 @@ class SortedMergeNode : public ExecNode {
       : ExecNode(plan, inputs, GetInputLabels(inputs), 
std::move(output_schema)),
         ordering_(std::move(new_ordering)),
         input_counter(inputs_.size()),
-        output_counter(inputs_.size()),
-        process_thread() {
+        output_counter(inputs_.size())
+#ifdef ARROW_ENABLE_THREADING
+        ,
+        process_thread()
+#endif
+  {
     SetLabel("sorted_merge");
   }
 
   ~SortedMergeNode() override {
-    process_queue.Push(
-        kPoisonPill);  // poison pill
-                       // We might create a temporary (such as to inspect the 
output
-                       // schema), in which case there isn't anything  to join
+    PushTask(kPoisonPill);
+#ifdef ARROW_ENABLE_THREADING
     if (process_thread.joinable()) {
       process_thread.join();
     }
+#endif
   }
 
   static arrow::Result<arrow::acero::ExecNode*> Make(
@@ -355,10 +358,25 @@ class SortedMergeNode : public ExecNode {
     // InputState's ConcurrentQueue manages locking
     input_counter[index] += rb->num_rows();
     ARROW_RETURN_NOT_OK(state[index]->Push(rb));
-    process_queue.Push(kNewTask);
+    PushTask(kNewTask);
     return Status::OK();
   }
 
+  void PushTask(bool ok) {
+#ifdef ARROW_ENABLE_THREADING
+    process_queue.Push(ok);
+#else
+    if (process_task.is_finished()) {
+      return;
+    }
+    if (ok == kNewTask) {
+      PollOnce();
+    } else {
+      EndFromProcessThread();
+    }
+#endif
+  }
+
   arrow::Status InputFinished(arrow::acero::ExecNode* input, int 
total_batches) override {
     ARROW_DCHECK(std_has(inputs_, input));
     {
@@ -368,7 +386,8 @@ class SortedMergeNode : public ExecNode {
       state.at(k)->set_total_batches(total_batches);
     }
     // Trigger a final process call for stragglers
-    process_queue.Push(kNewTask);
+    PushTask(kNewTask);
+
     return Status::OK();
   }
 
@@ -379,13 +398,17 @@ class SortedMergeNode : public ExecNode {
       // Plan has already aborted.  Do not start process thread
       return Status::OK();
     }
+#ifdef ARROW_ENABLE_THREADING
     process_thread = std::thread(&SortedMergeNode::StartPoller, this);
+#endif
     return Status::OK();
   }
 
   arrow::Status StopProducingImpl() override {
+#ifdef ARROW_ENABLE_THREADING
     process_queue.Clear();
-    process_queue.Push(kPoisonPill);
+#endif
+    PushTask(kPoisonPill);
     return Status::OK();
   }
 
@@ -408,6 +431,7 @@ class SortedMergeNode : public ExecNode {
           << input_counter[i] << " != " << output_counter[i];
     }
 
+#ifdef ARROW_ENABLE_THREADING
     ARROW_UNUSED(
         plan_->query_context()->executor()->Spawn([this, st = std::move(st)]() 
mutable {
           Defer cleanup([this, &st]() { process_task.MarkFinished(st); });
@@ -415,6 +439,12 @@ class SortedMergeNode : public ExecNode {
             st = output_->InputFinished(this, batches_produced);
           }
         }));
+#else
+    process_task.MarkFinished(st);
+    if (st.ok()) {
+      st = output_->InputFinished(this, batches_produced);
+    }
+#endif
   }
 
   bool CheckEnded() {
@@ -552,6 +582,7 @@ class SortedMergeNode : public ExecNode {
     return true;
   }
 
+#ifdef ARROW_ENABLE_THREADING
   void EmitBatches() {
     while (true) {
       // Implementation note: If the queue is empty, we will block here
@@ -567,6 +598,7 @@ class SortedMergeNode : public ExecNode {
 
   /// The entry point for processThread
   static void StartPoller(SortedMergeNode* node) { node->EmitBatches(); }
+#endif
 
   arrow::Ordering ordering_;
 
@@ -583,11 +615,13 @@ class SortedMergeNode : public ExecNode {
 
   std::atomic<int32_t> batches_produced{0};
 
+#ifdef ARROW_ENABLE_THREADING
   // Queue to trigger processing of a given input. False acts as a poison pill
   ConcurrentQueue<bool> process_queue;
   // Once StartProducing is called, we initialize this thread to poll the
   // input states and emit batches
   std::thread process_thread;
+#endif
   arrow::Future<> process_task;
 
   // Map arg index --> completion counter

Reply via email to