westonpace commented on code in PR #34627:
URL: https://github.com/apache/arrow/pull/34627#discussion_r1151186698


##########
cpp/src/arrow/engine/substrait/options.cc:
##########
@@ -165,6 +170,69 @@ class DefaultExtensionProvider : public 
BaseExtensionProvider {
                                                 named_tap_rel.name(), 
renamed_schema));
     return RelationInfo{{std::move(decl), std::move(renamed_schema)}, 
std::nullopt};
   }
+
+  Result<RelationInfo> MakeSegmentedAggregateRel(
+      const ConversionOptions& conv_opts, const std::vector<DeclarationInfo>& 
inputs,
+      const substrait_ext::SegmentedAggregateRel& seg_agg_rel,
+      const ExtensionSet& ext_set) {
+    if (inputs.size() != 1) {
+      return Status::Invalid(
+          "substrait_ext::SegmentedAggregateRel requires a single input but 
got: ",
+          inputs.size());
+    }
+    if (seg_agg_rel.segment_keys_size() == 0) {
+      return Status::Invalid(
+          "substrait_ext::SegmentedAggregateRel requires at least one segment 
key");
+    }

Review Comment:
   This is more of a question for @icexelloss I think (or whoever is going to 
be producing these plans).  We could fall back to a regular aggregation if there 
are no segment keys, right?  If we leave this check in place, then the user will 
have to make sure to only use this message if segment keys are present, 
correct?



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -293,6 +293,86 @@ Status DiscoverFilesFromDir(const 
std::shared_ptr<fs::LocalFileSystem>& local_fs
   return Status::OK();
 }
 
+namespace internal {
+
+ARROW_ENGINE_EXPORT Status ParseAggregateMeasure(
+    const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& 
ext_set,
+    const ConversionOptions& conversion_options, bool is_hash,
+    const std::shared_ptr<Schema> input_schema,
+    std::vector<compute::Aggregate>* aggregates_ptr,
+    std::vector<std::vector<int>>* agg_src_fieldsets_ptr) {
+  std::vector<compute::Aggregate>& aggregates = *aggregates_ptr;
+  std::vector<std::vector<int>>& agg_src_fieldsets = *agg_src_fieldsets_ptr;
+  if (agg_measure.has_measure()) {
+    if (agg_measure.has_filter()) {
+      return Status::NotImplemented("Aggregate filters are not supported.");
+    }
+    const auto& agg_func = agg_measure.measure();
+    ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call,
+                          FromProto(agg_func, is_hash, ext_set, 
conversion_options));
+    ExtensionIdRegistry::SubstraitAggregateToArrow converter;
+    if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') {
+      ARROW_ASSIGN_OR_RAISE(converter,
+                            
ext_set.registry()->GetSubstraitAggregateToArrowFallback(
+                                aggregate_call.id().name));
+    } else {
+      ARROW_ASSIGN_OR_RAISE(converter, 
ext_set.registry()->GetSubstraitAggregateToArrow(
+                                           aggregate_call.id()));
+    }
+    ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, 
converter(aggregate_call));
+
+    // find aggregate field ids from schema
+    const auto& target = arrow_agg.target;
+    size_t measure_id = agg_src_fieldsets.size();
+    agg_src_fieldsets.push_back({});
+    for (const auto& field_ref : target) {
+      ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
+      agg_src_fieldsets[measure_id].push_back(match[0]);
+    }
+
+    aggregates.push_back(std::move(arrow_agg));
+    return Status::OK();
+  } else {
+    return Status::Invalid("substrait::AggregateFunction not provided");
+  }
+}
+
+ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(
+    compute::Declaration input_decl, std::shared_ptr<Schema> input_schema,
+    const int measure_size, std::vector<compute::Aggregate> aggregates,
+    std::vector<std::vector<int>> agg_src_fieldsets, std::vector<FieldRef> 
keys,
+    std::vector<int> key_field_ids, std::vector<FieldRef> segment_keys,
+    std::vector<int> segment_key_field_ids, const ExtensionSet& ext_set,
+    const ConversionOptions& conversion_options) {
+  FieldVector output_fields;
+  output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() +
+                        measure_size);
+  // extract aggregate fields to output schema
+  for (const auto& agg_src_fieldset : agg_src_fieldsets) {
+    for (int field : agg_src_fieldset) {
+      output_fields.emplace_back(input_schema->field(field));
+    }
+  }

