bkietz commented on a change in pull request #9621:
URL: https://github.com/apache/arrow/pull/9621#discussion_r592454275



##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -91,6 +95,377 @@ struct CountImpl : public ScalarAggregator {
   int64_t nulls = 0;
 };
 
+struct GroupedAggregator {
+  virtual ~GroupedAggregator() = default;
+
+  virtual void Consume(KernelContext*, const Datum& aggregand,
+                       const uint32_t* group_ids) = 0;
+
+  virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
+
+  virtual void Resize(KernelContext* ctx, int64_t new_num_groups) = 0;
+
+  virtual int64_t num_groups() const = 0;
+
+  void MaybeResize(KernelContext* ctx, int64_t length, const uint32_t* group_ids) {
+    if (length == 0) return;
+
+    // maybe a batch of group_ids should include the min/max group id
+    int64_t max_group = *std::max_element(group_ids, group_ids + length);
+    auto old_size = num_groups();
+
+    if (max_group >= old_size) {
+      auto new_size = BufferBuilder::GrowByFactor(old_size, max_group + 1);
+      Resize(ctx, new_size);
+    }
+  }
+
+  virtual std::shared_ptr<DataType> out_type() const = 0;
+};
+
+struct GroupedCountImpl : public GroupedAggregator {
+  static std::unique_ptr<GroupedCountImpl> Make(KernelContext* ctx,
+                                                const std::shared_ptr<DataType>&,
+                                                const FunctionOptions* options) {
+    auto out = ::arrow::internal::make_unique<GroupedCountImpl>();
+    out->options_ = checked_cast<const CountOptions&>(*options);
+    ctx->SetStatus(ctx->Allocate(0).Value(&out->counts_));
+    return out;
+  }
+
+  void Resize(KernelContext* ctx, int64_t new_num_groups) override {
+    auto old_size = num_groups();
+    KERNEL_RETURN_IF_ERROR(ctx, counts_->TypedResize<int64_t>(new_num_groups));
+    auto new_size = num_groups();
+
+    auto raw_counts = reinterpret_cast<int64_t*>(counts_->mutable_data());
+    for (auto i = old_size; i < new_size; ++i) {
+      raw_counts[i] = 0;
+    }
+  }
+
+  void Consume(KernelContext* ctx, const Datum& aggregand,
+               const uint32_t* group_ids) override {
+    MaybeResize(ctx, aggregand.length(), group_ids);
+    if (ctx->HasError()) return;
+
+    auto raw_counts = reinterpret_cast<int64_t*>(counts_->mutable_data());
+
+    const auto& input = aggregand.array();
+
+    if (options_.count_mode == CountOptions::COUNT_NULL) {
+      for (int64_t i = 0, input_i = input->offset; i < input->length; ++i, ++input_i) {
+        auto g = group_ids[i];
+        raw_counts[g] += !BitUtil::GetBit(input->buffers[0]->data(), input_i);
+      }
+      return;
+    }
+
+    arrow::internal::VisitSetBitRunsVoid(
+        input->buffers[0], input->offset, input->length,
+        [&](int64_t begin, int64_t length) {
+          for (int64_t input_i = begin, i = begin - input->offset;
+               input_i < begin + length; ++input_i, ++i) {
+            auto g = group_ids[i];
+            raw_counts[g] += 1;
+          }
+        });
+  }
+
+  void Finalize(KernelContext* ctx, Datum* out) override {
+    auto length = num_groups();
+    *out = std::make_shared<Int64Array>(length, std::move(counts_));
+  }
+
+  int64_t num_groups() const override { return counts_->size() / sizeof(int64_t); }
+
+  std::shared_ptr<DataType> out_type() const override { return int64(); }
+
+  CountOptions options_;
+  std::shared_ptr<ResizableBuffer> counts_;
+};
+
+struct GroupedSumImpl : public GroupedAggregator {
+  // NB: whether we are accumulating into double, int64_t, or uint64_t
+  // we always have 64 bits per group in the sums buffer.
+  static constexpr size_t kSumSize = sizeof(int64_t);
+
+  using ConsumeImpl = std::function<void(const std::shared_ptr<ArrayData>&,
+                                         const uint32_t*, Buffer*, Buffer*)>;
+
+  struct GetConsumeImpl {
+    template <typename T,
+              typename AccumulatorType = typename FindAccumulatorType<T>::Type>
+    Status Visit(const T&) {
+      consume_impl = [](const std::shared_ptr<ArrayData>& input,
+                        const uint32_t* group_ids, Buffer* sums, Buffer* counts) {
+        auto raw_input = reinterpret_cast<const typename TypeTraits<T>::CType*>(
+            input->buffers[1]->data());
+        auto raw_sums = reinterpret_cast<typename TypeTraits<AccumulatorType>::CType*>(
+            sums->mutable_data());
+        auto raw_counts = reinterpret_cast<int64_t*>(counts->mutable_data());
+
+        arrow::internal::VisitSetBitRunsVoid(
+            input->buffers[0], input->offset, input->length,
+            [&](int64_t begin, int64_t length) {
+              for (int64_t input_i = begin, i = begin - input->offset;
+                   input_i < begin + length; ++input_i, ++i) {
+                auto g = group_ids[i];
+                raw_sums[g] += raw_input[input_i];
+                raw_counts[g] += 1;
+              }
+            });
+      };
+      out_type = TypeTraits<AccumulatorType>::type_singleton();
+      return Status::OK();
+    }
+
+    Status Visit(const BooleanType&) {
+      consume_impl = [](const std::shared_ptr<ArrayData>& input,
+                        const uint32_t* group_ids, Buffer* sums, Buffer* counts) {
+        auto raw_input = input->buffers[1]->data();
+        auto raw_sums = reinterpret_cast<uint64_t*>(sums->mutable_data());
+        auto raw_counts = reinterpret_cast<int64_t*>(counts->mutable_data());
+
+        arrow::internal::VisitSetBitRunsVoid(
+            input->buffers[0], input->offset, input->length,
+            [&](int64_t begin, int64_t length) {
+              for (int64_t input_i = begin, i = begin - input->offset;
+                   input_i < begin + length; ++input_i) {
+                auto g = group_ids[i];
+                raw_sums[g] += BitUtil::GetBit(raw_input, input_i);
+                raw_counts[g] += 1;
+              }
+            });
+      };
+      out_type = boolean();

Review comment:
       This is indeed a typo. Will correct to `uint64()`




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to