This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 479510cb9 [CELEBORN-888][WORKER] Tweak the logic and add unit tests for the MemoryManager#currentServingState method
479510cb9 is described below

commit 479510cb9cdba5dfec37169d04d0344d50511947
Author: zwangsheng <[email protected]>
AuthorDate: Wed Aug 23 17:20:31 2023 +0800

    [CELEBORN-888][WORKER] Tweak the logic and add unit tests for the MemoryManager#currentServingState method
    
    ### What changes were proposed in this pull request?
    Tweak the logic of `MemoryManager#currentServingState`
    
    Add unit tests for this method.
    
    ```mermaid
    graph TB
    
    A(Check Used Memory) --> B{Reach Pause Replicate Threshold}
    B --> | N | C{Reach Pause Push Threshold}
    B --> | Y | Z(Trigger Pause Push and Replicate)
    C --> | N | D{Reach Resume Threshold}
    C --> | Y | Y(Trigger Pause Push but Resume Replicate)
    D --> | N | E{In Pause Mode}
    D --> | Y | X(Trigger Resume Push and Replicate)
    E --> | N | U(Do Nothing)
    E --> | Y | Y
    ```
    ### Why are the changes needed?
    Simplify the logic of this method so that its state transitions are easier to follow, and add unit tests to guard that logic against accidental modification.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a unit test in `MemoryManagerSuite` that drives memory usage through each serving-state transition.
    
    Closes #1811 from zwangsheng/CELEBORN-888.
    
    Authored-by: zwangsheng <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 project/CelebornBuild.scala                        |  3 +-
 worker/pom.xml                                     |  7 ++++
 .../deploy/worker/memory/MemoryManager.java        | 35 ++++++++---------
 .../service/deploy/memory/MemoryManagerSuite.scala | 45 ++++++++++++++++++++--
 4 files changed, 67 insertions(+), 23 deletions(-)

diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala
index 06e2b831f..833b06084 100644
--- a/project/CelebornBuild.scala
+++ b/project/CelebornBuild.scala
@@ -375,7 +375,8 @@ object CelebornMaster {
 
 object CelebornWorker {
   lazy val worker = Project("celeborn-worker", file("worker"))
-    .dependsOn(CelebornCommon.common, CelebornService.service)
+    .dependsOn(CelebornService.service)
+    .dependsOn(CelebornCommon.common % "test->test;compile->compile")
     .dependsOn(CelebornClient.client % "test->test;compile->compile")
     .dependsOn(CelebornMaster.master % "test->test;compile->compile")
     .settings (
diff --git a/worker/pom.xml b/worker/pom.xml
index 3d7ab4675..b3d0daf86 100644
--- a/worker/pom.xml
+++ b/worker/pom.xml
@@ -77,6 +77,13 @@
       <artifactId>log4j-1.2-api</artifactId>
     </dependency>
 
+    <dependency>
+      <groupId>org.apache.celeborn</groupId>
+      <artifactId>celeborn-common_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.celeborn</groupId>
       <artifactId>celeborn-client_${scala.binary.version}</artifactId>
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
index 758ddbb8c..7f0192496 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
@@ -43,7 +43,7 @@ import 
org.apache.celeborn.service.deploy.worker.storage.CreditStreamManager;
 public class MemoryManager {
   private static final Logger logger = 
LoggerFactory.getLogger(MemoryManager.class);
   private static volatile MemoryManager _INSTANCE = null;
-  private long maxDirectorMemory = 0;
+  @VisibleForTesting public long maxDirectorMemory = 0;
   private final long pausePushDataThreshold;
   private final long pauseReplicateThreshold;
   private final long resumeThreshold;
@@ -66,7 +66,7 @@ public class MemoryManager {
   private final LongAdder pausePushDataCounter = new LongAdder();
   private final LongAdder pausePushDataAndReplicateCounter = new LongAdder();
   private ServingState servingState = ServingState.NONE_PAUSED;
-  private boolean underPressure;
+  private volatile boolean isPaused = false;
 
   // For credit stream
   private final AtomicLong readBufferCounter = new AtomicLong(0);
@@ -253,27 +253,25 @@ public class MemoryManager {
 
   public ServingState currentServingState() {
     long memoryUsage = getMemoryUsage();
-    boolean pausePushData = memoryUsage > pausePushDataThreshold;
-    boolean pauseReplicate = memoryUsage > pauseReplicateThreshold;
-    boolean resume = memoryUsage < resumeThreshold;
-    if (pausePushData || pauseReplicate) {
-      underPressure = true;
-    } else if (resume) {
-      underPressure = false;
-    }
-    if (pausePushData && pauseReplicate) {
+    // pause replicate threshold always greater than pause push data threshold
+    // so when trigger pause replicate, pause both push and replicate
+    if (memoryUsage > pauseReplicateThreshold) {
+      isPaused = true;
       return ServingState.PUSH_AND_REPLICATE_PAUSED;
     }
-    if (pausePushData) {
+    // trigger pause only push
+    if (memoryUsage > pausePushDataThreshold) {
+      isPaused = true;
       return ServingState.PUSH_PAUSED;
     }
-    if (resume) {
+    // trigger resume
+    if (memoryUsage < resumeThreshold) {
+      isPaused = false;
       return ServingState.NONE_PAUSED;
     }
-    if (underPressure) {
-      return ServingState.PUSH_PAUSED;
-    }
-    return ServingState.NONE_PAUSED;
+    // if isPaused and not trigger resume, then return pause push
+    // wait for trigger resumeThreshold to resume state
+    return isPaused ? ServingState.PUSH_PAUSED : ServingState.NONE_PAUSED;
   }
 
   public void trimAllListeners() {
@@ -412,7 +410,8 @@ public class MemoryManager {
     void onChange(long newMemoryTarget);
   }
 
-  enum ServingState {
+  @VisibleForTesting
+  public enum ServingState {
     NONE_PAUSED,
     PUSH_AND_REPLICATE_PAUSED,
     PUSH_PAUSED
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
index 0e0d627bf..f329b6f36 100644
--- 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
@@ -17,13 +17,12 @@
 
 package org.apache.celeborn.service.deploy.memory
 
-import org.scalatest.BeforeAndAfterEach
-import org.scalatest.funsuite.AnyFunSuite
-
+import org.apache.celeborn.CelebornFunSuite
 import org.apache.celeborn.common.CelebornConf
 import 
org.apache.celeborn.common.CelebornConf.{WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE,
 WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE}
 import org.apache.celeborn.service.deploy.worker.memory.MemoryManager
-class MemoryManagerSuite extends AnyFunSuite with BeforeAndAfterEach {
+import 
org.apache.celeborn.service.deploy.worker.memory.MemoryManager.ServingState
+class MemoryManagerSuite extends CelebornFunSuite {
 
   // reset the memory manager before each test
   override protected def beforeEach(): Unit = {
@@ -42,4 +41,42 @@ class MemoryManagerSuite extends AnyFunSuite with 
BeforeAndAfterEach {
       caught.getMessage == s"Invalid config, 
${WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE.key}(0.85) " +
         s"should be greater than 
${WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE.key}(0.95)")
   }
+
+  test("[CELEBORN-888] Test MemoryManager#currentServingState trigger case") {
+    val conf = new CelebornConf()
+    try {
+      val memoryManager = MemoryManager.initialize(conf)
+      val maxDirectorMemory = memoryManager.maxDirectorMemory
+      val pushThreshold =
+        (conf.workerDirectMemoryRatioToPauseReceive * 
maxDirectorMemory).longValue()
+      val replicateThreshold =
+        (conf.workerDirectMemoryRatioToPauseReplicate * 
maxDirectorMemory).longValue()
+      val resumeThreshold = (conf.workerDirectMemoryRatioToResume * 
maxDirectorMemory).longValue()
+
+      // use sortMemoryCounter to trigger each state
+      val memoryCounter = memoryManager.getSortMemoryCounter
+
+      // default state
+      assert(ServingState.NONE_PAUSED == memoryManager.currentServingState())
+      // reach pause push data threshold
+      memoryCounter.set(pushThreshold + 1)
+      assert(ServingState.PUSH_PAUSED == memoryManager.currentServingState())
+      // reach pause replicate data threshold
+      memoryCounter.set(replicateThreshold + 1);
+      assert(ServingState.PUSH_AND_REPLICATE_PAUSED == 
memoryManager.currentServingState());
+      // touch pause push data threshold again
+      memoryCounter.set(pushThreshold + 1);
+      assert(MemoryManager.ServingState.PUSH_PAUSED == 
memoryManager.currentServingState());
+      // between pause push data threshold and resume data threshold
+      memoryCounter.set(resumeThreshold + 2);
+      assert(MemoryManager.ServingState.PUSH_PAUSED == 
memoryManager.currentServingState());
+      // touch resume data threshold
+      memoryCounter.set(0);
+      assert(MemoryManager.ServingState.NONE_PAUSED == 
memoryManager.currentServingState());
+    } catch {
+      case e: Exception => throw e
+    } finally {
+      MemoryManager.reset()
+    }
+  }
 }

Reply via email to