This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 34db3f118 [#1187] feat(netty): Netty Encoder Support zero-copy. (#1313)
34db3f118 is described below

commit 34db3f118c62407a880a3b12d8f582e8c556bf89
Author: Xianming Lei <[email protected]>
AuthorDate: Fri Nov 24 16:33:45 2023 +0800

    [#1187] feat(netty): Netty Encoder Support zero-copy. (#1313)
    
    ### What changes were proposed in this pull request?
    
    Make the Netty message encoder support zero-copy transfer of message bodies.
    
    ### Why are the changes needed?
    
    Reduce ByteBuf copies.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing unit tests.
    
    Co-authored-by: leixianming <[email protected]>
---
 .../client/factory/ShuffleClientFactory.java       |  11 ++
 .../uniffle/client/impl/ShuffleReadClientImpl.java |  12 +-
 .../apache/uniffle/common/ShuffleDataResult.java   |  13 +-
 .../apache/uniffle/common/ShuffleIndexResult.java  |   4 +
 .../uniffle/common/config/RssClientConf.java       |   7 +
 .../apache/uniffle/common/netty/FrameDecoder.java  |   2 +-
 .../uniffle/common/netty/MessageEncoder.java       |  79 ++++++---
 .../common/netty/TransportFrameDecoder.java        |  15 +-
 .../netty/buffer/FileSegmentManagedBuffer.java     |   5 +
 .../uniffle/common/netty/buffer/ManagedBuffer.java |   2 +
 .../common/netty/buffer/NettyManagedBuffer.java    |   9 +
 .../netty/handle/TransportResponseHandler.java     |   8 +-
 .../AbstractFileRegion.java}                       |  39 ++---
 .../uniffle/common/netty/protocol/Decoders.java    |   4 +-
 .../protocol/GetLocalShuffleDataResponse.java      |  51 ++----
 .../protocol/GetLocalShuffleIndexResponse.java     |  42 ++---
 .../protocol/GetMemoryShuffleDataRequest.java      |  16 +-
 .../protocol/GetMemoryShuffleDataResponse.java     |  39 +++--
 .../uniffle/common/netty/protocol/Message.java     |  30 +++-
 .../common/netty/protocol/MessageWithHeader.java   | 193 +++++++++++++++++++++
 .../common/netty/protocol/RequestMessage.java      |   8 +-
 .../{Transferable.java => ResponseMessage.java}    |  15 +-
 .../uniffle/common/netty/protocol/RpcResponse.java |  31 +++-
 .../apache/uniffle/common/util/ByteBufUtils.java   |   3 +
 .../org/apache/uniffle/common/util/NettyUtils.java |  20 +++
 .../common/netty/protocol/NettyProtocolTest.java   |  16 +-
 .../test/ShuffleServerFaultToleranceTest.java      |   2 +
 .../client/impl/grpc/ShuffleServerGrpcClient.java  |   5 +-
 .../impl/grpc/ShuffleServerGrpcNettyClient.java    |   6 +-
 .../RssGetInMemoryShuffleDataResponse.java         |  13 +-
 .../client/response/RssGetShuffleDataResponse.java |  12 +-
 .../response/RssGetShuffleIndexResponse.java       |   5 +-
 .../uniffle/server/buffer/ShuffleBuffer.java       |   6 +-
 .../server/netty/ShuffleServerNettyHandler.java    |  27 ++-
 .../apache/uniffle/server/netty/StreamServer.java  |   6 +-
 .../storage/factory/ShuffleHandlerFactory.java     |   5 +-
 .../handler/impl/DataSkippableReadHandler.java     |   1 +
 .../request/CreateShuffleReadHandlerRequest.java   |  11 ++
 .../impl/LocalFileServerReadHandlerTest.java       |   7 +-
 39 files changed, 562 insertions(+), 218 deletions(-)

diff --git 
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
 
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index 6baccc05b..ce7d90041 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -26,6 +26,7 @@ import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
 import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssConf;
@@ -210,6 +211,7 @@ public class ShuffleClientFactory {
     private String storageType;
     private int indexReadLimit;
     private long readBufferSize;
+    private ClientType clientType;
 
     public ReadClientBuilder appId(String appId) {
       this.appId = appId;
@@ -303,6 +305,11 @@ public class ShuffleClientFactory {
       return this;
     }
 
+    public ReadClientBuilder clientType(ClientType clientType) {
+      this.clientType = clientType;
+      return this;
+    }
+
     public ReadClientBuilder() {}
 
     public String getAppId() {
@@ -377,6 +384,10 @@ public class ShuffleClientFactory {
       return readBufferSize;
     }
 
+    public ClientType getClientType() {
+      return clientType;
+    }
+
     public ShuffleReadClientImpl build() {
       return new ShuffleReadClientImpl(this);
     }
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 24fb20609..e94f5aa6d 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -55,6 +55,7 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
   private int shuffleId;
   private int partitionId;
   private ByteBuffer readBuffer;
+  private ShuffleDataResult sdr;
   private Roaring64NavigableMap blockIdBitmap;
   private Roaring64NavigableMap taskIdBitmap;
   private Roaring64NavigableMap pendingBlockIds;
@@ -96,6 +97,7 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
       builder.storageType(storageType);
       builder.readBufferSize(readBufferSize);
       builder.offHeapEnable(offHeapEnabled);
+      
builder.clientType(builder.getRssConf().get(RssClientConf.RSS_CLIENT_TYPE));
     } else {
       // most for test
       RssConf rssConf = new RssConf();
@@ -107,6 +109,7 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
       builder.rssConf(rssConf);
       builder.offHeapEnable(false);
       builder.expectedTaskIdsBitmapFilterEnable(false);
+      builder.clientType(rssConf.get(RssClientConf.RSS_CLIENT_TYPE));
     }
 
     init(builder);
@@ -138,6 +141,7 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
     request.setIdHelper(idHelper);
     request.setExpectTaskIds(taskIdBitmap);
     request.setClientConf(builder.getRssConf());
+    request.setClientType(builder.getClientType());
     if (builder.isExpectedTaskIdsBitmapFilterEnable()) {
       request.useExpectedTaskIdsBitmapFilter();
     }
@@ -258,7 +262,13 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
 
   private int read() {
     long start = System.currentTimeMillis();
-    ShuffleDataResult sdr = clientReadHandler.readShuffleData();
+    // In order to avoid copying, we postpone the release here instead of in 
the Decoder.
+    // RssUtils.releaseByteBuffer(readBuffer) cannot actually release the 
memory,
+    // because PlatformDependent.freeDirectBuffer can only release the 
ByteBuffer with cleaner.
+    if (sdr != null) {
+      sdr.release();
+    }
+    sdr = clientReadHandler.readShuffleData();
     readDataTime.addAndGet(System.currentTimeMillis() - start);
     if (sdr == null) {
       return 0;
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
index e63357e10..00867e700 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
@@ -53,14 +53,18 @@ public class ShuffleDataResult {
   }
 
   public ShuffleDataResult(ByteBuf data, List<BufferSegment> bufferSegments) {
-    this.buffer = new NettyManagedBuffer(data);
-    this.bufferSegments = bufferSegments;
+    this(new NettyManagedBuffer(data), bufferSegments);
   }
 
   public ShuffleDataResult(byte[] data, List<BufferSegment> bufferSegments) {
     this(data != null ? ByteBuffer.wrap(data) : null, bufferSegments);
   }
 
+  public ShuffleDataResult(ManagedBuffer data, List<BufferSegment> 
bufferSegments) {
+    this.buffer = data;
+    this.bufferSegments = bufferSegments;
+  }
+
   public byte[] getData() {
     if (buffer == null) {
       return null;
@@ -75,10 +79,7 @@ public class ShuffleDataResult {
     if (buffer == null) {
       return 0;
     }
-    if (buffer.nioByteBuffer().hasArray()) {
-      return buffer.nioByteBuffer().array().length;
-    }
-    return buffer.nioByteBuffer().remaining();
+    return buffer.size();
   }
 
   public ByteBuf getDataBuf() {
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
index 2a686c44f..71bb3df39 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleIndexResult.java
@@ -75,4 +75,8 @@ public class ShuffleIndexResult {
       this.buffer.release();
     }
   }
+
+  public ManagedBuffer getManagedBuffer() {
+    return buffer;
+  }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java 
b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index c516f7470..86eb1d950 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -17,6 +17,7 @@
 
 package org.apache.uniffle.common.config;
 
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.netty.IOMode;
@@ -135,4 +136,10 @@ public class RssClientConf {
           .stringType()
           .defaultValue("14m")
           .withDescription("The max data size read from storage");
+
+  public static final ConfigOption<ClientType> RSS_CLIENT_TYPE =
+      ConfigOptions.key("rss.client.type")
+          .enumType(ClientType.class)
+          .defaultValue(ClientType.GRPC)
+          .withDescription("Supports GRPC, GRPC_NETTY");
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java 
b/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java
index 53366dc3c..493fbf8d6 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java
@@ -19,5 +19,5 @@ package org.apache.uniffle.common.netty;
 
 public interface FrameDecoder {
   String HANDLER_NAME = "FrameDecoder";
-  int HEADER_SIZE = Integer.BYTES + Byte.BYTES;
+  int HEADER_SIZE = Integer.BYTES + Byte.BYTES + Integer.BYTES;
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java 
b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
index 5221cc428..d2947b3bc 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
@@ -17,50 +17,81 @@
 
 package org.apache.uniffle.common.netty;
 
+import java.util.List;
+
 import io.netty.buffer.ByteBuf;
 import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelOutboundHandlerAdapter;
-import io.netty.channel.ChannelPromise;
+import io.netty.handler.codec.MessageToMessageEncoder;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.common.netty.protocol.Message;
-import org.apache.uniffle.common.netty.protocol.Transferable;
+import org.apache.uniffle.common.netty.protocol.MessageWithHeader;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
 
 /**
  * Encoder used by the server side to encode server-to-client responses. This 
encoder is stateless
- * so it is safe to be shared by multiple threads. The content of encode 
consists of two parts,
- * header and message body. The encoded binary stream contains encodeLength (4 
bytes), messageType
- * (1 byte) and messageBody (encodeLength bytes).
+ * so it is safe to be shared by multiple threads.
  */
 @ChannelHandler.Sharable
-public class MessageEncoder extends ChannelOutboundHandlerAdapter {
+public final class MessageEncoder extends MessageToMessageEncoder<Message> {
 
-  private static final Logger LOG = 
LoggerFactory.getLogger(MessageEncoder.class);
+  private static final Logger logger = 
LoggerFactory.getLogger(MessageEncoder.class);
 
   public static final MessageEncoder INSTANCE = new MessageEncoder();
 
   private MessageEncoder() {}
 
+  /**
+   * * Encodes a Message by invoking its encode() method. For non-data 
messages, we will add one
+   * ByteBuf to 'out' containing the total frame length, the message type, and 
the message itself.
+   * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer 
corresponding to the
+   * data to 'out', in order to enable zero-copy transfer.
+   */
   @Override
-  public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise 
promise) {
-    Message message = (Message) msg;
-    int encodeLength = message.encodedLength();
-    ByteBuf byteBuf = ctx.alloc().buffer(FrameDecoder.HEADER_SIZE + 
encodeLength);
-    try {
-      byteBuf.writeInt(encodeLength);
-      byteBuf.writeByte(message.type().id());
-      message.encode(byteBuf);
-    } catch (Exception e) {
-      LOG.error("Unexpected exception during process encode!", e);
-      byteBuf.release();
-      throw e;
+  public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) 
throws Exception {
+    Object body = null;
+    int bodyLength = 0;
+
+    // If the message has a body, take it out to enable zero-copy transfer for 
the payload.
+    if (in.body() != null) {
+      try {
+        bodyLength = (int) in.body().size();
+        body = in.body().convertToNetty();
+      } catch (Exception e) {
+        in.body().release();
+        if (in instanceof RpcResponse) {
+          RpcResponse resp = (RpcResponse) in;
+          // Re-encode this message as a failure response.
+          String error = e.getMessage() != null ? e.getMessage() : "null";
+          logger.error(
+              String.format("Error processing %s for client %s", in, 
ctx.channel().remoteAddress()),
+              e);
+          encode(ctx, resp.createFailureResponse(error), out);
+        } else {
+          throw e;
+        }
+        return;
+      }
     }
-    ctx.writeAndFlush(byteBuf);
-    // do transferTo send data after encode buffer send.
-    if (message instanceof Transferable) {
-      ((Transferable) message).transferTo(ctx.channel());
+
+    Message.Type msgType = in.type();
+    // message size, message type size, body size, message encoded length
+    int headerLength = Integer.BYTES + msgType.encodedLength() + Integer.BYTES 
+ in.encodedLength();
+    ByteBuf header = ctx.alloc().heapBuffer(headerLength);
+    header.writeInt(in.encodedLength());
+    msgType.encode(header);
+    header.writeInt(bodyLength);
+    in.encode(header);
+    assert header.writableBytes() == 0;
+
+    if (body != null) {
+      // We transfer ownership of the reference on in.body() to 
MessageWithHeader.
+      // This reference will be freed when MessageWithHeader.deallocate() is 
called.
+      out.add(new MessageWithHeader(in.body(), header, body, bodyLength));
+    } else {
+      out.add(header);
     }
   }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
index d9fd734cd..4a7b8ab4b 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
@@ -45,6 +45,7 @@ import org.apache.uniffle.common.netty.protocol.Message;
  */
 public class TransportFrameDecoder extends ChannelInboundHandlerAdapter 
implements FrameDecoder {
   private int msgSize = -1;
+  private int bodySize = -1;
   private Message.Type curType = Message.Type.UNKNOWN_TYPE;
   private ByteBuf headerBuf = Unpooled.buffer(HEADER_SIZE, HEADER_SIZE);
   private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
@@ -66,10 +67,10 @@ public class TransportFrameDecoder extends 
ChannelInboundHandlerAdapter implemen
       if (frame == null) {
         break;
       }
-      // todo: An exception may be thrown during the decoding process, causing 
frame.release() to
-      // fail to be called
       Message msg = Message.decode(curType, frame);
-      frame.release();
+      if (msg.body() == null) {
+        frame.release();
+      }
       ctx.fireChannelRead(msg);
       clear();
     }
@@ -78,6 +79,7 @@ public class TransportFrameDecoder extends 
ChannelInboundHandlerAdapter implemen
   private void clear() {
     curType = Message.Type.UNKNOWN_TYPE;
     msgSize = -1;
+    bodySize = -1;
     headerBuf.clear();
   }
 
@@ -94,7 +96,8 @@ public class TransportFrameDecoder extends 
ChannelInboundHandlerAdapter implemen
     if (first.readableBytes() >= HEADER_SIZE) {
       msgSize = first.readInt();
       curType = Message.Type.decode(first);
-      nextFrameSize = msgSize;
+      bodySize = first.readInt();
+      nextFrameSize = msgSize + bodySize;
       totalSize -= HEADER_SIZE;
       if (!first.isReadable()) {
         buffers.removeFirst().release();
@@ -113,7 +116,8 @@ public class TransportFrameDecoder extends 
ChannelInboundHandlerAdapter implemen
 
     msgSize = headerBuf.readInt();
     curType = Message.Type.decode(headerBuf);
-    nextFrameSize = msgSize;
+    bodySize = headerBuf.readInt();
+    nextFrameSize = msgSize + bodySize;
     totalSize -= HEADER_SIZE;
     return nextFrameSize;
   }
@@ -126,7 +130,6 @@ public class TransportFrameDecoder extends 
ChannelInboundHandlerAdapter implemen
 
     // Reset size for next frame.
     nextFrameSize = UNKNOWN_FRAME_SIZE;
-
     Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: 
%s", frameSize);
     Preconditions.checkArgument(frameSize > 0, "Frame length should be 
positive: %s", frameSize);
 
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java
index a4975bde4..a83492f79 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/FileSegmentManagedBuffer.java
@@ -91,6 +91,11 @@ public class FileSegmentManagedBuffer extends ManagedBuffer {
     }
   }
 
+  @Override
+  public ManagedBuffer retain() {
+    return this;
+  }
+
   @Override
   public ManagedBuffer release() {
     return this;
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/ManagedBuffer.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/ManagedBuffer.java
index ed15640e9..844ed4c13 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/ManagedBuffer.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/ManagedBuffer.java
@@ -29,6 +29,8 @@ public abstract class ManagedBuffer {
 
   public abstract ByteBuffer nioByteBuffer();
 
+  public abstract ManagedBuffer retain();
+
   public abstract ManagedBuffer release();
 
   /**
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
index 4cc6686a5..53286429d 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
@@ -24,6 +24,9 @@ import io.netty.buffer.Unpooled;
 
 public class NettyManagedBuffer extends ManagedBuffer {
 
+  public static final NettyManagedBuffer EMPTY_BUFFER =
+      new NettyManagedBuffer(Unpooled.buffer(0, 0));
+
   private ByteBuf buf;
 
   public NettyManagedBuffer(ByteBuf byteBuf) {
@@ -45,6 +48,12 @@ public class NettyManagedBuffer extends ManagedBuffer {
     return buf.nioBuffer();
   }
 
+  @Override
+  public ManagedBuffer retain() {
+    buf.retain();
+    return this;
+  }
+
   @Override
   public ManagedBuffer release() {
     buf.release();
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
index f1ab3adc9..f7c9ccc53 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
@@ -61,9 +61,11 @@ public class TransportResponseHandler extends 
MessageHandler<RpcResponse> {
   public void handle(RpcResponse message) throws Exception {
     RpcResponseCallback listener = 
outstandingRpcRequests.get(message.getRequestId());
     if (listener == null) {
-      logger.warn(
-          "Ignoring response from {} since it is not outstanding",
-          NettyUtils.getRemoteAddress(channel));
+      logger.error(
+          "Ignoring response from {} since it is not outstanding, {} {}",
+          NettyUtils.getRemoteAddress(channel),
+          message.type(),
+          message.getRequestId());
     } else {
       listener.onSuccess(message);
     }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/AbstractFileRegion.java
similarity index 59%
copy from 
common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
copy to 
common/src/main/java/org/apache/uniffle/common/netty/protocol/AbstractFileRegion.java
index 4cc6686a5..c716ba7d8 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/buffer/NettyManagedBuffer.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/AbstractFileRegion.java
@@ -15,44 +15,39 @@
  * limitations under the License.
  */
 
-package org.apache.uniffle.common.netty.buffer;
+package org.apache.uniffle.common.netty.protocol;
 
-import java.nio.ByteBuffer;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
 
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-
-public class NettyManagedBuffer extends ManagedBuffer {
-
-  private ByteBuf buf;
-
-  public NettyManagedBuffer(ByteBuf byteBuf) {
-    this.buf = byteBuf;
-  }
+public abstract class AbstractFileRegion extends AbstractReferenceCounted 
implements FileRegion {
 
   @Override
-  public int size() {
-    return buf.readableBytes();
+  @SuppressWarnings("deprecation")
+  public final long transfered() {
+    return transferred();
   }
 
   @Override
-  public ByteBuf byteBuf() {
-    return Unpooled.wrappedBuffer(this.nioByteBuffer());
+  public AbstractFileRegion retain() {
+    super.retain();
+    return this;
   }
 
   @Override
-  public ByteBuffer nioByteBuffer() {
-    return buf.nioBuffer();
+  public AbstractFileRegion retain(int increment) {
+    super.retain(increment);
+    return this;
   }
 
   @Override
-  public ManagedBuffer release() {
-    buf.release();
+  public AbstractFileRegion touch() {
+    super.touch();
     return this;
   }
 
   @Override
-  public Object convertToNetty() {
-    return buf.duplicate().retain();
+  public AbstractFileRegion touch(Object o) {
+    return this;
   }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
index db6c2a000..b8c687ce7 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
@@ -23,12 +23,12 @@ import java.util.Map;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
 
 import org.apache.uniffle.common.BufferSegment;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.util.ByteBufUtils;
+import org.apache.uniffle.common.util.NettyUtils;
 
 public class Decoders {
   public static ShuffleServerInfo decodeShuffleServerInfo(ByteBuf byteBuf) {
@@ -47,7 +47,7 @@ public class Decoders {
     long crc = byteBuf.readLong();
     long taskAttemptId = byteBuf.readLong();
     int dataLength = byteBuf.readInt();
-    ByteBuf data = Unpooled.directBuffer(dataLength);
+    ByteBuf data = 
NettyUtils.getNettyBufferAllocator().directBuffer(dataLength);
     data.writeBytes(byteBuf, dataLength);
     int lengthOfShuffleServers = byteBuf.readInt();
     List<ShuffleServerInfo> serverInfos = Lists.newArrayList();
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
index e1b7af7c4..1a078b7a0 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
@@ -18,65 +18,34 @@
 package org.apache.uniffle.common.netty.protocol;
 
 import io.netty.buffer.ByteBuf;
-import io.netty.channel.Channel;
 
-import org.apache.uniffle.common.netty.buffer.FileSegmentManagedBuffer;
 import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
 import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.ByteBufUtils;
 
-public class GetLocalShuffleDataResponse extends RpcResponse implements 
Transferable {
-
-  private ManagedBuffer buffer;
+public class GetLocalShuffleDataResponse extends RpcResponse {
 
   public GetLocalShuffleDataResponse(
       long requestId, StatusCode statusCode, String retMessage, ManagedBuffer 
data) {
-    super(requestId, statusCode, retMessage);
-    this.buffer = data;
-  }
-
-  @Override
-  public int encodedLength() {
-    return super.encodedLength() + Integer.BYTES + buffer.size();
-  }
-
-  @Override
-  public void encode(ByteBuf buf) {
-    super.encode(buf);
-    if (buffer instanceof FileSegmentManagedBuffer) {
-      buf.writeInt(buffer.size());
-    } else {
-      ByteBufUtils.copyByteBuf(buffer.byteBuf(), buf);
-      buffer.release();
-    }
+    super(requestId, statusCode, retMessage, data);
   }
 
-  public static GetLocalShuffleDataResponse decode(ByteBuf byteBuf) {
+  public static GetLocalShuffleDataResponse decode(ByteBuf byteBuf, boolean 
decodeBody) {
     long requestId = byteBuf.readLong();
     StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
     String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
-    ByteBuf data = ByteBufUtils.readSlice(byteBuf);
-    return new GetLocalShuffleDataResponse(
-        requestId, statusCode, retMessage, new NettyManagedBuffer(data));
+    if (decodeBody) {
+      NettyManagedBuffer nettyManagedBuffer = new NettyManagedBuffer(byteBuf);
+      return new GetLocalShuffleDataResponse(requestId, statusCode, 
retMessage, nettyManagedBuffer);
+    } else {
+      return new GetLocalShuffleDataResponse(
+          requestId, statusCode, retMessage, NettyManagedBuffer.EMPTY_BUFFER);
+    }
   }
 
   @Override
   public Type type() {
     return Type.GET_LOCAL_SHUFFLE_DATA_RESPONSE;
   }
-
-  public ManagedBuffer getBuffer() {
-    return buffer;
-  }
-
-  public Object getData() {
-    return buffer.convertToNetty();
-  }
-
-  @Override
-  public void transferTo(Channel channel) {
-    channel.write(buffer.convertToNetty());
-    buffer.release();
-  }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
index d6179086b..f97373805 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
@@ -19,16 +19,14 @@ package org.apache.uniffle.common.netty.protocol;
 
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
-import io.netty.channel.Channel;
 
-import org.apache.uniffle.common.netty.buffer.FileSegmentManagedBuffer;
 import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
 import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.ByteBufUtils;
 
-public class GetLocalShuffleIndexResponse extends RpcResponse implements 
Transferable {
-  private ManagedBuffer buffer;
+public class GetLocalShuffleIndexResponse extends RpcResponse {
+
   private long fileLength;
 
   public GetLocalShuffleIndexResponse(
@@ -54,38 +52,36 @@ public class GetLocalShuffleIndexResponse extends 
RpcResponse implements Transfe
       long requestId,
       StatusCode statusCode,
       String retMessage,
-      ManagedBuffer indexData,
+      ManagedBuffer managedBuffer,
       long fileLength) {
-    super(requestId, statusCode, retMessage);
-    this.buffer = indexData;
+    super(requestId, statusCode, retMessage, managedBuffer);
     this.fileLength = fileLength;
   }
 
   @Override
   public int encodedLength() {
-    return super.encodedLength() + Integer.BYTES + buffer.size() + Long.BYTES;
+    return super.encodedLength() + Long.BYTES;
   }
 
   @Override
   public void encode(ByteBuf buf) {
     super.encode(buf);
-    if (buffer instanceof FileSegmentManagedBuffer) {
-      buf.writeInt(buffer.size());
-    } else {
-      ByteBufUtils.copyByteBuf(buffer.byteBuf(), buf);
-      buffer.release();
-    }
     buf.writeLong(fileLength);
   }
 
-  public static GetLocalShuffleIndexResponse decode(ByteBuf byteBuf) {
+  public static GetLocalShuffleIndexResponse decode(ByteBuf byteBuf, boolean 
decodeBody) {
     long requestId = byteBuf.readLong();
     StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
     String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
-    ByteBuf indexData = ByteBufUtils.readSlice(byteBuf);
     long fileLength = byteBuf.readLong();
-    return new GetLocalShuffleIndexResponse(
-        requestId, statusCode, retMessage, indexData, fileLength);
+    if (decodeBody) {
+      NettyManagedBuffer nettyManagedBuffer = new NettyManagedBuffer(byteBuf);
+      return new GetLocalShuffleIndexResponse(
+          requestId, statusCode, retMessage, nettyManagedBuffer, fileLength);
+    } else {
+      return new GetLocalShuffleIndexResponse(
+          requestId, statusCode, retMessage, NettyManagedBuffer.EMPTY_BUFFER, 
fileLength);
+    }
   }
 
   @Override
@@ -93,17 +89,7 @@ public class GetLocalShuffleIndexResponse extends 
RpcResponse implements Transfe
     return Type.GET_LOCAL_SHUFFLE_INDEX_RESPONSE;
   }
 
-  public ByteBuf getIndexData() {
-    return buffer.byteBuf();
-  }
-
   public long getFileLength() {
     return fileLength;
   }
-
-  @Override
-  public void transferTo(Channel channel) {
-    channel.write(buffer.convertToNetty());
-    buffer.release();
-  }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
index e766b28b1..d358cf7cd 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
@@ -67,7 +67,7 @@ public class GetMemoryShuffleDataRequest extends 
RequestMessage {
             + ByteBufUtils.encodedLength(appId)
             + 4 * Integer.BYTES
             + 2 * Long.BYTES
-            + expectedTaskIdsBitmap.serializedSizeInBytes());
+            + (expectedTaskIdsBitmap == null ? 0L : 
expectedTaskIdsBitmap.serializedSizeInBytes()));
   }
 
   @Override
@@ -79,9 +79,13 @@ public class GetMemoryShuffleDataRequest extends 
RequestMessage {
     buf.writeLong(lastBlockId);
     buf.writeInt(readBufferSize);
     buf.writeLong(timestamp);
-    buf.writeInt((int) expectedTaskIdsBitmap.serializedSizeInBytes());
     try {
-      buf.writeBytes(RssUtils.serializeBitMap(expectedTaskIdsBitmap));
+      if (expectedTaskIdsBitmap != null) {
+        buf.writeInt((int) expectedTaskIdsBitmap.serializedSizeInBytes());
+        buf.writeBytes(RssUtils.serializeBitMap(expectedTaskIdsBitmap));
+      } else {
+        buf.writeInt(-1);
+      }
     } catch (IOException ioException) {
       throw new EncodeException(
           "serializeBitMap failed while encode GetMemoryShuffleDataRequest!", 
ioException);
@@ -97,9 +101,11 @@ public class GetMemoryShuffleDataRequest extends 
RequestMessage {
     int readBufferSize = byteBuf.readInt();
     long timestamp = byteBuf.readLong();
     byte[] bytes = ByteBufUtils.readByteArray(byteBuf);
-    Roaring64NavigableMap expectedTaskIdsBitmap;
+    Roaring64NavigableMap expectedTaskIdsBitmap = null;
     try {
-      expectedTaskIdsBitmap = RssUtils.deserializeBitMap(bytes);
+      if (bytes != null) {
+        expectedTaskIdsBitmap = RssUtils.deserializeBitMap(bytes);
+      }
     } catch (IOException ioException) {
       throw new DecodeException(
           "serializeBitMap failed while decode GetMemoryShuffleDataRequest!", 
ioException);
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
index c77f64a96..c9b9f7ed5 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
@@ -23,12 +23,13 @@ import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 
 import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.ByteBufUtils;
 
 public class GetMemoryShuffleDataResponse extends RpcResponse {
   private List<BufferSegment> bufferSegments;
-  private ByteBuf data;
 
   public GetMemoryShuffleDataResponse(
       long requestId, StatusCode statusCode, List<BufferSegment> 
bufferSegments, byte[] data) {
@@ -50,35 +51,43 @@ public class GetMemoryShuffleDataResponse extends 
RpcResponse {
       String retMessage,
       List<BufferSegment> bufferSegments,
       ByteBuf data) {
-    super(requestId, statusCode, retMessage);
+    this(requestId, statusCode, retMessage, bufferSegments, new 
NettyManagedBuffer(data));
+  }
+
+  public GetMemoryShuffleDataResponse(
+      long requestId,
+      StatusCode statusCode,
+      String retMessage,
+      List<BufferSegment> bufferSegments,
+      ManagedBuffer managedBuffer) {
+    super(requestId, statusCode, retMessage, managedBuffer);
     this.bufferSegments = bufferSegments;
-    this.data = data;
   }
 
   @Override
   public int encodedLength() {
-    return super.encodedLength()
-        + Encoders.encodeLengthOfBufferSegments(bufferSegments)
-        + Integer.BYTES
-        + data.readableBytes();
+    return super.encodedLength() + 
Encoders.encodeLengthOfBufferSegments(bufferSegments);
   }
 
   @Override
   public void encode(ByteBuf buf) {
     super.encode(buf);
     Encoders.encodeBufferSegments(bufferSegments, buf);
-    ByteBufUtils.copyByteBuf(data, buf);
-    data.release();
   }
 
-  public static GetMemoryShuffleDataResponse decode(ByteBuf byteBuf) {
+  public static GetMemoryShuffleDataResponse decode(ByteBuf byteBuf, boolean 
decodeBody) {
     long requestId = byteBuf.readLong();
     StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
     String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
     List<BufferSegment> bufferSegments = 
Decoders.decodeBufferSegments(byteBuf);
-    ByteBuf data = ByteBufUtils.readSlice(byteBuf);
-    return new GetMemoryShuffleDataResponse(
-        requestId, statusCode, retMessage, bufferSegments, data);
+    if (decodeBody) {
+      NettyManagedBuffer nettyManagedBuffer = new NettyManagedBuffer(byteBuf);
+      return new GetMemoryShuffleDataResponse(
+          requestId, statusCode, retMessage, bufferSegments, 
nettyManagedBuffer);
+    } else {
+      return new GetMemoryShuffleDataResponse(
+          requestId, statusCode, retMessage, bufferSegments, 
NettyManagedBuffer.EMPTY_BUFFER);
+    }
   }
 
   @Override
@@ -89,8 +98,4 @@ public class GetMemoryShuffleDataResponse extends RpcResponse 
{
   public List<BufferSegment> getBufferSegments() {
     return bufferSegments;
   }
-
-  public ByteBuf getData() {
-    return data;
-  }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
index c019099fd..e941a8d1b 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
@@ -19,10 +19,26 @@ package org.apache.uniffle.common.netty.protocol;
 
 import io.netty.buffer.ByteBuf;
 
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+
 public abstract class Message implements Encodable {
 
+  private ManagedBuffer body;
+
+  protected Message() {
+    this(null);
+  }
+
+  protected Message(ManagedBuffer body) {
+    this.body = body;
+  }
+
   public abstract Type type();
 
+  public ManagedBuffer body() {
+    return body;
+  }
+
   public enum Type implements Encodable {
     UNKNOWN_TYPE(-1),
     RPC_RESPONSE(0),
@@ -124,9 +140,21 @@ public abstract class Message implements Encodable {
   public static Message decode(Type msgType, ByteBuf in) {
     switch (msgType) {
       case RPC_RESPONSE:
-        return RpcResponse.decode(in);
+        return RpcResponse.decode(in, false);
       case SEND_SHUFFLE_DATA_REQUEST:
         return SendShuffleDataRequest.decode(in);
+      case GET_LOCAL_SHUFFLE_DATA_REQUEST:
+        return GetLocalShuffleDataRequest.decode(in);
+      case GET_LOCAL_SHUFFLE_DATA_RESPONSE:
+        return GetLocalShuffleDataResponse.decode(in, true);
+      case GET_LOCAL_SHUFFLE_INDEX_REQUEST:
+        return GetLocalShuffleIndexRequest.decode(in);
+      case GET_LOCAL_SHUFFLE_INDEX_RESPONSE:
+        return GetLocalShuffleIndexResponse.decode(in, true);
+      case GET_MEMORY_SHUFFLE_DATA_REQUEST:
+        return GetMemoryShuffleDataRequest.decode(in);
+      case GET_MEMORY_SHUFFLE_DATA_RESPONSE:
+        return GetMemoryShuffleDataResponse.decode(in, true);
       default:
         throw new IllegalArgumentException("Unexpected message type: " + 
msgType);
     }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/MessageWithHeader.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/MessageWithHeader.java
new file mode 100644
index 000000000..a53f013d0
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/MessageWithHeader.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import javax.annotation.Nullable;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.FileRegion;
+import io.netty.util.ReferenceCountUtil;
+
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+
+/**
+ * A wrapper message that holds two separate pieces (a header and a body).
+ *
+ * <p>The header must be a ByteBuf, while the body can be a ByteBuf or a 
FileRegion.
+ */
+public class MessageWithHeader extends AbstractFileRegion {
+
+  @Nullable private final ManagedBuffer managedBuffer;
+  private final ByteBuf header;
+  private final int headerLength;
+  private final Object body;
+  private final long bodyLength;
+  private long totalBytesTransferred;
+
+  /**
+   * When the write buffer size is larger than this limit, I/O will be done in 
chunks of this size.
+   * The size should not be too large as it will waste underlying memory copy. 
e.g. If network
+   * available buffer is smaller than this limit, the data cannot be sent 
within one single write
+   * operation while it still will make memory copy with this size.
+   */
+  private static final int NIO_BUFFER_LIMIT = 256 * 1024;
+
+  /**
+   * Construct a new MessageWithHeader.
+   *
+   * @param managedBuffer the {@link ManagedBuffer} that the message body came 
from. This needs to
+   *     be passed in so that the buffer can be freed when this message is 
deallocated. Ownership of
+   *     the caller's reference to this buffer is transferred to this class, 
so if the caller wants
+   *     to continue to use the ManagedBuffer in other messages then they will 
need to call retain()
+   *     on it before passing it to this constructor. This may be null if and 
only if `body` is a
+   *     {@link FileRegion}.
+   * @param header the message header.
+   * @param body the message body. Must be either a {@link ByteBuf} or a 
{@link FileRegion}.
+   * @param bodyLength the length of the message body, in bytes.
+   */
+  public MessageWithHeader(
+      @Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long 
bodyLength) {
+    Preconditions.checkArgument(
+        body instanceof ByteBuf || body instanceof FileRegion,
+        "Body must be a ByteBuf or a FileRegion.");
+    this.managedBuffer = managedBuffer;
+    this.header = header;
+    this.headerLength = header.readableBytes();
+    this.body = body;
+    this.bodyLength = bodyLength;
+  }
+
+  @Override
+  public long count() {
+    return headerLength + bodyLength;
+  }
+
+  @Override
+  public long position() {
+    return 0;
+  }
+
+  @Override
+  public long transferred() {
+    return totalBytesTransferred;
+  }
+
+  /**
+   * This code is more complicated than you would think because we might 
require multiple transferTo
+   * invocations in order to transfer a single MessageWithHeader to avoid busy 
waiting.
+   *
+   * <p>The contract is that the caller will ensure position is properly set 
to the total number of
+   * bytes transferred so far (i.e. value returned by transferred()).
+   */
+  @Override
+  public long transferTo(final WritableByteChannel target, final long 
position) throws IOException {
+    Preconditions.checkArgument(position == totalBytesTransferred, "Invalid 
position.");
+    // Bytes written for header in this call.
+    long writtenHeader = 0;
+    if (header.readableBytes() > 0) {
+      writtenHeader = copyByteBuf(header, target);
+      totalBytesTransferred += writtenHeader;
+      if (header.readableBytes() > 0) {
+        return writtenHeader;
+      }
+    }
+
+    // Bytes written for body in this call.
+    long writtenBody = 0;
+    if (body instanceof FileRegion) {
+      writtenBody = ((FileRegion) body).transferTo(target, 
totalBytesTransferred - headerLength);
+    } else if (body instanceof ByteBuf) {
+      writtenBody = copyByteBuf((ByteBuf) body, target);
+    }
+    totalBytesTransferred += writtenBody;
+
+    return writtenHeader + writtenBody;
+  }
+
+  @Override
+  protected void deallocate() {
+    header.release();
+    ReferenceCountUtil.release(body);
+    if (managedBuffer != null) {
+      managedBuffer.release();
+    }
+  }
+
+  private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws 
IOException {
+    // SPARK-24578: cap the sub-region's size of returned nio buffer to 
improve the performance
+    // for the case that the passed-in buffer has too many components.
+    int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT);
+    // If the ByteBuf holds more then one ByteBuffer we should better call 
nioBuffers(...)
+    // to eliminate extra memory copies.
+    int written = 0;
+    if (buf.nioBufferCount() == 1) {
+      ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length);
+      written = target.write(buffer);
+    } else {
+      ByteBuffer[] buffers = buf.nioBuffers(buf.readerIndex(), length);
+      for (ByteBuffer buffer : buffers) {
+        int remaining = buffer.remaining();
+        int w = target.write(buffer);
+        written += w;
+        if (w < remaining) {
+          // Could not write all, we need to break now.
+          break;
+        }
+      }
+    }
+    buf.skipBytes(written);
+    return written;
+  }
+
+  @Override
+  public MessageWithHeader touch(Object o) {
+    super.touch(o);
+    header.touch(o);
+    ReferenceCountUtil.touch(body, o);
+    return this;
+  }
+
+  @Override
+  public MessageWithHeader retain(int increment) {
+    super.retain(increment);
+    header.retain(increment);
+    ReferenceCountUtil.retain(body, increment);
+    if (managedBuffer != null) {
+      for (int i = 0; i < increment; i++) {
+        managedBuffer.retain();
+      }
+    }
+    return this;
+  }
+
+  @Override
+  public boolean release(int decrement) {
+    header.release(decrement);
+    ReferenceCountUtil.release(body, decrement);
+    if (managedBuffer != null) {
+      for (int i = 0; i < decrement; i++) {
+        managedBuffer.release();
+      }
+    }
+    return super.release(decrement);
+  }
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
index 695484dbe..cfa55287c 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
@@ -17,12 +17,18 @@
 
 package org.apache.uniffle.common.netty.protocol;
 
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+
 public abstract class RequestMessage extends Message {
   private final long requestId;
   public static final int REQUEST_ID_ENCODE_LENGTH = Long.BYTES;
 
   public RequestMessage(long requestId) {
-    super();
+    this(requestId, null);
+  }
+
+  public RequestMessage(long requestId, ManagedBuffer managedBuffer) {
+    super(managedBuffer);
     this.requestId = requestId;
   }
 
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Transferable.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/ResponseMessage.java
similarity index 71%
rename from 
common/src/main/java/org/apache/uniffle/common/netty/protocol/Transferable.java
rename to 
common/src/main/java/org/apache/uniffle/common/netty/protocol/ResponseMessage.java
index d0fc8ac9d..36e3f3c2e 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Transferable.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/ResponseMessage.java
@@ -17,9 +17,18 @@
 
 package org.apache.uniffle.common.netty.protocol;
 
-import io.netty.channel.Channel;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
 
-public interface Transferable {
+public abstract class ResponseMessage extends Message {
+  public ResponseMessage() {
+    super();
+  }
 
-  void transferTo(Channel channel);
+  public ResponseMessage(ManagedBuffer buffer) {
+    super(buffer);
+  }
+
+  public ResponseMessage createFailureResponse(String error) {
+    throw new UnsupportedOperationException();
+  }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
index b5a35b6af..d56a24d9e 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
@@ -19,19 +19,27 @@ package org.apache.uniffle.common.netty.protocol;
 
 import io.netty.buffer.ByteBuf;
 
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.ByteBufUtils;
 
-public class RpcResponse extends Message {
+public class RpcResponse extends ResponseMessage {
   private long requestId;
   private StatusCode statusCode;
   private String retMessage;
 
-  public RpcResponse(long requestId, StatusCode statusCode) {
-    this(requestId, statusCode, null);
+  public RpcResponse(long requestId, StatusCode statusCode, String retMessage) 
{
+    this(requestId, statusCode, retMessage, null);
   }
 
-  public RpcResponse(long requestId, StatusCode statusCode, String retMessage) 
{
+  public RpcResponse(long requestId, StatusCode statusCode, ManagedBuffer 
message) {
+    this(requestId, statusCode, null, message);
+  }
+
+  public RpcResponse(
+      long requestId, StatusCode statusCode, String retMessage, ManagedBuffer 
message) {
+    super(message);
     this.requestId = requestId;
     this.statusCode = statusCode;
     this.retMessage = retMessage;
@@ -70,11 +78,16 @@ public class RpcResponse extends Message {
     ByteBufUtils.writeLengthAndString(buf, retMessage);
   }
 
-  public static RpcResponse decode(ByteBuf buf) {
-    long requestId = buf.readLong();
-    StatusCode statusCode = StatusCode.fromCode(buf.readInt());
-    String retMessage = ByteBufUtils.readLengthAndString(buf);
-    return new RpcResponse(requestId, statusCode, retMessage);
+  public static RpcResponse decode(ByteBuf byteBuf, boolean decodeBody) {
+    long requestId = byteBuf.readLong();
+    StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
+    String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
+    if (decodeBody) {
+      NettyManagedBuffer nettyManagedBuffer = new NettyManagedBuffer(byteBuf);
+      return new RpcResponse(requestId, statusCode, retMessage, 
nettyManagedBuffer);
+    } else {
+      return new RpcResponse(requestId, statusCode, retMessage, 
NettyManagedBuffer.EMPTY_BUFFER);
+    }
   }
 
   public long getRequestId() {
diff --git 
a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java 
b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
index 7cb5eab6f..84f77c34d 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
@@ -67,6 +67,9 @@ public class ByteBufUtils {
 
   public static final byte[] readByteArray(ByteBuf byteBuf) {
     int length = byteBuf.readInt();
+    if (length < 0) {
+      return null;
+    }
     byte[] data = new byte[length];
     byteBuf.readBytes(data);
     return data;
diff --git 
a/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java 
b/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
index ac0946034..5f1c87ced 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
@@ -112,4 +112,24 @@ public class NettyUtils {
   public static String getServerConnectionInfo(Channel channel) {
     return String.format("[%s -> %s]", channel.localAddress(), 
channel.remoteAddress());
   }
+
+  private static class AllocatorHolder {
+    private static final PooledByteBufAllocator INSTANCE = createAllocator();
+  }
+
+  public static PooledByteBufAllocator getNettyBufferAllocator() {
+    return AllocatorHolder.INSTANCE;
+  }
+
+  private static PooledByteBufAllocator createAllocator() {
+    return new PooledByteBufAllocator(
+        true,
+        PooledByteBufAllocator.defaultNumHeapArena(),
+        PooledByteBufAllocator.defaultNumDirectArena(),
+        PooledByteBufAllocator.defaultPageSize(),
+        PooledByteBufAllocator.defaultMaxOrder(),
+        0,
+        0,
+        PooledByteBufAllocator.defaultUseCacheForAllThreads());
+  }
 }
diff --git 
a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
 
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
index 48ff31ffa..a370828c3 100644
--- 
a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
+++ 
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
@@ -126,7 +126,7 @@ public class NettyProtocolTest {
     ByteBuf byteBuf = Unpooled.buffer(encodeLength);
     rpcResponse.encode(byteBuf);
     assertEquals(byteBuf.readableBytes(), encodeLength);
-    RpcResponse rpcResponse1 = RpcResponse.decode(byteBuf);
+    RpcResponse rpcResponse1 = RpcResponse.decode(byteBuf, true);
     assertEquals(rpcResponse.getRequestId(), rpcResponse1.getRequestId());
     assertEquals(rpcResponse.getRetMessage(), rpcResponse1.getRetMessage());
     assertEquals(rpcResponse.getStatusCode(), rpcResponse1.getStatusCode());
@@ -177,7 +177,7 @@ public class NettyProtocolTest {
     ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
     getLocalShuffleDataResponse.encode(byteBuf);
     GetLocalShuffleDataResponse getLocalShuffleDataResponse1 =
-        GetLocalShuffleDataResponse.decode(byteBuf);
+        GetLocalShuffleDataResponse.decode(byteBuf, true);
 
     assertEquals(
         getLocalShuffleDataResponse.getRequestId(), 
getLocalShuffleDataResponse1.getRequestId());
@@ -185,7 +185,6 @@ public class NettyProtocolTest {
         getLocalShuffleDataResponse.getRetMessage(), 
getLocalShuffleDataResponse1.getRetMessage());
     assertEquals(
         getLocalShuffleDataResponse.getStatusCode(), 
getLocalShuffleDataResponse1.getStatusCode());
-    assertEquals(getLocalShuffleDataResponse.getData(), 
getLocalShuffleDataResponse1.getData());
   }
 
   @Test
@@ -224,7 +223,7 @@ public class NettyProtocolTest {
     ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
     getLocalShuffleIndexResponse.encode(byteBuf);
     GetLocalShuffleIndexResponse getLocalShuffleIndexResponse1 =
-        GetLocalShuffleIndexResponse.decode(byteBuf);
+        GetLocalShuffleIndexResponse.decode(byteBuf, true);
 
     assertEquals(
         getLocalShuffleIndexResponse.getRequestId(), 
getLocalShuffleIndexResponse1.getRequestId());
@@ -234,11 +233,6 @@ public class NettyProtocolTest {
     assertEquals(
         getLocalShuffleIndexResponse.getRetMessage(),
         getLocalShuffleIndexResponse1.getRetMessage());
-    assertEquals(
-        getLocalShuffleIndexResponse.getFileLength(),
-        getLocalShuffleIndexResponse1.getFileLength());
-    assertEquals(
-        getLocalShuffleIndexResponse.getIndexData(), 
getLocalShuffleIndexResponse1.getIndexData());
   }
 
   @Test
@@ -287,15 +281,13 @@ public class NettyProtocolTest {
     ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
     getMemoryShuffleDataResponse.encode(byteBuf);
     GetMemoryShuffleDataResponse getMemoryShuffleDataResponse1 =
-        GetMemoryShuffleDataResponse.decode(byteBuf);
+        GetMemoryShuffleDataResponse.decode(byteBuf, true);
 
     assertEquals(
         getMemoryShuffleDataResponse.getRequestId(), 
getMemoryShuffleDataResponse1.getRequestId());
     assertEquals(
         getMemoryShuffleDataResponse.getStatusCode(),
         getMemoryShuffleDataResponse1.getStatusCode());
-    assertTrue(
-        
getMemoryShuffleDataResponse.getData().equals(getMemoryShuffleDataResponse1.getData()));
 
     for (int i = 0; i < 2; i++) {
       assertEquals(
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
index 85c410908..bb9dd4450 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
@@ -37,6 +37,7 @@ import 
org.apache.uniffle.client.request.RssRegisterShuffleRequest;
 import org.apache.uniffle.client.request.RssSendCommitRequest;
 import org.apache.uniffle.client.request.RssSendShuffleDataRequest;
 import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
@@ -241,6 +242,7 @@ public class ShuffleServerFaultToleranceTest extends 
ShuffleReadWriteBase {
     request.setDistributionType(ShuffleDataDistributionType.NORMAL);
     Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0);
     request.setExpectTaskIds(taskIdBitmap);
+    request.setClientType(ClientType.GRPC);
     return request;
   }
 
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index fa9a0520b..e36d18528 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -29,6 +29,7 @@ import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.UnsafeByteOperations;
+import io.netty.buffer.Unpooled;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -64,6 +65,7 @@ import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.exception.NotRetryException;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.RssUtils;
@@ -810,7 +812,8 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
         response =
             new RssGetShuffleIndexResponse(
                 StatusCode.SUCCESS,
-                ByteBuffer.wrap(rpcResponse.getIndexData().toByteArray()),
+                new NettyManagedBuffer(
+                    
Unpooled.wrappedBuffer(rpcResponse.getIndexData().toByteArray())),
                 rpcResponse.getDataFileLen());
 
         break;
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
index cb8e1e623..9230a49e5 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -200,7 +200,7 @@ public class ShuffleServerGrpcNettyClient extends 
ShuffleServerGrpcClient {
       case SUCCESS:
         return new RssGetInMemoryShuffleDataResponse(
             StatusCode.SUCCESS,
-            getMemoryShuffleDataResponse.getData().nioBuffer(),
+            getMemoryShuffleDataResponse.body(),
             getMemoryShuffleDataResponse.getBufferSegments());
       default:
         String msg =
@@ -251,7 +251,7 @@ public class ShuffleServerGrpcNettyClient extends 
ShuffleServerGrpcClient {
       case SUCCESS:
         return new RssGetShuffleIndexResponse(
             StatusCode.SUCCESS,
-            getLocalShuffleIndexResponse.getIndexData().nioBuffer(),
+            getLocalShuffleIndexResponse.body(),
             getLocalShuffleIndexResponse.getFileLength());
       default:
         String msg =
@@ -305,7 +305,7 @@ public class ShuffleServerGrpcNettyClient extends 
ShuffleServerGrpcClient {
     switch (statusCode) {
       case SUCCESS:
         return new RssGetShuffleDataResponse(
-            StatusCode.SUCCESS, 
getLocalShuffleDataResponse.getBuffer().nioByteBuffer());
+            StatusCode.SUCCESS, getLocalShuffleDataResponse.body());
       default:
         String msg =
             "Can't get shuffle data from "
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
index 1468d5104..bbf3738cb 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetInMemoryShuffleDataResponse.java
@@ -20,22 +20,31 @@ package org.apache.uniffle.client.response;
 import java.nio.ByteBuffer;
 import java.util.List;
 
+import io.netty.buffer.Unpooled;
+
 import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
 import org.apache.uniffle.common.rpc.StatusCode;
 
 public class RssGetInMemoryShuffleDataResponse extends ClientResponse {
 
-  private final ByteBuffer data;
+  private final ManagedBuffer data;
   private final List<BufferSegment> bufferSegments;
 
   public RssGetInMemoryShuffleDataResponse(
       StatusCode statusCode, ByteBuffer data, List<BufferSegment> 
bufferSegments) {
+    this(statusCode, new NettyManagedBuffer(Unpooled.wrappedBuffer(data)), 
bufferSegments);
+  }
+
+  public RssGetInMemoryShuffleDataResponse(
+      StatusCode statusCode, ManagedBuffer data, List<BufferSegment> 
bufferSegments) {
     super(statusCode);
     this.bufferSegments = bufferSegments;
     this.data = data;
   }
 
-  public ByteBuffer getData() {
+  public ManagedBuffer getData() {
     return data;
   }
 
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
index 7d9ca3927..474106fb0 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleDataResponse.java
@@ -19,18 +19,26 @@ package org.apache.uniffle.client.response;
 
 import java.nio.ByteBuffer;
 
+import io.netty.buffer.Unpooled;
+
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
 import org.apache.uniffle.common.rpc.StatusCode;
 
 public class RssGetShuffleDataResponse extends ClientResponse {
 
-  private final ByteBuffer shuffleData;
+  private final ManagedBuffer shuffleData;
 
   public RssGetShuffleDataResponse(StatusCode statusCode, ByteBuffer data) {
+    this(statusCode, new NettyManagedBuffer(Unpooled.wrappedBuffer(data)));
+  }
+
+  public RssGetShuffleDataResponse(StatusCode statusCode, ManagedBuffer data) {
     super(statusCode);
     this.shuffleData = data;
   }
 
-  public ByteBuffer getShuffleData() {
+  public ManagedBuffer getShuffleData() {
     return shuffleData;
   }
 }
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
index 29ba4320a..37a31652e 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java
@@ -17,15 +17,14 @@
 
 package org.apache.uniffle.client.response;
 
-import java.nio.ByteBuffer;
-
 import org.apache.uniffle.common.ShuffleIndexResult;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
 import org.apache.uniffle.common.rpc.StatusCode;
 
 public class RssGetShuffleIndexResponse extends ClientResponse {
   private final ShuffleIndexResult shuffleIndexResult;
 
-  public RssGetShuffleIndexResponse(StatusCode statusCode, ByteBuffer data, long dataFileLen) {
+  public RssGetShuffleIndexResponse(StatusCode statusCode, ManagedBuffer data, long dataFileLen) {
     super(statusCode);
     this.shuffleIndexResult = new ShuffleIndexResult(data, dataFileLen);
   }
diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
index 0736780b0..4a9a215a7 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
@@ -25,7 +25,6 @@ import java.util.function.Supplier;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
-import io.netty.buffer.ByteBufAllocator;
 import io.netty.buffer.CompositeByteBuf;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
@@ -38,6 +37,7 @@ import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.ShufflePartitionedData;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.common.util.NettyUtils;
 import org.apache.uniffle.server.ShuffleDataFlushEvent;
 import org.apache.uniffle.server.ShuffleFlushManager;
 
@@ -164,7 +164,9 @@ public class ShuffleBuffer {
       if (!bufferSegments.isEmpty()) {
         CompositeByteBuf byteBuf =
             new CompositeByteBuf(
-                ByteBufAllocator.DEFAULT, true, Constants.COMPOSITE_BYTE_BUF_MAX_COMPONENTS);
+                NettyUtils.getNettyBufferAllocator(),
+                true,
+                Constants.COMPOSITE_BYTE_BUF_MAX_COMPONENTS);
         // copy result data
         updateShuffleData(readBlocks, byteBuf);
         return new ShuffleDataResult(byteBuf, bufferSegments);
diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
index fd1c99e62..e4b469152 100644
--- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
+++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -17,12 +17,10 @@
 
 package org.apache.uniffle.server.netty;
 
-import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.Map;
 
 import com.google.common.collect.Lists;
-import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -36,6 +34,7 @@ import org.apache.uniffle.common.ShufflePartitionedData;
 import org.apache.uniffle.common.config.RssBaseConf;
 import org.apache.uniffle.common.exception.FileNotFoundException;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
 import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
 import org.apache.uniffle.common.netty.client.TransportClient;
 import org.apache.uniffle.common.netty.handle.BaseMessageHandler;
@@ -252,13 +251,13 @@ public class ShuffleServerNettyHandler implements BaseMessageHandler {
                     blockId,
                     readBufferSize,
                     req.getExpectedTaskIdsBitmap());
-        ByteBuf data = Unpooled.EMPTY_BUFFER;
+        ManagedBuffer data = NettyManagedBuffer.EMPTY_BUFFER;
         List<BufferSegment> bufferSegments = Lists.newArrayList();
         if (shuffleDataResult != null) {
-          data = Unpooled.wrappedBuffer(shuffleDataResult.getDataBuffer());
+          data = shuffleDataResult.getManagedBuffer();
           bufferSegments = shuffleDataResult.getBufferSegments();
-          ShuffleServerMetrics.counterTotalReadDataSize.inc(data.readableBytes());
-          ShuffleServerMetrics.counterTotalReadMemoryDataSize.inc(data.readableBytes());
+          ShuffleServerMetrics.counterTotalReadDataSize.inc(data.size());
+          ShuffleServerMetrics.counterTotalReadMemoryDataSize.inc(data.size());
         }
         long costTime = System.currentTimeMillis() - start;
         shuffleServer
@@ -267,7 +266,7 @@ public class ShuffleServerNettyHandler implements BaseMessageHandler {
         LOG.info(
             "Successfully getInMemoryShuffleData cost {} ms with {} bytes shuffle" + " data for {}",
             costTime,
-            data.readableBytes(),
+            data.size(),
             requestInfo);
 
         response =
@@ -334,21 +333,17 @@ public class ShuffleServerNettyHandler implements BaseMessageHandler {
                 .getShuffleTaskManager()
                 .getShuffleIndex(appId, shuffleId, partitionId, partitionNumPerRange, partitionNum);
 
-        ByteBuffer data = shuffleIndexResult.getIndexData();
-        ShuffleServerMetrics.counterTotalReadDataSize.inc(data.remaining());
-        ShuffleServerMetrics.counterTotalReadLocalIndexFileSize.inc(data.remaining());
+        ManagedBuffer data = shuffleIndexResult.getManagedBuffer();
+        ShuffleServerMetrics.counterTotalReadDataSize.inc(data.size());
+        ShuffleServerMetrics.counterTotalReadLocalIndexFileSize.inc(data.size());
         response =
             new GetLocalShuffleIndexResponse(
-                req.getRequestId(),
-                status,
-                msg,
-                Unpooled.wrappedBuffer(data),
-                shuffleIndexResult.getDataFileLen());
+                req.getRequestId(), status, msg, data, shuffleIndexResult.getDataFileLen());
         long readTime = System.currentTimeMillis() - start;
         LOG.info(
             "Successfully getShuffleIndex cost {} ms for {}" + " bytes with {}",
             readTime,
-            data.remaining(),
+            data.size(),
             requestInfo);
       } catch (FileNotFoundException indexFileNotFoundException) {
         LOG.warn(
diff --git a/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java b/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
index 2c2e6892e..d7990a126 100644
--- a/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
+++ b/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
@@ -21,7 +21,6 @@ import java.io.IOException;
 import java.util.concurrent.TimeUnit;
 
 import io.netty.bootstrap.ServerBootstrap;
-import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.ChannelFuture;
 import io.netty.channel.ChannelInitializer;
 import io.netty.channel.ChannelOption;
@@ -40,6 +39,7 @@ import org.apache.uniffle.common.netty.client.TransportContext;
 import org.apache.uniffle.common.rpc.ServerInterface;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.ExitUtils;
+import org.apache.uniffle.common.util.NettyUtils;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.server.ShuffleServer;
 import org.apache.uniffle.server.ShuffleServerConf;
@@ -101,9 +101,9 @@ public class StreamServer implements ServerInterface {
             })
         .option(ChannelOption.SO_BACKLOG, backlogSize)
         .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, timeoutMillis)
-        .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+        .option(ChannelOption.ALLOCATOR, NettyUtils.getNettyBufferAllocator())
         .childOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, timeoutMillis)
-        .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+        .childOption(ChannelOption.ALLOCATOR, NettyUtils.getNettyBufferAllocator())
         .childOption(ChannelOption.TCP_NODELAY, true)
         .childOption(ChannelOption.SO_KEEPALIVE, true);
 
diff --git a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index dcdf4ac77..1db3326be 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -27,7 +27,6 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
 import org.apache.uniffle.client.api.ShuffleServerClient;
 import org.apache.uniffle.client.factory.ShuffleServerClientFactory;
-import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.RssUtils;
@@ -121,7 +120,7 @@ public class ShuffleHandlerFactory {
       CreateShuffleReadHandlerRequest request, ShuffleServerInfo ssi) {
     ShuffleServerClient shuffleServerClient =
         ShuffleServerClientFactory.getInstance()
-            .getShuffleServerClient(ClientType.GRPC.name(), ssi, request.getClientConf());
+            .getShuffleServerClient(request.getClientType().name(), ssi, request.getClientConf());
     Roaring64NavigableMap expectTaskIds = null;
     if (request.isExpectedTaskIdsBitmapFilterEnable()) {
       Roaring64NavigableMap realExceptBlockIds = RssUtils.cloneBitMap(request.getExpectBlockIds());
@@ -143,7 +142,7 @@ public class ShuffleHandlerFactory {
       CreateShuffleReadHandlerRequest request, ShuffleServerInfo ssi) {
     ShuffleServerClient shuffleServerClient =
         ShuffleServerClientFactory.getInstance()
-            .getShuffleServerClient(ClientType.GRPC.name(), ssi, request.getClientConf());
+            .getShuffleServerClient(request.getClientType().name(), ssi, request.getClientConf());
     return new LocalFileClientReadHandler(
         request.getAppId(),
         request.getShuffleId(),
diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
index f39637d8e..f3ea3c1f5 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
@@ -77,6 +77,7 @@ public abstract class DataSkippableReadHandler extends AbstractClientReadHandler
           SegmentSplitterFactory.getInstance()
               .get(distributionType, expectTaskIds, readBufferSize)
               .split(shuffleIndexResult);
+      shuffleIndexResult.release();
     }
 
     // We should skip unexpected and processed segments when handler is read
diff --git a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
index 90fc0e397..38c7e9efb 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
@@ -22,6 +22,7 @@ import java.util.List;
 import org.apache.hadoop.conf.Configuration;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssBaseConf;
@@ -52,6 +53,8 @@ public class CreateShuffleReadHandlerRequest {
 
   private IdHelper idHelper;
 
+  private ClientType clientType;
+
   public CreateShuffleReadHandlerRequest() {}
 
   public RssBaseConf getRssBaseConf() {
@@ -213,4 +216,12 @@ public class CreateShuffleReadHandlerRequest {
   public void setClientConf(RssConf clientConf) {
     this.clientConf = clientConf;
   }
+
+  public ClientType getClientType() {
+    return clientType;
+  }
+
+  public void setClientType(ClientType clientType) {
+    this.clientType = clientType;
+  }
 }
diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
index 29ddd6070..884f2b969 100644
--- a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
+++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
@@ -24,6 +24,7 @@ import java.util.Map;
 import java.util.stream.Collectors;
 
 import com.google.common.collect.Maps;
+import io.netty.buffer.Unpooled;
 import org.junit.jupiter.api.Test;
 import org.mockito.ArgumentMatcher;
 import org.mockito.Mockito;
@@ -36,6 +37,7 @@ import org.apache.uniffle.client.response.RssGetShuffleIndexResponse;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleDataResult;
 import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
 
@@ -86,7 +88,10 @@ public class LocalFileServerReadHandlerTest {
     int actualWriteDataBlock = expectTotalBlockNum - 1;
     int actualFileLen = blockSize * actualWriteDataBlock;
     RssGetShuffleIndexResponse response =
-        new RssGetShuffleIndexResponse(StatusCode.SUCCESS, byteBuffer, actualFileLen);
+        new RssGetShuffleIndexResponse(
+            StatusCode.SUCCESS,
+            new NettyManagedBuffer(Unpooled.wrappedBuffer(byteBuffer)),
+            actualFileLen);
     Mockito.doReturn(response).when(mockShuffleServerClient).getShuffleIndex(Mockito.any());
 
     int readBufferSize = 13;

Reply via email to