Author: Kiran Chandramohan
Date: 2024-05-24T11:18:55+01:00
New Revision: 2790b4d9e63c13d1e692cc301bbd373b10f28070

URL: https://github.com/llvm/llvm-project/commit/2790b4d9e63c13d1e692cc301bbd373b10f28070
DIFF: https://github.com/llvm/llvm-project/commit/2790b4d9e63c13d1e692cc301bbd373b10f28070.diff

LOG: Revert "[mlir] Fix race condition introduced in ThreadLocalCache (#93280)"

This reverts commit 6977bfb57c3efb9488aef463cd7ea521fd25a067.

Added: 
    

Modified: 
    mlir/include/mlir/Support/ThreadLocalCache.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index fe6c6fa3cf6bd..d19257bf6e25e 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/Mutex.h"
 
 namespace mlir {
@@ -24,80 +25,28 @@ namespace mlir {
 /// cache has very large lock contention.
 template <typename ValueT>
 class ThreadLocalCache {
-  struct PerInstanceState;
-
-  /// The "observer" is owned by a thread-local cache instance. It is
-  /// constructed the first time a `ThreadLocalCache` instance is accessed by a
-  /// thread, unless `perInstanceState` happens to get re-allocated to the same
-  /// address as a previous one. This class is destructed the thread in which
-  /// the `thread_local` cache lives is destroyed.
-  ///
-  /// This class is called the "observer" because while values cached in
-  /// thread-local caches are owned by `PerInstanceState`, a reference is stored
-  /// via this class in the TLC. With a double pointer, it knows when the
-  /// referenced value has been destroyed.
-  struct Observer {
-    /// This is the double pointer, explicitly allocated because we need to keep
-    /// the address stable if the TLC map re-allocates. It is owned by the
-    /// observer and shared with the value owner.
-    std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(nullptr);
-    /// Because `Owner` living inside `PerInstanceState` contains a reference to
-    /// the double pointer, and livkewise this class contains a reference to the
-    /// value, we need to synchronize destruction of the TLC and the
-    /// `PerInstanceState` to avoid racing. This weak pointer is acquired during
-    /// TLC destruction if the `PerInstanceState` hasn't entered its destructor
-    /// yet, and prevents it from happening.
-    std::weak_ptr<PerInstanceState> keepalive;
-  };
-
-  /// This struct owns the cache entries. It contains a reference back to the
-  /// reference inside the cache so that it can be written to null to indicate
-  /// that the cache entry is invalidated. It needs to do this because
-  /// `perInstanceState` could get re-allocated to the same pointer and we don't
-  /// remove entries from the TLC when it is deallocated. Thus, we have to reset
-  /// the TLC entries to a starting state in case the `ThreadLocalCache` lives
-  /// shorter than the threads.
-  struct Owner {
-    /// Save a pointer to the reference and write it to the newly created entry.
-    Owner(Observer &observer)
-        : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
-      *observer.ptr = value.get();
-    }
-    ~Owner() {
-      if (std::shared_ptr<ValueT *> ptr = ptrRef.lock())
-        *ptr = nullptr;
-    }
-
-    Owner(Owner &&) = default;
-    Owner &operator=(Owner &&) = default;
-
-    std::unique_ptr<ValueT> value;
-    std::weak_ptr<ValueT *> ptrRef;
-  };
-
   // Keep a separate shared_ptr protected state that can be acquired atomically
   // instead of using shared_ptr's for each value. This avoids a problem
   // where the instance shared_ptr is locked() successfully, and then the
   // ThreadLocalCache gets destroyed before remove() can be called successfully.
   struct PerInstanceState {
-    /// Remove the given value entry. This is called when a thread local cache
-    /// is destructing but still contains references to values owned by the
-    /// `PerInstanceState`. Removal is required because it prevents writeback to
-    /// a pointer that was deallocated.
+    /// Remove the given value entry. This is generally called when a thread
+    /// local cache is destructing.
     void remove(ValueT *value) {
       // Erase the found value directly, because it is guaranteed to be in the
       // list.
       llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
-      auto it = llvm::find_if(instances, [&](Owner &instance) {
-        return instance.value.get() == value;
-      });
+      auto it =
+          llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
+            return instance.get() == value;
+          });
       assert(it != instances.end() && "expected value to exist in cache");
       instances.erase(it);
     }
 
     /// Owning pointers to all of the values that have been constructed for this
     /// object in the static cache.
-    SmallVector<Owner, 1> instances;
+    SmallVector<std::unique_ptr<ValueT>, 1> instances;
 
     /// A mutex used when a new thread instance has been added to the cache for
     /// this object.
@@ -108,14 +57,14 @@ class ThreadLocalCache {
   /// instance of the non-static cache and a weak reference to an instance of
   /// ValueT. We use a weak reference here so that the object can be destroyed
   /// without needing to lock access to the cache itself.
-  struct CacheType : public llvm::SmallDenseMap<PerInstanceState *, Observer> {
+  struct CacheType
+      : public llvm::SmallDenseMap<PerInstanceState *,
+                                   std::pair<std::weak_ptr<ValueT>, ValueT *>> {
     ~CacheType() {
-      // Remove the values of this cache that haven't already expired. This is
-      // required because if we don't remove them, they will contain a reference
-      // back to the data here that is being destroyed.
-      for (auto &[instance, observer] : *this)
-        if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
-          state->remove(*observer.ptr);
+      // Remove the values of this cache that haven't already expired.
+      for (auto &it : *this)
+        if (std::shared_ptr<ValueT> value = it.second.first.lock())
+          it.first->remove(value.get());
     }
 
     /// Clear out any unused entries within the map. This method is not
@@ -123,7 +72,7 @@ class ThreadLocalCache {
     void clearExpiredEntries() {
       for (auto it = this->begin(), e = this->end(); it != e;) {
         auto curIt = it++;
-        if (!*curIt->second.ptr)
+        if (curIt->second.first.expired())
           this->erase(curIt);
       }
     }
@@ -140,23 +89,27 @@ class ThreadLocalCache {
   ValueT &get() {
     // Check for an already existing instance for this thread.
     CacheType &staticCache = getStaticCache();
-    Observer &threadInstance = staticCache[perInstanceState.get()];
-    if (ValueT *value = *threadInstance.ptr)
+    std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
+        staticCache[perInstanceState.get()];
+    if (ValueT *value = threadInstance.second)
       return *value;
 
     // Otherwise, create a new instance for this thread.
     {
       llvm::sys::SmartScopedLock<true> threadInstanceLock(
           perInstanceState->instanceMutex);
-      perInstanceState->instances.emplace_back(threadInstance);
+      threadInstance.second =
+          perInstanceState->instances.emplace_back(std::make_unique<ValueT>())
+              .get();
     }
-    threadInstance.keepalive = perInstanceState;
+    threadInstance.first =
+        std::shared_ptr<ValueT>(perInstanceState, threadInstance.second);
 
     // Before returning the new instance, take the chance to clear out any used
     // entries in the static map. The cache is only cleared within the same
     // thread to remove the need to lock the cache itself.
     staticCache.clearExpiredEntries();
-    return **threadInstance.ptr;
+    return *threadInstance.second;
   }
   ValueT &operator*() { return get(); }
   ValueT *operator->() { return &get(); }


        
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to