This is an automated email from the ASF dual-hosted git repository.

guoweijie pushed a commit to branch release-1.18
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 9c960536c4a7fac8bffbfcab2d0c5c3500dbd28b
Author: Wencong Liu <[email protected]>
AuthorDate: Mon Sep 11 09:56:50 2023 +0800

    [hotfix][network] Optimize the backlog calculation logic in Hybrid Shuffle
---
 .../hybrid/tiered/netty/NettyPayloadManager.java   | 52 +++++++++++++++++++---
 .../netty/TieredStorageResultSubpartitionView.java |  8 ++--
 .../tiered/netty/NettyConnectionWriterTest.java    | 41 +++++++++++++++++
 .../TieredStorageResultSubpartitionViewTest.java   | 27 ++++++-----
 4 files changed, 104 insertions(+), 24 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/NettyPayloadManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/NettyPayloadManager.java
index 95e7bc77bc0..00350a676ce 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/NettyPayloadManager.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/NettyPayloadManager.java
@@ -19,13 +19,17 @@
 package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty;
 
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.util.Preconditions;
 
 import javax.annotation.concurrent.GuardedBy;
 
+import java.util.Deque;
 import java.util.LinkedList;
 import java.util.Optional;
 import java.util.Queue;
 
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
 /** {@link NettyPayloadManager} is used to contain all netty payloads from a 
storage tier. */
 public class NettyPayloadManager {
 
@@ -33,16 +37,29 @@ public class NettyPayloadManager {
 
     private final Queue<NettyPayload> queue = new LinkedList<>();
 
-    /** Number of buffers whose {@link Buffer.DataType} is buffer in the 
queue. */
+    /**
+     * The queue contains a collection of numbers. Each number represents the 
number of buffers that
+     * belongs to consecutive segment ids.
+     */
+    @GuardedBy("lock")
+    private final Deque<Integer> backlogs = new LinkedList<>();
+
     @GuardedBy("lock")
-    private int backlog = 0;
+    private int lastSegmentId = -1;
 
     public void add(NettyPayload nettyPayload) {
         synchronized (lock) {
             queue.add(nettyPayload);
+            int segmentId = nettyPayload.getSegmentId();
+            if (segmentId != -1 && segmentId != lastSegmentId) {
+                if (segmentId == 0 || segmentId != (lastSegmentId + 1)) {
+                    addNewBacklog();
+                }
+                lastSegmentId = segmentId;
+            }
             Optional<Buffer> buffer = nettyPayload.getBuffer();
             if (buffer.isPresent() && buffer.get().isBuffer()) {
-                backlog++;
+                addBacklog();
             }
         }
     }
@@ -59,7 +76,7 @@ public class NettyPayloadManager {
             if (nettyPayload != null
                     && nettyPayload.getBuffer().isPresent()
                     && nettyPayload.getBuffer().get().isBuffer()) {
-                backlog--;
+                decreaseBacklog();
             }
             return nettyPayload;
         }
@@ -67,7 +84,8 @@ public class NettyPayloadManager {
 
     public int getBacklog() {
         synchronized (lock) {
-            return backlog;
+            Integer backlog = backlogs.peekFirst();
+            return backlog == null ? 0 : backlog;
         }
     }
 
@@ -76,4 +94,28 @@ public class NettyPayloadManager {
             return queue.size();
         }
     }
+
+    @GuardedBy("lock")
+    private void addNewBacklog() {
+        backlogs.addLast(0);
+    }
+
+    @GuardedBy("lock")
+    private void addBacklog() {
+        Integer backlog = backlogs.pollLast();
+        if (backlog == null) {
+            backlogs.addLast(1);
+        } else {
+            backlogs.addLast(backlog + 1);
+        }
+    }
+
+    @GuardedBy("lock")
+    private void decreaseBacklog() {
+        int backlog = checkNotNull(backlogs.pollFirst());
+        Preconditions.checkState(backlog > 0);
+        if (backlog > 1) {
+            backlogs.addFirst(backlog - 1);
+        }
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageResultSubpartitionView.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageResultSubpartitionView.java
index 674ebe9a359..8ec8f252bb7 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageResultSubpartitionView.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageResultSubpartitionView.java
@@ -200,11 +200,9 @@ public class TieredStorageResultSubpartitionView implements ResultSubpartitionVi
     }
 
     private int getBacklog() {
-        int backlog = 0;
-        for (NettyPayloadManager nettyPayloadManager : nettyPayloadManagers) {
-            backlog += nettyPayloadManager.getBacklog();
-        }
-        return backlog;
+        return managerIndexContainsCurrentSegment == -1
+                ? 0
+                : 
nettyPayloadManagers.get(managerIndexContainsCurrentSegment).getBacklog();
     }
 
     private boolean isEventOrError(NettyPayloadManager nettyPayloadManager) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/NettyConnectionWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/NettyConnectionWriterTest.java
index 5d655218472..7748a3889a1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/NettyConnectionWriterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/NettyConnectionWriterTest.java
@@ -79,6 +79,40 @@ public class NettyConnectionWriterTest {
         assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isZero();
     }
 
+    @Test
+    void testGetNumQueuedBufferPayloads() {
+        NettyPayloadManager nettyPayloadManager = new NettyPayloadManager();
+        NettyConnectionWriter nettyConnectionWriter =
+                new NettyConnectionWriterImpl(nettyPayloadManager, () -> {});
+        nettyConnectionWriter.writeNettyPayload(NettyPayload.newSegment(0));
+        writeBufferToWriter(3, nettyConnectionWriter);
+        nettyConnectionWriter.writeNettyPayload(NettyPayload.newSegment(2));
+        writeBufferToWriter(1, nettyConnectionWriter);
+        nettyConnectionWriter.writeNettyPayload(NettyPayload.newSegment(3));
+        writeBufferToWriter(1, nettyConnectionWriter);
+        nettyConnectionWriter.writeNettyPayload(NettyPayload.newSegment(5));
+        writeBufferToWriter(5, nettyConnectionWriter);
+        
assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isEqualTo(3);
+        clearNettyPayloadManager(1, nettyPayloadManager);
+        
assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isEqualTo(3);
+        clearNettyPayloadManager(2, nettyPayloadManager);
+        
assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isEqualTo(1);
+        clearNettyPayloadManager(1, nettyPayloadManager);
+        
assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isEqualTo(2);
+        clearNettyPayloadManager(1, nettyPayloadManager);
+        
assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isEqualTo(2);
+        clearNettyPayloadManager(1, nettyPayloadManager);
+        
assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isEqualTo(1);
+        clearNettyPayloadManager(1, nettyPayloadManager);
+        
assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isEqualTo(1);
+        clearNettyPayloadManager(1, nettyPayloadManager);
+        
assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isEqualTo(5);
+        clearNettyPayloadManager(1, nettyPayloadManager);
+        
assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isEqualTo(5);
+        clearNettyPayloadManager(2, nettyPayloadManager);
+        
assertThat(nettyConnectionWriter.numQueuedBufferPayloads()).isEqualTo(3);
+    }
+
     private static void writeBufferToWriter(
             int bufferNumber, NettyConnectionWriter nettyConnectionWriter) {
         for (int index = 0; index < bufferNumber; ++index) {
@@ -87,4 +121,11 @@ public class NettyConnectionWriterTest {
                             BufferBuilderTestUtils.buildSomeBuffer(0), index, 
SUBPARTITION_ID));
         }
     }
+
+    private static void clearNettyPayloadManager(
+            int payloadNumber, NettyPayloadManager nettyPayloadManager) {
+        for (int index = 0; index < payloadNumber; ++index) {
+            nettyPayloadManager.poll();
+        }
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageResultSubpartitionViewTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageResultSubpartitionViewTest.java
index afd84848cef..29ab615bcff 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageResultSubpartitionViewTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageResultSubpartitionViewTest.java
@@ -55,7 +55,7 @@ public class TieredStorageResultSubpartitionViewTest {
     @BeforeEach
     void before() {
         availabilityListener = new CompletableFuture<>();
-        nettyPayloadManagers = createNettyPayloadQueues();
+        nettyPayloadManagers = createNettyPayloadManagers();
         connectionBrokenConsumers =
                 Arrays.asList(new CompletableFuture<>(), new 
CompletableFuture<>());
         tieredStorageResultSubpartitionView =
@@ -68,12 +68,10 @@ public class TieredStorageResultSubpartitionViewTest {
 
     @Test
     void testGetNextBuffer() throws IOException {
-        
checkBufferAndBacklog(tieredStorageResultSubpartitionView.getNextBuffer(), 1);
-        
checkBufferAndBacklog(tieredStorageResultSubpartitionView.getNextBuffer(), 1);
+        
checkBufferAndBacklog(tieredStorageResultSubpartitionView.getNextBuffer(), 0);
         tieredStorageResultSubpartitionView.notifyRequiredSegmentId(1);
         assertThat(availabilityListener).isDone();
         
checkBufferAndBacklog(tieredStorageResultSubpartitionView.getNextBuffer(), 0);
-        
checkBufferAndBacklog(tieredStorageResultSubpartitionView.getNextBuffer(), 0);
         
assertThat(tieredStorageResultSubpartitionView.getNextBuffer()).isNull();
     }
 
@@ -96,11 +94,11 @@ public class TieredStorageResultSubpartitionViewTest {
     void testGetAvailabilityAndBacklog() {
         ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog1 
=
                 
tieredStorageResultSubpartitionView.getAvailabilityAndBacklog(0);
-        assertThat(availabilityAndBacklog1.getBacklog()).isEqualTo(2);
+        assertThat(availabilityAndBacklog1.getBacklog()).isEqualTo(1);
         assertThat(availabilityAndBacklog1.isAvailable()).isEqualTo(false);
         ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog2 
=
                 
tieredStorageResultSubpartitionView.getAvailabilityAndBacklog(2);
-        assertThat(availabilityAndBacklog2.getBacklog()).isEqualTo(2);
+        assertThat(availabilityAndBacklog2.getBacklog()).isEqualTo(1);
         assertThat(availabilityAndBacklog2.isAvailable()).isEqualTo(true);
     }
 
@@ -122,9 +120,9 @@ public class TieredStorageResultSubpartitionViewTest {
 
     @Test
     void testGetNumberOfQueuedBuffers() {
-        
assertThat(tieredStorageResultSubpartitionView.getNumberOfQueuedBuffers()).isEqualTo(2);
+        
assertThat(tieredStorageResultSubpartitionView.getNumberOfQueuedBuffers()).isEqualTo(1);
         
assertThat(tieredStorageResultSubpartitionView.unsynchronizedGetNumberOfQueuedBuffers())
-                .isEqualTo(2);
+                .isEqualTo(1);
     }
 
     private static void checkBufferAndBacklog(BufferAndBacklog 
bufferAndBacklog, int backlog) {
@@ -138,13 +136,14 @@ public class TieredStorageResultSubpartitionViewTest {
         return () -> notifier.complete(null);
     }
 
-    private static List<NettyPayloadManager> createNettyPayloadQueues() {
+    private static List<NettyPayloadManager> createNettyPayloadManagers() {
         List<NettyPayloadManager> nettyPayloadManagers = new ArrayList<>();
         for (int index = 0; index < TIER_NUMBER; ++index) {
-            NettyPayloadManager queue = new NettyPayloadManager();
-            queue.add(NettyPayload.newSegment(index));
-            
queue.add(NettyPayload.newBuffer(BufferBuilderTestUtils.buildSomeBuffer(0), 0, 
index));
-            queue.add(
+            NettyPayloadManager nettyPayloadManager = new 
NettyPayloadManager();
+            nettyPayloadManager.add(NettyPayload.newSegment(index));
+            nettyPayloadManager.add(
+                    
NettyPayload.newBuffer(BufferBuilderTestUtils.buildSomeBuffer(0), 0, index));
+            nettyPayloadManager.add(
                     NettyPayload.newBuffer(
                             new NetworkBuffer(
                                     
MemorySegmentFactory.allocateUnpooledSegment(0),
@@ -152,7 +151,7 @@ public class TieredStorageResultSubpartitionViewTest {
                                     END_OF_SEGMENT),
                             1,
                             index));
-            nettyPayloadManagers.add(queue);
+            nettyPayloadManagers.add(nettyPayloadManager);
         }
         return nettyPayloadManagers;
     }

Reply via email to