This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new d0f0b572 [#477] feat(spark): support getShuffleResult throws 
FetchFailedException. (#1004)
d0f0b572 is described below

commit d0f0b5725425bf453b8fac6b7aad52862730ca42
Author: Xianming Lei <[email protected]>
AuthorDate: Thu Jul 20 16:09:33 2023 +0800

    [#477] feat(spark): support getShuffleResult throws FetchFailedException. 
(#1004)
    
    ### What changes were proposed in this pull request?
    
    feat #477
    
    ### Why are the changes needed?
    
    support getShuffleResult throws FetchFailedException.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    
    Existing UTs; also added a mocked getShuffleResult failure case — see 
`MockedShuffleServerGrpcService`.
    
    Co-authored-by: leixianming <[email protected]>
---
 .../hadoop/mapred/SortWriteBufferManagerTest.java  |  3 +-
 .../hadoop/mapreduce/task/reduce/FetcherTest.java  |  3 +-
 .../apache/spark/shuffle/RssSparkShuffleUtils.java | 52 ++++++++++++++++++++++
 .../apache/spark/shuffle/RssShuffleManager.java    | 22 ++++++++-
 .../apache/spark/shuffle/RssShuffleManager.java    | 25 ++++++++++-
 .../spark/shuffle/reader/RssShuffleReader.java     |  4 +-
 .../common/sort/buffer/WriteBufferManagerTest.java |  3 +-
 .../uniffle/client/api/ShuffleWriteClient.java     |  3 +-
 .../client/impl/ShuffleWriteClientImpl.java        |  4 +-
 .../org/apache/uniffle/common/util/Constants.java  |  6 +++
 .../server/MockedShuffleServerGrpcService.java     | 20 +++++++++
 11 files changed, 135 insertions(+), 10 deletions(-)

diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index a9f84936..cd737877 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -473,7 +473,8 @@ public class SortWriteBufferManagerTest {
         String clientType,
         Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
         String appId,
-        int shuffleId) {
+        int shuffleId,
+        Set<Integer> failedPartitions) {
       return null;
     }
 
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index 8de5c4b3..b171fa9e 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -550,7 +550,8 @@ public class FetcherTest {
         String clientType,
         Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
         String appId,
-        int shuffleId) {
+        int shuffleId,
+        Set<Integer> failedPartitions) {
       return null;
     }
 
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index bbb17f6b..0edef8f4 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.shuffle;
 
+import java.io.IOException;
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
 import java.util.Arrays;
@@ -39,14 +40,24 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.api.CoordinatorClient;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.factory.CoordinatorClientFactory;
+import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
+import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
 import org.apache.uniffle.client.util.ClientUtils;
+import org.apache.uniffle.client.util.RssClientConfig;
 import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.util.Constants;
 
+import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
+
 public class RssSparkShuffleUtils {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(RssSparkShuffleUtils.class);
@@ -328,4 +339,45 @@ public class RssSparkShuffleUtils {
     return SparkVersionUtils.isSpark3()
         || (SparkVersionUtils.isSpark2() && SparkVersionUtils.MINOR_VERSION >= 
3);
   }
+
+  public static RssException reportRssFetchFailedException(
+      RssFetchFailedException rssFetchFailedException,
+      SparkConf sparkConf,
+      String appId,
+      int shuffleId,
+      int stageAttemptId,
+      Set<Integer> failedPartitions) {
+    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+    if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
+        && RssSparkShuffleUtils.isStageResubmitSupported()) {
+      String driver = rssConf.getString(DRIVER_HOST, "");
+      int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+      try (ShuffleManagerClient client =
+          ShuffleManagerClientFactory.getInstance()
+              .createShuffleManagerClient(ClientType.GRPC, driver, port)) {
+        // todo: Create a new rpc interface to report failures in batch.
+        for (int partitionId : failedPartitions) {
+          RssReportShuffleFetchFailureRequest req =
+              new RssReportShuffleFetchFailureRequest(
+                  appId,
+                  shuffleId,
+                  stageAttemptId,
+                  partitionId,
+                  rssFetchFailedException.getMessage());
+          RssReportShuffleFetchFailureResponse response = 
client.reportShuffleFetchFailure(req);
+          if (response.getReSubmitWholeStage()) {
+            // since we are going to roll out the whole stage, mapIndex 
shouldn't matter, hence -1
+            // is provided.
+            FetchFailedException ffe =
+                RssSparkShuffleUtils.createFetchFailedException(
+                    shuffleId, -1, partitionId, rssFetchFailedException);
+            return new RssException(ffe);
+          }
+        }
+      } catch (IOException ioe) {
+        LOG.info("Error closing shuffle manager client with error:", ioe);
+      }
+    }
+    return rssFetchFailedException;
+  }
 }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 66e259cb..90b7a8e3 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -65,6 +65,7 @@ import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.rpc.GrpcServer;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.RetryUtils;
@@ -477,12 +478,13 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       Map<Integer, List<ShuffleServerInfo>> partitionToServers =
           rssShuffleHandle.getPartitionToServers();
       Roaring64NavigableMap blockIdBitmap =
-          shuffleWriteClient.getShuffleResult(
+          getShuffleResult(
               clientType,
               Sets.newHashSet(partitionToServers.get(startPartition)),
               rssShuffleHandle.getAppId(),
               shuffleId,
-              startPartition);
+              startPartition,
+              context.stageAttemptNumber());
       LOG.info(
           "Get shuffle blockId cost "
               + (System.currentTimeMillis() - start)
@@ -687,4 +689,20 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
   public int getNumMaps(int shuffleId) {
     return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
   }
+
+  private Roaring64NavigableMap getShuffleResult(
+      String clientType,
+      Set<ShuffleServerInfo> shuffleServerInfoSet,
+      String appId,
+      int shuffleId,
+      int partitionId,
+      int stageAttemptId) {
+    try {
+      return shuffleWriteClient.getShuffleResult(
+          clientType, shuffleServerInfoSet, appId, shuffleId, partitionId);
+    } catch (RssFetchFailedException e) {
+      throw RssSparkShuffleUtils.reportRssFetchFailedException(
+          e, sparkConf, appId, shuffleId, stageAttemptId, 
Sets.newHashSet(partitionId));
+    }
+  }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index fa3c172f..674d6777 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -73,6 +73,7 @@ import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.rpc.GrpcServer;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.RetryUtils;
@@ -597,8 +598,12 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         RssUtils.generateServerToPartitions(requirePartitionToServers);
     long start = System.currentTimeMillis();
     Roaring64NavigableMap blockIdBitmap =
-        shuffleWriteClient.getShuffleResultForMultiPart(
-            clientType, serverToPartitions, rssShuffleHandle.getAppId(), 
shuffleId);
+        getShuffleResultForMultiPart(
+            clientType,
+            serverToPartitions,
+            rssShuffleHandle.getAppId(),
+            shuffleId,
+            context.stageAttemptNumber());
     LOG.info(
         "Get shuffle blockId cost "
             + (System.currentTimeMillis() - start)
@@ -979,4 +984,20 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
   public boolean isValidTask(String taskId) {
     return !failedTaskIds.contains(taskId);
   }
+
+  private Roaring64NavigableMap getShuffleResultForMultiPart(
+      String clientType,
+      Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
+      String appId,
+      int shuffleId,
+      int stageAttemptId) {
+    Set<Integer> failedPartitions = Sets.newHashSet();
+    try {
+      return shuffleWriteClient.getShuffleResultForMultiPart(
+          clientType, serverToPartitions, appId, shuffleId, failedPartitions);
+    } catch (RssFetchFailedException e) {
+      throw RssSparkShuffleUtils.reportRssFetchFailedException(
+          e, sparkConf, appId, shuffleId, stageAttemptId, failedPartitions);
+    }
+  }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index ad6c6810..4cd4aaa2 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -56,6 +56,8 @@ import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 
+import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
+
 public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleReader.class);
   private final Map<Integer, List<ShuffleServerInfo>> 
partitionToShuffleServers;
@@ -184,7 +186,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
     // resubmit stage and shuffle manager server port are both set
     if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
         && rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0) 
{
-      String driver = rssConf.getString("driver.host", "");
+      String driver = rssConf.getString(DRIVER_HOST, "");
       int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
       resultIter =
           RssFetchFailedIterator.newBuilder()
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index 281dbb97..974b92cb 100644
--- 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -524,7 +524,8 @@ public class WriteBufferManagerTest {
         String clientType,
         Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
         String appId,
-        int shuffleId) {
+        int shuffleId,
+        Set<Integer> failedPartitions) {
       return null;
     }
 
diff --git 
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java 
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index ac8fe4e8..39931646 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -89,7 +89,8 @@ public interface ShuffleWriteClient {
       String clientType,
       Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
       String appId,
-      int shuffleId);
+      int shuffleId,
+      Set<Integer> failedPartitions);
 
   void close();
 
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index fed876dc..9c5cb535 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -792,7 +792,8 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       String clientType,
       Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
       String appId,
-      int shuffleId) {
+      int shuffleId,
+      Set<Integer> failedPartitions) {
     Map<Integer, Integer> partitionReadSuccess = Maps.newHashMap();
     Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
     for (Map.Entry<ShuffleServerInfo, Set<Integer>> entry : 
serverToPartitions.entrySet()) {
@@ -819,6 +820,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
           }
         }
       } catch (Exception e) {
+        failedPartitions.addAll(requestPartitions);
         LOG.warn(
             "Get shuffle result is failed from "
                 + shuffleServerInfo
diff --git a/common/src/main/java/org/apache/uniffle/common/util/Constants.java 
b/common/src/main/java/org/apache/uniffle/common/util/Constants.java
index cde4bbdb..5970bacc 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/Constants.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/Constants.java
@@ -79,4 +79,10 @@ public final class Constants {
   public static final String GRPC_SERVICE_NAME = "grpc.server";
 
   public static final int COMPOSITE_BYTE_BUF_MAX_COMPONENTS = 1024;
+
+  // The `driver.host` is matching spark's `DRIVER_HOST_ADDRESS` 
configuration, which is
+  // `spark.driver.host`.
+  // We are accessing this configuration through RssConf, the spark prefix is 
stripped, hence, this
+  // field.
+  public static final String DRIVER_HOST = "driver.host";
 }
diff --git 
a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
 
b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
index c382028c..28c0944c 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
@@ -100,6 +100,16 @@ public class MockedShuffleServerGrpcService extends 
ShuffleServerGrpcService {
       LOG.info("Add a mocked timeout on getShuffleResult");
       Uninterruptibles.sleepUninterruptibly(mockedTimeout, 
TimeUnit.MILLISECONDS);
     }
+    if (numOfFailedReadRequest > 0) {
+      int currentFailedReadRequest = failedReadRequest.getAndIncrement();
+      if (currentFailedReadRequest < numOfFailedReadRequest) {
+        LOG.info(
+            "This request is failed as mocked failure, current/firstN: {}/{}",
+            currentFailedReadRequest,
+            numOfFailedReadRequest);
+        throw new RuntimeException("This request is failed as mocked failure");
+      }
+    }
     super.getShuffleResult(request, responseObserver);
   }
 
@@ -111,6 +121,16 @@ public class MockedShuffleServerGrpcService extends 
ShuffleServerGrpcService {
       LOG.info("Add a mocked timeout on getShuffleResult");
       Uninterruptibles.sleepUninterruptibly(mockedTimeout, 
TimeUnit.MILLISECONDS);
     }
+    if (numOfFailedReadRequest > 0) {
+      int currentFailedReadRequest = failedReadRequest.getAndIncrement();
+      if (currentFailedReadRequest < numOfFailedReadRequest) {
+        LOG.info(
+            "This request is failed as mocked failure, current/firstN: {}/{}",
+            currentFailedReadRequest,
+            numOfFailedReadRequest);
+        throw new RuntimeException("This request is failed as mocked failure");
+      }
+    }
     if (recordGetShuffleResult) {
       List<Integer> requestPartitions = request.getPartitionsList();
       Map<Integer, AtomicInteger> shuffleIdToPartitionRequestNum =

Reply via email to