JerAguilon commented on code in PR #38380: URL: https://github.com/apache/arrow/pull/38380#discussion_r1373792311
########## cpp/src/arrow/acero/sorted_merge_node.cc: ########## @@ -0,0 +1,606 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <arrow/api.h> +#include <atomic> +#include <mutex> +#include <sstream> +#include <thread> +#include <tuple> +#include <unordered_map> +#include <vector> +#include "arrow/acero/concurrent_queue.h" +#include "arrow/acero/exec_plan.h" +#include "arrow/acero/options.h" +#include "arrow/acero/query_context.h" +#include "arrow/acero/time_series_util.h" +#include "arrow/acero/unmaterialized_table.h" +#include "arrow/acero/util.h" +#include "arrow/array/builder_base.h" +#include "arrow/result.h" +#include "arrow/type_fwd.h" +#include "arrow/util/logging.h" + +namespace { +template <typename Callable> +struct Defer { + Callable callable; + explicit Defer(Callable callable_) : callable(std::move(callable_)) {} + ~Defer() noexcept { callable(); } +}; + +std::vector<std::string> GetInputLabels( + const arrow::acero::ExecNode::NodeVector& inputs) { + std::vector<std::string> labels(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + labels[i] = "input_" + std::to_string(i) + "_label"; + } + return labels; +} + +template <typename T, typename V = typename T::value_type> +inline 
typename T::const_iterator std_find(const T& container, const V& val) { + return std::find(container.begin(), container.end(), val); +} + +template <typename T, typename V = typename T::value_type> +inline bool std_has(const T& container, const V& val) { + return container.end() != std_find(container, val); +} + +} // namespace + +namespace arrow::acero { + +namespace sorted_merge { + +// Each slice is associated with a single input source, so we only need 1 record +// batch per slice +using UnmaterializedSlice = arrow::acero::UnmaterializedSlice<1>; +using UnmaterializedCompositeTable = arrow::acero::UnmaterializedCompositeTable<1>; + +using row_index_t = uint64_t; +using time_unit_t = uint64_t; +using col_index_t = int; + +#define NEW_TASK true +#define POISON_PILL false + +class BackpressureController : public BackpressureControl { + public: + BackpressureController(ExecNode* node, ExecNode* output, + std::atomic<int32_t>& backpressure_counter) + : node_(node), output_(output), backpressure_counter_(backpressure_counter) {} + + void Pause() override { node_->PauseProducing(output_, ++backpressure_counter_); } + void Resume() override { node_->ResumeProducing(output_, ++backpressure_counter_); } + + private: + ExecNode* node_; + ExecNode* output_; + std::atomic<int32_t>& backpressure_counter_; +}; + +/// InputState correponds to an input. Input record batches are queued up in InputState +/// until processed and turned into output record batches. 
+class InputState { + public: + InputState(size_t index, BackpressureHandler handler, + const std::shared_ptr<arrow::Schema>& schema, const int time_col_index) + : index_(index), + queue_(std::move(handler)), + schema_(schema), + time_col_index_(time_col_index), + time_type_id_(schema_->fields()[time_col_index_]->type()->id()) {} + + template <typename PtrType> + static arrow::Result<PtrType> Make(size_t index, arrow::acero::ExecNode* input, + arrow::acero::ExecNode* output, + std::atomic<int32_t>& backpressure_counter, + const std::shared_ptr<arrow::Schema>& schema, + const col_index_t time_col_index) { + constexpr size_t low_threshold = 4, high_threshold = 8; + std::unique_ptr<arrow::acero::BackpressureControl> backpressure_control = + std::make_unique<BackpressureController>(input, output, backpressure_counter); + ARROW_ASSIGN_OR_RAISE(auto handler, + BackpressureHandler::Make(input, low_threshold, high_threshold, + std::move(backpressure_control))); + return PtrType(new InputState(index, std::move(handler), schema, time_col_index)); + } + + bool IsTimeColumn(col_index_t i) const { + DCHECK_LT(i, schema_->num_fields()); + return (i == time_col_index_); + } + + // Gets the latest row index, assuming the queue isn't empty + row_index_t GetLatestRow() const { return latest_ref_row_; } + + bool Empty() const { + // cannot be empty if ref row is >0 -- can avoid slow queue lock + // below + if (latest_ref_row_ > 0) { + return false; + } + return queue_.Empty(); + } + + size_t index() const { return index_; } + + int total_batches() const { return total_batches_; } + + // Gets latest batch (precondition: must not be empty) + const std::shared_ptr<arrow::RecordBatch>& GetLatestBatch() const { + return queue_.UnsyncFront(); + } + +#define LATEST_VAL_CASE(id, val) \ + case arrow::Type::id: { \ + using T = typename arrow::TypeIdTraits<arrow::Type::id>::Type; \ + using CType = typename arrow::TypeTraits<T>::CType; \ + return val(data->GetValues<CType>(1)[row]); \ + } + + 
inline time_unit_t GetLatestTime() const { + return GetTime(GetLatestBatch().get(), time_type_id_, time_col_index_, + latest_ref_row_); + } + +#undef LATEST_VAL_CASE + + bool Finished() const { return batches_processed_ == total_batches_; } + + arrow::Result<std::pair<UnmaterializedSlice, std::shared_ptr<arrow::RecordBatch>>> + Advance() { + // Advance the row until a new time is encountered or the record batch + // ends. This will return a range of {-1, -1} and a nullptr if there is + // no input + + bool active = + (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.Empty(); + + if (!active) { + return std::make_pair(UnmaterializedSlice(), nullptr); + } + + row_index_t start = latest_ref_row_; + row_index_t end = latest_ref_row_; + time_unit_t startTime = GetLatestTime(); + std::shared_ptr<arrow::RecordBatch> batch = queue_.UnsyncFront(); + auto rows_in_batch = (row_index_t)batch->num_rows(); + + while (GetLatestTime() == startTime) { + end = ++latest_ref_row_; + if (latest_ref_row_ >= rows_in_batch) { + // hit the end of the batch, need to get the next batch if + // possible. 
+ ++batches_processed_; + latest_ref_row_ = 0; + active &= !queue_.TryPop(); + if (active) { + DCHECK_GT(queue_.UnsyncFront()->num_rows(), + 0); // empty batches disallowed, sanity check + } + break; + } + } + + UnmaterializedSlice slice; + slice.num_components = 1; + slice.components[0] = CompositeEntry{batch.get(), start, end}; + return std::make_pair(slice, batch); + } + + arrow::Status Push(const std::shared_ptr<arrow::RecordBatch>& rb) { + if (rb->num_rows() > 0) { + queue_.Push(rb); + } else { + ++batches_processed_; // don't enqueue empty batches, just record + // as processed + } + return arrow::Status::OK(); + } + + const std::shared_ptr<arrow::Schema>& get_schema() const { return schema_; } + + void set_total_batches(int n) { + DCHECK_GE(n, 0); + DCHECK_EQ(total_batches_, -1) << "Set total batch more than once"; + total_batches_ = n; + } + + private: + size_t index_; + // Pending record batches. The latest is the front. Batches cannot be empty. + BackpressureConcurrentQueue<std::shared_ptr<arrow::RecordBatch>> queue_; + // Schema associated with the input + std::shared_ptr<arrow::Schema> schema_; + // Total number of batches (only int because InputFinished uses int) + std::atomic<int> total_batches_{-1}; + // Number of batches processed so far (only int because InputFinished uses + // int) + std::atomic<int> batches_processed_{0}; + // Index of the time col + col_index_t time_col_index_; + // Type id of the time column + arrow::Type::type time_type_id_; + // Index of the latest row reference within; if >0 then queue_ cannot be + // empty Must be < queue_.front()->num_rows() if queue_ is non-empty + row_index_t latest_ref_row_ = 0; + // Time of latest row + time_unit_t latest_time_ = std::numeric_limits<time_unit_t>::lowest(); +}; + +struct InputStateComparator { + bool operator()(const std::shared_ptr<InputState>& lhs, + const std::shared_ptr<InputState>& rhs) const { + // True if lhs is ahead of time of rhs + if (lhs->Finished()) { + return false; + } + 
if (rhs->Finished()) { + return false; + } + time_unit_t lFirst = lhs->GetLatestTime(); + time_unit_t rFirst = rhs->GetLatestTime(); + return lFirst > rFirst; + } +}; + +class SortedMergeNode : public ExecNode { + static constexpr int64_t kTargetOutputBatchSize = 1024 * 1024; + + public: + SortedMergeNode(arrow::acero::ExecPlan* plan, + std::vector<arrow::acero::ExecNode*> inputs, + std::shared_ptr<arrow::Schema> output_schema, + arrow::Ordering new_ordering) + : ExecNode(plan, inputs, GetInputLabels(inputs), std::move(output_schema)), + ordering_(std::move(new_ordering)), + input_counter(inputs_.size()), + output_counter(inputs_.size()), + process_thread() { + SetLabel("sorted_merge"); + } + + ~SortedMergeNode() override { + process_queue.Push( + POISON_PILL); // poison pill + // We might create a temporary (such as to inspect the output + // schema), in which case there isn't anything to join + if (process_thread.joinable()) { + process_thread.join(); + } + } + + static arrow::Result<arrow::acero::ExecNode*> Make( + arrow::acero::ExecPlan* plan, std::vector<arrow::acero::ExecNode*> inputs, + const arrow::acero::ExecNodeOptions& options) { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, static_cast<int>(inputs.size()), + "SortedMergeNode")); + + if (inputs.size() < 1) { + return Status::Invalid("Constructing a `SortedMergeNode` with < 1 inputs"); + } + + const auto schema = inputs.at(0)->output_schema(); + for (const auto& input : inputs) { + if (!input->output_schema()->Equals(schema)) { + return Status::Invalid( + "SortedMergeNode input schemas must all " + "match, first schema " + "was: ", + schema->ToString(), " got schema: ", input->output_schema()->ToString()); + } + } + + const auto& order_options = + arrow::internal::checked_cast<const OrderByNodeOptions&>(options); + + if (order_options.ordering.is_implicit() || order_options.ordering.is_unordered()) { + return Status::Invalid("`ordering` must be an explicit non-empty ordering"); + } + + 
std::shared_ptr<Schema> output_schema = inputs[0]->output_schema(); + return plan->EmplaceNode<SortedMergeNode>( + plan, std::move(inputs), std::move(output_schema), order_options.ordering); + } + + const char* kind_name() const override { return "SortedMergeNode"; } + + const arrow::Ordering& ordering() const override { return ordering_; } + + arrow::Status Init() override { + auto inputs = this->inputs(); + for (size_t i = 0; i < inputs.size(); i++) { + ExecNode* input = inputs[i]; + const auto& schema = input->output_schema(); + const auto& sort_key = ordering_.sort_keys()[0]; + if (sort_key.order != arrow::compute::SortOrder::Ascending) { + return Status::Invalid("Only ascending sort order is supported"); + } + + const auto& ref = sort_key.target; + if (!ref.IsName()) { + return Status::Invalid("Ordering must be a name. ", ref.ToString(), + " is not a name"); + } Review Comment: Below, I do this: ``` ARROW_ASSIGN_OR_RAISE(auto input_state, InputState::Make<std::shared_ptr<InputState>>( i, input, this, backpressure_counter, schema, schema->GetFieldIndex(*ref.name()))); ``` I'm not sure whether there is an API I could use to avoid this `IsName()` check; it would be nice to remove the constraint. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
