This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new bd8005151c GH-34405: [C++] Add support for custom names in
QueryOptions. Wire this up to Substrait (#34406)
bd8005151c is described below
commit bd8005151cac0470474c0c65b6a9299f2c0bde83
Author: Weston Pace <[email protected]>
AuthorDate: Tue Mar 7 14:19:01 2023 -0800
GH-34405: [C++] Add support for custom names in QueryOptions. Wire this up
to Substrait (#34406)
### Rationale for this change
Users want to be able to specify custom column names / aliases instead of
using the ones generated by Acero
### What changes are included in this PR?
It is now possible to specify custom column names in QueryOptions. In
addition, the python Substrait bindings now use this feature so that the
Substrait plan's names will be respected.
### Are these changes tested?
Yes. These are tested directly. In addition, I added a python test for
the Substrait bindings as this is actually a regression there and this should
close https://github.com/apache/arrow/issues/33434.
### Are there any user-facing changes?
There is new API surface but nothing breaking.
* Closes: #34405
* Closes: gh-33434
Authored-by: Weston Pace <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
cpp/src/arrow/compute/exec/exec_plan.cc | 70 +++++++++++++++++++++---
cpp/src/arrow/compute/exec/exec_plan.h | 12 +++-
cpp/src/arrow/compute/exec/options.h | 12 +++-
cpp/src/arrow/compute/exec/plan_test.cc | 30 ++++++++++
cpp/src/arrow/compute/exec/sink_node.cc | 19 +++----
cpp/src/arrow/engine/substrait/relation.h | 16 ++++++
cpp/src/arrow/engine/substrait/serde.cc | 51 +++++++++++------
cpp/src/arrow/engine/substrait/serde.h | 3 +-
cpp/src/arrow/engine/substrait/util.cc | 12 +++-
cpp/src/arrow/flight/sql/example/acero_server.cc | 14 +++--
python/pyarrow/tests/test_substrait.py | 46 +++++++++++++++-
11 files changed, 234 insertions(+), 51 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc
b/cpp/src/arrow/compute/exec/exec_plan.cc
index d119dc271a..c7e745d7a8 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -624,6 +624,7 @@ Future<std::shared_ptr<Table>> DeclarationToTableImpl(
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ExecPlan> exec_plan,
ExecPlan::Make(exec_ctx));
TableSinkNodeOptions sink_options(output_table.get());
sink_options.sequence_output = query_options.sequence_output;
+ sink_options.names = std::move(query_options.field_names);
Declaration with_sink =
Declaration::Sequence({declaration, {"table_sink", sink_options}});
ARROW_RETURN_NOT_OK(with_sink.AddToPlan(exec_plan.get()));
@@ -652,6 +653,10 @@ Future<BatchesWithCommonSchema>
DeclarationToExecBatchesImpl(
sink_options.sequence_output = options.sequence_output;
Declaration with_sink = Declaration::Sequence({declaration, {"sink",
sink_options}});
ARROW_RETURN_NOT_OK(with_sink.AddToPlan(exec_plan.get()));
+ if (!options.field_names.empty()) {
+ ARROW_ASSIGN_OR_RAISE(out_schema,
+
out_schema->WithNames(std::move(options.field_names)));
+ }
ARROW_RETURN_NOT_OK(exec_plan->Validate());
exec_plan->StartProducing();
auto collected_fut = CollectAsyncGenerator(sink_gen);
@@ -800,6 +805,19 @@ Result<std::vector<std::shared_ptr<RecordBatch>>>
DeclarationToBatches(
use_threads);
}
+Result<std::vector<std::shared_ptr<RecordBatch>>> DeclarationToBatches(
+ Declaration declaration, QueryOptions query_options) {
+ if (query_options.custom_cpu_executor != nullptr) {
+ return Status::Invalid("Cannot use synchronous methods with a custom CPU
executor");
+ }
+ return ::arrow::internal::RunSynchronously<
+ Future<std::vector<std::shared_ptr<RecordBatch>>>>(
+ [=, declaration = std::move(declaration)](::arrow::internal::Executor*
executor) {
+ return DeclarationToBatchesImpl(std::move(declaration), query_options,
executor);
+ },
+ query_options.use_threads);
+}
+
Future<BatchesWithCommonSchema> DeclarationToExecBatchesAsync(Declaration
declaration,
ExecContext
exec_context) {
return DeclarationToExecBatchesImpl(std::move(declaration),
@@ -925,14 +943,35 @@ struct BatchConverter {
});
}
+ Result<std::shared_ptr<Schema>> InitializeSchema(
+ const std::vector<std::string>& names) {
+ // By this point this->schema will have been set by the SinkNode. We
potentially
+ // rename it with the names provided by the user and then return this in
case the user
+ // wants to know the output schema.
+ if (!names.empty()) {
+ if (static_cast<int>(names.size()) != schema->num_fields()) {
+ return Status::Invalid(
+ "A plan was created with custom field names but the number of
names (",
+ names.size(),
+ ") did not "
+ "match the number of output columns (",
+ schema->num_fields(), ")");
+ }
+ ARROW_ASSIGN_OR_RAISE(schema, schema->WithNames(names));
+ }
+ return schema;
+ }
+
AsyncGenerator<std::optional<ExecBatch>> exec_batch_gen;
std::shared_ptr<Schema> schema;
std::shared_ptr<ExecPlan> exec_plan;
};
Result<AsyncGenerator<std::shared_ptr<RecordBatch>>>
DeclarationToRecordBatchGenerator(
- Declaration declaration, ExecContext exec_ctx, std::shared_ptr<Schema>*
out_schema) {
+ Declaration declaration, QueryOptions options,
+ ::arrow::internal::Executor* cpu_executor, std::shared_ptr<Schema>*
out_schema) {
auto converter = std::make_shared<BatchConverter>();
+ ExecContext exec_ctx(options.memory_pool, cpu_executor,
options.function_registry);
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ExecPlan> plan,
ExecPlan::Make(exec_ctx));
Declaration with_sink = Declaration::Sequence(
{declaration,
@@ -941,23 +980,28 @@ Result<AsyncGenerator<std::shared_ptr<RecordBatch>>>
DeclarationToRecordBatchGen
ARROW_RETURN_NOT_OK(plan->Validate());
plan->StartProducing();
converter->exec_plan = std::move(plan);
- *out_schema = converter->schema;
+ ARROW_ASSIGN_OR_RAISE(*out_schema,
converter->InitializeSchema(options.field_names));
return [conv = std::move(converter)] { return (*conv)(); };
}
+
} // namespace
-Result<std::unique_ptr<RecordBatchReader>> DeclarationToReader(
- Declaration declaration, bool use_threads, MemoryPool* memory_pool,
- FunctionRegistry* function_registry) {
+Result<std::unique_ptr<RecordBatchReader>> DeclarationToReader(Declaration
declaration,
+ QueryOptions
options) {
+ if (options.custom_cpu_executor != nullptr) {
+ return Status::Invalid("Cannot use synchronous methods with a custom CPU
executor");
+ }
std::shared_ptr<Schema> schema;
auto batch_iterator =
std::make_unique<Iterator<std::shared_ptr<RecordBatch>>>(
::arrow::internal::IterateSynchronously<std::shared_ptr<RecordBatch>>(
[&](::arrow::internal::Executor* executor)
-> Result<AsyncGenerator<std::shared_ptr<RecordBatch>>> {
- ExecContext exec_ctx(memory_pool, executor, function_registry);
- return DeclarationToRecordBatchGenerator(declaration, exec_ctx,
&schema);
+ ExecContext exec_ctx(options.memory_pool, executor,
+ options.function_registry);
+ return DeclarationToRecordBatchGenerator(declaration,
std::move(options),
+ executor, &schema);
},
- use_threads));
+ options.use_threads));
struct PlanReader : RecordBatchReader {
PlanReader(std::shared_ptr<Schema> schema,
@@ -994,6 +1038,16 @@ Result<std::unique_ptr<RecordBatchReader>>
DeclarationToReader(
return std::make_unique<PlanReader>(std::move(schema),
std::move(batch_iterator));
}
+Result<std::unique_ptr<RecordBatchReader>> DeclarationToReader(
+ Declaration declaration, bool use_threads, MemoryPool* memory_pool,
+ FunctionRegistry* function_registry) {
+ QueryOptions options;
+ options.memory_pool = memory_pool;
+ options.function_registry = function_registry;
+ options.use_threads = use_threads;
+ return DeclarationToReader(std::move(declaration), std::move(options));
+}
+
namespace internal {
void RegisterSourceNode(ExecFactoryRegistry*);
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h
b/cpp/src/arrow/compute/exec/exec_plan.h
index 1f47515e3f..83b9248eb6 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -505,6 +505,12 @@ struct ARROW_EXPORT QueryOptions {
///
/// Must remain valid for the duration of the plan.
FunctionRegistry* function_registry = GetFunctionRegistry();
+ /// \brief the names of the output columns
+ ///
+ /// If this is empty then names will be generated based on the input columns
+ ///
+ /// If set then the number of names must equal the number of output columns
+ std::vector<std::string> field_names;
};
/// \brief Calculate the output schema of a declaration
@@ -622,6 +628,9 @@ ARROW_EXPORT
Result<std::vector<std::shared_ptr<RecordBatch>>> DeclarationToBatc
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
+ARROW_EXPORT Result<std::vector<std::shared_ptr<RecordBatch>>>
DeclarationToBatches(
+ Declaration declaration, QueryOptions query_options);
+
/// \brief Asynchronous version of \see DeclarationToBatches
///
/// \see DeclarationToTableAsync for details on threading & execution
@@ -656,9 +665,8 @@ ARROW_EXPORT Result<std::unique_ptr<RecordBatchReader>>
DeclarationToReader(
MemoryPool* memory_pool = default_memory_pool(),
FunctionRegistry* function_registry = NULLPTR);
-/// \brief Overload of \see DeclarationToReader accepting a custom exec context
ARROW_EXPORT Result<std::unique_ptr<RecordBatchReader>> DeclarationToReader(
- Declaration declaration, ExecContext exec_context);
+ Declaration declaration, QueryOptions query_options);
/// \brief Utility method to run a declaration and ignore results
///
diff --git a/cpp/src/arrow/compute/exec/options.h
b/cpp/src/arrow/compute/exec/options.h
index 628c547dfc..f532dd1c09 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -352,7 +352,6 @@ class ARROW_EXPORT SinkNodeConsumer {
/// This will be run once the schema is finalized as the plan is starting and
/// before any calls to Consume. A common use is to save off the schema so
that
/// batches can be interpreted.
- /// TODO(ARROW-17837) Move ExecPlan* plan to query context
virtual Status Init(const std::shared_ptr<Schema>& schema,
BackpressureControl* backpressure_control, ExecPlan*
plan) = 0;
/// \brief Consume a batch of data
@@ -380,7 +379,9 @@ class ARROW_EXPORT ConsumingSinkNodeOptions : public
ExecNodeOptions {
/// \brief Names to rename the sink's schema fields to
///
/// If specified then names must be provided for all fields. Currently, only
a flat
- /// schema is supported (see ARROW-15901).
+ /// schema is supported (see GH-31875).
+ ///
+ /// If not specified then names will be generated based on the source data.
std::vector<std::string> names;
/// \brief Controls whether batches should be emitted immediately or
sequenced in order
///
@@ -614,6 +615,13 @@ class ARROW_EXPORT TableSinkNodeOptions : public
ExecNodeOptions {
///
/// \see QueryOptions for more details
std::optional<bool> sequence_output;
+ /// \brief Custom names to use for the columns.
+ ///
+ /// If specified then names must be provided for all fields. Currently, only
a flat
+ /// schema is supported (see GH-31875).
+ ///
+ /// If not specified then names will be generated based on the source data.
+ std::vector<std::string> names;
};
struct ARROW_EXPORT PivotLongerRowTemplate {
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc
b/cpp/src/arrow/compute/exec/plan_test.cc
index e0868ebf12..eac4d12a06 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -549,6 +549,36 @@
custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32,
)a");
}
+TEST(ExecPlanExecution, CustomFieldNames) {
+ Declaration source = gen::Gen({{"x", gen::Step()}})
+ ->FailOnError()
+ ->SourceNode(/*rows_per_batch=*/1,
/*num_batches=*/1);
+
+ QueryOptions opts;
+ opts.field_names = {"y"};
+
+ ASSERT_OK_AND_ASSIGN(std::vector<std::shared_ptr<RecordBatch>> batches,
+ DeclarationToBatches(source, opts));
+
+ std::shared_ptr<Schema> expected_schema = schema({field("y", uint32())});
+
+ for (const auto& batch : batches) {
+ AssertSchemaEqual(*expected_schema, *batch->schema());
+ }
+
+ ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema batches_with_schema,
+ DeclarationToExecBatches(source, opts));
+
+ AssertSchemaEqual(*expected_schema, *batches_with_schema.schema);
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> table,
DeclarationToTable(source, opts));
+ AssertSchemaEqual(*expected_schema, *table->schema());
+
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<RecordBatchReader> reader,
+ DeclarationToReader(source, opts));
+ AssertSchemaEqual(*expected_schema, *reader->schema());
+}
+
TEST(ExecPlanExecution, SourceOrderBy) {
std::vector<ExecBatch> expected = {
ExecBatchFromJSON({int32(), boolean()},
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc
b/cpp/src/arrow/compute/exec/sink_node.cc
index 34341a241e..79c2472822 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -104,7 +104,7 @@ class SinkNode : public ExecNode,
public:
SinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
AsyncGenerator<std::optional<ExecBatch>>* generator,
- std::shared_ptr<Schema>* schema, BackpressureOptions backpressure,
+ std::shared_ptr<Schema>* schema_out, BackpressureOptions
backpressure,
BackpressureMonitor** backpressure_monitor_out,
std::optional<bool> sequence_output)
: ExecNode(plan, std::move(inputs), {"collected"}, {}),
@@ -117,8 +117,8 @@ class SinkNode : public ExecNode,
*backpressure_monitor_out = &backpressure_queue_;
}
auto node_destroyed_capture = node_destroyed_;
- if (schema) {
- *schema = inputs_[0]->output_schema();
+ if (schema_out) {
+ *schema_out = inputs_[0]->output_schema();
}
*generator = [this, node_destroyed_capture]() ->
Future<std::optional<ExecBatch>> {
if (*node_destroyed_capture) {
@@ -334,15 +334,11 @@ class ConsumingSinkNode : public ExecNode,
if (names_.size() > 0) {
int num_fields = output_schema->num_fields();
if (names_.size() != static_cast<size_t>(num_fields)) {
- return Status::Invalid("ConsumingSinkNode with mismatched number of
names");
- }
- FieldVector fields(num_fields);
- int i = 0;
- for (const auto& output_field : output_schema->fields()) {
- fields[i] = field(names_[i], output_field->type());
- ++i;
+ return Status::Invalid(
+ "A plan was created with custom field names but the number of
names did not "
+ "match the number of output columns");
}
- output_schema = schema(std::move(fields));
+ ARROW_ASSIGN_OR_RAISE(output_schema, output_schema->WithNames(names_));
}
RETURN_NOT_OK(consumer_->Init(output_schema, this, plan_));
return Status::OK();
@@ -437,6 +433,7 @@ static Result<ExecNode*> MakeTableConsumingSinkNode(
std::make_shared<TableSinkNodeConsumer>(sink_options.output_table, pool);
auto consuming_sink_node_options = ConsumingSinkNodeOptions{tb_consumer};
consuming_sink_node_options.sequence_output = sink_options.sequence_output;
+ consuming_sink_node_options.names = sink_options.names;
return MakeExecNode("consuming_sink", plan, inputs,
consuming_sink_node_options);
}
diff --git a/cpp/src/arrow/engine/substrait/relation.h
b/cpp/src/arrow/engine/substrait/relation.h
index 39f7b65069..d750abd4e3 100644
--- a/cpp/src/arrow/engine/substrait/relation.h
+++ b/cpp/src/arrow/engine/substrait/relation.h
@@ -50,5 +50,21 @@ struct ARROW_ENGINE_EXPORT RelationInfo {
std::optional<std::vector<int>> field_output_indices;
};
+/// Information resulting from converting a Substrait plan
+struct ARROW_ENGINE_EXPORT PlanInfo {
+ /// The root declaration.
+ ///
+ /// Only plans containing a single top-level relation are supported and so
this will
+ /// represent that relation.
+ ///
+ /// This should technically be a RelRoot but some producers use a simple Rel
here and so
+ /// Acero currently supports that case.
+ DeclarationInfo root;
+ /// The names of the output fields
+ ///
+ /// If `root` was created from a simple Rel then this will be empty
+ std::vector<std::string> names;
+};
+
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/serde.cc
b/cpp/src/arrow/engine/substrait/serde.cc
index 0c4fe99317..a6116af959 100644
--- a/cpp/src/arrow/engine/substrait/serde.cc
+++ b/cpp/src/arrow/engine/substrait/serde.cc
@@ -127,13 +127,6 @@ DeclarationFactory MakeWriteDeclarationFactory(
};
}
-DeclarationFactory MakeNoSinkDeclarationFactory() {
- return [](compute::Declaration input,
- std::vector<std::string> names) -> Result<compute::Declaration> {
- return input;
- };
-}
-
constexpr uint32_t kMinimumMajorVersion = 0;
constexpr uint32_t kMinimumMinorVersion = 20;
@@ -194,19 +187,45 @@ Result<std::vector<compute::Declaration>>
DeserializePlans(
registry, ext_set_out, conversion_options);
}
-ARROW_ENGINE_EXPORT Result<compute::Declaration> DeserializePlan(
+ARROW_ENGINE_EXPORT Result<PlanInfo> DeserializePlan(
const Buffer& buf, const ExtensionIdRegistry* registry, ExtensionSet*
ext_set_out,
const ConversionOptions& conversion_options) {
- ARROW_ASSIGN_OR_RAISE(std::vector<compute::Declaration> top_level_decls,
- DeserializePlans(buf, MakeNoSinkDeclarationFactory(),
registry,
- ext_set_out, conversion_options));
- if (top_level_decls.empty()) {
- return Status::Invalid("No RelRoot in plan");
+ ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer<substrait::Plan>(buf));
+
+ if (plan.version().major_number() < kMinimumMajorVersion &&
+ plan.version().minor_number() < kMinimumMinorVersion) {
+ return Status::Invalid("Can only parse plans with a version >= ",
+ kMinimumMajorVersion, ".", kMinimumMinorVersion);
}
- if (top_level_decls.size() != 1) {
- return Status::Invalid("Multiple top level declarations found in Substrait
plan");
+
+ ARROW_ASSIGN_OR_RAISE(auto ext_set,
+ GetExtensionSetFromPlan(plan, conversion_options,
registry));
+
+ if (plan.relations_size() == 0) {
+ return Status::Invalid("Plan has no relations");
+ }
+ if (plan.relations_size() > 1) {
+ return Status::NotImplemented("Common sub-plans");
+ }
+ const substrait::PlanRel& root_rel = plan.relations(0);
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto decl_info,
+ FromProto(root_rel.has_root() ? root_rel.root().input() :
root_rel.rel(), ext_set,
+ conversion_options));
+
+ std::vector<std::string> names;
+ if (root_rel.has_root()) {
+ names.assign(root_rel.root().names().begin(),
root_rel.root().names().end());
+ ARROW_ASSIGN_OR_RAISE(decl_info.output_schema,
+ decl_info.output_schema->WithNames(names));
}
- return top_level_decls[0];
+
+ if (ext_set_out) {
+ *ext_set_out = std::move(ext_set);
+ }
+
+ return PlanInfo{std::move(decl_info), std::move(names)};
}
namespace {
diff --git a/cpp/src/arrow/engine/substrait/serde.h
b/cpp/src/arrow/engine/substrait/serde.h
index a4e3b3df14..c13e0b90f3 100644
--- a/cpp/src/arrow/engine/substrait/serde.h
+++ b/cpp/src/arrow/engine/substrait/serde.h
@@ -28,6 +28,7 @@
#include "arrow/compute/type_fwd.h"
#include "arrow/dataset/type_fwd.h"
#include "arrow/engine/substrait/options.h"
+#include "arrow/engine/substrait/relation.h"
#include "arrow/engine/substrait/type_fwd.h"
#include "arrow/engine/substrait/visibility.h"
#include "arrow/result.h"
@@ -151,7 +152,7 @@ ARROW_ENGINE_EXPORT
Result<std::shared_ptr<compute::ExecPlan>> DeserializePlan(
/// Plan is returned here.
/// \param[in] conversion_options options to control how the conversion is to
be done.
/// \return A declaration representing the Substrait plan
-ARROW_ENGINE_EXPORT Result<compute::Declaration> DeserializePlan(
+ARROW_ENGINE_EXPORT Result<PlanInfo> DeserializePlan(
const Buffer& buf, const ExtensionIdRegistry* registry = NULLPTR,
ExtensionSet* ext_set_out = NULLPTR,
const ConversionOptions& conversion_options = {});
diff --git a/cpp/src/arrow/engine/substrait/util.cc
b/cpp/src/arrow/engine/substrait/util.cc
index e0c876d21d..a00a7470fc 100644
--- a/cpp/src/arrow/engine/substrait/util.cc
+++ b/cpp/src/arrow/engine/substrait/util.cc
@@ -28,6 +28,7 @@
#include "arrow/compute/exec/options.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/engine/substrait/relation.h"
#include "arrow/engine/substrait/serde.h"
#include "arrow/engine/substrait/type_fwd.h"
#include "arrow/status.h"
@@ -44,11 +45,16 @@ Result<std::shared_ptr<RecordBatchReader>>
ExecuteSerializedPlan(
const Buffer& substrait_buffer, const ExtensionIdRegistry* registry,
compute::FunctionRegistry* func_registry, const ConversionOptions&
conversion_options,
bool use_threads, MemoryPool* memory_pool) {
- ARROW_ASSIGN_OR_RAISE(compute::Declaration plan,
+ ARROW_ASSIGN_OR_RAISE(PlanInfo plan_info,
DeserializePlan(substrait_buffer, registry,
/*ext_set_out=*/nullptr,
conversion_options));
- return compute::DeclarationToReader(std::move(plan), use_threads,
memory_pool,
- func_registry);
+ compute::QueryOptions query_options;
+ query_options.memory_pool = memory_pool;
+ query_options.function_registry = func_registry;
+ query_options.use_threads = use_threads;
+ query_options.field_names = plan_info.names;
+ return compute::DeclarationToReader(std::move(plan_info.root.declaration),
+ std::move(query_options));
}
Result<std::shared_ptr<Buffer>> SerializeJsonPlan(const std::string&
substrait_json) {
diff --git a/cpp/src/arrow/flight/sql/example/acero_server.cc
b/cpp/src/arrow/flight/sql/example/acero_server.cc
index ed5422e81f..c65ad186f8 100644
--- a/cpp/src/arrow/flight/sql/example/acero_server.cc
+++ b/cpp/src/arrow/flight/sql/example/acero_server.cc
@@ -91,13 +91,15 @@ class AceroFlightSqlServer : public FlightSqlServerBase {
// GetFlightInfoSubstraitPlan encodes the plan into the ticket
std::shared_ptr<Buffer> serialized_plan =
Buffer::FromString(command.statement_handle);
- ARROW_ASSIGN_OR_RAISE(compute::Declaration plan,
+ ARROW_ASSIGN_OR_RAISE(engine::PlanInfo plan,
engine::DeserializePlan(*serialized_plan));
- ARROW_LOG(INFO) << "DoGetStatement: executing plan "
- << compute::DeclarationToString(plan).ValueOr("Invalid
plan");
+ ARROW_LOG(INFO)
+ << "DoGetStatement: executing plan "
+ <<
compute::DeclarationToString(plan.root.declaration).ValueOr("Invalid plan");
- ARROW_ASSIGN_OR_RAISE(auto reader, compute::DeclarationToReader(plan));
+ ARROW_ASSIGN_OR_RAISE(auto reader,
+ compute::DeclarationToReader(plan.root.declaration));
return std::make_unique<RecordBatchStream>(std::move(reader));
}
@@ -157,8 +159,8 @@ class AceroFlightSqlServer : public FlightSqlServerBase {
arrow::Result<std::shared_ptr<arrow::Schema>> GetPlanSchema(
const std::string& serialized_plan) {
std::shared_ptr<Buffer> plan_buf = Buffer::FromString(serialized_plan);
- ARROW_ASSIGN_OR_RAISE(compute::Declaration plan,
engine::DeserializePlan(*plan_buf));
- return compute::DeclarationToSchema(plan);
+ ARROW_ASSIGN_OR_RAISE(engine::PlanInfo plan,
engine::DeserializePlan(*plan_buf));
+ return compute::DeclarationToSchema(plan.root.declaration);
}
arrow::Result<std::unique_ptr<FlightInfo>> MakeFlightInfo(
diff --git a/python/pyarrow/tests/test_substrait.py
b/python/pyarrow/tests/test_substrait.py
index 87d3bfc444..d0da517ea7 100644
--- a/python/pyarrow/tests/test_substrait.py
+++ b/python/pyarrow/tests/test_substrait.py
@@ -116,7 +116,7 @@ def test_invalid_plan():
}
"""
buf = pa._substrait._parse_json_plan(tobytes(query))
- exec_message = "No RelRoot in plan"
+ exec_message = "Plan has no relations"
with pytest.raises(ArrowInvalid, match=exec_message):
substrait.run_query(buf)
@@ -443,7 +443,6 @@ def test_udf_via_substrait(unary_func_fixture, use_threads):
function, name = unary_func_fixture
expected_tb = test_table.add_column(1, 'y', function(
mock_scalar_udf_context(10), test_table['x']))
- res_tb = res_tb.rename_columns(['x', 'y'])
assert res_tb == expected_tb
@@ -563,3 +562,46 @@ def test_udf_via_substrait_wrong_udf_name():
with pytest.raises(pa.ArrowKeyError) as excinfo:
pa.substrait.run_query(buf, table_provider=table_provider)
assert "No function registered" in str(excinfo.value)
+
+
[email protected]("use_threads", [True, False])
+def test_output_field_names(use_threads):
+ in_table = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+ def table_provider(names, schema):
+ return in_table
+
+ substrait_query = """
+ {
+ "version": { "major": 9999 },
+ "relations": [
+ {
+ "root": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "struct": {
+ "types": [{"i64": {}}]
+ },
+ "names": ["x"]
+ },
+ "namedTable": {
+ "names": ["t1"]
+ }
+ }
+ },
+ "names": ["out"]
+ }
+ }
+ ]
+ }
+ """
+
+ buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+ reader = pa.substrait.run_query(
+ buf, table_provider=table_provider, use_threads=use_threads)
+ res_tb = reader.read_all()
+
+ expected = pa.Table.from_pydict({"out": [1, 2, 3]})
+
+ assert res_tb == expected