rtpsw commented on code in PR #34311:
URL: https://github.com/apache/arrow/pull/34311#discussion_r1117884067


##########
cpp/src/arrow/compute/row/grouper.h:
##########
@@ -30,6 +30,49 @@
 namespace arrow {
 namespace compute {
 
+/// \brief A segment of contiguous rows for grouping
+struct ARROW_EXPORT GroupingSegment {
+  int64_t offset;
+  int64_t length;
+  bool is_open;
+  bool extends;
+};
+
+inline bool operator==(const GroupingSegment& segment1, const GroupingSegment& 
segment2) {
+  return segment1.offset == segment2.offset && segment1.length == 
segment2.length &&
+         segment1.is_open == segment2.is_open && segment1.extends == 
segment2.extends;
+}
+inline bool operator!=(const GroupingSegment& segment1, const GroupingSegment& 
segment2) {
+  return !(segment1 == segment2);
+}
+
+/// \brief Computes grouping segments for a batch. Each segment covers rows with identical
+/// values in the batch. The values in the batch are often selected as keys from a larger
+/// batch.
+class ARROW_EXPORT GroupingSegmenter {
+ public:
+  virtual ~GroupingSegmenter() = default;
+
+  /// \brief Construct a GroupingSegmenter which receives the specified key types
+  static Result<std::unique_ptr<GroupingSegmenter>> Make(
+      const std::vector<TypeHolder>& key_types, bool nullable_keys = false,

Review Comment:
   Done.



##########
cpp/src/arrow/compute/exec/aggregate_node.cc:
##########
@@ -326,46 +446,86 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
   }
 
  private:
-  Status Finish() {
-    auto scope = TraceFinish();
+  Status ReconstructAggregates() {
+    const auto& input_schema = *inputs()[0]->output_schema();
+    auto exec_ctx = plan()->query_context()->exec_context();
+    for (size_t i = 0; i < kernels_.size(); ++i) {
+      std::vector<TypeHolder> in_types;
+      for (const auto& target : target_fieldsets_[i]) {
+        in_types.emplace_back(input_schema.field(target)->type().get());
+      }
+      states_[i].resize(plan()->query_context()->max_concurrency());
+      KernelContext kernel_ctx{exec_ctx};
+      RETURN_NOT_OK(Kernel::InitAll(
+          &kernel_ctx, KernelInitArgs{kernels_[i], in_types, 
aggs_[i].options.get()},
+          &states_[i]));
+    }
+    return Status::OK();
+  }
+
+  Status OutputResult(bool is_last = false, bool traced = false) {
+    if (is_last && !traced) {
+      auto scope = TraceFinish();
+      return OutputResult(is_last, /*traced=*/true);
+    }
+    GatedUniqueLock lock(gated_shared_mutex_);
     ExecBatch batch{{}, 1};
-    batch.values.resize(kernels_.size());
+    batch.values.resize(kernels_.size() + segment_field_ids_.size());
 
     for (size_t i = 0; i < kernels_.size(); ++i) {
       util::tracing::Span span;
       START_COMPUTE_SPAN(span, aggs_[i].function,
                          {{"function.name", aggs_[i].function},
                           {"function.options",
                            aggs_[i].options ? aggs_[i].options->ToString() : 
"<NULLPTR>"},
-                          {"function.kind", std::string(kind_name()) + 
"::Finalize"}});
+                          {"function.kind", std::string(kind_name()) + 
"::Output"}});
       KernelContext ctx{plan()->query_context()->exec_context()};
       ARROW_ASSIGN_OR_RAISE(auto merged, ScalarAggregateKernel::MergeAll(
                                              kernels_[i], &ctx, 
std::move(states_[i])));
       RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
     }
+    PlaceFields(batch, kernels_.size(), segmenter_values_);
 
-    return output_->InputReceived(this, std::move(batch));
+    ARROW_RETURN_NOT_OK(output_->InputReceived(this, std::move(batch)));
+    total_output_batches_++;
+    if (is_last) {
+      ARROW_RETURN_NOT_OK(output_->InputFinished(this, total_output_batches_));
+    } else {
+      ARROW_RETURN_NOT_OK(ReconstructAggregates());
+    }
+    return Status::OK();
   }
 
+  std::unique_ptr<GroupingSegmenter> segmenter_;
+  const std::vector<int> segment_field_ids_;
+  std::vector<Datum> segmenter_values_;
+
   const std::vector<std::vector<int>> target_fieldsets_;
   const std::vector<Aggregate> aggs_;
   const std::vector<const ScalarAggregateKernel*> kernels_;
 
   std::vector<std::vector<std::unique_ptr<KernelState>>> states_;
 
   AtomicCounter input_counter_;
+  int64_t total_output_batches_ = 0;

Review Comment:
   Done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to