This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.3 by this push:
     new a74600c8e [CELEBORN-908][WORKER] Tweak pause and resume logic && add 
unit test MemoryManager memory check thread
a74600c8e is described below

commit a74600c8e400acd06740365eab7c070d559e7b09
Author: zwangsheng <[email protected]>
AuthorDate: Fri Aug 25 17:55:00 2023 +0800

    [CELEBORN-908][WORKER] Tweak pause and resume logic && add unit test 
MemoryManager memory check thread
    
    ### What changes were proposed in this pull request?
    - Tweak pause and resume logic
    - Add unit test
    
    ```mermaid
    graph TB
    A(NON_PAUSE)
    B(PAUSE_PUSH)
    C(PAUSE_PUSH_AND_REPLICATE)
    A --> | pause push listener | B
    B --> | resume push listener | A
    A --> | pause push and replicate listeners | C
    C --> | resume push and replicate listeners | A
    B --> | pause replicate listener | C
    C --> | resume replicate listener | B
    ```
    
    ### Why are the changes needed?
    Add unit tests for the pause and resume logic.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Add unit test.
    
    Closes #1835 from zwangsheng/CELEBORN-908.
    
    Authored-by: zwangsheng <[email protected]>
    Signed-off-by: mingji <[email protected]>
    (cherry picked from commit 467b9bd81c7a96e26a80ddb9f7d75f95d3c7c773)
    Signed-off-by: mingji <[email protected]>
---
 .../deploy/worker/memory/MemoryManager.java        | 89 +++++++++++++---------
 .../service/deploy/memory/MemoryManagerSuite.scala | 87 +++++++++++++++++++++
 2 files changed, 141 insertions(+), 35 deletions(-)

diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
index 7f0192496..8ef2ee19d 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
@@ -142,41 +142,7 @@ public class MemoryManager {
     checkService.scheduleWithFixedDelay(
         () -> {
           try {
-            ServingState lastState = servingState;
-            servingState = currentServingState();
-            if (lastState != servingState) {
-              logger.info("Serving state transformed from {} to {}", 
lastState, servingState);
-              if (servingState == ServingState.PUSH_PAUSED) {
-                pausePushDataCounter.increment();
-                logger.info("Trigger action: PAUSE PUSH, RESUME REPLICATE");
-                memoryPressureListeners.forEach(
-                    memoryPressureListener ->
-                        
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
-                memoryPressureListeners.forEach(
-                    memoryPressureListener ->
-                        
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
-                trimAllListeners();
-              } else if (servingState == 
ServingState.PUSH_AND_REPLICATE_PAUSED) {
-                pausePushDataAndReplicateCounter.increment();
-                logger.info("Trigger action: PAUSE PUSH and REPLICATE");
-                memoryPressureListeners.forEach(
-                    memoryPressureListener ->
-                        
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
-                memoryPressureListeners.forEach(
-                    memoryPressureListener ->
-                        
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
-                trimAllListeners();
-              } else {
-                logger.info("Trigger action: RESUME PUSH and REPLICATE");
-                memoryPressureListeners.forEach(
-                    memoryPressureListener -> 
memoryPressureListener.onResume("all"));
-              }
-            } else {
-              if (servingState != ServingState.NONE_PAUSED) {
-                logger.debug("Trigger action: TRIM");
-                trimAllListeners();
-              }
-            }
+            switchServingState();
           } catch (Exception e) {
             logger.error("Memory tracker check error", e);
           }
@@ -274,6 +240,59 @@ public class MemoryManager {
     return isPaused ? ServingState.PUSH_PAUSED : ServingState.NONE_PAUSED;
   }
 
+  @VisibleForTesting
+  protected void switchServingState() {
+    ServingState lastState = servingState;
+    servingState = currentServingState();
+    if (lastState == servingState) {
+      if (servingState != ServingState.NONE_PAUSED) {
+        logger.debug("Trigger action: TRIM");
+        trimAllListeners();
+      }
+      return;
+    }
+    logger.info("Serving state transformed from {} to {}", lastState, 
servingState);
+    switch (servingState) {
+      case PUSH_PAUSED:
+        pausePushDataCounter.increment();
+        logger.info("Trigger action: PAUSE PUSH, RESUME REPLICATE");
+        if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
+          memoryPressureListeners.forEach(
+              memoryPressureListener ->
+                  
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
+        } else if (lastState == ServingState.NONE_PAUSED) {
+          memoryPressureListeners.forEach(
+              memoryPressureListener ->
+                  
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
+        }
+        trimAllListeners();
+        break;
+      case PUSH_AND_REPLICATE_PAUSED:
+        pausePushDataAndReplicateCounter.increment();
+        logger.info("Trigger action: PAUSE PUSH and REPLICATE");
+        if (lastState == ServingState.NONE_PAUSED) {
+          memoryPressureListeners.forEach(
+              memoryPressureListener ->
+                  
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
+        }
+        memoryPressureListeners.forEach(
+            memoryPressureListener ->
+                
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
+        trimAllListeners();
+        break;
+      case NONE_PAUSED:
+        logger.info("Trigger action: RESUME PUSH and REPLICATE");
+        if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
+          memoryPressureListeners.forEach(
+              memoryPressureListener ->
+                  
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
+        }
+        memoryPressureListeners.forEach(
+            memoryPressureListener ->
+                
memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE));
+    }
+  }
+
   public void trimAllListeners() {
     if (trimInProcess.compareAndSet(false, true)) {
       actionService.submit(
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
index f329b6f36..864ddffec 100644
--- 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
@@ -17,10 +17,17 @@
 
 package org.apache.celeborn.service.deploy.memory
 
+import scala.concurrent.duration.DurationInt
+
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.{interval, timeout}
+
 import org.apache.celeborn.CelebornFunSuite
 import org.apache.celeborn.common.CelebornConf
 import 
org.apache.celeborn.common.CelebornConf.{WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE,
 WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE}
+import org.apache.celeborn.common.protocol.TransportModuleConstants
 import org.apache.celeborn.service.deploy.worker.memory.MemoryManager
+import 
org.apache.celeborn.service.deploy.worker.memory.MemoryManager.MemoryPressureListener
 import 
org.apache.celeborn.service.deploy.worker.memory.MemoryManager.ServingState
 class MemoryManagerSuite extends CelebornFunSuite {
 
@@ -79,4 +86,84 @@ class MemoryManagerSuite extends CelebornFunSuite {
       MemoryManager.reset()
     }
   }
+
+  test("[CELEBORN-882] Test MemoryManager check memory thread logic") {
+    val conf = new CelebornConf()
+    val memoryManager = MemoryManager.initialize(conf)
+    val maxDirectorMemory = memoryManager.maxDirectorMemory
+    val pushThreshold =
+      (conf.workerDirectMemoryRatioToPauseReceive * 
maxDirectorMemory).longValue()
+    val replicateThreshold =
+      (conf.workerDirectMemoryRatioToPauseReplicate * 
maxDirectorMemory).longValue()
+    val memoryCounter = memoryManager.getSortMemoryCounter
+
+    val pushListener = new 
MockMemoryPressureListener(TransportModuleConstants.PUSH_MODULE)
+    val replicateListener =
+      new MockMemoryPressureListener(TransportModuleConstants.REPLICATE_MODULE)
+    memoryManager.registerMemoryListener(pushListener)
+    memoryManager.registerMemoryListener(replicateListener)
+
+    // NONE PAUSED -> PAUSE PUSH
+    memoryCounter.set(pushThreshold + 1)
+    // default check interval is 10ms; poll for up to 30s to make sure the 
listener is triggered
+    eventually(timeout(30.second), interval(10.milliseconds)) {
+      assert(pushListener.isPause)
+      assert(!replicateListener.isPause)
+    }
+
+    // PAUSE PUSH -> PAUSE PUSH AND REPLICATE
+    memoryCounter.set(replicateThreshold + 1);
+    eventually(timeout(30.second), interval(10.milliseconds)) {
+      assert(pushListener.isPause)
+      assert(replicateListener.isPause)
+    }
+
+    // PAUSE PUSH AND REPLICATE -> PAUSE PUSH
+    memoryCounter.set(pushThreshold + 1);
+    eventually(timeout(30.second), interval(10.milliseconds)) {
+      assert(pushListener.isPause)
+      assert(!replicateListener.isPause)
+    }
+
+    // PAUSE PUSH -> NONE PAUSED
+    memoryCounter.set(0);
+    eventually(timeout(30.second), interval(10.milliseconds)) {
+      assert(!pushListener.isPause)
+      assert(!replicateListener.isPause)
+    }
+
+    // NONE PAUSED -> PAUSE PUSH AND REPLICATE
+    memoryCounter.set(replicateThreshold + 1);
+    eventually(timeout(30.second), interval(10.milliseconds)) {
+      assert(pushListener.isPause)
+      assert(replicateListener.isPause)
+    }
+
+    // PAUSE PUSH AND REPLICATE -> NONE PAUSED
+    memoryCounter.set(0);
+    eventually(timeout(30.second), interval(10.milliseconds)) {
+      assert(!pushListener.isPause)
+      assert(!replicateListener.isPause)
+    }
+  }
+
+  class MockMemoryPressureListener(
+      val belongModuleName: String,
+      var isPause: Boolean = false) extends MemoryPressureListener {
+    override def onPause(moduleName: String): Unit = {
+      if (belongModuleName == moduleName) {
+        isPause = true
+      }
+    }
+
+    override def onResume(moduleName: String): Unit = {
+      if (belongModuleName == moduleName) {
+        isPause = false
+      }
+    }
+
+    override def onTrim(): Unit = {
+      // do nothing
+    }
+  }
 }

Reply via email to