GitHub user NicoK commented on a diff in the pull request:

    https://github.com/apache/flink/pull/4509#discussion_r152860104
  
    --- Diff: 
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
 ---
    @@ -301,81 +306,388 @@ public void testProducerFailedException() throws 
Exception {
        }
     
        /**
    -    * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying 
the exclusive segment is
    -    * recycled to available buffers directly and it triggers notify of 
announced credit.
    +    * Tests to verify that the input channel requests floating buffers 
from buffer pool
    +    * in order to maintain backlog + initialCredit buffers available once 
receiving the
    +    * sender's backlog, and registers as listener if no floating buffers 
available.
         */
        @Test
    -   public void testRecycleExclusiveBufferBeforeReleased() throws Exception 
{
    -           final SingleInputGate inputGate = mock(SingleInputGate.class);
    -           final RemoteInputChannel inputChannel = 
spy(createRemoteInputChannel(inputGate));
    +   public void testRequestFloatingBufferOnSenderBacklog() throws Exception 
{
    +           // Setup
    +           final NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(12, 32, MemoryType.HEAP);
    +           final SingleInputGate inputGate = createSingleInputGate();
    +           final RemoteInputChannel inputChannel = 
createRemoteInputChannel(inputGate);
    +           try {
    +                   final int numFloatingBuffers = 10;
    +                   final BufferPool bufferPool = 
spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
    +                   inputGate.setBufferPool(bufferPool);
    +
    +                   // Assign exclusive segments to the channel
    +                   final int numExclusiveBuffers = 2;
    +                   
inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), 
inputChannel);
    +                   inputGate.assignExclusiveSegments(networkBufferPool, 
numExclusiveBuffers);
    +
    +                   assertEquals("There should be " + numExclusiveBuffers + 
" buffers available in the channel",
    +                           numExclusiveBuffers, 
inputChannel.getNumberOfAvailableBuffers());
     
    -           // Recycle exclusive segment
    -           
inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, 
inputChannel));
    +                   // Receive the producer's backlog
    +                   inputChannel.onSenderBacklog(8);
     
    -           assertEquals("There should be one buffer available after 
recycle.",
    -                   1, inputChannel.getNumberOfAvailableBuffers());
    -           verify(inputChannel, times(1)).notifyCreditAvailable();
    +                   // Request the number of floating buffers by the 
formula of backlog + initialCredit - availableBuffers
    +                   verify(bufferPool, times(8)).requestBuffer();
    +                   verify(bufferPool, 
times(0)).addBufferListener(inputChannel);
    +                   assertEquals("There should be 10 buffers available in 
the channel",
    +                           10, inputChannel.getNumberOfAvailableBuffers());
     
    -           
inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, 
inputChannel));
    +                   inputChannel.onSenderBacklog(11);
     
    -           assertEquals("There should be two buffers available after 
recycle.",
    -                   2, inputChannel.getNumberOfAvailableBuffers());
    -           // It should be called only once when increased from zero.
    -           verify(inputChannel, times(1)).notifyCreditAvailable();
    +                   // Need extra three floating buffers, but only two 
buffers available in buffer pool, register as listener as a result
    +                   verify(bufferPool, times(11)).requestBuffer();
    +                   verify(bufferPool, 
times(1)).addBufferListener(inputChannel);
    +                   assertEquals("There should be 12 buffers available in 
the channel",
    +                           12, inputChannel.getNumberOfAvailableBuffers());
    +
    +                   inputChannel.onSenderBacklog(12);
    +
    +                   // Already in the status of waiting for buffers and 
