pitrou commented on code in PR #15083:
URL: https://github.com/apache/arrow/pull/15083#discussion_r1072494388


##########
cpp/src/arrow/compute/exec.h:
##########
@@ -174,7 +175,12 @@ struct ARROW_EXPORT ExecBatch {
 
   explicit ExecBatch(const RecordBatch& batch);
 
-  static Result<ExecBatch> Make(std::vector<Datum> values);
+  /// \brief Infers the ExecBatch length from values.

Review Comment:
   ```suggestion
     /// \brief Infer the ExecBatch length from values.
   ```



##########
cpp/src/arrow/compute/exec/aggregate.cc:
##########
@@ -121,27 +127,48 @@ Result<Datum> GroupBy(const std::vector<Datum>& 
arguments, const std::vector<Dat
   ExecSpanIterator argument_iterator;
 
   ExecBatch args_batch;
-  if (!arguments.empty()) {
-    ARROW_ASSIGN_OR_RAISE(args_batch, ExecBatch::Make(arguments));
+  std::optional<int64_t> inferred_length = ExecBatch::InferLength(arguments);
+  if (!inferred_length.has_value()) {
+    inferred_length = ExecBatch::InferLength(keys);
+  }
+  DCHECK(inferred_length.has_value());
+  const int64_t length = inferred_length.value();
+  if (!aggregates.empty()) {
+    ARROW_ASSIGN_OR_RAISE(args_batch, ExecBatch::Make(arguments, length));
 
     // Construct and initialize HashAggregateKernels
-    auto argument_types = args_batch.GetTypes();
+    std::vector<std::vector<TypeHolder>> 
aggs_argument_types(aggregates.size());
+    {
+      // Contains the flattened list of aggregate arguments. We use the size of
+      // each Aggregate::target to re-group the aggregate argument types.
+      auto argument_types = args_batch.GetTypes();
+      size_t i = 0;
+      for (size_t j = 0; j < aggregates.size(); j++) {
+        const size_t num_agg_args = aggregates[j].target.size();
+        for (size_t k = 0; k < num_agg_args && i < argument_types.size(); k++, 
i++) {

Review Comment:
   `i < argument_types.size()` isn't necessary, is it?



##########
cpp/src/arrow/compute/exec/aggregate.cc:
##########
@@ -121,27 +127,48 @@ Result<Datum> GroupBy(const std::vector<Datum>& 
arguments, const std::vector<Dat
   ExecSpanIterator argument_iterator;
 
   ExecBatch args_batch;
-  if (!arguments.empty()) {
-    ARROW_ASSIGN_OR_RAISE(args_batch, ExecBatch::Make(arguments));
+  std::optional<int64_t> inferred_length = ExecBatch::InferLength(arguments);
+  if (!inferred_length.has_value()) {
+    inferred_length = ExecBatch::InferLength(keys);
+  }
+  DCHECK(inferred_length.has_value());

Review Comment:
   Are we sure this always succeeds? If not, we should return an error to the caller instead of crashing.



##########
cpp/src/arrow/compute/api_aggregate.h:
##########
@@ -56,6 +58,13 @@ class ARROW_EXPORT ScalarAggregateOptions : public 
FunctionOptions {
   uint32_t min_count;
 };
 
+class ARROW_EXPORT CountAllOptions : public FunctionOptions {

Review Comment:
   Hmm, what is this for? If a given function doesn't support any specific options, then a `nullptr` FunctionOptions should simply be passed.



##########
python/pyarrow/table.pxi:
##########
@@ -5334,7 +5334,7 @@ class TableGroupBy:
         Parameters
         ----------
         aggregations : list[tuple(str, str)] or \
-list[tuple(str, str, FunctionOptions)]
+list[tuple(str|list[str]|tuple(str*), str, FunctionOptions)]

Review Comment:
   This is getting rather unwieldy (both the API and the implementation below).
   I wonder if we can find another way?
   cc @jorisvandenbossche for opinions



##########
cpp/src/arrow/compute/exec/aggregate.cc:
##########
@@ -121,27 +127,48 @@ Result<Datum> GroupBy(const std::vector<Datum>& 
arguments, const std::vector<Dat
   ExecSpanIterator argument_iterator;
 
   ExecBatch args_batch;
-  if (!arguments.empty()) {
-    ARROW_ASSIGN_OR_RAISE(args_batch, ExecBatch::Make(arguments));
+  std::optional<int64_t> inferred_length = ExecBatch::InferLength(arguments);
+  if (!inferred_length.has_value()) {
+    inferred_length = ExecBatch::InferLength(keys);
+  }
+  DCHECK(inferred_length.has_value());
+  const int64_t length = inferred_length.value();
+  if (!aggregates.empty()) {
+    ARROW_ASSIGN_OR_RAISE(args_batch, ExecBatch::Make(arguments, length));
 
     // Construct and initialize HashAggregateKernels
-    auto argument_types = args_batch.GetTypes();
+    std::vector<std::vector<TypeHolder>> 
aggs_argument_types(aggregates.size());
+    {
+      // Contains the flattened list of aggregate arguments. We use the size of
+      // each Aggregate::target to re-group the aggregate argument types.
+      auto argument_types = args_batch.GetTypes();
+      size_t i = 0;
+      for (size_t j = 0; j < aggregates.size(); j++) {
+        const size_t num_agg_args = aggregates[j].target.size();
+        for (size_t k = 0; k < num_agg_args && i < argument_types.size(); k++, 
i++) {
+          aggs_argument_types[j].push_back(std::move(argument_types[i]));
+        }
+      }
+      DCHECK_EQ(i, argument_types.size())
+          << "argument_types should contain input types for all the 
aggregates.";
+    }
 
-    ARROW_ASSIGN_OR_RAISE(kernels, GetKernels(ctx, aggregates, 
argument_types));
+    ARROW_ASSIGN_OR_RAISE(kernels, GetKernels(ctx, aggregates, 
aggs_argument_types));

Review Comment:
   Isn't it slightly silly to bother merging all aggregates together, and then call `GetKernels`, which just iterates over them? Or am I missing something?



##########
cpp/src/arrow/compute/kernels/hash_aggregate.cc:
##########
@@ -223,6 +235,53 @@ void VisitGroupedValuesNonNull(const ExecSpan& batch, 
ConsumeValue&& valid_func)
 // ----------------------------------------------------------------------
 // Count implementation
 
+// Nullary-count implementation -- COUNT(*).
+struct GroupedCountAllImpl : public GroupedAggregator {
+  Status Init(ExecContext* ctx, const KernelInitArgs& args) override {
+    counts_ = BufferBuilder(ctx->memory_pool());
+    return Status::OK();
+  }
+
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    return counts_.Append(added_groups * sizeof(int64_t), 0);
+  }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    auto other = checked_cast<GroupedCountAllImpl*>(&raw_other);
+
+    auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+    auto other_counts = reinterpret_cast<const 
int64_t*>(other->counts_.mutable_data());

Review Comment:
   ```suggestion
       auto other_counts = reinterpret_cast<const 
int64_t*>(other->counts_.data());
   ```



##########
cpp/src/arrow/compute/exec.cc:
##########
@@ -163,6 +163,34 @@ Result<ExecBatch> ExecBatch::Make(std::vector<Datum> 
values) {
       continue;
     }
 
+    if (length != value.length()) {
+      // all the arrays should have the same length
+      return -1;
+    }
+  }
+
+  return length == -1 ? 1 : length;
+}
+
+Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values, int64_t length) {
+  if (length < 0) {

Review Comment:
   I'd rather converge on a single function as well. Besides, `InferLength` should return a `Result<...>` if an error may occur, IMHO.



##########
cpp/src/arrow/engine/substrait/test_plan_builder.h:
##########
@@ -64,11 +64,12 @@ ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> 
CreateScanProjectSubstrait(
 /// \brief Create a scan->aggregate->sink plan for tests
 ///
 /// The plan will create an aggregate with one grouping set (defined by
-/// key_idxs) and one measure.  The measure will be a unary function
-/// defined by `function_id` and a direct reference to `arg_idx`.
+/// key_idxs) and one measure.  The measure will be a function
+/// defined by `function_id` and direct references to `arg_idxs`.
 ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> CreateScanAggSubstrait(
     Id function_id, const std::shared_ptr<Table>& input_table,
-    const std::vector<int>& key_idxs, int arg_idx, const DataType& 
output_type);
+    const std::vector<int>& key_idxs, const std::vector<int>& arg_idx,

Review Comment:
   ```suggestion
       const std::vector<int>& key_idxs, const std::vector<int>& arg_idxs,
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to