bkietz commented on a change in pull request #9621:
URL: https://github.com/apache/arrow/pull/9621#discussion_r592471600



##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -91,6 +95,377 @@ struct CountImpl : public ScalarAggregator {
   int64_t nulls = 0;
 };
 
+struct GroupedAggregator {
+  virtual ~GroupedAggregator() = default;
+
+  virtual void Consume(KernelContext*, const Datum& aggregand,
+                       const uint32_t* group_ids) = 0;
+
+  virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
+
+  virtual void Resize(KernelContext* ctx, int64_t new_num_groups) = 0;
+
+  virtual int64_t num_groups() const = 0;
+
+  void MaybeResize(KernelContext* ctx, int64_t length, const uint32_t* group_ids) {
+    if (length == 0) return;
+
+    // maybe a batch of group_ids should include the min/max group id
+    int64_t max_group = *std::max_element(group_ids, group_ids + length);
+    auto old_size = num_groups();
+
+    if (max_group >= old_size) {
+      auto new_size = BufferBuilder::GrowByFactor(old_size, max_group + 1);
+      Resize(ctx, new_size);
+    }
+  }
+
+  virtual std::shared_ptr<DataType> out_type() const = 0;
+};
+
+struct GroupedCountImpl : public GroupedAggregator {
+  static std::unique_ptr<GroupedCountImpl> Make(KernelContext* ctx,
+                                                const std::shared_ptr<DataType>&,
+                                                const FunctionOptions* options) {
+    auto out = ::arrow::internal::make_unique<GroupedCountImpl>();
+    out->options_ = checked_cast<const CountOptions&>(*options);
+    ctx->SetStatus(ctx->Allocate(0).Value(&out->counts_));
+    return out;
+  }
+
+  void Resize(KernelContext* ctx, int64_t new_num_groups) override {
+    auto old_size = num_groups();
+    KERNEL_RETURN_IF_ERROR(ctx, counts_->TypedResize<int64_t>(new_num_groups));
+    auto new_size = num_groups();
+
+    auto raw_counts = reinterpret_cast<int64_t*>(counts_->mutable_data());
+    for (auto i = old_size; i < new_size; ++i) {
+      raw_counts[i] = 0;
+    }
+  }
+
+  void Consume(KernelContext* ctx, const Datum& aggregand,
+               const uint32_t* group_ids) override {
+    MaybeResize(ctx, aggregand.length(), group_ids);
+    if (ctx->HasError()) return;
+
+    auto raw_counts = reinterpret_cast<int64_t*>(counts_->mutable_data());
+
+    const auto& input = aggregand.array();
+
+    if (options_.count_mode == CountOptions::COUNT_NULL) {
+      for (int64_t i = 0, input_i = input->offset; i < input->length; ++i, ++input_i) {
+        auto g = group_ids[i];
+        raw_counts[g] += !BitUtil::GetBit(input->buffers[0]->data(), input_i);
+      }
+      return;
+    }
+
+    arrow::internal::VisitSetBitRunsVoid(
+        input->buffers[0], input->offset, input->length,
+        [&](int64_t begin, int64_t length) {
+          for (int64_t input_i = begin, i = begin - input->offset;
+               input_i < begin + length; ++input_i, ++i) {
+            auto g = group_ids[i];
+            raw_counts[g] += 1;
+          }
+        });
+  }
+
+  void Finalize(KernelContext* ctx, Datum* out) override {
+    auto length = num_groups();
+    *out = std::make_shared<Int64Array>(length, std::move(counts_));
+  }
+
+  int64_t num_groups() const override { return counts_->size() / sizeof(int64_t); }

Review comment:
       I'll refactor to use BufferBuilder.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to