This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 5a17f4376 [#1765] improvement(client): Improve shuffle unregister
timeouts (#1738)
5a17f4376 is described below
commit 5a17f43762c189f966f39fa74d249405936d7338
Author: Enrico Minack <[email protected]>
AuthorDate: Wed Jun 12 07:52:36 2024 +0200
[#1765] improvement(client): Improve shuffle unregister timeouts (#1738)
### What changes were proposed in this pull request?
Use the `spark.rss.client.unregister.request.timeout.sec` timeout for the
individual gRPC calls and introduce an overall timeout,
`spark.rss.client.unregister.timeout.sec`, for all unregister requests together.
### Why are the changes needed?
When unregistering from many shuffle servers with a relatively small thread
pool, some requests are executed sequentially, but the current per-request
timeout is applied to all requests together.
Example: with 100 shuffle servers and 10 threads, each thread has to execute
10 requests sequentially. Completion of all tasks currently times out after
10 seconds, giving each request only about 1 second (even though the
per-request timeout is 10s).
Fix: #1765
### Does this PR introduce _any_ user-facing change?
Adds config option `spark.rss.client.unregister.timeout.sec`.
### How was this patch tested?
Manually tested.
---
.../org/apache/spark/shuffle/RssSparkConfig.java | 14 +++-
.../spark/shuffle/writer/DataPusherTest.java | 1 +
.../apache/spark/shuffle/RssShuffleManager.java | 4 +-
.../apache/spark/shuffle/RssShuffleManager.java | 6 +-
.../apache/tez/dag/app/RssDAGAppMasterTest.java | 1 +
.../client/factory/ShuffleClientFactory.java | 10 +++
.../client/impl/ShuffleWriteClientImpl.java | 89 +++++++++++++++++-----
.../client/impl/ShuffleWriteClientImplTest.java | 6 ++
.../apache/uniffle/common/util/ThreadUtils.java | 71 +++++++++++++----
.../uniffle/common/util/ThreadUtilsTest.java | 82 ++++++++++++++++++--
.../uniffle/test/AssignmentWithTagsTest.java | 1 +
.../uniffle/test/CoordinatorAssignmentTest.java | 3 +
.../java/org/apache/uniffle/test/QuorumTest.java | 1 +
.../apache/uniffle/test/RpcClientRetryTest.java | 1 +
.../apache/uniffle/test/ShuffleServerGrpcTest.java | 1 +
.../test/ShuffleServerInternalGrpcTest.java | 2 +-
.../uniffle/test/ShuffleWithRssClientTest.java | 1 +
.../apache/uniffle/test/RssShuffleManagerTest.java | 1 +
.../client/impl/grpc/ShuffleServerGrpcClient.java | 29 ++++---
.../RssUnregisterShuffleByAppIdRequest.java | 8 +-
.../request/RssUnregisterShuffleRequest.java | 8 +-
21 files changed, 285 insertions(+), 55 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 75b4b998b..455ac6d52 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -191,8 +191,20 @@ public class RssSparkConfig {
createIntegerBuilder(new
ConfigBuilder("spark.rss.client.unregister.thread.pool.size"))
.createWithDefault(10);
+ public static final ConfigEntry<Integer> RSS_CLIENT_UNREGISTER_TIMEOUT_SEC =
+ createIntegerBuilder(
+ new ConfigBuilder("spark.rss.client.unregister.timeout.sec")
+ .doc(
+ "Unregister requests are executed concurrently and all
requests together "
+ + "have to complete within this timeout."))
+ .createWithDefault(10);
+
public static final ConfigEntry<Integer>
RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC =
- createIntegerBuilder(new
ConfigBuilder("spark.rss.client.unregister.request.timeout.sec"))
+ createIntegerBuilder(
+ new
ConfigBuilder("spark.rss.client.unregister.request.timeout.sec")
+ .doc(
+ "Unregister requests are executed concurrently and
individual requests "
+ + "have to complete within this timeout."))
.createWithDefault(10);
// When the size of read buffer reaches the half of JVM region (i.e., 32m),
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index 2a608bd4c..ebc590af5 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -61,6 +61,7 @@ public class DataPusherTest {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(1)
+ .unregisterTimeSec(1)
.unregisterRequestTimeSec(1));
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 6f5b255be..0f2830721 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -199,6 +199,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.dataCommitPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
int unregisterThreadPoolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+ int unregisterTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
int unregisterRequestTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
// External shuffle service is not supported when using remote shuffle
service
@@ -259,6 +260,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
.dataTransferPoolSize(dataTransferPoolSize)
.dataCommitPoolSize(dataCommitPoolSize)
.unregisterThreadPoolSize(unregisterThreadPoolSize)
+ .unregisterTimeSec(unregisterTimeoutSec)
.unregisterRequestTimeSec(unregisterRequestTimeoutSec)
.rssConf(rssConf));
registerCoordinator();
@@ -639,7 +641,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
}
} catch (Exception e) {
- LOG.warn("Errors on unregister to remote shuffle-servers", e);
+ LOG.warn("Errors on unregistering from remote shuffle-servers", e);
}
return true;
}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 802d4d5ad..dc1de59ef 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -269,6 +269,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
int unregisterThreadPoolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+ int unregisterTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
int unregisterRequestTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
long retryIntervalMax =
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
@@ -290,6 +291,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
.dataTransferPoolSize(dataTransferPoolSize)
.dataCommitPoolSize(dataCommitPoolSize)
.unregisterThreadPoolSize(unregisterThreadPoolSize)
+ .unregisterTimeSec(unregisterTimeoutSec)
.unregisterRequestTimeSec(unregisterRequestTimeoutSec)
.rssConf(rssConf));
registerCoordinator();
@@ -373,6 +375,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
this.dataCommitPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
int unregisterThreadPoolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+ int unregisterTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
int unregisterRequestTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
shuffleWriteClient =
@@ -391,6 +394,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
.dataTransferPoolSize(dataTransferPoolSize)
.dataCommitPoolSize(dataCommitPoolSize)
.unregisterThreadPoolSize(unregisterThreadPoolSize)
+ .unregisterTimeSec(unregisterTimeoutSec)
.unregisterRequestTimeSec(unregisterRequestTimeoutSec)
.rssConf(rssConf));
this.taskToSuccessBlockIds = taskToSuccessBlockIds;
@@ -893,7 +897,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
}
} catch (Exception e) {
- LOG.warn("Errors on unregister to remote shuffle-servers", e);
+ LOG.warn("Errors on unregistering from remote shuffle-servers", e);
}
return true;
}
diff --git
a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
index ba1921b60..ed631e534 100644
--- a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
+++ b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
@@ -511,6 +511,7 @@ public class RssDAGAppMasterTest {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(1)
+ .unregisterTimeSec(1)
.unregisterRequestTimeSec(1));
this.mode = mode;
}
diff --git
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index 0eed01ea8..def9d5fa6 100644
---
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -63,6 +63,7 @@ public class ShuffleClientFactory {
private int dataTransferPoolSize;
private int dataCommitPoolSize;
private int unregisterThreadPoolSize;
+ private int unregisterTimeSec;
private int unregisterRequestTimeSec;
private RssConf rssConf;
@@ -110,6 +111,10 @@ public class ShuffleClientFactory {
return unregisterThreadPoolSize;
}
+ public int getUnregisterTimeSec() {
+ return unregisterTimeSec;
+ }
+
public int getUnregisterRequestTimeSec() {
return unregisterRequestTimeSec;
}
@@ -177,6 +182,11 @@ public class ShuffleClientFactory {
return self();
}
+ public T unregisterTimeSec(int unregisterTimeSec) {
+ this.unregisterTimeSec = unregisterTimeSec;
+ return self();
+ }
+
public T unregisterRequestTimeSec(int unregisterRequestTimeSec) {
this.unregisterRequestTimeSec = unregisterRequestTimeSec;
return self();
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index c109668ac..dc6ef420a 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -37,6 +37,8 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -115,6 +117,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
private int dataCommitPoolSize = -1;
private final ExecutorService dataTransferPool;
private final int unregisterThreadPoolSize;
+ private final int unregisterTimeSec;
private final int unregisterRequestTimeSec;
private Set<ShuffleServerInfo> defectiveServers;
private RssConf rssConf;
@@ -128,6 +131,9 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
if (builder.getUnregisterThreadPoolSize() == 0) {
builder.unregisterThreadPoolSize(10);
}
+ if (builder.getUnregisterTimeSec() == 0) {
+ builder.unregisterTimeSec(10);
+ }
if (builder.getUnregisterRequestTimeSec() == 0) {
builder.unregisterRequestTimeSec(10);
}
@@ -146,6 +152,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
builder.getDataTransferPoolSize(), "client-data-transfer");
this.dataCommitPoolSize = builder.getDataCommitPoolSize();
this.unregisterThreadPoolSize = builder.getUnregisterThreadPoolSize();
+ this.unregisterTimeSec = builder.getUnregisterTimeSec();
this.unregisterRequestTimeSec = builder.getUnregisterRequestTimeSec();
if (replica > 1) {
defectiveServers = Sets.newConcurrentHashSet();
@@ -957,8 +964,9 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
@Override
public void unregisterShuffle(String appId, int shuffleId) {
- int unregisterTimeMs = unregisterRequestTimeSec * 1000;
- RssUnregisterShuffleRequest request = new
RssUnregisterShuffleRequest(appId, shuffleId);
+ int unregisterTimeMs = unregisterTimeSec * 1000;
+ RssUnregisterShuffleRequest request =
+ new RssUnregisterShuffleRequest(appId, shuffleId,
unregisterRequestTimeSec);
Map<Integer, Set<ShuffleServerInfo>> appServerMap =
shuffleServerInfoMap.get(appId);
if (appServerMap == null) {
@@ -968,12 +976,17 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
if (shuffleServerInfos == null) {
return;
}
+ LOG.info(
+ "Unregistering shuffleId[{}] from {} shuffle servers with individual
timeout[{}s] and overall timeout[{}s]",
+ shuffleId,
+ shuffleServerInfos.size(),
+ unregisterRequestTimeSec,
+ unregisterTimeSec);
ExecutorService executorService = null;
try {
- executorService =
- ThreadUtils.getDaemonFixedThreadPool(
- Math.min(unregisterThreadPoolSize, shuffleServerInfos.size()),
"unregister-shuffle");
+ int concurrency = Math.min(unregisterThreadPoolSize,
shuffleServerInfos.size());
+ executorService = ThreadUtils.getDaemonFixedThreadPool(concurrency,
"unregister-shuffle");
ThreadUtils.executeTasks(
executorService,
@@ -984,16 +997,33 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
ShuffleServerClientFactory.getInstance()
.getShuffleServerClient(clientType, shuffleServerInfo,
rssConf);
RssUnregisterShuffleResponse response =
client.unregisterShuffle(request);
- if (response.getStatusCode() != StatusCode.SUCCESS) {
- LOG.warn("Failed to unregister shuffle to " +
shuffleServerInfo);
+ if (response.getStatusCode() == StatusCode.SUCCESS) {
+ LOG.info("Successfully unregistered shuffle from {}",
shuffleServerInfo);
+ } else {
+ LOG.warn("Failed to unregister shuffle from {}",
shuffleServerInfo);
}
} catch (Exception e) {
- LOG.warn("Error happened when unregistering to " +
shuffleServerInfo, e);
+ // this request observed the unregisterRequestTimeSec timeout
+ if (e instanceof StatusRuntimeException
+ && ((StatusRuntimeException) e).getStatus().getCode()
+ == Status.DEADLINE_EXCEEDED.getCode()) {
+ LOG.warn(
+ "Timeout occurred while unregistering from {}. The request
timeout is {}s: {}",
+ shuffleServerInfo,
+ unregisterRequestTimeSec,
+ ((StatusRuntimeException) e).getStatus().getDescription());
+ } else {
+ LOG.warn("Error while unregistering from {}",
shuffleServerInfo, e);
+ }
}
return null;
},
unregisterTimeMs,
- "unregister shuffle server");
+ "unregister shuffle server",
+ String.format(
+ "Please consider increasing the thread pool size (%s) or the
overall timeout (%ss) "
+ + "if you still think the request timeout (%ss) is
sensible.",
+ unregisterThreadPoolSize, unregisterTimeSec,
unregisterRequestTimeSec));
} finally {
if (executorService != null) {
@@ -1005,8 +1035,9 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
@Override
public void unregisterShuffle(String appId) {
- int unregisterTimeMs = unregisterRequestTimeSec * 1000;
- RssUnregisterShuffleByAppIdRequest request = new
RssUnregisterShuffleByAppIdRequest(appId);
+ int unregisterTimeMs = unregisterTimeSec * 1000;
+ RssUnregisterShuffleByAppIdRequest request =
+ new RssUnregisterShuffleByAppIdRequest(appId,
unregisterRequestTimeSec);
if (appId == null) {
return;
@@ -1016,12 +1047,17 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
return;
}
Set<ShuffleServerInfo> shuffleServerInfos = getAllShuffleServers(appId);
+ LOG.info(
+ "Unregistering shuffles of appId[{}] from {} shuffle servers with
individual timeout[{}s] and overall timeout[{}s]",
+ appId,
+ shuffleServerInfos.size(),
+ unregisterRequestTimeSec,
+ unregisterTimeSec);
ExecutorService executorService = null;
try {
- executorService =
- ThreadUtils.getDaemonFixedThreadPool(
- Math.min(unregisterThreadPoolSize, shuffleServerInfos.size()),
"unregister-shuffle");
+ int concurrency = Math.min(unregisterThreadPoolSize,
shuffleServerInfos.size());
+ executorService = ThreadUtils.getDaemonFixedThreadPool(concurrency,
"unregister-shuffle");
ThreadUtils.executeTasks(
executorService,
@@ -1033,16 +1069,33 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
.getShuffleServerClient(clientType, shuffleServerInfo,
rssConf);
RssUnregisterShuffleByAppIdResponse response =
client.unregisterShuffleByAppId(request);
- if (response.getStatusCode() != StatusCode.SUCCESS) {
- LOG.warn("Failed to unregister shuffle to " +
shuffleServerInfo);
+ if (response.getStatusCode() == StatusCode.SUCCESS) {
+ LOG.info("Successfully unregistered shuffle from {}",
shuffleServerInfo);
+ } else {
+ LOG.warn("Failed to unregister shuffle from {}",
shuffleServerInfo);
}
} catch (Exception e) {
- LOG.warn("Error happened when unregistering to " +
shuffleServerInfo, e);
+ // this request observed the unregisterRequestTimeSec timeout
+ if (e instanceof StatusRuntimeException
+ && ((StatusRuntimeException) e).getStatus().getCode()
+ == Status.DEADLINE_EXCEEDED.getCode()) {
+ LOG.warn(
+ "Timeout occurred while unregistering from {}. The request
timeout is {}s: {}",
+ shuffleServerInfo,
+ unregisterRequestTimeSec,
+ ((StatusRuntimeException) e).getStatus().getDescription());
+ } else {
+ LOG.warn("Error while unregistering from {}",
shuffleServerInfo, e);
+ }
}
return null;
},
unregisterTimeMs,
- "unregister shuffle server");
+ "unregister shuffle server",
+ String.format(
+ "Please consider increasing the thread pool size (%s) or the
overall timeout (%ss) "
+ + "if you still think the request timeout (%ss) is
sensible.",
+ unregisterThreadPoolSize, unregisterTimeSec,
unregisterRequestTimeSec));
} finally {
if (executorService != null) {
diff --git
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
index d26e96982..4ae2262e8 100644
---
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
+++
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
@@ -78,6 +78,7 @@ public class ShuffleWriteClientImplTest {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.build();
ShuffleServerClient mockShuffleServerClient =
mock(ShuffleServerClient.class);
@@ -124,6 +125,7 @@ public class ShuffleWriteClientImplTest {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.build();
ShuffleServerClient mockShuffleServerClient =
mock(ShuffleServerClient.class);
@@ -159,6 +161,7 @@ public class ShuffleWriteClientImplTest {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.build();
String appId1 = "testRegisterAndUnRegisterShuffleServer-1";
@@ -197,6 +200,7 @@ public class ShuffleWriteClientImplTest {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.build();
ShuffleServerClient mockShuffleServerClient =
mock(ShuffleServerClient.class);
@@ -324,6 +328,7 @@ public class ShuffleWriteClientImplTest {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.rssConf(rssConf);
ShuffleWriteClientImpl client = writeClientBuilder.build();
@@ -357,6 +362,7 @@ public class ShuffleWriteClientImplTest {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.rssConf(rssConf)
.build();
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
index f68eae444..eb2003995 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
@@ -28,6 +28,7 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -104,6 +105,7 @@ public class ThreadUtils {
Function<T, R> task,
long timeoutMs,
String taskMsg,
+ String timeoutMsg,
Function<Future<R>, R> futureHandler) {
List<Callable<R>> callableList =
items.stream()
@@ -112,7 +114,40 @@ public class ThreadUtils {
try {
List<Future<R>> futures =
executorService.invokeAll(callableList, timeoutMs,
TimeUnit.MILLISECONDS);
- return futures.stream().map(futureHandler).collect(Collectors.toList());
+ AtomicInteger cancelled = new AtomicInteger();
+ List<R> result =
+ futures.stream()
+ .map(
+ future -> {
+ // api doc says all futures are done, but better be safe
here
+ if (!future.isDone()) {
+ future.cancel(true);
+ }
+ // detect cancelled tasks (timeout)
+ if (future.isCancelled()) {
+ cancelled.incrementAndGet();
+ }
+ // do not replace this map with peek as peek is for debug
purposes and may be
+ // optimized away
+ return future;
+ })
+ .map(futureHandler)
+ .collect(Collectors.toList());
+ if (cancelled.get() > 0) {
+ if (timeoutMsg != null) {
+ timeoutMsg = " " + timeoutMsg;
+ } else {
+ timeoutMsg = "";
+ }
+ LOGGER.warn(
+ "Executing {} observed timeout of {}ms, {} out of {} tasks
cancelled.{}",
+ taskMsg,
+ timeoutMs,
+ cancelled.get(),
+ items.size(),
+ timeoutMsg);
+ }
+ return result;
} catch (InterruptedException ie) {
LOGGER.warn("Execute " + taskMsg + " is interrupted", ie);
return Collections.emptyList();
@@ -124,18 +159,28 @@ public class ThreadUtils {
Collection<T> items,
Function<T, R> task,
long timeoutMs,
- String taskMsg) {
+ String taskMsg,
+ Function<Future<R>, R> futureHandler) {
+ return executeTasks(executorService, items, task, timeoutMs, taskMsg,
null, futureHandler);
+ }
+
+ public static <T, R> List<R> executeTasks(
+ ExecutorService executorService,
+ Collection<T> items,
+ Function<T, R> task,
+ long timeoutMs,
+ String taskMsg,
+ String timeoutMsg) {
return executeTasks(
- executorService,
- items,
- task,
- timeoutMs,
- taskMsg,
- future -> {
- if (!future.isDone()) {
- future.cancel(true);
- }
- return null;
- });
+ executorService, items, task, timeoutMs, taskMsg, timeoutMsg, future
-> null);
+ }
+
+ public static <T, R> List<R> executeTasks(
+ ExecutorService executorService,
+ Collection<T> items,
+ Function<T, R> task,
+ long timeoutMs,
+ String taskMsg) {
+ return executeTasks(executorService, items, task, timeoutMs, taskMsg,
future -> null);
}
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
index d7c70a32f..1e621c6ec 100644
--- a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
@@ -19,6 +19,7 @@ package org.apache.uniffle.common.util;
import java.util.Arrays;
import java.util.List;
+import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
@@ -26,9 +27,11 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
+import com.google.common.collect.Lists;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class ThreadUtilsTest {
@@ -37,19 +40,46 @@ public class ThreadUtilsTest {
public void shutdownThreadPoolTest() throws InterruptedException {
ExecutorService executorService = Executors.newFixedThreadPool(2);
AtomicBoolean finished = new AtomicBoolean(false);
- executorService.submit(
+ AtomicBoolean interrupted = new AtomicBoolean(false);
+ Future<?> future =
+ executorService.submit(
+ () -> {
+ try {
+ Thread.sleep(100000);
+ } catch (InterruptedException interruptedException) {
+ interrupted.set(true);
+ } finally {
+ finished.set(true);
+ }
+ });
+ ThreadUtils.shutdownThreadPool(executorService, 1);
+ assertTrue(future.isDone());
+ assertFalse(future.isCancelled());
+ assertTrue(finished.get());
+ assertTrue(interrupted.get());
+ assertTrue(executorService.isShutdown());
+ }
+
+ @Test
+ public void invokeAllTimeoutThreadPoolTest() throws InterruptedException {
+ ExecutorService executorService = Executors.newFixedThreadPool(2);
+ Callable<Boolean> slowTask =
() -> {
try {
Thread.sleep(100000);
} catch (InterruptedException interruptedException) {
- // ignore
- } finally {
- finished.set(true);
+ // ignored
}
- });
- ThreadUtils.shutdownThreadPool(executorService, 1);
- assertTrue(finished.get());
- assertTrue(executorService.isShutdown());
+ return true;
+ };
+ Callable<Boolean> fastTask = () -> true;
+ List<Future<Boolean>> future =
+ executorService.invokeAll(Lists.newArrayList(slowTask, fastTask), 1,
TimeUnit.SECONDS);
+ assertTrue(future.get(0).isDone());
+ assertTrue(future.get(0).isCancelled());
+ assertTrue(future.get(1).isDone());
+ assertFalse(future.get(1).isCancelled());
+ assertFalse(executorService.isShutdown());
}
@Test
@@ -72,4 +102,40 @@ public class ThreadUtilsTest {
ThreadUtils.executeTasks(executorService, items, task, timeoutMs,
taskMsg, futureHandler);
assertEquals(Arrays.asList(2, 4, 6, 8, 10), results);
}
+
+ @Test
+ public void testExecuteTasksWithFutureHandlerAndTimeout() {
+ ExecutorService executorService = Executors.newFixedThreadPool(2);
+ List<Integer> items = Arrays.asList(1, 2, 3, 4, 5);
+ AtomicBoolean completed = new AtomicBoolean(false);
+ Function<Integer, Integer> task =
+ item -> {
+ if (item == 3) {
+ try {
+ Thread.sleep(100000);
+ completed.set(true);
+ } catch (InterruptedException interruptedException) {
+ // ignored
+ }
+ }
+ return item * 2;
+ };
+ long timeoutMs = 1000;
+ String taskMsg = "Test Task";
+ String timeoutMsg = "timeout message";
+ Function<Future<Integer>, Integer> futureHandler =
+ future -> {
+ try {
+ return future.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (Exception e) {
+ return null;
+ }
+ };
+
+ List<Integer> results =
+ ThreadUtils.executeTasks(
+ executorService, items, task, timeoutMs, taskMsg, timeoutMsg,
futureHandler);
+ assertFalse(completed.get());
+ assertEquals(Arrays.asList(2, 4, null, 8, 10), results);
+ }
}
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
index fa6d036f4..2512c6c05 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
@@ -165,6 +165,7 @@ public class AssignmentWithTagsTest extends
CoordinatorTestBase {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.build();
shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
index 769a74faa..fd45c6929 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java
@@ -109,6 +109,7 @@ public class CoordinatorAssignmentTest extends
CoordinatorTestBase {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.build();
shuffleWriteClient.registerCoordinators(QUORUM);
@@ -150,6 +151,7 @@ public class CoordinatorAssignmentTest extends
CoordinatorTestBase {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.build();
shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
@@ -199,6 +201,7 @@ public class CoordinatorAssignmentTest extends
CoordinatorTestBase {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.build();
shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
index d6af2d06b..35037470b 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
@@ -342,6 +342,7 @@ public class QuorumTest extends ShuffleReadWriteBase {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10));
List<ShuffleServerInfo> allServers =
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java
index 100b645e7..9bd0944be 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java
@@ -248,6 +248,7 @@ public class RpcClientRetryTest extends
ShuffleReadWriteBase {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10));
List<ShuffleServerInfo> allServers =
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
index c7a17b9b4..8a6e0cecf 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
@@ -176,6 +176,7 @@ public class ShuffleServerGrpcTest extends
IntegrationTestBase {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10));
shuffleWriteClient.registerCoordinators("127.0.0.1:" + COORDINATOR_PORT_1);
shuffleWriteClient.registerShuffle(
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerInternalGrpcTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerInternalGrpcTest.java
index 17bd40dbe..a149f4af2 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerInternalGrpcTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerInternalGrpcTest.java
@@ -103,7 +103,7 @@ public class ShuffleServerInternalGrpcTest extends
IntegrationTestBase {
assertEquals(ServerStatus.ACTIVE, shuffleServer.getServerStatus());
// Clean all apps, shuffle server will be shutdown right now.
- shuffleServerClient.unregisterShuffle(new RssUnregisterShuffleRequest(appId, shuffleId));
+ shuffleServerClient.unregisterShuffle(new RssUnregisterShuffleRequest(appId, shuffleId, 1));
response = shuffleServerInternalClient.decommission(new RssDecommissionRequest());
assertEquals(StatusCode.SUCCESS, response.getStatusCode());
assertEquals(ServerStatus.DECOMMISSIONING, shuffleServer.getServerStatus());
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
index de96781c8..595f5d2d3 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
@@ -116,6 +116,7 @@ public class ShuffleWithRssClientTest extends ShuffleReadWriteBase {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.build();
}
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
index 0401b2c1e..9589293e0 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
@@ -215,6 +215,7 @@ public class RssShuffleManagerTest extends SparkIntegrationTestBase {
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
+ .unregisterTimeSec(10)
.unregisterRequestTimeSec(10)
.rssConf(rssConf)
.build();
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index f20cd85f5..659b1731f 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -359,17 +359,20 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer
return result;
}
- private RssProtos.ShuffleUnregisterByAppIdResponse doUnregisterShuffleByAppId(String appId) {
+ private RssProtos.ShuffleUnregisterByAppIdResponse doUnregisterShuffleByAppId(
+ String appId, int timeoutSec) {
RssProtos.ShuffleUnregisterByAppIdRequest request =
RssProtos.ShuffleUnregisterByAppIdRequest.newBuilder().setAppId(appId).build();
- return blockingStub.unregisterShuffleByAppId(request);
+ return blockingStub
+ .withDeadlineAfter(timeoutSec, TimeUnit.SECONDS)
+ .unregisterShuffleByAppId(request);
}
@Override
public RssUnregisterShuffleByAppIdResponse unregisterShuffleByAppId(
RssUnregisterShuffleByAppIdRequest request) {
RssProtos.ShuffleUnregisterByAppIdResponse rpcResponse =
- doUnregisterShuffleByAppId(request.getAppId());
+ doUnregisterShuffleByAppId(request.getAppId(), request.getTimeoutSec());
RssUnregisterShuffleByAppIdResponse response;
RssProtos.StatusCode statusCode = rpcResponse.getStatus();
@@ -381,8 +384,8 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer
default:
String msg =
String.format(
- "Errors on unregister app to %s:%s for appId[%s], error: %s",
- host, port, request.getAppId(), rpcResponse.getRetMsg());
+ "Errors on unregistering app from %s:%s for appId[%s] and timeout[%ss], error: %s",
+ host, port, request.getAppId(), request.getTimeoutSec(), rpcResponse.getRetMsg());
LOG.error(msg);
throw new RssException(msg);
}
@@ -390,19 +393,20 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer
return response;
}
- private RssProtos.ShuffleUnregisterResponse doUnregisterShuffle(String appId, int shuffleId) {
+ private RssProtos.ShuffleUnregisterResponse doUnregisterShuffle(
+ String appId, int shuffleId, int timeoutSec) {
RssProtos.ShuffleUnregisterRequest request =
RssProtos.ShuffleUnregisterRequest.newBuilder()
.setAppId(appId)
.setShuffleId(shuffleId)
.build();
- return blockingStub.unregisterShuffle(request);
+ return blockingStub.withDeadlineAfter(timeoutSec, TimeUnit.SECONDS).unregisterShuffle(request);
}
@Override
public RssUnregisterShuffleResponse unregisterShuffle(RssUnregisterShuffleRequest request) {
RssProtos.ShuffleUnregisterResponse rpcResponse =
- doUnregisterShuffle(request.getAppId(), request.getShuffleId());
+ doUnregisterShuffle(request.getAppId(), request.getShuffleId(), request.getTimeoutSec());
RssUnregisterShuffleResponse response;
RssProtos.StatusCode statusCode = rpcResponse.getStatus();
@@ -414,8 +418,13 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer
default:
String msg =
String.format(
- "Errors on unregister shuffle to %s:%s for appId[%s].shuffleId[%], error: %s",
- host, port, request.getAppId(), request.getShuffleId(), rpcResponse.getRetMsg());
+ "Errors on unregistering shuffle from %s:%s for appId[%s].shuffleId[%s] and timeout[%ss], error: %s",
+ host,
+ port,
+ request.getAppId(),
+ request.getShuffleId(),
+ request.getTimeoutSec(),
+ rpcResponse.getRetMsg());
LOG.error(msg);
throw new RssException(msg);
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleByAppIdRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleByAppIdRequest.java
index 0992355a5..c37c4a6b7 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleByAppIdRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleByAppIdRequest.java
@@ -19,12 +19,18 @@ package org.apache.uniffle.client.request;
public class RssUnregisterShuffleByAppIdRequest {
private String appId;
+ private int timeoutSec;
- public RssUnregisterShuffleByAppIdRequest(String appId) {
+ public RssUnregisterShuffleByAppIdRequest(String appId, int timeoutSec) {
this.appId = appId;
+ this.timeoutSec = timeoutSec;
}
public String getAppId() {
return appId;
}
+
+ public int getTimeoutSec() {
+ return this.timeoutSec;
+ }
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleRequest.java
index 317e27ab6..48e8fc159 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssUnregisterShuffleRequest.java
@@ -20,10 +20,12 @@ package org.apache.uniffle.client.request;
public class RssUnregisterShuffleRequest {
private String appId;
private int shuffleId;
+ private int timeoutSec;
- public RssUnregisterShuffleRequest(String appId, int shuffleId) {
+ public RssUnregisterShuffleRequest(String appId, int shuffleId, int timeoutSec) {
this.appId = appId;
this.shuffleId = shuffleId;
+ this.timeoutSec = timeoutSec;
}
public String getAppId() {
@@ -33,4 +35,8 @@ public class RssUnregisterShuffleRequest {
public int getShuffleId() {
return shuffleId;
}
+
+ public int getTimeoutSec() {
+ return timeoutSec;
+ }
}