This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 9c0c27dea [#2725] fix(spark)(partition-split): Add fallback under
load-balance mode and fix stale assignment missing callback that caused timeout
(#2729)
9c0c27dea is described below
commit 9c0c27dead125a50f4140b342c470e0ed13ec8f6
Author: Junfan Zhang <[email protected]>
AuthorDate: Thu Feb 12 15:33:26 2026 +0800
[#2725] fix(spark)(partition-split): Add fallback under load-balance mode
and fix stale assignment missing callback that caused timeout (#2729)
### What changes were proposed in this pull request?
1. Fallback to random server when no servers are available in load-balance
mode
2. Fix stale assignment missing callback in data pusher that caused the
writer to hang until timeout, preventing reassign from being triggered
### Why are the changes needed?
Fixes #2725. Finally tracked down and fixed this tricky bug after a
thorough investigation.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests
---
.../shuffle/handle/MutableShuffleHandleInfo.java | 9 +++
.../apache/spark/shuffle/writer/DataPusher.java | 36 ++++++------
.../shuffle/writer/TaskAttemptAssignment.java | 9 +--
.../apache/uniffle/shuffle/ReassignExecutor.java | 31 +++++++----
.../handle/MutableShuffleHandleInfoTest.java | 42 ++++++++++++++
.../spark/shuffle/writer/DataPusherTest.java | 65 ++++++++++++++++++++++
.../spark/shuffle/writer/RssShuffleWriter.java | 35 +++++++-----
.../spark/shuffle/writer/RssShuffleWriterTest.java | 2 +-
8 files changed, 184 insertions(+), 45 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
index 4872dc171..987497cef 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
@@ -259,6 +259,15 @@ public class MutableShuffleHandleInfo extends
ShuffleHandleInfoBase {
// 0, 1, 2
int idx = (int) (taskAttemptId % (serverSize - 1)) + 1;
candidate = servers.get(idx);
+ } else {
+ // fallback to random server if no available servers in
load-balanced mode
+ servers =
+ replicaServerEntry.getValue().stream()
+ .filter(x ->
!excludedServerToReplacements.containsKey(x.getId()))
+ .collect(Collectors.toList());
+ serverSize = servers.size();
+ int idx = (int) (taskAttemptId % (serverSize - 1)) + 1;
+ candidate = servers.get(idx);
}
}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index df13e0f39..963000e8c 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -124,27 +124,26 @@ public class DataPusher implements Closeable {
String taskId = event.getTaskId();
List<ShuffleBlockInfo> blocks = event.getShuffleDataInfoList();
List<ShuffleBlockInfo> validBlocks =
filterOutStaleAssignmentBlocks(taskId, blocks);
- if (CollectionUtils.isEmpty(validBlocks)) {
- return 0L;
- }
SendShuffleDataResult result = null;
try {
- result =
- shuffleWriteClient.sendShuffleData(
- rssAppId,
- event.getStageAttemptNumber(),
- validBlocks,
- () -> !isValidTask(taskId));
- // completionCallback should be executed before updating
taskToSuccessBlockIds
- // structure to avoid side effect
- Set<Long> succeedBlockIds = getSucceedBlockIds(result);
- for (ShuffleBlockInfo block : validBlocks) {
-
block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
+ if (CollectionUtils.isNotEmpty(validBlocks)) {
+ result =
+ shuffleWriteClient.sendShuffleData(
+ rssAppId,
+ event.getStageAttemptNumber(),
+ validBlocks,
+ () -> !isValidTask(taskId));
+ // completionCallback should be executed before updating
taskToSuccessBlockIds
+ // structure to avoid side effect
+ Set<Long> succeedBlockIds = getSucceedBlockIds(result);
+ for (ShuffleBlockInfo block : validBlocks) {
+
block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
+ }
+ putBlockId(taskToSuccessBlockIds, taskId,
result.getSuccessBlockIds());
+ putFailedBlockSendTracker(
+ taskToFailedBlockSendTracker, taskId,
result.getFailedBlockSendTracker());
}
- putBlockId(taskToSuccessBlockIds, taskId,
result.getSuccessBlockIds());
- putFailedBlockSendTracker(
- taskToFailedBlockSendTracker, taskId,
result.getFailedBlockSendTracker());
} finally {
WriteBufferManager bufferManager = event.getBufferManager();
if (bufferManager != null && result != null) {
@@ -159,6 +158,9 @@ public class DataPusher implements Closeable {
runnable.run();
}
}
+ if (CollectionUtils.isEmpty(validBlocks)) {
+ return 0L;
+ }
Set<Long> succeedBlockIds = getSucceedBlockIds(result);
return validBlocks.stream()
.filter(x -> succeedBlockIds.contains(x.getBlockId()))
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
index 63fac0c12..2faac350b 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
@@ -91,10 +91,11 @@ public class TaskAttemptAssignment {
public boolean tryNextServerForSplitPartition(
int partitionId, List<ShuffleServerInfo> exclusiveServers) {
if (hasBeenLoadBalanced(partitionId)) {
- Set<ShuffleServerInfo> servers =
- this.exclusiveServersForPartition.computeIfAbsent(
- partitionId, k -> new ConcurrentSkipListSet<>());
- servers.addAll(exclusiveServers);
+ // update the exclusive servers
+ this.exclusiveServersForPartition
+ .computeIfAbsent(partitionId, k -> new ConcurrentSkipListSet<>())
+ .addAll(exclusiveServers);
+ // update the assignment due to the above change of exclusive servers
update(this.handle);
return true;
}
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
index c380e3ee7..57bd25641 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
@@ -33,6 +33,7 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
@@ -306,19 +307,22 @@ public class ReassignExecutor {
}
}
+ String readableMessage = readableResult(fastSwitchList);
if (reassignList.isEmpty()) {
- LOG.info(
- "[partition-split] All fast switch to another servers successfully
for taskId[{}]. list: {}",
- taskId,
- readableResult(fastSwitchList));
- return;
- } else {
- if (!fastSwitchList.isEmpty()) {
+ if (StringUtils.isNotEmpty(readableMessage)) {
LOG.info(
- "[partition-split] Partial fast switch to another servers for
taskId[{}]. list: {}",
+ "[partition-split] All partitions fast-switched successfully for
taskId[{}]. list: {}",
taskId,
- readableResult(fastSwitchList));
+ readableMessage);
}
+ return;
+ }
+
+ if (StringUtils.isNotEmpty(readableMessage)) {
+ LOG.info(
+ "[partition-split] Partial partitions fast-switched for taskId[{}].
list: {}",
+ taskId,
+ readableMessage);
}
@SuppressWarnings("checkstyle:VariableDeclarationUsageDistance")
@@ -385,6 +389,7 @@ public class ReassignExecutor {
List<ShuffleBlockInfo> resendCandidates = Lists.newArrayList();
Map<Integer, List<TrackingBlockStatus>> partitionedFailedBlocks =
blocks.stream()
+ .filter(x -> x.getStatusCode() != null)
.collect(Collectors.groupingBy(d ->
d.getShuffleBlockInfo().getPartitionId()));
Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new
HashMap<>();
@@ -429,8 +434,12 @@ public class ReassignExecutor {
readableResult(constructUpdateList(failurePartitionToServers)));
}
+ int staleCnt = 0;
for (TrackingBlockStatus blockStatus : blocks) {
ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
+ if (blockStatus.getStatusCode() == null) {
+ staleCnt += 1;
+ }
// todo: getting the replacement should support multi replica.
List<ShuffleServerInfo> servers =
taskAttemptAssignment.retrieve(block.getPartitionId());
// Gets the first replica for this partition for now.
@@ -459,8 +468,10 @@ public class ReassignExecutor {
}
resendBlocksFunction.accept(resendCandidates);
LOG.info(
- "[partition-reassign] All {} blocks have been resent to queue
successfully in {} ms.",
+ "[partition-reassign] {} blocks (failed/stale: {}/{}) have been resent
to queue successfully in {} ms.",
blocks.size(),
+ blocks.size() - staleCnt,
+ staleCnt,
System.currentTimeMillis() - start);
}
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
index cf75152e7..63dc247d3 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
@@ -276,4 +276,46 @@ public class MutableShuffleHandleInfoTest {
// All the servers were selected as writer are available as reader
assertEquals(6, assignment.get(1).size());
}
+
+ @Test
+ public void testLoadBalanceFallbackToNonExcludedServers() {
+ // prepare servers
+ ShuffleServerInfo a = createFakeServerInfo("a");
+ ShuffleServerInfo b = createFakeServerInfo("b");
+
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+ partitionToServers.put(1, Arrays.asList(a, b));
+
+ // create handle with LOAD_BALANCE mode
+ MutableShuffleHandleInfo handleInfo =
+ new MutableShuffleHandleInfo(
+ 1,
+ partitionToServers,
+ new RemoteStorageInfo(""),
+ org.apache.uniffle.common.PartitionSplitMode.LOAD_BALANCE);
+
+ int partitionId = 1;
+
+ // mark partition as split by excluding server "a"
+ Set<ShuffleServerInfo> replacements =
Sets.newHashSet(createFakeServerInfo("c"));
+ handleInfo.updateAssignmentOnPartitionSplit(partitionId, "a",
replacements);
+
+ // also make sure excludedServerToReplacements contains "b"
+ // so that first filtering (exclude problem nodes) removes all servers
+ handleInfo.updateAssignment(partitionId, "b",
Sets.newHashSet(createFakeServerInfo("d")));
+
+ // now call writer assignment
+ Map<Integer, List<ShuffleServerInfo>> available =
+ handleInfo.getAvailablePartitionServersForWriter(null);
+
+ // fallback branch should be triggered and still return a valid candidate
+ // ensure we have exactly one candidate for replica 0
+ assertTrue(available.containsKey(partitionId));
+ assertEquals(2, available.get(partitionId).size());
+
+ // candidate must be one of the original servers or appended replacements,
rather than always
+ // the last one
+ ShuffleServerInfo candidate = available.get(partitionId).get(0);
+ assertEquals("c", candidate.getId());
+ }
}
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index eb357d9da..720bad4e3 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -25,6 +25,7 @@ import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import com.google.common.collect.Maps;
@@ -130,6 +131,70 @@ public class DataPusherTest {
assertEquals(3, failedBlockIds.stream().findFirst().get());
}
+ /**
+ * Test that when all blocks in a batch are stale (filtered out by
fast-switch), the
+ * processedCallbackChain is still executed. Before the fix, if all blocks
were stale, the early
+ * return skipped the finally block, causing the callback (which notifies
checkBlockSendResult via
+ * finishEventQueue) to never run. This led to checkBlockSendResult blocking
indefinitely on
+ * poll(), unable to call reassign() to resend the stale blocks, ultimately
timing out.
+ */
+ @Test
+ public void testProcessedCallbackChainExecutedWhenAllBlocksAreStale()
+ throws ExecutionException, InterruptedException {
+ FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
+
+ Map<String, Set<Long>> taskToSuccessBlockIds = Maps.newConcurrentMap();
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker =
JavaUtils.newConcurrentMap();
+ Set<String> failedTaskIds = new HashSet<>();
+
+ RssConf rssConf = new RssConf();
+ rssConf.set(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED, true);
+
rssConf.set(RssSparkConfig.RSS_PARTITION_REASSIGN_STALE_ASSIGNMENT_FAST_SWITCH_ENABLED,
true);
+ DataPusher dataPusher =
+ new DataPusher(
+ shuffleWriteClient,
+ taskToSuccessBlockIds,
+ taskToFailedBlockSendTracker,
+ failedTaskIds,
+ 1,
+ 2,
+ rssConf);
+ dataPusher.setRssAppId("testCallbackWhenAllStale");
+
+ String taskId = "taskId1";
+ List<ShuffleServerInfo> server1 =
+ Collections.singletonList(new ShuffleServerInfo("0", "localhost",
1234));
+ // Create a stale block: isStaleAssignment() returns true because the
+ // partitionAssignmentRetrieveFunc returns an empty list (different from
the block's servers).
+ ShuffleBlockInfo staleBlock =
+ new ShuffleBlockInfo(
+ 1, 1, 10, 1, 1, new byte[1], server1, 1, 100, 1, integer ->
Collections.emptyList());
+
+ // Track whether processedCallbackChain is invoked
+ AtomicBoolean callbackExecuted = new AtomicBoolean(false);
+ AddBlockEvent event = new AddBlockEvent(taskId, Arrays.asList(staleBlock));
+ event.addCallback(() -> callbackExecuted.set(true));
+
+ CompletableFuture<Long> future = dataPusher.send(event);
+ long result = future.get();
+
+ // The block is stale, so no data is actually sent (0 bytes freed)
+ assertEquals(0L, result);
+
+ // The stale block should be tracked in the FailedBlockSendTracker
+ Set<Long> failedBlockIds =
taskToFailedBlockSendTracker.get(taskId).getFailedBlockIds();
+ assertEquals(1, failedBlockIds.size());
+ assertEquals(10, failedBlockIds.stream().findFirst().get());
+
+ // The processedCallbackChain MUST be executed even when all blocks are
stale.
+ // Before the fix, this assertion would fail because the early return
(return 0L)
+ // was placed before the try-finally that executes the callback chain.
+ assertTrue(
+ callbackExecuted.get(),
+ "processedCallbackChain must be executed even when all blocks are
stale, "
+ + "otherwise checkBlockSendResult will block on
finishEventQueue.poll() indefinitely");
+ }
+
@Test
public void testSendData() throws ExecutionException, InterruptedException {
FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 60f3ef46f..03d92e8a3 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -81,6 +81,7 @@ import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.shuffle.BlockStats;
import org.apache.uniffle.shuffle.ReassignExecutor;
import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
@@ -518,8 +519,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
event.addCallback(
() -> {
- boolean ret = finishEventQueue.add(new Object());
- if (!ret) {
+ if (!finishEventQueue.add(new Object())) {
LOG.error("Add event " + event + " to finishEventQueue fail");
}
});
@@ -572,17 +572,26 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
Set<Long> successBlockIds = shuffleManager.getSuccessBlockIds(taskId);
if (currentAckValue != 0 || blockIds.size() != successBlockIds.size()) {
- int failedBlockCount = blockIds.size() - successBlockIds.size();
- String errorMsg =
- "Timeout: Task["
- + taskId
- + "] failed because "
- + failedBlockCount
- + " blocks can't be sent to shuffle server in "
- + sendCheckTimeout
- + " ms.";
- LOG.error(errorMsg);
- throw new RssWaitFailedException(errorMsg);
+ int missing = blockIds.size() - successBlockIds.size();
+ int failed =
+
Optional.ofNullable(shuffleManager.getFailedBlockIds(taskId)).map(Set::size).orElse(0);
+ String message =
+ String.format(
"TaskId[%s] failed because %d blocks (failed: %d) can't be
sent to shuffle server in %d ms",
+ taskId, missing, failed, sendCheckTimeout);
+
+ // detailed error message
+ Set<Long> missingBlockIds = new HashSet<>(blockIds);
+ missingBlockIds.removeAll(successBlockIds);
+ BlockIdLayout layout = BlockIdLayout.from(rssConf);
+ LOG.error(
+ "{}, includes partitions: {}",
+ message,
+ missingBlockIds.stream()
+ .map(x -> layout.getPartitionId(x))
+ .collect(Collectors.toSet()));
+
+ throw new RssWaitFailedException(message);
}
} finally {
if (interrupted) {
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 2e16f454c..a054b9b2a 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -611,7 +611,7 @@ public class RssShuffleWriterTest {
assertThrows(
RuntimeException.class,
() -> rssShuffleWriter.checkBlockSendResult(Sets.newHashSet(1L,
2L, 3L)));
- assertTrue(e2.getMessage().startsWith("Timeout:"));
+ assertTrue(e2.getMessage().contains("failed because"));
successBlocks.clear();
// case 3: partial blocks are sent failed, Runtime exception will be thrown