This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a17f4376 [#1765] improvement(client): Improve shuffle unregister 
timeouts (#1738)
5a17f4376 is described below

commit 5a17f43762c189f966f39fa74d249405936d7338
Author: Enrico Minack <[email protected]>
AuthorDate: Wed Jun 12 07:52:36 2024 +0200

    [#1765] improvement(client): Improve shuffle unregister timeouts (#1738)
    
    ### What changes were proposed in this pull request?
    Use the `spark.rss.client.unregister.request.timeout.sec` timeout for the 
individual gRPC calls and introduce an overall timeout, 
`spark.rss.client.unregister.timeout.sec`.
    
    ### Why are the changes needed?
    When unregistering with many shuffle servers and a relatively small task 
pool, some requests are executed sequentially, but the current 
per-request timeout is applied to all requests together.
    
    Example: with 100 shuffle servers and 10 threads, completion of all tasks 
currently times out after 10 seconds, giving each request only about 1 second 
(even though the per-request timeout is 10s).
    
    Fix: #1765
    
    ### Does this PR introduce _any_ user-facing change?
    Adds config option `spark.rss.client.unregister.timeout.sec`.
    
    ### How was this patch tested?
    Manually tested.
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   | 14 +++-
 .../spark/shuffle/writer/DataPusherTest.java       |  1 +
 .../apache/spark/shuffle/RssShuffleManager.java    |  4 +-
 .../apache/spark/shuffle/RssShuffleManager.java    |  6 +-
 .../apache/tez/dag/app/RssDAGAppMasterTest.java    |  1 +
 .../client/factory/ShuffleClientFactory.java       | 10 +++
 .../client/impl/ShuffleWriteClientImpl.java        | 89 +++++++++++++++++-----
 .../client/impl/ShuffleWriteClientImplTest.java    |  6 ++
 .../apache/uniffle/common/util/ThreadUtils.java    | 71 +++++++++++++----
 .../uniffle/common/util/ThreadUtilsTest.java       | 82 ++++++++++++++++++--
 .../uniffle/test/AssignmentWithTagsTest.java       |  1 +
 .../uniffle/test/CoordinatorAssignmentTest.java    |  3 +
 .../java/org/apache/uniffle/test/QuorumTest.java   |  1 +
 .../apache/uniffle/test/RpcClientRetryTest.java    |  1 +
 .../apache/uniffle/test/ShuffleServerGrpcTest.java |  1 +
 .../test/ShuffleServerInternalGrpcTest.java        |  2 +-
 .../uniffle/test/ShuffleWithRssClientTest.java     |  1 +
 .../apache/uniffle/test/RssShuffleManagerTest.java |  1 +
 .../client/impl/grpc/ShuffleServerGrpcClient.java  | 29 ++++---
 .../RssUnregisterShuffleByAppIdRequest.java        |  8 +-
 .../request/RssUnregisterShuffleRequest.java       |  8 +-
 21 files changed, 285 insertions(+), 55 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 75b4b998b..455ac6d52 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -191,8 +191,20 @@ public class RssSparkConfig {
       createIntegerBuilder(new 
ConfigBuilder("spark.rss.client.unregister.thread.pool.size"))
           .createWithDefault(10);
 
+  public static final ConfigEntry<Integer> RSS_CLIENT_UNREGISTER_TIMEOUT_SEC =
+      createIntegerBuilder(
+              new ConfigBuilder("spark.rss.client.unregister.timeout.sec")
+                  .doc(
+                      "Unregister requests are executed concurrently and all 
requests together "
+                          + "have to complete within this timeout."))
+          .createWithDefault(10);
+
   public static final ConfigEntry<Integer> 
RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC =
-      createIntegerBuilder(new 
ConfigBuilder("spark.rss.client.unregister.request.timeout.sec"))
+      createIntegerBuilder(
+              new 
ConfigBuilder("spark.rss.client.unregister.request.timeout.sec")
+                  .doc(
+                      "Unregister requests are executed concurrently and 
individual requests "
+                          + "have to complete within this timeout."))
           .createWithDefault(10);
 
   // When the size of read buffer reaches the half of JVM region (i.e., 32m),
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index 2a608bd4c..ebc590af5 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -61,6 +61,7 @@ public class DataPusherTest {
               .dataTransferPoolSize(1)
               .dataCommitPoolSize(1)
               .unregisterThreadPoolSize(1)
+              .unregisterTimeSec(1)
               .unregisterRequestTimeSec(1));
     }
 
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 6f5b255be..0f2830721 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -199,6 +199,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     this.dataCommitPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
     int unregisterThreadPoolSize =
         sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+    int unregisterTimeoutSec = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
     int unregisterRequestTimeoutSec =
         
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
     // External shuffle service is not supported when using remote shuffle 
service
@@ -259,6 +260,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
                       .dataTransferPoolSize(dataTransferPoolSize)
                       .dataCommitPoolSize(dataCommitPoolSize)
                       .unregisterThreadPoolSize(unregisterThreadPoolSize)
+                      .unregisterTimeSec(unregisterTimeoutSec)
                       .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
                       .rssConf(rssConf));
       registerCoordinator();
@@ -639,7 +641,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         }
       }
     } catch (Exception e) {
-      LOG.warn("Errors on unregister to remote shuffle-servers", e);
+      LOG.warn("Errors on unregistering from remote shuffle-servers", e);
     }
     return true;
   }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 802d4d5ad..dc1de59ef 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -269,6 +269,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     }
     int unregisterThreadPoolSize =
         sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+    int unregisterTimeoutSec = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
     int unregisterRequestTimeoutSec =
         
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
     long retryIntervalMax = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
