This is an automated email from the ASF dual-hosted git repository. kszucs pushed a commit to branch maint-9.0.0 in repository https://gitbox.apache.org/repos/asf/arrow.git
commit 74a4a0244ed7ea5acc3512e0a9a93036844abc0e Author: Vibhatha Lakmal Abeykoon <[email protected]> AuthorDate: Tue Jul 26 05:29:15 2022 +0530 ARROW-15591: [C++] Add support for aggregation to the Substrait consumer (#13130) This PR includes the Substrait-Arrow Aggregate integration where a Substrait plan can be consumed in ACERO. Lead-authored-by: Vibhatha Abeykoon <[email protected]> Co-authored-by: Vibhatha Lakmal Abeykoon <[email protected]> Signed-off-by: Weston Pace <[email protected]> --- cpp/src/arrow/engine/substrait/extension_set.cc | 1 + .../arrow/engine/substrait/relation_internal.cc | 72 ++++- cpp/src/arrow/engine/substrait/serde_test.cc | 317 +++++++++++++++++++++ 3 files changed, 389 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index f60f6ac1cb..08eb6acc9c 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -445,6 +445,7 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { "add", "equal", "is_not_distinct_from", + "hash_count", }) { DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); } diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 09ecb2f069..8f6cb0ce36 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -307,7 +307,7 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, callptr->function_name); } - // TODO: ARROW-166241 Add Suffix support for Substrait + // TODO: ARROW-16624 Add Suffix support for Substrait const auto* left_keys = callptr->arguments[0].field_ref(); const auto* right_keys = callptr->arguments[1].field_ref(); if (!left_keys || !right_keys) { @@ -323,6 +323,76 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, join_dec.inputs.emplace_back(std::move(right.declaration)); return DeclarationInfo{std::move(join_dec), num_columns}; } + case substrait::Rel::RelTypeCase::kAggregate: { + const auto& aggregate = rel.aggregate(); + RETURN_NOT_OK(CheckRelCommon(aggregate)); + + if (!aggregate.has_input()) { + return Status::Invalid("substrait::AggregateRel with no input relation"); + } + + ARROW_ASSIGN_OR_RAISE(auto input, FromProto(aggregate.input(), ext_set)); + + if (aggregate.groupings_size() > 1) { + return Status::NotImplemented( + "Grouping sets not supported. AggregateRel::groupings may not have more " + "than one item"); + } + std::vector<FieldRef> keys; + auto group = aggregate.groupings(0); + keys.reserve(group.grouping_expressions_size()); + for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) { + ARROW_ASSIGN_OR_RAISE(auto expr, + FromProto(group.grouping_expressions(exp_id), ext_set)); + const auto* field_ref = expr.field_ref(); + if (field_ref) { + keys.emplace_back(std::move(*field_ref)); + } else { + return Status::Invalid( + "The grouping expression for an aggregate must be a direct reference."); + } + } + + int measure_size = aggregate.measures_size(); + std::vector<compute::Aggregate> aggregates; + aggregates.reserve(measure_size); + for (int measure_id = 0; measure_id < measure_size; measure_id++) { + const auto& agg_measure = aggregate.measures(measure_id); + if (agg_measure.has_measure()) { + if (agg_measure.has_filter()) { + return Status::NotImplemented("Aggregate filters are not supported."); + } + const auto& agg_func = agg_measure.measure(); + if (agg_func.arguments_size() != 1) { + return Status::NotImplemented("Aggregate function must be a unary function."); + } + int func_reference = agg_func.function_reference(); + ARROW_ASSIGN_OR_RAISE(auto func_record, ext_set.DecodeFunction(func_reference)); + // aggreagte function name + auto func_name = std::string(func_record.id.name); + // aggregate target + auto subs_func_args = agg_func.arguments(0); + ARROW_ASSIGN_OR_RAISE(auto field_expr, + FromProto(subs_func_args.value(), ext_set)); + auto target = field_expr.field_ref(); + if (!target) { + return Status::Invalid( + "The input expression to an aggregate function must be a direct " + "reference."); + } + aggregates.emplace_back(compute::Aggregate{std::move(func_name), NULLPTR, + std::move(*target), std::move("")}); + } else { + return Status::Invalid("substrait::AggregateFunction not provided"); + } + } + + return DeclarationInfo{ + compute::Declaration::Sequence( + {std::move(input.declaration), + {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), + static_cast<int>(aggregates.size())}; + } default: break; diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index e10082392d..8e5745d6df 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1383,5 +1383,322 @@ TEST(Substrait, JoinPlanInvalidKeys) { } } +TEST(Substrait, AggregateBasic) { + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "relations": [{ + "rel": { + "aggregate": { + "input": { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": "file:///tmp/dat.parquet", + "parquet": {} + } + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + } + } + }] + } + } + }], + "extensionUris": [{ + "extension_uri_anchor": 0, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + }], + "extensions": [{ + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "hash_count" + } + }], + })")); + + auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry(); + ASSERT_OK_AND_ASSIGN(auto sink_decls, + DeserializePlans(*buf, [] { return kNullConsumer; })); + auto agg_decl = sink_decls[0].inputs[0]; + + const auto& agg_rel = agg_decl.get<compute::Declaration>(); + + const auto& agg_options = + checked_cast<const compute::AggregateNodeOptions&>(*agg_rel->options); + + EXPECT_EQ(agg_rel->factory_name, "aggregate"); + EXPECT_EQ(agg_options.aggregates[0].name, ""); + EXPECT_EQ(agg_options.aggregates[0].function, "hash_count"); +} + +TEST(Substrait, AggregateInvalidRel) { + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "relations": [{ + "rel": { + "aggregate": { + } + } + }], + "extensionUris": [{ + "extension_uri_anchor": 0, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + }], + "extensions": [{ + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "hash_count" + } + }], + })")); + + ASSERT_RAISES(Invalid, DeserializePlans(*buf, [] { return kNullConsumer; })); +} + +TEST(Substrait, AggregateInvalidFunction) { + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "relations": [{ + "rel": { + "aggregate": { + "input": { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": "file:///tmp/dat.parquet", + "parquet": {} + } + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + } + } + }] + }], + "measures": [{ + }] + } + } + }], + "extensionUris": [{ + "extension_uri_anchor": 0, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + }], + "extensions": [{ + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "hash_count" + } + }], + })")); + + ASSERT_RAISES(Invalid, DeserializePlans(*buf, [] { return kNullConsumer; })); +} + +TEST(Substrait, AggregateInvalidAggFuncArgs) { + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "relations": [{ + "rel": { + "aggregate": { + "input": { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": "file:///tmp/dat.parquet", + "parquet": {} + } + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + } + } + }] + } + } + }], + "extensionUris": [{ + "extension_uri_anchor": 0, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + }], + "extensions": [{ + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "hash_count" + } + }], + })")); + + ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; })); +} + +TEST(Substrait, AggregateWithFilter) { + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "relations": [{ + "rel": { + "aggregate": { + "input": { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": "file:///tmp/dat.parquet", + "parquet": {} + } + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + } + } + }] + } + } + }], + "extensionUris": [{ + "extension_uri_anchor": 0, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + }], + "extensions": [{ + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "equal" + } + }], + })")); + + ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; })); +} + } // namespace engine } // namespace arrow
