This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 375262ad9 [#1579][part-1] fix(spark): Ensure all previous data is 
cleared for stage retry (#1762)
375262ad9 is described below

commit 375262ad9236efebe0010a65a556c2dec3e9ae82
Author: yl09099 <[email protected]>
AuthorDate: Thu Jun 13 09:31:29 2024 +0800

    [#1579][part-1] fix(spark): Ensure all previous data is cleared for stage 
retry (#1762)
    
    ### What changes were proposed in this pull request?
    
    1. clear out previous stage attempt data synchronously when registering the 
re-assignment shuffleIds.
    2. rework the stage retry interface and rpc
    3. introducing the stage version to avoid accepting the older data
    
    ### Why are the changes needed?
    
    Fix: https://github.com/apache/incubator-uniffle/issues/1579
    
    If the previous stage attempt is in the purge queue on the shuffle-server side, 
the retried stage's writes will cause
    unknown exceptions, so we'd better clear out all previous stage attempt 
data before re-registering.
    
    This PR synchronously removes the previous stage's data when the first attempt writer 
is initialized.
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
---
 .../hadoop/mapred/SortWriteBufferManagerTest.java  |   3 +-
 .../hadoop/mapreduce/task/reduce/FetcherTest.java  |   3 +-
 .../org/apache/spark/shuffle/RssStageInfo.java     |  33 +-
 .../spark/shuffle/RssStageResubmitManager.java     |  69 +++
 .../handle/StageAttemptShuffleHandleInfo.java      | 137 ++++++
 .../apache/spark/shuffle/writer/AddBlockEvent.java |  11 +
 .../apache/spark/shuffle/writer/DataPusher.java    |   5 +-
 .../shuffle/manager/RssShuffleManagerBase.java     | 467 +++++++++++++++++++++
 .../manager/RssShuffleManagerInterface.java        |  19 +-
 .../shuffle/manager/ShuffleManagerGrpcService.java |  47 ++-
 .../spark/shuffle/writer/DataPusherTest.java       |   9 +
 .../shuffle/manager/DummyRssShuffleManager.java    |  28 +-
 .../manager/ShuffleManagerGrpcServiceTest.java     |   4 +-
 .../apache/spark/shuffle/RssShuffleManager.java    | 334 ++-------------
 .../spark/shuffle/writer/RssShuffleWriter.java     |   1 +
 .../apache/spark/shuffle/RssShuffleManager.java    | 441 ++-----------------
 .../common/sort/buffer/WriteBufferManagerTest.java |   3 +-
 .../uniffle/client/api/ShuffleWriteClient.java     |  32 +-
 .../client/impl/ShuffleWriteClientImpl.java        |  24 +-
 .../netty/protocol/SendShuffleDataRequest.java     |  18 +
 .../org/apache/uniffle/common/rpc/StatusCode.java  |   1 +
 .../uniffle/client/api/ShuffleManagerClient.java   |  17 +-
 .../client/impl/grpc/ShuffleManagerGrpcClient.java |  27 +-
 .../client/impl/grpc/ShuffleServerGrpcClient.java  |  11 +-
 .../impl/grpc/ShuffleServerGrpcNettyClient.java    |   2 +
 .../request/RssGetShuffleAssignmentsRequest.java   |  24 ++
 .../client/request/RssRegisterShuffleRequest.java  |  32 +-
 .../client/request/RssSendShuffleDataRequest.java  |  15 +
 .../RssReassignOnBlockSendFailureResponse.java     |   2 +-
 ...e.java => RssReassignOnStageRetryResponse.java} |  16 +-
 proto/src/main/proto/Rss.proto                     |  22 +-
 .../uniffle/server/ShuffleServerGrpcService.java   |  58 +++
 .../org/apache/uniffle/server/ShuffleTaskInfo.java |  11 +
 .../apache/uniffle/server/ShuffleTaskManager.java  |   2 +-
 .../server/netty/ShuffleServerNettyHandler.java    |  13 +
 35 files changed, 1153 insertions(+), 788 deletions(-)

diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 5c9f401b8..430e2ff58 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -517,7 +517,8 @@ public class SortWriteBufferManagerTest {
         List<PartitionRange> partitionRanges,
         RemoteStorageInfo remoteStorage,
         ShuffleDataDistributionType distributionType,
-        int maxConcurrencyPerPartitionToWrite) {}
+        int maxConcurrencyPerPartitionToWrite,
+        int stageAttemptNumber) {}
 
     @Override
     public boolean sendCommit(
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index b858312b0..7664b47d6 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -504,7 +504,8 @@ public class FetcherTest {
         List<PartitionRange> partitionRanges,
         RemoteStorageInfo storageType,
         ShuffleDataDistributionType distributionType,
-        int maxConcurrencyPerPartitionToWrite) {}
+        int maxConcurrencyPerPartitionToWrite,
+        int stageAttemptNumber) {}
 
     @Override
     public boolean sendCommit(
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
 b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageInfo.java
similarity index 50%
copy from 
internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
copy to 
client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageInfo.java
index 81ca7d548..c8168d6c4 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageInfo.java
@@ -15,27 +15,30 @@
  * limitations under the License.
  */
 
-package org.apache.uniffle.client.response;
+package org.apache.spark.shuffle;
 
-import org.apache.uniffle.common.rpc.StatusCode;
-import org.apache.uniffle.proto.RssProtos;
+public class RssStageInfo {
+  private String stageAttemptIdAndNumber;
+  private boolean isReassigned;
 
-public class RssReassignOnBlockSendFailureResponse extends ClientResponse {
-  private RssProtos.MutableShuffleHandleInfo handle;
+  public RssStageInfo(String stageAttemptIdAndNumber, boolean isReassigned) {
+    this.stageAttemptIdAndNumber = stageAttemptIdAndNumber;
+    this.isReassigned = isReassigned;
+  }
+
+  public String getStageAttemptIdAndNumber() {
+    return stageAttemptIdAndNumber;
+  }
 
-  public RssReassignOnBlockSendFailureResponse(
-      StatusCode statusCode, String message, 
RssProtos.MutableShuffleHandleInfo handle) {
-    super(statusCode, message);
-    this.handle = handle;
+  public void setStageAttemptIdAndNumber(String stageAttemptIdAndNumber) {
+    this.stageAttemptIdAndNumber = stageAttemptIdAndNumber;
   }
 
-  public RssProtos.MutableShuffleHandleInfo getHandle() {
-    return handle;
+  public boolean isReassigned() {
+    return isReassigned;
   }
 
-  public static RssReassignOnBlockSendFailureResponse fromProto(
-      RssProtos.RssReassignOnBlockSendFailureResponse response) {
-    return new RssReassignOnBlockSendFailureResponse(
-        StatusCode.valueOf(response.getStatus().name()), response.getMsg(), 
response.getHandle());
+  public void setReassigned(boolean reassigned) {
+    isReassigned = reassigned;
   }
 }
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java
new file mode 100644
index 000000000..028622f92
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle;
+
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.util.JavaUtils;
+
+public class RssStageResubmitManager {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(RssStageResubmitManager.class);
+  /** Blacklist of the Shuffle Server when the write fails. */
+  private Set<String> serverIdBlackList;
+  /**
+   * Prevent multiple tasks from reporting FetchFailed, resulting in multiple 
ShuffleServer
+   * assignments, stageID, Attemptnumber Whether to reassign the combination 
flag;
+   */
+  private Map<Integer, RssStageInfo> serverAssignedInfos;
+
+  public RssStageResubmitManager() {
+    this.serverIdBlackList = Sets.newConcurrentHashSet();
+    this.serverAssignedInfos = JavaUtils.newConcurrentMap();
+  }
+
+  public Set<String> getServerIdBlackList() {
+    return serverIdBlackList;
+  }
+
+  public void resetServerIdBlackList(Set<String> failuresShuffleServerIds) {
+    this.serverIdBlackList = failuresShuffleServerIds;
+  }
+
+  public void recordFailuresShuffleServer(String shuffleServerId) {
+    serverIdBlackList.add(shuffleServerId);
+  }
+
+  public RssStageInfo recordAndGetServerAssignedInfo(int shuffleId, String 
stageIdAndAttempt) {
+
+    return serverAssignedInfos.computeIfAbsent(
+        shuffleId, id -> new RssStageInfo(stageIdAndAttempt, false));
+  }
+
+  public void recordAndGetServerAssignedInfo(
+      int shuffleId, String stageIdAndAttempt, boolean isRetried) {
+    serverAssignedInfos
+        .computeIfAbsent(shuffleId, id -> new RssStageInfo(stageIdAndAttempt, 
false))
+        .setReassigned(isRetried);
+  }
+}
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
new file mode 100644
index 000000000..8fd9642ac
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.handle;
+
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Lists;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.proto.RssProtos;
+
+public class StageAttemptShuffleHandleInfo extends ShuffleHandleInfoBase {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(StageAttemptShuffleHandleInfo.class);
+
+  private ShuffleHandleInfo current;
+  /** When Stage retry occurs, record the Shuffle Server of the previous 
Stage. */
+  private LinkedList<ShuffleHandleInfo> historyHandles;
+
+  public StageAttemptShuffleHandleInfo(
+      int shuffleId, RemoteStorageInfo remoteStorage, ShuffleHandleInfo 
shuffleServerInfo) {
+    super(shuffleId, remoteStorage);
+    this.current = shuffleServerInfo;
+    this.historyHandles = Lists.newLinkedList();
+  }
+
+  public StageAttemptShuffleHandleInfo(
+      int shuffleId,
+      RemoteStorageInfo remoteStorage,
+      ShuffleHandleInfo currentShuffleServerInfo,
+      LinkedList<ShuffleHandleInfo> historyHandles) {
+    super(shuffleId, remoteStorage);
+    this.current = currentShuffleServerInfo;
+    this.historyHandles = historyHandles;
+  }
+
+  @Override
+  public Set<ShuffleServerInfo> getServers() {
+    return current.getServers();
+  }
+
+  @Override
+  public Map<Integer, List<ShuffleServerInfo>> 
getAvailablePartitionServersForWriter() {
+    return current.getAvailablePartitionServersForWriter();
+  }
+
+  @Override
+  public Map<Integer, List<ShuffleServerInfo>> 
getAllPartitionServersForReader() {
+    return current.getAllPartitionServersForReader();
+  }
+
+  @Override
+  public PartitionDataReplicaRequirementTracking 
createPartitionReplicaTracking() {
+    return current.createPartitionReplicaTracking();
+  }
+
+  /**
+   * When a Stage retry occurs, replace the current shuffleHandleInfo and 
record the historical
+   * shuffleHandleInfo.
+   */
+  public void replaceCurrentShuffleHandleInfo(ShuffleHandleInfo 
shuffleHandleInfo) {
+    this.historyHandles.add(current);
+    this.current = shuffleHandleInfo;
+  }
+
+  public ShuffleHandleInfo getCurrent() {
+    return current;
+  }
+
+  public LinkedList<ShuffleHandleInfo> getHistoryHandles() {
+    return historyHandles;
+  }
+
+  public static RssProtos.StageAttemptShuffleHandleInfo toProto(
+      StageAttemptShuffleHandleInfo handleInfo) {
+    LinkedList<RssProtos.MutableShuffleHandleInfo> 
mutableShuffleHandleInfoLinkedList =
+        Lists.newLinkedList();
+    RssProtos.MutableShuffleHandleInfo currentMutableShuffleHandleInfo =
+        MutableShuffleHandleInfo.toProto((MutableShuffleHandleInfo) 
handleInfo.getCurrent());
+    for (ShuffleHandleInfo historyHandle : handleInfo.getHistoryHandles()) {
+      mutableShuffleHandleInfoLinkedList.add(
+          MutableShuffleHandleInfo.toProto((MutableShuffleHandleInfo) 
historyHandle));
+    }
+    RssProtos.StageAttemptShuffleHandleInfo handleProto =
+        RssProtos.StageAttemptShuffleHandleInfo.newBuilder()
+            
.setCurrentMutableShuffleHandleInfo(currentMutableShuffleHandleInfo)
+            
.addAllHistoryMutableShuffleHandleInfo(mutableShuffleHandleInfoLinkedList)
+            .build();
+    return handleProto;
+  }
+
+  public static StageAttemptShuffleHandleInfo fromProto(
+      RssProtos.StageAttemptShuffleHandleInfo handleProto) {
+    if (handleProto == null) {
+      return null;
+    }
+
+    MutableShuffleHandleInfo mutableShuffleHandleInfo =
+        
MutableShuffleHandleInfo.fromProto(handleProto.getCurrentMutableShuffleHandleInfo());
+    List<RssProtos.MutableShuffleHandleInfo> 
historyMutableShuffleHandleInfoList =
+        handleProto.getHistoryMutableShuffleHandleInfoList();
+    LinkedList<ShuffleHandleInfo> historyHandles = Lists.newLinkedList();
+    for (RssProtos.MutableShuffleHandleInfo shuffleHandleInfo :
+        historyMutableShuffleHandleInfoList) {
+      
historyHandles.add(MutableShuffleHandleInfo.fromProto(shuffleHandleInfo));
+    }
+
+    StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo =
+        new StageAttemptShuffleHandleInfo(
+            mutableShuffleHandleInfo.shuffleId,
+            mutableShuffleHandleInfo.remoteStorage,
+            mutableShuffleHandleInfo,
+            historyHandles);
+    return stageAttemptShuffleHandleInfo;
+  }
+}
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
index 9751ba0b8..f989fdb0b 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
@@ -25,11 +25,18 @@ import org.apache.uniffle.common.ShuffleBlockInfo;
 public class AddBlockEvent {
 
   private String taskId;
+  private int stageAttemptNumber;
   private List<ShuffleBlockInfo> shuffleDataInfoList;
   private List<Runnable> processedCallbackChain;
 
   public AddBlockEvent(String taskId, List<ShuffleBlockInfo> 
shuffleDataInfoList) {
+    this(taskId, 0, shuffleDataInfoList);
+  }
+
+  public AddBlockEvent(
+      String taskId, int stageAttemptNumber, List<ShuffleBlockInfo> 
shuffleDataInfoList) {
     this.taskId = taskId;
+    this.stageAttemptNumber = stageAttemptNumber;
     this.shuffleDataInfoList = shuffleDataInfoList;
     this.processedCallbackChain = new ArrayList<>();
   }
@@ -43,6 +50,10 @@ public class AddBlockEvent {
     return taskId;
   }
 
+  public int getStageAttemptNumber() {
+    return stageAttemptNumber;
+  }
+
   public List<ShuffleBlockInfo> getShuffleDataInfoList() {
     return shuffleDataInfoList;
   }
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index e9ef2ba61..bdf0cf849 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -92,7 +92,10 @@ public class DataPusher implements Closeable {
               try {
                 result =
                     shuffleWriteClient.sendShuffleData(
-                        rssAppId, shuffleBlockInfoList, () -> 
!isValidTask(taskId));
+                        rssAppId,
+                        event.getStageAttemptNumber(),
+                        shuffleBlockInfoList,
+                        () -> !isValidTask(taskId));
                 putBlockId(taskToSuccessBlockIds, taskId, 
result.getSuccessBlockIds());
                 putFailedBlockSendTracker(
                     taskToFailedBlockSendTracker, taskId, 
result.getFailedBlockSendTracker());
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 77cb173e3..912bf998f 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -19,16 +19,24 @@ package org.apache.uniffle.shuffle.manager;
 
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.spark.MapOutputTracker;
@@ -38,22 +46,41 @@ import org.apache.spark.SparkEnv;
 import org.apache.spark.SparkException;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.apache.spark.shuffle.RssSparkShuffleUtils;
+import org.apache.spark.shuffle.RssStageInfo;
+import org.apache.spark.shuffle.RssStageResubmitManager;
+import org.apache.spark.shuffle.ShuffleHandleInfoManager;
 import org.apache.spark.shuffle.ShuffleManager;
 import org.apache.spark.shuffle.SparkVersionUtils;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.api.CoordinatorClient;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.CoordinatorClientFactory;
+import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
 import org.apache.uniffle.client.request.RssFetchClientConfRequest;
+import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
 import org.apache.uniffle.client.response.RssFetchClientConfResponse;
+import 
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
+import org.apache.uniffle.client.response.RssReassignOnStageRetryResponse;
+import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleDataDistributionType;
+import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.ConfigOption;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.shuffle.BlockIdManager;
 
 import static 
org.apache.uniffle.common.config.RssClientConf.HADOOP_CONFIG_KEY_PREFIX;
@@ -65,7 +92,34 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
   private Method unregisterAllMapOutputMethod;
   private Method registerShuffleMethod;
   private volatile BlockIdManager blockIdManager;
+  protected ShuffleDataDistributionType dataDistributionType;
   private Object blockIdManagerLock = new Object();
+  protected AtomicReference<String> id = new AtomicReference<>();
+  protected String appId = "";
+  protected ShuffleWriteClient shuffleWriteClient;
+  protected boolean dynamicConfEnabled;
+  protected int maxConcurrencyPerPartitionToWrite;
+  protected String clientType;
+
+  protected SparkConf sparkConf;
+  protected ShuffleManagerClient shuffleManagerClient;
+  /** Whether to enable the dynamic shuffleServer function rewrite and reread 
functions */
+  protected boolean rssResubmitStage;
+  /**
+   * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is 
dynamically allocated.
+   * ShuffleServer is not obtained from RssShuffleHandle, but from this 
mapping.
+   */
+  protected ShuffleHandleInfoManager shuffleHandleInfoManager;
+
+  protected RssStageResubmitManager rssStageResubmitManager;
+
+  protected int partitionReassignMaxServerNum;
+
+  protected boolean blockIdSelfManagedEnabled;
+
+  protected boolean taskBlockSendFailureRetryEnabled;
+
+  protected boolean shuffleManagerRpcServiceEnabled;
 
   public BlockIdManager getBlockIdManager() {
     if (blockIdManager == null) {
@@ -520,4 +574,417 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     return new RemoteStorageInfo(
         sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""), 
confItems);
   }
+
+  /**
+   * In Stage Retry mode, obtain the Shuffle Server list from the Driver based 
on shuffleId.
+   *
+   * @param shuffleId shuffleId
+   * @return ShuffleHandleInfo
+   */
+  protected synchronized StageAttemptShuffleHandleInfo 
getRemoteShuffleHandleInfoWithStageRetry(
+      int shuffleId) {
+    RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
+        new RssPartitionToShuffleServerRequest(shuffleId);
+    RssReassignOnStageRetryResponse rpcPartitionToShufflerServer =
+        getOrCreateShuffleManagerClient()
+            
.getPartitionToShufflerServerWithStageRetry(rssPartitionToShuffleServerRequest);
+    StageAttemptShuffleHandleInfo shuffleHandleInfo =
+        StageAttemptShuffleHandleInfo.fromProto(
+            rpcPartitionToShufflerServer.getShuffleHandleInfoProto());
+    return shuffleHandleInfo;
+  }
+
+  /**
+   * In Block Retry mode, obtain the Shuffle Server list from the Driver based 
on shuffleId.
+   *
+   * @param shuffleId shuffleId
+   * @return ShuffleHandleInfo
+   */
+  protected synchronized MutableShuffleHandleInfo 
getRemoteShuffleHandleInfoWithBlockRetry(
+      int shuffleId) {
+    RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
+        new RssPartitionToShuffleServerRequest(shuffleId);
+    RssReassignOnBlockSendFailureResponse rpcPartitionToShufflerServer =
+        getOrCreateShuffleManagerClient()
+            
.getPartitionToShufflerServerWithBlockRetry(rssPartitionToShuffleServerRequest);
+    MutableShuffleHandleInfo shuffleHandleInfo =
+        
MutableShuffleHandleInfo.fromProto(rpcPartitionToShufflerServer.getHandle());
+    return shuffleHandleInfo;
+  }
+
+  // todo: automatic close client when the client is idle to avoid too much 
connections for spark
+  // driver.
+  protected ShuffleManagerClient getOrCreateShuffleManagerClient() {
+    if (shuffleManagerClient == null) {
+      RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+      String driver = rssConf.getString("driver.host", "");
+      int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+      this.shuffleManagerClient =
+          ShuffleManagerClientFactory.getInstance()
+              .createShuffleManagerClient(ClientType.GRPC, driver, port);
+    }
+    return shuffleManagerClient;
+  }
+
+  @Override
+  public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
+    return shuffleHandleInfoManager.get(shuffleId);
+  }
+
+  /**
+   * @return the maximum number of fetch failures per shuffle partition before 
that shuffle stage
+   *     should be recomputed
+   */
+  @Override
+  public int getMaxFetchFailures() {
+    final String TASK_MAX_FAILURE = "spark.task.maxFailures";
+    return Math.max(1, sparkConf.getInt(TASK_MAX_FAILURE, 4) - 1);
+  }
+
+  /**
+   * Add the shuffleServer that failed to write to the failure list
+   *
+   * @param shuffleServerId
+   */
+  @Override
+  public void addFailuresShuffleServerInfos(String shuffleServerId) {
+    rssStageResubmitManager.recordFailuresShuffleServer(shuffleServerId);
+  }
+
+  /**
+   * Reassign the ShuffleServer list for ShuffleId
+   *
+   * @param shuffleId
+   * @param numPartitions
+   */
+  @Override
+  public boolean reassignOnStageResubmit(
+      int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) {
+    String stageIdAndAttempt = stageId + "_" + stageAttemptNumber;
+    RssStageInfo rssStageInfo =
+        rssStageResubmitManager.recordAndGetServerAssignedInfo(shuffleId, 
stageIdAndAttempt);
+    synchronized (rssStageInfo) {
+      Boolean needReassign = rssStageInfo.isReassigned();
+      if (!needReassign) {
+        int requiredShuffleServerNumber =
+            RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
+        int estimateTaskConcurrency = 
RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
+
+        /**
+         * this will clear up the previous stage attempt all data when 
registering the same
+         * shuffleId at the second time
+         */
+        Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+            requestShuffleAssignment(
+                shuffleId,
+                numPartitions,
+                1,
+                requiredShuffleServerNumber,
+                estimateTaskConcurrency,
+                rssStageResubmitManager.getServerIdBlackList(),
+                stageAttemptNumber);
+        /**
+         * we need to clear the metadata of the completed task, otherwise some 
of the stage's data
+         * will be lost
+         */
+        try {
+          unregisterAllMapOutput(shuffleId);
+        } catch (SparkException e) {
+          LOG.error("Clear MapoutTracker Meta failed!");
+          throw new RssException("Clear MapoutTracker Meta failed!", e);
+        }
+        MutableShuffleHandleInfo shuffleHandleInfo =
+            new MutableShuffleHandleInfo(shuffleId, partitionToServers, 
getRemoteStorageInfo());
+        StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo =
+            (StageAttemptShuffleHandleInfo) 
shuffleHandleInfoManager.get(shuffleId);
+        
stageAttemptShuffleHandleInfo.replaceCurrentShuffleHandleInfo(shuffleHandleInfo);
+        rssStageResubmitManager.recordAndGetServerAssignedInfo(shuffleId, 
stageIdAndAttempt, true);
+        return true;
+      } else {
+        LOG.info(
+            "Do nothing that the stage: {} has been reassigned for attempt{}",
+            stageId,
+            stageAttemptNumber);
+        return false;
+      }
+    }
+  }
+
+  /** this is only valid on driver side that exposed to being invoked by grpc 
server */
+  @Override
+  public MutableShuffleHandleInfo reassignOnBlockSendFailure(
+      int shuffleId, Map<Integer, List<ReceivingFailureServer>> 
partitionToFailureServers) {
+    long startTime = System.currentTimeMillis();
+    MutableShuffleHandleInfo handleInfo =
+        (MutableShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId);
+    synchronized (handleInfo) {
+      // If the reassignment servers for one partition exceeds the max 
reassign server num,
+      // it should fast fail.
+      handleInfo.checkPartitionReassignServerNum(
+          partitionToFailureServers.keySet(), partitionReassignMaxServerNum);
+
+      Map<ShuffleServerInfo, List<PartitionRange>> newServerToPartitions = new 
HashMap<>();
+      // receivingFailureServer -> partitionId -> replacementServerIds. For 
logging
+      Map<String, Map<Integer, Set<String>>> reassignResult = new HashMap<>();
+
+      for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
+          partitionToFailureServers.entrySet()) {
+        int partitionId = entry.getKey();
+        for (ReceivingFailureServer receivingFailureServer : entry.getValue()) 
{
+          StatusCode code = receivingFailureServer.getStatusCode();
+          String serverId = receivingFailureServer.getServerId();
+
+          boolean serverHasReplaced = false;
+          Set<ShuffleServerInfo> replacements = 
handleInfo.getReplacements(serverId);
+          if (CollectionUtils.isEmpty(replacements)) {
+            final int requiredServerNum = 1;
+            Set<String> excludedServers = new 
HashSet<>(handleInfo.listExcludedServers());
+            excludedServers.add(serverId);
+            replacements =
+                reassignServerForTask(
+                    shuffleId, Sets.newHashSet(partitionId), excludedServers, 
requiredServerNum);
+          } else {
+            serverHasReplaced = true;
+          }
+
+          Set<ShuffleServerInfo> updatedReassignServers =
+              handleInfo.updateAssignment(partitionId, serverId, replacements);
+
+          reassignResult
+              .computeIfAbsent(serverId, x -> new HashMap<>())
+              .computeIfAbsent(partitionId, x -> new HashSet<>())
+              .addAll(
+                  updatedReassignServers.stream().map(x -> 
x.getId()).collect(Collectors.toSet()));
+
+          if (serverHasReplaced) {
+            for (ShuffleServerInfo serverInfo : updatedReassignServers) {
+              newServerToPartitions
+                  .computeIfAbsent(serverInfo, x -> new ArrayList<>())
+                  .add(new PartitionRange(partitionId, partitionId));
+            }
+          }
+        }
+      }
+      if (!newServerToPartitions.isEmpty()) {
+        LOG.info(
+            "Register the new partition->servers assignment on reassign. {}",
+            newServerToPartitions);
+        registerShuffleServers(id.get(), shuffleId, newServerToPartitions, 
getRemoteStorageInfo());
+      }
+
+      LOG.info(
+          "Finished reassignOnBlockSendFailure request and cost {}(ms). 
Reassign result: {}",
+          System.currentTimeMillis() - startTime,
+          reassignResult);
+
+      return handleInfo;
+    }
+  }
+
+  /**
+   * Creating the shuffleAssignmentInfo from the servers and partitionIds
+   *
+   * @param servers
+   * @param partitionIds
+   * @return
+   */
+  private ShuffleAssignmentsInfo createShuffleAssignmentsInfo(
+      Set<ShuffleServerInfo> servers, Set<Integer> partitionIds) {
+    Map<Integer, List<ShuffleServerInfo>> newPartitionToServers = new 
HashMap<>();
+    List<PartitionRange> partitionRanges = new ArrayList<>();
+    for (Integer partitionId : partitionIds) {
+      newPartitionToServers.put(partitionId, new ArrayList<>(servers));
+      partitionRanges.add(new PartitionRange(partitionId, partitionId));
+    }
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new 
HashMap<>();
+    for (ShuffleServerInfo server : servers) {
+      serverToPartitionRanges.put(server, partitionRanges);
+    }
+    return new ShuffleAssignmentsInfo(newPartitionToServers, 
serverToPartitionRanges);
+  }
+
+  /** Request the new shuffle-servers to replace faulty server. */
+  private Set<ShuffleServerInfo> reassignServerForTask(
+      int shuffleId,
+      Set<Integer> partitionIds,
+      Set<String> excludedServers,
+      int requiredServerNum) {
+    AtomicReference<Set<ShuffleServerInfo>> replacementsRef =
+        new AtomicReference<>(new HashSet<>());
+    requestShuffleAssignment(
+        shuffleId,
+        requiredServerNum,
+        1,
+        requiredServerNum,
+        1,
+        excludedServers,
+        shuffleAssignmentsInfo -> {
+          if (shuffleAssignmentsInfo == null) {
+            return null;
+          }
+          Set<ShuffleServerInfo> replacements =
+              shuffleAssignmentsInfo.getPartitionToServers().values().stream()
+                  .flatMap(x -> x.stream())
+                  .collect(Collectors.toSet());
+          replacementsRef.set(replacements);
+          return createShuffleAssignmentsInfo(replacements, partitionIds);
+        });
+    return replacementsRef.get();
+  }
+
+  private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
+      int shuffleId,
+      int partitionNum,
+      int partitionNumPerRange,
+      int assignmentShuffleServerNumber,
+      int estimateTaskConcurrency,
+      Set<String> faultyServerIds,
+      Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo> 
reassignmentHandler) {
+    Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+    ClientUtils.validateClientType(clientType);
+    assignmentTags.add(clientType);
+    long retryInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+    int retryTimes = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+    faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
+    try {
+      return RetryUtils.retry(
+          () -> {
+            ShuffleAssignmentsInfo response =
+                shuffleWriteClient.getShuffleAssignments(
+                    id.get(),
+                    shuffleId,
+                    partitionNum,
+                    partitionNumPerRange,
+                    assignmentTags,
+                    assignmentShuffleServerNumber,
+                    estimateTaskConcurrency,
+                    faultyServerIds);
+            LOG.info("Finished reassign");
+            if (reassignmentHandler != null) {
+              response = reassignmentHandler.apply(response);
+            }
+            registerShuffleServers(
+                id.get(), shuffleId, response.getServerToPartitionRanges(), 
getRemoteStorageInfo());
+            return response.getPartitionToServers();
+          },
+          retryInterval,
+          retryTimes);
+    } catch (Throwable throwable) {
+      throw new RssException("registerShuffle failed!", throwable);
+    }
+  }
+
+  protected Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
+      int shuffleId,
+      int partitionNum,
+      int partitionNumPerRange,
+      int assignmentShuffleServerNumber,
+      int estimateTaskConcurrency,
+      Set<String> faultyServerIds,
+      int stageAttemptNumber) {
+    Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+    ClientUtils.validateClientType(clientType);
+    assignmentTags.add(clientType);
+
+    long retryInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+    int retryTimes = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+    faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
+    try {
+      return RetryUtils.retry(
+          () -> {
+            ShuffleAssignmentsInfo response =
+                shuffleWriteClient.getShuffleAssignments(
+                    appId,
+                    shuffleId,
+                    partitionNum,
+                    partitionNumPerRange,
+                    assignmentTags,
+                    assignmentShuffleServerNumber,
+                    estimateTaskConcurrency,
+                    faultyServerIds);
+            registerShuffleServers(
+                appId,
+                shuffleId,
+                response.getServerToPartitionRanges(),
+                getRemoteStorageInfo(),
+                stageAttemptNumber);
+            return response.getPartitionToServers();
+          },
+          retryInterval,
+          retryTimes);
+    } catch (Throwable throwable) {
+      throw new RssException("registerShuffle failed!", throwable);
+    }
+  }
+
+  protected void registerShuffleServers(
+      String appId,
+      int shuffleId,
+      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
+      RemoteStorageInfo remoteStorage,
+      int stageAttemptNumber) {
+    if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
+      return;
+    }
+    LOG.info("Start to register shuffleId {}", shuffleId);
+    long start = System.currentTimeMillis();
+    serverToPartitionRanges.entrySet().stream()
+        .forEach(
+            entry -> {
+              shuffleWriteClient.registerShuffle(
+                  entry.getKey(),
+                  appId,
+                  shuffleId,
+                  entry.getValue(),
+                  remoteStorage,
+                  ShuffleDataDistributionType.NORMAL,
+                  maxConcurrencyPerPartitionToWrite,
+                  stageAttemptNumber);
+            });
+    LOG.info(
+        "Finish register shuffleId {} with {} ms", shuffleId, 
(System.currentTimeMillis() - start));
+  }
+
+  @VisibleForTesting
+  protected void registerShuffleServers(
+      String appId,
+      int shuffleId,
+      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
+      RemoteStorageInfo remoteStorage) {
+    if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
+      return;
+    }
+    LOG.info("Start to register shuffleId[{}]", shuffleId);
+    long start = System.currentTimeMillis();
+    Set<Map.Entry<ShuffleServerInfo, List<PartitionRange>>> entries =
+        serverToPartitionRanges.entrySet();
+    entries.stream()
+        .forEach(
+            entry -> {
+              shuffleWriteClient.registerShuffle(
+                  entry.getKey(),
+                  appId,
+                  shuffleId,
+                  entry.getValue(),
+                  remoteStorage,
+                  dataDistributionType,
+                  maxConcurrencyPerPartitionToWrite);
+            });
+    LOG.info(
+        "Finish register shuffleId[{}] with {} ms",
+        shuffleId,
+        (System.currentTimeMillis() - start));
+  }
+
+  protected RemoteStorageInfo getRemoteStorageInfo() {
+    String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
+    RemoteStorageInfo defaultRemoteStorage =
+        new 
RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), 
""));
+    return ClientUtils.fetchRemoteStorage(
+        appId, defaultRemoteStorage, dynamicConfEnabled, storageType, 
shuffleWriteClient);
+  }
+
  /** @return whether stage resubmission (stage retry on fetch failure) is enabled. */
  public boolean isRssResubmitStage() {
    return rssResubmitStage;
  }
 }
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
index b21360041..52ded5db7 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
@@ -37,12 +37,6 @@ public interface RssShuffleManagerInterface {
   /** @return the unique spark id for rss shuffle */
   String getAppId();
 
-  /**
-   * @return the maximum number of fetch failures per shuffle partition before 
that shuffle stage
-   *     should be re-submitted
-   */
-  int getMaxFetchFailures();
-
   /**
    * @param shuffleId the shuffle id to query
    * @return the num of partitions(a.k.a reduce tasks) for shuffle with 
shuffle id.
@@ -63,6 +57,8 @@ public interface RssShuffleManagerInterface {
    */
   void unregisterAllMapOutput(int shuffleId) throws SparkException;
 
+  BlockIdManager getBlockIdManager();
+
   /**
    * Get ShuffleHandleInfo with ShuffleId
    *
@@ -71,6 +67,12 @@ public interface RssShuffleManagerInterface {
    */
   ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId);
 
+  /**
+   * @return the maximum number of fetch failures per shuffle partition before 
that shuffle stage
+   *     should be re-submitted
+   */
+  int getMaxFetchFailures();
+
   /**
    * Add the shuffleServer that failed to write to the failure list
    *
@@ -78,11 +80,8 @@ public interface RssShuffleManagerInterface {
    */
   void addFailuresShuffleServerInfos(String shuffleServerId);
 
-  boolean reassignAllShuffleServersForWholeStage(
-      int stageId, int stageAttemptNumber, int shuffleId, int numMaps);
+  boolean reassignOnStageResubmit(int stageId, int stageAttemptNumber, int 
shuffleId, int numMaps);
 
   MutableShuffleHandleInfo reassignOnBlockSendFailure(
       int shuffleId, Map<Integer, List<ReceivingFailureServer>> 
partitionToFailureServers);
-
-  BlockIdManager getBlockIdManager();
 }
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
index a4ff727b5..425f03e65 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
@@ -30,6 +30,7 @@ import java.util.stream.Collectors;
 import com.google.protobuf.UnsafeByteOperations;
 import io.grpc.stub.StreamObserver;
 import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -188,10 +189,34 @@ public class ShuffleManagerGrpcService extends 
ShuffleManagerImplBase {
   }
 
   @Override
