This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 1e4804b9 [#1068] feat(tez): Fail fast in the client when sending data to the server fails. (#1069)
1e4804b9 is described below

commit 1e4804b944bd91887748fc0ad2771825752f3a28
Author: Fantasy-Jay <[email protected]>
AuthorDate: Thu Aug 3 14:28:04 2023 +0800

    [#1068] feat(tez): Fail fast in the client when sending data to the server fails. (#1069)
    
    ### What changes were proposed in this pull request?
    
    Currently, the client checks for blocks that failed to send only after all
    buffered data has been sent. This check should also run earlier, inside the
    addRecord method, so that the client fails fast.
    
    ### Why are the changes needed?
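    Fix: #1068 (https://github.com/apache/incubator-uniffle/issues/1068)
    
    Previously, failed blocks were detected only while waiting for all sends to
    finish, so a task kept accepting records after a send had already failed.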
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added a new test case, testFastFailWhenSendBlocksFailed, to WriteBufferManagerTest.
---
 .../common/sort/buffer/WriteBufferManager.java     |  24 ++--
 .../common/sort/buffer/WriteBufferManagerTest.java | 126 ++++++++++++++++++---
 2 files changed, 129 insertions(+), 21 deletions(-)

diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
index 06e4194b..7da3208c 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
@@ -172,6 +172,9 @@ public class WriteBufferManager<K, V> {
       memoryLock.unlock();
     }
 
+    // Fail fast if there are some failed blocks.
+    checkFailedBlocks();
+
     if (!buffers.containsKey(partitionId)) {
       WriteBuffer<K, V> sortWriterBuffer =
           new WriteBuffer(
@@ -282,14 +285,7 @@ public class WriteBufferManager<K, V> {
     }
     long start = System.currentTimeMillis();
     while (true) {
-      if (failedBlockIds.size() > 0) {
-        String errorMsg =
-            "Send failed: failed because "
-                + failedBlockIds.size()
-                + " blocks can't be sent to shuffle server.";
-        LOG.error(errorMsg);
-        throw new RssException(errorMsg);
-      }
+      checkFailedBlocks();
       allBlockIds.removeAll(successBlockIds);
       if (allBlockIds.isEmpty()) {
         break;
@@ -335,6 +331,18 @@ public class WriteBufferManager<K, V> {
         sortTime);
   }
 
+  // Check if there are some failed blocks, if true then throw Exception.
+  private void checkFailedBlocks() {
+    if (failedBlockIds.size() > 0) {
+      String errorMsg =
+          "Send failed: failed because "
+              + failedBlockIds.size()
+              + " blocks can't be sent to shuffle server.";
+      LOG.error(errorMsg);
+      throw new RssException(errorMsg);
+    }
+  }
+
   ShuffleBlockInfo createShuffleBlock(WriteBuffer wb) {
     byte[] data = wb.getData();
     copyTime += wb.getCopyTime();
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index 8f2af65b..1449649c 100644
--- 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -44,7 +44,6 @@ import org.apache.tez.common.TezRuntimeFrameworkConfigs;
 import org.apache.tez.common.counters.TaskCounter;
 import org.apache.tez.common.counters.TezCounter;
 import org.apache.tez.dag.records.TezTaskAttemptID;
-import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.api.OutputContext;
 import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
 import org.apache.tez.runtime.library.output.OutputTestHelpers;
@@ -67,6 +66,7 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class WriteBufferManagerTest {
@@ -99,7 +99,9 @@ public class WriteBufferManagerTest {
     long sendCheckInterval = 500L;
     long sendCheckTimeout = 5;
     int bitmapSplitNum = 1;
-    int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
+    int shuffleId =
+        RssTezUtils.computeShuffleId(
+            tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 
2);
 
     Configuration conf = new Configuration();
     FileSystem localFs = FileSystem.getLocal(conf);
@@ -139,7 +141,7 @@ public class WriteBufferManagerTest {
             rssConf,
             partitionToServers,
             numMaps,
-            isMemoryShuffleEnabled(storageType),
+            StorageType.withMemory(StorageType.valueOf(storageType)),
             sendCheckInterval,
             sendCheckTimeout,
             bitmapSplitNum,
@@ -197,7 +199,9 @@ public class WriteBufferManagerTest {
     long sendCheckInterval = 500L;
     long sendCheckTimeout = 60 * 1000 * 10L;
     int bitmapSplitNum = 1;
-    int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
+    int shuffleId =
+        RssTezUtils.computeShuffleId(
+            tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 
2);
 
     Configuration conf = new Configuration();
     FileSystem localFs = FileSystem.getLocal(conf);
@@ -237,7 +241,7 @@ public class WriteBufferManagerTest {
             rssConf,
             partitionToServers,
             numMaps,
-            isMemoryShuffleEnabled(storageType),
+            StorageType.withMemory(StorageType.valueOf(storageType)),
             sendCheckInterval,
             sendCheckTimeout,
             bitmapSplitNum,
@@ -305,7 +309,9 @@ public class WriteBufferManagerTest {
     long sendCheckInterval = 500L;
     long sendCheckTimeout = 60 * 1000 * 10L;
     int bitmapSplitNum = 1;
-    int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2);
+    int shuffleId =
+        RssTezUtils.computeShuffleId(
+            tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 
2);
 
     Configuration conf = new Configuration();
     FileSystem localFs = FileSystem.getLocal(conf);
@@ -371,15 +377,109 @@ public class WriteBufferManagerTest {
         writeClient.mockedShuffleServer.getFlushBlockSize());
   }
 
-  private int getShuffleId(TezTaskAttemptID tezTaskAttemptID, int upVertexId, 
int downVertexId) {
-    TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID();
+  @Test
+  public void testFastFailWhenSendBlocksFailed(@TempDir File tmpDir)
+      throws IOException, InterruptedException {
+    TezTaskAttemptID tezTaskAttemptID =
+        
TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0");
+    final long maxMemSize = 10240;
+    final String appId = "application_1681717153064_3770270";
+    final long taskAttemptId = 0;
+    final Set<Long> successBlockIds = Sets.newConcurrentHashSet();
+    final Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
+    MockShuffleWriteClient writeClient = new MockShuffleWriteClient();
+    // set mode = 1 to fake sending shuffle data failed.
+    writeClient.setMode(1);
+    RawComparator comparator = WritableComparator.get(BytesWritable.class);
+    long maxSegmentSize = 3 * 1024;
+    SerializationFactory serializationFactory = new SerializationFactory(new 
JobConf());
+    Serializer<BytesWritable> keySerializer =
+        serializationFactory.getSerializer(BytesWritable.class);
+    Serializer<BytesWritable> valSerializer =
+        serializationFactory.getSerializer(BytesWritable.class);
+    // note: max buffer size is tiny.
+    long maxBufferSize = 14 * 1024;
+    double memoryThreshold = 0.8f;
+    int sendThreadNum = 1;
+    double sendThreshold = 0.2f;
+    int batch = 50;
+    int numMaps = 1;
+    RssConf rssConf = new RssConf();
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    long sendCheckInterval = 500L;
+    long sendCheckTimeout = 60 * 1000 * 10L;
+    int bitmapSplitNum = 1;
     int shuffleId =
-        RssTezUtils.computeShuffleId(tezVertexID.getDAGId().getId(), 
upVertexId, downVertexId);
-    return shuffleId;
-  }
+        RssTezUtils.computeShuffleId(
+            tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 
2);
+
+    Configuration conf = new Configuration();
+    FileSystem localFs = FileSystem.getLocal(conf);
+    Path workingDir =
+        new Path(
+                System.getProperty(
+                    "test.build.data", System.getProperty("java.io.tmpdir", 
tmpDir.toString())),
+                RssOrderedPartitionedKVOutputTest.class.getName())
+            .makeQualified(localFs.getUri(), localFs.getWorkingDirectory());
+    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, 
Text.class.getName());
+    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, 
Text.class.getName());
+    conf.set(
+        TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS, 
HashPartitioner.class.getName());
+    conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, 
workingDir.toString());
+    OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, 
workingDir);
+    TezCounter mapOutputByteCounter =
+        outputContext.getCounters().findCounter(TaskCounter.OUTPUT_BYTES);
 
-  private boolean isMemoryShuffleEnabled(String storageType) {
-    return StorageType.withMemory(StorageType.valueOf(storageType));
+    WriteBufferManager<BytesWritable, BytesWritable> bufferManager =
+        new WriteBufferManager(
+            tezTaskAttemptID,
+            maxMemSize,
+            appId,
+            taskAttemptId,
+            successBlockIds,
+            failedBlockIds,
+            writeClient,
+            comparator,
+            maxSegmentSize,
+            keySerializer,
+            valSerializer,
+            maxBufferSize,
+            memoryThreshold,
+            sendThreadNum,
+            sendThreshold,
+            batch,
+            rssConf,
+            partitionToServers,
+            numMaps,
+            false,
+            sendCheckInterval,
+            sendCheckTimeout,
+            bitmapSplitNum,
+            shuffleId,
+            true,
+            mapOutputByteCounter);
+
+    Random random = new Random();
+    RssException rssException =
+        assertThrows(
+            RssException.class,
+            () -> {
+              for (int i = 0; i < 10000; i++) {
+                byte[] key = new byte[20];
+                byte[] value = new byte[1024];
+                random.nextBytes(key);
+                random.nextBytes(value);
+                int partitionId = random.nextInt(50);
+                bufferManager.addRecord(
+                    partitionId, new BytesWritable(key), new 
BytesWritable(value));
+              }
+            });
+    assertTrue(rssException.getMessage().contains("Send failed"));
+
+    rssException = assertThrows(RssException.class, 
bufferManager::waitSendFinished);
+    assertTrue(rssException.getMessage().contains("Send failed"));
+
+    assertTrue(mapOutputByteCounter.getValue() < 10520000);
   }
 
   class MockShuffleServer {

Reply via email to