This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
     new 077a48adb [GLUTEN-6736][VL] Phase 2: Minimize lock scope in ListenableArbitrator (#6783)
077a48adb is described below
commit 077a48adbe14d71d5cc3e5c1447ea6f6ecff7f33
Author: Hongze Zhang <[email protected]>
AuthorDate: Mon Aug 12 17:02:13 2024 +0800
    [GLUTEN-6736][VL] Phase 2: Minimize lock scope in ListenableArbitrator (#6783)
Closes #6736
---
cpp/velox/memory/VeloxMemoryManager.cc | 36 ++---
cpp/velox/tests/MemoryManagerTest.cc | 270 +++++++++++++++++++++++++++++++++
2 files changed, 288 insertions(+), 18 deletions(-)
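
[Editor's note] The gist of the change: growCapacity() and shrinkCapacity() previously held the arbitrator mutex across the whole arbitration, including pool reclamation that can call back into the arbitrator, which is why the mutex had to be recursive. After this patch the lock is held only long enough to snapshot the single candidate root pool, and the grow/shrink work runs outside the critical section. A minimal compilable sketch of that pattern, with hypothetical types (not the actual Gluten classes):

    #include <mutex>
    #include <unordered_map>

    struct Pool {
      void reclaim() { /* may call back into Arbitrator::shrink() */ }
    };

    class Arbitrator {
     public:
      void addCandidate(Pool* pool) {
        std::unique_lock<std::mutex> guard{mutex_};
        candidates_.emplace(pool, true);
      }

      void shrink() {
        Pool* candidate = nullptr;
        {
          // Hold the lock only while reading shared state.
          std::unique_lock<std::mutex> guard{mutex_};
          candidate = candidates_.begin()->first;
        }
        // The potentially re-entrant work runs unlocked, so a plain
        // std::mutex suffices and re-entry cannot self-deadlock.
        candidate->reclaim();
      }

     private:
      std::mutex mutex_;
      std::unordered_map<Pool*, bool> candidates_;
    };

    int main() {
      Pool pool;
      Arbitrator arbitrator;
      arbitrator.addCandidate(&pool);
      arbitrator.shrink();
      return 0;
    }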
diff --git a/cpp/velox/memory/VeloxMemoryManager.cc b/cpp/velox/memory/VeloxMemoryManager.cc
index 3f30d8627..5b2ba258a 100644
--- a/cpp/velox/memory/VeloxMemoryManager.cc
+++ b/cpp/velox/memory/VeloxMemoryManager.cc
@@ -91,33 +91,33 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
   bool growCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) override {
     velox::memory::ScopedMemoryArbitrationContext ctx(pool);
-    VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
-    auto candidate = candidates_.begin()->first;
+    velox::memory::MemoryPool* candidate;
+    {
+      std::unique_lock guard{mutex_};
+      VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
+      candidate = candidates_.begin()->first;
+    }
     VELOX_CHECK(pool->root() == candidate, "Illegal state in ListenableArbitrator");
-    std::lock_guard<std::recursive_mutex> l(mutex_);
-    growCapacityLocked(pool->root(), targetBytes);
+    growCapacity0(pool->root(), targetBytes);
     return true;
   }

   uint64_t shrinkCapacity(uint64_t targetBytes, bool allowSpill, bool allowAbort) override {
     velox::memory::ScopedMemoryArbitrationContext ctx((const velox::memory::MemoryPool*)nullptr);
     facebook::velox::exec::MemoryReclaimer::Stats status;
-    VELOX_CHECK_EQ(candidates_.size(), 1, "Gluten only has one root pool");
-    std::lock_guard<std::recursive_mutex> l(mutex_); // FIXME: Do we have recursive locking for this mutex?
-    auto pool = candidates_.begin()->first;
-    const uint64_t oldCapacity = pool->capacity();
+    velox::memory::MemoryPool* pool;
+    {
+      std::unique_lock guard{mutex_};
+      VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
+      pool = candidates_.begin()->first;
+    }
     pool->reclaim(targetBytes, 0, status); // ignore the output
-    shrinkPool(pool, 0);
-    const uint64_t newCapacity = pool->capacity();
-    uint64_t total = oldCapacity - newCapacity;
-    listener_->allocationChanged(-total);
-    return total;
+    return shrinkCapacity0(pool, 0);
   }

   uint64_t shrinkCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) override {
-    std::lock_guard<std::recursive_mutex> l(mutex_);
-    return shrinkCapacityLocked(pool, targetBytes);
+    return shrinkCapacity0(pool, targetBytes);
   }
Stats stats() const override {
@@ -130,7 +130,7 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
   }

  private:
-  void growCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
+  void growCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
     // Since
     // https://github.com/facebookincubator/velox/pull/9557/files#diff-436e44b7374032f8f5d7eb45869602add6f955162daa2798d01cc82f8725724dL812-L820,
     // We should pass bytes as parameter "reservationBytes" when calling ::grow.
@@ -152,7 +152,7 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
         pool->toString())
   }

-  uint64_t shrinkCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
+  uint64_t shrinkCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
     uint64_t freeBytes = shrinkPool(pool, bytes);
     listener_->allocationChanged(-freeBytes);
     return freeBytes;
@@ -162,7 +162,7 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
   const uint64_t memoryPoolInitialCapacity_; // FIXME: Unused.
   const uint64_t memoryPoolTransferCapacity_;

-  mutable std::recursive_mutex mutex_;
+  mutable std::mutex mutex_;
   inline static std::string kind_ = "GLUTEN";
   std::unordered_map<velox::memory::MemoryPool*, std::weak_ptr<velox::memory::MemoryPool>> candidates_;
 };
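
[Editor's note] A side effect of the narrowed critical sections is that mutex_ can be downgraded from std::recursive_mutex to std::mutex, resolving the FIXME in the removed line above. A standalone illustration of the hazard the recursive mutex guarded against, with hypothetical functions (not Gluten code): locking a std::mutex a second time on the owning thread is undefined behavior, so any lock held across a callback that may re-enter the arbitrator forces recursion-tolerant locking.

    #include <mutex>

    std::recursive_mutex m;

    void shrink();

    // Reclamation can call back into the arbitrator.
    void reclaim() {
      shrink();
    }

    void shrink() {
      // With std::mutex, this second acquisition on the same thread would
      // be undefined behavior (typically a deadlock); a recursive_mutex
      // tolerates it.
      std::lock_guard<std::recursive_mutex> l(m);
    }

    int main() {
      // Outer lock held across the re-entrant call, as growCapacity once did.
      std::lock_guard<std::recursive_mutex> l(m);
      reclaim();
      return 0;
    }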
diff --git a/cpp/velox/tests/MemoryManagerTest.cc b/cpp/velox/tests/MemoryManagerTest.cc
index d86bd46e2..bb102dc2d 100644
--- a/cpp/velox/tests/MemoryManagerTest.cc
+++ b/cpp/velox/tests/MemoryManagerTest.cc
@@ -128,4 +128,274 @@ TEST_F(MemoryManagerTest, memoryAllocatorWithBlockReservation) {
ASSERT_EQ(allocator_->getBytes(), 0);
}
+namespace {
+class AllocationListenerWrapper : public AllocationListener {
+ public:
+ explicit AllocationListenerWrapper() {}
+
+ void set(AllocationListener* const delegate) {
+ if (delegate_ != nullptr) {
+ throw std::runtime_error("Invalid state");
+ }
+ delegate_ = delegate;
+ }
+
+ void allocationChanged(int64_t diff) override {
+ delegate_->allocationChanged(diff);
+ }
+ int64_t currentBytes() override {
+ return delegate_->currentBytes();
+ }
+ int64_t peakBytes() override {
+ return delegate_->peakBytes();
+ }
+
+ private:
+ AllocationListener* delegate_{nullptr};
+};
+
+class SpillableAllocationListener : public AllocationListener {
+ public:
+ virtual uint64_t shrink(uint64_t bytes) = 0;
+ virtual uint64_t spill(uint64_t bytes) = 0;
+};
+
+class MockSparkTaskMemoryManager {
+ public:
+ explicit MockSparkTaskMemoryManager(const uint64_t maxBytes);
+
+  AllocationListener* newListener(std::function<uint64_t(uint64_t)> shrink, std::function<uint64_t(uint64_t)> spill);
+
+ uint64_t acquire(uint64_t bytes);
+ void release(uint64_t bytes);
+ uint64_t currentBytes() {
+ return currentBytes_;
+ }
+
+ private:
+ mutable std::recursive_mutex mutex_;
+ std::vector<std::unique_ptr<SpillableAllocationListener>> listeners_{};
+
+ const uint64_t maxBytes_;
+ uint64_t currentBytes_{0L};
+};
+
+class MockSparkAllocationListener : public SpillableAllocationListener {
+ public:
+ explicit MockSparkAllocationListener(
+ MockSparkTaskMemoryManager* const manager,
+ std::function<uint64_t(uint64_t)> shrink,
+ std::function<uint64_t(uint64_t)> spill)
+ : manager_(manager), shrink_(shrink), spill_(spill) {}
+
+ void allocationChanged(int64_t diff) override {
+ if (diff == 0) {
+ return;
+ }
+ if (diff > 0) {
+ auto granted = manager_->acquire(diff);
+ if (granted < diff) {
+ throw std::runtime_error("OOM");
+ }
+ currentBytes_ += granted;
+ return;
+ }
+ manager_->release(-diff);
+ currentBytes_ -= (-diff);
+ }
+
+ uint64_t shrink(uint64_t bytes) override {
+ return shrink_(bytes);
+ }
+
+ uint64_t spill(uint64_t bytes) override {
+ return spill_(bytes);
+ }
+
+ int64_t currentBytes() override {
+ return currentBytes_;
+ }
+
+ private:
+ MockSparkTaskMemoryManager* const manager_;
+ std::function<uint64_t(uint64_t)> shrink_;
+ std::function<uint64_t(uint64_t)> spill_;
+ std::atomic<uint64_t> currentBytes_{0L};
+};
+
+MockSparkTaskMemoryManager::MockSparkTaskMemoryManager(const uint64_t maxBytes) : maxBytes_(maxBytes) {}
+
+AllocationListener* MockSparkTaskMemoryManager::newListener(
+ std::function<uint64_t(uint64_t)> shrink,
+ std::function<uint64_t(uint64_t)> spill) {
+  listeners_.push_back(std::make_unique<MockSparkAllocationListener>(this, shrink, spill));
+ return listeners_.back().get();
+}
+
+uint64_t MockSparkTaskMemoryManager::acquire(uint64_t bytes) {
+ std::unique_lock l(mutex_);
+ auto freeBytes = maxBytes_ - currentBytes_;
+ if (bytes <= freeBytes) {
+ currentBytes_ += bytes;
+ return bytes;
+ }
+ // Shrink listeners.
+ int64_t bytesNeeded = bytes - freeBytes;
+ for (const auto& listener : listeners_) {
+ bytesNeeded -= listener->shrink(bytesNeeded);
+ if (bytesNeeded < 0) {
+ break;
+ }
+ }
+ if (bytesNeeded > 0) {
+ for (const auto& listener : listeners_) {
+ bytesNeeded -= listener->spill(bytesNeeded);
+ if (bytesNeeded < 0) {
+ break;
+ }
+ }
+ }
+
+ if (bytesNeeded > 0) {
+ uint64_t granted = bytes - bytesNeeded;
+ currentBytes_ += granted;
+ return granted;
+ }
+
+ currentBytes_ += bytes;
+ return bytes;
+}
+
+void MockSparkTaskMemoryManager::release(uint64_t bytes) {
+ std::unique_lock l(mutex_);
+ currentBytes_ -= bytes;
+}
+
+class MockMemoryReclaimer : public facebook::velox::memory::MemoryReclaimer {
+ public:
+  explicit MockMemoryReclaimer(std::vector<void*>& buffs, int32_t size) : buffs_(buffs), size_(size) {}
+
+  bool reclaimableBytes(const memory::MemoryPool& pool, uint64_t& reclaimableBytes) const override {
+ uint64_t total = 0;
+ for (const auto& buf : buffs_) {
+ if (buf == nullptr) {
+ continue;
+ }
+ total += size_;
+ }
+ if (total == 0) {
+ return false;
+ }
+ reclaimableBytes = total;
+ return true;
+ }
+
+  uint64_t reclaim(memory::MemoryPool* pool, uint64_t targetBytes, uint64_t maxWaitMs, Stats& stats) override {
+ uint64_t total = 0;
+ for (auto& buf : buffs_) {
+ if (buf == nullptr) {
+ // When:
+      // 1. Called by allocation from the same pool so buff is not allocated yet.
+ // 2. Already called once.
+ continue;
+ }
+ pool->free(buf, size_);
+ buf = nullptr;
+ total += size_;
+ }
+ return total;
+ }
+
+ private:
+ std::vector<void*>& buffs_;
+ int32_t size_;
+};
+
+void assertCapacitiesMatch(MockSparkTaskMemoryManager& tmm, std::vector<std::unique_ptr<VeloxMemoryManager>>& vmms) {
+ uint64_t sum = 0;
+ for (const auto& vmm : vmms) {
+ if (vmm == nullptr) {
+ continue;
+ }
+ sum += vmm->getAggregateMemoryPool()->capacity();
+ }
+ if (tmm.currentBytes() != sum) {
+ ASSERT_EQ(tmm.currentBytes(), sum);
+ }
+}
+} // namespace
+
+class MultiMemoryManagerTest : public ::testing::Test {
+ protected:
+ static void SetUpTestCase() {
+ std::unordered_map<std::string, std::string> conf = {
+        {kMemoryReservationBlockSize, std::to_string(kMemoryReservationBlockSizeDefault)},
+ {kVeloxMemInitCapacity, std::to_string(kVeloxMemInitCapacityDefault)}};
+ gluten::VeloxBackend::create(conf);
+ }
+
+  std::unique_ptr<VeloxMemoryManager> newVeloxMemoryManager(std::unique_ptr<AllocationListener> listener) {
+ return std::make_unique<VeloxMemoryManager>(std::move(listener));
+ }
+};
+
+TEST_F(MultiMemoryManagerTest, spill) {
+ const uint64_t maxBytes = 200 << 20;
+ const uint32_t numThreads = 100;
+ const uint32_t numAllocations = 200;
+ const int32_t allocateSize = 10 << 20;
+
+ MockSparkTaskMemoryManager tmm{maxBytes};
+ std::vector<std::unique_ptr<VeloxMemoryManager>> vmms{};
+ std::vector<std::thread> threads{};
+ std::vector<std::vector<void*>> buffs{};
+ for (size_t i = 0; i < numThreads; ++i) {
+ buffs.push_back({});
+ vmms.emplace_back(nullptr);
+ }
+
+ // Emulate a shared lock to avoid ABBA deadlock.
+ std::recursive_mutex mutex;
+
+ for (size_t i = 0; i < numThreads; ++i) {
+    threads.emplace_back([this, i, allocateSize, &tmm, &vmms, &mutex, &buffs]() -> void {
+      auto wrapper = std::make_unique<AllocationListenerWrapper>(); // Set later.
+ auto* listener = wrapper.get();
+
+ facebook::velox::memory::MemoryPool* pool; // Set later.
+ {
+ std::unique_lock<std::recursive_mutex> l(mutex);
+ vmms[i] = newVeloxMemoryManager(std::move(wrapper));
+ pool = vmms[i]->getLeafMemoryPool().get();
+        pool->setReclaimer(std::make_unique<MockMemoryReclaimer>(buffs[i], allocateSize));
+ listener->set(tmm.newListener(
+ [](uint64_t bytes) -> uint64_t { return 0; },
+ [i, &vmms, &mutex](uint64_t bytes) -> uint64_t {
+ std::unique_lock<std::recursive_mutex> l(mutex);
+            return vmms[i]->getMemoryManager()->arbitrator()->shrinkCapacity(bytes);
+ }));
+ }
+ {
+ std::unique_lock<std::recursive_mutex> l(mutex);
+ for (size_t j = 0; j < numAllocations; ++j) {
+ assertCapacitiesMatch(tmm, vmms);
+ buffs[i].push_back(pool->allocate(allocateSize));
+ assertCapacitiesMatch(tmm, vmms);
+ }
+ }
+ });
+ }
+
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ for (auto& vmm : vmms) {
+ assertCapacitiesMatch(tmm, vmms);
+    vmm->getMemoryManager()->arbitrator()->shrinkCapacity(allocateSize * numAllocations);
+ assertCapacitiesMatch(tmm, vmms);
+ }
+
+ ASSERT_EQ(tmm.currentBytes(), 0);
+}
} // namespace gluten
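
[Editor's note] On the test's "Emulate a shared lock to avoid ABBA deadlock" comment: the allocation path can take the arbitrator's lock and then call into the task memory manager (allocate -> allocationChanged -> acquire), while the spill path takes the task memory manager's lock and then calls back into the arbitrator (acquire -> spill -> shrinkCapacity). A self-contained sketch of that ordering hazard, with hypothetical names (not the test code):

    #include <mutex>

    std::mutex arbitratorMu;  // stands in for ListenableArbitrator's mutex_
    std::mutex tmmMu;         // stands in for the Spark task memory manager's lock

    void allocatePath() {
      std::lock_guard<std::mutex> a(arbitratorMu);  // locks A then B
      std::lock_guard<std::mutex> b(tmmMu);
      // grow capacity, account bytes ...
    }

    void spillPath() {
      std::lock_guard<std::mutex> b(tmmMu);         // locks B then A: ABBA
      std::lock_guard<std::mutex> a(arbitratorMu);
      // shrink capacity ...
    }

    int main() {
      // Running allocatePath() and spillPath() concurrently can deadlock:
      // each thread holds its first lock and waits for the other's. The
      // test sidesteps the cycle by serializing both paths behind one
      // shared mutex; std::scoped_lock(arbitratorMu, tmmMu) would be
      // another fix, acquiring both locks in a deadlock-free order.
      allocatePath();
      spillPath();
      return 0;
    }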