This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new c924a4ff [CELEBORN-61][CELEBORN-62][FEATURE] Shuffle client support 
slow start, congestion avoidance and congestion control (#1052)
c924a4ff is described below

commit c924a4ff0d38e5d39184e9753d3b7693640db46c
Author: Angerszhuuuu <[email protected]>
AuthorDate: Thu Dec 8 12:41:34 2022 +0800

    [CELEBORN-61][CELEBORN-62][FEATURE] Shuffle client support slow start, 
congestion avoidance and congestion control (#1052)
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 68 +++++++++++++++++++++-
 1 file changed, 65 insertions(+), 3 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 86f0c5b4..51243ee5 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -83,7 +83,9 @@ public class ShuffleClientImpl extends ShuffleClient {
 
   private final int registerShuffleMaxRetries;
   private final long registerShuffleRetryWait;
-  private final int maxInFlight;
+  private int maxInFlight;
+  private Integer currentMaxReqsInFlight = 1;
+  private int congestionAvoidanceFlag = 0;
   private final int pushBufferMaxSize;
 
   private final RpcEnv rpcEnv;
@@ -588,7 +590,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           partitionId,
           nextBatchId);
       // check limit
-      limitMaxInFlight(mapKey, pushState, maxInFlight);
+      limitMaxInFlight(mapKey, pushState, currentMaxReqsInFlight);
 
       // add inFlight requests
       pushState.inFlightBatches.put(nextBatchId, loc);
@@ -649,6 +651,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                       attemptId,
                       nextBatchId);
                   splitPartition(shuffleId, partitionId, applicationId, loc);
+                  slowStart();
                   callback.onSuccess(response);
                 } else if (reason == StatusCode.HARD_SPLIT.getValue()) {
                   logger.debug(
@@ -669,11 +672,27 @@ public class ShuffleClientImpl extends ShuffleClient {
                               this,
                               pushState,
                               StatusCode.HARD_SPLIT));
+                } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_MASTER_CONGESTED.getValue()) {
+                  logger.debug(
+                      "Push data split for map {} attempt {} batch {} return 
master congested.",
+                      mapId,
+                      attemptId,
+                      nextBatchId);
+                  congestionControl();
+                } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_SLAVE_CONGESTED.getValue()) {
+                  logger.debug(
+                      "Push data split for map {} attempt {} batch {} return 
slave congested.",
+                      mapId,
+                      attemptId,
+                      nextBatchId);
+                  congestionControl();
                 } else {
                   response.rewind();
+                  slowStart();
                   callback.onSuccess(response);
                 }
               } else {
+                slowStart();
                 callback.onSuccess(response);
               }
             }
@@ -853,7 +872,7 @@ public class ShuffleClientImpl extends ShuffleClient {
     ArrayList<Map.Entry<String, DataBatches>> batchesArr =
         new ArrayList<>(pushState.batchesMap.entrySet());
     while (!batchesArr.isEmpty()) {
-      limitMaxInFlight(mapKey, pushState, maxInFlight);
+      limitMaxInFlight(mapKey, pushState, currentMaxReqsInFlight);
       Map.Entry<String, DataBatches> entry = 
batchesArr.get(rand.nextInt(batchesArr.size()));
       ArrayList<DataBatches.DataBatch> batches = 
entry.getValue().requireBatches(pushBufferMaxSize);
       if (entry.getValue().getTotalSize() == 0) {
@@ -969,13 +988,29 @@ public class ShuffleClientImpl extends ShuffleClient {
                             batches,
                             StatusCode.HARD_SPLIT,
                             groupedBatchId));
+              } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_MASTER_CONGESTED.getValue()) {
+                logger.debug(
+                    "Push data split for map {} attempt {} batchs {} return 
master congested.",
+                    mapId,
+                    attemptId,
+                    Arrays.toString(batchIds));
+                congestionControl();
+              } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_SLAVE_CONGESTED.getValue()) {
+                logger.debug(
+                    "Push data split for map {} attempt {} batchs {} return 
slave congested.",
+                    mapId,
+                    attemptId,
+                    Arrays.toString(batchIds));
+                congestionControl();
               } else {
                 // Should not happen in current architecture.
                 response.rewind();
                 logger.error("Push merged data should not receive this 
response");
+                slowStart();
                 callback.onSuccess(response);
               }
             } else {
+              slowStart();
               callback.onSuccess(response);
             }
           }
@@ -1228,6 +1263,33 @@ public class ShuffleClientImpl extends ShuffleClient {
     driverRssMetaService = endpointRef;
   }
 
+  /** Raises the in-flight limit: +1 per call in slow start, +1 per window once above maxInFlight. */
+  private void slowStart() {
+    // Lock on `this`: currentMaxReqsInFlight is a boxed Integer that is reassigned below, so it
+    // is not a stable monitor and concurrent callers could end up holding different locks.
+    // NOTE(review): confirm nothing waits on push callbacks while holding this monitor.
+    synchronized (this) {
+      if (currentMaxReqsInFlight > maxInFlight) {
+        // Congestion avoidance: grow by one only after a full window's worth of calls.
+        congestionAvoidanceFlag++;
+        if (congestionAvoidanceFlag >= currentMaxReqsInFlight) {
+          currentMaxReqsInFlight++;
+          congestionAvoidanceFlag = 0;
+        }
+      } else {
+        currentMaxReqsInFlight++; // slow start: grow on every call
+      }
+    }
+  }
+
+  /** Halves the in-flight limit (never below 1) and lowers maxInFlight to the new value. */
+  private void congestionControl() {
+    synchronized (this) { // same stable monitor as slowStart()
+      currentMaxReqsInFlight = Math.max(1, currentMaxReqsInFlight / 2);
+      maxInFlight = currentMaxReqsInFlight;
+    }
+  }
+
   private boolean mapperEnded(int shuffleId, int mapId, int attemptId) {
     return mapperEndMap.containsKey(shuffleId)
         && mapperEndMap.get(shuffleId).contains(Utils.makeMapKey(shuffleId, 
mapId, attemptId));

Reply via email to