edponce commented on a change in pull request #11257:
URL: https://github.com/apache/arrow/pull/11257#discussion_r718153840



##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,84 @@ Result<std::unique_ptr<KernelState>> 
CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+      : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 
0)) {}
+
+  Status Consume(KernelContext*, const ExecBatch& batch) override {
+    if (batch[0].is_array()) {
+      const ArrayData& arr = *batch[0].array();
+      auto visit_null = [&]() {
+        if (this->nulls > 0) return Status::OK();
+        ++this->nulls;
+        return Status::OK();
+      };
+      auto visit_value = [&](typename Type::c_type arg) {
+        int y;

Review comment:
       Nit: Instead of using `typename Type::c_type`, we can define an alias at 
the top of the class
   ```c++
   using c_type = typename Type::c_type;
   ```
   but since it is only used here, it is OK as is. It is just a common pattern in 
the Arrow codebase.

##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,84 @@ Result<std::unique_ptr<KernelState>> 
CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+      : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 
0)) {}
+
+  Status Consume(KernelContext*, const ExecBatch& batch) override {
+    if (batch[0].is_array()) {
+      const ArrayData& arr = *batch[0].array();
+      auto visit_null = [&]() {
+        if (this->nulls > 0) return Status::OK();
+        ++this->nulls;
+        return Status::OK();
+      };
+      auto visit_value = [&](typename Type::c_type arg) {
+        int y;
+        RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y));
+        return Status::OK();
+      };
+      RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+      this->non_nulls += this->memo_table_->size();
+    } else {
+      const Scalar& input = *batch[0].scalar();
+      this->nulls += !input.is_valid * batch.length;
+      this->non_nulls += input.is_valid * batch.length;
+    }

Review comment:
       If `batch.length > 1`, then `this->nulls` can be greater than 1, 
which will result in an incorrect value in `Finalize()`. We can limit the value 
to 1 in `Finalize()`.

##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -754,6 +839,30 @@ void RegisterScalarAggregateBasic(FunctionRegistry* 
registry) {
                aggregate::CountInit, func.get());
   DCHECK_OK(registry->AddFunction(std::move(func)));
 
+  func = std::make_shared<ScalarAggregateFunction>(
+      "count_distinct", Arity::Unary(), &count_distinct_doc, 
&default_count_options);
+
+  // Takes any input, outputs int64 scalar
+  aggregate::AddCountDistinctKernel<Int8Type>(int8(), func.get());
+  aggregate::AddCountDistinctKernel<Int16Type>(int16(), func.get());
+  aggregate::AddCountDistinctKernel<Int32Type>(int32(), func.get());
+  aggregate::AddCountDistinctKernel<Date32Type>(date32(), func.get());
+  aggregate::AddCountDistinctKernel<Int64Type>(int64(), func.get());

Review comment:
       You are providing explicit input types because you need them for the 
[`ArrayDataInlineVisitor`](https://github.com/apache/arrow/blob/master/cpp/src/arrow/visitor_inline.h#L141)
 invoked via 
[`VisitArrayDataInline`](https://github.com/apache/arrow/blob/master/cpp/src/arrow/visitor_inline.h#L326).
 Note that `ArrayDataInlineVisitor` only checks for the base type, so we can 
possibly register fewer kernels by leveraging [implicit 
casting](https://arrow.apache.org/docs/cpp/compute.html#implicit-casts).

##########
File path: python/pyarrow/tests/test_compute.py
##########
@@ -2217,3 +2217,12 @@ def test_list_element():
     result = pa.compute.list_element(lists, index)
     expected = pa.array([{'a': 5.6, 'b': 6}, {'a': .6, 'b': 8}], element_type)
     assert result.equals(expected)
+
+
+def test_count_distinct():
+    seed = datetime.now()
+    samples = [seed.replace(year=y) for y in range(1992, 2092)]
+    arr = pa.array(samples, pa.timestamp("ns"))
+    result = pa.compute.count_distinct(arr)
+    expected = pa.scalar(len(samples), type=pa.int64())

Review comment:
       Add tests using the `CountOptions`, see 
[`count`](https://github.com/apache/arrow/blob/master/python/pyarrow/tests/test_compute.py#L1839).

##########
File path: cpp/src/arrow/compute/kernels/aggregate_test.cc
##########
@@ -873,6 +873,65 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) 
{
   }
 }
 
+//
+// Count Distinct
+//
+
+class TestCountDistinctKernel : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    only_valid = CountOptions(CountOptions::ONLY_VALID);
+    only_null = CountOptions(CountOptions::ONLY_NULL);
+    all = CountOptions(CountOptions::ALL);
+  }
+
+  const Datum& expected(int64_t value) {
+    expected_values[value] = Datum(static_cast<int64_t>(value));
+    return expected_values.at(value);
+  }
+
+  CountOptions only_valid;
+  CountOptions only_null;
+  CountOptions all;
+
+ private:
+  std::map<int64_t, Datum> expected_values;
+};
+
+TEST_F(TestCountDistinctKernel, NumericArrowTypesWithNulls) {
+  auto sample = "[1, 1, 2, 2, 5, 8, 9, 9, 9, 10, 6, 6]";
+  auto sample_nulls = "[null, 8, null, null, 6, null, 8]";
+  for (auto ty : NumericTypes()) {
+    auto input = ArrayFromJSON(ty, sample);
+    CheckScalar("count_distinct", {input}, expected(7), &only_valid);
+    CheckScalar("count_distinct", {input}, expected(0), &only_null);
+    CheckScalar("count_distinct", {input}, expected(7), &all);
+    auto input_nulls = ArrayFromJSON(ty, sample_nulls);
+    CheckScalar("count_distinct", {input_nulls}, expected(2), &only_valid);
+    CheckScalar("count_distinct", {input_nulls}, expected(1), &only_null);
+    CheckScalar("count_distinct", {input_nulls}, expected(3), &all);
+  }
+}
+
+TEST_F(TestCountDistinctKernel, RandomValidsStdMap) {
+  UInt32Builder builder;
+  std::map<uint32_t, int64_t> hashmap;
+  auto visit_null = [&]() { return Status::OK(); };
+  auto visit_value = [&](uint32_t arg) {

Review comment:
       `visit_null` does not need the ampersand; there are no variables being 
captured.

##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,84 @@ Result<std::unique_ptr<KernelState>> 
CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+      : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 
0)) {}
+
+  Status Consume(KernelContext*, const ExecBatch& batch) override {
+    if (batch[0].is_array()) {
+      const ArrayData& arr = *batch[0].array();
+      auto visit_null = [&]() {
+        if (this->nulls > 0) return Status::OK();
+        ++this->nulls;
+        return Status::OK();
+      };

Review comment:
       Since we already need to traverse the array to build the `memo_table`, if 
we also accumulate the `nulls`, we could potentially update the null count of 
`arr` as a side effect, but ... should we allow such a side effect? cc @lidavidm 

##########
File path: cpp/src/arrow/compute/kernels/aggregate_test.cc
##########
@@ -873,6 +873,65 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) 
{
   }
 }
 
