mapleFU commented on code in PR #40915:
URL: https://github.com/apache/arrow/pull/40915#discussion_r1545982462
##########
cpp/src/arrow/util/hashing.h:
##########
@@ -380,6 +382,269 @@ class HashTable {
TypedBufferBuilder<Entry> entries_builder_;
};
+// SwissHashTable is a hash table adapted from the "SwissTable" family of hash
+// tables from Abseil (https://abseil.io/blog/20180927-swisstables) (no deletes)
+template <typename Payload>
+class SwissHashTable {
+ public:
+ static constexpr hash_t kSentinel = 0ULL;
+ static constexpr uint64_t kGroupSize = 8;
+ static constexpr uint64_t kMaxAvgGroupLoad = 7;
+ static_assert(kMaxAvgGroupLoad < kGroupSize);
+
+ static constexpr uint64_t kH1Mask = 0xffffffffffffff80ULL;
+ static constexpr uint64_t kH2Mask = 0x000000000000007fULL;
+ static constexpr uint8_t kEmptyControlByte = 0b10000000;
+
+ static constexpr uint64_t kLoBits = 0x0101010101010101;
+ static constexpr uint64_t kHiBits = 0x8080808080808080;
+
+ using H1 = uint64_t;
+ using H2 = uint8_t;
+
+ struct Entry {
+ hash_t h;
+ Payload payload;
+ uint64_t entry_index;
+
+ // An entry is valid if the hash is different from the sentinel value
+ explicit operator bool() const { return h != kSentinel; }
+ };
+
+ // TODO(SGZW): support kGroupSize = 16 via SIMD
+ struct Group {
+ Entry entries[kGroupSize];
+ };
+
+ // GroupMeta is the H2 metadata array for a group.
+ // Find operations first probe the control bytes
+ // to filter candidates before matching keys.
+ struct GroupMeta {
+ uint8_t control_bytes[kGroupSize];
+ };
+
+ SwissHashTable(MemoryPool* pool, uint64_t capacity)
+ : group_builder_(pool), group_meta_builder_(pool) {
+ DCHECK_NE(pool, nullptr);
+ groups_count_ = NumGroups(capacity);
+ limit_ = groups_count_ * kMaxAvgGroupLoad;
+ size_ = 0;
+ DCHECK_OK(UpsizeBuffer(groups_count_));
+ }
+
+ // Lookup by linearly probing groups of control bytes.
+ // cmp_func should have signature bool(const Payload*).
+ // Returns an (Entry*, found) pair.
+ template <typename CmpFunc>
+ std::pair<Entry*, bool> Lookup(hash_t h, CmpFunc&& cmp_func) {
+ return DoLookup<DoCompare>(h, std::forward<CmpFunc>(cmp_func));
+ }
+
+ template <typename CmpFunc>
+ std::pair<const Entry*, bool> Lookup(hash_t h, CmpFunc&& cmp_func) const {
+ // DoLookup is non-const because it caches entry_index on the returned entry
+ return const_cast<SwissHashTable*>(this)->DoLookup<DoCompare>(
+     h, std::forward<CmpFunc>(cmp_func));
+ }
+
+ Status Insert(Entry* entry, hash_t h, const Payload& payload) {
+ return DoInsert(entry, h, payload);
+ }
+
+ uint64_t size() const { return size_; }
+
+ // Visit all non-empty entries in the table
+ // The visit_func should have signature void(const Entry*)
+ template <typename VisitFunc>
+ void VisitEntries(VisitFunc&& visit_func) const {
+ for (uint64_t i = 0; i < groups_count_; i++) {
+ const auto& group = groups_[i];
+ for (uint16_t j = 0; j < kGroupSize; j++) {
+ const auto& entry = group.entries[j];
+ if (entry) {
+ visit_func(&entry);
+ }
+ }
+ }
+ }
+
+ protected:
+ // NoCompare is for when the value is known not to exist in the table
+ enum CompareKind { DoCompare, NoCompare };
+
+ Status DoInsert(Entry* entry, hash_t h, Payload payload) {
+ // Ensure entry is empty before inserting
+ assert(!*entry);
+ assert(entry->entry_index != 0);
+ entry->h = h;
+ entry->payload = std::move(payload);
+ ++size_;
+
+ auto p = UnPackEntryIndex(entry->entry_index);
+ auto group_index = p.first;
+ auto group_internal_index = p.second;
+
+ // Update the control byte (H2) for this slot
+ group_metas_[group_index].control_bytes[group_internal_index] = h & kH2Mask;
+
+ if (ARROW_PREDICT_FALSE(NeedUpsizing())) {
+ // Resize less frequently since it is expensive
+ return Upsize();
+ }
+ return Status::OK();
+ }
+
+ // The workhorse lookup function
+ template <CompareKind CKind, typename CmpFunc>
+ std::pair<Entry*, bool> DoLookup(hash_t h, CmpFunc&& cmp_func) {
+ auto hash_pair = SplitHash(h);
+ auto h1 = hash_pair.first;
+ auto h2 = hash_pair.second;
+
+ auto group_index = ProbeStart(h1);
+ while (true) {
+ auto& group = groups_[group_index];
+ // probe
+ auto match_value = GroupMetaMatchH2(group_metas_[group_index], h2);
+ while (match_value != 0) {
+ auto group_internal_index = NextMatchGroupInternalIndex(&match_value);
+ auto* entry = &group.entries[group_internal_index];
+ if (CompareEntry<CKind, CmpFunc>(h, entry,
+                                  std::forward<CmpFunc>(cmp_func))) {
+ // Found
+ entry->entry_index = PackEntryIndex(group_index, group_internal_index);
+ return {entry, true};
+ }
+ }
+
+ // stop probing if we see an empty slot
+ auto match_empty_value = GroupMetaMatchEmpty(group_metas_[group_index]);
+ if (match_empty_value != 0) {
+ auto group_internal_index = NextMatchGroupInternalIndex(&match_empty_value);
+ auto* entry = &group.entries[group_internal_index];
+ // Not Found
+ entry->entry_index = PackEntryIndex(group_index, group_internal_index);
+ return {entry, false};
+ }
+
+ // next group
+ ++group_index;
+ if (ARROW_PREDICT_FALSE(group_index >= groups_count_)) {
+ group_index = 0;
+ }
+ }
+ }
+
+ template <CompareKind CKind, typename CmpFunc>
+ bool CompareEntry(hash_t h, const Entry* entry, CmpFunc&& cmp_func) const {
+ if (CKind == NoCompare) {
+ return false;
+ } else {
+ return entry->h == h && cmp_func(&entry->payload);
+ }
+ }
+
+ bool NeedUpsizing() const {
+ // Keep the load factor (size_ / (groups_count_ * kGroupSize)) <=
+ // kMaxAvgGroupLoad / kGroupSize
+ return size_ >= limit_;
+ }
+
+ Status UpsizeBuffer(uint64_t groups_count) {
+ RETURN_NOT_OK(group_builder_.Resize(groups_count));
+ RETURN_NOT_OK(group_meta_builder_.Resize(groups_count));
+ groups_ = group_builder_.mutable_data();
+ group_metas_ = group_meta_builder_.mutable_data();
+ memset(static_cast<void*>(groups_), 0, groups_count * sizeof(Group));
+ memset(static_cast<void*>(group_metas_), kEmptyControlByte,
+        groups_count * sizeof(GroupMeta));
+ return Status::OK();
+ }
+
+ Status Upsize() {
+ auto old_groups_count = groups_count_;
+ auto old_size = size_;
+
+ // Stash old entries and seal builder, effectively resetting the Buffer
+ const Group* old_groups = groups_;
+ ARROW_ASSIGN_OR_RAISE(auto previous_groups,
+ group_builder_.FinishWithLength(old_groups_count));
+ ARROW_ASSIGN_OR_RAISE(auto previous_metas,
+                       group_meta_builder_.FinishWithLength(old_groups_count));
+
+ groups_count_ = old_groups_count << 1;
+ limit_ = groups_count_ * kMaxAvgGroupLoad;
+ size_ = 0;
+ RETURN_NOT_OK(UpsizeBuffer(groups_count_));
+
+ uint64_t reinsert_count = 0;
+ for (uint64_t i = 0; i < old_groups_count; i++) {
+ auto& old_group = old_groups[i];
+ for (uint16_t j = 0; j < kGroupSize; j++) {
+ const auto* old_entry = &old_group.entries[j];
+ if (*old_entry) {
+ ++reinsert_count;
+ auto p = DoLookup<CompareKind::NoCompare>(
+     old_entry->h, [](const Payload*) { return false; });
+ assert(!p.second);
Review Comment:
DCHECK?
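
For context: the hunk above relies on several helpers that are not part of this
diff (SplitHash, ProbeStart, GroupMetaMatchH2, GroupMetaMatchEmpty,
NextMatchGroupInternalIndex, PackEntryIndex/UnPackEntryIndex). Below is a
minimal, hypothetical sketch of how the probing helpers are typically written
in the SwissTable/SWAR style, reusing the kH1Mask/kH2Mask/kLoBits/kHiBits
constants declared above. It is an illustration only; the PR's actual
definitions may differ (e.g. GroupMetaMatchH2 here takes a pre-loaded control
word rather than a GroupMeta).

    #include <cstdint>
    #include <cstring>
    #include <utility>

    using hash_t = uint64_t;

    constexpr uint64_t kH1Mask = 0xffffffffffffff80ULL;
    constexpr uint64_t kH2Mask = 0x000000000000007fULL;
    constexpr uint64_t kLoBits = 0x0101010101010101ULL;
    constexpr uint64_t kHiBits = 0x8080808080808080ULL;

    // Split the hash: the upper 57 bits (H1) select the starting group, the
    // low 7 bits (H2) become the control byte stored in GroupMeta.
    inline std::pair<uint64_t, uint8_t> SplitHash(hash_t h) {
      return {(h & kH1Mask) >> 7, static_cast<uint8_t>(h & kH2Mask)};
    }

    // Map H1 onto a group index (shown as a free function here; in the class
    // it would use groups_count_, possibly with a cheaper reduction than %).
    inline uint64_t ProbeStart(uint64_t h1, uint64_t groups_count) {
      return h1 % groups_count;
    }

    // Load the 8 control bytes of a group into one word (little-endian
    // layout assumed by the bit tricks below).
    inline uint64_t LoadCtrlWord(const uint8_t control_bytes[8]) {
      uint64_t w;
      std::memcpy(&w, control_bytes, sizeof(w));
      return w;
    }

    // Set the high bit of every byte equal to h2, via the classic
    // "word contains a zero byte" trick. Rare false positives are
    // acceptable because callers re-check the full hash and key.
    inline uint64_t GroupMetaMatchH2(uint64_t ctrl, uint8_t h2) {
      uint64_t x = ctrl ^ (kLoBits * h2);
      return (x - kLoBits) & ~x & kHiBits;
    }

    // Empty slots hold kEmptyControlByte (0b10000000); occupied slots hold
    // a 7-bit H2, so the high bit alone separates the two cases.
    inline uint64_t GroupMetaMatchEmpty(uint64_t ctrl) { return ctrl & kHiBits; }

    // Pop the lowest matching slot index (0..7) out of a match mask.
    // __builtin_ctzll is GCC/Clang; MSVC would use _BitScanForward64.
    inline uint64_t NextMatchGroupInternalIndex(uint64_t* match) {
      uint64_t index = static_cast<uint64_t>(__builtin_ctzll(*match)) >> 3;
      *match &= *match - 1;  // clear the lowest set bit
      return index;
    }

And a hypothetical find-or-insert call site, mirroring the Lookup/Insert API
in the hunk (value and memo_index are illustrative fields, not from the PR):

    auto [entry, found] = table.Lookup(h, [&](const Payload* p) {
      return p->value == value;
    });
    if (!found) {
      RETURN_NOT_OK(table.Insert(entry, h, Payload{value, memo_index}));
    }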