This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 375262ad9 [#1579][part-1] fix(spark): Ensure all previous data is
cleared for stage retry (#1762)
375262ad9 is described below
commit 375262ad9236efebe0010a65a556c2dec3e9ae82
Author: yl09099 <[email protected]>
AuthorDate: Thu Jun 13 09:31:29 2024 +0800
[#1579][part-1] fix(spark): Ensure all previous data is cleared for stage
retry (#1762)
### What changes were proposed in this pull request?
1. clear out previous stage attempt data synchronously when registering the
re-assignment shuffleIds.
2. rework the stage retry interface and rpc
3. introduce the stage version to avoid accepting older data
### Why are the changes needed?
Fix: https://github.com/apache/incubator-uniffle/issues/1579
If the previous stage attempt is still in the purge queue on the shuffle-server side,
writes from the retried stage will cause
unknown exceptions, so we had better clear out all previous stage attempt
data before re-registering.
This PR synchronously removes the previous stage data when the first attempt writer
is initialized.
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Existing tests.
---
.../hadoop/mapred/SortWriteBufferManagerTest.java | 3 +-
.../hadoop/mapreduce/task/reduce/FetcherTest.java | 3 +-
.../org/apache/spark/shuffle/RssStageInfo.java | 33 +-
.../spark/shuffle/RssStageResubmitManager.java | 69 +++
.../handle/StageAttemptShuffleHandleInfo.java | 137 ++++++
.../apache/spark/shuffle/writer/AddBlockEvent.java | 11 +
.../apache/spark/shuffle/writer/DataPusher.java | 5 +-
.../shuffle/manager/RssShuffleManagerBase.java | 467 +++++++++++++++++++++
.../manager/RssShuffleManagerInterface.java | 19 +-
.../shuffle/manager/ShuffleManagerGrpcService.java | 47 ++-
.../spark/shuffle/writer/DataPusherTest.java | 9 +
.../shuffle/manager/DummyRssShuffleManager.java | 28 +-
.../manager/ShuffleManagerGrpcServiceTest.java | 4 +-
.../apache/spark/shuffle/RssShuffleManager.java | 334 ++-------------
.../spark/shuffle/writer/RssShuffleWriter.java | 1 +
.../apache/spark/shuffle/RssShuffleManager.java | 441 ++-----------------
.../common/sort/buffer/WriteBufferManagerTest.java | 3 +-
.../uniffle/client/api/ShuffleWriteClient.java | 32 +-
.../client/impl/ShuffleWriteClientImpl.java | 24 +-
.../netty/protocol/SendShuffleDataRequest.java | 18 +
.../org/apache/uniffle/common/rpc/StatusCode.java | 1 +
.../uniffle/client/api/ShuffleManagerClient.java | 17 +-
.../client/impl/grpc/ShuffleManagerGrpcClient.java | 27 +-
.../client/impl/grpc/ShuffleServerGrpcClient.java | 11 +-
.../impl/grpc/ShuffleServerGrpcNettyClient.java | 2 +
.../request/RssGetShuffleAssignmentsRequest.java | 24 ++
.../client/request/RssRegisterShuffleRequest.java | 32 +-
.../client/request/RssSendShuffleDataRequest.java | 15 +
.../RssReassignOnBlockSendFailureResponse.java | 2 +-
...e.java => RssReassignOnStageRetryResponse.java} | 16 +-
proto/src/main/proto/Rss.proto | 22 +-
.../uniffle/server/ShuffleServerGrpcService.java | 58 +++
.../org/apache/uniffle/server/ShuffleTaskInfo.java | 11 +
.../apache/uniffle/server/ShuffleTaskManager.java | 2 +-
.../server/netty/ShuffleServerNettyHandler.java | 13 +
35 files changed, 1153 insertions(+), 788 deletions(-)
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 5c9f401b8..430e2ff58 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -517,7 +517,8 @@ public class SortWriteBufferManagerTest {
List<PartitionRange> partitionRanges,
RemoteStorageInfo remoteStorage,
ShuffleDataDistributionType distributionType,
- int maxConcurrencyPerPartitionToWrite) {}
+ int maxConcurrencyPerPartitionToWrite,
+ int stageAttemptNumber) {}
@Override
public boolean sendCommit(
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index b858312b0..7664b47d6 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -504,7 +504,8 @@ public class FetcherTest {
List<PartitionRange> partitionRanges,
RemoteStorageInfo storageType,
ShuffleDataDistributionType distributionType,
- int maxConcurrencyPerPartitionToWrite) {}
+ int maxConcurrencyPerPartitionToWrite,
+ int stageAttemptNumber) {}
@Override
public boolean sendCommit(
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageInfo.java
similarity index 50%
copy from
internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
copy to
client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageInfo.java
index 81ca7d548..c8168d6c4 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageInfo.java
@@ -15,27 +15,30 @@
* limitations under the License.
*/
-package org.apache.uniffle.client.response;
+package org.apache.spark.shuffle;
-import org.apache.uniffle.common.rpc.StatusCode;
-import org.apache.uniffle.proto.RssProtos;
+public class RssStageInfo {
+ private String stageAttemptIdAndNumber;
+ private boolean isReassigned;
-public class RssReassignOnBlockSendFailureResponse extends ClientResponse {
- private RssProtos.MutableShuffleHandleInfo handle;
+ public RssStageInfo(String stageAttemptIdAndNumber, boolean isReassigned) {
+ this.stageAttemptIdAndNumber = stageAttemptIdAndNumber;
+ this.isReassigned = isReassigned;
+ }
+
+ public String getStageAttemptIdAndNumber() {
+ return stageAttemptIdAndNumber;
+ }
- public RssReassignOnBlockSendFailureResponse(
- StatusCode statusCode, String message,
RssProtos.MutableShuffleHandleInfo handle) {
- super(statusCode, message);
- this.handle = handle;
+ public void setStageAttemptIdAndNumber(String stageAttemptIdAndNumber) {
+ this.stageAttemptIdAndNumber = stageAttemptIdAndNumber;
}
- public RssProtos.MutableShuffleHandleInfo getHandle() {
- return handle;
+ public boolean isReassigned() {
+ return isReassigned;
}
- public static RssReassignOnBlockSendFailureResponse fromProto(
- RssProtos.RssReassignOnBlockSendFailureResponse response) {
- return new RssReassignOnBlockSendFailureResponse(
- StatusCode.valueOf(response.getStatus().name()), response.getMsg(),
response.getHandle());
+ public void setReassigned(boolean reassigned) {
+ isReassigned = reassigned;
}
}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java
new file mode 100644
index 000000000..028622f92
--- /dev/null
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle;
+
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.util.JavaUtils;
+
+public class RssStageResubmitManager {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(RssStageResubmitManager.class);
+ /** Blacklist of the Shuffle Server when the write fails. */
+ private Set<String> serverIdBlackList;
+ /**
+ * Prevent multiple tasks from reporting FetchFailed, resulting in multiple
ShuffleServer
+ * assignments, stageID, Attemptnumber Whether to reassign the combination
flag;
+ */
+ private Map<Integer, RssStageInfo> serverAssignedInfos;
+
+ public RssStageResubmitManager() {
+ this.serverIdBlackList = Sets.newConcurrentHashSet();
+ this.serverAssignedInfos = JavaUtils.newConcurrentMap();
+ }
+
+ public Set<String> getServerIdBlackList() {
+ return serverIdBlackList;
+ }
+
+ public void resetServerIdBlackList(Set<String> failuresShuffleServerIds) {
+ this.serverIdBlackList = failuresShuffleServerIds;
+ }
+
+ public void recordFailuresShuffleServer(String shuffleServerId) {
+ serverIdBlackList.add(shuffleServerId);
+ }
+
+ public RssStageInfo recordAndGetServerAssignedInfo(int shuffleId, String
stageIdAndAttempt) {
+
+ return serverAssignedInfos.computeIfAbsent(
+ shuffleId, id -> new RssStageInfo(stageIdAndAttempt, false));
+ }
+
+ public void recordAndGetServerAssignedInfo(
+ int shuffleId, String stageIdAndAttempt, boolean isRetried) {
+ serverAssignedInfos
+ .computeIfAbsent(shuffleId, id -> new RssStageInfo(stageIdAndAttempt,
false))
+ .setReassigned(isRetried);
+ }
+}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
new file mode 100644
index 000000000..8fd9642ac
--- /dev/null
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.handle;
+
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Lists;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.proto.RssProtos;
+
+public class StageAttemptShuffleHandleInfo extends ShuffleHandleInfoBase {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(StageAttemptShuffleHandleInfo.class);
+
+ private ShuffleHandleInfo current;
+ /** When Stage retry occurs, record the Shuffle Server of the previous
Stage. */
+ private LinkedList<ShuffleHandleInfo> historyHandles;
+
+ public StageAttemptShuffleHandleInfo(
+ int shuffleId, RemoteStorageInfo remoteStorage, ShuffleHandleInfo
shuffleServerInfo) {
+ super(shuffleId, remoteStorage);
+ this.current = shuffleServerInfo;
+ this.historyHandles = Lists.newLinkedList();
+ }
+
+ public StageAttemptShuffleHandleInfo(
+ int shuffleId,
+ RemoteStorageInfo remoteStorage,
+ ShuffleHandleInfo currentShuffleServerInfo,
+ LinkedList<ShuffleHandleInfo> historyHandles) {
+ super(shuffleId, remoteStorage);
+ this.current = currentShuffleServerInfo;
+ this.historyHandles = historyHandles;
+ }
+
+ @Override
+ public Set<ShuffleServerInfo> getServers() {
+ return current.getServers();
+ }
+
+ @Override
+ public Map<Integer, List<ShuffleServerInfo>>
getAvailablePartitionServersForWriter() {
+ return current.getAvailablePartitionServersForWriter();
+ }
+
+ @Override
+ public Map<Integer, List<ShuffleServerInfo>>
getAllPartitionServersForReader() {
+ return current.getAllPartitionServersForReader();
+ }
+
+ @Override
+ public PartitionDataReplicaRequirementTracking
createPartitionReplicaTracking() {
+ return current.createPartitionReplicaTracking();
+ }
+
+ /**
+ * When a Stage retry occurs, replace the current shuffleHandleInfo and
record the historical
+ * shuffleHandleInfo.
+ */
+ public void replaceCurrentShuffleHandleInfo(ShuffleHandleInfo
shuffleHandleInfo) {
+ this.historyHandles.add(current);
+ this.current = shuffleHandleInfo;
+ }
+
+ public ShuffleHandleInfo getCurrent() {
+ return current;
+ }
+
+ public LinkedList<ShuffleHandleInfo> getHistoryHandles() {
+ return historyHandles;
+ }
+
+ public static RssProtos.StageAttemptShuffleHandleInfo toProto(
+ StageAttemptShuffleHandleInfo handleInfo) {
+ LinkedList<RssProtos.MutableShuffleHandleInfo>
mutableShuffleHandleInfoLinkedList =
+ Lists.newLinkedList();
+ RssProtos.MutableShuffleHandleInfo currentMutableShuffleHandleInfo =
+ MutableShuffleHandleInfo.toProto((MutableShuffleHandleInfo)
handleInfo.getCurrent());
+ for (ShuffleHandleInfo historyHandle : handleInfo.getHistoryHandles()) {
+ mutableShuffleHandleInfoLinkedList.add(
+ MutableShuffleHandleInfo.toProto((MutableShuffleHandleInfo)
historyHandle));
+ }
+ RssProtos.StageAttemptShuffleHandleInfo handleProto =
+ RssProtos.StageAttemptShuffleHandleInfo.newBuilder()
+
.setCurrentMutableShuffleHandleInfo(currentMutableShuffleHandleInfo)
+
.addAllHistoryMutableShuffleHandleInfo(mutableShuffleHandleInfoLinkedList)
+ .build();
+ return handleProto;
+ }
+
+ public static StageAttemptShuffleHandleInfo fromProto(
+ RssProtos.StageAttemptShuffleHandleInfo handleProto) {
+ if (handleProto == null) {
+ return null;
+ }
+
+ MutableShuffleHandleInfo mutableShuffleHandleInfo =
+
MutableShuffleHandleInfo.fromProto(handleProto.getCurrentMutableShuffleHandleInfo());
+ List<RssProtos.MutableShuffleHandleInfo>
historyMutableShuffleHandleInfoList =
+ handleProto.getHistoryMutableShuffleHandleInfoList();
+ LinkedList<ShuffleHandleInfo> historyHandles = Lists.newLinkedList();
+ for (RssProtos.MutableShuffleHandleInfo shuffleHandleInfo :
+ historyMutableShuffleHandleInfoList) {
+
historyHandles.add(MutableShuffleHandleInfo.fromProto(shuffleHandleInfo));
+ }
+
+ StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo =
+ new StageAttemptShuffleHandleInfo(
+ mutableShuffleHandleInfo.shuffleId,
+ mutableShuffleHandleInfo.remoteStorage,
+ mutableShuffleHandleInfo,
+ historyHandles);
+ return stageAttemptShuffleHandleInfo;
+ }
+}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
index 9751ba0b8..f989fdb0b 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
@@ -25,11 +25,18 @@ import org.apache.uniffle.common.ShuffleBlockInfo;
public class AddBlockEvent {
private String taskId;
+ private int stageAttemptNumber;
private List<ShuffleBlockInfo> shuffleDataInfoList;
private List<Runnable> processedCallbackChain;
public AddBlockEvent(String taskId, List<ShuffleBlockInfo>
shuffleDataInfoList) {
+ this(taskId, 0, shuffleDataInfoList);
+ }
+
+ public AddBlockEvent(
+ String taskId, int stageAttemptNumber, List<ShuffleBlockInfo>
shuffleDataInfoList) {
this.taskId = taskId;
+ this.stageAttemptNumber = stageAttemptNumber;
this.shuffleDataInfoList = shuffleDataInfoList;
this.processedCallbackChain = new ArrayList<>();
}
@@ -43,6 +50,10 @@ public class AddBlockEvent {
return taskId;
}
+ public int getStageAttemptNumber() {
+ return stageAttemptNumber;
+ }
+
public List<ShuffleBlockInfo> getShuffleDataInfoList() {
return shuffleDataInfoList;
}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index e9ef2ba61..bdf0cf849 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -92,7 +92,10 @@ public class DataPusher implements Closeable {
try {
result =
shuffleWriteClient.sendShuffleData(
- rssAppId, shuffleBlockInfoList, () ->
!isValidTask(taskId));
+ rssAppId,
+ event.getStageAttemptNumber(),
+ shuffleBlockInfoList,
+ () -> !isValidTask(taskId));
putBlockId(taskToSuccessBlockIds, taskId,
result.getSuccessBlockIds());
putFailedBlockSendTracker(
taskToFailedBlockSendTracker, taskId,
result.getFailedBlockSendTracker());
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 77cb173e3..912bf998f 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -19,16 +19,24 @@ package org.apache.uniffle.shuffle.manager;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
import java.util.stream.Collectors;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.apache.commons.collections.CollectionUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.spark.MapOutputTracker;
@@ -38,22 +46,41 @@ import org.apache.spark.SparkEnv;
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
+import org.apache.spark.shuffle.RssStageInfo;
+import org.apache.spark.shuffle.RssStageResubmitManager;
+import org.apache.spark.shuffle.ShuffleHandleInfoManager;
import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.SparkVersionUtils;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.CoordinatorClient;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
+import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
+import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
import org.apache.uniffle.client.response.RssFetchClientConfResponse;
+import
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
+import org.apache.uniffle.client.response.RssReassignOnStageRetryResponse;
+import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleDataDistributionType;
+import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.ConfigOption;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.shuffle.BlockIdManager;
import static
org.apache.uniffle.common.config.RssClientConf.HADOOP_CONFIG_KEY_PREFIX;
@@ -65,7 +92,34 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
private Method unregisterAllMapOutputMethod;
private Method registerShuffleMethod;
private volatile BlockIdManager blockIdManager;
+ protected ShuffleDataDistributionType dataDistributionType;
private Object blockIdManagerLock = new Object();
+ protected AtomicReference<String> id = new AtomicReference<>();
+ protected String appId = "";
+ protected ShuffleWriteClient shuffleWriteClient;
+ protected boolean dynamicConfEnabled;
+ protected int maxConcurrencyPerPartitionToWrite;
+ protected String clientType;
+
+ protected SparkConf sparkConf;
+ protected ShuffleManagerClient shuffleManagerClient;
+ /** Whether to enable the dynamic shuffleServer function rewrite and reread
functions */
+ protected boolean rssResubmitStage;
+ /**
+ * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is
dynamically allocated.
+ * ShuffleServer is not obtained from RssShuffleHandle, but from this
mapping.
+ */
+ protected ShuffleHandleInfoManager shuffleHandleInfoManager;
+
+ protected RssStageResubmitManager rssStageResubmitManager;
+
+ protected int partitionReassignMaxServerNum;
+
+ protected boolean blockIdSelfManagedEnabled;
+
+ protected boolean taskBlockSendFailureRetryEnabled;
+
+ protected boolean shuffleManagerRpcServiceEnabled;
public BlockIdManager getBlockIdManager() {
if (blockIdManager == null) {
@@ -520,4 +574,417 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
return new RemoteStorageInfo(
sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""),
confItems);
}
+
+ /**
+ * In Stage Retry mode, obtain the Shuffle Server list from the Driver based
on shuffleId.
+ *
+ * @param shuffleId shuffleId
+ * @return ShuffleHandleInfo
+ */
+ protected synchronized StageAttemptShuffleHandleInfo
getRemoteShuffleHandleInfoWithStageRetry(
+ int shuffleId) {
+ RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
+ new RssPartitionToShuffleServerRequest(shuffleId);
+ RssReassignOnStageRetryResponse rpcPartitionToShufflerServer =
+ getOrCreateShuffleManagerClient()
+
.getPartitionToShufflerServerWithStageRetry(rssPartitionToShuffleServerRequest);
+ StageAttemptShuffleHandleInfo shuffleHandleInfo =
+ StageAttemptShuffleHandleInfo.fromProto(
+ rpcPartitionToShufflerServer.getShuffleHandleInfoProto());
+ return shuffleHandleInfo;
+ }
+
+ /**
+ * In Block Retry mode, obtain the Shuffle Server list from the Driver based
on shuffleId.
+ *
+ * @param shuffleId shuffleId
+ * @return ShuffleHandleInfo
+ */
+ protected synchronized MutableShuffleHandleInfo
getRemoteShuffleHandleInfoWithBlockRetry(
+ int shuffleId) {
+ RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
+ new RssPartitionToShuffleServerRequest(shuffleId);
+ RssReassignOnBlockSendFailureResponse rpcPartitionToShufflerServer =
+ getOrCreateShuffleManagerClient()
+
.getPartitionToShufflerServerWithBlockRetry(rssPartitionToShuffleServerRequest);
+ MutableShuffleHandleInfo shuffleHandleInfo =
+
MutableShuffleHandleInfo.fromProto(rpcPartitionToShufflerServer.getHandle());
+ return shuffleHandleInfo;
+ }
+
+ // todo: automatic close client when the client is idle to avoid too much
connections for spark
+ // driver.
+ protected ShuffleManagerClient getOrCreateShuffleManagerClient() {
+ if (shuffleManagerClient == null) {
+ RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+ String driver = rssConf.getString("driver.host", "");
+ int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+ this.shuffleManagerClient =
+ ShuffleManagerClientFactory.getInstance()
+ .createShuffleManagerClient(ClientType.GRPC, driver, port);
+ }
+ return shuffleManagerClient;
+ }
+
+ @Override
+ public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
+ return shuffleHandleInfoManager.get(shuffleId);
+ }
+
+ /**
+ * @return the maximum number of fetch failures per shuffle partition before
that shuffle stage
+ * should be recomputed
+ */
+ @Override
+ public int getMaxFetchFailures() {
+ final String TASK_MAX_FAILURE = "spark.task.maxFailures";
+ return Math.max(1, sparkConf.getInt(TASK_MAX_FAILURE, 4) - 1);
+ }
+
+ /**
+ * Add the shuffleServer that failed to write to the failure list
+ *
+ * @param shuffleServerId
+ */
+ @Override
+ public void addFailuresShuffleServerInfos(String shuffleServerId) {
+ rssStageResubmitManager.recordFailuresShuffleServer(shuffleServerId);
+ }
+
+ /**
+ * Reassign the ShuffleServer list for ShuffleId
+ *
+ * @param shuffleId
+ * @param numPartitions
+ */
+ @Override
+ public boolean reassignOnStageResubmit(
+ int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) {
+ String stageIdAndAttempt = stageId + "_" + stageAttemptNumber;
+ RssStageInfo rssStageInfo =
+ rssStageResubmitManager.recordAndGetServerAssignedInfo(shuffleId,
stageIdAndAttempt);
+ synchronized (rssStageInfo) {
+ Boolean needReassign = rssStageInfo.isReassigned();
+ if (!needReassign) {
+ int requiredShuffleServerNumber =
+ RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
+ int estimateTaskConcurrency =
RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
+
+ /**
+ * this will clear up the previous stage attempt all data when
registering the same
+ * shuffleId at the second time
+ */
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+ requestShuffleAssignment(
+ shuffleId,
+ numPartitions,
+ 1,
+ requiredShuffleServerNumber,
+ estimateTaskConcurrency,
+ rssStageResubmitManager.getServerIdBlackList(),
+ stageAttemptNumber);
+ /**
+ * we need to clear the metadata of the completed task, otherwise some
of the stage's data
+ * will be lost
+ */
+ try {
+ unregisterAllMapOutput(shuffleId);
+ } catch (SparkException e) {
+ LOG.error("Clear MapoutTracker Meta failed!");
+ throw new RssException("Clear MapoutTracker Meta failed!", e);
+ }
+ MutableShuffleHandleInfo shuffleHandleInfo =
+ new MutableShuffleHandleInfo(shuffleId, partitionToServers,
getRemoteStorageInfo());
+ StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo =
+ (StageAttemptShuffleHandleInfo)
shuffleHandleInfoManager.get(shuffleId);
+
stageAttemptShuffleHandleInfo.replaceCurrentShuffleHandleInfo(shuffleHandleInfo);
+ rssStageResubmitManager.recordAndGetServerAssignedInfo(shuffleId,
stageIdAndAttempt, true);
+ return true;
+ } else {
+ LOG.info(
+ "Do nothing that the stage: {} has been reassigned for attempt{}",
+ stageId,
+ stageAttemptNumber);
+ return false;
+ }
+ }
+ }
+
+ /** this is only valid on driver side that exposed to being invoked by grpc
server */
+ @Override
+ public MutableShuffleHandleInfo reassignOnBlockSendFailure(
+ int shuffleId, Map<Integer, List<ReceivingFailureServer>>
partitionToFailureServers) {
+ long startTime = System.currentTimeMillis();
+ MutableShuffleHandleInfo handleInfo =
+ (MutableShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId);
+ synchronized (handleInfo) {
+ // If the reassignment servers for one partition exceeds the max
reassign server num,
+ // it should fast fail.
+ handleInfo.checkPartitionReassignServerNum(
+ partitionToFailureServers.keySet(), partitionReassignMaxServerNum);
+
+ Map<ShuffleServerInfo, List<PartitionRange>> newServerToPartitions = new
HashMap<>();
+ // receivingFailureServer -> partitionId -> replacementServerIds. For
logging
+ Map<String, Map<Integer, Set<String>>> reassignResult = new HashMap<>();
+
+ for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
+ partitionToFailureServers.entrySet()) {
+ int partitionId = entry.getKey();
+ for (ReceivingFailureServer receivingFailureServer : entry.getValue())
{
+ StatusCode code = receivingFailureServer.getStatusCode();
+ String serverId = receivingFailureServer.getServerId();
+
+ boolean serverHasReplaced = false;
+ Set<ShuffleServerInfo> replacements =
handleInfo.getReplacements(serverId);
+ if (CollectionUtils.isEmpty(replacements)) {
+ final int requiredServerNum = 1;
+ Set<String> excludedServers = new
HashSet<>(handleInfo.listExcludedServers());
+ excludedServers.add(serverId);
+ replacements =
+ reassignServerForTask(
+ shuffleId, Sets.newHashSet(partitionId), excludedServers,
requiredServerNum);
+ } else {
+ serverHasReplaced = true;
+ }
+
+ Set<ShuffleServerInfo> updatedReassignServers =
+ handleInfo.updateAssignment(partitionId, serverId, replacements);
+
+ reassignResult
+ .computeIfAbsent(serverId, x -> new HashMap<>())
+ .computeIfAbsent(partitionId, x -> new HashSet<>())
+ .addAll(
+ updatedReassignServers.stream().map(x ->
x.getId()).collect(Collectors.toSet()));
+
+ if (serverHasReplaced) {
+ for (ShuffleServerInfo serverInfo : updatedReassignServers) {
+ newServerToPartitions
+ .computeIfAbsent(serverInfo, x -> new ArrayList<>())
+ .add(new PartitionRange(partitionId, partitionId));
+ }
+ }
+ }
+ }
+ if (!newServerToPartitions.isEmpty()) {
+ LOG.info(
+ "Register the new partition->servers assignment on reassign. {}",
+ newServerToPartitions);
+ registerShuffleServers(id.get(), shuffleId, newServerToPartitions,
getRemoteStorageInfo());
+ }
+
+ LOG.info(
+ "Finished reassignOnBlockSendFailure request and cost {}(ms).
Reassign result: {}",
+ System.currentTimeMillis() - startTime,
+ reassignResult);
+
+ return handleInfo;
+ }
+ }
+
+ /**
+ * Creating the shuffleAssignmentInfo from the servers and partitionIds
+ *
+ * @param servers
+ * @param partitionIds
+ * @return
+ */
+ private ShuffleAssignmentsInfo createShuffleAssignmentsInfo(
+ Set<ShuffleServerInfo> servers, Set<Integer> partitionIds) {
+ Map<Integer, List<ShuffleServerInfo>> newPartitionToServers = new
HashMap<>();
+ List<PartitionRange> partitionRanges = new ArrayList<>();
+ for (Integer partitionId : partitionIds) {
+ newPartitionToServers.put(partitionId, new ArrayList<>(servers));
+ partitionRanges.add(new PartitionRange(partitionId, partitionId));
+ }
+ Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new
HashMap<>();
+ for (ShuffleServerInfo server : servers) {
+ serverToPartitionRanges.put(server, partitionRanges);
+ }
+ return new ShuffleAssignmentsInfo(newPartitionToServers,
serverToPartitionRanges);
+ }
+
+ /**
+  * Requests new shuffle servers to replace faulty ones and returns the servers that were granted.
+  *
+  * <p>The granted servers are re-assigned to the given partitions through
+  * {@code createShuffleAssignmentsInfo} before the assignment is registered by
+  * {@code requestShuffleAssignment}. The returned set is empty when the handler never saw a
+  * non-null assignment.
+  */
+ private Set<ShuffleServerInfo> reassignServerForTask(
+     int shuffleId,
+     Set<Integer> partitionIds,
+     Set<String> excludedServers,
+     int requiredServerNum) {
+   // The handler runs inside requestShuffleAssignment; this holder carries its result back out.
+   AtomicReference<Set<ShuffleServerInfo>> grantedServers =
+       new AtomicReference<>(new HashSet<>());
+   Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo> handler =
+       assignment -> {
+         if (assignment == null) {
+           return null;
+         }
+         Set<ShuffleServerInfo> candidates =
+             assignment.getPartitionToServers().values().stream()
+                 .flatMap(List::stream)
+                 .collect(Collectors.toSet());
+         grantedServers.set(candidates);
+         return createShuffleAssignmentsInfo(candidates, partitionIds);
+       };
+   requestShuffleAssignment(
+       shuffleId, requiredServerNum, 1, requiredServerNum, 1, excludedServers, handler);
+   return grantedServers.get();
+ }
+
+ /**
+  * Requests a shuffle assignment, optionally rewrites it through {@code reassignmentHandler},
+  * and registers the resulting server->partition-ranges mapping on the shuffle servers.
+  *
+  * <p>The request, the rewrite and the registration run together inside
+  * {@code RetryUtils.retry}, using the retry interval and retry times read from
+  * {@code sparkConf}.
+  *
+  * @param shuffleId the shuffle to request the assignment for
+  * @param partitionNum forwarded to {@code shuffleWriteClient.getShuffleAssignments}
+  * @param partitionNumPerRange forwarded to {@code shuffleWriteClient.getShuffleAssignments}
+  * @param assignmentShuffleServerNumber forwarded to
+  *     {@code shuffleWriteClient.getShuffleAssignments}
+  * @param estimateTaskConcurrency forwarded to {@code shuffleWriteClient.getShuffleAssignments}
+  * @param faultyServerIds ids of the servers to exclude; this set is modified in place: the ids
+  *     from the stage resubmit manager's blacklist are added to it
+  * @param reassignmentHandler optional rewrite applied to the received assignment before it is
+  *     registered; may be {@code null}, in which case the assignment is used as received
+  * @return the partition->servers mapping of the (possibly rewritten) assignment
+  * @throws RssException if the assignment could not be obtained and registered
+  */
+ private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
+     int shuffleId,
+     int partitionNum,
+     int partitionNumPerRange,
+     int assignmentShuffleServerNumber,
+     int estimateTaskConcurrency,
+     Set<String> faultyServerIds,
+     Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo> reassignmentHandler) {
+   Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+   ClientUtils.validateClientType(clientType);
+   assignmentTags.add(clientType);
+   long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+   int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+   faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
+   try {
+     return RetryUtils.retry(
+         () -> {
+           ShuffleAssignmentsInfo response =
+               shuffleWriteClient.getShuffleAssignments(
+                   id.get(),
+                   shuffleId,
+                   partitionNum,
+                   partitionNumPerRange,
+                   assignmentTags,
+                   assignmentShuffleServerNumber,
+                   estimateTaskConcurrency,
+                   faultyServerIds);
+           LOG.info("Finished reassign");
+           if (reassignmentHandler != null) {
+             response = reassignmentHandler.apply(response);
+           }
+           if (response == null) {
+             // The handler is allowed to return null; fail with a clear message instead of a
+             // bare NullPointerException on the dereference below.
+             throw new RssException(
+                 "No shuffle assignment is available for shuffleId: " + shuffleId);
+           }
+           registerShuffleServers(
+               id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo());
+           return response.getPartitionToServers();
+         },
+         retryInterval,
+         retryTimes);
+   } catch (Throwable throwable) {
+     throw new RssException("registerShuffle failed!", throwable);
+   }
+ }
+
+ /**
+  * Requests a shuffle assignment and registers it on the assigned shuffle servers for the given
+  * stage attempt.
+  *
+  * <p>The request and the registration run together inside {@code RetryUtils.retry}, using the
+  * retry interval and retry times read from {@code sparkConf}. The ids from the stage resubmit
+  * manager's blacklist are added to {@code faultyServerIds} in place before the request is sent.
+  *
+  * @return the partition->servers mapping of the received assignment
+  * @throws RssException if the assignment could not be obtained and registered
+  */
+ protected Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
+     int shuffleId,
+     int partitionNum,
+     int partitionNumPerRange,
+     int assignmentShuffleServerNumber,
+     int estimateTaskConcurrency,
+     Set<String> faultyServerIds,
+     int stageAttemptNumber) {
+   Set<String> tags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+   ClientUtils.validateClientType(clientType);
+   tags.add(clientType);
+
+   long intervalMs = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+   int maxRetries = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+   faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
+   try {
+     return RetryUtils.retry(
+         () -> {
+           ShuffleAssignmentsInfo assignment =
+               shuffleWriteClient.getShuffleAssignments(
+                   appId,
+                   shuffleId,
+                   partitionNum,
+                   partitionNumPerRange,
+                   tags,
+                   assignmentShuffleServerNumber,
+                   estimateTaskConcurrency,
+                   faultyServerIds);
+           Map<ShuffleServerInfo, List<PartitionRange>> serverToRanges =
+               assignment.getServerToPartitionRanges();
+           registerShuffleServers(
+               appId, shuffleId, serverToRanges, getRemoteStorageInfo(), stageAttemptNumber);
+           return assignment.getPartitionToServers();
+         },
+         intervalMs,
+         maxRetries);
+   } catch (Throwable failure) {
+     throw new RssException("registerShuffle failed!", failure);
+   }
+ }
+
+ /**
+  * Registers the shuffle, for the given stage attempt, on every server of the assignment.
+  *
+  * <p>Does nothing when {@code serverToPartitionRanges} is {@code null} or empty.
+  *
+  * @param appId the application id passed to each registration
+  * @param shuffleId the shuffle to register
+  * @param serverToPartitionRanges the partition ranges to register on each server
+  * @param remoteStorage the remote storage passed to each registration
+  * @param stageAttemptNumber the stage attempt passed to each registration
+  */
+ protected void registerShuffleServers(
+     String appId,
+     int shuffleId,
+     Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
+     RemoteStorageInfo remoteStorage,
+     int stageAttemptNumber) {
+   if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
+     return;
+   }
+   LOG.info("Start to register shuffleId {}", shuffleId);
+   long start = System.currentTimeMillis();
+   // Use the manager's distribution type, as the overload without a stage attempt does, rather
+   // than always registering NORMAL. NOTE(review): NORMAL is kept as the fallback in case a
+   // subclass never sets the field - confirm whether that can happen.
+   final ShuffleDataDistributionType distributionType =
+       dataDistributionType != null ? dataDistributionType : ShuffleDataDistributionType.NORMAL;
+   for (Map.Entry<ShuffleServerInfo, List<PartitionRange>> entry :
+       serverToPartitionRanges.entrySet()) {
+     shuffleWriteClient.registerShuffle(
+         entry.getKey(),
+         appId,
+         shuffleId,
+         entry.getValue(),
+         remoteStorage,
+         distributionType,
+         maxConcurrencyPerPartitionToWrite,
+         stageAttemptNumber);
+   }
+   LOG.info(
+       "Finish register shuffleId {} with {} ms", shuffleId, (System.currentTimeMillis() - start));
+ }
+
+ /**
+  * Registers the shuffle on every server of the assignment, using the manager's data
+  * distribution type.
+  *
+  * <p>Does nothing when {@code serverToPartitionRanges} is {@code null} or empty.
+  */
+ @VisibleForTesting
+ protected void registerShuffleServers(
+     String appId,
+     int shuffleId,
+     Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
+     RemoteStorageInfo remoteStorage) {
+   boolean nothingToRegister =
+       serverToPartitionRanges == null || serverToPartitionRanges.isEmpty();
+   if (nothingToRegister) {
+     return;
+   }
+   LOG.info("Start to register shuffleId[{}]", shuffleId);
+   final long beginMs = System.currentTimeMillis();
+   for (Map.Entry<ShuffleServerInfo, List<PartitionRange>> assignment :
+       serverToPartitionRanges.entrySet()) {
+     shuffleWriteClient.registerShuffle(
+         assignment.getKey(),
+         appId,
+         shuffleId,
+         assignment.getValue(),
+         remoteStorage,
+         dataDistributionType,
+         maxConcurrencyPerPartitionToWrite);
+   }
+   long elapsedMs = System.currentTimeMillis() - beginMs;
+   LOG.info("Finish register shuffleId[{}] with {} ms", shuffleId, elapsedMs);
+ }
+
+ /**
+  * Resolves the remote storage for this application through
+  * {@code ClientUtils.fetchRemoteStorage}, passing the remote storage path configured in
+  * {@code sparkConf} (empty string when absent) as the default.
+  */
+ protected RemoteStorageInfo getRemoteStorageInfo() {
+   String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
+   String configuredPath = sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), "");
+   RemoteStorageInfo fallbackStorage = new RemoteStorageInfo(configuredPath);
+   return ClientUtils.fetchRemoteStorage(
+       appId, fallbackStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
+ }
+
+ /** Returns the value of the {@code rssResubmitStage} flag held by this manager. */
+ public boolean isRssResubmitStage() {
+   return this.rssResubmitStage;
+ }
}
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
index b21360041..52ded5db7 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
@@ -37,12 +37,6 @@ public interface RssShuffleManagerInterface {
/** @return the unique spark id for rss shuffle */
String getAppId();
- /**
- * @return the maximum number of fetch failures per shuffle partition before
that shuffle stage
- * should be re-submitted
- */
- int getMaxFetchFailures();
-
/**
* @param shuffleId the shuffle id to query
* @return the num of partitions(a.k.a reduce tasks) for shuffle with
shuffle id.
@@ -63,6 +57,8 @@ public interface RssShuffleManagerInterface {
*/
void unregisterAllMapOutput(int shuffleId) throws SparkException;
+ BlockIdManager getBlockIdManager();
+
/**
* Get ShuffleHandleInfo with ShuffleId
*
@@ -71,6 +67,12 @@ public interface RssShuffleManagerInterface {
*/
ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId);
+ /**
+ * @return the maximum number of fetch failures per shuffle partition before
that shuffle stage
+ * should be re-submitted
+ */
+ int getMaxFetchFailures();
+
/**
* Add the shuffleServer that failed to write to the failure list
*
@@ -78,11 +80,8 @@ public interface RssShuffleManagerInterface {
*/
void addFailuresShuffleServerInfos(String shuffleServerId);
- boolean reassignAllShuffleServersForWholeStage(
- int stageId, int stageAttemptNumber, int shuffleId, int numMaps);
+ boolean reassignOnStageResubmit(int stageId, int stageAttemptNumber, int
shuffleId, int numMaps);
MutableShuffleHandleInfo reassignOnBlockSendFailure(
int shuffleId, Map<Integer, List<ReceivingFailureServer>>
partitionToFailureServers);
-
- BlockIdManager getBlockIdManager();
}
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
index a4ff727b5..425f03e65 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
@@ -30,6 +30,7 @@ import java.util.stream.Collectors;
import com.google.protobuf.UnsafeByteOperations;
import io.grpc.stub.StreamObserver;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -188,10 +189,34 @@ public class ShuffleManagerGrpcService extends
ShuffleManagerImplBase {
}
@Override
- public void getPartitionToShufflerServer(
+ public void getPartitionToShufflerServerWithStageRetry(
RssProtos.PartitionToShuffleServerRequest request,
- StreamObserver<RssProtos.PartitionToShuffleServerResponse>
responseObserver) {
- RssProtos.PartitionToShuffleServerResponse reply;
+ StreamObserver<RssProtos.ReassignOnStageRetryResponse> responseObserver)
{
+ RssProtos.ReassignOnStageRetryResponse reply;
+ RssProtos.StatusCode code;
+ int shuffleId = request.getShuffleId();
+ StageAttemptShuffleHandleInfo shuffleHandle =
+ (StageAttemptShuffleHandleInfo)
shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId);
+ if (shuffleHandle != null) {
+ code = RssProtos.StatusCode.SUCCESS;
+ reply =
+ RssProtos.ReassignOnStageRetryResponse.newBuilder()
+ .setStatus(code)
+
.setShuffleHandleInfo(StageAttemptShuffleHandleInfo.toProto(shuffleHandle))
+ .build();
+ } else {
+ code = RssProtos.StatusCode.INVALID_REQUEST;
+ reply =
RssProtos.ReassignOnStageRetryResponse.newBuilder().setStatus(code).build();
+ }
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ }
+
+ @Override
+ public void getPartitionToShufflerServerWithBlockRetry(
+ RssProtos.PartitionToShuffleServerRequest request,
+ StreamObserver<RssProtos.ReassignOnBlockSendFailureResponse>
responseObserver) {
+ RssProtos.ReassignOnBlockSendFailureResponse reply;
RssProtos.StatusCode code;
int shuffleId = request.getShuffleId();
MutableShuffleHandleInfo shuffleHandle =
@@ -199,13 +224,13 @@ public class ShuffleManagerGrpcService extends
ShuffleManagerImplBase {
if (shuffleHandle != null) {
code = RssProtos.StatusCode.SUCCESS;
reply =
- RssProtos.PartitionToShuffleServerResponse.newBuilder()
+ RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
.setStatus(code)
-
.setShuffleHandleInfo(MutableShuffleHandleInfo.toProto(shuffleHandle))
+ .setHandle(MutableShuffleHandleInfo.toProto(shuffleHandle))
.build();
} else {
code = RssProtos.StatusCode.INVALID_REQUEST;
- reply =
RssProtos.PartitionToShuffleServerResponse.newBuilder().setStatus(code).build();
+ reply =
RssProtos.ReassignOnBlockSendFailureResponse.newBuilder().setStatus(code).build();
}
responseObserver.onNext(reply);
responseObserver.onCompleted();
@@ -220,7 +245,7 @@ public class ShuffleManagerGrpcService extends
ShuffleManagerImplBase {
int shuffleId = request.getShuffleId();
int numPartitions = request.getNumPartitions();
boolean needReassign =
- shuffleManager.reassignAllShuffleServersForWholeStage(
+ shuffleManager.reassignOnStageResubmit(
stageId, stageAttemptNumber, shuffleId, numPartitions);
RssProtos.StatusCode code = RssProtos.StatusCode.SUCCESS;
RssProtos.ReassignServersResponse reply =
@@ -236,10 +261,10 @@ public class ShuffleManagerGrpcService extends
ShuffleManagerImplBase {
public void reassignOnBlockSendFailure(
org.apache.uniffle.proto.RssProtos.RssReassignOnBlockSendFailureRequest
request,
io.grpc.stub.StreamObserver<
-
org.apache.uniffle.proto.RssProtos.RssReassignOnBlockSendFailureResponse>
+
org.apache.uniffle.proto.RssProtos.ReassignOnBlockSendFailureResponse>
responseObserver) {
RssProtos.StatusCode code = RssProtos.StatusCode.INTERNAL_ERROR;
- RssProtos.RssReassignOnBlockSendFailureResponse reply;
+ RssProtos.ReassignOnBlockSendFailureResponse reply;
try {
LOG.info(
"Accepted reassign request on block sent failure for shuffleId: {},
stageId: {}, stageAttemptNumber: {} from taskAttemptId: {} on executorId: {}",
@@ -257,14 +282,14 @@ public class ShuffleManagerGrpcService extends
ShuffleManagerImplBase {
Map.Entry::getKey, x ->
ReceivingFailureServer.fromProto(x.getValue()))));
code = RssProtos.StatusCode.SUCCESS;
reply =
- RssProtos.RssReassignOnBlockSendFailureResponse.newBuilder()
+ RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
.setStatus(code)
.setHandle(MutableShuffleHandleInfo.toProto(handle))
.build();
} catch (Exception e) {
LOG.error("Errors on reassigning when block send failure.", e);
reply =
- RssProtos.RssReassignOnBlockSendFailureResponse.newBuilder()
+ RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
.setStatus(code)
.setMsg(e.getMessage())
.build();
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index ebc590af5..080ba1e33 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -70,6 +70,15 @@ public class DataPusherTest {
String appId,
List<ShuffleBlockInfo> shuffleBlockInfoList,
Supplier<Boolean> needCancelRequest) {
+ return sendShuffleData(appId, 0, shuffleBlockInfoList,
needCancelRequest);
+ }
+
+ @Override
+ public SendShuffleDataResult sendShuffleData(
+ String appId,
+ int stageAttemptNumber,
+ List<ShuffleBlockInfo> shuffleBlockInfoList,
+ Supplier<Boolean> needCancelRequest) {
return fakedShuffleDataResult;
}
diff --git
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
index 66bb26de8..df40e9d16 100644
---
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
+++
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
@@ -23,13 +23,11 @@ import java.util.Map;
import java.util.Set;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
-import org.apache.spark.shuffle.handle.ShuffleHandleInfoBase;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.shuffle.BlockIdManager;
-import static org.mockito.Mockito.mock;
-
public class DummyRssShuffleManager implements RssShuffleManagerInterface {
public Set<Integer> unregisteredShuffleIds = new LinkedHashSet<>();
@@ -38,11 +36,6 @@ public class DummyRssShuffleManager implements
RssShuffleManagerInterface {
return "testAppId";
}
- @Override
- public int getMaxFetchFailures() {
- return 2;
- }
-
@Override
public int getPartitionNum(int shuffleId) {
return 16;
@@ -59,15 +52,25 @@ public class DummyRssShuffleManager implements
RssShuffleManagerInterface {
}
@Override
- public ShuffleHandleInfoBase getShuffleHandleInfoByShuffleId(int shuffleId) {
+ public BlockIdManager getBlockIdManager() {
+ return null;
+ }
+
+ @Override
+ public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
return null;
}
+ @Override
+ public int getMaxFetchFailures() {
+ return 0;
+ }
+
@Override
public void addFailuresShuffleServerInfos(String shuffleServerId) {}
@Override
- public boolean reassignAllShuffleServersForWholeStage(
+ public boolean reassignOnStageResubmit(
int stageId, int stageAttemptNumber, int shuffleId, int numMaps) {
return false;
}
@@ -75,11 +78,6 @@ public class DummyRssShuffleManager implements
RssShuffleManagerInterface {
@Override
public MutableShuffleHandleInfo reassignOnBlockSendFailure(
int shuffleId, Map<Integer, List<ReceivingFailureServer>>
partitionToFailureServers) {
- return mock(MutableShuffleHandleInfo.class);
- }
-
- @Override
- public BlockIdManager getBlockIdManager() {
return null;
}
}
diff --git
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java
index 6dc2abbf6..ac3fbda7e 100644
---
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java
+++
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java
@@ -35,7 +35,7 @@ import static org.mockito.Mockito.mock;
public class ShuffleManagerGrpcServiceTest {
// create mock of RssShuffleManagerInterface.
- private static RssShuffleManagerInterface mockShuffleManager;
+ private static RssShuffleManagerBase mockShuffleManager;
private static final String appId = "app-123";
private static final int maxFetchFailures = 2;
private static final int shuffleId = 0;
@@ -65,7 +65,7 @@ public class ShuffleManagerGrpcServiceTest {
@BeforeAll
public static void setup() {
- mockShuffleManager = mock(RssShuffleManagerInterface.class);
+ mockShuffleManager = mock(RssShuffleManagerBase.class);
Mockito.when(mockShuffleManager.getAppId()).thenReturn(appId);
Mockito.when(mockShuffleManager.getNumMaps(shuffleId)).thenReturn(numMaps);
Mockito.when(mockShuffleManager.getPartitionNum(shuffleId)).thenReturn(numReduces);
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 0f2830721..eb48b2f38 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -37,13 +37,13 @@ import org.apache.hadoop.conf.Configuration;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
-import org.apache.spark.SparkException;
import org.apache.spark.TaskContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.apache.spark.shuffle.reader.RssShuffleReader;
import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.DataPusher;
@@ -54,20 +54,10 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.uniffle.client.api.ShuffleManagerClient;
-import org.apache.uniffle.client.api.ShuffleWriteClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
-import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
-import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.client.util.RssClientConfig;
-import org.apache.uniffle.common.ClientType;
-import org.apache.uniffle.common.PartitionRange;
-import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.RemoteStorageInfo;
-import org.apache.uniffle.common.ShuffleAssignmentsInfo;
-import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
@@ -76,7 +66,6 @@ import
org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.rpc.GrpcServer;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.JavaUtils;
-import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.common.util.ThreadUtils;
import org.apache.uniffle.shuffle.RssShuffleClientFactory;
@@ -94,10 +83,6 @@ public class RssShuffleManager extends RssShuffleManagerBase
{
private final long heartbeatInterval;
private final long heartbeatTimeout;
private ScheduledExecutorService heartBeatScheduledExecutorService;
- private SparkConf sparkConf;
- private String appId = "";
- private String clientType;
- private ShuffleWriteClient shuffleWriteClient;
private Map<String, Set<Long>> taskToSuccessBlockIds =
JavaUtils.newConcurrentMap();
private Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker =
JavaUtils.newConcurrentMap();
@@ -109,41 +94,16 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
private final int dataCommitPoolSize;
private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
private boolean heartbeatStarted = false;
- private boolean dynamicConfEnabled;
private final int maxFailures;
private final boolean speculation;
private final BlockIdLayout blockIdLayout;
private final String user;
private final String uuid;
private DataPusher dataPusher;
- private final int maxConcurrencyPerPartitionToWrite;
-
private final Map<Integer, Integer> shuffleIdToPartitionNum =
JavaUtils.newConcurrentMap();
private final Map<Integer, Integer> shuffleIdToNumMapTasks =
JavaUtils.newConcurrentMap();
private GrpcServer shuffleManagerServer;
private ShuffleManagerGrpcService service;
- private ShuffleManagerClient shuffleManagerClient;
- /**
- * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is
dynamically allocated.
- * ShuffleServer is not obtained from RssShuffleHandle, but from this
mapping.
- */
- private Map<Integer, ShuffleHandleInfo> shuffleIdToShuffleHandleInfo =
- JavaUtils.newConcurrentMap();
- /** Whether to enable the dynamic shuffleServer function rewrite and reread
functions */
- private boolean rssResubmitStage;
-
- private boolean taskBlockSendFailureRetry;
-
- private boolean shuffleManagerRpcServiceEnabled;
- /** A list of shuffleServer for Write failures */
- private Set<String> failuresShuffleServerIds = Sets.newHashSet();
- /**
- * Prevent multiple tasks from reporting FetchFailed, resulting in multiple
ShuffleServer
- * assignments, stageID, Attemptnumber Whether to reassign the combination
flag;
- */
- private Map<String, Boolean> serverAssignedInfos =
JavaUtils.newConcurrentMap();
-
- private boolean blockIdSelfManagedEnabled;
public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
@@ -178,15 +138,11 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.maxConcurrencyPerPartitionToWrite =
RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
LOG.info(
- "Check quorum config ["
- + dataReplica
- + ":"
- + dataReplicaWrite
- + ":"
- + dataReplicaRead
- + ":"
- + dataReplicaSkipEnabled
- + "]");
+ "Check quorum config [{}:{}:{}:{}]",
+ dataReplica,
+ dataReplicaWrite,
+ dataReplicaRead,
+ dataReplicaSkipEnabled);
RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite,
dataReplicaRead);
this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
@@ -212,10 +168,11 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.rssResubmitStage =
rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
&& RssSparkShuffleUtils.isStageResubmitSupported();
- this.taskBlockSendFailureRetry =
rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
+ this.taskBlockSendFailureRetryEnabled =
+ rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
this.blockIdSelfManagedEnabled =
rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
this.shuffleManagerRpcServiceEnabled =
- taskBlockSendFailureRetry || rssResubmitStage ||
blockIdSelfManagedEnabled;
+ taskBlockSendFailureRetryEnabled || rssResubmitStage ||
blockIdSelfManagedEnabled;
if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
if (isDriver) {
heartBeatScheduledExecutorService =
@@ -278,6 +235,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
poolSize,
keepAliveTime);
}
+ this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+ this.rssStageResubmitManager = new RssStageResubmitManager();
}
// This method is called in Spark driver side,
@@ -353,8 +312,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
ClientUtils.fetchRemoteStorage(
appId, defaultRemoteStorage, dynamicConfEnabled, storageType,
shuffleWriteClient);
- int partitionNumPerRange =
sparkConf.get(RssSparkConfig.RSS_PARTITION_NUM_PER_RANGE);
-
// get all register info according to coordinator's response
Set<String> assignmentTags =
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
ClientUtils.validateClientType(clientType);
@@ -362,44 +319,32 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
int requiredShuffleServerNumber =
RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
+ int estimateTaskConcurrency =
RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
- // retryInterval must bigger than `rss.server.heartbeat.interval`, or
maybe it will return the
- // same result
- long retryInterval =
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
- int retryTimes =
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
-
- Map<Integer, List<ShuffleServerInfo>> partitionToServers;
- try {
- partitionToServers =
- RetryUtils.retry(
- () -> {
- ShuffleAssignmentsInfo response =
- shuffleWriteClient.getShuffleAssignments(
- appId,
- shuffleId,
- dependency.partitioner().numPartitions(),
- partitionNumPerRange,
- assignmentTags,
- requiredShuffleServerNumber,
- -1);
- registerShuffleServers(
- appId, shuffleId, response.getServerToPartitionRanges(),
remoteStorage);
- return response.getPartitionToServers();
- },
- retryInterval,
- retryTimes);
- } catch (Throwable throwable) {
- throw new RssException("registerShuffle failed!", throwable);
- }
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+ requestShuffleAssignment(
+ shuffleId,
+ dependency.partitioner().numPartitions(),
+ 1,
+ requiredShuffleServerNumber,
+ estimateTaskConcurrency,
+ rssStageResubmitManager.getServerIdBlackList(),
+ 0);
startHeartbeat();
shuffleIdToPartitionNum.putIfAbsent(shuffleId,
dependency.partitioner().numPartitions());
shuffleIdToNumMapTasks.putIfAbsent(shuffleId,
dependency.rdd().partitions().length);
- if (shuffleManagerRpcServiceEnabled) {
- MutableShuffleHandleInfo handleInfo =
+ if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ ShuffleHandleInfo handleInfo =
new MutableShuffleHandleInfo(shuffleId, partitionToServers,
remoteStorage);
- shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
+ StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo =
+ new StageAttemptShuffleHandleInfo(shuffleId, remoteStorage,
handleInfo);
+ shuffleHandleInfoManager.register(shuffleId,
stageAttemptShuffleHandleInfo);
+ } else if (shuffleManagerRpcServiceEnabled &&
taskBlockSendFailureRetryEnabled) {
+ ShuffleHandleInfo shuffleHandleInfo =
+ new MutableShuffleHandleInfo(shuffleId, partitionToServers,
remoteStorage);
+ shuffleHandleInfoManager.register(shuffleId, shuffleHandleInfo);
}
Broadcast<SimpleShuffleHandleInfo> hdlInfoBd =
RssSparkShuffleUtils.broadcastShuffleHdlInfo(
@@ -435,37 +380,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
}
- @VisibleForTesting
- protected void registerShuffleServers(
- String appId,
- int shuffleId,
- Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
- RemoteStorageInfo remoteStorage) {
- if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
- return;
- }
- LOG.info("Start to register shuffleId[" + shuffleId + "]");
- long start = System.currentTimeMillis();
- serverToPartitionRanges.entrySet().stream()
- .forEach(
- entry -> {
- shuffleWriteClient.registerShuffle(
- entry.getKey(),
- appId,
- shuffleId,
- entry.getValue(),
- remoteStorage,
- ShuffleDataDistributionType.NORMAL,
- maxConcurrencyPerPartitionToWrite);
- });
- LOG.info(
- "Finish register shuffleId["
- + shuffleId
- + "] with "
- + (System.currentTimeMillis() - start)
- + " ms");
- }
-
@VisibleForTesting
protected void registerCoordinator() {
String coordinators =
sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
@@ -493,9 +407,12 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
int shuffleId = rssHandle.getShuffleId();
String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled) {
- // Get the ShuffleServer list from the Driver based on the shuffleId
- shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+ if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ // In Stage Retry mode, Get the ShuffleServer list from the Driver
based on the shuffleId
+ shuffleHandleInfo =
getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
+ } else if (shuffleManagerRpcServiceEnabled &&
taskBlockSendFailureRetryEnabled) {
+ // In Block Retry mode, Get the ShuffleServer list from the Driver
based on the shuffleId
+ shuffleHandleInfo =
getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
} else {
shuffleHandleInfo =
new SimpleShuffleHandleInfo(
@@ -563,9 +480,12 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
+ "]");
start = System.currentTimeMillis();
ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled) {
- // Get the ShuffleServer list from the Driver based on the shuffleId
- shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+ if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ // In Stage Retry mode, Get the ShuffleServer list from the Driver
based on the shuffleId.
+ shuffleHandleInfo =
getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
+ } else if (shuffleManagerRpcServiceEnabled &&
taskBlockSendFailureRetryEnabled) {
+ // In Block Retry mode, Get the ShuffleServer list from the Driver
based on the shuffleId
+ shuffleHandleInfo =
getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
} else {
shuffleHandleInfo =
new SimpleShuffleHandleInfo(
@@ -764,16 +684,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
return appId;
}
- /**
- * @return the maximum number of fetch failures per shuffle partition before
that shuffle stage
- * should be recomputed
- */
- @Override
- public int getMaxFetchFailures() {
- final String TASK_MAX_FAILURE = "spark.task.maxFailures";
- return Math.max(1, sparkConf.getInt(TASK_MAX_FAILURE, 4) - 1);
- }
-
/**
* @param shuffleId the shuffleId to query
* @return the num of partitions(a.k.a reduce tasks) for shuffle with
shuffle id.
@@ -812,170 +722,14 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
return taskToFailedBlockSendTracker.get(taskId);
}
- @Override
- public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
- return shuffleIdToShuffleHandleInfo.get(shuffleId);
- }
-
- private ShuffleManagerClient createShuffleManagerClient(String host, int
port) {
- // Host can be inferred from `spark.driver.bindAddress`, which would be
set when SparkContext is
- // constructed.
- return ShuffleManagerClientFactory.getInstance()
- .createShuffleManagerClient(ClientType.GRPC, host, port);
- }
-
- private ShuffleManagerClient getOrCreateShuffleManagerClient() {
- if (shuffleManagerClient == null) {
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- String driver = rssConf.getString("driver.host", "");
- int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
- this.shuffleManagerClient =
- ShuffleManagerClientFactory.getInstance()
- .createShuffleManagerClient(ClientType.GRPC, driver, port);
- }
- return shuffleManagerClient;
- }
-
- /**
- * Get the ShuffleServer list from the Driver based on the shuffleId
- *
- * @param shuffleId shuffleId
- * @return ShuffleHandleInfo
- */
- private synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfo(int
shuffleId) {
- RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
- new RssPartitionToShuffleServerRequest(shuffleId);
- RssPartitionToShuffleServerResponse handleInfoResponse =
- getOrCreateShuffleManagerClient()
- .getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
- MutableShuffleHandleInfo shuffleHandleInfo =
-
MutableShuffleHandleInfo.fromProto(handleInfoResponse.getShuffleHandleInfoProto());
- return shuffleHandleInfo;
- }
-
- /**
- * Add the shuffleServer that failed to write to the failure list
- *
- * @param shuffleServerId
- */
- @Override
- public void addFailuresShuffleServerInfos(String shuffleServerId) {
- failuresShuffleServerIds.add(shuffleServerId);
- }
-
- /**
- * Reassign the ShuffleServer list for ShuffleId
- *
- * @param shuffleId
- * @param numPartitions
- */
- @Override
- public synchronized boolean reassignAllShuffleServersForWholeStage(
- int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) {
- String stageIdAndAttempt = stageId + "_" + stageAttemptNumber;
- Boolean needReassign =
serverAssignedInfos.computeIfAbsent(stageIdAndAttempt, id -> false);
- if (!needReassign) {
- int requiredShuffleServerNumber =
- RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
- int estimateTaskConcurrency =
RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
- /** Before reassigning ShuffleServer, clear the ShuffleServer list in
ShuffleWriteClient. */
- shuffleWriteClient.unregisterShuffle(appId, shuffleId);
- Map<Integer, List<ShuffleServerInfo>> partitionToServers =
- requestShuffleAssignment(
- shuffleId,
- numPartitions,
- 1,
- requiredShuffleServerNumber,
- estimateTaskConcurrency,
- failuresShuffleServerIds);
- /**
- * we need to clear the metadata of the completed task, otherwise some
of the stage's data
- * will be lost
- */
- try {
- unregisterAllMapOutput(shuffleId);
- } catch (SparkException e) {
- LOG.error("Clear MapoutTracker Meta failed!");
- throw new RssException("Clear MapoutTracker Meta failed!", e);
- }
- MutableShuffleHandleInfo handleInfo =
- new MutableShuffleHandleInfo(shuffleId, partitionToServers,
getRemoteStorageInfo());
- shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
- serverAssignedInfos.put(stageIdAndAttempt, true);
- return true;
- } else {
- LOG.info(
- "The Stage:{} has been reassigned in an Attempt{},Return without
performing any operation",
- stageId,
- stageAttemptNumber);
- return false;
- }
- }
-
- @Override
- public MutableShuffleHandleInfo reassignOnBlockSendFailure(
- int shuffleId, Map<Integer, List<ReceivingFailureServer>>
partitionToFailureServers) {
- throw new RssException("Illegal access for reassignOnBlockSendFailure that
is not supported.");
- }
-
private ShuffleServerInfo assignShuffleServer(int shuffleId, String
faultyShuffleServerId) {
Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
- faultyServerIds.addAll(failuresShuffleServerIds);
+ faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
Map<Integer, List<ShuffleServerInfo>> partitionToServers =
- requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds);
+ requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds, 0);
if (partitionToServers.get(0) != null && partitionToServers.get(0).size()
== 1) {
return partitionToServers.get(0).get(0);
}
return null;
}
-
- private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
- int shuffleId,
- int partitionNum,
- int partitionNumPerRange,
- int assignmentShuffleServerNumber,
- int estimateTaskConcurrency,
- Set<String> faultyServerIds) {
- Set<String> assignmentTags =
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
- ClientUtils.validateClientType(clientType);
- assignmentTags.add(clientType);
-
- long retryInterval =
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
- int retryTimes =
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
- faultyServerIds.addAll(failuresShuffleServerIds);
- try {
- return RetryUtils.retry(
- () -> {
- ShuffleAssignmentsInfo response =
- shuffleWriteClient.getShuffleAssignments(
- appId,
- shuffleId,
- partitionNum,
- partitionNumPerRange,
- assignmentTags,
- assignmentShuffleServerNumber,
- estimateTaskConcurrency,
- faultyServerIds);
- registerShuffleServers(
- appId, shuffleId, response.getServerToPartitionRanges(),
getRemoteStorageInfo());
- return response.getPartitionToServers();
- },
- retryInterval,
- retryTimes);
- } catch (Throwable throwable) {
- throw new RssException("registerShuffle failed!", throwable);
- }
- }
-
- private RemoteStorageInfo getRemoteStorageInfo() {
- String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
- RemoteStorageInfo defaultRemoteStorage =
- new
RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(),
""));
- return ClientUtils.fetchRemoteStorage(
- appId, defaultRemoteStorage, dynamicConfEnabled, storageType,
shuffleWriteClient);
- }
-
- public boolean isRssResubmitStage() {
- return rssResubmitStage;
- }
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 81d54ec10..b40b55028 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -551,6 +551,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
RssReportShuffleWriteFailureResponse response =
shuffleManagerClient.reportShuffleWriteFailure(req);
if (response.getReSubmitWholeStage()) {
+ // The shuffle server is reassigned.
RssReassignServersRequest rssReassignServersRequest =
new RssReassignServersRequest(
taskContext.stageId(),
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index dc1de59ef..b5e89d1da 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -18,10 +18,7 @@
package org.apache.spark.shuffle;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -30,7 +27,6 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Function;
import java.util.stream.Collectors;
import scala.Tuple2;
@@ -41,13 +37,11 @@ import scala.collection.Seq;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Sets;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
-import org.apache.commons.collections4.CollectionUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
-import org.apache.spark.SparkException;
import org.apache.spark.TaskContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.executor.ShuffleReadMetrics;
@@ -55,6 +49,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.apache.spark.shuffle.reader.RssShuffleReader;
import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.DataPusher;
@@ -67,19 +62,10 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
-import org.apache.uniffle.client.api.ShuffleManagerClient;
-import org.apache.uniffle.client.api.ShuffleWriteClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
-import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
-import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.client.util.RssClientConfig;
-import org.apache.uniffle.common.ClientType;
-import org.apache.uniffle.common.PartitionRange;
-import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.RemoteStorageInfo;
-import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
@@ -87,10 +73,8 @@ import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.rpc.GrpcServer;
-import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.JavaUtils;
-import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.common.util.ThreadUtils;
import org.apache.uniffle.shuffle.RssShuffleClientFactory;
@@ -105,10 +89,8 @@ import static
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER
public class RssShuffleManager extends RssShuffleManagerBase {
private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleManager.class);
- private final String clientType;
private final long heartbeatInterval;
private final long heartbeatTimeout;
- private AtomicReference<String> id = new AtomicReference<>();
private final int dataReplica;
private final int dataReplicaWrite;
private final int dataReplicaRead;
@@ -119,10 +101,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
private final Map<String, FailedBlockSendTracker>
taskToFailedBlockSendTracker;
private ScheduledExecutorService heartBeatScheduledExecutorService;
private boolean heartbeatStarted = false;
- private boolean dynamicConfEnabled;
- private final ShuffleDataDistributionType dataDistributionType;
private final BlockIdLayout blockIdLayout;
- private final int maxConcurrencyPerPartitionToWrite;
private final int maxFailures;
private final boolean speculation;
private String user;
@@ -135,31 +114,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
private ShuffleManagerGrpcService service;
private GrpcServer shuffleManagerServer;
- /** used by columnar rss shuffle writer implementation */
- protected SparkConf sparkConf;
-
- protected ShuffleWriteClient shuffleWriteClient;
-
- private ShuffleManagerClient shuffleManagerClient;
- /** Whether to enable the dynamic shuffleServer function rewrite and reread
functions */
- private boolean rssResubmitStage;
-
- private boolean taskBlockSendFailureRetryEnabled;
-
- private boolean shuffleManagerRpcServiceEnabled;
- /** A list of shuffleServer for Write failures */
- private Set<String> failuresShuffleServerIds;
- /**
- * Prevent multiple tasks from reporting FetchFailed, resulting in multiple
ShuffleServer
- * assignments, stageID, Attemptnumber Whether to reassign the combination
flag;
- */
- private Map<String, Boolean> serverAssignedInfos;
-
- private final int partitionReassignMaxServerNum;
-
- private final ShuffleHandleInfoManager shuffleHandleInfoManager = new
ShuffleHandleInfoManager();
- private boolean blockIdSelfManagedEnabled;
-
public RssShuffleManager(SparkConf conf, boolean isDriver) {
this.sparkConf = conf;
boolean supportsRelocation =
@@ -307,10 +261,10 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
failedTaskIds,
poolSize,
keepAliveTime);
- this.failuresShuffleServerIds = Sets.newHashSet();
- this.serverAssignedInfos = JavaUtils.newConcurrentMap();
this.partitionReassignMaxServerNum =
rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
+ this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+ this.rssStageResubmitManager = new RssStageResubmitManager();
}
public CompletableFuture<Long> sendData(AddBlockEvent event) {
@@ -403,6 +357,8 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.dataPusher = dataPusher;
this.partitionReassignMaxServerNum =
rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
+ this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+ this.rssStageResubmitManager = new RssStageResubmitManager();
}
// This method is called in Spark driver side,
@@ -444,6 +400,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
if (id.get() == null) {
id.compareAndSet(null, SparkEnv.get().conf().getAppId() + "_" + uuid);
+ appId = id.get();
dataPusher.setRssAppId(id.get());
}
LOG.info("Generate application id used in rss: " + id.get());
@@ -478,43 +435,30 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
int requiredShuffleServerNumber =
RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
-
- // retryInterval must bigger than `rss.server.heartbeat.interval`, or
maybe it will return the
- // same result
- long retryInterval =
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
- int retryTimes =
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
int estimateTaskConcurrency =
RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
- Map<Integer, List<ShuffleServerInfo>> partitionToServers;
- try {
- partitionToServers =
- RetryUtils.retry(
- () -> {
- ShuffleAssignmentsInfo response =
- shuffleWriteClient.getShuffleAssignments(
- id.get(),
- shuffleId,
- dependency.partitioner().numPartitions(),
- 1,
- assignmentTags,
- requiredShuffleServerNumber,
- estimateTaskConcurrency);
- registerShuffleServers(
- id.get(), shuffleId,
response.getServerToPartitionRanges(), remoteStorage);
- return response.getPartitionToServers();
- },
- retryInterval,
- retryTimes);
- } catch (Throwable throwable) {
- throw new RssException("registerShuffle failed!", throwable);
- }
- startHeartbeat();
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+ requestShuffleAssignment(
+ shuffleId,
+ dependency.partitioner().numPartitions(),
+ 1,
+ requiredShuffleServerNumber,
+ estimateTaskConcurrency,
+ rssStageResubmitManager.getServerIdBlackList(),
+ 0);
+ startHeartbeat();
shuffleIdToPartitionNum.putIfAbsent(shuffleId,
dependency.partitioner().numPartitions());
shuffleIdToNumMapTasks.putIfAbsent(shuffleId,
dependency.rdd().partitions().length);
- if (shuffleManagerRpcServiceEnabled) {
- MutableShuffleHandleInfo handleInfo =
+ if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ ShuffleHandleInfo shuffleHandleInfo =
new MutableShuffleHandleInfo(shuffleId, partitionToServers,
remoteStorage);
+ StageAttemptShuffleHandleInfo handleInfo =
+ new StageAttemptShuffleHandleInfo(shuffleId, remoteStorage,
shuffleHandleInfo);
shuffleHandleInfoManager.register(shuffleId, handleInfo);
+ } else if (shuffleManagerRpcServiceEnabled &&
taskBlockSendFailureRetryEnabled) {
+ ShuffleHandleInfo shuffleHandleInfo =
+ new MutableShuffleHandleInfo(shuffleId, partitionToServers,
remoteStorage);
+ shuffleHandleInfoManager.register(shuffleId, shuffleHandleInfo);
}
Broadcast<SimpleShuffleHandleInfo> hdlInfoBd =
RssSparkShuffleUtils.broadcastShuffleHdlInfo(
@@ -549,9 +493,12 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
}
ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled) {
- // Get the ShuffleServer list from the Driver based on the shuffleId
- shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+ if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ // In Stage Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
+ shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
+ } else if (shuffleManagerRpcServiceEnabled &&
taskBlockSendFailureRetryEnabled) {
+ // In Block Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
+ shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
} else {
shuffleHandleInfo =
new SimpleShuffleHandleInfo(
@@ -691,9 +638,12 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
final int partitionNum =
rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled) {
- // Get the ShuffleServer list from the Driver based on the shuffleId
- shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
+ if (shuffleManagerRpcServiceEnabled && rssResubmitStage) {
+ // In Stage Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
+ shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
+ } else if (shuffleManagerRpcServiceEnabled &&
taskBlockSendFailureRetryEnabled) {
+ // In Block Retry mode, Get the ShuffleServer list from the Driver based
on the shuffleId.
+ shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
} else {
shuffleHandleInfo =
new SimpleShuffleHandleInfo(
@@ -940,39 +890,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
taskToFailedBlockSendTracker.remove(taskId);
}
- @VisibleForTesting
- protected void registerShuffleServers(
- String appId,
- int shuffleId,
- Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
- RemoteStorageInfo remoteStorage) {
- if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
- return;
- }
- LOG.info("Start to register shuffleId[" + shuffleId + "]");
- long start = System.currentTimeMillis();
- Set<Map.Entry<ShuffleServerInfo, List<PartitionRange>>> entries =
- serverToPartitionRanges.entrySet();
- entries.stream()
- .forEach(
- entry -> {
- shuffleWriteClient.registerShuffle(
- entry.getKey(),
- appId,
- shuffleId,
- entry.getValue(),
- remoteStorage,
- dataDistributionType,
- maxConcurrencyPerPartitionToWrite);
- });
- LOG.info(
- "Finish register shuffleId["
- + shuffleId
- + "] with "
- + (System.currentTimeMillis() - start)
- + " ms");
- }
-
@VisibleForTesting
protected void registerCoordinator() {
String coordinators =
sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
@@ -1027,16 +944,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
return id.get();
}
- /**
- * @return the maximum number of fetch failures per shuffle partition before
that shuffle stage
- * should be recomputed
- */
- @Override
- public int getMaxFetchFailures() {
- final String TASK_MAX_FAILURE = "spark.task.maxFailures";
- return Math.max(1, sparkConf.getInt(TASK_MAX_FAILURE, 4) - 1);
- }
-
@Override
public int getPartitionNum(int shuffleId) {
return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
@@ -1138,282 +1045,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
return taskToFailedBlockSendTracker.get(taskId);
}
- @Override
- public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
- return shuffleHandleInfoManager.get(shuffleId);
- }
-
- // todo: automatic close client when the client is idle to avoid too much
connections for spark
- // driver.
- private ShuffleManagerClient getOrCreateShuffleManagerClient() {
- if (shuffleManagerClient == null) {
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- String driver = rssConf.getString("driver.host", "");
- int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
- this.shuffleManagerClient =
- ShuffleManagerClientFactory.getInstance()
- .createShuffleManagerClient(ClientType.GRPC, driver, port);
- }
- return shuffleManagerClient;
- }
-
- /**
- * Get the ShuffleServer list from the Driver based on the shuffleId
- *
- * @param shuffleId shuffleId
- * @return ShuffleHandleInfo
- */
- private synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfo(int
shuffleId) {
- RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
- new RssPartitionToShuffleServerRequest(shuffleId);
- RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer =
- getOrCreateShuffleManagerClient()
- .getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
- MutableShuffleHandleInfo shuffleHandleInfo =
- MutableShuffleHandleInfo.fromProto(
- rpcPartitionToShufflerServer.getShuffleHandleInfoProto());
- return shuffleHandleInfo;
- }
-
- /**
- * Add the shuffleServer that failed to write to the failure list
- *
- * @param shuffleServerId
- */
- @Override
- public void addFailuresShuffleServerInfos(String shuffleServerId) {
- failuresShuffleServerIds.add(shuffleServerId);
- }
-
- /**
- * Reassign the ShuffleServer list for ShuffleId
- *
- * @param shuffleId
- * @param numPartitions
- */
- @Override
- public synchronized boolean reassignAllShuffleServersForWholeStage(
- int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) {
- String stageIdAndAttempt = stageId + "_" + stageAttemptNumber;
- Boolean needReassign =
serverAssignedInfos.computeIfAbsent(stageIdAndAttempt, id -> false);
- if (!needReassign) {
- int requiredShuffleServerNumber =
- RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
- int estimateTaskConcurrency =
RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
- /** Before reassigning ShuffleServer, clear the ShuffleServer list in
ShuffleWriteClient. */
- shuffleWriteClient.unregisterShuffle(id.get(), shuffleId);
- Map<Integer, List<ShuffleServerInfo>> partitionToServers =
- requestShuffleAssignment(
- shuffleId,
- numPartitions,
- 1,
- requiredShuffleServerNumber,
- estimateTaskConcurrency,
- failuresShuffleServerIds,
- null);
- /**
- * we need to clear the metadata of the completed task, otherwise some
of the stage's data
- * will be lost
- */
- try {
- unregisterAllMapOutput(shuffleId);
- } catch (SparkException e) {
- LOG.error("Clear MapoutTracker Meta failed!");
- throw new RssException("Clear MapoutTracker Meta failed!", e);
- }
- MutableShuffleHandleInfo handleInfo =
- new MutableShuffleHandleInfo(shuffleId, partitionToServers,
getRemoteStorageInfo());
- shuffleHandleInfoManager.register(shuffleId, handleInfo);
- serverAssignedInfos.put(stageIdAndAttempt, true);
- return true;
- } else {
- LOG.info(
- "The Stage:{} has been reassigned in an Attempt{},Return without
performing any operation",
- stageId,
- stageAttemptNumber);
- return false;
- }
- }
-
- /** this is only valid on driver side that exposed to being invoked by grpc
server */
- @Override
- public MutableShuffleHandleInfo reassignOnBlockSendFailure(
- int shuffleId, Map<Integer, List<ReceivingFailureServer>>
partitionToFailureServers) {
- long startTime = System.currentTimeMillis();
- MutableShuffleHandleInfo handleInfo =
- (MutableShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId);
- synchronized (handleInfo) {
- // If the reassignment servers for one partition exceeds the max
reassign server num,
- // it should fast fail.
- handleInfo.checkPartitionReassignServerNum(
- partitionToFailureServers.keySet(), partitionReassignMaxServerNum);
-
- Map<ShuffleServerInfo, List<PartitionRange>> newServerToPartitions = new
HashMap<>();
- // receivingFailureServer -> partitionId -> replacementServerIds. For
logging
- Map<String, Map<Integer, Set<String>>> reassignResult = new HashMap<>();
-
- for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
- partitionToFailureServers.entrySet()) {
- int partitionId = entry.getKey();
- for (ReceivingFailureServer receivingFailureServer : entry.getValue())
{
- StatusCode code = receivingFailureServer.getStatusCode();
- String serverId = receivingFailureServer.getServerId();
-
- boolean serverHasReplaced = false;
- Set<ShuffleServerInfo> replacements =
handleInfo.getReplacements(serverId);
- if (CollectionUtils.isEmpty(replacements)) {
- final int requiredServerNum = 1;
- Set<String> excludedServers = new
HashSet<>(handleInfo.listExcludedServers());
- excludedServers.add(serverId);
- replacements =
- reassignServerForTask(
- shuffleId, Sets.newHashSet(partitionId), excludedServers,
requiredServerNum);
- } else {
- serverHasReplaced = true;
- }
-
- Set<ShuffleServerInfo> updatedReassignServers =
- handleInfo.updateAssignment(partitionId, serverId, replacements);
-
- reassignResult
- .computeIfAbsent(serverId, x -> new HashMap<>())
- .computeIfAbsent(partitionId, x -> new HashSet<>())
- .addAll(
- updatedReassignServers.stream().map(x ->
x.getId()).collect(Collectors.toSet()));
-
- if (serverHasReplaced) {
- for (ShuffleServerInfo serverInfo : updatedReassignServers) {
- newServerToPartitions
- .computeIfAbsent(serverInfo, x -> new ArrayList<>())
- .add(new PartitionRange(partitionId, partitionId));
- }
- }
- }
- }
- if (!newServerToPartitions.isEmpty()) {
- LOG.info(
- "Register the new partition->servers assignment on reassign. {}",
- newServerToPartitions);
- registerShuffleServers(id.get(), shuffleId, newServerToPartitions,
getRemoteStorageInfo());
- }
-
- LOG.info(
- "Finished reassignOnBlockSendFailure request and cost {}(ms).
Reassign result: {}",
- System.currentTimeMillis() - startTime,
- reassignResult);
-
- return handleInfo;
- }
- }
-
- /**
- * Creating the shuffleAssignmentInfo from the servers and partitionIds
- *
- * @param servers
- * @param partitionIds
- * @return
- */
- private ShuffleAssignmentsInfo createShuffleAssignmentsInfo(
- Set<ShuffleServerInfo> servers, Set<Integer> partitionIds) {
- Map<Integer, List<ShuffleServerInfo>> newPartitionToServers = new
HashMap<>();
- List<PartitionRange> partitionRanges = new ArrayList<>();
- for (Integer partitionId : partitionIds) {
- newPartitionToServers.put(partitionId, new ArrayList<>(servers));
- partitionRanges.add(new PartitionRange(partitionId, partitionId));
- }
- Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new
HashMap<>();
- for (ShuffleServerInfo server : servers) {
- serverToPartitionRanges.put(server, partitionRanges);
- }
- return new ShuffleAssignmentsInfo(newPartitionToServers,
serverToPartitionRanges);
- }
-
- /** Request the new shuffle-servers to replace faulty server. */
- private Set<ShuffleServerInfo> reassignServerForTask(
- int shuffleId,
- Set<Integer> partitionIds,
- Set<String> excludedServers,
- int requiredServerNum) {
- AtomicReference<Set<ShuffleServerInfo>> replacementsRef =
- new AtomicReference<>(new HashSet<>());
- requestShuffleAssignment(
- shuffleId,
- requiredServerNum,
- 1,
- requiredServerNum,
- 1,
- excludedServers,
- shuffleAssignmentsInfo -> {
- if (shuffleAssignmentsInfo == null) {
- return null;
- }
- Set<ShuffleServerInfo> replacements =
- shuffleAssignmentsInfo.getPartitionToServers().values().stream()
- .flatMap(x -> x.stream())
- .collect(Collectors.toSet());
- replacementsRef.set(replacements);
- return createShuffleAssignmentsInfo(replacements, partitionIds);
- });
- return replacementsRef.get();
- }
-
- private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
- int shuffleId,
- int partitionNum,
- int partitionNumPerRange,
- int assignmentShuffleServerNumber,
- int estimateTaskConcurrency,
- Set<String> faultyServerIds,
- Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo>
reassignmentHandler) {
- Set<String> assignmentTags =
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
- ClientUtils.validateClientType(clientType);
- assignmentTags.add(clientType);
- long retryInterval =
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
- int retryTimes =
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
- faultyServerIds.addAll(failuresShuffleServerIds);
- try {
- return RetryUtils.retry(
- () -> {
- ShuffleAssignmentsInfo response =
- shuffleWriteClient.getShuffleAssignments(
- id.get(),
- shuffleId,
- partitionNum,
- partitionNumPerRange,
- assignmentTags,
- assignmentShuffleServerNumber,
- estimateTaskConcurrency,
- faultyServerIds);
- LOG.info("Finished the shuffle assignment request to
coordinator.");
- if (reassignmentHandler != null) {
- response = reassignmentHandler.apply(response);
- }
- LOG.info(
- "Register the partition->servers assignment. {}",
- response.getServerToPartitionRanges());
- registerShuffleServers(
- id.get(), shuffleId, response.getServerToPartitionRanges(),
getRemoteStorageInfo());
- return response.getPartitionToServers();
- },
- retryInterval,
- retryTimes);
- } catch (Throwable throwable) {
- throw new RssException("Errors on requesting shuffle assignment!",
throwable);
- }
- }
-
- private RemoteStorageInfo getRemoteStorageInfo() {
- String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
- RemoteStorageInfo defaultRemoteStorage =
- new
RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(),
""));
- return ClientUtils.fetchRemoteStorage(
- id.get(), defaultRemoteStorage, dynamicConfEnabled, storageType,
shuffleWriteClient);
- }
-
- public boolean isRssResubmitStage() {
- return rssResubmitStage;
- }
-
@VisibleForTesting
public void setDataPusher(DataPusher dataPusher) {
this.dataPusher = dataPusher;
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index 0805dfe35..d0eab5ae1 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -588,7 +588,8 @@ public class WriteBufferManagerTest {
List<PartitionRange> partitionRanges,
RemoteStorageInfo remoteStorage,
ShuffleDataDistributionType dataDistributionType,
- int maxConcurrencyPerPartitionToWrite) {}
+ int maxConcurrencyPerPartitionToWrite,
+ int stageAttemptNumber) {}
@Override
public boolean sendCommit(
diff --git
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index 7d8f53392..efd39e35a 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -35,6 +35,16 @@ import org.apache.uniffle.common.ShuffleServerInfo;
public interface ShuffleWriteClient {
+ default SendShuffleDataResult sendShuffleData(
+ String appId,
+ int stageAttemptNumber,
+ List<ShuffleBlockInfo> shuffleBlockInfoList,
+ Supplier<Boolean> needCancelRequest) {
+ throw new UnsupportedOperationException(
+ this.getClass().getName()
+ + " doesn't implement getShuffleAssignments with faultyServerIds");
+ }
+
SendShuffleDataResult sendShuffleData(
String appId,
List<ShuffleBlockInfo> shuffleBlockInfoList,
@@ -44,6 +54,25 @@ public interface ShuffleWriteClient {
void registerApplicationInfo(String appId, long timeoutMs, String user);
+ default void registerShuffle(
+ ShuffleServerInfo shuffleServerInfo,
+ String appId,
+ int shuffleId,
+ List<PartitionRange> partitionRanges,
+ RemoteStorageInfo remoteStorage,
+ ShuffleDataDistributionType dataDistributionType,
+ int maxConcurrencyPerPartitionToWrite) {
+ registerShuffle(
+ shuffleServerInfo,
+ appId,
+ shuffleId,
+ partitionRanges,
+ remoteStorage,
+ dataDistributionType,
+ maxConcurrencyPerPartitionToWrite,
+ 0);
+ }
+
void registerShuffle(
ShuffleServerInfo shuffleServerInfo,
String appId,
@@ -51,7 +80,8 @@ public interface ShuffleWriteClient {
List<PartitionRange> partitionRanges,
RemoteStorageInfo remoteStorage,
ShuffleDataDistributionType dataDistributionType,
- int maxConcurrencyPerPartitionToWrite);
+ int maxConcurrencyPerPartitionToWrite,
+ int stageAttemptNumber);
boolean sendCommit(
Set<ShuffleServerInfo> shuffleServerInfoSet, String appId, int
shuffleId, int numMaps);
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index dc6ef420a..a79a8557a 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -163,6 +163,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
private boolean sendShuffleDataAsync(
String appId,
+ int stageAttemptNumber,
Map<ShuffleServerInfo, Map<Integer, Map<Integer,
List<ShuffleBlockInfo>>>> serverToBlocks,
Map<ShuffleServerInfo, List<Long>> serverToBlockIds,
Map<Long, AtomicInteger> blockIdsSendSuccessTracker,
@@ -192,7 +193,11 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
// todo: compact unnecessary blocks that reach
replicaWrite
RssSendShuffleDataRequest request =
new RssSendShuffleDataRequest(
- appId, retryMax, retryIntervalMax,
shuffleIdToBlocks);
+ appId,
+ stageAttemptNumber,
+ retryMax,
+ retryIntervalMax,
+ shuffleIdToBlocks);
long s = System.currentTimeMillis();
RssSendShuffleDataResponse response =
getShuffleServerClient(ssi).sendShuffleData(request);
@@ -314,10 +319,19 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
});
}
+ @Override
+ public SendShuffleDataResult sendShuffleData(
+ String appId,
+ List<ShuffleBlockInfo> shuffleBlockInfoList,
+ Supplier<Boolean> needCancelRequest) {
+ return sendShuffleData(appId, 0, shuffleBlockInfoList, needCancelRequest);
+ }
+
/** The batch of sending belongs to the same task */
@Override
public SendShuffleDataResult sendShuffleData(
String appId,
+ int stageAttemptNumber,
List<ShuffleBlockInfo> shuffleBlockInfoList,
Supplier<Boolean> needCancelRequest) {
@@ -400,6 +414,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
boolean isAllSuccess =
sendShuffleDataAsync(
appId,
+ stageAttemptNumber,
primaryServerToBlocks,
primaryServerToBlockIds,
blockIdsSendSuccessTracker,
@@ -416,6 +431,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
LOG.info("The sending of primary round is failed partially, so start the
secondary round");
sendShuffleDataAsync(
appId,
+ stageAttemptNumber,
secondaryServerToBlocks,
secondaryServerToBlockIds,
blockIdsSendSuccessTracker,
@@ -545,7 +561,8 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
List<PartitionRange> partitionRanges,
RemoteStorageInfo remoteStorage,
ShuffleDataDistributionType dataDistributionType,
- int maxConcurrencyPerPartitionToWrite) {
+ int maxConcurrencyPerPartitionToWrite,
+ int stageAttemptNumber) {
String user = null;
try {
user = UserGroupInformation.getCurrentUser().getShortUserName();
@@ -561,7 +578,8 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
remoteStorage,
user,
dataDistributionType,
- maxConcurrencyPerPartitionToWrite);
+ maxConcurrencyPerPartitionToWrite,
+ stageAttemptNumber);
RssRegisterShuffleResponse response =
getShuffleServerClient(shuffleServerInfo).registerShuffle(request);
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
index a77b0d3c7..9fefb98f6 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
@@ -30,6 +30,8 @@ import org.apache.uniffle.common.util.ByteBufUtils;
public class SendShuffleDataRequest extends RequestMessage {
private String appId;
private int shuffleId;
+
+ private int stageAttemptNumber;
private long requireId;
private Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks;
private long timestamp;
@@ -41,12 +43,24 @@ public class SendShuffleDataRequest extends RequestMessage {
long requireId,
Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks,
long timestamp) {
+ this(requestId, appId, shuffleId, 0, requireId, partitionToBlocks,
timestamp);
+ }
+
+ public SendShuffleDataRequest(
+ long requestId,
+ String appId,
+ int shuffleId,
+ int stageAttemptNumber,
+ long requireId,
+ Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks,
+ long timestamp) {
super(requestId);
this.appId = appId;
this.shuffleId = shuffleId;
this.requireId = requireId;
this.partitionToBlocks = partitionToBlocks;
this.timestamp = timestamp;
+ this.stageAttemptNumber = stageAttemptNumber;
}
@Override
@@ -146,6 +160,10 @@ public class SendShuffleDataRequest extends RequestMessage
{
this.timestamp = timestamp;
}
+ public int getStageAttemptNumber() {
+ return stageAttemptNumber;
+ }
+
@Override
public String getOperationType() {
return "sendShuffleData";
diff --git a/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java
b/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java
index 79e35ecab..ff8ac231c 100644
--- a/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java
+++ b/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java
@@ -35,6 +35,7 @@ public enum StatusCode {
ACCESS_DENIED(8),
INVALID_REQUEST(9),
NO_BUFFER_FOR_HUGE_PARTITION(10),
+ STAGE_RETRY_IGNORE(11),
UNKNOWN(-1);
static final Map<Integer, StatusCode> VALUE_MAP =
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
index 6b6ee1ece..71dc66584 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
@@ -28,8 +28,8 @@ import
org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.request.RssReportShuffleResultRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import org.apache.uniffle.client.response.RssGetShuffleResultResponse;
-import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
+import org.apache.uniffle.client.response.RssReassignOnStageRetryResponse;
import org.apache.uniffle.client.response.RssReassignServersResponse;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
@@ -40,12 +40,23 @@ public interface ShuffleManagerClient extends Closeable {
RssReportShuffleFetchFailureRequest request);
/**
- * Gets the mapping between partitions and ShuffleServer from the
ShuffleManager server
+ * In Stage Retry mode, gets the mapping between partitions and ShuffleServer
from the
+ * ShuffleManager server.
*
* @param req request
* @return RssPartitionToShuffleServerResponse
*/
- RssPartitionToShuffleServerResponse getPartitionToShufflerServer(
+ RssReassignOnStageRetryResponse getPartitionToShufflerServerWithStageRetry(
+ RssPartitionToShuffleServerRequest req);
+
+ /**
+ * In Block Retry mode, gets the mapping between partitions and ShuffleServer
from the
+ * ShuffleManager server.
+ *
+ * @param req request
+ * @return RssReassignOnBlockSendFailureResponse
+ */
+ RssReassignOnBlockSendFailureResponse
getPartitionToShufflerServerWithBlockRetry(
RssPartitionToShuffleServerRequest req);
RssReportShuffleWriteFailureResponse reportShuffleWriteFailure(
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
index 997778bc9..030afd03d 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
@@ -32,8 +32,8 @@ import
org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.request.RssReportShuffleResultRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import org.apache.uniffle.client.response.RssGetShuffleResultResponse;
-import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
+import org.apache.uniffle.client.response.RssReassignOnStageRetryResponse;
import org.apache.uniffle.client.response.RssReassignServersResponse;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
@@ -90,14 +90,25 @@ public class ShuffleManagerGrpcClient extends GrpcClient
implements ShuffleManag
}
@Override
- public RssPartitionToShuffleServerResponse getPartitionToShufflerServer(
+ public RssReassignOnStageRetryResponse
getPartitionToShufflerServerWithStageRetry(
RssPartitionToShuffleServerRequest req) {
RssProtos.PartitionToShuffleServerRequest protoRequest = req.toProto();
- RssProtos.PartitionToShuffleServerResponse partitionToShufflerServer =
- getBlockingStub().getPartitionToShufflerServer(protoRequest);
- RssPartitionToShuffleServerResponse rssPartitionToShuffleServerResponse =
-
RssPartitionToShuffleServerResponse.fromProto(partitionToShufflerServer);
- return rssPartitionToShuffleServerResponse;
+ RssProtos.ReassignOnStageRetryResponse partitionToShufflerServer =
+
getBlockingStub().getPartitionToShufflerServerWithStageRetry(protoRequest);
+ RssReassignOnStageRetryResponse rssReassignOnStageRetryResponse =
+ RssReassignOnStageRetryResponse.fromProto(partitionToShufflerServer);
+ return rssReassignOnStageRetryResponse;
+ }
+
+ @Override
+ public RssReassignOnBlockSendFailureResponse
getPartitionToShufflerServerWithBlockRetry(
+ RssPartitionToShuffleServerRequest req) {
+ RssProtos.PartitionToShuffleServerRequest protoRequest = req.toProto();
+ RssProtos.ReassignOnBlockSendFailureResponse partitionToShufflerServer =
+
getBlockingStub().getPartitionToShufflerServerWithBlockRetry(protoRequest);
+ RssReassignOnBlockSendFailureResponse
rssReassignOnBlockSendFailureResponse =
+
RssReassignOnBlockSendFailureResponse.fromProto(partitionToShufflerServer);
+ return rssReassignOnBlockSendFailureResponse;
}
@Override
@@ -128,7 +139,7 @@ public class ShuffleManagerGrpcClient extends GrpcClient
implements ShuffleManag
RssReassignOnBlockSendFailureRequest request) {
RssProtos.RssReassignOnBlockSendFailureRequest protoReq =
RssReassignOnBlockSendFailureRequest.toProto(request);
- RssProtos.RssReassignOnBlockSendFailureResponse response =
+ RssProtos.ReassignOnBlockSendFailureResponse response =
getBlockingStub().reassignOnBlockSendFailure(protoReq);
return RssReassignOnBlockSendFailureResponse.fromProto(response);
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 9a315c162..14dbf2f60 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -200,7 +200,8 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
RemoteStorageInfo remoteStorageInfo,
String user,
ShuffleDataDistributionType dataDistributionType,
- int maxConcurrencyPerPartitionToWrite) {
+ int maxConcurrencyPerPartitionToWrite,
+ int stageAttemptNumber) {
ShuffleRegisterRequest.Builder reqBuilder =
ShuffleRegisterRequest.newBuilder();
reqBuilder
.setAppId(appId)
@@ -208,7 +209,8 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
.setUser(user)
.setShuffleDataDistribution(RssProtos.DataDistribution.valueOf(dataDistributionType.name()))
.setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite)
- .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges));
+ .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges))
+ .setStageAttemptNumber(stageAttemptNumber);
RemoteStorage.Builder rsBuilder = RemoteStorage.newBuilder();
rsBuilder.setPath(remoteStorageInfo.getPath());
Map<String, String> remoteStorageConf = remoteStorageInfo.getConfItems();
@@ -468,7 +470,8 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
request.getRemoteStorageInfo(),
request.getUser(),
request.getDataDistributionType(),
- request.getMaxConcurrencyPerPartitionToWrite());
+ request.getMaxConcurrencyPerPartitionToWrite(),
+ request.getStageAttemptNumber());
RssRegisterShuffleResponse response;
RssProtos.StatusCode statusCode = rpcResponse.getStatus();
@@ -499,6 +502,7 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
String appId = request.getAppId();
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks =
request.getShuffleIdToBlocks();
+ int stageAttemptNumber = request.getStageAttemptNumber();
boolean isSuccessful = true;
AtomicReference<StatusCode> failedStatusCode = new
AtomicReference<>(StatusCode.INTERNAL_ERROR);
@@ -563,6 +567,7 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
.setRequireBufferId(requireId)
.addAllShuffleData(shuffleData)
.setTimestamp(start)
+ .setStageAttemptNumber(stageAttemptNumber)
.build();
SendShuffleDataResponse response =
getBlockingStub().sendShuffleData(rpcRequest);
if (LOG.isDebugEnabled()) {
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
index 6451a97cd..1c46077ad 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -139,6 +139,7 @@ public class ShuffleServerGrpcNettyClient extends
ShuffleServerGrpcClient {
public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest
request) {
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks =
request.getShuffleIdToBlocks();
+ int stageAttemptNumber = request.getStageAttemptNumber();
boolean isSuccessful = true;
AtomicReference<StatusCode> failedStatusCode = new
AtomicReference<>(StatusCode.INTERNAL_ERROR);
@@ -159,6 +160,7 @@ public class ShuffleServerGrpcNettyClient extends
ShuffleServerGrpcClient {
requestId(),
request.getAppId(),
shuffleId,
+ stageAttemptNumber,
0L,
stb.getValue(),
System.currentTimeMillis());
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
index 98fd01241..4cbdc4448 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
@@ -64,6 +64,30 @@ public class RssGetShuffleAssignmentsRequest {
int assignmentShuffleServerNumber,
int estimateTaskConcurrency,
Set<String> faultyServerIds) {
+ this(
+ appId,
+ shuffleId,
+ partitionNum,
+ partitionNumPerRange,
+ dataReplica,
+ requiredTags,
+ assignmentShuffleServerNumber,
+ estimateTaskConcurrency,
+ faultyServerIds,
+ 0);
+ }
+
+ public RssGetShuffleAssignmentsRequest(
+ String appId,
+ int shuffleId,
+ int partitionNum,
+ int partitionNumPerRange,
+ int dataReplica,
+ Set<String> requiredTags,
+ int assignmentShuffleServerNumber,
+ int estimateTaskConcurrency,
+ Set<String> faultyServerIds,
+ int stageAttemptNumber) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionNum = partitionNum;
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
index 2cd49bb6d..7e42be653 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
@@ -35,6 +35,7 @@ public class RssRegisterShuffleRequest {
private String user;
private ShuffleDataDistributionType dataDistributionType;
private int maxConcurrencyPerPartitionToWrite;
+ private int stageAttemptNumber;
public RssRegisterShuffleRequest(
String appId,
@@ -44,6 +45,26 @@ public class RssRegisterShuffleRequest {
String user,
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite) {
+ this(
+ appId,
+ shuffleId,
+ partitionRanges,
+ remoteStorageInfo,
+ user,
+ dataDistributionType,
+ maxConcurrencyPerPartitionToWrite,
+ 0);
+ }
+
+ public RssRegisterShuffleRequest(
+ String appId,
+ int shuffleId,
+ List<PartitionRange> partitionRanges,
+ RemoteStorageInfo remoteStorageInfo,
+ String user,
+ ShuffleDataDistributionType dataDistributionType,
+ int maxConcurrencyPerPartitionToWrite,
+ int stageAttemptNumber) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionRanges = partitionRanges;
@@ -51,6 +72,7 @@ public class RssRegisterShuffleRequest {
this.user = user;
this.dataDistributionType = dataDistributionType;
this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite;
+ this.stageAttemptNumber = stageAttemptNumber;
}
public RssRegisterShuffleRequest(
@@ -67,7 +89,8 @@ public class RssRegisterShuffleRequest {
remoteStorageInfo,
user,
dataDistributionType,
- RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue());
+ RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(),
+ 0);
}
public RssRegisterShuffleRequest(
@@ -79,7 +102,8 @@ public class RssRegisterShuffleRequest {
new RemoteStorageInfo(remoteStoragePath),
StringUtils.EMPTY,
ShuffleDataDistributionType.NORMAL,
- RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue());
+ RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(),
+ 0);
}
public String getAppId() {
@@ -109,4 +133,8 @@ public class RssRegisterShuffleRequest {
public int getMaxConcurrencyPerPartitionToWrite() {
return maxConcurrencyPerPartitionToWrite;
}
+
+ public int getStageAttemptNumber() {
+ return stageAttemptNumber;
+ }
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
index 8fbf18f29..1b5fdcff8 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
@@ -25,6 +25,7 @@ import org.apache.uniffle.common.ShuffleBlockInfo;
public class RssSendShuffleDataRequest {
private String appId;
+ private int stageAttemptNumber;
private int retryMax;
private long retryIntervalMax;
private Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks;
@@ -34,10 +35,20 @@ public class RssSendShuffleDataRequest {
int retryMax,
long retryIntervalMax,
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks) {
+ this(appId, 0, retryMax, retryIntervalMax, shuffleIdToBlocks);
+ }
+
+ public RssSendShuffleDataRequest(
+ String appId,
+ int stageAttemptNumber,
+ int retryMax,
+ long retryIntervalMax,
+ Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks) {
this.appId = appId;
this.retryMax = retryMax;
this.retryIntervalMax = retryIntervalMax;
this.shuffleIdToBlocks = shuffleIdToBlocks;
+ this.stageAttemptNumber = stageAttemptNumber;
}
public String getAppId() {
@@ -52,6 +63,10 @@ public class RssSendShuffleDataRequest {
return retryIntervalMax;
}
+ public int getStageAttemptNumber() {
+ return stageAttemptNumber;
+ }
+
public Map<Integer, Map<Integer, List<ShuffleBlockInfo>>>
getShuffleIdToBlocks() {
return shuffleIdToBlocks;
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
index 81ca7d548..7d20f8742 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
@@ -34,7 +34,7 @@ public class RssReassignOnBlockSendFailureResponse extends
ClientResponse {
}
public static RssReassignOnBlockSendFailureResponse fromProto(
- RssProtos.RssReassignOnBlockSendFailureResponse response) {
+ RssProtos.ReassignOnBlockSendFailureResponse response) {
return new RssReassignOnBlockSendFailureResponse(
StatusCode.valueOf(response.getStatus().name()), response.getMsg(),
response.getHandle());
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnStageRetryResponse.java
similarity index 71%
rename from
internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
rename to
internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnStageRetryResponse.java
index 9daa002ed..3762ea41b 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnStageRetryResponse.java
@@ -20,24 +20,24 @@ package org.apache.uniffle.client.response;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.proto.RssProtos;
-public class RssPartitionToShuffleServerResponse extends ClientResponse {
- private RssProtos.MutableShuffleHandleInfo shuffleHandleInfoProto;
+public class RssReassignOnStageRetryResponse extends ClientResponse {
+ private RssProtos.StageAttemptShuffleHandleInfo shuffleHandleInfoProto;
- public RssPartitionToShuffleServerResponse(
+ public RssReassignOnStageRetryResponse(
StatusCode statusCode,
String message,
- RssProtos.MutableShuffleHandleInfo shuffleHandleInfoProto) {
+ RssProtos.StageAttemptShuffleHandleInfo shuffleHandleInfoProto) {
super(statusCode, message);
this.shuffleHandleInfoProto = shuffleHandleInfoProto;
}
- public RssProtos.MutableShuffleHandleInfo getShuffleHandleInfoProto() {
+ public RssProtos.StageAttemptShuffleHandleInfo getShuffleHandleInfoProto() {
return shuffleHandleInfoProto;
}
- public static RssPartitionToShuffleServerResponse fromProto(
- RssProtos.PartitionToShuffleServerResponse response) {
- return new RssPartitionToShuffleServerResponse(
+ public static RssReassignOnStageRetryResponse fromProto(
+ RssProtos.ReassignOnStageRetryResponse response) {
+ return new RssReassignOnStageRetryResponse(
StatusCode.valueOf(response.getStatus().name()),
response.getMsg(),
response.getShuffleHandleInfo());
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index a63e6d23a..61afad299 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -184,6 +184,7 @@ message ShuffleRegisterRequest {
string user = 5;
DataDistribution shuffleDataDistribution = 6;
int32 maxConcurrencyPerPartitionToWrite = 7;
+ int32 stageAttemptNumber = 8;
}
enum DataDistribution {
@@ -221,6 +222,7 @@ message SendShuffleDataRequest {
int64 requireBufferId = 3;
repeated ShuffleData shuffleData = 4;
int64 timestamp = 5;
+ int32 stageAttemptNumber = 6;
}
message SendShuffleDataResponse {
@@ -305,6 +307,7 @@ enum StatusCode {
ACCESS_DENIED = 8;
INVALID_REQUEST = 9;
NO_BUFFER_FOR_HUGE_PARTITION = 10;
+ STAGE_RETRY_IGNORE = 11;
// add more status
}
@@ -528,14 +531,16 @@ message CancelDecommissionResponse {
// per application.
service ShuffleManager {
rpc reportShuffleFetchFailure (ReportShuffleFetchFailureRequest) returns
(ReportShuffleFetchFailureResponse);
- // Gets the mapping between partitions and ShuffleServer from the
ShuffleManager server
- rpc getPartitionToShufflerServer(PartitionToShuffleServerRequest) returns
(PartitionToShuffleServerResponse);
+ // Gets the mapping between partitions and ShuffleServer from the
ShuffleManager server on Stage Retry.
+ rpc
getPartitionToShufflerServerWithStageRetry(PartitionToShuffleServerRequest)
returns (ReassignOnStageRetryResponse);
+ // Gets the mapping between partitions and ShuffleServer from the
ShuffleManager server on Block Retry.
+ rpc
getPartitionToShufflerServerWithBlockRetry(PartitionToShuffleServerRequest)
returns (ReassignOnBlockSendFailureResponse);
// Report write failures to ShuffleManager
rpc reportShuffleWriteFailure (ReportShuffleWriteFailureRequest) returns
(ReportShuffleWriteFailureResponse);
// Reassign the RPC interface of the ShuffleServer list
rpc reassignShuffleServers(ReassignServersRequest) returns
(ReassignServersResponse);
// Reassign on block send failure that occurs in writer
- rpc reassignOnBlockSendFailure(RssReassignOnBlockSendFailureRequest) returns
(RssReassignOnBlockSendFailureResponse);
+ rpc reassignOnBlockSendFailure(RssReassignOnBlockSendFailureRequest) returns
(ReassignOnBlockSendFailureResponse);
rpc reportShuffleResult (ReportShuffleResultRequest) returns
(ReportShuffleResultResponse);
rpc getShuffleResult (GetShuffleResultRequest) returns
(GetShuffleResultResponse);
rpc getShuffleResultForMultiPart (GetShuffleResultForMultiPartRequest)
returns (GetShuffleResultForMultiPartResponse);
@@ -563,10 +568,15 @@ message PartitionToShuffleServerRequest {
int32 shuffleId = 2;
}
-message PartitionToShuffleServerResponse {
+message ReassignOnStageRetryResponse {
StatusCode status = 1;
string msg = 2;
- MutableShuffleHandleInfo shuffleHandleInfo = 3;
+ StageAttemptShuffleHandleInfo shuffleHandleInfo = 3;
+}
+
+message StageAttemptShuffleHandleInfo {
+ repeated MutableShuffleHandleInfo historyMutableShuffleHandleInfo= 1;
+ MutableShuffleHandleInfo currentMutableShuffleHandleInfo = 2;
}
message MutableShuffleHandleInfo {
@@ -633,7 +643,7 @@ message ReceivingFailureServer {
StatusCode statusCode = 2;
}
-message RssReassignOnBlockSendFailureResponse {
+message ReassignOnBlockSendFailureResponse {
StatusCode status = 1;
string msg = 2;
MutableShuffleHandleInfo handle = 3;
diff --git
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index de378d66e..b6e37029f 100644
---
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -157,6 +157,48 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
int shuffleId = req.getShuffleId();
String remoteStoragePath = req.getRemoteStorage().getPath();
String user = req.getUser();
+ int stageAttemptNumber = req.getStageAttemptNumber();
+ // Attempt 0 is the first registration of the stage, so there is no earlier
+ // attempt whose already-sent block data would have to be deleted.
+ if (stageAttemptNumber > 0) {
+ ShuffleTaskInfo taskInfo =
shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId);
+ // Serialize concurrent registrations that may update the latest attempt number.
+ synchronized (taskInfo) {
+ int attemptNumber = taskInfo.getLatestStageAttemptNumber(shuffleId);
+ if (stageAttemptNumber > attemptNumber) {
+ taskInfo.refreshLatestStageAttemptNumber(shuffleId,
stageAttemptNumber);
+ try {
+ long start = System.currentTimeMillis();
+ shuffleServer.getShuffleTaskManager().removeShuffleDataSync(appId,
shuffleId);
+ LOG.info(
+ "Deleted the previous stage attempt data due to stage
recomputing for app: {}, "
+ + "shuffleId: {}. It costs {} ms",
+ appId,
+ shuffleId,
+ System.currentTimeMillis() - start);
+ } catch (Exception e) {
+ LOG.error(
+ "Errors on clearing previous stage attempt data for app: {},
shuffleId: {}",
+ appId,
+ shuffleId,
+ e);
+ StatusCode code = StatusCode.INTERNAL_ERROR;
+ reply =
ShuffleRegisterResponse.newBuilder().setStatus(code.toProto()).build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return;
+ }
+ } else if (stageAttemptNumber < attemptNumber) {
+ // This registration comes from an older stage attempt than the latest one
+ // recorded, so ignore it and return the STAGE_RETRY_IGNORE status right away.
+ StatusCode code = StatusCode.STAGE_RETRY_IGNORE;
+ reply =
ShuffleRegisterResponse.newBuilder().setStatus(code.toProto()).build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return;
+ }
+ }
+ }
ShuffleDataDistributionType shuffleDataDistributionType =
ShuffleDataDistributionType.valueOf(
@@ -210,6 +252,22 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
int shuffleId = req.getShuffleId();
long requireBufferId = req.getRequireBufferId();
long timestamp = req.getTimestamp();
+ int stageAttemptNumber = req.getStageAttemptNumber();
+ ShuffleTaskInfo taskInfo =
shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId);
+ Integer latestStageAttemptNumber =
taskInfo.getLatestStageAttemptNumber(shuffleId);
+ // A stage retry has occurred: data sent by a task of an older stage attempt
+ // is ignored and not processed.
+ if (stageAttemptNumber < latestStageAttemptNumber) {
+ String responseMessage = "A retry has occurred at the Stage, sending
data is invalid.";
+ reply =
+ SendShuffleDataResponse.newBuilder()
+ .setStatus(StatusCode.STAGE_RETRY_IGNORE.toProto())
+ .setRetMsg(responseMessage)
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return;
+ }
if (timestamp > 0) {
/*
* Here we record the transport time, but we don't consider the impact
of data size on transport time.
diff --git
a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java
b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java
index b6806f634..bbbfded01 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java
@@ -66,6 +66,8 @@ public class ShuffleTaskInfo {
private final Map<Integer, Map<Integer, AtomicLong>> partitionBlockCounters;
+ private final Map<Integer, Integer> latestStageAttemptNumbers;
+
public ShuffleTaskInfo(String appId) {
this.appId = appId;
this.currentTimes = System.currentTimeMillis();
@@ -78,6 +80,7 @@ public class ShuffleTaskInfo {
this.existHugePartition = new AtomicBoolean(false);
this.specification = new AtomicReference<>();
this.partitionBlockCounters = JavaUtils.newConcurrentMap();
+ this.latestStageAttemptNumbers = JavaUtils.newConcurrentMap();
}
public Long getCurrentTimes() {
@@ -220,6 +223,14 @@ public class ShuffleTaskInfo {
return counter.get();
}
+ public Integer getLatestStageAttemptNumber(int shuffleId) {
+ return latestStageAttemptNumbers.computeIfAbsent(shuffleId, key -> 0);
+ }
+
+ public void refreshLatestStageAttemptNumber(int shuffleId, int stageAttemptNumber) {
+ latestStageAttemptNumbers.put(shuffleId, stageAttemptNumber);
+ }
+
@Override
public String toString() {
return "ShuffleTaskInfo{"
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index a98eac26c..8fe597d03 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -910,7 +910,7 @@ public class ShuffleTaskManager {
}
@VisibleForTesting
- void removeShuffleDataSync(String appId, int shuffleId) {
+ public void removeShuffleDataSync(String appId, int shuffleId) {
removeResourcesByShuffleIds(appId, Arrays.asList(shuffleId));
}
diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
index 668b53cba..448c12a23 100644
--- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
+++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -57,6 +57,7 @@ import org.apache.uniffle.server.ShuffleDataReadEvent;
import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.server.ShuffleServerMetrics;
+import org.apache.uniffle.server.ShuffleTaskInfo;
import org.apache.uniffle.server.ShuffleTaskManager;
import org.apache.uniffle.server.buffer.PreAllocatedBufferInfo;
import org.apache.uniffle.server.buffer.ShuffleBufferManager;
@@ -102,6 +103,18 @@ public class ShuffleServerNettyHandler implements BaseMessageHandler {
int shuffleId = req.getShuffleId();
long requireBufferId = req.getRequireId();
long timestamp = req.getTimestamp();
+ int stageAttemptNumber = req.getStageAttemptNumber();
+ ShuffleTaskInfo taskInfo = shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId);
+ Integer latestStageAttemptNumber = taskInfo.getLatestStageAttemptNumber(shuffleId);
+ // The Stage retry occurred, and the task before StageNumber was simply ignored and not
+ // processed if the task was being sent.
+ if (stageAttemptNumber < latestStageAttemptNumber) {
+ String responseMessage = "A retry has occurred at the Stage, sending data is invalid.";
+ rpcResponse =
+ new RpcResponse(req.getRequestId(), StatusCode.STAGE_RETRY_IGNORE, responseMessage);
+ client.getChannel().writeAndFlush(rpcResponse);
+ return;
+ }
if (timestamp > 0) {
/*
* Here we record the transport time, but we don't consider the impact of data size on transport time.