+//
+// Count Distinct
+//
+
+class TestCountDistinctKernel : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    only_valid = CountOptions(CountOptions::ONLY_VALID);
+    only_null = CountOptions(CountOptions::ONLY_NULL);
+    all = CountOptions(CountOptions::ALL);
+  }
+
+  const Datum& expected(int64_t value) {
+    expected_values[value] = Datum(static_cast<int64_t>(value));
+    return expected_values.at(value);
+  }
+
+  CountOptions only_valid;
+  CountOptions only_null;
+  CountOptions all;
+
+ private:
+  std::map<int64_t, Datum> expected_values;
+};
+
+TEST_F(TestCountDistinctKernel, NumericArrowTypesWithNulls) {
+  auto sample = "[1, 1, 2, 2, 5, 8, 9, 9, 9, 10, 6, 6]";

Review comment:
       Also add tests for non-numeric types.

##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,84 @@ Result<std::unique_ptr<KernelState>> 
CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+      : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 
0)) {}
+
+  Status Consume(KernelContext*, const ExecBatch& batch) override {
+    if (batch[0].is_array()) {
+      const ArrayData& arr = *batch[0].array();
+      auto visit_null = [&]() {
+        if (this->nulls > 0) return Status::OK();
+        ++this->nulls;
+        return Status::OK();
+      };
+      auto visit_value = [&](typename Type::c_type arg) {
+        int y;
+        RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y));
+        return Status::OK();
+      };
+      RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+      this->non_nulls += this->memo_table_->size();
+    } else {
+      const Scalar& input = *batch[0].scalar();
+      this->nulls += !input.is_valid * batch.length;
+      this->non_nulls += input.is_valid * batch.length;
+    }
+    return Status::OK();
+  }
+
+  Status MergeFrom(KernelContext*, KernelState&& src) override {
+    const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
+    this->non_nulls += other_state.non_nulls;
+    this->nulls += other_state.nulls;
+    return Status::OK();
+  }
+
+  Status Finalize(KernelContext* ctx, Datum* out) override {
+    const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state());
+    switch (state.options.mode) {
+      case CountOptions::ONLY_VALID:
+        *out = Datum(state.non_nulls);
+        break;
+      case CountOptions::ALL:
+        *out = Datum(state.non_nulls + state.nulls);
+        break;

Review comment:
       If `this->nulls` is being accumulated, then use `std::min(state.nulls, 
1)` for its value.

##########
File path: cpp/src/arrow/compute/kernels/aggregate_test.cc
##########
@@ -873,6 +873,65 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) 
{
   }
 }
 
+//
+// Count Distinct
+//
+
+class TestCountDistinctKernel : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    only_valid = CountOptions(CountOptions::ONLY_VALID);
+    only_null = CountOptions(CountOptions::ONLY_NULL);
+    all = CountOptions(CountOptions::ALL);
+  }
+
+  const Datum& expected(int64_t value) {
+    expected_values[value] = Datum(static_cast<int64_t>(value));
+    return expected_values.at(value);

Review comment:
       Not sure I understand the rationale here. You are basically storing and 
returning `Datum(value)`?

##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,84 @@ Result<std::unique_ptr<KernelState>> 
CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;

Review comment:
       There could be a base class for aggregate `CountImpl` that includes 
`options`, `non_nulls`, `nulls`, and `MergeFrom()`, with virtual `Consume()` and 
`Finalize()`, but maybe it is overkill.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to