icexelloss commented on code in PR #34627:
URL: https://github.com/apache/arrow/pull/34627#discussion_r1150630902


##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -293,6 +303,90 @@ Status DiscoverFilesFromDir(const std::shared_ptr<fs::LocalFileSystem>& local_fs
   return Status::OK();
 }
 
+namespace internal {
+
+ARROW_ENGINE_EXPORT Status ParseAggregateMeasure(
+    const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set,
+    const ConversionOptions& conversion_options, bool is_hash,
+    const std::shared_ptr<Schema> input_schema,
+    std::vector<compute::Aggregate>* aggregates_ptr,
+    std::vector<std::vector<int>>* agg_src_fieldsets_ptr) {
+  std::vector<compute::Aggregate>& aggregates = *aggregates_ptr;
+  std::vector<std::vector<int>>& agg_src_fieldsets = *agg_src_fieldsets_ptr;
+  if (agg_measure.has_measure()) {
+    if (agg_measure.has_filter()) {
+      return Status::NotImplemented("Aggregate filters are not supported.");
+    }
+    const auto& agg_func = agg_measure.measure();
+    ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call,
+                          FromProto(agg_func, is_hash, ext_set, conversion_options));
+    ExtensionIdRegistry::SubstraitAggregateToArrow converter;
+    if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') {
+      ARROW_ASSIGN_OR_RAISE(converter,
+                            ext_set.registry()->GetSubstraitAggregateToArrowFallback(
+                                aggregate_call.id().name));
+    } else {
+      ARROW_ASSIGN_OR_RAISE(converter, ext_set.registry()->GetSubstraitAggregateToArrow(
+                                           aggregate_call.id()));
+    }
+    ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call));
+
+    // find aggregate field ids from schema
+    const auto& target = arrow_agg.target;
+    size_t measure_id = agg_src_fieldsets.size();
+    agg_src_fieldsets.push_back({});
+    for (const auto& field_ref : target) {
+      ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
+      agg_src_fieldsets[measure_id].push_back(match[0]);
+    }
+
+    aggregates.push_back(std::move(arrow_agg));
+    return Status::OK();
+  } else {
+    return Status::Invalid("substrait::AggregateFunction not provided");
+  }
+}
+
+ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(

Review Comment:
   Never mind, I found it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to