This is an automated email from the ASF dual-hosted git repository.
ethanfeng pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/branch-0.3 by this push:
new a74600c8e [CELEBORN-908][WORKER] Tweak pause and resume logic && add
unit test MemoryManager memory check thread
a74600c8e is described below
commit a74600c8e400acd06740365eab7c070d559e7b09
Author: zwangsheng <[email protected]>
AuthorDate: Fri Aug 25 17:55:00 2023 +0800
[CELEBORN-908][WORKER] Tweak pause and resume logic && add unit test
MemoryManager memory check thread
### What changes were proposed in this pull request?
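- Tweak the pause and resume logic in `MemoryManager`: move the serving-state transition into `switchServingState()` and only notify listeners about the modules whose pause state actually changes.
- Add a unit test that covers every transition shown below.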
```mermaid
graph TB
A(NON_PAUSE)
B(PAUSE_PUSH)
C(PAUSE_PUSH_AND_REPLICATE)
A --> | pause push listener | B
B --> | resume push listener | A
A --> | pause push and replicate listeners | C
C --> | resume push and replicate listeners | A
B --> | pause replicate listener | C
C --> | resume replicate listener | B
```
### Why are the changes needed?
Add unit tests for the pause and resume logic.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a unit test in `MemoryManagerSuite`.
Closes #1835 from zwangsheng/CELEBORN-908.
Authored-by: zwangsheng <[email protected]>
Signed-off-by: mingji <[email protected]>
(cherry picked from commit 467b9bd81c7a96e26a80ddb9f7d75f95d3c7c773)
Signed-off-by: mingji <[email protected]>
---
.../deploy/worker/memory/MemoryManager.java | 89 +++++++++++++---------
.../service/deploy/memory/MemoryManagerSuite.scala | 87 +++++++++++++++++++++
2 files changed, 141 insertions(+), 35 deletions(-)
diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
index 7f0192496..8ef2ee19d 100644
--- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
+++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java
@@ -142,41 +142,7 @@ public class MemoryManager {
checkService.scheduleWithFixedDelay(
() -> {
try {
- ServingState lastState = servingState;
- servingState = currentServingState();
- if (lastState != servingState) {
- logger.info("Serving state transformed from {} to {}",
lastState, servingState);
- if (servingState == ServingState.PUSH_PAUSED) {
- pausePushDataCounter.increment();
- logger.info("Trigger action: PAUSE PUSH, RESUME REPLICATE");
- memoryPressureListeners.forEach(
- memoryPressureListener ->
-
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
- memoryPressureListeners.forEach(
- memoryPressureListener ->
-
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
- trimAllListeners();
- } else if (servingState ==
ServingState.PUSH_AND_REPLICATE_PAUSED) {
- pausePushDataAndReplicateCounter.increment();
- logger.info("Trigger action: PAUSE PUSH and REPLICATE");
- memoryPressureListeners.forEach(
- memoryPressureListener ->
-
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
- memoryPressureListeners.forEach(
- memoryPressureListener ->
-
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
- trimAllListeners();
- } else {
- logger.info("Trigger action: RESUME PUSH and REPLICATE");
- memoryPressureListeners.forEach(
- memoryPressureListener ->
memoryPressureListener.onResume("all"));
- }
- } else {
- if (servingState != ServingState.NONE_PAUSED) {
- logger.debug("Trigger action: TRIM");
- trimAllListeners();
- }
- }
+ switchServingState();
} catch (Exception e) {
logger.error("Memory tracker check error", e);
}
@@ -274,6 +240,59 @@ public class MemoryManager {
return isPaused ? ServingState.PUSH_PAUSED : ServingState.NONE_PAUSED;
}
+  /**
+   * Re-evaluates the serving state from the current memory usage and notifies the registered
+   * memory pressure listeners about the transition.
+   *
+   * <p>Listeners are only told about modules whose pause status actually changes. For example,
+   * moving from PUSH_PAUSED to PUSH_AND_REPLICATE_PAUSED pauses only the replicate module. On
+   * every transition into a paused state, and on every call while the state stays paused, all
+   * listeners are asked to trim.
+   */
+  @VisibleForTesting
+  protected void switchServingState() {
+    ServingState lastState = servingState;
+    servingState = currentServingState();
+    if (lastState == servingState) {
+      // No transition: keep relieving memory pressure for as long as anything is paused.
+      if (servingState != ServingState.NONE_PAUSED) {
+        logger.debug("Trigger action: TRIM");
+        trimAllListeners();
+      }
+      return;
+    }
+    logger.info("Serving state transformed from {} to {}", lastState, servingState);
+    switch (servingState) {
+      case PUSH_PAUSED:
+        pausePushDataCounter.increment();
+        if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
+          // Push is already paused; only replicate needs to come back.
+          logger.info("Trigger action: RESUME REPLICATE");
+          notifyListenersResume(TransportModuleConstants.REPLICATE_MODULE);
+        } else if (lastState == ServingState.NONE_PAUSED) {
+          logger.info("Trigger action: PAUSE PUSH");
+          notifyListenersPause(TransportModuleConstants.PUSH_MODULE);
+        }
+        trimAllListeners();
+        break;
+      case PUSH_AND_REPLICATE_PAUSED:
+        pausePushDataAndReplicateCounter.increment();
+        if (lastState == ServingState.NONE_PAUSED) {
+          logger.info("Trigger action: PAUSE PUSH and REPLICATE");
+          notifyListenersPause(TransportModuleConstants.PUSH_MODULE);
+        } else {
+          // Coming from PUSH_PAUSED: push is already paused.
+          logger.info("Trigger action: PAUSE REPLICATE");
+        }
+        notifyListenersPause(TransportModuleConstants.REPLICATE_MODULE);
+        trimAllListeners();
+        break;
+      case NONE_PAUSED:
+        if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
+          logger.info("Trigger action: RESUME PUSH and REPLICATE");
+          notifyListenersResume(TransportModuleConstants.REPLICATE_MODULE);
+        } else {
+          // Coming from PUSH_PAUSED: replicate was never paused.
+          logger.info("Trigger action: RESUME PUSH");
+        }
+        notifyListenersResume(TransportModuleConstants.PUSH_MODULE);
+        break;
+    }
+  }
+
+  /** Asks every registered memory pressure listener to pause {@code moduleName}. */
+  private void notifyListenersPause(String moduleName) {
+    memoryPressureListeners.forEach(listener -> listener.onPause(moduleName));
+  }
+
+  /** Asks every registered memory pressure listener to resume {@code moduleName}. */
+  private void notifyListenersResume(String moduleName) {
+    memoryPressureListeners.forEach(listener -> listener.onResume(moduleName));
+  }
+
public void trimAllListeners() {
if (trimInProcess.compareAndSet(false, true)) {
actionService.submit(
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
index f329b6f36..864ddffec 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala
@@ -17,10 +17,17 @@
package org.apache.celeborn.service.deploy.memory
+import scala.concurrent.duration.DurationInt
+
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.{interval, timeout}
+
import org.apache.celeborn.CelebornFunSuite
import org.apache.celeborn.common.CelebornConf
import
org.apache.celeborn.common.CelebornConf.{WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE,
WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE}
+import org.apache.celeborn.common.protocol.TransportModuleConstants
import org.apache.celeborn.service.deploy.worker.memory.MemoryManager
+import
org.apache.celeborn.service.deploy.worker.memory.MemoryManager.MemoryPressureListener
import
org.apache.celeborn.service.deploy.worker.memory.MemoryManager.ServingState
class MemoryManagerSuite extends CelebornFunSuite {
@@ -79,4 +86,84 @@ class MemoryManagerSuite extends CelebornFunSuite {
MemoryManager.reset()
}
}
+
+  test("[CELEBORN-882] Test MemoryManager check memory thread logic") {
+    val conf = new CelebornConf()
+    val memoryManager = MemoryManager.initialize(conf)
+    // Reset the MemoryManager singleton afterwards (as the other tests in this suite do) so the
+    // listeners and counter values set here do not leak into later tests.
+    try {
+      val maxDirectorMemory = memoryManager.maxDirectorMemory
+      val pushThreshold =
+        (conf.workerDirectMemoryRatioToPauseReceive * maxDirectorMemory).longValue()
+      val replicateThreshold =
+        (conf.workerDirectMemoryRatioToPauseReplicate * maxDirectorMemory).longValue()
+      val memoryCounter = memoryManager.getSortMemoryCounter
+
+      val pushListener = new MockMemoryPressureListener(TransportModuleConstants.PUSH_MODULE)
+      val replicateListener =
+        new MockMemoryPressureListener(TransportModuleConstants.REPLICATE_MODULE)
+      memoryManager.registerMemoryListener(pushListener)
+      memoryManager.registerMemoryListener(replicateListener)
+
+      // Listeners are notified asynchronously by the memory check thread, so poll every 10ms
+      // (for up to 30s) until both listeners report the expected pause state.
+      def assertPauseState(pushPaused: Boolean, replicatePaused: Boolean): Unit =
+        eventually(timeout(30.seconds), interval(10.milliseconds)) {
+          assert(pushListener.isPause == pushPaused)
+          assert(replicateListener.isPause == replicatePaused)
+        }
+
+      // NONE PAUSED -> PAUSE PUSH
+      memoryCounter.set(pushThreshold + 1)
+      assertPauseState(pushPaused = true, replicatePaused = false)
+
+      // PAUSE PUSH -> PAUSE PUSH AND REPLICATE
+      memoryCounter.set(replicateThreshold + 1)
+      assertPauseState(pushPaused = true, replicatePaused = true)
+
+      // PAUSE PUSH AND REPLICATE -> PAUSE PUSH
+      memoryCounter.set(pushThreshold + 1)
+      assertPauseState(pushPaused = true, replicatePaused = false)
+
+      // PAUSE PUSH -> NONE PAUSED
+      memoryCounter.set(0)
+      assertPauseState(pushPaused = false, replicatePaused = false)
+
+      // NONE PAUSED -> PAUSE PUSH AND REPLICATE
+      memoryCounter.set(replicateThreshold + 1)
+      assertPauseState(pushPaused = true, replicatePaused = true)
+
+      // PAUSE PUSH AND REPLICATE -> NONE PAUSED
+      memoryCounter.set(0)
+      assertPauseState(pushPaused = false, replicatePaused = false)
+    } finally {
+      MemoryManager.reset()
+    }
+  }
+
+  /**
+   * Test listener that records whether its own module (`belongModuleName`) is currently paused.
+   * Events for other modules are ignored. `isPause` is written by the MemoryManager check thread
+   * and read by the test thread, so it is `@volatile` to guarantee the update is visible.
+   */
+  class MockMemoryPressureListener(
+      val belongModuleName: String,
+      @volatile var isPause: Boolean = false) extends MemoryPressureListener {
+    override def onPause(moduleName: String): Unit = {
+      if (belongModuleName == moduleName) {
+        isPause = true
+      }
+    }
+
+    override def onResume(moduleName: String): Unit = {
+      if (belongModuleName == moduleName) {
+        isPause = false
+      }
+    }
+
+    override def onTrim(): Unit = {
+      // Trimming is irrelevant for pause/resume assertions.
+    }
+  }
}