westonpace commented on code in PR #34627:
URL: https://github.com/apache/arrow/pull/34627#discussion_r1151186698
##########
cpp/src/arrow/engine/substrait/options.cc:
##########
@@ -165,6 +170,69 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
named_tap_rel.name(),
renamed_schema));
return RelationInfo{{std::move(decl), std::move(renamed_schema)},
std::nullopt};
}
+
+ Result<RelationInfo> MakeSegmentedAggregateRel(
+ const ConversionOptions& conv_opts, const std::vector<DeclarationInfo>& inputs,
+ const substrait_ext::SegmentedAggregateRel& seg_agg_rel,
+ const ExtensionSet& ext_set) {
+ if (inputs.size() != 1) {
+ return Status::Invalid(
+ "substrait_ext::SegmentedAggregateRel requires a single input but got: ",
+ inputs.size());
+ }
+ if (seg_agg_rel.segment_keys_size() == 0) {
+ return Status::Invalid(
+ "substrait_ext::SegmentedAggregateRel requires at least one segment key");
+ }
Review Comment:
This is more of a question for @icexelloss, I think (or whoever is going to be producing these plans). We could fall back to a regular aggregation if there are no segment keys, right? If we leave this check in place, then the user will have to make sure to only use this message when segment keys are present, correct?
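To make the fallback idea concrete, here is a minimal, self-contained sketch that does not use Arrow. `AggregateKind` and `ChooseAggregateKind` are hypothetical names, and the sketch assumes (as the question implies) that a segmented aggregation differs from an ordinary one only by the presence of segment keys. Instead of rejecting an empty segment-key list, the converter picks a strategy from whichever keys are present:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical classification of the aggregation a relation converts into.
// Assumption: all three map onto the same aggregate declaration and differ
// only in which key lists are populated.
enum class AggregateKind { kScalar, kHash, kSegmented };

// Rather than returning an error when segment_keys is empty, treat a
// SegmentedAggregateRel without segment keys as an ordinary aggregation.
AggregateKind ChooseAggregateKind(const std::vector<std::string>& group_keys,
                                  const std::vector<std::string>& segment_keys) {
  if (!segment_keys.empty()) {
    return AggregateKind::kSegmented;  // segment keys present: segmented path
  }
  if (!group_keys.empty()) {
    return AggregateKind::kHash;  // ordinary GROUP BY
  }
  return AggregateKind::kScalar;  // no keys at all: a single output row
}
```

With this shape a producer could always emit `SegmentedAggregateRel` and let the consumer degrade gracefully; whether that is desirable is the open question for the plan producers.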
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -293,6 +293,86 @@ Status DiscoverFilesFromDir(const std::shared_ptr<fs::LocalFileSystem>& local_fs
return Status::OK();
}
+namespace internal {
+
+ARROW_ENGINE_EXPORT Status ParseAggregateMeasure(
+ const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options, bool is_hash,
+ const std::shared_ptr<Schema> input_schema,
+ std::vector<compute::Aggregate>* aggregates_ptr,
+ std::vector<std::vector<int>>* agg_src_fieldsets_ptr) {
+ std::vector<compute::Aggregate>& aggregates = *aggregates_ptr;
+ std::vector<std::vector<int>>& agg_src_fieldsets = *agg_src_fieldsets_ptr;
+ if (agg_measure.has_measure()) {
+ if (agg_measure.has_filter()) {
+ return Status::NotImplemented("Aggregate filters are not supported.");
+ }
+ const auto& agg_func = agg_measure.measure();
+ ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call,
+ FromProto(agg_func, is_hash, ext_set, conversion_options));
+ ExtensionIdRegistry::SubstraitAggregateToArrow converter;
+ if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') {
+ ARROW_ASSIGN_OR_RAISE(converter,
+ ext_set.registry()->GetSubstraitAggregateToArrowFallback(
+ aggregate_call.id().name));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(converter, ext_set.registry()->GetSubstraitAggregateToArrow(
+ aggregate_call.id()));
+ }
+ ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call));
+
+ // find aggregate field ids from schema
+ const auto& target = arrow_agg.target;
+ size_t measure_id = agg_src_fieldsets.size();
+ agg_src_fieldsets.push_back({});
+ for (const auto& field_ref : target) {
+ ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
+ agg_src_fieldsets[measure_id].push_back(match[0]);
+ }
+
+ aggregates.push_back(std::move(arrow_agg));
+ return Status::OK();
+ } else {
+ return Status::Invalid("substrait::AggregateFunction not provided");
+ }
+}
+
+ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(
+ compute::Declaration input_decl, std::shared_ptr<Schema> input_schema,
+ const int measure_size, std::vector<compute::Aggregate> aggregates,
+ std::vector<std::vector<int>> agg_src_fieldsets, std::vector<FieldRef> keys,
+ std::vector<int> key_field_ids, std::vector<FieldRef> segment_keys,
+ std::vector<int> segment_key_field_ids, const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options) {
+ FieldVector output_fields;
+ output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() +
+ measure_size);
+ // extract aggregate fields to output schema
+ for (const auto& agg_src_fieldset : agg_src_fieldsets) {
+ for (int field : agg_src_fieldset) {
+ output_fields.emplace_back(input_schema->field(field));
+ }
+ }
Review Comment:
I don't think this is correct.
1. Values come after keys (note: this changed recently in https://github.com/apache/arrow/issues/32897).
2. You're inserting fields here for the inputs to the aggregates. However, the fields should be based on the outputs.
For example, consider the query `SELECT SUM(x), COVARIANCE(x,y) GROUP BY key` with the input schema `{key: int32, x: float32, y: float32}`. I believe this method would return `{x: float32, x: float32, y: float32, key: int32}`. Instead, it should return `{key: int32, sum(x): float64, covariance(x, y): float64}`.
Note that the names aren't really important (and are probably difficult to recreate), so it'd be fine to use `{key0: int32, measure0: float64, measure1: float64}`. However, we do need to get the types and the number of output fields correct, or else it will break any project relations added after this.
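A minimal sketch of the schema computation described above, using hypothetical stand-in types (`Field`, `ParsedMeasure`, `MakeAggregateOutputSchema`) instead of Arrow's `Schema` and `DataType`. It assumes each measure's output type has already been resolved while parsing the Substrait plan, and it leaves segment keys out for brevity:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a schema field.
struct Field {
  std::string name;
  std::string type;  // e.g. "int32", "float64"
};

// Hypothetical parsed measure: the columns it reads and the type it produces.
struct ParsedMeasure {
  std::vector<int> fieldset;  // input columns (NOT part of the output schema)
  std::string output_type;    // type of the aggregate's result
};

// Keys come first, copied from the input schema; then exactly one field per
// measure, typed by the measure's output type rather than its inputs.
std::vector<Field> MakeAggregateOutputSchema(const std::vector<Field>& input_schema,
                                             const std::vector<int>& key_field_ids,
                                             const std::vector<ParsedMeasure>& measures) {
  std::vector<Field> output;
  output.reserve(key_field_ids.size() + measures.size());
  for (std::size_t i = 0; i < key_field_ids.size(); ++i) {
    output.push_back({"key" + std::to_string(i), input_schema[key_field_ids[i]].type});
  }
  for (std::size_t i = 0; i < measures.size(); ++i) {
    output.push_back({"measure" + std::to_string(i), measures[i].output_type});
  }
  return output;
}
```

For the example query (input `{key: int32, x: float32, y: float32}`, key index `0`, and two measures producing `float64`), this yields `{key0: int32, measure0: float64, measure1: float64}`.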
##########
cpp/src/arrow/engine/substrait/relation_internal.h:
##########
@@ -46,5 +52,46 @@ Result<DeclarationInfo> FromProto(const substrait::Rel&,
const ExtensionSet&,
ARROW_ENGINE_EXPORT Result<std::unique_ptr<substrait::Rel>> ToProto(
const compute::Declaration&, ExtensionSet*, const ConversionOptions&);
+namespace internal {
+
+/// \brief Parse an aggregate relation's measure
+///
+/// \param[in] agg_measure the measure
+/// \param[in] ext_set an extension mapping to use in parsing
+/// \param[in] conversion_options options to control how the conversion is done
+/// \param[in] input_schema the schema to which field refs apply
+/// \param[in] is_hash whether the measure is a hash one (i.e., aggregation keys exist)
+/// \param[out] aggregates points to vector to push the parsed measure into
+/// \param[out] agg_src_fieldsets points to vector to push the parsed field set into
+ARROW_ENGINE_EXPORT Status ParseAggregateMeasure(
Review Comment:
I don't love these methods. It's not clear why they are public, and they have a lot of arguments. However, I can understand why they are here.
`ParseAggregateMeasure` is probably inevitable, although I wonder if it might be easier to understand if it returned:
```
struct ParsedMeasure {
compute::Aggregate aggregate;
std::vector<int> fieldset;
std::shared_ptr<DataType> output_type;
};
Result<ParsedMeasure> ParseAggregateMeasure(...);
```
I think we could do away with `MakeAggregateDeclaration` if we had a better way of computing the output schema given an input schema and a `compute::AggregateNodeOptions`. I started trying to make a `compute::AggregateNodeOptions::CalculateOutputSchema(const Schema& input_schema)`, but it turns out it's rather tricky to determine the output types of measures once we have left Substrait. So, for the moment, I think we can leave this as-is.
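The value-returning shape suggested above can be sketched without Arrow. Every type and function here is a hypothetical stand-in: `std::optional` plays the role of `Result<T>`, and `FindOneField` only loosely mirrors the "exactly one match" behavior expected of `FieldRef::FindOne`. The point is that each call returns everything one measure contributes, so the caller owns the accumulation and no out-parameters are needed:

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for a schema field and a Substrait measure.
struct Field {
  std::string name;
  std::string type;
};

struct MeasureSpec {
  std::string function;                // e.g. "sum"
  std::vector<std::string> arg_names;  // input columns the measure references
  std::string output_type;             // assumed resolved while still in Substrait
};

// Everything one measure contributes, returned by value.
struct ParsedMeasure {
  std::string function;
  std::vector<int> fieldset;
  std::string output_type;
};

// Index of the single field called `name`; std::nullopt if none or several match.
std::optional<int> FindOneField(const std::vector<Field>& schema, const std::string& name) {
  std::optional<int> match;
  for (std::size_t i = 0; i < schema.size(); ++i) {
    if (schema[i].name != name) continue;
    if (match) return std::nullopt;  // ambiguous reference
    match = static_cast<int>(i);
  }
  return match;
}

// std::optional stands in for Result<ParsedMeasure> in this sketch.
std::optional<ParsedMeasure> ParseAggregateMeasure(const MeasureSpec& spec,
                                                   const std::vector<Field>& input_schema) {
  ParsedMeasure parsed{spec.function, {}, spec.output_type};
  for (const auto& name : spec.arg_names) {
    std::optional<int> index = FindOneField(input_schema, name);
    if (!index) return std::nullopt;  // unresolved argument: fail the whole measure
    parsed.fieldset.push_back(*index);
  }
  return parsed;
}
```

The caller can append each result's `fieldset` and `output_type` to its own vectors, and because the output type travels with the measure, the output-schema computation has what it needs.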
##########
cpp/src/arrow/engine/substrait/expression_internal.h:
##########
@@ -34,6 +35,15 @@ namespace engine {
class SubstraitCall;
+ARROW_ENGINE_EXPORT
+Result<FieldRef> DirectReferenceFromProto(const substrait::Expression::ReferenceSegment*,
+ const ExtensionSet&, const ConversionOptions&);
+
+ARROW_ENGINE_EXPORT
+Result<compute::Expression> FromProto(const substrait::Expression::ReferenceSegment*,
+ const ExtensionSet&, const ConversionOptions&,
+ std::optional<compute::Expression>);
Review Comment:
It's the "current expression". This method is currently called in the middle of deserializing an expression tree. So, for example:
```mermaid
flowchart TD
A[Call] -->|args| FieldRef
A -->|args| C
C[Call*] -->|args| D
C -->|args| E
D[Literal]
E[FieldRef*]
```
So, when dereferencing `FieldRef`, this will be `Call`, and when dereferencing `FieldRef*`, this will be `Call*`.
However, can we remove this prototype from the header file and put it in an
anonymous namespace? I think it should be an internal method and not exposed.
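The role of the "current expression" parameter and the anonymous-namespace suggestion can be sketched together. `Node`, `RefContext`, `ResolveRefs`, and `ResolveAllRefs` are hypothetical and far simpler than Substrait's protobuf types; the sketch only shows that a reference is resolved against the call that encloses it, and that the recursive helper can live in an anonymous namespace in the `.cc` file while only the entry point would be declared in a header:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, heavily simplified expression node (not Substrait's protobuf).
struct Node {
  std::string label;                        // e.g. "Call", "FieldRef"
  bool is_field_ref;                        // true for reference leaves
  std::vector<std::shared_ptr<Node>> args;  // call arguments (empty for leaves)
};

// (reference label, label of the enclosing "current expression")
using RefContext = std::pair<std::string, std::string>;

namespace {

// Internal helper, not declared in any header. `current` is the call whose
// arguments are being deserialized, or nullptr at the root.
void ResolveRefs(const Node& node, const Node* current, std::vector<RefContext>* out) {
  if (node.is_field_ref) {
    out->emplace_back(node.label, current != nullptr ? current->label : "<root>");
    return;
  }
  for (const auto& arg : node.args) {
    ResolveRefs(*arg, &node, out);  // `node` becomes the current expression
  }
}

}  // namespace

// Public entry point; the only function a header would need to declare.
std::vector<RefContext> ResolveAllRefs(const Node& root) {
  std::vector<RefContext> out;
  ResolveRefs(root, nullptr, &out);
  return out;
}
```

Built from the tree in the diagram, `ResolveAllRefs` pairs `FieldRef` with `Call` and `FieldRef*` with `Call*`.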
--
This is an automated message from the Apache Git Service.