This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 34db3f118 [#1187] feat(netty): Netty Encoder Support zero-copy. (#1313)
34db3f118 is described below
commit 34db3f118c62407a880a3b12d8f582e8c556bf89
Author: Xianming Lei <[email protected]>
AuthorDate: Fri Nov 24 16:33:45 2023 +0800
[#1187] feat(netty): Netty Encoder Support zero-copy. (#1313)
### What changes were proposed in this pull request?
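Support zero-copy in the Netty message encoder: a response body is passed to the channel through `MessageWithHeader` instead of being copied into the encoded frame buffer.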
### Why are the changes needed?
To reduce ByteBuf copies when sending responses over Netty.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing unit tests.
Co-authored-by: leixianming <[email protected]>
---
.../client/factory/ShuffleClientFactory.java | 11 ++
.../uniffle/client/impl/ShuffleReadClientImpl.java | 12 +-
.../apache/uniffle/common/ShuffleDataResult.java | 13 +-
.../apache/uniffle/common/ShuffleIndexResult.java | 4 +
.../uniffle/common/config/RssClientConf.java | 7 +
.../apache/uniffle/common/netty/FrameDecoder.java | 2 +-
.../uniffle/common/netty/MessageEncoder.java | 79 ++++++---
.../common/netty/TransportFrameDecoder.java | 15 +-
.../netty/buffer/FileSegmentManagedBuffer.java | 5 +
.../uniffle/common/netty/buffer/ManagedBuffer.java | 2 +
.../common/netty/buffer/NettyManagedBuffer.java | 9 +
.../netty/handle/TransportResponseHandler.java | 8 +-
.../AbstractFileRegion.java} | 39 ++---
.../uniffle/common/netty/protocol/Decoders.java | 4 +-
.../protocol/GetLocalShuffleDataResponse.java | 51 ++----
.../protocol/GetLocalShuffleIndexResponse.java | 42 ++---
.../protocol/GetMemoryShuffleDataRequest.java | 16 +-
.../protocol/GetMemoryShuffleDataResponse.java | 39 +++--
.../uniffle/common/netty/protocol/Message.java | 30 +++-
.../common/netty/protocol/MessageWithHeader.java | 193 +++++++++++++++++++++
.../common/netty/protocol/RequestMessage.java | 8 +-
.../{Transferable.java => ResponseMessage.java} | 15 +-
.../uniffle/common/netty/protocol/RpcResponse.java | 31 +++-
.../apache/uniffle/common/util/ByteBufUtils.java | 3 +
.../org/apache/uniffle/common/util/NettyUtils.java | 20 +++
.../common/netty/protocol/NettyProtocolTest.java | 16 +-
.../test/ShuffleServerFaultToleranceTest.java | 2 +
.../client/impl/grpc/ShuffleServerGrpcClient.java | 5 +-
.../impl/grpc/ShuffleServerGrpcNettyClient.java | 6 +-
.../RssGetInMemoryShuffleDataResponse.java | 13 +-
.../client/response/RssGetShuffleDataResponse.java | 12 +-
.../response/RssGetShuffleIndexResponse.java | 5 +-
.../uniffle/server/buffer/ShuffleBuffer.java | 6 +-
.../server/netty/ShuffleServerNettyHandler.java | 27 ++-
.../apache/uniffle/server/netty/StreamServer.java | 6 +-
.../storage/factory/ShuffleHandlerFactory.java | 5 +-
.../handler/impl/DataSkippableReadHandler.java | 1 +
.../request/CreateShuffleReadHandlerRequest.java | 11 ++
.../impl/LocalFileServerReadHandlerTest.java | 7 +-
39 files changed, 562 insertions(+), 218 deletions(-)
diff --git
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index 6baccc05b..ce7d90041 100644
---
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -26,6 +26,7 @@ import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
@@ -210,6 +211,7 @@ public class ShuffleClientFactory {
private String storageType;
private int indexReadLimit;
private long readBufferSize;
+ private ClientType clientType;
public ReadClientBuilder appId(String appId) {
this.appId = appId;
@@ -303,6 +305,11 @@ public class ShuffleClientFactory {
return this;
}
+ public ReadClientBuilder clientType(ClientType clientType) {
+ this.clientType = clientType;
+ return this;
+ }
+
public ReadClientBuilder() {}
public String getAppId() {
@@ -377,6 +384,10 @@ public class ShuffleClientFactory {
return readBufferSize;
}
+ public ClientType getClientType() {
+ return clientType;
+ }
+
public ShuffleReadClientImpl build() {
return new ShuffleReadClientImpl(this);
}
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 24fb20609..e94f5aa6d 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -55,6 +55,7 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
private int shuffleId;
private int partitionId;
private ByteBuffer readBuffer;
+ private ShuffleDataResult sdr;
private Roaring64NavigableMap blockIdBitmap;
private Roaring64NavigableMap taskIdBitmap;
private Roaring64NavigableMap pendingBlockIds;
@@ -96,6 +97,7 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
builder.storageType(storageType);
builder.readBufferSize(readBufferSize);
builder.offHeapEnable(offHeapEnabled);
+
builder.clientType(builder.getRssConf().get(RssClientConf.RSS_CLIENT_TYPE));
} else {
// most for test
RssConf rssConf = new RssConf();
@@ -107,6 +109,7 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
builder.rssConf(rssConf);
builder.offHeapEnable(false);
builder.expectedTaskIdsBitmapFilterEnable(false);
+ builder.clientType(rssConf.get(RssClientConf.RSS_CLIENT_TYPE));
}
init(builder);
@@ -138,6 +141,7 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
request.setIdHelper(idHelper);
request.setExpectTaskIds(taskIdBitmap);
request.setClientConf(builder.getRssConf());
+ request.setClientType(builder.getClientType());
if (builder.isExpectedTaskIdsBitmapFilterEnable()) {
request.useExpectedTaskIdsBitmapFilter();
}
@@ -258,7 +262,13 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
private int read() {
long start = System.currentTimeMillis();
- ShuffleDataResult sdr = clientReadHandler.readShuffleData();
+ // In order to avoid copying, we postpone the release here instead of in
the Decoder.
+ // RssUtils.releaseByteBuffer(readBuffer) cannot actually release the
memory,
+ // because PlatformDependent.freeDirectBuffer can only release the
ByteBuffer with cleaner.
+ if (sdr != null) {
+ sdr.release();
+ }
+ sdr = clientReadHandler.readShuffleData();
readDataTime.addAndGet(System.currentTimeMillis() - start);
if (sdr == null) {
return 0;
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
b/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
index e63357e10..00867e700 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
@@ -53,14 +53,18 @@ public class ShuffleDataResult {
}
public ShuffleDataResult(ByteBuf data, List<BufferSegment> bufferSegments) {
- this.buffer = new NettyManagedBuffer(data);
- this.bufferSegments = bufferSegments;
+ this(new NettyManagedBuffer(data), bufferSegments);
}
public ShuffleDataResult(byte[] data, List<BufferSegment> bufferSegments) {
this(data != null ? ByteBuffer.wrap(data) : null, bufferSegments);
}
+ public ShuffleDataResult(ManagedBuffer data, List<BufferSegment>
bufferSegments) {
+ this.buffer = data;
+ this.bufferSegments = bufferSegments;
+ }
+
public byte[] getData() {
if (buffer == null) {
return null;
@@ -75,10 +79,7 @@ public class ShuffleDataResult {
if (buffer == null) {
return 0;
}
- if (buffer.nioByteBuffer().hasArray()) {
- return buffer.nioByteBuffer().array().length;
- }
- return buffer.nioByteBuffer().remaining();
+ return buffer.size();
}
public ByteBuf getDataBuf() {
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
b/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
index 2a686c44f..71bb3df39 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
@@ -75,4 +75,8 @@ public class ShuffleIndexResult {
this.buffer.release();
}
}
+
+ public ManagedBuffer getManagedBuffer() {
+ return buffer;
+ }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index c516f7470..86eb1d950 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -17,6 +17,7 @@
package org.apache.uniffle.common.config;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.netty.IOMode;
@@ -135,4 +136,10 @@ public class RssClientConf {
.stringType()
.defaultValue("14m")
.withDescription("The max data size read from storage");
+
+ public static final ConfigOption<ClientType> RSS_CLIENT_TYPE =
+ ConfigOptions.key("rss.client.type")
+ .enumType(ClientType.class)
+ .defaultValue(ClientType.GRPC)
+ .withDescription("Supports GRPC, GRPC_NETTY");
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java
b/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java
index 53366dc3c..493fbf8d6 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java
@@ -19,5 +19,5 @@ package org.apache.uniffle.common.netty;
public interface FrameDecoder {
String HANDLER_NAME = "FrameDecoder";
- int HEADER_SIZE = Integer.BYTES + Byte.BYTES;
+ int HEADER_SIZE = Integer.BYTES + Byte.BYTES + Integer.BYTES;
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
index 5221cc428..d2947b3bc 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
@@ -17,50 +17,81 @@
package org.apache.uniffle.common.netty;
+import java.util.List;
+
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelOutboundHandlerAdapter;
-import io.netty.channel.ChannelPromise;
+import io.netty.handler.codec.MessageToMessageEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.netty.protocol.Message;
-import org.apache.uniffle.common.netty.protocol.Transferable;
+import org.apache.uniffle.common.netty.protocol.MessageWithHeader;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
/**
* Encoder used by the server side to encode server-to-client responses. This
encoder is stateless
- * so it is safe to be shared by multiple threads. The content of encode
consists of two parts,
- * header and message body. The encoded binary stream contains encodeLength (4
bytes), messageType
- * (1 byte) and messageBody (encodeLength bytes).
+ * so it is safe to be shared by multiple threads.
*/
@ChannelHandler.Sharable
-public class MessageEncoder extends ChannelOutboundHandlerAdapter {
+public final class MessageEncoder extends MessageToMessageEncoder<Message> {
- private static final Logger LOG =
LoggerFactory.getLogger(MessageEncoder.class);
+ private static final Logger logger =
LoggerFactory.getLogger(MessageEncoder.class);
public static final MessageEncoder INSTANCE = new MessageEncoder();
private MessageEncoder() {}
+ /**
+ * * Encodes a Message by invoking its encode() method. For non-data
messages, we will add one
+ * ByteBuf to 'out' containing the total frame length, the message type, and
the message itself.
+ * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer
corresponding to the
+ * data to 'out', in order to enable zero-copy transfer.
+ */
@Override
- public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise
promise) {
- Message message = (Message) msg;
- int encodeLength = message.encodedLength();
- ByteBuf byteBuf = ctx.alloc().buffer(FrameDecoder.HEADER_SIZE +
encodeLength);
- try {
- byteBuf.writeInt(encodeLength);
- byteBuf.writeByte(message.type().id());
- message.encode(byteBuf);
- } catch (Exception e) {
- LOG.error("Unexpected exception during process encode!", e);
- byteBuf.release();
- throw e;
+ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out)
throws Exception {
+ Object body = null;
+ int bodyLength = 0;
+
+ // If the message has a body, take it out to enable zero-copy transfer for
the payload.
+ if (in.body() != null) {
+ try {
+ bodyLength = (int) in.body().size();
+ body = in.body().convertToNetty();
+ } catch (Exception e) {
+ in.body().release();
+ if (in instanceof RpcResponse) {
+ RpcResponse resp = (RpcResponse) in;
+ // Re-encode this message as a failure response.
+ String error = e.getMessage() != null ? e.getMessage() : "null";
+ logger.error(
+ String.format("Error processing %s for client %s", in,
ctx.channel().remoteAddress()),
+ e);
+ encode(ctx, resp.createFailureResponse(error), out);
+ } else {
+ throw e;
+ }
+ return;
+ }
}
- ctx.writeAndFlush(byteBuf);
- // do transferTo send data after encode buffer send.
- if (message instanceof Transferable) {
- ((Transferable) message).transferTo(ctx.channel());
+
+ Message.Type msgType = in.type();
+ // message size, message type size, body size, message encoded length
+ int headerLength = Integer.BYTES + msgType.encodedLength() + Integer.BYTES
+ in.encodedLength();
+ ByteBuf header = ctx.alloc().heapBuffer(headerLength);
+ header.writeInt(in.encodedLength());
+ msgType.encode(header);
+ header.writeInt(bodyLength);
+ in.encode(header);
+ assert header.writableBytes() == 0;
+
+ if (body != null) {
+ // We transfer ownership of the reference on in.body() to
MessageWithHeader.
+ // This reference will be freed when MessageWithHeader.deallocate() is
called.
+ out.add(new MessageWithHeader(in.body(), header, body, bodyLength));
+ } else {
+ out.add(header);
}
}
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
b/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
index d9fd734cd..4a7b8ab4b 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
@@ -45,6 +45,7 @@ import org.apache.uniffle.common.netty.protocol.Message;
*/
public class TransportFrameDecoder extends ChannelInboundHandlerAdapter
implements FrameDecoder {
private int msgSize = -1;
+ private int bodySize = -1;
private Message.Type curType = Message.Type.UNKNOWN_TYPE;
private ByteBuf headerBuf = Unpooled.buffer(HEADER_SIZE, HEADER_SIZE);
private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
@@ -66,10 +67,10 @@ public class TransportFrameDecoder extends
ChannelInboundHandlerAdapter implemen
if (frame == null) {
break;
}
- // todo: An exception may be thrown during the decoding process, causing
frame.release() to
- // fail to be called
Message msg = Message.decode(curType, frame);
- frame.release();
+ if (msg.body() == null) {
+ frame.release();
+ }
ctx.fireChannelRead(msg);
clear();
}
@@ -78,6 +79,7 @@ public class TransportFrameDecoder extends
ChannelInboundHandlerAdapter implemen
private void clear() {
curType = Message.Type.UNKNOWN_TYPE;
msgSize = -1;
+ bodySize = -1;
headerBuf.clear();
}
@@ -94,7 +96,8 @@ public class TransportFrameDecoder extends
ChannelInboundHandlerAdapter implemen
if (first.readableBytes() >= HEADER_SIZE) {
msgSize = first.readInt();
curType = Message.Type.decode(first);
- nextFrameSize = msgSize;
+ bodySize = first.readInt();
+ nextFrameSize = msgSize + bodySize;
totalSize -= HEADER_SIZE;
if (!first.isReadable()) {
buffers.removeFirst().release();
@@ -113,7 +116,8 @@ public class TransportFrameDecoder extends
ChannelInboundHandlerAdapter implemen
msgSize = headerBuf.readInt();
curType = Message.Type.decode(headerBuf);
- nextFrameSize = msgSize;
+ bodySize = headerBuf.readInt();
+ nextFrameSize = msgSize + bodySize;
totalSize -= HEADER_SIZE;
return nextFrameSize;
}
@@ -126,7 +130,6 @@ public class TransportFrameDecoder extends
ChannelInboundHandlerAdapter implemen
// Reset size for next frame.
nextFrameSize = UNKNOWN_FRAME_SIZE;
-
Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame:
%s", frameSize);
Preconditions.checkArgument(frameSize > 0, "Frame length should be
positive: %s", frameSize);
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java
index a4975bde4..a83492f79 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java
@@ -91,6 +91,11 @@ public class FileSegmentManagedBuffer extends ManagedBuffer {
}
}
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
@Override
public ManagedBuffer release() {
return this;
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/ManagedBuffer.java
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/ManagedBuffer.java
index ed15640e9..844ed4c13 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/ManagedBuffer.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/ManagedBuffer.java
@@ -29,6 +29,8 @@ public abstract class ManagedBuffer {
public abstract ByteBuffer nioByteBuffer();
+ public abstract ManagedBuffer retain();
+
public abstract ManagedBuffer release();
/**
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
index 4cc6686a5..53286429d 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
@@ -24,6 +24,9 @@ import io.netty.buffer.Unpooled;
public class NettyManagedBuffer extends ManagedBuffer {
+ public static final NettyManagedBuffer EMPTY_BUFFER =
+ new NettyManagedBuffer(Unpooled.buffer(0, 0));
+
private ByteBuf buf;
public NettyManagedBuffer(ByteBuf byteBuf) {
@@ -45,6 +48,12 @@ public class NettyManagedBuffer extends ManagedBuffer {
return buf.nioBuffer();
}
+ @Override
+ public ManagedBuffer retain() {
+ buf.retain();
+ return this;
+ }
+
@Override
public ManagedBuffer release() {
buf.release();
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
index f1ab3adc9..f7c9ccc53 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
@@ -61,9 +61,11 @@ public class TransportResponseHandler extends
MessageHandler<RpcResponse> {
public void handle(RpcResponse message) throws Exception {
RpcResponseCallback listener =
outstandingRpcRequests.get(message.getRequestId());
if (listener == null) {
- logger.warn(
- "Ignoring response from {} since it is not outstanding",
- NettyUtils.getRemoteAddress(channel));
+ logger.error(
+ "Ignoring response from {} since it is not outstanding, {} {}",
+ NettyUtils.getRemoteAddress(channel),
+ message.type(),
+ message.getRequestId());
} else {
listener.onSuccess(message);
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/AbstractFileRegion.java
similarity index 59%
copy from
common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
copy to
common/src/main/java/org/apache/uniffle/common/netty/protocol/AbstractFileRegion.java
index 4cc6686a5..c716ba7d8 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/AbstractFileRegion.java
@@ -15,44 +15,39 @@
* limitations under the License.
*/
-package org.apache.uniffle.common.netty.buffer;
+package org.apache.uniffle.common.netty.protocol;
-import java.nio.ByteBuffer;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-
-public class NettyManagedBuffer extends ManagedBuffer {
-
- private ByteBuf buf;
-
- public NettyManagedBuffer(ByteBuf byteBuf) {
- this.buf = byteBuf;
- }
+public abstract class AbstractFileRegion extends AbstractReferenceCounted
implements FileRegion {
@Override
- public int size() {
- return buf.readableBytes();
+ @SuppressWarnings("deprecation")
+ public final long transfered() {
+ return transferred();
}
@Override
- public ByteBuf byteBuf() {
- return Unpooled.wrappedBuffer(this.nioByteBuffer());
+ public AbstractFileRegion retain() {
+ super.retain();
+ return this;
}
@Override
- public ByteBuffer nioByteBuffer() {
- return buf.nioBuffer();
+ public AbstractFileRegion retain(int increment) {
+ super.retain(increment);
+ return this;
}
@Override
- public ManagedBuffer release() {
- buf.release();
+ public AbstractFileRegion touch() {
+ super.touch();
return this;
}
@Override
- public Object convertToNetty() {
- return buf.duplicate().retain();
+ public AbstractFileRegion touch(Object o) {
+ return this;
}
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
index db6c2a000..b8c687ce7 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
@@ -23,12 +23,12 @@ import java.util.Map;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
import org.apache.uniffle.common.BufferSegment;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.ByteBufUtils;
+import org.apache.uniffle.common.util.NettyUtils;
public class Decoders {
public static ShuffleServerInfo decodeShuffleServerInfo(ByteBuf byteBuf) {
@@ -47,7 +47,7 @@ public class Decoders {
long crc = byteBuf.readLong();
long taskAttemptId = byteBuf.readLong();
int dataLength = byteBuf.readInt();
- ByteBuf data = Unpooled.directBuffer(dataLength);
+ ByteBuf data =
NettyUtils.getNettyBufferAllocator().directBuffer(dataLength);
data.writeBytes(byteBuf, dataLength);
int lengthOfShuffleServers = byteBuf.readInt();
List<ShuffleServerInfo> serverInfos = Lists.newArrayList();
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
index e1b7af7c4..1a078b7a0 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
@@ -18,65 +18,34 @@
package org.apache.uniffle.common.netty.protocol;
import io.netty.buffer.ByteBuf;
-import io.netty.channel.Channel;
-import org.apache.uniffle.common.netty.buffer.FileSegmentManagedBuffer;
import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.ByteBufUtils;
-public class GetLocalShuffleDataResponse extends RpcResponse implements
Transferable {
-
- private ManagedBuffer buffer;
+public class GetLocalShuffleDataResponse extends RpcResponse {
public GetLocalShuffleDataResponse(
long requestId, StatusCode statusCode, String retMessage, ManagedBuffer
data) {
- super(requestId, statusCode, retMessage);
- this.buffer = data;
- }
-
- @Override
- public int encodedLength() {
- return super.encodedLength() + Integer.BYTES + buffer.size();
- }
-
- @Override
- public void encode(ByteBuf buf) {
- super.encode(buf);
- if (buffer instanceof FileSegmentManagedBuffer) {
- buf.writeInt(buffer.size());
- } else {
- ByteBufUtils.copyByteBuf(buffer.byteBuf(), buf);
- buffer.release();
- }
+ super(requestId, statusCode, retMessage, data);
}
- public static GetLocalShuffleDataResponse decode(ByteBuf byteBuf) {
+ public static GetLocalShuffleDataResponse decode(ByteBuf byteBuf, boolean
decodeBody) {
long requestId = byteBuf.readLong();
StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
- ByteBuf data = ByteBufUtils.readSlice(byteBuf);
- return new GetLocalShuffleDataResponse(
- requestId, statusCode, retMessage, new NettyManagedBuffer(data));
+ if (decodeBody) {
+ NettyManagedBuffer nettyManagedBuffer = new NettyManagedBuffer(byteBuf);
+ return new GetLocalShuffleDataResponse(requestId, statusCode,
retMessage, nettyManagedBuffer);
+ } else {
+ return new GetLocalShuffleDataResponse(
+ requestId, statusCode, retMessage, NettyManagedBuffer.EMPTY_BUFFER);
+ }
}
@Override
public Type type() {
return Type.GET_LOCAL_SHUFFLE_DATA_RESPONSE;
}
-
- public ManagedBuffer getBuffer() {
- return buffer;
- }
-
- public Object getData() {
- return buffer.convertToNetty();
- }
-
- @Override
- public void transferTo(Channel channel) {
- channel.write(buffer.convertToNetty());
- buffer.release();
- }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
index d6179086b..f97373805 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
@@ -19,16 +19,14 @@ package org.apache.uniffle.common.netty.protocol;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
-import io.netty.channel.Channel;
-import org.apache.uniffle.common.netty.buffer.FileSegmentManagedBuffer;
import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.ByteBufUtils;
-public class GetLocalShuffleIndexResponse extends RpcResponse implements
Transferable {
- private ManagedBuffer buffer;
+public class GetLocalShuffleIndexResponse extends RpcResponse {
+
private long fileLength;
public GetLocalShuffleIndexResponse(
@@ -54,38 +52,36 @@ public class GetLocalShuffleIndexResponse extends
RpcResponse implements Transfe
long requestId,
StatusCode statusCode,
String retMessage,
- ManagedBuffer indexData,
+ ManagedBuffer managedBuffer,
long fileLength) {
- super(requestId, statusCode, retMessage);
- this.buffer = indexData;
+ super(requestId, statusCode, retMessage, managedBuffer);
this.fileLength = fileLength;
}
@Override
public int encodedLength() {
- return super.encodedLength() + Integer.BYTES + buffer.size() + Long.BYTES;
+ return super.encodedLength() + Long.BYTES;
}
@Override
public void encode(ByteBuf buf) {
super.encode(buf);
- if (buffer instanceof FileSegmentManagedBuffer) {
- buf.writeInt(buffer.size());
- } else {
- ByteBufUtils.copyByteBuf(buffer.byteBuf(), buf);
- buffer.release();
- }
buf.writeLong(fileLength);
}
- public static GetLocalShuffleIndexResponse decode(ByteBuf byteBuf) {
+ public static GetLocalShuffleIndexResponse decode(ByteBuf byteBuf, boolean
decodeBody) {
long requestId = byteBuf.readLong();
StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
- ByteBuf indexData = ByteBufUtils.readSlice(byteBuf);
long fileLength = byteBuf.readLong();
- return new GetLocalShuffleIndexResponse(
- requestId, statusCode, retMessage, indexData, fileLength);
+ if (decodeBody) {
+ NettyManagedBuffer nettyManagedBuffer = new NettyManagedBuffer(byteBuf);
+ return new GetLocalShuffleIndexResponse(
+ requestId, statusCode, retMessage, nettyManagedBuffer, fileLength);
+ } else {
+ return new GetLocalShuffleIndexResponse(
+ requestId, statusCode, retMessage, NettyManagedBuffer.EMPTY_BUFFER,
fileLength);
+ }
}
@Override
@@ -93,17 +89,7 @@ public class GetLocalShuffleIndexResponse extends
RpcResponse implements Transfe
return Type.GET_LOCAL_SHUFFLE_INDEX_RESPONSE;
}
- public ByteBuf getIndexData() {
- return buffer.byteBuf();
- }
-
public long getFileLength() {
return fileLength;
}
-
- @Override
- public void transferTo(Channel channel) {
- channel.write(buffer.convertToNetty());
- buffer.release();
- }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
index e766b28b1..d358cf7cd 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
@@ -67,7 +67,7 @@ public class GetMemoryShuffleDataRequest extends
RequestMessage {
+ ByteBufUtils.encodedLength(appId)
+ 4 * Integer.BYTES
+ 2 * Long.BYTES
- + expectedTaskIdsBitmap.serializedSizeInBytes());
+ + (expectedTaskIdsBitmap == null ? 0L :
expectedTaskIdsBitmap.serializedSizeInBytes()));
}
@Override
@@ -79,9 +79,13 @@ public class GetMemoryShuffleDataRequest extends
RequestMessage {
buf.writeLong(lastBlockId);
buf.writeInt(readBufferSize);
buf.writeLong(timestamp);
- buf.writeInt((int) expectedTaskIdsBitmap.serializedSizeInBytes());
try {
- buf.writeBytes(RssUtils.serializeBitMap(expectedTaskIdsBitmap));
+ if (expectedTaskIdsBitmap != null) {
+ buf.writeInt((int) expectedTaskIdsBitmap.serializedSizeInBytes());
+ buf.writeBytes(RssUtils.serializeBitMap(expectedTaskIdsBitmap));
+ } else {
+ buf.writeInt(-1);
+ }
} catch (IOException ioException) {
throw new EncodeException(
"serializeBitMap failed while encode GetMemoryShuffleDataRequest!",
ioException);
@@ -97,9 +101,11 @@ public class GetMemoryShuffleDataRequest extends
RequestMessage {
int readBufferSize = byteBuf.readInt();
long timestamp = byteBuf.readLong();
byte[] bytes = ByteBufUtils.readByteArray(byteBuf);
- Roaring64NavigableMap expectedTaskIdsBitmap;
+ Roaring64NavigableMap expectedTaskIdsBitmap = null;
try {
- expectedTaskIdsBitmap = RssUtils.deserializeBitMap(bytes);
+ if (bytes != null) {
+ expectedTaskIdsBitmap = RssUtils.deserializeBitMap(bytes);
+ }
} catch (IOException ioException) {
throw new DecodeException(
"serializeBitMap failed while decode GetMemoryShuffleDataRequest!",
ioException);
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
index c77f64a96..c9b9f7ed5 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
@@ -23,12 +23,13 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.ByteBufUtils;
public class GetMemoryShuffleDataResponse extends RpcResponse {
private List<BufferSegment> bufferSegments;
- private ByteBuf data;
public GetMemoryShuffleDataResponse(
long requestId, StatusCode statusCode, List<BufferSegment>
bufferSegments, byte[] data) {
@@ -50,35 +51,43 @@ public class GetMemoryShuffleDataResponse extends
RpcResponse {
String retMessage,
List<BufferSegment> bufferSegments,
ByteBuf data) {
- super(requestId, statusCode, retMessage);
+ this(requestId, statusCode, retMessage, bufferSegments, new
NettyManagedBuffer(data));
+ }
+
+ public GetMemoryShuffleDataResponse(
+ long requestId,
+ StatusCode statusCode,
+ String retMessage,
+ List<BufferSegment> bufferSegments,
+ ManagedBuffer managedBuffer) {
+ super(requestId, statusCode, retMessage, managedBuffer);
this.bufferSegments = bufferSegments;
- this.data = data;
}
@Override
public int encodedLength() {
- return super.encodedLength()
- + Encoders.encodeLengthOfBufferSegments(bufferSegments)
- + Integer.BYTES
- + data.readableBytes();
+ return super.encodedLength() +
Encoders.encodeLengthOfBufferSegments(bufferSegments);
}
@Override
public void encode(ByteBuf buf) {
super.encode(buf);
Encoders.encodeBufferSegments(bufferSegments, buf);
- ByteBufUtils.copyByteBuf(data, buf);
- data.release();
}
- public static GetMemoryShuffleDataResponse decode(ByteBuf byteBuf) {
+ public static GetMemoryShuffleDataResponse decode(ByteBuf byteBuf, boolean
decodeBody) {
long requestId = byteBuf.readLong();
StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
List<BufferSegment> bufferSegments =
Decoders.decodeBufferSegments(byteBuf);
- ByteBuf data = ByteBufUtils.readSlice(byteBuf);
- return new GetMemoryShuffleDataResponse(
- requestId, statusCode, retMessage, bufferSegments, data);
+ if (decodeBody) {
+ NettyManagedBuffer nettyManagedBuffer = new NettyManagedBuffer(byteBuf);
+ return new GetMemoryShuffleDataResponse(
+ requestId, statusCode, retMessage, bufferSegments,
nettyManagedBuffer);
+ } else {
+ return new GetMemoryShuffleDataResponse(
+ requestId, statusCode, retMessage, bufferSegments,
NettyManagedBuffer.EMPTY_BUFFER);
+ }
}
@Override
@@ -89,8 +98,4 @@ public class GetMemoryShuffleDataResponse extends RpcResponse
{
public List<BufferSegment> getBufferSegments() {
return bufferSegments;
}
-
- public ByteBuf getData() {
- return data;
- }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
index c019099fd..e941a8d1b 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
@@ -19,10 +19,26 @@ package org.apache.uniffle.common.netty.protocol;
import io.netty.buffer.ByteBuf;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+
public abstract class Message implements Encodable {
+ private ManagedBuffer body;
+
+ protected Message() {
+ this(null);
+ }
+
+ protected Message(ManagedBuffer body) {
+ this.body = body;
+ }
+
public abstract Type type();
+ public ManagedBuffer body() {
+ return body;
+ }
+
public enum Type implements Encodable {
UNKNOWN_TYPE(-1),
RPC_RESPONSE(0),
@@ -124,9 +140,21 @@ public abstract class Message implements Encodable {
public static Message decode(Type msgType, ByteBuf in) {
switch (msgType) {
case RPC_RESPONSE:
- return RpcResponse.decode(in);
+ return RpcResponse.decode(in, false);
case SEND_SHUFFLE_DATA_REQUEST:
return SendShuffleDataRequest.decode(in);
+ case GET_LOCAL_SHUFFLE_DATA_REQUEST:
+ return GetLocalShuffleDataRequest.decode(in);
+ case GET_LOCAL_SHUFFLE_DATA_RESPONSE:
+ return GetLocalShuffleDataResponse.decode(in, true);
+ case GET_LOCAL_SHUFFLE_INDEX_REQUEST:
+ return GetLocalShuffleIndexRequest.decode(in);
+ case GET_LOCAL_SHUFFLE_INDEX_RESPONSE:
+ return GetLocalShuffleIndexResponse.decode(in, true);
+ case GET_MEMORY_SHUFFLE_DATA_REQUEST:
+ return GetMemoryShuffleDataRequest.decode(in);
+ case GET_MEMORY_SHUFFLE_DATA_RESPONSE:
+ return GetMemoryShuffleDataResponse.decode(in, true);
default:
throw new IllegalArgumentException("Unexpected message type: " +
msgType);
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/MessageWithHeader.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/MessageWithHeader.java
new file mode 100644
index 000000000..a53f013d0
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/MessageWithHeader.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import javax.annotation.Nullable;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.FileRegion;
+import io.netty.util.ReferenceCountUtil;
+
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+
+/**
+ * A wrapper message that holds two separate pieces (a header and a body).
+ *
+ * <p>The header must be a ByteBuf, while the body can be a ByteBuf or a
FileRegion.
+ */
+public class MessageWithHeader extends AbstractFileRegion {
+
+ @Nullable private final ManagedBuffer managedBuffer;
+ private final ByteBuf header;
+ private final int headerLength;
+ private final Object body;
+ private final long bodyLength;
+ private long totalBytesTransferred;
+
+ /**
+ * When the write buffer size is larger than this limit, I/O will be done in
chunks of this size.
+ * The size should not be too large as it will waste underlying memory copy.
e.g. If network
+ * available buffer is smaller than this limit, the data cannot be sent
within one single write
+ * operation while it still will make memory copy with this size.
+ */
+ private static final int NIO_BUFFER_LIMIT = 256 * 1024;
+
+ /**
+ * Construct a new MessageWithHeader.
+ *
+ * @param managedBuffer the {@link ManagedBuffer} that the message body came
from. This needs to
+ * be passed in so that the buffer can be freed when this message is
deallocated. Ownership of
+ * the caller's reference to this buffer is transferred to this class,
so if the caller wants
+ * to continue to use the ManagedBuffer in other messages then they will
need to call retain()
+ * on it before passing it to this constructor. This may be null if and
only if `body` is a
+ * {@link FileRegion}.
+ * @param header the message header.
+ * @param body the message body. Must be either a {@link ByteBuf} or a
{@link FileRegion}.
+ * @param bodyLength the length of the message body, in bytes.
+ */
+ public MessageWithHeader(
+ @Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long
bodyLength) {
+ Preconditions.checkArgument(
+ body instanceof ByteBuf || body instanceof FileRegion,
+ "Body must be a ByteBuf or a FileRegion.");
+ this.managedBuffer = managedBuffer;
+ this.header = header;
+ this.headerLength = header.readableBytes();
+ this.body = body;
+ this.bodyLength = bodyLength;
+ }
+
+ @Override
+ public long count() {
+ return headerLength + bodyLength;
+ }
+
+ @Override
+ public long position() {
+ return 0;
+ }
+
+ @Override
+ public long transferred() {
+ return totalBytesTransferred;
+ }
+
+ /**
+ * This code is more complicated than you would think because we might
require multiple transferTo
+ * invocations in order to transfer a single MessageWithHeader to avoid busy
waiting.
+ *
+ * <p>The contract is that the caller will ensure position is properly set
to the total number of
+ * bytes transferred so far (i.e. value returned by transferred()).
+ */
+ @Override
+ public long transferTo(final WritableByteChannel target, final long
position) throws IOException {
+ Preconditions.checkArgument(position == totalBytesTransferred, "Invalid
position.");
+ // Bytes written for header in this call.
+ long writtenHeader = 0;
+ if (header.readableBytes() > 0) {
+ writtenHeader = copyByteBuf(header, target);
+ totalBytesTransferred += writtenHeader;
+ if (header.readableBytes() > 0) {
+ return writtenHeader;
+ }
+ }
+
+ // Bytes written for body in this call.
+ long writtenBody = 0;
+ if (body instanceof FileRegion) {
+ writtenBody = ((FileRegion) body).transferTo(target,
totalBytesTransferred - headerLength);
+ } else if (body instanceof ByteBuf) {
+ writtenBody = copyByteBuf((ByteBuf) body, target);
+ }
+ totalBytesTransferred += writtenBody;
+
+ return writtenHeader + writtenBody;
+ }
+
+ @Override
+ protected void deallocate() {
+ header.release();
+ ReferenceCountUtil.release(body);
+ if (managedBuffer != null) {
+ managedBuffer.release();
+ }
+ }
+
+ private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws
IOException {
+ // SPARK-24578: cap the sub-region's size of returned nio buffer to
improve the performance
+ // for the case that the passed-in buffer has too many components.
+ int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT);
+ // If the ByteBuf holds more then one ByteBuffer we should better call
nioBuffers(...)
+ // to eliminate extra memory copies.
+ int written = 0;
+ if (buf.nioBufferCount() == 1) {
+ ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length);
+ written = target.write(buffer);
+ } else {
+ ByteBuffer[] buffers = buf.nioBuffers(buf.readerIndex(), length);
+ for (ByteBuffer buffer : buffers) {
+ int remaining = buffer.remaining();
+ int w = target.write(buffer);
+ written += w;
+ if (w < remaining) {
+ // Could not write all, we need to break now.
+ break;
+ }
+ }
+ }
+ buf.skipBytes(written);
+ return written;
+ }
+
+ @Override
+ public MessageWithHeader touch(Object o) {
+ super.touch(o);
+ header.touch(o);
+ ReferenceCountUtil.touch(body, o);
+ return this;
+ }
+
+ @Override
+ public MessageWithHeader retain(int increment) {
+ super.retain(increment);
+ header.retain(increment);
+ ReferenceCountUtil.retain(body, increment);
+ if (managedBuffer != null) {
+ for (int i = 0; i < increment; i++) {
+ managedBuffer.retain();
+ }
+ }
+ return this;
+ }
+
+ @Override
+ public boolean release(int decrement) {
+ header.release(decrement);
+ ReferenceCountUtil.release(body, decrement);
+ if (managedBuffer != null) {
+ for (int i = 0; i < decrement; i++) {
+ managedBuffer.release();
+ }
+ }
+ return super.release(decrement);
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
index 695484dbe..cfa55287c 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
@@ -17,12 +17,18 @@
package org.apache.uniffle.common.netty.protocol;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+
public abstract class RequestMessage extends Message {
private final long requestId;
public static final int REQUEST_ID_ENCODE_LENGTH = Long.BYTES;
public RequestMessage(long requestId) {
- super();
+ this(requestId, null);
+ }
+
+ public RequestMessage(long requestId, ManagedBuffer managedBuffer) {
+ super(managedBuffer);
this.requestId = requestId;
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Transferable.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/ResponseMessage.java
similarity index 71%
rename from
common/src/main/java/org/apache/uniffle/common/netty/protocol/Transferable.java
rename to
common/src/main/java/org/apache/uniffle/common/netty/protocol/ResponseMessage.java
index d0fc8ac9d..36e3f3c2e 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Transferable.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/ResponseMessage.java
@@ -17,9 +17,18 @@
package org.apache.uniffle.common.netty.protocol;
-import io.netty.channel.Channel;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
-public interface Transferable {
+public abstract class ResponseMessage extends Message {
+ public ResponseMessage() {
+ super();
+ }
- void transferTo(Channel channel);
+ public ResponseMessage(ManagedBuffer buffer) {
+ super(buffer);
+ }
+
+ public ResponseMessage createFailureResponse(String error) {
+ throw new UnsupportedOperationException();
+ }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
index b5a35b6af..d56a24d9e 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
@@ -19,19 +19,27 @@ package org.apache.uniffle.common.netty.protocol;
import io.netty.buffer.ByteBuf;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.ByteBufUtils;
-public class RpcResponse extends Message {
+public class RpcResponse extends ResponseMessage {
private long requestId;
private StatusCode statusCode;
private String retMessage;
- public RpcResponse(long requestId, StatusCode statusCode) {
- this(requestId, statusCode, null);
+ public RpcResponse(long requestId, StatusCode statusCode, String retMessage)
{
+ this(requestId, statusCode, retMessage, null);
}
- public RpcResponse(long requestId, StatusCode statusCode, String retMessage)
{
+ public RpcResponse(long requestId, StatusCode statusCode, ManagedBuffer
message) {
+ this(requestId, statusCode, null, message);
+ }
+
+ public RpcResponse(
+ long requestId, StatusCode statusCode, String retMessage, ManagedBuffer
message) {
+ super(message);
this.requestId = requestId;
this.statusCode = statusCode;
this.retMessage = retMessage;
@@ -70,11 +78,16 @@ public class RpcResponse extends Message {
ByteBufUtils.writeLengthAndString(buf, retMessage);
}
- public static RpcResponse decode(ByteBuf buf) {
- long requestId = buf.readLong();
- StatusCode statusCode = StatusCode.fromCode(buf.readInt());
- String retMessage = ByteBufUtils.readLengthAndString(buf);
- return new RpcResponse(requestId, statusCode, retMessage);
+ public static RpcResponse decode(ByteBuf byteBuf, boolean decodeBody) {
+ long requestId = byteBuf.readLong();
+ StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
+ String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
+ if (decodeBody) {
+ NettyManagedBuffer nettyManagedBuffer = new NettyManagedBuffer(byteBuf);
+ return new RpcResponse(requestId, statusCode, retMessage,
nettyManagedBuffer);
+ } else {
+ return new RpcResponse(requestId, statusCode, retMessage,
NettyManagedBuffer.EMPTY_BUFFER);
+ }
}
public long getRequestId() {
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
index 7cb5eab6f..84f77c34d 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
@@ -67,6 +67,9 @@ public class ByteBufUtils {
public static final byte[] readByteArray(ByteBuf byteBuf) {
int length = byteBuf.readInt();
+ if (length < 0) {
+ return null;
+ }
byte[] data = new byte[length];
byteBuf.readBytes(data);
return data;
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
index ac0946034..5f1c87ced 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
@@ -112,4 +112,24 @@ public class NettyUtils {
public static String getServerConnectionInfo(Channel channel) {
return String.format("[%s -> %s]", channel.localAddress(),
channel.remoteAddress());
}
+
+ private static class AllocatorHolder {
+ private static final PooledByteBufAllocator INSTANCE = createAllocator();
+ }
+
+ public static PooledByteBufAllocator getNettyBufferAllocator() {
+ return AllocatorHolder.INSTANCE;
+ }
+
+ private static PooledByteBufAllocator createAllocator() {
+ return new PooledByteBufAllocator(
+ true,
+ PooledByteBufAllocator.defaultNumHeapArena(),
+ PooledByteBufAllocator.defaultNumDirectArena(),
+ PooledByteBufAllocator.defaultPageSize(),
+ PooledByteBufAllocator.defaultMaxOrder(),
+ 0,
+ 0,
+ PooledByteBufAllocator.defaultUseCacheForAllThreads());
+ }
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
index 48ff31ffa..a370828c3 100644
---
a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
+++
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
@@ -126,7 +126,7 @@ public class NettyProtocolTest {
ByteBuf byteBuf = Unpooled.buffer(encodeLength);
rpcResponse.encode(byteBuf);
assertEquals(byteBuf.readableBytes(), encodeLength);
- RpcResponse rpcResponse1 = RpcResponse.decode(byteBuf);
+ RpcResponse rpcResponse1 = RpcResponse.decode(byteBuf, true);
assertEquals(rpcResponse.getRequestId(), rpcResponse1.getRequestId());
assertEquals(rpcResponse.getRetMessage(), rpcResponse1.getRetMessage());
assertEquals(rpcResponse.getStatusCode(), rpcResponse1.getStatusCode());
@@ -177,7 +177,7 @@ public class NettyProtocolTest {
ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
getLocalShuffleDataResponse.encode(byteBuf);
GetLocalShuffleDataResponse getLocalShuffleDataResponse1 =
- GetLocalShuffleDataResponse.decode(byteBuf);
+ GetLocalShuffleDataResponse.decode(byteBuf, true);
assertEquals(
getLocalShuffleDataResponse.getRequestId(),
getLocalShuffleDataResponse1.getRequestId());
@@ -185,7 +185,6 @@ public class NettyProtocolTest {
getLocalShuffleDataResponse.getRetMessage(),
getLocalShuffleDataResponse1.getRetMessage());
assertEquals(
getLocalShuffleDataResponse.getStatusCode(),
getLocalShuffleDataResponse1.getStatusCode());
- assertEquals(getLocalShuffleDataResponse.getData(),
getLocalShuffleDataResponse1.getData());
}
@Test
@@ -224,7 +223,7 @@ public class NettyProtocolTest {
ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
getLocalShuffleIndexResponse.encode(byteBuf);
GetLocalShuffleIndexResponse getLocalShuffleIndexResponse1 =
- GetLocalShuffleIndexResponse.decode(byteBuf);
+ GetLocalShuffleIndexResponse.decode(byteBuf, true);
assertEquals(
getLocalShuffleIndexResponse.getRequestId(),
getLocalShuffleIndexResponse1.getRequestId());
@@ -234,11 +233,6 @@ public class NettyProtocolTest {
assertEquals(
getLocalShuffleIndexResponse.getRetMessage(),
getLocalShuffleIndexResponse1.getRetMessage());
- assertEquals(
- getLocalShuffleIndexResponse.getFileLength(),
- getLocalShuffleIndexResponse1.getFileLength());
- assertEquals(
- getLocalShuffleIndexResponse.getIndexData(),
getLocalShuffleIndexResponse1.getIndexData());
}
@Test
@@ -287,15 +281,13 @@ public class NettyProtocolTest {
ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
getMemoryShuffleDataResponse.encode(byteBuf);
GetMemoryShuffleDataResponse getMemoryShuffleDataResponse1 =
- GetMemoryShuffleDataResponse.decode(byteBuf);
+ GetMemoryShuffleDataResponse.decode(byteBuf, true);
assertEquals(
getMemoryShuffleDataResponse.getRequestId(),
getMemoryShuffleDataResponse1.getRequestId());
assertEquals(
getMemoryShuffleDataResponse.getStatusCode(),
getMemoryShuffleDataResponse1.getStatusCode());
- assertTrue(
-
getMemoryShuffleDataResponse.getData().equals(getMemoryShuffleDataResponse1.getData()));
for (int i = 0; i < 2; i++) {
assertEquals(
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
index 85c410908..bb9dd4450 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
@@ -37,6 +37,7 @@ import
org.apache.uniffle.client.request.RssRegisterShuffleRequest;
import org.apache.uniffle.client.request.RssSendCommitRequest;
import org.apache.uniffle.client.request.RssSendShuffleDataRequest;
import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
@@ -241,6 +242,7 @@ public class ShuffleServerFaultToleranceTest extends
ShuffleReadWriteBase {
request.setDistributionType(ShuffleDataDistributionType.NORMAL);
Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0);
request.setExpectTaskIds(taskIdBitmap);
+ request.setClientType(ClientType.GRPC);
return request;
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index fa9a0520b..e36d18528 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -29,6 +29,7 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
import com.google.protobuf.UnsafeByteOperations;
+import io.netty.buffer.Unpooled;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -64,6 +65,7 @@ import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.exception.NotRetryException;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
@@ -810,7 +812,8 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
response =
new RssGetShuffleIndexResponse(
StatusCode.SUCCESS,
- ByteBuffer.wrap(rpcResponse.getIndexData().toByteArray()),
+ new NettyManagedBuffer(
+
Unpooled.wrappedBuffer(rpcResponse.getIndexData().toByteArray())),
rpcResponse.getDataFileLen());
break;
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
index cb8e1e623..9230a49e5 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -200,7 +200,7 @@ public class ShuffleServerGrpcNettyClient extends
ShuffleServerGrpcClient {
case SUCCESS:
return new RssGetInMemoryShuffleDataResponse(
StatusCode.SUCCESS,
- getMemoryShuffleDataResponse.getData().nioBuffer(),
+ getMemoryShuffleDataResponse.body(),
getMemoryShuffleDataResponse.getBufferSegments());
default:
String msg =
@@ -251,7 +251,7 @@ public class ShuffleServerGrpcNettyClient extends
ShuffleServerGrpcClient {
case SUCCESS:
return new RssGetShuffleIndexResponse(
StatusCode.SUCCESS,
- getLocalShuffleIndexResponse.getIndexData().nioBuffer(),
+ getLocalShuffleIndexResponse.body(),
getLocalShuffleIndexResponse.getFileLength());
default:
String msg =
@@ -305,7 +305,7 @@ public class ShuffleServerGrpcNettyClient extends
ShuffleServerGrpcClient {
switch (statusCode) {
case SUCCESS:
return new RssGetShuffleDataResponse(
- StatusCode.SUCCESS,
getLocalShuffleDataResponse.getBuffer().nioByteBuffer());
+ StatusCode.SUCCESS, getLocalShuffleDataResponse.body());
default:
String msg =
"Can't get shuffle data from "
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
index 1468d5104..bbf3738cb 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
@@ -20,22 +20,31 @@ package org.apache.uniffle.client.response;
import java.nio.ByteBuffer;
import java.util.List;
+import io.netty.buffer.Unpooled;
+
import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
import org.apache.uniffle.common.rpc.StatusCode;
public class RssGetInMemoryShuffleDataResponse extends ClientResponse {
- private final ByteBuffer data;
+ private final ManagedBuffer data;
private final List<BufferSegment> bufferSegments;
public RssGetInMemoryShuffleDataResponse(
StatusCode statusCode, ByteBuffer data, List<BufferSegment>
bufferSegments) {
+ this(statusCode, new NettyManagedBuffer(Unpooled.wrappedBuffer(data)),
bufferSegments);
+ }
+
+ public RssGetInMemoryShuffleDataResponse(
+ StatusCode statusCode, ManagedBuffer data, List<BufferSegment>
bufferSegments) {
super(statusCode);
this.bufferSegments = bufferSegments;
this.data = data;
}
- public ByteBuffer getData() {
+ public ManagedBuffer getData() {
return data;
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
index 7d9ca3927..474106fb0 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
@@ -19,18 +19,26 @@ package org.apache.uniffle.client.response;
import java.nio.ByteBuffer;
+import io.netty.buffer.Unpooled;
+
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
import org.apache.uniffle.common.rpc.StatusCode;
public class RssGetShuffleDataResponse extends ClientResponse {
- private final ByteBuffer shuffleData;
+ private final ManagedBuffer shuffleData;
public RssGetShuffleDataResponse(StatusCode statusCode, ByteBuffer data) {
+ this(statusCode, new NettyManagedBuffer(Unpooled.wrappedBuffer(data)));
+ }
+
+ public RssGetShuffleDataResponse(StatusCode statusCode, ManagedBuffer data) {
super(statusCode);
this.shuffleData = data;
}
- public ByteBuffer getShuffleData() {
+ public ManagedBuffer getShuffleData() {
return shuffleData;
}
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
index 29ba4320a..37a31652e 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
@@ -17,15 +17,14 @@
package org.apache.uniffle.client.response;
-import java.nio.ByteBuffer;
-
import org.apache.uniffle.common.ShuffleIndexResult;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
import org.apache.uniffle.common.rpc.StatusCode;
public class RssGetShuffleIndexResponse extends ClientResponse {
private final ShuffleIndexResult shuffleIndexResult;
- public RssGetShuffleIndexResponse(StatusCode statusCode, ByteBuffer data,
long dataFileLen) {
+ public RssGetShuffleIndexResponse(StatusCode statusCode, ManagedBuffer data,
long dataFileLen) {
super(statusCode);
this.shuffleIndexResult = new ShuffleIndexResult(data, dataFileLen);
}
diff --git
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
index 0736780b0..4a9a215a7 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
@@ -25,7 +25,6 @@ import java.util.function.Supplier;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
-import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
@@ -38,6 +37,7 @@ import org.apache.uniffle.common.ShufflePartitionedBlock;
import org.apache.uniffle.common.ShufflePartitionedData;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.common.util.NettyUtils;
import org.apache.uniffle.server.ShuffleDataFlushEvent;
import org.apache.uniffle.server.ShuffleFlushManager;
@@ -164,7 +164,9 @@ public class ShuffleBuffer {
if (!bufferSegments.isEmpty()) {
CompositeByteBuf byteBuf =
new CompositeByteBuf(
- ByteBufAllocator.DEFAULT, true,
Constants.COMPOSITE_BYTE_BUF_MAX_COMPONENTS);
+ NettyUtils.getNettyBufferAllocator(),
+ true,
+ Constants.COMPOSITE_BYTE_BUF_MAX_COMPONENTS);
// copy result data
updateShuffleData(readBlocks, byteBuf);
return new ShuffleDataResult(byteBuf, bufferSegments);
diff --git
a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
index fd1c99e62..e4b469152 100644
---
a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
+++
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -17,12 +17,10 @@
package org.apache.uniffle.server.netty;
-import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import com.google.common.collect.Lists;
-import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,6 +34,7 @@ import org.apache.uniffle.common.ShufflePartitionedData;
import org.apache.uniffle.common.config.RssBaseConf;
import org.apache.uniffle.common.exception.FileNotFoundException;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
import org.apache.uniffle.common.netty.client.TransportClient;
import org.apache.uniffle.common.netty.handle.BaseMessageHandler;
@@ -252,13 +251,13 @@ public class ShuffleServerNettyHandler implements
BaseMessageHandler {
blockId,
readBufferSize,
req.getExpectedTaskIdsBitmap());
- ByteBuf data = Unpooled.EMPTY_BUFFER;
+ ManagedBuffer data = NettyManagedBuffer.EMPTY_BUFFER;
List<BufferSegment> bufferSegments = Lists.newArrayList();
if (shuffleDataResult != null) {
- data = Unpooled.wrappedBuffer(shuffleDataResult.getDataBuffer());
+ data = shuffleDataResult.getManagedBuffer();
bufferSegments = shuffleDataResult.getBufferSegments();
-
ShuffleServerMetrics.counterTotalReadDataSize.inc(data.readableBytes());
-
ShuffleServerMetrics.counterTotalReadMemoryDataSize.inc(data.readableBytes());
+ ShuffleServerMetrics.counterTotalReadDataSize.inc(data.size());
+ ShuffleServerMetrics.counterTotalReadMemoryDataSize.inc(data.size());
}
long costTime = System.currentTimeMillis() - start;
shuffleServer
@@ -267,7 +266,7 @@ public class ShuffleServerNettyHandler implements
BaseMessageHandler {
LOG.info(
"Successfully getInMemoryShuffleData cost {} ms with {} bytes
shuffle" + " data for {}",
costTime,
- data.readableBytes(),
+ data.size(),
requestInfo);
response =
@@ -334,21 +333,17 @@ public class ShuffleServerNettyHandler implements
BaseMessageHandler {
.getShuffleTaskManager()
.getShuffleIndex(appId, shuffleId, partitionId,
partitionNumPerRange, partitionNum);
- ByteBuffer data = shuffleIndexResult.getIndexData();
- ShuffleServerMetrics.counterTotalReadDataSize.inc(data.remaining());
-
ShuffleServerMetrics.counterTotalReadLocalIndexFileSize.inc(data.remaining());
+ ManagedBuffer data = shuffleIndexResult.getManagedBuffer();
+ ShuffleServerMetrics.counterTotalReadDataSize.inc(data.size());
+
ShuffleServerMetrics.counterTotalReadLocalIndexFileSize.inc(data.size());
response =
new GetLocalShuffleIndexResponse(
- req.getRequestId(),
- status,
- msg,
- Unpooled.wrappedBuffer(data),
- shuffleIndexResult.getDataFileLen());
+ req.getRequestId(), status, msg, data,
shuffleIndexResult.getDataFileLen());
long readTime = System.currentTimeMillis() - start;
LOG.info(
"Successfully getShuffleIndex cost {} ms for {}" + " bytes with
{}",
readTime,
- data.remaining(),
+ data.size(),
requestInfo);
} catch (FileNotFoundException indexFileNotFoundException) {
LOG.warn(
diff --git
a/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
b/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
index 2c2e6892e..d7990a126 100644
--- a/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
+++ b/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
@@ -21,7 +21,6 @@ import java.io.IOException;
import java.util.concurrent.TimeUnit;
import io.netty.bootstrap.ServerBootstrap;
-import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
@@ -40,6 +39,7 @@ import
org.apache.uniffle.common.netty.client.TransportContext;
import org.apache.uniffle.common.rpc.ServerInterface;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.ExitUtils;
+import org.apache.uniffle.common.util.NettyUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
@@ -101,9 +101,9 @@ public class StreamServer implements ServerInterface {
})
.option(ChannelOption.SO_BACKLOG, backlogSize)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, timeoutMillis)
- .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+ .option(ChannelOption.ALLOCATOR, NettyUtils.getNettyBufferAllocator())
.childOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, timeoutMillis)
- .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+ .childOption(ChannelOption.ALLOCATOR,
NettyUtils.getNettyBufferAllocator())
.childOption(ChannelOption.TCP_NODELAY, true)
.childOption(ChannelOption.SO_KEEPALIVE, true);
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index dcdf4ac77..1db3326be 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -27,7 +27,6 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.api.ShuffleServerClient;
import org.apache.uniffle.client.factory.ShuffleServerClientFactory;
-import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.RssUtils;
@@ -121,7 +120,7 @@ public class ShuffleHandlerFactory {
CreateShuffleReadHandlerRequest request, ShuffleServerInfo ssi) {
ShuffleServerClient shuffleServerClient =
ShuffleServerClientFactory.getInstance()
- .getShuffleServerClient(ClientType.GRPC.name(), ssi,
request.getClientConf());
+ .getShuffleServerClient(request.getClientType().name(), ssi,
request.getClientConf());
Roaring64NavigableMap expectTaskIds = null;
if (request.isExpectedTaskIdsBitmapFilterEnable()) {
Roaring64NavigableMap realExceptBlockIds =
RssUtils.cloneBitMap(request.getExpectBlockIds());
@@ -143,7 +142,7 @@ public class ShuffleHandlerFactory {
CreateShuffleReadHandlerRequest request, ShuffleServerInfo ssi) {
ShuffleServerClient shuffleServerClient =
ShuffleServerClientFactory.getInstance()
- .getShuffleServerClient(ClientType.GRPC.name(), ssi,
request.getClientConf());
+ .getShuffleServerClient(request.getClientType().name(), ssi,
request.getClientConf());
return new LocalFileClientReadHandler(
request.getAppId(),
request.getShuffleId(),
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
index f39637d8e..f3ea3c1f5 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
@@ -77,6 +77,7 @@ public abstract class DataSkippableReadHandler extends
AbstractClientReadHandler
SegmentSplitterFactory.getInstance()
.get(distributionType, expectTaskIds, readBufferSize)
.split(shuffleIndexResult);
+ shuffleIndexResult.release();
}
// We should skip unexpected and processed segments when handler is read
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
index 90fc0e397..38c7e9efb 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
@@ -22,6 +22,7 @@ import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssBaseConf;
@@ -52,6 +53,8 @@ public class CreateShuffleReadHandlerRequest {
private IdHelper idHelper;
+ private ClientType clientType;
+
public CreateShuffleReadHandlerRequest() {}
public RssBaseConf getRssBaseConf() {
@@ -213,4 +216,12 @@ public class CreateShuffleReadHandlerRequest {
public void setClientConf(RssConf clientConf) {
this.clientConf = clientConf;
}
+
+ public ClientType getClientType() {
+ return clientType;
+ }
+
+ public void setClientType(ClientType clientType) {
+ this.clientType = clientType;
+ }
}
diff --git
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
index 29ddd6070..884f2b969 100644
---
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
+++
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
@@ -24,6 +24,7 @@ import java.util.Map;
import java.util.stream.Collectors;
import com.google.common.collect.Maps;
+import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentMatcher;
import org.mockito.Mockito;
@@ -36,6 +37,7 @@ import
org.apache.uniffle.client.response.RssGetShuffleIndexResponse;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
@@ -86,7 +88,10 @@ public class LocalFileServerReadHandlerTest {
int actualWriteDataBlock = expectTotalBlockNum - 1;
int actualFileLen = blockSize * actualWriteDataBlock;
RssGetShuffleIndexResponse response =
- new RssGetShuffleIndexResponse(StatusCode.SUCCESS, byteBuffer,
actualFileLen);
+ new RssGetShuffleIndexResponse(
+ StatusCode.SUCCESS,
+ new NettyManagedBuffer(Unpooled.wrappedBuffer(byteBuffer)),
+ actualFileLen);
Mockito.doReturn(response).when(mockShuffleServerClient).getShuffleIndex(Mockito.any());
int readBufferSize = 13;