This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new af4db7731b ARROW-16807: [C++][R] count distinct incorrectly merges
state (#13583)
af4db7731b is described below
commit af4db7731b1f857e78221c53c2d8221849b1eeec
Author: octalene <[email protected]>
AuthorDate: Sat Jul 16 14:45:27 2022 -0700
ARROW-16807: [C++][R] count distinct incorrectly merges state (#13583)
This addresses a bug where the `count_distinct` function simply added
counts when merging state. The correct logic is to return the number of
distinct elements after both states have been merged.
State for count_distinct is backed by a MemoTable, which is then backed by
a HashTable. To properly merge state, this PR adds 2 functions to each
MemoTable: `MaybeInsert` and `MergeTable`. `MaybeInsert` provides
simplified logic for inserting a single element into the MemoTable, while
`MergeTable` iterates over the elements of the MemoTable _to be merged_ and
inserts each one into the destination table.
This PR also adds an R test and a C++ test. The R test mirrors the
reproduction provided in ARROW-16807. The C++ test, `AllChunkedArrayTypesWithNulls`, mirrors
another C++ test, `AllArrayTypesWithNulls`, but uses chunked arrays for test
data.
Lead-authored-by: Aldrin Montana <[email protected]>
Co-authored-by: Aldrin M <[email protected]>
Co-authored-by: Wes McKinney <[email protected]>
Signed-off-by: Wes McKinney <[email protected]>
---
cpp/src/arrow/compute/kernels/aggregate_basic.cc | 17 ++++--
cpp/src/arrow/compute/kernels/aggregate_test.cc | 72 ++++++++++++++++++++++++
cpp/src/arrow/compute/kernels/codegen_internal.h | 2 +-
cpp/src/arrow/util/hashing.h | 32 +++++++++++
r/tests/testthat/test-dplyr-summarize.R | 9 +++
5 files changed, 126 insertions(+), 6 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index 57cee87f00..fec483318e 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -136,27 +136,34 @@ struct CountDistinctImpl : public ScalarAggregator {
Status Consume(KernelContext*, const ExecBatch& batch) override {
if (batch[0].is_array()) {
const ArrayData& arr = *batch[0].array();
+ this->has_nulls = arr.GetNullCount() > 0;
+
auto visit_null = []() { return Status::OK(); };
auto visit_value = [&](VisitorArgType arg) {
- int y;
+ int32_t y;
return memo_table_->GetOrInsert(arg, &y);
};
RETURN_NOT_OK(VisitArraySpanInline<Type>(arr, visit_value, visit_null));
- this->non_nulls += memo_table_->size();
- this->has_nulls = arr.GetNullCount() > 0;
+
} else {
const Scalar& input = *batch[0].scalar();
this->has_nulls = !input.is_valid;
+
if (input.is_valid) {
- this->non_nulls += batch.length;
+ int32_t unused;
+
RETURN_NOT_OK(memo_table_->GetOrInsert(UnboxScalar<Type>::Unbox(input),
&unused));
}
}
+
+ this->non_nulls = memo_table_->size();
+
return Status::OK();
}
Status MergeFrom(KernelContext*, KernelState&& src) override {
const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
- this->non_nulls += other_state.non_nulls;
+ RETURN_NOT_OK(this->memo_table_->MergeTable(*(other_state.memo_table_)));
+ this->non_nulls = this->memo_table_->size();
this->has_nulls = this->has_nulls || other_state.has_nulls;
return Status::OK();
}
diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc
b/cpp/src/arrow/compute/kernels/aggregate_test.cc
index aa54fe5f3e..abd5b5210a 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -962,11 +962,83 @@ class TestCountDistinctKernel : public ::testing::Test {
EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one);
}
+ void CheckChunkedArr(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& json, int64_t
expected_all,
+ bool has_nulls = true) {
+ Check(ChunkedArrayFromJSON(type, json), expected_all, has_nulls);
+ }
+
CountOptions only_valid{CountOptions::ONLY_VALID};
CountOptions only_null{CountOptions::ONLY_NULL};
CountOptions all{CountOptions::ALL};
};
+TEST_F(TestCountDistinctKernel, AllChunkedArrayTypesWithNulls) {
+ // Boolean
+ CheckChunkedArr(boolean(), {"[]", "[]"}, 0, /*has_nulls=*/false);
+ CheckChunkedArr(boolean(), {"[true, null]", "[false, null, false]",
"[true]"}, 3);
+
+ // Number
+ for (auto ty : NumericTypes()) {
+ CheckChunkedArr(ty, {"[1, 1, null, 2]", "[5, 8, 9, 9, null, 10]", "[6, 6,
8, 9, 10]"},
+ 8);
+ CheckChunkedArr(ty, {"[1, 1, 8, 2]", "[5, 8, 9, 9, 10]", "[10, 6, 6]"}, 7,
+ /*has_nulls=*/false);
+ }
+
+ // Date
+ CheckChunkedArr(date32(), {"[0, 11016]", "[0, null, 14241, 14241, null]"},
4);
+ CheckChunkedArr(date64(), {"[0, null]", "[0, null, 0, 0, 1262217600000]"},
3);
+
+ // Time
+ CheckChunkedArr(time32(TimeUnit::SECOND), {"[ 0, 11, 0, null]", "[14, 14,
null]"}, 4);
+ CheckChunkedArr(time32(TimeUnit::MILLI), {"[ 0, 11000, 0]", "[null, 11000,
11000]"}, 3);
+
+ CheckChunkedArr(time64(TimeUnit::MICRO), {"[84203999999, 0, null,
84203999999]", "[0]"},
+ 3);
+ CheckChunkedArr(time64(TimeUnit::NANO),
+ {"[11715003000000, 0, null, 0, 0]", "[0, 0, null]"}, 3);
+
+ // Timestamp & Duration
+ for (auto u : TimeUnit::values()) {
+ CheckChunkedArr(duration(u), {"[123456789, null, 987654321]", "[123456789,
null]"},
+ 3);
+
+ CheckChunkedArr(duration(u),
+ {"[123456789, 987654321, 123456789, 123456789]",
"[123456789]"}, 2,
+ /*has_nulls=*/false);
+
+ auto ts =
+ std::vector<std::string>{R"(["2009-12-31T04:20:20",
"2009-12-31T04:20:20"])",
+ R"(["2020-01-01", null])", R"(["2020-01-01",
null])"};
+ CheckChunkedArr(timestamp(u), ts, 3);
+ CheckChunkedArr(timestamp(u, "Pacific/Marquesas"), ts, 3);
+ }
+
+ // Interval
+ CheckChunkedArr(month_interval(), {"[9012, 5678, null, 9012]", "[5678, null,
9012]"},
+ 3);
+ CheckChunkedArr(day_time_interval(),
+ {"[[0, 1], [0, 1]]", "[null, [0, 1], [1234, 5678]]"}, 3);
+ CheckChunkedArr(month_day_nano_interval(),
+ {"[[0, 1, 2]]", "[[0, 1, 2], null, [0, 1, 2]]"}, 2);
+
+ // Binary & String & Fixed binary
+ auto samples = std::vector<std::string>{
+ R"([null, "abc", null])", R"(["abc", "abc", "cba"])", R"(["bca", "cba",
null])"};
+
+ CheckChunkedArr(binary(), samples, 4);
+ CheckChunkedArr(large_binary(), samples, 4);
+ CheckChunkedArr(utf8(), samples, 4);
+ CheckChunkedArr(large_utf8(), samples, 4);
+ CheckChunkedArr(fixed_size_binary(3), samples, 4);
+
+ // Decimal
+ samples = {R"(["12345.679", "98765.421"])", R"([null, "12345.679",
"98765.421"])"};
+ CheckChunkedArr(decimal128(21, 3), samples, 3);
+ CheckChunkedArr(decimal256(13, 3), samples, 3);
+}
+
TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) {
// Boolean
Check(boolean(), "[]", 0, /*has_nulls=*/false);
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h
b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 1d5f5dd9bd..f008314e8b 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -343,7 +343,7 @@ struct UnboxScalar<Type, enable_if_has_string_view<Type>> {
using T = util::string_view;
static T Unbox(const Scalar& val) {
if (!val.is_valid) return util::string_view();
- return util::string_view(*checked_cast<const
BaseBinaryScalar&>(val).value);
+ return checked_cast<const
::arrow::internal::PrimitiveScalarBase&>(val).view();
}
};
diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h
index d2c0178b00..ca5a6c766b 100644
--- a/cpp/src/arrow/util/hashing.h
+++ b/cpp/src/arrow/util/hashing.h
@@ -485,6 +485,20 @@ class ScalarMemoTable : public MemoTable {
hash_t ComputeHash(const Scalar& value) const {
return ScalarHelper<Scalar, 0>::ComputeHash(value);
}
+
+ public:
+ // defined here so that `HashTableType` is visible
+ // Merge entries from `other_table` into `this->hash_table_`.
+ Status MergeTable(const ScalarMemoTable& other_table) {
+ const HashTableType& other_hashtable = other_table.hash_table_;
+
+ other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) {
+ int32_t unused;
+ DCHECK_OK(this->GetOrInsert(other_entry->payload.value, &unused));
+ });
+ // TODO: ARROW-17074 - implement proper error handling
+ return Status::OK();
+ }
};
// ----------------------------------------------------------------------
@@ -568,6 +582,15 @@ class SmallScalarMemoTable : public MemoTable {
// (which is also 1 + the largest memo index)
int32_t size() const override { return
static_cast<int32_t>(index_to_value_.size()); }
+ // Merge entries from `other_table` into `this`.
+ Status MergeTable(const SmallScalarMemoTable& other_table) {
+ for (const Scalar& other_val : other_table.index_to_value_) {
+ int32_t unused;
+ RETURN_NOT_OK(this->GetOrInsert(other_val, &unused));
+ }
+ return Status::OK();
+ }
+
// Copy values starting from index `start` into `out_data`
void CopyValues(int32_t start, Scalar* out_data) const {
DCHECK_GE(start, 0);
@@ -824,6 +847,15 @@ class BinaryMemoTable : public MemoTable {
};
return hash_table_.Lookup(h, cmp_func);
}
+
+ public:
+ Status MergeTable(const BinaryMemoTable& other_table) {
+ other_table.VisitValues(0, [this](const util::string_view& other_value) {
+ int32_t unused;
+ DCHECK_OK(this->GetOrInsert(other_value, &unused));
+ });
+ return Status::OK();
+ }
};
template <typename T, typename Enable = void>
diff --git a/r/tests/testthat/test-dplyr-summarize.R
b/r/tests/testthat/test-dplyr-summarize.R
index c2207a1f27..3711b49975 100644
--- a/r/tests/testthat/test-dplyr-summarize.R
+++ b/r/tests/testthat/test-dplyr-summarize.R
@@ -236,6 +236,15 @@ test_that("Group by any/all", {
)
})
+test_that("n_distinct() with many batches", {
+ tf <- tempfile()
+ write_parquet(dplyr::starwars, tf, chunk_size = 20)
+
+ ds <- open_dataset(tf)
+ expect_equal(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(),
+ ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE)))
+})
+
test_that("n_distinct() on dataset", {
# With group_by
compare_dplyr_binding(