This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new be12888997 GH-35059: [C++] Fix "hash_count" for run-end encoded inputs
(#35129)
be12888997 is described below
commit be12888997c81b1fb7947f6284be1256edd4d3e4
Author: Felipe Oliveira Carvalho <[email protected]>
AuthorDate: Mon May 1 12:19:17 2023 -0300
GH-35059: [C++] Fix "hash_count" for run-end encoded inputs (#35129)
### Rationale for this change
Fixing a bug.
### What changes are included in this PR?
Changes to the `"hash_count"` kernel implementation to handle REE and union
arrays correctly.
- [x] Generic (potentially slow) implementation
- [x] REE-specialized implementation
### Are these changes tested?
Yes, by modifying the existing unit tests.
* Closes: #35059
Authored-by: Felipe Oliveira Carvalho <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
cpp/src/arrow/acero/hash_aggregate_test.cc | 117 ++++++++++++++++--------
cpp/src/arrow/array/data.cc | 2 +-
cpp/src/arrow/compute/kernels/hash_aggregate.cc | 102 ++++++++++++++++++---
cpp/src/arrow/testing/gtest_util.cc | 19 ++++
cpp/src/arrow/testing/gtest_util.h | 4 +
5 files changed, 192 insertions(+), 52 deletions(-)
diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc
b/cpp/src/arrow/acero/hash_aggregate_test.cc
index 144098e169..02e67927cc 100644
--- a/cpp/src/arrow/acero/hash_aggregate_test.cc
+++ b/cpp/src/arrow/acero/hash_aggregate_test.cc
@@ -1361,46 +1361,87 @@ void SortBy(std::vector<std::string> names, Datum*
aggregated_and_grouped) {
} // namespace
TEST_P(GroupBy, CountOnly) {
- for (bool use_threads : {true, false}) {
- SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
-
- auto table =
- TableFromJSON(schema({field("argument", float64()), field("key",
int64())}), {R"([
- [1.0, 1],
- [null, 1]
- ])",
-
R"([
- [0.0, 2],
- [null, 3],
- [4.0, null],
- [3.25, 1],
- [0.125, 2]
- ])",
-
R"([
- [-0.25, 2],
- [0.75, null],
- [null, 3]
- ])"});
+ const std::vector<std::string> json = {
+ // Test inputs ("argument", "key")
+ R"([[1.0, 1],
+ [null, 1]])",
+ R"([[0.0, 2],
+ [null, 3],
+ [null, 2],
+ [4.0, null],
+ [3.25, 1],
+ [3.25, 1],
+ [0.125, 2]])",
+ R"([[-0.25, 2],
+ [0.75, null],
+ [null, 3]])",
+ };
+ const auto skip_nulls =
std::make_shared<CountOptions>(CountOptions::ONLY_VALID);
+ const auto only_nulls =
std::make_shared<CountOptions>(CountOptions::ONLY_NULL);
+ const auto count_all = std::make_shared<CountOptions>(CountOptions::ALL);
+ const auto possible_count_options =
std::vector<std::shared_ptr<CountOptions>>{
+ nullptr, // default = skip_nulls
+ skip_nulls,
+ only_nulls,
+ count_all,
+ };
+ const auto expected_results = std::vector<std::string>{
+ // Results ("key_0", "hash_count")
+ // nullptr = skip_nulls
+ R"([[1, 3],
+ [2, 3],
+ [3, 0],
+ [null, 2]])",
+ // skip_nulls
+ R"([[1, 3],
+ [2, 3],
+ [3, 0],
+ [null, 2]])",
+ // only_nulls
+ R"([[1, 1],
+ [2, 1],
+ [3, 2],
+ [null, 0]])",
+ // count_all
+ R"([[1, 4],
+ [2, 4],
+ [3, 2],
+ [null, 2]])",
+ };
+ // NOTE: the "key" column (1) does not appear in the possible run-end
+ // encoding transformations because GroupBy kernels do not support run-end
+ // encoded key arrays.
+ for (const auto& re_encode_cols : std::vector<std::vector<int>>{{}, {0}}) {
+ for (bool use_threads : {/*true, */ false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+ for (size_t i = 0; i < possible_count_options.size(); i++) {
+ SCOPED_TRACE(possible_count_options[i] ?
possible_count_options[i]->ToString()
+ : "default");
+ auto table = TableFromJSON(
+ schema({field("argument", float64()), field("key", int64())}),
json);
+
+ auto transformed_table = table;
+ if (!re_encode_cols.empty()) {
+ ASSERT_OK_AND_ASSIGN(transformed_table,
+ RunEndEncodeTableColumns(*table,
re_encode_cols));
+ }
- ASSERT_OK_AND_ASSIGN(
- Datum aggregated_and_grouped,
- GroupByTest({table->GetColumnByName("argument")},
{table->GetColumnByName("key")},
- {
- {"hash_count", nullptr},
- },
- use_threads));
- SortBy({"key_0"}, &aggregated_and_grouped);
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+
GroupByTest({transformed_table->GetColumnByName("argument")},
+
{transformed_table->GetColumnByName("key")},
+ {
+ {"hash_count",
possible_count_options[i]},
+ },
+ use_threads));
+ SortBy({"key_0"}, &aggregated_and_grouped);
- AssertDatumsEqual(
- ArrayFromJSON(struct_({field("key_0", int64()), field("hash_count",
int64())}),
- R"([
- [1, 2],
- [2, 3],
- [3, 0],
- [null, 2]
- ])"),
- aggregated_and_grouped,
- /*verbose=*/true);
+ AssertDatumsEqual(aggregated_and_grouped,
+ ArrayFromJSON(struct_({field("key_0", int64()),
+ field("hash_count",
int64())}),
+ expected_results[i]),
+ /*verbose=*/true);
+ }
+ }
}
}
diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc
index 18c9eb720a..8764e9c354 100644
--- a/cpp/src/arrow/array/data.cc
+++ b/cpp/src/arrow/array/data.cc
@@ -203,7 +203,7 @@ void ArraySpan::SetMembers(const ArrayData& data) {
type_id = ext_type->storage_type()->id();
}
- if (data.buffers[0] == nullptr && type_id != Type::NA &&
+ if ((data.buffers.size() == 0 || data.buffers[0] == nullptr) && type_id !=
Type::NA &&
type_id != Type::SPARSE_UNION && type_id != Type::DENSE_UNION) {
// This should already be zero but we make sure
this->null_count = 0;
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 4242680adb..6ab95bedc6 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -45,6 +45,7 @@
#include "arrow/util/cpu_info.h"
#include "arrow/util/int128_internal.h"
#include "arrow/util/int_util_overflow.h"
+#include "arrow/util/ree_util.h"
#include "arrow/util/task_group.h"
#include "arrow/util/tdigest.h"
#include "arrow/util/thread_pool.h"
@@ -302,6 +303,46 @@ struct GroupedCountImpl : public GroupedAggregator {
return Status::OK();
}
+ template <bool count_valid>
+ struct RunEndEncodedCountImpl {
+ /// Count the number of valid or invalid values in a run-end-encoded array.
+ ///
+ /// \param[in] input the run-end-encoded array
+ /// \param[out] counts the counts being accumulated
+ /// \param[in] g the group ids of the values in the array
+ template <typename RunEndCType>
+ void DoCount(const ArraySpan& input, int64_t* counts, const uint32_t* g) {
+ ree_util::RunEndEncodedArraySpan<RunEndCType> ree_span(input);
+ const auto* physical_validity =
ree_util::ValuesArray(input).GetValues<uint8_t>(0);
+ auto end = ree_span.end();
+ for (auto it = ree_span.begin(); it != end; ++it) {
+ const bool is_valid = bit_util::GetBit(physical_validity,
it.index_into_array());
+ if (is_valid == count_valid) {
+ for (int64_t i = 0; i < it.run_length(); ++i, ++g) {
+ counts[*g] += 1;
+ }
+ } else {
+ g += it.run_length();
+ }
+ }
+ }
+
+ void operator()(const ArraySpan& input, int64_t* counts, const uint32_t*
g) {
+ auto ree_type = checked_cast<const RunEndEncodedType*>(input.type);
+ switch (ree_type->run_end_type()->id()) {
+ case Type::INT16:
+ DoCount<int16_t>(input, counts, g);
+ break;
+ case Type::INT32:
+ DoCount<int32_t>(input, counts, g);
+ break;
+ default:
+ DoCount<int64_t>(input, counts, g);
+ break;
+ }
+ }
+ };
+
Status Consume(const ExecSpan& batch) override {
auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
auto g_begin = batch[1].array.GetValues<uint32_t>(1);
@@ -312,26 +353,61 @@ struct GroupedCountImpl : public GroupedAggregator {
}
} else if (batch[0].is_array()) {
const ArraySpan& input = batch[0].array;
- if (options_.mode == CountOptions::ONLY_VALID) {
+ if (options_.mode == CountOptions::ONLY_VALID) { // ONLY_VALID
if (input.type->id() != arrow::Type::NA) {
- arrow::internal::VisitSetBitRunsVoid(
- input.buffers[0].data, input.offset, input.length,
- [&](int64_t offset, int64_t length) {
- auto g = g_begin + offset;
- for (int64_t i = 0; i < length; ++i, ++g) {
- counts[*g] += 1;
- }
- });
+ const uint8_t* bitmap = input.buffers[0].data;
+ if (bitmap) {
+ arrow::internal::VisitSetBitRunsVoid(
+ bitmap, input.offset, input.length, [&](int64_t offset,
int64_t length) {
+ auto g = g_begin + offset;
+ for (int64_t i = 0; i < length; ++i, ++g) {
+ counts[*g] += 1;
+ }
+ });
+ } else {
+            // Arrays without validity bitmaps require special handling of
nulls.
+ const bool all_valid = !input.MayHaveLogicalNulls();
+ if (all_valid) {
+ for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
+ counts[*g_begin] += 1;
+ }
+ } else {
+ switch (input.type->id()) {
+ case Type::RUN_END_ENCODED:
+ RunEndEncodedCountImpl<true>{}(input, counts, g_begin);
+ break;
+ default: // Generic and forward-compatible version.
+ for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
+ counts[*g_begin] += input.IsValid(i);
+ }
+ break;
+ }
+ }
+ }
}
} else { // ONLY_NULL
if (input.type->id() == arrow::Type::NA) {
for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
counts[*g_begin] += 1;
}
- } else if (input.MayHaveNulls()) {
- auto end = input.offset + input.length;
- for (int64_t i = input.offset; i < end; ++i, ++g_begin) {
- counts[*g_begin] += !bit_util::GetBit(input.buffers[0].data, i);
+ } else if (input.MayHaveLogicalNulls()) {
+ if (input.HasValidityBitmap()) {
+ auto end = input.offset + input.length;
+ for (int64_t i = input.offset; i < end; ++i, ++g_begin) {
+ counts[*g_begin] += !bit_util::GetBit(input.buffers[0].data, i);
+ }
+ } else {
+ // Arrays without validity bitmaps require special handling of
nulls.
+ switch (input.type->id()) {
+ case Type::RUN_END_ENCODED:
+ RunEndEncodedCountImpl<false>{}(input, counts, g_begin);
+ break;
+ default: // Generic and forward-compatible version.
+ for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
+ counts[*g_begin] += input.IsNull(i);
+ }
+ break;
+ }
}
}
}
diff --git a/cpp/src/arrow/testing/gtest_util.cc
b/cpp/src/arrow/testing/gtest_util.cc
index 37c430892d..9569375bda 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -47,6 +47,7 @@
#include "arrow/array.h"
#include "arrow/buffer.h"
+#include "arrow/compute/api_vector.h"
#include "arrow/datum.h"
#include "arrow/ipc/json_simple.h"
#include "arrow/pretty_print.h"
@@ -427,6 +428,24 @@ std::shared_ptr<Table> TableFromJSON(const
std::shared_ptr<Schema>& schema,
return *Table::FromRecordBatches(schema, std::move(batches));
}
+Result<std::shared_ptr<Table>> RunEndEncodeTableColumns(
+ const Table& table, const std::vector<int>& column_indices) {
+ const int num_columns = table.num_columns();
+ std::vector<std::shared_ptr<ChunkedArray>> encoded_columns;
+ encoded_columns.reserve(num_columns);
+ for (int i = 0; i < num_columns; i++) {
+ if (std::find(column_indices.begin(), column_indices.end(), i) !=
+ column_indices.end()) {
+ ARROW_ASSIGN_OR_RAISE(auto run_end_encoded,
compute::RunEndEncode(table.column(i)));
+ DCHECK_EQ(run_end_encoded.kind(), Datum::CHUNKED_ARRAY);
+ encoded_columns.push_back(run_end_encoded.chunked_array());
+ } else {
+ encoded_columns.push_back(table.column(i));
+ }
+ }
+ return Table::Make(table.schema(), std::move(encoded_columns));
+}
+
Result<std::optional<std::string>> PrintArrayDiff(const ChunkedArray& expected,
const ChunkedArray& actual) {
if (actual.Equals(expected)) {
diff --git a/cpp/src/arrow/testing/gtest_util.h
b/cpp/src/arrow/testing/gtest_util.h
index 2708056295..55bd307b12 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -341,6 +341,10 @@ ARROW_TESTING_EXPORT
std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>&,
const std::vector<std::string>& json);
+ARROW_TESTING_EXPORT
+Result<std::shared_ptr<Table>> RunEndEncodeTableColumns(
+ const Table& table, const std::vector<int>& column_indices);
+
// Given an array, return a new identical array except for one validity bit
// set to a new value.
// This is useful to force the underlying "value" of null entries to otherwise