This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 2731cf2df [#2724] refactor(spark): Introduce `ReassignExecutor` to
simplify shuffle writer logic (#2727)
2731cf2df is described below
commit 2731cf2df0de5c8322acbfbc2d81a35d1b7517cb
Author: Junfan Zhang <[email protected]>
AuthorDate: Thu Feb 12 10:10:33 2026 +0800
[#2724] refactor(spark): Introduce `ReassignExecutor` to simplify shuffle
writer logic (#2727)
### What changes were proposed in this pull request?
Introduce `ReassignExecutor` to clarify shuffle writer logic
### Why are the changes needed?
for #2724
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests
---
.../shuffle/writer/TaskAttemptAssignment.java | 2 +-
.../spark/shuffle/writer/WriteBufferManager.java | 5 +
.../apache/uniffle/shuffle/ReassignExecutor.java | 471 +++++++++++++++++++++
.../shuffle/writer/TaskAttemptAssignmentTest.java | 2 +-
.../uniffle/shuffle/ReassignExecutorTest.java | 124 ++++++
.../spark/shuffle/writer/RssShuffleWriter.java | 330 ++-------------
.../spark/shuffle/writer/RssShuffleWriterTest.java | 40 +-
7 files changed, 659 insertions(+), 315 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
index 7e46528c4..63fac0c12 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
@@ -88,7 +88,7 @@ public class TaskAttemptAssignment {
* @param exclusiveServers
* @return
*/
- public boolean updatePartitionSplitAssignment(
+ public boolean tryNextServerForSplitPartition(
int partitionId, List<ShuffleServerInfo> exclusiveServers) {
if (hasBeenLoadBalanced(partitionId)) {
Set<ShuffleServerInfo> servers =
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index b50cc5d56..280c6aefd 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -678,6 +678,11 @@ public class WriteBufferManager extends MemoryConsumer {
return recordCounter.get();
}
+ @VisibleForTesting
+ protected void resetRecordCount() {
+ recordCounter.set(0);
+ }
+
public long getBlockCount() {
return blockCounter.get();
}
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
new file mode 100644
index 000000000..c380e3ee7
--- /dev/null
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
@@ -0,0 +1,471 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.writer.TaskAttemptAssignment;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
+import org.apache.uniffle.client.impl.TrackingBlockStatus;
+import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
+import
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
+import org.apache.uniffle.common.ReceivingFailureServer;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.exception.RssSendFailedException;
+import org.apache.uniffle.common.rpc.StatusCode;
+
+/**
+ * This class is responsible for reassignment, including the partition split and
+ * the block resend after reassignment.
+ */
+public class ReassignExecutor {
+ private static final Logger LOG =
LoggerFactory.getLogger(ReassignExecutor.class);
+ private static final Set<StatusCode> STATUS_CODE_WITHOUT_BLOCK_RESEND =
+ Sets.newHashSet(StatusCode.NO_REGISTER);
+
+ private Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker;
+ private final TaskAttemptAssignment taskAttemptAssignment;
+
+ private final Consumer<ShuffleBlockInfo> removeBlockStatsFunction;
+ private final Consumer<List<ShuffleBlockInfo>> resendBlocksFunction;
+ private final Supplier<ShuffleManagerClient> managerClientSupplier;
+
+ private String taskId;
+ private final TaskContext taskContext;
+ private final int shuffleId;
+ private int blockFailSentRetryMaxTimes;
+
+ public ReassignExecutor(
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker,
+ String taskId,
+ TaskAttemptAssignment taskAttemptAssignment,
+ Consumer<ShuffleBlockInfo> removeBlockStatsFunction,
+ Consumer<List<ShuffleBlockInfo>> resendBlocksFunction,
+ Supplier<ShuffleManagerClient> managerClientSupplier,
+ TaskContext taskContext,
+ int shuffleId,
+ int blockFailSentRetryMaxTimes) {
+ this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
+ this.taskId = taskId;
+ this.taskAttemptAssignment = taskAttemptAssignment;
+ this.removeBlockStatsFunction = removeBlockStatsFunction;
+ this.resendBlocksFunction = resendBlocksFunction;
+ this.managerClientSupplier = managerClientSupplier;
+ this.taskContext = taskContext;
+ this.shuffleId = shuffleId;
+ this.blockFailSentRetryMaxTimes = blockFailSentRetryMaxTimes;
+ LOG.debug("Initialized {} for taskId[{}]",
this.getClass().getSimpleName(), taskId);
+ }
+
+ public void reassign() {
+ FailedBlockSendTracker tracker = taskToFailedBlockSendTracker.get(taskId);
+ if (tracker == null) {
+ return;
+ }
+ // 1. reassign for split partitions.
+ reassignOnPartitionNeedSplit(tracker);
+ // 2. reassign for failed blocks
+ reassignAndResendForFailedBlocks(tracker);
+ }
+
+ @VisibleForTesting
+ public void resetBlockRetryMaxTimes(int times) {
+ this.blockFailSentRetryMaxTimes = times;
+ }
+
+ private void releaseResources(FailedBlockSendTracker tracker, Set<Long>
blockIds) {
+ for (Long blockId : blockIds) {
+ List<TrackingBlockStatus> failedBlockStatus =
tracker.getFailedBlockStatus(blockId);
+ if (CollectionUtils.isNotEmpty(failedBlockStatus)) {
+ TrackingBlockStatus blockStatus = failedBlockStatus.get(0);
+ blockStatus.getShuffleBlockInfo().executeCompletionCallback(true);
+ }
+ }
+ }
+
+ private void reassignAndResendForFailedBlocks(FailedBlockSendTracker
failedBlockSendTracker) {
+ Set<Long> failedBlockIds = failedBlockSendTracker.getFailedBlockIds();
+ if (CollectionUtils.isEmpty(failedBlockIds)) {
+ return;
+ }
+
+ Set<TrackingBlockStatus> resendBlocks = new HashSet<>();
+ for (Long blockId : failedBlockIds) {
+ List<TrackingBlockStatus> failedBlockStatus =
+ failedBlockSendTracker.getFailedBlockStatus(blockId);
+ synchronized (failedBlockStatus) {
+ int retryCnt =
+ failedBlockStatus.stream()
+ .filter(
+ x -> {
+ // If statusCode is null, the block was resent due to a
stale assignment.
+ // In this case, the retry count checking should be
ignored.
+ return x.getStatusCode() != null;
+ })
+ .map(x -> x.getShuffleBlockInfo().getRetryCnt())
+ .max(Comparator.comparing(Integer::valueOf))
+ .orElse(-1);
+ if (retryCnt >= blockFailSentRetryMaxTimes) {
+ releaseResources(failedBlockSendTracker, failedBlockIds);
+ String message =
+ String.format(
+ "Block send retry exceeded max retries. blockId=%d,
retryCount=%d, maxRetry=%d, faultyServers=%s",
+ blockId,
+ retryCnt,
+ blockFailSentRetryMaxTimes,
+ failedBlockStatus.stream()
+ .map(TrackingBlockStatus::getShuffleServerInfo)
+ .map(ShuffleServerInfo::getId)
+ .collect(Collectors.toSet()));
+ throw new RssSendFailedException(message);
+ }
+
+ for (TrackingBlockStatus status : failedBlockStatus) {
+ StatusCode code = status.getStatusCode();
+ if (STATUS_CODE_WITHOUT_BLOCK_RESEND.contains(code)) {
+ releaseResources(failedBlockSendTracker, failedBlockIds);
+ String message =
+ String.format(
+ "Block send failed with status code [%s] which does not
trigger block resend. blockId=%d, retryCount=%d, maxRetry=%d, faultyServer=%s",
+ code,
+ blockId,
+ retryCnt,
+ blockFailSentRetryMaxTimes,
+ status.getShuffleServerInfo());
+ throw new RssSendFailedException(message);
+ }
+ }
+
+ // todo: if multiple replicas are configured and another replica succeeded to send, no need to resend
+ resendBlocks.addAll(failedBlockStatus);
+ }
+ }
+ reassignAndResendBlocks(resendBlocks);
+ }
+
+ private Map<Integer, Pair<List<String>, List<String>>> constructUpdateList(
+ Map<Integer, List<ReceivingFailureServer>> requestList) {
+ Map<Integer, Pair<List<String>, List<String>>> reassignUpdateList = new
HashMap<>();
+ for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
requestList.entrySet()) {
+ Integer partitionId = entry.getKey();
+ List<String> oldServers =
+ entry.getValue().stream()
+ .map(ReceivingFailureServer::getServerId)
+ .collect(Collectors.toList());
+ List<String> newServers =
+ taskAttemptAssignment.retrieve(partitionId).stream()
+ .map(ShuffleServerInfo::getId)
+ .collect(Collectors.toList());
+ reassignUpdateList.put(partitionId, Pair.of(oldServers, newServers));
+ }
+ return reassignUpdateList;
+ }
+
+ @VisibleForTesting
+ protected static String readableResult(
+ Map<Integer, Pair<List<String>, List<String>>> fastSwitchList) {
+ if (fastSwitchList == null || fastSwitchList.isEmpty()) {
+ return "";
+ }
+
+ StringBuilder sb = new StringBuilder();
+ boolean hasDiff = false;
+
+ for (Map.Entry<Integer, Pair<List<String>, List<String>>> entry :
+ fastSwitchList.entrySet().stream()
+ .sorted(Map.Entry.comparingByKey())
+ .collect(Collectors.toList())) {
+
+ Integer partitionId = entry.getKey();
+ Pair<List<String>, List<String>> servers = entry.getValue();
+ List<String> oldServers = servers.getLeft();
+ List<String> newServers = servers.getRight();
+
+ // compare as set to avoid ordering impact
+ if (oldServers != null
+ && newServers != null
+ && new HashSet<>(oldServers).equals(new HashSet<>(newServers))) {
+ continue;
+ }
+
+ hasDiff = true;
+
+ sb.append("partitionId=")
+ .append(partitionId)
+ .append(": ")
+ .append(oldServers)
+ .append(" -> ")
+ .append(newServers)
+ .append("; ");
+ }
+
+ if (!hasDiff) {
+ return "";
+ }
+
+ return sb.toString().trim();
+ }
+
+ private void reassignOnPartitionNeedSplit(FailedBlockSendTracker
failedTracker) {
+ Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new
HashMap<>();
+
+ failedTracker
+ .removeAllTrackedPartitions()
+ .forEach(
+ partitionStatus -> {
+ List<ReceivingFailureServer> servers =
+ failurePartitionToServers.computeIfAbsent(
+ partitionStatus.getPartitionId(), x -> new
ArrayList<>());
+ String serverId = partitionStatus.getShuffleServerInfo().getId();
+ // todo: use better data structure to filter
+ if (!servers.stream()
+ .map(x -> x.getServerId())
+ .collect(Collectors.toSet())
+ .contains(serverId)) {
+ servers.add(new ReceivingFailureServer(serverId,
StatusCode.SUCCESS));
+ }
+ });
+
+ if (failurePartitionToServers.isEmpty()) {
+ return;
+ }
+
+ //
+ // For the [load balance] mode
+ // Once a partition has been split, subsequent split triggers will be ignored.
+ //
+ // For the [pipeline] mode
+ // The split request will always be responded to
+ //
+
+ // the partitions that require a full reassignment request
+ Map<Integer, List<ReceivingFailureServer>> reassignList = new HashMap<>();
+
+ // the fast switch results. key: partitionId, value: left=old servers, right=new servers
+ Map<Integer, Pair<List<String>, List<String>>> fastSwitchList = new
HashMap<>();
+
+ for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
+ failurePartitionToServers.entrySet()) {
+ int partitionId = entry.getKey();
+ List<ReceivingFailureServer> failureServers = entry.getValue();
+ if (taskAttemptAssignment.tryNextServerForSplitPartition(
+ partitionId,
+ failureServers.stream()
+ .map(x -> ShuffleServerInfo.from(x.getServerId()))
+ .collect(Collectors.toList()))) {
+ fastSwitchList.put(
+ partitionId,
+ Pair.of(
+ failureServers.stream()
+ .map(ReceivingFailureServer::getServerId)
+ .collect(Collectors.toList()),
+ taskAttemptAssignment.retrieve(partitionId).stream()
+ .map(ShuffleServerInfo::getId)
+ .collect(Collectors.toList())));
+ } else {
+ reassignList.put(partitionId, failureServers);
+ }
+ }
+
+ if (reassignList.isEmpty()) {
+ LOG.info(
+ "[partition-split] All fast switch to another servers successfully
for taskId[{}]. list: {}",
+ taskId,
+ readableResult(fastSwitchList));
+ return;
+ } else {
+ if (!fastSwitchList.isEmpty()) {
+ LOG.info(
+ "[partition-split] Partial fast switch to another servers for
taskId[{}]. list: {}",
+ taskId,
+ readableResult(fastSwitchList));
+ }
+ }
+
+ @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance")
+ long start = System.currentTimeMillis();
+ doReassignOnBlockSendFailure(reassignList, true);
+ LOG.info(
+ "[partition-split] Reassign successfully for taskId[{}] in {} ms.
list: {}",
+ taskId,
+ System.currentTimeMillis() - start,
+ readableResult(constructUpdateList(reassignList)));
+ }
+
+ @VisibleForTesting
+ protected void doReassignOnBlockSendFailure(
+ Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers,
+ boolean partitionSplit) {
+ LOG.info(
+ "Initiate reassignOnBlockSendFailure of taskId[{}]. isPartitionSplit:
{}. failurePartitionServers: {}.",
+ taskId,
+ partitionSplit,
+ failurePartitionToServers);
+ // default value used when SparkEnv is unavailable (e.g. in tests)
+ String executorId = "NULL";
+ try {
+ executorId = SparkEnv.get().executorId();
+ } catch (Exception e) {
+ // ignore
+ }
+ long taskAttemptId = taskContext.taskAttemptId();
+ int stageId = taskContext.stageId();
+ int stageAttemptNum = taskContext.stageAttemptNumber();
+ try {
+ RssReassignOnBlockSendFailureRequest request =
+ new RssReassignOnBlockSendFailureRequest(
+ shuffleId,
+ failurePartitionToServers,
+ executorId,
+ taskAttemptId,
+ stageId,
+ stageAttemptNum,
+ partitionSplit);
+ RssReassignOnBlockSendFailureResponse response =
+ managerClientSupplier.get().reassignOnBlockSendFailure(request);
+ if (response.getStatusCode() != StatusCode.SUCCESS) {
+ String msg =
+ String.format(
+ "Reassign request failed. statusCode: %s, msg: %s",
+ response.getStatusCode(), response.getMessage());
+ throw new RssException(msg);
+ }
+ MutableShuffleHandleInfo handle =
MutableShuffleHandleInfo.fromProto(response.getHandle());
+ taskAttemptAssignment.update(handle);
+ } catch (Exception e) {
+ throw new RssException(
+ "Errors on reassign on block send failure. failure
partition->servers : "
+ + failurePartitionToServers,
+ e);
+ }
+ }
+
+ private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
+ @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance")
+ long start = System.currentTimeMillis();
+ List<ShuffleBlockInfo> resendCandidates = Lists.newArrayList();
+ Map<Integer, List<TrackingBlockStatus>> partitionedFailedBlocks =
+ blocks.stream()
+ .collect(Collectors.groupingBy(d ->
d.getShuffleBlockInfo().getPartitionId()));
+
+ Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new
HashMap<>();
+ for (Map.Entry<Integer, List<TrackingBlockStatus>> entry :
partitionedFailedBlocks.entrySet()) {
+ int partitionId = entry.getKey();
+ List<TrackingBlockStatus> partitionBlocks = entry.getValue();
+ Map<ShuffleServerInfo, TrackingBlockStatus> serverBlocks =
+ partitionBlocks.stream()
+ .collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()))
+ .entrySet()
+ .stream()
+ .collect(
+ Collectors.toMap(
+ Map.Entry::getKey, x ->
x.getValue().stream().findFirst().get()));
+ for (Map.Entry<ShuffleServerInfo, TrackingBlockStatus> blockStatusEntry :
+ serverBlocks.entrySet()) {
+ String serverId = blockStatusEntry.getKey().getId();
+ // avoid duplicate reassign for the same failure server.
+ // todo: getting the replacement should support multi replica.
+ List<ShuffleServerInfo> servers =
taskAttemptAssignment.retrieve(partitionId);
+ // Gets the first replica for this partition for now.
+ // This will not work once multiple replicas are used.
+ ShuffleServerInfo replacement = servers.get(0);
+ String latestServerId = replacement.getId();
+ if (!serverId.equals(latestServerId)) {
+ continue;
+ }
+ StatusCode code = blockStatusEntry.getValue().getStatusCode();
+ failurePartitionToServers
+ .computeIfAbsent(partitionId, x -> new ArrayList<>())
+ .add(new ReceivingFailureServer(serverId, code));
+ }
+ }
+
+ if (!failurePartitionToServers.isEmpty()) {
+ @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance")
+ long requestStart = System.currentTimeMillis();
+ doReassignOnBlockSendFailure(failurePartitionToServers, false);
+ LOG.info(
+ "[partition-reassign] Do reassign request successfully in {} ms.
list: {}",
+ System.currentTimeMillis() - requestStart,
+ readableResult(constructUpdateList(failurePartitionToServers)));
+ }
+
+ for (TrackingBlockStatus blockStatus : blocks) {
+ ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
+ // todo: getting the replacement should support multi replica.
+ List<ShuffleServerInfo> servers =
taskAttemptAssignment.retrieve(block.getPartitionId());
+ // Gets the first replica for this partition for now.
+ // This will not work once multiple replicas are used.
+ ShuffleServerInfo replacement = servers.get(0);
+ if
(blockStatus.getShuffleServerInfo().getId().equals(replacement.getId())) {
+ LOG.warn(
+ "PartitionId:{} has the following assigned servers: {}. But
currently the replacement server:{} is the same with previous one!",
+ block.getPartitionId(),
+ taskAttemptAssignment.list(block.getPartitionId()),
+ replacement);
+ throw new RssException(
+ "No available replacement server for: " +
blockStatus.getShuffleServerInfo().getId());
+ }
+ // clear the previous retry state of block
+ removeBlockStatsFunction.accept(block);
+ final ShuffleBlockInfo newBlock = block;
+ // if the status code is null, it means the block is resent due to stale
assignment, not
+ // because of the block send failure. In this case, the retry count
should not be increased;
+ // otherwise it may cause unexpected fast failure.
+ if (blockStatus.getStatusCode() != null) {
+ newBlock.incrRetryCnt();
+ }
+ newBlock.reassignShuffleServers(Arrays.asList(replacement));
+ resendCandidates.add(newBlock);
+ }
+ resendBlocksFunction.accept(resendCandidates);
+ LOG.info(
+ "[partition-reassign] All {} blocks have been resent to queue
successfully in {} ms.",
+ blocks.size(),
+ System.currentTimeMillis() - start);
+ }
+
+ @VisibleForTesting
+ public void resetTaskId(String taskId) {
+ this.taskId = taskId;
+ }
+}
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/TaskAttemptAssignmentTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/TaskAttemptAssignmentTest.java
index b68e2964c..2ec928b10 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/TaskAttemptAssignmentTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/TaskAttemptAssignmentTest.java
@@ -79,7 +79,7 @@ public class TaskAttemptAssignmentTest {
public void testUpdatePartitionSplitAssignment() {
TaskAttemptAssignment assignment = new TaskAttemptAssignment(1, new
MockShuffleHandleInfo());
assertTrue(
- assignment.updatePartitionSplitAssignment(
+ assignment.tryNextServerForSplitPartition(
1, Arrays.asList(new ShuffleServerInfo("localhost", 122))));
}
}
diff --git
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ReassignExecutorTest.java
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ReassignExecutorTest.java
new file mode 100644
index 000000000..fd25e0a2c
--- /dev/null
+++
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/ReassignExecutorTest.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.writer.TaskAttemptAssignment;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+
+import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
+import org.apache.uniffle.client.impl.TrackingBlockStatus;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssSendFailedException;
+import org.apache.uniffle.common.rpc.StatusCode;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class ReassignExecutorTest {
+
+ @Mock private FailedBlockSendTracker failedBlockSendTracker =
mock(FailedBlockSendTracker.class);
+
+ @Mock private TaskAttemptAssignment taskAttemptAssignment =
mock(TaskAttemptAssignment.class);
+
+ @Mock private ShuffleManagerClient shuffleManagerClient =
mock(ShuffleManagerClient.class);
+
+ @Mock private TaskContext taskContext = mock(TaskContext.class);
+
+ @Mock private Consumer<ShuffleBlockInfo> removeBlockStatsFunction =
mock(Consumer.class);
+
+ @Mock private Consumer<List<ShuffleBlockInfo>> resendBlocksFunction =
mock(Consumer.class);
+
+ private ReassignExecutor executor = mock(ReassignExecutor.class);
+
+ @BeforeEach
+ void setUp() {
+ when(taskContext.taskAttemptId()).thenReturn(1L);
+ when(taskContext.stageId()).thenReturn(1);
+ when(taskContext.stageAttemptNumber()).thenReturn(0);
+
+ Map<String, FailedBlockSendTracker> taskToTracker = new HashMap<>();
+ taskToTracker.put("task1", failedBlockSendTracker);
+ executor =
+ new ReassignExecutor(
+ taskToTracker,
+ "task1",
+ taskAttemptAssignment,
+ removeBlockStatsFunction,
+ resendBlocksFunction,
+ () -> shuffleManagerClient,
+ taskContext,
+ 1,
+ 3);
+ }
+
+ @Test
+ void testRetryExceededShouldFailAndReleaseResources() {
+ long blockId = 100L;
+
+ ShuffleBlockInfo blockInfo =
org.mockito.Mockito.mock(ShuffleBlockInfo.class);
+ when(blockInfo.getRetryCnt()).thenReturn(3);
+
+ TrackingBlockStatus status =
org.mockito.Mockito.mock(TrackingBlockStatus.class);
+ when(status.getShuffleBlockInfo()).thenReturn(blockInfo);
+ when(status.getStatusCode()).thenReturn(StatusCode.INTERNAL_ERROR);
+ when(status.getShuffleServerInfo()).thenReturn(new
ShuffleServerInfo("localhost", 1234));
+
+ when(failedBlockSendTracker.getFailedBlockIds())
+ .thenReturn(new HashSet<>(Arrays.asList(blockId)));
+
when(failedBlockSendTracker.getFailedBlockStatus(blockId)).thenReturn(Arrays.asList(status));
+
+ assertThrows(RssSendFailedException.class, executor::reassign);
+
+ verify(blockInfo).executeCompletionCallback(true);
+ }
+
+ @Test
+ public void testMixedSameAndDifferent() throws Exception {
+ Map<Integer, Pair<List<String>, List<String>>> map = new HashMap<>();
+
+ // same -> should ignore
+ map.put(1, Pair.of(Arrays.asList("s1", "s2"), Arrays.asList("s2", "s1")));
+
+ // diff -> should print
+ map.put(3, Pair.of(Arrays.asList("x"), Arrays.asList("y")));
+
+ String result = ReassignExecutor.readableResult(map);
+
+ System.out.println(result);
+
+ assertFalse(result.contains("partitionId=1"));
+ assertTrue(result.contains("partitionId=3"));
+ }
+}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 75bd1fb76..60f3ef46f 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -18,9 +18,7 @@
package org.apache.spark.shuffle.writer;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Collections;
-import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -52,7 +50,6 @@ import org.apache.commons.collections4.CollectionUtils;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
-import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
@@ -62,7 +59,6 @@ import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleWriter;
-import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.storage.BlockManagerId;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -73,13 +69,10 @@ import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.impl.TrackingBlockStatus;
-import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteMetricRequest;
-import
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteMetricResponse;
-import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
@@ -89,6 +82,7 @@ import
org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.shuffle.BlockStats;
+import org.apache.uniffle.shuffle.ReassignExecutor;
import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
import org.apache.uniffle.storage.util.StorageType;
@@ -127,7 +121,6 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private SparkConf sparkConf;
private RssConf rssConf;
private boolean blockFailSentRetryEnabled;
- private int blockFailSentRetryMaxTimes = 1;
/** used by columnar rss shuffle writer implementation */
protected final long taskAttemptId;
@@ -157,6 +150,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private boolean isIntegrityValidationClientManagementEnabled = false;
+ private ReassignExecutor reassignExecutor;
+
// Only for tests
@VisibleForTesting
public RssShuffleWriter(
@@ -189,6 +184,17 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
context);
this.bufferManager = bufferManager;
this.taskAttemptAssignment = new TaskAttemptAssignment(taskAttemptId,
shuffleHandleInfo);
+ this.reassignExecutor =
+ new ReassignExecutor(
+ shuffleManager.getTaskToFailedBlockSendTracker(),
+ taskId,
+ taskAttemptAssignment,
+ block -> clearFailedBlockState(block),
+ blocks -> processShuffleBlockInfos(blocks),
+ managerClientSupplier,
+ taskContext,
+ shuffleId,
+ rssConf.get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES));
}
private RssShuffleWriter(
@@ -239,7 +245,6 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
+ RssClientConf.RSS_CLIENT_REASSIGN_ENABLED.key(),
RssClientConf.RSS_CLIENT_REASSIGN_ENABLED.defaultValue());
- this.blockFailSentRetryMaxTimes =
rssConf.get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES);
this.enableWriteFailureRetry =
rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
this.recordReportFailedShuffleservers = Sets.newConcurrentHashSet();
this.isIntegrityValidationClientManagementEnabled =
@@ -321,6 +326,17 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this::getPartitionAssignedServers,
context.stageAttemptNumber());
this.bufferManager = bufferManager;
+ this.reassignExecutor =
+ new ReassignExecutor(
+ shuffleManager.getTaskToFailedBlockSendTracker(),
+ taskId,
+ taskAttemptAssignment,
+ block -> clearFailedBlockState(block),
+ blocks -> processShuffleBlockInfos(blocks),
+ managerClientSupplier,
+ taskContext,
+ shuffleId,
+ rssConf.get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES));
}
@VisibleForTesting
@@ -421,11 +437,11 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private void checkSentRecordCount(long recordCount) {
if (recordCount != bufferManager.getRecordCount()) {
- String errorMsg =
- "Potential record loss may have occurred while preparing to send
blocks for task["
- + taskId
- + "]";
- throw new RssSendFailedException(errorMsg);
+ String message =
+ String.format(
+ "Inconsistent records number for taskId[%s]. expected: %d,
actual: %d.",
+ taskId, recordCount, bufferManager.getRecordCount());
+ throw new RssSendFailedException(message);
}
}
@@ -578,7 +594,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
// This method should remain protected so that Gluten can invoke it
protected void checkDataIfAnyFailure() {
if (blockFailSentRetryEnabled) {
- collectFailedBlocksToResend();
+ reassignExecutor.reassign();
} else {
String errorMsg = getFirstBlockFailure();
if (errorMsg != null) {
@@ -608,281 +624,6 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
return null;
}
- private void collectFailedBlocksToResend() {
- if (!blockFailSentRetryEnabled) {
- return;
- }
-
- FailedBlockSendTracker failedTracker =
shuffleManager.getBlockIdsFailedSendTracker(taskId);
- if (failedTracker == null) {
- return;
- }
-
- reassignOnPartitionNeedSplit(failedTracker);
-
- Set<Long> failedBlockIds = failedTracker.getFailedBlockIds();
- if (CollectionUtils.isEmpty(failedBlockIds)) {
- return;
- }
-
- boolean isFastFail = false;
- Set<TrackingBlockStatus> resendCandidates = new HashSet<>();
- // to check whether the blocks resent exceed the max resend count.
- for (Long blockId : failedBlockIds) {
- List<TrackingBlockStatus> failedBlockStatus =
failedTracker.getFailedBlockStatus(blockId);
- synchronized (failedBlockStatus) {
- int retryCnt =
- failedBlockStatus.stream()
- .filter(
- x -> {
- // If statusCode is null, the block was resent due to a
stale assignment.
- // In this case, the retry count checking should be
ignored.
- return x.getStatusCode() != null;
- })
- .map(x -> x.getShuffleBlockInfo().getRetryCnt())
- .max(Comparator.comparing(Integer::valueOf))
- .orElse(-1);
- if (retryCnt >= blockFailSentRetryMaxTimes) {
- LOG.error(
- "Partial blocks for taskId: [{}] retry exceeding the max retry
times: [{}]. Fast fail! faulty server list: {}",
- taskId,
- blockFailSentRetryMaxTimes,
- failedBlockStatus.stream()
- .map(x -> x.getShuffleServerInfo())
- .collect(Collectors.toSet()));
- isFastFail = true;
- break;
- }
-
- for (TrackingBlockStatus status : failedBlockStatus) {
- StatusCode code = status.getStatusCode();
- if (STATUS_CODE_WITHOUT_BLOCK_RESEND.contains(code)) {
- LOG.error(
- "Partial blocks for taskId: [{}] failed on the illegal status
code: [{}] without resend on server: {}",
- taskId,
- code,
- status.getShuffleServerInfo());
- isFastFail = true;
- break;
- }
- }
-
- // todo: if setting multi replica and another replica is succeed to
send, no need to resend
- resendCandidates.addAll(failedBlockStatus);
- }
- }
-
- if (isFastFail) {
- // release data and allocated memory
- for (Long blockId : failedBlockIds) {
- List<TrackingBlockStatus> failedBlockStatus =
failedTracker.getFailedBlockStatus(blockId);
- if (CollectionUtils.isNotEmpty(failedBlockStatus)) {
- TrackingBlockStatus blockStatus = failedBlockStatus.get(0);
- blockStatus.getShuffleBlockInfo().executeCompletionCallback(true);
- }
- }
-
- throw new RssSendFailedException(
- "Errors on resending the blocks data to the remote shuffle-server.");
- }
-
- reassignAndResendBlocks(resendCandidates);
- }
-
- private void reassignOnPartitionNeedSplit(FailedBlockSendTracker
failedTracker) {
- Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new
HashMap<>();
-
- failedTracker
- .removeAllTrackedPartitions()
- .forEach(
- partitionStatus -> {
- List<ReceivingFailureServer> servers =
- failurePartitionToServers.computeIfAbsent(
- partitionStatus.getPartitionId(), x -> new
ArrayList<>());
- String serverId = partitionStatus.getShuffleServerInfo().getId();
- // todo: use better data structure to filter
- if (!servers.stream()
- .map(x -> x.getServerId())
- .collect(Collectors.toSet())
- .contains(serverId)) {
- servers.add(new ReceivingFailureServer(serverId,
StatusCode.SUCCESS));
- }
- });
-
- if (failurePartitionToServers.isEmpty()) {
- return;
- }
-
- //
- // For the [load balance] mode
- // Once partition has been split, the following split trigger will be
ignored.
- //
- // For the [pipeline] mode
- // The split request will be always response
- //
- Map<Integer, List<ReceivingFailureServer>> partitionToServersReassignList
= new HashMap<>();
- for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
- failurePartitionToServers.entrySet()) {
- int partitionId = entry.getKey();
- List<ReceivingFailureServer> failureServers = entry.getValue();
- if (!taskAttemptAssignment.updatePartitionSplitAssignment(
- partitionId,
- failureServers.stream()
- .map(x -> ShuffleServerInfo.from(x.getServerId()))
- .collect(Collectors.toList()))) {
- partitionToServersReassignList.put(partitionId, failureServers);
- }
- }
-
- if (partitionToServersReassignList.isEmpty()) {
- LOG.info(
- "[Partition split] Skip the following partition split request (maybe
has been load balanced). partitionIds: {}",
- failurePartitionToServers.keySet());
- return;
- }
-
- doReassignOnBlockSendFailure(partitionToServersReassignList, true);
-
- LOG.info("========================= Partition Split Result
=========================");
- for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
- partitionToServersReassignList.entrySet()) {
- LOG.info(
- "partitionId:{}. {} -> {}",
- entry.getKey(),
- entry.getValue().stream().map(x ->
x.getServerId()).collect(Collectors.toList()),
- taskAttemptAssignment.retrieve(entry.getKey()));
- }
-
LOG.info("==========================================================================");
- }
-
- private void doReassignOnBlockSendFailure(
- Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers,
- boolean partitionSplit) {
- LOG.info(
- "Initiate reassignOnBlockSendFailure of taskId[{}]. partition split:
{}. failure partition servers: {}. ",
- taskAttemptId,
- partitionSplit,
- failurePartitionToServers);
- String executorId = SparkEnv.get().executorId();
- long taskAttemptId = taskContext.taskAttemptId();
- int stageId = taskContext.stageId();
- int stageAttemptNum = taskContext.stageAttemptNumber();
- try {
- RssReassignOnBlockSendFailureRequest request =
- new RssReassignOnBlockSendFailureRequest(
- shuffleId,
- failurePartitionToServers,
- executorId,
- taskAttemptId,
- stageId,
- stageAttemptNum,
- partitionSplit);
- RssReassignOnBlockSendFailureResponse response =
- managerClientSupplier.get().reassignOnBlockSendFailure(request);
- if (response.getStatusCode() != StatusCode.SUCCESS) {
- String msg =
- String.format(
- "Reassign failed. statusCode: %s, msg: %s",
- response.getStatusCode(), response.getMessage());
- throw new RssException(msg);
- }
- MutableShuffleHandleInfo handle =
MutableShuffleHandleInfo.fromProto(response.getHandle());
- taskAttemptAssignment.update(handle);
-
- // print the lastest assignment for those reassignment partition ids
- Map<Integer, List<String>> reassignments = new HashMap<>();
- for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
- failurePartitionToServers.entrySet()) {
- int partitionId = entry.getKey();
- List<ShuffleServerInfo> servers =
taskAttemptAssignment.retrieve(partitionId);
- reassignments.put(
- partitionId, servers.stream().map(x ->
x.getId()).collect(Collectors.toList()));
- }
- LOG.info("Succeed to reassign that the latest assignment is {}",
reassignments);
- } catch (Exception e) {
- throw new RssException(
- "Errors on reassign on block send failure. failure
partition->servers : "
- + failurePartitionToServers,
- e);
- }
- }
-
- private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
- List<ShuffleBlockInfo> resendCandidates = Lists.newArrayList();
- Map<Integer, List<TrackingBlockStatus>> partitionedFailedBlocks =
- blocks.stream()
- .collect(Collectors.groupingBy(d ->
d.getShuffleBlockInfo().getPartitionId()));
-
- Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new
HashMap<>();
- for (Map.Entry<Integer, List<TrackingBlockStatus>> entry :
partitionedFailedBlocks.entrySet()) {
- int partitionId = entry.getKey();
- List<TrackingBlockStatus> partitionBlocks = entry.getValue();
- Map<ShuffleServerInfo, TrackingBlockStatus> serverBlocks =
- partitionBlocks.stream()
- .collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()))
- .entrySet()
- .stream()
- .collect(
- Collectors.toMap(
- Map.Entry::getKey, x ->
x.getValue().stream().findFirst().get()));
- for (Map.Entry<ShuffleServerInfo, TrackingBlockStatus> blockStatusEntry :
- serverBlocks.entrySet()) {
- String serverId = blockStatusEntry.getKey().getId();
- // avoid duplicate reassign for the same failure server.
- // todo: getting the replacement should support multi replica.
- List<ShuffleServerInfo> servers =
getPartitionAssignedServers(partitionId);
- // Gets the first replica for this partition for now.
- // It can not work if we want to use multiple replicas.
- ShuffleServerInfo replacement = servers.get(0);
- String latestServerId = replacement.getId();
- if (!serverId.equals(latestServerId)) {
- continue;
- }
- StatusCode code = blockStatusEntry.getValue().getStatusCode();
- failurePartitionToServers
- .computeIfAbsent(partitionId, x -> new ArrayList<>())
- .add(new ReceivingFailureServer(serverId, code));
- }
- }
-
- if (!failurePartitionToServers.isEmpty()) {
- doReassignOnBlockSendFailure(failurePartitionToServers, false);
- }
-
- for (TrackingBlockStatus blockStatus : blocks) {
- ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
- // todo: getting the replacement should support multi replica.
- List<ShuffleServerInfo> servers =
getPartitionAssignedServers(block.getPartitionId());
- // Gets the first replica for this partition for now.
- // It can not work if we want to use multiple replicas.
- ShuffleServerInfo replacement = servers.get(0);
- if
(blockStatus.getShuffleServerInfo().getId().equals(replacement.getId())) {
- LOG.warn(
- "PartitionId:{} has the following assigned servers: {}. But
currently the replacement server:{} is the same with previous one!",
- block.getPartitionId(),
- taskAttemptAssignment.list(block.getPartitionId()),
- replacement);
- throw new RssException(
- "No available replacement server for: " +
blockStatus.getShuffleServerInfo().getId());
- }
- // clear the previous retry state of block
- clearFailedBlockState(block);
- final ShuffleBlockInfo newBlock = block;
- // if the status code is null, it means the block is resent due to stale
assignment, not
- // because of the block send failure. In this case, the retry count
should not be increased;
- // otherwise it may cause unexpected fast failure.
- if (blockStatus.getStatusCode() != null) {
- newBlock.incrRetryCnt();
- }
- newBlock.reassignShuffleServers(Arrays.asList(replacement));
- resendCandidates.add(newBlock);
- }
-
- processShuffleBlockInfos(resendCandidates);
- LOG.info(
- "Failed blocks have been resent to data pusher queue since
reassignment has been finished successfully");
- }
-
private void clearFailedBlockState(ShuffleBlockInfo block) {
shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(block.getBlockId());
shuffleTaskStats.decPartitionBlock(block.getPartitionId());
@@ -1109,8 +850,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
@VisibleForTesting
- protected void setBlockFailSentRetryMaxTimes(int blockFailSentRetryMaxTimes)
{
- this.blockFailSentRetryMaxTimes = blockFailSentRetryMaxTimes;
+ protected void resetBlockFailSentRetryMaxTimes(int times) {
+ reassignExecutor.resetBlockRetryMaxTimes(times);
}
@VisibleForTesting
@@ -1160,6 +901,11 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
return shuffleManager;
}
+ @VisibleForTesting
+ public ReassignExecutor getReassignExecutor() {
+ return reassignExecutor;
+ }
+
public TaskAttemptAssignment getTaskAttemptAssignment() {
return taskAttemptAssignment;
}
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index dd52aa915..2e16f454c 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -20,9 +20,11 @@ package org.apache.spark.shuffle.writer;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
@@ -212,7 +214,7 @@ public class RssShuffleWriterTest {
String taskId = "taskId";
MutableShuffleHandleInfo shuffleHandle = createMutableShuffleHandle();
RssShuffleWriter writer = createMockWriter(shuffleHandle, taskId);
- writer.setBlockFailSentRetryMaxTimes(10);
+ writer.resetBlockFailSentRetryMaxTimes(10);
// Make the id1 + id10 + id11 broken, and then finally, it will use the
id12 successfully
AtomicInteger failureCnt = new AtomicInteger();
@@ -498,28 +500,26 @@ public class RssShuffleWriterTest {
assertEquals(2,
serverToPartitionToBlockIds.get(replacement).get(0).size());
// case2. If exceeding the max retry times, it will fast fail.
- rssShuffleWriter.setBlockFailSentRetryMaxTimes(1);
- rssShuffleWriter.setTaskId("taskId2");
- rssShuffleWriter.getBufferManager().setTaskId("taskId2");
- taskToFailedBlockSendTracker.put("taskId2", new FailedBlockSendTracker());
- AtomicInteger rejectCnt = new AtomicInteger(0);
+ String taskId = "t2";
+ rssShuffleWriter.getReassignExecutor().resetTaskId(taskId);
+ bufferManagerSpy.resetRecordCount();
+ rssShuffleWriter.resetBlockFailSentRetryMaxTimes(1);
+ rssShuffleWriter.setTaskId(taskId);
+ rssShuffleWriter.getBufferManager().setTaskId(taskId);
+ FailedBlockSendTracker tracker = new FailedBlockSendTracker();
+ taskToFailedBlockSendTracker.put(taskId, tracker);
FakedDataPusher alwaysFailedDataPusher =
new FakedDataPusher(
event -> {
- assertEquals("taskId2", event.getTaskId());
- FailedBlockSendTracker tracker =
taskToFailedBlockSendTracker.get(event.getTaskId());
+ assertEquals(taskId, event.getTaskId());
for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
- boolean isSuccessful = true;
- ShuffleServerInfo shuffleServer =
block.getShuffleServerInfos().get(0);
- if (shuffleServer.getId().equals("id1") && rejectCnt.get() <=
3) {
- tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
- isSuccessful = false;
- rejectCnt.incrementAndGet();
- } else {
- successBlockIds.putIfAbsent(event.getTaskId(),
Sets.newConcurrentHashSet());
-
successBlockIds.get(event.getTaskId()).add(block.getBlockId());
- }
- block.executeCompletionCallback(isSuccessful);
+ tracker.add(block, block.getShuffleServerInfos().get(0),
StatusCode.NO_BUFFER);
+ block.executeCompletionCallback(false);
+ }
+ List<Runnable> callbackChain =
+
Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST);
+ for (Runnable runnable : callbackChain) {
+ runnable.run();
}
return new CompletableFuture<>();
});
@@ -533,8 +533,6 @@ public class RssShuffleWriterTest {
} catch (Exception e) {
// ignore
}
- assertEquals(0, bufferManagerSpy.getUsedBytes());
- assertEquals(0, bufferManagerSpy.getInSendListBytes());
}
@Test