wForget commented on code in PR #7648:
URL: https://github.com/apache/incubator-gluten/pull/7648#discussion_r1816536045
##########
cpp/velox/tests/MemoryManagerTest.cc:
##########
@@ -129,4 +129,280 @@ TEST_F(MemoryManagerTest, memoryAllocatorWithBlockReservation) {
ASSERT_EQ(allocator_->getBytes(), 0);
}
+namespace {
+class AllocationListenerWrapper : public AllocationListener {
+ public:
+ explicit AllocationListenerWrapper() {}
+
+ void set(AllocationListener* const delegate) {
+ if (delegate_ != nullptr) {
+ throw std::runtime_error("Invalid state");
+ }
+ delegate_ = delegate;
+ }
+
+ void allocationChanged(int64_t diff) override {
+ if (delegate_ == nullptr) {
+ if (diff > 0) {
+ throw std::runtime_error("changed without delegate");
+ }
+ return;
+ }
+ delegate_->allocationChanged(diff);
+ }
+ int64_t currentBytes() override {
+ return delegate_->currentBytes();
+ }
+ int64_t peakBytes() override {
+ return delegate_->peakBytes();
+ }
+
+ private:
+ AllocationListener* delegate_{nullptr};
+};
+
+class SpillableAllocationListener : public AllocationListener {
+ public:
+ virtual uint64_t shrink(uint64_t bytes) = 0;
+ virtual uint64_t spill(uint64_t bytes) = 0;
+};
+
+class MockSparkTaskMemoryManager {
+ public:
+ explicit MockSparkTaskMemoryManager(const uint64_t maxBytes);
+
+ AllocationListener* newListener(std::function<uint64_t(uint64_t)> shrink,
std::function<uint64_t(uint64_t)> spill);
+
+ uint64_t acquire(uint64_t bytes);
+ void release(uint64_t bytes);
+ uint64_t currentBytes() {
+ return currentBytes_;
+ }
+
+ private:
+ mutable std::recursive_mutex mutex_;
+ std::vector<std::unique_ptr<SpillableAllocationListener>> listeners_{};
+
+ const uint64_t maxBytes_;
+ uint64_t currentBytes_{0L};
+};
+
+class MockSparkAllocationListener : public SpillableAllocationListener {
+ public:
+ explicit MockSparkAllocationListener(
+ MockSparkTaskMemoryManager* const manager,
+ std::function<uint64_t(uint64_t)> shrink,
+ std::function<uint64_t(uint64_t)> spill)
+ : manager_(manager), shrink_(shrink), spill_(spill) {}
+
+ void allocationChanged(int64_t diff) override {
+ if (diff == 0) {
+ return;
+ }
+ if (diff > 0) {
+ auto granted = manager_->acquire(diff);
+ if (granted < diff) {
+ throw std::runtime_error("OOM");
+ }
+ currentBytes_ += granted;
+ return;
+ }
+ manager_->release(-diff);
+ currentBytes_ -= (-diff);
+ }
+
+ uint64_t shrink(uint64_t bytes) override {
+ return shrink_(bytes);
+ }
+
+ uint64_t spill(uint64_t bytes) override {
+ return spill_(bytes);
+ }
+
+ int64_t currentBytes() override {
+ return currentBytes_;
+ }
+
+ private:
+ MockSparkTaskMemoryManager* const manager_;
+ std::function<uint64_t(uint64_t)> shrink_;
+ std::function<uint64_t(uint64_t)> spill_;
+ std::atomic<uint64_t> currentBytes_{0L};
+};
+
+MockSparkTaskMemoryManager::MockSparkTaskMemoryManager(const uint64_t
maxBytes) : maxBytes_(maxBytes) {}
+
+AllocationListener* MockSparkTaskMemoryManager::newListener(
+ std::function<uint64_t(uint64_t)> shrink,
+ std::function<uint64_t(uint64_t)> spill) {
+ listeners_.push_back(std::make_unique<MockSparkAllocationListener>(this,
shrink, spill));
+ return listeners_.back().get();
+}
+
+uint64_t MockSparkTaskMemoryManager::acquire(uint64_t bytes) {
+ std::unique_lock l(mutex_);
+ auto freeBytes = maxBytes_ - currentBytes_;
+ if (bytes <= freeBytes) {
+ currentBytes_ += bytes;
+ return bytes;
+ }
+ // Shrink listeners.
+ int64_t bytesNeeded = bytes - freeBytes;
+ for (const auto& listener : listeners_) {
+ bytesNeeded -= listener->shrink(bytesNeeded);
+ if (bytesNeeded < 0) {
+ break;
+ }
+ }
+ if (bytesNeeded > 0) {
+ for (const auto& listener : listeners_) {
+ bytesNeeded -= listener->spill(bytesNeeded);
+ if (bytesNeeded < 0) {
+ break;
+ }
+ }
+ }
+
+ if (bytesNeeded > 0) {
+ uint64_t granted = bytes - bytesNeeded;
+ currentBytes_ += granted;
+ return granted;
+ }
+
+ currentBytes_ += bytes;
+ return bytes;
+}
+
+void MockSparkTaskMemoryManager::release(uint64_t bytes) {
+ std::unique_lock l(mutex_);
+ currentBytes_ -= bytes;
+}
+
+class MockMemoryReclaimer : public facebook::velox::memory::MemoryReclaimer {
+ public:
+ explicit MockMemoryReclaimer(std::vector<void*>& buffs, int32_t size) :
buffs_(buffs), size_(size) {}
+
+ bool reclaimableBytes(const memory::MemoryPool& pool, uint64_t&
reclaimableBytes) const override {
+ uint64_t total = 0;
+ for (const auto& buf : buffs_) {
+ if (buf == nullptr) {
+ continue;
+ }
+ total += size_;
+ }
+ if (total == 0) {
+ return false;
+ }
+ reclaimableBytes = total;
+ return true;
+ }
+
+ uint64_t reclaim(memory::MemoryPool* pool, uint64_t targetBytes, uint64_t
maxWaitMs, Stats& stats) override {
+ uint64_t total = 0;
+ for (auto& buf : buffs_) {
+ if (buf == nullptr) {
+ // When:
+ // 1. Called by allocation from the same pool so buff is not allocated yet.
+ // 2. Already called once.
+ continue;
+ }
+ pool->free(buf, size_);
+ buf = nullptr;
+ total += size_;
+ }
+ return total;
+ }
+
+ private:
+ std::vector<void*>& buffs_;
+ int32_t size_;
+};
+
+void assertCapacitiesMatch(MockSparkTaskMemoryManager& tmm,
std::vector<std::unique_ptr<VeloxMemoryManager>>& vmms) {
+ uint64_t sum = 0;
+ for (const auto& vmm : vmms) {
+ if (vmm == nullptr) {
+ continue;
+ }
+ sum += vmm->getAggregateMemoryPool()->capacity();
+ }
+ if (tmm.currentBytes() != sum) {
+ ASSERT_EQ(tmm.currentBytes(), sum);
+ }
+}
+} // namespace
+
+class MultiMemoryManagerTest : public ::testing::Test {
+ protected:
+ static void SetUpTestCase() {
+ std::unordered_map<std::string, std::string> conf = {
+ {kMemoryReservationBlockSize,
std::to_string(kMemoryReservationBlockSizeDefault)},
+ {kVeloxMemInitCapacity, std::to_string(0)}};
Review Comment:
Set the veloxMemInitCapacity value to `0`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]