westonpace commented on code in PR #13130:
URL: https://github.com/apache/arrow/pull/13130#discussion_r918293275


##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -308,6 +308,71 @@ Result<compute::Declaration> FromProto(const 
substrait::Rel& rel,
       join_dec.inputs.emplace_back(std::move(right));
       return std::move(join_dec);
     }
+    case substrait::Rel::RelTypeCase::kAggregate: {
+      const auto& aggregate = rel.aggregate();
+      RETURN_NOT_OK(CheckRelCommon(aggregate));
+
+      if (!aggregate.has_input()) {
+        return Status::Invalid("substrait::AggregateRel with no input 
relation");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto input, FromProto(aggregate.input(), ext_set));
+
+      if (aggregate.groupings_size() > 1) {
+        return Status::NotImplemented("Grouping sets not supported.");
+      }
+      std::vector<FieldRef> keys;
+      auto group = aggregate.groupings(0);
+      keys.reserve(group.grouping_expressions_size());
+      for (int exp_id = 0; exp_id < group.grouping_expressions_size(); 
exp_id++) {
+        ARROW_ASSIGN_OR_RAISE(auto expr,
+                              FromProto(group.grouping_expressions(exp_id), 
ext_set));
+        const auto* field_ref = expr.field_ref();
+        if (field_ref) {
+          keys.emplace_back(std::move(*field_ref));
+        } else {
+          return Status::Invalid(
+              "The grouping expression for an aggregate must be a direct 
reference.");
+        }
+      }
+
+      int measure_size = aggregate.measures_size();
+      std::vector<compute::Aggregate> aggregates;
+      aggregates.reserve(measure_size);
+      for (int measure_id = 0; measure_id < measure_size; measure_id++) {
+        const auto& agg_measure = aggregate.measures(measure_id);
+        if (agg_measure.has_measure()) {
+          if (agg_measure.has_filter()) {
+            return Status::NotImplemented("Aggregate filters are not 
supported.");
+          }
+          const auto& agg_func = agg_measure.measure();
+          if (agg_func.arguments_size() != 1) {
+            return Status::NotImplemented("Aggregate function must be a unary 
function.");
+          }
+          int func_reference = agg_func.function_reference();
+          ARROW_ASSIGN_OR_RAISE(auto func_record, 
ext_set.DecodeFunction(func_reference));
+          // aggreagte function name
+          auto func_name = std::string(func_record.id.name);
+          // aggregate target
+          auto subs_func_args = agg_func.arguments(0);
+          ARROW_ASSIGN_OR_RAISE(auto field_expr,
+                                FromProto(subs_func_args.value(), ext_set));
+          auto target = field_expr.field_ref();
+          if (!target) {
+            return Status::Invalid(
+                "The input expression to an aggregate function must be a 
direct "
+                "reference.");
+          }
+          aggregates.emplace_back(compute::Aggregate{std::move(func_name), 
NULLPTR,
+                                                     std::move(*target), 
std::move("")});

Review Comment:
   ```suggestion
             aggregates.emplace_back(compute::Aggregate{std::move(func_name), 
NULLPTR,
                                                        std::move(*target), 
std::move("")});
   ```
   
   Minor nit: it might be better to use one of the following two forms:
   
   ```
             // If you are going to create the instance yourself you can just 
use push_back
             aggregates.push_back(compute::Aggregate{std::move(func_name), 
NULLPTR,
                                                        std::move(*target), 
std::move("")});
   ```
   
   or...
   
   ```
             // If you are using emplace_back you do not need to create the 
instance
             // yourself.
             aggregates.emplace_back(std::move(func_name), NULLPTR,
                                                        std::move(*target), 
std::move(""));
   ```



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -308,6 +308,71 @@ Result<compute::Declaration> FromProto(const 
substrait::Rel& rel,
       join_dec.inputs.emplace_back(std::move(right));
       return std::move(join_dec);
     }
+    case substrait::Rel::RelTypeCase::kAggregate: {
+      const auto& aggregate = rel.aggregate();
+      RETURN_NOT_OK(CheckRelCommon(aggregate));
+
+      if (!aggregate.has_input()) {
+        return Status::Invalid("substrait::AggregateRel with no input 
relation");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto input, FromProto(aggregate.input(), ext_set));
+
+      if (aggregate.groupings_size() > 1) {
+        return Status::NotImplemented("Grouping sets not supported.");

Review Comment:
   ```suggestion
           return Status::NotImplemented("Grouping sets not supported.  
AggregateRel::groupings may not have more than one item");
   ```
   
   Minor nit: if someone gets this error, they might not immediately realize how to modify the Substrait plan so that Acero can run it.



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1383,5 +1383,350 @@ TEST(Substrait, JoinPlanInvalidKeys) {
   }
 }
 
+TEST(Substrait, AggregateBase) {

Review Comment:
   ```suggestion
   TEST(Substrait, AggregateBasic) {
   ```
   
   Minor nit: `Base` might lead one to think that this test is meant to be extended somehow.



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -1383,5 +1383,350 @@ TEST(Substrait, JoinPlanInvalidKeys) {
   }
 }
 
+TEST(Substrait, AggregateBase) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+    "relations": [{
+      "rel": {
+        "aggregate": {
+          "input": {
+            "read": {
+              "base_schema": {
+                "names": ["A", "B", "C"],
+                "struct": {
+                  "types": [{
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }]
+                }
+              },
+              "local_files": { 
+                "items": [
+                  {
+                    "uri_file": "file:///tmp/dat.parquet",
+                    "parquet": {}
+                  }
+                ]
+              }
+            }
+          },
+          "groupings": [{
+            "groupingExpressions": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                }
+              }
+            }]
+          }],
+          "measures": [{
+            "measure": {
+              "functionReference": 0,
+              "arguments": [{
+                "value": {
+                  "selection": {
+                    "directReference": {
+                      "structField": {
+                        "field": 1
+                      }
+                    }
+                  }
+                }
+            }],
+              "sorts": [],
+              "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+              "outputType": {
+                "i64": {}
+              }
+            }
+          }]
+        }
+      }
+    }],
+    "extensionUris": [{
+      "extension_uri_anchor": 0,
+      "uri": 
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml";
+    }],
+    "extensions": [{
+      "extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "count"

Review Comment:
   I think this plan wouldn't actually work. Since there is a grouping, it will be a hash aggregate and will need to use `hash_count` instead of `count`. However, we are not running the test end-to-end, so I think we get away with it. Still, it might be nice to update the plan so that it is accurate.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to