This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new d0f0b572 [#477] feat(spark): support getShuffleResult throws
FetchFailedException. (#1004)
d0f0b572 is described below
commit d0f0b5725425bf453b8fac6b7aad52862730ca42
Author: Xianming Lei <[email protected]>
AuthorDate: Thu Jul 20 16:09:33 2023 +0800
[#477] feat(spark): support getShuffleResult throws FetchFailedException.
(#1004)
### What changes were proposed in this pull request?
feat #477
### Why are the changes needed?
support getShuffleResult throws FetchFailedException.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing UTs; additionally, a getShuffleResult failure case was added — see
`MockedShuffleServerGrpcService`.
Co-authored-by: leixianming <[email protected]>
---
.../hadoop/mapred/SortWriteBufferManagerTest.java | 3 +-
.../hadoop/mapreduce/task/reduce/FetcherTest.java | 3 +-
.../apache/spark/shuffle/RssSparkShuffleUtils.java | 52 ++++++++++++++++++++++
.../apache/spark/shuffle/RssShuffleManager.java | 22 ++++++++-
.../apache/spark/shuffle/RssShuffleManager.java | 25 ++++++++++-
.../spark/shuffle/reader/RssShuffleReader.java | 4 +-
.../common/sort/buffer/WriteBufferManagerTest.java | 3 +-
.../uniffle/client/api/ShuffleWriteClient.java | 3 +-
.../client/impl/ShuffleWriteClientImpl.java | 4 +-
.../org/apache/uniffle/common/util/Constants.java | 6 +++
.../server/MockedShuffleServerGrpcService.java | 20 +++++++++
11 files changed, 135 insertions(+), 10 deletions(-)
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index a9f84936..cd737877 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -473,7 +473,8 @@ public class SortWriteBufferManagerTest {
String clientType,
Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
String appId,
- int shuffleId) {
+ int shuffleId,
+ Set<Integer> failedPartitions) {
return null;
}
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index 8de5c4b3..b171fa9e 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -550,7 +550,8 @@ public class FetcherTest {
String clientType,
Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
String appId,
- int shuffleId) {
+ int shuffleId,
+ Set<Integer> failedPartitions) {
return null;
}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index bbb17f6b..0edef8f4 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -17,6 +17,7 @@
package org.apache.spark.shuffle;
+import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
@@ -39,14 +40,24 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.CoordinatorClient;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
+import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
+import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.util.ClientUtils;
+import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.util.Constants;
+import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
+
public class RssSparkShuffleUtils {
private static final Logger LOG =
LoggerFactory.getLogger(RssSparkShuffleUtils.class);
@@ -328,4 +339,45 @@ public class RssSparkShuffleUtils {
return SparkVersionUtils.isSpark3()
|| (SparkVersionUtils.isSpark2() && SparkVersionUtils.MINOR_VERSION >=
3);
}
+
+ public static RssException reportRssFetchFailedException(
+ RssFetchFailedException rssFetchFailedException,
+ SparkConf sparkConf,
+ String appId,
+ int shuffleId,
+ int stageAttemptId,
+ Set<Integer> failedPartitions) {
+ RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+ if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
+ && RssSparkShuffleUtils.isStageResubmitSupported()) {
+ String driver = rssConf.getString(DRIVER_HOST, "");
+ int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+ try (ShuffleManagerClient client =
+ ShuffleManagerClientFactory.getInstance()
+ .createShuffleManagerClient(ClientType.GRPC, driver, port)) {
+ // todo: Create a new rpc interface to report failures in batch.
+ for (int partitionId : failedPartitions) {
+ RssReportShuffleFetchFailureRequest req =
+ new RssReportShuffleFetchFailureRequest(
+ appId,
+ shuffleId,
+ stageAttemptId,
+ partitionId,
+ rssFetchFailedException.getMessage());
+ RssReportShuffleFetchFailureResponse response =
client.reportShuffleFetchFailure(req);
+ if (response.getReSubmitWholeStage()) {
+ // since we are going to roll out the whole stage, mapIndex
shouldn't matter, hence -1
+ // is provided.
+ FetchFailedException ffe =
+ RssSparkShuffleUtils.createFetchFailedException(
+ shuffleId, -1, partitionId, rssFetchFailedException);
+ return new RssException(ffe);
+ }
+ }
+ } catch (IOException ioe) {
+ LOG.info("Error closing shuffle manager client with error:", ioe);
+ }
+ }
+ return rssFetchFailedException;
+ }
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 66e259cb..90b7a8e3 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -65,6 +65,7 @@ import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.rpc.GrpcServer;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RetryUtils;
@@ -477,12 +478,13 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
Map<Integer, List<ShuffleServerInfo>> partitionToServers =
rssShuffleHandle.getPartitionToServers();
Roaring64NavigableMap blockIdBitmap =
- shuffleWriteClient.getShuffleResult(
+ getShuffleResult(
clientType,
Sets.newHashSet(partitionToServers.get(startPartition)),
rssShuffleHandle.getAppId(),
shuffleId,
- startPartition);
+ startPartition,
+ context.stageAttemptNumber());
LOG.info(
"Get shuffle blockId cost "
+ (System.currentTimeMillis() - start)
@@ -687,4 +689,20 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
public int getNumMaps(int shuffleId) {
return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
}
+
+ private Roaring64NavigableMap getShuffleResult(
+ String clientType,
+ Set<ShuffleServerInfo> shuffleServerInfoSet,
+ String appId,
+ int shuffleId,
+ int partitionId,
+ int stageAttemptId) {
+ try {
+ return shuffleWriteClient.getShuffleResult(
+ clientType, shuffleServerInfoSet, appId, shuffleId, partitionId);
+ } catch (RssFetchFailedException e) {
+ throw RssSparkShuffleUtils.reportRssFetchFailedException(
+ e, sparkConf, appId, shuffleId, stageAttemptId,
Sets.newHashSet(partitionId));
+ }
+ }
}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index fa3c172f..674d6777 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -73,6 +73,7 @@ import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.rpc.GrpcServer;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RetryUtils;
@@ -597,8 +598,12 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
RssUtils.generateServerToPartitions(requirePartitionToServers);
long start = System.currentTimeMillis();
Roaring64NavigableMap blockIdBitmap =
- shuffleWriteClient.getShuffleResultForMultiPart(
- clientType, serverToPartitions, rssShuffleHandle.getAppId(),
shuffleId);
+ getShuffleResultForMultiPart(
+ clientType,
+ serverToPartitions,
+ rssShuffleHandle.getAppId(),
+ shuffleId,
+ context.stageAttemptNumber());
LOG.info(
"Get shuffle blockId cost "
+ (System.currentTimeMillis() - start)
@@ -979,4 +984,20 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
public boolean isValidTask(String taskId) {
return !failedTaskIds.contains(taskId);
}
+
+ private Roaring64NavigableMap getShuffleResultForMultiPart(
+ String clientType,
+ Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
+ String appId,
+ int shuffleId,
+ int stageAttemptId) {
+ Set<Integer> failedPartitions = Sets.newHashSet();
+ try {
+ return shuffleWriteClient.getShuffleResultForMultiPart(
+ clientType, serverToPartitions, appId, shuffleId, failedPartitions);
+ } catch (RssFetchFailedException e) {
+ throw RssSparkShuffleUtils.reportRssFetchFailedException(
+ e, sparkConf, appId, shuffleId, stageAttemptId, failedPartitions);
+ }
+ }
}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index ad6c6810..4cd4aaa2 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -56,6 +56,8 @@ import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
+import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
+
public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleReader.class);
private final Map<Integer, List<ShuffleServerInfo>>
partitionToShuffleServers;
@@ -184,7 +186,7 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
// resubmit stage and shuffle manager server port are both set
if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
&& rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0)
{
- String driver = rssConf.getString("driver.host", "");
+ String driver = rssConf.getString(DRIVER_HOST, "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
resultIter =
RssFetchFailedIterator.newBuilder()
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index 281dbb97..974b92cb 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -524,7 +524,8 @@ public class WriteBufferManagerTest {
String clientType,
Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
String appId,
- int shuffleId) {
+ int shuffleId,
+ Set<Integer> failedPartitions) {
return null;
}
diff --git
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index ac8fe4e8..39931646 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -89,7 +89,8 @@ public interface ShuffleWriteClient {
String clientType,
Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
String appId,
- int shuffleId);
+ int shuffleId,
+ Set<Integer> failedPartitions);
void close();
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index fed876dc..9c5cb535 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -792,7 +792,8 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
String clientType,
Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
String appId,
- int shuffleId) {
+ int shuffleId,
+ Set<Integer> failedPartitions) {
Map<Integer, Integer> partitionReadSuccess = Maps.newHashMap();
Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
for (Map.Entry<ShuffleServerInfo, Set<Integer>> entry :
serverToPartitions.entrySet()) {
@@ -819,6 +820,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
}
}
} catch (Exception e) {
+ failedPartitions.addAll(requestPartitions);
LOG.warn(
"Get shuffle result is failed from "
+ shuffleServerInfo
diff --git a/common/src/main/java/org/apache/uniffle/common/util/Constants.java
b/common/src/main/java/org/apache/uniffle/common/util/Constants.java
index cde4bbdb..5970bacc 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/Constants.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/Constants.java
@@ -79,4 +79,10 @@ public final class Constants {
public static final String GRPC_SERVICE_NAME = "grpc.server";
public static final int COMPOSITE_BYTE_BUF_MAX_COMPONENTS = 1024;
+
+ // The `driver.host` is matching spark's `DRIVER_HOST_ADDRESS`
configuration, which is
+ // `spark.driver.host`.
+ // We are accessing this configuration through RssConf, the spark prefix is
stripped, hence, this
+ // field.
+ public static final String DRIVER_HOST = "driver.host";
}
diff --git
a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
index c382028c..28c0944c 100644
---
a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
+++
b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
@@ -100,6 +100,16 @@ public class MockedShuffleServerGrpcService extends
ShuffleServerGrpcService {
LOG.info("Add a mocked timeout on getShuffleResult");
Uninterruptibles.sleepUninterruptibly(mockedTimeout,
TimeUnit.MILLISECONDS);
}
+ if (numOfFailedReadRequest > 0) {
+ int currentFailedReadRequest = failedReadRequest.getAndIncrement();
+ if (currentFailedReadRequest < numOfFailedReadRequest) {
+ LOG.info(
+ "This request is failed as mocked failure, current/firstN: {}/{}",
+ currentFailedReadRequest,
+ numOfFailedReadRequest);
+ throw new RuntimeException("This request is failed as mocked failure");
+ }
+ }
super.getShuffleResult(request, responseObserver);
}
@@ -111,6 +121,16 @@ public class MockedShuffleServerGrpcService extends
ShuffleServerGrpcService {
LOG.info("Add a mocked timeout on getShuffleResult");
Uninterruptibles.sleepUninterruptibly(mockedTimeout,
TimeUnit.MILLISECONDS);
}
+ if (numOfFailedReadRequest > 0) {
+ int currentFailedReadRequest = failedReadRequest.getAndIncrement();
+ if (currentFailedReadRequest < numOfFailedReadRequest) {
+ LOG.info(
+ "This request is failed as mocked failure, current/firstN: {}/{}",
+ currentFailedReadRequest,
+ numOfFailedReadRequest);
+ throw new RuntimeException("This request is failed as mocked failure");
+ }
+ }
if (recordGetShuffleResult) {
List<Integer> requestPartitions = request.getPartitionsList();
Map<Integer, AtomicInteger> shuffleIdToPartitionRequestNum =