icexelloss commented on code in PR #34627:
URL: https://github.com/apache/arrow/pull/34627#discussion_r1152588087


##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -293,6 +293,86 @@ Status DiscoverFilesFromDir(const 
std::shared_ptr<fs::LocalFileSystem>& local_fs
   return Status::OK();
 }
 
+namespace internal {
+
+ARROW_ENGINE_EXPORT Status ParseAggregateMeasure(
+    const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& 
ext_set,
+    const ConversionOptions& conversion_options, bool is_hash,
+    const std::shared_ptr<Schema> input_schema,
+    std::vector<compute::Aggregate>* aggregates_ptr,
+    std::vector<std::vector<int>>* agg_src_fieldsets_ptr) {
+  std::vector<compute::Aggregate>& aggregates = *aggregates_ptr;
+  std::vector<std::vector<int>>& agg_src_fieldsets = *agg_src_fieldsets_ptr;
+  if (agg_measure.has_measure()) {
+    if (agg_measure.has_filter()) {
+      return Status::NotImplemented("Aggregate filters are not supported.");
+    }
+    const auto& agg_func = agg_measure.measure();
+    ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call,
+                          FromProto(agg_func, is_hash, ext_set, 
conversion_options));
+    ExtensionIdRegistry::SubstraitAggregateToArrow converter;
+    if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') {
+      ARROW_ASSIGN_OR_RAISE(converter,
+                            
ext_set.registry()->GetSubstraitAggregateToArrowFallback(
+                                aggregate_call.id().name));
+    } else {
+      ARROW_ASSIGN_OR_RAISE(converter, 
ext_set.registry()->GetSubstraitAggregateToArrow(
+                                           aggregate_call.id()));
+    }
+    ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, 
converter(aggregate_call));
+
+    // find aggregate field ids from schema
+    const auto& target = arrow_agg.target;
+    size_t measure_id = agg_src_fieldsets.size();
+    agg_src_fieldsets.push_back({});
+    for (const auto& field_ref : target) {
+      ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
+      agg_src_fieldsets[measure_id].push_back(match[0]);
+    }
+
+    aggregates.push_back(std::move(arrow_agg));
+    return Status::OK();
+  } else {
+    return Status::Invalid("substrait::AggregateFunction not provided");
+  }
+}
+
+ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(
+    compute::Declaration input_decl, std::shared_ptr<Schema> input_schema,
+    const int measure_size, std::vector<compute::Aggregate> aggregates,
+    std::vector<std::vector<int>> agg_src_fieldsets, std::vector<FieldRef> 
keys,
+    std::vector<int> key_field_ids, std::vector<FieldRef> segment_keys,
+    std::vector<int> segment_key_field_ids, const ExtensionSet& ext_set,
+    const ConversionOptions& conversion_options) {
+  FieldVector output_fields;
+  output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() +
+                        measure_size);
+  // extract aggregate fields to output schema
+  for (const auto& agg_src_fieldset : agg_src_fieldsets) {
+    for (int field : agg_src_fieldset) {
+      output_fields.emplace_back(input_schema->field(field));
+    }
+  }

Review Comment:
   @westonpace I agree this looks strange/incorrect, but it does seem to be
the same as the existing code:
   
https://github.com/apache/arrow/blob/main/cpp/src/arrow/engine/substrait/relation_internal.cc#L771
   
   Would you prefer to get to the bottom of this in this PR, or leave it as a
follow-up? I am fine either way.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to