This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 2731cf2df [#2724] refactor(spark): Introduce `ReassignExecutor` to 
simplify shuffle writer logic (#2727)
2731cf2df is described below

commit 2731cf2df0de5c8322acbfbc2d81a35d1b7517cb
Author: Junfan Zhang <[email protected]>
AuthorDate: Thu Feb 12 10:10:33 2026 +0800

    [#2724] refactor(spark): Introduce `ReassignExecutor` to simplify shuffle 
writer logic (#2727)
    
    ### What changes were proposed in this pull request?
    
    Introduce `ReassignExecutor` to clarify shuffle writer logic
    
    ### Why are the changes needed?
    
    for #2724
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests
---
 .../shuffle/writer/TaskAttemptAssignment.java      |   2 +-
 .../spark/shuffle/writer/WriteBufferManager.java   |   5 +
 .../apache/uniffle/shuffle/ReassignExecutor.java   | 471 +++++++++++++++++++++
 .../shuffle/writer/TaskAttemptAssignmentTest.java  |   2 +-
 .../uniffle/shuffle/ReassignExecutorTest.java      | 124 ++++++
 .../spark/shuffle/writer/RssShuffleWriter.java     | 330 ++-------------
 .../spark/shuffle/writer/RssShuffleWriterTest.java |  40 +-
 7 files changed, 659 insertions(+), 315 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
index 7e46528c4..63fac0c12 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
@@ -88,7 +88,7 @@ public class TaskAttemptAssignment {
    * @param exclusiveServers
    * @return
    */
-  public boolean updatePartitionSplitAssignment(
+  public boolean tryNextServerForSplitPartition(
       int partitionId, List<ShuffleServerInfo> exclusiveServers) {
     if (hasBeenLoadBalanced(partitionId)) {
       Set<ShuffleServerInfo> servers =
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index b50cc5d56..280c6aefd 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -678,6 +678,11 @@ public class WriteBufferManager extends MemoryConsumer {
     return recordCounter.get();
   }
 
+  @VisibleForTesting
+  protected void resetRecordCount() {
+    recordCounter.set(0);
+  }
+
   public long getBlockCount() {
     return blockCounter.get();
   }
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
new file mode 100644
index 000000000..c380e3ee7
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
@@ -0,0 +1,471 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.writer.TaskAttemptAssignment;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
+import org.apache.uniffle.client.impl.TrackingBlockStatus;
+import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
+import 
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
+import org.apache.uniffle.common.ReceivingFailureServer;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.exception.RssSendFailedException;
+import org.apache.uniffle.common.rpc.StatusCode;
+
+/**
+ * This class is responsible for the reassignment, including the partition 
split and the block
+ * resend after reassignment.
+ */
+public class ReassignExecutor {
+  private static final Logger LOG = 
LoggerFactory.getLogger(ReassignExecutor.class);
+  private static final Set<StatusCode> STATUS_CODE_WITHOUT_BLOCK_RESEND =
+      Sets.newHashSet(StatusCode.NO_REGISTER);
+
+  private Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker;
+  private final TaskAttemptAssignment taskAttemptAssignment;
+
+  private final Consumer<ShuffleBlockInfo> removeBlockStatsFunction;
+  private final Consumer<List<ShuffleBlockInfo>> resendBlocksFunction;
+  private final Supplier<ShuffleManagerClient> managerClientSupplier;
+
+  private String taskId;
+  private final TaskContext taskContext;
+  private final int shuffleId;
+  private int blockFailSentRetryMaxTimes;
+
+  public ReassignExecutor(
+      Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker,
+      String taskId,
+      TaskAttemptAssignment taskAttemptAssignment,
+      Consumer<ShuffleBlockInfo> removeBlockStatsFunction,
+      Consumer<List<ShuffleBlockInfo>> resendBlocksFunction,
+      Supplier<ShuffleManagerClient> managerClientSupplier,
+      TaskContext taskContext,
+      int shuffleId,
+      int blockFailSentRetryMaxTimes) {
+    this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
+    this.taskId = taskId;
+    this.taskAttemptAssignment = taskAttemptAssignment;
+    this.removeBlockStatsFunction = removeBlockStatsFunction;
+    this.resendBlocksFunction = resendBlocksFunction;
+    this.managerClientSupplier = managerClientSupplier;
+    this.taskContext = taskContext;
+    this.shuffleId = shuffleId;
+    this.blockFailSentRetryMaxTimes = blockFailSentRetryMaxTimes;
+    LOG.debug("Initialized {} for taskId[{}]", 
this.getClass().getSimpleName(), taskId);
+  }
+
+  public void reassign() {
+    FailedBlockSendTracker tracker = taskToFailedBlockSendTracker.get(taskId);
+    if (tracker == null) {
+      return;
+    }
+    // 1. reassign for split partitions.
+    reassignOnPartitionNeedSplit(tracker);
+    // 2. reassign for failed blocks
+    reassignAndResendForFailedBlocks(tracker);
+  }
+
+  @VisibleForTesting
+  public void resetBlockRetryMaxTimes(int times) {
+    this.blockFailSentRetryMaxTimes = times;
+  }
+
+  private void releaseResources(FailedBlockSendTracker tracker, Set<Long> 
blockIds) {
+    for (Long blockId : blockIds) {
+      List<TrackingBlockStatus> failedBlockStatus = 
tracker.getFailedBlockStatus(blockId);
+      if (CollectionUtils.isNotEmpty(failedBlockStatus)) {
+        TrackingBlockStatus blockStatus = failedBlockStatus.get(0);
+        blockStatus.getShuffleBlockInfo().executeCompletionCallback(true);
+      }
+    }
+  }
+
+  private void reassignAndResendForFailedBlocks(FailedBlockSendTracker 
failedBlockSendTracker) {
+    Set<Long> failedBlockIds = failedBlockSendTracker.getFailedBlockIds();
+    if (CollectionUtils.isEmpty(failedBlockIds)) {
+      return;
+    }
+
+    Set<TrackingBlockStatus> resendBlocks = new HashSet<>();
+    for (Long blockId : failedBlockIds) {
+      List<TrackingBlockStatus> failedBlockStatus =
+          failedBlockSendTracker.getFailedBlockStatus(blockId);
+      synchronized (failedBlockStatus) {
+        int retryCnt =
+            failedBlockStatus.stream()
+                .filter(
+                    x -> {
+                      // If statusCode is null, the block was resent due to a 
stale assignment.
+                      // In this case, the retry count checking should be 
ignored.
+                      return x.getStatusCode() != null;
+                    })
+                .map(x -> x.getShuffleBlockInfo().getRetryCnt())
+                .max(Comparator.comparing(Integer::valueOf))
+                .orElse(-1);
+        if (retryCnt >= blockFailSentRetryMaxTimes) {
+          releaseResources(failedBlockSendTracker, failedBlockIds);
+          String message =
+              String.format(
+                  "Block send retry exceeded max retries. blockId=%d, 
retryCount=%d, maxRetry=%d, faultyServers=%s",
+                  blockId,
+                  retryCnt,
+                  blockFailSentRetryMaxTimes,
+                  failedBlockStatus.stream()
+                      .map(TrackingBlockStatus::getShuffleServerInfo)
+                      .map(ShuffleServerInfo::getId)
+                      .collect(Collectors.toSet()));
+          throw new RssSendFailedException(message);
+        }
+
+        for (TrackingBlockStatus status : failedBlockStatus) {
+          StatusCode code = status.getStatusCode();
+          if (STATUS_CODE_WITHOUT_BLOCK_RESEND.contains(code)) {
+            releaseResources(failedBlockSendTracker, failedBlockIds);
+            String message =
+                String.format(
+                    "Block send failed with status code [%s] which does not 
trigger block resend. blockId=%d, retryCount=%d, maxRetry=%d, faultyServer=%s",
+                    code,
+                    blockId,
+                    retryCnt,
+                    blockFailSentRetryMaxTimes,
+                    status.getShuffleServerInfo());
+            throw new RssSendFailedException(message);
+          }
+        }
+
+        // todo: with multiple replicas, if another replica succeeded to 
send, no need to resend
+        resendBlocks.addAll(failedBlockStatus);
+      }
+    }
+    reassignAndResendBlocks(resendBlocks);
+  }
+
+  private Map<Integer, Pair<List<String>, List<String>>> constructUpdateList(
+      Map<Integer, List<ReceivingFailureServer>> requestList) {
+    Map<Integer, Pair<List<String>, List<String>>> reassignUpdateList = new 
HashMap<>();
+    for (Map.Entry<Integer, List<ReceivingFailureServer>> entry : 
requestList.entrySet()) {
+      Integer partitionId = entry.getKey();
+      List<String> oldServers =
+          entry.getValue().stream()
+              .map(ReceivingFailureServer::getServerId)
+              .collect(Collectors.toList());
+      List<String> newServers =
+          taskAttemptAssignment.retrieve(partitionId).stream()
+              .map(ShuffleServerInfo::getId)
+              .collect(Collectors.toList());
+      reassignUpdateList.put(partitionId, Pair.of(oldServers, newServers));
+    }
+    return reassignUpdateList;
+  }
+
+  @VisibleForTesting
+  protected static String readableResult(
+      Map<Integer, Pair<List<String>, List<String>>> fastSwitchList) {
+    if (fastSwitchList == null || fastSwitchList.isEmpty()) {
+      return "";
+    }
+
+    StringBuilder sb = new StringBuilder();
+    boolean hasDiff = false;
+
+    for (Map.Entry<Integer, Pair<List<String>, List<String>>> entry :
+        fastSwitchList.entrySet().stream()
+            .sorted(Map.Entry.comparingByKey())
+            .collect(Collectors.toList())) {
+
+      Integer partitionId = entry.getKey();
+      Pair<List<String>, List<String>> servers = entry.getValue();
+      List<String> oldServers = servers.getLeft();
+      List<String> newServers = servers.getRight();
+
+      // compare as set to avoid ordering impact
+      if (oldServers != null
+          && newServers != null
+          && new HashSet<>(oldServers).equals(new HashSet<>(newServers))) {
+        continue;
+      }
+
+      hasDiff = true;
+
+      sb.append("partitionId=")
+          .append(partitionId)
+          .append(": ")
+          .append(oldServers)
+          .append(" -> ")
+          .append(newServers)
+          .append("; ");
+    }
+
+    if (!hasDiff) {
+      return "";
+    }
+
+    return sb.toString().trim();
+  }
+
+  private void reassignOnPartitionNeedSplit(FailedBlockSendTracker 
failedTracker) {
+    Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new 
HashMap<>();
+
+    failedTracker
+        .removeAllTrackedPartitions()
+        .forEach(
+            partitionStatus -> {
+              List<ReceivingFailureServer> servers =
+                  failurePartitionToServers.computeIfAbsent(
+                      partitionStatus.getPartitionId(), x -> new 
ArrayList<>());
+              String serverId = partitionStatus.getShuffleServerInfo().getId();
+              // todo: use better data structure to filter
+              if (!servers.stream()
+                  .map(x -> x.getServerId())
+                  .collect(Collectors.toSet())
+                  .contains(serverId)) {
+                servers.add(new ReceivingFailureServer(serverId, 
StatusCode.SUCCESS));
+              }
+            });
+
+    if (failurePartitionToServers.isEmpty()) {
+      return;
+    }
+
+    //
+    // For the [load balance] mode
+    // Once a partition has been split, subsequent split triggers will be 
ignored.
+    //
+    // For the [pipeline] mode
+    // The split request will always be responded to
+    //
+
+    // the list of reassign list
+    Map<Integer, List<ReceivingFailureServer>> reassignList = new HashMap<>();
+
+    // the list of fast switch list. key: partitionId, value: left=old, 
right=new
+    Map<Integer, Pair<List<String>, List<String>>> fastSwitchList = new 
HashMap<>();
+
+    for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
+        failurePartitionToServers.entrySet()) {
+      int partitionId = entry.getKey();
+      List<ReceivingFailureServer> failureServers = entry.getValue();
+      if (taskAttemptAssignment.tryNextServerForSplitPartition(
+          partitionId,
+          failureServers.stream()
+              .map(x -> ShuffleServerInfo.from(x.getServerId()))
+              .collect(Collectors.toList()))) {
+        fastSwitchList.put(
+            partitionId,
+            Pair.of(
+                failureServers.stream()
+                    .map(ReceivingFailureServer::getServerId)
+                    .collect(Collectors.toList()),
+                taskAttemptAssignment.retrieve(partitionId).stream()
+                    .map(ShuffleServerInfo::getId)
+                    .collect(Collectors.toList())));
+      } else {
+        reassignList.put(partitionId, failureServers);
+      }
+    }
+
+    if (reassignList.isEmpty()) {
+      LOG.info(
+          "[partition-split] All fast switch to another servers successfully 
for taskId[{}]. list: {}",
+          taskId,
+          readableResult(fastSwitchList));
+      return;
+    } else {
+      if (!fastSwitchList.isEmpty()) {
+        LOG.info(
+            "[partition-split] Partial fast switch to another servers for 
taskId[{}]. list: {}",
+            taskId,
+            readableResult(fastSwitchList));
+      }
+    }
+
+    @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance")
+    long start = System.currentTimeMillis();
+    doReassignOnBlockSendFailure(reassignList, true);
+    LOG.info(
+        "[partition-split] Reassign successfully for taskId[{}] in {} ms. 
list: {}",
+        taskId,
+        System.currentTimeMillis() - start,
+        readableResult(constructUpdateList(reassignList)));
+  }
+
+  @VisibleForTesting
+  protected void doReassignOnBlockSendFailure(
+      Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers,
+      boolean partitionSplit) {
+    LOG.info(
+        "Initiate reassignOnBlockSendFailure of taskId[{}]. isPartitionSplit: 
{}. failurePartitionServers: {}.",
+        taskId,
+        partitionSplit,
+        failurePartitionToServers);
+    // for tests to set the default value
+    String executorId = "NULL";
+    try {
+      executorId = SparkEnv.get().executorId();
+    } catch (Exception e) {
+      // ignore
+    }
+    long taskAttemptId = taskContext.taskAttemptId();
+    int stageId = taskContext.stageId();
+    int stageAttemptNum = taskContext.stageAttemptNumber();
+    try {
+      RssReassignOnBlockSendFailureRequest request =
+          new RssReassignOnBlockSendFailureRequest(
+              shuffleId,
+              failurePartitionToServers,
+              executorId,
+              taskAttemptId,
+              stageId,
+              stageAttemptNum,
+              partitionSplit);
+      RssReassignOnBlockSendFailureResponse response =
+          managerClientSupplier.get().reassignOnBlockSendFailure(request);
+      if (response.getStatusCode() != StatusCode.SUCCESS) {
+        String msg =
+            String.format(
+                "Reassign request failed. statusCode: %s, msg: %s",
+                response.getStatusCode(), response.getMessage());
+        throw new RssException(msg);
+      }
+      MutableShuffleHandleInfo handle = 
MutableShuffleHandleInfo.fromProto(response.getHandle());
+      taskAttemptAssignment.update(handle);
+    } catch (Exception e) {
+      throw new RssException(
+          "Errors on reassign on block send failure. failure 
partition->servers : "
+              + failurePartitionToServers,
+          e);
+    }
+  }
+
+  private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
+    @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance")
+    long start = System.currentTimeMillis();
+    List<ShuffleBlockInfo> resendCandidates = Lists.newArrayList();
+    Map<Integer, List<TrackingBlockStatus>> partitionedFailedBlocks =
+        blocks.stream()
+            .collect(Collectors.groupingBy(d -> 
d.getShuffleBlockInfo().getPartitionId()));
+
+    Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new 
HashMap<>();
+    for (Map.Entry<Integer, List<TrackingBlockStatus>> entry : 
partitionedFailedBlocks.entrySet()) {
+      int partitionId = entry.getKey();
+      List<TrackingBlockStatus> partitionBlocks = entry.getValue();
+      Map<ShuffleServerInfo, TrackingBlockStatus> serverBlocks =
+          partitionBlocks.stream()
+              .collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()))
+              .entrySet()
+              .stream()
+              .collect(
+                  Collectors.toMap(
+                      Map.Entry::getKey, x -> 
x.getValue().stream().findFirst().get()));
+      for (Map.Entry<ShuffleServerInfo, TrackingBlockStatus> blockStatusEntry :
+          serverBlocks.entrySet()) {
+        String serverId = blockStatusEntry.getKey().getId();
+        // avoid duplicate reassign for the same failure server.
+        // todo: getting the replacement should support multi replica.
+        List<ShuffleServerInfo> servers = 
taskAttemptAssignment.retrieve(partitionId);
+        // Gets the first replica for this partition for now.
+        // It can not work if we want to use multiple replicas.
+        ShuffleServerInfo replacement = servers.get(0);
+        String latestServerId = replacement.getId();
+        if (!serverId.equals(latestServerId)) {
+          continue;
+        }
+        StatusCode code = blockStatusEntry.getValue().getStatusCode();
+        failurePartitionToServers
+            .computeIfAbsent(partitionId, x -> new ArrayList<>())
+            .add(new ReceivingFailureServer(serverId, code));
+      }
+    }
+
+    if (!failurePartitionToServers.isEmpty()) {
+      @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance")
+      long requestStart = System.currentTimeMillis();
+      doReassignOnBlockSendFailure(failurePartitionToServers, false);
+      LOG.info(
+          "[partition-reassign] Do reassign request successfully in {} ms. 
list: {}",
+          System.currentTimeMillis() - requestStart,
+          readableResult(constructUpdateList(failurePartitionToServers)));
+    }
+
+    for (TrackingBlockStatus blockStatus : blocks) {
+      ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
+      // todo: getting the replacement should support multi replica.
+      List<ShuffleServerInfo> servers = 
taskAttemptAssignment.retrieve(block.getPartitionId());
+      // Gets the first replica for this partition for now.
+      // It can not work if we want to use multiple replicas.
+      ShuffleServerInfo replacement = servers.get(0);
+      if 
(blockStatus.getShuffleServerInfo().getId().equals(replacement.getId())) {
+        LOG.warn(
+            "PartitionId:{} has the following assigned servers: {}. But 
currently the replacement server:{} is the same with previous one!",
+            block.getPartitionId(),
+            taskAttemptAssignment.list(block.getPartitionId()),
+            replacement);
+        throw new RssException(
+            "No available replacement server for: " + 
blockStatus.getShuffleServerInfo().getId());
+      }
+      // clear the previous retry state of block
+      removeBlockStatsFunction.accept(block);
+      final ShuffleBlockInfo newBlock = block;
+      // if the status code is null, it means the block is resent due to stale 
assignment, not
+      // because of the block send failure. In this case, the retry count 
should not be increased;
+      // otherwise it may cause unexpected fast failure.
+      if (blockStatus.getStatusCode() != null) {
+        newBlock.incrRetryCnt();
+      }
+      newBlock.reassignShuffleServers(Arrays.asList(replacement));
+      resendCandidates.add(newBlock);
+    }
+    resendBlocksFunction.accept(resendCandidates);
+    LOG.info(
+        "[partition-reassign] All {} blocks have been resent to queue 
successfully in {} ms.",
+        blocks.size(),
+        System.currentTimeMillis() - start);
+  }
+
+  @VisibleForTesting
+  public void resetTaskId(String taskId) {
+    this.taskId = taskId;
+  }
+}
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/TaskAttemptAssignmentTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/TaskAttemptAssignmentTest.java
index b68e2964c..2ec928b10 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/TaskAttemptAssignmentTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/TaskAttemptAssignmentTest.java
@@ -79,7 +79,7 @@ public class TaskAttemptAssignmentTest {
   public void testUpdatePartitionSplitAssignment() {
     TaskAttemptAssignment assignment = new TaskAttemptAssignment(1, new 
MockShuffleHandleInfo());
     assertTrue(
-        assignment.updatePartitionSplitAssignment(
+        assignment.tryNextServerForSplitPartition(
             1, Arrays.asList(new ShuffleServerInfo("localhost", 122))));
   }
 }
diff --git 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ReassignExecutorTest.java
 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ReassignExecutorTest.java
new file mode 100644
index 000000000..fd25e0a2c
--- /dev/null
+++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ReassignExecutorTest.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.writer.TaskAttemptAssignment;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+
+import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
+import org.apache.uniffle.client.impl.TrackingBlockStatus;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssSendFailedException;
+import org.apache.uniffle.common.rpc.StatusCode;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class ReassignExecutorTest {
+
+  @Mock private FailedBlockSendTracker failedBlockSendTracker = 
mock(FailedBlockSendTracker.class);
+
+  @Mock private TaskAttemptAssignment taskAttemptAssignment = 
mock(TaskAttemptAssignment.class);
+
+  @Mock private ShuffleManagerClient shuffleManagerClient = 
mock(ShuffleManagerClient.class);
+
+  @Mock private TaskContext taskContext = mock(TaskContext.class);
+
+  @Mock private Consumer<ShuffleBlockInfo> removeBlockStatsFunction = 
mock(Consumer.class);
+
+  @Mock private Consumer<List<ShuffleBlockInfo>> resendBlocksFunction = 
mock(Consumer.class);
+
+  private ReassignExecutor executor = mock(ReassignExecutor.class);
+
+  @BeforeEach
+  void setUp() {
+    when(taskContext.taskAttemptId()).thenReturn(1L);
+    when(taskContext.stageId()).thenReturn(1);
+    when(taskContext.stageAttemptNumber()).thenReturn(0);
+
+    Map<String, FailedBlockSendTracker> taskToTracker = new HashMap<>();
+    taskToTracker.put("task1", failedBlockSendTracker);
+    executor =
+        new ReassignExecutor(
+            taskToTracker,
+            "task1",
+            taskAttemptAssignment,
+            removeBlockStatsFunction,
+            resendBlocksFunction,
+            () -> shuffleManagerClient,
+            taskContext,
+            1,
+            3);
+  }
+
+  @Test
+  void testRetryExceededShouldFailAndReleaseResources() {
+    long blockId = 100L;
+
+    ShuffleBlockInfo blockInfo = 
org.mockito.Mockito.mock(ShuffleBlockInfo.class);
+    when(blockInfo.getRetryCnt()).thenReturn(3);
+
+    TrackingBlockStatus status = 
org.mockito.Mockito.mock(TrackingBlockStatus.class);
+    when(status.getShuffleBlockInfo()).thenReturn(blockInfo);
+    when(status.getStatusCode()).thenReturn(StatusCode.INTERNAL_ERROR);
+    when(status.getShuffleServerInfo()).thenReturn(new 
ShuffleServerInfo("localhost", 1234));
+
+    when(failedBlockSendTracker.getFailedBlockIds())
+        .thenReturn(new HashSet<>(Arrays.asList(blockId)));
+    
when(failedBlockSendTracker.getFailedBlockStatus(blockId)).thenReturn(Arrays.asList(status));
+
+    assertThrows(RssSendFailedException.class, executor::reassign);
+
+    verify(blockInfo).executeCompletionCallback(true);
+  }
+
+  @Test
+  public void testMixedSameAndDifferent() throws Exception {
+    Map<Integer, Pair<List<String>, List<String>>> map = new HashMap<>();
+
+    // same -> should ignore
+    map.put(1, Pair.of(Arrays.asList("s1", "s2"), Arrays.asList("s2", "s1")));
+
+    // diff -> should print
+    map.put(3, Pair.of(Arrays.asList("x"), Arrays.asList("y")));
+
+    String result = ReassignExecutor.readableResult(map);
+
+    System.out.println(result);
+
+    assertFalse(result.contains("partitionId=1"));
+    assertTrue(result.contains("partitionId=3"));
+  }
+}
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 75bd1fb76..60f3ef46f 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -18,9 +18,7 @@
 package org.apache.spark.shuffle.writer;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
-import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -52,7 +50,6 @@ import org.apache.commons.collections4.CollectionUtils;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
-import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.scheduler.MapStatus;
@@ -62,7 +59,6 @@ import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.apache.spark.shuffle.RssSparkShuffleUtils;
 import org.apache.spark.shuffle.ShuffleWriter;
-import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 import org.apache.spark.storage.BlockManagerId;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -73,13 +69,10 @@ import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.impl.TrackingBlockStatus;
-import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
 import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
 import org.apache.uniffle.client.request.RssReportShuffleWriteMetricRequest;
-import 
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteMetricResponse;
-import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
@@ -89,6 +82,7 @@ import 
org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.shuffle.BlockStats;
+import org.apache.uniffle.shuffle.ReassignExecutor;
 import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
 import org.apache.uniffle.storage.util.StorageType;
 
@@ -127,7 +121,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private SparkConf sparkConf;
   private RssConf rssConf;
   private boolean blockFailSentRetryEnabled;
-  private int blockFailSentRetryMaxTimes = 1;
 
   /** used by columnar rss shuffle writer implementation */
   protected final long taskAttemptId;
@@ -157,6 +150,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   private boolean isIntegrityValidationClientManagementEnabled = false;
 
+  private ReassignExecutor reassignExecutor;
+
   // Only for tests
   @VisibleForTesting
   public RssShuffleWriter(
@@ -189,6 +184,17 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         context);
     this.bufferManager = bufferManager;
     this.taskAttemptAssignment = new TaskAttemptAssignment(taskAttemptId, 
shuffleHandleInfo);
+    this.reassignExecutor =
+        new ReassignExecutor(
+            shuffleManager.getTaskToFailedBlockSendTracker(),
+            taskId,
+            taskAttemptAssignment,
+            block -> clearFailedBlockState(block),
+            blocks -> processShuffleBlockInfos(blocks),
+            managerClientSupplier,
+            taskContext,
+            shuffleId,
+            rssConf.get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES));
   }
 
   private RssShuffleWriter(
@@ -239,7 +245,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
                 + RssClientConf.RSS_CLIENT_REASSIGN_ENABLED.key(),
             RssClientConf.RSS_CLIENT_REASSIGN_ENABLED.defaultValue());
-    this.blockFailSentRetryMaxTimes = 
rssConf.get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES);
     this.enableWriteFailureRetry = 
rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
     this.recordReportFailedShuffleservers = Sets.newConcurrentHashSet();
     this.isIntegrityValidationClientManagementEnabled =
@@ -321,6 +326,17 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
             this::getPartitionAssignedServers,
             context.stageAttemptNumber());
     this.bufferManager = bufferManager;
+    this.reassignExecutor =
+        new ReassignExecutor(
+            shuffleManager.getTaskToFailedBlockSendTracker(),
+            taskId,
+            taskAttemptAssignment,
+            block -> clearFailedBlockState(block),
+            blocks -> processShuffleBlockInfos(blocks),
+            managerClientSupplier,
+            taskContext,
+            shuffleId,
+            rssConf.get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES));
   }
 
   @VisibleForTesting
