This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 3ea3aaa11 [#1373][FOLLOWUP] fix(spark): register with incorrect partitionRanges after reassign (#1612)
3ea3aaa11 is described below

commit 3ea3aaa1143cf8abea0998af4089934e9fabedd2
Author: dingshun3016 <[email protected]>
AuthorDate: Mon Apr 8 17:27:51 2024 +0800

    [#1373][FOLLOWUP] fix(spark): register with incorrect partitionRanges after reassign (#1612)
    
    ### What changes were proposed in this pull request?
    
    Fix the partition ID inconsistency that occurs when a new shuffle server is reassigned.
    
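    For example:
    When writing data to node a1, the registered partition ID is 1003.
    Node a1 then fails, so the client reassigns node b1 and registers shuffle
    server b1, but that reassignment request asks for a single partition with
    partitionNumPerRange = 1, so b1 is registered for partition 0 rather than
    partition 1003.
    When data for partition 1003 is then written to node b1, a NO_REGISTER
    exception is thrown.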
    
    ### Why are the changes needed?
    
    Fix: (#1373)
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    ---------
    
    Co-authored-by: shun01.ding <[email protected]>
---
 .../apache/spark/shuffle/RssShuffleManager.java    | 59 ++++++++++++++++++----
 1 file changed, 49 insertions(+), 10 deletions(-)

diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 1b4df1747..6d9487ca4 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -18,7 +18,10 @@
 package org.apache.spark.shuffle;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -27,6 +30,7 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import scala.Tuple2;
@@ -1157,7 +1161,8 @@ public class RssShuffleManager extends RssShuffleManagerBase {
               1,
               requiredShuffleServerNumber,
               estimateTaskConcurrency,
-              failuresShuffleServerIds);
+              failuresShuffleServerIds,
+              null);
       /**
        * we need to clear the metadata of the completed task, otherwise some of the stage's data
        * will be lost
@@ -1196,24 +1201,54 @@ public class RssShuffleManager extends RssShuffleManagerBase {
       }
 
       // get the newer server to replace faulty server.
-      ShuffleServerInfo newAssignedServer = assignShuffleServer(shuffleId, faultyShuffleServerId);
+      ShuffleServerInfo newAssignedServer =
+          reassignShuffleServerForTask(shuffleId, partitionIds, faultyShuffleServerId);
       if (newAssignedServer != null) {
         handleInfo.createNewReassignmentForMultiPartitions(
             partitionIds, faultyShuffleServerId, newAssignedServer);
       }
+      LOG.info(
+          "Reassign shuffle-server from {} -> {} for shuffleId: {}, partitionIds: {}",
+          faultyShuffleServerId,
+          newAssignedServer,
+          shuffleId,
+          partitionIds);
       return newAssignedServer;
     }
   }
 
-  private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) {
+  private ShuffleServerInfo reassignShuffleServerForTask(
+      int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
     Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
     faultyServerIds.addAll(failuresShuffleServerIds);
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers =
-        requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds);
-    if (partitionToServers.get(0) != null && partitionToServers.get(0).size() == 1) {
-      return partitionToServers.get(0).get(0);
-    }
-    return null;
+    AtomicReference<ShuffleServerInfo> replacementRef = new AtomicReference<>();
+    requestShuffleAssignment(
+        shuffleId,
+        1,
+        1,
+        1,
+        1,
+        faultyServerIds,
+        shuffleAssignmentsInfo -> {
+          if (shuffleAssignmentsInfo == null) {
+            return null;
+          }
+          Optional<List<ShuffleServerInfo>> replacementOpt =
+              shuffleAssignmentsInfo.getPartitionToServers().values().stream().findFirst();
+          ShuffleServerInfo replacement = replacementOpt.get().get(0);
+          replacementRef.set(replacement);
+
+          Map<Integer, List<ShuffleServerInfo>> newPartitionToServers = new HashMap<>();
+          List<PartitionRange> partitionRanges = new ArrayList<>();
+          for (Integer partitionId : partitionIds) {
+            newPartitionToServers.put(partitionId, Arrays.asList(replacement));
+            partitionRanges.add(new PartitionRange(partitionId, partitionId));
+          }
+          Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new HashMap<>();
+          serverToPartitionRanges.put(replacement, partitionRanges);
+          return new ShuffleAssignmentsInfo(newPartitionToServers, serverToPartitionRanges);
+        });
+    return replacementRef.get();
   }
 
   private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
@@ -1222,7 +1257,8 @@ public class RssShuffleManager extends RssShuffleManagerBase {
       int partitionNumPerRange,
       int assignmentShuffleServerNumber,
       int estimateTaskConcurrency,
-      Set<String> faultyServerIds) {
+      Set<String> faultyServerIds,
+      Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo> reassignmentHandler) {
     Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
     ClientUtils.validateClientType(clientType);
     assignmentTags.add(clientType);
@@ -1242,6 +1278,9 @@ public class RssShuffleManager extends RssShuffleManagerBase {
                     assignmentShuffleServerNumber,
                     estimateTaskConcurrency,
                     faultyServerIds);
+            if (reassignmentHandler != null) {
+              response = reassignmentHandler.apply(response);
+            }
             registerShuffleServers(
                 id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo());
             return response.getPartitionToServers();

Reply via email to