-  public void getPartitionToShufflerServer(
+  public void getPartitionToShufflerServerWithStageRetry(
       RssProtos.PartitionToShuffleServerRequest request,
-      StreamObserver<RssProtos.PartitionToShuffleServerResponse> 
responseObserver) {
-    RssProtos.PartitionToShuffleServerResponse reply;
+      StreamObserver<RssProtos.ReassignOnStageRetryResponse> responseObserver) 
{
+    RssProtos.ReassignOnStageRetryResponse reply;
+    RssProtos.StatusCode code;
+    int shuffleId = request.getShuffleId();
+    StageAttemptShuffleHandleInfo shuffleHandle =
+        (StageAttemptShuffleHandleInfo) 
shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId);
+    if (shuffleHandle != null) {
+      code = RssProtos.StatusCode.SUCCESS;
+      reply =
+          RssProtos.ReassignOnStageRetryResponse.newBuilder()
+              .setStatus(code)
+              
.setShuffleHandleInfo(StageAttemptShuffleHandleInfo.toProto(shuffleHandle))
+              .build();
+    } else {
+      code = RssProtos.StatusCode.INVALID_REQUEST;
+      reply = 
RssProtos.ReassignOnStageRetryResponse.newBuilder().setStatus(code).build();
+    }
+    responseObserver.onNext(reply);
+    responseObserver.onCompleted();
+  }
+
+  @Override
+  public void getPartitionToShufflerServerWithBlockRetry(
+      RssProtos.PartitionToShuffleServerRequest request,
+      StreamObserver<RssProtos.ReassignOnBlockSendFailureResponse> 
responseObserver) {
+    RssProtos.ReassignOnBlockSendFailureResponse reply;
     RssProtos.StatusCode code;
     int shuffleId = request.getShuffleId();
     MutableShuffleHandleInfo shuffleHandle =
@@ -199,13 +224,13 @@ public class ShuffleManagerGrpcService extends 
ShuffleManagerImplBase {
     if (shuffleHandle != null) {
       code = RssProtos.StatusCode.SUCCESS;
       reply =
-          RssProtos.PartitionToShuffleServerResponse.newBuilder()
+          RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
               .setStatus(code)
-              
.setShuffleHandleInfo(MutableShuffleHandleInfo.toProto(shuffleHandle))
+              .setHandle(MutableShuffleHandleInfo.toProto(shuffleHandle))
               .build();
     } else {
       code = RssProtos.StatusCode.INVALID_REQUEST;
-      reply = 
RssProtos.PartitionToShuffleServerResponse.newBuilder().setStatus(code).build();
+      reply = 
RssProtos.ReassignOnBlockSendFailureResponse.newBuilder().setStatus(code).build();
     }
     responseObserver.onNext(reply);
     responseObserver.onCompleted();
@@ -220,7 +245,7 @@ public class ShuffleManagerGrpcService extends 
ShuffleManagerImplBase {
     int shuffleId = request.getShuffleId();
     int numPartitions = request.getNumPartitions();
     boolean needReassign =
-        shuffleManager.reassignAllShuffleServersForWholeStage(
+        shuffleManager.reassignOnStageResubmit(
             stageId, stageAttemptNumber, shuffleId, numPartitions);
     RssProtos.StatusCode code = RssProtos.StatusCode.SUCCESS;
     RssProtos.ReassignServersResponse reply =
@@ -236,10 +261,10 @@ public class ShuffleManagerGrpcService extends 
ShuffleManagerImplBase {
   public void reassignOnBlockSendFailure(
       org.apache.uniffle.proto.RssProtos.RssReassignOnBlockSendFailureRequest 
request,
       io.grpc.stub.StreamObserver<
-              
org.apache.uniffle.proto.RssProtos.RssReassignOnBlockSendFailureResponse>
+              
org.apache.uniffle.proto.RssProtos.ReassignOnBlockSendFailureResponse>
           responseObserver) {
     RssProtos.StatusCode code = RssProtos.StatusCode.INTERNAL_ERROR;
-    RssProtos.RssReassignOnBlockSendFailureResponse reply;
+    RssProtos.ReassignOnBlockSendFailureResponse reply;
     try {
       LOG.info(
           "Accepted reassign request on block sent failure for shuffleId: {}, 
stageId: {}, stageAttemptNumber: {} from taskAttemptId: {} on executorId: {}",
@@ -257,14 +282,14 @@ public class ShuffleManagerGrpcService extends 
ShuffleManagerImplBase {
                           Map.Entry::getKey, x -> 
ReceivingFailureServer.fromProto(x.getValue()))));
       code = RssProtos.StatusCode.SUCCESS;
       reply =
-          RssProtos.RssReassignOnBlockSendFailureResponse.newBuilder()
+          RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
               .setStatus(code)
               .setHandle(MutableShuffleHandleInfo.toProto(handle))
               .build();
     } catch (Exception e) {
       LOG.error("Errors on reassigning when block send failure.", e);
       reply =
-          RssProtos.RssReassignOnBlockSendFailureResponse.newBuilder()
+          RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
               .setStatus(code)
               .setMsg(e.getMessage())
               .build();
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index ebc590af5..080ba1e33 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -70,6 +70,15 @@ public class DataPusherTest {
         String appId,
         List<ShuffleBlockInfo> shuffleBlockInfoList,
         Supplier<Boolean> needCancelRequest) {
+      return sendShuffleData(appId, 0, shuffleBlockInfoList, 
needCancelRequest);
+    }
+
+    @Override
+    public SendShuffleDataResult sendShuffleData(
+        String appId,
+        int stageAttemptNumber,
+        List<ShuffleBlockInfo> shuffleBlockInfoList,
+        Supplier<Boolean> needCancelRequest) {
       return fakedShuffleDataResult;
     }
 
diff --git 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
index 66bb26de8..df40e9d16 100644
--- 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
+++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
@@ -23,13 +23,11 @@ import java.util.Map;
 import java.util.Set;
 
 import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
-import org.apache.spark.shuffle.handle.ShuffleHandleInfoBase;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 
 import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.shuffle.BlockIdManager;
 
-import static org.mockito.Mockito.mock;
-
 public class DummyRssShuffleManager implements RssShuffleManagerInterface {
   public Set<Integer> unregisteredShuffleIds = new LinkedHashSet<>();
 
@@ -38,11 +36,6 @@ public class DummyRssShuffleManager implements 
RssShuffleManagerInterface {
     return "testAppId";
   }
 
-  @Override
-  public int getMaxFetchFailures() {
-    return 2;
-  }
-
   @Override
   public int getPartitionNum(int shuffleId) {
     return 16;
@@ -59,15 +52,25 @@ public class DummyRssShuffleManager implements 
RssShuffleManagerInterface {
   }
 
   @Override
-  public ShuffleHandleInfoBase getShuffleHandleInfoByShuffleId(int shuffleId) {
+  public BlockIdManager getBlockIdManager() {
+    return null;
+  }
+
+  @Override
+  public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
     return null;
   }
 
+  @Override
+  public int getMaxFetchFailures() {
+    return 0;
+  }
+
   @Override
   public void addFailuresShuffleServerInfos(String shuffleServerId) {}
 
   @Override
-  public boolean reassignAllShuffleServersForWholeStage(
+  public boolean reassignOnStageResubmit(
       int stageId, int stageAttemptNumber, int shuffleId, int numMaps) {
     return false;
   }
@@ -75,11 +78,6 @@ public class DummyRssShuffleManager implements 
RssShuffleManagerInterface {
   @Override
   public MutableShuffleHandleInfo reassignOnBlockSendFailure(
       int shuffleId, Map<Integer, List<ReceivingFailureServer>> 
partitionToFailureServers) {
-    return mock(MutableShuffleHandleInfo.class);
-  }
-
-  @Override
-  public BlockIdManager getBlockIdManager() {
     return null;
   }
 }
diff --git 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java
 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java
index 6dc2abbf6..ac3fbda7e 100644
--- 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java
@@ -35,7 +35,7 @@ import static org.mockito.Mockito.mock;
 
 public class ShuffleManagerGrpcServiceTest {
   // create mock of RssShuffleManagerInterface.
-  private static RssShuffleManagerInterface mockShuffleManager;
+  private static RssShuffleManagerBase mockShuffleManager;
   private static final String appId = "app-123";
   private static final int maxFetchFailures = 2;
   private static final int shuffleId = 0;
@@ -65,7 +65,7 @@ public class ShuffleManagerGrpcServiceTest {
 
   @BeforeAll
   public static void setup() {
-    mockShuffleManager = mock(RssShuffleManagerInterface.class);
+    mockShuffleManager = mock(RssShuffleManagerBase.class);
     Mockito.when(mockShuffleManager.getAppId()).thenReturn(appId);
     Mockito.when(mockShuffleManager.getNumMaps(shuffleId)).thenReturn(numMaps);
     
Mockito.when(mockShuffleManager.getPartitionNum(shuffleId)).thenReturn(numReduces);
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 0f2830721..eb48b2f38 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -37,13 +37,13 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkEnv;
-import org.apache.spark.SparkException;
 import org.apache.spark.TaskContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
 import org.apache.spark.shuffle.writer.AddBlockEvent;
 import org.apache.spark.shuffle.writer.DataPusher;
@@ -54,20 +54,10 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.uniffle.client.api.ShuffleManagerClient;
-import org.apache.uniffle.client.api.ShuffleWriteClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
-import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
-import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.client.util.RssClientConfig;
-import org.apache.uniffle.common.ClientType;
-import org.apache.uniffle.common.PartitionRange;
-import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.RemoteStorageInfo;
-import org.apache.uniffle.common.ShuffleAssignmentsInfo;
-import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
@@ -76,7 +66,6 @@ import 
org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.rpc.GrpcServer;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.JavaUtils;
-import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
 import org.apache.uniffle.shuffle.RssShuffleClientFactory;
@@ -94,10 +83,6 @@ public class RssShuffleManager extends RssShuffleManagerBase 
{
   private final long heartbeatInterval;
   private final long heartbeatTimeout;
   private ScheduledExecutorService heartBeatScheduledExecutorService;
-  private SparkConf sparkConf;
-  private String appId = "";
-  private String clientType;
-  private ShuffleWriteClient shuffleWriteClient;
   private Map<String, Set<Long>> taskToSuccessBlockIds = 
JavaUtils.newConcurrentMap();
   private Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker =
       JavaUtils.newConcurrentMap();
@@ -109,41 +94,16 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
   private final int dataCommitPoolSize;
   private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
   private boolean heartbeatStarted = false;
-  private boolean dynamicConfEnabled;
   private final int maxFailures;
   private final boolean speculation;
   private final BlockIdLayout blockIdLayout;
   private final String user;
   private final String uuid;
   private DataPusher dataPusher;
-  private final int maxConcurrencyPerPartitionToWrite;
-
   private final Map<Integer, Integer> shuffleIdToPartitionNum = 
JavaUtils.newConcurrentMap();
   private final Map<Integer, Integer> shuffleIdToNumMapTasks = 
JavaUtils.newConcurrentMap();
   private GrpcServer shuffleManagerServer;
   private ShuffleManagerGrpcService service;
-  private ShuffleManagerClient shuffleManagerClient;
-  /**
-   * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is 
dynamically allocated.
-   * ShuffleServer is not obtained from RssShuffleHandle, but from this 
mapping.
-   */
-  private Map<Integer, ShuffleHandleInfo> shuffleIdToShuffleHandleInfo =
-      JavaUtils.newConcurrentMap();
-  /** Whether to enable the dynamic shuffleServer function rewrite and reread 
functions */
-  private boolean rssResubmitStage;
-
-  private boolean taskBlockSendFailureRetry;
-
-  private boolean shuffleManagerRpcServiceEnabled;
-  /** A list of shuffleServer for Write failures */
-  private Set<String> failuresShuffleServerIds = Sets.newHashSet();
-  /**
-   * Prevent multiple tasks from reporting FetchFailed, resulting in multiple 
ShuffleServer
-   * assignments, stageID, Attemptnumber Whether to reassign the combination 
flag;
-   */
-  private Map<String, Boolean> serverAssignedInfos = 
JavaUtils.newConcurrentMap();
-
-  private boolean blockIdSelfManagedEnabled;
 
   public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
     if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
@@ -178,15 +138,11 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     this.maxConcurrencyPerPartitionToWrite =
         
RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
     LOG.info(
-        "Check quorum config ["
-            + dataReplica
-            + ":"
-            + dataReplicaWrite
-            + ":"
-            + dataReplicaRead
-            + ":"
-            + dataReplicaSkipEnabled
-            + "]");
+        "Check quorum config [{}:{}:{}:{}]",
+        dataReplica,
+        dataReplicaWrite,
+        dataReplicaRead,
+        dataReplicaSkipEnabled);
     RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, 
dataReplicaRead);
 
     this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
@@ -212,10 +168,11 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     this.rssResubmitStage =
         rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
             && RssSparkShuffleUtils.isStageResubmitSupported();
-    this.taskBlockSendFailureRetry = 
rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
+    this.taskBlockSendFailureRetryEnabled =
+        rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
     this.blockIdSelfManagedEnabled = 
rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
     this.shuffleManagerRpcServiceEnabled =
-        taskBlockSendFailureRetry || rssResubmitStage || 
blockIdSelfManagedEnabled;
+        taskBlockSendFailureRetryEnabled || rssResubmitStage || 
blockIdSelfManagedEnabled;
     if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
       if (isDriver) {
         heartBeatScheduledExecutorService =
@@ -278,6 +235,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
               poolSize,
               keepAliveTime);
     }
+    this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+    this.rssStageResubmitManager = new RssStageResubmitManager();
   }
 
   // This method is called in Spark driver side,
@@ -353,8 +312,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         ClientUtils.fetchRemoteStorage(
             appId, defaultRemoteStorage, dynamicConfEnabled, storageType, 
shuffleWriteClient);
 
-    int partitionNumPerRange = 
sparkConf.get(RssSparkConfig.RSS_PARTITION_NUM_PER_RANGE);
-
     // get all register info according to coordinator's response
     Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
     ClientUtils.validateClientType(clientType);
@@ -362,44 +319,32 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
     int requiredShuffleServerNumber =
         RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
+    int estimateTaskConcurrency = 
RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
 
-    // retryInterval must bigger than `rss.server.heartbeat.interval`, or 
maybe it will return the
-    // same result
-    long retryInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
-    int retryTimes = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
-
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers;
-    try {
-      partitionToServers =
-          RetryUtils.retry(
-              () -> {
-                ShuffleAssignmentsInfo response =
-                    shuffleWriteClient.getShuffleAssignments(
-                        appId,
-                        shuffleId,
-                        dependency.partitioner().numPartitions(),
-                        partitionNumPerRange,
-                        assignmentTags,
-                        requiredShuffleServerNumber,
-                        -1);
-                registerShuffleServers(
-                    appId, shuffleId, response.getServerToPartitionRanges(), 
remoteStorage);
-                return response.getPartitionToServers();
-              },
-              retryInterval,
-              retryTimes);
-    } catch (Throwable throwable) {
-      throw new RssException("registerShuffle failed!", throwable);
-    }
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+        requestShuffleAssignment(
+            shuffleId,
+            dependency.partitioner().numPartitions(),
+            1,
+            requiredShuffleServerNumber,
+            estimateTaskConcurrency,
+            rssStageResubmitManager.getServerIdBlackList(),
+            0);
 
     startHeartbeat();
 
     shuffleIdToPartitionNum.putIfAbsent(shuffleId, 
dependency.partitioner().numPartitions());
     shuffleIdToNumMapTasks.putIfAbsent(shuffleId, 
dependency.rdd().partitions().length);
-    if (shuffleManagerRpcServiceEnabled) {
-      MutableShuffleHandleInfo handleInfo =
+    if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+      ShuffleHandleInfo handleInfo =
           new MutableShuffleHandleInfo(shuffleId, partitionToServers, 
remoteStorage);
-      shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
+      StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo =
+          new StageAttemptShuffleHandleInfo(shuffleId, remoteStorage, 
handleInfo);
+      shuffleHandleInfoManager.register(shuffleId, 
stageAttemptShuffleHandleInfo);
+    } else if (shuffleManagerRpcServiceEnabled && 
taskBlockSendFailureRetryEnabled) {
+      ShuffleHandleInfo shuffleHandleInfo =
+          new MutableShuffleHandleInfo(shuffleId, partitionToServers, 
remoteStorage);
+      shuffleHandleInfoManager.register(shuffleId, shuffleHandleInfo);
     }
     Broadcast<SimpleShuffleHandleInfo> hdlInfoBd =
         RssSparkShuffleUtils.broadcastShuffleHdlInfo(
@@ -435,37 +380,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     }
   }
 
-  @VisibleForTesting
-  protected void registerShuffleServers(
-      String appId,
-      int shuffleId,
-      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
-      RemoteStorageInfo remoteStorage) {
-    if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
-      return;
-    }
-    LOG.info("Start to register shuffleId[" + shuffleId + "]");
-    long start = System.currentTimeMillis();
-    serverToPartitionRanges.entrySet().stream()
-        .forEach(
-            entry -> {
-              shuffleWriteClient.registerShuffle(
-                  entry.getKey(),
-                  appId,
-                  shuffleId,
-                  entry.getValue(),
-                  remoteStorage,
-                  ShuffleDataDistributionType.NORMAL,
-                  maxConcurrencyPerPartitionToWrite);
-            });
-    LOG.info(
-        "Finish register shuffleId["
-            + shuffleId
-            + "] with "
-            + (System.currentTimeMillis() - start)
-            + " ms");
-  }
-
   @VisibleForTesting
   protected void registerCoordinator() {
     String coordinators = 
sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
@@ -493,9 +407,12 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       int shuffleId = rssHandle.getShuffleId();
       String taskId = "" + context.taskAttemptId() + "_" + 
context.attemptNumber();
       ShuffleHandleInfo shuffleHandleInfo;
-      if (shuffleManagerRpcServiceEnabled) {
-        // Get the ShuffleServer list from the Driver based on the shuffleId
-        shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+      if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+        // In Stage Retry mode, Get the ShuffleServer list from the Driver 
based on the shuffleId
+        shuffleHandleInfo = 
getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
+      } else if (shuffleManagerRpcServiceEnabled && 
taskBlockSendFailureRetryEnabled) {
+        // In Block Retry mode, Get the ShuffleServer list from the Driver 
based on the shuffleId
+        shuffleHandleInfo = 
getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
       } else {
         shuffleHandleInfo =
             new SimpleShuffleHandleInfo(
@@ -563,9 +480,12 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
               + "]");
       start = System.currentTimeMillis();
       ShuffleHandleInfo shuffleHandleInfo;
-      if (shuffleManagerRpcServiceEnabled) {
-        // Get the ShuffleServer list from the Driver based on the shuffleId
-        shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+      if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+        // In Stage Retry mode, Get the ShuffleServer list from the Driver 
based on the shuffleId.
+        shuffleHandleInfo = 
getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
+      } else if (shuffleManagerRpcServiceEnabled && 
taskBlockSendFailureRetryEnabled) {
+        // In Block Retry mode, Get the ShuffleServer list from the Driver 
based on the shuffleId
+        shuffleHandleInfo = 
getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
       } else {
         shuffleHandleInfo =
             new SimpleShuffleHandleInfo(
@@ -764,16 +684,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     return appId;
   }
 
-  /**
-   * @return the maximum number of fetch failures per shuffle partition before 
that shuffle stage
-   *     should be recomputed
-   */
-  @Override
-  public int getMaxFetchFailures() {
-    final String TASK_MAX_FAILURE = "spark.task.maxFailures";
-    return Math.max(1, sparkConf.getInt(TASK_MAX_FAILURE, 4) - 1);
-  }
-
   /**
    * @param shuffleId the shuffleId to query
    * @return the num of partitions(a.k.a reduce tasks) for shuffle with 
shuffle id.
@@ -812,170 +722,14 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     return taskToFailedBlockSendTracker.get(taskId);
   }
 
-  @Override
-  public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
-    return shuffleIdToShuffleHandleInfo.get(shuffleId);
-  }
-
-  private ShuffleManagerClient createShuffleManagerClient(String host, int 
port) {
-    // Host can be inferred from `spark.driver.bindAddress`, which would be 
set when SparkContext is
-    // constructed.
-    return ShuffleManagerClientFactory.getInstance()
-        .createShuffleManagerClient(ClientType.GRPC, host, port);
-  }
-
-  private ShuffleManagerClient getOrCreateShuffleManagerClient() {
-    if (shuffleManagerClient == null) {
-      RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-      String driver = rssConf.getString("driver.host", "");
-      int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
-      this.shuffleManagerClient =
-          ShuffleManagerClientFactory.getInstance()
-              .createShuffleManagerClient(ClientType.GRPC, driver, port);
-    }
-    return shuffleManagerClient;
-  }
-
-  /**
-   * Get the ShuffleServer list from the Driver based on the shuffleId
-   *
-   * @param shuffleId shuffleId
-   * @return ShuffleHandleInfo
-   */
-  private synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfo(int 
shuffleId) {
-    RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
-        new RssPartitionToShuffleServerRequest(shuffleId);
-    RssPartitionToShuffleServerResponse handleInfoResponse =
-        getOrCreateShuffleManagerClient()
-            .getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
-    MutableShuffleHandleInfo shuffleHandleInfo =
-        
MutableShuffleHandleInfo.fromProto(handleInfoResponse.getShuffleHandleInfoProto());
-    return shuffleHandleInfo;
-  }
-
-  /**
-   * Add the shuffleServer that failed to write to the failure list
-   *
-   * @param shuffleServerId
-   */
-  @Override
-  public void addFailuresShuffleServerInfos(String shuffleServerId) {
-    failuresShuffleServerIds.add(shuffleServerId);
-  }
-
-  /**
-   * Reassign the ShuffleServer list for ShuffleId
-   *
-   * @param shuffleId
-   * @param numPartitions
-   */
-  @Override
-  public synchronized boolean reassignAllShuffleServersForWholeStage(
-      int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) {
-    String stageIdAndAttempt = stageId + "_" + stageAttemptNumber;
-    Boolean needReassign = 
serverAssignedInfos.computeIfAbsent(stageIdAndAttempt, id -> false);
-    if (!needReassign) {
-      int requiredShuffleServerNumber =
-          RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
-      int estimateTaskConcurrency = 
RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
-      /** Before reassigning ShuffleServer, clear the ShuffleServer list in 
ShuffleWriteClient. */
-      shuffleWriteClient.unregisterShuffle(appId, shuffleId);
-      Map<Integer, List<ShuffleServerInfo>> partitionToServers =
-          requestShuffleAssignment(
-              shuffleId,
-              numPartitions,
-              1,
-              requiredShuffleServerNumber,
-              estimateTaskConcurrency,
-              failuresShuffleServerIds);
-      /**
-       * we need to clear the metadata of the completed task, otherwise some 
of the stage's data
-       * will be lost
-       */
-      try {
-        unregisterAllMapOutput(shuffleId);
-      } catch (SparkException e) {
-        LOG.error("Clear MapoutTracker Meta failed!");
-        throw new RssException("Clear MapoutTracker Meta failed!", e);
-      }
-      MutableShuffleHandleInfo handleInfo =
-          new MutableShuffleHandleInfo(shuffleId, partitionToServers, 
getRemoteStorageInfo());
-      shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
-      serverAssignedInfos.put(stageIdAndAttempt, true);
-      return true;
-    } else {
-      LOG.info(
-          "The Stage:{} has been reassigned in an Attempt{},Return without 
performing any operation",
-          stageId,
-          stageAttemptNumber);
-      return false;
-    }
-  }
-
-  @Override
-  public MutableShuffleHandleInfo reassignOnBlockSendFailure(
-      int shuffleId, Map<Integer, List<ReceivingFailureServer>> 
partitionToFailureServers) {
-    throw new RssException("Illegal access for reassignOnBlockSendFailure that 
is not supported.");
-  }
-
   private ShuffleServerInfo assignShuffleServer(int shuffleId, String 
faultyShuffleServerId) {
     Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
-    faultyServerIds.addAll(failuresShuffleServerIds);
+    faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
     Map<Integer, List<ShuffleServerInfo>> partitionToServers =
-        requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds);
+        requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds, 0);
     if (partitionToServers.get(0) != null && partitionToServers.get(0).size() 
== 1) {
       return partitionToServers.get(0).get(0);
     }
     return null;
   }
-
-  private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
-      int shuffleId,
-      int partitionNum,
-      int partitionNumPerRange,
-      int assignmentShuffleServerNumber,
-      int estimateTaskConcurrency,
-      Set<String> faultyServerIds) {
-    Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
-    ClientUtils.validateClientType(clientType);
-    assignmentTags.add(clientType);
-
-    long retryInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
-    int retryTimes = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
-    faultyServerIds.addAll(failuresShuffleServerIds);
-    try {
-      return RetryUtils.retry(
-          () -> {
-            ShuffleAssignmentsInfo response =
-                shuffleWriteClient.getShuffleAssignments(
-                    appId,
-                    shuffleId,
-                    partitionNum,
-                    partitionNumPerRange,
-                    assignmentTags,
-                    assignmentShuffleServerNumber,
-                    estimateTaskConcurrency,
-                    faultyServerIds);
-            registerShuffleServers(
-                appId, shuffleId, response.getServerToPartitionRanges(), 
getRemoteStorageInfo());
-            return response.getPartitionToServers();
-          },
-          retryInterval,
-          retryTimes);
-    } catch (Throwable throwable) {
-      throw new RssException("registerShuffle failed!", throwable);
-    }
-  }
-
-  private RemoteStorageInfo getRemoteStorageInfo() {
-    String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
-    RemoteStorageInfo defaultRemoteStorage =
-        new 
RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), 
""));
-    return ClientUtils.fetchRemoteStorage(
-        appId, defaultRemoteStorage, dynamicConfEnabled, storageType, 
shuffleWriteClient);
-  }
-
-  public boolean isRssResubmitStage() {
-    return rssResubmitStage;
-  }
 }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 81d54ec10..b40b55028 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -551,6 +551,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         RssReportShuffleWriteFailureResponse response =
             shuffleManagerClient.reportShuffleWriteFailure(req);
         if (response.getReSubmitWholeStage()) {
+          // The shuffle server is reassigned.
           RssReassignServersRequest rssReassignServersRequest =
               new RssReassignServersRequest(
                   taskContext.stageId(),
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index dc1de59ef..b5e89d1da 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -18,10 +18,7 @@
 package org.apache.spark.shuffle;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -30,7 +27,6 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import scala.Tuple2;
@@ -41,13 +37,11 @@ import scala.collection.Seq;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Sets;
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
-import org.apache.commons.collections4.CollectionUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.MapOutputTracker;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkEnv;
-import org.apache.spark.SparkException;
 import org.apache.spark.TaskContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.executor.ShuffleReadMetrics;
@@ -55,6 +49,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
 import org.apache.spark.shuffle.writer.AddBlockEvent;
 import org.apache.spark.shuffle.writer.DataPusher;
@@ -67,19 +62,10 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
-import org.apache.uniffle.client.api.ShuffleManagerClient;
-import org.apache.uniffle.client.api.ShuffleWriteClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
-import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
-import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.client.util.RssClientConfig;
-import org.apache.uniffle.common.ClientType;
-import org.apache.uniffle.common.PartitionRange;
-import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.RemoteStorageInfo;
-import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
@@ -87,10 +73,8 @@ import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.rpc.GrpcServer;
-import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.JavaUtils;
-import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
 import org.apache.uniffle.shuffle.RssShuffleClientFactory;
@@ -105,10 +89,8 @@ import static 
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER
 
 public class RssShuffleManager extends RssShuffleManagerBase {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleManager.class);
-  private final String clientType;
   private final long heartbeatInterval;
   private final long heartbeatTimeout;
-  private AtomicReference<String> id = new AtomicReference<>();
   private final int dataReplica;
   private final int dataReplicaWrite;
   private final int dataReplicaRead;
@@ -119,10 +101,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
   private final Map<String, FailedBlockSendTracker> 
taskToFailedBlockSendTracker;
   private ScheduledExecutorService heartBeatScheduledExecutorService;
   private boolean heartbeatStarted = false;
-  private boolean dynamicConfEnabled;
-  private final ShuffleDataDistributionType dataDistributionType;
   private final BlockIdLayout blockIdLayout;
-  private final int maxConcurrencyPerPartitionToWrite;
   private final int maxFailures;
   private final boolean speculation;
   private String user;
@@ -135,31 +114,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
   private ShuffleManagerGrpcService service;
   private GrpcServer shuffleManagerServer;
 
-  /** used by columnar rss shuffle writer implementation */
-  protected SparkConf sparkConf;
-
-  protected ShuffleWriteClient shuffleWriteClient;
-
-  private ShuffleManagerClient shuffleManagerClient;
-  /** Whether to enable the dynamic shuffleServer function rewrite and reread 
functions */
-  private boolean rssResubmitStage;
-
-  private boolean taskBlockSendFailureRetryEnabled;
-
-  private boolean shuffleManagerRpcServiceEnabled;
-  /** A list of shuffleServer for Write failures */
-  private Set<String> failuresShuffleServerIds;
-  /**
-   * Prevent multiple tasks from reporting FetchFailed, resulting in multiple 
ShuffleServer
-   * assignments, stageID, Attemptnumber Whether to reassign the combination 
flag;
-   */
-  private Map<String, Boolean> serverAssignedInfos;
-
-  private final int partitionReassignMaxServerNum;
-
-  private final ShuffleHandleInfoManager shuffleHandleInfoManager = new 
ShuffleHandleInfoManager();
-  private boolean blockIdSelfManagedEnabled;
-
   public RssShuffleManager(SparkConf conf, boolean isDriver) {
     this.sparkConf = conf;
     boolean supportsRelocation =
@@ -307,10 +261,10 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
             failedTaskIds,
             poolSize,
             keepAliveTime);
-    this.failuresShuffleServerIds = Sets.newHashSet();
-    this.serverAssignedInfos = JavaUtils.newConcurrentMap();
     this.partitionReassignMaxServerNum =
         rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
+    this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+    this.rssStageResubmitManager = new RssStageResubmitManager();
   }
 
   public CompletableFuture<Long> sendData(AddBlockEvent event) {
@@ -403,6 +357,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     this.dataPusher = dataPusher;
     this.partitionReassignMaxServerNum =
         rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
+    this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+    this.rssStageResubmitManager = new RssStageResubmitManager();
   }
 
   // This method is called in Spark driver side,
@@ -444,6 +400,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
     if (id.get() == null) {
       id.compareAndSet(null, SparkEnv.get().conf().getAppId() + "_" + uuid);
+      appId = id.get();
       dataPusher.setRssAppId(id.get());
     }
     LOG.info("Generate application id used in rss: " + id.get());
@@ -478,43 +435,30 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
     int requiredShuffleServerNumber =
         RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
-
-    // retryInterval must bigger than `rss.server.heartbeat.interval`, or 
maybe it will return the
-    // same result
-    long retryInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
-    int retryTimes = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
     int estimateTaskConcurrency = 
RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers;
-    try {
-      partitionToServers =
-          RetryUtils.retry(
-              () -> {
-                ShuffleAssignmentsInfo response =
-                    shuffleWriteClient.getShuffleAssignments(
-                        id.get(),
-                        shuffleId,
-                        dependency.partitioner().numPartitions(),
-                        1,
-                        assignmentTags,
-                        requiredShuffleServerNumber,
-                        estimateTaskConcurrency);
-                registerShuffleServers(
-                    id.get(), shuffleId, 
response.getServerToPartitionRanges(), remoteStorage);
-                return response.getPartitionToServers();
-              },
-              retryInterval,
-              retryTimes);
-    } catch (Throwable throwable) {
-      throw new RssException("registerShuffle failed!", throwable);
-    }
-    startHeartbeat();
 
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+        requestShuffleAssignment(
+            shuffleId,
+            dependency.partitioner().numPartitions(),
+            1,
+            requiredShuffleServerNumber,
+            estimateTaskConcurrency,
+            rssStageResubmitManager.getServerIdBlackList(),
+            0);
+    startHeartbeat();
     shuffleIdToPartitionNum.putIfAbsent(shuffleId, 
dependency.partitioner().numPartitions());
     shuffleIdToNumMapTasks.putIfAbsent(shuffleId, 
dependency.rdd().partitions().length);
-    if (shuffleManagerRpcServiceEnabled) {
-      MutableShuffleHandleInfo handleInfo =
+    if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+      ShuffleHandleInfo shuffleHandleInfo =
           new MutableShuffleHandleInfo(shuffleId, partitionToServers, 
remoteStorage);
+      StageAttemptShuffleHandleInfo handleInfo =
+          new StageAttemptShuffleHandleInfo(shuffleId, remoteStorage, 
shuffleHandleInfo);
       shuffleHandleInfoManager.register(shuffleId, handleInfo);
+    } else if (shuffleManagerRpcServiceEnabled && 
taskBlockSendFailureRetryEnabled) {
+      ShuffleHandleInfo shuffleHandleInfo =
+          new MutableShuffleHandleInfo(shuffleId, partitionToServers, 
remoteStorage);
+      shuffleHandleInfoManager.register(shuffleId, shuffleHandleInfo);
     }
     Broadcast<SimpleShuffleHandleInfo> hdlInfoBd =
         RssSparkShuffleUtils.broadcastShuffleHdlInfo(
@@ -549,9 +493,12 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       writeMetrics = context.taskMetrics().shuffleWriteMetrics();
     }
     ShuffleHandleInfo shuffleHandleInfo;
-    if (shuffleManagerRpcServiceEnabled) {
-      // Get the ShuffleServer list from the Driver based on the shuffleId
-      shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+    if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+      // In Stage Retry mode, Get the ShuffleServer list from the Driver based 
on the shuffleId.
+      shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
+    } else if (shuffleManagerRpcServiceEnabled && 
taskBlockSendFailureRetryEnabled) {
+      // In Block Retry mode, Get the ShuffleServer list from the Driver based 
on the shuffleId.
+      shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
     } else {
       shuffleHandleInfo =
           new SimpleShuffleHandleInfo(
@@ -691,9 +638,12 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     final int partitionNum = 
rssShuffleHandle.getDependency().partitioner().numPartitions();
     int shuffleId = rssShuffleHandle.getShuffleId();
     ShuffleHandleInfo shuffleHandleInfo;
-    if (shuffleManagerRpcServiceEnabled) {
-      // Get the ShuffleServer list from the Driver based on the shuffleId
-      shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+    if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+      // In Stage Retry mode, Get the ShuffleServer list from the Driver based 
on the shuffleId.
+      shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
+    } else if (shuffleManagerRpcServiceEnabled && 
taskBlockSendFailureRetryEnabled) {
+      // In Block Retry mode, Get the ShuffleServer list from the Driver based 
on the shuffleId.
+      shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
     } else {
       shuffleHandleInfo =
           new SimpleShuffleHandleInfo(
@@ -940,39 +890,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     taskToFailedBlockSendTracker.remove(taskId);
   }
 
-  @VisibleForTesting
-  protected void registerShuffleServers(
-      String appId,
-      int shuffleId,
-      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
-      RemoteStorageInfo remoteStorage) {
-    if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
-      return;
-    }
-    LOG.info("Start to register shuffleId[" + shuffleId + "]");
-    long start = System.currentTimeMillis();
-    Set<Map.Entry<ShuffleServerInfo, List<PartitionRange>>> entries =
-        serverToPartitionRanges.entrySet();
-    entries.stream()
-        .forEach(
-            entry -> {
-              shuffleWriteClient.registerShuffle(
-                  entry.getKey(),
-                  appId,
-                  shuffleId,
-                  entry.getValue(),
-                  remoteStorage,
-                  dataDistributionType,
-                  maxConcurrencyPerPartitionToWrite);
-            });
-    LOG.info(
-        "Finish register shuffleId["
-            + shuffleId
-            + "] with "
-            + (System.currentTimeMillis() - start)
-            + " ms");
-  }
-
   @VisibleForTesting
   protected void registerCoordinator() {
     String coordinators = 
sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
@@ -1027,16 +944,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     return id.get();
   }
 
-  /**
-   * @return the maximum number of fetch failures per shuffle partition before 
that shuffle stage
-   *     should be recomputed
-   */
-  @Override
-  public int getMaxFetchFailures() {
-    final String TASK_MAX_FAILURE = "spark.task.maxFailures";
-    return Math.max(1, sparkConf.getInt(TASK_MAX_FAILURE, 4) - 1);
-  }
-
   @Override
   public int getPartitionNum(int shuffleId) {
     return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
@@ -1138,282 +1045,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     return taskToFailedBlockSendTracker.get(taskId);
   }
 
-  @Override
-  public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
-    return shuffleHandleInfoManager.get(shuffleId);
-  }
-
-  // todo: automatic close client when the client is idle to avoid too much 
connections for spark
-  // driver.
-  private ShuffleManagerClient getOrCreateShuffleManagerClient() {
-    if (shuffleManagerClient == null) {
-      RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-      String driver = rssConf.getString("driver.host", "");
-      int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
-      this.shuffleManagerClient =
-          ShuffleManagerClientFactory.getInstance()
-              .createShuffleManagerClient(ClientType.GRPC, driver, port);
-    }
-    return shuffleManagerClient;
-  }
-
-  /**
-   * Get the ShuffleServer list from the Driver based on the shuffleId
-   *
-   * @param shuffleId shuffleId
-   * @return ShuffleHandleInfo
-   */
-  private synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfo(int 
shuffleId) {
-    RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
-        new RssPartitionToShuffleServerRequest(shuffleId);
-    RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer =
-        getOrCreateShuffleManagerClient()
-            .getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
-    MutableShuffleHandleInfo shuffleHandleInfo =
-        MutableShuffleHandleInfo.fromProto(
-            rpcPartitionToShufflerServer.getShuffleHandleInfoProto());
-    return shuffleHandleInfo;
-  }
-
-  /**
-   * Add the shuffleServer that failed to write to the failure list
-   *
-   * @param shuffleServerId
-   */
-  @Override
-  public void addFailuresShuffleServerInfos(String shuffleServerId) {
-    failuresShuffleServerIds.add(shuffleServerId);
-  }
-
-  /**
-   * Reassign the ShuffleServer list for ShuffleId
-   *
-   * @param shuffleId
-   * @param numPartitions
-   */
-  @Override
-  public synchronized boolean reassignAllShuffleServersForWholeStage(
-      int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) {
-    String stageIdAndAttempt = stageId + "_" + stageAttemptNumber;
-    Boolean needReassign = 
serverAssignedInfos.computeIfAbsent(stageIdAndAttempt, id -> false);
-    if (!needReassign) {
-      int requiredShuffleServerNumber =
-          RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
-      int estimateTaskConcurrency = 
RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
-      /** Before reassigning ShuffleServer, clear the ShuffleServer list in 
ShuffleWriteClient. */
-      shuffleWriteClient.unregisterShuffle(id.get(), shuffleId);
-      Map<Integer, List<ShuffleServerInfo>> partitionToServers =
-          requestShuffleAssignment(
-              shuffleId,
-              numPartitions,
-              1,
-              requiredShuffleServerNumber,
-              estimateTaskConcurrency,
-              failuresShuffleServerIds,
-              null);
-      /**
-       * we need to clear the metadata of the completed task, otherwise some 
of the stage's data
-       * will be lost
-       */
-      try {
-        unregisterAllMapOutput(shuffleId);
-      } catch (SparkException e) {
-        LOG.error("Clear MapoutTracker Meta failed!");
-        throw new RssException("Clear MapoutTracker Meta failed!", e);
-      }
-      MutableShuffleHandleInfo handleInfo =
-          new MutableShuffleHandleInfo(shuffleId, partitionToServers, 
getRemoteStorageInfo());
-      shuffleHandleInfoManager.register(shuffleId, handleInfo);
-      serverAssignedInfos.put(stageIdAndAttempt, true);
-      return true;
-    } else {
-      LOG.info(
-          "The Stage:{} has been reassigned in an Attempt{},Return without 
performing any operation",
-          stageId,
-          stageAttemptNumber);
-      return false;
-    }
-  }
-
-  /** this is only valid on driver side that exposed to being invoked by grpc 
server */
-  @Override
-  public MutableShuffleHandleInfo reassignOnBlockSendFailure(
-      int shuffleId, Map<Integer, List<ReceivingFailureServer>> 
partitionToFailureServers) {
-    long startTime = System.currentTimeMillis();
-    MutableShuffleHandleInfo handleInfo =
-        (MutableShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId);
-    synchronized (handleInfo) {
-      // If the reassignment servers for one partition exceeds the max 
reassign server num,
-      // it should fast fail.
-      handleInfo.checkPartitionReassignServerNum(
-          partitionToFailureServers.keySet(), partitionReassignMaxServerNum);
-
-      Map<ShuffleServerInfo, List<PartitionRange>> newServerToPartitions = new 
HashMap<>();
-      // receivingFailureServer -> partitionId -> replacementServerIds. For 
logging
-      Map<String, Map<Integer, Set<String>>> reassignResult = new HashMap<>();
-
-      for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
-          partitionToFailureServers.entrySet()) {
-        int partitionId = entry.getKey();
-        for (ReceivingFailureServer receivingFailureServer : entry.getValue()) 
{
-          StatusCode code = receivingFailureServer.getStatusCode();
-          String serverId = receivingFailureServer.getServerId();
-
-          boolean serverHasReplaced = false;
-          Set<ShuffleServerInfo> replacements = 
handleInfo.getReplacements(serverId);
-          if (CollectionUtils.isEmpty(replacements)) {
-            final int requiredServerNum = 1;
-            Set<String> excludedServers = new 
HashSet<>(handleInfo.listExcludedServers());
-            excludedServers.add(serverId);
-            replacements =
-                reassignServerForTask(
-                    shuffleId, Sets.newHashSet(partitionId), excludedServers, 
requiredServerNum);
-          } else {
-            serverHasReplaced = true;
-          }
-
-          Set<ShuffleServerInfo> updatedReassignServers =
-              handleInfo.updateAssignment(partitionId, serverId, replacements);
-
-          reassignResult
-              .computeIfAbsent(serverId, x -> new HashMap<>())
-              .computeIfAbsent(partitionId, x -> new HashSet<>())
-              .addAll(
-                  updatedReassignServers.stream().map(x -> 
x.getId()).collect(Collectors.toSet()));
-
-          if (serverHasReplaced) {
-            for (ShuffleServerInfo serverInfo : updatedReassignServers) {
-              newServerToPartitions
-                  .computeIfAbsent(serverInfo, x -> new ArrayList<>())
-                  .add(new PartitionRange(partitionId, partitionId));
-            }
-          }
-        }
-      }
-      if (!newServerToPartitions.isEmpty()) {
-        LOG.info(
-            "Register the new partition->servers assignment on reassign. {}",
-            newServerToPartitions);
-        registerShuffleServers(id.get(), shuffleId, newServerToPartitions, 
getRemoteStorageInfo());
-      }
-
-      LOG.info(
-          "Finished reassignOnBlockSendFailure request and cost {}(ms). 
Reassign result: {}",
-          System.currentTimeMillis() - startTime,
-          reassignResult);
-
-      return handleInfo;
-    }
-  }
-
-  /**
-   * Creating the shuffleAssignmentInfo from the servers and partitionIds
-   *
-   * @param servers
-   * @param partitionIds
-   * @return
-   */
-  private ShuffleAssignmentsInfo createShuffleAssignmentsInfo(
-      Set<ShuffleServerInfo> servers, Set<Integer> partitionIds) {
-    Map<Integer, List<ShuffleServerInfo>> newPartitionToServers = new 
HashMap<>();
-    List<PartitionRange> partitionRanges = new ArrayList<>();
-    for (Integer partitionId : partitionIds) {
-      newPartitionToServers.put(partitionId, new ArrayList<>(servers));
-      partitionRanges.add(new PartitionRange(partitionId, partitionId));
-    }
-    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new 
HashMap<>();
-    for (ShuffleServerInfo server : servers) {
-      serverToPartitionRanges.put(server, partitionRanges);
-    }
-    return new ShuffleAssignmentsInfo(newPartitionToServers, 
serverToPartitionRanges);
-  }
-
-  /** Request the new shuffle-servers to replace faulty server. */
-  private Set<ShuffleServerInfo> reassignServerForTask(
-      int shuffleId,
-      Set<Integer> partitionIds,
-      Set<String> excludedServers,
-      int requiredServerNum) {
-    AtomicReference<Set<ShuffleServerInfo>> replacementsRef =
-        new AtomicReference<>(new HashSet<>());
-    requestShuffleAssignment(
-        shuffleId,
-        requiredServerNum,
-        1,
-        requiredServerNum,
-        1,
-        excludedServers,
-        shuffleAssignmentsInfo -> {
-          if (shuffleAssignmentsInfo == null) {
-            return null;
-          }
-          Set<ShuffleServerInfo> replacements =
-              shuffleAssignmentsInfo.getPartitionToServers().values().stream()
-                  .flatMap(x -> x.stream())
-                  .collect(Collectors.toSet());
-          replacementsRef.set(replacements);
-          return createShuffleAssignmentsInfo(replacements, partitionIds);
-        });
-    return replacementsRef.get();
-  }
-
-  private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
-      int shuffleId,
-      int partitionNum,
-      int partitionNumPerRange,
-      int assignmentShuffleServerNumber,
-      int estimateTaskConcurrency,
-      Set<String> faultyServerIds,
-      Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo> 
reassignmentHandler) {
-    Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
-    ClientUtils.validateClientType(clientType);
-    assignmentTags.add(clientType);
-    long retryInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
-    int retryTimes = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
-    faultyServerIds.addAll(failuresShuffleServerIds);
-    try {
-      return RetryUtils.retry(
-          () -> {
-            ShuffleAssignmentsInfo response =
-                shuffleWriteClient.getShuffleAssignments(
-                    id.get(),
-                    shuffleId,
-                    partitionNum,
-                    partitionNumPerRange,
-                    assignmentTags,
-                    assignmentShuffleServerNumber,
-                    estimateTaskConcurrency,
-                    faultyServerIds);
-            LOG.info("Finished the shuffle assignment request to 
coordinator.");
-            if (reassignmentHandler != null) {
-              response = reassignmentHandler.apply(response);
-            }
-            LOG.info(
-                "Register the partition->servers assignment. {}",
-                response.getServerToPartitionRanges());
-            registerShuffleServers(
-                id.get(), shuffleId, response.getServerToPartitionRanges(), 
getRemoteStorageInfo());
-            return response.getPartitionToServers();
-          },
-          retryInterval,
-          retryTimes);
-    } catch (Throwable throwable) {
-      throw new RssException("Errors on requesting shuffle assignment!", 
throwable);
-    }
-  }
-
-  private RemoteStorageInfo getRemoteStorageInfo() {
-    String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
-    RemoteStorageInfo defaultRemoteStorage =
-        new 
RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), 
""));
-    return ClientUtils.fetchRemoteStorage(
-        id.get(), defaultRemoteStorage, dynamicConfEnabled, storageType, 
shuffleWriteClient);
-  }
-
-  public boolean isRssResubmitStage() {
-    return rssResubmitStage;
-  }
-
   @VisibleForTesting
   public void setDataPusher(DataPusher dataPusher) {
     this.dataPusher = dataPusher;
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index 0805dfe35..d0eab5ae1 100644
--- 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -588,7 +588,8 @@ public class WriteBufferManagerTest {
         List<PartitionRange> partitionRanges,
         RemoteStorageInfo remoteStorage,
         ShuffleDataDistributionType dataDistributionType,
-        int maxConcurrencyPerPartitionToWrite) {}
+        int maxConcurrencyPerPartitionToWrite,
+        int stageAttemptNumber) {}
 
     @Override
     public boolean sendCommit(
diff --git 
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java 
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index 7d8f53392..efd39e35a 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -35,6 +35,16 @@ import org.apache.uniffle.common.ShuffleServerInfo;
 
 public interface ShuffleWriteClient {
 
+  default SendShuffleDataResult sendShuffleData(
+      String appId,
+      int stageAttemptNumber,
+      List<ShuffleBlockInfo> shuffleBlockInfoList,
+      Supplier<Boolean> needCancelRequest) {
+    throw new UnsupportedOperationException(
+        this.getClass().getName()
+            + " doesn't implement sendShuffleData with stageAttemptNumber");
+  }
+
   SendShuffleDataResult sendShuffleData(
       String appId,
       List<ShuffleBlockInfo> shuffleBlockInfoList,
@@ -44,6 +54,25 @@ public interface ShuffleWriteClient {
 
   void registerApplicationInfo(String appId, long timeoutMs, String user);
 
+  default void registerShuffle(
+      ShuffleServerInfo shuffleServerInfo,
+      String appId,
+      int shuffleId,
+      List<PartitionRange> partitionRanges,
+      RemoteStorageInfo remoteStorage,
+      ShuffleDataDistributionType dataDistributionType,
+      int maxConcurrencyPerPartitionToWrite) {
+    registerShuffle(
+        shuffleServerInfo,
+        appId,
+        shuffleId,
+        partitionRanges,
+        remoteStorage,
+        dataDistributionType,
+        maxConcurrencyPerPartitionToWrite,
+        0);
+  }
+
   void registerShuffle(
       ShuffleServerInfo shuffleServerInfo,
       String appId,
@@ -51,7 +80,8 @@ public interface ShuffleWriteClient {
       List<PartitionRange> partitionRanges,
       RemoteStorageInfo remoteStorage,
       ShuffleDataDistributionType dataDistributionType,
-      int maxConcurrencyPerPartitionToWrite);
+      int maxConcurrencyPerPartitionToWrite,
+      int stageAttemptNumber);
 
   boolean sendCommit(
       Set<ShuffleServerInfo> shuffleServerInfoSet, String appId, int 
shuffleId, int numMaps);
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index dc6ef420a..a79a8557a 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -163,6 +163,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
 
   private boolean sendShuffleDataAsync(
       String appId,
+      int stageAttemptNumber,
       Map<ShuffleServerInfo, Map<Integer, Map<Integer, 
List<ShuffleBlockInfo>>>> serverToBlocks,
       Map<ShuffleServerInfo, List<Long>> serverToBlockIds,
       Map<Long, AtomicInteger> blockIdsSendSuccessTracker,
@@ -192,7 +193,11 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
                       // todo: compact unnecessary blocks that reach 
replicaWrite
                       RssSendShuffleDataRequest request =
                           new RssSendShuffleDataRequest(
-                              appId, retryMax, retryIntervalMax, 
shuffleIdToBlocks);
+                              appId,
+                              stageAttemptNumber,
+                              retryMax,
+                              retryIntervalMax,
+                              shuffleIdToBlocks);
                       long s = System.currentTimeMillis();
                       RssSendShuffleDataResponse response =
                           getShuffleServerClient(ssi).sendShuffleData(request);
@@ -314,10 +319,19 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
         });
   }
 
+  @Override
+  public SendShuffleDataResult sendShuffleData(
+      String appId,
+      List<ShuffleBlockInfo> shuffleBlockInfoList,
+      Supplier<Boolean> needCancelRequest) {
+    return sendShuffleData(appId, 0, shuffleBlockInfoList, needCancelRequest);
+  }
+
   /** The batch of sending belongs to the same task */
   @Override
   public SendShuffleDataResult sendShuffleData(
       String appId,
+      int stageAttemptNumber,
       List<ShuffleBlockInfo> shuffleBlockInfoList,
       Supplier<Boolean> needCancelRequest) {
 
@@ -400,6 +414,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
     boolean isAllSuccess =
         sendShuffleDataAsync(
             appId,
+            stageAttemptNumber,
             primaryServerToBlocks,
             primaryServerToBlockIds,
             blockIdsSendSuccessTracker,
@@ -416,6 +431,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       LOG.info("The sending of primary round is failed partially, so start the 
secondary round");
       sendShuffleDataAsync(
           appId,
+          stageAttemptNumber,
           secondaryServerToBlocks,
           secondaryServerToBlockIds,
           blockIdsSendSuccessTracker,
@@ -545,7 +561,8 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       List<PartitionRange> partitionRanges,
       RemoteStorageInfo remoteStorage,
       ShuffleDataDistributionType dataDistributionType,
-      int maxConcurrencyPerPartitionToWrite) {
+      int maxConcurrencyPerPartitionToWrite,
+      int stageAttemptNumber) {
     String user = null;
     try {
       user = UserGroupInformation.getCurrentUser().getShortUserName();
@@ -561,7 +578,8 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
             remoteStorage,
             user,
             dataDistributionType,
-            maxConcurrencyPerPartitionToWrite);
+            maxConcurrencyPerPartitionToWrite,
+            stageAttemptNumber);
     RssRegisterShuffleResponse response =
         getShuffleServerClient(shuffleServerInfo).registerShuffle(request);
 
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
index a77b0d3c7..9fefb98f6 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
@@ -30,6 +30,8 @@ import org.apache.uniffle.common.util.ByteBufUtils;
 public class SendShuffleDataRequest extends RequestMessage {
   private String appId;
   private int shuffleId;
+
+  private int stageAttemptNumber;
   private long requireId;
   private Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks;
   private long timestamp;
@@ -41,12 +43,24 @@ public class SendShuffleDataRequest extends RequestMessage {
       long requireId,
       Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks,
       long timestamp) {
+    this(requestId, appId, shuffleId, 0, requireId, partitionToBlocks, 
timestamp);
+  }
+
+  public SendShuffleDataRequest(
+      long requestId,
+      String appId,
+      int shuffleId,
+      int stageAttemptNumber,
+      long requireId,
+      Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks,
+      long timestamp) {
     super(requestId);
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.requireId = requireId;
     this.partitionToBlocks = partitionToBlocks;
     this.timestamp = timestamp;
+    this.stageAttemptNumber = stageAttemptNumber;
   }
 
   @Override
@@ -146,6 +160,10 @@ public class SendShuffleDataRequest extends RequestMessage 
{
     this.timestamp = timestamp;
   }
 
+  public int getStageAttemptNumber() {
+    return stageAttemptNumber;
+  }
+
   @Override
   public String getOperationType() {
     return "sendShuffleData";
diff --git a/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java 
b/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java
index 79e35ecab..ff8ac231c 100644
--- a/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java
+++ b/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java
@@ -35,6 +35,7 @@ public enum StatusCode {
   ACCESS_DENIED(8),
   INVALID_REQUEST(9),
   NO_BUFFER_FOR_HUGE_PARTITION(10),
+  STAGE_RETRY_IGNORE(11),
   UNKNOWN(-1);
 
   static final Map<Integer, StatusCode> VALUE_MAP =
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
index 6b6ee1ece..71dc66584 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
@@ -28,8 +28,8 @@ import 
org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
 import org.apache.uniffle.client.request.RssReportShuffleResultRequest;
 import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
 import org.apache.uniffle.client.response.RssGetShuffleResultResponse;
-import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
 import 
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
+import org.apache.uniffle.client.response.RssReassignOnStageRetryResponse;
 import org.apache.uniffle.client.response.RssReassignServersResponse;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
@@ -40,12 +40,23 @@ public interface ShuffleManagerClient extends Closeable {
       RssReportShuffleFetchFailureRequest request);
 
   /**
-   * Gets the mapping between partitions and ShuffleServer from the 
ShuffleManager server
+   * In Stage Retry mode, gets the mapping between partitions and ShuffleServer 
from the
+   * ShuffleManager server.
    *
    * @param req request
    * @return RssPartitionToShuffleServerResponse
    */
-  RssPartitionToShuffleServerResponse getPartitionToShufflerServer(
+  RssReassignOnStageRetryResponse getPartitionToShufflerServerWithStageRetry(
+      RssPartitionToShuffleServerRequest req);
+
+  /**
+   * In Block Retry mode, gets the mapping between partitions and ShuffleServer 
from the
+   * ShuffleManager server.
+   *
+   * @param req request
+   * @return RssReassignOnBlockSendFailureResponse
+   */
+  RssReassignOnBlockSendFailureResponse 
getPartitionToShufflerServerWithBlockRetry(
       RssPartitionToShuffleServerRequest req);
 
   RssReportShuffleWriteFailureResponse reportShuffleWriteFailure(
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
index 997778bc9..030afd03d 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
@@ -32,8 +32,8 @@ import 
org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
 import org.apache.uniffle.client.request.RssReportShuffleResultRequest;
 import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
 import org.apache.uniffle.client.response.RssGetShuffleResultResponse;
-import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
 import 
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
+import org.apache.uniffle.client.response.RssReassignOnStageRetryResponse;
 import org.apache.uniffle.client.response.RssReassignServersResponse;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
@@ -90,14 +90,25 @@ public class ShuffleManagerGrpcClient extends GrpcClient 
implements ShuffleManag
   }
 
   @Override
-  public RssPartitionToShuffleServerResponse getPartitionToShufflerServer(
+  public RssReassignOnStageRetryResponse 
getPartitionToShufflerServerWithStageRetry(
       RssPartitionToShuffleServerRequest req) {
     RssProtos.PartitionToShuffleServerRequest protoRequest = req.toProto();
-    RssProtos.PartitionToShuffleServerResponse partitionToShufflerServer =
-        getBlockingStub().getPartitionToShufflerServer(protoRequest);
-    RssPartitionToShuffleServerResponse rssPartitionToShuffleServerResponse =
-        
RssPartitionToShuffleServerResponse.fromProto(partitionToShufflerServer);
-    return rssPartitionToShuffleServerResponse;
+    RssProtos.ReassignOnStageRetryResponse partitionToShufflerServer =
+        
getBlockingStub().getPartitionToShufflerServerWithStageRetry(protoRequest);
+    RssReassignOnStageRetryResponse rssReassignOnStageRetryResponse =
+        RssReassignOnStageRetryResponse.fromProto(partitionToShufflerServer);
+    return rssReassignOnStageRetryResponse;
+  }
+
+  @Override
+  public RssReassignOnBlockSendFailureResponse 
getPartitionToShufflerServerWithBlockRetry(
+      RssPartitionToShuffleServerRequest req) {
+    RssProtos.PartitionToShuffleServerRequest protoRequest = req.toProto();
+    RssProtos.ReassignOnBlockSendFailureResponse partitionToShufflerServer =
+        
getBlockingStub().getPartitionToShufflerServerWithBlockRetry(protoRequest);
+    RssReassignOnBlockSendFailureResponse 
rssReassignOnBlockSendFailureResponse =
+        
RssReassignOnBlockSendFailureResponse.fromProto(partitionToShufflerServer);
+    return rssReassignOnBlockSendFailureResponse;
   }
 
   @Override
@@ -128,7 +139,7 @@ public class ShuffleManagerGrpcClient extends GrpcClient 
implements ShuffleManag
       RssReassignOnBlockSendFailureRequest request) {
     RssProtos.RssReassignOnBlockSendFailureRequest protoReq =
         RssReassignOnBlockSendFailureRequest.toProto(request);
-    RssProtos.RssReassignOnBlockSendFailureResponse response =
+    RssProtos.ReassignOnBlockSendFailureResponse response =
         getBlockingStub().reassignOnBlockSendFailure(protoReq);
     return RssReassignOnBlockSendFailureResponse.fromProto(response);
   }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 9a315c162..14dbf2f60 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -200,7 +200,8 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
       RemoteStorageInfo remoteStorageInfo,
       String user,
       ShuffleDataDistributionType dataDistributionType,
-      int maxConcurrencyPerPartitionToWrite) {
+      int maxConcurrencyPerPartitionToWrite,
+      int stageAttemptNumber) {
     ShuffleRegisterRequest.Builder reqBuilder = 
ShuffleRegisterRequest.newBuilder();
     reqBuilder
         .setAppId(appId)
@@ -208,7 +209,8 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
         .setUser(user)
         
.setShuffleDataDistribution(RssProtos.DataDistribution.valueOf(dataDistributionType.name()))
         
.setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite)
-        .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges));
+        .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges))
+        .setStageAttemptNumber(stageAttemptNumber);
     RemoteStorage.Builder rsBuilder = RemoteStorage.newBuilder();
     rsBuilder.setPath(remoteStorageInfo.getPath());
     Map<String, String> remoteStorageConf = remoteStorageInfo.getConfItems();
@@ -468,7 +470,8 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
             request.getRemoteStorageInfo(),
             request.getUser(),
             request.getDataDistributionType(),
-            request.getMaxConcurrencyPerPartitionToWrite());
+            request.getMaxConcurrencyPerPartitionToWrite(),
+            request.getStageAttemptNumber());
 
     RssRegisterShuffleResponse response;
     RssProtos.StatusCode statusCode = rpcResponse.getStatus();
@@ -499,6 +502,7 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
     String appId = request.getAppId();
     Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks =
         request.getShuffleIdToBlocks();
+    int stageAttemptNumber = request.getStageAttemptNumber();
 
     boolean isSuccessful = true;
     AtomicReference<StatusCode> failedStatusCode = new 
AtomicReference<>(StatusCode.INTERNAL_ERROR);
@@ -563,6 +567,7 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
                       .setRequireBufferId(requireId)
                       .addAllShuffleData(shuffleData)
                       .setTimestamp(start)
+                      .setStageAttemptNumber(stageAttemptNumber)
                       .build();
               SendShuffleDataResponse response = 
getBlockingStub().sendShuffleData(rpcRequest);
               if (LOG.isDebugEnabled()) {
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
index 6451a97cd..1c46077ad 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -139,6 +139,7 @@ public class ShuffleServerGrpcNettyClient extends 
ShuffleServerGrpcClient {
   public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest 
request) {
     Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks =
         request.getShuffleIdToBlocks();
+    int stageAttemptNumber = request.getStageAttemptNumber();
     boolean isSuccessful = true;
     AtomicReference<StatusCode> failedStatusCode = new 
AtomicReference<>(StatusCode.INTERNAL_ERROR);
 
@@ -159,6 +160,7 @@ public class ShuffleServerGrpcNettyClient extends 
ShuffleServerGrpcClient {
               requestId(),
               request.getAppId(),
               shuffleId,
+              stageAttemptNumber,
               0L,
               stb.getValue(),
               System.currentTimeMillis());
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
index 98fd01241..4cbdc4448 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
@@ -64,6 +64,30 @@ public class RssGetShuffleAssignmentsRequest {
       int assignmentShuffleServerNumber,
       int estimateTaskConcurrency,
       Set<String> faultyServerIds) {
+    this(
+        appId,
+        shuffleId,
+        partitionNum,
+        partitionNumPerRange,
+        dataReplica,
+        requiredTags,
+        assignmentShuffleServerNumber,
+        estimateTaskConcurrency,
+        faultyServerIds,
+        0);
+  }
+
+  public RssGetShuffleAssignmentsRequest(
+      String appId,
+      int shuffleId,
+      int partitionNum,
+      int partitionNumPerRange,
+      int dataReplica,
+      Set<String> requiredTags,
+      int assignmentShuffleServerNumber,
+      int estimateTaskConcurrency,
+      Set<String> faultyServerIds,
+      int stageAttemptNumber) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionNum = partitionNum;
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
index 2cd49bb6d..7e42be653 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
@@ -35,6 +35,7 @@ public class RssRegisterShuffleRequest {
   private String user;
   private ShuffleDataDistributionType dataDistributionType;
   private int maxConcurrencyPerPartitionToWrite;
+  private int stageAttemptNumber;
 
   public RssRegisterShuffleRequest(
       String appId,
@@ -44,6 +45,26 @@ public class RssRegisterShuffleRequest {
       String user,
       ShuffleDataDistributionType dataDistributionType,
       int maxConcurrencyPerPartitionToWrite) {
+    this(
+        appId,
+        shuffleId,
+        partitionRanges,
+        remoteStorageInfo,
+        user,
+        dataDistributionType,
+        maxConcurrencyPerPartitionToWrite,
+        0);
+  }
+
+  public RssRegisterShuffleRequest(
+      String appId,
+      int shuffleId,
+      List<PartitionRange> partitionRanges,
+      RemoteStorageInfo remoteStorageInfo,
+      String user,
+      ShuffleDataDistributionType dataDistributionType,
+      int maxConcurrencyPerPartitionToWrite,
+      int stageAttemptNumber) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionRanges = partitionRanges;
@@ -51,6 +72,7 @@ public class RssRegisterShuffleRequest {
     this.user = user;
     this.dataDistributionType = dataDistributionType;
     this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite;
+    this.stageAttemptNumber = stageAttemptNumber;
   }
 
   public RssRegisterShuffleRequest(
@@ -67,7 +89,8 @@ public class RssRegisterShuffleRequest {
         remoteStorageInfo,
         user,
         dataDistributionType,
-        RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue());
+        RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(),
+        0);
   }
 
   public RssRegisterShuffleRequest(
@@ -79,7 +102,8 @@ public class RssRegisterShuffleRequest {
         new RemoteStorageInfo(remoteStoragePath),
         StringUtils.EMPTY,
         ShuffleDataDistributionType.NORMAL,
-        RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue());
+        RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(),
+        0);
   }
 
   public String getAppId() {
@@ -109,4 +133,8 @@ public class RssRegisterShuffleRequest {
   public int getMaxConcurrencyPerPartitionToWrite() {
     return maxConcurrencyPerPartitionToWrite;
   }
+
+  public int getStageAttemptNumber() {
+    return stageAttemptNumber;
+  }
 }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
index 8fbf18f29..1b5fdcff8 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
@@ -25,6 +25,7 @@ import org.apache.uniffle.common.ShuffleBlockInfo;
 public class RssSendShuffleDataRequest {
 
   private String appId;
+  private int stageAttemptNumber;
   private int retryMax;
   private long retryIntervalMax;
   private Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks;
@@ -34,10 +35,20 @@ public class RssSendShuffleDataRequest {
       int retryMax,
       long retryIntervalMax,
       Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks) {
+    this(appId, 0, retryMax, retryIntervalMax, shuffleIdToBlocks);
+  }
+
+  public RssSendShuffleDataRequest(
+      String appId,
+      int stageAttemptNumber,
+      int retryMax,
+      long retryIntervalMax,
+      Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks) {
     this.appId = appId;
     this.retryMax = retryMax;
     this.retryIntervalMax = retryIntervalMax;
     this.shuffleIdToBlocks = shuffleIdToBlocks;
+    this.stageAttemptNumber = stageAttemptNumber;
   }
 
   public String getAppId() {
@@ -52,6 +63,10 @@ public class RssSendShuffleDataRequest {
     return retryIntervalMax;
   }
 
+  public int getStageAttemptNumber() {
+    return stageAttemptNumber;
+  }
+
   public Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> 
getShuffleIdToBlocks() {
     return shuffleIdToBlocks;
   }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
index 81ca7d548..7d20f8742 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
@@ -34,7 +34,7 @@ public class RssReassignOnBlockSendFailureResponse extends 
ClientResponse {
   }
 
   public static RssReassignOnBlockSendFailureResponse fromProto(
-      RssProtos.RssReassignOnBlockSendFailureResponse response) {
+      RssProtos.ReassignOnBlockSendFailureResponse response) {
     return new RssReassignOnBlockSendFailureResponse(
         StatusCode.valueOf(response.getStatus().name()), response.getMsg(), 
response.getHandle());
   }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnStageRetryResponse.java
similarity index 71%
rename from 
internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
rename to 
internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnStageRetryResponse.java
index 9daa002ed..3762ea41b 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnStageRetryResponse.java
@@ -20,24 +20,24 @@ package org.apache.uniffle.client.response;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.proto.RssProtos;
 
-public class RssPartitionToShuffleServerResponse extends ClientResponse {
-  private RssProtos.MutableShuffleHandleInfo shuffleHandleInfoProto;
+public class RssReassignOnStageRetryResponse extends ClientResponse {
+  private RssProtos.StageAttemptShuffleHandleInfo shuffleHandleInfoProto;
 
-  public RssPartitionToShuffleServerResponse(
+  public RssReassignOnStageRetryResponse(
       StatusCode statusCode,
       String message,
-      RssProtos.MutableShuffleHandleInfo shuffleHandleInfoProto) {
+      RssProtos.StageAttemptShuffleHandleInfo shuffleHandleInfoProto) {
     super(statusCode, message);
     this.shuffleHandleInfoProto = shuffleHandleInfoProto;
   }
 
-  public RssProtos.MutableShuffleHandleInfo getShuffleHandleInfoProto() {
+  public RssProtos.StageAttemptShuffleHandleInfo getShuffleHandleInfoProto() {
     return shuffleHandleInfoProto;
   }
 
-  public static RssPartitionToShuffleServerResponse fromProto(
-      RssProtos.PartitionToShuffleServerResponse response) {
-    return new RssPartitionToShuffleServerResponse(
+  public static RssReassignOnStageRetryResponse fromProto(
+      RssProtos.ReassignOnStageRetryResponse response) {
+    return new RssReassignOnStageRetryResponse(
         StatusCode.valueOf(response.getStatus().name()),
         response.getMsg(),
         response.getShuffleHandleInfo());
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index a63e6d23a..61afad299 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -184,6 +184,7 @@ message ShuffleRegisterRequest {
   string user = 5;
   DataDistribution shuffleDataDistribution = 6;
   int32 maxConcurrencyPerPartitionToWrite = 7;
+  int32 stageAttemptNumber = 8;
 }
 
 enum DataDistribution {
@@ -221,6 +222,7 @@ message SendShuffleDataRequest {
   int64 requireBufferId = 3;
   repeated ShuffleData shuffleData = 4;
   int64 timestamp = 5;
+  int32 stageAttemptNumber = 6;
 }
 
 message SendShuffleDataResponse {
@@ -305,6 +307,7 @@ enum StatusCode {
   ACCESS_DENIED = 8;
   INVALID_REQUEST = 9;
   NO_BUFFER_FOR_HUGE_PARTITION = 10;
+  STAGE_RETRY_IGNORE = 11;
   // add more status
 }
 
@@ -528,14 +531,16 @@ message CancelDecommissionResponse {
 // per application.
 service ShuffleManager {
   rpc reportShuffleFetchFailure (ReportShuffleFetchFailureRequest) returns 
(ReportShuffleFetchFailureResponse);
-  // Gets the mapping between partitions and ShuffleServer from the 
ShuffleManager server
-  rpc getPartitionToShufflerServer(PartitionToShuffleServerRequest) returns 
(PartitionToShuffleServerResponse);
+  // Gets the mapping between partitions and ShuffleServer from the 
ShuffleManager server on Stage Retry.
+  rpc 
getPartitionToShufflerServerWithStageRetry(PartitionToShuffleServerRequest) 
returns (ReassignOnStageRetryResponse);
+  // Gets the mapping between partitions and ShuffleServer from the 
ShuffleManager server on Block Retry.
+  rpc 
getPartitionToShufflerServerWithBlockRetry(PartitionToShuffleServerRequest) 
returns (ReassignOnBlockSendFailureResponse);
   // Report write failures to ShuffleManager
   rpc reportShuffleWriteFailure (ReportShuffleWriteFailureRequest) returns 
(ReportShuffleWriteFailureResponse);
   // Reassign the RPC interface of the ShuffleServer list
   rpc reassignShuffleServers(ReassignServersRequest) returns 
(ReassignServersResponse);
   // Reassign on block send failure that occurs in writer
-  rpc reassignOnBlockSendFailure(RssReassignOnBlockSendFailureRequest) returns 
(RssReassignOnBlockSendFailureResponse);
+  rpc reassignOnBlockSendFailure(RssReassignOnBlockSendFailureRequest) returns 
(ReassignOnBlockSendFailureResponse);
   rpc reportShuffleResult (ReportShuffleResultRequest) returns 
(ReportShuffleResultResponse);
   rpc getShuffleResult (GetShuffleResultRequest) returns 
(GetShuffleResultResponse);
   rpc getShuffleResultForMultiPart (GetShuffleResultForMultiPartRequest) 
returns (GetShuffleResultForMultiPartResponse);
@@ -563,10 +568,15 @@ message PartitionToShuffleServerRequest {
   int32 shuffleId = 2;
 }
 
-message PartitionToShuffleServerResponse {
+message ReassignOnStageRetryResponse {
   StatusCode status = 1;
   string msg = 2;
-  MutableShuffleHandleInfo shuffleHandleInfo = 3;
+  StageAttemptShuffleHandleInfo shuffleHandleInfo = 3;
+}
+
+message StageAttemptShuffleHandleInfo {
+  repeated MutableShuffleHandleInfo historyMutableShuffleHandleInfo = 1;
+  MutableShuffleHandleInfo currentMutableShuffleHandleInfo = 2;
 }
 
 message MutableShuffleHandleInfo {
@@ -633,7 +643,7 @@ message ReceivingFailureServer {
   StatusCode statusCode = 2;
 }
 
-message RssReassignOnBlockSendFailureResponse {
+message ReassignOnBlockSendFailureResponse {
   StatusCode status = 1;
   string msg = 2;
   MutableShuffleHandleInfo handle = 3;
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index de378d66e..b6e37029f 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -157,6 +157,48 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
     int shuffleId = req.getShuffleId();
     String remoteStoragePath = req.getRemoteStorage().getPath();
     String user = req.getUser();
+    int stageAttemptNumber = req.getStageAttemptNumber();
+    // If the stage is registered for the first time, there is no need to handle stage
+    // retry or to delete block data that has already been sent.
+    if (stageAttemptNumber > 0) {
+      ShuffleTaskInfo taskInfo = 
shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId);
+      // Synchronize so that concurrent registrations from different stage attempts
+      // cannot race when updating the latest attempt number.
+      synchronized (taskInfo) {
+        int attemptNumber = taskInfo.getLatestStageAttemptNumber(shuffleId);
+        if (stageAttemptNumber > attemptNumber) {
+          taskInfo.refreshLatestStageAttemptNumber(shuffleId, 
stageAttemptNumber);
+          try {
+            long start = System.currentTimeMillis();
+            shuffleServer.getShuffleTaskManager().removeShuffleDataSync(appId, 
shuffleId);
+            LOG.info(
+                "Deleted the previous stage attempt data due to stage 
recomputing for app: {}, "
+                    + "shuffleId: {}. It costs {} ms",
+                appId,
+                shuffleId,
+                System.currentTimeMillis() - start);
+          } catch (Exception e) {
+            LOG.error(
+                "Errors on clearing previous stage attempt data for app: {}, 
shuffleId: {}",
+                appId,
+                shuffleId,
+                e);
+            StatusCode code = StatusCode.INTERNAL_ERROR;
+            reply = 
ShuffleRegisterResponse.newBuilder().setStatus(code.toProto()).build();
+            responseObserver.onNext(reply);
+            responseObserver.onCompleted();
+            return;
+          }
+        } else if (stageAttemptNumber < attemptNumber) {
+          // When a stage retry has occurred, registrations carrying an older attempt
+          // number must be ignored, and the ignored status is returned immediately.
+          StatusCode code = StatusCode.STAGE_RETRY_IGNORE;
+          reply = 
ShuffleRegisterResponse.newBuilder().setStatus(code.toProto()).build();
+          responseObserver.onNext(reply);
+          responseObserver.onCompleted();
+          return;
+        }
+      }
+    }
 
     ShuffleDataDistributionType shuffleDataDistributionType =
         ShuffleDataDistributionType.valueOf(
@@ -210,6 +252,22 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
     int shuffleId = req.getShuffleId();
     long requireBufferId = req.getRequireBufferId();
     long timestamp = req.getTimestamp();
+    int stageAttemptNumber = req.getStageAttemptNumber();
+    ShuffleTaskInfo taskInfo = 
shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId);
+    Integer latestStageAttemptNumber = 
taskInfo.getLatestStageAttemptNumber(shuffleId);
+    // If a stage retry has occurred, data sent by tasks from an earlier stage
+    // attempt is simply ignored rather than processed.
+    if (stageAttemptNumber < latestStageAttemptNumber) {
+      String responseMessage = "A retry has occurred at the Stage, sending 
data is invalid.";
+      reply =
+          SendShuffleDataResponse.newBuilder()
+              .setStatus(StatusCode.STAGE_RETRY_IGNORE.toProto())
+              .setRetMsg(responseMessage)
+              .build();
+      responseObserver.onNext(reply);
+      responseObserver.onCompleted();
+      return;
+    }
     if (timestamp > 0) {
       /*
        * Here we record the transport time, but we don't consider the impact 
of data size on transport time.
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java
index b6806f634..bbbfded01 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java
@@ -66,6 +66,8 @@ public class ShuffleTaskInfo {
 
   private final Map<Integer, Map<Integer, AtomicLong>> partitionBlockCounters;
 
+  private final Map<Integer, Integer> latestStageAttemptNumbers;
+
   public ShuffleTaskInfo(String appId) {
     this.appId = appId;
     this.currentTimes = System.currentTimeMillis();
@@ -78,6 +80,7 @@ public class ShuffleTaskInfo {
     this.existHugePartition = new AtomicBoolean(false);
     this.specification = new AtomicReference<>();
     this.partitionBlockCounters = JavaUtils.newConcurrentMap();
+    this.latestStageAttemptNumbers = JavaUtils.newConcurrentMap();
   }
 
   public Long getCurrentTimes() {
@@ -220,6 +223,14 @@ public class ShuffleTaskInfo {
     return counter.get();
   }
 
+  public Integer getLatestStageAttemptNumber(int shuffleId) {
+    return latestStageAttemptNumbers.computeIfAbsent(shuffleId, key -> 0);
+  }
+
+  public void refreshLatestStageAttemptNumber(int shuffleId, int 
stageAttemptNumber) {
+    latestStageAttemptNumbers.put(shuffleId, stageAttemptNumber);
+  }
+
   @Override
   public String toString() {
     return "ShuffleTaskInfo{"
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index a98eac26c..8fe597d03 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -910,7 +910,7 @@ public class ShuffleTaskManager {
   }
 
   @VisibleForTesting
-  void removeShuffleDataSync(String appId, int shuffleId) {
+  public void removeShuffleDataSync(String appId, int shuffleId) {
     removeResourcesByShuffleIds(appId, Arrays.asList(shuffleId));
   }
 
diff --git 
a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
 
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
index 668b53cba..448c12a23 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -57,6 +57,7 @@ import org.apache.uniffle.server.ShuffleDataReadEvent;
 import org.apache.uniffle.server.ShuffleServer;
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.server.ShuffleServerMetrics;
+import org.apache.uniffle.server.ShuffleTaskInfo;
 import org.apache.uniffle.server.ShuffleTaskManager;
 import org.apache.uniffle.server.buffer.PreAllocatedBufferInfo;
 import org.apache.uniffle.server.buffer.ShuffleBufferManager;
@@ -102,6 +103,18 @@ public class ShuffleServerNettyHandler implements 
BaseMessageHandler {
     int shuffleId = req.getShuffleId();
     long requireBufferId = req.getRequireId();
     long timestamp = req.getTimestamp();
+    int stageAttemptNumber = req.getStageAttemptNumber();
+    ShuffleTaskInfo taskInfo = 
shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId);
+    Integer latestStageAttemptNumber = 
taskInfo.getLatestStageAttemptNumber(shuffleId);
+    // If a stage retry has occurred, data sent by tasks from an earlier stage
+    // attempt is simply ignored rather than processed.
+    if (stageAttemptNumber < latestStageAttemptNumber) {
+      String responseMessage = "A retry has occurred at the Stage, sending 
data is invalid.";
+      rpcResponse =
+          new RpcResponse(req.getRequestId(), StatusCode.STAGE_RETRY_IGNORE, 
responseMessage);
+      client.getChannel().writeAndFlush(rpcResponse);
+      return;
+    }
     if (timestamp > 0) {
       /*
        * Here we record the transport time, but we don't consider the impact 
of data size on transport time.


Reply via email to