westonpace commented on code in PR #13130:
URL: https://github.com/apache/arrow/pull/13130#discussion_r923297309
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1383,5 +1383,350 @@ TEST(Substrait, JoinPlanInvalidKeys) {
}
}
+TEST(Substrait, AggregateBasic) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "arguments": [{
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 1
+ }
+ }
+ }
+ }
+ }],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(
+ *buf, [] { return kNullConsumer; },
+ ext_id_reg, &ext_set_invalid));
Review Comment:
```suggestion
ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(
*buf, [] { return kNullConsumer; }));
```
The third argument (`ext_id_reg`) is optional if you want to just use the
default extension registry (which should be fine for this test). The fourth
argument `&ext_set_invalid` is an out-parameter, and also optional, so not
needed for this test.
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1383,5 +1383,350 @@ TEST(Substrait, JoinPlanInvalidKeys) {
}
}
+TEST(Substrait, AggregateBasic) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "arguments": [{
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 1
+ }
+ }
+ }
+ }
+ }],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(
+ *buf, [] { return kNullConsumer; },
+ ext_id_reg, &ext_set_invalid));
+ auto agg_decl = sink_decls[0].inputs[0];
+
+ const auto& agg_rel = agg_decl.get<compute::Declaration>();
+
+ const auto& agg_options =
+ checked_cast<const compute::AggregateNodeOptions&>(*agg_rel->options);
+
+ EXPECT_EQ(agg_rel->factory_name, "aggregate");
+ EXPECT_EQ(agg_options.aggregates[0].name, "");
+ EXPECT_EQ(agg_options.aggregates[0].function, "hash_count");
+}
+
+TEST(Substrait, AggregateInvalidRel) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_RAISES(Invalid,
+ DeserializePlans(
+ *buf, [] { return kNullConsumer; }, ext_id_reg,
&ext_set_invalid));
Review Comment:
```suggestion
ASSERT_RAISES(Invalid,
DeserializePlans(
*buf, [] { return kNullConsumer; }));
```
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1383,5 +1383,350 @@ TEST(Substrait, JoinPlanInvalidKeys) {
}
}
+TEST(Substrait, AggregateBasic) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "arguments": [{
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 1
+ }
+ }
+ }
+ }
+ }],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(
+ *buf, [] { return kNullConsumer; },
+ ext_id_reg, &ext_set_invalid));
+ auto agg_decl = sink_decls[0].inputs[0];
+
+ const auto& agg_rel = agg_decl.get<compute::Declaration>();
+
+ const auto& agg_options =
+ checked_cast<const compute::AggregateNodeOptions&>(*agg_rel->options);
+
+ EXPECT_EQ(agg_rel->factory_name, "aggregate");
+ EXPECT_EQ(agg_options.aggregates[0].name, "");
+ EXPECT_EQ(agg_options.aggregates[0].function, "hash_count");
+}
+
+TEST(Substrait, AggregateInvalidRel) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_RAISES(Invalid,
+ DeserializePlans(
+ *buf, [] { return kNullConsumer; }, ext_id_reg,
&ext_set_invalid));
+}
+
+TEST(Substrait, AggregateInvalidFunction) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_RAISES(Invalid,
+ DeserializePlans(
+ *buf, [] { return kNullConsumer; }, ext_id_reg,
&ext_set_invalid));
Review Comment:
```suggestion
ASSERT_RAISES(Invalid,
DeserializePlans(
*buf, [] { return kNullConsumer; }));
```
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1383,5 +1383,350 @@ TEST(Substrait, JoinPlanInvalidKeys) {
}
}
+TEST(Substrait, AggregateBasic) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "arguments": [{
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 1
+ }
+ }
+ }
+ }
+ }],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(
+ *buf, [] { return kNullConsumer; },
+ ext_id_reg, &ext_set_invalid));
+ auto agg_decl = sink_decls[0].inputs[0];
+
+ const auto& agg_rel = agg_decl.get<compute::Declaration>();
+
+ const auto& agg_options =
+ checked_cast<const compute::AggregateNodeOptions&>(*agg_rel->options);
+
+ EXPECT_EQ(agg_rel->factory_name, "aggregate");
+ EXPECT_EQ(agg_options.aggregates[0].name, "");
+ EXPECT_EQ(agg_options.aggregates[0].function, "hash_count");
+}
+
+TEST(Substrait, AggregateInvalidRel) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_RAISES(Invalid,
+ DeserializePlans(
+ *buf, [] { return kNullConsumer; }, ext_id_reg,
&ext_set_invalid));
+}
+
+TEST(Substrait, AggregateInvalidFunction) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_RAISES(Invalid,
+ DeserializePlans(
+ *buf, [] { return kNullConsumer; }, ext_id_reg,
&ext_set_invalid));
+}
+
+TEST(Substrait, AggregateInvalidAggFuncArgs) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "args": [],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_RAISES(NotImplemented,
+ DeserializePlans(
+ *buf, [] { return kNullConsumer; }, ext_id_reg,
&ext_set_invalid));
+}
+
+TEST(Substrait, AggregateWithFilter) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "args": [],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "equal"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_RAISES(NotImplemented,
+ DeserializePlans(
+ *buf, [] { return kNullConsumer; }, ext_id_reg,
&ext_set_invalid));
Review Comment:
```suggestion
ASSERT_RAISES(NotImplemented,
DeserializePlans(
*buf, [] { return kNullConsumer; }));
```
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1383,5 +1383,350 @@ TEST(Substrait, JoinPlanInvalidKeys) {
}
}
+TEST(Substrait, AggregateBasic) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "arguments": [{
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 1
+ }
+ }
+ }
+ }
+ }],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(
+ *buf, [] { return kNullConsumer; },
+ ext_id_reg, &ext_set_invalid));
+ auto agg_decl = sink_decls[0].inputs[0];
+
+ const auto& agg_rel = agg_decl.get<compute::Declaration>();
+
+ const auto& agg_options =
+ checked_cast<const compute::AggregateNodeOptions&>(*agg_rel->options);
+
+ EXPECT_EQ(agg_rel->factory_name, "aggregate");
+ EXPECT_EQ(agg_options.aggregates[0].name, "");
+ EXPECT_EQ(agg_options.aggregates[0].function, "hash_count");
+}
+
+TEST(Substrait, AggregateInvalidRel) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_RAISES(Invalid,
+ DeserializePlans(
+ *buf, [] { return kNullConsumer; }, ext_id_reg,
&ext_set_invalid));
+}
+
+TEST(Substrait, AggregateInvalidFunction) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_RAISES(Invalid,
+ DeserializePlans(
+ *buf, [] { return kNullConsumer; }, ext_id_reg,
&ext_set_invalid));
+}
+
+TEST(Substrait, AggregateInvalidAggFuncArgs) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "args": [],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri":
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "hash_count"
+ }
+ }],
+ })"));
+
+ auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ // invalid before registration
+ ExtensionSet ext_set_invalid(ext_id_reg);
+ ASSERT_RAISES(NotImplemented,
+ DeserializePlans(
+ *buf, [] { return kNullConsumer; }, ext_id_reg,
&ext_set_invalid));
Review Comment:
```suggestion
ASSERT_RAISES(NotImplemented,
DeserializePlans(
*buf, [] { return kNullConsumer; }));
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]