Review Comment:
   I don't think this is correct.
   
   1. Values come after keys (note, this changed recently in 
https://github.com/apache/arrow/issues/32897)
   2. You're inserting fields here for the inputs to the aggregates.  However, 
the fields should be based on the outputs.
   
   For example, consider the query `SELECT SUM(x), COVARIANCE(x,y) GROUP BY 
key` with the input schema `{key: int32, x: float32, y: float32}`.
   
   I believe this method would return `{x: float32, x: float32, y: float32, 
key: int32}`.  Instead it should return `{ key: int32, sum(x): float64, 
covariance(x, y): float64}`.
   
   Note that the names aren't really important (and are probably difficult to 
recreate), so it'd be fine to produce `{key0: int32, measure0: float64, measure1: 
float64}`.  However, we do need to get the types and number of output fields 
correct, or else any project relations added after this will break.



##########
cpp/src/arrow/engine/substrait/relation_internal.h:
##########
@@ -46,5 +52,46 @@ Result<DeclarationInfo> FromProto(const substrait::Rel&, 
const ExtensionSet&,
 ARROW_ENGINE_EXPORT Result<std::unique_ptr<substrait::Rel>> ToProto(
     const compute::Declaration&, ExtensionSet*, const ConversionOptions&);
 
+namespace internal {
+
+/// \brief Parse an aggregate relation's measure
+///
+/// \param[in] agg_measure the measure
+/// \param[in] ext_set an extension mapping to use in parsing
+/// \param[in] conversion_options options to control how the conversion is done
+/// \param[in] input_schema the schema to which field refs apply
+/// \param[in] is_hash whether the measure is a hash one (i.e., aggregation 
keys exist)
+/// \param[out] aggregates points to vector to push the parsed measure into
+/// \param[out] agg_src_fieldsets points to vector to push the parsed field 
set into
+ARROW_ENGINE_EXPORT Status ParseAggregateMeasure(

Review Comment:
   I don't love these methods.  It's not clear why they are public and they 
have a lot of arguments.  However, I can understand why they are here.
   
   ParseAggregateMeasure is probably inevitable, although I wonder if it 
might be easier to understand if it returned:
   
   ```
   struct ParsedMeasure {
     compute::Aggregate aggregate;
     std::vector<int> fieldset;
     std::shared_ptr<DataType> output_type;
   };
   
   Result<ParsedMeasure> ParseAggregateMeasure(...);
   ```
   
   I think we could do away with `MakeAggregateDeclaration` if we had a better 
way of computing the output schema given an input schema and a 
`compute::AggregateNodeOptions`. I started trying to make a 
`compute::AggregateNodeOptions::CalculateOutputSchema(const Schema& 
input_schema)` but it turns out it's rather tricky to determine the output 
types of measures once we have left Substrait.  So, for the moment, I think we 
can leave this as-is.



##########
cpp/src/arrow/engine/substrait/expression_internal.h:
##########
@@ -34,6 +35,15 @@ namespace engine {
 
 class SubstraitCall;
 
+ARROW_ENGINE_EXPORT
+Result<FieldRef> DirectReferenceFromProto(const 
substrait::Expression::ReferenceSegment*,
+                                          const ExtensionSet&, const 
ConversionOptions&);
+
+ARROW_ENGINE_EXPORT
+Result<compute::Expression> FromProto(const 
substrait::Expression::ReferenceSegment*,
+                                      const ExtensionSet&, const 
ConversionOptions&,
+                                      std::optional<compute::Expression>);

Review Comment:
   It's the "current expression".  This method is currently called in the 
middle of deserializing an expression tree.  For example:
   
   ```mermaid
   flowchart TD
       A[Call] -->|args| FieldRef
       A -->|args| C
       C[Call*] -->|args| D
       C -->|Call| E
       D[Literal]
       E[FieldRef*]
   ```
   
   So, when dereferencing `FieldRef`, this will be `Call`, and when 
dereferencing `FieldRef*`, this will be `Call*`.
   
   However, can we remove this prototype from the header file and put it in an 
anonymous namespace?  I think it should be an internal method and not exposed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to