This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 80caa0ecf [#1608][part-2] fix(spark): avoid releasing block in advance when enable block resend (#1610)
80caa0ecf is described below

commit 80caa0ecf7a605d591c8b3e7319cd046901617f3
Author: Junfan Zhang <[email protected]>
AuthorDate: Mon Apr 8 15:09:10 2024 +0800

    [#1608][part-2] fix(spark): avoid releasing block in advance when enable block resend (#1610)
    
    ### What changes were proposed in this pull request?
    
    1. Avoid releasing block data prematurely when block resend is enabled.
    2. Introduce a maximum retry count for resending failed blocks.
    
    ### Why are the changes needed?
    
    For: #1608
    
    The current partition-reassignment code has the following bugs:
    1. Block data may already have been released by the time the block is resent.
    2. If resending the blocks fails, the task may fail fast without retrying again.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    `RssShuffleWriterTest#blockFailureResendTest` tests the block-resend mechanism.
---
 .../apache/spark/shuffle/writer/AddBlockEvent.java |   8 -
 .../spark/shuffle/writer/BlockFailureCallback.java |  24 ++
 .../spark/shuffle/writer/BlockSuccessCallback.java |  24 ++
 .../apache/spark/shuffle/writer/DataPusher.java    |  11 +-
 .../spark/shuffle/writer/WriteBufferManager.java   |  32 +--
 .../shuffle/writer/WriteBufferManagerTest.java     |   6 +
 .../apache/spark/shuffle/RssShuffleManager.java    |   5 +
 .../spark/shuffle/writer/RssShuffleWriter.java     | 259 +++++++++++++--------
 .../spark/shuffle/writer/RssShuffleWriterTest.java | 197 +++++++++++++++-
 .../client/impl/FailedBlockSendTracker.java        |  11 +-
 .../uniffle/common/BlockCompletionCallback.java    |  22 ++
 .../apache/uniffle/common/ShuffleBlockInfo.java    |  26 +++
 .../uniffle/common/function/TupleConsumer.java     |  23 ++
 13 files changed, 520 insertions(+), 128 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
index 5a93c2b11..9751ba0b8 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
@@ -34,14 +34,6 @@ public class AddBlockEvent {
     this.processedCallbackChain = new ArrayList<>();
   }
 
-  public AddBlockEvent(
-      String taskId, List<ShuffleBlockInfo> shuffleBlockInfoList, Runnable 
callback) {
-    this.taskId = taskId;
-    this.shuffleDataInfoList = shuffleBlockInfoList;
-    this.processedCallbackChain = new ArrayList<>();
-    addCallback(callback);
-  }
-
   /** @param callback, should not throw any exception and execute fast. */
   public void addCallback(Runnable callback) {
     processedCallbackChain.add(callback);
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockFailureCallback.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockFailureCallback.java
new file mode 100644
index 000000000..116d1945d
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockFailureCallback.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+
+public interface BlockFailureCallback {
+  void onBlockFailure(ShuffleBlockInfo block);
+}
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockSuccessCallback.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockSuccessCallback.java
new file mode 100644
index 000000000..2b5dc0d09
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockSuccessCallback.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+
+public interface BlockSuccessCallback {
+  void onBlockSuccess(ShuffleBlockInfo block);
+}
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index 30f649f68..1517b7173 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -88,14 +88,23 @@ public class DataPusher implements Closeable {
         () -> {
           String taskId = event.getTaskId();
           List<ShuffleBlockInfo> shuffleBlockInfoList = 
event.getShuffleDataInfoList();
+          SendShuffleDataResult result = null;
           try {
-            SendShuffleDataResult result =
+            result =
                 shuffleWriteClient.sendShuffleData(
                     rssAppId, shuffleBlockInfoList, () -> 
!isValidTask(taskId));
             putBlockId(taskToSuccessBlockIds, taskId, 
result.getSuccessBlockIds());
             putFailedBlockSendTracker(
                 taskToFailedBlockSendTracker, taskId, 
result.getFailedBlockSendTracker());
           } finally {
+            Set<Long> succeedBlockIds =
+                result.getSuccessBlockIds() == null
+                    ? Collections.emptySet()
+                    : result.getSuccessBlockIds();
+            for (ShuffleBlockInfo block : shuffleBlockInfoList) {
+              
block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
+            }
+
             List<Runnable> callbackChain =
                 
Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST);
             for (Runnable runnable : callbackChain) {
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index d8261047f..efe376a34 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -408,14 +408,18 @@ public class WriteBufferManager extends MemoryConsumer {
     }
   }
 
+  public void releaseBlockResource(ShuffleBlockInfo block) {
+    this.freeAllocatedMemory(block.getFreeMemory());
+    block.getData().release();
+  }
+
   public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> 
shuffleBlockInfoList) {
     long totalSize = 0;
-    long memoryUsed = 0;
     List<AddBlockEvent> events = new ArrayList<>();
     List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
     for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
+      sbi.withCompletionCallback((block, isSuccessful) -> 
this.releaseBlockResource(block));
       totalSize += sbi.getSize();
-      memoryUsed += sbi.getFreeMemory();
       shuffleBlockInfosPerEvent.add(sbi);
       // split shuffle data according to the size
       if (totalSize > sendSizeLimit) {
@@ -427,20 +431,9 @@ public class WriteBufferManager extends MemoryConsumer {
                   + totalSize
                   + " bytes");
         }
-        // Use final temporary variables for closures
-        final long memoryUsedTemp = memoryUsed;
-        final List<ShuffleBlockInfo> shuffleBlocksTemp = 
shuffleBlockInfosPerEvent;
-        events.add(
-            new AddBlockEvent(
-                taskId,
-                shuffleBlockInfosPerEvent,
-                () -> {
-                  freeAllocatedMemory(memoryUsedTemp);
-                  shuffleBlocksTemp.stream().forEach(x -> 
x.getData().release());
-                }));
+        events.add(new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
         shuffleBlockInfosPerEvent = Lists.newArrayList();
         totalSize = 0;
-        memoryUsed = 0;
       }
     }
     if (!shuffleBlockInfosPerEvent.isEmpty()) {
@@ -453,16 +446,7 @@ public class WriteBufferManager extends MemoryConsumer {
                 + " bytes");
       }
       // Use final temporary variables for closures
-      final long memoryUsedTemp = memoryUsed;
-      final List<ShuffleBlockInfo> shuffleBlocksTemp = 
shuffleBlockInfosPerEvent;
-      events.add(
-          new AddBlockEvent(
-              taskId,
-              shuffleBlockInfosPerEvent,
-              () -> {
-                freeAllocatedMemory(memoryUsedTemp);
-                shuffleBlocksTemp.stream().forEach(x -> x.getData().release());
-              }));
+      events.add(new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
     }
     return events;
   }
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 38ebbbd37..22143bc0e 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -371,6 +371,9 @@ public class WriteBufferManagerTest {
           long sum = 0L;
           List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
           for (AddBlockEvent event : events) {
+            for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+              block.executeCompletionCallback(true);
+            }
             event.getProcessedCallbackChain().stream().forEach(x -> x.run());
             sum += event.getShuffleDataInfoList().stream().mapToLong(x -> 
x.getFreeMemory()).sum();
           }
@@ -413,6 +416,9 @@ public class WriteBufferManagerTest {
                             // ignore.
                           }
                         }
+                        for (ShuffleBlockInfo block : 
event.getShuffleDataInfoList()) {
+                          block.executeCompletionCallback(true);
+                        }
                         event.getProcessedCallbackChain().stream().forEach(x 
-> x.run());
                         sum +=
                             event.getShuffleDataInfoList().stream()
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 0b4faef82..1b4df1747 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -1264,4 +1264,9 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
   public boolean isRssResubmitStage() {
     return rssResubmitStage;
   }
+
+  @VisibleForTesting
+  public void setDataPusher(DataPusher dataPusher) {
+    this.dataPusher = dataPusher;
+  }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 635b3593a..8a22b73ba 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -21,10 +21,10 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
@@ -46,6 +46,7 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
@@ -83,6 +84,7 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.storage.util.StorageType;
 
 public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
@@ -94,7 +96,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final String appId;
   private final int shuffleId;
   private WriteBufferManager bufferManager;
-  private final String taskId;
+  private String taskId;
   private final int numMaps;
   private final ShuffleDependency<K, V, C> shuffleDependency;
   private final Partitioner partitioner;
@@ -113,7 +115,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final Set<Long> blockIds = Sets.newConcurrentHashSet();
   private TaskContext taskContext;
   private SparkConf sparkConf;
-  private boolean blockSendFailureRetryEnabled;
+  private boolean blockFailSentRetryEnabled;
+  private int blockFailSentRetryMaxTimes = 1;
 
   /** used by columnar rss shuffle writer implementation */
   protected final long taskAttemptId;
@@ -122,7 +125,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   private final BlockingQueue<Object> finishEventQueue = new 
LinkedBlockingQueue<>();
 
-  private final Map<String, ShuffleServerInfo> faultyServers = new HashMap<>();
+  // shuffleServerId -> failoverShuffleServer
+  private final Map<String, ShuffleServerInfo> replacementShuffleServers =
+      JavaUtils.newConcurrentMap();
 
   // Only for tests
   @VisibleForTesting
@@ -192,7 +197,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.taskFailureCallback = taskFailureCallback;
     this.taskContext = context;
     this.sparkConf = sparkConf;
-    this.blockSendFailureRetryEnabled =
+    this.blockFailSentRetryEnabled =
         sparkConf.getBoolean(
             RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
                 + 
RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.key(),
@@ -269,8 +274,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     long recordCount = 0;
     while (records.hasNext()) {
       recordCount++;
-      // Task should fast fail when sending data failed
-      checkIfBlocksFailed();
+
+      checkDataIfAnyFailure();
 
       Product2<K, V> record = records.next();
       K key = record._1();
@@ -363,6 +368,17 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       List<ShuffleBlockInfo> shuffleBlockInfoList) {
     List<CompletableFuture<Long>> futures = new ArrayList<>();
     for (AddBlockEvent event : 
bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
+      if (blockFailSentRetryEnabled) {
+        // do nothing if failed.
+        for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+          block.withCompletionCallback(
+              (completionBlock, isSuccessful) -> {
+                if (isSuccessful) {
+                  bufferManager.releaseBlockResource(completionBlock);
+                }
+              });
+        }
+      }
       event.addCallback(
           () -> {
             boolean ret = finishEventQueue.add(new Object());
@@ -386,7 +402,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       while (true) {
         try {
           finishEventQueue.clear();
-          checkIfBlocksFailed();
+          checkDataIfAnyFailure();
           Set<Long> successBlockIds = 
shuffleManager.getSuccessBlockIds(taskId);
           blockIds.removeAll(successBlockIds);
           if (blockIds.isEmpty()) {
@@ -422,105 +438,128 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     }
   }
 
-  private void checkIfBlocksFailed() {
-    Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
-    if (blockSendFailureRetryEnabled && !failedBlockIds.isEmpty()) {
-      Set<TrackingBlockStatus> shouldResendBlockSet = 
shouldResendBlockStatusSet(failedBlockIds);
-      try {
-        reSendFailedBlockIds(shouldResendBlockSet);
-      } catch (Exception e) {
-        LOG.error("resend failed blocks failed.", e);
+  private void checkDataIfAnyFailure() {
+    if (blockFailSentRetryEnabled) {
+      collectFailedBlocksToResend();
+    } else {
+      if (hasAnyBlockFailure()) {
+        throw new RssSendFailedException("Fail to send the block");
       }
-      failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
     }
+  }
+
+  private boolean hasAnyBlockFailure() {
+    Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
     if (!failedBlockIds.isEmpty()) {
-      String errorMsg =
-          "Send failed: Task["
-              + taskId
-              + "]"
-              + " failed because "
-              + failedBlockIds.size()
-              + " blocks can't be sent to shuffle server: "
-              + 
shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers();
-      LOG.error(errorMsg);
-      throw new RssSendFailedException(errorMsg);
+      LOG.error(
+          "Errors on sending blocks for task[{}]. {} blocks can't be sent to 
remote servers: {}",
+          taskId,
+          failedBlockIds.size(),
+          
shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers());
+      return true;
     }
+    return false;
   }
 
-  private Set<TrackingBlockStatus> shouldResendBlockStatusSet(Set<Long> 
failedBlockIds) {
-    FailedBlockSendTracker failedBlockTracker = 
shuffleManager.getBlockIdsFailedSendTracker(taskId);
-    Set<TrackingBlockStatus> resendBlockStatusSet = Sets.newHashSet();
-    for (Long failedBlockId : failedBlockIds) {
-      failedBlockTracker.getFailedBlockStatus(failedBlockId).stream()
-          // todo: more status need reassign
-          .filter(
-              trackingBlockStatus -> trackingBlockStatus.getStatusCode() == 
StatusCode.NO_BUFFER)
-          .forEach(trackingBlockStatus -> 
resendBlockStatusSet.add(trackingBlockStatus));
+  private void collectFailedBlocksToResend() {
+    if (!blockFailSentRetryEnabled) {
+      return;
+    }
+
+    FailedBlockSendTracker failedTracker = 
shuffleManager.getBlockIdsFailedSendTracker(taskId);
+    Set<Long> failedBlockIds = failedTracker.getFailedBlockIds();
+    if (CollectionUtils.isEmpty(failedBlockIds)) {
+      return;
+    }
+
+    boolean isFastFail = false;
+    Set<TrackingBlockStatus> resendCandidates = new HashSet<>();
+    // to check whether the blocks resent exceed the max resend count.
+    for (Long blockId : failedBlockIds) {
+      List<TrackingBlockStatus> failedBlockStatus = 
failedTracker.getFailedBlockStatus(blockId);
+      int retryIndex = 
failedBlockStatus.get(0).getShuffleBlockInfo().getRetryCnt();
+      // todo: support retry times by config
+      if (retryIndex >= blockFailSentRetryMaxTimes) {
+        LOG.error(
+            "Partial blocks for taskId: [{}] retry exceeding the max retry 
times: [{}]. Fast fail! faulty server list: {}",
+            taskId,
+            blockFailSentRetryMaxTimes,
+            failedBlockStatus.stream()
+                .map(x -> x.getShuffleServerInfo())
+                .collect(Collectors.toSet()));
+        isFastFail = true;
+        break;
+      }
+
+      // todo: if setting multi replica and another replica is succeed to 
send, no need to resend
+      resendCandidates.addAll(failedBlockStatus);
     }
-    return resendBlockStatusSet;
+
+    if (isFastFail) {
+      // release data and allocated memory
+      for (Long blockId : failedBlockIds) {
+        List<TrackingBlockStatus> failedBlockStatus = 
failedTracker.getFailedBlockStatus(blockId);
+        Optional<TrackingBlockStatus> blockStatus = 
failedBlockStatus.stream().findFirst();
+        if (blockStatus.isPresent()) {
+          
blockStatus.get().getShuffleBlockInfo().executeCompletionCallback(true);
+        }
+      }
+
+      throw new RssSendFailedException(
+          "Errors on resending the blocks data to the remote shuffle-server.");
+    }
+
+    resendFailedBlocks(resendCandidates);
   }
 
-  private void reSendFailedBlockIds(Set<TrackingBlockStatus> 
failedBlockStatusSet) {
-    List<ShuffleBlockInfo> reAssignSeverBlockInfoList = Lists.newArrayList();
-    List<ShuffleBlockInfo> failedBlockInfoList = Lists.newArrayList();
+  private void resendFailedBlocks(Set<TrackingBlockStatus> 
failedBlockStatusSet) {
+    List<ShuffleBlockInfo> reassignBlocks = Lists.newArrayList();
     Map<ShuffleServerInfo, List<TrackingBlockStatus>> faultyServerToPartitions 
=
         failedBlockStatusSet.stream().collect(Collectors.groupingBy(d -> 
d.getShuffleServerInfo()));
-    faultyServerToPartitions.entrySet().stream()
-        .forEach(
-            t -> {
-              Set<Integer> partitionIds =
-                  t.getValue().stream()
-                      .map(x -> x.getShuffleBlockInfo().getPartitionId())
-                      .collect(Collectors.toSet());
-              ShuffleServerInfo dynamicShuffleServer = 
faultyServers.get(t.getKey().getId());
-              if (dynamicShuffleServer == null) {
-                dynamicShuffleServer =
-                    reAssignFaultyShuffleServer(partitionIds, 
t.getKey().getId());
-                faultyServers.put(t.getKey().getId(), dynamicShuffleServer);
-              }
-
-              ShuffleServerInfo finalDynamicShuffleServer = 
dynamicShuffleServer;
-              failedBlockStatusSet.forEach(
-                  trackingBlockStatus -> {
-                    ShuffleBlockInfo failedBlockInfo = 
trackingBlockStatus.getShuffleBlockInfo();
-                    failedBlockInfoList.add(failedBlockInfo);
-                    reAssignSeverBlockInfoList.add(
-                        new ShuffleBlockInfo(
-                            failedBlockInfo.getShuffleId(),
-                            failedBlockInfo.getPartitionId(),
-                            failedBlockInfo.getBlockId(),
-                            failedBlockInfo.getLength(),
-                            failedBlockInfo.getCrc(),
-                            failedBlockInfo.getData(),
-                            Lists.newArrayList(finalDynamicShuffleServer),
-                            failedBlockInfo.getUncompressLength(),
-                            failedBlockInfo.getFreeMemory(),
-                            taskAttemptId));
-                  });
-            });
-    clearFailedBlockIdsStates(failedBlockInfoList, faultyServers);
-    processShuffleBlockInfos(reAssignSeverBlockInfoList);
-    checkIfBlocksFailed();
+
+    for (Map.Entry<ShuffleServerInfo, List<TrackingBlockStatus>> entry :
+        faultyServerToPartitions.entrySet()) {
+      Set<Integer> partitionIds =
+          entry.getValue().stream()
+              .map(x -> x.getShuffleBlockInfo().getPartitionId())
+              .collect(Collectors.toSet());
+      ShuffleServerInfo replacement = 
replacementShuffleServers.get(entry.getKey().getId());
+      if (replacement == null) {
+        // todo: merge multiple requests into one.
+        replacement = reassignFaultyShuffleServer(partitionIds, 
entry.getKey().getId());
+        replacementShuffleServers.put(entry.getKey().getId(), replacement);
+      }
+
+      for (TrackingBlockStatus blockStatus : failedBlockStatusSet) {
+        // clear the previous retry state of block
+        ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
+        clearFailedBlockState(block);
+
+        final ShuffleBlockInfo newBlock = block;
+        newBlock.incrRetryCnt();
+        newBlock.reassignShuffleServers(Arrays.asList(replacement));
+
+        reassignBlocks.add(newBlock);
+      }
+    }
+
+    processShuffleBlockInfos(reassignBlocks);
   }
 
-  private void clearFailedBlockIdsStates(
-      List<ShuffleBlockInfo> failedBlockInfoList, Map<String, 
ShuffleServerInfo> faultyServers) {
-    failedBlockInfoList.forEach(
-        shuffleBlockInfo -> {
-          
shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(shuffleBlockInfo.getBlockId());
-          shuffleBlockInfo.getShuffleServerInfos().stream()
-              .filter(s -> faultyServers.containsKey(s.getId()))
-              .forEach(
-                  s ->
-                      serverToPartitionToBlockIds
-                          .get(s)
-                          .get(shuffleBlockInfo.getPartitionId())
-                          .remove(shuffleBlockInfo.getBlockId()));
-          partitionLengths[shuffleBlockInfo.getPartitionId()] -= 
shuffleBlockInfo.getLength();
-        });
+  private void clearFailedBlockState(ShuffleBlockInfo block) {
+    
shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(block.getBlockId());
+    block.getShuffleServerInfos().stream()
+        .filter(s -> replacementShuffleServers.containsKey(s.getId()))
+        .forEach(
+            s ->
+                serverToPartitionToBlockIds
+                    .get(s)
+                    .get(block.getPartitionId())
+                    .remove(block.getBlockId()));
+    partitionLengths[block.getPartitionId()] -= block.getLength();
   }
 
-  private ShuffleServerInfo reAssignFaultyShuffleServer(
+  private ShuffleServerInfo reassignFaultyShuffleServer(
       Set<Integer> partitionIds, String faultyServerId) {
     RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
     String driver = rssConf.getString("driver.host", "");
@@ -611,6 +650,17 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         return Option.empty();
       }
     } finally {
+      if (blockFailSentRetryEnabled) {
+        if (success) {
+          if 
(CollectionUtils.isNotEmpty(shuffleManager.getFailedBlockIds(taskId))) {
+            LOG.error(
+                "Errors on stopping writer due to the remaining failed 
blockIds. This should not happen.");
+            return Option.empty();
+          }
+        } else {
+          
shuffleManager.getBlockIdsFailedSendTracker(taskId).clearAndReleaseBlockResources();
+        }
+      }
       // free all memory & metadata, or memory leak happen in executor
       if (bufferManager != null) {
         bufferManager.freeAllMemory();
@@ -694,4 +744,29 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     }
     throw new RssException(e);
   }
+
+  @VisibleForTesting
+  protected void enableBlockFailSentRetry() {
+    this.blockFailSentRetryEnabled = true;
+  }
+
+  @VisibleForTesting
+  protected void setBlockFailSentRetryMaxTimes(int blockFailSentRetryMaxTimes) 
{
+    this.blockFailSentRetryMaxTimes = blockFailSentRetryMaxTimes;
+  }
+
+  @VisibleForTesting
+  protected void addReassignmentShuffleServer(String shuffleId, 
ShuffleServerInfo replacement) {
+    replacementShuffleServers.put(shuffleId, replacement);
+  }
+
+  @VisibleForTesting
+  protected void setTaskId(String taskId) {
+    this.taskId = taskId;
+  }
+
+  @VisibleForTesting
+  protected Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
getServerToPartitionToBlockIds() {
+    return serverToPartitionToBlockIds;
+  }
 }
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index b68d4b74e..5ca85eced 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -26,6 +26,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -64,6 +65,7 @@ import org.apache.uniffle.storage.util.StorageType;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
 import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
@@ -73,6 +75,198 @@ import static org.mockito.Mockito.when;
 
 public class RssShuffleWriterTest {
 
+  private MutableList<Product2<String, String>> createMockRecords() {
+    MutableList<Product2<String, String>> data = new MutableList<>();
+    data.appendElem(new Tuple2<>("testKey2", "testValue2"));
+    data.appendElem(new Tuple2<>("testKey3", "testValue3"));
+    data.appendElem(new Tuple2<>("testKey4", "testValue4"));
+    data.appendElem(new Tuple2<>("testKey6", "testValue6"));
+    data.appendElem(new Tuple2<>("testKey1", "testValue1"));
+    data.appendElem(new Tuple2<>("testKey5", "testValue5"));
+    return data;
+  }
+
+  @Test
+  public void blockFailureResendTest() throws Exception {
+    SparkConf conf = new SparkConf();
+    conf.setAppName("testApp")
+        .setMaster("local[2]")
+        .set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
+        .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64")
+        .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128")
+        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.LOCALFILE.name());
+
+    List<ShuffleBlockInfo> shuffleBlockInfos = Lists.newArrayList();
+    Map<String, Set<Long>> successBlockIds = JavaUtils.newConcurrentMap();
+    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = 
JavaUtils.newConcurrentMap();
+    taskToFailedBlockSendTracker.put("taskId", new FailedBlockSendTracker());
+
+    AtomicInteger sentFailureCnt = new AtomicInteger();
+    FakedDataPusher dataPusher =
+        new FakedDataPusher(
+            event -> {
+              assertEquals("taskId", event.getTaskId());
+              FailedBlockSendTracker tracker = 
taskToFailedBlockSendTracker.get(event.getTaskId());
+              for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+                boolean isSuccessful = true;
+                ShuffleServerInfo shuffleServer = 
block.getShuffleServerInfos().get(0);
+                if (shuffleServer.getId().equals("id1") && block.getRetryCnt() 
== 0) {
+                  tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
+                  sentFailureCnt.addAndGet(1);
+                  isSuccessful = false;
+                } else {
+                  successBlockIds.putIfAbsent(event.getTaskId(), 
Sets.newConcurrentHashSet());
+                  
successBlockIds.get(event.getTaskId()).add(block.getBlockId());
+                  shuffleBlockInfos.add(block);
+                }
+                block.executeCompletionCallback(isSuccessful);
+              }
+              return new CompletableFuture<>();
+            });
+
+    final RssShuffleManager manager =
+        TestUtils.createShuffleManager(
+            conf, false, dataPusher, successBlockIds, 
taskToFailedBlockSendTracker);
+    Serializer kryoSerializer = new KryoSerializer(conf);
+    Partitioner mockPartitioner = mock(Partitioner.class);
+    final ShuffleWriteClient mockShuffleWriteClient = 
mock(ShuffleWriteClient.class);
+    ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
+    RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
+    when(mockHandle.getDependency()).thenReturn(mockDependency);
+    when(mockDependency.serializer()).thenReturn(kryoSerializer);
+    when(mockDependency.partitioner()).thenReturn(mockPartitioner);
+    when(mockPartitioner.numPartitions()).thenReturn(3);
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = 
Maps.newHashMap();
+    List<ShuffleServerInfo> ssi12 =
+        Arrays.asList(
+            new ShuffleServerInfo("id1", "0.0.0.1", 100),
+            new ShuffleServerInfo("id2", "0.0.0.2", 100));
+    partitionToServers.put(0, ssi12);
+    List<ShuffleServerInfo> ssi34 =
+        Arrays.asList(
+            new ShuffleServerInfo("id3", "0.0.0.3", 100),
+            new ShuffleServerInfo("id4", "0.0.0.4", 100));
+    partitionToServers.put(1, ssi34);
+    List<ShuffleServerInfo> ssi56 =
+        Arrays.asList(
+            new ShuffleServerInfo("id5", "0.0.0.5", 100),
+            new ShuffleServerInfo("id6", "0.0.0.6", 100));
+    partitionToServers.put(2, ssi56);
+    when(mockPartitioner.getPartition("testKey1")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey2")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey4")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey5")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey3")).thenReturn(2);
+    when(mockPartitioner.getPartition("testKey7")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey8")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey9")).thenReturn(2);
+    when(mockPartitioner.getPartition("testKey6")).thenReturn(2);
+
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
+    WriteBufferManager bufferManager =
+        new WriteBufferManager(
+            0,
+            0,
+            bufferOptions,
+            kryoSerializer,
+            partitionToServers,
+            mockTaskMemoryManager,
+            shuffleWriteMetrics,
+            RssSparkConfig.toRssConf(conf));
+    bufferManager.setTaskId("taskId");
+
+    WriteBufferManager bufferManagerSpy = spy(bufferManager);
+    TaskContext contextMock = mock(TaskContext.class);
+    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    RssShuffleWriter<String, String, String> rssShuffleWriter =
+        new RssShuffleWriter<>(
+            "appId",
+            0,
+            "taskId",
+            1L,
+            bufferManagerSpy,
+            shuffleWriteMetrics,
+            manager,
+            conf,
+            mockShuffleWriteClient,
+            mockHandle,
+            mockShuffleHandleInfo,
+            contextMock);
+    rssShuffleWriter.enableBlockFailSentRetry();
+    doReturn(100000L).when(bufferManagerSpy).acquireMemory(anyLong());
+    ShuffleServerInfo replacement = new ShuffleServerInfo("id10", "0.0.0.10", 
100);
+    rssShuffleWriter.addReassignmentShuffleServer("id1", replacement);
+
+    RssShuffleWriter<String, String, String> rssShuffleWriterSpy = 
spy(rssShuffleWriter);
+    doNothing().when(rssShuffleWriterSpy).sendCommit();
+
+    // case 1. failed blocks will be resent
+    MutableList<Product2<String, String>> data = createMockRecords();
+    rssShuffleWriterSpy.write(data.iterator());
+
+    Awaitility.await()
+        .timeout(Duration.ofSeconds(5))
+        .until(() -> successBlockIds.get("taskId").size() == data.size());
+    assertEquals(2, sentFailureCnt.get());
+    assertEquals(0, 
taskToFailedBlockSendTracker.get("taskId").getFailedBlockIds().size());
+    assertEquals(6, shuffleWriteMetrics.recordsWritten());
+    assertEquals(
+        shuffleBlockInfos.stream().mapToInt(ShuffleBlockInfo::getLength).sum(),
+        shuffleWriteMetrics.bytesWritten());
+    assertEquals(6, shuffleBlockInfos.size());
+
+    assertEquals(0, bufferManagerSpy.getUsedBytes());
+    assertEquals(0, bufferManagerSpy.getInSendListBytes());
+
+    // check the blockId -> servers mapping.
+    // server -> partitionId -> blockIds
+    Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
serverToPartitionToBlockIds =
+        rssShuffleWriterSpy.getServerToPartitionToBlockIds();
+    assertEquals(2, 
serverToPartitionToBlockIds.get(replacement).get(0).size());
+
+    // case2. If exceeding the max retry times, it will fast fail.
+    rssShuffleWriterSpy.setBlockFailSentRetryMaxTimes(1);
+    rssShuffleWriterSpy.setTaskId("taskId2");
+    FakedDataPusher alwaysFailedDataPusher =
+        new FakedDataPusher(
+            event -> {
+              assertEquals("taskId2", event.getTaskId());
+              FailedBlockSendTracker tracker = 
taskToFailedBlockSendTracker.get(event.getTaskId());
+              for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+                boolean isSuccessful = true;
+                ShuffleServerInfo shuffleServer = 
block.getShuffleServerInfos().get(0);
+                if (shuffleServer.getId().equals("id1")) {
+                  tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
+                  isSuccessful = false;
+                } else {
+                  successBlockIds.putIfAbsent(event.getTaskId(), 
Sets.newConcurrentHashSet());
+                  
successBlockIds.get(event.getTaskId()).add(block.getBlockId());
+                }
+                block.executeCompletionCallback(isSuccessful);
+              }
+              return new CompletableFuture<>();
+            });
+    manager.setDataPusher(alwaysFailedDataPusher);
+
+    MutableList<Product2<String, String>> mockedData = createMockRecords();
+    try {
+      rssShuffleWriterSpy.write(mockedData.iterator());
+      fail();
+    } catch (Exception e) {
+      // ignore
+    }
+    assertEquals(0, bufferManagerSpy.getUsedBytes());
+    assertEquals(0, bufferManagerSpy.getInSendListBytes());
+  }
+
   @Test
   public void checkBlockSendResultTest() {
     SparkConf conf = new SparkConf();
@@ -161,8 +355,7 @@ public class RssShuffleWriterTest {
         assertThrows(
             RuntimeException.class,
             () -> rssShuffleWriter.checkBlockSendResult(Sets.newHashSet(1L, 
2L, 3L)));
-    System.out.println(e2.getMessage());
-    assertTrue(e3.getMessage().startsWith("Send failed:"));
+    assertTrue(e3.getMessage().startsWith("Fail to send the block"));
     successBlocks.clear();
     taskToFailedBlockSendTracker.clear();
   }
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
index 0c239c7e1..93e20dd02 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
@@ -32,6 +32,12 @@ import org.apache.uniffle.common.rpc.StatusCode;
 
 public class FailedBlockSendTracker {
 
  /**
   * blockId -> list(trackingStatus)
   *
   * <p>Holds only the latest sending status of each block; the resend history is not kept. The
   * value is a list because, when multiple replicas are configured, a block is sent to several
   * servers and gets one status per server.
   */
  private Map<Long, List<TrackingBlockStatus>> trackingBlockStatusMap;
 
   public FailedBlockSendTracker() {
@@ -55,7 +61,10 @@ public class FailedBlockSendTracker {
     trackingBlockStatusMap.remove(blockId);
   }
 
-  public void clear() {
+  public void clearAndReleaseBlockResources() {
+    trackingBlockStatusMap.values().stream()
+        .flatMap(x -> x.stream())
+        .forEach(x -> x.getShuffleBlockInfo().executeCompletionCallback(true));
     trackingBlockStatusMap.clear();
   }
 
diff --git 
a/common/src/main/java/org/apache/uniffle/common/BlockCompletionCallback.java 
b/common/src/main/java/org/apache/uniffle/common/BlockCompletionCallback.java
new file mode 100644
index 000000000..01ba694c3
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/BlockCompletionCallback.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+public interface BlockCompletionCallback {
+  void onBlockCompletion(ShuffleBlockInfo block, boolean isSuccessful);
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
index 8de75d90d..36dec5e25 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
@@ -36,6 +36,9 @@ public class ShuffleBlockInfo {
   private List<ShuffleServerInfo> shuffleServerInfos;
   private int uncompressLength;
   private long freeMemory;
+  private int retryCnt = 0;
+
+  private transient BlockCompletionCallback completionCallback;
 
   public ShuffleBlockInfo(
       int shuffleId,
@@ -153,7 +156,30 @@ public class ShuffleBlockInfo {
     return sb.toString();
   }
 
+  public void incrRetryCnt() {
+    this.retryCnt += 1;
+  }
+
+  public int getRetryCnt() {
+    return retryCnt;
+  }
+
+  public void reassignShuffleServers(List<ShuffleServerInfo> replacements) {
+    this.shuffleServerInfos = replacements;
+  }
+
   public synchronized void copyDataTo(ByteBuf to) {
     ByteBufUtils.copyByteBuf(data, to);
   }
+
+  public void withCompletionCallback(BlockCompletionCallback callback) {
+    this.completionCallback = callback;
+  }
+
+  public void executeCompletionCallback(boolean isSuccessful) {
+    if (completionCallback == null) {
+      return;
+    }
+    completionCallback.onBlockCompletion(this, isSuccessful);
+  }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java 
b/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java
new file mode 100644
index 000000000..2a4638702
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.function;
+
/**
 * An operation that accepts two arguments and returns no result.
 *
 * <p>This is a specialization of {@link java.util.function.BiConsumer}. Instances can therefore be
 * passed wherever a {@code BiConsumer} is expected and composed with {@code andThen}. Existing
 * lambdas and implementations of {@link #accept} keep working unchanged.
 *
 * @param <T> type of the first argument
 * @param <F> type of the second argument
 */
@FunctionalInterface
public interface TupleConsumer<T, F> extends java.util.function.BiConsumer<T, F> {
  /**
   * Performs this operation on the given arguments.
   *
   * @param t the first argument
   * @param f the second argument
   */
  @Override
  void accept(T t, F f);
}


Reply via email to