@@ -290,6 +291,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
                     .dataTransferPoolSize(dataTransferPoolSize)
                     .dataCommitPoolSize(dataCommitPoolSize)
                     .unregisterThreadPoolSize(unregisterThreadPoolSize)
+                    .unregisterTimeSec(unregisterTimeoutSec)
                     .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
                     .rssConf(rssConf));
     registerCoordinator();
@@ -373,6 +375,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     this.dataCommitPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
     int unregisterThreadPoolSize =
         sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+    int unregisterTimeoutSec = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
     int unregisterRequestTimeoutSec =
         
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
     shuffleWriteClient =
@@ -391,6 +394,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
                     .dataTransferPoolSize(dataTransferPoolSize)
                     .dataCommitPoolSize(dataCommitPoolSize)
                     .unregisterThreadPoolSize(unregisterThreadPoolSize)
+                    .unregisterTimeSec(unregisterTimeoutSec)
                     .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
                     .rssConf(rssConf));
     this.taskToSuccessBlockIds = taskToSuccessBlockIds;
@@ -893,7 +897,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         }
       }
     } catch (Exception e) {
-      LOG.warn("Errors on unregister to remote shuffle-servers", e);
+      LOG.warn("Errors on unregistering from remote shuffle-servers", e);
     }
     return true;
   }
diff --git 
a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java 
b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
index ba1921b60..ed631e534 100644
--- a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
+++ b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
@@ -511,6 +511,7 @@ public class RssDAGAppMasterTest {
               .dataTransferPoolSize(1)
               .dataCommitPoolSize(1)
               .unregisterThreadPoolSize(1)
+              .unregisterTimeSec(1)
               .unregisterRequestTimeSec(1));
       this.mode = mode;
     }
