JerAguilon commented on code in PR #38380:
URL: https://github.com/apache/arrow/pull/38380#discussion_r1373784091


##########
cpp/src/arrow/acero/sorted_merge_node.cc:
##########
@@ -0,0 +1,606 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/api.h>
+#include <atomic>
+#include <mutex>
+#include <sstream>
+#include <thread>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+#include "arrow/acero/concurrent_queue.h"
+#include "arrow/acero/exec_plan.h"
+#include "arrow/acero/options.h"
+#include "arrow/acero/query_context.h"
+#include "arrow/acero/time_series_util.h"
+#include "arrow/acero/unmaterialized_table.h"
+#include "arrow/acero/util.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/logging.h"
+
+namespace {
+template <typename Callable>
+struct Defer {
+  Callable callable;
+  explicit Defer(Callable callable_) : callable(std::move(callable_)) {}
+  ~Defer() noexcept { callable(); }
+};
+
+std::vector<std::string> GetInputLabels(
+    const arrow::acero::ExecNode::NodeVector& inputs) {
+  std::vector<std::string> labels(inputs.size());
+  for (size_t i = 0; i < inputs.size(); i++) {
+    labels[i] = "input_" + std::to_string(i) + "_label";
+  }
+  return labels;
+}
+
+template <typename T, typename V = typename T::value_type>
+inline typename T::const_iterator std_find(const T& container, const V& val) {
+  return std::find(container.begin(), container.end(), val);
+}
+
+template <typename T, typename V = typename T::value_type>
+inline bool std_has(const T& container, const V& val) {
+  return container.end() != std_find(container, val);
+}
+
+}  // namespace
+
+namespace arrow::acero {
+
+namespace sorted_merge {
+
+// Each slice is associated with a single input source, so we only need 1 record
+// batch per slice
+using UnmaterializedSlice = arrow::acero::UnmaterializedSlice<1>;
+using UnmaterializedCompositeTable = arrow::acero::UnmaterializedCompositeTable<1>;
+
+using row_index_t = uint64_t;
+using time_unit_t = uint64_t;
+using col_index_t = int;
+
+#define NEW_TASK true
+#define POISON_PILL false
+
+class BackpressureController : public BackpressureControl {
+ public:
+  BackpressureController(ExecNode* node, ExecNode* output,
+                         std::atomic<int32_t>& backpressure_counter)
+      : node_(node), output_(output), backpressure_counter_(backpressure_counter) {}
+
+  void Pause() override { node_->PauseProducing(output_, ++backpressure_counter_); }
+  void Resume() override { node_->ResumeProducing(output_, ++backpressure_counter_); }
+
+ private:
+  ExecNode* node_;
+  ExecNode* output_;
+  std::atomic<int32_t>& backpressure_counter_;
+};
+
+/// InputState corresponds to an input. Input record batches are queued up in
+/// InputState until processed and turned into output record batches.
+class InputState {
+ public:
+  InputState(size_t index, BackpressureHandler handler,
+             const std::shared_ptr<arrow::Schema>& schema, const int time_col_index)
+      : index_(index),
+        queue_(std::move(handler)),
+        schema_(schema),
+        time_col_index_(time_col_index),
+        time_type_id_(schema_->fields()[time_col_index_]->type()->id()) {}
+
+  template <typename PtrType>
+  static arrow::Result<PtrType> Make(size_t index, arrow::acero::ExecNode* input,
+                                     arrow::acero::ExecNode* output,
+                                     std::atomic<int32_t>& backpressure_counter,
+                                     const std::shared_ptr<arrow::Schema>& schema,
+                                     const col_index_t time_col_index) {
+    constexpr size_t low_threshold = 4, high_threshold = 8;
+    std::unique_ptr<arrow::acero::BackpressureControl> backpressure_control =
+        std::make_unique<BackpressureController>(input, output, backpressure_counter);
+    ARROW_ASSIGN_OR_RAISE(auto handler,
+                          BackpressureHandler::Make(input, low_threshold, high_threshold,
+                                                    std::move(backpressure_control)));
+    return PtrType(new InputState(index, std::move(handler), schema, time_col_index));
+  }
+
+  bool IsTimeColumn(col_index_t i) const {
+    DCHECK_LT(i, schema_->num_fields());
+    return (i == time_col_index_);
+  }
+
+  // Gets the latest row index, assuming the queue isn't empty
+  row_index_t GetLatestRow() const { return latest_ref_row_; }
+
+  bool Empty() const {
+    // cannot be empty if ref row is >0 -- can avoid slow queue lock
+    // below
+    if (latest_ref_row_ > 0) {
+      return false;
+    }
+    return queue_.Empty();
+  }
+
+  size_t index() const { return index_; }
+
+  int total_batches() const { return total_batches_; }
+
+  // Gets latest batch (precondition: must not be empty)
+  const std::shared_ptr<arrow::RecordBatch>& GetLatestBatch() const {
+    return queue_.UnsyncFront();
+  }
+
+#define LATEST_VAL_CASE(id, val)                                   \
+  case arrow::Type::id: {                                          \
+    using T = typename arrow::TypeIdTraits<arrow::Type::id>::Type; \
+    using CType = typename arrow::TypeTraits<T>::CType;            \
+    return val(data->GetValues<CType>(1)[row]);                    \
+  }
+
+  inline time_unit_t GetLatestTime() const {
+    return GetTime(GetLatestBatch().get(), time_type_id_, time_col_index_,
+                   latest_ref_row_);
+  }
+
+#undef LATEST_VAL_CASE
+
+  bool Finished() const { return batches_processed_ == total_batches_; }
+
+  arrow::Result<std::pair<UnmaterializedSlice, std::shared_ptr<arrow::RecordBatch>>>
+  Advance() {
+    // Advance the row until a new time is encountered or the record batch
+    // ends. This will return a range of {-1, -1} and a nullptr if there is
+    // no input

Review Comment:
   The asof join node already uses time-oriented nomenclature (`time_unit_t`, for
example), so I followed that tradition. In my admittedly biased opinion this makes
the code easier to picture and understand, but let me know if you want me to change
the language.
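
For readers skimming the thread, a minimal sketch of the selection rule behind this
nomenclature may help: each input exposes the `time_unit_t` of its current head row,
and the merge repeatedly advances whichever input has the smallest head time. The
`SimpleInput` struct and `MergeBySmallestTime` function below are hypothetical
illustrations in plain C++17 (no Arrow types), not code from this PR.

    // Hypothetical sketch of the sorted-merge selection rule; not part of the PR.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    using time_unit_t = uint64_t;  // same alias the node uses for ordering values

    // Stand-in for InputState: a pre-sorted stream of times plus a cursor.
    struct SimpleInput {
      std::vector<time_unit_t> times;  // sorted ascending, like the node's sort key
      size_t cursor = 0;

      bool Finished() const { return cursor == times.size(); }
      time_unit_t GetLatestTime() const { return times[cursor]; }
      void Advance() { ++cursor; }
    };

    // Repeatedly pick the input with the smallest head time and emit it,
    // producing one globally sorted sequence (the essence of a sorted merge).
    std::vector<time_unit_t> MergeBySmallestTime(std::vector<SimpleInput> inputs) {
      std::vector<time_unit_t> out;
      while (true) {
        SimpleInput* best = nullptr;
        for (auto& in : inputs) {
          if (!in.Finished() &&
              (best == nullptr || in.GetLatestTime() < best->GetLatestTime())) {
            best = &in;
          }
        }
        if (best == nullptr) break;  // every input is exhausted
        out.push_back(best->GetLatestTime());
        best->Advance();
      }
      return out;
    }

    int main() {
      std::vector<SimpleInput> inputs = {{{1, 4, 9}}, {{2, 3, 10}}, {{5, 6, 7}}};
      for (time_unit_t t : MergeBySmallestTime(std::move(inputs))) {
        std::cout << t << ' ';
      }
      std::cout << '\n';  // prints: 1 2 3 4 5 6 7 9 10
      return 0;
    }

The real node makes the same comparison via `GetLatestTime()` over queued record
batches (and handles backpressure and finished inputs) rather than over in-memory
vectors.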



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
