This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 342b74e419 GH-34786: [C++] Fix output schema calculated by Substrait consumer for AggregateRel (#34904)
342b74e419 is described below

commit 342b74e4192a69e849b67e36b054078c208fac91
Author: rtpsw <[email protected]>
AuthorDate: Wed Apr 12 19:43:38 2023 +0300

    GH-34786: [C++] Fix output schema calculated by Substrait consumer for AggregateRel (#34904)
    
    See #34786
    * Closes: #34786
    
    Can replace #34885
    
    Lead-authored-by: Yaron Gvili <[email protected]>
    Co-authored-by: rtpsw <[email protected]>
    Co-authored-by: Weston Pace <[email protected]>
    Signed-off-by: Weston Pace <[email protected]>
---
 cpp/src/arrow/acero/aggregate_node.cc              | 187 ++++++++++++++-------
 cpp/src/arrow/acero/aggregate_node.h               |  57 +++++++
 cpp/src/arrow/acero/hash_aggregate_test.cc         |  73 ++++++++
 cpp/src/arrow/engine/substrait/options.cc          |  37 ++--
 .../arrow/engine/substrait/relation_internal.cc    |  73 ++------
 cpp/src/arrow/engine/substrait/relation_internal.h |  24 +--
 6 files changed, 293 insertions(+), 158 deletions(-)

diff --git a/cpp/src/arrow/acero/aggregate_node.cc 
b/cpp/src/arrow/acero/aggregate_node.cc
index bd97235df6..c5b4442544 100644
--- a/cpp/src/arrow/acero/aggregate_node.cc
+++ b/cpp/src/arrow/acero/aggregate_node.cc
@@ -21,6 +21,7 @@
 #include <unordered_map>
 #include <unordered_set>
 
+#include "arrow/acero/aggregate_node.h"
 #include "arrow/acero/exec_plan.h"
 #include "arrow/acero/options.h"
 #include "arrow/acero/query_context.h"
@@ -77,6 +78,19 @@ namespace acero {
 
 namespace {
 
+template <typename KernelType>
+struct AggregateNodeArgs {
+  std::shared_ptr<Schema> output_schema;
+  std::vector<int> grouping_key_field_ids;
+  std::vector<int> segment_key_field_ids;
+  std::unique_ptr<RowSegmenter> segmenter;
+  std::vector<std::vector<int>> target_fieldsets;
+  std::vector<Aggregate> aggregates;
+  std::vector<const KernelType*> kernels;
+  std::vector<std::vector<TypeHolder>> kernel_intypes;
+  std::vector<std::vector<std::unique_ptr<KernelState>>> states;
+};
+
 std::vector<TypeHolder> ExtendWithGroupIdType(const std::vector<TypeHolder>& 
in_types) {
   std::vector<TypeHolder> aggr_in_types;
   aggr_in_types.reserve(in_types.size() + 1);
@@ -274,36 +288,22 @@ class ScalarAggregateNode : public ExecNode, public 
TracedNode {
         kernel_intypes_(std::move(kernel_intypes)),
         states_(std::move(states)) {}
 
-  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
-                                const ExecNodeOptions& options) {
-    RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, 
"ScalarAggregateNode"));
-
-    const auto& aggregate_options = checked_cast<const 
AggregateNodeOptions&>(options);
-    auto aggregates = aggregate_options.aggregates;
-    const auto& keys = aggregate_options.keys;
-    const auto& segment_keys = aggregate_options.segment_keys;
-
-    if (keys.size() > 0) {
-      return Status::Invalid("Scalar aggregation with some key");
-    }
-    if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
-        segment_keys.size() > 0) {
-      return Status::NotImplemented("Segmented aggregation in a multi-threaded 
plan");
-    }
-
-    const auto& input_schema = *inputs[0]->output_schema();
-    auto exec_ctx = plan->query_context()->exec_context();
-
+  static Result<AggregateNodeArgs<ScalarAggregateKernel>> 
MakeAggregateNodeArgs(
+      const std::shared_ptr<Schema>& input_schema, const 
std::vector<FieldRef>& keys,
+      const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& 
aggs,
+      ExecContext* exec_ctx, size_t concurrency) {
+    // Copy (need to modify options pointer below)
+    std::vector<Aggregate> aggregates(aggs);
     std::vector<int> segment_field_ids(segment_keys.size());
     std::vector<TypeHolder> segment_key_types(segment_keys.size());
     for (size_t i = 0; i < segment_keys.size(); i++) {
-      ARROW_ASSIGN_OR_RAISE(FieldPath match, 
segment_keys[i].FindOne(input_schema));
+      ARROW_ASSIGN_OR_RAISE(FieldPath match, 
segment_keys[i].FindOne(*input_schema));
       if (match.indices().size() > 1) {
         // ARROW-18369: Support nested references as segment ids
         return Status::Invalid("Nested references cannot be used as segment 
ids");
       }
       segment_field_ids[i] = match[0];
-      segment_key_types[i] = input_schema.field(match[0])->type().get();
+      segment_key_types[i] = input_schema->field(match[0])->type().get();
     }
 
     ARROW_ASSIGN_OR_RAISE(auto segmenter,
@@ -317,16 +317,15 @@ class ScalarAggregateNode : public ExecNode, public 
TracedNode {
 
     // Output the segment keys first, followed by the aggregates
     for (size_t i = 0; i < segment_keys.size(); ++i) {
-      ARROW_ASSIGN_OR_RAISE(fields[i],
-                            
segment_keys[i].GetOne(*inputs[0]->output_schema()));
+      ARROW_ASSIGN_OR_RAISE(fields[i], segment_keys[i].GetOne(*input_schema));
     }
 
     std::vector<std::vector<int>> target_fieldsets(kernels.size());
     std::size_t base = segment_keys.size();
     for (size_t i = 0; i < kernels.size(); ++i) {
-      const auto& target_fieldset = aggregate_options.aggregates[i].target;
+      const auto& target_fieldset = aggregates[i].target;
       for (const auto& target : target_fieldset) {
-        ARROW_ASSIGN_OR_RAISE(auto match, 
FieldRef(target).FindOne(input_schema));
+        ARROW_ASSIGN_OR_RAISE(auto match, 
FieldRef(target).FindOne(*input_schema));
         target_fieldsets[i].push_back(match[0]);
       }
 
@@ -346,7 +345,7 @@ class ScalarAggregateNode : public ExecNode, public 
TracedNode {
 
       std::vector<TypeHolder> in_types;
       for (const auto& target : target_fieldsets[i]) {
-        in_types.emplace_back(input_schema.field(target)->type().get());
+        in_types.emplace_back(input_schema->field(target)->type().get());
       }
       kernel_intypes[i] = in_types;
       ARROW_ASSIGN_OR_RAISE(const Kernel* kernel,
@@ -362,7 +361,7 @@ class ScalarAggregateNode : public ExecNode, public 
TracedNode {
       }
 
       KernelContext kernel_ctx{exec_ctx};
-      states[i].resize(plan->query_context()->max_concurrency());
+      states[i].resize(concurrency);
       RETURN_NOT_OK(Kernel::InitAll(
           &kernel_ctx,
           KernelInitArgs{kernels[i], kernel_intypes[i], 
aggregates[i].options.get()},
@@ -373,14 +372,47 @@ class ScalarAggregateNode : public ExecNode, public 
TracedNode {
       ARROW_ASSIGN_OR_RAISE(auto out_type, 
kernels[i]->signature->out_type().Resolve(
                                                &kernel_ctx, 
kernel_intypes[i]));
 
-      fields[base + i] =
-          field(aggregate_options.aggregates[i].name, out_type.GetSharedPtr());
+      fields[base + i] = field(aggregates[i].name, out_type.GetSharedPtr());
     }
 
+    return AggregateNodeArgs<ScalarAggregateKernel>{
+        schema(std::move(fields)),
+        /*grouping_key_field_ids=*/{}, std::move(segment_field_ids),
+        std::move(segmenter),          std::move(target_fieldsets),
+        std::move(aggregates),         std::move(kernels),
+        std::move(kernel_intypes),     std::move(states)};
+  }
+
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, 
"ScalarAggregateNode"));
+
+    const auto& aggregate_options = checked_cast<const 
AggregateNodeOptions&>(options);
+    auto aggregates = aggregate_options.aggregates;
+    const auto& keys = aggregate_options.keys;
+    const auto& segment_keys = aggregate_options.segment_keys;
+
+    if (keys.size() > 0) {
+      return Status::Invalid("Scalar aggregation with some key");
+    }
+    if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
+        segment_keys.size() > 0) {
+      return Status::NotImplemented("Segmented aggregation in a multi-threaded 
plan");
+    }
+
+    const auto& input_schema = inputs[0]->output_schema();
+    auto exec_ctx = plan->query_context()->exec_context();
+
+    ARROW_ASSIGN_OR_RAISE(
+        auto args,
+        MakeAggregateNodeArgs(input_schema, keys, segment_keys, aggregates, 
exec_ctx,
+                              
/*concurrency=*/plan->query_context()->max_concurrency()));
+
     return plan->EmplaceNode<ScalarAggregateNode>(
-        plan, std::move(inputs), schema(std::move(fields)), 
std::move(segmenter),
-        std::move(segment_field_ids), std::move(target_fieldsets), 
std::move(aggregates),
-        std::move(kernels), std::move(kernel_intypes), std::move(states));
+        plan, std::move(inputs), std::move(args.output_schema), 
std::move(args.segmenter),
+        std::move(args.segment_key_field_ids), 
std::move(args.target_fieldsets),
+        std::move(args.aggregates), std::move(args.kernels),
+        std::move(args.kernel_intypes), std::move(args.states));
   }
 
   const char* kind_name() const override { return "ScalarAggregateNode"; }
@@ -564,25 +596,10 @@ class GroupByNode : public ExecNode, public TracedNode {
     return Status::OK();
   }
 
-  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
-                                const ExecNodeOptions& options) {
-    RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "GroupByNode"));
-
-    auto input = inputs[0];
-    const auto& aggregate_options = checked_cast<const 
AggregateNodeOptions&>(options);
-    const auto& keys = aggregate_options.keys;
-    const auto& segment_keys = aggregate_options.segment_keys;
-    // Copy (need to modify options pointer below)
-    auto aggs = aggregate_options.aggregates;
-
-    if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
-        segment_keys.size() > 0) {
-      return Status::NotImplemented("Segmented aggregation in a multi-threaded 
plan");
-    }
-
-    // Get input schema
-    auto input_schema = input->output_schema();
-
+  static Result<AggregateNodeArgs<HashAggregateKernel>> MakeAggregateNodeArgs(
+      const std::shared_ptr<Schema>& input_schema, const 
std::vector<FieldRef>& keys,
+      const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& 
aggs,
+      ExecContext* ctx) {
     // Find input field indices for key fields
     std::vector<int> key_field_ids(keys.size());
     for (size_t i = 0; i < keys.size(); ++i) {
@@ -632,8 +649,6 @@ class GroupByNode : public ExecNode, public TracedNode {
       segment_key_types[i] = 
input_schema->field(segment_key_field_id)->type().get();
     }
 
-    auto ctx = plan->query_context()->exec_context();
-
     ARROW_ASSIGN_OR_RAISE(auto segmenter,
                           RowSegmenter::Make(std::move(segment_key_types),
                                              /*nullable_keys=*/false, ctx));
@@ -665,14 +680,47 @@ class GroupByNode : public ExecNode, public TracedNode {
     }
     base += segment_keys.size();
     for (size_t i = 0; i < aggs.size(); ++i) {
-      output_fields[base + i] =
-          agg_result_fields[i]->WithName(aggregate_options.aggregates[i].name);
+      output_fields[base + i] = agg_result_fields[i]->WithName(aggs[i].name);
+    }
+
+    return 
AggregateNodeArgs<HashAggregateKernel>{schema(std::move(output_fields)),
+                                                  std::move(key_field_ids),
+                                                  
std::move(segment_key_field_ids),
+                                                  std::move(segmenter),
+                                                  std::move(agg_src_fieldsets),
+                                                  std::move(aggs),
+                                                  std::move(agg_kernels),
+                                                  std::move(agg_src_types),
+                                                  /*states=*/{}};
+  }
+
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "GroupByNode"));
+
+    auto input = inputs[0];
+    const auto& aggregate_options = checked_cast<const 
AggregateNodeOptions&>(options);
+    const auto& keys = aggregate_options.keys;
+    const auto& segment_keys = aggregate_options.segment_keys;
+    auto aggs = aggregate_options.aggregates;
+
+    if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
+        segment_keys.size() > 0) {
+      return Status::NotImplemented(
+          "Segmented aggregation in a multi-threaded execution context");
     }
 
+    const auto& input_schema = input->output_schema();
+    auto exec_ctx = plan->query_context()->exec_context();
+
+    ARROW_ASSIGN_OR_RAISE(auto args, MakeAggregateNodeArgs(input_schema, keys,
+                                                           segment_keys, aggs, 
exec_ctx));
+
     return input->plan()->EmplaceNode<GroupByNode>(
-        input, schema(std::move(output_fields)), std::move(key_field_ids),
-        std::move(segment_key_field_ids), std::move(segmenter), 
std::move(agg_src_types),
-        std::move(agg_src_fieldsets), std::move(aggs), std::move(agg_kernels));
+        input, std::move(args.output_schema), 
std::move(args.grouping_key_field_ids),
+        std::move(args.segment_key_field_ids), std::move(args.segmenter),
+        std::move(args.kernel_intypes), std::move(args.target_fieldsets),
+        std::move(args.aggregates), std::move(args.kernels));
   }
 
   Status ResetKernelStates() {
@@ -981,6 +1029,27 @@ class GroupByNode : public ExecNode, public TracedNode {
 
 }  // namespace
 
+namespace aggregate {
+
+Result<std::shared_ptr<Schema>> MakeOutputSchema(
+    const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>& 
keys,
+    const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& 
aggregates,
+    ExecContext* exec_ctx) {
+  if (keys.empty()) {
+    ARROW_ASSIGN_OR_RAISE(auto args, 
ScalarAggregateNode::MakeAggregateNodeArgs(
+                                         input_schema, keys, segment_keys, 
aggregates,
+                                         exec_ctx, /*concurrency=*/1));
+    return std::move(args.output_schema);
+  } else {
+    ARROW_ASSIGN_OR_RAISE(
+        auto args, GroupByNode::MakeAggregateNodeArgs(input_schema, keys, 
segment_keys,
+                                                      aggregates, exec_ctx));
+    return std::move(args.output_schema);
+  }
+}
+
+}  // namespace aggregate
+
 namespace internal {
 
 void RegisterAggregateNode(ExecFactoryRegistry* registry) {
diff --git a/cpp/src/arrow/acero/aggregate_node.h 
b/cpp/src/arrow/acero/aggregate_node.h
new file mode 100644
index 0000000000..790264b208
--- /dev/null
+++ b/cpp/src/arrow/acero/aggregate_node.h
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "arrow/acero/visibility.h"
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/type_fwd.h"
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace acero {
+namespace aggregate {
+
+using compute::Aggregate;
+using compute::default_exec_context;
+using compute::ExecContext;
+
+/// \brief Make the output schema of an aggregate node
+///
+/// The output schema is determined by the aggregation kernels, which may depend on the
+/// ExecContext argument. To guarantee correct results, the same ExecContext argument
+/// should be used in execution.
+///
+/// \param[in] input_schema the schema of the input to the node
+/// \param[in] keys the grouping keys for the aggregation
+/// \param[in] segment_keys the segmenting keys for the aggregation
+/// \param[in] aggregates the aggregates for the aggregation
+/// \param[in] exec_ctx the execution context for the aggregation
+ARROW_ACERO_EXPORT Result<std::shared_ptr<Schema>> MakeOutputSchema(
+    const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>& 
keys,
+    const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& 
aggregates,
+    ExecContext* exec_ctx = default_exec_context());
+
+}  // namespace aggregate
+}  // namespace acero
+}  // namespace arrow
diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc 
b/cpp/src/arrow/acero/hash_aggregate_test.cc
index 0ae06d0572..ba8b6f4653 100644
--- a/cpp/src/arrow/acero/hash_aggregate_test.cc
+++ b/cpp/src/arrow/acero/hash_aggregate_test.cc
@@ -25,6 +25,7 @@
 #include <utility>
 #include <vector>
 
+#include "arrow/acero/aggregate_node.h"
 #include "arrow/acero/exec_plan.h"
 #include "arrow/acero/options.h"
 #include "arrow/acero/test_util_internal.h"
@@ -86,6 +87,78 @@ using compute::TDigestOptions;
 using compute::VarianceOptions;
 
 namespace acero {
+
+TEST(AggregateSchema, NoKeys) {
+  auto input_schema = schema({field("x", int32())});
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, HasSubstr("is a hash aggregate function"),
+      aggregate::MakeOutputSchema(input_schema, {}, {},
+                                  {{"hash_count", nullptr, "x", 
"hash_count"}}));
+  ASSERT_OK_AND_ASSIGN(auto output_schema,
+                       aggregate::MakeOutputSchema(input_schema, {}, {},
+                                                   {{"count", nullptr, "x", 
"count"}}));
+  AssertSchemaEqual(schema({field("count", int64())}), output_schema);
+}
+
+TEST(AggregateSchema, SingleKey) {
+  auto input_schema = schema({field("x", int32()), field("y", int32())});
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, HasSubstr("is a scalar aggregate function"),
+      aggregate::MakeOutputSchema(input_schema, {FieldRef("y")}, {},
+                                  {{"count", nullptr, "x", "count"}}));
+  ASSERT_OK_AND_ASSIGN(
+      auto output_schema,
+      aggregate::MakeOutputSchema(input_schema, {FieldRef("y")}, {},
+                                  {{"hash_count", nullptr, "x", 
"hash_count"}}));
+  AssertSchemaEqual(schema({field("y", int32()), field("hash_count", 
int64())}),
+                    output_schema);
+}
+
+TEST(AggregateSchema, DoubleKey) {
+  auto input_schema =
+      schema({field("x", int32()), field("y", int32()), field("z", int32())});
+  ASSERT_OK_AND_ASSIGN(
+      auto output_schema,
+      aggregate::MakeOutputSchema(input_schema, {FieldRef("z"), 
FieldRef("y")}, {},
+                                  {{"hash_count", nullptr, "x", 
"hash_count"}}));
+  AssertSchemaEqual(
+      schema({field("z", int32()), field("y", int32()), field("hash_count", 
int64())}),
+      output_schema);
+}
+
+TEST(AggregateSchema, SingleSegmentKey) {
+  auto input_schema = schema({field("x", int32()), field("y", int32())});
+  ASSERT_OK_AND_ASSIGN(auto output_schema,
+                       aggregate::MakeOutputSchema(input_schema, {}, 
{FieldRef("y")},
+                                                   {{"count", nullptr, "x", 
"count"}}));
+  AssertSchemaEqual(schema({field("y", int32()), field("count", int64())}),
+                    output_schema);
+}
+
+TEST(AggregateSchema, DoubleSegmentKey) {
+  auto input_schema =
+      schema({field("x", int32()), field("y", int32()), field("z", int32())});
+  ASSERT_OK_AND_ASSIGN(
+      auto output_schema,
+      aggregate::MakeOutputSchema(input_schema, {}, {FieldRef("z"), 
FieldRef("y")},
+                                  {{"count", nullptr, "x", "count"}}));
+  AssertSchemaEqual(
+      schema({field("z", int32()), field("y", int32()), field("count", 
int64())}),
+      output_schema);
+}
+
+TEST(AggregateSchema, SingleKeyAndSegmentKey) {
+  auto input_schema =
+      schema({field("x", int32()), field("y", int32()), field("z", int32())});
+  ASSERT_OK_AND_ASSIGN(
+      auto output_schema,
+      aggregate::MakeOutputSchema(input_schema, {FieldRef("y")}, 
{FieldRef("z")},
+                                  {{"hash_count", nullptr, "x", 
"hash_count"}}));
+  AssertSchemaEqual(
+      schema({field("y", int32()), field("z", int32()), field("hash_count", 
int64())}),
+      output_schema);
+}
+
 namespace {
 
 using GroupByFunction = std::function<Result<Datum>(
diff --git a/cpp/src/arrow/engine/substrait/options.cc 
b/cpp/src/arrow/engine/substrait/options.cc
index 979db875df..0a1af6fce1 100644
--- a/cpp/src/arrow/engine/substrait/options.cc
+++ b/cpp/src/arrow/engine/substrait/options.cc
@@ -20,6 +20,7 @@
 #include <google/protobuf/util/json_util.h>
 #include <mutex>
 
+#include "arrow/acero/aggregate_node.h"
 #include "arrow/acero/asof_join_node.h"
 #include "arrow/acero/options.h"
 #include "arrow/engine/substrait/expression_internal.h"
@@ -187,50 +188,38 @@ class DefaultExtensionProvider : public 
BaseExtensionProvider {
 
     auto input_schema = inputs[0].output_schema;
 
-    // store key fields to be used when output schema is created
-    std::vector<int> key_field_ids;
     std::vector<FieldRef> keys;
     for (auto& ref : seg_agg_rel.grouping_keys()) {
       ARROW_ASSIGN_OR_RAISE(auto field_ref,
                             DirectReferenceFromProto(&ref, ext_set, 
conv_opts));
-      ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
-      key_field_ids.emplace_back(std::move(match[0]));
       keys.emplace_back(std::move(field_ref));
     }
 
-    // store segment key fields to be used when output schema is created
-    std::vector<int> segment_key_field_ids;
     std::vector<FieldRef> segment_keys;
     for (auto& ref : seg_agg_rel.segment_keys()) {
       ARROW_ASSIGN_OR_RAISE(auto field_ref,
                             DirectReferenceFromProto(&ref, ext_set, 
conv_opts));
-      ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
-      segment_key_field_ids.emplace_back(std::move(match[0]));
       segment_keys.emplace_back(std::move(field_ref));
     }
 
     std::vector<compute::Aggregate> aggregates;
     aggregates.reserve(seg_agg_rel.measures_size());
-    std::vector<std::vector<int>> agg_src_fieldsets;
-    agg_src_fieldsets.reserve(seg_agg_rel.measures_size());
     for (auto agg_measure : seg_agg_rel.measures()) {
-      ARROW_ASSIGN_OR_RAISE(
-          auto parsed_measure,
-          internal::ParseAggregateMeasure(agg_measure, ext_set, conv_opts,
-                                          /*is_hash=*/!keys.empty(), 
input_schema));
-      aggregates.push_back(std::move(parsed_measure.aggregate));
-      agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset));
+      ARROW_ASSIGN_OR_RAISE(auto aggregate, internal::ParseAggregateMeasure(
+                                                agg_measure, ext_set, 
conv_opts,
+                                                /*is_hash=*/!keys.empty(), 
input_schema));
+      aggregates.push_back(std::move(aggregate));
     }
 
-    ARROW_ASSIGN_OR_RAISE(auto decl_info,
-                          internal::MakeAggregateDeclaration(
-                              std::move(inputs[0].declaration), 
std::move(input_schema),
-                              seg_agg_rel.measures_size(), 
std::move(aggregates),
-                              std::move(agg_src_fieldsets), std::move(keys),
-                              std::move(key_field_ids), 
std::move(segment_keys),
-                              std::move(segment_key_field_ids), ext_set, 
conv_opts));
+    ARROW_ASSIGN_OR_RAISE(
+        auto output_schema,
+        acero::aggregate::MakeOutputSchema(input_schema, keys, segment_keys, 
aggregates));
+
+    ARROW_ASSIGN_OR_RAISE(auto decl_info, internal::MakeAggregateDeclaration(
+                                              std::move(inputs[0].declaration),
+                                              output_schema, 
std::move(aggregates),
+                                              std::move(keys), 
std::move(segment_keys)));
 
-    const auto& output_schema = decl_info.output_schema;
     size_t out_size = output_schema->num_fields();
     std::vector<int> field_output_indices(out_size);
     for (int i = 0; i < static_cast<int>(out_size); i++) {
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc 
b/cpp/src/arrow/engine/substrait/relation_internal.cc
index e7d704554c..0336bb3dd1 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -28,6 +28,7 @@
 #include <variant>
 #include <vector>
 
+#include "arrow/acero/aggregate_node.h"
 #include "arrow/acero/exec_plan.h"
 #include "arrow/acero/options.h"
 #include "arrow/compute/api_aggregate.h"
@@ -294,7 +295,7 @@ Status DiscoverFilesFromDir(const 
std::shared_ptr<fs::LocalFileSystem>& local_fs
 
 namespace internal {
 
-Result<ParsedMeasure> ParseAggregateMeasure(
+Result<compute::Aggregate> ParseAggregateMeasure(
     const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& 
ext_set,
     const ConversionOptions& conversion_options, bool is_hash,
     const std::shared_ptr<Schema> input_schema) {
@@ -314,50 +315,16 @@ Result<ParsedMeasure> ParseAggregateMeasure(
       ARROW_ASSIGN_OR_RAISE(converter, 
ext_set.registry()->GetSubstraitAggregateToArrow(
                                            aggregate_call.id()));
     }
-    ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, 
converter(aggregate_call));
-
-    // find aggregate field ids from schema
-    const auto& target = arrow_agg.target;
-    std::vector<int> fieldset;
-    fieldset.reserve(target.size());
-    for (const auto& field_ref : target) {
-      ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
-      fieldset.push_back(match[0]);
-    }
-
-    return ParsedMeasure{std::move(arrow_agg), std::move(fieldset)};
+    return converter(aggregate_call);
   } else {
     return Status::Invalid("substrait::AggregateFunction not provided");
   }
 }
 
 ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(
-    acero::Declaration input_decl, std::shared_ptr<Schema> input_schema,
-    const int measure_size, std::vector<compute::Aggregate> aggregates,
-    std::vector<std::vector<int>> agg_src_fieldsets, std::vector<FieldRef> 
keys,
-    std::vector<int> key_field_ids, std::vector<FieldRef> segment_keys,
-    std::vector<int> segment_key_field_ids, const ExtensionSet& ext_set,
-    const ConversionOptions& conversion_options) {
-  FieldVector output_fields;
-  output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() +
-                        measure_size);
-  // extract aggregate fields to output schema
-  for (const auto& agg_src_fieldset : agg_src_fieldsets) {
-    for (int field : agg_src_fieldset) {
-      output_fields.emplace_back(input_schema->field(field));
-    }
-  }
-  // extract key fields to output schema
-  for (int key_field_id : key_field_ids) {
-    output_fields.emplace_back(input_schema->field(key_field_id));
-  }
-  // extract segment key fields to output schema
-  for (int segment_key_field_id : segment_key_field_ids) {
-    output_fields.emplace_back(input_schema->field(segment_key_field_id));
-  }
-
-  std::shared_ptr<Schema> aggregate_schema = schema(std::move(output_fields));
-
+    acero::Declaration input_decl, std::shared_ptr<Schema> aggregate_schema,
+    std::vector<compute::Aggregate> aggregates, std::vector<FieldRef> keys,
+    std::vector<FieldRef> segment_keys) {
   return DeclarationInfo{
       acero::Declaration::Sequence(
           {std::move(input_decl),
@@ -771,22 +738,17 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& 
rel, const ExtensionSet&
 
       // prepare output schema from aggregates
       auto input_schema = input.output_schema;
-      // store key fields to be used when output schema is created
-      std::vector<int> key_field_ids;
       std::vector<FieldRef> keys;
       if (aggregate.groupings_size() > 0) {
         const substrait::AggregateRel::Grouping& group = 
aggregate.groupings(0);
         int grouping_expr_size = group.grouping_expressions_size();
         keys.reserve(grouping_expr_size);
-        key_field_ids.reserve(grouping_expr_size);
         for (int exp_id = 0; exp_id < grouping_expr_size; exp_id++) {
           ARROW_ASSIGN_OR_RAISE(
               compute::Expression expr,
               FromProto(group.grouping_expressions(exp_id), ext_set, 
conversion_options));
           const FieldRef* field_ref = expr.field_ref();
           if (field_ref) {
-            ARROW_ASSIGN_OR_RAISE(auto match, 
field_ref->FindOne(*input_schema));
-            key_field_ids.emplace_back(std::move(match[0]));
             keys.emplace_back(std::move(*field_ref));
           } else {
             return Status::Invalid(
@@ -798,28 +760,27 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& 
rel, const ExtensionSet&
       const int measure_size = aggregate.measures_size();
       std::vector<compute::Aggregate> aggregates;
       aggregates.reserve(measure_size);
-      // store aggregate fields to be used when output schema is created
-      std::vector<std::vector<int>> agg_src_fieldsets;
-      agg_src_fieldsets.reserve(measure_size);
       for (int measure_id = 0; measure_id < measure_size; measure_id++) {
         const auto& agg_measure = aggregate.measures(measure_id);
         ARROW_ASSIGN_OR_RAISE(
-            auto parsed_measure,
+            auto aggregate,
             internal::ParseAggregateMeasure(agg_measure, ext_set, 
conversion_options,
                                             /*is_hash=*/!keys.empty(), 
input_schema));
-        aggregates.push_back(std::move(parsed_measure.aggregate));
-        agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset));
+        aggregates.push_back(std::move(aggregate));
       }
 
+      ARROW_ASSIGN_OR_RAISE(auto aggregate_schema,
+                            acero::aggregate::MakeOutputSchema(
+                                input_schema, keys, /*segment_keys=*/{}, 
aggregates));
+
       ARROW_ASSIGN_OR_RAISE(
           auto aggregate_declaration,
-          internal::MakeAggregateDeclaration(
-              std::move(input.declaration), std::move(input_schema), 
measure_size,
-              std::move(aggregates), std::move(agg_src_fieldsets), 
std::move(keys),
-              std::move(key_field_ids), {}, {}, ext_set, conversion_options));
+          internal::MakeAggregateDeclaration(std::move(input.declaration),
+                                             aggregate_schema, 
std::move(aggregates),
+                                             std::move(keys), 
/*segment_keys=*/{}));
 
-      return ProcessEmit(aggregate, aggregate_declaration,
-                         aggregate_declaration.output_schema);
+      return ProcessEmit(std::move(aggregate), 
std::move(aggregate_declaration),
+                         std::move(aggregate_schema));
     }
 
     case substrait::Rel::RelTypeCase::kExtensionLeaf:
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h 
b/cpp/src/arrow/engine/substrait/relation_internal.h
index 72a0c3f98a..a436f1770d 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.h
+++ b/cpp/src/arrow/engine/substrait/relation_internal.h
@@ -50,11 +50,6 @@ ARROW_ENGINE_EXPORT Result<std::unique_ptr<substrait::Rel>> 
ToProto(
 
 namespace internal {
 
-struct ParsedMeasure {
-  compute::Aggregate aggregate;
-  std::vector<int> fieldset;
-};
-
 /// \brief Parse an aggregate relation's measure
 ///
 /// \param[in] agg_measure the measure
@@ -63,7 +58,7 @@ struct ParsedMeasure {
 /// \param[in] input_schema the schema to which field refs apply
 /// \param[in] is_hash whether the measure is a hash one (i.e., aggregation 
keys exist)
 ARROW_ENGINE_EXPORT
-Result<ParsedMeasure> ParseAggregateMeasure(
+Result<compute::Aggregate> ParseAggregateMeasure(
     const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& 
ext_set,
     const ConversionOptions& conversion_options, bool is_hash,
     const std::shared_ptr<Schema> input_schema);
@@ -71,23 +66,14 @@ Result<ParsedMeasure> ParseAggregateMeasure(
 /// \brief Make an aggregate declaration info
 ///
 /// \param[in] input_decl the input declaration to use
-/// \param[in] input_schema the schema to which field refs apply
-/// \param[in] measure_size the number of measures to use
+/// \param[in] output_schema the output schema of the aggregation (not the input schema)
 /// \param[in] aggregates the aggregates to use
-/// \param[in] agg_src_fieldsets the field-sets per aggregate to use
 /// \param[in] keys the field-refs for grouping keys to use
-/// \param[in] key_field_ids the field-ids for grouping keys to use
 /// \param[in] segment_keys the field-refs for segment keys to use
-/// \param[in] segment_key_field_ids the field-ids for segment keys to use
-/// \param[in] ext_set an extension mapping to use
-/// \param[in] conversion_options options to control how the conversion is done
 ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(
-    acero::Declaration input_decl, std::shared_ptr<Schema> input_schema,
-    const int measure_size, std::vector<compute::Aggregate> aggregates,
-    std::vector<std::vector<int>> agg_src_fieldsets, std::vector<FieldRef> 
keys,
-    std::vector<int> key_field_ids, std::vector<FieldRef> segment_keys,
-    std::vector<int> segment_key_field_ids, const ExtensionSet& ext_set,
-    const ConversionOptions& conversion_options);
+    acero::Declaration input_decl, std::shared_ptr<Schema> output_schema,
+    std::vector<compute::Aggregate> aggregates, std::vector<FieldRef> keys,
+    std::vector<FieldRef> segment_keys);
 
 }  // namespace internal
 

Reply via email to