westonpace commented on code in PR #34627:
URL: https://github.com/apache/arrow/pull/34627#discussion_r1149877822
##########
cpp/src/arrow/engine/substrait/options.cc:
##########
@@ -166,6 +171,57 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
named_tap_rel.name(),
std::move(renamed_schema)));
return RelationInfo{{std::move(decl), std::move(renamed_schema)},
std::nullopt};
}
+
+ Result<RelationInfo> MakeSegmentedAggregateRel(
+ const ConversionOptions& conv_opts, const std::vector<DeclarationInfo>& inputs,
+ const substrait_ext::SegmentedAggregateRel& seg_agg_rel,
+ const ExtensionSet& ext_set) {
+ if (inputs.size() != 1) {
Review Comment:
I think this is no longer a concern, right? We only use the `Measure` part of the original message, not the part that contains its own inputs, so we should be able to drop this `inputs` check.
##########
cpp/src/arrow/engine/substrait/expression_internal.cc:
##########
@@ -138,6 +138,77 @@ std::string EnumToString(int value, const google::protobuf::EnumDescriptor* desc
return value_desc->name();
}
+Result<compute::Expression> FromProto(const substrait::Expression::ReferenceSegment* ref,
+ const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options,
+ std::optional<compute::Expression> in_expr) {
+ auto in_ref = ref;
+ auto& out = in_expr;
+ while (ref != nullptr) {
+ switch (ref->reference_type_case()) {
+ case substrait::Expression::ReferenceSegment::kStructField: {
+ auto index = ref->struct_field().field();
+ if (!out) {
+ // Root StructField (column selection)
+ out = compute::field_ref(FieldRef(index));
+ } else if (auto out_ref = out->field_ref()) {
+ // Nested StructFields on the root (selection of struct-typed column
+ // combined with selecting struct fields)
+ out = compute::field_ref(FieldRef(*out_ref, index));
+ } else if (out->call() && out->call()->function_name == "struct_field") {
+ // Nested StructFields on top of an arbitrary expression
+ auto* field_options =
+ checked_cast<compute::StructFieldOptions*>(out->call()->options.get());
+ field_options->field_ref = FieldRef(std::move(field_options->field_ref), index);
+ } else {
+ // First StructField on top of an arbitrary expression
+ out = compute::call("struct_field", {std::move(*out)},
+ arrow::compute::StructFieldOptions({index}));
+ }
Review Comment:
The code here is perhaps more complex than the caller needs in many cases. Ideally it would return a simple field reference. The challenge is that Arrow's expressions don't support field references into intermediate outputs. For example, if you have a function `min_max` that returns a struct like `{"minimum": 0, "maximum": 10}`, there is no way to express `min_max(x).minimum > 0` using a field reference. Instead, you need the `struct_field` function: `call("greater", {call("struct_field", {call("min_max", {field_ref("x")})}, StructFieldOptions(FieldRef("minimum"))), literal(0)})`.
In many cases (e.g., aggregate keys, join keys, sort keys), Acero only accepts "direct references". These can always be expressed as a `FieldRef`.
So, there may be some value in creating a second `DirectReferenceFromProto`
that accepts `ReferenceSegment` and returns `FieldRef`. The implementation
could call the existing `FromProto(Expression)`:
```
// Converts a Substrait expression and accepts it only if it reduces to a
// direct reference (a plain FieldRef); anything else is rejected.
Result<FieldRef> DirectReferenceFromProto(const substrait::Expression& substrait_expr,
                                          const ExtensionSet& ext_set,
                                          const ConversionOptions& conversion_options) {
  ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
                        FromProto(substrait_expr, ext_set, conversion_options));
  // field_ref() returns nullptr unless the expression is a bare field reference.
  const FieldRef* field_ref = expr.field_ref();
  if (field_ref) {
    return *field_ref;
  }
  return Status::Invalid(
      "A direct reference was expected but a more complex expression was given "
      "instead");
}
```
You would still need the change here, so I agree with breaking it out into its own method.
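To make the acceptance rule concrete without pulling in Arrow, here is a minimal stdlib-only model. `FieldPath`, `Call`, `Expr`, and `DirectReference` are hypothetical simplifications (not Arrow types), and `std::nullopt` stands in for `Status::Invalid`. A bare column path is accepted as a direct reference, while something like `min_max(x).minimum`, which has to be built from calls, is rejected.

```cpp
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for a direct reference (FieldRef): a path of child
// indices starting at the input schema root.
using FieldPath = std::vector<int>;

// Hypothetical stand-in for any call expression, e.g. "struct_field" applied
// to the output of another function.
struct Call {
  std::string function_name;
};

// In this model an expression is either a direct reference or a call.
using Expr = std::variant<FieldPath, Call>;

// Models the proposed DirectReferenceFromProto check: return the path for a
// bare field reference, std::nullopt (standing in for Status::Invalid) otherwise.
std::optional<FieldPath> DirectReference(const Expr& expr) {
  if (const auto* path = std::get_if<FieldPath>(&expr)) {
    return *path;
  }
  return std::nullopt;
}
```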
##########
cpp/src/arrow/compute/exec/source_node.cc:
##########
@@ -102,6 +104,19 @@ struct SourceNode : ExecNode, public TracedNode {
batch_size = morsel_length;
}
ExecBatch batch = morsel.Slice(offset, batch_size);
+ for (auto& value : batch.values) {
+ if (value.is_array()) {
+ ARROW_ASSIGN_OR_RAISE(
+ value, util::EnsureAlignment(value.make_array(), ipc::kArrowAlignment,
+ default_memory_pool()));
+ }
+ if (value.is_chunked_array()) {
Review Comment:
`value` can't be a chunked array here: `ExecBatch` values are only ever arrays or scalars.
##########
cpp/proto/substrait/extension_rels.proto:
##########
@@ -58,3 +58,16 @@ message NamedTapRel {
// If empty, field names will be automatically generated.
repeated string columns = 3;
}
+
+message SegmentedAggregateRel {
+ substrait.RelCommon common = 1;
+
+ // Grouping keys of the aggregation
+ repeated substrait.Expression.ReferenceSegment grouping_keys = 2;
+
+ // Segment keys of the aggregation
+ repeated substrait.Expression.ReferenceSegment segment_keys = 3;
+
+ // A list of one or more aggregate expressions along with an optional filter.
+ repeated substrait.AggregateRel.Measure measures = 4;
Review Comment:
I think reusing `substrait.AggregateRel.Measure` here is OK. If this were in the Substrait spec, we might define a separate message so the two relations could diverge in the future. However, for something internal, I think it is fine to keep it simpler.
##########
cpp/proto/substrait/extension_rels.proto:
##########
@@ -58,3 +58,16 @@ message NamedTapRel {
// If empty, field names will be automatically generated.
repeated string columns = 3;
}
+
+message SegmentedAggregateRel {
+ substrait.RelCommon common = 1;
Review Comment:
Agreed.
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -293,6 +303,90 @@ Status DiscoverFilesFromDir(const std::shared_ptr<fs::LocalFileSystem>& local_fs
return Status::OK();
}
+namespace internal {
+
+ARROW_ENGINE_EXPORT Status ParseAggregateMeasure(
+ const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options, bool is_hash,
+ const std::shared_ptr<Schema> input_schema,
+ std::vector<compute::Aggregate>* aggregates_ptr,
+ std::vector<std::vector<int>>* agg_src_fieldsets_ptr) {
+ std::vector<compute::Aggregate>& aggregates = *aggregates_ptr;
+ std::vector<std::vector<int>>& agg_src_fieldsets = *agg_src_fieldsets_ptr;
+ if (agg_measure.has_measure()) {
+ if (agg_measure.has_filter()) {
+ return Status::NotImplemented("Aggregate filters are not supported.");
+ }
+ const auto& agg_func = agg_measure.measure();
+ ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call,
+ FromProto(agg_func, is_hash, ext_set, conversion_options));
+ ExtensionIdRegistry::SubstraitAggregateToArrow converter;
+ if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') {
+ ARROW_ASSIGN_OR_RAISE(converter,
+ ext_set.registry()->GetSubstraitAggregateToArrowFallback(
+ aggregate_call.id().name));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(converter, ext_set.registry()->GetSubstraitAggregateToArrow(
+ aggregate_call.id()));
+ }
+ ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call));
+
+ // find aggregate field ids from schema
+ const auto& target = arrow_agg.target;
+ size_t measure_id = agg_src_fieldsets.size();
+ agg_src_fieldsets.push_back({});
+ for (const auto& field_ref : target) {
+ ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
+ agg_src_fieldsets[measure_id].push_back(match[0]);
+ }
+
+ aggregates.push_back(std::move(arrow_agg));
+ return Status::OK();
+ } else {
+ return Status::Invalid("substrait::AggregateFunction not provided");
+ }
+}
+
+ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(
+ std::optional<substrait::RelCommon> agg_common_opt, compute::Declaration input_decl,
+ std::shared_ptr<Schema> input_schema, const int measure_size,
+ std::vector<compute::Aggregate> aggregates,
+ std::vector<std::vector<int>> agg_src_fieldsets, std::vector<FieldRef> keys,
+ std::vector<int> key_field_ids, std::vector<FieldRef> segment_keys,
+ std::vector<int> segment_key_field_ids, const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options) {
+ FieldVector output_fields;
+ output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() +
+ measure_size);
+ // extract aggregate fields to output schema
+ for (const auto& agg_src_fieldset : agg_src_fieldsets) {
+ for (int field : agg_src_fieldset) {
+ output_fields.emplace_back(input_schema->field(field));
+ }
+ }
+ // extract key fields to output schema
+ for (int key_field_id : key_field_ids) {
+ output_fields.emplace_back(input_schema->field(key_field_id));
+ }
+ // extract segment key fields to output schema
+ for (int segment_key_field_id : segment_key_field_ids) {
+ output_fields.emplace_back(input_schema->field(segment_key_field_id));
+ }
+
+ std::shared_ptr<Schema> aggregate_schema = schema(std::move(output_fields));
+
+ DeclarationInfo aggregate_declaration{
+ compute::Declaration::Sequence(
+ {std::move(input_decl),
+ {"aggregate", compute::AggregateNodeOptions{aggregates, keys, segment_keys}}}),
+ aggregate_schema};
+
+ return ProcessEmit(std::move(agg_common_opt), std::move(aggregate_declaration),
+ std::move(aggregate_schema));
Review Comment:
`ProcessEmit` should not be part of this method; you don't need it when working on an extension relation. It should stay where it was, in the `kAggregate` case of the top-level relation switch statement. Then you don't need to pass in `RelCommon`, and you don't need the changes to `ProcessEmit`.
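A minimal sketch of the suggested split, assuming simplified stand-in types: `DeclarationInfo`, `Emit`, `ProcessEmit`, `ConvertAggregateRel`, and `ConvertSegmentedAggregateRel` below are hypothetical models (not the Arrow signatures), and columns are plain names rather than an Arrow `Schema`. The shared helper takes no `RelCommon`; only the `kAggregate` path applies emit, while the extension relation path returns the helper's result directly, as suggested above.

```cpp
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in: a plan description plus output column names.
struct DeclarationInfo {
  std::string plan;
  std::vector<std::string> columns;
};

// Hypothetical stand-in for RelCommon's emit: output column indices to keep.
using Emit = std::optional<std::vector<int>>;

// Shared helper: builds the aggregate only (aggregates first, then keys).
// It takes no RelCommon/emit argument.
DeclarationInfo MakeAggregateDeclaration(const DeclarationInfo& input,
                                         std::vector<std::string> agg_columns,
                                         const std::vector<std::string>& key_columns) {
  std::vector<std::string> out = std::move(agg_columns);
  out.insert(out.end(), key_columns.begin(), key_columns.end());
  return {input.plan + " -> aggregate", std::move(out)};
}

// Stand-in for ProcessEmit: projects the selected columns, or passes through.
DeclarationInfo ProcessEmit(const Emit& emit, DeclarationInfo info) {
  if (!emit) return info;
  std::vector<std::string> projected;
  for (int i : *emit) projected.push_back(info.columns.at(i));
  return {info.plan + " -> project", std::move(projected)};
}

// kAggregate case of the top-level switch: the caller owns emit handling.
// The "sum(x)" / "k" columns are illustrative placeholders.
DeclarationInfo ConvertAggregateRel(const Emit& emit, const DeclarationInfo& input) {
  return ProcessEmit(emit, MakeAggregateDeclaration(input, {"sum(x)"}, {"k"}));
}

// Extension relation path: returns the helper's result as-is.
DeclarationInfo ConvertSegmentedAggregateRel(const DeclarationInfo& input) {
  return MakeAggregateDeclaration(input, {"sum(x)"}, {"k"});
}
```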
##########
cpp/src/arrow/engine/substrait/expression_internal.cc:
##########
@@ -138,6 +138,77 @@ std::string EnumToString(int value, const google::protobuf::EnumDescriptor* desc
return value_desc->name();
}
+Result<compute::Expression> FromProto(const substrait::Expression::ReferenceSegment* ref,
+ const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options,
+ std::optional<compute::Expression> in_expr) {
+ auto in_ref = ref;
+ auto& out = in_expr;
Review Comment:
We could maybe name it `current` or something. The logic is (very roughly):
```
FieldRef current = in_expr;
while (ref != null) {
current = current.join(ref);
ref = ref->next;
}
```
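The rough loop above can be expanded into a runnable, self-contained model. `Segment`, `Expr`, and `ApplySegments` are hypothetical stand-ins rather than Substrait or Arrow types, and only the `StructField` segment kind is modeled. Following the quoted `FromProto` overload, a root segment starts a direct reference, later segments extend an existing reference or `struct_field` path, and the first segment applied to any other expression wraps it in a `struct_field` call.

```cpp
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for substrait::Expression::ReferenceSegment,
// restricted to the StructField case.
struct Segment {
  int field;
  const Segment* child;  // next segment, or nullptr at the end of the chain
};

// Hypothetical stand-in for compute::Expression:
//   kFieldRef:    a direct reference; `path` is the column path
//   kStructField: call("struct_field", {*arg}) selecting `path`
//   kOther:       any other expression, identified by `name`
struct Expr {
  enum Kind { kFieldRef, kStructField, kOther };
  Kind kind;
  std::vector<int> path;
  std::shared_ptr<Expr> arg;
  std::string name;
};

// Folds each segment into `current`, one segment per loop iteration.
std::optional<Expr> ApplySegments(const Segment* ref, std::optional<Expr> current) {
  while (ref != nullptr) {
    const int index = ref->field;
    if (!current) {
      // Root segment: select a top-level column.
      current = Expr{Expr::kFieldRef, {index}, nullptr, ""};
    } else if (current->kind == Expr::kFieldRef ||
               current->kind == Expr::kStructField) {
      // Extend an existing direct reference or struct_field path.
      current->path.push_back(index);
    } else {
      // First segment on top of an arbitrary expression: wrap it.
      auto arg = std::make_shared<Expr>(std::move(*current));
      current = Expr{Expr::kStructField, {index}, std::move(arg), ""};
    }
    ref = ref->child;
  }
  return current;
}
```

Because a chain that starts at the root stays a plain path in this model, a follow-up check like the proposed `DirectReferenceFromProto` would only need to inspect the final `kind`.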
--
This is an automated message from the Apache Git Service.