This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 9131c1e07 [CELEBORN-1792] MemoryManager resume should use pinnedDirectMemory instead of usedDirectMemory
9131c1e07 is described below

commit 9131c1e07a321e4b54685c166b40ee0ecc762e68
Author: Xianming Lei <[email protected]>
AuthorDate: Wed Jan 22 14:30:20 2025 +0800

    [CELEBORN-1792] MemoryManager resume should use pinnedDirectMemory instead of usedDirectMemory
    
    ### What changes were proposed in this pull request?
    Congestion and MemoryManager should use pinnedDirectMemory instead of 
usedDirectMemory
    
    ### Why are the changes needed?
    In our production environment, after a worker pauses, usedDirectMemory
    stays high and does not decrease. The worker node is then permanently
    blacklisted and cannot be used.
    
    This problem has been bothering us for a long time. Even when the thread
    cache is turned off, **after ctx.channel().config().setAutoRead(false), the
    netty framework still holds some ByteBufs**. These ByteBufs prevent many
    PoolChunks from being released.
    
    In netty, if a chunk is 16M and 8k of it has been allocated, then the
    pinnedMemory is 8k and the activeMemory is 16M. The remaining (16M-8k) is
    available for allocation but has not been allocated yet. Because netty
    allocates and releases memory in chunk units, the 8k that has been allocated
    keeps the whole 16M from being returned to the operating system.
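    
    The accounting above can be sketched with a toy model. This is only an
    illustration, not netty's allocator: the class ChunkPoolModel and its
    methods are hypothetical names, and it assumes every allocation lands in a
    freshly reserved chunk.

```java
import java.util.ArrayList;
import java.util.List;

/** Toy model of chunk-granular pooling; NOT netty's real allocator. */
class ChunkPoolModel {
    private final long chunkSize;
    // Bytes currently handed out (pinned) from each reserved chunk.
    private final List<Long> pinnedPerChunk = new ArrayList<>();

    ChunkPoolModel(long chunkSize) {
        this.chunkSize = chunkSize;
    }

    /** Reserves a fresh chunk and pins {@code bytes} of it (assumption: one allocation per chunk). */
    void allocateInNewChunk(long bytes) {
        if (bytes <= 0 || bytes > chunkSize) {
            throw new IllegalArgumentException("bytes must be in (0, chunkSize]");
        }
        pinnedPerChunk.add(bytes);
    }

    /** Memory reserved in whole chunks, however little of each chunk is in use. */
    long activeBytes() {
        return chunkSize * pinnedPerChunk.size();
    }

    /** Memory actually handed out to live buffers. */
    long pinnedBytes() {
        long sum = 0;
        for (long p : pinnedPerChunk) {
            sum += p;
        }
        return sum;
    }

    /** Reserved but unallocated memory that new requests could still use. */
    long allocatableBytes() {
        return activeBytes() - pinnedBytes();
    }

    public static void main(String[] args) {
        // The 16M chunk with 8k allocated from the description above.
        ChunkPoolModel pool = new ChunkPoolModel(16L * 1024 * 1024);
        pool.allocateInNewChunk(8L * 1024);
        System.out.println("pinned=" + pool.pinnedBytes()
            + " active=" + pool.activeBytes()
            + " allocatable=" + pool.allocatableBytes());
    }
}
```

    In this model a single 8k buffer keeps a full 16M chunk counted as active,
    which is the gap between the two metrics that this change relies on.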
    
    Here are some scenes from our production/test environment:
    
    We configured 10 GB of off-heap memory for the worker; the other configs are as below:
    ```
    celeborn.network.memory.allocator.allowCache                         false
    celeborn.worker.monitor.memory.check.interval                         100ms
    celeborn.worker.monitor.memory.report.interval                        10s
    celeborn.worker.directMemoryRatioToPauseReceive                       0.75
    celeborn.worker.directMemoryRatioToPauseReplicate                     0.85
    celeborn.worker.directMemoryRatioToResume                             0.5
    ```
    
    When receiving high traffic, the worker's usedDirectMemory increases. After
    trim and pause are triggered, usedDirectMemory still does not drop below the
    resume threshold, and the worker is excluded.
    
    
![image](https://github.com/user-attachments/assets/40f5609e-fbf9-4841-84ec-69a69256edf4)
    
    So we checked the heap snapshot of the abnormal worker and saw a large
    number of DirectByteBuffers in the heap memory. These DirectByteBuffers are
    all 4 MB in size, which is exactly the chunk size. According to the path to
    GC root, the DirectByteBuffers are held by PoolChunks, and these 4 MB chunks
    have only 160k pinnedBytes.
    
    
![image](https://github.com/user-attachments/assets/3d755ef3-164c-4b5b-bec1-aaf039c0c0a5)
    
    
![image](https://github.com/user-attachments/assets/17907753-2f42-4617-a95e-1ee980553fb9)
    
    There are many ByteBufs that have not been released.
    
    
![image](https://github.com/user-attachments/assets/b87eb1a9-313f-4f42-baa8-227fd49c19b6)
    
    The stack shows that these ByteBufs are allocated by netty.
    
![image](https://github.com/user-attachments/assets/f8783f99-507a-44a8-9de5-7215a5eed1db)
    
    We tried to reproduce this situation in the test environment. When the same
    problem occurred, we added a RESTful API to the worker to force it to
    resume. After the resume, the worker returned to normal, and PushDataHandler
    handled many delayed requests.
    
    
![image](https://github.com/user-attachments/assets/be37039b-97b8-4ae8-a64f-a2003bea613e)
    
    
![image](https://github.com/user-attachments/assets/24b1c8ad-131c-4bd6-adcb-bad655cfbdbf)
    
    So I think that when pinnedMemory is not high enough, we should not trigger
    pause and congestion, because at that point a large part of the memory can
    still be allocated.
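    
    A minimal sketch of that reasoning follows. The class ResumeCheck and the
    used/pinned byte counts are hypothetical; the 0.5 ratio is the
    directMemoryRatioToResume configured above, and 0.3 is the default this
    patch gives celeborn.worker.pinnedMemoryRatioToResume. The enable flag and
    check interval that the patch also applies are left out.

```java
/** Sketch of the two resume checks; thresholds are passed in explicitly. */
class ResumeCheck {
    /** Resume rule based on used direct memory (whole chunks counted). */
    static boolean canResumeByUsedMemory(long usedBytes, long maxBytes, double directResumeRatio) {
        return usedBytes / (double) maxBytes < directResumeRatio;
    }

    /** Resume rule based on pinned memory (only bytes held by live buffers). */
    static boolean canResumeByPinnedMemory(long pinnedBytes, long maxBytes, double pinnedResumeRatio) {
        return pinnedBytes / (double) maxBytes < pinnedResumeRatio;
    }

    public static void main(String[] args) {
        long mib = 1024L * 1024;
        long gib = 1024L * mib;
        long max = 10 * gib;      // 10 GB off-heap, as in the setup above
        long used = 8 * gib;      // hypothetical: chunks still held after pause
        long pinned = 320 * mib;  // hypothetical: bytes actually in live ByteBufs

        // used/max = 0.8, which is not below 0.5, so the worker stays paused.
        System.out.println("resume by used memory:   "
            + canResumeByUsedMemory(used, max, 0.5));
        // pinned/max = 0.03125, which is below 0.3, so the worker may resume.
        System.out.println("resume by pinned memory: "
            + canResumeByPinnedMemory(pinned, max, 0.3));
    }
}
```

    With these assumed numbers the used-memory rule keeps the worker paused
    while the pinned-memory rule lets it resume.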
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing UTs.
    
    Closes #3018 from leixm/CELEBORN-1792.
    
    Authored-by: Xianming Lei <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../celeborn/common/network/util/NettyUtils.java   |  14 ++
 .../org/apache/celeborn/common/CelebornConf.scala  |  31 ++++
 docs/configuration/worker.md                       |   3 +
 .../deploy/worker/memory/MemoryManager.java        | 158 +++++++++++++++------
 .../service/deploy/memory/MemoryManagerSuite.scala |  67 ++++++++-
 5 files changed, 224 insertions(+), 49 deletions(-)

diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java 
b/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java
index 596b80785..55a018914 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java
@@ -17,8 +17,10 @@
 
 package org.apache.celeborn.common.network.util;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ThreadFactory;
@@ -47,6 +49,8 @@ public class NettyUtils {
   private static final ByteBufAllocator[] _sharedByteBufAllocator = new 
ByteBufAllocator[2];
   private static final ConcurrentHashMap<String, Integer> allocatorsIndex =
       JavaUtils.newConcurrentHashMap();
+  private static final List<PooledByteBufAllocator> pooledByteBufAllocators = 
new ArrayList<>();
+
   /** Creates a new ThreadFactory which prefixes each thread with the given 
name. */
   public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
     return new DefaultThreadFactory(threadPoolPrefix, true);
@@ -141,6 +145,9 @@ public class NettyUtils {
       _sharedByteBufAllocator[index] =
           createByteBufAllocator(
               conf.networkMemoryAllocatorPooled(), true, allowCache, 
conf.networkAllocatorArenas());
+      if (conf.networkMemoryAllocatorPooled()) {
+        pooledByteBufAllocators.add((PooledByteBufAllocator) 
_sharedByteBufAllocator[index]);
+      }
       if (source != null) {
         new NettyMemoryMetrics(
             _sharedByteBufAllocator[index],
@@ -178,6 +185,9 @@ public class NettyUtils {
             conf.preferDirectBufs(),
             allowCache,
             arenas);
+    if (conf.getCelebornConf().networkMemoryAllocatorPooled()) {
+      pooledByteBufAllocators.add((PooledByteBufAllocator) allocator);
+    }
     if (source != null) {
       String poolName = "default-netty-pool";
       Map<String, String> labels = new HashMap<>();
@@ -196,4 +206,8 @@ public class NettyUtils {
     }
     return allocator;
   }
+
+  public static List<PooledByteBufAllocator> getAllPooledByteBufAllocators() {
+    return pooledByteBufAllocators;
+  }
 }
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 791c6fc2f..f18cc272f 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1278,9 +1278,12 @@ class CelebornConf(loadDefaults: Boolean) extends 
Cloneable with Logging with Se
   def workerDirectMemoryRatioToPauseReplicate: Double =
     get(WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE)
   def workerDirectMemoryRatioToResume: Double = 
get(WORKER_DIRECT_MEMORY_RATIO_RESUME)
+  def workerPinnedMemoryRatioToResume: Double = 
get(WORKER_PINNED_MEMORY_RATIO_RESUME)
   def workerPartitionSorterDirectMemoryRatioThreshold: Double =
     get(WORKER_PARTITION_SORTER_DIRECT_MEMORY_RATIO_THRESHOLD)
   def workerDirectMemoryPressureCheckIntervalMs: Long = 
get(WORKER_DIRECT_MEMORY_CHECK_INTERVAL)
+  def workerPinnedMemoryCheckEnabled: Boolean = 
get(WORKER_PINNED_MEMORY_CHECK_ENABLED)
+  def workerPinnedMemoryCheckIntervalMs: Long = 
get(WORKER_PINNED_MEMORY_CHECK_INTERVAL)
   def workerDirectMemoryReportIntervalSecond: Long = 
get(WORKER_DIRECT_MEMORY_REPORT_INTERVAL)
   def workerDirectMemoryTrimChannelWaitInterval: Long =
     get(WORKER_DIRECT_MEMORY_TRIM_CHANNEL_WAIT_INTERVAL)
@@ -3711,6 +3714,24 @@ object CelebornConf extends Logging {
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("10ms")
 
+  val WORKER_PINNED_MEMORY_CHECK_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.worker.monitor.pinnedMemory.check.enabled")
+      .categories("worker")
+      .doc("If true, MemoryManager will check worker should resume by pinned 
memory used.")
+      .version("0.6.0")
+      .booleanConf
+      .createWithDefaultString("true")
+
+  val WORKER_PINNED_MEMORY_CHECK_INTERVAL: ConfigEntry[Long] =
+    buildConf("celeborn.worker.monitor.pinnedMemory.check.interval")
+      .categories("worker")
+      .doc("Interval of worker direct pinned memory checking, " +
+        "only takes effect when celeborn.network.memory.allocator.pooled and " 
+
+        "celeborn.worker.monitor.pinnedMemory.check.enabled are enabled.")
+      .version("0.6.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("10s")
+
   val WORKER_DIRECT_MEMORY_REPORT_INTERVAL: ConfigEntry[Long] =
     buildConf("celeborn.worker.monitor.memory.report.interval")
       .withAlternative("celeborn.worker.memory.reportInterval")
@@ -3860,6 +3881,16 @@ object CelebornConf extends Logging {
       .doubleConf
       .createWithDefault(0.7)
 
+  val WORKER_PINNED_MEMORY_RATIO_RESUME: ConfigEntry[Double] =
+    buildConf("celeborn.worker.pinnedMemoryRatioToResume")
+      .categories("worker")
+      .doc("If pinned memory usage is less than this limit, worker will 
resume, " +
+        "only takes effect when celeborn.network.memory.allocator.pooled and " 
+
+        "celeborn.worker.monitor.pinnedMemory.check.enabled are enabled")
+      .version("0.6.0")
+      .doubleConf
+      .createWithDefault(0.3)
+
   val WORKER_MEMORY_FILE_STORAGE_MAX_FILE_SIZE: ConfigEntry[Long] =
     buildConf("celeborn.worker.memoryFileStorage.maxFileSize")
       .categories("worker")
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index 14c7e791c..b4018b35e 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -144,9 +144,12 @@ license: |
 | celeborn.worker.monitor.memory.report.interval | 10s | false | Interval of 
worker direct memory tracker reporting to log. | 0.3.0 | 
celeborn.worker.memory.reportInterval | 
 | celeborn.worker.monitor.memory.trimChannelWaitInterval | 1s | false | Wait 
time after worker trigger channel to trim cache. | 0.3.0 |  | 
 | celeborn.worker.monitor.memory.trimFlushWaitInterval | 1s | false | Wait 
time after worker trigger StorageManger to flush data. | 0.3.0 |  | 
+| celeborn.worker.monitor.pinnedMemory.check.enabled | true | false | If true, 
MemoryManager will check worker should resume by pinned memory used. | 0.6.0 |  
| 
+| celeborn.worker.monitor.pinnedMemory.check.interval | 10s | false | Interval 
of worker direct pinned memory checking, only takes effect when 
celeborn.network.memory.allocator.pooled and 
celeborn.worker.monitor.pinnedMemory.check.enabled are enabled. | 0.6.0 |  | 
 | celeborn.worker.partition.initial.readBuffersMax | 1024 | false | Max number 
of initial read buffers | 0.3.0 |  | 
 | celeborn.worker.partition.initial.readBuffersMin | 1 | false | Min number of 
initial read buffers | 0.3.0 |  | 
 | celeborn.worker.partitionSorter.directMemoryRatioThreshold | 0.1 | false | 
Max ratio of partition sorter's memory for sorting, when reserved memory is 
higher than max partition sorter memory, partition sorter will stop sorting. If 
this value is set to 0, partition files sorter will skip memory check and 
ServingState check. | 0.2.0 |  | 
+| celeborn.worker.pinnedMemoryRatioToResume | 0.3 | false | If pinned memory 
usage is less than this limit, worker will resume, only takes effect when 
celeborn.network.memory.allocator.pooled and 
celeborn.worker.monitor.pinnedMemory.check.enabled are enabled | 0.6.0 |  | 
 | celeborn.worker.push.heartbeat.enabled | false | false | enable the 
heartbeat from worker to client when pushing data | 0.3.0 |  | 
 | celeborn.worker.push.io.threads | &lt;undefined&gt; | false | Netty IO 
thread number of worker to handle client push data. The default threads number 
is the number of flush thread. | 0.2.0 |  | 
 | celeborn.worker.push.port | 0 | false | Server port for Worker to receive 
push data request from ShuffleClient. | 0.2.0 |  | 
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
index 6db63598e..31d2e47e8 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
@@ -30,12 +30,14 @@ import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.util.internal.PlatformDependent;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.metrics.source.AbstractSource;
+import org.apache.celeborn.common.network.util.NettyUtils;
 import org.apache.celeborn.common.protocol.TransportModuleConstants;
 import org.apache.celeborn.common.util.ThreadUtils;
 import org.apache.celeborn.common.util.Utils;
@@ -50,7 +52,8 @@ public class MemoryManager {
   @VisibleForTesting public long maxDirectMemory;
   private final long pausePushDataThreshold;
   private final long pauseReplicateThreshold;
-  private final double resumeRatio;
+  private final double directMemoryResumeRatio;
+  private final double pinnedMemoryResumeRatio;
   private final long maxSortMemory;
   private final int forceAppendPauseSpentTimeThreshold;
   private final List<MemoryPressureListener> memoryPressureListeners = new 
ArrayList<>();
@@ -93,6 +96,9 @@ public class MemoryManager {
   private long memoryFileStorageThreshold;
   private final LongAdder memoryFileStorageCounter = new LongAdder();
   private final StorageManager storageManager;
+  private boolean pinnedMemoryCheckEnabled;
+  private long pinnedMemoryCheckInterval;
+  private long pinnedMemoryLastCheckTime = 0;
 
   @VisibleForTesting
   public static MemoryManager initialize(CelebornConf conf) {
@@ -120,11 +126,14 @@ public class MemoryManager {
   private MemoryManager(CelebornConf conf, StorageManager storageManager, 
AbstractSource source) {
     double pausePushDataRatio = conf.workerDirectMemoryRatioToPauseReceive();
     double pauseReplicateRatio = 
conf.workerDirectMemoryRatioToPauseReplicate();
-    this.resumeRatio = conf.workerDirectMemoryRatioToResume();
+    this.directMemoryResumeRatio = conf.workerDirectMemoryRatioToResume();
+    this.pinnedMemoryResumeRatio = conf.workerPinnedMemoryRatioToResume();
     double maxSortMemRatio = 
conf.workerPartitionSorterDirectMemoryRatioThreshold();
     double readBufferRatio = conf.workerDirectMemoryRatioForReadBuffer();
     double memoryFileStorageRatio = 
conf.workerDirectMemoryRatioForMemoryFilesStorage();
     long checkInterval = conf.workerDirectMemoryPressureCheckIntervalMs();
+    this.pinnedMemoryCheckEnabled = conf.workerPinnedMemoryCheckEnabled();
+    this.pinnedMemoryCheckInterval = conf.workerPinnedMemoryCheckIntervalMs();
     long reportInterval = conf.workerDirectMemoryReportIntervalSecond();
     double readBufferTargetRatio = conf.readBufferTargetRatio();
     long readBufferTargetUpdateInterval = 
conf.readBufferTargetUpdateInterval();
@@ -148,9 +157,10 @@ public class MemoryManager {
             pauseReplicateRatio,
             CelebornConf.WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE().key(),
             pausePushDataRatio));
-    Preconditions.checkArgument(pausePushDataRatio > resumeRatio);
+    Preconditions.checkArgument(pausePushDataRatio > directMemoryResumeRatio);
     if (memoryFileStorageRatio > 0) {
-      Preconditions.checkArgument(resumeRatio > (readBufferRatio + 
memoryFileStorageRatio));
+      Preconditions.checkArgument(
+          directMemoryResumeRatio > (readBufferRatio + 
memoryFileStorageRatio));
     }
 
     maxSortMemory = ((long) (maxDirectMemory * maxSortMemRatio));
@@ -275,14 +285,16 @@ public class MemoryManager {
             + "pause replication memory: {},  "
             + "read buffer memory limit: {} target: {}, "
             + "memory shuffle storage limit: {}, "
-            + "resume memory ratio: {}",
+            + "resume by direct memory ratio: {}, "
+            + "resume by pinned memory ratio: {}",
         Utils.bytesToString(maxDirectMemory),
         Utils.bytesToString(pausePushDataThreshold),
         Utils.bytesToString(pauseReplicateThreshold),
         Utils.bytesToString(readBufferThreshold),
         Utils.bytesToString(readBufferTarget),
         Utils.bytesToString(memoryFileStorageThreshold),
-        resumeRatio);
+        directMemoryResumeRatio,
+        pinnedMemoryResumeRatio);
   }
 
   public boolean shouldEvict(boolean aggressiveMemoryFileEvictEnabled, double 
evictRatio) {
@@ -305,7 +317,7 @@ public class MemoryManager {
       return ServingState.PUSH_PAUSED;
     }
     // trigger resume
-    if (memoryUsage / (double) (maxDirectMemory) < resumeRatio) {
+    if (memoryUsage / (double) (maxDirectMemory) < directMemoryResumeRatio) {
       isPaused = false;
       return ServingState.NONE_PAUSED;
     }
@@ -315,69 +327,70 @@ public class MemoryManager {
   }
 
   @VisibleForTesting
-  protected void switchServingState() {
+  public void switchServingState() {
     ServingState lastState = servingState;
     servingState = currentServingState();
-    if (lastState == servingState) {
-      if (servingState != ServingState.NONE_PAUSED) {
+    logger.info("Serving state transformed from {} to {}", lastState, 
servingState);
+    switch (servingState) {
+      case PUSH_PAUSED:
+        if (canResumeByPinnedMemory()) {
+          resumeByPinnedMemory(servingState);
+        } else {
+          pausePushDataCounter.increment();
+          if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
+            logger.info("Trigger action: RESUME REPLICATE");
+            resumeReplicate();
+          } else {
+            logger.info("Trigger action: PAUSE PUSH");
+            pausePushDataStartTime = System.currentTimeMillis();
+            memoryPressureListeners.forEach(
+                memoryPressureListener ->
+                    
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
+          }
+        }
         logger.debug("Trigger action: TRIM");
         trimCounter += 1;
-        // force to append pause spent time even we are in pause state
+        trimAllListeners();
         if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
           logger.debug(
               "Trigger action: TRIM for {} times, force to append pause spent 
time.", trimCounter);
           appendPauseSpentTime(servingState);
         }
-        trimAllListeners();
-      }
-      return;
-    }
-    logger.info("Serving state transformed from {} to {}", lastState, 
servingState);
-    switch (servingState) {
-      case PUSH_PAUSED:
-        pausePushDataCounter.increment();
-        if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
-          logger.info("Trigger action: RESUME REPLICATE");
-          memoryPressureListeners.forEach(
-              memoryPressureListener ->
-                  
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
-        } else if (lastState == ServingState.NONE_PAUSED) {
-          logger.info("Trigger action: PAUSE PUSH");
-          pausePushDataStartTime = System.currentTimeMillis();
-          memoryPressureListeners.forEach(
-              memoryPressureListener ->
-                  
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
-        }
-        trimAllListeners();
         break;
       case PUSH_AND_REPLICATE_PAUSED:
-        pausePushDataAndReplicateCounter.increment();
-        if (lastState == ServingState.NONE_PAUSED) {
+        if (canResumeByPinnedMemory()) {
+          resumeByPinnedMemory(servingState);
+        } else {
+          pausePushDataAndReplicateCounter.increment();
           logger.info("Trigger action: PAUSE PUSH");
           pausePushDataAndReplicateStartTime = System.currentTimeMillis();
           memoryPressureListeners.forEach(
               memoryPressureListener ->
                   
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
+          logger.info("Trigger action: PAUSE REPLICATE");
+          memoryPressureListeners.forEach(
+              memoryPressureListener ->
+                  
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
         }
-        logger.info("Trigger action: PAUSE REPLICATE");
-        memoryPressureListeners.forEach(
-            memoryPressureListener ->
-                
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
+        logger.debug("Trigger action: TRIM");
+        trimCounter += 1;
         trimAllListeners();
+        if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
+          logger.debug(
+              "Trigger action: TRIM for {} times, force to append pause spent 
time.", trimCounter);
+          appendPauseSpentTime(servingState);
+        }
         break;
       case NONE_PAUSED:
         // resume from paused mode, append pause spent time
-        appendPauseSpentTime(lastState);
         if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
-          logger.info("Trigger action: RESUME REPLICATE");
-          memoryPressureListeners.forEach(
-              memoryPressureListener ->
-                  
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
+          resumeReplicate();
+          resumePush();
+          appendPauseSpentTime(lastState);
+        } else if (lastState == ServingState.PUSH_PAUSED) {
+          resumePush();
+          appendPauseSpentTime(lastState);
         }
-        logger.info("Trigger action: RESUME PUSH");
-        memoryPressureListeners.forEach(
-            memoryPressureListener ->
-                
memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE));
     }
   }
 
@@ -436,6 +449,16 @@ public class MemoryManager {
     return getNettyUsedDirectMemory() + sortMemoryCounter.get();
   }
 
+  public long getPinnedMemory() {
+    return getNettyPinnedDirectMemory() + sortMemoryCounter.get();
+  }
+
+  public long getNettyPinnedDirectMemory() {
+    return NettyUtils.getAllPooledByteBufAllocators().stream()
+        .mapToLong(PooledByteBufAllocator::pinnedDirectMemory)
+        .sum();
+  }
+
   public AtomicLong getSortMemoryCounter() {
     return sortMemoryCounter;
   }
@@ -557,6 +580,47 @@ public class MemoryManager {
     _INSTANCE = null;
   }
 
+  private void resumeByPinnedMemory(ServingState servingState) {
+    switch (servingState) {
+      case PUSH_AND_REPLICATE_PAUSED:
+        logger.info(
+            "Serving State is PUSH_AND_REPLICATE_PAUSED, but resume by lower 
pinned memory {}",
+            getNettyPinnedDirectMemory());
+        resumeReplicate();
+        resumePush();
+      case PUSH_PAUSED:
+        logger.info(
+            "Serving State is PUSH_PAUSED, but resume by lower pinned memory 
{}",
+            getNettyPinnedDirectMemory());
+        resumePush();
+    }
+  }
+
+  private boolean canResumeByPinnedMemory() {
+    if (pinnedMemoryCheckEnabled
+        && System.currentTimeMillis() - pinnedMemoryLastCheckTime >= 
pinnedMemoryCheckInterval
+        && getPinnedMemory() / (double) (maxDirectMemory) < 
pinnedMemoryResumeRatio) {
+      pinnedMemoryLastCheckTime = System.currentTimeMillis();
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  private void resumePush() {
+    logger.info("Trigger action: RESUME PUSH");
+    memoryPressureListeners.forEach(
+        memoryPressureListener ->
+            
memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE));
+  }
+
+  private void resumeReplicate() {
+    logger.info("Trigger action: RESUME REPLICATE");
+    memoryPressureListeners.forEach(
+        memoryPressureListener ->
+            
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
+  }
+
   public interface MemoryPressureListener {
     void onPause(String moduleName);
 
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
index 78fc43670..c0fe08e61 100644
--- 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.celeborn.service.deploy.memory
 
 import scala.concurrent.duration.DurationInt
 
+import org.mockito.{Mockito, MockitoSugar}
 import org.scalatest.concurrent.Eventually.eventually
 import org.scalatest.concurrent.Futures.{interval, timeout}
 
@@ -27,8 +28,8 @@ import org.apache.celeborn.common.CelebornConf
 import 
org.apache.celeborn.common.CelebornConf.{WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE,
 WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE}
 import org.apache.celeborn.common.protocol.TransportModuleConstants
 import org.apache.celeborn.service.deploy.worker.memory.MemoryManager
-import 
org.apache.celeborn.service.deploy.worker.memory.MemoryManager.MemoryPressureListener
-import 
org.apache.celeborn.service.deploy.worker.memory.MemoryManager.ServingState
+import 
org.apache.celeborn.service.deploy.worker.memory.MemoryManager.{MemoryPressureListener,
 ServingState}
+
 class MemoryManagerSuite extends CelebornFunSuite {
 
   // reset the memory manager before each test
@@ -153,6 +154,68 @@ class MemoryManagerSuite extends CelebornFunSuite {
     assert(memoryManager.getPausePushDataAndReplicateTime.longValue() > 0)
   }
 
+  test("[CELEBORN-1792] Test MemoryManager resume by pinned memory") {
+    val conf = new CelebornConf()
+    conf.set(CelebornConf.WORKER_DIRECT_MEMORY_CHECK_INTERVAL.key, "300s")
+    conf.set(CelebornConf.WORKER_PINNED_MEMORY_CHECK_INTERVAL.key, "0")
+    MemoryManager.reset()
+    val memoryManager = MockitoSugar.spy(MemoryManager.initialize(conf))
+    val maxDirectorMemory = memoryManager.maxDirectMemory
+    val pushThreshold =
+      (conf.workerDirectMemoryRatioToPauseReceive * 
maxDirectorMemory).longValue()
+    val replicateThreshold =
+      (conf.workerDirectMemoryRatioToPauseReplicate * 
maxDirectorMemory).longValue()
+
+    val pushListener = new 
MockMemoryPressureListener(TransportModuleConstants.PUSH_MODULE)
+    val replicateListener =
+      new MockMemoryPressureListener(TransportModuleConstants.REPLICATE_MODULE)
+    memoryManager.registerMemoryListener(pushListener)
+    memoryManager.registerMemoryListener(replicateListener)
+
+    // NONE PAUSED -> PAUSE PUSH
+    Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(0L)
+    Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1)
+    memoryManager.switchServingState()
+    assert(!pushListener.isPause)
+    assert(!replicateListener.isPause)
+    assert(memoryManager.servingState == ServingState.PUSH_PAUSED)
+
+    // KEEP PAUSE PUSH
+    
Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(pushThreshold 
+ 1)
+    memoryManager.switchServingState()
+    assert(pushListener.isPause)
+    assert(!replicateListener.isPause)
+    assert(memoryManager.servingState == ServingState.PUSH_PAUSED)
+
+    Mockito.when(memoryManager.getMemoryUsage).thenReturn(0L)
+    memoryManager.switchServingState()
+    assert(!pushListener.isPause)
+    assert(!replicateListener.isPause)
+    assert(memoryManager.servingState == ServingState.NONE_PAUSED)
+
+    // NONE PAUSED -> PAUSE PUSH AND REPLICATE
+    Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(0L)
+    Mockito.when(memoryManager.getMemoryUsage).thenReturn(replicateThreshold + 
1)
+    memoryManager.switchServingState()
+    assert(!pushListener.isPause)
+    assert(!replicateListener.isPause)
+    assert(memoryManager.servingState == 
ServingState.PUSH_AND_REPLICATE_PAUSED)
+
+    // KEEP PAUSE PUSH AND REPLICATE
+    
Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(replicateThreshold
 + 1)
+    memoryManager.switchServingState()
+    assert(pushListener.isPause)
+    assert(replicateListener.isPause)
+    assert(memoryManager.servingState == 
ServingState.PUSH_AND_REPLICATE_PAUSED)
+
+    Mockito.when(memoryManager.getMemoryUsage).thenReturn(0L)
+    memoryManager.switchServingState()
+    assert(!pushListener.isPause)
+    assert(!replicateListener.isPause)
+    assert(memoryManager.servingState == ServingState.NONE_PAUSED)
+    MemoryManager.reset()
+  }
+
   class MockMemoryPressureListener(
       val belongModuleName: String,
       var isPause: Boolean = false) extends MemoryPressureListener {