@@ -421,11 +437,11 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   private void checkSentRecordCount(long recordCount) {
     if (recordCount != bufferManager.getRecordCount()) {
-      String errorMsg =
-          "Potential record loss may have occurred while preparing to send 
blocks for task["
-              + taskId
-              + "]";
-      throw new RssSendFailedException(errorMsg);
+      String message =
+          String.format(
+              "Inconsistent records number for taskId[%s]. expected: %d, 
actual: %d.",
+              taskId, recordCount, bufferManager.getRecordCount());
+      throw new RssSendFailedException(message);
     }
   }
 
@@ -578,7 +594,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   // This method should remain protected so that Gluten can invoke it
   protected void checkDataIfAnyFailure() {
     if (blockFailSentRetryEnabled) {
-      collectFailedBlocksToResend();
+      reassignExecutor.reassign();
     } else {
       String errorMsg = getFirstBlockFailure();
       if (errorMsg != null) {
@@ -608,281 +624,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     return null;
   }
 
-  private void collectFailedBlocksToResend() {
-    if (!blockFailSentRetryEnabled) {
-      return;
-    }
-
-    FailedBlockSendTracker failedTracker = 
shuffleManager.getBlockIdsFailedSendTracker(taskId);
-    if (failedTracker == null) {
-      return;
-    }
-
-    reassignOnPartitionNeedSplit(failedTracker);
-
-    Set<Long> failedBlockIds = failedTracker.getFailedBlockIds();
-    if (CollectionUtils.isEmpty(failedBlockIds)) {
-      return;
-    }
-
-    boolean isFastFail = false;
-    Set<TrackingBlockStatus> resendCandidates = new HashSet<>();
-    // to check whether the blocks resent exceed the max resend count.
-    for (Long blockId : failedBlockIds) {
-      List<TrackingBlockStatus> failedBlockStatus = 
failedTracker.getFailedBlockStatus(blockId);
-      synchronized (failedBlockStatus) {
-        int retryCnt =
-            failedBlockStatus.stream()
-                .filter(
-                    x -> {
-                      // If statusCode is null, the block was resent due to a 
stale assignment.
-                      // In this case, the retry count checking should be 
ignored.
-                      return x.getStatusCode() != null;
-                    })
-                .map(x -> x.getShuffleBlockInfo().getRetryCnt())
-                .max(Comparator.comparing(Integer::valueOf))
-                .orElse(-1);
-        if (retryCnt >= blockFailSentRetryMaxTimes) {
-          LOG.error(
-              "Partial blocks for taskId: [{}] retry exceeding the max retry 
times: [{}]. Fast fail! faulty server list: {}",
-              taskId,
-              blockFailSentRetryMaxTimes,
-              failedBlockStatus.stream()
-                  .map(x -> x.getShuffleServerInfo())
-                  .collect(Collectors.toSet()));
-          isFastFail = true;
-          break;
-        }
-
-        for (TrackingBlockStatus status : failedBlockStatus) {
-          StatusCode code = status.getStatusCode();
-          if (STATUS_CODE_WITHOUT_BLOCK_RESEND.contains(code)) {
-            LOG.error(
-                "Partial blocks for taskId: [{}] failed on the illegal status 
code: [{}] without resend on server: {}",
-                taskId,
-                code,
-                status.getShuffleServerInfo());
-            isFastFail = true;
-            break;
-          }
-        }
-
-        // todo: if setting multi replica and another replica is succeed to 
send, no need to resend
-        resendCandidates.addAll(failedBlockStatus);
-      }
-    }
-
-    if (isFastFail) {
-      // release data and allocated memory
-      for (Long blockId : failedBlockIds) {
-        List<TrackingBlockStatus> failedBlockStatus = 
failedTracker.getFailedBlockStatus(blockId);
-        if (CollectionUtils.isNotEmpty(failedBlockStatus)) {
-          TrackingBlockStatus blockStatus = failedBlockStatus.get(0);
-          blockStatus.getShuffleBlockInfo().executeCompletionCallback(true);
-        }
-      }
-
-      throw new RssSendFailedException(
-          "Errors on resending the blocks data to the remote shuffle-server.");
-    }
-
-    reassignAndResendBlocks(resendCandidates);
-  }
-
-  private void reassignOnPartitionNeedSplit(FailedBlockSendTracker 
failedTracker) {
-    Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new 
HashMap<>();
-
-    failedTracker
-        .removeAllTrackedPartitions()
-        .forEach(
-            partitionStatus -> {
-              List<ReceivingFailureServer> servers =
-                  failurePartitionToServers.computeIfAbsent(
-                      partitionStatus.getPartitionId(), x -> new 
ArrayList<>());
-              String serverId = partitionStatus.getShuffleServerInfo().getId();
-              // todo: use better data structure to filter
-              if (!servers.stream()
-                  .map(x -> x.getServerId())
-                  .collect(Collectors.toSet())
-                  .contains(serverId)) {
-                servers.add(new ReceivingFailureServer(serverId, 
StatusCode.SUCCESS));
-              }
-            });
-
-    if (failurePartitionToServers.isEmpty()) {
-      return;
-    }
-
-    //
-    // For the [load balance] mode
-    // Once partition has been split, the following split trigger will be 
ignored.
-    //
-    // For the [pipeline] mode
-    // The split request will be always response
-    //
-    Map<Integer, List<ReceivingFailureServer>> partitionToServersReassignList 
= new HashMap<>();
-    for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
-        failurePartitionToServers.entrySet()) {
-      int partitionId = entry.getKey();
-      List<ReceivingFailureServer> failureServers = entry.getValue();
-      if (!taskAttemptAssignment.updatePartitionSplitAssignment(
-          partitionId,
-          failureServers.stream()
-              .map(x -> ShuffleServerInfo.from(x.getServerId()))
-              .collect(Collectors.toList()))) {
-        partitionToServersReassignList.put(partitionId, failureServers);
-      }
-    }
-
-    if (partitionToServersReassignList.isEmpty()) {
-      LOG.info(
-          "[Partition split] Skip the following partition split request (maybe 
has been load balanced). partitionIds: {}",
-          failurePartitionToServers.keySet());
-      return;
-    }
-
-    doReassignOnBlockSendFailure(partitionToServersReassignList, true);
-
-    LOG.info("========================= Partition Split Result 
=========================");
-    for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
-        partitionToServersReassignList.entrySet()) {
-      LOG.info(
-          "partitionId:{}. {} -> {}",
-          entry.getKey(),
-          entry.getValue().stream().map(x -> 
x.getServerId()).collect(Collectors.toList()),
-          taskAttemptAssignment.retrieve(entry.getKey()));
-    }
-    
LOG.info("==========================================================================");
-  }
-
-  private void doReassignOnBlockSendFailure(
-      Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers,
-      boolean partitionSplit) {
-    LOG.info(
-        "Initiate reassignOnBlockSendFailure of taskId[{}]. partition split: 
{}. failure partition servers: {}. ",
-        taskAttemptId,
-        partitionSplit,
-        failurePartitionToServers);
-    String executorId = SparkEnv.get().executorId();
-    long taskAttemptId = taskContext.taskAttemptId();
-    int stageId = taskContext.stageId();
-    int stageAttemptNum = taskContext.stageAttemptNumber();
-    try {
-      RssReassignOnBlockSendFailureRequest request =
-          new RssReassignOnBlockSendFailureRequest(
-              shuffleId,
-              failurePartitionToServers,
-              executorId,
-              taskAttemptId,
-              stageId,
-              stageAttemptNum,
-              partitionSplit);
-      RssReassignOnBlockSendFailureResponse response =
-          managerClientSupplier.get().reassignOnBlockSendFailure(request);
-      if (response.getStatusCode() != StatusCode.SUCCESS) {
-        String msg =
-            String.format(
-                "Reassign failed. statusCode: %s, msg: %s",
-                response.getStatusCode(), response.getMessage());
-        throw new RssException(msg);
-      }
-      MutableShuffleHandleInfo handle = 
MutableShuffleHandleInfo.fromProto(response.getHandle());
-      taskAttemptAssignment.update(handle);
-
-      // print the lastest assignment for those reassignment partition ids
-      Map<Integer, List<String>> reassignments = new HashMap<>();
-      for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
-          failurePartitionToServers.entrySet()) {
-        int partitionId = entry.getKey();
-        List<ShuffleServerInfo> servers = 
taskAttemptAssignment.retrieve(partitionId);
-        reassignments.put(
-            partitionId, servers.stream().map(x -> 
x.getId()).collect(Collectors.toList()));
-      }
-      LOG.info("Succeed to reassign that the latest assignment is {}", 
reassignments);
-    } catch (Exception e) {
-      throw new RssException(
-          "Errors on reassign on block send failure. failure 
partition->servers : "
-              + failurePartitionToServers,
-          e);
-    }
-  }
-
-  private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
-    List<ShuffleBlockInfo> resendCandidates = Lists.newArrayList();
-    Map<Integer, List<TrackingBlockStatus>> partitionedFailedBlocks =
-        blocks.stream()
-            .collect(Collectors.groupingBy(d -> 
d.getShuffleBlockInfo().getPartitionId()));
-
-    Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new 
HashMap<>();
-    for (Map.Entry<Integer, List<TrackingBlockStatus>> entry : 
partitionedFailedBlocks.entrySet()) {
-      int partitionId = entry.getKey();
-      List<TrackingBlockStatus> partitionBlocks = entry.getValue();
-      Map<ShuffleServerInfo, TrackingBlockStatus> serverBlocks =
-          partitionBlocks.stream()
-              .collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()))
-              .entrySet()
-              .stream()
-              .collect(
-                  Collectors.toMap(
-                      Map.Entry::getKey, x -> 
x.getValue().stream().findFirst().get()));
-      for (Map.Entry<ShuffleServerInfo, TrackingBlockStatus> blockStatusEntry :
-          serverBlocks.entrySet()) {
-        String serverId = blockStatusEntry.getKey().getId();
-        // avoid duplicate reassign for the same failure server.
-        // todo: getting the replacement should support multi replica.
-        List<ShuffleServerInfo> servers = 
getPartitionAssignedServers(partitionId);
-        // Gets the first replica for this partition for now.
-        // It can not work if we want to use multiple replicas.
-        ShuffleServerInfo replacement = servers.get(0);
-        String latestServerId = replacement.getId();
-        if (!serverId.equals(latestServerId)) {
-          continue;
-        }
-        StatusCode code = blockStatusEntry.getValue().getStatusCode();
-        failurePartitionToServers
-            .computeIfAbsent(partitionId, x -> new ArrayList<>())
-            .add(new ReceivingFailureServer(serverId, code));
-      }
-    }
-
-    if (!failurePartitionToServers.isEmpty()) {
-      doReassignOnBlockSendFailure(failurePartitionToServers, false);
-    }
-
-    for (TrackingBlockStatus blockStatus : blocks) {
-      ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
-      // todo: getting the replacement should support multi replica.
-      List<ShuffleServerInfo> servers = 
getPartitionAssignedServers(block.getPartitionId());
-      // Gets the first replica for this partition for now.
-      // It can not work if we want to use multiple replicas.
-      ShuffleServerInfo replacement = servers.get(0);
-      if 
(blockStatus.getShuffleServerInfo().getId().equals(replacement.getId())) {
-        LOG.warn(
-            "PartitionId:{} has the following assigned servers: {}. But 
currently the replacement server:{} is the same with previous one!",
-            block.getPartitionId(),
-            taskAttemptAssignment.list(block.getPartitionId()),
-            replacement);
-        throw new RssException(
-            "No available replacement server for: " + 
blockStatus.getShuffleServerInfo().getId());
-      }
-      // clear the previous retry state of block
-      clearFailedBlockState(block);
-      final ShuffleBlockInfo newBlock = block;
-      // if the status code is null, it means the block is resent due to stale 
assignment, not
-      // because of the block send failure. In this case, the retry count 
should not be increased;
-      // otherwise it may cause unexpected fast failure.
-      if (blockStatus.getStatusCode() != null) {
-        newBlock.incrRetryCnt();
-      }
-      newBlock.reassignShuffleServers(Arrays.asList(replacement));
-      resendCandidates.add(newBlock);
-    }
-
-    processShuffleBlockInfos(resendCandidates);
-    LOG.info(
-        "Failed blocks have been resent to data pusher queue since 
reassignment has been finished successfully");
-  }
-
   private void clearFailedBlockState(ShuffleBlockInfo block) {
     
shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(block.getBlockId());
     shuffleTaskStats.decPartitionBlock(block.getPartitionId());
@@ -1109,8 +850,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   }
 
   @VisibleForTesting
