Author: Duncan P. N. Exon Smith Date: 2021-01-13T20:00:44-08:00 New Revision: 3043e5a5c33c4c871f4a1dfd621a8839f9a1f0b3
URL: https://github.com/llvm/llvm-project/commit/3043e5a5c33c4c871f4a1dfd621a8839f9a1f0b3 DIFF: https://github.com/llvm/llvm-project/commit/3043e5a5c33c4c871f4a1dfd621a8839f9a1f0b3.diff LOG: ADT: Fix reference invalidation in N-element SmallVector::append and insert For small enough, trivially copyable `T`, take the parameter by-value in `SmallVector::append` and `SmallVector::insert`. Otherwise, when growing, update the argument appropriately. Differential Revision: https://reviews.llvm.org/D93780 Added: Modified: llvm/include/llvm/ADT/SmallVector.h llvm/unittests/ADT/SmallVectorTest.cpp Removed: ################################################################################ diff --git a/llvm/include/llvm/ADT/SmallVector.h b/llvm/include/llvm/ADT/SmallVector.h index c91075677b3f..fea8a763d48f 100644 --- a/llvm/include/llvm/ADT/SmallVector.h +++ b/llvm/include/llvm/ADT/SmallVector.h @@ -223,8 +223,9 @@ class SmallVectorTemplateCommon /// Reserve enough space to add one element, and return the updated element /// pointer in case it was a reference to the storage. template <class U> - static const T *reserveForAndGetAddressImpl(U *This, const T &Elt) { - if (LLVM_LIKELY(This->size() < This->capacity())) + static const T *reserveForAndGetAddressImpl(U *This, const T &Elt, size_t N) { + size_t NewSize = This->size() + N; + if (LLVM_LIKELY(NewSize <= This->capacity())) return &Elt; bool ReferencesStorage = false; @@ -233,7 +234,7 @@ class SmallVectorTemplateCommon ReferencesStorage = true; Index = &Elt - This->begin(); } - This->grow(); + This->grow(NewSize); return ReferencesStorage ? This->begin() + Index : &Elt; } @@ -357,14 +358,14 @@ class SmallVectorTemplateBase : public SmallVectorTemplateCommon<T> { /// Reserve enough space to add one element, and return the updated element /// pointer in case it was a reference to the storage. 
- const T *reserveForAndGetAddress(const T &Elt) { - return this->reserveForAndGetAddressImpl(this, Elt); + const T *reserveForAndGetAddress(const T &Elt, size_t N = 1) { + return this->reserveForAndGetAddressImpl(this, Elt, N); } /// Reserve enough space to add one element, and return the updated element /// pointer in case it was a reference to the storage. - T *reserveForAndGetAddress(T &Elt) { - return const_cast<T *>(this->reserveForAndGetAddressImpl(this, Elt)); + T *reserveForAndGetAddress(T &Elt, size_t N = 1) { + return const_cast<T *>(this->reserveForAndGetAddressImpl(this, Elt, N)); } static T &&forward_value_param(T &&V) { return std::move(V); } @@ -483,14 +484,14 @@ class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> { /// Reserve enough space to add one element, and return the updated element /// pointer in case it was a reference to the storage. - const T *reserveForAndGetAddress(const T &Elt) { - return this->reserveForAndGetAddressImpl(this, Elt); + const T *reserveForAndGetAddress(const T &Elt, size_t N = 1) { + return this->reserveForAndGetAddressImpl(this, Elt, N); } /// Reserve enough space to add one element, and return the updated element /// pointer in case it was a reference to the storage. - T *reserveForAndGetAddress(T &Elt) { - return const_cast<T *>(this->reserveForAndGetAddressImpl(this, Elt)); + T *reserveForAndGetAddress(T &Elt, size_t N = 1) { + return const_cast<T *>(this->reserveForAndGetAddressImpl(this, Elt, N)); } /// Copy \p V or return a reference, depending on \a ValueParamT. @@ -616,12 +617,9 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> { } /// Append \p NumInputs copies of \p Elt to the end. 
- void append(size_type NumInputs, const T &Elt) { - this->assertSafeToAdd(&Elt, NumInputs); - if (NumInputs > this->capacity() - this->size()) - this->grow(this->size()+NumInputs); - - std::uninitialized_fill_n(this->end(), NumInputs, Elt); + void append(size_type NumInputs, ValueParamT Elt) { + const T *EltPtr = this->reserveForAndGetAddress(Elt, NumInputs); + std::uninitialized_fill_n(this->end(), NumInputs, *EltPtr); this->set_size(this->size() + NumInputs); } @@ -732,7 +730,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> { return insert_one_impl(I, this->forward_value_param(Elt)); } - iterator insert(iterator I, size_type NumToInsert, const T &Elt) { + iterator insert(iterator I, size_type NumToInsert, ValueParamT Elt) { // Convert iterator to elt# to avoid invalidating iterator when we reserve() size_t InsertElt = I - this->begin(); @@ -743,11 +741,9 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> { assert(this->isReferenceToStorage(I) && "Insertion iterator is out of bounds."); - // Check that adding NumToInsert elements won't invalidate Elt. - this->assertSafeToAdd(&Elt, NumToInsert); - - // Ensure there is enough space. - reserve(this->size() + NumToInsert); + // Ensure there is enough space, and get the (maybe updated) address of + // Elt. + const T *EltPtr = this->reserveForAndGetAddress(Elt, NumToInsert); // Uninvalidate the iterator. I = this->begin()+InsertElt; @@ -764,7 +760,12 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> { // Copy the existing elements that get replaced. std::move_backward(I, OldEnd-NumToInsert, OldEnd); - std::fill_n(I, NumToInsert, Elt); + // If we just moved the element we're inserting, be sure to update + // the reference (never happens if TakesParamByValue). 
+ if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end()) + EltPtr += NumToInsert; + + std::fill_n(I, NumToInsert, *EltPtr); return I; } @@ -777,11 +778,16 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> { size_t NumOverwritten = OldEnd-I; this->uninitialized_move(I, OldEnd, this->end()-NumOverwritten); + // If we just moved the element we're inserting, be sure to update + // the reference (never happens if TakesParamByValue). + if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end()) + EltPtr += NumToInsert; + // Replace the overwritten part. - std::fill_n(I, NumOverwritten, Elt); + std::fill_n(I, NumOverwritten, *EltPtr); // Insert the non-overwritten middle part. - std::uninitialized_fill_n(OldEnd, NumToInsert-NumOverwritten, Elt); + std::uninitialized_fill_n(OldEnd, NumToInsert - NumOverwritten, *EltPtr); return I; } diff --git a/llvm/unittests/ADT/SmallVectorTest.cpp b/llvm/unittests/ADT/SmallVectorTest.cpp index c880a6b6c543..c236a68636d0 100644 --- a/llvm/unittests/ADT/SmallVectorTest.cpp +++ b/llvm/unittests/ADT/SmallVectorTest.cpp @@ -1146,9 +1146,17 @@ TYPED_TEST(SmallVectorReferenceInvalidationTest, Resize) { TYPED_TEST(SmallVectorReferenceInvalidationTest, Append) { auto &V = this->V; (void)V; -#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST - EXPECT_DEATH(V.append(1, V.back()), this->AssertionMessage); -#endif + V.append(1, V.back()); + int N = this->NumBuiltinElts(V); + EXPECT_EQ(N, V[N - 1]); + + // Append enough more elements that V will grow again. This tests growing + // when already in large mode. + // + // If reference invalidation breaks in the future, sanitizers should be able + // to catch a use-after-free here. 
+ V.append(V.capacity() - V.size() + 1, V.front()); + EXPECT_EQ(1, V.back()); } TYPED_TEST(SmallVectorReferenceInvalidationTest, AppendRange) { @@ -1244,9 +1252,20 @@ TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertMoved) { TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertN) { auto &V = this->V; (void)V; -#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST - EXPECT_DEATH(V.insert(V.begin(), 2, V.back()), this->AssertionMessage); -#endif + + // Cover NumToInsert <= this->end() - I. + V.insert(V.begin() + 1, 1, V.back()); + int N = this->NumBuiltinElts(V); + EXPECT_EQ(N, V[1]); + + // Cover NumToInsert > this->end() - I, inserting enough elements that V will + // also grow again; V.capacity() will be more elements than necessary but + // it's a simple way to cover both conditions. + // + // If reference invalidation breaks in the future, sanitizers should be able + // to catch a use-after-free here. + V.insert(V.begin(), V.capacity(), V.front()); + EXPECT_EQ(1, V.front()); } TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertRange) { _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits