rtpsw commented on code in PR #14352:
URL: https://github.com/apache/arrow/pull/14352#discussion_r1021812019


##########
cpp/src/arrow/compute/exec/aggregate.cc:
##########
@@ -101,136 +106,401 @@ Result<FieldVector> ResolveKernels(
   return fields;
 }
 
-Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Datum>& keys,
-                      const std::vector<Aggregate>& aggregates, bool use_threads,
-                      ExecContext* ctx) {
-  auto task_group =
-      use_threads
-          ? arrow::internal::TaskGroup::MakeThreaded(arrow::internal::GetCpuThreadPool())
-          : arrow::internal::TaskGroup::MakeSerial();
+namespace {
 
-  std::vector<const HashAggregateKernel*> kernels;
-  std::vector<std::vector<std::unique_ptr<KernelState>>> states;
-  FieldVector out_fields;
+template <typename T>
+inline std::string ToString(const std::vector<T>& v) {
+  std::stringstream s;
+  s << '[';
+  for (size_t i = 0; i < v.size(); i++) {
+    if (i != 0) s << ',';
+    s << v[i];
+  }
+  s << ']';
+  return s.str();
+}
 
-  using arrow::compute::detail::ExecSpanIterator;
-  ExecSpanIterator argument_iterator;
+int64_t FindLength(const std::vector<Datum>& arguments, const std::vector<Datum>& keys,
+                   const std::vector<Datum>& segment_keys) {
+  int64_t length = -1;
+  for (const auto& datums : {arguments, keys, segment_keys}) {
+    for (const auto& datum : datums) {
+      if (datum.is_scalar()) {
+        // do nothing
+      } else if (datum.is_array() || datum.is_chunked_array()) {
+        int64_t datum_length =
+            datum.is_array() ? datum.array()->length : datum.chunked_array()->length();
+        if (length == -1) {
+          length = datum_length;
+        } else if (length != datum_length) {
+          return -1;
+        }
+      } else {
+        ARROW_DCHECK(false);
+      }
+    }
+  }
+  return length;
+}
 
-  ExecBatch args_batch;
-  if (!arguments.empty()) {
-    ARROW_ASSIGN_OR_RAISE(args_batch, ExecBatch::Make(arguments));
+class GroupByProcess {
+ public:
+  struct BatchInfo {
+    ExecBatch args_batch;
+    std::vector<TypeHolder> argument_types;
+    ExecBatch keys_batch;
+    std::vector<TypeHolder> key_types;
+    ExecBatch segment_keys_batch;
+    std::vector<TypeHolder> segment_key_types;
+
+    static Result<BatchInfo> Make(const std::vector<Datum>& arguments,
+                                  const std::vector<Datum>& keys,
+                                  const std::vector<Datum>& segment_keys) {
+      int64_t batch_length = FindLength(arguments, keys, segment_keys);
+
+      ARROW_ASSIGN_OR_RAISE(auto args_batch, ExecBatch::Make(arguments, batch_length));
+      auto argument_types = args_batch.GetTypes();
+
+      ARROW_ASSIGN_OR_RAISE(auto keys_batch, ExecBatch::Make(keys, batch_length));
+      auto key_types = keys_batch.GetTypes();
+
+      ARROW_ASSIGN_OR_RAISE(auto segment_keys_batch,
+                            ExecBatch::Make(segment_keys, batch_length));
+      auto segment_key_types = segment_keys_batch.GetTypes();
+
+      return BatchInfo{std::move(args_batch),         std::move(argument_types),
+                       std::move(keys_batch),         std::move(key_types),
+                       std::move(segment_keys_batch), std::move(segment_key_types)};
+    }
 
-    // Construct and initialize HashAggregateKernels
-    auto argument_types = args_batch.GetTypes();
+    BatchInfo Slice(int64_t offset, int64_t length) const {
+      return BatchInfo{args_batch.Slice(offset, length),         argument_types,
+                       keys_batch.Slice(offset, length),         key_types,
+                       segment_keys_batch.Slice(offset, length), segment_key_types};
+    }
+  };
+
+  struct StateInfo {
+    GroupByProcess& process;
+    std::shared_ptr<arrow::internal::TaskGroup> task_group;
+    std::vector<std::unique_ptr<Grouper>> groupers;
+    std::vector<const HashAggregateKernel*> kernels;
+    std::vector<std::vector<std::unique_ptr<KernelState>>> states;
+    FieldVector out_fields;
+    ExecSpanIterator argument_iterator;
+    ExecSpanIterator key_iterator;
+    ScalarVector segment_keys;
+
+    explicit StateInfo(GroupByProcess& process) : process(process) {}
+
+    Status Init() {
+      const std::vector<TypeHolder>& argument_types = process.argument_types;
+      const std::vector<TypeHolder>& key_types = process.key_types;
+      const std::vector<Aggregate>& aggregates = process.aggregates;
+      ExecContext* ctx = process.ctx;
+      const FieldVector& key_fields = process.key_fields;
+
+      task_group = process.use_threads ? arrow::internal::TaskGroup::MakeThreaded(
+                                             arrow::internal::GetCpuThreadPool())
+                                       : arrow::internal::TaskGroup::MakeSerial();
+
+      groupers.resize(task_group->parallelism());
+      for (auto& grouper : groupers) {
+        ARROW_ASSIGN_OR_RAISE(grouper, Grouper::Make(key_types, ctx));
+      }
 
-    ARROW_ASSIGN_OR_RAISE(kernels, GetKernels(ctx, aggregates, argument_types));
+      if (!argument_types.empty()) {
+        // Construct and initialize HashAggregateKernels
+        ARROW_ASSIGN_OR_RAISE(kernels, GetKernels(ctx, aggregates, argument_types));
 
-    states.resize(task_group->parallelism());
-    for (auto& state : states) {
-      ARROW_ASSIGN_OR_RAISE(state, InitKernels(kernels, ctx, aggregates, argument_types));
-    }
+        states.resize(task_group->parallelism());
+        for (auto& state : states) {
+          ARROW_ASSIGN_OR_RAISE(state,
+                                InitKernels(kernels, ctx, aggregates, argument_types));
+        }
 
-    ARROW_ASSIGN_OR_RAISE(
-        out_fields, ResolveKernels(aggregates, kernels, states[0], ctx, argument_types));
+        ARROW_ASSIGN_OR_RAISE(out_fields, ResolveKernels(aggregates, kernels, states[0],
+                                                         ctx, argument_types));
+      } else {
+        out_fields = {};
+      }
+      out_fields.insert(out_fields.end(), key_fields.begin(), key_fields.end());
 
-    RETURN_NOT_OK(argument_iterator.Init(args_batch, ctx->exec_chunksize()));
-  }
+      return Status::OK();
+    }
 
-  // Construct Groupers
-  ARROW_ASSIGN_OR_RAISE(ExecBatch keys_batch, ExecBatch::Make(keys));
-  auto key_types = keys_batch.GetTypes();
+    Status Consume(const BatchInfo& batch_info) {
+      const std::vector<TypeHolder>& argument_types = process.argument_types;
+      ExecContext* ctx = process.ctx;
 
-  std::vector<std::unique_ptr<Grouper>> groupers(task_group->parallelism());
-  for (auto& grouper : groupers) {
-    ARROW_ASSIGN_OR_RAISE(grouper, Grouper::Make(key_types, ctx));
-  }
+      const ExecBatch& args_batch = batch_info.args_batch;
+      const ExecBatch& keys_batch = batch_info.keys_batch;
+      const ExecBatch& segment_keys_batch = batch_info.segment_keys_batch;
+
+      if (segment_keys_batch.length == 0) {
+        return Status::OK();
+      }
+      segment_keys = {};
+      for (auto value : segment_keys_batch.values) {
+        if (value.is_scalar()) {
+          segment_keys.push_back(value.scalar());
+        } else if (value.is_array()) {
+          ARROW_ASSIGN_OR_RAISE(auto scalar, value.make_array()->GetScalar(0));
+          segment_keys.push_back(scalar);
+        } else if (value.is_chunked_array()) {
+          ARROW_ASSIGN_OR_RAISE(auto scalar, value.chunked_array()->GetScalar(0));
+          segment_keys.push_back(scalar);
+        } else {
+          return Status::Invalid("consuming an invalid segment key type ", *value.type());
+        }
+      }
 
-  std::mutex mutex;
-  std::unordered_map<std::thread::id, size_t> thread_ids;
+      if (!argument_types.empty()) {
+        ARROW_RETURN_NOT_OK(argument_iterator.Init(args_batch, ctx->exec_chunksize()));
+      }
+      ARROW_RETURN_NOT_OK(key_iterator.Init(keys_batch, ctx->exec_chunksize()));
+
+      std::mutex mutex;
+      std::unordered_map<std::thread::id, size_t> thread_ids;

Review Comment:
   I'll go with `ThreadIndexer` and defer on #12871.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to