This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 01925d27b [851] improvement: Add a similar util method like ThreadUtils.parmap in the Spark (#1396)
01925d27b is described below

commit 01925d27b535985aa172cea899ca74a186d64d79
Author: Qing <[email protected]>
AuthorDate: Wed Dec 27 09:59:27 2023 +0800

    [851] improvement: Add a similar util method like ThreadUtils.parmap in the Spark (#1396)
    
    ### What changes were proposed in this pull request?
    
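    Add `ThreadUtils.executeTasks`, a utility similar to Spark's `ThreadUtils.parmap`: it runs a
    task for each element of a collection on an executor, waits for all of them under a single
    timeout, and cancels the ones that did not finish. Use it to replace the duplicated
    `invokeAll`/cancel boilerplate in `ShuffleWriteClientImpl` and `RegisterHeartBeat`.
    
    ### Why are the changes needed?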
    
    Fix: https://github.com/apache/incubator-uniffle/issues/851
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Covered by unit tests in `ThreadUtilsTest`.
---
 .../client/impl/ShuffleWriteClientImpl.java        | 186 +++++++++------------
 .../apache/uniffle/common/util/ThreadUtils.java    |  48 ++++++
 .../uniffle/common/util/ThreadUtilsTest.java       |  27 +++
 .../apache/uniffle/server/RegisterHeartBeat.java   |  45 +++--
 4 files changed, 173 insertions(+), 133 deletions(-)

diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 49be1d325..814072127 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -24,13 +24,11 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.Callable;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ForkJoinPool;
-import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Supplier;
 import java.util.stream.Stream;
@@ -863,93 +861,75 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
   @Override
   public void registerApplicationInfo(String appId, long timeoutMs, String user) {
     RssApplicationInfoRequest request = new RssApplicationInfoRequest(appId, timeoutMs, user);
-    List<Callable<Void>> callableList = Lists.newArrayList();
-    coordinatorClients.forEach(
+
+    ThreadUtils.executeTasks(
+        heartBeatExecutorService,
+        coordinatorClients,
         coordinatorClient -> {
-          callableList.add(
-              () -> {
-                try {
-                  RssApplicationInfoResponse response =
-                      coordinatorClient.registerApplicationInfo(request);
-                  if (response.getStatusCode() != StatusCode.SUCCESS) {
-                    LOG.error("Failed to send applicationInfo to " + coordinatorClient.getDesc());
-                  } else {
-                    LOG.info("Successfully send applicationInfo to " + coordinatorClient.getDesc());
-                  }
-                } catch (Exception e) {
-                  LOG.warn(
-                      "Error happened when send applicationInfo to " + coordinatorClient.getDesc(),
-                      e);
-                }
-                return null;
-              });
-        });
-    try {
-      List<Future<Void>> futures =
-          heartBeatExecutorService.invokeAll(callableList, timeoutMs, TimeUnit.MILLISECONDS);
-      for (Future<Void> future : futures) {
-        if (!future.isDone()) {
-          future.cancel(true);
-        }
-      }
-    } catch (InterruptedException ie) {
-      LOG.warn("register application is interrupted", ie);
-    }
+          try {
+            RssApplicationInfoResponse response =
+                coordinatorClient.registerApplicationInfo(request);
+            if (response.getStatusCode() != StatusCode.SUCCESS) {
+              LOG.error("Failed to send applicationInfo to " + coordinatorClient.getDesc());
+            } else {
+              LOG.info("Successfully send applicationInfo to " + coordinatorClient.getDesc());
+            }
+          } catch (Exception e) {
+            LOG.warn(
+                "Error happened when send applicationInfo to " + coordinatorClient.getDesc(), e);
+          }
+          return null;
+        },
+        timeoutMs,
+        "register application");
   }
 
   @Override
   public void sendAppHeartbeat(String appId, long timeoutMs) {
     RssAppHeartBeatRequest request = new RssAppHeartBeatRequest(appId, timeoutMs);
-    List<Callable<Void>> callableList = Lists.newArrayList();
+    // Submit shuffle-server and coordinator heartbeats in one batch: coordinators must not wait
+    // for the shuffle-server round to finish, and the whole round shares a single timeoutMs.
+    List<Supplier<Void>> heartbeatTasks = Lists.newArrayList();
     Set<ShuffleServerInfo> allShuffleServers = getAllShuffleServers(appId);
     allShuffleServers.forEach(
         shuffleServerInfo -> {
-          callableList.add(
+          heartbeatTasks.add(
               () -> {
                 try {
                   ShuffleServerClient client =
                       ShuffleServerClientFactory.getInstance()
                           .getShuffleServerClient(clientType, shuffleServerInfo, rssConf);
                   RssAppHeartBeatResponse response = client.sendHeartBeat(request);
                   if (response.getStatusCode() != StatusCode.SUCCESS) {
                     LOG.warn("Failed to send heartbeat to " + shuffleServerInfo);
                   }
                 } catch (Exception e) {
                   LOG.warn("Error happened when send heartbeat to " + shuffleServerInfo, e);
                 }
                 return null;
               });
         });
 
     coordinatorClients.forEach(
         coordinatorClient -> {
-          callableList.add(
+          heartbeatTasks.add(
               () -> {
                 try {
                   RssAppHeartBeatResponse response = coordinatorClient.sendAppHeartBeat(request);
                   if (response.getStatusCode() != StatusCode.SUCCESS) {
                     LOG.warn("Failed to send heartbeat to " + coordinatorClient.getDesc());
                   } else {
                     LOG.info("Successfully send heartbeat to " + coordinatorClient.getDesc());
                   }
                 } catch (Exception e) {
                   LOG.warn(
                       "Error happened when send heartbeat to " + coordinatorClient.getDesc(), e);
                 }
                 return null;
               });
         });
-    try {
-      List<Future<Void>> futures =
-          heartBeatExecutorService.invokeAll(callableList, timeoutMs, TimeUnit.MILLISECONDS);
-      for (Future<Void> future : futures) {
-        if (!future.isDone()) {
-          future.cancel(true);
-        }
-      }
-    } catch (InterruptedException ie) {
-      LOG.warn("heartbeat is interrupted", ie);
-    }
+    ThreadUtils.executeTasks(
+        heartBeatExecutorService, heartbeatTasks, Supplier::get, timeoutMs, "send heartbeat");
   }
 
   @Override
@@ -962,7 +942,6 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
   @Override
   public void unregisterShuffle(String appId, int shuffleId) {
     RssUnregisterShuffleRequest request = new RssUnregisterShuffleRequest(appId, shuffleId);
-    List<Callable<Void>> callableList = Lists.newArrayList();
 
     Map<Integer, Set<ShuffleServerInfo>> appServerMap = shuffleServerInfoMap.get(appId);
     if (appServerMap == null) {
@@ -973,39 +952,32 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
       return;
     }
 
-    shuffleServerInfos.forEach(
-        shuffleServerInfo -> {
-          callableList.add(
-              () -> {
-                try {
-                  ShuffleServerClient client =
-                      ShuffleServerClientFactory.getInstance()
-                          .getShuffleServerClient(clientType, shuffleServerInfo, rssConf);
-                  RssUnregisterShuffleResponse response = client.unregisterShuffle(request);
-                  if (response.getStatusCode() != StatusCode.SUCCESS) {
-                    LOG.warn("Failed to unregister shuffle to " + shuffleServerInfo);
-                  }
-                } catch (Exception e) {
-                  LOG.warn("Error happened when unregistering to " + shuffleServerInfo, e);
-                }
-                return null;
-              });
-        });
-
     ExecutorService executorService = null;
     try {
       executorService =
           ThreadUtils.getDaemonFixedThreadPool(
               Math.min(unregisterThreadPoolSize, shuffleServerInfos.size()), "unregister-shuffle");
-      List<Future<Void>> futures =
-          executorService.invokeAll(callableList, unregisterRequestTimeSec, TimeUnit.SECONDS);
-      for (Future<Void> future : futures) {
-        if (!future.isDone()) {
-          future.cancel(true);
-        }
-      }
-    } catch (InterruptedException ie) {
-      LOG.warn("Unregister shuffle is interrupted", ie);
+
+      ThreadUtils.executeTasks(
+          executorService,
+          shuffleServerInfos,
+          shuffleServerInfo -> {
+            try {
+              ShuffleServerClient client =
+                  ShuffleServerClientFactory.getInstance()
+                      .getShuffleServerClient(clientType, shuffleServerInfo, rssConf);
+              RssUnregisterShuffleResponse response = client.unregisterShuffle(request);
+              if (response.getStatusCode() != StatusCode.SUCCESS) {
+                LOG.warn("Failed to unregister shuffle to " + shuffleServerInfo);
+              }
+            } catch (Exception e) {
+              LOG.warn("Error happened when unregistering to " + shuffleServerInfo, e);
+            }
+            return null;
+          },
+          TimeUnit.SECONDS.toMillis(unregisterRequestTimeSec),
+          "unregister shuffle server");
+
     } finally {
       if (executorService != null) {
         executorService.shutdownNow();
diff --git a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
index e343c128f..f68eae444 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
@@ -17,12 +17,19 @@
 
 package org.apache.uniffle.common.util;
 
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import io.netty.util.concurrent.DefaultThreadFactory;
@@ -90,4 +97,79 @@ public class ThreadUtils {
       }
     }
   }
+
+  /**
+   * Runs {@code task} for every element of {@code items} on {@code executorService} and waits for
+   * the outcomes, similar to Spark's {@code ThreadUtils.parmap}.
+   *
+   * <p>All tasks share a single {@code timeoutMs} budget. They are submitted via
+   * {@link ExecutorService#invokeAll(Collection, long, TimeUnit)}, which cancels every task
+   * that has not completed when the timeout elapses. Each returned future is then passed to
+   * {@code futureHandler}, in the iteration order of {@code items}.
+   *
+   * <p>If the calling thread is interrupted while waiting, a warning is logged, the thread's
+   * interrupt status is restored and an empty list is returned.
+   *
+   * @param executorService the executor that runs the tasks
+   * @param items the inputs; one task is created per element
+   * @param task the function applied to each element
+   * @param timeoutMs the maximum time in milliseconds to wait for all tasks together
+   * @param taskMsg a short description of the tasks, used in log messages
+   * @param futureHandler converts each (completed or cancelled) future into a result
+   * @return one {@code futureHandler} result per item, or an empty list if interrupted
+   */
+  public static <T, R> List<R> executeTasks(
+      ExecutorService executorService,
+      Collection<T> items,
+      Function<T, R> task,
+      long timeoutMs,
+      String taskMsg,
+      Function<Future<R>, R> futureHandler) {
+    List<Callable<R>> callableList =
+        items.stream()
+            .map(item -> (Callable<R>) () -> task.apply(item))
+            .collect(Collectors.toList());
+    try {
+      List<Future<R>> futures =
+          executorService.invokeAll(callableList, timeoutMs, TimeUnit.MILLISECONDS);
+      return futures.stream().map(futureHandler).collect(Collectors.toList());
+    } catch (InterruptedException ie) {
+      // Restore the interrupt flag instead of swallowing it, so callers can still react to it.
+      Thread.currentThread().interrupt();
+      LOGGER.warn("Execute " + taskMsg + " is interrupted", ie);
+      return Collections.emptyList();
+    }
+  }
+
+  /**
+   * Same as {@link #executeTasks(ExecutorService, Collection, Function, long, String, Function)},
+   * but ignores the task results: every element of the returned list is {@code null}.
+   *
+   * @param executorService the executor that runs the tasks
+   * @param items the inputs; one task is created per element
+   * @param task the function applied to each element
+   * @param timeoutMs the maximum time in milliseconds to wait for all tasks together
+   * @param taskMsg a short description of the tasks, used in log messages
+   * @return a list with one {@code null} per item, or an empty list if interrupted
+   */
+  public static <T, R> List<R> executeTasks(
+      ExecutorService executorService,
+      Collection<T> items,
+      Function<T, R> task,
+      long timeoutMs,
+      String taskMsg) {
+    return executeTasks(
+        executorService,
+        items,
+        task,
+        timeoutMs,
+        taskMsg,
+        future -> {
+          // invokeAll already cancels unfinished tasks on timeout; this is only a safety net.
+          if (!future.isDone()) {
+            future.cancel(true);
+          }
+          return null;
+        });
+  }
 }
diff --git a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
index 02ab5d879..d7c70a32f 100644
--- a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
@@ -17,12 +17,18 @@
 
 package org.apache.uniffle.common.util;
 
+import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Function;
 
 import org.junit.jupiter.api.Test;
 
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class ThreadUtilsTest {
@@ -45,4 +51,55 @@ public class ThreadUtilsTest {
     assertTrue(finished.get());
     assertTrue(executorService.isShutdown());
   }
+
+  @Test
+  public void testExecuteTasksWithFutureHandler() {
+    ExecutorService executorService = Executors.newFixedThreadPool(2);
+    try {
+      List<Integer> items = Arrays.asList(1, 2, 3, 4, 5);
+      Function<Integer, Integer> task = item -> item * 2;
+      long timeoutMs = 1000;
+      String taskMsg = "Test Task";
+      Function<Future<Integer>, Integer> futureHandler =
+          future -> {
+            try {
+              return future.get(timeoutMs, TimeUnit.MILLISECONDS);
+            } catch (Exception e) {
+              return null;
+            }
+          };
+
+      List<Integer> results =
+          ThreadUtils.executeTasks(executorService, items, task, timeoutMs, taskMsg, futureHandler);
+      assertEquals(Arrays.asList(2, 4, 6, 8, 10), results);
+    } finally {
+      executorService.shutdownNow();
+    }
+  }
+
+  @Test
+  public void testExecuteTasksRestoresInterruptStatus() {
+    ExecutorService executorService = Executors.newSingleThreadExecutor();
+    try {
+      // Long-running task: it is still pending when executeTasks starts waiting on it.
+      Function<Integer, Integer> slowTask =
+          item -> {
+            try {
+              Thread.sleep(10_000);
+            } catch (InterruptedException e) {
+              Thread.currentThread().interrupt();
+            }
+            return item;
+          };
+      Thread.currentThread().interrupt();
+      List<Integer> results =
+          ThreadUtils.executeTasks(
+              executorService, Arrays.asList(1, 2), slowTask, 10_000, "interrupt test");
+      assertTrue(results.isEmpty());
+      // Thread.interrupted() also clears the flag so it does not leak into other tests.
+      assertTrue(Thread.interrupted());
+    } finally {
+      executorService.shutdownNow();
+    }
+  }
 }
diff --git a/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java b/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
index 95b7b729c..9b7ca1de2 100644
--- a/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
+++ b/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
@@ -21,10 +21,10 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
@@ -33,7 +33,6 @@ import org.slf4j.LoggerFactory;
 import org.apache.uniffle.client.api.CoordinatorClient;
 import org.apache.uniffle.client.factory.CoordinatorClientFactory;
 import org.apache.uniffle.client.request.RssSendHeartBeatRequest;
-import org.apache.uniffle.client.response.RssSendHeartBeatResponse;
 import org.apache.uniffle.common.ServerStatus;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.storage.StorageInfo;
@@ -109,7 +108,7 @@ public class RegisterHeartBeat {
       ServerStatus serverStatus,
       Map<String, StorageInfo> localStorageInfo,
       int nettyPort) {
-    boolean sendSuccessfully = false;
+    AtomicBoolean sendSuccessfully = new AtomicBoolean(false);
     // use `rss.server.heartbeat.interval` as the timeout option
     RssSendHeartBeatRequest request =
         new RssSendHeartBeatRequest(
@@ -125,28 +124,33 @@ public class RegisterHeartBeat {
             serverStatus,
             localStorageInfo,
             nettyPort);
-    List<Future<RssSendHeartBeatResponse>> respFutures =
-        coordinatorClients.stream()
-            .map(client -> heartBeatExecutorService.submit(() -> client.sendHeartBeat(request)))
-            .collect(Collectors.toList());
 
-    String msg = "";
-    for (Future<RssSendHeartBeatResponse> rf : respFutures) {
-      try {
-        if (rf.get(request.getTimeout() * 2, TimeUnit.MILLISECONDS).getStatusCode()
-            == StatusCode.SUCCESS) {
-          sendSuccessfully = true;
-        }
-      } catch (Exception e) {
-        msg = e.getMessage();
-      }
-    }
-
-    if (!sendSuccessfully) {
-      LOG.error(msg);
-    }
+    // Failures are only reported when no coordinator accepted the heartbeat; the most recent
+    // one is kept so that its cause can be logged.
+    AtomicReference<Exception> lastError = new AtomicReference<>();
+    ThreadUtils.executeTasks(
+        heartBeatExecutorService,
+        coordinatorClients,
+        client -> client.sendHeartBeat(request),
+        request.getTimeout() * 2,
+        "send heartbeat",
+        future -> {
+          try {
+            if (future.get(request.getTimeout() * 2, TimeUnit.MILLISECONDS).getStatusCode()
+                == StatusCode.SUCCESS) {
+              sendSuccessfully.set(true);
+            }
+          } catch (Exception e) {
+            lastError.set(e);
+          }
+          return null;
+        });
+
+    if (!sendSuccessfully.get()) {
+      LOG.error("Failed to send heartbeat to any coordinator", lastError.get());
+    }
 
-    return sendSuccessfully;
+    return sendSuccessfully.get();
   }
 
   public void shutdown() {

Reply via email to