This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 4669d1e31 [CELEBORN-788] Update latest PartitionLocation before retry 
PushData
4669d1e31 is described below

commit 4669d1e31c79826b3a510297690322265272b42a
Author: caojiaqing <[email protected]>
AuthorDate: Thu Jul 20 21:36:37 2023 +0800

    [CELEBORN-788] Update latest PartitionLocation before retry PushData
    
    ### What changes were proposed in this pull request?
    
    Inside `ShuffleClient.submitRetryPushData`, record the latest
PartitionLocation on the push callback before retrying the push.
    
    ### Why are the changes needed?
    Before this PR, after a retry in `ShuffleClient.submitRetryPushData`, the
push callback kept using the previous PartitionLocation,
    which is incorrect and may cause inefficiency in some cases.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Passes GA.
    
    Closes #1706 from JQ-Cao/788.
    
    Authored-by: caojiaqing <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 55 +++++++++++++---------
 1 file changed, 33 insertions(+), 22 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index dcec21a94..7361c7906 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -216,7 +216,7 @@ public class ShuffleClientImpl extends ShuffleClient {
       int shuffleId,
       byte[] body,
       int batchId,
-      RpcResponseCallback wrappedCallback,
+      PushDataRpcResponseCallback pushDataRpcResponseCallback,
       PushState pushState,
       ReviveRequest request,
       int remainReviveTimes,
@@ -250,7 +250,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           loc);
       pushState.removeBatch(batchId, loc.hostAndPushPort());
     } else if (request.reviveStatus != StatusCode.SUCCESS.getValue()) {
-      wrappedCallback.onFailure(
+      pushDataRpcResponseCallback.onFailure(
           new CelebornIOException(
               cause
                   + " then revive but "
@@ -273,7 +273,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           batchId,
           newLoc);
       try {
-        if (!isPushTargetWorkerExcluded(newLoc, wrappedCallback)) {
+        if (!isPushTargetWorkerExcluded(newLoc, pushDataRpcResponseCallback)) {
           if (!testRetryRevive || remainReviveTimes < 1) {
             TransportClient client =
                 dataClientFactory.createClient(newLoc.getHost(), 
newLoc.getPushPort(), partitionId);
@@ -281,7 +281,8 @@ public class ShuffleClientImpl extends ShuffleClient {
             String shuffleKey = Utils.makeShuffleKey(appUniqueId, shuffleId);
             PushData newPushData =
                 new PushData(PRIMARY_MODE, shuffleKey, newLoc.getUniqueId(), 
newBuffer);
-            client.pushData(newPushData, pushDataTimeout, wrappedCallback);
+            pushDataRpcResponseCallback.updateLatestPartition(newLoc);
+            client.pushData(newPushData, pushDataTimeout, 
pushDataRpcResponseCallback);
           } else {
             throw new RuntimeException(
                 "Mock push data submit retry failed. remainReviveTimes = "
@@ -299,7 +300,7 @@ public class ShuffleClientImpl extends ShuffleClient {
             batchId,
             newLoc,
             e);
-        wrappedCallback.onFailure(
+        pushDataRpcResponseCallback.onFailure(
             new 
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY, e));
       }
     }
@@ -756,6 +757,10 @@ public class ShuffleClientImpl extends ShuffleClient {
     }
   }
 
+  private interface PushDataRpcResponseCallback extends RpcResponseCallback {
+    default void updateLatestPartition(PartitionLocation latest) {}
+  }
+
   public int pushOrMergeData(
       int shuffleId,
       int mapId,
@@ -894,8 +899,14 @@ public class ShuffleClientImpl extends ShuffleClient {
           };
 
       RpcResponseCallback wrappedCallback =
-          new RpcResponseCallback() {
+          new PushDataRpcResponseCallback() {
             int remainReviveTimes = maxReviveTimes;
+            PartitionLocation latest = loc;
+
+            @Override
+            public void updateLatestPartition(PartitionLocation latest) {
+              this.latest = latest;
+            }
 
             @Override
             public void onSuccess(ByteBuffer response) {
@@ -904,19 +915,19 @@ public class ShuffleClientImpl extends ShuffleClient {
                 if (reason == StatusCode.SOFT_SPLIT.getValue()) {
                   logger.debug(
                       "Push data to {} soft split required for shuffle {} map 
{} attempt {} partition {} batch {}.",
-                      loc.hostAndPushPort(),
+                      latest.hostAndPushPort(),
                       shuffleId,
                       mapId,
                       attemptId,
                       partitionId,
                       nextBatchId);
-                  splitPartition(shuffleId, partitionId, loc);
-                  pushState.onSuccess(loc.hostAndPushPort());
+                  splitPartition(shuffleId, partitionId, latest);
+                  pushState.onSuccess(latest.hostAndPushPort());
                   callback.onSuccess(response);
                 } else if (reason == StatusCode.HARD_SPLIT.getValue()) {
                   logger.debug(
                       "Push data to {} hard split required for shuffle {} map 
{} attempt {} partition {} batch {}.",
-                      loc.hostAndPushPort(),
+                      latest.hostAndPushPort(),
                       shuffleId,
                       mapId,
                       attemptId,
@@ -928,8 +939,8 @@ public class ShuffleClientImpl extends ShuffleClient {
                           mapId,
                           attemptId,
                           partitionId,
-                          loc.getEpoch(),
-                          loc,
+                          latest.getEpoch(),
+                          latest,
                           StatusCode.HARD_SPLIT);
                   reviveManager.addRequest(reviveRequest);
                   long dueTime =
@@ -951,33 +962,33 @@ public class ShuffleClientImpl extends ShuffleClient {
                 } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) {
                   logger.debug(
                       "Push data to {} primary congestion required for shuffle 
{} map {} attempt {} partition {} batch {}.",
-                      loc.hostAndPushPort(),
+                      latest.hostAndPushPort(),
                       shuffleId,
                       mapId,
                       attemptId,
                       partitionId,
                       nextBatchId);
-                  pushState.onCongestControl(loc.hostAndPushPort());
+                  pushState.onCongestControl(latest.hostAndPushPort());
                   callback.onSuccess(response);
                 } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) {
                   logger.debug(
                       "Push data to {} replica congestion required for shuffle 
{} map {} attempt {} partition {} batch {}.",
-                      loc.hostAndPushPort(),
+                      latest.hostAndPushPort(),
                       shuffleId,
                       mapId,
                       attemptId,
                       partitionId,
                       nextBatchId);
-                  pushState.onCongestControl(loc.hostAndPushPort());
+                  pushState.onCongestControl(latest.hostAndPushPort());
                   callback.onSuccess(response);
                 } else {
                   // StageEnd.
                   response.rewind();
-                  pushState.onSuccess(loc.hostAndPushPort());
+                  pushState.onSuccess(latest.hostAndPushPort());
                   callback.onSuccess(response);
                 }
               } else {
-                pushState.onSuccess(loc.hostAndPushPort());
+                pushState.onSuccess(latest.hostAndPushPort());
                 callback.onSuccess(response);
               }
             }
@@ -1001,7 +1012,7 @@ public class ShuffleClientImpl extends ShuffleClient {
 
               logger.error(
                   "Push data to {} failed for shuffle {} map {} attempt {} 
partition {} batch {}, remain revive times {}.",
-                  loc.hostAndPushPort(),
+                  latest.hostAndPushPort(),
                   shuffleId,
                   mapId,
                   attemptId,
@@ -1014,7 +1025,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                 remainReviveTimes = remainReviveTimes - 1;
                 ReviveRequest reviveRequest =
                     new ReviveRequest(
-                        shuffleId, mapId, attemptId, partitionId, 
loc.getEpoch(), loc, cause);
+                        shuffleId, mapId, attemptId, partitionId, 
latest.getEpoch(), latest, cause);
                 reviveManager.addRequest(reviveRequest);
                 long dueTime =
                     System.currentTimeMillis()
@@ -1033,10 +1044,10 @@ public class ShuffleClientImpl extends ShuffleClient {
                             remainReviveTimes,
                             dueTime));
               } else {
-                pushState.removeBatch(nextBatchId, loc.hostAndPushPort());
+                pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
                 logger.info(
                     "Push data to {} failed but mapper already ended for 
shuffle {} map {} attempt {} partition {} batch {}, remain revive times {}.",
-                    loc.hostAndPushPort(),
+                    latest.hostAndPushPort(),
                     shuffleId,
                     mapId,
                     attemptId,

Reply via email to