This is an automated email from the ASF dual-hosted git repository. sewen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 65c1edf2ab076718cecbd847b926a7f02c4a4dcb Author: Zhijiang <[email protected]> AuthorDate: Wed Oct 28 15:19:30 2020 +0800 [FLINK-15981][refactor] Refactor NettyMessage#write() to allow writing multiple message parts. Rather than having to produce one message (as a composite buffer) this now allows writing multiple partial messages. That way, we can combine different message types, like a memory buffer (for headers and events) with file region buffers (for direct file transfer). --- .../runtime/io/network/netty/NettyMessage.java | 224 +++++++++------------ .../NettyMessageClientDecoderDelegateTest.java | 36 ++-- .../runtime/io/network/netty/NettyTestUtil.java | 9 +- 3 files changed, 118 insertions(+), 151 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java index 3bf537a..057b219 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.io.network.netty; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; @@ -36,6 +37,7 @@ import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOutboundHandlerAdapter; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOutboundInvoker; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelPromise; import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder; @@ -46,6 +48,7 @@ import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.net.ProtocolException; import java.nio.ByteBuffer; +import java.util.function.Consumer; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -67,7 +70,7 @@ public abstract class NettyMessage { static final int MAGIC_NUMBER = 0xBADC0FFE; - abstract ByteBuf write(ByteBufAllocator allocator) throws Exception; + abstract void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) throws IOException; // ------------------------------------------------------------------------ @@ -165,22 +168,9 @@ public abstract class NettyMessage { static class NettyMessageEncoder extends ChannelOutboundHandlerAdapter { @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws IOException { if (msg instanceof NettyMessage) { - - ByteBuf serialized = null; - - try { - serialized = ((NettyMessage) msg).write(ctx.alloc()); - } - catch (Throwable t) { - throw new IOException("Error while serializing message: " + msg, t); - } - finally { - if (serialized != null) { - ctx.write(serialized, promise); - } - } + ((NettyMessage) msg).write(ctx, promise, ctx.alloc()); } else { ctx.write(msg, promise); @@ -337,21 +327,29 @@ public abstract class NettyMessage { // -------------------------------------------------------------------- @Override - ByteBuf write(ByteBufAllocator allocator) throws IOException { + void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) throws IOException { ByteBuf headerBuf = null; try { // in order to forward the buffer to netty, it needs an allocator set buffer.setAllocator(allocator); - // only allocate header buffer - we will combine it with the data buffer below - headerBuf = allocateBuffer(allocator, ID, MESSAGE_HEADER_LENGTH, bufferSize, false); + headerBuf = fillHeader(allocator); + out.write(headerBuf); + out.write(buffer, promise); + } + catch (Throwable t) { + handleException(headerBuf, buffer, t); + } + } + + @VisibleForTesting + ByteBuf write(ByteBufAllocator allocator) throws IOException { + ByteBuf headerBuf = null; + try { + // in order to forward the buffer to netty, it needs an allocator set + buffer.setAllocator(allocator); - receiverId.writeTo(headerBuf); - headerBuf.writeInt(sequenceNumber); - headerBuf.writeInt(backlog); - headerBuf.writeByte(dataType.ordinal()); - headerBuf.writeBoolean(isCompressed); - headerBuf.writeInt(buffer.readableBytes()); + headerBuf = fillHeader(allocator); CompositeByteBuf composityBuf = allocator.compositeDirectBuffer(); composityBuf.addComponent(headerBuf); @@ -361,16 +359,24 @@ public abstract class NettyMessage { return composityBuf; } catch (Throwable t) { - if (headerBuf != null) { - headerBuf.release(); - } - buffer.recycleBuffer(); - - ExceptionUtils.rethrowIOException(t); + handleException(headerBuf, buffer, t); return null; // silence the compiler } } + private ByteBuf fillHeader(ByteBufAllocator allocator) { + // only allocate header buffer - we will combine it with the data buffer below + ByteBuf headerBuf = allocateBuffer(allocator, ID, MESSAGE_HEADER_LENGTH, bufferSize, false); + + receiverId.writeTo(headerBuf); + headerBuf.writeInt(sequenceNumber); + headerBuf.writeInt(backlog); + headerBuf.writeByte(dataType.ordinal()); + headerBuf.writeBoolean(isCompressed); + headerBuf.writeInt(buffer.readableBytes()); + return headerBuf; + } + /** * Parses the message header part and composes a new BufferResponse with an empty data buffer. The * data buffer will be filled in later. @@ -437,7 +443,7 @@ public abstract class NettyMessage { } @Override - ByteBuf write(ByteBufAllocator allocator) throws IOException { + void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) throws IOException { final ByteBuf result = allocateBuffer(allocator, ID); try (ObjectOutputStream oos = new ObjectOutputStream(new ByteBufOutputStream(result))) { @@ -452,16 +458,10 @@ public abstract class NettyMessage { // Update frame length... result.setInt(0, result.readableBytes()); - return result; + out.write(result, promise); } catch (Throwable t) { - result.release(); - - if (t instanceof IOException) { - throw (IOException) t; - } else { - throw new IOException(t); - } + handleException(result, null, t); } } @@ -508,27 +508,16 @@ public abstract class NettyMessage { } @Override - ByteBuf write(ByteBufAllocator allocator) throws IOException { - ByteBuf result = null; - - try { - result = allocateBuffer(allocator, ID, 20 + 16 + 4 + 16 + 4); - - partitionId.getPartitionId().writeTo(result); - partitionId.getProducerId().writeTo(result); - result.writeInt(queueIndex); - receiverId.writeTo(result); - result.writeInt(credit); - - return result; - } - catch (Throwable t) { - if (result != null) { - result.release(); - } + void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) throws IOException { + Consumer<ByteBuf> consumer = (bb) -> { + partitionId.getPartitionId().writeTo(bb); + partitionId.getProducerId().writeTo(bb); + bb.writeInt(queueIndex); + receiverId.writeTo(bb); + bb.writeInt(credit); + }; - throw new IOException(t); - } + writeToChannel(out, promise, allocator, consumer, ID, 20 + 16 + 4 + 16 + 4); } static PartitionRequest readFrom(ByteBuf buffer) { @@ -566,32 +555,20 @@ public abstract class NettyMessage { } @Override - ByteBuf write(ByteBufAllocator allocator) throws IOException { - ByteBuf result = null; - - try { - // TODO Directly serialize to Netty's buffer - ByteBuffer serializedEvent = EventSerializer.toSerializedEvent(event); - - result = allocateBuffer(allocator, ID, 4 + serializedEvent.remaining() + 20 + 16 + 16); + void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) throws IOException { + // TODO Directly serialize to Netty's buffer + ByteBuffer serializedEvent = EventSerializer.toSerializedEvent(event); - result.writeInt(serializedEvent.remaining()); - result.writeBytes(serializedEvent); + Consumer<ByteBuf> consumer = (bb) -> { + bb.writeInt(serializedEvent.remaining()); + bb.writeBytes(serializedEvent); - partitionId.getPartitionId().writeTo(result); - partitionId.getProducerId().writeTo(result); + partitionId.getPartitionId().writeTo(bb); + partitionId.getProducerId().writeTo(bb); + receiverId.writeTo(bb); + }; - receiverId.writeTo(result); - - return result; - } - catch (Throwable t) { - if (result != null) { - result.release(); - } - - throw new IOException(t); - } + writeToChannel(out, promise, allocator, consumer, ID, 4 + serializedEvent.remaining() + 20 + 16 + 16); } static TaskEventRequest readFrom(ByteBuf buffer, ClassLoader classLoader) throws IOException { @@ -633,22 +610,8 @@ public abstract class NettyMessage { } @Override - ByteBuf write(ByteBufAllocator allocator) throws Exception { - ByteBuf result = null; - - try { - result = allocateBuffer(allocator, ID, 16); - receiverId.writeTo(result); - } - catch (Throwable t) { - if (result != null) { - result.release(); - } - - throw new IOException(t); - } - - return result; + void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) throws IOException { + writeToChannel(out, promise, allocator, receiverId :: writeTo, ID, 16); } static CancelPartitionRequest readFrom(ByteBuf buffer) throws Exception { @@ -664,8 +627,8 @@ public abstract class NettyMessage { } @Override - ByteBuf write(ByteBufAllocator allocator) throws Exception { - return allocateBuffer(allocator, ID, 0); + void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) throws IOException { + writeToChannel(out, promise, allocator, ignored -> {}, ID, 0); } static CloseRequest readFrom(@SuppressWarnings("unused") ByteBuf buffer) throws Exception { @@ -691,7 +654,7 @@ public abstract class NettyMessage { } @Override - ByteBuf write(ByteBufAllocator allocator) throws IOException { + void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) throws IOException { ByteBuf result = null; try { @@ -699,14 +662,10 @@ public abstract class NettyMessage { result.writeInt(credit); receiverId.writeTo(result); - return result; + out.write(result, promise); } catch (Throwable t) { - if (result != null) { - result.release(); - } - - throw new IOException(t); + handleException(result, null, t); } } @@ -737,22 +696,8 @@ public abstract class NettyMessage { } @Override - ByteBuf write(ByteBufAllocator allocator) throws IOException { - ByteBuf result = null; - - try { - result = allocateBuffer(allocator, ID, 16); - receiverId.writeTo(result); - - return result; - } - catch (Throwable t) { - if (result != null) { - result.release(); - } - - throw new IOException(t); - } + void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator allocator) throws IOException { + writeToChannel(out, promise, allocator, receiverId :: writeTo, ID, 16); } static ResumeConsumption readFrom(ByteBuf buffer) { @@ -764,4 +709,35 @@ public abstract class NettyMessage { return String.format("ResumeConsumption(%s)", receiverId); } } + + // ------------------------------------------------------------------------ + + void writeToChannel( + ChannelOutboundInvoker out, + ChannelPromise promise, + ByteBufAllocator allocator, + Consumer<ByteBuf> consumer, + byte id, + int length) throws IOException { + + ByteBuf byteBuf = null; + try { + byteBuf = allocateBuffer(allocator, id, length); + consumer.accept(byteBuf); + out.write(byteBuf, promise); + } + catch (Throwable t) { + handleException(byteBuf, null, t); + } + } + + void handleException(@Nullable ByteBuf byteBuf, @Nullable Buffer buffer, Throwable t) throws IOException { + if (byteBuf != null) { + byteBuf.release(); + } + if (buffer != null) { + buffer.recycleBuffer(); + } + ExceptionUtils.rethrowIOException(t); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientDecoderDelegateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientDecoderDelegateTest.java index 6613328..7954258 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientDecoderDelegateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientDecoderDelegateTest.java @@ -48,13 +48,10 @@ import java.util.List; import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.assertTrue; import static org.apache.flink.runtime.io.network.netty.NettyMessage.BufferResponse; -import static org.apache.flink.runtime.io.network.netty.NettyMessage.ErrorResponse; import static org.apache.flink.runtime.io.network.netty.NettyTestUtil.verifyBufferResponseHeader; -import static org.apache.flink.runtime.io.network.netty.NettyTestUtil.verifyErrorResponse; import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createRemoteInputChannel; import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate; import static org.junit.Assert.assertNull; -import static org.junit.Assert.fail; /** * Tests the client side message decoder. @@ -163,7 +160,7 @@ public class NettyMessageClientDecoderDelegateTest extends TestLogger { ByteBuf[] encodedMessages = null; List<NettyMessage> decodedMessages = null; try { - List<NettyMessage> messages = createMessageList( + List<BufferResponse> messages = createMessageList( hasEmptyBuffer, hasBufferForReleasedChannel, hasBufferForRemovedChannel); @@ -186,13 +183,13 @@ public class NettyMessageClientDecoderDelegateTest extends TestLogger { } } - private List<NettyMessage> createMessageList( + private List<BufferResponse> createMessageList( boolean hasEmptyBuffer, boolean hasBufferForRemovedChannel, boolean hasBufferForReleasedChannel) { int seqNumber = 1; - List<NettyMessage> messages = new ArrayList<>(); + List<BufferResponse> messages = new ArrayList<>(); for (int i = 0; i < NUMBER_OF_BUFFER_RESPONSES - 1; i++) { addBufferResponse(messages, inputChannelId, Buffer.DataType.DATA_BUFFER, BUFFER_SIZE, seqNumber++); @@ -210,13 +207,12 @@ public class NettyMessageClientDecoderDelegateTest extends TestLogger { addBufferResponse(messages, inputChannelId, Buffer.DataType.EVENT_BUFFER, 32, seqNumber++); addBufferResponse(messages, inputChannelId, Buffer.DataType.DATA_BUFFER, BUFFER_SIZE, seqNumber); - messages.add(new NettyMessage.ErrorResponse(new RuntimeException("test"), inputChannelId)); return messages; } private void addBufferResponse( - List<NettyMessage> messages, + List<BufferResponse> messages, InputChannelID inputChannelId, Buffer.DataType dataType, int bufferSize, @@ -236,7 +232,7 @@ public class NettyMessageClientDecoderDelegateTest extends TestLogger { return buffer; } - private ByteBuf[] encodeMessages(List<NettyMessage> messages) throws Exception { + private ByteBuf[] encodeMessages(List<BufferResponse> messages) throws Exception { ByteBuf[] encodedMessages = new ByteBuf[messages.size()]; for (int i = 0; i < messages.size(); ++i) { encodedMessages[i] = messages.get(i).write(ALLOCATOR); @@ -315,26 +311,18 @@ public class NettyMessageClientDecoderDelegateTest extends TestLogger { return decodedMessages; } - private void verifyDecodedMessages(List<NettyMessage> expectedMessages, List<NettyMessage> decodedMessages) { + private void verifyDecodedMessages(List<BufferResponse> expectedMessages, List<NettyMessage> decodedMessages) { assertEquals(expectedMessages.size(), decodedMessages.size()); for (int i = 0; i < expectedMessages.size(); ++i) { assertEquals(expectedMessages.get(i).getClass(), decodedMessages.get(i).getClass()); - if (expectedMessages.get(i) instanceof NettyMessage.BufferResponse) { - BufferResponse expected = (BufferResponse) expectedMessages.get(i); - BufferResponse actual = (BufferResponse) decodedMessages.get(i); - - verifyBufferResponseHeader(expected, actual); - if (expected.bufferSize == 0 || !expected.receiverId.equals(inputChannelId)) { - assertNull(actual.getBuffer()); - } else { - assertEquals(expected.getBuffer(), actual.getBuffer()); - } - - } else if (expectedMessages.get(i) instanceof NettyMessage.ErrorResponse) { - verifyErrorResponse((ErrorResponse) expectedMessages.get(i), (ErrorResponse) decodedMessages.get(i)); + BufferResponse expected = expectedMessages.get(i); + BufferResponse actual = (BufferResponse) decodedMessages.get(i); + verifyBufferResponseHeader(expected, actual); + if (expected.bufferSize == 0 || !expected.receiverId.equals(inputChannelId)) { + assertNull(actual.getBuffer()); } else { - fail("Unsupported message type"); + assertEquals(expected.getBuffer(), actual.getBuffer()); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyTestUtil.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyTestUtil.java index c84bcf8..afa3127 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyTestUtil.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyTestUtil.java @@ -175,9 +175,12 @@ public class NettyTestUtil { static <T extends NettyMessage> T encodeAndDecode(T msg, EmbeddedChannel channel) { channel.writeOutbound(msg); - ByteBuf encoded = channel.readOutbound(); - - assertTrue(channel.writeInbound(encoded)); + ByteBuf encoded; + boolean msgNotEmpty = false; + while ((encoded = channel.readOutbound()) != null) { + msgNotEmpty = channel.writeInbound(encoded); + } + assertTrue(msgNotEmpty); return channel.readInbound(); }
