This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a commit to branch release-1.17
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.17 by this push:
     new b90bf7acde9 [FLINK-31386][network] Fix the potential deadlock issue of 
blocking shuffle
b90bf7acde9 is described below

commit b90bf7acde9632a399137c20be36f12706bd75f2
Author: kevin.cyj <[email protected]>
AuthorDate: Thu Mar 9 11:52:21 2023 +0800

    [FLINK-31386][network] Fix the potential deadlock issue of blocking shuffle
    
    Currently, the SortMergeResultPartition may allocate more network buffers 
than the guaranteed size of the LocalBufferPool. As a result, some result 
partitions may need to wait for other result partitions to release the 
over-allocated network buffers before they can continue. However, the result 
partitions that have allocated more than their guaranteed buffers rely on the 
processing of input data to trigger data spilling and buffer recycling. The 
input data further relies on batch reading buffers used by  [...]
    
    This closes #22148.
---
 .../io/disk/BatchShuffleReadBufferPool.java        |   2 +-
 .../partition/SortMergeResultPartition.java        |  12 ++-
 .../partition/SortMergeResultPartitionTest.java    | 108 ++++++++++++++++++++-
 3 files changed, 117 insertions(+), 5 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java
index fc9aa1b0ef2..863c0429400 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java
@@ -54,7 +54,7 @@ public class BatchShuffleReadBufferPool {
      * Memory size in bytes can be allocated from this buffer pool for a 
single request (4M is for
      * better sequential read).
      */
-    private static final int NUM_BYTES_PER_REQUEST = 4 * 1024 * 1024;
+    public static final int NUM_BYTES_PER_REQUEST = 4 * 1024 * 1024;
 
     /**
      * Wait for at most 2 seconds before return if there is no enough 
available buffers currently.
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
index def80f40574..ba84f8b6e0b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
@@ -186,6 +186,11 @@ public class SortMergeResultPartition extends 
ResultPartition {
             }
         }
 
+        // reserve the "guaranteed" buffers for this buffer pool to avoid the 
case that those
+        // buffers are taken by other result partitions and can not be 
released, which may cause
+        // deadlock
+        requestGuaranteedBuffers();
+
         // initialize the buffer pool eagerly to avoid reporting errors such 
as OOM too late
         readBufferPool.initialize();
         LOG.info("Sort-merge partition {} initialized.", getPartitionId());
@@ -325,7 +330,7 @@ public class SortMergeResultPartition extends 
ResultPartition {
         }
     }
 
-    private void requestNetworkBuffers() throws IOException {
+    private void requestGuaranteedBuffers() throws IOException {
         int numRequiredBuffer = bufferPool.getNumberOfRequiredMemorySegments();
         if (numRequiredBuffer < 2) {
             throw new IOException(
@@ -339,8 +344,13 @@ public class SortMergeResultPartition extends 
ResultPartition {
                 
freeSegments.add(checkNotNull(bufferPool.requestMemorySegmentBlocking()));
             }
         } catch (InterruptedException exception) {
+            freeSegments.forEach(bufferPool::recycle);
             throw new IOException("Failed to allocate buffers for result 
partition.", exception);
         }
+    }
+
+    private void requestNetworkBuffers() throws IOException {
+        requestGuaranteedBuffers();
 
         // avoid taking too many buffers in one result partition
         while (freeSegments.size() < 
bufferPool.getMaxNumberOfMemorySegments()) {
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
index 6f01823c3fd..c202669e9f3 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
@@ -50,6 +50,7 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Queue;
 import java.util.Random;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.function.Consumer;
@@ -325,7 +326,7 @@ public class SortMergeResultPartitionTest {
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, 
numBuffers);
         SortMergeResultPartition partition = createSortMergedPartition(10, 
bufferPool);
-        assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0);
+        
assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);
 
         partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 
1)), 0);
         partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 
1)), 1);
@@ -348,7 +349,7 @@ public class SortMergeResultPartitionTest {
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, 
numBuffers);
         SortMergeResultPartition partition = createSortMergedPartition(10, 
bufferPool);
-        assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0);
+        
assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);
 
         partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 
1)), 0);
         partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 
1)), 1);
@@ -381,7 +382,7 @@ public class SortMergeResultPartitionTest {
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, 
numBuffers);
         SortMergeResultPartition partition = createSortMergedPartition(10, 
bufferPool);
-        assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0);
+        
assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);
 
         partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 
1)), 5);
         assertThat(bufferPool.bestEffortGetNumOfUsedBuffers())
@@ -423,6 +424,107 @@ public class SortMergeResultPartitionTest {
         testResultPartitionBytesCounter(true);
     }
 
+    @TestTemplate
+    void testNetworkBufferReservation() throws IOException {
+        int numBuffers = 10;
+
+        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, 2 * 
numBuffers);
+        SortMergeResultPartition partition = createSortMergedPartition(1, 
bufferPool);
+        
assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);
+
+        partition.finish();
+        partition.close();
+    }
+
+    @TestTemplate
+    void testNoDeadlockOnSpecificConsumptionOrder() throws Exception {
+        // see https://issues.apache.org/jira/browse/FLINK-31386 for more 
information
+        int numNetworkBuffers = 2 * 
BatchShuffleReadBufferPool.NUM_BYTES_PER_REQUEST / bufferSize;
+        NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(numNetworkBuffers, bufferSize);
+        BatchShuffleReadBufferPool readBufferPool =
+                new BatchShuffleReadBufferPool(
+                        BatchShuffleReadBufferPool.NUM_BYTES_PER_REQUEST, 
bufferSize);
+
+        BufferPool bufferPool =
+                networkBufferPool.createBufferPool(numNetworkBuffers, 
numNetworkBuffers);
+        SortMergeResultPartition partition =
+                createSortMergedPartition(1, bufferPool, readBufferPool);
+        for (int i = 0; i < numNetworkBuffers; ++i) {
+            partition.emitRecord(ByteBuffer.allocate(bufferSize), 0);
+        }
+        partition.finish();
+        partition.close();
+
+        CountDownLatch condition1 = new CountDownLatch(1);
+        CountDownLatch condition2 = new CountDownLatch(1);
+
+        Runnable task1 =
+                () -> {
+                    try {
+                        ResultSubpartitionView view = 
partition.createSubpartitionView(0, listener);
+                        BufferPool bufferPool1 =
+                                networkBufferPool.createBufferPool(
+                                        numNetworkBuffers / 2, 
numNetworkBuffers);
+                        SortMergeResultPartition partition1 =
+                                createSortMergedPartition(1, bufferPool1);
+                        readAndEmitData(view, partition1);
+
+                        condition1.countDown();
+                        condition2.await();
+                        readAndEmitAllData(view, partition1);
+                    } catch (Exception ignored) {
+                    }
+                };
+        Thread consumer1 = new Thread(task1);
+        consumer1.start();
+
+        Runnable task2 =
+                () -> {
+                    try {
+                        condition1.await();
+                        BufferPool bufferPool2 =
+                                networkBufferPool.createBufferPool(
+                                        numNetworkBuffers / 2, 
numNetworkBuffers);
+                        condition2.countDown();
+
+                        SortMergeResultPartition partition2 =
+                                createSortMergedPartition(1, bufferPool2);
+                        ResultSubpartitionView view = 
partition.createSubpartitionView(0, listener);
+                        readAndEmitAllData(view, partition2);
+                    } catch (Exception ignored) {
+                    }
+                };
+        Thread consumer2 = new Thread(task2);
+        consumer2.start();
+
+        consumer1.join();
+        consumer2.join();
+    }
+
+    private boolean readAndEmitData(ResultSubpartitionView view, 
SortMergeResultPartition partition)
+            throws Exception {
+        MemorySegment segment = 
MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
+        ResultSubpartition.BufferAndBacklog buffer;
+        do {
+            buffer = view.getNextBuffer();
+            if (buffer != null) {
+                Buffer data = ((CompositeBuffer) 
buffer.buffer()).getFullBufferData(segment);
+                partition.emitRecord(data.getNioBufferReadable(), 0);
+                if (!data.isRecycled()) {
+                    data.recycleBuffer();
+                }
+                return buffer.buffer().isBuffer();
+            }
+        } while (true);
+    }
+
+    private void readAndEmitAllData(ResultSubpartitionView view, 
SortMergeResultPartition partition)
+            throws Exception {
+        while (readAndEmitData(view, partition)) {}
+        partition.finish();
+        partition.close();
+    }
+
     private void testResultPartitionBytesCounter(boolean isBroadcast) throws 
IOException {
         int numBuffers = useHashDataBuffer ? 100 : 15;
         int numSubpartitions = 2;

Reply via email to