This is an automated email from the ASF dual-hosted git repository.

sewen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 65c1edf2ab076718cecbd847b926a7f02c4a4dcb
Author: Zhijiang <[email protected]>
AuthorDate: Wed Oct 28 15:19:30 2020 +0800

    [FLINK-15981][refactor] Refactor NettyMessage#write() to allow writing multiple message parts.
    
    Rather than having to produce a single message (as a composite buffer), write() can now
    emit multiple partial messages directly to the channel. That way, different buffer types
    can be combined, e.g. a memory buffer (for headers and events) with file region buffers
    (for direct file transfer).
---
 .../runtime/io/network/netty/NettyMessage.java     | 224 +++++++++------------
 .../NettyMessageClientDecoderDelegateTest.java     |  36 ++--
 .../runtime/io/network/netty/NettyTestUtil.java    |   9 +-
 3 files changed, 118 insertions(+), 151 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
index 3bf537a..057b219 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.netty;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
@@ -36,6 +37,7 @@ import 
org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
 import 
org.apache.flink.shaded.netty4.io.netty.channel.ChannelOutboundHandlerAdapter;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOutboundInvoker;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelPromise;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
 
@@ -46,6 +48,7 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.net.ProtocolException;
 import java.nio.ByteBuffer;
+import java.util.function.Consumer;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -67,7 +70,7 @@ public abstract class NettyMessage {
 
        static final int MAGIC_NUMBER = 0xBADC0FFE;
 
-       abstract ByteBuf write(ByteBufAllocator allocator) throws Exception;
+       abstract void write(ChannelOutboundInvoker out, ChannelPromise promise, 
ByteBufAllocator allocator) throws IOException;
 
        // 
------------------------------------------------------------------------
 
@@ -165,22 +168,9 @@ public abstract class NettyMessage {
        static class NettyMessageEncoder extends ChannelOutboundHandlerAdapter {
 
                @Override
-               public void write(ChannelHandlerContext ctx, Object msg, 
ChannelPromise promise) throws Exception {
+               public void write(ChannelHandlerContext ctx, Object msg, 
ChannelPromise promise) throws IOException {
                        if (msg instanceof NettyMessage) {
-
-                               ByteBuf serialized = null;
-
-                               try {
-                                       serialized = ((NettyMessage) 
msg).write(ctx.alloc());
-                               }
-                               catch (Throwable t) {
-                                       throw new IOException("Error while 
serializing message: " + msg, t);
-                               }
-                               finally {
-                                       if (serialized != null) {
-                                               ctx.write(serialized, promise);
-                                       }
-                               }
+                               ((NettyMessage) msg).write(ctx, promise, 
ctx.alloc());
                        }
                        else {
                                ctx.write(msg, promise);
@@ -337,21 +327,29 @@ public abstract class NettyMessage {
                // 
--------------------------------------------------------------------
 
                @Override
-               ByteBuf write(ByteBufAllocator allocator) throws IOException {
+               void write(ChannelOutboundInvoker out, ChannelPromise promise, 
ByteBufAllocator allocator) throws IOException {
                        ByteBuf headerBuf = null;
                        try {
                                // in order to forward the buffer to netty, it 
needs an allocator set
                                buffer.setAllocator(allocator);
 
-                               // only allocate header buffer - we will 
combine it with the data buffer below
-                               headerBuf = allocateBuffer(allocator, ID, 
MESSAGE_HEADER_LENGTH, bufferSize, false);
+                               headerBuf = fillHeader(allocator);
+                               out.write(headerBuf);
+                               out.write(buffer, promise);
+                       }
+                       catch (Throwable t) {
+                               handleException(headerBuf, buffer, t);
+                       }
+               }
+
+               @VisibleForTesting
+               ByteBuf write(ByteBufAllocator allocator) throws IOException {
+                       ByteBuf headerBuf = null;
+                       try {
+                               // in order to forward the buffer to netty, it 
needs an allocator set
+                               buffer.setAllocator(allocator);
 
-                               receiverId.writeTo(headerBuf);
-                               headerBuf.writeInt(sequenceNumber);
-                               headerBuf.writeInt(backlog);
-                               headerBuf.writeByte(dataType.ordinal());
-                               headerBuf.writeBoolean(isCompressed);
-                               headerBuf.writeInt(buffer.readableBytes());
+                               headerBuf = fillHeader(allocator);
 
                                CompositeByteBuf composityBuf = 
allocator.compositeDirectBuffer();
                                composityBuf.addComponent(headerBuf);
@@ -361,16 +359,24 @@ public abstract class NettyMessage {
                                return composityBuf;
                        }
                        catch (Throwable t) {
-                               if (headerBuf != null) {
-                                       headerBuf.release();
-                               }
-                               buffer.recycleBuffer();
-
-                               ExceptionUtils.rethrowIOException(t);
+                               handleException(headerBuf, buffer, t);
                                return null; // silence the compiler
                        }
                }
 
+               private ByteBuf fillHeader(ByteBufAllocator allocator) {
+                       // only allocate header buffer - we will combine it 
with the data buffer below
+                       ByteBuf headerBuf = allocateBuffer(allocator, ID, 
MESSAGE_HEADER_LENGTH, bufferSize, false);
+
+                       receiverId.writeTo(headerBuf);
+                       headerBuf.writeInt(sequenceNumber);
+                       headerBuf.writeInt(backlog);
+                       headerBuf.writeByte(dataType.ordinal());
+                       headerBuf.writeBoolean(isCompressed);
+                       headerBuf.writeInt(buffer.readableBytes());
+                       return headerBuf;
+               }
+
                /**
                 * Parses the message header part and composes a new 
BufferResponse with an empty data buffer. The
                 * data buffer will be filled in later.
@@ -437,7 +443,7 @@ public abstract class NettyMessage {
                }
 
                @Override
-               ByteBuf write(ByteBufAllocator allocator) throws IOException {
+               void write(ChannelOutboundInvoker out, ChannelPromise promise, 
ByteBufAllocator allocator) throws IOException {
                        final ByteBuf result = allocateBuffer(allocator, ID);
 
                        try (ObjectOutputStream oos = new 
ObjectOutputStream(new ByteBufOutputStream(result))) {
@@ -452,16 +458,10 @@ public abstract class NettyMessage {
 
                                // Update frame length...
                                result.setInt(0, result.readableBytes());
-                               return result;
+                               out.write(result, promise);
                        }
                        catch (Throwable t) {
-                               result.release();
-
-                               if (t instanceof IOException) {
-                                       throw (IOException) t;
-                               } else {
-                                       throw new IOException(t);
-                               }
+                               handleException(result, null, t);
                        }
                }
 
@@ -508,27 +508,16 @@ public abstract class NettyMessage {
                }
 
                @Override
-               ByteBuf write(ByteBufAllocator allocator) throws IOException {
-                       ByteBuf result = null;
-
-                       try {
-                               result = allocateBuffer(allocator, ID, 20 + 16 
+ 4 + 16 + 4);
-
-                               partitionId.getPartitionId().writeTo(result);
-                               partitionId.getProducerId().writeTo(result);
-                               result.writeInt(queueIndex);
-                               receiverId.writeTo(result);
-                               result.writeInt(credit);
-
-                               return result;
-                       }
-                       catch (Throwable t) {
-                               if (result != null) {
-                                       result.release();
-                               }
+               void write(ChannelOutboundInvoker out, ChannelPromise promise, 
ByteBufAllocator allocator) throws IOException {
+                       Consumer<ByteBuf> consumer = (bb) -> {
+                               partitionId.getPartitionId().writeTo(bb);
+                               partitionId.getProducerId().writeTo(bb);
+                               bb.writeInt(queueIndex);
+                               receiverId.writeTo(bb);
+                               bb.writeInt(credit);
+                       };
 
-                               throw new IOException(t);
-                       }
+                       writeToChannel(out, promise, allocator, consumer, ID, 
20 + 16 + 4 + 16 + 4);
                }
 
                static PartitionRequest readFrom(ByteBuf buffer) {
@@ -566,32 +555,20 @@ public abstract class NettyMessage {
                }
 
                @Override
-               ByteBuf write(ByteBufAllocator allocator) throws IOException {
-                       ByteBuf result = null;
-
-                       try {
-                               // TODO Directly serialize to Netty's buffer
-                               ByteBuffer serializedEvent = 
EventSerializer.toSerializedEvent(event);
-
-                               result = allocateBuffer(allocator, ID, 4 + 
serializedEvent.remaining() + 20 + 16 + 16);
+               void write(ChannelOutboundInvoker out, ChannelPromise promise, 
ByteBufAllocator allocator) throws IOException {
+                       // TODO Directly serialize to Netty's buffer
+                       ByteBuffer serializedEvent = 
EventSerializer.toSerializedEvent(event);
 
-                               result.writeInt(serializedEvent.remaining());
-                               result.writeBytes(serializedEvent);
+                       Consumer<ByteBuf> consumer = (bb) -> {
+                               bb.writeInt(serializedEvent.remaining());
+                               bb.writeBytes(serializedEvent);
 
-                               partitionId.getPartitionId().writeTo(result);
-                               partitionId.getProducerId().writeTo(result);
+                               partitionId.getPartitionId().writeTo(bb);
+                               partitionId.getProducerId().writeTo(bb);
+                               receiverId.writeTo(bb);
+                       };
 
-                               receiverId.writeTo(result);
-
-                               return result;
-                       }
-                       catch (Throwable t) {
-                               if (result != null) {
-                                       result.release();
-                               }
-
-                               throw new IOException(t);
-                       }
+                       writeToChannel(out, promise, allocator, consumer, ID, 4 
+ serializedEvent.remaining() + 20 + 16 + 16);
                }
 
                static TaskEventRequest readFrom(ByteBuf buffer, ClassLoader 
classLoader) throws IOException {
@@ -633,22 +610,8 @@ public abstract class NettyMessage {
                }
 
                @Override
-               ByteBuf write(ByteBufAllocator allocator) throws Exception {
-                       ByteBuf result = null;
-
-                       try {
-                               result = allocateBuffer(allocator, ID, 16);
-                               receiverId.writeTo(result);
-                       }
-                       catch (Throwable t) {
-                               if (result != null) {
-                                       result.release();
-                               }
-
-                               throw new IOException(t);
-                       }
-
-                       return result;
+               void write(ChannelOutboundInvoker out, ChannelPromise promise, 
ByteBufAllocator allocator) throws IOException {
+                       writeToChannel(out, promise, allocator, receiverId :: 
writeTo, ID, 16);
                }
 
                static CancelPartitionRequest readFrom(ByteBuf buffer) throws 
Exception {
@@ -664,8 +627,8 @@ public abstract class NettyMessage {
                }
 
                @Override
-               ByteBuf write(ByteBufAllocator allocator) throws Exception {
-                       return allocateBuffer(allocator, ID, 0);
+               void write(ChannelOutboundInvoker out, ChannelPromise promise, 
ByteBufAllocator allocator) throws IOException {
+                       writeToChannel(out, promise, allocator, ignored -> {}, 
ID, 0);
                }
 
                static CloseRequest readFrom(@SuppressWarnings("unused") 
ByteBuf buffer) throws Exception {
@@ -691,7 +654,7 @@ public abstract class NettyMessage {
                }
 
                @Override
-               ByteBuf write(ByteBufAllocator allocator) throws IOException {
+               void write(ChannelOutboundInvoker out, ChannelPromise promise, 
ByteBufAllocator allocator) throws IOException {
                        ByteBuf result = null;
 
                        try {
@@ -699,14 +662,10 @@ public abstract class NettyMessage {
                                result.writeInt(credit);
                                receiverId.writeTo(result);
 
-                               return result;
+                               out.write(result, promise);
                        }
                        catch (Throwable t) {
-                               if (result != null) {
-                                       result.release();
-                               }
-
-                               throw new IOException(t);
+                               handleException(result, null, t);
                        }
                }
 
@@ -737,22 +696,8 @@ public abstract class NettyMessage {
                }
 
                @Override
-               ByteBuf write(ByteBufAllocator allocator) throws IOException {
-                       ByteBuf result = null;
-
-                       try {
-                               result = allocateBuffer(allocator, ID, 16);
-                               receiverId.writeTo(result);
-
-                               return result;
-                       }
-                       catch (Throwable t) {
-                               if (result != null) {
-                                       result.release();
-                               }
-
-                               throw new IOException(t);
-                       }
+               void write(ChannelOutboundInvoker out, ChannelPromise promise, 
ByteBufAllocator allocator) throws IOException {
+                       writeToChannel(out, promise, allocator, receiverId :: 
writeTo, ID, 16);
                }
 
                static ResumeConsumption readFrom(ByteBuf buffer) {
@@ -764,4 +709,35 @@ public abstract class NettyMessage {
                        return String.format("ResumeConsumption(%s)", 
receiverId);
                }
        }
+
+       // 
------------------------------------------------------------------------
+
+       void writeToChannel(
+                       ChannelOutboundInvoker out,
+                       ChannelPromise promise,
+                       ByteBufAllocator allocator,
+                       Consumer<ByteBuf> consumer,
+                       byte id,
+                       int length) throws IOException {
+
+               ByteBuf byteBuf = null;
+               try {
+                       byteBuf = allocateBuffer(allocator, id, length);
+                       consumer.accept(byteBuf);
+                       out.write(byteBuf, promise);
+               }
+               catch (Throwable t) {
+                       handleException(byteBuf, null, t);
+               }
+       }
+
+       void handleException(@Nullable ByteBuf byteBuf, @Nullable Buffer 
buffer, Throwable t) throws IOException {
+               if (byteBuf != null) {
+                       byteBuf.release();
+               }
+               if (buffer != null) {
+                       buffer.recycleBuffer();
+               }
+               ExceptionUtils.rethrowIOException(t);
+       }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientDecoderDelegateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientDecoderDelegateTest.java
index 6613328..7954258 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientDecoderDelegateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientDecoderDelegateTest.java
@@ -48,13 +48,10 @@ import java.util.List;
 import static junit.framework.TestCase.assertEquals;
 import static junit.framework.TestCase.assertTrue;
 import static 
org.apache.flink.runtime.io.network.netty.NettyMessage.BufferResponse;
-import static 
org.apache.flink.runtime.io.network.netty.NettyMessage.ErrorResponse;
 import static 
org.apache.flink.runtime.io.network.netty.NettyTestUtil.verifyBufferResponseHeader;
-import static 
org.apache.flink.runtime.io.network.netty.NettyTestUtil.verifyErrorResponse;
 import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createRemoteInputChannel;
 import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
 import static org.junit.Assert.assertNull;
-import static org.junit.Assert.fail;
 
 /**
  * Tests the client side message decoder.
@@ -163,7 +160,7 @@ public class NettyMessageClientDecoderDelegateTest extends 
TestLogger {
                ByteBuf[] encodedMessages = null;
                List<NettyMessage> decodedMessages = null;
                try {
-                       List<NettyMessage> messages = createMessageList(
+                       List<BufferResponse> messages = createMessageList(
                                hasEmptyBuffer,
                                hasBufferForReleasedChannel,
                                hasBufferForRemovedChannel);
@@ -186,13 +183,13 @@ public class NettyMessageClientDecoderDelegateTest 
extends TestLogger {
                }
        }
 
-       private List<NettyMessage> createMessageList(
+       private List<BufferResponse> createMessageList(
                boolean hasEmptyBuffer,
                boolean hasBufferForRemovedChannel,
                boolean hasBufferForReleasedChannel) {
 
                int seqNumber = 1;
-               List<NettyMessage> messages = new ArrayList<>();
+               List<BufferResponse> messages = new ArrayList<>();
 
                for (int i = 0; i < NUMBER_OF_BUFFER_RESPONSES - 1; i++) {
                        addBufferResponse(messages, inputChannelId, 
Buffer.DataType.DATA_BUFFER, BUFFER_SIZE, seqNumber++);
@@ -210,13 +207,12 @@ public class NettyMessageClientDecoderDelegateTest 
extends TestLogger {
 
                addBufferResponse(messages, inputChannelId, 
Buffer.DataType.EVENT_BUFFER, 32, seqNumber++);
                addBufferResponse(messages, inputChannelId, 
Buffer.DataType.DATA_BUFFER, BUFFER_SIZE, seqNumber);
-               messages.add(new NettyMessage.ErrorResponse(new 
RuntimeException("test"), inputChannelId));
 
                return messages;
        }
 
        private void addBufferResponse(
-               List<NettyMessage> messages,
+               List<BufferResponse> messages,
                InputChannelID inputChannelId,
                Buffer.DataType dataType,
                int bufferSize,
@@ -236,7 +232,7 @@ public class NettyMessageClientDecoderDelegateTest extends 
TestLogger {
                return buffer;
        }
 
-       private ByteBuf[] encodeMessages(List<NettyMessage> messages) throws 
Exception {
+       private ByteBuf[] encodeMessages(List<BufferResponse> messages) throws 
Exception {
                ByteBuf[] encodedMessages = new ByteBuf[messages.size()];
                for (int i = 0; i < messages.size(); ++i) {
                        encodedMessages[i] = messages.get(i).write(ALLOCATOR);
@@ -315,26 +311,18 @@ public class NettyMessageClientDecoderDelegateTest 
extends TestLogger {
                return decodedMessages;
        }
 
-       private void verifyDecodedMessages(List<NettyMessage> expectedMessages, 
List<NettyMessage> decodedMessages) {
+       private void verifyDecodedMessages(List<BufferResponse> 
expectedMessages, List<NettyMessage> decodedMessages) {
                assertEquals(expectedMessages.size(), decodedMessages.size());
                for (int i = 0; i < expectedMessages.size(); ++i) {
                        assertEquals(expectedMessages.get(i).getClass(), 
decodedMessages.get(i).getClass());
 
-                       if (expectedMessages.get(i) instanceof 
NettyMessage.BufferResponse) {
-                               BufferResponse expected = (BufferResponse) 
expectedMessages.get(i);
-                               BufferResponse actual = (BufferResponse) 
decodedMessages.get(i);
-
-                               verifyBufferResponseHeader(expected, actual);
-                               if (expected.bufferSize == 0 || 
!expected.receiverId.equals(inputChannelId)) {
-                                       assertNull(actual.getBuffer());
-                               } else {
-                                       assertEquals(expected.getBuffer(), 
actual.getBuffer());
-                               }
-
-                       } else if (expectedMessages.get(i) instanceof 
NettyMessage.ErrorResponse) {
-                               verifyErrorResponse((ErrorResponse) 
expectedMessages.get(i), (ErrorResponse) decodedMessages.get(i));
+                       BufferResponse expected = expectedMessages.get(i);
+                       BufferResponse actual = (BufferResponse) 
decodedMessages.get(i);
+                       verifyBufferResponseHeader(expected, actual);
+                       if (expected.bufferSize == 0 || 
!expected.receiverId.equals(inputChannelId)) {
+                               assertNull(actual.getBuffer());
                        } else {
-                               fail("Unsupported message type");
+                               assertEquals(expected.getBuffer(), 
actual.getBuffer());
                        }
                }
        }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyTestUtil.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyTestUtil.java
index c84bcf8..afa3127 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyTestUtil.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyTestUtil.java
@@ -175,9 +175,12 @@ public class NettyTestUtil {
 
        static <T extends NettyMessage> T encodeAndDecode(T msg, 
EmbeddedChannel channel) {
                channel.writeOutbound(msg);
-               ByteBuf encoded = channel.readOutbound();
-
-               assertTrue(channel.writeInbound(encoded));
+               ByteBuf encoded;
+               boolean msgNotEmpty = false;
+               while ((encoded = channel.readOutbound()) != null) {
+                       msgNotEmpty = channel.writeInbound(encoded);
+               }
+               assertTrue(msgNotEmpty);
 
                return channel.readInbound();
        }

Reply via email to