This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 01925d27b [851] improvement: Add a similar util method like
ThreadUtils.parmap in the Spark (#1396)
01925d27b is described below
commit 01925d27b535985aa172cea899ca74a186d64d79
Author: Qing <[email protected]>
AuthorDate: Wed Dec 27 09:59:27 2023 +0800
[851] improvement: Add a similar util method like ThreadUtils.parmap in the
Spark (#1396)
### What changes were proposed in this pull request?
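Add a utility method, `ThreadUtils.executeTasks`, similar to Spark's `ThreadUtils.parmap`, and use it in `ShuffleWriteClientImpl` and `RegisterHeartBeat` in place of the hand-written `invokeAll`/`Future` handling.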
### Why are the changes needed?
Fix: https://github.com/apache/incubator-uniffle/issues/851
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added a unit test: `ThreadUtilsTest#testExecuteTasksWithFutureHandler`.
---
.../client/impl/ShuffleWriteClientImpl.java | 186 +++++++++------------
.../apache/uniffle/common/util/ThreadUtils.java | 48 ++++++
.../uniffle/common/util/ThreadUtilsTest.java | 27 +++
.../apache/uniffle/server/RegisterHeartBeat.java | 45 +++--
4 files changed, 173 insertions(+), 133 deletions(-)
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 49be1d325..814072127 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -24,13 +24,10 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
-import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Stream;
@@ -863,93 +860,72 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
@Override
public void registerApplicationInfo(String appId, long timeoutMs, String
user) {
RssApplicationInfoRequest request = new RssApplicationInfoRequest(appId,
timeoutMs, user);
- List<Callable<Void>> callableList = Lists.newArrayList();
- coordinatorClients.forEach(
+
+ ThreadUtils.executeTasks(
+ heartBeatExecutorService,
+ coordinatorClients,
coordinatorClient -> {
- callableList.add(
- () -> {
- try {
- RssApplicationInfoResponse response =
- coordinatorClient.registerApplicationInfo(request);
- if (response.getStatusCode() != StatusCode.SUCCESS) {
- LOG.error("Failed to send applicationInfo to " +
coordinatorClient.getDesc());
- } else {
- LOG.info("Successfully send applicationInfo to " +
coordinatorClient.getDesc());
- }
- } catch (Exception e) {
- LOG.warn(
- "Error happened when send applicationInfo to " +
coordinatorClient.getDesc(),
- e);
- }
- return null;
- });
- });
- try {
- List<Future<Void>> futures =
- heartBeatExecutorService.invokeAll(callableList, timeoutMs,
TimeUnit.MILLISECONDS);
- for (Future<Void> future : futures) {
- if (!future.isDone()) {
- future.cancel(true);
- }
- }
- } catch (InterruptedException ie) {
- LOG.warn("register application is interrupted", ie);
- }
+ try {
+ RssApplicationInfoResponse response =
+ coordinatorClient.registerApplicationInfo(request);
+ if (response.getStatusCode() != StatusCode.SUCCESS) {
+ LOG.error("Failed to send applicationInfo to " +
coordinatorClient.getDesc());
+ } else {
+ LOG.info("Successfully send applicationInfo to " +
coordinatorClient.getDesc());
+ }
+ } catch (Exception e) {
+ LOG.warn(
+ "Error happened when send applicationInfo to " +
coordinatorClient.getDesc(), e);
+ }
+ return null;
+ },
+ timeoutMs,
+ "register application");
}
@Override
public void sendAppHeartbeat(String appId, long timeoutMs) {
RssAppHeartBeatRequest request = new RssAppHeartBeatRequest(appId,
timeoutMs);
- List<Callable<Void>> callableList = Lists.newArrayList();
Set<ShuffleServerInfo> allShuffleServers = getAllShuffleServers(appId);
- allShuffleServers.forEach(
+
+ ThreadUtils.executeTasks(
+ heartBeatExecutorService,
+ allShuffleServers,
shuffleServerInfo -> {
- callableList.add(
- () -> {
- try {
- ShuffleServerClient client =
- ShuffleServerClientFactory.getInstance()
- .getShuffleServerClient(clientType,
shuffleServerInfo, rssConf);
- RssAppHeartBeatResponse response =
client.sendHeartBeat(request);
- if (response.getStatusCode() != StatusCode.SUCCESS) {
- LOG.warn("Failed to send heartbeat to " +
shuffleServerInfo);
- }
- } catch (Exception e) {
- LOG.warn("Error happened when send heartbeat to " +
shuffleServerInfo, e);
- }
- return null;
- });
- });
+ try {
+ ShuffleServerClient client =
+ ShuffleServerClientFactory.getInstance()
+ .getShuffleServerClient(clientType, shuffleServerInfo,
rssConf);
+ RssAppHeartBeatResponse response = client.sendHeartBeat(request);
+ if (response.getStatusCode() != StatusCode.SUCCESS) {
+ LOG.warn("Failed to send heartbeat to " + shuffleServerInfo);
+ }
+ } catch (Exception e) {
+ LOG.warn("Error happened when send heartbeat to " +
shuffleServerInfo, e);
+ }
+ return null;
+ },
+ timeoutMs,
+ "send heartbeat to shuffle server");
- coordinatorClients.forEach(
+ ThreadUtils.executeTasks(
+ heartBeatExecutorService,
+ coordinatorClients,
coordinatorClient -> {
- callableList.add(
- () -> {
- try {
- RssAppHeartBeatResponse response =
coordinatorClient.sendAppHeartBeat(request);
- if (response.getStatusCode() != StatusCode.SUCCESS) {
- LOG.warn("Failed to send heartbeat to " +
coordinatorClient.getDesc());
- } else {
- LOG.info("Successfully send heartbeat to " +
coordinatorClient.getDesc());
- }
- } catch (Exception e) {
- LOG.warn(
- "Error happened when send heartbeat to " +
coordinatorClient.getDesc(), e);
- }
- return null;
- });
- });
- try {
- List<Future<Void>> futures =
- heartBeatExecutorService.invokeAll(callableList, timeoutMs,
TimeUnit.MILLISECONDS);
- for (Future<Void> future : futures) {
- if (!future.isDone()) {
- future.cancel(true);
- }
- }
- } catch (InterruptedException ie) {
- LOG.warn("heartbeat is interrupted", ie);
- }
+ try {
+ RssAppHeartBeatResponse response =
coordinatorClient.sendAppHeartBeat(request);
+ if (response.getStatusCode() != StatusCode.SUCCESS) {
+ LOG.warn("Failed to send heartbeat to " +
coordinatorClient.getDesc());
+ } else {
+ LOG.info("Successfully send heartbeat to " +
coordinatorClient.getDesc());
+ }
+ } catch (Exception e) {
+ LOG.warn("Error happened when send heartbeat to " +
coordinatorClient.getDesc(), e);
+ }
+ return null;
+ },
+ timeoutMs,
+ "send heartbeat to coordinator");
}
@Override
@@ -962,7 +938,6 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
@Override
public void unregisterShuffle(String appId, int shuffleId) {
RssUnregisterShuffleRequest request = new
RssUnregisterShuffleRequest(appId, shuffleId);
- List<Callable<Void>> callableList = Lists.newArrayList();
Map<Integer, Set<ShuffleServerInfo>> appServerMap =
shuffleServerInfoMap.get(appId);
if (appServerMap == null) {
@@ -973,39 +948,32 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
return;
}
- shuffleServerInfos.forEach(
- shuffleServerInfo -> {
- callableList.add(
- () -> {
- try {
- ShuffleServerClient client =
- ShuffleServerClientFactory.getInstance()
- .getShuffleServerClient(clientType,
shuffleServerInfo, rssConf);
- RssUnregisterShuffleResponse response =
client.unregisterShuffle(request);
- if (response.getStatusCode() != StatusCode.SUCCESS) {
- LOG.warn("Failed to unregister shuffle to " +
shuffleServerInfo);
- }
- } catch (Exception e) {
- LOG.warn("Error happened when unregistering to " +
shuffleServerInfo, e);
- }
- return null;
- });
- });
-
ExecutorService executorService = null;
try {
executorService =
ThreadUtils.getDaemonFixedThreadPool(
Math.min(unregisterThreadPoolSize, shuffleServerInfos.size()),
"unregister-shuffle");
- List<Future<Void>> futures =
- executorService.invokeAll(callableList, unregisterRequestTimeSec,
TimeUnit.SECONDS);
- for (Future<Void> future : futures) {
- if (!future.isDone()) {
- future.cancel(true);
- }
- }
- } catch (InterruptedException ie) {
- LOG.warn("Unregister shuffle is interrupted", ie);
+
+ ThreadUtils.executeTasks(
+ executorService,
+ shuffleServerInfos,
+ shuffleServerInfo -> {
+ try {
+ ShuffleServerClient client =
+ ShuffleServerClientFactory.getInstance()
+ .getShuffleServerClient(clientType, shuffleServerInfo,
rssConf);
+ RssUnregisterShuffleResponse response =
client.unregisterShuffle(request);
+ if (response.getStatusCode() != StatusCode.SUCCESS) {
+ LOG.warn("Failed to unregister shuffle to " +
shuffleServerInfo);
+ }
+ } catch (Exception e) {
+ LOG.warn("Error happened when unregistering to " +
shuffleServerInfo, e);
+ }
+ return null;
+ },
+ unregisterRequestTimeSec,
+ "unregister shuffle server");
+
} finally {
if (executorService != null) {
executorService.shutdownNow();
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
index e343c128f..f68eae444 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
@@ -17,12 +17,19 @@
package org.apache.uniffle.common.util;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+import java.util.stream.Collectors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
@@ -90,4 +97,45 @@ public class ThreadUtils {
}
}
}
+
+ public static <T, R> List<R> executeTasks(
+ ExecutorService executorService,
+ Collection<T> items,
+ Function<T, R> task,
+ long timeoutMs,
+ String taskMsg,
+ Function<Future<R>, R> futureHandler) {
+ List<Callable<R>> callableList =
+ items.stream()
+ .map(item -> (Callable<R>) () -> task.apply(item))
+ .collect(Collectors.toList());
+ try {
+ List<Future<R>> futures =
+ executorService.invokeAll(callableList, timeoutMs,
TimeUnit.MILLISECONDS);
+ return futures.stream().map(futureHandler).collect(Collectors.toList());
+ } catch (InterruptedException ie) {
+ LOGGER.warn("Execute " + taskMsg + " is interrupted", ie);
+ return Collections.emptyList();
+ }
+ }
+
+ public static <T, R> List<R> executeTasks(
+ ExecutorService executorService,
+ Collection<T> items,
+ Function<T, R> task,
+ long timeoutMs,
+ String taskMsg) {
+ return executeTasks(
+ executorService,
+ items,
+ task,
+ timeoutMs,
+ taskMsg,
+ future -> {
+ if (!future.isDone()) {
+ future.cancel(true);
+ }
+ return null;
+ });
+ }
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
index 02ab5d879..d7c70a32f 100644
--- a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
@@ -17,12 +17,18 @@
package org.apache.uniffle.common.util;
+import java.util.Arrays;
+import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Function;
import org.junit.jupiter.api.Test;
+import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class ThreadUtilsTest {
@@ -45,4 +51,25 @@ public class ThreadUtilsTest {
assertTrue(finished.get());
assertTrue(executorService.isShutdown());
}
+
+ @Test
+ public void testExecuteTasksWithFutureHandler() {
+ ExecutorService executorService = Executors.newFixedThreadPool(2);
+ List<Integer> items = Arrays.asList(1, 2, 3, 4, 5);
+ Function<Integer, Integer> task = item -> item * 2;
+ long timeoutMs = 1000;
+ String taskMsg = "Test Task";
+ Function<Future<Integer>, Integer> futureHandler =
+ future -> {
+ try {
+ return future.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (Exception e) {
+ return null;
+ }
+ };
+
+ List<Integer> results =
+ ThreadUtils.executeTasks(executorService, items, task, timeoutMs,
taskMsg, futureHandler);
+ assertEquals(Arrays.asList(2, 4, 6, 8, 10), results);
+ }
}
diff --git
a/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
b/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
index 95b7b729c..9b7ca1de2 100644
--- a/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
+++ b/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
@@ -21,10 +21,9 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
+import java.util.concurrent.atomic.AtomicBoolean;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
@@ -33,7 +32,6 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.request.RssSendHeartBeatRequest;
-import org.apache.uniffle.client.response.RssSendHeartBeatResponse;
import org.apache.uniffle.common.ServerStatus;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.storage.StorageInfo;
@@ -109,7 +107,7 @@ public class RegisterHeartBeat {
ServerStatus serverStatus,
Map<String, StorageInfo> localStorageInfo,
int nettyPort) {
- boolean sendSuccessfully = false;
+ AtomicBoolean sendSuccessfully = new AtomicBoolean(false);
// use `rss.server.heartbeat.interval` as the timeout option
RssSendHeartBeatRequest request =
new RssSendHeartBeatRequest(
@@ -125,28 +123,27 @@ public class RegisterHeartBeat {
serverStatus,
localStorageInfo,
nettyPort);
- List<Future<RssSendHeartBeatResponse>> respFutures =
- coordinatorClients.stream()
- .map(client -> heartBeatExecutorService.submit(() ->
client.sendHeartBeat(request)))
- .collect(Collectors.toList());
- String msg = "";
- for (Future<RssSendHeartBeatResponse> rf : respFutures) {
- try {
- if (rf.get(request.getTimeout() * 2,
TimeUnit.MILLISECONDS).getStatusCode()
- == StatusCode.SUCCESS) {
- sendSuccessfully = true;
- }
- } catch (Exception e) {
- msg = e.getMessage();
- }
- }
-
- if (!sendSuccessfully) {
- LOG.error(msg);
- }
+ ThreadUtils.executeTasks(
+ heartBeatExecutorService,
+ coordinatorClients,
+ client -> client.sendHeartBeat(request),
+ request.getTimeout() * 2,
+ "send heartbeat",
+ future -> {
+ try {
+ if (future.get(request.getTimeout() * 2,
TimeUnit.MILLISECONDS).getStatusCode()
+ == StatusCode.SUCCESS) {
+ sendSuccessfully.set(true);
+ }
+ } catch (Exception e) {
+ LOG.error(e.getMessage());
+ return null;
+ }
+ return null;
+ });
- return sendSuccessfully;
+ return sendSuccessfully.get();
}
public void shutdown() {