This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 342b74e419 GH-34786: [C++] Fix output schema calculated by Substrait
consumer for AggregateRel (#34904)
342b74e419 is described below
commit 342b74e4192a69e849b67e36b054078c208fac91
Author: rtpsw <[email protected]>
AuthorDate: Wed Apr 12 19:43:38 2023 +0300
GH-34786: [C++] Fix output schema calculated by Substrait consumer for
AggregateRel (#34904)
See #34786
* Closes: #34786
Can replace #34885
Lead-authored-by: Yaron Gvili <[email protected]>
Co-authored-by: rtpsw <[email protected]>
Co-authored-by: Weston Pace <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
cpp/src/arrow/acero/aggregate_node.cc | 187 ++++++++++++++-------
cpp/src/arrow/acero/aggregate_node.h | 57 +++++++
cpp/src/arrow/acero/hash_aggregate_test.cc | 73 ++++++++
cpp/src/arrow/engine/substrait/options.cc | 37 ++--
.../arrow/engine/substrait/relation_internal.cc | 73 ++------
cpp/src/arrow/engine/substrait/relation_internal.h | 24 +--
6 files changed, 293 insertions(+), 158 deletions(-)
diff --git a/cpp/src/arrow/acero/aggregate_node.cc
b/cpp/src/arrow/acero/aggregate_node.cc
index bd97235df6..c5b4442544 100644
--- a/cpp/src/arrow/acero/aggregate_node.cc
+++ b/cpp/src/arrow/acero/aggregate_node.cc
@@ -21,6 +21,7 @@
#include <unordered_map>
#include <unordered_set>
+#include "arrow/acero/aggregate_node.h"
#include "arrow/acero/exec_plan.h"
#include "arrow/acero/options.h"
#include "arrow/acero/query_context.h"
@@ -77,6 +78,19 @@ namespace acero {
namespace {
+template <typename KernelType>
+struct AggregateNodeArgs {
+ std::shared_ptr<Schema> output_schema;
+ std::vector<int> grouping_key_field_ids;
+ std::vector<int> segment_key_field_ids;
+ std::unique_ptr<RowSegmenter> segmenter;
+ std::vector<std::vector<int>> target_fieldsets;
+ std::vector<Aggregate> aggregates;
+ std::vector<const KernelType*> kernels;
+ std::vector<std::vector<TypeHolder>> kernel_intypes;
+ std::vector<std::vector<std::unique_ptr<KernelState>>> states;
+};
+
std::vector<TypeHolder> ExtendWithGroupIdType(const std::vector<TypeHolder>&
in_types) {
std::vector<TypeHolder> aggr_in_types;
aggr_in_types.reserve(in_types.size() + 1);
@@ -274,36 +288,22 @@ class ScalarAggregateNode : public ExecNode, public
TracedNode {
kernel_intypes_(std::move(kernel_intypes)),
states_(std::move(states)) {}
- static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
- const ExecNodeOptions& options) {
- RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1,
"ScalarAggregateNode"));
-
- const auto& aggregate_options = checked_cast<const
AggregateNodeOptions&>(options);
- auto aggregates = aggregate_options.aggregates;
- const auto& keys = aggregate_options.keys;
- const auto& segment_keys = aggregate_options.segment_keys;
-
- if (keys.size() > 0) {
- return Status::Invalid("Scalar aggregation with some key");
- }
- if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
- segment_keys.size() > 0) {
- return Status::NotImplemented("Segmented aggregation in a multi-threaded
plan");
- }
-
- const auto& input_schema = *inputs[0]->output_schema();
- auto exec_ctx = plan->query_context()->exec_context();
-
+ static Result<AggregateNodeArgs<ScalarAggregateKernel>>
MakeAggregateNodeArgs(
+ const std::shared_ptr<Schema>& input_schema, const
std::vector<FieldRef>& keys,
+ const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>&
aggs,
+ ExecContext* exec_ctx, size_t concurrency) {
+ // Copy (need to modify options pointer below)
+ std::vector<Aggregate> aggregates(aggs);
std::vector<int> segment_field_ids(segment_keys.size());
std::vector<TypeHolder> segment_key_types(segment_keys.size());
for (size_t i = 0; i < segment_keys.size(); i++) {
- ARROW_ASSIGN_OR_RAISE(FieldPath match,
segment_keys[i].FindOne(input_schema));
+ ARROW_ASSIGN_OR_RAISE(FieldPath match,
segment_keys[i].FindOne(*input_schema));
if (match.indices().size() > 1) {
// ARROW-18369: Support nested references as segment ids
return Status::Invalid("Nested references cannot be used as segment
ids");
}
segment_field_ids[i] = match[0];
- segment_key_types[i] = input_schema.field(match[0])->type().get();
+ segment_key_types[i] = input_schema->field(match[0])->type().get();
}
ARROW_ASSIGN_OR_RAISE(auto segmenter,
@@ -317,16 +317,15 @@ class ScalarAggregateNode : public ExecNode, public
TracedNode {
// Output the segment keys first, followed by the aggregates
for (size_t i = 0; i < segment_keys.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(fields[i],
-
segment_keys[i].GetOne(*inputs[0]->output_schema()));
+ ARROW_ASSIGN_OR_RAISE(fields[i], segment_keys[i].GetOne(*input_schema));
}
std::vector<std::vector<int>> target_fieldsets(kernels.size());
std::size_t base = segment_keys.size();
for (size_t i = 0; i < kernels.size(); ++i) {
- const auto& target_fieldset = aggregate_options.aggregates[i].target;
+ const auto& target_fieldset = aggregates[i].target;
for (const auto& target : target_fieldset) {
- ARROW_ASSIGN_OR_RAISE(auto match,
FieldRef(target).FindOne(input_schema));
+ ARROW_ASSIGN_OR_RAISE(auto match,
FieldRef(target).FindOne(*input_schema));
target_fieldsets[i].push_back(match[0]);
}
@@ -346,7 +345,7 @@ class ScalarAggregateNode : public ExecNode, public
TracedNode {
std::vector<TypeHolder> in_types;
for (const auto& target : target_fieldsets[i]) {
- in_types.emplace_back(input_schema.field(target)->type().get());
+ in_types.emplace_back(input_schema->field(target)->type().get());
}
kernel_intypes[i] = in_types;
ARROW_ASSIGN_OR_RAISE(const Kernel* kernel,
@@ -362,7 +361,7 @@ class ScalarAggregateNode : public ExecNode, public
TracedNode {
}
KernelContext kernel_ctx{exec_ctx};
- states[i].resize(plan->query_context()->max_concurrency());
+ states[i].resize(concurrency);
RETURN_NOT_OK(Kernel::InitAll(
&kernel_ctx,
KernelInitArgs{kernels[i], kernel_intypes[i],
aggregates[i].options.get()},
@@ -373,14 +372,47 @@ class ScalarAggregateNode : public ExecNode, public
TracedNode {
ARROW_ASSIGN_OR_RAISE(auto out_type,
kernels[i]->signature->out_type().Resolve(
&kernel_ctx,
kernel_intypes[i]));
- fields[base + i] =
- field(aggregate_options.aggregates[i].name, out_type.GetSharedPtr());
+ fields[base + i] = field(aggregates[i].name, out_type.GetSharedPtr());
}
+ return AggregateNodeArgs<ScalarAggregateKernel>{
+ schema(std::move(fields)),
+ /*grouping_key_field_ids=*/{}, std::move(segment_field_ids),
+ std::move(segmenter), std::move(target_fieldsets),
+ std::move(aggregates), std::move(kernels),
+ std::move(kernel_intypes), std::move(states)};
+ }
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1,
"ScalarAggregateNode"));
+
+ const auto& aggregate_options = checked_cast<const
AggregateNodeOptions&>(options);
+ auto aggregates = aggregate_options.aggregates;
+ const auto& keys = aggregate_options.keys;
+ const auto& segment_keys = aggregate_options.segment_keys;
+
+ if (keys.size() > 0) {
+ return Status::Invalid("Scalar aggregation with some key");
+ }
+ if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
+ segment_keys.size() > 0) {
+ return Status::NotImplemented("Segmented aggregation in a multi-threaded
plan");
+ }
+
+ const auto& input_schema = inputs[0]->output_schema();
+ auto exec_ctx = plan->query_context()->exec_context();
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto args,
+ MakeAggregateNodeArgs(input_schema, keys, segment_keys, aggregates,
exec_ctx,
+
/*concurrency=*/plan->query_context()->max_concurrency()));
+
return plan->EmplaceNode<ScalarAggregateNode>(
- plan, std::move(inputs), schema(std::move(fields)),
std::move(segmenter),
- std::move(segment_field_ids), std::move(target_fieldsets),
std::move(aggregates),
- std::move(kernels), std::move(kernel_intypes), std::move(states));
+ plan, std::move(inputs), std::move(args.output_schema),
std::move(args.segmenter),
+ std::move(args.segment_key_field_ids),
std::move(args.target_fieldsets),
+ std::move(args.aggregates), std::move(args.kernels),
+ std::move(args.kernel_intypes), std::move(args.states));
}
const char* kind_name() const override { return "ScalarAggregateNode"; }
@@ -564,25 +596,10 @@ class GroupByNode : public ExecNode, public TracedNode {
return Status::OK();
}
- static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
- const ExecNodeOptions& options) {
- RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "GroupByNode"));
-
- auto input = inputs[0];
- const auto& aggregate_options = checked_cast<const
AggregateNodeOptions&>(options);
- const auto& keys = aggregate_options.keys;
- const auto& segment_keys = aggregate_options.segment_keys;
- // Copy (need to modify options pointer below)
- auto aggs = aggregate_options.aggregates;
-
- if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
- segment_keys.size() > 0) {
- return Status::NotImplemented("Segmented aggregation in a multi-threaded
plan");
- }
-
- // Get input schema
- auto input_schema = input->output_schema();
-
+ static Result<AggregateNodeArgs<HashAggregateKernel>> MakeAggregateNodeArgs(
+ const std::shared_ptr<Schema>& input_schema, const
std::vector<FieldRef>& keys,
+ const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>&
aggs,
+ ExecContext* ctx) {
// Find input field indices for key fields
std::vector<int> key_field_ids(keys.size());
for (size_t i = 0; i < keys.size(); ++i) {
@@ -632,8 +649,6 @@ class GroupByNode : public ExecNode, public TracedNode {
segment_key_types[i] =
input_schema->field(segment_key_field_id)->type().get();
}
- auto ctx = plan->query_context()->exec_context();
-
ARROW_ASSIGN_OR_RAISE(auto segmenter,
RowSegmenter::Make(std::move(segment_key_types),
/*nullable_keys=*/false, ctx));
@@ -665,14 +680,47 @@ class GroupByNode : public ExecNode, public TracedNode {
}
base += segment_keys.size();
for (size_t i = 0; i < aggs.size(); ++i) {
- output_fields[base + i] =
- agg_result_fields[i]->WithName(aggregate_options.aggregates[i].name);
+ output_fields[base + i] = agg_result_fields[i]->WithName(aggs[i].name);
+ }
+
+ return
AggregateNodeArgs<HashAggregateKernel>{schema(std::move(output_fields)),
+ std::move(key_field_ids),
+
std::move(segment_key_field_ids),
+ std::move(segmenter),
+ std::move(agg_src_fieldsets),
+ std::move(aggs),
+ std::move(agg_kernels),
+ std::move(agg_src_types),
+ /*states=*/{}};
+ }
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "GroupByNode"));
+
+ auto input = inputs[0];
+ const auto& aggregate_options = checked_cast<const
AggregateNodeOptions&>(options);
+ const auto& keys = aggregate_options.keys;
+ const auto& segment_keys = aggregate_options.segment_keys;
+ auto aggs = aggregate_options.aggregates;
+
+ if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
+ segment_keys.size() > 0) {
+ return Status::NotImplemented(
+ "Segmented aggregation in a multi-threaded execution context");
}
+ const auto& input_schema = input->output_schema();
+ auto exec_ctx = plan->query_context()->exec_context();
+
+ ARROW_ASSIGN_OR_RAISE(auto args, MakeAggregateNodeArgs(input_schema, keys,
+ segment_keys, aggs,
exec_ctx));
+
return input->plan()->EmplaceNode<GroupByNode>(
- input, schema(std::move(output_fields)), std::move(key_field_ids),
- std::move(segment_key_field_ids), std::move(segmenter),
std::move(agg_src_types),
- std::move(agg_src_fieldsets), std::move(aggs), std::move(agg_kernels));
+ input, std::move(args.output_schema),
std::move(args.grouping_key_field_ids),
+ std::move(args.segment_key_field_ids), std::move(args.segmenter),
+ std::move(args.kernel_intypes), std::move(args.target_fieldsets),
+ std::move(args.aggregates), std::move(args.kernels));
}
Status ResetKernelStates() {
@@ -981,6 +1029,27 @@ class GroupByNode : public ExecNode, public TracedNode {
} // namespace
+namespace aggregate {
+
+Result<std::shared_ptr<Schema>> MakeOutputSchema(
+ const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>&
keys,
+ const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>&
aggregates,
+ ExecContext* exec_ctx) {
+ if (keys.empty()) {
+ ARROW_ASSIGN_OR_RAISE(auto args,
ScalarAggregateNode::MakeAggregateNodeArgs(
+ input_schema, keys, segment_keys,
aggregates,
+ exec_ctx, /*concurrency=*/1));
+ return std::move(args.output_schema);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ auto args, GroupByNode::MakeAggregateNodeArgs(input_schema, keys,
segment_keys,
+ aggregates, exec_ctx));
+ return std::move(args.output_schema);
+ }
+}
+
+} // namespace aggregate
+
namespace internal {
void RegisterAggregateNode(ExecFactoryRegistry* registry) {
diff --git a/cpp/src/arrow/acero/aggregate_node.h
b/cpp/src/arrow/acero/aggregate_node.h
new file mode 100644
index 0000000000..790264b208
--- /dev/null
+++ b/cpp/src/arrow/acero/aggregate_node.h
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "arrow/acero/visibility.h"
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/type_fwd.h"
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace acero {
+namespace aggregate {
+
+using compute::Aggregate;
+using compute::default_exec_context;
+using compute::ExecContext;
+
+/// \brief Make the output schema of an aggregate node
+///
+/// The output schema is determined by the aggregation kernels, which may
depend on the
+/// ExecContext argument. To guarantee correct results, the same ExecContext
argument
+/// should be used in execution.
+///
+/// \param[in] input_schema the schema of the input to the node
+/// \param[in] keys the grouping keys for the aggregation
+/// \param[in] segment_keys the segmenting keys for the aggregation
+/// \param[in] aggregates the aggregates for the aggregation
+/// \param[in] exec_ctx the execution context for the aggregation
+ARROW_ACERO_EXPORT Result<std::shared_ptr<Schema>> MakeOutputSchema(
+ const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>&
keys,
+ const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>&
aggregates,
+ ExecContext* exec_ctx = default_exec_context());
+
+} // namespace aggregate
+} // namespace acero
+} // namespace arrow
diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc
b/cpp/src/arrow/acero/hash_aggregate_test.cc
index 0ae06d0572..ba8b6f4653 100644
--- a/cpp/src/arrow/acero/hash_aggregate_test.cc
+++ b/cpp/src/arrow/acero/hash_aggregate_test.cc
@@ -25,6 +25,7 @@
#include <utility>
#include <vector>
+#include "arrow/acero/aggregate_node.h"
#include "arrow/acero/exec_plan.h"
#include "arrow/acero/options.h"
#include "arrow/acero/test_util_internal.h"
@@ -86,6 +87,78 @@ using compute::TDigestOptions;
using compute::VarianceOptions;
namespace acero {
+
+TEST(AggregateSchema, NoKeys) {
+ auto input_schema = schema({field("x", int32())});
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, HasSubstr("is a hash aggregate function"),
+ aggregate::MakeOutputSchema(input_schema, {}, {},
+ {{"hash_count", nullptr, "x",
"hash_count"}}));
+ ASSERT_OK_AND_ASSIGN(auto output_schema,
+ aggregate::MakeOutputSchema(input_schema, {}, {},
+ {{"count", nullptr, "x",
"count"}}));
+ AssertSchemaEqual(schema({field("count", int64())}), output_schema);
+}
+
+TEST(AggregateSchema, SingleKey) {
+ auto input_schema = schema({field("x", int32()), field("y", int32())});
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, HasSubstr("is a scalar aggregate function"),
+ aggregate::MakeOutputSchema(input_schema, {FieldRef("y")}, {},
+ {{"count", nullptr, "x", "count"}}));
+ ASSERT_OK_AND_ASSIGN(
+ auto output_schema,
+ aggregate::MakeOutputSchema(input_schema, {FieldRef("y")}, {},
+ {{"hash_count", nullptr, "x",
"hash_count"}}));
+ AssertSchemaEqual(schema({field("y", int32()), field("hash_count",
int64())}),
+ output_schema);
+}
+
+TEST(AggregateSchema, DoubleKey) {
+ auto input_schema =
+ schema({field("x", int32()), field("y", int32()), field("z", int32())});
+ ASSERT_OK_AND_ASSIGN(
+ auto output_schema,
+ aggregate::MakeOutputSchema(input_schema, {FieldRef("z"),
FieldRef("y")}, {},
+ {{"hash_count", nullptr, "x",
"hash_count"}}));
+ AssertSchemaEqual(
+ schema({field("z", int32()), field("y", int32()), field("hash_count",
int64())}),
+ output_schema);
+}
+
+TEST(AggregateSchema, SingleSegmentKey) {
+ auto input_schema = schema({field("x", int32()), field("y", int32())});
+ ASSERT_OK_AND_ASSIGN(auto output_schema,
+ aggregate::MakeOutputSchema(input_schema, {},
{FieldRef("y")},
+ {{"count", nullptr, "x",
"count"}}));
+ AssertSchemaEqual(schema({field("y", int32()), field("count", int64())}),
+ output_schema);
+}
+
+TEST(AggregateSchema, DoubleSegmentKey) {
+ auto input_schema =
+ schema({field("x", int32()), field("y", int32()), field("z", int32())});
+ ASSERT_OK_AND_ASSIGN(
+ auto output_schema,
+ aggregate::MakeOutputSchema(input_schema, {}, {FieldRef("z"),
FieldRef("y")},
+ {{"count", nullptr, "x", "count"}}));
+ AssertSchemaEqual(
+ schema({field("z", int32()), field("y", int32()), field("count",
int64())}),
+ output_schema);
+}
+
+TEST(AggregateSchema, SingleKeyAndSegmentKey) {
+ auto input_schema =
+ schema({field("x", int32()), field("y", int32()), field("z", int32())});
+ ASSERT_OK_AND_ASSIGN(
+ auto output_schema,
+ aggregate::MakeOutputSchema(input_schema, {FieldRef("y")},
{FieldRef("z")},
+ {{"hash_count", nullptr, "x",
"hash_count"}}));
+ AssertSchemaEqual(
+ schema({field("y", int32()), field("z", int32()), field("hash_count",
int64())}),
+ output_schema);
+}
+
namespace {
using GroupByFunction = std::function<Result<Datum>(
diff --git a/cpp/src/arrow/engine/substrait/options.cc
b/cpp/src/arrow/engine/substrait/options.cc
index 979db875df..0a1af6fce1 100644
--- a/cpp/src/arrow/engine/substrait/options.cc
+++ b/cpp/src/arrow/engine/substrait/options.cc
@@ -20,6 +20,7 @@
#include <google/protobuf/util/json_util.h>
#include <mutex>
+#include "arrow/acero/aggregate_node.h"
#include "arrow/acero/asof_join_node.h"
#include "arrow/acero/options.h"
#include "arrow/engine/substrait/expression_internal.h"
@@ -187,50 +188,38 @@ class DefaultExtensionProvider : public
BaseExtensionProvider {
auto input_schema = inputs[0].output_schema;
- // store key fields to be used when output schema is created
- std::vector<int> key_field_ids;
std::vector<FieldRef> keys;
for (auto& ref : seg_agg_rel.grouping_keys()) {
ARROW_ASSIGN_OR_RAISE(auto field_ref,
DirectReferenceFromProto(&ref, ext_set,
conv_opts));
- ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
- key_field_ids.emplace_back(std::move(match[0]));
keys.emplace_back(std::move(field_ref));
}
- // store segment key fields to be used when output schema is created
- std::vector<int> segment_key_field_ids;
std::vector<FieldRef> segment_keys;
for (auto& ref : seg_agg_rel.segment_keys()) {
ARROW_ASSIGN_OR_RAISE(auto field_ref,
DirectReferenceFromProto(&ref, ext_set,
conv_opts));
- ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
- segment_key_field_ids.emplace_back(std::move(match[0]));
segment_keys.emplace_back(std::move(field_ref));
}
std::vector<compute::Aggregate> aggregates;
aggregates.reserve(seg_agg_rel.measures_size());
- std::vector<std::vector<int>> agg_src_fieldsets;
- agg_src_fieldsets.reserve(seg_agg_rel.measures_size());
for (auto agg_measure : seg_agg_rel.measures()) {
- ARROW_ASSIGN_OR_RAISE(
- auto parsed_measure,
- internal::ParseAggregateMeasure(agg_measure, ext_set, conv_opts,
- /*is_hash=*/!keys.empty(),
input_schema));
- aggregates.push_back(std::move(parsed_measure.aggregate));
- agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset));
+ ARROW_ASSIGN_OR_RAISE(auto aggregate, internal::ParseAggregateMeasure(
+ agg_measure, ext_set,
conv_opts,
+ /*is_hash=*/!keys.empty(),
input_schema));
+ aggregates.push_back(std::move(aggregate));
}
- ARROW_ASSIGN_OR_RAISE(auto decl_info,
- internal::MakeAggregateDeclaration(
- std::move(inputs[0].declaration),
std::move(input_schema),
- seg_agg_rel.measures_size(),
std::move(aggregates),
- std::move(agg_src_fieldsets), std::move(keys),
- std::move(key_field_ids),
std::move(segment_keys),
- std::move(segment_key_field_ids), ext_set,
conv_opts));
+ ARROW_ASSIGN_OR_RAISE(
+ auto output_schema,
+ acero::aggregate::MakeOutputSchema(input_schema, keys, segment_keys,
aggregates));
+
+ ARROW_ASSIGN_OR_RAISE(auto decl_info, internal::MakeAggregateDeclaration(
+ std::move(inputs[0].declaration),
+ output_schema,
std::move(aggregates),
+ std::move(keys),
std::move(segment_keys)));
- const auto& output_schema = decl_info.output_schema;
size_t out_size = output_schema->num_fields();
std::vector<int> field_output_indices(out_size);
for (int i = 0; i < static_cast<int>(out_size); i++) {
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc
b/cpp/src/arrow/engine/substrait/relation_internal.cc
index e7d704554c..0336bb3dd1 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -28,6 +28,7 @@
#include <variant>
#include <vector>
+#include "arrow/acero/aggregate_node.h"
#include "arrow/acero/exec_plan.h"
#include "arrow/acero/options.h"
#include "arrow/compute/api_aggregate.h"
@@ -294,7 +295,7 @@ Status DiscoverFilesFromDir(const
std::shared_ptr<fs::LocalFileSystem>& local_fs
namespace internal {
-Result<ParsedMeasure> ParseAggregateMeasure(
+Result<compute::Aggregate> ParseAggregateMeasure(
const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet&
ext_set,
const ConversionOptions& conversion_options, bool is_hash,
const std::shared_ptr<Schema> input_schema) {
@@ -314,50 +315,16 @@ Result<ParsedMeasure> ParseAggregateMeasure(
ARROW_ASSIGN_OR_RAISE(converter,
ext_set.registry()->GetSubstraitAggregateToArrow(
aggregate_call.id()));
}
- ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg,
converter(aggregate_call));
-
- // find aggregate field ids from schema
- const auto& target = arrow_agg.target;
- std::vector<int> fieldset;
- fieldset.reserve(target.size());
- for (const auto& field_ref : target) {
- ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
- fieldset.push_back(match[0]);
- }
-
- return ParsedMeasure{std::move(arrow_agg), std::move(fieldset)};
+ return converter(aggregate_call);
} else {
return Status::Invalid("substrait::AggregateFunction not provided");
}
}
ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(
- acero::Declaration input_decl, std::shared_ptr<Schema> input_schema,
- const int measure_size, std::vector<compute::Aggregate> aggregates,
- std::vector<std::vector<int>> agg_src_fieldsets, std::vector<FieldRef>
keys,
- std::vector<int> key_field_ids, std::vector<FieldRef> segment_keys,
- std::vector<int> segment_key_field_ids, const ExtensionSet& ext_set,
- const ConversionOptions& conversion_options) {
- FieldVector output_fields;
- output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() +
- measure_size);
- // extract aggregate fields to output schema
- for (const auto& agg_src_fieldset : agg_src_fieldsets) {
- for (int field : agg_src_fieldset) {
- output_fields.emplace_back(input_schema->field(field));
- }
- }
- // extract key fields to output schema
- for (int key_field_id : key_field_ids) {
- output_fields.emplace_back(input_schema->field(key_field_id));
- }
- // extract segment key fields to output schema
- for (int segment_key_field_id : segment_key_field_ids) {
- output_fields.emplace_back(input_schema->field(segment_key_field_id));
- }
-
- std::shared_ptr<Schema> aggregate_schema = schema(std::move(output_fields));
-
+ acero::Declaration input_decl, std::shared_ptr<Schema> aggregate_schema,
+ std::vector<compute::Aggregate> aggregates, std::vector<FieldRef> keys,
+ std::vector<FieldRef> segment_keys) {
return DeclarationInfo{
acero::Declaration::Sequence(
{std::move(input_decl),
@@ -771,22 +738,17 @@ Result<DeclarationInfo> FromProto(const substrait::Rel&
rel, const ExtensionSet&
// prepare output schema from aggregates
auto input_schema = input.output_schema;
- // store key fields to be used when output schema is created
- std::vector<int> key_field_ids;
std::vector<FieldRef> keys;
if (aggregate.groupings_size() > 0) {
const substrait::AggregateRel::Grouping& group =
aggregate.groupings(0);
int grouping_expr_size = group.grouping_expressions_size();
keys.reserve(grouping_expr_size);
- key_field_ids.reserve(grouping_expr_size);
for (int exp_id = 0; exp_id < grouping_expr_size; exp_id++) {
ARROW_ASSIGN_OR_RAISE(
compute::Expression expr,
FromProto(group.grouping_expressions(exp_id), ext_set,
conversion_options));
const FieldRef* field_ref = expr.field_ref();
if (field_ref) {
- ARROW_ASSIGN_OR_RAISE(auto match,
field_ref->FindOne(*input_schema));
- key_field_ids.emplace_back(std::move(match[0]));
keys.emplace_back(std::move(*field_ref));
} else {
return Status::Invalid(
@@ -798,28 +760,27 @@ Result<DeclarationInfo> FromProto(const substrait::Rel&
rel, const ExtensionSet&
const int measure_size = aggregate.measures_size();
std::vector<compute::Aggregate> aggregates;
aggregates.reserve(measure_size);
- // store aggregate fields to be used when output schema is created
- std::vector<std::vector<int>> agg_src_fieldsets;
- agg_src_fieldsets.reserve(measure_size);
for (int measure_id = 0; measure_id < measure_size; measure_id++) {
const auto& agg_measure = aggregate.measures(measure_id);
ARROW_ASSIGN_OR_RAISE(
- auto parsed_measure,
+ auto aggregate,
internal::ParseAggregateMeasure(agg_measure, ext_set,
conversion_options,
/*is_hash=*/!keys.empty(),
input_schema));
- aggregates.push_back(std::move(parsed_measure.aggregate));
- agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset));
+ aggregates.push_back(std::move(aggregate));
}
+ ARROW_ASSIGN_OR_RAISE(auto aggregate_schema,
+ acero::aggregate::MakeOutputSchema(
+ input_schema, keys, /*segment_keys=*/{},
aggregates));
+
ARROW_ASSIGN_OR_RAISE(
auto aggregate_declaration,
- internal::MakeAggregateDeclaration(
- std::move(input.declaration), std::move(input_schema),
measure_size,
- std::move(aggregates), std::move(agg_src_fieldsets),
std::move(keys),
- std::move(key_field_ids), {}, {}, ext_set, conversion_options));
+ internal::MakeAggregateDeclaration(std::move(input.declaration),
+ aggregate_schema,
std::move(aggregates),
+ std::move(keys),
/*segment_keys=*/{}));
- return ProcessEmit(aggregate, aggregate_declaration,
- aggregate_declaration.output_schema);
+ return ProcessEmit(std::move(aggregate),
std::move(aggregate_declaration),
+ std::move(aggregate_schema));
}
case substrait::Rel::RelTypeCase::kExtensionLeaf:
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h
b/cpp/src/arrow/engine/substrait/relation_internal.h
index 72a0c3f98a..a436f1770d 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.h
+++ b/cpp/src/arrow/engine/substrait/relation_internal.h
@@ -50,11 +50,6 @@ ARROW_ENGINE_EXPORT Result<std::unique_ptr<substrait::Rel>>
ToProto(
namespace internal {
-struct ParsedMeasure {
- compute::Aggregate aggregate;
- std::vector<int> fieldset;
-};
-
/// \brief Parse an aggregate relation's measure
///
/// \param[in] agg_measure the measure
@@ -63,7 +58,7 @@ struct ParsedMeasure {
/// \param[in] input_schema the schema to which field refs apply
/// \param[in] is_hash whether the measure is a hash one (i.e., aggregation
keys exist)
ARROW_ENGINE_EXPORT
-Result<ParsedMeasure> ParseAggregateMeasure(
+Result<compute::Aggregate> ParseAggregateMeasure(
const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet&
ext_set,
const ConversionOptions& conversion_options, bool is_hash,
const std::shared_ptr<Schema> input_schema);
@@ -71,23 +66,14 @@ Result<ParsedMeasure> ParseAggregateMeasure(
/// \brief Make an aggregate declaration info
///
/// \param[in] input_decl the input declaration to use
-/// \param[in] input_schema the schema to which field refs apply
-/// \param[in] measure_size the number of measures to use
+/// \param[in] output_schema the precomputed output schema of the aggregate node
/// \param[in] aggregates the aggregates to use
-/// \param[in] agg_src_fieldsets the field-sets per aggregate to use
/// \param[in] keys the field-refs for grouping keys to use
-/// \param[in] key_field_ids the field-ids for grouping keys to use
/// \param[in] segment_keys the field-refs for segment keys to use
-/// \param[in] segment_key_field_ids the field-ids for segment keys to use
-/// \param[in] ext_set an extension mapping to use
-/// \param[in] conversion_options options to control how the conversion is done
ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(
- acero::Declaration input_decl, std::shared_ptr<Schema> input_schema,
- const int measure_size, std::vector<compute::Aggregate> aggregates,
- std::vector<std::vector<int>> agg_src_fieldsets, std::vector<FieldRef>
keys,
- std::vector<int> key_field_ids, std::vector<FieldRef> segment_keys,
- std::vector<int> segment_key_field_ids, const ExtensionSet& ext_set,
- const ConversionOptions& conversion_options);
+ acero::Declaration input_decl, std::shared_ptr<Schema> output_schema,
+ std::vector<compute::Aggregate> aggregates, std::vector<FieldRef> keys,
+ std::vector<FieldRef> segment_keys);
} // namespace internal