will not request any more
    +                   verify(bufferPool, times(11)).requestBuffer();
    +                   verify(bufferPool, 
times(1)).addBufferListener(inputChannel);
    +
    +           } finally {
    +                   // Release all the buffer resources
    +                   inputChannel.releaseAllResources();
    +
    +                   networkBufferPool.destroyAllBufferPools();
    +                   networkBufferPool.destroy();
    +           }
        }
     
        /**
    -    * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying 
the exclusive segment is
    -    * recycled to global pool via input gate when channel is released.
    +    * Tests to verify that the buffer pool will distribute available 
floating buffers among
    +    * all the channel listeners in a fair way.
         */
        @Test
    -   public void testRecycleExclusiveBufferAfterReleased() throws Exception {
    +   public void testFairDistributionFloatingBuffers() throws Exception {
                // Setup
    -           final SingleInputGate inputGate = mock(SingleInputGate.class);
    -           final RemoteInputChannel inputChannel = 
spy(createRemoteInputChannel(inputGate));
    +           final NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(12, 32, MemoryType.HEAP);
    +           final SingleInputGate inputGate = createSingleInputGate();
    +           final RemoteInputChannel channel1 = 
spy(createRemoteInputChannel(inputGate));
    +           final RemoteInputChannel channel2 = 
spy(createRemoteInputChannel(inputGate));
    +           final RemoteInputChannel channel3 = 
spy(createRemoteInputChannel(inputGate));
    +           try {
    +                   final int numFloatingBuffers = 3;
    +                   final BufferPool bufferPool = 
spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
    +                   inputGate.setBufferPool(bufferPool);
    +
    +                   // Assign exclusive segments to the channels
    +                   
inputGate.setInputChannel(channel1.partitionId.getPartitionId(), channel1);
    +                   
inputGate.setInputChannel(channel2.partitionId.getPartitionId(), channel2);
    +                   
inputGate.setInputChannel(channel3.partitionId.getPartitionId(), channel3);
    +                   final int numExclusiveBuffers = 2;
    +                   inputGate.assignExclusiveSegments(networkBufferPool, 
numExclusiveBuffers);
    +
    +                   // Exhaust all the floating buffers
    +                   final List<Buffer> floatingBuffers = new 
ArrayList<>(numFloatingBuffers);
    +                   for (int i = 0; i < numFloatingBuffers; i++) {
    +                           Buffer buffer = bufferPool.requestBuffer();
    +                           assertNotNull(buffer);
    +                           floatingBuffers.add(buffer);
    +                   }
    +
    +                   // Receive the producer's backlog to trigger request 
floating buffers from pool
    +                   // and register as listeners as a result
    +                   channel1.onSenderBacklog(8);
    +                   channel2.onSenderBacklog(8);
    +                   channel3.onSenderBacklog(8);
    +
    +                   verify(bufferPool, 
times(1)).addBufferListener(channel1);
    +                   verify(bufferPool, 
times(1)).addBufferListener(channel2);
    +                   verify(bufferPool, 
times(1)).addBufferListener(channel3);
    +                   assertEquals("There should be " + numExclusiveBuffers + 
" buffers available in the channel",
    +                           numExclusiveBuffers, 
channel1.getNumberOfAvailableBuffers());
    +                   assertEquals("There should be " + numExclusiveBuffers + 
" buffers available in the channel",
    +                           numExclusiveBuffers, 
channel2.getNumberOfAvailableBuffers());
    +                   assertEquals("There should be " + numExclusiveBuffers + 
" buffers available in the channel",
    +                           numExclusiveBuffers, 
channel3.getNumberOfAvailableBuffers());
    +
    +                   // Recycle three floating buffers to trigger notify 
buffer available
    +                   for (Buffer buffer : floatingBuffers) {
    +                           buffer.recycle();
    +                   }
    +
    +                   verify(channel1, 
times(1)).notifyBufferAvailable(any(Buffer.class));
    +                   verify(channel2, 
times(1)).notifyBufferAvailable(any(Buffer.class));
    +                   verify(channel3, 
times(1)).notifyBufferAvailable(any(Buffer.class));
    +                   assertEquals("There should be 3 buffers available in 
the channel", 3, channel1.getNumberOfAvailableBuffers());
    +                   assertEquals("There should be 3 buffers available in 
the channel", 3, channel2.getNumberOfAvailableBuffers());
    +                   assertEquals("There should be 3 buffers available in 
the channel", 3, channel3.getNumberOfAvailableBuffers());
    +
    +           } finally {
    +                   // Release all the buffer resources
    +                   channel1.releaseAllResources();
    +                   channel2.releaseAllResources();
    +                   channel3.releaseAllResources();
    +
    +                   networkBufferPool.destroyAllBufferPools();
    +                   networkBufferPool.destroy();
    +           }
    +   }
    +
    +   /**
    +    * Tests to verify that there is no race condition with two things 
running in parallel:
    +    * requesting floating buffers on sender backlog and some other thread 
releasing
    +    * the input channel.
    +    */
    +   @Test
    +   public void testConcurrentOnSenderBacklogAndRelease() throws Exception {
    +           // Setup
    +           final NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(256, 32, MemoryType.HEAP);
    +           final ExecutorService executor = 
Executors.newFixedThreadPool(2);
    +           final SingleInputGate inputGate = createSingleInputGate();
    +           final RemoteInputChannel inputChannel  = 
createRemoteInputChannel(inputGate);
    +           
inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), 
inputChannel);
    +           try {
    +                   final BufferPool bufferPool = 
networkBufferPool.createBufferPool(128, 128);
    +                   inputGate.setBufferPool(bufferPool);
    +                   inputGate.assignExclusiveSegments(networkBufferPool, 2);
    +
    +                   final Callable<Void> requestBufferTask = new 
Callable<Void>() {
    +                           @Override
    +                           public Void call() throws Exception {
    +                                   while (true) {
    +                                           for (int j = 1; j <= 128; j++) {
    +                                                   
inputChannel.onSenderBacklog(j);
    +                                           }
    +
    +                                           if (inputChannel.isReleased()) {
    +                                                   return null;
    +                                           }
    +                                   }
    +                           }
    +                   };
     
    -           inputChannel.releaseAllResources();
    +                   final Callable<Void> releaseTask = new Callable<Void>() 
{
    +                           @Override
    +                           public Void call() throws Exception {
    +                                   inputChannel.releaseAllResources();
    +
    +                                   return null;
    +                           }
    +                   };
    +
    +                   // Submit tasks and wait to finish
    +                   final List<Future<Void>> results = 
Lists.newArrayListWithCapacity(2);
    +                   results.add(executor.submit(requestBufferTask));
    +                   results.add(executor.submit(releaseTask));
    +                   for (Future<Void> result : results) {
    +                           result.get();
    +                   }
     
    -           // Recycle exclusive segment after channel released
    -           
inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, 
inputChannel));
    +                   assertEquals("There should be no buffers available in 
the channel.",
    +                           0, inputChannel.getNumberOfAvailableBuffers());
     
    -           assertEquals("Resource leak during recycling buffer after 
channel is released.",
    -                   0, inputChannel.getNumberOfAvailableBuffers());
    -           verify(inputChannel, times(0)).notifyCreditAvailable();
    -           verify(inputGate, 
times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class));
    +           } finally {
    +                   // Release all the buffer resources once exception
    +                   if (!inputChannel.isReleased()) {
    +                           inputChannel.releaseAllResources();
    +                   }
    +
    +                   networkBufferPool.destroyAllBufferPools();
    +                   networkBufferPool.destroy();
    +
    +                   executor.shutdown();
    +           }
        }
     
        /**
    -    * Tests {@link RemoteInputChannel#releaseAllResources()}, verifying 
the exclusive segments are
    -    * recycled to global pool via input gate and no resource leak.
    +    * Tests to verify that there is no race condition with two things 
running in parallel:
    +    * requesting floating buffers on sender backlog and some other thread 
recycling
    +    * floating or exclusive buffers.
         */
        @Test
    -   public void testReleaseExclusiveBuffers() throws Exception {
    +   public void testConcurrentOnSenderBacklogAndRecycle() throws Exception {
                // Setup
    -           final SingleInputGate inputGate = mock(SingleInputGate.class);
    -           final RemoteInputChannel inputChannel = 
createRemoteInputChannel(inputGate);
    +           final NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(256, 32, MemoryType.HEAP);
    +           final ExecutorService executor = 
Executors.newFixedThreadPool(2);
    +           final SingleInputGate inputGate = createSingleInputGate();
    +           final RemoteInputChannel inputChannel  = 
createRemoteInputChannel(inputGate);
    +           
inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), 
inputChannel);
    +           try {
    +                   final int numFloatingBuffers = 128;
    +                   final int numExclusiveSegments = 2;
    +                   final BufferPool bufferPool = 
networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
    +                   inputGate.setBufferPool(bufferPool);
    +                   inputGate.assignExclusiveSegments(networkBufferPool, 
numExclusiveSegments);
    +
    +                   // Exhaust all the floating buffers
    +                   final List<Buffer> floatingBuffers = new 
ArrayList<>(numFloatingBuffers);
    +                   for (int i = 0; i < numFloatingBuffers; i++) {
    +                           Buffer buffer = bufferPool.requestBuffer();
    +                           assertNotNull(buffer);
    +                           floatingBuffers.add(buffer);
    +                   }
     
    -           // Assign exclusive segments to channel
    -           final List<MemorySegment> exclusiveSegments = new ArrayList<>();
    -           final int numExclusiveBuffers = 2;
    -           for (int i = 0; i < numExclusiveBuffers; i++) {
    -                   
exclusiveSegments.add(MemorySegmentFactory.allocateUnpooledSegment(1024, 
inputChannel));
    +                   // Exhaust all the exclusive buffers
    +                   final List<Buffer> exclusiveBuffers = new 
ArrayList<>(numExclusiveSegments);
    +                   for (int i = 0; i < numExclusiveSegments; i++) {
    +                           Buffer buffer = inputChannel.requestBuffer();
    +                           assertNotNull(buffer);
    +                           exclusiveBuffers.add(buffer);
    +                   }
    +
    +                   final int backlog = 128;
    +                   final Callable<Void> requestBufferTask = new 
Callable<Void>() {
    +                           @Override
    +                           public Void call() throws Exception {
    +                                   for (int j = 1; j <= backlog; j++) {
    +                                           inputChannel.onSenderBacklog(j);
    +                                   }
    +
    +                                   return null;
    +                           }
    +                   };
    +
    +                   final Callable<Void> recycleBufferTask = new 
Callable<Void>() {
    +                           @Override
    +                           public Void call() throws Exception {
    +                                   // Recycle all the exclusive buffers
    +                                   for (Buffer buffer : exclusiveBuffers) {
    +                                           buffer.recycle();
    +                                   }
    +
    +                                   // Recycle all the floating buffers
    +                                   for (Buffer buffer : floatingBuffers) {
    +                                           buffer.recycle();
    +                                   }
    +
    +                                   return null;
    +                           }
    +                   };
    +
    +                   // Submit tasks and wait to finish
    +                   final List<Future<Void>> results = 
Lists.newArrayListWithCapacity(2);
    +                   results.add(executor.submit(requestBufferTask));
    +                   results.add(executor.submit(recycleBufferTask));
    +                   for (Future<Void> result : results) {
    +                           result.get();
    +                   }
    +
    +                   final int numRequiredBuffers = backlog + 
numExclusiveSegments;
    +                   assertEquals("There should be " + numRequiredBuffers +" 
buffers available in channel.",
    +                           numRequiredBuffers, 
inputChannel.getNumberOfAvailableBuffers());
    +                   assertEquals("There should be no buffers available in 
buffer pool.",
    +                           0, 
bufferPool.getNumberOfAvailableMemorySegments());
    +
    +           } finally {
    +                   // Release all the buffer resources
    +                   inputChannel.releaseAllResources();
    +
    +                   networkBufferPool.destroyAllBufferPools();
    +                   networkBufferPool.destroy();
    +
    +                   executor.shutdown();
                }
    -           inputChannel.assignExclusiveSegments(exclusiveSegments);
    +   }
     
    -           assertEquals("The number of available buffers is not equal to 
the assigned amount.",
    -                   numExclusiveBuffers, 
inputChannel.getNumberOfAvailableBuffers());
    +   /**
    +    * Tests to verify that there is no race condition with two things 
running in parallel:
    +    * recycling the exclusive or floating buffers and some other thread 
releasing the
    +    * input channel.
    +    */
    +   @Test
    +   public void testConcurrentRecycleAndRelease() throws Exception {
    +           // Setup
    +           final NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(256, 32, MemoryType.HEAP);
    +           final ExecutorService executor = 
Executors.newFixedThreadPool(2);
    +           final SingleInputGate inputGate = createSingleInputGate();
    +           final RemoteInputChannel inputChannel  = 
createRemoteInputChannel(inputGate);
    +           
inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), 
inputChannel);
    +           try {
    +                   final int numFloatingBuffers = 128;
    +                   final int numExclusiveSegments = 2;
    +                   final BufferPool bufferPool = 
networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
    +                   inputGate.setBufferPool(bufferPool);
    +                   inputGate.assignExclusiveSegments(networkBufferPool, 
numExclusiveSegments);
    +
    +                   // Exhaust all the floating buffers
    +                   final List<Buffer> floatingBuffers = new 
ArrayList<>(numFloatingBuffers);
    +                   for (int i = 0; i < numFloatingBuffers; i++) {
    +                           Buffer buffer = bufferPool.requestBuffer();
    +                           assertNotNull(buffer);
    +                           floatingBuffers.add(buffer);
    +                   }
    +
    +                   // Exhaust all the exclusive buffers
    +                   final List<Buffer> exclusiveBuffers = new 
ArrayList<>(numExclusiveSegments);
    +                   for (int i = 0; i < numExclusiveSegments; i++) {
    +                           Buffer buffer = inputChannel.requestBuffer();
    +                           assertNotNull(buffer);
    +                           exclusiveBuffers.add(buffer);
    +                   }
    +
    +                   final Callable<Void> recycleBufferTask = new 
Callable<Void>() {
    +                           @Override
    +                           public Void call() throws Exception {
    +                                   // Recycle all the exclusive buffers
    --- End diff --
    
    I was actually hoping we could extract more of this into a common test method, but
it's probably best the way you implemented it, since that keeps the actual tests easier to
understand.


---

Reply via email to