This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c0c27dea [#2725] fix(spark)(partition-split): Add fallback under 
load-balance mode and fix stale assignment missing callback that caused timeout 
(#2729)
9c0c27dea is described below

commit 9c0c27dead125a50f4140b342c470e0ed13ec8f6
Author: Junfan Zhang <[email protected]>
AuthorDate: Thu Feb 12 15:33:26 2026 +0800

    [#2725] fix(spark)(partition-split): Add fallback under load-balance mode 
and fix stale assignment missing callback that caused timeout (#2729)
    
    ### What changes were proposed in this pull request?
    
    1. Fall back to a random server when no servers are available in 
load-balance mode
    2. Fix the missing completion callback for stale-assignment blocks in the 
data pusher, which caused the writer to hang until the timeout and prevented 
reassign from being triggered
    
    ### Why are the changes needed?
    
    Fixes #2725. This tricky bug was finally tracked down and fixed after a 
thorough investigation.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests
---
 .../shuffle/handle/MutableShuffleHandleInfo.java   |  9 +++
 .../apache/spark/shuffle/writer/DataPusher.java    | 36 ++++++------
 .../shuffle/writer/TaskAttemptAssignment.java      |  9 +--
 .../apache/uniffle/shuffle/ReassignExecutor.java   | 31 +++++++----
 .../handle/MutableShuffleHandleInfoTest.java       | 42 ++++++++++++++
 .../spark/shuffle/writer/DataPusherTest.java       | 65 ++++++++++++++++++++++
 .../spark/shuffle/writer/RssShuffleWriter.java     | 35 +++++++-----
 .../spark/shuffle/writer/RssShuffleWriterTest.java |  2 +-
 8 files changed, 184 insertions(+), 45 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
index 4872dc171..987497cef 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
@@ -259,6 +259,15 @@ public class MutableShuffleHandleInfo extends 
ShuffleHandleInfoBase {
             // 0, 1, 2
             int idx = (int) (taskAttemptId % (serverSize - 1)) + 1;
             candidate = servers.get(idx);
+          } else {
+            // fallback to random server if no available servers in 
load-balanced mode
+            servers =
+                replicaServerEntry.getValue().stream()
+                    .filter(x -> 
!excludedServerToReplacements.containsKey(x.getId()))
+                    .collect(Collectors.toList());
+            serverSize = servers.size();
+            int idx = (int) (taskAttemptId % (serverSize - 1)) + 1;
+            candidate = servers.get(idx);
           }
         }
 
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index df13e0f39..963000e8c 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -124,27 +124,26 @@ public class DataPusher implements Closeable {
               String taskId = event.getTaskId();
               List<ShuffleBlockInfo> blocks = event.getShuffleDataInfoList();
               List<ShuffleBlockInfo> validBlocks = 
filterOutStaleAssignmentBlocks(taskId, blocks);
-              if (CollectionUtils.isEmpty(validBlocks)) {
-                return 0L;
-              }
 
               SendShuffleDataResult result = null;
               try {
-                result =
-                    shuffleWriteClient.sendShuffleData(
-                        rssAppId,
-                        event.getStageAttemptNumber(),
-                        validBlocks,
-                        () -> !isValidTask(taskId));
-                // completionCallback should be executed before updating 
taskToSuccessBlockIds
-                // structure to avoid side effect
-                Set<Long> succeedBlockIds = getSucceedBlockIds(result);
-                for (ShuffleBlockInfo block : validBlocks) {
-                  
block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
+                if (CollectionUtils.isNotEmpty(validBlocks)) {
+                  result =
+                      shuffleWriteClient.sendShuffleData(
+                          rssAppId,
+                          event.getStageAttemptNumber(),
+                          validBlocks,
+                          () -> !isValidTask(taskId));
+                  // completionCallback should be executed before updating 
taskToSuccessBlockIds
+                  // structure to avoid side effect
+                  Set<Long> succeedBlockIds = getSucceedBlockIds(result);
+                  for (ShuffleBlockInfo block : validBlocks) {
+                    
block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
+                  }
+                  putBlockId(taskToSuccessBlockIds, taskId, 
result.getSuccessBlockIds());
+                  putFailedBlockSendTracker(
+                      taskToFailedBlockSendTracker, taskId, 
result.getFailedBlockSendTracker());
                 }
-                putBlockId(taskToSuccessBlockIds, taskId, 
result.getSuccessBlockIds());
-                putFailedBlockSendTracker(
-                    taskToFailedBlockSendTracker, taskId, 
result.getFailedBlockSendTracker());
               } finally {
                 WriteBufferManager bufferManager = event.getBufferManager();
                 if (bufferManager != null && result != null) {
@@ -159,6 +158,9 @@ public class DataPusher implements Closeable {
                   runnable.run();
                 }
               }
+              if (CollectionUtils.isEmpty(validBlocks)) {
+                return 0L;
+              }
               Set<Long> succeedBlockIds = getSucceedBlockIds(result);
               return validBlocks.stream()
                   .filter(x -> succeedBlockIds.contains(x.getBlockId()))
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
index 63fac0c12..2faac350b 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
@@ -91,10 +91,11 @@ public class TaskAttemptAssignment {
   public boolean tryNextServerForSplitPartition(
       int partitionId, List<ShuffleServerInfo> exclusiveServers) {
     if (hasBeenLoadBalanced(partitionId)) {
-      Set<ShuffleServerInfo> servers =
-          this.exclusiveServersForPartition.computeIfAbsent(
-              partitionId, k -> new ConcurrentSkipListSet<>());
-      servers.addAll(exclusiveServers);
+      // update the exclusive servers
+      this.exclusiveServersForPartition
+          .computeIfAbsent(partitionId, k -> new ConcurrentSkipListSet<>())
+          .addAll(exclusiveServers);
+      // update the assignment due to the upper exclusive servers change
       update(this.handle);
       return true;
     }
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
index c380e3ee7..57bd25641 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
@@ -33,6 +33,7 @@ import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
@@ -306,19 +307,22 @@ public class ReassignExecutor {
       }
     }
 
+    String readableMessage = readableResult(fastSwitchList);
     if (reassignList.isEmpty()) {
-      LOG.info(
-          "[partition-split] All fast switch to another servers successfully 
for taskId[{}]. list: {}",
-          taskId,
-          readableResult(fastSwitchList));
-      return;
-    } else {
-      if (!fastSwitchList.isEmpty()) {
+      if (StringUtils.isNotEmpty(readableMessage)) {
         LOG.info(
-            "[partition-split] Partial fast switch to another servers for 
taskId[{}]. list: {}",
+            "[partition-split] All partitions fast-switched successfully for 
taskId[{}]. list: {}",
             taskId,
-            readableResult(fastSwitchList));
+            readableMessage);
       }
+      return;
+    }
+
+    if (StringUtils.isNotEmpty(readableMessage)) {
+      LOG.info(
+          "[partition-split] Partial partitions fast-switched for taskId[{}]. 
list: {}",
+          taskId,
+          readableMessage);
     }
 
     @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance")
@@ -385,6 +389,7 @@ public class ReassignExecutor {
     List<ShuffleBlockInfo> resendCandidates = Lists.newArrayList();
     Map<Integer, List<TrackingBlockStatus>> partitionedFailedBlocks =
         blocks.stream()
+            .filter(x -> x.getStatusCode() != null)
             .collect(Collectors.groupingBy(d -> 
d.getShuffleBlockInfo().getPartitionId()));
 
     Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new 
HashMap<>();
@@ -429,8 +434,12 @@ public class ReassignExecutor {
           readableResult(constructUpdateList(failurePartitionToServers)));
     }
 
+    int staleCnt = 0;
     for (TrackingBlockStatus blockStatus : blocks) {
       ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
+      if (blockStatus.getStatusCode() == null) {
+        staleCnt += 1;
+      }
       // todo: getting the replacement should support multi replica.
       List<ShuffleServerInfo> servers = 
taskAttemptAssignment.retrieve(block.getPartitionId());
       // Gets the first replica for this partition for now.
@@ -459,8 +468,10 @@ public class ReassignExecutor {
     }
     resendBlocksFunction.accept(resendCandidates);
     LOG.info(
-        "[partition-reassign] All {} blocks have been resent to queue 
successfully in {} ms.",
+        "[partition-reassign] {} blocks (failed/stale: {}/{}) have been resent 
to queue successfully in {} ms.",
         blocks.size(),
+        blocks.size() - staleCnt,
+        staleCnt,
         System.currentTimeMillis() - start);
   }
 
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
index cf75152e7..63dc247d3 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
@@ -276,4 +276,46 @@ public class MutableShuffleHandleInfoTest {
     // All the servers were selected as writer are available as reader
     assertEquals(6, assignment.get(1).size());
   }
+
+  @Test
+  public void testLoadBalanceFallbackToNonExcludedServers() {
+    // prepare servers
+    ShuffleServerInfo a = createFakeServerInfo("a");
+    ShuffleServerInfo b = createFakeServerInfo("b");
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(1, Arrays.asList(a, b));
+
+    // create handle with LOAD_BALANCE mode
+    MutableShuffleHandleInfo handleInfo =
+        new MutableShuffleHandleInfo(
+            1,
+            partitionToServers,
+            new RemoteStorageInfo(""),
+            org.apache.uniffle.common.PartitionSplitMode.LOAD_BALANCE);
+
+    int partitionId = 1;
+
+    // mark partition as split by excluding server "a"
+    Set<ShuffleServerInfo> replacements = 
Sets.newHashSet(createFakeServerInfo("c"));
+    handleInfo.updateAssignmentOnPartitionSplit(partitionId, "a", 
replacements);
+
+    // also make sure excludedServerToReplacements contains "b"
+    // so that first filtering (exclude problem nodes) removes all servers
+    handleInfo.updateAssignment(partitionId, "b", 
Sets.newHashSet(createFakeServerInfo("d")));
+
+    // now call writer assignment
+    Map<Integer, List<ShuffleServerInfo>> available =
+        handleInfo.getAvailablePartitionServersForWriter(null);
+
+    // fallback branch should be triggered and still return a valid candidate
+    // ensure we have exactly one candidate for replica 0
+    assertTrue(available.containsKey(partitionId));
+    assertEquals(2, available.get(partitionId).size());
+
+    // candidate must be one of the original servers or appended replacements, 
rather than always
+    // the last one
+    ShuffleServerInfo candidate = available.get(partitionId).get(0);
+    assertEquals("c", candidate.getId());
+  }
 }
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index eb357d9da..720bad4e3 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -25,6 +25,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Supplier;
 
 import com.google.common.collect.Maps;
@@ -130,6 +131,70 @@ public class DataPusherTest {
     assertEquals(3, failedBlockIds.stream().findFirst().get());
   }
 
+  /**
+   * Test that when all blocks in a batch are stale (filtered out by 
fast-switch), the
+   * processedCallbackChain is still executed. Before the fix, if all blocks 
were stale, the early
+   * return skipped the finally block, causing the callback (which notifies 
checkBlockSendResult via
+   * finishEventQueue) to never run. This led to checkBlockSendResult blocking 
indefinitely on
+   * poll(), unable to call reassign() to resend the stale blocks, ultimately 
timing out.
+   */
+  @Test
+  public void testProcessedCallbackChainExecutedWhenAllBlocksAreStale()
+      throws ExecutionException, InterruptedException {
+    FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
+
+    Map<String, Set<Long>> taskToSuccessBlockIds = Maps.newConcurrentMap();
+    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = 
JavaUtils.newConcurrentMap();
+    Set<String> failedTaskIds = new HashSet<>();
+
+    RssConf rssConf = new RssConf();
+    rssConf.set(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED, true);
+    
rssConf.set(RssSparkConfig.RSS_PARTITION_REASSIGN_STALE_ASSIGNMENT_FAST_SWITCH_ENABLED,
 true);
+    DataPusher dataPusher =
+        new DataPusher(
+            shuffleWriteClient,
+            taskToSuccessBlockIds,
+            taskToFailedBlockSendTracker,
+            failedTaskIds,
+            1,
+            2,
+            rssConf);
+    dataPusher.setRssAppId("testCallbackWhenAllStale");
+
+    String taskId = "taskId1";
+    List<ShuffleServerInfo> server1 =
+        Collections.singletonList(new ShuffleServerInfo("0", "localhost", 
1234));
+    // Create a stale block: isStaleAssignment() returns true because the
+    // partitionAssignmentRetrieveFunc returns an empty list (different from 
the block's servers).
+    ShuffleBlockInfo staleBlock =
+        new ShuffleBlockInfo(
+            1, 1, 10, 1, 1, new byte[1], server1, 1, 100, 1, integer -> 
Collections.emptyList());
+
+    // Track whether processedCallbackChain is invoked
+    AtomicBoolean callbackExecuted = new AtomicBoolean(false);
+    AddBlockEvent event = new AddBlockEvent(taskId, Arrays.asList(staleBlock));
+    event.addCallback(() -> callbackExecuted.set(true));
+
+    CompletableFuture<Long> future = dataPusher.send(event);
+    long result = future.get();
+
+    // The block is stale, so no data is actually sent (0 bytes freed)
+    assertEquals(0L, result);
+
+    // The stale block should be tracked in the FailedBlockSendTracker
+    Set<Long> failedBlockIds = 
taskToFailedBlockSendTracker.get(taskId).getFailedBlockIds();
+    assertEquals(1, failedBlockIds.size());
+    assertEquals(10, failedBlockIds.stream().findFirst().get());
+
+    // The processedCallbackChain MUST be executed even when all blocks are 
stale.
+    // Before the fix, this assertion would fail because the early return 
(return 0L)
+    // was placed before the try-finally that executes the callback chain.
+    assertTrue(
+        callbackExecuted.get(),
+        "processedCallbackChain must be executed even when all blocks are 
stale, "
+            + "otherwise checkBlockSendResult will block on 
finishEventQueue.poll() indefinitely");
+  }
+
   @Test
   public void testSendData() throws ExecutionException, InterruptedException {
     FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 60f3ef46f..03d92e8a3 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -81,6 +81,7 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.shuffle.BlockStats;
 import org.apache.uniffle.shuffle.ReassignExecutor;
 import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
@@ -518,8 +519,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       }
       event.addCallback(
           () -> {
-            boolean ret = finishEventQueue.add(new Object());
-            if (!ret) {
+            if (!finishEventQueue.add(new Object())) {
               LOG.error("Add event " + event + " to finishEventQueue fail");
             }
           });
@@ -572,17 +572,26 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       }
       Set<Long> successBlockIds = shuffleManager.getSuccessBlockIds(taskId);
       if (currentAckValue != 0 || blockIds.size() != successBlockIds.size()) {
-        int failedBlockCount = blockIds.size() - successBlockIds.size();
-        String errorMsg =
-            "Timeout: Task["
-                + taskId
-                + "] failed because "
-                + failedBlockCount
-                + " blocks can't be sent to shuffle server in "
-                + sendCheckTimeout
-                + " ms.";
-        LOG.error(errorMsg);
-        throw new RssWaitFailedException(errorMsg);
+        int missing = blockIds.size() - successBlockIds.size();
+        int failed =
+            
Optional.ofNullable(shuffleManager.getFailedBlockIds(taskId)).map(Set::size).orElse(0);
+        String message =
+            String.format(
+                "TaskId[%s] failed because %d blocks (failed: %d}) can't be 
sent to shuffle server in %d ms",
+                taskId, missing, failed, sendCheckTimeout);
+
+        // detailed error message
+        Set<Long> missingBlockIds = new HashSet<>(blockIds);
+        missingBlockIds.removeAll(successBlockIds);
+        BlockIdLayout layout = BlockIdLayout.from(rssConf);
+        LOG.error(
+            "{}, includes partitions: {}",
+            message,
+            missingBlockIds.stream()
+                .map(x -> layout.getPartitionId(x))
+                .collect(Collectors.toSet()));
+
+        throw new RssWaitFailedException(message);
       }
     } finally {
       if (interrupted) {
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 2e16f454c..a054b9b2a 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -611,7 +611,7 @@ public class RssShuffleWriterTest {
         assertThrows(
             RuntimeException.class,
             () -> rssShuffleWriter.checkBlockSendResult(Sets.newHashSet(1L, 
2L, 3L)));
-    assertTrue(e2.getMessage().startsWith("Timeout:"));
+    assertTrue(e2.getMessage().contains("failed because"));
     successBlocks.clear();
 
     // case 3: partial blocks are sent failed, Runtime exception will be thrown

Reply via email to