bkietz commented on a change in pull request #10660:
URL: https://github.com/apache/arrow/pull/10660#discussion_r668001720



##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -748,67 +752,128 @@ struct GrouperFastImpl : Grouper {
 /// Implementations should be default constructible and perform initialization in
 /// Init().
 struct GroupedAggregator : KernelState {
-  virtual Status Init(ExecContext*, const FunctionOptions*,
-                      const std::shared_ptr<DataType>&) = 0;
+  virtual Status Init(ExecContext*, const FunctionOptions*) = 0;
+
+  virtual Status Resize(int64_t new_num_groups) = 0;
 
   virtual Status Consume(const ExecBatch& batch) = 0;
 
-  virtual Result<Datum> Finalize() = 0;
+  virtual Status Merge(GroupedAggregator&& other, const ArrayData& group_id_mapping) = 0;
 
-  template <typename Reserve>
-  Status MaybeReserve(int64_t old_num_groups, const ExecBatch& batch,
-                      const Reserve& reserve) {
-    int64_t new_num_groups = batch[2].scalar_as<UInt32Scalar>().value;
-    if (new_num_groups <= old_num_groups) {
-      return Status::OK();
-    }
-    return reserve(new_num_groups - old_num_groups);
-  }
+  virtual Result<Datum> Finalize() = 0;
 
   virtual std::shared_ptr<DataType> out_type() const = 0;
 };
 
+template <typename Impl>
+Result<std::unique_ptr<KernelState>> HashAggregateInit(KernelContext* ctx,
+                                                       const KernelInitArgs& args) {
+  auto impl = ::arrow::internal::make_unique<Impl>();
+  RETURN_NOT_OK(impl->Init(ctx->exec_context(), args.options));
+  return std::move(impl);
+}
+
+HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
+  HashAggregateKernel kernel;
+
+  kernel.init = std::move(init);
+
+  kernel.signature = KernelSignature::Make(
+      {std::move(argument_type), InputType::Array(Type::UINT32)},
+      OutputType(
+          [](KernelContext* ctx, const std::vector<ValueDescr>&) -> Result<ValueDescr> {
+            return checked_cast<GroupedAggregator*>(ctx->state())->out_type();
+          }));
+
+  kernel.resize = [](KernelContext* ctx, int64_t num_groups) {
+    return checked_cast<GroupedAggregator*>(ctx->state())->Resize(num_groups);
+  };
+
+  kernel.consume = [](KernelContext* ctx, const ExecBatch& batch) {
+    return checked_cast<GroupedAggregator*>(ctx->state())->Consume(batch);
+  };
+
+  kernel.merge = [](KernelContext* ctx, KernelState&& other,
+                    const ArrayData& group_id_mapping) {
+    return checked_cast<GroupedAggregator*>(ctx->state())
+        ->Merge(checked_cast<GroupedAggregator&&>(other), group_id_mapping);
+  };
+
+  kernel.finalize = [](KernelContext* ctx, Datum* out) {
+    ARROW_ASSIGN_OR_RAISE(*out,
+                          checked_cast<GroupedAggregator*>(ctx->state())->Finalize());
+    return Status::OK();
+  };
+
+  return kernel;
+}
+
+Status AddHashAggKernels(
+    const std::vector<std::shared_ptr<DataType>>& types,
+    Result<HashAggregateKernel> make_kernel(const std::shared_ptr<DataType>&),
+    HashAggregateFunction* function) {
+  for (const auto& ty : types) {
+    ARROW_ASSIGN_OR_RAISE(auto kernel, make_kernel(ty));
+    RETURN_NOT_OK(function->AddKernel(std::move(kernel)));
+  }
+  return Status::OK();
+}
+
 // ----------------------------------------------------------------------
 // Count implementation
 
 struct GroupedCountImpl : public GroupedAggregator {
-  Status Init(ExecContext* ctx, const FunctionOptions* options,
-              const std::shared_ptr<DataType>&) override {
+  Status Init(ExecContext* ctx, const FunctionOptions* options) override {
     options_ = checked_cast<const ScalarAggregateOptions&>(*options);
     counts_ = BufferBuilder(ctx->memory_pool());
     return Status::OK();
   }
 
-  Status Consume(const ExecBatch& batch) override {
-    RETURN_NOT_OK(MaybeReserve(num_groups_, batch, [&](int64_t added_groups) {
-      num_groups_ += added_groups;
-      return counts_.Append(added_groups * sizeof(int64_t), 0);
-    }));
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    return counts_.Append(added_groups * sizeof(int64_t), 0);
+  }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    auto other = checked_cast<GroupedCountImpl*>(&raw_other);
 
-    auto group_ids = batch[1].array()->GetValues<uint32_t>(1);
-    auto raw_counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+    auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+    auto other_counts = reinterpret_cast<const int64_t*>(other->counts_.mutable_data());
+
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+    for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+      counts[*g] += other_counts[other_g];
+    }
+    return Status::OK();
+  }
+
+  Status Consume(const ExecBatch& batch) override {
+    auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
 
     const auto& input = batch[0].array();
 
-    if (!options_.skip_nulls) {
-      if (input->GetNullCount() != 0) {
-        for (int64_t i = 0, input_i = input->offset; i < input->length; ++i, ++input_i) {
-          auto g = group_ids[i];
-          raw_counts[g] += !BitUtil::GetBit(input->buffers[0]->data(), input_i);
-        }
+    if (options_.skip_nulls) {
+      auto g_begin =
+          reinterpret_cast<const uint32_t*>(batch[1].array()->buffers[1]->data());
+
+      arrow::internal::VisitSetBitRunsVoid(input->buffers[0], input->offset,
+                                           input->length,
+                                           [&](int64_t offset, int64_t length) {
+                                             auto g = g_begin + offset;
+                                             for (int64_t i = 0; i < length; ++i, ++g) {
+                                               counts[*g] += 1;
+                                             }
+                                           });
+    } else if (input->MayHaveNulls()) {

Review comment:
       If `skip_nulls`, we want to count only valid slots. If `!skip_nulls`, we want to count only null slots; in that case, if `!MayHaveNulls` then no counts will be incremented.
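
       For reference, here is a minimal standalone sketch of the two counting modes described above. It is not the PR's code; `GetBit` and `CountPerGroup` are made-up helpers used only to illustrate the behavior:

```cpp
#include <cstdint>
#include <vector>

// Little-endian bit order, as in Arrow validity bitmaps.
inline bool GetBit(const uint8_t* bitmap, int64_t i) {
  return (bitmap[i / 8] >> (i % 8)) & 1;
}

// Count valid slots per group when skip_nulls, otherwise count null slots per group.
void CountPerGroup(const uint8_t* validity,    // nullptr means all slots are valid
                   const uint32_t* group_ids,  // one group id per row
                   int64_t length, bool skip_nulls,
                   std::vector<int64_t>* counts) {
  for (int64_t i = 0; i < length; ++i) {
    const bool valid = validity == nullptr || GetBit(validity, i);
    (*counts)[group_ids[i]] += skip_nulls ? valid : !valid;
  }
}
```

       Note that with `validity == nullptr` (the no-nulls case) and `!skip_nulls`, the loop never increments anything, which matches the expected behavior above.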



