This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 33d2f2d31 [VL] Following #6526, minor fixes and improvements (#6554)
33d2f2d31 is described below

commit 33d2f2d31ec223303cdaf38c88d8478ded36dbf2
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed Jul 24 13:44:38 2024 +0800

    [VL] Following #6526, minor fixes and improvements (#6554)
---
 cpp/core/jni/JniCommon.h             | 22 +++++++++++------
 cpp/core/memory/AllocationListener.h | 47 +++++++++++++++++++-----------------
 cpp/core/memory/MemoryAllocator.cc   | 13 ++++++++--
 3 files changed, 51 insertions(+), 31 deletions(-)

diff --git a/cpp/core/jni/JniCommon.h b/cpp/core/jni/JniCommon.h
index 8f8002b2c..1d784f3a5 100644
--- a/cpp/core/jni/JniCommon.h
+++ b/cpp/core/jni/JniCommon.h
@@ -403,17 +403,25 @@ class SparkAllocationListener final : public gluten::AllocationListener {
       env->CallLongMethod(jListenerGlobalRef_, jReserveMethod_, size);
       checkException(env);
     }
-    // atomic operation is enough here, no need to use mutex
-    bytesReserved_.fetch_add(size);
-    maxBytesReserved_.store(std::max(bytesReserved_.load(), 
maxBytesReserved_.load()));
+    usedBytes_ += size;
+    while (true) {
+      int64_t savedPeakBytes = peakBytes_;
+      if (usedBytes_ <= savedPeakBytes) {
+        break;
+      }
+      // usedBytes_ > savedPeakBytes, update peak
+      if (peakBytes_.compare_exchange_weak(savedPeakBytes, usedBytes_)) {
+        break;
+      }
+    }
   }
 
   int64_t currentBytes() override {
-    return bytesReserved_;
+    return usedBytes_;
   }
 
   int64_t peakBytes() override {
-    return maxBytesReserved_;
+    return peakBytes_;
   }
 
  private:
@@ -421,8 +429,8 @@ class SparkAllocationListener final : public gluten::AllocationListener {
   jobject jListenerGlobalRef_;
   const jmethodID jReserveMethod_;
   const jmethodID jUnreserveMethod_;
-  std::atomic_int64_t bytesReserved_{0L};
-  std::atomic_int64_t maxBytesReserved_{0L};
+  std::atomic_int64_t usedBytes_{0L};
+  std::atomic_int64_t peakBytes_{0L};
 };
 
 class BacktraceAllocationListener final : public gluten::AllocationListener {
diff --git a/cpp/core/memory/AllocationListener.h b/cpp/core/memory/AllocationListener.h
index 695552cef..41797641f 100644
--- a/cpp/core/memory/AllocationListener.h
+++ b/cpp/core/memory/AllocationListener.h
@@ -50,32 +50,18 @@ class AllocationListener {
 // The class must be thread safe
 class BlockAllocationListener final : public AllocationListener {
  public:
-  BlockAllocationListener(AllocationListener* delegated, uint64_t blockSize)
+  BlockAllocationListener(AllocationListener* delegated, int64_t blockSize)
       : delegated_(delegated), blockSize_(blockSize) {}
 
   void allocationChanged(int64_t diff) override {
     if (diff == 0) {
       return;
     }
-    std::unique_lock<std::mutex> guard{mutex_};
-    if (diff > 0) {
-      if (reservationBytes_ - usedBytes_ < diff) {
-        auto roundSize = (diff + (blockSize_ - 1)) / blockSize_ * blockSize_;
-        reservationBytes_ += roundSize;
-        peakBytes_ = std::max(peakBytes_, reservationBytes_);
-        guard.unlock();
-        // unnecessary to lock the delegated listener, assume it's thread safe
-        delegated_->allocationChanged(roundSize);
-      }
-      usedBytes_ += diff;
-    } else {
-      usedBytes_ += diff;
-      auto unreservedSize = (reservationBytes_ - usedBytes_) / blockSize_ * 
blockSize_;
-      reservationBytes_ -= unreservedSize;
-      guard.unlock();
-      // unnecessary to lock the delegated listener
-      delegated_->allocationChanged(-unreservedSize);
+    int64_t granted = reserve(diff);
+    if (granted == 0) {
+      return;
     }
+    delegated_->allocationChanged(granted);
   }
 
   int64_t currentBytes() override {
@@ -87,11 +73,28 @@ class BlockAllocationListener final : public AllocationListener {
   }
 
  private:
+  inline int64_t reserve(int64_t diff) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    usedBytes_ += diff;
+    int64_t newBlockCount;
+    if (usedBytes_ == 0) {
+      newBlockCount = 0;
+    } else {
+      // ceil to get the required block number
+      newBlockCount = (usedBytes_ - 1) / blockSize_ + 1;
+    }
+    int64_t bytesGranted = (newBlockCount - blocksReserved_) * blockSize_;
+    blocksReserved_ = newBlockCount;
+    peakBytes_ = std::max(peakBytes_, usedBytes_);
+    return bytesGranted;
+  }
+
   AllocationListener* const delegated_;
   const uint64_t blockSize_;
-  uint64_t usedBytes_{0L};
-  uint64_t peakBytes_{0L};
-  uint64_t reservationBytes_{0L};
+  int64_t blocksReserved_{0L};
+  int64_t usedBytes_{0L};
+  int64_t peakBytes_{0L};
+  int64_t reservationBytes_{0L};
 
   mutable std::mutex mutex_;
 };
diff --git a/cpp/core/memory/MemoryAllocator.cc b/cpp/core/memory/MemoryAllocator.cc
index 01818636a..c637c6a9c 100644
--- a/cpp/core/memory/MemoryAllocator.cc
+++ b/cpp/core/memory/MemoryAllocator.cc
@@ -92,8 +92,17 @@ int64_t ListenableMemoryAllocator::peakBytes() const {
 
 void ListenableMemoryAllocator::updateUsage(int64_t size) {
   listener_->allocationChanged(size);
-  usedBytes_.fetch_add(size);
-  peakBytes_.store(std::max(peakBytes_.load(), usedBytes_.load()));
+  usedBytes_ += size;
+  while (true) {
+    int64_t savedPeakBytes = peakBytes_;
+    if (usedBytes_ <= savedPeakBytes) {
+      break;
+    }
+    // usedBytes_ > savedPeakBytes, update peak
+    if (peakBytes_.compare_exchange_weak(savedPeakBytes, usedBytes_)) {
+      break;
+    }
+  }
 }
 
 bool StdMemoryAllocator::allocate(int64_t size, void** out) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to