diff --git 
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
 
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index 0eed01ea8..def9d5fa6 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -63,6 +63,7 @@ public class ShuffleClientFactory {
     private int dataTransferPoolSize;
     private int dataCommitPoolSize;
     private int unregisterThreadPoolSize;
+    private int unregisterTimeSec;
     private int unregisterRequestTimeSec;
     private RssConf rssConf;
 
@@ -110,6 +111,10 @@ public class ShuffleClientFactory {
       return unregisterThreadPoolSize;
     }
 
+    public int getUnregisterTimeSec() {
+      return unregisterTimeSec;
+    }
+
     public int getUnregisterRequestTimeSec() {
       return unregisterRequestTimeSec;
     }
@@ -177,6 +182,11 @@ public class ShuffleClientFactory {
       return self();
     }
 
+    public T unregisterTimeSec(int unregisterTimeSec) {
+      this.unregisterTimeSec = unregisterTimeSec;
+      return self();
+    }
+
     public T unregisterRequestTimeSec(int unregisterRequestTimeSec) {
       this.unregisterRequestTimeSec = unregisterRequestTimeSec;
       return self();
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index c109668ac..dc6ef420a 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -37,6 +37,8 @@ import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
 import org.apache.commons.collections4.CollectionUtils;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -115,6 +117,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
   private int dataCommitPoolSize = -1;
   private final ExecutorService dataTransferPool;
   private final int unregisterThreadPoolSize;
+  private final int unregisterTimeSec;
   private final int unregisterRequestTimeSec;
   private Set<ShuffleServerInfo> defectiveServers;
   private RssConf rssConf;
@@ -128,6 +131,9 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
     if (builder.getUnregisterThreadPoolSize() == 0) {
       builder.unregisterThreadPoolSize(10);
     }
+    if (builder.getUnregisterTimeSec() == 0) {
+      builder.unregisterTimeSec(10);
+    }
     if (builder.getUnregisterRequestTimeSec() == 0) {
       builder.unregisterRequestTimeSec(10);
     }
@@ -146,6 +152,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
             builder.getDataTransferPoolSize(), "client-data-transfer");
     this.dataCommitPoolSize = builder.getDataCommitPoolSize();
     this.unregisterThreadPoolSize = builder.getUnregisterThreadPoolSize();
+    this.unregisterTimeSec = builder.getUnregisterTimeSec();
     this.unregisterRequestTimeSec = builder.getUnregisterRequestTimeSec();
     if (replica > 1) {
       defectiveServers = Sets.newConcurrentHashSet();
@@ -957,8 +964,9 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
 
   @Override
   public void unregisterShuffle(String appId, int shuffleId) {
-    int unregisterTimeMs = unregisterRequestTimeSec * 1000;
-    RssUnregisterShuffleRequest request = new 
RssUnregisterShuffleRequest(appId, shuffleId);
+    int unregisterTimeMs = unregisterTimeSec * 1000;
+    RssUnregisterShuffleRequest request =
+        new RssUnregisterShuffleRequest(appId, shuffleId, 
unregisterRequestTimeSec);
 
     Map<Integer, Set<ShuffleServerInfo>> appServerMap = 
shuffleServerInfoMap.get(appId);
     if (appServerMap == null) {
@@ -968,12 +976,17 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
     if (shuffleServerInfos == null) {
       return;
     }
+    LOG.info(
+        "Unregistering shuffleId[{}] from {} shuffle servers with individual 
timeout[{}s] and overall timeout[{}s]",
+        shuffleId,
+        shuffleServerInfos.size(),
+        unregisterRequestTimeSec,
+        unregisterTimeSec);
 
     ExecutorService executorService = null;
     try {
-      executorService =
-          ThreadUtils.getDaemonFixedThreadPool(
-              Math.min(unregisterThreadPoolSize, shuffleServerInfos.size()), 
"unregister-shuffle");
+      int concurrency = Math.min(unregisterThreadPoolSize, 
shuffleServerInfos.size());
+      executorService = ThreadUtils.getDaemonFixedThreadPool(concurrency, 
"unregister-shuffle");
 
       ThreadUtils.executeTasks(
           executorService,
@@ -984,16 +997,33 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
                   ShuffleServerClientFactory.getInstance()
                       .getShuffleServerClient(clientType, shuffleServerInfo, 
rssConf);
               RssUnregisterShuffleResponse response = 
client.unregisterShuffle(request);
-              if (response.getStatusCode() != StatusCode.SUCCESS) {
-                LOG.warn("Failed to unregister shuffle to " + 
shuffleServerInfo);
+              if (response.getStatusCode() == StatusCode.SUCCESS) {
+                LOG.info("Successfully unregistered shuffle from {}", 
shuffleServerInfo);
+              } else {
+                LOG.warn("Failed to unregister shuffle from {}", 
shuffleServerInfo);
               }
             } catch (Exception e) {
-              LOG.warn("Error happened when unregistering to " + 
shuffleServerInfo, e);
+              // this request observed the unregisterRequestTimeSec timeout
+              if (e instanceof StatusRuntimeException
+                  && ((StatusRuntimeException) e).getStatus().getCode()
+                      == Status.DEADLINE_EXCEEDED.getCode()) {
+                LOG.warn(
+                    "Timeout occurred while unregistering from {}. The request 
timeout is {}s: {}",
+                    shuffleServerInfo,
+                    unregisterRequestTimeSec,
+                    ((StatusRuntimeException) e).getStatus().getDescription());
+              } else {
+                LOG.warn("Error while unregistering from {}", 
shuffleServerInfo, e);
+              }
             }
             return null;
           },
           unregisterTimeMs,
-          "unregister shuffle server");
+          "unregister shuffle server",
+          String.format(
+              "Please consider increasing the thread pool size (%s) or the 
overall timeout (%ss) "
+                  + "if you still think the request timeout (%ss) is 
sensible.",
+              unregisterThreadPoolSize, unregisterTimeSec, 
unregisterRequestTimeSec));
 
     } finally {
       if (executorService != null) {
@@ -1005,8 +1035,9 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
 
   @Override
   public void unregisterShuffle(String appId) {
-    int unregisterTimeMs = unregisterRequestTimeSec * 1000;
-    RssUnregisterShuffleByAppIdRequest request = new 
RssUnregisterShuffleByAppIdRequest(appId);
+    int unregisterTimeMs = unregisterTimeSec * 1000;
+    RssUnregisterShuffleByAppIdRequest request =
+        new RssUnregisterShuffleByAppIdRequest(appId, 
unregisterRequestTimeSec);
 
     if (appId == null) {
       return;
@@ -1016,12 +1047,17 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       return;
     }
     Set<ShuffleServerInfo> shuffleServerInfos = getAllShuffleServers(appId);
+    LOG.info(
+        "Unregistering shuffles of appId[{}] from {} shuffle servers with 
individual timeout[{}s] and overall timeout[{}s]",
+        appId,
+        shuffleServerInfos.size(),
+        unregisterRequestTimeSec,
+        unregisterTimeSec);
 
     ExecutorService executorService = null;
     try {
-      executorService =
-          ThreadUtils.getDaemonFixedThreadPool(
-              Math.min(unregisterThreadPoolSize, shuffleServerInfos.size()), 
"unregister-shuffle");
+      int concurrency = Math.min(unregisterThreadPoolSize, 
shuffleServerInfos.size());
+      executorService = ThreadUtils.getDaemonFixedThreadPool(concurrency, 
"unregister-shuffle");
 
       ThreadUtils.executeTasks(
           executorService,
@@ -1033,16 +1069,33 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
                       .getShuffleServerClient(clientType, shuffleServerInfo, 
rssConf);
               RssUnregisterShuffleByAppIdResponse response =
                   client.unregisterShuffleByAppId(request);
-              if (response.getStatusCode() != StatusCode.SUCCESS) {
-                LOG.warn("Failed to unregister shuffle to " + 
shuffleServerInfo);
+              if (response.getStatusCode() == StatusCode.SUCCESS) {
+                LOG.info("Successfully unregistered shuffle from {}", 
shuffleServerInfo);
+              } else {
+                LOG.warn("Failed to unregister shuffle from {}", 
shuffleServerInfo);
               }
             } catch (Exception e) {
-              LOG.warn("Error happened when unregistering to " + 
shuffleServerInfo, e);
+              // this request observed the unregisterRequestTimeSec timeout
+              if (e instanceof StatusRuntimeException
+                  && ((StatusRuntimeException) e).getStatus().getCode()
+                      == Status.DEADLINE_EXCEEDED.getCode()) {
+                LOG.warn(
+                    "Timeout occurred while unregistering from {}. The request 
timeout is {}s: {}",
+                    shuffleServerInfo,
+                    unregisterRequestTimeSec,
+                    ((StatusRuntimeException) e).getStatus().getDescription());
+              } else {
+                LOG.warn("Error while unregistering from {}", 
shuffleServerInfo, e);
+              }
             }
             return null;
           },
           unregisterTimeMs,
-          "unregister shuffle server");
+          "unregister shuffle server",
+          String.format(
+              "Please consider increasing the thread pool size (%s) or the 
overall timeout (%ss) "
+                  + "if you still think the request timeout (%ss) is 
sensible.",
+              unregisterThreadPoolSize, unregisterTimeSec, 
unregisterRequestTimeSec));
 
     } finally {
       if (executorService != null) {
diff --git 
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
 
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
index d26e96982..4ae2262e8 100644
--- 
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
+++ 
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
@@ -78,6 +78,7 @@ public class ShuffleWriteClientImplTest {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .build();
     ShuffleServerClient mockShuffleServerClient = 
mock(ShuffleServerClient.class);
@@ -124,6 +125,7 @@ public class ShuffleWriteClientImplTest {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .build();
     ShuffleServerClient mockShuffleServerClient = 
mock(ShuffleServerClient.class);
@@ -159,6 +161,7 @@ public class ShuffleWriteClientImplTest {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .build();
     String appId1 = "testRegisterAndUnRegisterShuffleServer-1";
@@ -197,6 +200,7 @@ public class ShuffleWriteClientImplTest {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .build();
     ShuffleServerClient mockShuffleServerClient = 
mock(ShuffleServerClient.class);
@@ -324,6 +328,7 @@ public class ShuffleWriteClientImplTest {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .rssConf(rssConf);
     ShuffleWriteClientImpl client = writeClientBuilder.build();
@@ -357,6 +362,7 @@ public class ShuffleWriteClientImplTest {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .rssConf(rssConf)
             .build();
diff --git 
a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java 
b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
index f68eae444..eb2003995 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
@@ -28,6 +28,7 @@ import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -104,6 +105,7 @@ public class ThreadUtils {
       Function<T, R> task,
       long timeoutMs,
       String taskMsg,
+      String timeoutMsg,
       Function<Future<R>, R> futureHandler) {
     List<Callable<R>> callableList =
         items.stream()
@@ -112,7 +114,40 @@ public class ThreadUtils {
     try {
       List<Future<R>> futures =
           executorService.invokeAll(callableList, timeoutMs, 
TimeUnit.MILLISECONDS);
-      return futures.stream().map(futureHandler).collect(Collectors.toList());
+      AtomicInteger cancelled = new AtomicInteger();
+      List<R> result =
+          futures.stream()
+              .map(
+                  future -> {
+                    // api doc says all futures are done, but better be safe 
here
+                    if (!future.isDone()) {
+                      future.cancel(true);
+                    }
+                    // detect cancelled tasks (timeout)
+                    if (future.isCancelled()) {
+                      cancelled.incrementAndGet();
+                    }
+                    // do not replace this map with peek as peek is for debug 
purposes and may be
+                    // optimized away
+                    return future;
+                  })
+              .map(futureHandler)
+              .collect(Collectors.toList());
+      if (cancelled.get() > 0) {
+        if (timeoutMsg != null) {
+          timeoutMsg = " " + timeoutMsg;
+        } else {
+          timeoutMsg = "";
+        }
+        LOGGER.warn(
+            "Executing {} observed timeout of {}ms, {} out of {} tasks 
cancelled.{}",
+            taskMsg,
+            timeoutMs,
+            cancelled.get(),
+            items.size(),
+            timeoutMsg);
+      }
+      return result;
     } catch (InterruptedException ie) {
       LOGGER.warn("Execute " + taskMsg + " is interrupted", ie);
       return Collections.emptyList();
@@ -124,18 +159,28 @@ public class ThreadUtils {
       Collection<T> items,
       Function<T, R> task,
       long timeoutMs,
-      String taskMsg) {
+      String taskMsg,
+      Function<Future<R>, R> futureHandler) {
+    return executeTasks(executorService, items, task, timeoutMs, taskMsg, 
null, futureHandler);
+  }
+
+  public static <T, R> List<R> executeTasks(
+      ExecutorService executorService,
+      Collection<T> items,
+      Function<T, R> task,
+      long timeoutMs,
+      String taskMsg,
+      String timeoutMsg) {
     return executeTasks(
-        executorService,
-        items,
-        task,
-        timeoutMs,
-        taskMsg,
-        future -> {
-          if (!future.isDone()) {
-            future.cancel(true);
-          }
-          return null;
-        });
+        executorService, items, task, timeoutMs, taskMsg, timeoutMsg, future 
-> null);
+  }
+
+  public static <T, R> List<R> executeTasks(
+      ExecutorService executorService,
+      Collection<T> items,
+      Function<T, R> task,
+      long timeoutMs,
+      String taskMsg) {
+    return executeTasks(executorService, items, task, timeoutMs, taskMsg, 
future -> null);
   }
 }
diff --git 
a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java 
b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
index d7c70a32f..1e621c6ec 100644
--- a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
@@ -19,6 +19,7 @@ package org.apache.uniffle.common.util;
 
 import java.util.Arrays;
 import java.util.List;
+import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -26,9 +27,11 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Function;
 
+import com.google.common.collect.Lists;
 import org.junit.jupiter.api.Test;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class ThreadUtilsTest {
@@ -37,19 +40,46 @@ public class ThreadUtilsTest {
   public void shutdownThreadPoolTest() throws InterruptedException {
     ExecutorService executorService = Executors.newFixedThreadPool(2);
     AtomicBoolean finished = new AtomicBoolean(false);
-    executorService.submit(
+    AtomicBoolean interrupted = new AtomicBoolean(false);
+    Future<?> future =
+        executorService.submit(
+            () -> {
+              try {
+                Thread.sleep(100000);
+              } catch (InterruptedException interruptedException) {
+                interrupted.set(true);
+              } finally {
+                finished.set(true);
+              }
+            });
+    ThreadUtils.shutdownThreadPool(executorService, 1);
+    assertTrue(future.isDone());
+    assertFalse(future.isCancelled());
+    assertTrue(finished.get());
+    assertTrue(interrupted.get());
+    assertTrue(executorService.isShutdown());
+  }
+
+  @Test
+  public void invokeAllTimeoutThreadPoolTest() throws InterruptedException {
+    ExecutorService executorService = Executors.newFixedThreadPool(2);
+    Callable<Boolean> slowTask =
         () -> {
           try {
             Thread.sleep(100000);
           } catch (InterruptedException interruptedException) {
-            // ignore
-          } finally {
-            finished.set(true);
+            // ignored
           }
-        });
-    ThreadUtils.shutdownThreadPool(executorService, 1);
-    assertTrue(finished.get());
-    assertTrue(executorService.isShutdown());
+          return true;
+        };
+    Callable<Boolean> fastTask = () -> true;
+    List<Future<Boolean>> future =
+        executorService.invokeAll(Lists.newArrayList(slowTask, fastTask), 1, 
TimeUnit.SECONDS);
+    assertTrue(future.get(0).isDone());
+    assertTrue(future.get(0).isCancelled());
+    assertTrue(future.get(1).isDone());
+    assertFalse(future.get(1).isCancelled());
+    assertFalse(executorService.isShutdown());
   }
 
   @Test
@@ -72,4 +102,40 @@ public class ThreadUtilsTest {
         ThreadUtils.executeTasks(executorService, items, task, timeoutMs, 
taskMsg, futureHandler);
     assertEquals(Arrays.asList(2, 4, 6, 8, 10), results);
   }
+
+  @Test
+  public void testExecuteTasksWithFutureHandlerAndTimeout() {
+    ExecutorService executorService = Executors.newFixedThreadPool(2);
+    List<Integer> items = Arrays.asList(1, 2, 3, 4, 5);
+    AtomicBoolean completed = new AtomicBoolean(false);
+    Function<Integer, Integer> task =
+        item -> {
+          if (item == 3) {
+            try {
+              Thread.sleep(100000);
+              completed.set(true);
+            } catch (InterruptedException interruptedException) {
+              // ignored
+            }
+          }
+          return item * 2;
+        };
+    long timeoutMs = 1000;
+    String taskMsg = "Test Task";
+    String timeoutMsg = "timeout message";
+    Function<Future<Integer>, Integer> futureHandler =
+        future -> {
+          try {
+            return future.get(timeoutMs, TimeUnit.MILLISECONDS);
+          } catch (Exception e) {
+            return null;
+          }
+        };
+
+    List<Integer> results =
+        ThreadUtils.executeTasks(
+            executorService, items, task, timeoutMs, taskMsg, timeoutMsg, 
futureHandler);
+    assertFalse(completed.get());
+    assertEquals(Arrays.asList(2, 4, null, 8, 10), results);
+  }
 }
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
index fa6d036f4..2512c6c05 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
@@ -165,6 +165,7 @@ public class AssignmentWithTagsTest extends CoordinatorTestBase {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .build();
     shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
index 769a74faa..fd45c6929 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
@@ -109,6 +109,7 @@ public class CoordinatorAssignmentTest extends CoordinatorTestBase {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .build();
     shuffleWriteClient.registerCoordinators(QUORUM);
@@ -150,6 +151,7 @@ public class CoordinatorAssignmentTest extends CoordinatorTestBase {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .build();
     shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
@@ -199,6 +201,7 @@ public class CoordinatorAssignmentTest extends CoordinatorTestBase {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .build();
     shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
index d6af2d06b..35037470b 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
@@ -342,6 +342,7 @@ public class QuorumTest extends ShuffleReadWriteBase {
                 .dataTransferPoolSize(1)
                 .dataCommitPoolSize(1)
                 .unregisterThreadPoolSize(10)
+                .unregisterTimeSec(10)
                 .unregisterRequestTimeSec(10));
 
     List<ShuffleServerInfo> allServers =
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java
index 100b645e7..9bd0944be 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java
@@ -248,6 +248,7 @@ public class RpcClientRetryTest extends ShuffleReadWriteBase {
                 .dataTransferPoolSize(1)
                 .dataCommitPoolSize(1)
                 .unregisterThreadPoolSize(10)
+                .unregisterTimeSec(10)
                 .unregisterRequestTimeSec(10));
 
     List<ShuffleServerInfo> allServers =
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
index c7a17b9b4..8a6e0cecf 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
@@ -176,6 +176,7 @@ public class ShuffleServerGrpcTest extends IntegrationTestBase {
                     .dataTransferPoolSize(1)
                     .dataCommitPoolSize(1)
                     .unregisterThreadPoolSize(10)
+                    .unregisterTimeSec(10)
                     .unregisterRequestTimeSec(10));
     shuffleWriteClient.registerCoordinators("127.0.0.1:" + COORDINATOR_PORT_1);
     shuffleWriteClient.registerShuffle(
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerInternalGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerInternalGrpcTest.java
index 17bd40dbe..a149f4af2 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerInternalGrpcTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerInternalGrpcTest.java
@@ -103,7 +103,7 @@ public class ShuffleServerInternalGrpcTest extends IntegrationTestBase {
     assertEquals(ServerStatus.ACTIVE, shuffleServer.getServerStatus());
 
     // Clean all apps, shuffle server will be shutdown right now.
-    shuffleServerClient.unregisterShuffle(new RssUnregisterShuffleRequest(appId, shuffleId));
+    shuffleServerClient.unregisterShuffle(new RssUnregisterShuffleRequest(appId, shuffleId, 1));
     response = shuffleServerInternalClient.decommission(new RssDecommissionRequest());
     assertEquals(StatusCode.SUCCESS, response.getStatusCode());
     assertEquals(ServerStatus.DECOMMISSIONING, shuffleServer.getServerStatus());
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
index de96781c8..595f5d2d3 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
@@ -116,6 +116,7 @@ public class ShuffleWithRssClientTest extends ShuffleReadWriteBase {
             .dataTransferPoolSize(1)
             .dataCommitPoolSize(1)
             .unregisterThreadPoolSize(10)
+            .unregisterTimeSec(10)
             .unregisterRequestTimeSec(10)
             .build();
   }
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
index 0401b2c1e..9589293e0 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
@@ -215,6 +215,7 @@ public class RssShuffleManagerTest extends SparkIntegrationTestBase {
               .dataTransferPoolSize(1)
               .dataCommitPoolSize(1)
               .unregisterThreadPoolSize(10)
+              .unregisterTimeSec(10)
               .unregisterRequestTimeSec(10)
               .rssConf(rssConf)
               .build();
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index f20cd85f5..659b1731f 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -359,17 +359,20 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer
     return result;
   }
 
-  private RssProtos.ShuffleUnregisterByAppIdResponse doUnregisterShuffleByAppId(String appId) {
+  private RssProtos.ShuffleUnregisterByAppIdResponse doUnregisterShuffleByAppId(
+      String appId, int timeoutSec) {
     RssProtos.ShuffleUnregisterByAppIdRequest request =
         RssProtos.ShuffleUnregisterByAppIdRequest.newBuilder().setAppId(appId).build();
-    return blockingStub.unregisterShuffleByAppId(request);
+    return blockingStub
+        .withDeadlineAfter(timeoutSec, TimeUnit.SECONDS)
+        .unregisterShuffleByAppId(request);
   }
 
   @Override
   public RssUnregisterShuffleByAppIdResponse unregisterShuffleByAppId(
       RssUnregisterShuffleByAppIdRequest request) {
     RssProtos.ShuffleUnregisterByAppIdResponse rpcResponse =
-        doUnregisterShuffleByAppId(request.getAppId());
+        doUnregisterShuffleByAppId(request.getAppId(), request.getTimeoutSec());
 
     RssUnregisterShuffleByAppIdResponse response;
     RssProtos.StatusCode statusCode = rpcResponse.getStatus();
@@ -381,8 +384,8 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer
       default:
         String msg =
             String.format(
-                "Errors on unregister app to %s:%s for appId[%s], error: %s",
-                host, port, request.getAppId(), rpcResponse.getRetMsg());
+                "Errors on unregistering app from %s:%s for appId[%s] and timeout[%ss], error: %s",
+                host, port, request.getAppId(), request.getTimeoutSec(), rpcResponse.getRetMsg());
         LOG.error(msg);
         throw new RssException(msg);
     }
@@ -390,19 +393,20 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer
     return response;
   }
 
-  private RssProtos.ShuffleUnregisterResponse doUnregisterShuffle(String appId, int shuffleId) {
+  private RssProtos.ShuffleUnregisterResponse doUnregisterShuffle(
+      String appId, int shuffleId, int timeoutSec) {
     RssProtos.ShuffleUnregisterRequest request =
         RssProtos.ShuffleUnregisterRequest.newBuilder()
             .setAppId(appId)
             .setShuffleId(shuffleId)
             .build();
-    return blockingStub.unregisterShuffle(request);
+    return blockingStub.withDeadlineAfter(timeoutSec, TimeUnit.SECONDS).unregisterShuffle(request);
   }
 
   @Override
   public RssUnregisterShuffleResponse unregisterShuffle(RssUnregisterShuffleRequest request) {
     RssProtos.ShuffleUnregisterResponse rpcResponse =
-        doUnregisterShuffle(request.getAppId(), request.getShuffleId());
+        doUnregisterShuffle(request.getAppId(), request.getShuffleId(), request.getTimeoutSec());
 
     RssUnregisterShuffleResponse response;
     RssProtos.StatusCode statusCode = rpcResponse.getStatus();
@@ -414,8 +418,13 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer
       default:
         String msg =
             String.format(
-                "Errors on unregister shuffle to %s:%s for appId[%s].shuffleId[%], error: %s",
-                host, port, request.getAppId(), request.getShuffleId(), rpcResponse.getRetMsg());
+                "Errors on unregistering shuffle from %s:%s for appId[%s].shuffleId[%s] and timeout[%ss], error: %s",
+                host,
+                port,
+                request.getAppId(),
+                request.getShuffleId(),
+                request.getTimeoutSec(),
+                rpcResponse.getRetMsg());
         LOG.error(msg);
         throw new RssException(msg);
     }
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleByAppIdRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleByAppIdRequest.java
index 0992355a5..c37c4a6b7 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleByAppIdRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleByAppIdRequest.java
@@ -19,12 +19,18 @@ package org.apache.uniffle.client.request;
 
 public class RssUnregisterShuffleByAppIdRequest {
   private String appId;
+  private int timeoutSec;
 
-  public RssUnregisterShuffleByAppIdRequest(String appId) {
+  public RssUnregisterShuffleByAppIdRequest(String appId, int timeoutSec) {
     this.appId = appId;
+    this.timeoutSec = timeoutSec;
   }
 
   public String getAppId() {
     return appId;
   }
+
+  public int getTimeoutSec() {
+    return this.timeoutSec;
+  }
 }
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleRequest.java
index 317e27ab6..48e8fc159 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleRequest.java
@@ -20,10 +20,12 @@ package org.apache.uniffle.client.request;
 public class RssUnregisterShuffleRequest {
   private String appId;
   private int shuffleId;
+  private int timeoutSec;
 
-  public RssUnregisterShuffleRequest(String appId, int shuffleId) {
+  public RssUnregisterShuffleRequest(String appId, int shuffleId, int timeoutSec) {
     this.appId = appId;
     this.shuffleId = shuffleId;
+    this.timeoutSec = timeoutSec;
   }
 
   public String getAppId() {
@@ -33,4 +35,8 @@ public class RssUnregisterShuffleRequest {
   public int getShuffleId() {
     return shuffleId;
   }
+
+  public int getTimeoutSec() {
+    return timeoutSec;
+  }
 }


Reply via email to