This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 10c20837126d7e50c4b2c22678a6853193ae7534
Author: Weijie Guo <[email protected]>
AuthorDate: Fri Jul 8 01:00:33 2022 +0800

    [hotfix] Migrate CreditBasedPartitionRequestClientHandlerTest, 
NettyMessageClientSideSerializationTest, SingleInputGateTest and 
BlockCompressionTest to JUnit 5/AssertJ
    
    This closes #20216.
---
 .../io/compression/BlockCompressionTest.java       |  57 +--
 ...editBasedPartitionRequestClientHandlerTest.java | 224 +++++------
 .../NettyMessageClientSideSerializationTest.java   |  60 +--
 .../partition/consumer/SingleInputGateTest.java    | 428 +++++++++++----------
 4 files changed, 384 insertions(+), 385 deletions(-)

diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java
index 1f57ce2ad19..fd8a05db6f9 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java
@@ -18,19 +18,19 @@
 
 package org.apache.flink.runtime.io.compression;
 
-import org.junit.Assert;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.nio.ByteBuffer;
 
 import static 
org.apache.flink.runtime.io.compression.Lz4BlockCompressionFactory.HEADER_LENGTH;
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Tests for block compression. */
-public class BlockCompressionTest {
+class BlockCompressionTest {
 
     @Test
-    public void testLz4() {
+    void testLz4() {
         BlockCompressionFactory factory = new Lz4BlockCompressionFactory();
         runArrayTest(factory, 32768);
         runArrayTest(factory, 16);
@@ -54,12 +54,16 @@ public class BlockCompressionTest {
         int compressedOff = 32;
 
         // 1. test compress with insufficient target
-        byte[] insufficientArray = new byte[compressedOff + HEADER_LENGTH + 1];
-        try {
-            compressor.compress(data, originalOff, originalLen, 
insufficientArray, compressedOff);
-            Assert.fail("expect exception here");
-        } catch (InsufficientBufferException ex) {
-        }
+        byte[] insufficientCompressArray = new byte[compressedOff + 
HEADER_LENGTH + 1];
+        assertThatThrownBy(
+                        () ->
+                                compressor.compress(
+                                        data,
+                                        originalOff,
+                                        originalLen,
+                                        insufficientCompressArray,
+                                        compressedOff))
+                .isInstanceOf(InsufficientBufferException.class);
 
         // 2. test normal compress
         byte[] compressedData =
@@ -70,17 +74,16 @@ public class BlockCompressionTest {
         int decompressedOff = 16;
 
         // 3. test decompress with insufficient target
-        insufficientArray = new byte[decompressedOff + originalLen - 1];
-        try {
-            decompressor.decompress(
-                    compressedData,
-                    compressedOff,
-                    compressedLen,
-                    insufficientArray,
-                    decompressedOff);
-            Assert.fail("expect exception here");
-        } catch (InsufficientBufferException ex) {
-        }
+        byte[] insufficientDecompressArray = new byte[decompressedOff + 
originalLen - 1];
+        assertThatThrownBy(
+                        () ->
+                                decompressor.decompress(
+                                        compressedData,
+                                        compressedOff,
+                                        compressedLen,
+                                        insufficientDecompressArray,
+                                        decompressedOff))
+                .isInstanceOf(InsufficientBufferException.class);
 
         // 4. test normal decompress
         byte[] decompressedData = new byte[decompressedOff + originalLen];
@@ -91,10 +94,10 @@ public class BlockCompressionTest {
                         compressedLen,
                         decompressedData,
                         decompressedOff);
-        assertEquals(originalLen, decompressedLen);
+        assertThat(decompressedLen).isEqualTo(originalLen);
 
         for (int i = 0; i < originalLen; i++) {
-            assertEquals(data[originalOff + i], 
decompressedData[decompressedOff + i]);
+            assertThat(decompressedData[decompressedOff + 
i]).isEqualTo(data[originalOff + i]);
         }
     }
 
@@ -129,7 +132,7 @@ public class BlockCompressionTest {
             compressedData = ByteBuffer.allocate(maxCompressedLen);
         }
         int compressedLen = compressor.compress(data, originalOff, 
originalLen, compressedData, 0);
-        assertEquals(compressedLen, compressedData.position());
+        assertThat(compressedData.position()).isEqualTo(compressedLen);
         compressedData.flip();
 
         int compressedOff = 32;
@@ -159,11 +162,11 @@ public class BlockCompressionTest {
         int decompressedLen =
                 decompressor.decompress(
                         copiedCompressedData, compressedOff, compressedLen, 
decompressedData, 0);
-        assertEquals(decompressedLen, decompressedData.position());
+        assertThat(decompressedData.position()).isEqualTo(decompressedLen);
         decompressedData.flip();
 
         for (int i = 0; i < decompressedLen; i++) {
-            assertEquals((byte) i, decompressedData.get());
+            assertThat(decompressedData.get()).isEqualTo((byte) i);
         }
     }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
index f3074f9682b..8ad892b1d7a 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
@@ -56,24 +56,17 @@ import 
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
 import org.apache.flink.shaded.netty4.io.netty.channel.epoll.Epoll;
 import org.apache.flink.shaded.netty4.io.netty.channel.unix.Errors;
 
-import org.junit.Assume;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
 
 import java.io.IOException;
 
 import static 
org.apache.flink.runtime.io.network.netty.PartitionRequestQueueTest.blockChannel;
 import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createRemoteInputChannel;
 import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.instanceOf;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertSame;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.Assumptions.assumeThat;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -83,7 +76,7 @@ import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 /** Test for {@link CreditBasedPartitionRequestClientHandler}. */
-public class CreditBasedPartitionRequestClientHandlerTest {
+class CreditBasedPartitionRequestClientHandlerTest {
 
     /**
      * Tests a fix for FLINK-1627.
@@ -96,9 +89,10 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      *
      * @see <a 
href="https://issues.apache.org/jira/browse/FLINK-1627";>FLINK-1627</a>
      */
-    @Test(timeout = 60000)
+    @Test
+    @Timeout(60)
     @SuppressWarnings("unchecked")
-    public void testReleaseInputChannelDuringDecode() throws Exception {
+    void testReleaseInputChannelDuringDecode() throws Exception {
         // Mocks an input channel in a state as it was released during a 
decode.
         final BufferProvider bufferProvider = mock(BufferProvider.class);
         when(bufferProvider.requestBuffer()).thenReturn(null);
@@ -130,7 +124,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * <p>FLINK-1761 discovered an IndexOutOfBoundsException, when receiving 
buffers of size 0.
      */
     @Test
-    public void testReceiveEmptyBuffer() throws Exception {
+    void testReceiveEmptyBuffer() throws Exception {
         // Minimal mock of a remote input channel
         final BufferProvider bufferProvider = mock(BufferProvider.class);
         
when(bufferProvider.requestBuffer()).thenReturn(TestBufferFactory.createBuffer(0));
@@ -168,7 +162,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * BufferResponse} is received.
      */
     @Test
-    public void testReceiveBuffer() throws Exception {
+    void testReceiveBuffer() throws Exception {
         final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 
32);
         final SingleInputGate inputGate = createSingleInputGate(1, 
networkBufferPool);
         final RemoteInputChannel inputChannel =
@@ -193,8 +187,8 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                             new NetworkBufferAllocator(handler));
             handler.channelRead(mock(ChannelHandlerContext.class), 
bufferResponse);
 
-            assertEquals(1, inputChannel.getNumberOfQueuedBuffers());
-            assertEquals(2, inputChannel.getSenderBacklog());
+            assertThat(inputChannel.getNumberOfQueuedBuffers()).isEqualTo(1);
+            assertThat(inputChannel.getSenderBacklog()).isEqualTo(2);
         } finally {
             releaseResource(inputGate, networkBufferPool);
         }
@@ -204,7 +198,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * Verifies that {@link BufferResponse} of compressed {@link Buffer} can 
be handled correctly.
      */
     @Test
-    public void testReceiveCompressedBuffer() throws Exception {
+    void testReceiveCompressedBuffer() throws Exception {
         int bufferSize = 1024;
         String compressionCodec = "LZ4";
         BufferCompressor compressor = new BufferCompressor(bufferSize, 
compressionCodec);
@@ -236,12 +230,12 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
                             inputChannel.getInputChannelId(),
                             2,
                             new NetworkBufferAllocator(handler));
-            assertTrue(bufferResponse.isCompressed);
+            assertThat(bufferResponse.isCompressed).isTrue();
             handler.channelRead(null, bufferResponse);
 
             Buffer receivedBuffer = inputChannel.getNextReceivedBuffer();
-            assertNotNull(receivedBuffer);
-            assertTrue(receivedBuffer.isCompressed());
+            assertThat(receivedBuffer).isNotNull();
+            assertThat(receivedBuffer.isCompressed()).isTrue();
             receivedBuffer.recycleBuffer();
         } finally {
             releaseResource(inputGate, networkBufferPool);
@@ -250,7 +244,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 
     /** Verifies that {@link NettyMessage.BacklogAnnouncement} can be handled 
correctly. */
     @Test
-    public void testReceiveBacklogAnnouncement() throws Exception {
+    void testReceiveBacklogAnnouncement() throws Exception {
         int bufferSize = 1024;
         int numBuffers = 10;
         NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(numBuffers, bufferSize);
@@ -268,26 +262,26 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
                     new CreditBasedPartitionRequestClientHandler();
             handler.addInputChannel(inputChannel);
 
-            assertEquals(2, inputChannel.getNumberOfAvailableBuffers());
-            assertEquals(0, 
inputChannel.unsynchronizedGetFloatingBuffersAvailable());
+            
assertThat(inputChannel.getNumberOfAvailableBuffers()).isEqualTo(2);
+            
assertThat(inputChannel.unsynchronizedGetFloatingBuffersAvailable()).isZero();
 
             int backlog = 5;
             NettyMessage.BacklogAnnouncement announcement =
                     new NettyMessage.BacklogAnnouncement(backlog, 
inputChannel.getInputChannelId());
             handler.channelRead(null, announcement);
-            assertEquals(7, inputChannel.getNumberOfAvailableBuffers());
-            assertEquals(7, inputChannel.getNumberOfRequiredBuffers());
-            assertEquals(backlog, inputChannel.getSenderBacklog());
-            assertEquals(5, 
inputChannel.unsynchronizedGetFloatingBuffersAvailable());
+            
assertThat(inputChannel.getNumberOfAvailableBuffers()).isEqualTo(7);
+            assertThat(inputChannel.getNumberOfRequiredBuffers()).isEqualTo(7);
+            assertThat(inputChannel.getSenderBacklog()).isEqualTo(backlog);
+            
assertThat(inputChannel.unsynchronizedGetFloatingBuffersAvailable()).isEqualTo(5);
 
             backlog = 12;
             announcement =
                     new NettyMessage.BacklogAnnouncement(backlog, 
inputChannel.getInputChannelId());
             handler.channelRead(null, announcement);
-            assertEquals(10, inputChannel.getNumberOfAvailableBuffers());
-            assertEquals(14, inputChannel.getNumberOfRequiredBuffers());
-            assertEquals(backlog, inputChannel.getSenderBacklog());
-            assertEquals(8, 
inputChannel.unsynchronizedGetFloatingBuffersAvailable());
+            
assertThat(inputChannel.getNumberOfAvailableBuffers()).isEqualTo(10);
+            
assertThat(inputChannel.getNumberOfRequiredBuffers()).isEqualTo(14);
+            assertThat(inputChannel.getSenderBacklog()).isEqualTo(backlog);
+            
assertThat(inputChannel.unsynchronizedGetFloatingBuffersAvailable()).isEqualTo(8);
         } finally {
             releaseResource(inputGate, networkBufferPool);
         }
@@ -298,7 +292,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * BufferResponse} is received but no available buffer in input channel.
      */
     @Test
-    public void testThrowExceptionForNoAvailableBuffer() throws Exception {
+    void testThrowExceptionForNoAvailableBuffer() throws Exception {
         final SingleInputGate inputGate = createSingleInputGate(1);
         final RemoteInputChannel inputChannel =
                 
spy(InputChannelBuilder.newBuilder().buildRemoteChannel(inputGate));
@@ -307,10 +301,9 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                 new CreditBasedPartitionRequestClientHandler();
         handler.addInputChannel(inputChannel);
 
-        assertEquals(
-                "There should be no buffers available in the channel.",
-                0,
-                inputChannel.getNumberOfAvailableBuffers());
+        assertThat(inputChannel.getNumberOfAvailableBuffers())
+                .as("There should be no buffers available in the channel.")
+                .isEqualTo(0);
 
         final BufferResponse bufferResponse =
                 createBufferResponse(
@@ -319,7 +312,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                         inputChannel.getInputChannelId(),
                         2,
                         new NetworkBufferAllocator(handler));
-        assertNull(bufferResponse.getBuffer());
+        assertThat(bufferResponse.getBuffer()).isNull();
 
         handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse);
         verify(inputChannel, 
times(1)).onError(any(IllegalStateException.class));
@@ -330,7 +323,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * PartitionNotFoundException} is received.
      */
     @Test
-    public void testReceivePartitionNotFoundException() throws Exception {
+    void testReceivePartitionNotFoundException() throws Exception {
         // Minimal mock of a remote input channel
         final BufferProvider bufferProvider = mock(BufferProvider.class);
         
when(bufferProvider.requestBuffer()).thenReturn(TestBufferFactory.createBuffer(0));
@@ -360,7 +353,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
     }
 
     @Test
-    public void testCancelBeforeActive() throws Exception {
+    void testCancelBeforeActive() throws Exception {
 
         final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
         when(inputChannel.getInputChannelId()).thenReturn(new 
InputChannelID());
@@ -382,7 +375,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * changed.
      */
     @Test
-    public void testNotifyCreditAvailable() throws Exception {
+    void testNotifyCreditAvailable() throws Exception {
         final CreditBasedPartitionRequestClientHandler handler =
                 new CreditBasedPartitionRequestClientHandler();
         final NetworkBufferAllocator allocator = new 
NetworkBufferAllocator(handler);
@@ -409,20 +402,18 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
             inputChannels[1].requestSubpartition();
 
             // The two input channels should send partition requests
-            assertTrue(channel.isWritable());
+            assertThat(channel.isWritable()).isTrue();
             Object readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(PartitionRequest.class));
-            assertEquals(
-                    inputChannels[0].getInputChannelId(),
-                    ((PartitionRequest) readFromOutbound).receiverId);
-            assertEquals(2, ((PartitionRequest) readFromOutbound).credit);
+            assertThat(readFromOutbound).isInstanceOf(PartitionRequest.class);
+            assertThat(inputChannels[0].getInputChannelId())
+                    .isEqualTo(((PartitionRequest) 
readFromOutbound).receiverId);
+            assertThat(((PartitionRequest) 
readFromOutbound).credit).isEqualTo(2);
 
             readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(PartitionRequest.class));
-            assertEquals(
-                    inputChannels[1].getInputChannelId(),
-                    ((PartitionRequest) readFromOutbound).receiverId);
-            assertEquals(2, ((PartitionRequest) readFromOutbound).credit);
+            assertThat(readFromOutbound).isInstanceOf(PartitionRequest.class);
+            assertThat(inputChannels[1].getInputChannelId())
+                    .isEqualTo(((PartitionRequest) 
readFromOutbound).receiverId);
+            assertThat(((PartitionRequest) 
readFromOutbound).credit).isEqualTo(2);
 
             // The buffer response will take one available buffer from input 
channel, and it will
             // trigger
@@ -444,26 +435,24 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
             handler.channelRead(mock(ChannelHandlerContext.class), 
bufferResponse1);
             handler.channelRead(mock(ChannelHandlerContext.class), 
bufferResponse2);
 
-            assertEquals(2, inputChannels[0].getUnannouncedCredit());
-            assertEquals(2, inputChannels[1].getUnannouncedCredit());
+            assertThat(inputChannels[0].getUnannouncedCredit()).isEqualTo(2);
+            assertThat(inputChannels[1].getUnannouncedCredit()).isEqualTo(2);
 
             channel.runPendingTasks();
 
             // The two input channels should notify credits availability via 
the writable channel
             readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(AddCredit.class));
-            assertEquals(
-                    inputChannels[0].getInputChannelId(),
-                    ((AddCredit) readFromOutbound).receiverId);
-            assertEquals(2, ((AddCredit) readFromOutbound).credit);
+            assertThat(readFromOutbound).isInstanceOf(AddCredit.class);
+            assertThat(inputChannels[0].getInputChannelId())
+                    .isEqualTo(((AddCredit) readFromOutbound).receiverId);
+            assertThat(((AddCredit) readFromOutbound).credit).isEqualTo(2);
 
             readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(AddCredit.class));
-            assertEquals(
-                    inputChannels[1].getInputChannelId(),
-                    ((AddCredit) readFromOutbound).receiverId);
-            assertEquals(2, ((AddCredit) readFromOutbound).credit);
-            assertNull(channel.readOutbound());
+            assertThat(readFromOutbound).isInstanceOf(AddCredit.class);
+            assertThat(inputChannels[1].getInputChannelId())
+                    .isEqualTo(((AddCredit) readFromOutbound).receiverId);
+            assertThat(((AddCredit) readFromOutbound).credit).isEqualTo(2);
+            assertThat((Object) channel.readOutbound()).isNull();
 
             ByteBuf channelBlockingBuffer = blockChannel(channel);
 
@@ -478,29 +467,29 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
                             allocator);
             handler.channelRead(mock(ChannelHandlerContext.class), 
bufferResponse3);
 
-            assertEquals(1, inputChannels[0].getUnannouncedCredit());
-            assertEquals(0, inputChannels[1].getUnannouncedCredit());
+            assertThat(inputChannels[0].getUnannouncedCredit()).isEqualTo(1);
+            assertThat(inputChannels[1].getUnannouncedCredit()).isZero();
 
             channel.runPendingTasks();
 
             // The input channel will not notify credits via un-writable 
channel
-            assertFalse(channel.isWritable());
-            assertNull(channel.readOutbound());
+            assertThat(channel.isWritable()).isFalse();
+            assertThat((Object) channel.readOutbound()).isNull();
 
             // Flush the buffer to make the channel writable again
             channel.flush();
-            assertSame(channelBlockingBuffer, channel.readOutbound());
+            assertThat(channelBlockingBuffer).isSameAs(channel.readOutbound());
 
             // The input channel should notify credits via channel's 
writability changed event
-            assertTrue(channel.isWritable());
+            assertThat(channel.isWritable()).isTrue();
             readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(AddCredit.class));
-            assertEquals(1, ((AddCredit) readFromOutbound).credit);
-            assertEquals(0, inputChannels[0].getUnannouncedCredit());
-            assertEquals(0, inputChannels[1].getUnannouncedCredit());
+            assertThat(readFromOutbound).isInstanceOf(AddCredit.class);
+            assertThat(((AddCredit) readFromOutbound).credit).isEqualTo(1);
+            assertThat(inputChannels[0].getUnannouncedCredit()).isZero();
+            assertThat(inputChannels[1].getUnannouncedCredit()).isZero();
 
             // no more messages
-            assertNull(channel.readOutbound());
+            assertThat((Object) channel.readOutbound()).isNull();
         } finally {
             releaseResource(inputGate, networkBufferPool);
             channel.close();
@@ -512,7 +501,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * message is not sent actually when this input channel is released.
      */
     @Test
-    public void testNotifyCreditAvailableAfterReleased() throws Exception {
+    void testNotifyCreditAvailableAfterReleased() throws Exception {
         final CreditBasedPartitionRequestClientHandler handler =
                 new CreditBasedPartitionRequestClientHandler();
         final EmbeddedChannel channel = new EmbeddedChannel(handler);
@@ -536,8 +525,8 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 
             // This should send the partition request
             Object readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(PartitionRequest.class));
-            assertEquals(2, ((PartitionRequest) readFromOutbound).credit);
+            assertThat(readFromOutbound).isInstanceOf(PartitionRequest.class);
+            assertThat(((PartitionRequest) 
readFromOutbound).credit).isEqualTo(2);
 
             // Trigger request floating buffers via buffer response to notify 
credits available
             final BufferResponse bufferResponse =
@@ -549,7 +538,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                             new NetworkBufferAllocator(handler));
             handler.channelRead(mock(ChannelHandlerContext.class), 
bufferResponse);
 
-            assertEquals(2, inputChannel.getUnannouncedCredit());
+            assertThat(inputChannel.getUnannouncedCredit()).isEqualTo(2);
 
             // Release the input channel
             inputGate.close();
@@ -557,11 +546,10 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
             // it should send a close request after releasing the input 
channel,
             // but will not notify credits for a released input channel.
             readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(CloseRequest.class));
+            assertThat(readFromOutbound).isInstanceOf(CloseRequest.class);
 
             channel.runPendingTasks();
-
-            assertNull(channel.readOutbound());
+            assertThat((Object) channel.readOutbound()).isNull();
         } finally {
             releaseResource(inputGate, networkBufferPool);
             channel.close();
@@ -569,27 +557,27 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
     }
 
     @Test
-    public void testReadBufferResponseBeforeReleasingChannel() throws 
Exception {
+    void testReadBufferResponseBeforeReleasingChannel() throws Exception {
         testReadBufferResponseWithReleasingOrRemovingChannel(false, true);
     }
 
     @Test
-    public void testReadBufferResponseBeforeRemovingChannel() throws Exception 
{
+    void testReadBufferResponseBeforeRemovingChannel() throws Exception {
         testReadBufferResponseWithReleasingOrRemovingChannel(true, true);
     }
 
     @Test
-    public void testReadBufferResponseAfterReleasingChannel() throws Exception 
{
+    void testReadBufferResponseAfterReleasingChannel() throws Exception {
         testReadBufferResponseWithReleasingOrRemovingChannel(false, false);
     }
 
     @Test
-    public void testReadBufferResponseAfterRemovingChannel() throws Exception {
+    void testReadBufferResponseAfterRemovingChannel() throws Exception {
         testReadBufferResponseWithReleasingOrRemovingChannel(true, false);
     }
 
     @Test
-    public void testDoNotFailHandlerOnSingleChannelFailure() throws Exception {
+    void testDoNotFailHandlerOnSingleChannelFailure() throws Exception {
         // Setup
         final int bufferSize = 1024;
         final String expectedMessage = "test exception on buffer";
@@ -620,13 +608,11 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
             // The handler should not be tagged as error for above excepted 
exception
             handler.checkError();
 
-            try {
-                // The input channel should be tagged as error and the 
respective exception is
-                // thrown via #getNext
-                inputGate.getNext();
-            } catch (IOException ignored) {
-                assertEquals(expectedMessage, ignored.getMessage());
-            }
+            // The input channel should be tagged as error and the respective 
exception is
+            // thrown via #getNext
+            assertThatThrownBy(inputGate::getNext)
+                    .isInstanceOf(IOException.class)
+                    .hasMessage(expectedMessage);
         } finally {
             // Cleanup
             releaseResource(inputGate, networkBufferPool);
@@ -634,7 +620,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
     }
 
     @Test
-    public void testExceptionWrap() {
+    void testExceptionWrap() {
         testExceptionWrap(LocalTransportException.class, new Exception());
         testExceptionWrap(LocalTransportException.class, new Exception("some 
error"));
         testExceptionWrap(
@@ -642,7 +628,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 
         // Only when Epoll is available the following exception could be 
initiated normally
         // since it relies on the native strerror method.
-        Assume.assumeTrue(Epoll.isAvailable());
+        assumeThat(Epoll.isAvailable()).isTrue();
         testExceptionWrap(
                 RemoteTransportException.class,
                 new Errors.NativeIoException("readAddress", 
Errors.ERRNO_ECONNRESET_NEGATIVE));
@@ -665,19 +651,16 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
                         handler);
 
         embeddedChannel.writeInbound(1);
-        try {
-            handler.checkError();
-            fail(
-                    String.format(
-                            "The handler should wrap the exception %s as %s, 
but it does not.",
-                            cause, expectedClass));
-        } catch (IOException e) {
-            assertThat(e, instanceOf(expectedClass));
-        }
+        assertThatThrownBy(() -> handler.checkError())
+                .isInstanceOf(expectedClass)
+                .withFailMessage(
+                        String.format(
+                                "The handler should wrap the exception %s as 
%s, but it does not.",
+                                cause, expectedClass));
     }
 
     @Test
-    public void testAnnounceBufferSize() throws Exception {
+    void testAnnounceBufferSize() throws Exception {
         final CreditBasedPartitionRequestClientHandler handler =
                 new CreditBasedPartitionRequestClientHandler();
         final EmbeddedChannel channel = new EmbeddedChannel(handler);
@@ -709,13 +692,13 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
             channel.runPendingTasks();
 
             NettyMessage.NewBufferSize readOutbound = channel.readOutbound();
-            assertThat(readOutbound, 
instanceOf(NettyMessage.NewBufferSize.class));
-            assertThat(readOutbound.receiverId, 
is(inputChannels[0].getInputChannelId()));
-            assertThat(readOutbound.bufferSize, is(333));
+            
assertThat(readOutbound).isInstanceOf(NettyMessage.NewBufferSize.class);
+            
assertThat(inputChannels[0].getInputChannelId()).isEqualTo(readOutbound.receiverId);
+            assertThat(readOutbound.bufferSize).isEqualTo(333);
 
             readOutbound = channel.readOutbound();
-            assertThat(readOutbound.receiverId, 
is(inputChannels[1].getInputChannelId()));
-            assertThat(readOutbound.bufferSize, is(333));
+            
assertThat(inputChannels[1].getInputChannelId()).isEqualTo(readOutbound.receiverId);
+            assertThat(readOutbound.bufferSize).isEqualTo(333);
 
         } finally {
             releaseResource(inputGate, networkBufferPool);
@@ -766,19 +749,20 @@ public class CreditBasedPartitionRequestClientHandlerTest 
{
 
             handler.channelRead(null, bufferResponse);
 
-            assertEquals(0, inputChannel.getNumberOfQueuedBuffers());
+            assertThat(inputChannel.getNumberOfQueuedBuffers()).isZero();
             if (!readBeforeReleasingOrRemoving) {
-                assertNull(bufferResponse.getBuffer());
+                assertThat(bufferResponse.getBuffer()).isNull();
             } else {
-                assertNotNull(bufferResponse.getBuffer());
-                assertTrue(bufferResponse.getBuffer().isRecycled());
+                assertThat(bufferResponse.getBuffer()).isNotNull();
+                assertThat(bufferResponse.getBuffer().isRecycled()).isTrue();
             }
 
             embeddedChannel.runScheduledPendingTasks();
             NettyMessage.CancelPartitionRequest cancelPartitionRequest =
                     embeddedChannel.readOutbound();
-            assertNotNull(cancelPartitionRequest);
-            assertEquals(inputChannel.getInputChannelId(), 
cancelPartitionRequest.receiverId);
+            assertThat(cancelPartitionRequest).isNotNull();
+            assertThat(inputChannel.getInputChannelId())
+                    .isEqualTo(cancelPartitionRequest.receiverId);
         } finally {
             releaseResource(inputGate, networkBufferPool);
             embeddedChannel.close();
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java
index ee42d8dd748..925a1444c32 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java
@@ -30,13 +30,14 @@ import 
org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
 import 
org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
-import org.apache.flink.util.TestLogger;
+import org.apache.flink.util.TestLoggerExtension;
 
 import 
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
 
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
 
 import java.io.IOException;
 import java.util.Random;
@@ -51,15 +52,14 @@ import static 
org.apache.flink.runtime.io.network.netty.NettyTestUtil.verifyErro
 import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createRemoteInputChannel;
 import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
 import static org.apache.flink.util.Preconditions.checkArgument;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /**
  * Tests for the serialization and deserialization of the various {@link 
NettyMessage} sub-classes
  * sent from server side to client side.
  */
-public class NettyMessageClientSideSerializationTest extends TestLogger {
+@ExtendWith(TestLoggerExtension.class)
+class NettyMessageClientSideSerializationTest {
 
     private static final int BUFFER_SIZE = 1024;
 
@@ -78,8 +78,8 @@ public class NettyMessageClientSideSerializationTest extends 
TestLogger {
 
     private InputChannelID inputChannelId;
 
-    @Before
-    public void setup() throws IOException, InterruptedException {
+    @BeforeEach
+    void setup() throws IOException, InterruptedException {
         networkBufferPool = new NetworkBufferPool(8, BUFFER_SIZE);
         inputGate = createSingleInputGate(1, networkBufferPool);
         RemoteInputChannel inputChannel =
@@ -100,8 +100,8 @@ public class NettyMessageClientSideSerializationTest 
extends TestLogger {
         inputChannelId = inputChannel.getInputChannelId();
     }
 
-    @After
-    public void tearDown() throws IOException {
+    @AfterEach
+    void tearDown() throws IOException {
         if (inputGate != null) {
             inputGate.close();
         }
@@ -117,43 +117,43 @@ public class NettyMessageClientSideSerializationTest 
extends TestLogger {
     }
 
     @Test
-    public void testErrorResponseWithoutErrorMessage() {
+    void testErrorResponseWithoutErrorMessage() {
         testErrorResponse(new ErrorResponse(new IllegalStateException(), 
inputChannelId));
     }
 
     @Test
-    public void testErrorResponseWithErrorMessage() {
+    void testErrorResponseWithErrorMessage() {
         testErrorResponse(
                 new ErrorResponse(
                         new IllegalStateException("Illegal illegal illegal"), 
inputChannelId));
     }
 
     @Test
-    public void testErrorResponseWithFatalError() {
+    void testErrorResponseWithFatalError() {
         testErrorResponse(new ErrorResponse(new IllegalStateException("Illegal 
illegal illegal")));
     }
 
     @Test
-    public void testOrdinaryBufferResponse() {
+    void testOrdinaryBufferResponse() {
         testBufferResponse(false, false);
     }
 
     @Test
-    public void testBufferResponseWithReadOnlySlice() {
+    void testBufferResponseWithReadOnlySlice() {
         testBufferResponse(true, false);
     }
 
     @Test
-    public void testCompressedBufferResponse() {
+    void testCompressedBufferResponse() {
         testBufferResponse(false, true);
     }
 
     @Test
-    public void testBacklogAnnouncement() {
+    void testBacklogAnnouncement() {
         BacklogAnnouncement expected = new BacklogAnnouncement(1024, 
inputChannelId);
         BacklogAnnouncement actual = encodeAndDecode(expected, channel);
-        assertEquals(expected.backlog, actual.backlog);
-        assertEquals(expected.receiverId, actual.receiverId);
+        assertThat(actual.backlog).isEqualTo(expected.backlog);
+        assertThat(actual.receiverId).isEqualTo(expected.receiverId);
     }
 
     private void testErrorResponse(ErrorResponse expect) {
@@ -189,22 +189,22 @@ public class NettyMessageClientSideSerializationTest 
extends TestLogger {
                         random.nextInt(Integer.MAX_VALUE));
         BufferResponse actual = encodeAndDecode(expected, channel);
 
-        assertTrue(buffer.isRecycled());
-        assertTrue(testBuffer.isRecycled());
-        assertNotNull(
-                "The request input channel should always have available 
buffers in this test.",
-                actual.getBuffer());
+        assertThat(buffer.isRecycled()).isTrue();
+        assertThat(testBuffer.isRecycled()).isTrue();
+        assertThat(actual.getBuffer())
+                .as("The request input channel should always have available 
buffers in this test.")
+                .isNotNull();
 
         Buffer decodedBuffer = actual.getBuffer();
         if (testCompressedBuffer) {
-            assertTrue(actual.isCompressed);
+            assertThat(actual.isCompressed).isTrue();
             decodedBuffer = decompress(decodedBuffer);
         }
 
         verifyBufferResponseHeader(expected, actual);
-        assertEquals(BUFFER_SIZE, decodedBuffer.readableBytes());
+        assertThat(decodedBuffer.readableBytes()).isEqualTo(BUFFER_SIZE);
         for (int i = 0; i < BUFFER_SIZE; i += 8) {
-            assertEquals(i, decodedBuffer.asByteBuf().readLong());
+            assertThat(decodedBuffer.asByteBuf().readLong()).isEqualTo(i);
         }
 
         // Release the received message.
@@ -213,7 +213,7 @@ public class NettyMessageClientSideSerializationTest 
extends TestLogger {
             decodedBuffer.recycleBuffer();
         }
 
-        assertTrue(actual.getBuffer().isRecycled());
+        assertThat(actual.getBuffer().isRecycled()).isTrue();
     }
 
     private Buffer decompress(Buffer buffer) {
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index dc0e383660f..44249a57bfa 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -72,7 +72,7 @@ import org.apache.flink.util.CompressedSerializedValue;
 
 import org.apache.flink.shaded.guava30.com.google.common.io.Closer;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -97,35 +97,40 @@ import static 
org.apache.flink.runtime.io.network.util.TestBufferFactory.createB
 import static 
org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault;
 import static 
org.apache.flink.runtime.util.NettyShuffleDescriptorBuilder.createRemoteWithIdAndLocation;
 import static org.apache.flink.util.Preconditions.checkState;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.instanceOf;
-import static org.hamcrest.Matchers.is;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Tests for {@link SingleInputGate}. */
 public class SingleInputGateTest extends InputGateTestBase {
 
-    @Test(expected = CheckpointException.class)
-    public void testCheckpointsDeclinedUnlessAllChannelsAreKnown() throws 
CheckpointException {
+    @Test
+    void testCheckpointsDeclinedUnlessAllChannelsAreKnown() throws 
CheckpointException {
         SingleInputGate gate =
                 createInputGate(createNettyShuffleEnvironment(), 1, 
ResultPartitionType.PIPELINED);
         gate.setInputChannels(
                 new 
InputChannelBuilder().setChannelIndex(0).buildUnknownChannel(gate));
-        gate.checkpointStarted(
-                new CheckpointBarrier(1L, 1L, alignedNoTimeout(CHECKPOINT, 
getDefault())));
+        assertThatThrownBy(
+                        () ->
+                                gate.checkpointStarted(
+                                        new CheckpointBarrier(
+                                                1L,
+                                                1L,
+                                                alignedNoTimeout(CHECKPOINT, 
getDefault()))))
+                .isInstanceOf(CheckpointException.class);
     }
 
-    @Test(expected = CheckpointException.class)
-    public void testCheckpointsDeclinedUnlessStateConsumed() throws 
CheckpointException {
+    @Test
+    void testCheckpointsDeclinedUnlessStateConsumed() throws 
CheckpointException {
         SingleInputGate gate = 
createInputGate(createNettyShuffleEnvironment());
         checkState(!gate.getStateConsumedFuture().isDone());
-        gate.checkpointStarted(
-                new CheckpointBarrier(1L, 1L, alignedNoTimeout(CHECKPOINT, 
getDefault())));
+        assertThatThrownBy(
+                        () ->
+                                gate.checkpointStarted(
+                                        new CheckpointBarrier(
+                                                1L,
+                                                1L,
+                                                alignedNoTimeout(CHECKPOINT, 
getDefault()))))
+                .isInstanceOf(CheckpointException.class);
     }
 
     /**
@@ -133,7 +138,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * exclusive buffers for {@link RemoteInputChannel}s, but should not 
request partitions.
      */
     @Test
-    public void testSetupLogic() throws Exception {
+    void testSetupLogic() throws Exception {
         final NettyShuffleEnvironment environment = 
createNettyShuffleEnvironment();
         final SingleInputGate inputGate = createInputGate(environment);
         try (Closer closer = Closer.create()) {
@@ -141,52 +146,53 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             closer.register(inputGate::close);
 
             // before setup
-            assertNull(inputGate.getBufferPool());
+            assertThat(inputGate.getBufferPool()).isNull();
             for (InputChannel inputChannel : 
inputGate.getInputChannels().values()) {
-                assertTrue(
-                        inputChannel instanceof RecoveredInputChannel
-                                || inputChannel instanceof 
UnknownInputChannel);
+                assertThat(
+                                inputChannel instanceof RecoveredInputChannel
+                                        || inputChannel instanceof 
UnknownInputChannel)
+                        .isTrue();
                 if (inputChannel instanceof RecoveredInputChannel) {
-                    assertEquals(
-                            0,
-                            ((RecoveredInputChannel) inputChannel)
-                                    
.bufferManager.getNumberOfAvailableBuffers());
+                    assertThat(
+                                    ((RecoveredInputChannel) inputChannel)
+                                            
.bufferManager.getNumberOfAvailableBuffers())
+                            .isEqualTo(0);
                 }
             }
 
             inputGate.setup();
 
             // after setup
-            assertNotNull(inputGate.getBufferPool());
-            assertEquals(1, 
inputGate.getBufferPool().getNumberOfRequiredMemorySegments());
+            assertThat(inputGate.getBufferPool()).isNotNull();
+            
assertThat(inputGate.getBufferPool().getNumberOfRequiredMemorySegments()).isEqualTo(1);
             for (InputChannel inputChannel : 
inputGate.getInputChannels().values()) {
                 if (inputChannel instanceof RemoteRecoveredInputChannel) {
-                    assertEquals(
-                            0,
-                            ((RemoteRecoveredInputChannel) inputChannel)
-                                    
.bufferManager.getNumberOfAvailableBuffers());
+                    assertThat(
+                                    ((RemoteRecoveredInputChannel) 
inputChannel)
+                                            
.bufferManager.getNumberOfAvailableBuffers())
+                            .isEqualTo(0);
                 } else if (inputChannel instanceof LocalRecoveredInputChannel) 
{
-                    assertEquals(
-                            0,
-                            ((LocalRecoveredInputChannel) inputChannel)
-                                    
.bufferManager.getNumberOfAvailableBuffers());
+                    assertThat(
+                                    ((LocalRecoveredInputChannel) inputChannel)
+                                            
.bufferManager.getNumberOfAvailableBuffers())
+                            .isEqualTo(0);
                 }
             }
 
             inputGate.convertRecoveredInputChannels();
-            assertNotNull(inputGate.getBufferPool());
-            assertEquals(1, 
inputGate.getBufferPool().getNumberOfRequiredMemorySegments());
+            assertThat(inputGate.getBufferPool()).isNotNull();
+            
assertThat(inputGate.getBufferPool().getNumberOfRequiredMemorySegments()).isEqualTo(1);
             for (InputChannel inputChannel : 
inputGate.getInputChannels().values()) {
                 if (inputChannel instanceof RemoteInputChannel) {
-                    assertEquals(
-                            2, ((RemoteInputChannel) 
inputChannel).getNumberOfAvailableBuffers());
+                    assertThat(((RemoteInputChannel) 
inputChannel).getNumberOfAvailableBuffers())
+                            .isEqualTo(2);
                 }
             }
         }
     }
 
     @Test
-    public void testPartitionRequestLogic() throws Exception {
+    void testPartitionRequestLogic() throws Exception {
         final NettyShuffleEnvironment environment = new 
NettyShuffleEnvironmentBuilder().build();
         final SingleInputGate gate = createInputGate(environment);
 
@@ -203,15 +209,16 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             gate.pollNext();
 
             final InputChannel remoteChannel = gate.getChannel(0);
-            assertThat(remoteChannel, instanceOf(RemoteInputChannel.class));
-            assertNotNull(((RemoteInputChannel) 
remoteChannel).getPartitionRequestClient());
-            assertEquals(2, ((RemoteInputChannel) 
remoteChannel).getInitialCredit());
+            assertThat(remoteChannel).isInstanceOf(RemoteInputChannel.class);
+            assertThat(((RemoteInputChannel) 
remoteChannel).getPartitionRequestClient())
+                    .isNotNull();
+            assertThat(((RemoteInputChannel) 
remoteChannel).getInitialCredit()).isEqualTo(2);
 
             final InputChannel localChannel = gate.getChannel(1);
-            assertThat(localChannel, instanceOf(LocalInputChannel.class));
-            assertNotNull(((LocalInputChannel) 
localChannel).getSubpartitionView());
+            assertThat(localChannel).isInstanceOf(LocalInputChannel.class);
+            assertThat(((LocalInputChannel) 
localChannel).getSubpartitionView()).isNotNull();
 
-            assertThat(gate.getChannel(2), 
instanceOf(UnknownInputChannel.class));
+            
assertThat(gate.getChannel(2)).isInstanceOf(UnknownInputChannel.class);
         }
     }
 
@@ -220,7 +227,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * value after receiving all end-of-partition events.
      */
     @Test
-    public void testBasicGetNextLogic() throws Exception {
+    void testBasicGetNextLogic() throws Exception {
         // Setup
         final SingleInputGate inputGate = createInputGate();
 
@@ -248,20 +255,19 @@ public class SingleInputGateTest extends 
InputGateTestBase {
         verifyBufferOrEvent(inputGate, true, 0, true);
         verifyBufferOrEvent(inputGate, false, 1, true);
         // we have received EndOfData on a single channel only
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
-                inputGate.hasReceivedEndOfData());
+        assertThat(inputGate.hasReceivedEndOfData())
+                
.isEqualTo(PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA);
         verifyBufferOrEvent(inputGate, false, 0, true);
-        assertFalse(inputGate.isFinished());
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.DRAINED, 
inputGate.hasReceivedEndOfData());
+        assertThat(inputGate.isFinished()).isFalse();
+        assertThat(inputGate.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.DRAINED);
         verifyBufferOrEvent(inputGate, false, 1, true);
         verifyBufferOrEvent(inputGate, false, 0, false);
 
         // Return null when the input gate has received all end-of-partition 
events
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.DRAINED, 
inputGate.hasReceivedEndOfData());
-        assertTrue(inputGate.isFinished());
+        assertThat(inputGate.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.DRAINED);
+        assertThat(inputGate.isFinished()).isTrue();
 
         for (TestInputChannel ic : inputChannels) {
             ic.assertReturnedEventsAreRecycled();
@@ -269,7 +275,7 @@ public class SingleInputGateTest extends InputGateTestBase {
     }
 
     @Test
-    public void testDrainFlagComputation() throws Exception {
+    void testDrainFlagComputation() throws Exception {
         // Setup
         final SingleInputGate inputGate1 = createInputGate();
         final SingleInputGate inputGate2 = createInputGate();
@@ -299,23 +305,21 @@ public class SingleInputGateTest extends 
InputGateTestBase {
 
         verifyBufferOrEvent(inputGate1, false, 0, true);
         // we have received EndOfData on a single channel only
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
-                inputGate1.hasReceivedEndOfData());
+        assertThat(inputGate1.hasReceivedEndOfData())
+                
.isEqualTo(PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA);
         verifyBufferOrEvent(inputGate1, false, 1, true);
         // one of the channels said we should not drain
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.STOPPED, 
inputGate1.hasReceivedEndOfData());
+        assertThat(inputGate1.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.STOPPED);
 
         verifyBufferOrEvent(inputGate2, false, 0, true);
         // we have received EndOfData on a single channel only
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
-                inputGate2.hasReceivedEndOfData());
+        assertThat(inputGate2.hasReceivedEndOfData())
+                
.isEqualTo(PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA);
         verifyBufferOrEvent(inputGate2, false, 1, true);
         // both channels said we should drain
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.DRAINED, 
inputGate2.hasReceivedEndOfData());
+        assertThat(inputGate2.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.DRAINED);
     }
 
     /**
@@ -323,7 +327,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * SingleInputGate#getNext()}.
      */
     @Test
-    public void testGetCompressedBuffer() throws Exception {
+    void testGetCompressedBuffer() throws Exception {
         int bufferSize = 1024;
         String compressionCodec = "LZ4";
         BufferCompressor compressor = new BufferCompressor(bufferSize, 
compressionCodec);
@@ -340,15 +344,15 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             Buffer uncompressedBuffer = new NetworkBuffer(segment, 
FreeingBufferRecycler.INSTANCE);
             uncompressedBuffer.setSize(bufferSize);
             Buffer compressedBuffer = 
compressor.compressToOriginalBuffer(uncompressedBuffer);
-            assertTrue(compressedBuffer.isCompressed());
+            assertThat(compressedBuffer.isCompressed()).isTrue();
 
             inputChannel.read(compressedBuffer);
             inputGate.setInputChannels(inputChannel);
             inputGate.notifyChannelNonEmpty(inputChannel);
 
             Optional<BufferOrEvent> bufferOrEvent = inputGate.getNext();
-            assertTrue(bufferOrEvent.isPresent());
-            assertTrue(bufferOrEvent.get().isBuffer());
+            assertThat(bufferOrEvent.isPresent()).isTrue();
+            assertThat(bufferOrEvent.get().isBuffer()).isTrue();
             ByteBuffer buffer =
                     bufferOrEvent
                             .get()
@@ -356,29 +360,29 @@ public class SingleInputGateTest extends 
InputGateTestBase {
                             .getNioBufferReadable()
                             .order(ByteOrder.LITTLE_ENDIAN);
             for (int i = 0; i < bufferSize; i += 8) {
-                assertEquals(i, buffer.getLong());
+                assertThat(buffer.getLong()).isEqualTo(i);
             }
         }
     }
 
     @Test
-    public void testNotifyAfterEndOfPartition() throws Exception {
+    void testNotifyAfterEndOfPartition() throws Exception {
         final SingleInputGate inputGate = createInputGate(2);
         TestInputChannel inputChannel = new TestInputChannel(inputGate, 0);
         inputGate.setInputChannels(inputChannel, new 
TestInputChannel(inputGate, 1));
 
         inputChannel.readEndOfPartitionEvent();
         inputChannel.notifyChannelNonEmpty();
-        assertEquals(EndOfPartitionEvent.INSTANCE, 
inputGate.pollNext().get().getEvent());
+        
assertThat(inputGate.pollNext().get().getEvent()).isEqualTo(EndOfPartitionEvent.INSTANCE);
 
         // gate is still active because of secondary channel
         // test if released channel is enqueued
         inputChannel.notifyChannelNonEmpty();
-        assertFalse(inputGate.pollNext().isPresent());
+        assertThat(inputGate.pollNext().isPresent()).isFalse();
     }
 
     @Test
-    public void testIsAvailable() throws Exception {
+    void testIsAvailable() throws Exception {
         final SingleInputGate inputGate = createInputGate(1);
         TestInputChannel inputChannel = new TestInputChannel(inputGate, 0);
         inputGate.setInputChannels(inputChannel);
@@ -387,7 +391,7 @@ public class SingleInputGateTest extends InputGateTestBase {
     }
 
     @Test
-    public void testIsAvailableAfterFinished() throws Exception {
+    void testIsAvailableAfterFinished() throws Exception {
         final SingleInputGate inputGate = createInputGate(1);
         TestInputChannel inputChannel = new TestInputChannel(inputGate, 0);
         inputGate.setInputChannels(inputChannel);
@@ -401,7 +405,7 @@ public class SingleInputGateTest extends InputGateTestBase {
     }
 
     @Test
-    public void testIsMoreAvailableReadingFromSingleInputChannel() throws 
Exception {
+    void testIsMoreAvailableReadingFromSingleInputChannel() throws Exception {
         // Setup
         final SingleInputGate inputGate = createInputGate();
 
@@ -423,7 +427,7 @@ public class SingleInputGateTest extends InputGateTestBase {
     }
 
     @Test
-    public void testBackwardsEventWithUninitializedChannel() throws Exception {
+    void testBackwardsEventWithUninitializedChannel() throws Exception {
         // Setup environment
         TestingTaskEventPublisher taskEventPublisher = new 
TestingTaskEventPublisher();
 
@@ -464,14 +468,14 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             setupInputGate(inputGate, inputChannels);
 
             // Only the local channel can request
-            assertEquals(1, partitionManager.counter);
+            assertThat(partitionManager.counter).isEqualTo(1);
 
             // Send event backwards and initialize unknown channel afterwards
             final TaskEvent event = new TestTaskEvent();
             inputGate.sendTaskEvent(event);
 
             // Only the local channel can send out the event
-            assertEquals(1, taskEventPublisher.counter);
+            assertThat(taskEventPublisher.counter).isEqualTo(1);
 
             // After the update, the pending event should be send to local 
channel
 
@@ -480,8 +484,8 @@ public class SingleInputGateTest extends InputGateTestBase {
                     location,
                     
createRemoteWithIdAndLocation(unknownPartitionId.getPartitionId(), location));
 
-            assertEquals(2, partitionManager.counter);
-            assertEquals(2, taskEventPublisher.counter);
+            assertThat(partitionManager.counter).isEqualTo(2);
+            assertThat(taskEventPublisher.counter).isEqualTo(2);
         }
     }
 
@@ -492,7 +496,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * listener.
      */
     @Test
-    public void testUpdateChannelBeforeRequest() throws Exception {
+    void testUpdateChannelBeforeRequest() throws Exception {
         SingleInputGate inputGate = createInputGate(1);
 
         TestingResultPartitionManager partitionManager =
@@ -511,7 +515,7 @@ public class SingleInputGateTest extends InputGateTestBase {
                 location,
                 
createRemoteWithIdAndLocation(resultPartitionID.getPartitionId(), location));
 
-        assertEquals(0, partitionManager.counter);
+        assertThat(partitionManager.counter).isEqualTo(0);
     }
 
     /**
@@ -519,7 +523,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * data.
      */
     @Test
-    public void testReleaseWhilePollingChannel() throws Exception {
+    void testReleaseWhilePollingChannel() throws Exception {
         final AtomicReference<Exception> asyncException = new 
AtomicReference<>();
 
         // Setup the input gate with a single channel that does nothing
@@ -558,7 +562,7 @@ public class SingleInputGateTest extends InputGateTestBase {
         }
 
         // Verify that async consumer is in blocking request
-        assertTrue("Did not trigger blocking buffer request.", success);
+        assertThat(success).as("Did not trigger blocking buffer 
request.").isTrue();
 
         // Release the input gate
         inputGate.close();
@@ -568,13 +572,13 @@ public class SingleInputGateTest extends 
InputGateTestBase {
         // call will never return.
         asyncConsumer.join();
 
-        assertNotNull(asyncException.get());
-        assertEquals(IllegalStateException.class, 
asyncException.get().getClass());
+        assertThat(asyncException.get()).isNotNull();
+        
assertThat(asyncException.get().getClass()).isEqualTo(IllegalStateException.class);
     }
 
     /** Tests request back off configuration is correctly forwarded to the 
channels. */
     @Test
-    public void testRequestBackoffConfiguration() throws Exception {
+    void testRequestBackoffConfiguration() throws Exception {
         IntermediateResultPartitionID[] partitionIds =
                 new IntermediateResultPartitionID[] {
                     new IntermediateResultPartitionID(),
@@ -605,11 +609,11 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             closer.register(netEnv::close);
             closer.register(gate::close);
 
-            assertEquals(ResultPartitionType.PIPELINED, 
gate.getConsumedPartitionType());
+            
assertThat(gate.getConsumedPartitionType()).isEqualTo(ResultPartitionType.PIPELINED);
 
             Map<SubpartitionInfo, InputChannel> channelMap = 
gate.getInputChannels();
 
-            assertEquals(3, channelMap.size());
+            assertThat(channelMap.size()).isEqualTo(3);
             channelMap
                     .values()
                     .forEach(
@@ -621,39 +625,39 @@ public class SingleInputGateTest extends 
InputGateTestBase {
                                 }
                             });
             InputChannel localChannel = 
channelMap.get(createSubpartitionInfo(partitionIds[0]));
-            assertEquals(LocalInputChannel.class, localChannel.getClass());
+            
assertThat(localChannel.getClass()).isEqualTo(LocalInputChannel.class);
 
             InputChannel remoteChannel = 
channelMap.get(createSubpartitionInfo(partitionIds[1]));
-            assertEquals(RemoteInputChannel.class, remoteChannel.getClass());
+            
assertThat(remoteChannel.getClass()).isEqualTo(RemoteInputChannel.class);
 
             InputChannel unknownChannel = 
channelMap.get(createSubpartitionInfo(partitionIds[2]));
-            assertEquals(UnknownInputChannel.class, unknownChannel.getClass());
+            
assertThat(unknownChannel.getClass()).isEqualTo(UnknownInputChannel.class);
 
             InputChannel[] channels =
                     new InputChannel[] {localChannel, remoteChannel, 
unknownChannel};
             for (InputChannel ch : channels) {
-                assertEquals(0, ch.getCurrentBackoff());
+                assertThat(ch.getCurrentBackoff()).isEqualTo(0);
 
-                assertTrue(ch.increaseBackoff());
-                assertEquals(initialBackoff, ch.getCurrentBackoff());
+                assertThat(ch.increaseBackoff()).isTrue();
+                assertThat(ch.getCurrentBackoff()).isEqualTo(initialBackoff);
 
-                assertTrue(ch.increaseBackoff());
-                assertEquals(initialBackoff * 2, ch.getCurrentBackoff());
+                assertThat(ch.increaseBackoff()).isTrue();
+                assertThat(ch.getCurrentBackoff()).isEqualTo(initialBackoff * 
2);
 
-                assertTrue(ch.increaseBackoff());
-                assertEquals(initialBackoff * 2 * 2, ch.getCurrentBackoff());
+                assertThat(ch.increaseBackoff()).isTrue();
+                assertThat(ch.getCurrentBackoff()).isEqualTo(initialBackoff * 
2 * 2);
 
-                assertTrue(ch.increaseBackoff());
-                assertEquals(maxBackoff, ch.getCurrentBackoff());
+                assertThat(ch.increaseBackoff()).isTrue();
+                assertThat(ch.getCurrentBackoff()).isEqualTo(maxBackoff);
 
-                assertFalse(ch.increaseBackoff());
+                assertThat(ch.increaseBackoff()).isFalse();
             }
         }
     }
 
     /** Tests that input gate requests and assigns network buffers for remote 
input channel. */
     @Test
-    public void testRequestBuffersWithRemoteInputChannel() throws Exception {
+    void testRequestBuffersWithRemoteInputChannel() throws Exception {
         final NettyShuffleEnvironment network = 
createNettyShuffleEnvironment();
         final SingleInputGate inputGate =
                 createInputGate(network, 1, 
ResultPartitionType.PIPELINED_BOUNDED);
@@ -674,14 +678,13 @@ public class SingleInputGateTest extends 
InputGateTestBase {
 
             NetworkBufferPool bufferPool = network.getNetworkBufferPool();
             // only the exclusive buffers should be assigned/available now
-            assertEquals(buffersPerChannel, 
remote.getNumberOfAvailableBuffers());
+            
assertThat(remote.getNumberOfAvailableBuffers()).isEqualTo(buffersPerChannel);
 
-            assertEquals(
-                    bufferPool.getTotalNumberOfMemorySegments() - 
buffersPerChannel - 1,
-                    bufferPool.getNumberOfAvailableMemorySegments());
+            assertThat(bufferPool.getNumberOfAvailableMemorySegments())
+                    .isEqualTo(bufferPool.getTotalNumberOfMemorySegments() - 
buffersPerChannel - 1);
             // note: exclusive buffers are not handed out into LocalBufferPool 
and are thus not
             // counted
-            assertEquals(extraNetworkBuffersPerGate, 
bufferPool.countBuffers());
+            
assertThat(bufferPool.countBuffers()).isEqualTo(extraNetworkBuffersPerGate);
         }
     }
 
@@ -690,7 +693,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * to remote input channel.
      */
     @Test
-    public void testRequestBuffersWithUnknownInputChannel() throws Exception {
+    void testRequestBuffersWithUnknownInputChannel() throws Exception {
         final NettyShuffleEnvironment network = 
createNettyShuffleEnvironment();
         final SingleInputGate inputGate =
                 createInputGate(network, 1, 
ResultPartitionType.PIPELINED_BOUNDED);
@@ -709,12 +712,11 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             inputGate.setup();
             NetworkBufferPool bufferPool = network.getNetworkBufferPool();
 
-            assertEquals(
-                    bufferPool.getTotalNumberOfMemorySegments() - 1,
-                    bufferPool.getNumberOfAvailableMemorySegments());
+            assertThat(bufferPool.getNumberOfAvailableMemorySegments())
+                    .isEqualTo(bufferPool.getTotalNumberOfMemorySegments() - 
1);
             // note: exclusive buffers are not handed out into LocalBufferPool 
and are thus not
             // counted
-            assertEquals(extraNetworkBuffersPerGate, 
bufferPool.countBuffers());
+            
assertThat(bufferPool.countBuffers()).isEqualTo(extraNetworkBuffersPerGate);
 
             // Trigger updates to remote input channel from unknown input 
channel
             inputGate.updateInputChannel(
@@ -730,14 +732,13 @@ public class SingleInputGateTest extends 
InputGateTestBase {
                                             createSubpartitionInfo(
                                                     
resultPartitionId.getPartitionId()));
             // only the exclusive buffers should be assigned/available now
-            assertEquals(buffersPerChannel, 
remote.getNumberOfAvailableBuffers());
+            
assertThat(remote.getNumberOfAvailableBuffers()).isEqualTo(buffersPerChannel);
 
-            assertEquals(
-                    bufferPool.getTotalNumberOfMemorySegments() - 
buffersPerChannel - 1,
-                    bufferPool.getNumberOfAvailableMemorySegments());
+            assertThat(bufferPool.getNumberOfAvailableMemorySegments())
+                    .isEqualTo(bufferPool.getTotalNumberOfMemorySegments() - 
buffersPerChannel - 1);
             // note: exclusive buffers are not handed out into LocalBufferPool 
and are thus not
             // counted
-            assertEquals(extraNetworkBuffersPerGate, 
bufferPool.countBuffers());
+            
assertThat(bufferPool.countBuffers()).isEqualTo(extraNetworkBuffersPerGate);
         }
     }
 
@@ -746,7 +747,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * channels.
      */
     @Test
-    public void testUpdateUnknownInputChannel() throws Exception {
+    void testUpdateUnknownInputChannel() throws Exception {
         final NettyShuffleEnvironment network = 
createNettyShuffleEnvironment();
 
         final ResultPartition localResultPartition =
@@ -785,15 +786,19 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             inputGate.setup();
 
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            
.get(createSubpartitionInfo(remoteResultPartitionId.getPartitionId())),
-                    is(instanceOf((UnknownInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    
remoteResultPartitionId.getPartitionId())))
+                    .isInstanceOf(UnknownInputChannel.class);
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            
.get(createSubpartitionInfo(localResultPartitionId.getPartitionId())),
-                    is(instanceOf((UnknownInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    
localResultPartitionId.getPartitionId())))
+                    .isInstanceOf(UnknownInputChannel.class);
 
             ResourceID localLocation = ResourceID.generate();
 
@@ -804,15 +809,19 @@ public class SingleInputGateTest extends 
InputGateTestBase {
                             remoteResultPartitionId.getPartitionId(), 
ResourceID.generate()));
 
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            
.get(createSubpartitionInfo(remoteResultPartitionId.getPartitionId())),
-                    is(instanceOf((RemoteInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    
remoteResultPartitionId.getPartitionId())))
+                    .isInstanceOf(RemoteInputChannel.class);
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            
.get(createSubpartitionInfo(localResultPartitionId.getPartitionId())),
-                    is(instanceOf((UnknownInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    
localResultPartitionId.getPartitionId())))
+                    .isInstanceOf(UnknownInputChannel.class);
 
             // Trigger updates to local input channel from unknown input 
channel
             inputGate.updateInputChannel(
@@ -821,21 +830,24 @@ public class SingleInputGateTest extends 
InputGateTestBase {
                             localResultPartitionId.getPartitionId(), 
localLocation));
 
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            
.get(createSubpartitionInfo(remoteResultPartitionId.getPartitionId())),
-                    is(instanceOf((RemoteInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    
remoteResultPartitionId.getPartitionId())))
+                    .isInstanceOf(RemoteInputChannel.class);
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            
.get(createSubpartitionInfo(localResultPartitionId.getPartitionId())),
-                    is(instanceOf((LocalInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    
localResultPartitionId.getPartitionId())))
+                    .isInstanceOf(LocalInputChannel.class);
         }
     }
 
     @Test
-    public void testSingleInputGateWithSubpartitionIndexRange()
-            throws IOException, InterruptedException {
+    void testSingleInputGateWithSubpartitionIndexRange() throws IOException, 
InterruptedException {
 
         IntermediateResultPartitionID[] partitionIds =
                 new IntermediateResultPartitionID[] {
@@ -872,13 +884,13 @@ public class SingleInputGateTest extends 
InputGateTestBase {
         SubpartitionInfo info5 = createSubpartitionInfo(partitionIds[2], 0);
         SubpartitionInfo info6 = createSubpartitionInfo(partitionIds[2], 1);
 
-        assertThat(gate.getInputChannels().size(), is(6));
-        
assertThat(gate.getInputChannels().get(info1).getConsumedSubpartitionIndex(), 
is(0));
-        
assertThat(gate.getInputChannels().get(info2).getConsumedSubpartitionIndex(), 
is(1));
-        
assertThat(gate.getInputChannels().get(info3).getConsumedSubpartitionIndex(), 
is(0));
-        
assertThat(gate.getInputChannels().get(info4).getConsumedSubpartitionIndex(), 
is(1));
-        
assertThat(gate.getInputChannels().get(info5).getConsumedSubpartitionIndex(), 
is(0));
-        
assertThat(gate.getInputChannels().get(info6).getConsumedSubpartitionIndex(), 
is(1));
+        assertThat(gate.getInputChannels().size()).isEqualTo(6);
+        
assertThat(gate.getInputChannels().get(info1).getConsumedSubpartitionIndex()).isEqualTo(0);
+        
assertThat(gate.getInputChannels().get(info2).getConsumedSubpartitionIndex()).isEqualTo(1);
+        
assertThat(gate.getInputChannels().get(info3).getConsumedSubpartitionIndex()).isEqualTo(0);
+        
assertThat(gate.getInputChannels().get(info4).getConsumedSubpartitionIndex()).isEqualTo(1);
+        
assertThat(gate.getInputChannels().get(info5).getConsumedSubpartitionIndex()).isEqualTo(0);
+        
assertThat(gate.getInputChannels().get(info6).getConsumedSubpartitionIndex()).isEqualTo(1);
 
         assertChannelsType(gate, LocalRecoveredInputChannel.class, 
Arrays.asList(info1, info2));
         assertChannelsType(gate, RemoteRecoveredInputChannel.class, 
Arrays.asList(info3, info4));
@@ -886,8 +898,8 @@ public class SingleInputGateTest extends InputGateTestBase {
 
         // test setup
         gate.setup();
-        assertNotNull(gate.getBufferPool());
-        assertEquals(1, 
gate.getBufferPool().getNumberOfRequiredMemorySegments());
+        assertThat(gate.getBufferPool()).isNotNull();
+        
assertThat(gate.getBufferPool().getNumberOfRequiredMemorySegments()).isEqualTo(1);
 
         gate.finishReadRecoveredState();
         while (!gate.getStateConsumedFuture().isDone()) {
@@ -902,10 +914,11 @@ public class SingleInputGateTest extends 
InputGateTestBase {
         assertChannelsType(gate, UnknownInputChannel.class, 
Arrays.asList(info5, info6));
         for (InputChannel inputChannel : gate.getInputChannels().values()) {
             if (inputChannel instanceof RemoteInputChannel) {
-                assertNotNull(((RemoteInputChannel) 
inputChannel).getPartitionRequestClient());
-                assertEquals(2, ((RemoteInputChannel) 
inputChannel).getInitialCredit());
+                assertThat(((RemoteInputChannel) 
inputChannel).getPartitionRequestClient())
+                        .isNotNull();
+                assertThat(((RemoteInputChannel) 
inputChannel).getInitialCredit()).isEqualTo(2);
             } else if (inputChannel instanceof LocalInputChannel) {
-                assertNotNull(((LocalInputChannel) 
inputChannel).getSubpartitionView());
+                assertThat(((LocalInputChannel) 
inputChannel).getSubpartitionView()).isNotNull();
             }
         }
 
@@ -920,12 +933,12 @@ public class SingleInputGateTest extends 
InputGateTestBase {
     private void assertChannelsType(
             SingleInputGate gate, Class<?> clazz, List<SubpartitionInfo> 
infos) {
         for (SubpartitionInfo subpartitionInfo : infos) {
-            assertThat(gate.getInputChannels().get(subpartitionInfo), 
instanceOf(clazz));
+            
assertThat(gate.getInputChannels().get(subpartitionInfo)).isInstanceOf(clazz);
         }
     }
 
     @Test
-    public void testQueuedBuffers() throws Exception {
+    void testQueuedBuffers() throws Exception {
         final NettyShuffleEnvironment network = 
createNettyShuffleEnvironment();
 
         final BufferWritingResultPartition resultPartition =
@@ -966,10 +979,10 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             setupInputGate(inputGate, inputChannels);
 
             remoteInputChannel.onBuffer(createBuffer(1), 0, 0);
-            assertEquals(1, inputGate.getNumberOfQueuedBuffers());
+            assertThat(inputGate.getNumberOfQueuedBuffers()).isEqualTo(1);
 
             resultPartition.emitRecord(ByteBuffer.allocate(1), 0);
-            assertEquals(2, inputGate.getNumberOfQueuedBuffers());
+            assertThat(inputGate.getNumberOfQueuedBuffers()).isEqualTo(2);
         }
     }
 
@@ -979,7 +992,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * the {@link SingleInputGate} would not swallow or transform the original 
exception.
      */
     @Test
-    public void testPartitionNotFoundExceptionWhileGetNextBuffer() throws 
Exception {
+    void testPartitionNotFoundExceptionWhileGetNextBuffer() throws Exception {
         final SingleInputGate inputGate = 
InputChannelTestUtils.createSingleInputGate(1);
         final LocalInputChannel localChannel =
                 createLocalInputChannel(inputGate, new 
ResultPartitionManager());
@@ -987,17 +1000,16 @@ public class SingleInputGateTest extends 
InputGateTestBase {
 
         inputGate.setInputChannels(localChannel);
         localChannel.setError(new PartitionNotFoundException(partitionId));
-        try {
-            inputGate.getNext();
-
-            fail("Should throw a PartitionNotFoundException.");
-        } catch (PartitionNotFoundException notFound) {
-            assertThat(partitionId, is(notFound.getPartitionId()));
-        }
+        assertThatThrownBy(inputGate::getNext)
+                .isInstanceOfSatisfying(
+                        PartitionNotFoundException.class,
+                        (notFoundException) ->
+                                assertThat(notFoundException.getPartitionId())
+                                        .isEqualTo(partitionId));
     }
 
     @Test
-    public void testAnnounceBufferSize() throws Exception {
+    void testAnnounceBufferSize() throws Exception {
         final SingleInputGate inputGate = 
InputChannelTestUtils.createSingleInputGate(2);
         final LocalInputChannel localChannel =
                 createLocalInputChannel(
@@ -1028,7 +1040,7 @@ public class SingleInputGateTest extends 
InputGateTestBase {
     }
 
     @Test
-    public void testInputGateRemovalFromNettyShuffleEnvironment() throws 
Exception {
+    void testInputGateRemovalFromNettyShuffleEnvironment() throws Exception {
         NettyShuffleEnvironment network = createNettyShuffleEnvironment();
 
         try (Closer closer = Closer.create()) {
@@ -1038,18 +1050,18 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             Map<InputGateID, SingleInputGate> createdInputGatesById =
                     createInputGateWithLocalChannels(network, numberOfGates, 
1);
 
-            assertEquals(numberOfGates, createdInputGatesById.size());
+            assertThat(createdInputGatesById.size()).isEqualTo(numberOfGates);
 
             for (InputGateID id : createdInputGatesById.keySet()) {
-                assertThat(network.getInputGate(id).isPresent(), is(true));
+                assertThat(network.getInputGate(id).isPresent()).isTrue();
                 createdInputGatesById.get(id).close();
-                assertThat(network.getInputGate(id).isPresent(), is(false));
+                assertThat(network.getInputGate(id).isPresent()).isFalse();
             }
         }
     }
 
     @Test
-    public void testSingleInputGateInfo() {
+    void testSingleInputGateInfo() {
         final int numSingleInputGates = 2;
         final int numInputChannels = 3;
 
@@ -1064,14 +1076,14 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             for (InputChannel inputChannel : gate.getInputChannels().values()) 
{
                 InputChannelInfo channelInfo = inputChannel.getChannelInfo();
 
-                assertEquals(i, channelInfo.getGateIdx());
-                assertEquals(channelCounter++, 
channelInfo.getInputChannelIdx());
+                assertThat(channelInfo.getGateIdx()).isEqualTo(i);
+                
assertThat(channelInfo.getInputChannelIdx()).isEqualTo(channelCounter++);
             }
         }
     }
 
     @Test
-    public void testGetUnfinishedChannels() throws IOException, 
InterruptedException {
+    void testGetUnfinishedChannels() throws IOException, InterruptedException {
         SingleInputGate inputGate =
                 new SingleInputGateBuilder()
                         .setSingleInputGateIndex(1)
@@ -1085,35 +1097,36 @@ public class SingleInputGateTest extends 
InputGateTestBase {
                 };
         inputGate.setInputChannels(inputChannels);
 
-        assertEquals(
-                Arrays.asList(
-                        inputChannels[0].getChannelInfo(),
-                        inputChannels[1].getChannelInfo(),
-                        inputChannels[2].getChannelInfo()),
-                inputGate.getUnfinishedChannels());
+        assertThat(inputGate.getUnfinishedChannels())
+                .isEqualTo(
+                        Arrays.asList(
+                                inputChannels[0].getChannelInfo(),
+                                inputChannels[1].getChannelInfo(),
+                                inputChannels[2].getChannelInfo()));
 
         inputChannels[1].readEndOfPartitionEvent();
         inputGate.notifyChannelNonEmpty(inputChannels[1]);
         inputGate.getNext();
-        assertEquals(
-                Arrays.asList(inputChannels[0].getChannelInfo(), 
inputChannels[2].getChannelInfo()),
-                inputGate.getUnfinishedChannels());
+        assertThat(inputGate.getUnfinishedChannels())
+                .isEqualTo(
+                        Arrays.asList(
+                                inputChannels[0].getChannelInfo(),
+                                inputChannels[2].getChannelInfo()));
 
         inputChannels[0].readEndOfPartitionEvent();
         inputGate.notifyChannelNonEmpty(inputChannels[0]);
         inputGate.getNext();
-        assertEquals(
-                Collections.singletonList(inputChannels[2].getChannelInfo()),
-                inputGate.getUnfinishedChannels());
+        assertThat(inputGate.getUnfinishedChannels())
+                
.isEqualTo(Collections.singletonList(inputChannels[2].getChannelInfo()));
 
         inputChannels[2].readEndOfPartitionEvent();
         inputGate.notifyChannelNonEmpty(inputChannels[2]);
         inputGate.getNext();
-        assertEquals(Collections.emptyList(), 
inputGate.getUnfinishedChannels());
+        
assertThat(inputGate.getUnfinishedChannels()).isEqualTo(Collections.emptyList());
     }
 
     @Test
-    public void testBufferInUseCount() throws Exception {
+    void testBufferInUseCount() throws Exception {
         // Setup
         final SingleInputGate inputGate = createInputGate();
 
@@ -1125,17 +1138,17 @@ public class SingleInputGateTest extends 
InputGateTestBase {
         inputGate.setInputChannels(inputChannels);
 
         // It should be no buffers when all channels are empty.
-        assertThat(inputGate.getBuffersInUseCount(), is(0));
+        assertThat(inputGate.getBuffersInUseCount()).isEqualTo(0);
 
         // Add buffers into channels.
         inputChannels[0].readBuffer();
-        assertThat(inputGate.getBuffersInUseCount(), is(1));
+        assertThat(inputGate.getBuffersInUseCount()).isEqualTo(1);
 
         inputChannels[0].readBuffer();
-        assertThat(inputGate.getBuffersInUseCount(), is(2));
+        assertThat(inputGate.getBuffersInUseCount()).isEqualTo(2);
 
         inputChannels[1].readBuffer();
-        assertThat(inputGate.getBuffersInUseCount(), is(3));
+        assertThat(inputGate.getBuffersInUseCount()).isEqualTo(3);
     }
 
     // 
---------------------------------------------------------------------------------------------
@@ -1278,14 +1291,13 @@ public class SingleInputGateTest extends 
InputGateTestBase {
             throws IOException, InterruptedException {
 
         final Optional<BufferOrEvent> bufferOrEvent = inputGate.getNext();
-        assertTrue(bufferOrEvent.isPresent());
-        assertEquals(expectedIsBuffer, bufferOrEvent.get().isBuffer());
-        assertEquals(
-                inputGate.getChannel(expectedChannelIndex).getChannelInfo(),
-                bufferOrEvent.get().getChannelInfo());
-        assertEquals(expectedMoreAvailable, 
bufferOrEvent.get().moreAvailable());
+        assertThat(bufferOrEvent.isPresent()).isTrue();
+        assertThat(bufferOrEvent.get().isBuffer()).isEqualTo(expectedIsBuffer);
+        assertThat(bufferOrEvent.get().getChannelInfo())
+                
.isEqualTo(inputGate.getChannel(expectedChannelIndex).getChannelInfo());
+        
assertThat(bufferOrEvent.get().moreAvailable()).isEqualTo(expectedMoreAvailable);
         if (!expectedMoreAvailable) {
-            assertFalse(inputGate.pollNext().isPresent());
+            assertThat(inputGate.pollNext().isPresent()).isFalse();
         }
     }
 

Reply via email to