This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 077a48adb [GLUTEN-6736][VL] Phase 2: Minimize lock scope in ListenableArbitrator (#6783)
077a48adb is described below

commit 077a48adbe14d71d5cc3e5c1447ea6f6ecff7f33
Author: Hongze Zhang <[email protected]>
AuthorDate: Mon Aug 12 17:02:13 2024 +0800

    [GLUTEN-6736][VL] Phase 2: Minimize lock scope in ListenableArbitrator (#6783)
    
    Closes #6736
---
 cpp/velox/memory/VeloxMemoryManager.cc |  36 ++---
 cpp/velox/tests/MemoryManagerTest.cc   | 270 +++++++++++++++++++++++++++++++++
 2 files changed, 288 insertions(+), 18 deletions(-)

diff --git a/cpp/velox/memory/VeloxMemoryManager.cc 
b/cpp/velox/memory/VeloxMemoryManager.cc
index 3f30d8627..5b2ba258a 100644
--- a/cpp/velox/memory/VeloxMemoryManager.cc
+++ b/cpp/velox/memory/VeloxMemoryManager.cc
@@ -91,33 +91,33 @@ class ListenableArbitrator : public 
velox::memory::MemoryArbitrator {
 
   bool growCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) 
override {
     velox::memory::ScopedMemoryArbitrationContext ctx(pool);
-    VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be 
used within a single root pool")
-    auto candidate = candidates_.begin()->first;
+    velox::memory::MemoryPool* candidate;
+    {
+      std::unique_lock guard{mutex_};
+      VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only 
be used within a single root pool")
+      candidate = candidates_.begin()->first;
+    }
     VELOX_CHECK(pool->root() == candidate, "Illegal state in 
ListenableArbitrator");
 
-    std::lock_guard<std::recursive_mutex> l(mutex_);
-    growCapacityLocked(pool->root(), targetBytes);
+    growCapacity0(pool->root(), targetBytes);
     return true;
   }
 
   uint64_t shrinkCapacity(uint64_t targetBytes, bool allowSpill, bool 
allowAbort) override {
     velox::memory::ScopedMemoryArbitrationContext ctx((const 
velox::memory::MemoryPool*)nullptr);
     facebook::velox::exec::MemoryReclaimer::Stats status;
-    VELOX_CHECK_EQ(candidates_.size(), 1, "Gluten only has one root pool");
-    std::lock_guard<std::recursive_mutex> l(mutex_); // FIXME: Do we have 
recursive locking for this mutex?
-    auto pool = candidates_.begin()->first;
-    const uint64_t oldCapacity = pool->capacity();
+    velox::memory::MemoryPool* pool;
+    {
+      std::unique_lock guard{mutex_};
+      VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only 
be used within a single root pool")
+      pool = candidates_.begin()->first;
+    }
     pool->reclaim(targetBytes, 0, status); // ignore the output
-    shrinkPool(pool, 0);
-    const uint64_t newCapacity = pool->capacity();
-    uint64_t total = oldCapacity - newCapacity;
-    listener_->allocationChanged(-total);
-    return total;
+    return shrinkCapacity0(pool, 0);
   }
 
   uint64_t shrinkCapacity(velox::memory::MemoryPool* pool, uint64_t 
targetBytes) override {
-    std::lock_guard<std::recursive_mutex> l(mutex_);
-    return shrinkCapacityLocked(pool, targetBytes);
+    return shrinkCapacity0(pool, targetBytes);
   }
 
   Stats stats() const override {
@@ -130,7 +130,7 @@ class ListenableArbitrator : public 
velox::memory::MemoryArbitrator {
   }
 
  private:
-  void growCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
+  void growCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
     // Since
     // 
https://github.com/facebookincubator/velox/pull/9557/files#diff-436e44b7374032f8f5d7eb45869602add6f955162daa2798d01cc82f8725724dL812-L820,
     // We should pass bytes as parameter "reservationBytes" when calling 
::grow.
@@ -152,7 +152,7 @@ class ListenableArbitrator : public 
velox::memory::MemoryArbitrator {
         pool->toString())
   }
 
-  uint64_t shrinkCapacityLocked(velox::memory::MemoryPool* pool, uint64_t 
bytes) {
+  uint64_t shrinkCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
     uint64_t freeBytes = shrinkPool(pool, bytes);
     listener_->allocationChanged(-freeBytes);
     return freeBytes;
@@ -162,7 +162,7 @@ class ListenableArbitrator : public 
velox::memory::MemoryArbitrator {
   const uint64_t memoryPoolInitialCapacity_; // FIXME: Unused.
   const uint64_t memoryPoolTransferCapacity_;
 
-  mutable std::recursive_mutex mutex_;
+  mutable std::mutex mutex_;
   inline static std::string kind_ = "GLUTEN";
   std::unordered_map<velox::memory::MemoryPool*, 
std::weak_ptr<velox::memory::MemoryPool>> candidates_;
 };
diff --git a/cpp/velox/tests/MemoryManagerTest.cc 
b/cpp/velox/tests/MemoryManagerTest.cc
index d86bd46e2..bb102dc2d 100644
--- a/cpp/velox/tests/MemoryManagerTest.cc
+++ b/cpp/velox/tests/MemoryManagerTest.cc
@@ -128,4 +128,312 @@ TEST_F(MemoryManagerTest, memoryAllocatorWithBlockReservation) {
   ASSERT_EQ(allocator_->getBytes(), 0);
 }
 
+namespace {
+// Forwards every AllocationListener call to a delegate installed later via set(),
+// so a VeloxMemoryManager can be constructed before the mock Spark listener it
+// reports to exists.
+class AllocationListenerWrapper : public AllocationListener {
+ public:
+  AllocationListenerWrapper() = default;
+
+  // Installs the delegate. May be called only once; throws otherwise.
+  void set(AllocationListener* const delegate) {
+    if (delegate_ != nullptr) {
+      throw std::runtime_error("Invalid state");
+    }
+    delegate_ = delegate;
+  }
+
+  void allocationChanged(int64_t diff) override {
+    delegate_->allocationChanged(diff);
+  }
+  int64_t currentBytes() override {
+    return delegate_->currentBytes();
+  }
+  int64_t peakBytes() override {
+    return delegate_->peakBytes();
+  }
+
+ private:
+  AllocationListener* delegate_{nullptr};
+};
+
+// AllocationListener that MockSparkTaskMemoryManager can ask to give memory back.
+class SpillableAllocationListener : public AllocationListener {
+ public:
+  // Frees up to `bytes` without spilling; returns the number of bytes freed.
+  virtual uint64_t shrink(uint64_t bytes) = 0;
+  // Spills to free up to `bytes`; returns the number of bytes freed.
+  virtual uint64_t spill(uint64_t bytes) = 0;
+};
+
+// Stand-in for Spark's task memory manager: a fixed budget of `maxBytes` shared by
+// every listener it hands out. When the budget is exhausted, acquire() asks the
+// listeners to shrink and then to spill before granting what it can.
+class MockSparkTaskMemoryManager {
+ public:
+  explicit MockSparkTaskMemoryManager(const uint64_t maxBytes);
+
+  // Creates a listener owned by this manager; the pointer stays valid for the manager's lifetime.
+  AllocationListener* newListener(std::function<uint64_t(uint64_t)> shrink, std::function<uint64_t(uint64_t)> spill);
+
+  // Returns the number of bytes granted, which may be less than `bytes`.
+  uint64_t acquire(uint64_t bytes);
+  void release(uint64_t bytes);
+  uint64_t currentBytes() {
+    std::unique_lock l(mutex_);
+    return currentBytes_;
+  }
+
+ private:
+  // Recursive because spill callbacks invoked from acquire() may call release()
+  // on this manager from the same thread.
+  mutable std::recursive_mutex mutex_;
+  std::vector<std::unique_ptr<SpillableAllocationListener>> listeners_{};
+
+  const uint64_t maxBytes_;
+  uint64_t currentBytes_{0L};
+};
+
+// Listener handed out by MockSparkTaskMemoryManager. Mirrors every capacity change
+// reported by Velox into the manager's budget.
+class MockSparkAllocationListener : public SpillableAllocationListener {
+ public:
+  explicit MockSparkAllocationListener(
+      MockSparkTaskMemoryManager* const manager,
+      std::function<uint64_t(uint64_t)> shrink,
+      std::function<uint64_t(uint64_t)> spill)
+      : manager_(manager), shrink_(std::move(shrink)), spill_(std::move(spill)) {}
+
+  // diff > 0: acquire from the manager; throws "OOM" unless the full amount is granted.
+  // diff < 0: return the bytes to the manager.
+  void allocationChanged(int64_t diff) override {
+    if (diff == 0) {
+      return;
+    }
+    if (diff > 0) {
+      const auto requested = static_cast<uint64_t>(diff);
+      const auto granted = manager_->acquire(requested);
+      if (granted < requested) {
+        // The partial grant was already charged to the manager but is never recorded
+        // here; hand it back so both sides stay in sync.
+        manager_->release(granted);
+        throw std::runtime_error("OOM");
+      }
+      currentBytes_ += granted;
+      return;
+    }
+    manager_->release(-diff);
+    currentBytes_ -= (-diff);
+  }
+
+  uint64_t shrink(uint64_t bytes) override {
+    return shrink_(bytes);
+  }
+
+  uint64_t spill(uint64_t bytes) override {
+    return spill_(bytes);
+  }
+
+  int64_t currentBytes() override {
+    return currentBytes_;
+  }
+
+ private:
+  MockSparkTaskMemoryManager* const manager_;
+  std::function<uint64_t(uint64_t)> shrink_;
+  std::function<uint64_t(uint64_t)> spill_;
+  std::atomic<uint64_t> currentBytes_{0L};
+};
+
+MockSparkTaskMemoryManager::MockSparkTaskMemoryManager(const uint64_t maxBytes) : maxBytes_(maxBytes) {}
+
+AllocationListener* MockSparkTaskMemoryManager::newListener(
+    std::function<uint64_t(uint64_t)> shrink,
+    std::function<uint64_t(uint64_t)> spill) {
+  // acquire() iterates listeners_ under mutex_, so registration must hold it too.
+  std::unique_lock l(mutex_);
+  listeners_.push_back(std::make_unique<MockSparkAllocationListener>(this, std::move(shrink), std::move(spill)));
+  return listeners_.back().get();
+}
+
+uint64_t MockSparkTaskMemoryManager::acquire(uint64_t bytes) {
+  std::unique_lock l(mutex_);
+  auto freeBytes = maxBytes_ - currentBytes_;
+  if (bytes <= freeBytes) {
+    currentBytes_ += bytes;
+    return bytes;
+  }
+  // Ask listeners to shrink, then to spill, until the shortfall is covered. Freed memory
+  // comes back through release(), lowering currentBytes_. Stop as soon as the shortfall
+  // is covered: otherwise the remaining listeners get shrink(0)/spill(0) calls, and a
+  // spill in this test reclaims everything regardless of the amount requested.
+  int64_t bytesNeeded = bytes - freeBytes;
+  for (const auto& listener : listeners_) {
+    bytesNeeded -= listener->shrink(bytesNeeded);
+    if (bytesNeeded <= 0) {
+      break;
+    }
+  }
+  if (bytesNeeded > 0) {
+    for (const auto& listener : listeners_) {
+      bytesNeeded -= listener->spill(bytesNeeded);
+      if (bytesNeeded <= 0) {
+        break;
+      }
+    }
+  }
+
+  if (bytesNeeded > 0) {
+    // Grant only what is available: the original free bytes plus whatever was freed above.
+    uint64_t granted = bytes - bytesNeeded;
+    currentBytes_ += granted;
+    return granted;
+  }
+
+  currentBytes_ += bytes;
+  return bytes;
+}
+
+void MockSparkTaskMemoryManager::release(uint64_t bytes) {
+  std::unique_lock l(mutex_);
+  currentBytes_ -= bytes;
+}
+
+// Reclaimer over a caller-owned list of equally sized buffers allocated from the pool.
+// reclaim() frees every live buffer and nulls its slot.
+class MockMemoryReclaimer : public facebook::velox::memory::MemoryReclaimer {
+ public:
+  explicit MockMemoryReclaimer(std::vector<void*>& buffs, int32_t size) : buffs_(buffs), size_(size) {}
+
+  bool reclaimableBytes(const memory::MemoryPool& pool, uint64_t& reclaimableBytes) const override {
+    uint64_t total = 0;
+    for (const auto& buf : buffs_) {
+      if (buf == nullptr) {
+        continue;
+      }
+      total += size_;
+    }
+    if (total == 0) {
+      return false;
+    }
+    reclaimableBytes = total;
+    return true;
+  }
+
+  // Frees all live buffers regardless of targetBytes; returns the number of bytes freed.
+  uint64_t reclaim(memory::MemoryPool* pool, uint64_t targetBytes, uint64_t maxWaitMs, Stats& stats) override {
+    uint64_t total = 0;
+    for (auto& buf : buffs_) {
+      if (buf == nullptr) {
+        // When:
+        // 1. Called by allocation from the same pool so buff is not allocated yet.
+        // 2. Already called once.
+        continue;
+      }
+      pool->free(buf, size_);
+      buf = nullptr;
+      total += size_;
+    }
+    return total;
+  }
+
+ private:
+  std::vector<void*>& buffs_;
+  int32_t size_;
+};
+
+// Asserts that the bytes held in the mock Spark budget equal the summed capacities
+// of the aggregate pools of all live VeloxMemoryManagers.
+void assertCapacitiesMatch(MockSparkTaskMemoryManager& tmm, std::vector<std::unique_ptr<VeloxMemoryManager>>& vmms) {
+  uint64_t sum = 0;
+  for (const auto& vmm : vmms) {
+    if (vmm == nullptr) {
+      continue;
+    }
+    sum += vmm->getAggregateMemoryPool()->capacity();
+  }
+  ASSERT_EQ(tmm.currentBytes(), sum);
+}
+} // namespace
+
+class MultiMemoryManagerTest : public ::testing::Test {
+ protected:
+  static void SetUpTestCase() {
+    std::unordered_map<std::string, std::string> conf = {
+        {kMemoryReservationBlockSize, std::to_string(kMemoryReservationBlockSizeDefault)},
+        {kVeloxMemInitCapacity, std::to_string(kVeloxMemInitCapacityDefault)}};
+    gluten::VeloxBackend::create(conf);
+  }
+
+  std::unique_ptr<VeloxMemoryManager> newVeloxMemoryManager(std::unique_ptr<AllocationListener> listener) {
+    return std::make_unique<VeloxMemoryManager>(std::move(listener));
+  }
+};
+
+// Many threads, each owning a VeloxMemoryManager, allocate against one shared mock
+// Spark budget far smaller than their combined demand, forcing spills across managers.
+// Spark-side and Velox-side accounting must agree throughout, and all memory must be
+// returned once every manager has been shrunk.
+TEST_F(MultiMemoryManagerTest, spill) {
+  const uint64_t maxBytes = 200 << 20;
+  const uint32_t numThreads = 100;
+  const uint32_t numAllocations = 200;
+  const int32_t allocateSize = 10 << 20;
+
+  MockSparkTaskMemoryManager tmm{maxBytes};
+  std::vector<std::unique_ptr<VeloxMemoryManager>> vmms{};
+  std::vector<std::thread> threads{};
+  std::vector<std::vector<void*>> buffs{};
+  for (size_t i = 0; i < numThreads; ++i) {
+    buffs.push_back({});
+    vmms.emplace_back(nullptr);
+  }
+
+  // Emulate a shared lock to avoid ABBA deadlock. Recursive: spills re-enter it.
+  std::recursive_mutex mutex;
+
+  for (size_t i = 0; i < numThreads; ++i) {
+    threads.emplace_back([this, i, allocateSize, &tmm, &vmms, &mutex, &buffs]() -> void {
+      auto wrapper = std::make_unique<AllocationListenerWrapper>(); // Set later.
+      auto* listener = wrapper.get();
+
+      facebook::velox::memory::MemoryPool* pool; // Set later.
+      {
+        std::unique_lock<std::recursive_mutex> l(mutex);
+        vmms[i] = newVeloxMemoryManager(std::move(wrapper));
+        pool = vmms[i]->getLeafMemoryPool().get();
+        pool->setReclaimer(std::make_unique<MockMemoryReclaimer>(buffs[i], allocateSize));
+        listener->set(tmm.newListener(
+            [](uint64_t bytes) -> uint64_t { return 0; },
+            [i, &vmms, &mutex](uint64_t bytes) -> uint64_t {
+              std::unique_lock<std::recursive_mutex> l(mutex);
+              return vmms[i]->getMemoryManager()->arbitrator()->shrinkCapacity(bytes);
+            }));
+      }
+      {
+        std::unique_lock<std::recursive_mutex> l(mutex);
+        for (size_t j = 0; j < numAllocations; ++j) {
+          assertCapacitiesMatch(tmm, vmms);
+          buffs[i].push_back(pool->allocate(allocateSize));
+          assertCapacitiesMatch(tmm, vmms);
+        }
+      }
+    });
+  }
+
+  for (auto& thread : threads) {
+    thread.join();
+  }
+
+  // Shrink every manager fully; widen before multiplying so the target cannot wrap.
+  const uint64_t shrinkTarget = static_cast<uint64_t>(allocateSize) * numAllocations;
+  for (auto& vmm : vmms) {
+    assertCapacitiesMatch(tmm, vmms);
+    vmm->getMemoryManager()->arbitrator()->shrinkCapacity(shrinkTarget);
+    assertCapacitiesMatch(tmm, vmms);
+  }
+
+  ASSERT_EQ(tmm.currentBytes(), 0);
+}
 } // namespace gluten


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to