rtpsw commented on code in PR #34627:
URL: https://github.com/apache/arrow/pull/34627#discussion_r1150575591
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -293,6 +303,90 @@ Status DiscoverFilesFromDir(const
std::shared_ptr<fs::LocalFileSystem>& local_fs
return Status::OK();
}
+namespace internal {
+
+// Translate a single Substrait aggregate measure into an Arrow aggregate.
+//
+// On success, the converted compute::Aggregate is appended to *aggregates_ptr
+// and a new entry is appended to *agg_src_fieldsets_ptr holding, for each
+// target field reference of the aggregate, the first index of the path that
+// the reference resolves to in *input_schema.
+//
+// Returns Status::Invalid when the measure carries no aggregate function and
+// Status::NotImplemented when the measure has a filter. Errors raised while
+// decoding the call, looking up the converter, converting, or resolving a
+// field reference are propagated unchanged.
+ARROW_ENGINE_EXPORT Status ParseAggregateMeasure(
+    const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set,
+    const ConversionOptions& conversion_options, bool is_hash,
+    const std::shared_ptr<Schema> input_schema,
+    std::vector<compute::Aggregate>* aggregates_ptr,
+    std::vector<std::vector<int>>* agg_src_fieldsets_ptr) {
+  // Reject unsupported / malformed measures up front.
+  if (!agg_measure.has_measure()) {
+    return Status::Invalid("substrait::AggregateFunction not provided");
+  }
+  if (agg_measure.has_filter()) {
+    return Status::NotImplemented("Aggregate filters are not supported.");
+  }
+
+  ARROW_ASSIGN_OR_RAISE(
+      SubstraitCall call,
+      FromProto(agg_measure.measure(), is_hash, ext_set, conversion_options));
+
+  // An empty URI, or one starting with '/', is looked up by function name only
+  // through the registry's fallback; anything else is looked up by full id.
+  const auto& call_id = call.id();
+  const bool use_fallback = call_id.uri.empty() || call_id.uri[0] == '/';
+  ExtensionIdRegistry::SubstraitAggregateToArrow to_arrow;
+  if (use_fallback) {
+    ARROW_ASSIGN_OR_RAISE(
+        to_arrow,
+        ext_set.registry()->GetSubstraitAggregateToArrowFallback(call_id.name));
+  } else {
+    ARROW_ASSIGN_OR_RAISE(to_arrow,
+                          ext_set.registry()->GetSubstraitAggregateToArrow(call_id));
+  }
+  ARROW_ASSIGN_OR_RAISE(compute::Aggregate converted, to_arrow(call));
+
+  // Open this measure's fieldset first, then fill it one reference at a time;
+  // the slot is addressed by index on every iteration.
+  std::vector<std::vector<int>>& fieldsets = *agg_src_fieldsets_ptr;
+  const size_t slot = fieldsets.size();
+  fieldsets.emplace_back();
+  for (const auto& ref : converted.target) {
+    ARROW_ASSIGN_OR_RAISE(auto path, ref.FindOne(*input_schema));
+    fieldsets[slot].push_back(path[0]);
+  }
+
+  aggregates_ptr->push_back(std::move(converted));
+  return Status::OK();
+}
+
+ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(
+ std::optional<substrait::RelCommon> agg_common_opt, compute::Declaration
input_decl,
+ std::shared_ptr<Schema> input_schema, const int measure_size,
+ std::vector<compute::Aggregate> aggregates,
+ std::vector<std::vector<int>> agg_src_fieldsets, std::vector<FieldRef>
keys,
+ std::vector<int> key_field_ids, std::vector<FieldRef> segment_keys,
+ std::vector<int> segment_key_field_ids, const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options) {
+ FieldVector output_fields;
+ output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() +
+ measure_size);
+ // extract aggregate fields to output schema
+ for (const auto& agg_src_fieldset : agg_src_fieldsets) {
+ for (int field : agg_src_fieldset) {
+ output_fields.emplace_back(input_schema->field(field));
+ }
+ }
+ // extract key fields to output schema
+ for (int key_field_id : key_field_ids) {
+ output_fields.emplace_back(input_schema->field(key_field_id));
+ }
+ // extract segment key fields to output schema
+ for (int segment_key_field_id : segment_key_field_ids) {
+ output_fields.emplace_back(input_schema->field(segment_key_field_id));
+ }
+
+ std::shared_ptr<Schema> aggregate_schema = schema(std::move(output_fields));
+
+ DeclarationInfo aggregate_declaration{
+ compute::Declaration::Sequence(
+ {std::move(input_decl),
+ {"aggregate", compute::AggregateNodeOptions{aggregates, keys,
segment_keys}}}),
+ aggregate_schema};
+
+ return ProcessEmit(std::move(agg_common_opt),
std::move(aggregate_declaration),
+ std::move(aggregate_schema));
Review Comment:
Fixed.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]