This is an automated email from the ASF dual-hosted git repository.

xianjingfeng pushed a commit to branch branch-0.8
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git

commit a785d3a7e934f432fc0966dfa24d14de7a9fe319
Author: roryqi <[email protected]>
AuthorDate: Thu Aug 31 10:05:12 2023 +0800

    [#1177] improvement: Reduce the write time of tasks (#1179)
    
    ### What changes were proposed in this pull request?
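    Reduce task write time by replacing fixed-interval polling with event-driven waiting: each AddBlockEvent now registers a callback that signals `finishEventQueue`, and checkBlockSendResult blocks on that queue (bounded by sendCheckTimeout) and re-checks the sent block IDs whenever a signal arrives, instead of sleeping for sendCheckInterval between checks.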
    
    ### Why are the changes needed?
    
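    Previously, checkBlockSendResult slept for sendCheckInterval ms between checks of the successfully sent block IDs, so a task could wait up to one full interval longer than necessary after its last block was reported as sent.
    
    Fix: #1177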
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    CI passed.
    
    (cherry picked from commit 164e0d02456763f661fb55cc47459312bf4e80b4)
---
 .../spark/shuffle/writer/RssShuffleWriter.java     | 51 +++++++++++++++++-----
 1 file changed, 41 insertions(+), 10 deletions(-)

diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index fb0c78502..330f56c8d 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -24,10 +24,12 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 
@@ -93,6 +95,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   protected final ShuffleWriteMetrics shuffleWriteMetrics;
 
+  private final BlockingQueue<Object> finishEventQueue = new 
LinkedBlockingQueue<>();
+
   // Only for tests
   @VisibleForTesting
   public RssShuffleWriter(
@@ -293,6 +297,13 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
       List<ShuffleBlockInfo> shuffleBlockInfoList) {
     List<CompletableFuture<Long>> futures = new ArrayList<>();
     for (AddBlockEvent event : 
bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
+      event.addCallback(
+          () -> {
+            boolean ret = finishEventQueue.add(new Object());
+            if (!ret) {
+              LOG.error("Add event " + event + " to finishEventQueue fail");
+            }
+          });
       futures.add(shuffleManager.sendData(event));
     }
     return futures;
@@ -300,17 +311,33 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   @VisibleForTesting
   protected void checkBlockSendResult(Set<Long> blockIds) {
-    long start = System.currentTimeMillis();
-    while (true) {
-      checkIfBlocksFailed();
-      Set<Long> successBlockIds = shuffleManager.getSuccessBlockIds(taskId);
-      blockIds.removeAll(successBlockIds);
-      if (blockIds.isEmpty()) {
-        break;
+    boolean interrupted = false;
+
+    try {
+      long remainingMs = sendCheckTimeout;
+      long end = System.currentTimeMillis() + remainingMs;
+
+      while (true) {
+        try {
+          finishEventQueue.clear();
+          checkIfBlocksFailed();
+          Set<Long> successBlockIds = 
shuffleManager.getSuccessBlockIds(taskId);
+          blockIds.removeAll(successBlockIds);
+          if (blockIds.isEmpty()) {
+            break;
+          }
+          if (finishEventQueue.isEmpty()) {
+            remainingMs = Math.max(end - System.currentTimeMillis(), 0);
+            Object event = finishEventQueue.poll(remainingMs, 
TimeUnit.MILLISECONDS);
+            if (event == null) {
+              break;
+            }
+          }
+        } catch (InterruptedException e) {
+          interrupted = true;
+        }
       }
-      LOG.info("Wait " + blockIds.size() + " blocks sent to shuffle server");
-      Uninterruptibles.sleepUninterruptibly(sendCheckInterval, 
TimeUnit.MILLISECONDS);
-      if (System.currentTimeMillis() - start > sendCheckTimeout) {
+      if (!blockIds.isEmpty()) {
         String errorMsg =
             "Timeout: Task["
                 + taskId
@@ -322,6 +349,10 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         LOG.error(errorMsg);
         throw new RssException(errorMsg);
       }
+    } finally {
+      if (interrupted) {
+        Thread.currentThread().interrupt();
+      }
     }
   }
 

Reply via email to