aucahuasi commented on a change in pull request #11257:
URL: https://github.com/apache/arrow/pull/11257#discussion_r721446677
##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,137 @@ Result<std::unique_ptr<KernelState>>
CountInit(KernelContext*,
static_cast<const CountOptions&>(*args.options));
}
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+// Aggregator state for the distinct-count kernel.
+//
+// Non-null values of array inputs are inserted into a memo table (the table
+// type chosen by HashTraits<Type>), and the number of entries that table gains
+// is accumulated in `non_nulls`. Whether any null has been observed is kept
+// separately in `has_nulls`, so nulls contribute at most one to the result.
+//
+// Template parameters:
+//   Type            selects the memo table type and the array visitor.
+//   VisitorArgType  type under which each non-null value is passed to the
+//                   visitor and inserted into the memo table.
+template <typename Type, typename VisitorArgType>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+      : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 0)) {}
+
+  // Folds one batch into the state. Only batch[0] is read.
+  Status Consume(KernelContext*, const ExecBatch& batch) override {
+    if (batch[0].is_array()) {
+      const ArrayData& arr = *batch[0].array();
+      // The memo table persists across Consume() calls, so only the entries
+      // this batch adds may be credited; adding the full size every time
+      // would recount values seen in earlier batches.
+      const int64_t size_before = memo_table_->size();
+      auto visit_null = []() { return Status::OK(); };
+      auto visit_value = [&](VisitorArgType arg) {
+        int32_t unused_memo_index;
+        return memo_table_->GetOrInsert(arg, &unused_memo_index);
+      };
+      RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+      this->non_nulls += memo_table_->size() - size_before;
+      // Sticky: a later batch without nulls must not erase an earlier null.
+      this->has_nulls = this->has_nulls || arr.GetNullCount() > 0;
+    } else {
+      const Scalar& input = *batch[0].scalar();
+      if (input.is_valid) {
+        // NOTE(review): the scalar is not inserted into the memo table, so it
+        // is credited batch.length times and is not deduplicated against
+        // array values. Kept as in the original; confirm intended semantics.
+        this->non_nulls += batch.length;
+      } else {
+        // A null scalar contributes no non-null values.
+        this->has_nulls = true;
+      }
+    }
+    return Status::OK();
+  }
+
+  // Combines another CountDistinctImpl state into this one.
+  //
+  // NOTE(review): the counts are summed without merging the two memo tables,
+  // so a value present in both states is counted twice. Fixing this needs a
+  // memo-table merge facility that is not visible in this hunk.
+  Status MergeFrom(KernelContext*, KernelState&& src) override {
+    const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
+    this->non_nulls += other_state.non_nulls;
+    this->has_nulls = this->has_nulls || other_state.has_nulls;
+    return Status::OK();
+  }
+
+  // Writes the result, read from ctx->state(), into *out according to
+  // options.mode:
+  //   ONLY_VALID -> non_nulls
+  //   ALL        -> non_nulls plus 1 if a null was seen
+  //   ONLY_NULL  -> 1 if a null was seen, otherwise 0
+  Status Finalize(KernelContext* ctx, Datum* out) override {
+    const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state());
+    const int64_t nulls = state.has_nulls ? 1 : 0;
+    switch (state.options.mode) {
+      case CountOptions::ONLY_VALID:
+        *out = Datum(state.non_nulls);
+        break;
+      case CountOptions::ALL:
+        *out = Datum(state.non_nulls + nulls);
+        break;
+      case CountOptions::ONLY_NULL:
+        *out = Datum(nulls);
+        break;
+      default:
+        DCHECK(false) << "unreachable";
+    }
+    return Status::OK();
+  }
+
+  const CountOptions options;
+  // Running count of non-null values credited so far.
+  int64_t non_nulls = 0;
+  // True once any consumed input contained a null.
+  bool has_nulls = false;
+  // Holds the non-null array values seen so far.
+  std::unique_ptr<MemoTable> memo_table_;
+};
+
+// Kernel init callback: creates a CountDistinctImpl<Type, VisitorArgType>
+// that allocates from the context's memory pool and carries a copy of the
+// CountOptions found in `args.options`.
+template <typename Type, typename VisitorArgType>
+Result<std::unique_ptr<KernelState>> CountDistinctInit(KernelContext* ctx,
+                                                       const KernelInitArgs& args) {
+  using Impl = CountDistinctImpl<Type, VisitorArgType>;
+  const auto& count_options = static_cast<const CountOptions&>(*args.options);
+  return ::arrow::internal::make_unique<Impl>(ctx->memory_pool(), count_options);
+}
+
+// Registers on `func` one aggregate kernel that accepts `type` as its single
+// input, produces a scalar int64, and is initialised by
+// CountDistinctInit<Type, VisitorArgType>.
+template <typename Type, typename VisitorArgType>
+void AddCountDistinctKernel(InputType type, ScalarAggregateFunction* func) {
+  auto signature = KernelSignature::Make({type}, ValueDescr::Scalar(int64()));
+  AddAggKernel(std::move(signature),
+               aggregate::CountDistinctInit<Type, VisitorArgType>, func);
+}
+
+// Registration helper: adds a distinct-count kernel for `type`, instantiated
+// on Type's PhysicalType. VisitorArgType defaults to Type::c_type.
+template <typename Type, typename VisitorArgType = typename Type::c_type>
+struct CountDistinctKernel {
+  static void Add(InputType type, ScalarAggregateFunction* func) {
+    AddCountDistinctKernel<typename Type::PhysicalType, VisitorArgType>(type, func);
+  }
+};
+
+// Specialisation for fixed-size binary values viewed as util::string_view:
+// FixedSizeBinaryType is used directly instead of looking up a PhysicalType.
+template <>
+struct CountDistinctKernel<FixedSizeBinaryType, util::string_view> {
+  static void Add(InputType type, ScalarAggregateFunction* func) {
+    using ValueType = FixedSizeBinaryType;
+    using ViewType = util::string_view;
+    AddCountDistinctKernel<ValueType, ViewType>(type, func);
+  }
+};
+
+void AddCountDistinctKernels(ScalarAggregateFunction* func) {
+ // Boolean
+ aggregate::CountDistinctKernel<BooleanType>::Add(boolean(), func);
+ // Number
+ aggregate::CountDistinctKernel<Int8Type>::Add(int8(), func);
+ aggregate::CountDistinctKernel<Int16Type>::Add(int16(), func);
+ aggregate::CountDistinctKernel<Int32Type>::Add(int32(), func);
+ aggregate::CountDistinctKernel<Int64Type>::Add(int64(), func);
+ aggregate::CountDistinctKernel<UInt8Type>::Add(uint8(), func);
+ aggregate::CountDistinctKernel<UInt16Type>::Add(uint16(), func);
+ aggregate::CountDistinctKernel<UInt32Type>::Add(uint32(), func);
+ aggregate::CountDistinctKernel<UInt64Type>::Add(uint64(), func);
+ aggregate::CountDistinctKernel<HalfFloatType>::Add(float16(), func);
+ aggregate::CountDistinctKernel<FloatType>::Add(float32(), func);
+ aggregate::CountDistinctKernel<DoubleType>::Add(float64(), func);
+ // Date
+ aggregate::CountDistinctKernel<Date32Type>::Add(date32(), func);
+ aggregate::CountDistinctKernel<Date64Type>::Add(date64(), func);
+ // Time
+
aggregate::CountDistinctKernel<Time32Type>::Add(match::SameTypeId(Type::TIME32),
func);
+
aggregate::CountDistinctKernel<Time64Type>::Add(match::SameTypeId(Type::TIME64),
func);
+ // Timestamp & Duration
+
aggregate::CountDistinctKernel<TimestampType>::Add(match::SameTypeId(Type::TIMESTAMP),
+ func);
+
aggregate::CountDistinctKernel<DurationType>::Add(match::SameTypeId(Type::DURATION),
+ func);
+ // Interval
+ aggregate::CountDistinctKernel<MonthIntervalType>::Add(month_interval(),
func);
+
aggregate::CountDistinctKernel<DayTimeIntervalType>::Add(day_time_interval(),
func);
+
aggregate::CountDistinctKernel<MonthDayNanoIntervalType>::Add(month_day_nano_interval(),
+ func);
+ // Binary & String
+ aggregate::CountDistinctKernel<BinaryType,
util::string_view>::Add(match::BinaryLike(),
+ func);
+ aggregate::CountDistinctKernel<LargeBinaryType, util::string_view>::Add(
+ match::LargeBinaryLike(), func);
+ // Fixed binary & Decimal
+ aggregate::CountDistinctKernel<FixedSizeBinaryType, util::string_view>::Add(
+ match::FixedSizeBinaryLike(), func);
Review comment:
Thanks, but as I mentioned before, I already looked at those helpers, and it
seems they are for generating the exec part of the kernel. Do you have
anything specific in mind? If so, maybe this can be a follow-up as a
nice-to-have/code-style improvement.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]