gaoyunhaii closed pull request #7367: [Flink-10742][network] Let Netty use
Flink's buffers directly in credit-based mode
URL: https://github.com/apache/flink/pull/7367
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/docs/_includes/generated/netty_configuration.html
b/docs/_includes/generated/netty_configuration.html
index 47c48c0aa38..1f9aac4cd59 100644
--- a/docs/_includes/generated/netty_configuration.html
+++ b/docs/_includes/generated/netty_configuration.html
@@ -22,6 +22,11 @@
<td style="word-wrap: break-word;">-1</td>
<td>The number of Netty arenas.</td>
</tr>
+ <tr>
+ <td><h5>taskmanager.network.netty.max-order</h5></td>
+ <td style="word-wrap: break-word;">9</td>
+ <td>The power of 2 of the number of pages in each chunk.</td>
+ </tr>
<tr>
<td><h5>taskmanager.network.netty.sendReceiveBufferSize</h5></td>
<td style="word-wrap: break-word;">0</td>
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/BufferResponseAndNoDataBufferMessageParser.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/BufferResponseAndNoDataBufferMessageParser.java
new file mode 100644
index 00000000000..c2be630a719
--- /dev/null
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/BufferResponseAndNoDataBufferMessageParser.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.netty;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+/**
+ * The parser for both {@link NettyMessage.BufferResponse} and messages with
no data buffer part.
+ */
+public class BufferResponseAndNoDataBufferMessageParser extends
NoDataBufferMessageParser {
+ /** The allocator for the flink buffer. */
+ private final NetworkBufferAllocator networkBufferAllocator;
+
+ BufferResponseAndNoDataBufferMessageParser(NetworkBufferAllocator
networkBufferAllocator) {
+ this.networkBufferAllocator = networkBufferAllocator;
+ }
+
+ @Override
+ public int getMessageHeaderLength(int lengthWithoutFrameHeader, int
msgId) {
+ if (msgId == NettyMessage.BufferResponse.ID) {
+ return
NettyMessage.BufferResponse.MESSAGE_HEADER_LENGTH;
+ } else {
+ return
super.getMessageHeaderLength(lengthWithoutFrameHeader, msgId);
+ }
+ }
+
+ @Override
+ public MessageHeaderParseResult parseMessageHeader(int msgId, ByteBuf
messageHeader) throws Exception {
+ if (msgId == NettyMessage.BufferResponse.ID) {
+ NettyMessage.BufferResponse<Buffer> bufferResponse =
+
NettyMessage.BufferResponse.readFrom(messageHeader, networkBufferAllocator);
+
+ if (bufferResponse.getBuffer() == null) {
+ return
MessageHeaderParseResult.discardDataBuffer(bufferResponse,
bufferResponse.getBufferSize());
+ } else {
+ return
MessageHeaderParseResult.receiveDataBuffer(bufferResponse,
bufferResponse.getBufferSize(), bufferResponse.getBuffer());
+ }
+ } else {
+ return super.parseMessageHeader(msgId, messageHeader);
+ }
+ }
+}
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandler.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandler.java
index cc0b2220fd2..abf1d8b928e 100644
---
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandler.java
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandler.java
@@ -18,12 +18,8 @@
package org.apache.flink.runtime.io.network.netty;
-import org.apache.flink.core.memory.MemorySegment;
-import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.NetworkClientHandler;
import org.apache.flink.runtime.io.network.buffer.Buffer;
-import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
-import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import
org.apache.flink.runtime.io.network.netty.exception.LocalTransportException;
import
org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException;
import org.apache.flink.runtime.io.network.netty.exception.TransportException;
@@ -32,7 +28,6 @@
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
import
org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
-import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener;
@@ -108,6 +103,10 @@ public void cancelRequestFor(InputChannelID
inputChannelId) {
}
}
+ RemoteInputChannel getInputChannel(InputChannelID inputChannelId) {
+ return inputChannels.get(inputChannelId);
+ }
+
@Override
public void notifyCreditAvailable(final RemoteInputChannel
inputChannel) {
ctx.executor().execute(() ->
ctx.pipeline().fireUserEventTriggered(inputChannel));
@@ -241,12 +240,13 @@ private void checkError() throws IOException {
}
}
+ @SuppressWarnings("unchecked")
private void decodeMsg(Object msg) throws Throwable {
final Class<?> msgClazz = msg.getClass();
// ---- Buffer
--------------------------------------------------------
if (msgClazz == NettyMessage.BufferResponse.class) {
- NettyMessage.BufferResponse bufferOrEvent =
(NettyMessage.BufferResponse) msg;
+ NettyMessage.BufferResponse<Buffer> bufferOrEvent =
(NettyMessage.BufferResponse<Buffer>) msg;
RemoteInputChannel inputChannel =
inputChannels.get(bufferOrEvent.receiverId);
if (inputChannel == null) {
@@ -289,43 +289,21 @@ private void decodeMsg(Object msg) throws Throwable {
}
}
- private void decodeBufferOrEvent(RemoteInputChannel inputChannel,
NettyMessage.BufferResponse bufferOrEvent) throws Throwable {
- try {
- ByteBuf nettyBuffer = bufferOrEvent.getNettyBuffer();
- final int receivedSize = nettyBuffer.readableBytes();
- if (bufferOrEvent.isBuffer()) {
- // ---- Buffer
------------------------------------------------
-
- // Early return for empty buffers. Otherwise
Netty's readBytes() throws an
- // IndexOutOfBoundsException.
- if (receivedSize == 0) {
-
inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
- return;
- }
-
- Buffer buffer = inputChannel.requestBuffer();
- if (buffer != null) {
-
nettyBuffer.readBytes(buffer.asByteBuf(), receivedSize);
-
- inputChannel.onBuffer(buffer,
bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
- } else if (inputChannel.isReleased()) {
-
cancelRequestFor(bufferOrEvent.receiverId);
- } else {
- throw new IllegalStateException("No
buffer available in credit-based input channel.");
- }
- } else {
- // ---- Event
-------------------------------------------------
- // TODO We can just keep the serialized data in
the Netty buffer and release it later at the reader
- byte[] byteArray = new byte[receivedSize];
- nettyBuffer.readBytes(byteArray);
+ private void decodeBufferOrEvent(RemoteInputChannel inputChannel,
NettyMessage.BufferResponse<Buffer> bufferOrEvent) throws Throwable {
+ // Early return for empty buffers.
+ if (bufferOrEvent.isBuffer() && bufferOrEvent.getBufferSize()
== 0) {
+
inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
+ return;
+ }
- MemorySegment memSeg =
MemorySegmentFactory.wrap(byteArray);
- Buffer buffer = new NetworkBuffer(memSeg,
FreeingBufferRecycler.INSTANCE, false, receivedSize);
+ Buffer dataBuffer = bufferOrEvent.getBuffer();
- inputChannel.onBuffer(buffer,
bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
- }
- } finally {
- bufferOrEvent.releaseBuffer();
+ if (dataBuffer != null) {
+ inputChannel.onBuffer(dataBuffer,
bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
+ } else if (inputChannel.isReleased()) {
+ cancelRequestFor(bufferOrEvent.receiverId);
+ } else {
+ throw new IllegalStateException("The read buffer is
null in credit-based input channel.");
}
}
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyBufferPool.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyBufferPool.java
index 6d2a6c88287..24e62bc36fc 100644
---
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyBufferPool.java
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyBufferPool.java
@@ -52,28 +52,24 @@
private static final boolean PREFER_DIRECT = true;
/**
- * Arenas allocate chunks of pageSize << maxOrder bytes. With these
defaults, this results in
- * chunks of 16 MB.
- *
- * @see #MAX_ORDER
- */
- private static final int PAGE_SIZE = 8192;
-
- /**
- * Arenas allocate chunks of pageSize << maxOrder bytes. With these
defaults, this results in
- * chunks of 16 MB.
+ * Creates Netty's buffer pool with the specified number of direct
arenas.
*
- * @see #PAGE_SIZE
+ * @param numberOfArenas Number of arenas (recommended: 2 * number of
task
+ * slots)
*/
- private static final int MAX_ORDER = 11;
+ public NettyBufferPool(int numberOfArenas) {
+ this(numberOfArenas, NettyConfig.MAX_ORDER.defaultValue(),
NettyConfig.DEFAULT_PAGE_SIZE);
+ }
/**
* Creates Netty's buffer pool with the specified number of direct
arenas.
*
* @param numberOfArenas Number of arenas (recommended: 2 * number of
task
* slots)
+ * @param maxOrder Max order for the netty buffer pool, defaults to 9.
+ * @param pageSize Page size for the netty buffer pool, defaults to 8k.
*/
- public NettyBufferPool(int numberOfArenas) {
+ public NettyBufferPool(int numberOfArenas, int maxOrder, int pageSize) {
super(
PREFER_DIRECT,
// No heap arenas, please.
@@ -85,8 +81,8 @@ public NettyBufferPool(int numberOfArenas) {
// control the memory allocations with low/high
watermarks when writing
// to the TCP channels. Chunks are allocated lazily.
numberOfArenas,
- PAGE_SIZE,
- MAX_ORDER);
+ pageSize,
+ maxOrder);
checkArgument(numberOfArenas >= 1, "Number of arenas");
this.numberOfArenas = numberOfArenas;
@@ -94,7 +90,7 @@ public NettyBufferPool(int numberOfArenas) {
// Arenas allocate chunks of pageSize << maxOrder bytes. With
these
// defaults, this results in chunks of 16 MB.
- this.chunkSize = PAGE_SIZE << MAX_ORDER;
+ this.chunkSize = pageSize << maxOrder;
Object[] allocDirectArenas = null;
try {
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConfig.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConfig.java
index 4694c69385c..5706ff2e9d5 100644
---
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConfig.java
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConfig.java
@@ -24,6 +24,7 @@
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.runtime.net.SSLUtils;
+import org.apache.flink.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -46,6 +47,15 @@
.withDeprecatedKeys("taskmanager.net.num-arenas")
.withDescription("The number of Netty arenas.");
+ /**
+ * Arenas allocate chunks of pageSize << maxOrder bytes. With these
defaults, this results in
+ * chunks of 4 MB.
+ */
+ public static final ConfigOption<Integer> MAX_ORDER = ConfigOptions
+ .key("taskmanager.network.netty.max-order")
+ .defaultValue(9)
+ .withDescription("The power of 2 of the number of pages
in each chunk.");
+
public static final ConfigOption<Integer> NUM_THREADS_SERVER =
ConfigOptions
.key("taskmanager.network.netty.server.numThreads")
.defaultValue(-1)
@@ -85,6 +95,14 @@
//
------------------------------------------------------------------------
+ /**
+ * Arenas allocate chunks of pageSize << maxOrder bytes. With these
defaults, this results in
+ * chunks of 4 MB.
+ *
+ * @see #MAX_ORDER
+ */
+ static final int DEFAULT_PAGE_SIZE = 8192;
+
enum TransportType {
NIO, EPOLL, AUTO
}
@@ -212,6 +230,20 @@ public boolean isCreditBasedEnabled() {
return
config.getBoolean(TaskManagerOptions.NETWORK_CREDIT_MODEL);
}
+ public int getPageSize() {
+ return DEFAULT_PAGE_SIZE;
+ }
+
+ public int getMaxOrder() {
+ int maxOrder = config.getInteger(MAX_ORDER);
+
+ // Ensure the chunk size is not too small to fulfill the
requirements of a single thread.
+ // We require the chunk size to be at least 1 MB based on
experiment results.
+ int minimumMaxOrder = 20 - MathUtils.log2strict(getPageSize());
+
+ return Math.max(minimumMaxOrder, maxOrder);
+ }
+
public Configuration getConfig() {
return config;
}
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConnectionManager.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConnectionManager.java
index 3fe15e5c898..8756574de05 100644
---
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConnectionManager.java
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConnectionManager.java
@@ -38,7 +38,7 @@
public NettyConnectionManager(NettyConfig nettyConfig) {
this.server = new NettyServer(nettyConfig);
this.client = new NettyClient(nettyConfig);
- this.bufferPool = new
NettyBufferPool(nettyConfig.getNumberOfArenas());
+ this.bufferPool = new
NettyBufferPool(nettyConfig.getNumberOfArenas(), nettyConfig.getMaxOrder(),
nettyConfig.getPageSize());
this.partitionRequestClientFactory = new
PartitionRequestClientFactory(client);
}
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
index ca0de6b10dc..3ce650ff5a0 100644
---
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
@@ -44,7 +44,6 @@
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
-import java.net.ProtocolException;
import java.nio.ByteBuffer;
import static org.apache.flink.util.Preconditions.checkArgument;
@@ -203,6 +202,7 @@ public void write(ChannelHandlerContext ctx, Object msg,
ChannelPromise promise)
*/
static class NettyMessageDecoder extends LengthFieldBasedFrameDecoder {
private final boolean restoreOldNettyBehaviour;
+ private final NoDataBufferMessageParser
noDataBufferMessageParser;
/**
* Creates a new message decoded with the required frame
properties.
@@ -214,6 +214,7 @@ public void write(ChannelHandlerContext ctx, Object msg,
ChannelPromise promise)
NettyMessageDecoder(boolean restoreOldNettyBehaviour) {
super(Integer.MAX_VALUE, 0, 4, -4, 4);
this.restoreOldNettyBehaviour =
restoreOldNettyBehaviour;
+ this.noDataBufferMessageParser = new
NoDataBufferMessageParser();
}
@Override
@@ -234,31 +235,11 @@ protected Object decode(ChannelHandlerContext ctx,
ByteBuf in) throws Exception
byte msgId = msg.readByte();
final NettyMessage decodedMsg;
- switch (msgId) {
- case BufferResponse.ID:
- decodedMsg =
BufferResponse.readFrom(msg);
- break;
- case PartitionRequest.ID:
- decodedMsg =
PartitionRequest.readFrom(msg);
- break;
- case TaskEventRequest.ID:
- decodedMsg =
TaskEventRequest.readFrom(msg, getClass().getClassLoader());
- break;
- case ErrorResponse.ID:
- decodedMsg =
ErrorResponse.readFrom(msg);
- break;
- case CancelPartitionRequest.ID:
- decodedMsg =
CancelPartitionRequest.readFrom(msg);
- break;
- case CloseRequest.ID:
- decodedMsg =
CloseRequest.readFrom(msg);
- break;
- case AddCredit.ID:
- decodedMsg =
AddCredit.readFrom(msg);
- break;
- default:
- throw new ProtocolException(
- "Received unknown
message from producer: " + msg);
+
+ if (msgId == BufferResponse.ID) {
+ decodedMsg =
BufferResponse.readFrom(msg);
+ } else {
+ decodedMsg =
noDataBufferMessageParser.parseMessageHeader(msgId, msg).getParsedMessage();
}
return decodedMsg;
@@ -293,15 +274,114 @@ protected ByteBuf extractFrame(ChannelHandlerContext
ctx, ByteBuf buffer, int in
}
}
+ /**
+ * Holds a buffer object that can be viewed as a netty {@link ByteBuf}.
+ */
+ interface BufferHolder<T> {
+ /**
+ * Gets the original buffer object.
+ *
+ * @return the original buffer object.
+ */
+ T getBuffer();
+
+ /**
+ * Gets the view that casts the buffer object as a netty
ByteBuf.
+ *
+ * @return the view of the buffer object as a netty ByteBuf.
+ */
+ ByteBuf asByteBuf();
+
+ /**
+ * Notifies the holder that the buffer object is about to be written.
+ *
+ * @param allocator the ByteBuf allocator of current netty
pipeline.
+ */
+ void onWrite(ByteBufAllocator allocator);
+
+ /**
+ * Releases the underlying buffer object.
+ */
+ void release();
+ }
+
+ /**
+ * The buffer holder to hold a netty {@link ByteBuf}.
+ */
+ static class NettyBufferHolder implements BufferHolder<ByteBuf> {
+ private ByteBuf byteBuf;
+
+ NettyBufferHolder(@Nullable ByteBuf byteBuf) {
+ this.byteBuf = byteBuf;
+ }
+
+ @Override
+ @Nullable
+ public ByteBuf getBuffer() {
+ return byteBuf;
+ }
+
+ @Override
+ public ByteBuf asByteBuf() {
+ return byteBuf;
+ }
+
+ @Override
+ public void onWrite(ByteBufAllocator allocator) {
+ // No operations.
+ }
+
+ @Override
+ public void release() {
+ byteBuf.release();
+ }
+ }
+
+ /**
+ * The buffer holder to hold a flink {@link Buffer}.
+ */
+ static class FlinkBufferHolder implements BufferHolder<Buffer> {
+ private Buffer buffer;
+
+ FlinkBufferHolder(@Nullable Buffer buffer) {
+ this.buffer = buffer;
+ }
+
+ @Override
+ @Nullable
+ public Buffer getBuffer() {
+ return buffer;
+ }
+
+ @Override
+ public ByteBuf asByteBuf() {
+ return buffer == null ? null : buffer.asByteBuf();
+ }
+
+ @Override
+ public void onWrite(ByteBufAllocator allocator) {
+ // in order to forward the buffer to netty, it needs an
allocator set
+ buffer.setAllocator(allocator);
+ }
+
+ @Override
+ public void release() {
+ buffer.recycleBuffer();
+ }
+ }
+
//
------------------------------------------------------------------------
// Server responses
//
------------------------------------------------------------------------
- static class BufferResponse extends NettyMessage {
+ static class BufferResponse<T> extends NettyMessage {
- private static final byte ID = 0;
+ static final byte ID = 0;
- final ByteBuf buffer;
+ // receiver ID (16), sequence number (4), backlog (4), isBuffer
(1), buffer size (4)
+ static final int MESSAGE_HEADER_LENGTH = 16 + 4 + 4 + 1 + 4;
+
+ final BufferHolder<T> bufferHolder;
final InputChannelID receiverId;
@@ -311,41 +391,58 @@ protected ByteBuf extractFrame(ChannelHandlerContext ctx,
ByteBuf buffer, int in
final boolean isBuffer;
- private BufferResponse(
- ByteBuf buffer,
+ final int bufferSize;
+
+ BufferResponse(
+ BufferHolder<T> bufferHolder,
boolean isBuffer,
int sequenceNumber,
InputChannelID receiverId,
int backlog) {
- this.buffer = checkNotNull(buffer);
+ this.bufferHolder = checkNotNull(bufferHolder);
+ checkNotNull(bufferHolder.getBuffer());
+
this.isBuffer = isBuffer;
this.sequenceNumber = sequenceNumber;
this.receiverId = checkNotNull(receiverId);
this.backlog = backlog;
+ this.bufferSize =
bufferHolder.asByteBuf().readableBytes();
}
- BufferResponse(
- Buffer buffer,
+ private BufferResponse(
+ BufferHolder<T> bufferHolder,
+ boolean isBuffer,
int sequenceNumber,
InputChannelID receiverId,
- int backlog) {
- this.buffer = checkNotNull(buffer).asByteBuf();
- this.isBuffer = buffer.isBuffer();
+ int backlog,
+ int bufferSize) {
+ this.bufferHolder = checkNotNull(bufferHolder);
+
+ this.isBuffer = isBuffer;
this.sequenceNumber = sequenceNumber;
this.receiverId = checkNotNull(receiverId);
this.backlog = backlog;
+ this.bufferSize = bufferSize;
}
boolean isBuffer() {
return isBuffer;
}
- ByteBuf getNettyBuffer() {
- return buffer;
+ T getBuffer() {
+ return bufferHolder.getBuffer();
+ }
+
+ ByteBuf asByteBuf() {
+ return bufferHolder.asByteBuf();
+ }
+
+ public int getBufferSize() {
+ return bufferSize;
}
void releaseBuffer() {
- buffer.release();
+ bufferHolder.release();
}
//
--------------------------------------------------------------------
@@ -354,44 +451,45 @@ void releaseBuffer() {
@Override
ByteBuf write(ByteBufAllocator allocator) throws IOException {
- // receiver ID (16), sequence number (4), backlog (4),
isBuffer (1), buffer size (4)
- final int messageHeaderLength = 16 + 4 + 4 + 1 + 4;
-
ByteBuf headerBuf = null;
try {
- if (buffer instanceof Buffer) {
- // in order to forward the buffer to
netty, it needs an allocator set
- ((Buffer)
buffer).setAllocator(allocator);
- }
+ bufferHolder.onWrite(allocator);
// only allocate header buffer - we will
combine it with the data buffer below
- headerBuf = allocateBuffer(allocator, ID,
messageHeaderLength, buffer.readableBytes(), false);
+ headerBuf = allocateBuffer(allocator, ID,
MESSAGE_HEADER_LENGTH, bufferSize, false);
receiverId.writeTo(headerBuf);
headerBuf.writeInt(sequenceNumber);
headerBuf.writeInt(backlog);
headerBuf.writeBoolean(isBuffer);
- headerBuf.writeInt(buffer.readableBytes());
+ headerBuf.writeInt(bufferSize);
CompositeByteBuf composityBuf =
allocator.compositeDirectBuffer();
composityBuf.addComponent(headerBuf);
- composityBuf.addComponent(buffer);
+
composityBuf.addComponent(bufferHolder.asByteBuf());
// update writer index since we have data
written to the components:
-
composityBuf.writerIndex(headerBuf.writerIndex() + buffer.writerIndex());
+
composityBuf.writerIndex(headerBuf.writerIndex() +
bufferHolder.asByteBuf().writerIndex());
return composityBuf;
}
catch (Throwable t) {
if (headerBuf != null) {
headerBuf.release();
}
- buffer.release();
+ bufferHolder.release();
ExceptionUtils.rethrowIOException(t);
return null; // silence the compiler
}
}
- static BufferResponse readFrom(ByteBuf buffer) {
+ /**
+ * Parses the whole BufferResponse message and composes a new
BufferResponse with both the header parsed and
+ * the data buffer filled in. This method is used in the
non-credit-based network stack.
+ *
+ * @param buffer the whole serialized BufferResponse message.
+ * @return a BufferResponse object with the header parsed and
the data buffer filled in.
+ */
+ static BufferResponse<ByteBuf> readFrom(ByteBuf buffer) {
InputChannelID receiverId =
InputChannelID.fromByteBuf(buffer);
int sequenceNumber = buffer.readInt();
int backlog = buffer.readInt();
@@ -399,13 +497,52 @@ static BufferResponse readFrom(ByteBuf buffer) {
int size = buffer.readInt();
ByteBuf retainedSlice = buffer.readSlice(size).retain();
- return new BufferResponse(retainedSlice, isBuffer,
sequenceNumber, receiverId, backlog);
+ return new BufferResponse<>(
+ new NettyBufferHolder(retainedSlice),
+ isBuffer,
+ sequenceNumber,
+ receiverId,
+ backlog,
+ size);
+ }
+
+ /**
+ * Parses the message header part and composes a new
BufferResponse with an empty data buffer. The
+ * data buffer will be filled in later. This method is used in
the credit-based network stack.
+ *
+ * @param messageHeader the serialized message header.
+ * @param bufferAllocator the allocator for network buffer.
+ * @return a BufferResponse object with the header parsed and
the data buffer to fill in later.
+ */
+ static BufferResponse<Buffer> readFrom(ByteBuf messageHeader,
NetworkBufferAllocator bufferAllocator) {
+ InputChannelID receiverId =
InputChannelID.fromByteBuf(messageHeader);
+ int sequenceNumber = messageHeader.readInt();
+ int backlog = messageHeader.readInt();
+ boolean isBuffer = messageHeader.readBoolean();
+ int size = messageHeader.readInt();
+
+ Buffer dataBuffer = null;
+ if (size != 0) {
+ if (isBuffer) {
+ dataBuffer =
bufferAllocator.allocatePooledNetworkBuffer(receiverId, size);
+ } else {
+ dataBuffer =
bufferAllocator.allocateUnPooledNetworkBuffer(size);
+ }
+ }
+
+ return new BufferResponse<>(
+ new FlinkBufferHolder(dataBuffer),
+ isBuffer,
+ sequenceNumber,
+ receiverId,
+ backlog,
+ size);
}
}
static class ErrorResponse extends NettyMessage {
- private static final byte ID = 1;
+ static final byte ID = 1;
final Throwable cause;
@@ -480,7 +617,7 @@ static ErrorResponse readFrom(ByteBuf buffer) throws
Exception {
static class PartitionRequest extends NettyMessage {
- private static final byte ID = 2;
+ static final byte ID = 2;
final ResultPartitionID partitionId;
@@ -541,7 +678,7 @@ public String toString() {
static class TaskEventRequest extends NettyMessage {
- private static final byte ID = 3;
+ static final byte ID = 3;
final TaskEvent event;
@@ -614,7 +751,7 @@ static TaskEventRequest readFrom(ByteBuf buffer,
ClassLoader classLoader) throws
*/
static class CancelPartitionRequest extends NettyMessage {
- private static final byte ID = 4;
+ static final byte ID = 4;
final InputChannelID receiverId;
@@ -648,7 +785,7 @@ static CancelPartitionRequest readFrom(ByteBuf buffer)
throws Exception {
static class CloseRequest extends NettyMessage {
- private static final byte ID = 5;
+ static final byte ID = 5;
CloseRequest() {
}
@@ -668,7 +805,7 @@ static CloseRequest readFrom(@SuppressWarnings("unused")
ByteBuf buffer) throws
*/
static class AddCredit extends NettyMessage {
- private static final byte ID = 6;
+ static final byte ID = 6;
final ResultPartitionID partitionId;
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessageParser.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessageParser.java
new file mode 100644
index 00000000000..1b3e4e51525
--- /dev/null
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessageParser.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.netty;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+/**
+ * Responsible for offering the message header length and parsing the message
header part of the message.
+ */
+public interface NettyMessageParser {
+ /**
+ * Indicates how to deal with the data buffer part of the current
message.
+ */
+ enum DataBufferAction {
+ /** The current message does not have the data buffer part. */
+ NO_DATA_BUFFER,
+
+ /** The current message has a data buffer part and this part
needs to be received. */
+ RECEIVE,
+
+ /**
+ * The current message has a data buffer part but it needs to be
discarded. For example,
+ * when the target input channel has been closed.
+ */
+ DISCARD
+ }
+
+ /**
+ * Indicates the result of parsing the message header.
+ */
+ class MessageHeaderParseResult {
+ /** The message object parsed from the buffer. */
+ private NettyMessage parsedMessage;
+
+ /** The target action of how to process the data buffer part. */
+ private DataBufferAction dataBufferAction;
+
+ /** The size of the data buffer part. */
+ private int dataBufferSize;
+
+ /** The target data buffer for receiving the data. */
+ private Buffer targetDataBuffer;
+
+ private MessageHeaderParseResult(
+ NettyMessage parsedMessage,
+ DataBufferAction dataBufferAction,
+ int dataBufferSize,
+ Buffer targetDataBuffer) {
+ this.parsedMessage = parsedMessage;
+ this.dataBufferAction = dataBufferAction;
+ this.dataBufferSize = dataBufferSize;
+ this.targetDataBuffer = targetDataBuffer;
+ }
+
+ static MessageHeaderParseResult noDataBuffer(NettyMessage
message) {
+ return new MessageHeaderParseResult(message,
DataBufferAction.NO_DATA_BUFFER, 0, null);
+ }
+
+ static MessageHeaderParseResult receiveDataBuffer(NettyMessage
message, int dataBufferSize, Buffer targetDataBuffer) {
+ return new MessageHeaderParseResult(message,
DataBufferAction.RECEIVE, dataBufferSize, targetDataBuffer);
+ }
+
+ static MessageHeaderParseResult discardDataBuffer(NettyMessage
message, int dataBufferSize) {
+ return new MessageHeaderParseResult(message,
DataBufferAction.DISCARD, dataBufferSize, null);
+ }
+
+ NettyMessage getParsedMessage() {
+ return parsedMessage;
+ }
+
+ DataBufferAction getDataBufferAction() {
+ return dataBufferAction;
+ }
+
+ int getDataBufferSize() {
+ return dataBufferSize;
+ }
+
+ Buffer getTargetDataBuffer() {
+ return targetDataBuffer;
+ }
+ }
+
+ /**
+ * Get the length of the message header part.
+ *
+ * @param lengthWithoutFrameHeader the message length not counting the
frame header part.
+ * @param msgId the id of the current message.
+ * @return the length of the message header part.
+ */
+ int getMessageHeaderLength(int lengthWithoutFrameHeader, int msgId);
+
+ /**
+ * Parse the message header.
+ *
+ * @param msgId the id of the current message.
+ * @param messageHeader the buffer containing the serialized message
header.
+ * @return the parsed result.
+ */
+ MessageHeaderParseResult parseMessageHeader(int msgId, ByteBuf
messageHeader) throws Exception;
+}
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyProtocol.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyProtocol.java
index ebad11bdb39..d8df20c19c6 100644
---
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyProtocol.java
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyProtocol.java
@@ -18,7 +18,6 @@
package org.apache.flink.runtime.io.network.netty;
-import org.apache.flink.runtime.io.network.NetworkClientHandler;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.partition.ResultPartitionProvider;
@@ -83,7 +82,7 @@
return new ChannelHandler[] {
messageEncoder,
- new
NettyMessage.NettyMessageDecoder(!creditBasedEnabled),
+ new ZeroCopyNettyMessageDecoder(new
NoDataBufferMessageParser()),
serverHandler,
queueOfPartitionQueues
};
@@ -122,13 +121,21 @@
* @return channel handlers
*/
public ChannelHandler[] getClientChannelHandlers() {
- NetworkClientHandler networkClientHandler =
- creditBasedEnabled ? new
CreditBasedPartitionRequestClientHandler() :
- new PartitionRequestClientHandler();
- return new ChannelHandler[] {
- messageEncoder,
- new
NettyMessage.NettyMessageDecoder(!creditBasedEnabled),
- networkClientHandler};
- }
+ if (creditBasedEnabled) {
+ CreditBasedPartitionRequestClientHandler
networkClientHandler = new CreditBasedPartitionRequestClientHandler();
+ NetworkBufferAllocator networkBufferAllocator = new
NetworkBufferAllocator(networkClientHandler);
+ ZeroCopyNettyMessageDecoder zeroCopyNettyMessageDecoder
=
+ new ZeroCopyNettyMessageDecoder(new
BufferResponseAndNoDataBufferMessageParser(networkBufferAllocator));
+ return new ChannelHandler[] {
+ messageEncoder,
+ zeroCopyNettyMessageDecoder,
+ networkClientHandler};
+ } else {
+ return new ChannelHandler[] {
+ messageEncoder,
+ new
NettyMessage.NettyMessageDecoder(true),
+ new PartitionRequestClientHandler()};
+ }
+ }
}
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NetworkBufferAllocator.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NetworkBufferAllocator.java
new file mode 100644
index 00000000000..4a127e1653d
--- /dev/null
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NetworkBufferAllocator.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.netty;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
+import
org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Hands out network buffers to the receiver-side netty handlers.
+ */
+public class NetworkBufferAllocator {
+	/** The client handler used to look up the input channel that belongs to a receiver id. */
+	private final CreditBasedPartitionRequestClientHandler partitionRequestClientHandler;
+
+	NetworkBufferAllocator(CreditBasedPartitionRequestClientHandler partitionRequestClientHandler) {
+		this.partitionRequestClientHandler = checkNotNull(partitionRequestClientHandler);
+	}
+
+	/**
+	 * Requests a buffer from the input channel registered under the given id.
+	 *
+	 * @param receiverId The id of the input channel to request the buffer from.
+	 * @param size The requested buffer size. This method does not consult it.
+	 * @return The result of the input channel's {@code requestBuffer()}, or {@code null} if the
+	 *         client handler returns no input channel for {@code receiverId}.
+	 */
+	public Buffer allocatePooledNetworkBuffer(InputChannelID receiverId, int size) {
+		RemoteInputChannel channel = partitionRequestClientHandler.getInputChannel(receiverId);
+		if (channel == null) {
+			return null;
+		}
+
+		return channel.requestBuffer();
+	}
+
+	/**
+	 * Creates a network buffer on top of a freshly allocated byte array of the given size.
+	 *
+	 * @param size The length of the byte array backing the buffer.
+	 * @return The un-pooled network buffer, recycled through {@link FreeingBufferRecycler#INSTANCE}.
+	 */
+	public Buffer allocateUnPooledNetworkBuffer(int size) {
+		MemorySegment segment = MemorySegmentFactory.wrap(new byte[size]);
+		return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE, false);
+	}
+}
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NoDataBufferMessageParser.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NoDataBufferMessageParser.java
new file mode 100644
index 00000000000..3e848178bcd
--- /dev/null
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NoDataBufferMessageParser.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.netty;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+import java.net.ProtocolException;
+
+/**
+ * The parser for messages with no data buffer part.
+ *
+ * <p>Every supported message consists of the message header only, so the whole remainder of
+ * the frame is parsed in one go and the result is always reported as having no data buffer.
+ */
+public class NoDataBufferMessageParser implements NettyMessageParser {
+
+	@Override
+	public int getMessageHeaderLength(int lengthWithoutFrameHeader, int msgId) {
+		// Without a data buffer part, everything behind the frame header is message header.
+		return lengthWithoutFrameHeader;
+	}
+
+	@Override
+	public MessageHeaderParseResult parseMessageHeader(int msgId, ByteBuf messageHeader) throws Exception {
+		return MessageHeaderParseResult.noDataBuffer(readMessage(msgId, messageHeader));
+	}
+
+	/**
+	 * Deserializes the message identified by {@code msgId} from the given buffer.
+	 *
+	 * @param msgId the id of the message to read.
+	 * @param messageHeader the buffer containing the serialized message.
+	 * @return the deserialized message.
+	 * @throws ProtocolException if {@code msgId} is not one of the supported message ids.
+	 */
+	private NettyMessage readMessage(int msgId, ByteBuf messageHeader) throws Exception {
+		switch (msgId) {
+			case NettyMessage.PartitionRequest.ID:
+				return NettyMessage.PartitionRequest.readFrom(messageHeader);
+			case NettyMessage.TaskEventRequest.ID:
+				return NettyMessage.TaskEventRequest.readFrom(messageHeader, getClass().getClassLoader());
+			case NettyMessage.ErrorResponse.ID:
+				return NettyMessage.ErrorResponse.readFrom(messageHeader);
+			case NettyMessage.CancelPartitionRequest.ID:
+				return NettyMessage.CancelPartitionRequest.readFrom(messageHeader);
+			case NettyMessage.CloseRequest.ID:
+				return NettyMessage.CloseRequest.readFrom(messageHeader);
+			case NettyMessage.AddCredit.ID:
+				return NettyMessage.AddCredit.readFrom(messageHeader);
+			default:
+				// Report the offending id as well; the buffer's toString() alone does not reveal it.
+				throw new ProtocolException(
+					"Received unknown message from producer: " + messageHeader + " (message id: " + msgId + ")");
+		}
+	}
+}
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
index 34f65c0f22d..55942fe09a8 100644
---
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
@@ -233,12 +233,13 @@ public void channelReadComplete(ChannelHandlerContext
ctx) throws Exception {
super.channelReadComplete(ctx);
}
+ @SuppressWarnings("unchecked")
private boolean decodeMsg(Object msg, boolean isStagedBuffer) throws
Throwable {
final Class<?> msgClazz = msg.getClass();
// ---- Buffer
--------------------------------------------------------
if (msgClazz == NettyMessage.BufferResponse.class) {
- NettyMessage.BufferResponse bufferOrEvent =
(NettyMessage.BufferResponse) msg;
+ NettyMessage.BufferResponse<ByteBuf> bufferOrEvent =
(NettyMessage.BufferResponse<ByteBuf>) msg;
RemoteInputChannel inputChannel =
inputChannels.get(bufferOrEvent.receiverId);
if (inputChannel == null) {
@@ -284,11 +285,11 @@ else if (msgClazz == NettyMessage.ErrorResponse.class) {
return true;
}
- private boolean decodeBufferOrEvent(RemoteInputChannel inputChannel,
NettyMessage.BufferResponse bufferOrEvent, boolean isStagedBuffer) throws
Throwable {
+ private boolean decodeBufferOrEvent(RemoteInputChannel inputChannel,
NettyMessage.BufferResponse<ByteBuf> bufferOrEvent, boolean isStagedBuffer)
throws Throwable {
boolean releaseNettyBuffer = true;
try {
- ByteBuf nettyBuffer = bufferOrEvent.getNettyBuffer();
+ ByteBuf nettyBuffer = bufferOrEvent.getBuffer();
final int receivedSize = nettyBuffer.readableBytes();
if (bufferOrEvent.isBuffer()) {
// ---- Buffer
------------------------------------------------
@@ -378,9 +379,9 @@ public void run() {
private final AtomicReference<Buffer> availableBuffer = new
AtomicReference<Buffer>();
- private NettyMessage.BufferResponse stagedBufferResponse;
+ private NettyMessage.BufferResponse<ByteBuf>
stagedBufferResponse;
- private boolean waitForBuffer(BufferProvider bufferProvider,
NettyMessage.BufferResponse bufferResponse) {
+ private boolean waitForBuffer(BufferProvider bufferProvider,
NettyMessage.BufferResponse<ByteBuf> bufferResponse) {
stagedBufferResponse = bufferResponse;
@@ -447,7 +448,7 @@ public void run() {
throw new
IllegalStateException("Running buffer availability task w/o a buffer.");
}
- ByteBuf nettyBuffer =
stagedBufferResponse.getNettyBuffer();
+ ByteBuf nettyBuffer =
stagedBufferResponse.getBuffer();
nettyBuffer.readBytes(buffer.asByteBuf(),
nettyBuffer.readableBytes());
stagedBufferResponse.releaseBuffer();
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
index c3d3d1bcc10..d1b9978559d 100644
---
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
@@ -241,8 +241,9 @@ private void writeAndFlushNextMessageIfPossible(final
Channel channel) throws IO
registerAvailableReader(reader);
}
- BufferResponse msg = new BufferResponse(
- next.buffer(),
+ BufferResponse msg = new
BufferResponse<>(
+ new
NettyMessage.FlinkBufferHolder(next.buffer()),
+ next.buffer().isBuffer(),
reader.getSequenceNumber(),
reader.getReceiverId(),
next.buffersInBacklog());
diff --git
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/ZeroCopyNettyMessageDecoder.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/ZeroCopyNettyMessageDecoder.java
new file mode 100644
index 00000000000..dfe3e80fffc
--- /dev/null
+++
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/ZeroCopyNettyMessageDecoder.java
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.netty;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import
org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;
+
+import static
org.apache.flink.runtime.io.network.netty.NettyMessage.FRAME_HEADER_LENGTH;
+import static
org.apache.flink.runtime.io.network.netty.NettyMessage.MAGIC_NUMBER;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Decodes messages from the fragmentary netty buffers. This decoder assumes the
+ * messages have the following format:
+ * <pre>
+ * +-----------------------------------+--------------------------------+
+ * | FRAME HEADER || MESSAGE HEADER    | DATA BUFFER (Optional)         |
+ * +-----------------------------------+--------------------------------+
+ * </pre>
+ * and it decodes each part in order.
+ *
+ * <p>This decoder tries its best to avoid copying. The frame header and the message header are
+ * only cumulated into an intermediate buffer when they span multiple input buffers; otherwise
+ * they are read straight from the received buffer. The data buffer part is copied directly
+ * into the target buffer reported by the message parser.
+ *
+ * <p>The format of the frame header is
+ * <pre>
+ * +------------------+------------------+--------++----------------+
+ * | FRAME LENGTH (4) | MAGIC NUMBER (4) | ID (1) || CUSTOM MESSAGE |
+ * +------------------+------------------+--------++----------------+
+ * </pre>
+ * where the frame length includes the frame header itself.
+ */
+public class ZeroCopyNettyMessageDecoder extends ChannelInboundHandlerAdapter {
+
+	private static final int INITIAL_MESSAGE_HEADER_BUFFER_LENGTH = 128;
+
+	/** The parser to parse the message header. */
+	private final NettyMessageParser messageParser;
+
+	/** The buffer used to cumulate the frame header part. */
+	private ByteBuf frameHeaderBuffer;
+
+	/** The buffer used to receive the message header part. */
+	private ByteBuf messageHeaderBuffer;
+
+	/** Which part of the current message is being decoded. */
+	private DecodeStep decodeStep;
+
+	/** How many bytes have been decoded in current step. */
+	private int decodedBytesOfCurrentStep;
+
+	/** The intermediate state when decoding the current message. */
+	private final MessageDecodeIntermediateState intermediateState;
+
+	ZeroCopyNettyMessageDecoder(NettyMessageParser messageParser) {
+		this.messageParser = messageParser;
+		this.intermediateState = new MessageDecodeIntermediateState();
+	}
+
+	/**
+	 * Allocates the cumulation buffers and starts with decoding a frame header.
+	 */
+	@Override
+	public void channelActive(ChannelHandlerContext ctx) throws Exception {
+		super.channelActive(ctx);
+
+		frameHeaderBuffer = ctx.alloc().directBuffer(FRAME_HEADER_LENGTH);
+		messageHeaderBuffer = ctx.alloc().directBuffer(INITIAL_MESSAGE_HEADER_BUFFER_LENGTH);
+
+		decodeStep = DecodeStep.DECODING_FRAME;
+	}
+
+	/**
+	 * Forwards the event and then releases everything this decoder still holds.
+	 */
+	@Override
+	public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+		try {
+			super.channelInactive(ctx);
+		} finally {
+			// Release in a finally block so that a failure while forwarding the event does not
+			// leak the cumulation buffers or a partially filled target buffer.
+			releaseResources();
+		}
+	}
+
+	/**
+	 * Feeds the received bytes through the decoding steps and fires every message that
+	 * gets completed. The received buffer is always released before returning.
+	 */
+	@Override
+	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+		if (!(msg instanceof ByteBuf)) {
+			ctx.fireChannelRead(msg);
+			return;
+		}
+
+		ByteBuf data = (ByteBuf) msg;
+
+		try {
+			while (data.readableBytes() > 0) {
+				// The steps are deliberately not chained with "else": a step that completes
+				// falls through to the next one within the same iteration, which is required
+				// for parts of length zero when no more bytes are readable.
+				if (decodeStep == DecodeStep.DECODING_FRAME) {
+					ByteBuf toDecode = cumulateBufferIfNeeded(frameHeaderBuffer, data, FRAME_HEADER_LENGTH);
+
+					if (toDecode != null) {
+						decodeFrameHeader(toDecode);
+
+						decodedBytesOfCurrentStep = 0;
+						decodeStep = DecodeStep.DECODING_MESSAGE_HEADER;
+					}
+				}
+
+				if (decodeStep == DecodeStep.DECODING_MESSAGE_HEADER) {
+					ByteBuf toDecode = cumulateBufferIfNeeded(
+						messageHeaderBuffer,
+						data,
+						intermediateState.messageHeaderLength);
+
+					if (toDecode != null) {
+						intermediateState.messageHeaderParseResult =
+							messageParser.parseMessageHeader(intermediateState.msgId, toDecode);
+
+						if (intermediateState.messageHeaderParseResult.getDataBufferAction() ==
+								NettyMessageParser.DataBufferAction.NO_DATA_BUFFER) {
+							ctx.fireChannelRead(intermediateState.messageHeaderParseResult.getParsedMessage());
+							clearState();
+						} else {
+							decodedBytesOfCurrentStep = 0;
+							decodeStep = DecodeStep.DECODING_BUFFER;
+						}
+					}
+				}
+
+				if (decodeStep == DecodeStep.DECODING_BUFFER) {
+					readOrDiscardBufferResponse(data);
+
+					if (decodedBytesOfCurrentStep == intermediateState.messageHeaderParseResult.getDataBufferSize()) {
+						ctx.fireChannelRead(intermediateState.messageHeaderParseResult.getParsedMessage());
+						clearState();
+					}
+				}
+			}
+
+			checkState(!data.isReadable(), "Not all data of the received buffer consumed.");
+		} finally {
+			data.release();
+		}
+	}
+
+	/**
+	 * Recycles the target buffer of a partially received message, if any, and releases the
+	 * cumulation buffers.
+	 */
+	private void releaseResources() {
+		if (intermediateState.messageHeaderParseResult != null) {
+			Buffer buffer = intermediateState.messageHeaderParseResult.getTargetDataBuffer();
+
+			if (buffer != null) {
+				buffer.recycleBuffer();
+			}
+		}
+
+		clearState();
+
+		frameHeaderBuffer.release();
+		messageHeaderBuffer.release();
+	}
+
+	/**
+	 * Reads the frame header and records the message id and the message header length.
+	 *
+	 * @param frameHeaderBuffer the buffer positioned at the start of the frame header.
+	 * @throws IllegalStateException if the frame length is smaller than the frame header or
+	 *         the magic number does not match.
+	 */
+	private void decodeFrameHeader(ByteBuf frameHeaderBuffer) {
+		int messageLength = frameHeaderBuffer.readInt();
+		// The frame length includes the frame header, so anything shorter would yield a
+		// negative message header length below.
+		checkState(messageLength >= FRAME_HEADER_LENGTH,
+			"The length field of current message must not be smaller than the frame header length " +
+				FRAME_HEADER_LENGTH + ", but was " + messageLength);
+
+		int magicNumber = frameHeaderBuffer.readInt();
+		checkState(magicNumber == MAGIC_NUMBER, "Network stream corrupted: received incorrect magic number.");
+
+		intermediateState.msgId = frameHeaderBuffer.readByte();
+		intermediateState.messageHeaderLength = messageParser.getMessageHeaderLength(
+			messageLength - FRAME_HEADER_LENGTH,
+			intermediateState.msgId);
+	}
+
+	/**
+	 * Consumes as many bytes of the data buffer part as are available, either copying them
+	 * into the target buffer or skipping them, depending on the parse result.
+	 *
+	 * @param data the received buffer to consume from.
+	 */
+	private void readOrDiscardBufferResponse(ByteBuf data) {
+		int dataBufferSize = intermediateState.messageHeaderParseResult.getDataBufferSize();
+
+		// If current buffer is empty, then there is no more data to receive.
+		if (dataBufferSize == 0) {
+			return;
+		}
+
+		NettyMessageParser.DataBufferAction dataBufferAction =
+			intermediateState.messageHeaderParseResult.getDataBufferAction();
+		int remainingBufferSize = dataBufferSize - decodedBytesOfCurrentStep;
+
+		switch (dataBufferAction) {
+			case RECEIVE:
+				Buffer dataBuffer = intermediateState.messageHeaderParseResult.getTargetDataBuffer();
+				decodedBytesOfCurrentStep += copyToTargetBuffer(dataBuffer.asByteBuf(), data, remainingBufferSize);
+				break;
+			case DISCARD:
+				int actualBytesToDiscard = Math.min(data.readableBytes(), remainingBufferSize);
+				data.readerIndex(data.readerIndex() + actualBytesToDiscard);
+				decodedBytesOfCurrentStep += actualBytesToDiscard;
+				break;
+			default:
+				// Any other action would consume nothing and stall the decoding loop.
+				throw new IllegalStateException("Unexpected data buffer action: " + dataBufferAction);
+		}
+	}
+
+	/**
+	 * Makes {@code size} contiguous bytes available, copying only when they are split
+	 * across multiple received buffers.
+	 *
+	 * @param cumulatedBuffer the buffer to cumulate into if the bytes are fragmented.
+	 * @param src the received buffer.
+	 * @param size the number of bytes required.
+	 * @return {@code src} if nothing has been cumulated yet and it holds all required bytes,
+	 *         {@code cumulatedBuffer} once it holds all required bytes, or {@code null} if
+	 *         more data has to be received first.
+	 */
+	private ByteBuf cumulateBufferIfNeeded(ByteBuf cumulatedBuffer, ByteBuf src, int size) {
+		int cumulatedSize = cumulatedBuffer.readableBytes();
+
+		if (cumulatedSize == 0) {
+			if (src.readableBytes() >= size) {
+				return src;
+			} else {
+				// The capacity will stop increasing after reaching the maximum value.
+				if (cumulatedBuffer.capacity() < size) {
+					cumulatedBuffer.capacity(size);
+				}
+			}
+		}
+
+		copyToTargetBuffer(cumulatedBuffer, src, size - cumulatedBuffer.readableBytes());
+
+		if (cumulatedBuffer.readableBytes() == size) {
+			return cumulatedBuffer;
+		}
+
+		return null;
+	}
+
+	/**
+	 * Clears all the intermediate state for reading the next message.
+	 */
+	private void clearState() {
+		frameHeaderBuffer.clear();
+		messageHeaderBuffer.clear();
+
+		intermediateState.resetState();
+
+		decodedBytesOfCurrentStep = 0;
+		decodeStep = DecodeStep.DECODING_FRAME;
+	}
+
+	/**
+	 * Copies bytes from the src to dest, but do not exceed the capacity of the dest buffer.
+	 *
+	 * @param dest The ByteBuf to copy bytes to.
+	 * @param src The ByteBuf to copy bytes from.
+	 * @param maxCopySize Maximum size of bytes to copy.
+	 * @return The length of actually copied bytes.
+	 */
+	private int copyToTargetBuffer(ByteBuf dest, ByteBuf src, int maxCopySize) {
+		int copyLength = Math.min(src.readableBytes(), maxCopySize);
+		checkState(dest.writableBytes() >= copyLength,
+			"There is not enough space to copy " + copyLength + " bytes, writable = " + dest.writableBytes());
+
+		dest.writeBytes(src, copyLength);
+
+		return copyLength;
+	}
+
+	/**
+	 * Indicates which part of the current message is being decoded.
+	 */
+	private enum DecodeStep {
+		/** The frame header is under decoding. */
+		DECODING_FRAME,
+
+		/** The message header is under decoding. */
+		DECODING_MESSAGE_HEADER,
+
+		/** The data buffer part is under decoding. */
+		DECODING_BUFFER;
+	}
+
+	/**
+	 * The intermediate state produced when decoding the current message.
+	 */
+	private static class MessageDecodeIntermediateState {
+		/** The message id of current message. */
+		byte msgId = -1;
+
+		/** The length of message header part of current message. */
+		int messageHeaderLength = -1;
+
+		/** The parse result of the message header part. */
+		NettyMessageParser.MessageHeaderParseResult messageHeaderParseResult;
+
+		void resetState() {
+			msgId = -1;
+			messageHeaderLength = -1;
+			messageHeaderParseResult = null;
+		}
+	}
+}
diff --git
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
index ea3646dee3e..7b14262ea8c 100644
---
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
+++
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
@@ -37,13 +37,15 @@
import org.apache.flink.runtime.io.network.util.TestBufferFactory;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledByteBufAllocator;
import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
import
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
import org.junit.Test;
-import static
org.apache.flink.runtime.io.network.netty.PartitionRequestClientHandlerTest.createBufferResponse;
+import java.io.IOException;
+
import static
org.apache.flink.runtime.io.network.netty.PartitionRequestClientHandlerTest.createRemoteInputChannel;
import static
org.apache.flink.runtime.io.network.netty.PartitionRequestClientHandlerTest.createSingleInputGate;
import static
org.apache.flink.runtime.io.network.netty.PartitionRequestQueueTest.blockChannel;
@@ -88,12 +90,12 @@ public void testReleaseInputChannelDuringDecode() throws
Exception {
when(inputChannel.getInputChannelId()).thenReturn(new
InputChannelID());
when(inputChannel.getBufferProvider()).thenReturn(bufferProvider);
- final BufferResponse receivedBuffer = createBufferResponse(
-
TestBufferFactory.createBuffer(TestBufferFactory.BUFFER_SIZE), 0,
inputChannel.getInputChannelId(), 2);
-
final CreditBasedPartitionRequestClientHandler client = new
CreditBasedPartitionRequestClientHandler();
client.addInputChannel(inputChannel);
+ final BufferResponse receivedBuffer = createBufferResponse(
+
TestBufferFactory.createBuffer(TestBufferFactory.BUFFER_SIZE), 0, inputChannel,
2, client);
+
client.channelRead(mock(ChannelHandlerContext.class),
receivedBuffer);
}
@@ -115,13 +117,13 @@ public void testReceiveEmptyBuffer() throws Exception {
// An empty buffer of size 0
final Buffer emptyBuffer = TestBufferFactory.createBuffer(0);
- final int backlog = 2;
- final BufferResponse receivedBuffer = createBufferResponse(
- emptyBuffer, 0, inputChannel.getInputChannelId(),
backlog);
-
final CreditBasedPartitionRequestClientHandler client = new
CreditBasedPartitionRequestClientHandler();
client.addInputChannel(inputChannel);
+ final int backlog = 2;
+ final BufferResponse receivedBuffer = createBufferResponse(
+ emptyBuffer, 0, inputChannel, backlog, client);
+
// Read the empty buffer
client.channelRead(mock(ChannelHandlerContext.class),
receivedBuffer);
@@ -150,7 +152,7 @@ public void testReceiveBuffer() throws Exception {
final int backlog = 2;
final BufferResponse bufferResponse =
createBufferResponse(
- TestBufferFactory.createBuffer(32), 0,
inputChannel.getInputChannelId(), backlog);
+ TestBufferFactory.createBuffer(32), 0,
inputChannel, backlog, handler);
handler.channelRead(mock(ChannelHandlerContext.class),
bufferResponse);
assertEquals(1,
inputChannel.getNumberOfQueuedBuffers());
@@ -180,7 +182,7 @@ public void testThrowExceptionForNoAvailableBuffer() throws
Exception {
0, inputChannel.getNumberOfAvailableBuffers());
final BufferResponse bufferResponse = createBufferResponse(
-
TestBufferFactory.createBuffer(TestBufferFactory.BUFFER_SIZE), 0,
inputChannel.getInputChannelId(), 2);
+
TestBufferFactory.createBuffer(TestBufferFactory.BUFFER_SIZE), 0, inputChannel,
2, handler);
handler.channelRead(mock(ChannelHandlerContext.class),
bufferResponse);
verify(inputChannel,
times(1)).onError(any(IllegalStateException.class));
@@ -273,9 +275,9 @@ public void testNotifyCreditAvailable() throws Exception {
// The buffer response will take one available buffer
from input channel, and it will trigger
// requesting (backlog + numExclusiveBuffers -
numAvailableBuffers) floating buffers
final BufferResponse bufferResponse1 =
createBufferResponse(
- TestBufferFactory.createBuffer(32), 0,
inputChannel1.getInputChannelId(), 1);
+ TestBufferFactory.createBuffer(32), 0,
inputChannel1, 1, handler);
final BufferResponse bufferResponse2 =
createBufferResponse(
- TestBufferFactory.createBuffer(32), 0,
inputChannel2.getInputChannelId(), 1);
+ TestBufferFactory.createBuffer(32), 0,
inputChannel2, 1, handler);
handler.channelRead(mock(ChannelHandlerContext.class),
bufferResponse1);
handler.channelRead(mock(ChannelHandlerContext.class),
bufferResponse2);
@@ -300,7 +302,7 @@ public void testNotifyCreditAvailable() throws Exception {
// Trigger notify credits availability via buffer
response on the condition of an un-writable channel
final BufferResponse bufferResponse3 =
createBufferResponse(
- TestBufferFactory.createBuffer(32), 1,
inputChannel1.getInputChannelId(), 1);
+ TestBufferFactory.createBuffer(32), 1,
inputChannel1, 1, handler);
handler.channelRead(mock(ChannelHandlerContext.class),
bufferResponse3);
assertEquals(1, inputChannel1.getUnannouncedCredit());
@@ -364,7 +366,7 @@ public void testNotifyCreditAvailableAfterReleased() throws
Exception {
// Trigger request floating buffers via buffer response
to notify credits available
final BufferResponse bufferResponse =
createBufferResponse(
- TestBufferFactory.createBuffer(32), 0,
inputChannel.getInputChannelId(), 1);
+ TestBufferFactory.createBuffer(32), 0,
inputChannel, 1, handler);
handler.channelRead(mock(ChannelHandlerContext.class),
bufferResponse);
assertEquals(2, inputChannel.getUnannouncedCredit());
@@ -388,4 +390,37 @@ public void testNotifyCreditAvailableAfterReleased()
throws Exception {
networkBufferPool.destroy();
}
}
+
+	/**
+	 * Builds a {@link BufferResponse} the way it reaches the client handler: the response is
+	 * serialized, the frame header is skipped and the rest is deserialized again with a
+	 * {@link NetworkBufferAllocator} backed by the given client handler.
+	 *
+	 * @param buffer the buffer whose content the response carries.
+	 * @param sequenceNumber the sequence number of the response.
+	 * @param receivingChannel the input channel the response is addressed to.
+	 * @param backlog the backlog announced by the response.
+	 * @param clientHandler the client handler backing the allocator used for deserialization.
+	 * @return the deserialized response; the content of {@code buffer} is written into it
+	 *         unless the deserialized response has no buffer.
+	 */
+	private BufferResponse<Buffer> createBufferResponse(
+		Buffer buffer,
+		int sequenceNumber,
+		RemoteInputChannel receivingChannel,
+		int backlog,
+		CreditBasedPartitionRequestClientHandler clientHandler) throws IOException {
+
+		// Mock buffer to serialize
+		BufferResponse<Buffer> resp = new BufferResponse<>(
+			new NettyMessage.FlinkBufferHolder(buffer),
+			buffer.isBuffer(),
+			sequenceNumber,
+			receivingChannel.getInputChannelId(),
+			backlog);
+
+		ByteBuf serialized = resp.write(UnpooledByteBufAllocator.DEFAULT);
+
+		// Skip general header bytes. skipBytes() only advances the reader index, whereas
+		// readBytes(int) would allocate a new buffer that nobody releases.
+		serialized.skipBytes(NettyMessage.FRAME_HEADER_LENGTH);
+
+		// Deserialize the bytes again. We have to go this way, because we only partly deserialize
+		// the header of the response and wait for a buffer from the buffer pool to copy the payload
+		// data into.
+		BufferResponse<Buffer> deserialized =
+			BufferResponse.readFrom(serialized, new NetworkBufferAllocator(clientHandler));
+
+		if (deserialized.getBuffer() != null) {
+			deserialized.asByteBuf().writeBytes(buffer.asByteBuf());
+		}
+
+		return deserialized;
+	}
}
diff --git
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
index a8c473f3dcd..74a8e5ea549 100644
---
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
+++
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
@@ -18,32 +18,13 @@
package org.apache.flink.runtime.io.network.netty;
-import org.apache.flink.core.memory.MemorySegmentFactory;
-import org.apache.flink.runtime.event.task.IntegerTaskEvent;
-import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.io.network.buffer.Buffer;
-import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
-import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
-import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
-import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
-
-import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
-import org.junit.Test;
-
-import java.util.Random;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-
/**
- * Tests for the serialization and deserialization of the various {@link
NettyMessage} sub-classes.
+ * Tests for the serialization and deserialization of the various {@link
NettyMessage} sub-classes
+ * with the non-zero-copy netty handlers.
*/
-public class NettyMessageSerializationTest {
+public class NettyMessageSerializationTest extends
NettyMessageSerializationTestBase {
public static final boolean RESTORE_OLD_NETTY_BEHAVIOUR = false;
@@ -51,140 +32,15 @@
new NettyMessage.NettyMessageEncoder(), // outbound
messages
new
NettyMessage.NettyMessageDecoder(RESTORE_OLD_NETTY_BEHAVIOUR)); // inbound
messages
- private final Random random = new Random();
-
- @Test
- public void testEncodeDecode() {
- testEncodeDecodeBuffer(false);
- testEncodeDecodeBuffer(true);
-
- {
- {
- IllegalStateException expectedError = new
IllegalStateException();
- InputChannelID receiverId = new
InputChannelID();
-
- NettyMessage.ErrorResponse expected = new
NettyMessage.ErrorResponse(expectedError, receiverId);
- NettyMessage.ErrorResponse actual =
encodeAndDecode(expected);
-
- assertEquals(expected.cause.getClass(),
actual.cause.getClass());
- assertEquals(expected.cause.getMessage(),
actual.cause.getMessage());
- assertEquals(receiverId, actual.receiverId);
- }
-
- {
- IllegalStateException expectedError = new
IllegalStateException("Illegal illegal illegal");
- InputChannelID receiverId = new
InputChannelID();
-
- NettyMessage.ErrorResponse expected = new
NettyMessage.ErrorResponse(expectedError, receiverId);
- NettyMessage.ErrorResponse actual =
encodeAndDecode(expected);
-
- assertEquals(expected.cause.getClass(),
actual.cause.getClass());
- assertEquals(expected.cause.getMessage(),
actual.cause.getMessage());
- assertEquals(receiverId, actual.receiverId);
- }
-
- {
- IllegalStateException expectedError = new
IllegalStateException("Illegal illegal illegal");
-
- NettyMessage.ErrorResponse expected = new
NettyMessage.ErrorResponse(expectedError);
- NettyMessage.ErrorResponse actual =
encodeAndDecode(expected);
-
- assertEquals(expected.cause.getClass(),
actual.cause.getClass());
- assertEquals(expected.cause.getMessage(),
actual.cause.getMessage());
- assertNull(actual.receiverId);
- assertTrue(actual.isFatalError());
- }
- }
-
- {
- NettyMessage.PartitionRequest expected = new
NettyMessage.PartitionRequest(new ResultPartitionID(new
IntermediateResultPartitionID(), new ExecutionAttemptID()), random.nextInt(),
new InputChannelID(), random.nextInt());
- NettyMessage.PartitionRequest actual =
encodeAndDecode(expected);
-
- assertEquals(expected.partitionId, actual.partitionId);
- assertEquals(expected.queueIndex, actual.queueIndex);
- assertEquals(expected.receiverId, actual.receiverId);
- assertEquals(expected.credit, actual.credit);
- }
-
- {
- NettyMessage.TaskEventRequest expected = new
NettyMessage.TaskEventRequest(new IntegerTaskEvent(random.nextInt()), new
ResultPartitionID(new IntermediateResultPartitionID(), new
ExecutionAttemptID()), new InputChannelID());
- NettyMessage.TaskEventRequest actual =
encodeAndDecode(expected);
-
- assertEquals(expected.event, actual.event);
- assertEquals(expected.partitionId, actual.partitionId);
- assertEquals(expected.receiverId, actual.receiverId);
- }
-
- {
- NettyMessage.CancelPartitionRequest expected = new
NettyMessage.CancelPartitionRequest(new InputChannelID());
- NettyMessage.CancelPartitionRequest actual =
encodeAndDecode(expected);
-
- assertEquals(expected.receiverId, actual.receiverId);
- }
-
- {
- NettyMessage.CloseRequest expected = new
NettyMessage.CloseRequest();
- NettyMessage.CloseRequest actual =
encodeAndDecode(expected);
-
- assertEquals(expected.getClass(), actual.getClass());
- }
-
- {
- NettyMessage.AddCredit expected = new
NettyMessage.AddCredit(new ResultPartitionID(new
IntermediateResultPartitionID(), new ExecutionAttemptID()),
random.nextInt(Integer.MAX_VALUE) + 1, new InputChannelID());
- NettyMessage.AddCredit actual =
encodeAndDecode(expected);
-
- assertEquals(expected.partitionId, actual.partitionId);
- assertEquals(expected.credit, actual.credit);
- assertEquals(expected.receiverId, actual.receiverId);
- }
+ @Override
+ public EmbeddedChannel getChannel() {
+ return channel;
}
- private void testEncodeDecodeBuffer(boolean testReadOnlyBuffer) {
- NetworkBuffer buffer = new
NetworkBuffer(MemorySegmentFactory.allocateUnpooledSegment(1024),
FreeingBufferRecycler.INSTANCE);
-
- for (int i = 0; i < 1024; i += 4) {
- buffer.writeInt(i);
- }
-
- Buffer testBuffer = testReadOnlyBuffer ? buffer.readOnlySlice()
: buffer;
-
- NettyMessage.BufferResponse expected = new
NettyMessage.BufferResponse(
- testBuffer, random.nextInt(), new InputChannelID(),
random.nextInt());
- NettyMessage.BufferResponse actual = encodeAndDecode(expected);
-
+ @Override
+ public boolean bufferIsReleasedOnDecoding() {
// Netty 4.1 is not copying the messages, but retaining slices
of them. BufferResponse actual is in this case
// holding a reference to the buffer. Buffer will be recycled
only once "actual" will be released.
- assertFalse(buffer.isRecycled());
- assertFalse(testBuffer.isRecycled());
-
- final ByteBuf retainedSlice = actual.getNettyBuffer();
-
- // Ensure not recycled and same size as original buffer
- assertEquals(1, retainedSlice.refCnt());
- assertEquals(1024, retainedSlice.readableBytes());
-
- for (int i = 0; i < 1024; i += 4) {
- assertEquals(i, retainedSlice.readInt());
- }
-
- // Release the retained slice
- actual.releaseBuffer();
- assertEquals(0, retainedSlice.refCnt());
- assertTrue(buffer.isRecycled());
- assertTrue(testBuffer.isRecycled());
-
- assertEquals(expected.sequenceNumber, actual.sequenceNumber);
- assertEquals(expected.receiverId, actual.receiverId);
- assertEquals(expected.backlog, actual.backlog);
- }
-
- @SuppressWarnings("unchecked")
- private <T extends NettyMessage> T encodeAndDecode(T msg) {
- channel.writeOutbound(msg);
- ByteBuf encoded = (ByteBuf) channel.readOutbound();
-
- assertTrue(channel.writeInbound(encoded));
-
- return (T) channel.readInbound();
+ return false;
}
}
diff --git
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTestBase.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTestBase.java
new file mode 100644
index 00000000000..e4792c63c5c
--- /dev/null
+++
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTestBase.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.netty;
+
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.event.task.IntegerTaskEvent;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
+import org.junit.Test;
+
+import java.util.Random;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for the serialization and deserialization of the various {@link
NettyMessage} sub-classes.
+ */
+public abstract class NettyMessageSerializationTestBase {
+ private final Random random = new Random();
+
+ public abstract EmbeddedChannel getChannel();
+
+ public abstract boolean bufferIsReleasedOnDecoding();
+
+ @Test
+ public void testEncodeDecode() {
+ testEncodeDecodeBuffer(false);
+ testEncodeDecodeBuffer(true);
+
+ {
+ {
+ IllegalStateException expectedError = new
IllegalStateException();
+ InputChannelID receiverId = new
InputChannelID();
+
+ NettyMessage.ErrorResponse expected = new
NettyMessage.ErrorResponse(expectedError, receiverId);
+ NettyMessage.ErrorResponse actual =
encodeAndDecode(expected);
+
+ assertEquals(expected.cause.getClass(),
actual.cause.getClass());
+ assertEquals(expected.cause.getMessage(),
actual.cause.getMessage());
+ assertEquals(receiverId, actual.receiverId);
+ }
+
+ {
+ IllegalStateException expectedError = new
IllegalStateException("Illegal illegal illegal");
+ InputChannelID receiverId = new
InputChannelID();
+
+ NettyMessage.ErrorResponse expected = new
NettyMessage.ErrorResponse(expectedError, receiverId);
+ NettyMessage.ErrorResponse actual =
encodeAndDecode(expected);
+
+ assertEquals(expected.cause.getClass(),
actual.cause.getClass());
+ assertEquals(expected.cause.getMessage(),
actual.cause.getMessage());
+ assertEquals(receiverId, actual.receiverId);
+ }
+
+ {
+ IllegalStateException expectedError = new
IllegalStateException("Illegal illegal illegal");
+
+ NettyMessage.ErrorResponse expected = new
NettyMessage.ErrorResponse(expectedError);
+ NettyMessage.ErrorResponse actual =
encodeAndDecode(expected);
+
+ assertEquals(expected.cause.getClass(),
actual.cause.getClass());
+ assertEquals(expected.cause.getMessage(),
actual.cause.getMessage());
+ assertNull(actual.receiverId);
+ assertTrue(actual.isFatalError());
+ }
+ }
+
+ {
+ NettyMessage.PartitionRequest expected = new
NettyMessage.PartitionRequest(new ResultPartitionID(new
IntermediateResultPartitionID(), new ExecutionAttemptID()), random.nextInt(),
new InputChannelID(), random.nextInt());
+ NettyMessage.PartitionRequest actual =
encodeAndDecode(expected);
+
+ assertEquals(expected.partitionId, actual.partitionId);
+ assertEquals(expected.queueIndex, actual.queueIndex);
+ assertEquals(expected.receiverId, actual.receiverId);
+ assertEquals(expected.credit, actual.credit);
+ }
+
+ {
+ NettyMessage.TaskEventRequest expected = new
NettyMessage.TaskEventRequest(new IntegerTaskEvent(random.nextInt()), new
ResultPartitionID(new IntermediateResultPartitionID(), new
ExecutionAttemptID()), new InputChannelID());
+ NettyMessage.TaskEventRequest actual =
encodeAndDecode(expected);
+
+ assertEquals(expected.event, actual.event);
+ assertEquals(expected.partitionId, actual.partitionId);
+ assertEquals(expected.receiverId, actual.receiverId);
+ }
+
+ {
+ NettyMessage.CancelPartitionRequest expected = new
NettyMessage.CancelPartitionRequest(new InputChannelID());
+ NettyMessage.CancelPartitionRequest actual =
encodeAndDecode(expected);
+
+ assertEquals(expected.receiverId, actual.receiverId);
+ }
+
+ {
+ NettyMessage.CloseRequest expected = new
NettyMessage.CloseRequest();
+ NettyMessage.CloseRequest actual =
encodeAndDecode(expected);
+
+ assertEquals(expected.getClass(), actual.getClass());
+ }
+
+ {
+ NettyMessage.AddCredit expected = new
NettyMessage.AddCredit(new ResultPartitionID(new
IntermediateResultPartitionID(), new ExecutionAttemptID()),
random.nextInt(Integer.MAX_VALUE) + 1, new InputChannelID());
+ NettyMessage.AddCredit actual =
encodeAndDecode(expected);
+
+ assertEquals(expected.partitionId, actual.partitionId);
+ assertEquals(expected.credit, actual.credit);
+ assertEquals(expected.receiverId, actual.receiverId);
+ }
+ }
+
+ private void testEncodeDecodeBuffer(boolean testReadOnlyBuffer) {
+ NetworkBuffer buffer = new
NetworkBuffer(MemorySegmentFactory.allocateUnpooledSegment(1024),
FreeingBufferRecycler.INSTANCE);
+
+ for (int i = 0; i < 1024; i += 4) {
+ buffer.writeInt(i);
+ }
+
+ Buffer testBuffer = testReadOnlyBuffer ? buffer.readOnlySlice()
: buffer;
+
+ NettyMessage.BufferResponse expected = new
NettyMessage.BufferResponse<>(
+ new NettyMessage.FlinkBufferHolder(testBuffer),
+ testBuffer.isBuffer(),
+ random.nextInt(),
+ new InputChannelID(),
+ random.nextInt());
+ NettyMessage.BufferResponse actual = encodeAndDecode(expected);
+
+ if (!bufferIsReleasedOnDecoding()) {
+ assertFalse(buffer.isRecycled());
+ assertFalse(testBuffer.isRecycled());
+ } else {
+ assertTrue(buffer.isRecycled());
+ assertTrue(testBuffer.isRecycled());
+ }
+
+ final ByteBuf retainedSlice = actual.asByteBuf();
+
+ // Ensure not recycled and same size as original buffer
+ assertEquals(1, retainedSlice.refCnt());
+ assertEquals(1024, retainedSlice.readableBytes());
+
+ for (int i = 0; i < 1024; i += 4) {
+ assertEquals(i, retainedSlice.readInt());
+ }
+
+ // Release the retained slice
+ actual.releaseBuffer();
+ assertEquals(0, retainedSlice.refCnt());
+ assertTrue(buffer.isRecycled());
+ assertTrue(testBuffer.isRecycled());
+
+ assertEquals(expected.sequenceNumber, actual.sequenceNumber);
+ assertEquals(expected.receiverId, actual.receiverId);
+ assertEquals(expected.backlog, actual.backlog);
+ }
+
+ @SuppressWarnings("unchecked")
+ private <T extends NettyMessage> T encodeAndDecode(T msg) {
+ EmbeddedChannel channel = getChannel();
+ channel.writeOutbound(msg);
+ ByteBuf encoded = channel.readOutbound();
+ assertTrue(channel.writeInbound(encoded));
+ return (T) channel.readInbound();
+ }
+}
diff --git
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
index 842aed8186d..8e459ff4708 100644
---
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
+++
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
@@ -289,7 +289,12 @@ static BufferResponse createBufferResponse(
int backlog) throws IOException {
// Mock buffer to serialize
- BufferResponse resp = new BufferResponse(buffer,
sequenceNumber, receivingChannelId, backlog);
+ BufferResponse resp = new BufferResponse<>(
+ new NettyMessage.FlinkBufferHolder(buffer),
+ buffer.isBuffer(),
+ sequenceNumber,
+ receivingChannelId,
+ backlog);
ByteBuf serialized =
resp.write(UnpooledByteBufAllocator.DEFAULT);
diff --git
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ZeroCopyNettyMessageDecoderTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ZeroCopyNettyMessageDecoderTest.java
new file mode 100644
index 00000000000..5f9f202a7af
--- /dev/null
+++
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ZeroCopyNettyMessageDecoderTest.java
@@ -0,0 +1,338 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.netty;
+
+import org.apache.commons.lang3.builder.EqualsBuilder;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
+import
org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
+import org.apache.flink.runtime.io.network.util.TestTaskEvent;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.PooledByteBufAllocator;
+import
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.assertTrue;
+import static
org.apache.flink.runtime.io.network.netty.PartitionRequestClientHandlerTest.createRemoteInputChannel;
+import static
org.apache.flink.runtime.io.network.netty.PartitionRequestClientHandlerTest.createSingleInputGate;
+import static org.apache.flink.util.Preconditions.checkState;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+import static org.powermock.api.mockito.PowerMockito.spy;
+import static org.powermock.api.mockito.PowerMockito.when;
+
+/**
+ * Tests the zero copy message decoder.
+ */
+public class ZeroCopyNettyMessageDecoderTest {
+ private static final PooledByteBufAllocator ALLOCATOR = new
PooledByteBufAllocator();
+
+ private static final InputChannelID NORMAL_INPUT_CHANNEL_ID = new
InputChannelID();
+ private static final InputChannelID RELEASED_INPUT_CHANNEL_ID = new
InputChannelID();
+
+ /**
+ * Verifies that the message decoder works well for the upstream netty
handler pipeline.
+ */
+ @Test
+ public void testUpstreamMessageDecoder() throws Exception {
+ EmbeddedChannel channel = new EmbeddedChannel(new
ZeroCopyNettyMessageDecoder(new NoDataBufferMessageParser()));
+ NettyMessage[] messages = new NettyMessage[]{
+ new NettyMessage.PartitionRequest(new
ResultPartitionID(), 1, NORMAL_INPUT_CHANNEL_ID, 2),
+ new NettyMessage.TaskEventRequest(new TestTaskEvent(),
new ResultPartitionID(), NORMAL_INPUT_CHANNEL_ID),
+ new NettyMessage.CloseRequest(),
+ new
NettyMessage.CancelPartitionRequest(NORMAL_INPUT_CHANNEL_ID),
+ new NettyMessage.AddCredit(new ResultPartitionID(), 2,
NORMAL_INPUT_CHANNEL_ID),
+ };
+
+ // Segment points:
+ // +--------+--------+--------+-----------+-----------+
+ // | | | | | |
+ // +--------+--------+--------+-----------+-----------+
+ // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
+ ByteBuf[] splitBuffers = segmentMessages(messages, 3, new int[]
{
+ 1, 7, 11, 14
+ });
+ readInputAndVerify(channel, splitBuffers, messages);
+
+ splitBuffers = segmentMessages(messages, 3, new int[] {
+ 1, 4, 7, 9, 12, 14
+ });
+ readInputAndVerify(channel, splitBuffers, messages);
+ }
+
+ /**
+ * Verifies that the message decoder works well for the downstream
netty handler pipeline.
+ */
+ @Test
+ public void testDownstreamMessageDecode() throws Exception {
+ // 8 buffers required for running 2 rounds and 4 buffers each
round.
+
+ EmbeddedChannel channel = new EmbeddedChannel(
+ new ZeroCopyNettyMessageDecoder(new
BufferResponseAndNoDataBufferMessageParser(
+ new
NetworkBufferAllocator(createPartitionRequestClientHandler(8)))));
+
+ NettyMessage[] messages = new NettyMessage[]{
+ createBufferResponse(128, true, 0, NORMAL_INPUT_CHANNEL_ID,
4),
+ createBufferResponse(256, true, 1, NORMAL_INPUT_CHANNEL_ID, 3),
+ createBufferResponse(32, false, 2, NORMAL_INPUT_CHANNEL_ID, 4),
+ new NettyMessage.ErrorResponse(new EquableException("test"),
NORMAL_INPUT_CHANNEL_ID),
+ createBufferResponse(56, true, 3, NORMAL_INPUT_CHANNEL_ID, 4)
+ };
+
+ // Segment points of the above five messages are
+ // +--------+--------+--------+-----------+-----------+
+ // | | | | | |
+ // +--------+--------+--------+-----------+-----------+
+ // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
+ ByteBuf[] splitBuffers = segmentMessages(messages, 3, new int[]
{
+ 1, 7, 11, 14
+ });
+ readInputAndVerify(channel, splitBuffers, messages);
+
+ splitBuffers = segmentMessages(messages, 3, new int[] {
+ 1, 4, 7, 9, 12, 14
+ });
+ readInputAndVerify(channel, splitBuffers, messages);
+ }
+
+ /**
+ * Verifies that NettyMessageDecoder works well with buffers sent to a
released channel.
+ * For such a channel, no Buffer is available to receive the data
buffer in the message,
+ * and the data buffer part should be discarded before reading the next
message.
+ */
+ @Test
+ public void testDownstreamMessageDecodeWithReleasedInputChannel()
throws Exception {
+ // 6 buffers required for running 2 rounds and 3 buffers each
round.
+ EmbeddedChannel channel = new EmbeddedChannel(
+ new ZeroCopyNettyMessageDecoder(new
BufferResponseAndNoDataBufferMessageParser(
+ new
NetworkBufferAllocator(createPartitionRequestClientHandler(6)))));
+
+ NettyMessage[] messages = new NettyMessage[]{
+ createBufferResponse(128, true, 0, NORMAL_INPUT_CHANNEL_ID, 4),
+ createBufferResponse(256, true, 1, RELEASED_INPUT_CHANNEL_ID,
3),
+ createBufferResponse(32, false, 2, NORMAL_INPUT_CHANNEL_ID, 4),
+ new NettyMessage.ErrorResponse(new EquableException("test"),
RELEASED_INPUT_CHANNEL_ID),
+ createBufferResponse(64, false,3, NORMAL_INPUT_CHANNEL_ID, 4),
+ };
+
+ // Segment points of the above five messages are
+ // +--------+--------+--------+-----------+-----------+
+ // | | | | | |
+ // +--------+--------+--------+-----------+-----------+
+ // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
+ ByteBuf[] splitBuffers = segmentMessages(messages, 3, new int[]{
+ 1, 4, 7, 9, 12, 14
+ });
+ readInputAndVerify(channel, splitBuffers, messages);
+
+ splitBuffers = segmentMessages(messages, 3, new int[]{
+ 1, 3, 4, 5, 7, 10, 13
+ });
+ readInputAndVerify(channel, splitBuffers, messages);
+ }
+
+
//------------------------------------------------------------------------------------------------------------------
+
+ private void readInputAndVerify(EmbeddedChannel channel, ByteBuf[]
inputBuffers, NettyMessage[] expected) {
+ for (ByteBuf buffer : inputBuffers) {
+ channel.writeInbound(buffer);
+ }
+
+ channel.runPendingTasks();
+
+ List<NettyMessage> decodedMessages = new ArrayList<>();
+ Object input;
+ while ((input = channel.readInbound()) != null) {
+ assertTrue(input instanceof NettyMessage);
+ decodedMessages.add((NettyMessage) input);
+ }
+
+ assertEquals(expected.length, decodedMessages.size());
+ for (int i = 0; i < expected.length; ++i) {
+ assertEquals(expected[i].getClass(),
decodedMessages.get(i).getClass());
+
+ if (expected[i] instanceof NettyMessage.AddCredit ||
+ expected[i] instanceof NettyMessage.PartitionRequest ||
+ expected[i] instanceof NettyMessage.TaskEventRequest ||
+ expected[i] instanceof NettyMessage.CancelPartitionRequest ||
+ expected[i] instanceof NettyMessage.CloseRequest ||
+ expected[i] instanceof NettyMessage.ErrorResponse) {
+
+ assertTrue("Received different message,
expected is " + expected[i] + ", actual is " + decodedMessages.get(i),
+ EqualsBuilder.reflectionEquals(expected[i],
decodedMessages.get(i)));
+ } else if (expected[i] instanceof
NettyMessage.BufferResponse) {
+ assertEquals(((NettyMessage.BufferResponse)
expected[i]).backlog, ((NettyMessage.BufferResponse)
decodedMessages.get(i)).backlog);
+ assertEquals(((NettyMessage.BufferResponse)
expected[i]).sequenceNumber, ((NettyMessage.BufferResponse)
decodedMessages.get(i)).sequenceNumber);
+ assertEquals(((NettyMessage.BufferResponse)
expected[i]).isBuffer, ((NettyMessage.BufferResponse)
decodedMessages.get(i)).isBuffer);
+ assertEquals(((NettyMessage.BufferResponse)
expected[i]).bufferSize, ((NettyMessage.BufferResponse)
decodedMessages.get(i)).bufferSize);
+ assertEquals(((NettyMessage.BufferResponse)
expected[i]).receiverId, ((NettyMessage.BufferResponse)
decodedMessages.get(i)).receiverId);
+
+ if (((NettyMessage.BufferResponse)
expected[i]).receiverId.equals(RELEASED_INPUT_CHANNEL_ID)) {
+
assertNull(((NettyMessage.BufferResponse) decodedMessages.get(i)).getBuffer());
+ } else {
+
assertEquals(((NettyMessage.BufferResponse) expected[i]).getBuffer(),
((NettyMessage.BufferResponse) decodedMessages.get(i)).getBuffer());
+ }
+ } else {
+ fail("Unsupported message type");
+ }
+ }
+ }
+
+ /**
+ * Segments the serialized buffer of the messages. This method first
segments each message into
+ * numberOfSegmentsEachMessage parts, and numbers all the boundary and
inner segment points from
+ * 0. Then the segment points whose indices are in segmentPointIndex
are used to segment
+ * the serialized buffer.
+ *
+ * <p>For example, suppose there are 3 messages and
numberOfSegmentsEachMessage is 3,
+ * then all the available segment points are:
+ *
+ * <pre>
+ * +---------------+---------------+-------------------+
+ * | | | |
+ * +---------------+---------------+-------------------+
+ * 0 1 2 3 4 5 6 7 8 9
+ * </pre>
+ *
+ * @param messages The messages to be serialized and segmented.
+ * @param numberOfSegmentsEachMessage How many parts each message is
segmented into.
+ * @param segmentPointIndex The chosen segment points.
+ * @return The segmented ByteBuf.
+ */
+ private ByteBuf[] segmentMessages(NettyMessage[] messages, int
numberOfSegmentsEachMessage, int[] segmentPointIndex) throws Exception {
+ List<Integer> segmentPoints = new ArrayList<>();
+ ByteBuf allData = ALLOCATOR.buffer();
+
+ int startBytesOfCurrentMessage = 0;
+ for (NettyMessage message : messages) {
+ ByteBuf buf = message.write(ALLOCATOR);
+
+ // Records the position of each segment point.
+ int length = buf.readableBytes();
+
+ for (int i = 0; i < numberOfSegmentsEachMessage; ++i) {
+ segmentPoints.add(startBytesOfCurrentMessage +
length * i / numberOfSegmentsEachMessage);
+ }
+
+ allData.writeBytes(buf);
+ startBytesOfCurrentMessage += length;
+ }
+
+ // Adds the last segment point.
+ segmentPoints.add(allData.readableBytes());
+
+ // Segments the serialized buffer according to the segment
points.
+ List<ByteBuf> segmentedBuffers = new ArrayList<>();
+
+ for (int i = 0; i <= segmentPointIndex.length; ++i) {
+ ByteBuf buf = ALLOCATOR.buffer();
+
+ int startPos = (i == 0 ? 0 :
segmentPoints.get(segmentPointIndex[i - 1]));
+ int endPos = (i == segmentPointIndex.length ?
segmentPoints.get(segmentPoints.size() - 1) :
segmentPoints.get(segmentPointIndex[i]));
+
+ checkState(startPos == allData.readerIndex());
+
+ buf.writeBytes(allData, endPos - startPos);
+ segmentedBuffers.add(buf);
+ }
+
+ checkState(!allData.isReadable());
+ return segmentedBuffers.toArray(new ByteBuf[0]);
+ }
+
+ private NettyMessage.BufferResponse<Buffer> createBufferResponse(
+ int size,
+ boolean isBuffer,
+ int sequenceNumber,
+ InputChannelID receiverID,
+ int backlog) {
+
+ MemorySegment segment =
MemorySegmentFactory.allocateUnpooledSegment(size);
+ NetworkBuffer buffer = new NetworkBuffer(segment,
FreeingBufferRecycler.INSTANCE);
+ for (int i = 0; i < size / 4; ++i) {
+ buffer.writeInt(i);
+ }
+
+ if (!isBuffer) {
+ buffer.tagAsEvent();
+ }
+
+ return new NettyMessage.BufferResponse<>(
+ new NettyMessage.FlinkBufferHolder(buffer),
+ isBuffer,
+ sequenceNumber,
+ receiverID,
+ backlog);
+ }
+
+ private CreditBasedPartitionRequestClientHandler
createPartitionRequestClientHandler(int numberOfBuffersInNormalChannel) throws
Exception {
+ final NetworkBufferPool networkBufferPool = new
NetworkBufferPool(
+ numberOfBuffersInNormalChannel,
+ 32 * 1024);
+
+ final SingleInputGate inputGate = createSingleInputGate();
+
+ final RemoteInputChannel normalInputChannel =
spy(createRemoteInputChannel(inputGate));
+
when(normalInputChannel.getInputChannelId()).thenReturn(NORMAL_INPUT_CHANNEL_ID);
+
+ // Assign exclusive segments before adding the released input
channel so that the
+ // released input channel will return null when requesting
buffers. This will
+ // test the process of discarding the data buffer.
+ inputGate.assignExclusiveSegments(networkBufferPool,
numberOfBuffersInNormalChannel);
+
+ final RemoteInputChannel releasedInputChannel =
spy(createRemoteInputChannel(inputGate));
+
when(releasedInputChannel.getInputChannelId()).thenReturn(RELEASED_INPUT_CHANNEL_ID);
+ when(releasedInputChannel.isReleased()).thenReturn(true);
+
+ CreditBasedPartitionRequestClientHandler handler = new
CreditBasedPartitionRequestClientHandler();
+ handler.addInputChannel(normalInputChannel);
+ handler.addInputChannel(releasedInputChannel);
+
+ return handler;
+ }
+
+ /**
+ * A specialized exception that compares two objects by comparing the
messages.
+ */
+ private static class EquableException extends RuntimeException {
+ EquableException(String message) {
+ super(message);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (!(obj instanceof EquableException)) {
+ return false;
+ }
+
+ return getMessage().equals(((EquableException) obj).getMessage());
+ }
+ }
+}
diff --git
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ZeroCopyNettyMessageSerializationTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ZeroCopyNettyMessageSerializationTest.java
new file mode 100644
index 00000000000..a066aa0f9c4
--- /dev/null
+++
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ZeroCopyNettyMessageSerializationTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.netty;
+
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
+import
org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.powermock.api.mockito.PowerMockito.when;
+
+/**
+ * Tests for the serialization and deserialization of the various {@link
NettyMessage} sub-classes with
+ * the zero-copy netty handlers.
+ */
+public class ZeroCopyNettyMessageSerializationTest extends
NettyMessageSerializationTestBase {
+ private final EmbeddedChannel channel = new EmbeddedChannel(
+ new NettyMessage.NettyMessageEncoder(), // outbound
messages
+ new ZeroCopyNettyMessageDecoder(new
BufferResponseAndNoDataBufferMessageParser(
+ new
NetworkBufferAllocator(createPartitionRequestClientHandler()))));
+
+ @Override
+ public EmbeddedChannel getChannel() {
+ return channel;
+ }
+
+ @Override
+ public boolean bufferIsReleasedOnDecoding() {
+ // For ZeroCopyNettyMessageDecoder, the input buffer will be copied
to the input channels directly and thus the
+ // input buffer will be released.
+ return true;
+ }
+
+ private CreditBasedPartitionRequestClientHandler
createPartitionRequestClientHandler() {
+ CreditBasedPartitionRequestClientHandler handler =
mock(CreditBasedPartitionRequestClientHandler.class);
+
+ RemoteInputChannel inputChannel =
mock(RemoteInputChannel.class);
+ when(inputChannel.requestBuffer()).thenAnswer(new
Answer<Buffer>() {
+ @Override
+ public Buffer answer(InvocationOnMock invocationOnMock)
throws Throwable {
+ return new
NetworkBuffer(MemorySegmentFactory.allocateUnpooledSegment(1024),
FreeingBufferRecycler.INSTANCE);
+ }
+ });
+
when(handler.getInputChannel(any(InputChannelID.class))).thenReturn(inputChannel);
+
+ return handler;
+ }
+}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services