westonpace commented on code in PR #34627:
URL: https://github.com/apache/arrow/pull/34627#discussion_r1152716216


##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -293,6 +293,86 @@ Status DiscoverFilesFromDir(const std::shared_ptr<fs::LocalFileSystem>& local_fs
   return Status::OK();
 }
 
+namespace internal {
+
+ARROW_ENGINE_EXPORT Status ParseAggregateMeasure(
+    const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& 
ext_set,
+    const ConversionOptions& conversion_options, bool is_hash,
+    const std::shared_ptr<Schema> input_schema,
+    std::vector<compute::Aggregate>* aggregates_ptr,
+    std::vector<std::vector<int>>* agg_src_fieldsets_ptr) {
+  std::vector<compute::Aggregate>& aggregates = *aggregates_ptr;
+  std::vector<std::vector<int>>& agg_src_fieldsets = *agg_src_fieldsets_ptr;
+  if (agg_measure.has_measure()) {
+    if (agg_measure.has_filter()) {
+      return Status::NotImplemented("Aggregate filters are not supported.");
+    }
+    const auto& agg_func = agg_measure.measure();
+    ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call,
+                          FromProto(agg_func, is_hash, ext_set, 
conversion_options));
+    ExtensionIdRegistry::SubstraitAggregateToArrow converter;
+    if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') {
+      ARROW_ASSIGN_OR_RAISE(converter,
+                            
ext_set.registry()->GetSubstraitAggregateToArrowFallback(
+                                aggregate_call.id().name));
+    } else {
+      ARROW_ASSIGN_OR_RAISE(converter, 
ext_set.registry()->GetSubstraitAggregateToArrow(
+                                           aggregate_call.id()));
+    }
+    ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, 
converter(aggregate_call));
+
+    // find aggregate field ids from schema
+    const auto& target = arrow_agg.target;
+    size_t measure_id = agg_src_fieldsets.size();
+    agg_src_fieldsets.push_back({});
+    for (const auto& field_ref : target) {
+      ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
+      agg_src_fieldsets[measure_id].push_back(match[0]);
+    }
+
+    aggregates.push_back(std::move(arrow_agg));
+    return Status::OK();
+  } else {
+    return Status::Invalid("substrait::AggregateFunction not provided");
+  }
+}
+
+ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(
+    compute::Declaration input_decl, std::shared_ptr<Schema> input_schema,
+    const int measure_size, std::vector<compute::Aggregate> aggregates,
+    std::vector<std::vector<int>> agg_src_fieldsets, std::vector<FieldRef> 
keys,
+    std::vector<int> key_field_ids, std::vector<FieldRef> segment_keys,
+    std::vector<int> segment_key_field_ids, const ExtensionSet& ext_set,
+    const ConversionOptions& conversion_options) {
+  FieldVector output_fields;
+  output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() +
+                        measure_size);
+  // extract aggregate fields to output schema
+  for (const auto& agg_src_fieldset : agg_src_fieldsets) {
+    for (int field : agg_src_fieldset) {
+      output_fields.emplace_back(input_schema->field(field));
+    }
+  }

Review Comment:
   We can leave this for a follow-up.  Yes, I still think it's incorrect.  Following 
up an aggregate with a projection had been broken for a long time due to the 
ordering issue, so it wouldn't surprise me if this has always been broken.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to