This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 196ad607c [CELEBORN-1792][FOLLOWUP] Keep resume for a while after resumeByPinnedMemory
196ad607c is described below

commit 196ad607cd62af83b1ace887b8eb91d548fc36ac
Author: TheodoreLx <[email protected] >
AuthorDate: Wed Mar 5 09:37:59 2025 +0800

    [CELEBORN-1792][FOLLOWUP] Keep resume for a while after resumeByPinnedMemory
    
    ### What changes were proposed in this pull request?
    
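    In `switchServingState`, after `resumeByPinnedMemory` has resumed the
    channels, keep them resumed for a while instead of pausing them again on the
    next check. The channels stay resumed for up to
    `celeborn.worker.monitor.pinnedMemory.resumeKeepTime` (default 1s), as long
    as pinned memory stays below the resume ratio. This prevents channels from
    repeatedly resuming and pausing before memory usage drops below
    `pausePushDataThreshold`.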
    
    ### Why are the changes needed?
    
    Frequently resuming and pausing channels slows down data reception and
    prevents memory usage from quickly dropping below `pausePushDataThreshold`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    Unit test: added a new test case to `MemoryManagerSuite`.
    
    Closes #3099 from TheodoreLx/keep-resume.
    
    Lead-authored-by: TheodoreLx <[email protected] >
    Co-authored-by: 慧枫 <[email protected]>
    Co-authored-by: Zhengqi Zhang <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../org/apache/celeborn/common/CelebornConf.scala  |  9 +++
 docs/configuration/worker.md                       |  1 +
 .../deploy/worker/memory/MemoryManager.java        | 76 ++++++++++++++--------
 .../service/deploy/memory/MemoryManagerSuite.scala | 58 +++++++++++++++++
 4 files changed, 117 insertions(+), 27 deletions(-)

diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 9d6d2e445..00c25b379 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1298,6 +1298,7 @@ class CelebornConf(loadDefaults: Boolean) extends 
Cloneable with Logging with Se
   def workerDirectMemoryPressureCheckIntervalMs: Long = 
get(WORKER_DIRECT_MEMORY_CHECK_INTERVAL)
   def workerPinnedMemoryCheckEnabled: Boolean = 
get(WORKER_PINNED_MEMORY_CHECK_ENABLED)
   def workerPinnedMemoryCheckIntervalMs: Long = 
get(WORKER_PINNED_MEMORY_CHECK_INTERVAL)
+  def workerPinnedMemoryResumeKeepTime: Long = 
get(WORKER_PINNED_MEMORY_RESUME_KEEP_TIME)
   def workerDirectMemoryReportIntervalSecond: Long = 
get(WORKER_DIRECT_MEMORY_REPORT_INTERVAL)
   def workerDirectMemoryTrimChannelWaitInterval: Long =
     get(WORKER_DIRECT_MEMORY_TRIM_CHANNEL_WAIT_INTERVAL)
@@ -3780,6 +3781,14 @@ object CelebornConf extends Logging {
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("10s")
 
+  val WORKER_PINNED_MEMORY_RESUME_KEEP_TIME: ConfigEntry[Long] =
+    buildConf("celeborn.worker.monitor.pinnedMemory.resumeKeepTime")
+      .categories("worker")
+      .doc("Time of worker to stay in resume state after resumeByPinnedMemory")
+      .version("0.6.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("1s")
+
   val WORKER_DIRECT_MEMORY_REPORT_INTERVAL: ConfigEntry[Long] =
     buildConf("celeborn.worker.monitor.memory.report.interval")
       .withAlternative("celeborn.worker.memory.reportInterval")
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index 68a357bf5..a69c435da 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -146,6 +146,7 @@ license: |
 | celeborn.worker.monitor.memory.trimFlushWaitInterval | 1s | false | Wait 
time after worker trigger StorageManger to flush data. | 0.3.0 |  | 
 | celeborn.worker.monitor.pinnedMemory.check.enabled | true | false | If true, 
MemoryManager will check worker should resume by pinned memory used. | 0.6.0 |  
| 
 | celeborn.worker.monitor.pinnedMemory.check.interval | 10s | false | Interval 
of worker direct pinned memory checking, only takes effect when 
celeborn.network.memory.allocator.pooled and 
celeborn.worker.monitor.pinnedMemory.check.enabled are enabled. | 0.6.0 |  | 
+| celeborn.worker.monitor.pinnedMemory.resumeKeepTime | 1s | false | Time of 
worker to stay in resume state after resumeByPinnedMemory | 0.6.0 |  | 
 | celeborn.worker.partition.initial.readBuffersMax | 1024 | false | Max number 
of initial read buffers | 0.3.0 |  | 
 | celeborn.worker.partition.initial.readBuffersMin | 1 | false | Min number of 
initial read buffers | 0.3.0 |  | 
 | celeborn.worker.partitionSorter.directMemoryRatioThreshold | 0.1 | false | 
Max ratio of partition sorter's memory for sorting, when reserved memory is 
higher than max partition sorter memory, partition sorter will stop sorting. If 
this value is set to 0, partition files sorter will skip memory check and 
ServingState check. | 0.2.0 |  | 
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
index 13ffc367b..e729bdc3b 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
@@ -99,6 +99,8 @@ public class MemoryManager {
   private boolean pinnedMemoryCheckEnabled;
   private long pinnedMemoryCheckInterval;
   private long pinnedMemoryLastCheckTime = 0;
+  private boolean resumingByPinnedMemory = false;
+  private long workerPinnedMemoryResumeKeepTime;
 
   @VisibleForTesting
   public static MemoryManager initialize(CelebornConf conf) {
@@ -134,6 +136,7 @@ public class MemoryManager {
     long checkInterval = conf.workerDirectMemoryPressureCheckIntervalMs();
     this.pinnedMemoryCheckEnabled = conf.workerPinnedMemoryCheckEnabled();
     this.pinnedMemoryCheckInterval = conf.workerPinnedMemoryCheckIntervalMs();
+    this.workerPinnedMemoryResumeKeepTime = 
conf.workerPinnedMemoryResumeKeepTime();
     long reportInterval = conf.workerDirectMemoryReportIntervalSecond();
     double readBufferTargetRatio = conf.readBufferTargetRatio();
     long readBufferTargetUpdateInterval = 
conf.readBufferTargetUpdateInterval();
@@ -336,37 +339,37 @@ public class MemoryManager {
     }
     switch (servingState) {
       case PUSH_PAUSED:
-        if (canResumeByPinnedMemory()) {
-          resumeByPinnedMemory(servingState);
-        } else {
+        if (!tryResumeByPinnedMemory(servingState, lastState)) {
           pausePushDataCounter.increment();
           if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
-            logger.info("Trigger action: RESUME REPLICATE");
             resumeReplicate();
           } else {
             logger.info("Trigger action: PAUSE PUSH");
             pausePushDataStartTime = System.currentTimeMillis();
+            resumingByPinnedMemory = false;
             memoryPressureListeners.forEach(
                 memoryPressureListener ->
                     
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
+            // trimCounter cannot be increased when channels resume by 
PinnedMemory, otherwise
+            // PauseSpentTime will be increased unexpectedly
+            trimCounter += 1;
+            if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
+              logger.debug(
+                  "Trigger action: TRIM for {} times, force to append pause 
spent time.",
+                  trimCounter);
+              appendPauseSpentTime(servingState);
+            }
           }
         }
         logger.debug("Trigger action: TRIM");
-        trimCounter += 1;
         trimAllListeners();
-        if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
-          logger.debug(
-              "Trigger action: TRIM for {} times, force to append pause spent 
time.", trimCounter);
-          appendPauseSpentTime(servingState);
-        }
         break;
       case PUSH_AND_REPLICATE_PAUSED:
-        if (canResumeByPinnedMemory()) {
-          resumeByPinnedMemory(servingState);
-        } else {
+        if (!tryResumeByPinnedMemory(servingState, lastState)) {
           pausePushDataAndReplicateCounter.increment();
           logger.info("Trigger action: PAUSE PUSH");
           pausePushDataAndReplicateStartTime = System.currentTimeMillis();
+          resumingByPinnedMemory = false;
           memoryPressureListeners.forEach(
               memoryPressureListener ->
                   
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
@@ -374,18 +377,20 @@ public class MemoryManager {
           memoryPressureListeners.forEach(
               memoryPressureListener ->
                   
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
+          trimCounter += 1;
+          if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
+            logger.debug(
+                "Trigger action: TRIM for {} times, force to append pause 
spent time.",
+                trimCounter);
+            appendPauseSpentTime(servingState);
+          }
         }
         logger.debug("Trigger action: TRIM");
-        trimCounter += 1;
         trimAllListeners();
-        if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
-          logger.debug(
-              "Trigger action: TRIM for {} times, force to append pause spent 
time.", trimCounter);
-          appendPauseSpentTime(servingState);
-        }
         break;
       case NONE_PAUSED:
         // resume from paused mode, append pause spent time
+        resumingByPinnedMemory = false;
         if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
           resumeReplicate();
           resumePush();
@@ -599,15 +604,32 @@ public class MemoryManager {
     }
   }
 
-  private boolean canResumeByPinnedMemory() {
-    if (pinnedMemoryCheckEnabled
-        && System.currentTimeMillis() - pinnedMemoryLastCheckTime >= 
pinnedMemoryCheckInterval
-        && getPinnedMemory() / (double) (maxDirectMemory) < 
pinnedMemoryResumeRatio) {
-      pinnedMemoryLastCheckTime = System.currentTimeMillis();
-      return true;
-    } else {
-      return false;
+  private boolean tryResumeByPinnedMemory(ServingState currentState, 
ServingState lastState) {
+    if (pinnedMemoryCheckEnabled) {
+      long currentTime = System.currentTimeMillis();
+      if (currentTime - pinnedMemoryLastCheckTime >= 
pinnedMemoryCheckInterval) {
+        if (getPinnedMemory() / (double) (maxDirectMemory) < 
pinnedMemoryResumeRatio) {
+          pinnedMemoryLastCheckTime = currentTime;
+          resumingByPinnedMemory = true;
+          resumeByPinnedMemory(currentState);
+          return true;
+        }
+      } else {
+        if (resumingByPinnedMemory
+            && lastState != ServingState.NONE_PAUSED
+            && System.currentTimeMillis() - pinnedMemoryLastCheckTime
+                < workerPinnedMemoryResumeKeepTime
+            && getPinnedMemory() / (double) (maxDirectMemory) < 
pinnedMemoryResumeRatio) {
+          // do nothing, keep resume for a while
+          logger.info(
+              "currentState: {}, keep resume for {}ms after last 
resumeByPinnedMemory",
+              currentState,
+              currentTime - pinnedMemoryLastCheckTime);
+          return true;
+        }
+      }
     }
+    return false;
   }
 
   private void resumePush() {
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
index c0fe08e61..d74f1889e 100644
--- 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.celeborn.service.deploy.memory
 
+import java.util.concurrent.TimeUnit
+
 import scala.concurrent.duration.DurationInt
 
 import org.mockito.{Mockito, MockitoSugar}
@@ -216,6 +218,47 @@ class MemoryManagerSuite extends CelebornFunSuite {
     MemoryManager.reset()
   }
 
+  test("[CELEBORN-1792] Test MemoryManager keep resume a while by pinned 
memory") {
+    val conf = new CelebornConf()
+    conf.set(CelebornConf.WORKER_DIRECT_MEMORY_CHECK_INTERVAL.key, "300s")
+    conf.set(CelebornConf.WORKER_PINNED_MEMORY_CHECK_INTERVAL.key, "1s")
+    MemoryManager.reset()
+    val memoryManager = MockitoSugar.spy(MemoryManager.initialize(conf))
+    val maxDirectorMemory = memoryManager.maxDirectMemory
+    val pushThreshold =
+      (conf.workerDirectMemoryRatioToPauseReceive * 
maxDirectorMemory).longValue()
+    val pinnedMemoryResumeThreshold =
+      (conf.workerPinnedMemoryRatioToResume * maxDirectorMemory).longValue()
+    val channelsLimiter = new MockChannelsLimiter()
+    memoryManager.registerMemoryListener(channelsLimiter)
+
+    // NONE PAUSED -> PAUSE PUSH
+    Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(0L)
+    Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1)
+    memoryManager.switchServingState()
+    assert(channelsLimiter.isResume)
+    assert(memoryManager.servingState == ServingState.PUSH_PAUSED)
+
+    // keep pause push, but channels keep resume when pinnedMemory still less 
than threshold
+    Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1)
+    memoryManager.switchServingState()
+    assert(channelsLimiter.isResume)
+    assert(memoryManager.servingState == ServingState.PUSH_PAUSED)
+
+    // exit keepResumeByPinnedMemory because pinnedMemory is greater than 
threshold
+    Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(
+      pinnedMemoryResumeThreshold + 1)
+    memoryManager.switchServingState()
+    assert(!channelsLimiter.isResume)
+    assert(memoryManager.servingState == ServingState.PUSH_PAUSED)
+
+    Mockito.when(memoryManager.getMemoryUsage).thenReturn(0L)
+    memoryManager.switchServingState()
+    assert(channelsLimiter.isResume)
+    assert(memoryManager.servingState == ServingState.NONE_PAUSED)
+
+  }
+
   class MockMemoryPressureListener(
       val belongModuleName: String,
       var isPause: Boolean = false) extends MemoryPressureListener {
@@ -235,4 +278,19 @@ class MemoryManagerSuite extends CelebornFunSuite {
       // do nothing
     }
   }
+
+  class MockChannelsLimiter(var isResume: Boolean = false) extends 
MemoryPressureListener {
+    override def onPause(moduleName: String): Unit = {
+      isResume = false
+    }
+
+    override def onResume(moduleName: String): Unit = {
+      isResume = true
+    }
+
+    override def onTrim(): Unit = {
+      // do nothing
+    }
+  }
+
 }

Reply via email to