This is an automated email from the ASF dual-hosted git repository.
zanmato1984 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 70113a62ff GH-32381: [C++] Improve error handling for hash table merges (#49512)
70113a62ff is described below
commit 70113a62ff5fd5bc9196ffaa3df2a3c9d8e0551e
Author: Kristofer Gaudel <[email protected]>
AuthorDate: Thu May 14 07:45:15 2026 -0700
GH-32381: [C++] Improve error handling for hash table merges (#49512)
### Rationale for this change
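Fixes #32381. `ScalarMemoTable::MergeTable` and `BinaryMemoTable::MergeTable` checked the result of each `GetOrInsert` call only with `ARROW_DCHECK_OK` and then unconditionally returned `Status::OK()`, so a failed insertion (such as an out-of-memory error) was not returned to the caller.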
### What changes are included in this PR?
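The `MergeTable` methods in `util/hashing.h` now propagate the first error returned by `GetOrInsert` instead of discarding it, and `util/hashing_test.cc` gains unit tests covering that error path for both `ScalarMemoTable` and `BinaryMemoTable`.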
### Are these changes tested?
Yes, via the included unit tests.
### Are there any user-facing changes?
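Yes. `ScalarMemoTable::MergeTable` and `BinaryMemoTable::MergeTable` now return an error status (e.g. `OutOfMemory`) when inserting a value fails, instead of always returning `Status::OK()`.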
* GitHub Issue: #32381
Authored-by: kris-gaudel <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
---
cpp/src/arrow/util/hashing.h | 34 ++++++++++++++---------
cpp/src/arrow/util/hashing_test.cc | 56 ++++++++++++++++++++++++++++++++++++++
2 files changed, 77 insertions(+), 13 deletions(-)
diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h
index 53f92c8f23..c29b6c211d 100644
--- a/cpp/src/arrow/util/hashing.h
+++ b/cpp/src/arrow/util/hashing.h
@@ -286,15 +286,32 @@ class HashTable {
uint64_t size() const { return size_; }
// Visit all non-empty entries in the table
- // The visit_func should have signature void(const Entry*)
+ // The visit_func may have signature void(const Entry*) or
+ // Status(const Entry*). A void visitor is called on every non-empty entry
+ // and VisitEntries returns void. A Status visitor stops the traversal at
+ // the first non-OK Status, which is returned; otherwise Status::OK() is
+ // returned.
template <typename VisitFunc>
- void VisitEntries(VisitFunc&& visit_func) const {
+ auto VisitEntries(VisitFunc&& visit_func) const {
+ // Dispatch on the visitor's return type so infallible void visitors keep
+ // compiling unchanged while fallible visitors can propagate errors.
+ constexpr bool kVisitorReturnsStatus =
+ !std::is_void_v<std::invoke_result_t<VisitFunc&, const Entry*>>;
for (uint64_t i = 0; i < capacity_; i++) {
const auto& entry = entries_[i];
if (entry) {
- visit_func(&entry);
+ if constexpr (kVisitorReturnsStatus) {
+ RETURN_NOT_OK(visit_func(&entry));
+ } else {
+ visit_func(&entry);
+ }
}
}
+ if constexpr (kVisitorReturnsStatus) {
+ return Status::OK();
+ } else {
+ return;
+ }
}
protected:
@@ -494,12 +495,13 @@ class ScalarMemoTable : public MemoTable {
// So that both uint16_t and Float16 are allowed
static_assert(sizeof(Value) == sizeof(Scalar));
Scalar* out = reinterpret_cast<Scalar*>(out_data);
- hash_table_.VisitEntries([=](const HashTableEntry* entry) {
+ // This visitor only copies values out and always returns Status::OK(), so
+ // the resulting Status is checked with ARROW_DCHECK_OK rather than returned.
+ ARROW_DCHECK_OK(hash_table_.VisitEntries([=](const HashTableEntry* entry) {
int32_t index = entry->payload.memo_index - start;
if (index >= 0) {
out[index] = entry->payload.value;
}
- });
+ return Status::OK();
+ }));
// Zero-initialize the null entry
if (null_index_ != kKeyNotFound) {
int32_t index = null_index_ - start;
@@ -534,13 +536,12 @@ class ScalarMemoTable : public MemoTable {
// Merge entries from `other_table` into `this->hash_table_`.
+ // Returns the first error reported by GetOrInsert (e.g. OutOfMemory) and
+ // stops merging there; values inserted before the failure stay in this table.
Status MergeTable(const ScalarMemoTable& other_table) {
const HashTableType& other_hashtable = other_table.hash_table_;
-
- other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) {
+ return other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) {
int32_t unused;
- ARROW_DCHECK_OK(this->GetOrInsert(other_entry->payload.value, &unused));
+ return this->GetOrInsert(other_entry->payload.value, &unused);
});
- // TODO: ARROW-17074 - implement proper error handling
- return Status::OK();
}
};
@@ -867,6 +866,21 @@ class BinaryMemoTable : public MemoTable {
}
}
+ // Fallible counterpart of VisitValues: calls `visit` on each stored value
+ // from index `start` onward and stops at the first non-OK Status it returns,
+ // which is passed back to the caller; returns Status::OK() otherwise.
+ // `visit` should have signature `Status(std::string_view)` or
+ // `Status(const std::string_view&)`.
+ template <typename VisitFunc>
+ Status VisitValuesStatus(int32_t start, VisitFunc&& visit) const {
+ int32_t index = start;
+ while (index < size()) {
+ RETURN_NOT_OK(visit(binary_builder_.GetView(index)));
+ ++index;
+ }
+ return Status::OK();
+ }
+
// Visit the stored value at a specific index in insertion order.
// The visitor function should have the signature `void(std::string_view)`
// or `void(const std::string_view&)`.
@@ -899,11 +908,13 @@ class BinaryMemoTable : public MemoTable {
public:
+ // Merge the values of `other_table`, visited from index 0, into this table.
+ // Returns the first error reported by GetOrInsert (e.g. OutOfMemory) and
+ // stops merging there; values inserted before the failure stay in this table.
Status MergeTable(const BinaryMemoTable& other_table) {
- other_table.VisitValues(0, [this](std::string_view other_value) {
+ return other_table.VisitValuesStatus(0, [this](std::string_view other_value) {
int32_t unused;
- ARROW_DCHECK_OK(this->GetOrInsert(other_value, &unused));
+ return this->GetOrInsert(other_value, &unused);
});
- return Status::OK();
}
};
diff --git a/cpp/src/arrow/util/hashing_test.cc b/cpp/src/arrow/util/hashing_test.cc
index f6ada0acd2..6e4c59a1eb 100644
--- a/cpp/src/arrow/util/hashing_test.cc
+++ b/cpp/src/arrow/util/hashing_test.cc
@@ -30,6 +30,7 @@
#include "arrow/array/builder_primitive.h"
#include "arrow/array/concatenate.h"
+#include "arrow/memory_pool.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/hashing.h"
@@ -376,6 +377,40 @@ TEST(ScalarMemoTable, StressInt64) {
ASSERT_EQ(table.size(), map.size());
}
+TEST(ScalarMemoTable, MergeTablePropagatesInsertError) {
+ // The target table is preloaded with the values [0, kNumInitialValues);
+ // the source table holds only kNumInitialValues itself.
+ constexpr int64_t kNumInitialValues = 15;
+ auto insert_initial_values = [&](ScalarMemoTable<int64_t>* table) {
+ for (int64_t value = 0; value < kNumInitialValues; ++value) {
+ AssertGetOrInsert(*table, value, static_cast<int32_t>(value));
+ }
+ };
+
+ // Measure the bytes allocated by a table holding the initial values.
+ int64_t bytes_allocated_limit = 0;
+ {
+ ProxyMemoryPool probe(default_memory_pool());
+ ScalarMemoTable<int64_t> target(&probe, 0);
+ insert_initial_values(&target);
+ bytes_allocated_limit = probe.bytes_allocated();
+ }
+ ASSERT_GT(bytes_allocated_limit, 0);
+
+ ScalarMemoTable<int64_t> source(default_memory_pool(), 0);
+ AssertGetOrInsert(source, kNumInitialValues, 0);
+
+ // Rebuild the same target under a pool capped at exactly that many bytes,
+ // so the extra allocation needed to merge `source` is expected to fail.
+ ProxyMemoryPool proxy(default_memory_pool());
+ CappedMemoryPool pool(&proxy, bytes_allocated_limit);
+ ScalarMemoTable<int64_t> target(&pool, 0);
+ insert_initial_values(&target);
+ ASSERT_EQ(proxy.bytes_allocated(), bytes_allocated_limit);
+
+ ASSERT_RAISES(OutOfMemory, target.MergeTable(source));
+}
+
TEST(BinaryMemoTable, Basics) {
std::string A = "", B = "a", C = "foo", D = "bar", E, F;
E += '\0';
@@ -480,6 +507,41 @@ TEST(BinaryMemoTable, Stress) {
ASSERT_EQ(table.size(), map.size());
}
+TEST(BinaryMemoTable, MergeTablePropagatesInsertError) {
+ const std::vector<std::string> initial_values = {"a", "bb", "ccc", "dddd"};
+ // A single value much larger than all initial values combined.
+ const std::string extra_value(4096, 'x');
+ auto insert_initial_values = [&](BinaryMemoTable<BinaryBuilder>* table) {
+ int32_t expected_index = 0;
+ for (const auto& value : initial_values) {
+ AssertGetOrInsert(*table, value, expected_index++);
+ }
+ };
+
+ // Measure the bytes allocated by a table holding the initial values.
+ int64_t bytes_allocated_limit = 0;
+ {
+ ProxyMemoryPool probe(default_memory_pool());
+ BinaryMemoTable<BinaryBuilder> target(&probe, 0);
+ insert_initial_values(&target);
+ bytes_allocated_limit = probe.bytes_allocated();
+ }
+ ASSERT_GT(bytes_allocated_limit, 0);
+
+ BinaryMemoTable<BinaryBuilder> source(default_memory_pool(), 0);
+ AssertGetOrInsert(source, extra_value, 0);
+
+ // Rebuild the same target under a pool capped at exactly that many bytes,
+ // so the extra allocation needed to merge `extra_value` is expected to fail.
+ ProxyMemoryPool proxy(default_memory_pool());
+ CappedMemoryPool pool(&proxy, bytes_allocated_limit);
+ BinaryMemoTable<BinaryBuilder> target(&pool, 0);
+ insert_initial_values(&target);
+ ASSERT_EQ(proxy.bytes_allocated(), bytes_allocated_limit);
+
+ ASSERT_RAISES(OutOfMemory, target.MergeTable(source));
+}
+
TEST(BinaryMemoTable, Empty) {
BinaryMemoTable<BinaryBuilder> table(default_memory_pool());
ASSERT_EQ(table.size(), 0);