aucahuasi commented on a change in pull request #11257:
URL: https://github.com/apache/arrow/pull/11257#discussion_r720709247



##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,168 @@ Result<std::unique_ptr<KernelState>> CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+/// Scalar aggregate state for counting distinct values.
+///
+/// Non-null values are inserted into a hash memo table that lives for the whole
+/// aggregation, so `non_nulls` always equals the table's size. Nulls are not stored;
+/// `has_nulls` records whether any null was seen, and Finalize counts it as one
+/// extra value under CountOptions::ALL and as the whole result under ONLY_NULL.
+template <typename Type, typename VisitorArgType>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+      : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 0)) {}
+
+  Status Consume(KernelContext*, const ExecBatch& batch) override {
+    if (batch[0].is_array()) {
+      const ArrayData& arr = *batch[0].array();
+      auto visit_null = []() { return Status::OK(); };
+      auto visit_value = [&](VisitorArgType arg) {
+        int32_t unused_memo_index;
+        return memo_table_->GetOrInsert(arg, &unused_memo_index);
+      };
+      RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+      // Keep, rather than overwrite, a null seen in an earlier batch.
+      this->has_nulls = this->has_nulls || arr.GetNullCount() > 0;
+    } else {
+      const Scalar& input = *batch[0].scalar();
+      this->has_nulls = this->has_nulls || !input.is_valid;
+      if (input.is_valid) {
+        // A scalar broadcast over `batch.length` rows is still a single value.
+        // NOTE(review): assumes UnboxScalar<Type> accepts the scalar class of every
+        // logical type registered under this physical Type (e.g. temporal or
+        // decimal inputs) — confirm against the kernel registrations.
+        int32_t unused_memo_index;
+        RETURN_NOT_OK(memo_table_->GetOrInsert(
+            ::arrow::compute::internal::UnboxScalar<Type>::Unbox(input),
+            &unused_memo_index));
+      }
+    }
+    // The memo table accumulates across batches, so its size already is the running
+    // distinct count; adding it once per batch would count earlier values again.
+    this->non_nulls = memo_table_->size();
+    return Status::OK();
+  }
+
+  Status MergeFrom(KernelContext*, KernelState&& src) override {
+    const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
+    // Distinct counts are not additive (a value may appear in both states), so take
+    // the union of the memo tables and recount instead of summing `non_nulls`.
+    // NOTE(review): relies on MemoTable::MergeTable existing for every memo table
+    // type instantiated here — confirm in arrow/util/hashing.h.
+    RETURN_NOT_OK(this->memo_table_->MergeTable(*other_state.memo_table_));
+    this->non_nulls = this->memo_table_->size();
+    this->has_nulls = this->has_nulls || other_state.has_nulls;
+    return Status::OK();
+  }
+
+  Status Finalize(KernelContext* ctx, Datum* out) override {
+    const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state());
+    // All nulls collapse into a single "null" value.
+    const int64_t nulls = state.has_nulls ? 1 : 0;
+    switch (state.options.mode) {
+      case CountOptions::ONLY_VALID:
+        *out = Datum(state.non_nulls);
+        break;
+      case CountOptions::ALL:
+        *out = Datum(state.non_nulls + nulls);
+        break;
+      case CountOptions::ONLY_NULL:
+        *out = Datum(nulls);
+        break;
+      default:
+        DCHECK(false) << "unreachable";
+    }
+    return Status::OK();
+  }
+
+  const CountOptions options;
+  // Number of distinct non-null values seen so far (kept equal to memo_table_->size()).
+  int64_t non_nulls = 0;
+  // Whether any null input has been seen in any consumed batch or merged state.
+  bool has_nulls = false;
+  std::unique_ptr<MemoTable> memo_table_;
+};
+
+// Kernel init hook: builds a fresh CountDistinctImpl for the given options whose
+// memo table draws from the context's memory pool.
+template <typename Type, typename VisitorArgType>
+Result<std::unique_ptr<KernelState>> CountDistinctInit(KernelContext* ctx,
+                                                       const KernelInitArgs& args) {
+  using State = CountDistinctImpl<Type, VisitorArgType>;
+  const auto& count_options = static_cast<const CountOptions&>(*args.options);
+  return ::arrow::internal::make_unique<State>(ctx->memory_pool(), count_options);
+}
+
+// Registers one count-distinct kernel for `type` on `func`, backed by a memo table
+// keyed on `Type` and fed values of `VisitorArgType`.
+template <typename Type, typename VisitorArgType>
+void AddCountDistinctKernel(InputType type, ScalarAggregateFunction* func) {
+  // The signature declares one int64 scalar output whatever the input type is.
+  const auto signature = KernelSignature::Make({type}, ValueDescr::Scalar(int64()));
+  AddAggKernel(signature, aggregate::CountDistinctInit<Type, VisitorArgType>, func);
+}
+
+// Default registration: values of `Type` are hashed through the memo table of its
+// physical storage type, and visited as `VisitorArgType` (the C type by default).
+template <typename Type, typename VisitorArgType = typename Type::c_type>
+struct CountDistinctKernel {
+  static void Add(InputType type, ScalarAggregateFunction* func) {
+    AddCountDistinctKernel<typename Type::PhysicalType, VisitorArgType>(type, func);
+  }
+};
+
+// Fixed-size binary values are visited as string views and hashed directly with the
+// fixed-size-binary memo table.
+template <>
+struct CountDistinctKernel<FixedSizeBinaryType, util::string_view> {
+  static void Add(InputType type, ScalarAggregateFunction* func) {
+    using Storage = FixedSizeBinaryType;
+    AddCountDistinctKernel<Storage, util::string_view>(type, func);
+  }
+};
+
+// Generic decimal inputs are hashed by their fixed-size binary representation.
+template <>
+struct CountDistinctKernel<DecimalType, util::string_view> {
+  static void Add(InputType type, ScalarAggregateFunction* func) {
+    using Storage = FixedSizeBinaryType;
+    AddCountDistinctKernel<Storage, util::string_view>(type, func);
+  }
+};
+
+// Decimal128 inputs are hashed by their fixed-size binary representation.
+template <>
+struct CountDistinctKernel<Decimal128Type, util::string_view> {
+  static void Add(InputType type, ScalarAggregateFunction* func) {
+    using Storage = FixedSizeBinaryType;
+    AddCountDistinctKernel<Storage, util::string_view>(type, func);
+  }
+};
+
+template <>
+struct CountDistinctKernel<Decimal256Type, util::string_view> {
+  static void Add(InputType type, ScalarAggregateFunction* func) {
+    AddCountDistinctKernel<FixedSizeBinaryType, util::string_view>(type, func);
+  }

Review comment:
       Done!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to