This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new af4db7731b ARROW-16807: [C++][R] count distinct incorrectly merges state (#13583)
af4db7731b is described below

commit af4db7731b1f857e78221c53c2d8221849b1eeec
Author: octalene <octalene....@pm.me>
AuthorDate: Sat Jul 16 14:45:27 2022 -0700

    ARROW-16807: [C++][R] count distinct incorrectly merges state (#13583)
    
    This addresses a bug where the `count_distinct` function simply added the counts of the two states when merging them, so a value present in both states was counted twice. The correct logic is to return the number of distinct elements after both states have been merged.
    
    State for count_distinct is backed by a MemoTable, which is in turn backed by a HashTable. To properly merge state, this PR adds a `MergeTable` function to each MemoTable; it iterates over the elements of the MemoTable _to be merged_ and inserts each of them into the receiving MemoTable. (The original description also mentioned a `MaybeInsert` function for simplified insertion, but no such function appears in the diff below.)
    
    This PR also adds an R test and a C++ test. The R test mirrors the example provided in ARROW-16807. The C++ test, `AllChunkedArrayTypesWithNulls`, mirrors the existing C++ test `AllArrayTypesWithNulls` but uses chunked arrays as test data.
    
    Lead-authored-by: Aldrin Montana <octalene....@pm.me>
    Co-authored-by: Aldrin M <octalene....@pm.me>
    Co-authored-by: Wes McKinney <w...@apache.org>
    Signed-off-by: Wes McKinney <w...@apache.org>
---
 cpp/src/arrow/compute/kernels/aggregate_basic.cc | 17 ++++--
 cpp/src/arrow/compute/kernels/aggregate_test.cc  | 72 ++++++++++++++++++++++++
 cpp/src/arrow/compute/kernels/codegen_internal.h |  2 +-
 cpp/src/arrow/util/hashing.h                     | 32 +++++++++++
 r/tests/testthat/test-dplyr-summarize.R          |  9 +++
 5 files changed, 126 insertions(+), 6 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index 57cee87f00..fec483318e 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -136,27 +136,40 @@ struct CountDistinctImpl : public ScalarAggregator {
   Status Consume(KernelContext*, const ExecBatch& batch) override {
     if (batch[0].is_array()) {
       const ArrayData& arr = *batch[0].array();
+      // Accumulate instead of overwriting: when several batches are consumed into the
+      // same state, a null seen in an earlier batch must not be forgotten.
+      this->has_nulls = this->has_nulls || arr.GetNullCount() > 0;
+
       auto visit_null = []() { return Status::OK(); };
       auto visit_value = [&](VisitorArgType arg) {
-        int y;
+        int32_t y;
         return memo_table_->GetOrInsert(arg, &y);
       };
       RETURN_NOT_OK(VisitArraySpanInline<Type>(arr, visit_value, visit_null));
-      this->non_nulls += memo_table_->size();
-      this->has_nulls = arr.GetNullCount() > 0;
+
     } else {
       const Scalar& input = *batch[0].scalar();
-      this->has_nulls = !input.is_valid;
+      this->has_nulls = this->has_nulls || !input.is_valid;
+
       if (input.is_valid) {
-        this->non_nulls += batch.length;
+        int32_t unused;
+        RETURN_NOT_OK(memo_table_->GetOrInsert(UnboxScalar<Type>::Unbox(input), &unused));
       }
     }
+
+    // Nulls are never inserted into the memo table, so its size is the number of
+    // distinct non-null values consumed so far.
+    this->non_nulls = memo_table_->size();
+
     return Status::OK();
   }
 
   Status MergeFrom(KernelContext*, KernelState&& src) override {
     const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
-    this->non_nulls += other_state.non_nulls;
+    // Take the union of both sets of distinct values; adding the two counts would
+    // count a value present in both states twice.
+    RETURN_NOT_OK(this->memo_table_->MergeTable(*(other_state.memo_table_)));
+    this->non_nulls = this->memo_table_->size();
     this->has_nulls = this->has_nulls || other_state.has_nulls;
     return Status::OK();
   }
diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc
index aa54fe5f3e..abd5b5210a 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -962,11 +962,85 @@ class TestCountDistinctKernel : public ::testing::Test {
     EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one);
   }
 
+  // Builds a ChunkedArray (one chunk per JSON string) and forwards it to Check().
+  void CheckChunkedArr(const std::shared_ptr<DataType>& type,
+                       const std::vector<std::string>& json, int64_t expected_all,
+                       bool has_nulls = true) {
+    Check(ChunkedArrayFromJSON(type, json), expected_all, has_nulls);
+  }
+
   CountOptions only_valid{CountOptions::ONLY_VALID};
   CountOptions only_null{CountOptions::ONLY_NULL};
   CountOptions all{CountOptions::ALL};
 };
 
+// Mirrors AllArrayTypesWithNulls, but splits each input across several chunks so that
+// values (and nulls) repeated in different chunks must still be counted only once.
+TEST_F(TestCountDistinctKernel, AllChunkedArrayTypesWithNulls) {
+  // Boolean
+  CheckChunkedArr(boolean(), {"[]", "[]"}, 0, /*has_nulls=*/false);
+  CheckChunkedArr(boolean(), {"[true, null]", "[false, null, false]", "[true]"}, 3);
+
+  // Number
+  for (auto ty : NumericTypes()) {
+    CheckChunkedArr(ty, {"[1, 1, null, 2]", "[5, 8, 9, 9, null, 10]", "[6, 6, 8, 9, 10]"},
+                    8);
+    CheckChunkedArr(ty, {"[1, 1, 8, 2]", "[5, 8, 9, 9, 10]", "[10, 6, 6]"}, 7,
+                    /*has_nulls=*/false);
+  }
+
+  // Date
+  CheckChunkedArr(date32(), {"[0, 11016]", "[0, null, 14241, 14241, null]"}, 4);
+  CheckChunkedArr(date64(), {"[0, null]", "[0, null, 0, 0, 1262217600000]"}, 3);
+
+  // Time
+  CheckChunkedArr(time32(TimeUnit::SECOND), {"[ 0, 11, 0, null]", "[14, 14, null]"}, 4);
+  CheckChunkedArr(time32(TimeUnit::MILLI), {"[ 0, 11000, 0]", "[null, 11000, 11000]"}, 3);
+
+  CheckChunkedArr(time64(TimeUnit::MICRO), {"[84203999999, 0, null, 84203999999]", "[0]"},
+                  3);
+  CheckChunkedArr(time64(TimeUnit::NANO),
+                  {"[11715003000000, 0, null, 0, 0]", "[0, 0, null]"}, 3);
+
+  // Timestamp & Duration
+  for (auto u : TimeUnit::values()) {
+    CheckChunkedArr(duration(u), {"[123456789, null, 987654321]", "[123456789, null]"},
+                    3);
+
+    CheckChunkedArr(duration(u),
+                    {"[123456789, 987654321, 123456789, 123456789]", "[123456789]"}, 2,
+                    /*has_nulls=*/false);
+
+    auto ts =
+        std::vector<std::string>{R"(["2009-12-31T04:20:20", "2009-12-31T04:20:20"])",
+                                 R"(["2020-01-01", null])", R"(["2020-01-01", null])"};
+    CheckChunkedArr(timestamp(u), ts, 3);
+    CheckChunkedArr(timestamp(u, "Pacific/Marquesas"), ts, 3);
+  }
+
+  // Interval
+  CheckChunkedArr(month_interval(), {"[9012, 5678, null, 9012]", "[5678, null, 9012]"},
+                  3);
+  CheckChunkedArr(day_time_interval(),
+                  {"[[0, 1], [0, 1]]", "[null, [0, 1], [1234, 5678]]"}, 3);
+  CheckChunkedArr(month_day_nano_interval(),
+                  {"[[0, 1, 2]]", "[[0, 1, 2], null, [0, 1, 2]]"}, 2);
+
+  // Binary & String & Fixed binary
+  auto samples = std::vector<std::string>{
+      R"([null, "abc", null])", R"(["abc", "abc", "cba"])", R"(["bca", "cba", null])"};
+
+  CheckChunkedArr(binary(), samples, 4);
+  CheckChunkedArr(large_binary(), samples, 4);
+  CheckChunkedArr(utf8(), samples, 4);
+  CheckChunkedArr(large_utf8(), samples, 4);
+  CheckChunkedArr(fixed_size_binary(3), samples, 4);
+
+  // Decimal
+  samples = {R"(["12345.679", "98765.421"])", R"([null, "12345.679", "98765.421"])"};
+  CheckChunkedArr(decimal128(21, 3), samples, 3);
+  CheckChunkedArr(decimal256(13, 3), samples, 3);
+}
+
 TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) {
   // Boolean
   Check(boolean(), "[]", 0, /*has_nulls=*/false);
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 1d5f5dd9bd..f008314e8b 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -343,7 +343,8 @@ struct UnboxScalar<Type, enable_if_has_string_view<Type>> {
   using T = util::string_view;
   static T Unbox(const Scalar& val) {
     if (!val.is_valid) return util::string_view();
-    return util::string_view(*checked_cast<const BaseBinaryScalar&>(val).value);
+    // Non-owning view of the scalar's value (T is util::string_view); nulls return above.
+    return checked_cast<const ::arrow::internal::PrimitiveScalarBase&>(val).view();
   }
 };
 
diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h
index d2c0178b00..ca5a6c766b 100644
--- a/cpp/src/arrow/util/hashing.h
+++ b/cpp/src/arrow/util/hashing.h
@@ -485,6 +485,26 @@ class ScalarMemoTable : public MemoTable {
   hash_t ComputeHash(const Scalar& value) const {
     return ScalarHelper<Scalar, 0>::ComputeHash(value);
   }
+
+ public:
+  // Defined here so that `HashTableType` is visible.
+  // Merge the entries of `other_table`'s hash table into this table, returning the
+  // first error reported by GetOrInsert (remaining entries are then skipped).
+  // NOTE(review): only `hash_table_` entries are visited; confirm whether a null
+  // entry held by `other_table` also needs to be merged for some callers.
+  Status MergeTable(const ScalarMemoTable& other_table) {
+    const HashTableType& other_hashtable = other_table.hash_table_;
+
+    // The visitor returns nothing, so record the first failure and skip the
+    // remaining entries instead of silently dropping the error.
+    Status status;
+    other_hashtable.VisitEntries([this, &status](const HashTableEntry* other_entry) {
+      if (!status.ok()) return;
+      int32_t unused;
+      status = this->GetOrInsert(other_entry->payload.value, &unused);
+    });
+    return status;
+  }
 };
 
 // ----------------------------------------------------------------------
@@ -568,6 +588,19 @@ class SmallScalarMemoTable : public MemoTable {
   // (which is also 1 + the largest memo index)
   int32_t size() const override { return static_cast<int32_t>(index_to_value_.size()); }
 
+  // Merge entries from `other_table` into `this`, in the order they appear in
+  // `other_table.index_to_value_`. Returns the first error reported by GetOrInsert.
+  // NOTE(review): every slot of `index_to_value_` is re-inserted as a regular value;
+  // confirm how a null entry (if any) is stored there before merging tables that
+  // may contain nulls.
+  Status MergeTable(const SmallScalarMemoTable& other_table) {
+    for (const Scalar& other_val : other_table.index_to_value_) {
+      int32_t unused;
+      RETURN_NOT_OK(this->GetOrInsert(other_val, &unused));
+    }
+    return Status::OK();
+  }
+
   // Copy values starting from index `start` into `out_data`
   void CopyValues(int32_t start, Scalar* out_data) const {
     DCHECK_GE(start, 0);
@@ -824,6 +857,22 @@ class BinaryMemoTable : public MemoTable {
     };
     return hash_table_.Lookup(h, cmp_func);
   }
+
+ public:
+  // Merge every value of `other_table` (visited from index 0) into this table and
+  // return the first error reported by GetOrInsert.
+  // NOTE(review): confirm whether VisitValues also visits a null entry (if
+  // `other_table` has one), which would then be merged as an ordinary value.
+  Status MergeTable(const BinaryMemoTable& other_table) {
+    // The visitor returns nothing, so record the first failure and skip the rest.
+    Status status;
+    other_table.VisitValues(0, [this, &status](const util::string_view& other_value) {
+      if (!status.ok()) return;
+      int32_t unused;
+      status = this->GetOrInsert(other_value, &unused);
+    });
+    return status;
+  }
 };
 
 template <typename T, typename Enable = void>
diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R
index c2207a1f27..3711b49975 100644
--- a/r/tests/testthat/test-dplyr-summarize.R
+++ b/r/tests/testthat/test-dplyr-summarize.R
@@ -236,6 +236,17 @@ test_that("Group by any/all", {
   )
 })
 
+test_that("n_distinct() with many batches", {
+  tf <- tempfile()
+  on.exit(unlink(tf))
+  # presumably chunk_size = 20 yields several row groups, so the dataset is scanned
+  # as multiple batches and n_distinct() must combine per-batch results
+  write_parquet(dplyr::starwars, tf, chunk_size = 20)
+
+  ds <- open_dataset(tf)
+  expect_equal(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(),
+               ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE)))
+})
+
 test_that("n_distinct() on dataset", {
   # With group_by
   compare_dplyr_binding(

Reply via email to