llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-llvm-support Author: Vitaly Buka (vitalybuka) <details> <summary>Changes</summary> This commit introduces a RadixTree implementation to LLVM. Like any Trie, a RadixTree is very efficient at searching for prefixes, and a Radix Tree is a more compact and efficient implementation of a Trie. The tree will be used to optimize Glob matching in SpecialCaseList. --- Patch is 24.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164524.diff 3 Files Affected: - (added) llvm/include/llvm/Support/RadixTree.h (+345) - (modified) llvm/unittests/Support/CMakeLists.txt (+1) - (added) llvm/unittests/Support/RadixTreeTest.cpp (+372) ``````````diff diff --git a/llvm/include/llvm/Support/RadixTree.h b/llvm/include/llvm/Support/RadixTree.h new file mode 100644 index 0000000000000..c9bed5ca6ba1b --- /dev/null +++ b/llvm/include/llvm/Support/RadixTree.h @@ -0,0 +1,345 @@ +//===-- RadixTree.h - Radix Tree implementation -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +//===----------------------------------------------------------------------===// +// +// This file implements a Radix Tree. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_SUPPORT_RADIXTREE_H +#define LLVM_SUPPORT_RADIXTREE_H + +#include "llvm/ADT/ADL.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/iterator.h" +#include "llvm/ADT/iterator_range.h" +#include <cassert> +#include <cstddef> +#include <iterator> +#include <limits> +#include <list> +#include <utility> + +namespace llvm { + +/// \brief A Radix Tree implementation.
+/// +/// A Radix Tree (also known as a compact prefix tree or radix trie) is a +/// data structure that stores a dynamic set or associative array where keys +/// are strings and values are associated with these keys. Unlike a regular +/// trie, the edges of a radix tree can be labeled with sequences of characters +/// as well as single characters. This makes radix trees more efficient for +/// storing sparse data sets, where many nodes in a regular trie would have +/// only one child. +/// +/// This implementation supports arbitrary key types that can be iterated over +/// (e.g., `std::string`, `std::vector<char>`, `ArrayRef<char>`). The key type +/// must provide `begin()` and `end()` for iteration. +/// +/// The tree stores `std::pair<const KeyType, T>` as its value type. +/// +/// Example usage: +/// \code +/// llvm::RadixTree<StringRef, int> Tree; +/// Tree.emplace("apple", 1); +/// Tree.emplace("grapefruit", 2); +/// Tree.emplace("grape", 3); +/// +/// // Find prefixes +/// for (const auto &pair : Tree.find_prefixes("grapefruit juice")) { +/// // pair will be {"grape", 3} +/// // pair will be {"grapefruit", 2} +/// llvm::outs() << pair.first << ": " << pair.second << "\n"; +/// } +/// +/// // Iterate over all elements +/// for (const auto &pair : Tree) { +/// llvm::outs() << pair.first << ": " << pair.second << "\n"; +/// } +/// \endcode +/// +/// \note +/// The `RadixTree` takes ownership of the `KeyType` and `T` objects +/// inserted into it. When an element is removed or the tree is destroyed, +/// these objects will be destructed. +/// However, if `KeyType` is a reference-like type, e.g. `StringRef` or a range, +/// the user must guarantee that the referenced data outlives the tree.
+template <typename KeyType, typename T> class RadixTree { +public: + using key_type = KeyType; + using mapped_type = T; + using value_type = std::pair<const KeyType, mapped_type>; + +private: + using KeyConstIteratorType = + decltype(adl_begin(std::declval<const key_type &>())); + using KeyConstIteratorRangeType = iterator_range<KeyConstIteratorType>; + using KeyValueType = + remove_cvref_t<decltype(*adl_begin(std::declval<key_type &>()))>; + using ContainerType = std::list<value_type>; + + /// Represents an internal node in the Radix Tree. + struct Node { + KeyConstIteratorRangeType Key = {KeyConstIteratorType{}, + KeyConstIteratorType{}}; + std::vector<Node> Children; + + /// An iterator to the value associated with this node. + /// + /// If this node does not have a value (i.e., it's an internal node that + /// only serves as a path to other values), this iterator will be equal + /// to a default-constructed `ContainerType::iterator()`. + typename ContainerType::iterator Value; + + /// The first element of the Key. Used for fast child lookup. + KeyValueType KeyFront; + + Node() = default; + Node(const KeyConstIteratorRangeType &Key) + : Key(Key), KeyFront(*Key.begin()) { + assert(!Key.empty()); + } + + Node(Node &&) = default; + Node &operator=(Node &&) = default; + + Node(const Node &) = delete; + Node &operator=(const Node &) = delete; + + const Node *findChild(const KeyConstIteratorRangeType &Key) const { + if (!Key.empty()) { + for (const auto &Child : Children) { + assert(!Child.Key.empty()); // Only root can be empty. + if (Child.KeyFront == *Key.begin()) + return &Child; + } + } + return nullptr; + } + + Node *findChild(const KeyConstIteratorRangeType &Query) { + const Node *This = this; + return const_cast<Node *>(This->findChild(Query)); + } + + size_t countNodes() const { + size_t R = 1; + for (const auto &C : Children) + R += C.countNodes(); + return R; + } + + /// + /// Splits the current node into two.
+ /// + /// This function is used when a new key needs to be inserted that shares + /// a common prefix with the current node's key, but then diverges. + /// The current `Key` is truncated to the common prefix, and a new child + /// node is created for the remainder of the original node's `Key`. + /// + /// \param SplitPoint An iterator pointing to the character in the current + /// `Key` where the split should occur. + void split(KeyConstIteratorType SplitPoint) { + Node Child(make_range(SplitPoint, Key.end())); + Key = make_range(Key.begin(), SplitPoint); + + Children.swap(Child.Children); + std::swap(Value, Child.Value); + + Children.emplace_back(std::move(Child)); + } + }; + + Node Root; // The root always holds the empty key range. + ContainerType Values; + + /// Finds or creates a new tail or leaf node corresponding to the `Key`. + Node &findOrCreate(KeyConstIteratorRangeType Key) { + Node *Curr = &Root; + if (Key.empty()) + return *Curr; + + for (;;) { + auto [I1, I2] = llvm::mismatch(Key, Curr->Key); + Key = make_range(I1, Key.end()); + + if (I2 != Curr->Key.end()) { + // Match is partial. Either the query is too short, or there is a + // mismatching character. Split either way, and put the new node between + // the current node and its children. + Curr->split(I2); + + // After the split no child can match the remaining `Key`; stop here. + break; + } + + Node *Child = Curr->findChild(Key); + if (!Child) + break; + + // Move to child with the same first character. + Curr = Child; + } + + if (Key.empty()) { + // The current node completely matches the key, return it. + return *Curr; + } + + // `Key` is the suffix of the original `Key` not matched by the path from + // the `Root` to `Curr`, and no child can match more of it. Create a + // new one. + return Curr->Children.emplace_back(Key); + } + + /// + /// An iterator for traversing prefix search results.
+ /// + /// This iterator is used by `find_prefixes` to traverse the tree and find + /// elements that are prefixes to the given key. It's a forward iterator. + /// + /// \tparam MappedType The type of the value pointed to by the iterator. + /// This will be `value_type` for non-const iterators + /// and `const value_type` for const iterators. + template <typename MappedType> + class IteratorImpl + : public iterator_facade_base<IteratorImpl<MappedType>, + std::forward_iterator_tag, MappedType> { + const Node *Curr = nullptr; + KeyConstIteratorRangeType Query; + + void findNextValid() { + while (Curr && Curr->Value == typename ContainerType::iterator()) + advance(); + } + + void advance() { + assert(Curr); + if (Query.empty()) { + Curr = nullptr; + return; + } + + Curr = Curr->findChild(Query); + if (!Curr) { + Curr = nullptr; + return; + } + + auto [I1, I2] = llvm::mismatch(Query, Curr->Key); + if (I2 != Curr->Key.end()) { + Curr = nullptr; + return; + } + Query = make_range(I1, Query.end()); + } + + friend class RadixTree; + IteratorImpl(const Node *C, const KeyConstIteratorRangeType &Q) + : Curr(C), Query(Q) { + findNextValid(); + } + + public: + IteratorImpl() : Query{{}, {}} {} + + MappedType &operator*() const { return *Curr->Value; } + + IteratorImpl &operator++() { + advance(); + findNextValid(); + return *this; + } + + bool operator==(const IteratorImpl &Other) const { + return Curr == Other.Curr; + } + }; + +public: + RadixTree() = default; + RadixTree(RadixTree &&) = default; + RadixTree &operator=(RadixTree &&) = default; + + using prefix_iterator = IteratorImpl<value_type>; + using const_prefix_iterator = IteratorImpl<const value_type>; + + using iterator = typename ContainerType::iterator; + using const_iterator = typename ContainerType::const_iterator; + + /// Returns true if the tree is empty. + bool empty() const { return Values.empty(); } + + /// Returns the number of elements in the tree. 
+ size_t size() const { return Values.size(); } + + /// Returns the number of nodes in the tree. + /// + /// This function counts all nodes in the tree, including the root. It can be + /// useful for understanding the memory footprint or complexity of the tree. + size_t countNodes() const { return Root.countNodes(); } + + /// Returns an iterator to the first element. + iterator begin() { return Values.begin(); } + const_iterator begin() const { return Values.begin(); } + + /// Returns an iterator to the end of the tree. + iterator end() { return Values.end(); } + const_iterator end() const { return Values.end(); } + + /// Constructs and inserts a new element into the tree. + /// + /// This function constructs an element in-place within the tree. If an + /// element with the same key already exists, the insertion fails and the + /// function returns an iterator to the existing element along with `false`. + /// Otherwise, the new element is inserted and the function returns an + /// iterator to the new element along with `true`. + /// + /// \param Key The key of the element to construct. + /// \param Args Arguments to forward to the constructor of the mapped_type. + /// \return A pair consisting of an iterator to the inserted element (or to + /// the element that prevented insertion) and a boolean value + /// indicating whether the insertion took place. + template <typename... Ts> + std::pair<iterator, bool> emplace(key_type &&Key, Ts &&...Args) { + const value_type &NewValue = + Values.emplace_front(std::move(Key), T(std::move(Args)...)); + Node &Node = findOrCreate(NewValue.first); + bool HasValue = Node.Value != typename ContainerType::iterator(); + if (!HasValue) { + Node.Value = Values.begin(); + } else { + Values.pop_front(); + } + return std::make_pair(Node.Value, !HasValue); + } + + /// + /// Finds all elements whose keys are prefixes of the given `Key`.
+ /// + /// This function returns an iterator range over all elements in the tree + /// whose keys are prefixes of the provided `Key`. For example, if the tree + /// contains "abcde", "abc", "abcdefgh", and `Key` is "abcde", this function + /// would return iterators to "abcde" and "abc". + /// + /// \param Key The key to search for prefixes of. + /// \return An `iterator_range` of `const_prefix_iterator`s, allowing + /// iteration over the found prefix elements. + /// \note The returned iterators reference the `Key` provided by the caller. + /// The caller must ensure that `Key` remains valid for the lifetime + /// of the iterators. + iterator_range<const_prefix_iterator> + find_prefixes(const key_type &Key) const { + return iterator_range<const_prefix_iterator>{ + const_prefix_iterator( + &Root, KeyConstIteratorRangeType{adl_begin(Key), adl_end(Key)}), + const_prefix_iterator{}}; + } +}; + +} // namespace llvm + +#endif // LLVM_SUPPORT_RADIXTREE_H diff --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt index 21f10eb610f11..80646cfc0ef1f 100644 --- a/llvm/unittests/Support/CMakeLists.txt +++ b/llvm/unittests/Support/CMakeLists.txt @@ -76,6 +76,7 @@ add_llvm_unittest(SupportTests ProcessTest.cpp ProgramTest.cpp ProgramStackTest.cpp + RadixTreeTest.cpp RecyclerTest.cpp RegexTest.cpp ReverseIterationTest.cpp diff --git a/llvm/unittests/Support/RadixTreeTest.cpp b/llvm/unittests/Support/RadixTreeTest.cpp new file mode 100644 index 0000000000000..e94a40eaf0264 --- /dev/null +++ b/llvm/unittests/Support/RadixTreeTest.cpp @@ -0,0 +1,372 @@ +//===- llvm/unittest/Support/RadixTreeTest.cpp ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/RadixTree.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include <iterator> +#include <list> +#include <vector> + +using namespace llvm; +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +// Test with StringRef. + +TEST(RadixTreeTest, Empty) { + RadixTree<StringRef, int> T; + EXPECT_TRUE(T.empty()); + EXPECT_EQ(0u, T.size()); + + EXPECT_TRUE(T.find_prefixes("").empty()); + EXPECT_TRUE(T.find_prefixes("A").empty()); + + EXPECT_EQ(1u, T.countNodes()); +} + +TEST(RadixTreeTest, InsertEmpty) { + RadixTree<StringRef, int> T; + auto [It, IsNew] = T.emplace("", 4); + EXPECT_TRUE(!T.empty()); + EXPECT_EQ(1u, T.size()); + EXPECT_TRUE(IsNew); + const auto &[K, V] = *It; + EXPECT_TRUE(K.empty()); + EXPECT_EQ(4, V); + + EXPECT_THAT(T, ElementsAre(Pair("", 4))); + + EXPECT_THAT(T.find_prefixes(""), ElementsAre(Pair("", 4))); + + EXPECT_THAT(T.find_prefixes("a"), ElementsAre(Pair("", 4))); + + EXPECT_EQ(1u, T.countNodes()); +} + +TEST(RadixTreeTest, Complex) { + RadixTree<StringRef, int> T; + T.emplace("abcd", 1); + EXPECT_EQ(2u, T.countNodes()); + T.emplace("abklm", 2); + EXPECT_EQ(4u, T.countNodes()); + T.emplace("123abklm", 3); + EXPECT_EQ(5u, T.countNodes()); + T.emplace("123abklm", 4); + EXPECT_EQ(5u, T.countNodes()); + T.emplace("ab", 5); + EXPECT_EQ(5u, T.countNodes()); + T.emplace("1234567", 6); + EXPECT_EQ(7u, T.countNodes()); + T.emplace("123456", 7); + EXPECT_EQ(8u, T.countNodes()); + T.emplace("123456789", 8); + EXPECT_EQ(9u, T.countNodes()); + + EXPECT_THAT(T, UnorderedElementsAre(Pair("abcd", 1), Pair("abklm", 2), + Pair("123abklm", 3), Pair("ab", 5), + Pair("1234567", 6), Pair("123456", 7), + 
Pair("123456789", 8))); + + EXPECT_THAT(T.find_prefixes("1234567890"), + UnorderedElementsAre(Pair("1234567", 6), Pair("123456", 7), + Pair("123456789", 8))); + + EXPECT_THAT(T.find_prefixes("123abklm"), + UnorderedElementsAre(Pair("123abklm", 3))); + + EXPECT_THAT(T.find_prefixes("abcdefg"), + UnorderedElementsAre(Pair("abcd", 1), Pair("ab", 5))); + + EXPECT_EQ(9u, T.countNodes()); +} + +// Test different types, less readable. + +template <typename T> struct TestData { + static const T Data1[]; + static const T Data2[]; +}; + +template <> const char TestData<char>::Data1[] = "abcdedcba"; +template <> const char TestData<char>::Data2[] = "abCDEDCba"; + +template <> const int TestData<int>::Data1[] = {1, 2, 3, 4, 5, 4, 3, 2, 1}; +template <> const int TestData<int>::Data2[] = {1, 2, 4, 8, 16, 8, 4, 2, 1}; + +template <typename T> class RadixTreeTypeTest : public ::testing::Test { +public: + using IteratorType = decltype(adl_begin(std::declval<const T &>())); + using CharType = remove_cvref_t<decltype(*adl_begin(std::declval<T &>()))>; + + T make(const CharType *Data, size_t N) { return T(StringRef(Data, N)); } + + T make1(size_t N) { return make(TestData<CharType>::Data1, N); } + T make2(size_t N) { return make(TestData<CharType>::Data2, N); } +}; + +template <> +iterator_range<StringRef::const_iterator> +RadixTreeTypeTest<iterator_range<StringRef::const_iterator>>::make( + const char *Data, size_t N) { + return StringRef(Data).take_front(N); +} + +template <> +iterator_range<StringRef::const_reverse_iterator> +RadixTreeTypeTest<iterator_range<StringRef::const_reverse_iterator>>::make( + const char *Data, size_t N) { + return reverse(StringRef(Data).take_back(N)); +} + +template <> +ArrayRef<int> RadixTreeTypeTest<ArrayRef<int>>::make(const int *Data, + size_t N) { + return ArrayRef<int>(Data, Data + N); +} + +template <> +std::vector<int> RadixTreeTypeTest<std::vector<int>>::make(const int *Data, + size_t N) { + return std::vector<int>(Data, Data + N); +} + 
+template <> +std::list<int> RadixTreeTypeTest<std::list<int>>::make(const int *Data, + size_t N) { + return std::list<int>(Data, Data + N); +} + +class TypeNameGenerator { +public: + template <typename T> static std::string GetName(int) { + if (std::is_same_v<T, StringRef>) + return "StringRef"; + if (std::is_same_v<T, std::string>) + return "string"; + if (std::is_same_v<T, iterator_range<StringRef::const_iterator>>) + return "iterator_range"; + if (std::is_same_v<T, iterator_range<StringRef::const_reverse_iterator>>) + return "reverse_iterator_range"; + if (std::is_same_v<T, ArrayRef<int>>) + return "ArrayRef"; + if (std::is_same_v<T, std::vector<int>>) + return "vector"; + if (std::is_same_v<T, std::list<int>>) + return "list"; + return "Unknown"; + } +}; + +using TestTypes = + ::testing::Types<StringRef, std::string, + iterator_range<StringRef::const_iterator>, + iterator_range<StringRef::const_reverse_iterator>, + ArrayRef<int>, std::vector<int>, std::list<int>>; + +TYPED_TEST_SUITE(RadixTreeTypeTest, TestTypes, TypeNameGenerator); + +TYPED_TEST(RadixTreeTypeTest, Helpers) { + for (size_t i = 0; i < 9; ++i) { + auto R1 = this->make1(i); + auto R2 = this->make2(i); + EXPECT_EQ(i, llvm::range_size(R1)); + EXPECT_EQ(i, llvm::range_size(R2)); + auto [I1, I2] = llvm::mismatch(R1, R2); + // Exactly 2 first elements of Data1 and Data2 must match. 
+ EXPECT_EQ(std::min<int>(2, i), std::distance(R1.begin(), I1)); + } +} + +TYPED_TEST(RadixTreeTypeTest, Empty) { + RadixTree<TypeParam, int> T; + EXPECT_TRUE(T.empty()); + EXPECT_EQ(0u, T.size()); + + EXPECT_TRUE(T.find_prefixes(this->make1(0)).empty()); + EXPECT_TRUE(T.find_prefixes(this->make2(1)).empty()); + + EXPECT_EQ(1u, T.countNodes()); +} + +TYPED_TEST(RadixTreeTypeTest, InsertEmpty) { + using TreeType = RadixTree<TypeParam, int>; + TreeType T; + auto [It, IsNew] = T.emplace(this->make1(0), 5); + EXPECT_TRUE(!T.empty()); + EXPECT_EQ(1u, T.size()); + EXPECT_TRUE(IsNew); + const auto &[K, V] = *It; + EXPECT_TRUE(K.empty()); + EXPECT_EQ(5, V); + + EXPECT_THAT(T.find_prefixes(this->make1(0)), + ElementsAre(Pair(ElementsAre(), 5))); + + EXPECT_THAT(T.find_prefixes(this->make2(1)), + ElementsAre(Pair(ElementsAre(), 5))); + + EXPECT_THAT(T, ElementsAre(Pair(ElementsAre(), 5))); + + EXPECT_EQ(1u, T.countNodes()); +} + +TYPED_TEST(RadixTreeTypeTest, InsertEmptyTwice) { + using TreeType = RadixTree<TypeParam, int>; + TreeType T; + T.emplace(this->make1(0), 5); + auto [It, IsNew] = T.emplace(this->make1(0), 6); + EXPECT_TRUE(!T.empty()); + EXPECT_EQ(1u, T.size()); + EXPECT_TRUE(!IsNew); + const auto &[K, V] = *It; + EXPECT_TRUE(K.empty()); + EXPECT_EQ(5,... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/164524 _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
