This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 54eedb95ec GH-33960: [C++] Add DeclarationToSchema and DeclarationToString helper methods. (#34013)
54eedb95ec is described below
commit 54eedb95ec504a715d557e71139ae4df9657fde6
Author: Weston Pace <[email protected]>
AuthorDate: Fri Feb 3 12:25:19 2023 -0800
GH-33960: [C++] Add DeclarationToSchema and DeclarationToString helper methods. (#34013)
Also cleans up the Acero server example to use the DeclarationToXyz methods.
* Closes: #33960
Authored-by: Weston Pace <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
cpp/src/arrow/compute/exec/exec_plan.cc | 54 ++++++++-
cpp/src/arrow/compute/exec/exec_plan.h | 30 +++++
cpp/src/arrow/compute/exec/plan_test.cc | 116 +++++++++++--------
cpp/src/arrow/flight/sql/example/acero_server.cc | 137 ++---------------------
4 files changed, 160 insertions(+), 177 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 896eafb58c..a187d4346f 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -575,6 +575,52 @@ bool Declaration::IsValid(ExecFactoryRegistry* registry) const {
return !this->factory_name.empty() && this->options != nullptr;
}
+namespace {
+
+Result<ExecNode*> EnsureSink(ExecNode* last_node, ExecPlan* plan) {
+ if (!last_node->is_sink()) {
+ Declaration null_sink =
+ Declaration("consuming_sink", {last_node},
+ ConsumingSinkNodeOptions(NullSinkNodeConsumer::Make()));
+ return null_sink.AddToPlan(plan);
+ }
+ return last_node;
+}
+
+} // namespace
+
+Result<std::shared_ptr<Schema>> DeclarationToSchema(const Declaration&
declaration,
+ FunctionRegistry*
function_registry) {
+ // We pass in the default memory pool and the CPU executor but nothing we are doing
+ // should be starting new thread tasks or making large allocations.
+ ExecContext exec_context(default_memory_pool(),
::arrow::internal::GetCpuThreadPool(),
+ function_registry);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ExecPlan> exec_plan,
+ ExecPlan::Make(exec_context));
+ ARROW_ASSIGN_OR_RAISE(ExecNode * last_node,
declaration.AddToPlan(exec_plan.get()));
+ ARROW_ASSIGN_OR_RAISE(last_node, EnsureSink(last_node, exec_plan.get()));
+ ARROW_RETURN_NOT_OK(exec_plan->Validate());
+ if (last_node->inputs().size() != 1) {
+ // Every sink node today has exactly one input
+ return Status::Invalid("Unexpected sink node with more than one input");
+ }
+ return last_node->inputs()[0]->output_schema();
+}
+
+Result<std::string> DeclarationToString(const Declaration& declaration,
+ FunctionRegistry* function_registry) {
+ // We pass in the default memory pool and the CPU executor but nothing we are doing
+ // should be starting new thread tasks or making large allocations.
+ ExecContext exec_context(default_memory_pool(),
::arrow::internal::GetCpuThreadPool(),
+ function_registry);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ExecPlan> exec_plan,
+ ExecPlan::Make(exec_context));
+ ARROW_ASSIGN_OR_RAISE(ExecNode * last_node,
declaration.AddToPlan(exec_plan.get()));
+ ARROW_ASSIGN_OR_RAISE(last_node, EnsureSink(last_node, exec_plan.get()));
+ ARROW_RETURN_NOT_OK(exec_plan->Validate());
+ return exec_plan->ToString();
+}
+
Future<std::shared_ptr<Table>> DeclarationToTableAsync(Declaration declaration,
ExecContext
exec_context) {
std::shared_ptr<std::shared_ptr<Table>> output_table =
@@ -817,11 +863,17 @@ Result<std::unique_ptr<RecordBatchReader>> DeclarationToReader(
std::shared_ptr<Schema> schema() const override { return schema_; }
Status ReadNext(std::shared_ptr<RecordBatch>* record_batch) override {
- DCHECK(!!iterator_) << "call to ReadNext on already closed reader";
+ if (!iterator_) {
+ return Status::Invalid("call to ReadNext on already closed reader");
+ }
return iterator_->Next().Value(record_batch);
}
Status Close() override {
+ if (!iterator_) {
+ // Already closed
+ return Status::OK();
+ }
// End plan and read from generator until finished
std::shared_ptr<RecordBatch> batch;
do {
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index 0fcfb36754..dc875ef479 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -408,6 +408,36 @@ struct ARROW_EXPORT Declaration {
std::string label;
};
+/// \brief Calculate the output schema of a declaration
+///
+/// This does not actually execute the plan. This operation may fail if the
+/// declaration represents an invalid plan (e.g. a project node with multiple inputs)
+///
+/// \param declaration A declaration describing an execution plan
+/// \param function_registry The function registry to use for function execution. If null
+/// then the default function registry will be used.
+///
+/// \return the schema that batches would have after going through the execution plan
+ARROW_EXPORT Result<std::shared_ptr<Schema>> DeclarationToSchema(
+ const Declaration& declaration, FunctionRegistry* function_registry =
NULLPTR);
+
+/// \brief Create a string representation of a plan
+///
+/// This representation is for debug purposes only.
+///
+/// Conversion to a string may fail if the declaration represents an
+/// invalid plan.
+///
+/// Use Substrait for complete serialization of plans
+///
+/// \param declaration A declaration describing an execution plan
+/// \param function_registry The function registry to use for function execution. If null
+/// then the default function registry will be used.
+///
+/// \return a string representation of the plan suitable for debugging output
+ARROW_EXPORT Result<std::string> DeclarationToString(
+ const Declaration& declaration, FunctionRegistry* function_registry =
NULLPTR);
+
/// \brief Utility method to run a declaration and collect the results into a table
///
/// \param declaration A declaration describing the plan to run
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 497b719625..5b2af718df 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -440,6 +440,8 @@ TEST(ExecPlan, ToString) {
auto basic_data = MakeBasicBatches();
AsyncGenerator<std::optional<ExecBatch>> sink_gen;
+ // Cannot test the following mini-plans with DeclarationToString since validation
+ // would fail (no sink)
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
ASSERT_OK(Declaration::Sequence(
{
@@ -456,40 +458,36 @@ TEST(ExecPlan, ToString) {
:SourceNode{}
)");
- ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make());
std::shared_ptr<CountOptions> options =
std::make_shared<CountOptions>(CountOptions::ONLY_VALID);
- ASSERT_OK(
- Declaration::Sequence(
- {
- {"source",
- SourceNodeOptions{basic_data.schema,
- basic_data.gen(/*parallel=*/false,
/*slow=*/false)},
- "custom_source_label"},
- {"filter", FilterNodeOptions{greater_equal(field_ref("i32"),
literal(0))}},
- {"project", ProjectNodeOptions{{
- field_ref("bool"),
- call("multiply", {field_ref("i32"), literal(2)}),
- }}},
- {"aggregate",
- AggregateNodeOptions{
- /*aggregates=*/{
- {"hash_sum", nullptr, "multiply(i32, 2)",
"sum(multiply(i32, 2))"},
- {"hash_count", options, "multiply(i32, 2)",
- "count(multiply(i32, 2))"},
- {"hash_count_all", "count(*)"},
- },
- /*keys=*/{"bool"}}},
- {"filter",
FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
- literal(10))}},
- {"order_by_sink",
- OrderBySinkNodeOptions{
- SortOptions({SortKey{"sum(multiply(i32, 2))",
SortOrder::Ascending}}),
- &sink_gen},
- "custom_sink_label"},
- })
- .AddToPlan(plan.get()));
- EXPECT_EQ(plan->ToString(), R"a(ExecPlan with 6 nodes:
+ Declaration declaration = Declaration::Sequence({
+ {"source",
+ SourceNodeOptions{basic_data.schema,
+ basic_data.gen(/*parallel=*/false, /*slow=*/false)},
+ "custom_source_label"},
+ {"filter", FilterNodeOptions{greater_equal(field_ref("i32"),
literal(0))}},
+ {"project", ProjectNodeOptions{{
+ field_ref("bool"),
+ call("multiply", {field_ref("i32"), literal(2)}),
+ }}},
+ {"aggregate",
+ AggregateNodeOptions{
+ /*aggregates=*/{
+ {"hash_sum", nullptr, "multiply(i32, 2)", "sum(multiply(i32,
2))"},
+ {"hash_count", options, "multiply(i32, 2)",
"count(multiply(i32, 2))"},
+ {"hash_count_all", "count(*)"},
+ },
+ /*keys=*/{"bool"}}},
+ {"filter",
+ FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
literal(10))}},
+ {"order_by_sink",
+ OrderBySinkNodeOptions{
+ SortOptions({SortKey{"sum(multiply(i32, 2))",
SortOrder::Ascending}}),
+ &sink_gen},
+ "custom_sink_label"},
+ });
+ ASSERT_OK_AND_ASSIGN(std::string plan_str, DeclarationToString(declaration));
+ EXPECT_EQ(plan_str, R"a(ExecPlan with 6 nodes:
custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32,
2))) ASC], null_placement=AtEnd}}
:FilterNode{filter=(sum(multiply(i32, 2)) > 10)}
:GroupByNode{keys=["bool"], aggregates=[
@@ -502,8 +500,6 @@ custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32,
custom_source_label:SourceNode{}
)a");
- ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make());
-
Declaration union_node{"union", ExecNodeOptions{}};
Declaration lhs{"source",
SourceNodeOptions{basic_data.schema,
@@ -515,19 +511,17 @@ custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32,
rhs.label = "rhs";
union_node.inputs.emplace_back(lhs);
union_node.inputs.emplace_back(rhs);
- ASSERT_OK(Declaration::Sequence(
- {
- union_node,
- {"aggregate",
- AggregateNodeOptions{/*aggregates=*/{
- {"count", options, "i32",
"count(i32)"},
- {"count_all", "count(*)"},
- },
- /*keys=*/{}}},
- {"sink", SinkNodeOptions{&sink_gen}},
- })
- .AddToPlan(plan.get()));
- EXPECT_EQ(plan->ToString(), R"a(ExecPlan with 5 nodes:
+ declaration = Declaration::Sequence({
+ union_node,
+ {"aggregate", AggregateNodeOptions{/*aggregates=*/{
+ {"count", options, "i32",
"count(i32)"},
+ {"count_all", "count(*)"},
+ },
+ /*keys=*/{}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ });
+ ASSERT_OK_AND_ASSIGN(plan_str, DeclarationToString(declaration));
+ EXPECT_EQ(plan_str, R"a(ExecPlan with 5 nodes:
:SinkNode{}
:ScalarAggregateNode{aggregates=[
count(i32, {mode=NON_NULL}),
@@ -674,6 +668,34 @@ TEST(ExecPlanExecution, SourceTableConsumingSink) {
}
}
+TEST(ExecPlanExecution, DeclarationToSchema) {
+ auto basic_data = MakeBasicBatches();
+ auto plan = Declaration::Sequence(
+ {{"source", SourceNodeOptions(basic_data.schema, basic_data.gen(false,
false))},
+ {"aggregate", AggregateNodeOptions({{"hash_sum", "i32", "int32_sum"}},
{"bool"})},
+ {"project",
+ ProjectNodeOptions({field_ref("int32_sum"),
+ call("multiply", {field_ref("int32_sum"),
literal(2)})})}});
+ auto expected_out_schema =
+ schema({field("int32_sum", int64()), field("multiply(int32_sum, 2)",
int64())});
+ ASSERT_OK_AND_ASSIGN(auto actual_out_schema,
DeclarationToSchema(std::move(plan)));
+ AssertSchemaEqual(expected_out_schema, actual_out_schema);
+}
+
+TEST(ExecPlanExecution, DeclarationToReader) {
+ auto basic_data = MakeBasicBatches();
+ auto plan = Declaration::Sequence(
+ {{"source", SourceNodeOptions(basic_data.schema, basic_data.gen(false,
false))}});
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<RecordBatchReader> reader,
+ DeclarationToReader(plan));
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> out, reader->ToTable());
+ ASSERT_EQ(5, out->num_rows());
+ ASSERT_OK(reader->Close());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("already closed reader"),
+ reader->Next());
+}
+
TEST(ExecPlanExecution, ConsumingSinkNames) {
struct SchemaKeepingConsumer : public SinkNodeConsumer {
std::shared_ptr<Schema> schema_;
diff --git a/cpp/src/arrow/flight/sql/example/acero_server.cc b/cpp/src/arrow/flight/sql/example/acero_server.cc
index 43b69d669f..ed5422e81f 100644
--- a/cpp/src/arrow/flight/sql/example/acero_server.cc
+++ b/cpp/src/arrow/flight/sql/example/acero_server.cc
@@ -35,108 +35,6 @@ namespace sql {
namespace acero_example {
namespace {
-/// \brief A SinkNodeConsumer that saves the schema as given to it by
-/// the ExecPlan. Used to retrieve the schema of a Substrait plan to
-/// fulfill the Flight SQL API contract.
-class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer {
- public:
- Status Init(const std::shared_ptr<Schema>& schema,
compute::BackpressureControl*,
- compute::ExecPlan* plan) override {
- schema_ = schema;
- return Status::OK();
- }
- Status Consume(compute::ExecBatch exec_batch) override { return
Status::OK(); }
- Future<> Finish() override { return Status::OK(); }
-
- const std::shared_ptr<Schema>& schema() const { return schema_; }
-
- private:
- std::shared_ptr<Schema> schema_;
-};
-
-/// \brief A SinkNodeConsumer that internally saves batches into a
-/// queue, so that it can be read from a RecordBatchReader. In other
-/// words, this bridges a push-based interface (ExecPlan) to a
-/// pull-based interface (RecordBatchReader).
-class QueuingSinkNodeConsumer : public compute::SinkNodeConsumer {
- public:
- QueuingSinkNodeConsumer() : schema_(nullptr), finished_(false) {}
-
- Status Init(const std::shared_ptr<Schema>& schema,
compute::BackpressureControl*,
- compute::ExecPlan* plan) override {
- schema_ = schema;
- return Status::OK();
- }
-
- Status Consume(compute::ExecBatch exec_batch) override {
- {
- std::lock_guard<std::mutex> guard(mutex_);
- batches_.push_back(std::move(exec_batch));
- batches_added_.notify_all();
- }
-
- return Status::OK();
- }
-
- Future<> Finish() override {
- {
- std::lock_guard<std::mutex> guard(mutex_);
- finished_ = true;
- batches_added_.notify_all();
- }
-
- return Status::OK();
- }
-
- const std::shared_ptr<Schema>& schema() const { return schema_; }
-
- arrow::Result<std::shared_ptr<RecordBatch>> Next() {
- compute::ExecBatch batch;
- {
- std::unique_lock<std::mutex> guard(mutex_);
- batches_added_.wait(guard, [this] { return !batches_.empty() ||
finished_; });
-
- if (finished_ && batches_.empty()) {
- return nullptr;
- }
- batch = std::move(batches_.front());
- batches_.pop_front();
- }
-
- return batch.ToRecordBatch(schema_);
- }
-
- private:
- std::mutex mutex_;
- std::condition_variable batches_added_;
- std::deque<compute::ExecBatch> batches_;
- std::shared_ptr<Schema> schema_;
- bool finished_;
-};
-
-/// \brief A RecordBatchReader that pulls from the
-/// QueuingSinkNodeConsumer above, blocking until results are
-/// available as necessary.
-class ConsumerBasedRecordBatchReader : public RecordBatchReader {
- public:
- explicit ConsumerBasedRecordBatchReader(
- std::shared_ptr<compute::ExecPlan> plan,
- std::shared_ptr<QueuingSinkNodeConsumer> consumer)
- : plan_(std::move(plan)), consumer_(std::move(consumer)) {}
-
- std::shared_ptr<Schema> schema() const override { return
consumer_->schema(); }
-
- Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
- return consumer_->Next().Value(batch);
- }
-
- // TODO(ARROW-17242): FlightDataStream needs to call Close()
- Status Close() override { return plan_->finished().status(); }
-
- private:
- std::shared_ptr<compute::ExecPlan> plan_;
- std::shared_ptr<QueuingSinkNodeConsumer> consumer_;
-};
/// \brief An implementation of a Flight SQL service backed by Acero.
class AceroFlightSqlServer : public FlightSqlServerBase {
@@ -193,18 +91,14 @@ class AceroFlightSqlServer : public FlightSqlServerBase {
// GetFlightInfoSubstraitPlan encodes the plan into the ticket
std::shared_ptr<Buffer> serialized_plan =
Buffer::FromString(command.statement_handle);
- std::shared_ptr<QueuingSinkNodeConsumer> consumer =
- std::make_shared<QueuingSinkNodeConsumer>();
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<compute::ExecPlan> plan,
- engine::DeserializePlan(*serialized_plan, consumer));
-
- ARROW_LOG(INFO) << "DoGetStatement: executing plan " << plan->ToString();
+ ARROW_ASSIGN_OR_RAISE(compute::Declaration plan,
+ engine::DeserializePlan(*serialized_plan));
- plan->StartProducing();
+ ARROW_LOG(INFO) << "DoGetStatement: executing plan "
+ << compute::DeclarationToString(plan).ValueOr("Invalid
plan");
- auto reader =
std::make_shared<ConsumerBasedRecordBatchReader>(std::move(plan),
-
std::move(consumer));
- return std::make_unique<RecordBatchStream>(reader);
+ ARROW_ASSIGN_OR_RAISE(auto reader, compute::DeclarationToReader(plan));
+ return std::make_unique<RecordBatchStream>(std::move(reader));
}
arrow::Result<int64_t> DoPutCommandSubstraitPlan(
@@ -263,23 +157,8 @@ class AceroFlightSqlServer : public FlightSqlServerBase {
arrow::Result<std::shared_ptr<arrow::Schema>> GetPlanSchema(
const std::string& serialized_plan) {
std::shared_ptr<Buffer> plan_buf = Buffer::FromString(serialized_plan);
- std::shared_ptr<GetSchemaSinkNodeConsumer> consumer =
- std::make_shared<GetSchemaSinkNodeConsumer>();
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<compute::ExecPlan> plan,
- engine::DeserializePlan(*plan_buf, consumer));
- std::shared_ptr<Schema> output_schema;
- for (compute::ExecNode* possible_sink : plan->nodes()) {
- if (possible_sink->is_sink()) {
- // Force SinkNodeConsumer::Init to be called
- ARROW_RETURN_NOT_OK(possible_sink->StartProducing());
- output_schema = consumer->schema();
- break;
- }
- }
- if (!output_schema) {
- return Status::Invalid("Could not infer output schema");
- }
- return output_schema;
+ ARROW_ASSIGN_OR_RAISE(compute::Declaration plan,
engine::DeserializePlan(*plan_buf));
+ return compute::DeclarationToSchema(plan);
}
arrow::Result<std::unique_ptr<FlightInfo>> MakeFlightInfo(