westonpace commented on a change in pull request #10397:
URL: https://github.com/apache/arrow/pull/10397#discussion_r644175511

##########
File path: cpp/src/arrow/compute/exec/exec_plan.cc
##########
@@ -170,48 +165,409 @@ Status ExecPlan::Validate() { return ToDerived(this)->Validate(); }
 
 Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); }
 
-ExecNode::ExecNode(ExecPlan* plan, std::string label,
-                   std::vector<BatchDescr> input_descrs,
+ExecNode::ExecNode(ExecPlan* plan, std::string label, NodeVector inputs,
                    std::vector<std::string> input_labels, BatchDescr output_descr,
                    int num_outputs)
     : plan_(plan),
       label_(std::move(label)),
-      input_descrs_(std::move(input_descrs)),
+      inputs_(std::move(inputs)),
       input_labels_(std::move(input_labels)),
       output_descr_(std::move(output_descr)),
-      num_outputs_(num_outputs) {}
+      num_outputs_(num_outputs) {
+  for (auto input : inputs_) {
+    input->outputs_.push_back(this);
+  }
+}
 
 Status ExecNode::Validate() const {
-  if (inputs_.size() != input_descrs_.size()) {
+  if (inputs_.size() != input_labels_.size()) {
     return Status::Invalid("Invalid number of inputs for '", label(), "' (expected ",
-                           num_inputs(), ", actual ", inputs_.size(), ")");
+                           num_inputs(), ", actual ", input_labels_.size(), ")");
   }
 
   if (static_cast<int>(outputs_.size()) != num_outputs_) {
     return Status::Invalid("Invalid number of outputs for '", label(), "' (expected ",
                            num_outputs(), ", actual ", outputs_.size(), ")");
   }
 
-  DCHECK_EQ(input_descrs_.size(), input_labels_.size());
-
   for (auto out : outputs_) {
     auto input_index = GetNodeIndex(out->inputs(), this);
     if (!input_index) {
       return Status::Invalid("Node '", label(), "' outputs to node '", out->label(),
                              "' but is not listed as an input.");
+  }
 
-    const auto& in_descr = out->input_descrs_[*input_index];
-    if (in_descr != output_descr_) {
-      return Status::Invalid(
-          "Node '", label(), "' (bound to input ", input_labels_[*input_index],
-          ") produces batches with type '", ValueDescr::ToString(output_descr_),
-          "' inconsistent with consumer '", out->label(), "' which accepts '",
-          ValueDescr::ToString(in_descr), "'");
+  return Status::OK();
+}
+
+struct SourceNode : ExecNode {
+  SourceNode(ExecPlan* plan, std::string label, ExecNode::BatchDescr output_descr,
+             AsyncGenerator<util::optional<ExecBatch>> generator)
+      : ExecNode(plan, std::move(label), {}, {}, std::move(output_descr),
+                 /*num_outputs=*/1),
+        generator_(std::move(generator)) {}
+
+  const char* kind_name() override { return "SourceNode"; }
+
+  static void NoInputs() { DCHECK(false) << "no inputs; this should never be called"; }
+  void InputReceived(ExecNode*, int, ExecBatch) override { NoInputs(); }
+  void ErrorReceived(ExecNode*, Status) override { NoInputs(); }
+  void InputFinished(ExecNode*, int) override { NoInputs(); }
+
+  Status StartProducing() override {
+    if (finished_) {
+      return Status::Invalid("Restarted SourceNode '", label(), "'");
     }
+
+    auto gen = std::move(generator_);
+
+    /// XXX should we wait on this future anywhere? In StopProducing() maybe?
+    auto done_fut =
+        Loop([gen, this] {
+          std::unique_lock<std::mutex> lock(mutex_);
+          int seq = next_batch_index_++;
+          if (finished_) {
+            return Future<ControlFlow<int>>::MakeFinished(Break(seq));
+          }
+          lock.unlock();
+
+          return gen().Then(
+              [=](const util::optional<ExecBatch>& batch) -> ControlFlow<int> {
+                std::unique_lock<std::mutex> lock(mutex_);
+                if (!batch || finished_) {
+                  finished_ = true;
+                  return Break(seq);
+                }
+                lock.unlock();
+
+                outputs_[0]->InputReceived(this, seq, *batch);
+                return Continue();
+              },
+              [=](const Status& error) -> ControlFlow<int> {
+                std::unique_lock<std::mutex> lock(mutex_);
+                if (!finished_) {
+                  finished_ = true;
+                  lock.unlock();
+                  // unless we were already finished, push the error to our output
+                  // XXX is this correct? Is it reasonable for a consumer to ignore errors
+                  // from a finished producer?
+                  outputs_[0]->ErrorReceived(this, error);
+                }
+                return Break(seq);
+              });
+        }).Then([&](int seq) {
+          /// XXX this is probably redundant: do we always call InputFinished after
+          /// ErrorReceived or will ErrorRecieved be sufficient?
+          outputs_[0]->InputFinished(this, seq);
+        });
+
+    return Status::OK();
   }
-  return Status::OK();
+  void PauseProducing(ExecNode* output) override {}
+
+  void ResumeProducing(ExecNode* output) override {}
+
+  void StopProducing(ExecNode* output) override {
+    DCHECK_EQ(output, outputs_[0]);
+    std::unique_lock<std::mutex> lock(mutex_);
+    finished_ = true;
+  }
+
+  void StopProducing() override { StopProducing(outputs_[0]); }
+
+ private:
+  std::mutex mutex_;
+  bool finished_{false};
+  int next_batch_index_{0};
+  AsyncGenerator<util::optional<ExecBatch>> generator_;
+};
+
+ExecNode* MakeSourceNode(ExecPlan* plan, std::string label,
+                         ExecNode::BatchDescr output_descr,
+                         AsyncGenerator<util::optional<ExecBatch>> generator) {
+  return plan->EmplaceNode<SourceNode>(plan, std::move(label), std::move(output_descr),
+                                       std::move(generator));
+}
+
+struct FilterNode : ExecNode {
+  FilterNode(ExecNode* input, std::string label, Expression filter)
+      : ExecNode(input->plan(), std::move(label), {input}, {"target"},
+                 /*output_descr=*/{input->output_descr()},
+                 /*num_outputs=*/1),
+        filter_(std::move(filter)) {}
+
+  const char* kind_name() override { return "FilterNode"; }
+
+  Result<ExecBatch> DoFilter(const ExecBatch& target) {
+    ARROW_ASSIGN_OR_RAISE(Expression simplified_filter,
+                          SimplifyWithGuarantee(filter_, target.guarantee));
+
+    // XXX get a non-default exec context
+    ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, target));
+
+    if (mask.is_scalar()) {
+      const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
+      if (mask_scalar.is_valid && mask_scalar.value) {
+        return target;
+      }
+
+      return target.Slice(0, 0);
+    }
+
+    auto values = target.values;
+    for (auto& value : values) {
+      if (value.is_scalar()) continue;
+      ARROW_ASSIGN_OR_RAISE(value, Filter(value, mask, FilterOptions::Defaults()));
+    }
+    return ExecBatch::Make(std::move(values));
+  }
+
+  void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+    DCHECK_EQ(input, inputs_[0]);
+
+    auto maybe_filtered = DoFilter(std::move(batch));
+    if (!maybe_filtered.ok()) {
+      outputs_[0]->ErrorReceived(this, maybe_filtered.status());
+      inputs_[0]->StopProducing(this);
+      return;
+    }
+
+    maybe_filtered->guarantee = batch.guarantee;
+    outputs_[0]->InputReceived(this, seq, maybe_filtered.MoveValueUnsafe());
+  }
+
+  void ErrorReceived(ExecNode* input, Status error) override {
+    DCHECK_EQ(input, inputs_[0]);
+    outputs_[0]->ErrorReceived(this, std::move(error));
+    inputs_[0]->StopProducing(this);
+  }
+
+  void InputFinished(ExecNode* input, int seq) override {
+    DCHECK_EQ(input, inputs_[0]);
+    outputs_[0]->InputFinished(this, seq);
+    inputs_[0]->StopProducing(this);
+  }
+
+  Status StartProducing() override {
+    // XXX validate inputs_[0]->output_descr() against filter_
+    return Status::OK();
+  }
+
+  void PauseProducing(ExecNode* output) override {}
+
+  void ResumeProducing(ExecNode* output) override {}
+
+  void StopProducing(ExecNode* output) override {
+    DCHECK_EQ(output, outputs_[0]);
+    inputs_[0]->StopProducing(this);
+  }
+
+  void StopProducing() override { StopProducing(outputs_[0]); }
+
+ private:
+  Expression filter_;
+};
+
+ExecNode* MakeFilterNode(ExecNode* input, std::string label, Expression filter) {
+  return input->plan()->EmplaceNode<FilterNode>(input, std::move(label),
+                                                std::move(filter));
+}
+
+struct ProjectNode : ExecNode {
+  ProjectNode(ExecNode* input, std::string label, std::vector<Expression> exprs)
+      : ExecNode(input->plan(), std::move(label), {input}, {"target"},
+                 /*output_descr=*/{input->output_descr()},
+                 /*num_outputs=*/1),
+        exprs_(std::move(exprs)) {}
+
+  const char* kind_name() override { return "ProjectNode"; }
+
+  Result<ExecBatch> DoProject(const ExecBatch& target) {
+    // XXX get a non-default exec context
+    std::vector<Datum> values{exprs_.size()};
+    for (size_t i = 0; i < exprs_.size(); ++i) {
+      ARROW_ASSIGN_OR_RAISE(Expression simplified_expr,
+                            SimplifyWithGuarantee(exprs_[i], target.guarantee));
+
+      ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target));
+    }
+    return ExecBatch::Make(std::move(values));
+  }
+
+  void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+    DCHECK_EQ(input, inputs_[0]);
+
+    auto maybe_projected = DoProject(std::move(batch));
+    if (!maybe_projected.ok()) {
+      outputs_[0]->ErrorReceived(this, maybe_projected.status());
+      inputs_[0]->StopProducing(this);
+      return;
+    }
+
+    maybe_projected->guarantee = batch.guarantee;
+    outputs_[0]->InputReceived(this, seq, maybe_projected.MoveValueUnsafe());
+  }
+
+  void ErrorReceived(ExecNode* input, Status error) override {
+    DCHECK_EQ(input, inputs_[0]);
+    outputs_[0]->ErrorReceived(this, std::move(error));
+    inputs_[0]->StopProducing(this);
+  }
+
+  void InputFinished(ExecNode* input, int seq) override {
+    DCHECK_EQ(input, inputs_[0]);
+    outputs_[0]->InputFinished(this, seq);
+    inputs_[0]->StopProducing(this);
+  }
+
+  Status StartProducing() override {
+    // XXX validate inputs_[0]->output_descr() against filter_
+    return Status::OK();
+  }
+
+  void PauseProducing(ExecNode* output) override {}
+
+  void ResumeProducing(ExecNode* output) override {}
+
+  void StopProducing(ExecNode* output) override {
+    DCHECK_EQ(output, outputs_[0]);
+    inputs_[0]->StopProducing(this);
+  }
+
+  void StopProducing() override { StopProducing(outputs_[0]); }
+
+ private:
+  std::vector<Expression> exprs_;
+};
+
+ExecNode* MakeProjectNode(ExecNode* input, std::string label,
+                          std::vector<Expression> exprs) {
+  return input->plan()->EmplaceNode<ProjectNode>(input, std::move(label),
+                                                 std::move(exprs));
+}
+
+struct SinkNode : ExecNode {

Review comment:
I'm pondering how back pressure would be applied. I think there would be a new argument added to this `SinkNode` for `max_items_queued` or something like that. However, we could not naively apply that limit to `received_batches_` because of the resequencing.
Since we are delivering to a pull-based model, I think the appropriate way to apply back pressure would be to have the `PushGenerator` keep track of how many undelivered items it has. There would then need to be a check in this code: after pushing, if the `PushGenerator` is full, apply back pressure to the inputs. The `PushGenerator` would also need some way of signalling back into the `SinkNode` that the pressure has been relieved and it is ready for more items. I don't think this has to be implemented now, but does that sound reasonable?
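To make the idea concrete, here is a rough, hypothetical sketch of the counter-based handshake described above. Everything in it is illustrative only: the class name, the high/low water marks, and the callbacks are not part of this PR, and `PushGenerator` does not currently expose any of this. It simply shows a sink counting undelivered items, pausing its input once a limit is hit, and resuming it once the consumer catches up (for example by wiring the callbacks to `inputs_[0]->PauseProducing(this)` / `inputs_[0]->ResumeProducing(this)`).

```cpp
#include <cstddef>
#include <functional>
#include <mutex>

// Hypothetical helper, not part of the PR: counts batches that have been pushed to the
// consumer-facing generator but not yet pulled by the consumer, and toggles back
// pressure when a high-water mark is crossed.
class BackpressureCounter {
 public:
  BackpressureCounter(std::size_t max_undelivered, std::function<void()> pause_producing,
                      std::function<void()> resume_producing)
      : max_undelivered_(max_undelivered),
        pause_(std::move(pause_producing)),
        resume_(std::move(resume_producing)) {}

  // The sink calls this right after handing a (resequenced) batch to the push side.
  void OnPush() {
    bool should_pause = false;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (++undelivered_ >= max_undelivered_ && !paused_) {
        paused_ = should_pause = true;
      }
    }
    // e.g. inputs_[0]->PauseProducing(this)
    if (should_pause) pause_();
  }

  // The pull side calls this whenever the consumer actually receives a batch.
  void OnDelivered() {
    bool should_resume = false;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (undelivered_ > 0) --undelivered_;
      if (paused_ && undelivered_ <= max_undelivered_ / 2) {
        paused_ = false;
        should_resume = true;
      }
    }
    // e.g. inputs_[0]->ResumeProducing(this)
    if (should_resume) resume_();
  }

 private:
  std::mutex mutex_;
  const std::size_t max_undelivered_;
  std::size_t undelivered_ = 0;
  bool paused_ = false;
  std::function<void()> pause_;
  std::function<void()> resume_;
};
```

The signalling path back from the generator into the `SinkNode` (the `OnDelivered` call above) is the piece that would have to be added to `PushGenerator` or to a small wrapper around it; the rest is just bookkeeping inside the sink.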