This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new be12888997 GH-35059: [C++] Fix "hash_count" for run-end encoded inputs (#35129)
be12888997 is described below

commit be12888997c81b1fb7947f6284be1256edd4d3e4
Author: Felipe Oliveira Carvalho <[email protected]>
AuthorDate: Mon May 1 12:19:17 2023 -0300

    GH-35059: [C++] Fix "hash_count" for run-end encoded inputs (#35129)
    
    ### Rationale for this change
    
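    Fixing a bug: the `"hash_count"` kernel assumed that the argument's validity
    bitmap is stored in `buffers[0]`. Run-end encoded (REE) and union arrays have
    no top-level validity bitmap, so their logical nulls were not counted correctly.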
    
    ### What changes are included in this PR?
    
    Changes to the `"hash_count"` kernel implementation to handle run-end encoded
    (REE) and union arrays correctly.
    
    - [x] Generic (potentially slow) implementation
    - [x] REE-specialized implementation
    
    ### Are these changes tested?
    
    Yes, by modifying the existing unit tests.
    
    * Closes: #35059
    
    Authored-by: Felipe Oliveira Carvalho <[email protected]>
    Signed-off-by: Weston Pace <[email protected]>
---
 cpp/src/arrow/acero/hash_aggregate_test.cc      | 117 ++++++++++++++++--------
 cpp/src/arrow/array/data.cc                     |   2 +-
 cpp/src/arrow/compute/kernels/hash_aggregate.cc | 102 ++++++++++++++++++---
 cpp/src/arrow/testing/gtest_util.cc             |  19 ++++
 cpp/src/arrow/testing/gtest_util.h              |   4 +
 5 files changed, 192 insertions(+), 52 deletions(-)

diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc b/cpp/src/arrow/acero/hash_aggregate_test.cc
index 144098e169..02e67927cc 100644
--- a/cpp/src/arrow/acero/hash_aggregate_test.cc
+++ b/cpp/src/arrow/acero/hash_aggregate_test.cc
@@ -1361,46 +1361,87 @@ void SortBy(std::vector<std::string> names, Datum* aggregated_and_grouped) {
 }  // namespace
 
 TEST_P(GroupBy, CountOnly) {
-  for (bool use_threads : {true, false}) {
-    SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
-
-    auto table =
-        TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
-    [1.0,   1],
-    [null,  1]
-                        ])",
-                                                                                      R"([
-    [0.0,   2],
-    [null,  3],
-    [4.0,   null],
-    [3.25,  1],
-    [0.125, 2]
-                        ])",
-                                                                                      R"([
-    [-0.25, 2],
-    [0.75,  null],
-    [null,  3]
-                        ])"});
+  // Checks "hash_count" under every CountOptions mode, for a plain and a run-end
+  // encoded "argument" column, on both the serial and the parallel/merged path.
+  const std::vector<std::string> json = {
+      // Test inputs ("argument", "key")
+      R"([[1.0,   1],
+          [null,  1]])",
+      R"([[0.0,   2],
+          [null,  3],
+          [null,  2],
+          [4.0,   null],
+          [3.25,  1],
+          [3.25,  1],
+          [0.125, 2]])",
+      R"([[-0.25, 2],
+          [0.75,  null],
+          [null,  3]])",
+  };
+  const auto skip_nulls = std::make_shared<CountOptions>(CountOptions::ONLY_VALID);
+  const auto only_nulls = std::make_shared<CountOptions>(CountOptions::ONLY_NULL);
+  const auto count_all = std::make_shared<CountOptions>(CountOptions::ALL);
+  const auto possible_count_options = std::vector<std::shared_ptr<CountOptions>>{
+      nullptr,  // default = skip_nulls
+      skip_nulls,
+      only_nulls,
+      count_all,
+  };
+  // expected_results[i] is the expected output for possible_count_options[i].
+  const auto expected_results = std::vector<std::string>{
+      // Results ("key_0", "hash_count")
+      // nullptr = skip_nulls
+      R"([[1, 3],
+          [2, 3],
+          [3, 0],
+          [null, 2]])",
+      // skip_nulls
+      R"([[1, 3],
+          [2, 3],
+          [3, 0],
+          [null, 2]])",
+      // only_nulls
+      R"([[1, 1],
+          [2, 1],
+          [3, 2],
+          [null, 0]])",
+      // count_all
+      R"([[1, 4],
+          [2, 4],
+          [3, 2],
+          [null, 2]])",
+  };
+  // The input table does not depend on any loop variable: build it once.
+  const auto table = TableFromJSON(
+      schema({field("argument", float64()), field("key", int64())}), json);
+  // NOTE: the "key" column (1) does not appear in the possible run-end
+  // encoding transformations because GroupBy kernels do not support run-end
+  // encoded key arrays.
+  for (const auto& re_encode_cols : std::vector<std::vector<int>>{{}, {0}}) {
+    SCOPED_TRACE(re_encode_cols.empty() ? "argument: plain" : "argument: run-end encoded");
+    auto transformed_table = table;
+    if (!re_encode_cols.empty()) {
+      ASSERT_OK_AND_ASSIGN(transformed_table,
+                           RunEndEncodeTableColumns(*table, re_encode_cols));
+    }
+    // Exercise both the serial path and the parallel path (partial aggregates
+    // merged); SortBy below makes the comparison independent of output order.
+    for (bool use_threads : {true, false}) {
+      SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+      for (size_t i = 0; i < possible_count_options.size(); i++) {
+        SCOPED_TRACE(possible_count_options[i] ? possible_count_options[i]->ToString()
+                                               : "default");
 
-    ASSERT_OK_AND_ASSIGN(
-        Datum aggregated_and_grouped,
-        GroupByTest({table->GetColumnByName("argument")}, {table->GetColumnByName("key")},
-                    {
-                        {"hash_count", nullptr},
-                    },
-                    use_threads));
-    SortBy({"key_0"}, &aggregated_and_grouped);
+        ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                             GroupByTest({transformed_table->GetColumnByName("argument")},
+                                         {transformed_table->GetColumnByName("key")},
+                                         {
+                                             {"hash_count", possible_count_options[i]},
+                                         },
+                                         use_threads));
+        SortBy({"key_0"}, &aggregated_and_grouped);
 
-    AssertDatumsEqual(
-        ArrayFromJSON(struct_({field("key_0", int64()), field("hash_count", int64())}),
-                      R"([
-    [1, 2],
-    [2, 3],
-    [3, 0],
-    [null, 2]
-  ])"),
-        aggregated_and_grouped,
-        /*verbose=*/true);
+        // AssertDatumsEqual takes (expected, actual); keep that order so failure
+        // messages label the two sides correctly.
+        AssertDatumsEqual(ArrayFromJSON(struct_({field("key_0", int64()),
+                                                 field("hash_count", int64())}),
+                                        expected_results[i]),
+                          aggregated_and_grouped,
+                          /*verbose=*/true);
+      }
+    }
   }
 }
 
diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc
index 18c9eb720a..8764e9c354 100644
--- a/cpp/src/arrow/array/data.cc
+++ b/cpp/src/arrow/array/data.cc
@@ -203,7 +203,7 @@ void ArraySpan::SetMembers(const ArrayData& data) {
     type_id = ext_type->storage_type()->id();
   }
 
-  if (data.buffers[0] == nullptr && type_id != Type::NA &&
+  if ((data.buffers.size() == 0 || data.buffers[0] == nullptr) && type_id != 
Type::NA &&
       type_id != Type::SPARSE_UNION && type_id != Type::DENSE_UNION) {
     // This should already be zero but we make for sure
     this->null_count = 0;
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 4242680adb..6ab95bedc6 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -45,6 +45,7 @@
 #include "arrow/util/cpu_info.h"
 #include "arrow/util/int128_internal.h"
 #include "arrow/util/int_util_overflow.h"
+#include "arrow/util/ree_util.h"
 #include "arrow/util/task_group.h"
 #include "arrow/util/tdigest.h"
 #include "arrow/util/thread_pool.h"
@@ -302,6 +303,46 @@ struct GroupedCountImpl : public GroupedAggregator {
     return Status::OK();
   }
 
+  template <bool count_valid>
+  struct RunEndEncodedCountImpl {
+    /// Count the number of valid (count_valid == true) or null
+    /// (count_valid == false) values in a run-end-encoded array.
+    ///
+    /// Work is done once per run: the validity of the run's single physical
+    /// value decides whether every logical slot of that run is counted.
+    ///
+    /// \param[in] input the run-end-encoded array
+    /// \param[out] counts the counts being accumulated, indexed by group id
+    /// \param[in] g the group ids of the values in the array, one per logical
+    ///              slot of `input`
+    template <typename RunEndCType>
+    void DoCount(const ArraySpan& input, int64_t* counts, const uint32_t* g) {
+      ree_util::RunEndEncodedArraySpan<RunEndCType> ree_span(input);
+      const ArraySpan& values = ree_util::ValuesArray(input);
+      auto end = ree_span.end();
+      for (auto it = ree_span.begin(); it != end; ++it) {
+        const int64_t run_length = it.run_length();
+        // index_into_array() is relative to the start of the values child, so the
+        // child's own offset must still be applied, in *bits*, when reading its
+        // validity bitmap. ArraySpan::IsValid() applies it; GetValues<uint8_t>(0)
+        // would instead advance the bitmap pointer by `offset` *bytes*.
+        const bool is_valid = values.IsValid(it.index_into_array());
+        if (is_valid == count_valid) {
+          for (int64_t i = 0; i < run_length; ++i, ++g) {
+            counts[*g] += 1;
+          }
+        } else {
+          g += run_length;
+        }
+      }
+    }
+
+    /// Dispatch DoCount on the run-end index type of `input`.
+    void operator()(const ArraySpan& input, int64_t* counts, const uint32_t* g) {
+      auto ree_type = checked_cast<const RunEndEncodedType*>(input.type);
+      switch (ree_type->run_end_type()->id()) {
+        case Type::INT16:
+          DoCount<int16_t>(input, counts, g);
+          break;
+        case Type::INT32:
+          DoCount<int32_t>(input, counts, g);
+          break;
+        default:  // run ends are int16, int32 or int64; int64 is what remains
+          DoCount<int64_t>(input, counts, g);
+          break;
+      }
+    }
+  };
+
   Status Consume(const ExecSpan& batch) override {
     auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
     auto g_begin = batch[1].array.GetValues<uint32_t>(1);
@@ -312,26 +353,61 @@ struct GroupedCountImpl : public GroupedAggregator {
       }
     } else if (batch[0].is_array()) {
       const ArraySpan& input = batch[0].array;
-      if (options_.mode == CountOptions::ONLY_VALID) {
+      if (options_.mode == CountOptions::ONLY_VALID) {  // ONLY_VALID
         if (input.type->id() != arrow::Type::NA) {
-          arrow::internal::VisitSetBitRunsVoid(
-              input.buffers[0].data, input.offset, input.length,
-              [&](int64_t offset, int64_t length) {
-                auto g = g_begin + offset;
-                for (int64_t i = 0; i < length; ++i, ++g) {
-                  counts[*g] += 1;
-                }
-              });
+          const uint8_t* bitmap = input.buffers[0].data;
+          if (bitmap) {
+            arrow::internal::VisitSetBitRunsVoid(
+                bitmap, input.offset, input.length, [&](int64_t offset, 
int64_t length) {
+                  auto g = g_begin + offset;
+                  for (int64_t i = 0; i < length; ++i, ++g) {
+                    counts[*g] += 1;
+                  }
+                });
+          } else {
+            // Array without validity bitmaps require special handling of 
nulls.
+            const bool all_valid = !input.MayHaveLogicalNulls();
+            if (all_valid) {
+              for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
+                counts[*g_begin] += 1;
+              }
+            } else {
+              switch (input.type->id()) {
+                case Type::RUN_END_ENCODED:
+                  RunEndEncodedCountImpl<true>{}(input, counts, g_begin);
+                  break;
+                default:  // Generic and forward-compatible version.
+                  for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
+                    counts[*g_begin] += input.IsValid(i);
+                  }
+                  break;
+              }
+            }
+          }
         }
       } else {  // ONLY_NULL
         if (input.type->id() == arrow::Type::NA) {
           for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
             counts[*g_begin] += 1;
           }
-        } else if (input.MayHaveNulls()) {
-          auto end = input.offset + input.length;
-          for (int64_t i = input.offset; i < end; ++i, ++g_begin) {
-            counts[*g_begin] += !bit_util::GetBit(input.buffers[0].data, i);
+        } else if (input.MayHaveLogicalNulls()) {
+          if (input.HasValidityBitmap()) {
+            auto end = input.offset + input.length;
+            for (int64_t i = input.offset; i < end; ++i, ++g_begin) {
+              counts[*g_begin] += !bit_util::GetBit(input.buffers[0].data, i);
+            }
+          } else {
+            // Arrays without validity bitmaps require special handling of 
nulls.
+            switch (input.type->id()) {
+              case Type::RUN_END_ENCODED:
+                RunEndEncodedCountImpl<false>{}(input, counts, g_begin);
+                break;
+              default:  // Generic and forward-compatible version.
+                for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
+                  counts[*g_begin] += input.IsNull(i);
+                }
+                break;
+            }
           }
         }
       }
diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index 37c430892d..9569375bda 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -47,6 +47,7 @@
 
 #include "arrow/array.h"
 #include "arrow/buffer.h"
+#include "arrow/compute/api_vector.h"
 #include "arrow/datum.h"
 #include "arrow/ipc/json_simple.h"
 #include "arrow/pretty_print.h"
@@ -427,6 +428,24 @@ std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>& schema,
   return *Table::FromRecordBatches(schema, std::move(batches));
 }
 
+/// Return a copy of `table` in which the columns whose indices appear in
+/// `column_indices` are run-end encoded with compute::RunEndEncode; every other
+/// column is passed through unchanged. The schema fields of re-encoded columns
+/// are given the new column type. Fails if RunEndEncode fails.
+Result<std::shared_ptr<Table>> RunEndEncodeTableColumns(
+    const Table& table, const std::vector<int>& column_indices) {
+  const std::shared_ptr<Schema>& input_schema = table.schema();
+  const int num_columns = table.num_columns();
+  std::vector<std::shared_ptr<ChunkedArray>> encoded_columns;
+  encoded_columns.reserve(num_columns);
+  std::vector<std::shared_ptr<Field>> encoded_fields;
+  encoded_fields.reserve(num_columns);
+  for (int i = 0; i < num_columns; i++) {
+    const std::shared_ptr<Field>& input_field = input_schema->field(i);
+    if (std::find(column_indices.begin(), column_indices.end(), i) !=
+        column_indices.end()) {
+      ARROW_ASSIGN_OR_RAISE(auto run_end_encoded, compute::RunEndEncode(table.column(i)));
+      DCHECK_EQ(run_end_encoded.kind(), Datum::CHUNKED_ARRAY);
+      encoded_columns.push_back(run_end_encoded.chunked_array());
+      // The column now has a run-end encoded type, so its field must change too:
+      // Table::Make does not validate, and reusing the input schema would yield
+      // a table whose schema disagrees with its columns.
+      encoded_fields.push_back(input_field->WithType(encoded_columns.back()->type()));
+    } else {
+      encoded_columns.push_back(table.column(i));
+      encoded_fields.push_back(input_field);
+    }
+  }
+  auto encoded_schema =
+      arrow::schema(std::move(encoded_fields), input_schema->metadata());
+  return Table::Make(std::move(encoded_schema), std::move(encoded_columns));
+}
+
 Result<std::optional<std::string>> PrintArrayDiff(const ChunkedArray& expected,
                                                   const ChunkedArray& actual) {
   if (actual.Equals(expected)) {
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index 2708056295..55bd307b12 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -341,6 +341,10 @@ ARROW_TESTING_EXPORT
 std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>&,
                                      const std::vector<std::string>& json);
 
+/// \brief Run-end encode selected columns of a table (testing helper).
+///
+/// Columns whose indices appear in `column_indices` are replaced by the output
+/// of compute::RunEndEncode; all other columns are kept as-is. Returns an error
+/// status if encoding any selected column fails.
+ARROW_TESTING_EXPORT
+Result<std::shared_ptr<Table>> RunEndEncodeTableColumns(
+    const Table& table, const std::vector<int>& column_indices);
+
 // Given an array, return a new identical array except for one validity bit
 // set to a new value.
 // This is useful to force the underlying "value" of null entries to otherwise

Reply via email to