aucahuasi commented on a change in pull request #11257:
URL: https://github.com/apache/arrow/pull/11257#discussion_r718535676
##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,84 @@ Result<std::unique_ptr<KernelState>> CountInit(KernelContext*,
static_cast<const CountOptions&>(*args.options));
}
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+ using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+ explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+ : options(std::move(options)), memo_table_(new MemoTable(memory_pool,
0)) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (batch[0].is_array()) {
+ const ArrayData& arr = *batch[0].array();
+ auto visit_null = [&]() {
+ if (this->nulls > 0) return Status::OK();
+ ++this->nulls;
+ return Status::OK();
+ };
+ auto visit_value = [&](typename Type::c_type arg) {
+ int y;
+ RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y));
+ return Status::OK();
+ };
+ RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+ this->non_nulls += this->memo_table_->size();
+ } else {
+ const Scalar& input = *batch[0].scalar();
+ this->nulls += !input.is_valid * batch.length;
+ this->non_nulls += input.is_valid * batch.length;
+ }
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
+ this->non_nulls += other_state.non_nulls;
+ this->nulls += other_state.nulls;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext* ctx, Datum* out) override {
+ const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state());
+ switch (state.options.mode) {
+ case CountOptions::ONLY_VALID:
+ *out = Datum(state.non_nulls);
+ break;
+ case CountOptions::ALL:
+ *out = Datum(state.non_nulls + state.nulls);
+ break;
+ case CountOptions::ONLY_NULL:
+ *out = Datum(state.nulls);
+ break;
+ default:
+ DCHECK(false) << "unreachable";
+ }
+ return Status::OK();
+ }
+
+ CountOptions options;
+ int64_t non_nulls = 0;
Review comment:
Thanks, I did that in my first local version. I later changed it to make
the implementation more similar to Count's, but I don't think that was
the best choice. Let me improve this part!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]