-  protected void setBlockFailSentRetryMaxTimes(int blockFailSentRetryMaxTimes) {
-    this.blockFailSentRetryMaxTimes = blockFailSentRetryMaxTimes;
+  protected void resetBlockFailSentRetryMaxTimes(int times) {
+    reassignExecutor.resetBlockRetryMaxTimes(times);
   }
 
   @VisibleForTesting
@@ -1160,6 +901,11 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     return shuffleManager;
   }
 
+  @VisibleForTesting
+  public ReassignExecutor getReassignExecutor() {
+    return reassignExecutor;
+  }
+
   public TaskAttemptAssignment getTaskAttemptAssignment() {
     return taskAttemptAssignment;
   }
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index dd52aa915..2e16f454c 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -20,9 +20,11 @@ package org.apache.spark.shuffle.writer;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
@@ -212,7 +214,7 @@ public class RssShuffleWriterTest {
     String taskId = "taskId";
     MutableShuffleHandleInfo shuffleHandle = createMutableShuffleHandle();
     RssShuffleWriter writer = createMockWriter(shuffleHandle, taskId);
-    writer.setBlockFailSentRetryMaxTimes(10);
+    writer.resetBlockFailSentRetryMaxTimes(10);
 
     // Make the id1 + id10 + id11 broken, and then finally, it will use the 
id12 successfully
     AtomicInteger failureCnt = new AtomicInteger();
@@ -498,28 +500,26 @@ public class RssShuffleWriterTest {
     assertEquals(2, 
serverToPartitionToBlockIds.get(replacement).get(0).size());
 
     // case2. If exceeding the max retry times, it will fast fail.
-    rssShuffleWriter.setBlockFailSentRetryMaxTimes(1);
-    rssShuffleWriter.setTaskId("taskId2");
-    rssShuffleWriter.getBufferManager().setTaskId("taskId2");
-    taskToFailedBlockSendTracker.put("taskId2", new FailedBlockSendTracker());
-    AtomicInteger rejectCnt = new AtomicInteger(0);
+    String taskId = "t2";
+    rssShuffleWriter.getReassignExecutor().resetTaskId(taskId);
+    bufferManagerSpy.resetRecordCount();
+    rssShuffleWriter.resetBlockFailSentRetryMaxTimes(1);
+    rssShuffleWriter.setTaskId(taskId);
+    rssShuffleWriter.getBufferManager().setTaskId(taskId);
+    FailedBlockSendTracker tracker = new FailedBlockSendTracker();
+    taskToFailedBlockSendTracker.put(taskId, tracker);
     FakedDataPusher alwaysFailedDataPusher =
         new FakedDataPusher(
             event -> {
-              assertEquals("taskId2", event.getTaskId());
-              FailedBlockSendTracker tracker = 
taskToFailedBlockSendTracker.get(event.getTaskId());
+              assertEquals(taskId, event.getTaskId());
               for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
-                boolean isSuccessful = true;
-                ShuffleServerInfo shuffleServer = 
block.getShuffleServerInfos().get(0);
-                if (shuffleServer.getId().equals("id1") && rejectCnt.get() <= 
3) {
-                  tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
-                  isSuccessful = false;
-                  rejectCnt.incrementAndGet();
-                } else {
-                  successBlockIds.putIfAbsent(event.getTaskId(), 
Sets.newConcurrentHashSet());
-                  
successBlockIds.get(event.getTaskId()).add(block.getBlockId());
-                }
-                block.executeCompletionCallback(isSuccessful);
+                tracker.add(block, block.getShuffleServerInfos().get(0), 
StatusCode.NO_BUFFER);
+                block.executeCompletionCallback(false);
+              }
+              List<Runnable> callbackChain =
+                  Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST);
+              for (Runnable runnable : callbackChain) {
+                runnable.run();
               }
               return new CompletableFuture<>();
             });
@@ -533,8 +533,6 @@ public class RssShuffleWriterTest {
     } catch (Exception e) {
       // ignore
     }
-    assertEquals(0, bufferManagerSpy.getUsedBytes());
-    assertEquals(0, bufferManagerSpy.getInSendListBytes());
   }
 
   @Test

Reply via email to