This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 3b8d6c86 [#133] feat(netty): Implement ShuffleServer interface. (#879)
3b8d6c86 is described below

commit 3b8d6c86b7a2c89ec05ffa67928df640aba15ac3
Author: Xianming Lei <[email protected]>
AuthorDate: Wed May 17 16:29:38 2023 +0800

    [#133] feat(netty): Implement ShuffleServer interface. (#879)
    
    ### What changes were proposed in this pull request?
    
    Implement the ShuffleServer interface for the Netty transport.
    
    ### Why are the changes needed?
    
    For #133.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing unit tests.
    
    Co-authored-by: leixianming <[email protected]>
---
 .../apache/uniffle/common/ShuffleDataResult.java   |  23 +-
 .../uniffle/common/ShufflePartitionedBlock.java    |  29 +-
 .../common/netty/client/TransportClient.java       |   4 +
 .../netty/client/TransportClientFactory.java       |  14 +-
 .../common/netty/client/TransportContext.java      |  34 +-
 .../common/netty/handle/BaseMessageHandler.java    |  32 +-
 .../common/netty/handle/MessageHandler.java        |  36 +-
 .../netty/handle/TransportChannelHandler.java      | 139 +++++++
 .../netty/handle/TransportRequestHandler.java      |  61 +++
 .../netty/handle/TransportResponseHandler.java     |  94 ++++-
 .../apache/uniffle/common/util/ByteBufUtils.java   |   5 +
 .../org/apache/uniffle/common/util/Constants.java  |   2 +
 .../common/ShufflePartitionedBlockTest.java        |   4 +-
 .../uniffle/common/util/ByteBufUtilsTest.java      |  26 ++
 .../uniffle/server/buffer/ShuffleBuffer.java       |  16 +-
 .../server/netty/ShuffleServerNettyHandler.java    | 407 +++++++++++++++++++++
 .../apache/uniffle/server/netty/StreamServer.java  |  22 +-
 .../uniffle/server/ShuffleFlushManagerTest.java    |   5 +-
 .../server/buffer/ShuffleBufferManagerTest.java    |   9 +-
 .../uniffle/server/buffer/ShuffleBufferTest.java   |   3 +-
 .../server/storage/MultiStorageManagerTest.java    |   9 +-
 .../StorageManagerFallbackStrategyTest.java        |   6 +-
 .../handler/impl/HdfsShuffleWriteHandler.java      |   3 +-
 .../handler/impl/LocalFileWriteHandler.java        |   4 +-
 .../storage/HdfsShuffleHandlerTestBase.java        |   3 +-
 .../handler/impl/LocalFileHandlerTestBase.java     |   5 +-
 26 files changed, 864 insertions(+), 131 deletions(-)

diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
index 19e93dd1..98a8b1d7 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
@@ -21,10 +21,14 @@ import java.nio.ByteBuffer;
 import java.util.List;
 
 import com.google.common.collect.Lists;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.uniffle.common.util.ByteBufUtils;
 
 public class ShuffleDataResult {
 
-  private final ByteBuffer data;
+  private final ByteBuf data;
   private final List<BufferSegment> bufferSegments;
 
   public ShuffleDataResult() {
@@ -36,6 +40,11 @@ public class ShuffleDataResult {
   }
 
   public ShuffleDataResult(ByteBuffer data, List<BufferSegment> 
bufferSegments) {
+    this.data = data != null ? Unpooled.wrappedBuffer(data) : 
Unpooled.EMPTY_BUFFER;
+    this.bufferSegments = bufferSegments;
+  }
+
+  public ShuffleDataResult(ByteBuf data, List<BufferSegment> bufferSegments) {
     this.data = data;
     this.bufferSegments = bufferSegments;
   }
@@ -51,16 +60,17 @@ public class ShuffleDataResult {
     if (data.hasArray()) {
       return data.array();
     }
-    ByteBuffer dataBuffer = data.duplicate();
-    byte[] byteArray = new byte[dataBuffer.remaining()];
-    dataBuffer.get(byteArray);
-    return byteArray;
+    return ByteBufUtils.readBytes(data);
   }
 
-  public ByteBuffer getDataBuffer() {
+  public ByteBuf getDataBuf() {
     return data;
   }
 
+  public ByteBuffer getDataBuffer() {
+    return data.nioBuffer();
+  }
+
   public List<BufferSegment> getBufferSegments() {
     return bufferSegments;
   }
@@ -68,5 +78,4 @@ public class ShuffleDataResult {
   public boolean isEmpty() {
     return bufferSegments == null || bufferSegments.isEmpty() || data == null 
|| data.capacity() == 0;
   }
-
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java 
b/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java
index 9bee3fe3..04ba6542 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java
@@ -17,16 +17,18 @@
 
 package org.apache.uniffle.common;
 
-import java.util.Arrays;
 import java.util.Objects;
 
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
 public class ShufflePartitionedBlock {
 
   private int length;
   private long crc;
   private long blockId;
   private int uncompressLength;
-  private byte[] data;
+  private ByteBuf data;
   private long taskAttemptId;
 
   public ShufflePartitionedBlock(
@@ -41,6 +43,21 @@ public class ShufflePartitionedBlock {
     this.blockId = blockId;
     this.uncompressLength = uncompressLength;
     this.taskAttemptId = taskAttemptId;
+    this.data = data == null ? Unpooled.EMPTY_BUFFER : 
Unpooled.wrappedBuffer(data);
+  }
+
+  public ShufflePartitionedBlock(
+      int length,
+      int uncompressLength,
+      long crc,
+      long blockId,
+      long taskAttemptId,
+      ByteBuf data) {
+    this.length = length;
+    this.crc = crc;
+    this.blockId = blockId;
+    this.uncompressLength = uncompressLength;
+    this.taskAttemptId = taskAttemptId;
     this.data = data;
   }
 
@@ -62,12 +79,12 @@ public class ShufflePartitionedBlock {
     return length == that.length
         && crc == that.crc
         && blockId == that.blockId
-        && Arrays.equals(data, that.data);
+        && data.equals(that.data);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(length, crc, blockId, Arrays.hashCode(data));
+    return Objects.hash(length, crc, blockId, data);
   }
 
   public int getLength() {
@@ -94,11 +111,11 @@ public class ShufflePartitionedBlock {
     this.blockId = blockId;
   }
 
-  public byte[] getData() {
+  public ByteBuf getData() {
     return data;
   }
 
-  public void setData(byte[] data) {
+  public void setData(ByteBuf data) {
     this.data = data;
   }
 
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
index c82408af..01846a2f 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
@@ -165,4 +165,8 @@ public class TransportClient implements Closeable {
     channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
   }
 
+  public void timeOut() {
+    this.timedOut = true;
+  }
+
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
index c8056151..3e9e706a 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
@@ -40,7 +40,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.common.netty.IOMode;
 import org.apache.uniffle.common.netty.TransportFrameDecoder;
-import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
+import org.apache.uniffle.common.netty.handle.TransportChannelHandler;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.NettyUtils;
 
@@ -121,8 +121,11 @@ public class TransportClientFactory implements Closeable {
       // Make sure that the channel will not timeout by updating the last use 
time of the
       // handler. Then check that the client is still alive, in case it timed 
out before
       // this code was able to update things.
-      TransportResponseHandler handler =
-          
cachedClient.getChannel().pipeline().get(TransportResponseHandler.class);
+      TransportChannelHandler handler =
+          
cachedClient.getChannel().pipeline().get(TransportChannelHandler.class);
+      synchronized (handler) {
+        handler.getResponseHandler().updateTimeOfLastRequest();
+      }
 
       if (cachedClient.isActive()) {
         logger.trace(
@@ -197,9 +200,8 @@ public class TransportClientFactory implements Closeable {
         new ChannelInitializer<SocketChannel>() {
           @Override
           public void initChannel(SocketChannel ch) {
-            TransportResponseHandler transportResponseHandler = 
context.initializePipeline(ch, decoder);
-            TransportClient client = new TransportClient(ch, 
transportResponseHandler);
-            clientRef.set(client);
+            TransportChannelHandler transportResponseHandler = 
context.initializePipeline(ch, decoder);
+            clientRef.set(transportResponseHandler.getClient());
             channelRef.set(ch);
           }
         });
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
index 134b633a..5acb46dd 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
@@ -17,6 +17,7 @@
 
 package org.apache.uniffle.common.netty.client;
 
+import io.netty.channel.Channel;
 import io.netty.channel.ChannelInboundHandlerAdapter;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.handler.timeout.IdleStateHandler;
@@ -24,34 +25,59 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.common.netty.MessageEncoder;
+import org.apache.uniffle.common.netty.handle.BaseMessageHandler;
+import org.apache.uniffle.common.netty.handle.TransportChannelHandler;
+import org.apache.uniffle.common.netty.handle.TransportRequestHandler;
 import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
 
 public class TransportContext {
   private static final Logger logger = 
LoggerFactory.getLogger(TransportContext.class);
 
   private TransportConf transportConf;
+  private final BaseMessageHandler msgHandler;
+  private boolean closeIdleConnections;
 
   private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
 
   public TransportContext(TransportConf transportConf) {
+    this(transportConf, true);
+  }
+
+  public TransportContext(TransportConf transportConf, boolean 
closeIdleConnections) {
+    this(transportConf, null, closeIdleConnections);
+  }
+
+  public TransportContext(TransportConf transportConf, BaseMessageHandler 
msgHandler, boolean closeIdleConnections) {
     this.transportConf = transportConf;
+    this.msgHandler = msgHandler;
+    this.closeIdleConnections = closeIdleConnections;
   }
 
   public TransportClientFactory createClientFactory() {
     return new TransportClientFactory(this);
   }
 
-  public TransportResponseHandler initializePipeline(
+  public TransportChannelHandler initializePipeline(
       SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
-    TransportResponseHandler responseHandler = new 
TransportResponseHandler(channel);
+    TransportChannelHandler channelHandler = createChannelHandler(channel, 
msgHandler);
     channel
         .pipeline()
         .addLast("encoder", ENCODER) // out
         .addLast("decoder", decoder) // in
         .addLast(
             "idleStateHandler", new IdleStateHandler(0, 0, 
transportConf.connectionTimeoutMs() / 1000))
-        .addLast("responseHandler", responseHandler);
-    return responseHandler;
+        .addLast("responseHandler", channelHandler);
+    return channelHandler;
+  }
+
+  private TransportChannelHandler createChannelHandler(
+      Channel channel, BaseMessageHandler msgHandler) {
+    TransportResponseHandler responseHandler = new 
TransportResponseHandler(channel);
+    TransportClient client = new TransportClient(channel, responseHandler);
+    TransportRequestHandler requestHandler =
+        new TransportRequestHandler(client, msgHandler);
+    return new TransportChannelHandler(
+        client, responseHandler, requestHandler, 
transportConf.connectionTimeoutMs(), closeIdleConnections);
   }
 
   public TransportConf getConf() {
diff --git 
a/server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/BaseMessageHandler.java
similarity index 54%
copy from 
server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
copy to 
common/src/main/java/org/apache/uniffle/common/netty/handle/BaseMessageHandler.java
index d2551e55..91ba7a01 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/BaseMessageHandler.java
@@ -15,34 +15,14 @@
  * limitations under the License.
  */
 
-package org.apache.uniffle.server.netty.decoder;
+package org.apache.uniffle.common.netty.handle;
 
-import java.util.List;
+import org.apache.uniffle.common.netty.client.TransportClient;
+import org.apache.uniffle.common.netty.protocol.RequestMessage;
 
-import io.netty.buffer.ByteBuf;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.handler.codec.ByteToMessageDecoder;
+public interface BaseMessageHandler {
 
-public class StreamServerInitDecoder extends ByteToMessageDecoder {
+  void receive(TransportClient client, RequestMessage msg);
 
-  public StreamServerInitDecoder() {
-  }
-
-  private void addDecoder(ChannelHandlerContext ctx, byte type) {
-
-  }
-
-  @Override
-  protected void decode(ChannelHandlerContext ctx,
-      ByteBuf in,
-      List<Object> out) {
-    if (in.readableBytes() < Byte.BYTES) {
-      return;
-    }
-    in.markReaderIndex();
-    byte magicByte = in.readByte();
-    in.resetReaderIndex();
-
-    addDecoder(ctx, magicByte);
-  }
+  void exceptionCaught(Throwable cause, TransportClient client);
 }
diff --git 
a/server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/MessageHandler.java
similarity index 54%
rename from 
server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
rename to 
common/src/main/java/org/apache/uniffle/common/netty/handle/MessageHandler.java
index d2551e55..6712ddfc 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/MessageHandler.java
@@ -15,34 +15,20 @@
  * limitations under the License.
  */
 
-package org.apache.uniffle.server.netty.decoder;
+package org.apache.uniffle.common.netty.handle;
 
-import java.util.List;
+import org.apache.uniffle.common.netty.protocol.Message;
 
-import io.netty.buffer.ByteBuf;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.handler.codec.ByteToMessageDecoder;
+public abstract class MessageHandler<T extends Message> {
+  /** Handles the receipt of a single message. */
+  public abstract void handle(T message) throws Exception;
 
-public class StreamServerInitDecoder extends ByteToMessageDecoder {
+  /** Invoked when the channel this MessageHandler is on is active. */
+  public abstract void channelActive();
 
-  public StreamServerInitDecoder() {
-  }
+  /** Invoked when an exception was caught on the Channel. */
+  public abstract void exceptionCaught(Throwable cause);
 
-  private void addDecoder(ChannelHandlerContext ctx, byte type) {
-
-  }
-
-  @Override
-  protected void decode(ChannelHandlerContext ctx,
-      ByteBuf in,
-      List<Object> out) {
-    if (in.readableBytes() < Byte.BYTES) {
-      return;
-    }
-    in.markReaderIndex();
-    byte magicByte = in.readByte();
-    in.resetReaderIndex();
-
-    addDecoder(ctx, magicByte);
-  }
+  /** Invoked when the channel this MessageHandler is on is inactive. */
+  public abstract void channelInactive();
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportChannelHandler.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportChannelHandler.java
new file mode 100644
index 00000000..f7fb9f33
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportChannelHandler.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.handle;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.timeout.IdleState;
+import io.netty.handler.timeout.IdleStateEvent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.client.TransportClient;
+import org.apache.uniffle.common.netty.protocol.RequestMessage;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+import org.apache.uniffle.common.util.NettyUtils;
+
+public class TransportChannelHandler extends ChannelInboundHandlerAdapter {
+  private static final Logger logger = 
LoggerFactory.getLogger(TransportChannelHandler.class);
+
+  private final TransportClient client;
+  private final TransportResponseHandler responseHandler;
+  private final TransportRequestHandler requestHandler;
+  private final long requestTimeoutNs;
+  private final boolean closeIdleConnections;
+
+  public TransportChannelHandler(
+      TransportClient client,
+      TransportResponseHandler responseHandler,
+      TransportRequestHandler requestHandler,
+      long requestTimeoutMs,
+      boolean closeIdleConnections) {
+    this.client = client;
+    this.responseHandler = responseHandler;
+    this.requestHandler = requestHandler;
+    this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
+    this.closeIdleConnections = closeIdleConnections;
+  }
+
+  public TransportClient getClient() {
+    return client;
+  }
+
+  @Override
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) 
throws Exception {
+    logger.warn(
+        "Exception in connection from " + 
NettyUtils.getRemoteAddress(ctx.channel()), cause);
+    requestHandler.exceptionCaught(cause);
+    responseHandler.exceptionCaught(cause);
+    ctx.close();
+  }
+
+  @Override
+  public void channelActive(ChannelHandlerContext ctx) throws Exception {
+    try {
+      requestHandler.channelActive();
+    } catch (RuntimeException e) {
+      logger.error("Exception from request handler while channel is active", 
e);
+    }
+    try {
+      responseHandler.channelActive();
+    } catch (RuntimeException e) {
+      logger.error("Exception from response handler while channel is active", 
e);
+    }
+    super.channelActive(ctx);
+  }
+
+  @Override
+  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+    try {
+      requestHandler.channelInactive();
+    } catch (RuntimeException e) {
+      logger.error("Exception from request handler while channel is inactive", 
e);
+    }
+    try {
+      responseHandler.channelInactive();
+    } catch (RuntimeException e) {
+      logger.error("Exception from response handler while channel is 
inactive", e);
+    }
+    super.channelInactive(ctx);
+  }
+
+  @Override
+  public void channelRead(ChannelHandlerContext ctx, Object request) throws 
Exception {
+    if (request instanceof RequestMessage) {
+      requestHandler.handle((RequestMessage) request);
+    } else if (request instanceof RpcResponse) {
+      responseHandler.handle((RpcResponse) request);
+    } else {
+      ctx.fireChannelRead(request);
+    }
+  }
+
+  /** Triggered based on events from an {@link 
io.netty.handler.timeout.IdleStateHandler}. */
+  @Override
+  public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
+    if (evt instanceof IdleStateEvent) {
+      IdleStateEvent e = (IdleStateEvent) evt;
+      synchronized (this) {
+        boolean isActuallyOverdue =
+            System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > 
requestTimeoutNs;
+        if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) {
+          if (responseHandler.numOutstandingRequests() > 0) {
+            String address = NettyUtils.getRemoteAddress(ctx.channel());
+            logger.error(
+                "Connection to {} has been quiet for {} ms while there are 
outstanding "
+                    + "requests.",
+                address,
+                requestTimeoutNs / 1000 / 1000);
+          }
+          if (closeIdleConnections) {
+            // While CloseIdleConnections is enable, we also close idle 
connection
+            client.timeOut();
+            ctx.close();
+          }
+        }
+      }
+    }
+    ctx.fireUserEventTriggered(evt);
+  }
+
+  public TransportResponseHandler getResponseHandler() {
+    return responseHandler;
+  }
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportRequestHandler.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportRequestHandler.java
new file mode 100644
index 00000000..795cf270
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportRequestHandler.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.handle;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.client.TransportClient;
+import org.apache.uniffle.common.netty.protocol.RequestMessage;
+
+public class TransportRequestHandler extends MessageHandler<RequestMessage> {
+
+  private static final Logger logger = 
LoggerFactory.getLogger(TransportRequestHandler.class);
+
+  /** Client on the same channel allowing us to talk back to the requester. */
+  private final TransportClient reverseClient;
+
+  /** Handles all RPC messages. */
+  private final BaseMessageHandler msgHandler;
+
+  public TransportRequestHandler(TransportClient reverseClient, 
BaseMessageHandler msgHandler) {
+    this.reverseClient = reverseClient;
+    this.msgHandler = msgHandler;
+  }
+
+  @Override
+  public void exceptionCaught(Throwable cause) {
+    msgHandler.exceptionCaught(cause, reverseClient);
+  }
+
+  @Override
+  public void channelActive() {
+    logger.debug("channelActive: {}", reverseClient.getSocketAddress());
+  }
+
+  @Override
+  public void channelInactive() {
+    logger.debug("channelInactive: {}", reverseClient.getSocketAddress());
+  }
+
+  @Override
+  public void handle(RequestMessage request) {
+    msgHandler.receive(reverseClient, request);
+  }
+}
+
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
index 86dd9953..b0681503 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
@@ -17,56 +17,110 @@
 
 package org.apache.uniffle.common.netty.handle;
 
+import java.io.IOException;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
 
 import io.netty.channel.Channel;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.netty.client.RpcResponseCallback;
 import org.apache.uniffle.common.netty.protocol.RpcResponse;
 import org.apache.uniffle.common.util.NettyUtils;
 
 
-public class TransportResponseHandler extends ChannelInboundHandlerAdapter {
+public class TransportResponseHandler extends MessageHandler<RpcResponse> {
   private static final Logger logger = 
LoggerFactory.getLogger(TransportResponseHandler.class);
 
   private Map<Long, RpcResponseCallback> outstandingRpcRequests;
   private Channel channel;
 
+  /** Records the time (in system nanoseconds) that the last fetch or RPC 
request was sent. */
+  private final AtomicLong timeOfLastRequestNs;
+
   public TransportResponseHandler(Channel channel) {
     this.channel = channel;
     this.outstandingRpcRequests = new ConcurrentHashMap<>();
+    this.timeOfLastRequestNs = new AtomicLong(0);
+  }
+
+  public void addResponseCallback(long requestId, RpcResponseCallback 
callback) {
+    updateTimeOfLastRequest();
+    if (outstandingRpcRequests.containsKey(requestId)) {
+      logger.warn("[addRpcRequest] requestId {} already exists!", requestId);
+    }
+    outstandingRpcRequests.put(requestId, callback);
+  }
+
+  public void removeRpcRequest(long requestId) {
+    outstandingRpcRequests.remove(requestId);
   }
 
   @Override
-  public void channelRead(ChannelHandlerContext ctx, Object msg) throws 
Exception {
-    if (msg instanceof RpcResponse) {
-      RpcResponse responseMessage = (RpcResponse) msg;
-      RpcResponseCallback listener = 
outstandingRpcRequests.get(responseMessage.getRequestId());
-      if (listener == null) {
-        logger.warn("Ignoring response from {} since it is not outstanding",
-            NettyUtils.getRemoteAddress(channel));
-      } else {
-        listener.onSuccess(responseMessage);
-      }
+  public void handle(RpcResponse message) throws Exception {
+    RpcResponseCallback listener = 
outstandingRpcRequests.get(message.getRequestId());
+    if (listener == null) {
+      logger.warn("Ignoring response from {} since it is not outstanding",
+          NettyUtils.getRemoteAddress(channel));
     } else {
-      throw new RssException("receive unexpected message!");
+      listener.onSuccess(message);
     }
-    super.channelRead(ctx, msg);
   }
 
-  public void addResponseCallback(long requestId, RpcResponseCallback 
callback) {
-    outstandingRpcRequests.put(requestId, callback);
+  @Override
+  public void channelActive() {
+
   }
 
-  public void removeRpcRequest(long requestId) {
-    outstandingRpcRequests.remove(requestId);
+  @Override
+  public void exceptionCaught(Throwable cause) {
+    if (numOutstandingRequests() > 0) {
+      String remoteAddress = NettyUtils.getRemoteAddress(channel);
+      logger.error(
+          "Still have {} requests outstanding when connection from {} is 
closed",
+          numOutstandingRequests(),
+          remoteAddress);
+      failOutstandingRequests(cause);
+    }
+  }
+
+  @Override
+  public void channelInactive() {
+    if (numOutstandingRequests() > 0) {
+      String remoteAddress = NettyUtils.getRemoteAddress(channel);
+      logger.error(
+          "Still have {} requests outstanding when connection from {} is 
closed",
+          numOutstandingRequests(),
+          remoteAddress);
+      failOutstandingRequests(new IOException("Connection from " + 
remoteAddress + " closed"));
+    }
+  }
+
+  public int numOutstandingRequests() {
+    return outstandingRpcRequests.size();
+  }
+
+  private void failOutstandingRequests(Throwable cause) {
+    for (Map.Entry<Long, RpcResponseCallback> entry : 
outstandingRpcRequests.entrySet()) {
+      try {
+        entry.getValue().onFailure(cause);
+      } catch (Exception e) {
+        logger.warn("RpcResponseCallback.onFailure throws exception", e);
+      }
+    }
+
+    outstandingRpcRequests.clear();
   }
 
+  /** Returns the time in nanoseconds of when the last request was sent out. */
+  public long getTimeOfLastRequestNs() {
+    return timeOfLastRequestNs.get();
+  }
 
+  /** Updates the time of the last request to the current system time. */
+  public void updateTimeOfLastRequest() {
+    timeOfLastRequestNs.set(System.nanoTime());
+  }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java 
b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
index afdfae27..8288df8f 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
@@ -80,4 +80,9 @@ public class ByteBufUtils {
     buf.resetReaderIndex();
     return bytes;
   }
+
+  public static void readBytes(ByteBuf from, byte[] to, int offset, int 
length) {
+    from.readBytes(to, offset, length);
+    from.resetReaderIndex();
+  }
 }
diff --git a/common/src/main/java/org/apache/uniffle/common/util/Constants.java 
b/common/src/main/java/org/apache/uniffle/common/util/Constants.java
index af959c5b..6fdd2754 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/Constants.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/Constants.java
@@ -75,4 +75,6 @@ public final class Constants {
   public static final String DEVICE_NO_SPACE_ERROR_MESSAGE = "No space left on 
device";
   public static final String NETTY_STREAM_SERVICE_NAME = "netty.rpc.server";
   public static final String GRPC_SERVICE_NAME = "grpc.server";
+
+  public static final int COMPOSITE_BYTE_BUF_MAX_COMPONENTS = 1024;
 }
diff --git 
a/common/src/test/java/org/apache/uniffle/common/ShufflePartitionedBlockTest.java
 
b/common/src/test/java/org/apache/uniffle/common/ShufflePartitionedBlockTest.java
index d8e660d4..569643bd 100644
--- 
a/common/src/test/java/org/apache/uniffle/common/ShufflePartitionedBlockTest.java
+++ 
b/common/src/test/java/org/apache/uniffle/common/ShufflePartitionedBlockTest.java
@@ -23,6 +23,8 @@ import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.CsvSource;
 
+import org.apache.uniffle.common.util.ByteBufUtils;
+
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
@@ -40,7 +42,7 @@ public class ShufflePartitionedBlockTest {
     assertEquals(3, b1.getBlockId());
 
     ShufflePartitionedBlock b3 = new ShufflePartitionedBlock(1, 1, 2, 3, 3, 
buf);
-    assertArrayEquals(buf, b3.getData());
+    assertArrayEquals(buf, ByteBufUtils.readBytes(b3.getData())); // ByteBuf -> byte[] for comparison
   }
 
   @Test
diff --git 
a/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java 
b/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java
index 3f60eaef..fa2f6099 100644
--- a/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java
@@ -17,7 +17,10 @@
 
 package org.apache.uniffle.common.util;
 
+import java.nio.charset.StandardCharsets;
+
 import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
 import io.netty.buffer.Unpooled;
 import org.junit.jupiter.api.Test;
 
@@ -42,5 +45,28 @@ public class ByteBufUtilsTest {
     byteBuf.clear();
     ByteBufUtils.writeLengthAndString(byteBuf, null);
     assertNull(ByteBufUtils.readLengthAndString(byteBuf));
+
+    byteBuf.clear();
+    ByteBufUtils.writeLengthAndString(byteBuf, expectedString);
+    ByteBuf byteBuf1 = Unpooled.buffer(100);
+    ByteBufUtils.writeLengthAndString(byteBuf1, expectedString);
+    final int expectedLength = byteBuf.readableBytes() + 
byteBuf1.readableBytes();
+    CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+    compositeByteBuf.addComponent(true, byteBuf);
+    compositeByteBuf.addComponent(true, byteBuf1);
+
+    ByteBuf res = Unpooled.buffer(100);
+    ByteBufUtils.copyByteBuf(compositeByteBuf, res);
+    assertEquals(expectedLength, res.readableBytes() - Integer.BYTES); // presumably an int length prefix
+
+    res.clear(); // copy again: the source must not have been consumed
+    ByteBufUtils.copyByteBuf(compositeByteBuf, res);
+    assertEquals(expectedLength, res.readableBytes()  - Integer.BYTES);
+
+    byteBuf.clear();
+    byte[] bytes = expectedString.getBytes(StandardCharsets.UTF_8);
+    byteBuf.writeBytes(bytes);
+    ByteBufUtils.readBytes(byteBuf, bytes, 1, byteBuf.readableBytes() - 1); // shift right by one
+    assertEquals("ttest_st", new String(bytes, StandardCharsets.UTF_8));
   }
 }
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
index 78d76b38..fcb6567e 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
@@ -25,6 +25,8 @@ import java.util.function.Supplier;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.CompositeByteBuf;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -161,11 +163,11 @@ public class ShuffleBuffer {
       updateBufferSegmentsAndResultBlocks(
           lastBlockId, readBufferSize, bufferSegments, readBlocks, expectedTaskIds);
       if (!bufferSegments.isEmpty()) {
-        int length = calculateDataLength(bufferSegments);
-        byte[] data = new byte[length];
-        // copy result data
-        updateShuffleData(readBlocks, data);
-        return new ShuffleDataResult(data, bufferSegments);
+        CompositeByteBuf byteBuf =
+            new CompositeByteBuf(ByteBufAllocator.DEFAULT, true, Constants.COMPOSITE_BYTE_BUF_MAX_COMPONENTS);
+        // reference the retained block buffers instead of copying; NOTE(review): confirm who releases them
+        updateShuffleData(readBlocks, byteBuf);
+        return new ShuffleDataResult(byteBuf, bufferSegments);
       }
     } catch (Exception e) {
       LOG.error("Exception happened when getShuffleData in buffer", e);
@@ -238,16 +240,18 @@ public class ShuffleBuffer {
     return bufferSegment.getOffset() + bufferSegment.getLength();
   }
 
-  private void updateShuffleData(List<ShufflePartitionedBlock> readBlocks, byte[] data) {
+  private void updateShuffleData(List<ShufflePartitionedBlock> readBlocks, CompositeByteBuf data) {
     int offset = 0;
     for (ShufflePartitionedBlock block : readBlocks) {
       // fill shuffle data
       try {
-        System.arraycopy(block.getData(), 0, data, offset, block.getLength());
+        data.addComponent(true, block.getData().retain());
       } catch (Exception e) {
-        LOG.error("Unexpected exception for System.arraycopy, length["
+        LOG.error("Unexpected exception when adding block data to CompositeByteBuf, length["
             + block.getLength() + "], offset["
-            + offset + "], dataLength[" + data.length + "]", e);
+            + offset + "], dataLength[" + data.capacity() + "]", e);
+        // Release the partially built buffer so the block buffers retained so far are not leaked.
+        data.release();
         throw e;
       }
       offset += block.getLength();
diff --git 
a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
 
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
new file mode 100644
index 00000000..b87d57a3
--- /dev/null
+++ 
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -0,0 +1,439 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.netty;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleDataResult;
+import org.apache.uniffle.common.ShuffleIndexResult;
+import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.exception.FileNotFoundException;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.netty.client.TransportClient;
+import org.apache.uniffle.common.netty.handle.BaseMessageHandler;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataRequest;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataResponse;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexRequest;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse;
+import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest;
+import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataResponse;
+import org.apache.uniffle.common.netty.protocol.RequestMessage;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+import org.apache.uniffle.common.netty.protocol.SendShuffleDataRequest;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.server.ShuffleDataReadEvent;
+import org.apache.uniffle.server.ShuffleServer;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.server.ShuffleServerGrpcMetrics;
+import org.apache.uniffle.server.ShuffleServerMetrics;
+import org.apache.uniffle.server.ShuffleTaskManager;
+import org.apache.uniffle.server.buffer.PreAllocatedBufferInfo;
+import org.apache.uniffle.storage.common.Storage;
+import org.apache.uniffle.storage.common.StorageReadMetrics;
+import org.apache.uniffle.storage.util.ShuffleStorageUtils;
+
+/**
+ * Netty request handler of the shuffle server: each supported {@link RequestMessage} is
+ * dispatched to a handler method that delegates to the {@link ShuffleServer} components and
+ * answers the caller through {@link TransportClient#sendRpcSync}.
+ */
+public class ShuffleServerNettyHandler implements BaseMessageHandler {
+
+  private static final Logger LOG = LoggerFactory.getLogger(ShuffleServerNettyHandler.class);
+  // Timeout passed to TransportClient#sendRpcSync for every reply (presumably milliseconds).
+  private static final int RPC_TIMEOUT = 60000;
+  private final ShuffleServer shuffleServer;
+
+  public ShuffleServerNettyHandler(ShuffleServer shuffleServer) {
+    this.shuffleServer = shuffleServer;
+  }
+
+  /**
+   * Dispatches a request to its handler method; unsupported types raise {@link RssException}.
+   */
+  @Override
+  public void receive(TransportClient client, RequestMessage msg) {
+    if (msg instanceof SendShuffleDataRequest) {
+      handleSendShuffleDataRequest(client, (SendShuffleDataRequest)msg);
+    } else if (msg instanceof GetLocalShuffleDataRequest) {
+      handleGetLocalShuffleData(client, (GetLocalShuffleDataRequest)msg);
+    } else if (msg instanceof GetLocalShuffleIndexRequest) {
+      handleGetLocalShuffleIndexRequest(client, (GetLocalShuffleIndexRequest)msg);
+    } else if (msg instanceof GetMemoryShuffleDataRequest) {
+      handleGetMemoryShuffleDataRequest(client, (GetMemoryShuffleDataRequest)msg);
+    } else {
+      throw new RssException("Can not handle message " + msg.type());
+    }
+  }
+
+  @Override
+  public void exceptionCaught(Throwable cause, TransportClient client) {
+    LOG.error("exception caught {}", client.getSocketAddress(), cause);
+  }
+
+  /**
+   * Caches the blocks carried by a {@link SendShuffleDataRequest} and replies with an
+   * {@link RpcResponse}. The pre-allocated buffer referenced by the request's require id is
+   * removed up front, and whatever part of it was not consumed by cached blocks is released
+   * afterwards. A request without any blocks is rejected with {@link StatusCode#INTERNAL_ERROR}.
+   */
+  public void handleSendShuffleDataRequest(
+      TransportClient client, SendShuffleDataRequest req) {
+    RpcResponse rpcResponse;
+    String appId = req.getAppId();
+    int shuffleId = req.getShuffleId();
+    long requireBufferId = req.getRequireId();
+    long timestamp = req.getTimestamp();
+    if (timestamp > 0) {
+      /*
+       * Here we record the transport time, but we don't consider the impact of data size on transport time.
+       * The amount of data will not cause great fluctuations in latency. For example, 100K costs 1ms,
+       * and 1M costs 10ms. This seems like a normal fluctuation, but it may rise to 10s when the server load is high.
+       * In addition, we need to pay attention to that the time of the client machine and the machine
+       * time of the Shuffle Server should be kept in sync. TransportTime is accurate only if this condition is met.
+       * */
+      long transportTime = System.currentTimeMillis() - timestamp;
+      if (transportTime > 0) {
+        shuffleServer.getGrpcMetrics().recordTransportTime(
+            ShuffleServerGrpcMetrics.SEND_SHUFFLE_DATA_METHOD, transportTime);
+      }
+    }
+    int requireSize = shuffleServer
+        .getShuffleTaskManager().getRequireBufferSize(requireBufferId);
+
+    StatusCode ret = StatusCode.SUCCESS;
+    String responseMessage = "OK";
+    if (req.getPartitionToBlocks().size() > 0) {
+      ShuffleServerMetrics.counterTotalReceivedDataSize.inc(requireSize);
+      ShuffleTaskManager manager = shuffleServer.getShuffleTaskManager();
+      PreAllocatedBufferInfo info = manager.getAndRemovePreAllocatedBuffer(requireBufferId);
+      boolean isPreAllocated = info != null;
+      if (!isPreAllocated) {
+        String errorMsg = "Can't find requireBufferId[" + requireBufferId + "] for appId[" + appId
+            + "], shuffleId[" + shuffleId + "]";
+        LOG.warn(errorMsg);
+        responseMessage = errorMsg;
+        rpcResponse = new RpcResponse(req.getRequestId(), StatusCode.INTERNAL_ERROR, responseMessage);
+        client.sendRpcSync(rpcResponse, RPC_TIMEOUT);
+        return;
+      }
+      final long start = System.currentTimeMillis();
+      List<ShufflePartitionedData> shufflePartitionedData = toPartitionedData(req);
+      long alreadyReleasedSize = 0;
+      for (ShufflePartitionedData spd : shufflePartitionedData) {
+        String shuffleDataInfo = "appId[" + appId + "], shuffleId[" + shuffleId
+            + "], partitionId[" + spd.getPartitionId() + "]";
+        try {
+          ret = manager.cacheShuffleData(appId, shuffleId, isPreAllocated, spd);
+          if (ret != StatusCode.SUCCESS) {
+            String errorMsg = "Error happened when shuffleEngine.write for "
+                + shuffleDataInfo + ", statusCode=" + ret;
+            LOG.error(errorMsg);
+            responseMessage = errorMsg;
+            break;
+          } else {
+            long toReleasedSize = spd.getTotalBlockSize();
+            // after each cacheShuffleData call, the `preAllocatedSize` is updated timely.
+            manager.releasePreAllocatedSize(toReleasedSize);
+            alreadyReleasedSize += toReleasedSize;
+            manager.updateCachedBlockIds(appId, shuffleId, spd.getPartitionId(), spd.getBlockList());
+          }
+        } catch (Exception e) {
+          String errorMsg = "Error happened when shuffleEngine.write for "
+              + shuffleDataInfo + ": " + e.getMessage();
+          ret = StatusCode.INTERNAL_ERROR;
+          responseMessage = errorMsg;
+          LOG.error(errorMsg, e);
+          break;
+        }
+      }
+      // since the required buffer id is only used once, the shuffle client would try to require another buffer whether
+      // current connection succeeded or not. Therefore, the preAllocatedBuffer is first get and removed, then after
+      // cacheShuffleData finishes, the preAllocatedSize should be updated accordingly.
+      if (info.getRequireSize() > alreadyReleasedSize) {
+        manager.releasePreAllocatedSize(info.getRequireSize() - alreadyReleasedSize);
+      }
+      rpcResponse = new RpcResponse(req.getRequestId(), ret, responseMessage);
+      long costTime = System.currentTimeMillis() - start;
+      shuffleServer.getGrpcMetrics().recordProcessTime(ShuffleServerGrpcMetrics.SEND_SHUFFLE_DATA_METHOD, costTime);
+      LOG.debug("Cache Shuffle Data for appId[{}], shuffleId[{}], cost {} ms with {} blocks and {} bytes",
+          appId, shuffleId, costTime,
+          shufflePartitionedData.size(), requireSize);
+    } else {
+      rpcResponse = new RpcResponse(req.getRequestId(), StatusCode.INTERNAL_ERROR, "No data in request");
+    }
+
+    client.sendRpcSync(rpcResponse, RPC_TIMEOUT);
+  }
+
+  /**
+   * Reads a partition's shuffle data held in server memory, positioned by the request's last
+   * block id, and replies with a {@link GetMemoryShuffleDataResponse}. {@code readBufferSize}
+   * bytes of read memory are reserved for the read and released afterwards.
+   */
+  public void handleGetMemoryShuffleDataRequest(
+      TransportClient client, GetMemoryShuffleDataRequest req) {
+    String appId = req.getAppId();
+    int shuffleId = req.getShuffleId();
+    int partitionId = req.getPartitionId();
+    long blockId = req.getLastBlockId();
+    int readBufferSize = req.getReadBufferSize();
+    long timestamp = req.getTimestamp();
+
+    if (timestamp > 0) {
+      long transportTime = System.currentTimeMillis() - timestamp;
+      if (transportTime > 0) {
+        shuffleServer.getGrpcMetrics().recordTransportTime(
+            ShuffleServerGrpcMetrics.GET_MEMORY_SHUFFLE_DATA_METHOD, transportTime);
+      }
+    }
+    long start = System.currentTimeMillis();
+    StatusCode status = StatusCode.SUCCESS;
+    String msg = "OK";
+    GetMemoryShuffleDataResponse response;
+    String requestInfo = "appId[" + appId + "], shuffleId[" + shuffleId + "], partitionId["
+        + partitionId + "]";
+
+    // todo: if can get the exact memory size?
+    if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(readBufferSize)) {
+      try {
+        ShuffleDataResult shuffleDataResult = shuffleServer
+            .getShuffleTaskManager()
+            .getInMemoryShuffleData(
+                appId,
+                shuffleId,
+                partitionId,
+                blockId,
+                readBufferSize,
+                req.getExpectedTaskIdsBitmap()
+            );
+        ByteBuf data = Unpooled.EMPTY_BUFFER;
+        List<BufferSegment> bufferSegments = Lists.newArrayList();
+        if (shuffleDataResult != null) {
+          // NOTE(review): if this comes from ShuffleBuffer#getShuffleData it holds retained block
+          // buffers, but nothing here releases it; confirm ownership so they are not leaked.
+          data = Unpooled.wrappedBuffer(shuffleDataResult.getDataBuffer());
+          bufferSegments = shuffleDataResult.getBufferSegments();
+          ShuffleServerMetrics.counterTotalReadDataSize.inc(data.readableBytes());
+          ShuffleServerMetrics.counterTotalReadMemoryDataSize.inc(data.readableBytes());
+        }
+        long costTime = System.currentTimeMillis() - start;
+        shuffleServer.getGrpcMetrics().recordProcessTime(
+            ShuffleServerGrpcMetrics.GET_MEMORY_SHUFFLE_DATA_METHOD, costTime);
+        LOG.info("Successfully getInMemoryShuffleData cost {} ms with {} bytes shuffle"
+            + " data for {}", costTime, data.readableBytes(), requestInfo);
+
+        response = new GetMemoryShuffleDataResponse(req.getRequestId(), status, msg, bufferSegments, data);
+      } catch (Exception e) {
+        status = StatusCode.INTERNAL_ERROR;
+        msg = "Error happened when get in memory shuffle data for "
+            + requestInfo + ", " + e.getMessage();
+        LOG.error(msg, e);
+        response = new GetMemoryShuffleDataResponse(req.getRequestId(),
+            status, msg, Lists.newArrayList(), Unpooled.EMPTY_BUFFER);
+      } finally {
+        shuffleServer.getShuffleBufferManager().releaseReadMemory(readBufferSize);
+      }
+    } else {
+      status = StatusCode.INTERNAL_ERROR;
+      msg = "Can't require memory to get in memory shuffle data";
+      LOG.error(msg + " for " + requestInfo);
+      response = new GetMemoryShuffleDataResponse(req.getRequestId(),
+          status, msg, Lists.newArrayList(), Unpooled.EMPTY_BUFFER);
+    }
+    client.sendRpcSync(response, RPC_TIMEOUT);
+  }
+
+  /**
+   * Reads a partition's shuffle index and replies with a {@link GetLocalShuffleIndexResponse}.
+   * A missing index file yields a successful response with an empty index.
+   */
+  public void handleGetLocalShuffleIndexRequest(
+      TransportClient client, GetLocalShuffleIndexRequest req) {
+    String appId = req.getAppId();
+    int shuffleId = req.getShuffleId();
+    int partitionId = req.getPartitionId();
+    int partitionNumPerRange = req.getPartitionNumPerRange();
+    int partitionNum = req.getPartitionNum();
+    StatusCode status = StatusCode.SUCCESS;
+    String msg = "OK";
+    GetLocalShuffleIndexResponse response;
+    String requestInfo = "appId[" + appId + "], shuffleId[" + shuffleId + "], partitionId["
+        + partitionId + "]";
+
+    int[] range = ShuffleStorageUtils.getPartitionRange(partitionId, partitionNumPerRange, partitionNum);
+    Storage storage = shuffleServer.getStorageManager()
+        .selectStorage(new ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0]));
+    if (storage != null) {
+      storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
+    }
+    // Index file is expected small size and won't cause oom problem with the assumed size. An index segment is 40B,
+    // with the default size - 2MB, it can support 50k blocks for shuffle data.
+    long assumedFileSize = shuffleServer
+        .getShuffleServerConf().getLong(ShuffleServerConf.SERVER_SHUFFLE_INDEX_SIZE_HINT);
+    if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(assumedFileSize)) {
+      try {
+        final long start = System.currentTimeMillis();
+        ShuffleIndexResult shuffleIndexResult = shuffleServer.getShuffleTaskManager().getShuffleIndex(
+            appId, shuffleId, partitionId, partitionNumPerRange, partitionNum);
+
+        ByteBuffer data = shuffleIndexResult.getIndexData();
+        ShuffleServerMetrics.counterTotalReadDataSize.inc(data.remaining());
+        ShuffleServerMetrics.counterTotalReadLocalIndexFileSize.inc(data.remaining());
+        response = new GetLocalShuffleIndexResponse(req.getRequestId(),
+            status, msg, Unpooled.wrappedBuffer(data), shuffleIndexResult.getDataFileLen());
+        long readTime = System.currentTimeMillis() - start;
+        LOG.info("Successfully getShuffleIndex cost {} ms for {}"
+            + " bytes with {}", readTime, data.remaining(), requestInfo);
+      } catch (FileNotFoundException indexFileNotFoundException) {
+        LOG.warn("Index file for {} is not found, maybe the data has been flushed to cold storage.",
+            requestInfo, indexFileNotFoundException);
+        response = new GetLocalShuffleIndexResponse(req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER, 0L);
+      } catch (Exception e) {
+        status = StatusCode.INTERNAL_ERROR;
+        msg = "Error happened when get shuffle index for " + requestInfo + ", " + e.getMessage();
+        LOG.error(msg, e);
+        response = new GetLocalShuffleIndexResponse(req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER, 0L);
+      } finally {
+        shuffleServer.getShuffleBufferManager().releaseReadMemory(assumedFileSize);
+      }
+    } else {
+      status = StatusCode.INTERNAL_ERROR;
+      msg = "Can't require memory to get shuffle index";
+      LOG.error(msg + " for " + requestInfo);
+      response = new GetLocalShuffleIndexResponse(req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER, 0L);
+    }
+    client.sendRpcSync(response, RPC_TIMEOUT);
+  }
+
+  /**
+   * Reads {@code length} bytes at {@code offset} of a partition's shuffle data and replies with a
+   * {@link GetLocalShuffleDataResponse}. {@code length} bytes of read memory are reserved meanwhile.
+   */
+  public void handleGetLocalShuffleData(
+      TransportClient client, GetLocalShuffleDataRequest req) {
+    String appId = req.getAppId();
+    int shuffleId = req.getShuffleId();
+    int partitionId = req.getPartitionId();
+    int partitionNumPerRange = req.getPartitionNumPerRange();
+    int partitionNum = req.getPartitionNum();
+    long offset = req.getOffset();
+    int length = req.getLength();
+    long timestamp = req.getTimestamp();
+    if (timestamp > 0) {
+      long transportTime = System.currentTimeMillis() - timestamp;
+      if (transportTime > 0) {
+        shuffleServer.getGrpcMetrics().recordTransportTime(
+            ShuffleServerGrpcMetrics.GET_SHUFFLE_DATA_METHOD, transportTime);
+      }
+    }
+    String storageType = shuffleServer.getShuffleServerConf().get(RssBaseConf.RSS_STORAGE_TYPE);
+    StatusCode status = StatusCode.SUCCESS;
+    String msg = "OK";
+    GetLocalShuffleDataResponse response;
+    ShuffleDataResult sdr;
+    String requestInfo = "appId[" + appId + "], shuffleId[" + shuffleId + "], partitionId["
+        + partitionId + "], offset[" + offset + "], length[" + length + "]";
+
+    int[] range = ShuffleStorageUtils.getPartitionRange(partitionId, partitionNumPerRange, partitionNum);
+    Storage storage = shuffleServer
+        .getStorageManager()
+        .selectStorage(
+            new ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0])
+        );
+    if (storage != null) {
+      storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
+    }
+
+    if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(length)) {
+      try {
+        long start = System.currentTimeMillis();
+        sdr = shuffleServer.getShuffleTaskManager().getShuffleData(appId, shuffleId, partitionId,
+            partitionNumPerRange, partitionNum, storageType, offset, length);
+        long readTime = System.currentTimeMillis() - start;
+        ShuffleServerMetrics.counterTotalReadTime.inc(readTime);
+        int dataLength = sdr.getData().length;
+        ShuffleServerMetrics.counterTotalReadDataSize.inc(dataLength);
+        ShuffleServerMetrics.counterTotalReadLocalDataFileSize.inc(dataLength);
+        shuffleServer.getGrpcMetrics().recordProcessTime(
+            ShuffleServerGrpcMetrics.GET_SHUFFLE_DATA_METHOD, readTime);
+        LOG.info("Successfully getShuffleData cost {} ms for shuffle data with {}", readTime, requestInfo);
+        response = new GetLocalShuffleDataResponse(req.getRequestId(),
+            status, msg, sdr.getDataBuf());
+      } catch (Exception e) {
+        status = StatusCode.INTERNAL_ERROR;
+        msg = "Error happened when get shuffle data for " + requestInfo + ", " + e.getMessage();
+        LOG.error(msg, e);
+        response = new GetLocalShuffleDataResponse(req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER);
+      } finally {
+        shuffleServer.getShuffleBufferManager().releaseReadMemory(length);
+      }
+    } else {
+      status = StatusCode.INTERNAL_ERROR;
+      msg = "Can't require memory to get shuffle data";
+      LOG.error(msg + " for " + requestInfo);
+      response = new GetLocalShuffleDataResponse(req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER);
+    }
+    client.sendRpcSync(response, RPC_TIMEOUT);
+  }
+
+  /** Converts each partition's block list in the request into a {@link ShufflePartitionedData}. */
+  private List<ShufflePartitionedData> toPartitionedData(SendShuffleDataRequest req) {
+    List<ShufflePartitionedData> ret = Lists.newArrayList();
+
+    for (Map.Entry<Integer, List<ShuffleBlockInfo>> entry : req.getPartitionToBlocks().entrySet()) {
+      ret.add(new ShufflePartitionedData(
+          entry.getKey(),
+          toPartitionedBlock(entry.getValue())));
+    }
+    return ret;
+  }
+
+  /** Converts block infos into {@link ShufflePartitionedBlock}s; null or empty input yields an empty array. */
+  private ShufflePartitionedBlock[] toPartitionedBlock(List<ShuffleBlockInfo> blocks) {
+    if (blocks == null || blocks.isEmpty()) {
+      return new ShufflePartitionedBlock[]{};
+    }
+    ShufflePartitionedBlock[] ret = new ShufflePartitionedBlock[blocks.size()];
+    int i = 0;
+    for (ShuffleBlockInfo block : blocks) {
+      ret[i] = new ShufflePartitionedBlock(
+          block.getLength(),
+          block.getUncompressLength(),
+          block.getCrc(),
+          block.getBlockId(),
+          block.getTaskAttemptId(),
+          block.getData());
+      i++;
+    }
+    return ret;
+  }
+}
+
diff --git 
a/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java 
b/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
index 57f91338..ae412b38 100644
--- a/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
+++ b/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
@@ -19,12 +19,10 @@ package org.apache.uniffle.server.netty;
 
 import java.io.IOException;
 import java.util.concurrent.TimeUnit;
-import java.util.function.Supplier;
 
 import io.netty.bootstrap.ServerBootstrap;
 import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.ChannelFuture;
-import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelInitializer;
 import io.netty.channel.ChannelOption;
 import io.netty.channel.EventLoopGroup;
@@ -36,13 +34,15 @@ import io.netty.channel.socket.nio.NioServerSocketChannel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.common.netty.TransportFrameDecoder;
+import org.apache.uniffle.common.netty.client.TransportConf;
+import org.apache.uniffle.common.netty.client.TransportContext;
 import org.apache.uniffle.common.rpc.ServerInterface;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.ExitUtils;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.server.ShuffleServer;
 import org.apache.uniffle.server.ShuffleServerConf;
-import org.apache.uniffle.server.netty.decoder.StreamServerInitDecoder;
 
 public class StreamServer implements ServerInterface {
 
@@ -75,8 +75,7 @@ public class StreamServer implements ServerInterface {
       int backlogSize,
       int timeoutMillis,
       int sendBuf,
-      int receiveBuf,
-      Supplier<ChannelHandler[]> handlerSupplier) {
+      int receiveBuf) {
     ServerBootstrap serverBootstrap = new ServerBootstrap().group(bossGroup, 
workerGroup);
     if (bossGroup instanceof EpollEventLoopGroup) {
       serverBootstrap.channel(EpollServerSocketChannel.class);
@@ -84,10 +83,13 @@ public class StreamServer implements ServerInterface {
       serverBootstrap.channel(NioServerSocketChannel.class);
     }
 
+    ShuffleServerNettyHandler serverNettyHandler = new 
ShuffleServerNettyHandler(shuffleServer);
+    TransportContext transportContext = // shared by all child channels
+        new TransportContext(new TransportConf(shuffleServerConf), 
serverNettyHandler, true);
     serverBootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
       @Override
       public void initChannel(final SocketChannel ch) {
-        ch.pipeline().addLast(handlerSupplier.get());
+        transportContext.initializePipeline(ch, new TransportFrameDecoder()); // per-channel frame decoder
       }
     })
                .option(ChannelOption.SO_BACKLOG, backlogSize)
@@ -121,15 +123,12 @@ public class StreamServer implements ServerInterface {
 
   @Override
   public void startOnPort(int port) throws Exception {
-    Supplier<ChannelHandler[]> streamHandlers = () -> new ChannelHandler[]{
-        new StreamServerInitDecoder()
-    };
+    // Channel pipelines are initialized in bootstrapChannel via a TransportContext backed by ShuffleServerNettyHandler.
     ServerBootstrap serverBootstrap = bootstrapChannel(shuffleBossGroup, 
shuffleWorkerGroup,
         
shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_CONNECT_BACKLOG),
         
shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_CONNECT_TIMEOUT),
         shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_SEND_BUF),
-        
shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_RECEIVE_BUF),
-        streamHandlers);
+        
shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_RECEIVE_BUF));
 
     // Bind the ports and save the results so that the channels can be closed 
later.
     // If the second bind fails, the first one gets cleaned up in the shutdown.
@@ -143,6 +142,7 @@ public class StreamServer implements ServerInterface {
     }
   }
 
+  @Override
   public void stop() {
     if (channelFuture != null) {
       channelFuture.channel().close().awaitUninterruptibly(10L, 
TimeUnit.SECONDS);
diff --git 
a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java 
b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
index 53eb4ef8..659a9724 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
@@ -539,7 +539,7 @@ public class ShuffleFlushManagerTest extends HdfsTestBase {
 
     // case3: local disk is full or corrupted, fallback to HDFS
     List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
-        new ShufflePartitionedBlock(100000, 1000, 1, 1, 1L, null)
+        new ShufflePartitionedBlock(100000, 1000, 1, 1, 1L, (byte[]) null) // cast selects the byte[] overload
     );
     ShuffleDataFlushEvent bigEvent = new ShuffleDataFlushEvent(1, "1", 1, 1, 
1, 100, blocks, null, null);
     
bigEvent.setUnderStorage(((MultiStorageManager)storageManager).getWarmStorageManager().selectStorage(event));
@@ -571,7 +571,8 @@ public class ShuffleFlushManagerTest extends HdfsTestBase {
       Thread.sleep(1 * 1000);
     } while (manager.getEventNumInFlush() != 0);
 
-    List<ShufflePartitionedBlock> blocks = Lists.newArrayList(new 
ShufflePartitionedBlock(100, 1000, 1, 1, 1L, null));
+    List<ShufflePartitionedBlock> blocks =
+        Lists.newArrayList(new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, 
(byte[]) null));
     ShuffleDataFlushEvent bigEvent = new ShuffleDataFlushEvent(1, "1", 1, 1, 
1, 100, blocks, null, null);
     bigEvent.setUnderStorage(storageManager.selectStorage(event));
     storageManager.updateWriteMetrics(bigEvent, 0);
diff --git 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
index ff39cac6..dabd1e08 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
@@ -35,6 +35,7 @@ import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleDataResult;
 import org.apache.uniffle.common.ShufflePartitionedData;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ByteBufUtils;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.server.ShuffleFlushManager;
 import org.apache.uniffle.server.ShuffleServer;
@@ -173,11 +174,11 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
     // validate get shuffle data
     ShuffleDataResult sdr = shuffleBufferManager.getShuffleData(
         appId, 2, 0, Constants.INVALID_BLOCK_ID, 60);
-    assertArrayEquals(spd2.getBlockList()[0].getData(), sdr.getData());
+    
assertArrayEquals(ByteBufUtils.readBytes(spd2.getBlockList()[0].getData()), 
sdr.getData());
     long lastBlockId = spd2.getBlockList()[0].getBlockId();
     sdr = shuffleBufferManager.getShuffleData(
         appId, 2, 0, lastBlockId, 100);
-    assertArrayEquals(spd3.getBlockList()[0].getData(), sdr.getData());
+    
assertArrayEquals(ByteBufUtils.readBytes(spd3.getBlockList()[0].getData()), 
sdr.getData());
     // flush happen
     ShufflePartitionedData spd5 = createData(0, 10);
     shuffleBufferManager.cacheShuffleData(appId, 4, false, spd5);
@@ -193,11 +194,11 @@ public class ShuffleBufferManagerTest extends 
BufferTestBase {
     // data in flush buffer now, it also can be got before flush finish
     sdr = shuffleBufferManager.getShuffleData(
         appId, 2, 0, Constants.INVALID_BLOCK_ID, 60);
-    assertArrayEquals(spd2.getBlockList()[0].getData(), sdr.getData());
+    
assertArrayEquals(ByteBufUtils.readBytes(spd2.getBlockList()[0].getData()), 
sdr.getData());
     lastBlockId = spd2.getBlockList()[0].getBlockId();
     sdr = shuffleBufferManager.getShuffleData(
         appId, 2, 0, lastBlockId, 100);
-    assertArrayEquals(spd3.getBlockList()[0].getData(), sdr.getData());
+    
assertArrayEquals(ByteBufUtils.readBytes(spd3.getBlockList()[0].getData()), 
sdr.getData());
     // cache data again, it should cause flush
     spd1 = createData(0, 10);
     shuffleBufferManager.cacheShuffleData(appId, 1, false, spd1);
diff --git 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
index 39ab4e61..b4c0ee4e 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
@@ -28,6 +28,7 @@ import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleDataResult;
 import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.util.ByteBufUtils;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.server.ShuffleDataFlushEvent;
 
@@ -603,7 +604,7 @@ public class ShuffleBufferTest extends BufferTestBase {
     int offset = 0;
     for (ShufflePartitionedData spd : spds) {
       ShufflePartitionedBlock block = spd.getBlockList()[0];
-      System.arraycopy(block.getData(), 0, expectedData, offset, 
block.getLength());
+      ByteBufUtils.readBytes(block.getData(), expectedData, offset, 
block.getLength());
       offset += block.getLength();
     }
     return expectedData;
diff --git 
a/server/src/test/java/org/apache/uniffle/server/storage/MultiStorageManagerTest.java
 
b/server/src/test/java/org/apache/uniffle/server/storage/MultiStorageManagerTest.java
index 3cea6d32..1c66ad43 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/storage/MultiStorageManagerTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/storage/MultiStorageManagerTest.java
@@ -47,7 +47,8 @@ public class MultiStorageManagerTest {
     String remoteStorage = "test";
     String appId = "selectStorageManagerTest_appId";
     manager.registerRemoteStorage(appId, new RemoteStorageInfo(remoteStorage));
-    List<ShufflePartitionedBlock> blocks = Lists.newArrayList(new 
ShufflePartitionedBlock(100, 1000, 1, 1, 1L, null));
+    List<ShufflePartitionedBlock> blocks =
+        Lists.newArrayList(new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, 
(byte[]) null));
     ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
         1, appId, 1, 1, 1, 1000, blocks, null, null);
     assertTrue((manager.selectStorage(event) instanceof LocalStorage));
@@ -81,7 +82,7 @@ public class MultiStorageManagerTest {
      * is enabled.
      */
     List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
-        new ShufflePartitionedBlock(10001, 1000, 1, 1, 1L, null)
+        new ShufflePartitionedBlock(10001, 1000, 1, 1, 1L, (byte[]) null)
     );
     ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
         1, appId, 1, 1, 1, 100000, blocks, null, null);
@@ -112,7 +113,7 @@ public class MultiStorageManagerTest {
      * case1: big event should be written into cold storage directly
      */
     List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
-        new ShufflePartitionedBlock(10001, 1000, 1, 1, 1L, null)
+        new ShufflePartitionedBlock(10001, 1000, 1, 1, 1L, (byte[]) null)
     );
     ShuffleDataFlushEvent hugeEvent = new ShuffleDataFlushEvent(
         1, appId, 1, 1, 1, 10001, blocks, null, null);
@@ -121,7 +122,7 @@ public class MultiStorageManagerTest {
     /**
      * case2: fallback when disk can not write
      */
-    blocks = Lists.newArrayList(new ShufflePartitionedBlock(100, 1000, 1, 1, 
1L, null));
+    blocks = Lists.newArrayList(new ShufflePartitionedBlock(100, 1000, 1, 1, 
1L, (byte[]) null));
     ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
         1, appId, 1, 1, 1, 1000, blocks, null, null);
     Storage storage = manager.selectStorage(event);
diff --git 
a/server/src/test/java/org/apache/uniffle/server/storage/StorageManagerFallbackStrategyTest.java
 
b/server/src/test/java/org/apache/uniffle/server/storage/StorageManagerFallbackStrategyTest.java
index c5b4512b..2446c9af 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/storage/StorageManagerFallbackStrategyTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/storage/StorageManagerFallbackStrategyTest.java
@@ -54,7 +54,7 @@ public class StorageManagerFallbackStrategyTest {
     String appId = "testDefaultFallbackStrategy_appId";
     coldStorageManager.registerRemoteStorage(appId, new 
RemoteStorageInfo(remoteStorage));
     List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
-        new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, null));
+        new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, (byte[]) null));
     ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
         1, appId, 1, 1, 1, 1000, blocks, null, null);
     event.increaseRetryTimes();
@@ -90,7 +90,7 @@ public class StorageManagerFallbackStrategyTest {
     String appId = "testHdfsFallbackStrategy_appId";
     coldStorageManager.registerRemoteStorage(appId, new 
RemoteStorageInfo(remoteStorage));
     List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
-        new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, null));
+        new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, (byte[]) null));
     ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
         1, appId, 1, 1, 1, 1000, blocks, null, null);
     event.increaseRetryTimes();
@@ -112,7 +112,7 @@ public class StorageManagerFallbackStrategyTest {
     String appId = "testLocalFallbackStrategy_appId";
     coldStorageManager.registerRemoteStorage(appId, new 
RemoteStorageInfo(remoteStorage));
     List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
-        new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, null));
+        new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, (byte[]) null));
     ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
         1, appId, 1, 1, 1, 1000, blocks, null, null);
     event.increaseRetryTimes();
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleWriteHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleWriteHandler.java
index da701702..776115ef 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleWriteHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleWriteHandler.java
@@ -32,6 +32,7 @@ import org.slf4j.LoggerFactory;
 import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.filesystem.HadoopFilesystemProvider;
+import org.apache.uniffle.common.util.ByteBufUtils;
 import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
 import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler;
 import org.apache.uniffle.storage.util.ShuffleStorageUtils;
@@ -118,7 +119,7 @@ public class HdfsShuffleWriteHandler implements 
ShuffleWriteHandler {
           long blockId = block.getBlockId();
           long crc = block.getCrc();
           long startOffset = dataWriter.nextOffset();
-          dataWriter.writeData(block.getData());
+          dataWriter.writeData(ByteBufUtils.readBytes(block.getData()));
 
           FileBasedShuffleSegment segment = new FileBasedShuffleSegment(
               blockId, startOffset, block.getLength(), 
block.getUncompressLength(), crc, block.getTaskAttemptId());
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java
index 00e43927..25f504fb 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java
@@ -28,6 +28,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ByteBufUtils;
 import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
 import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler;
 import org.apache.uniffle.storage.util.ShuffleStorageUtils;
@@ -106,7 +107,8 @@ public class LocalFileWriteHandler implements 
ShuffleWriteHandler {
         long blockId = block.getBlockId();
         long crc = block.getCrc();
         long startOffset = dataWriter.nextOffset();
-        dataWriter.writeData(block.getData());
+        dataWriter.writeData(ByteBufUtils.readBytes(block.getData()));
+        block.getData().release();
 
         FileBasedShuffleSegment segment = new FileBasedShuffleSegment(
             blockId, startOffset, block.getLength(), 
block.getUncompressLength(), crc, block.getTaskAttemptId());
diff --git 
a/storage/src/test/java/org/apache/uniffle/storage/HdfsShuffleHandlerTestBase.java
 
b/storage/src/test/java/org/apache/uniffle/storage/HdfsShuffleHandlerTestBase.java
index dd8f8859..52d34efc 100644
--- 
a/storage/src/test/java/org/apache/uniffle/storage/HdfsShuffleHandlerTestBase.java
+++ 
b/storage/src/test/java/org/apache/uniffle/storage/HdfsShuffleHandlerTestBase.java
@@ -30,6 +30,7 @@ import org.apache.hadoop.fs.Path;
 import org.apache.uniffle.common.BufferSegment;
 import org.apache.uniffle.common.ShuffleDataResult;
 import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.util.ByteBufUtils;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
@@ -88,7 +89,7 @@ public class HdfsShuffleHandlerTestBase {
       offset += spb.getLength();
       segments.add(segment);
       if (doWrite) {
-        writer.writeData(spb.getData());
+        writer.writeData(ByteBufUtils.readBytes(spb.getData()));
       }
     }
     expectedIndexSegments.put(partitionId, segments);
diff --git 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileHandlerTestBase.java
 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileHandlerTestBase.java
index 89eb24d5..199e6ad2 100644
--- 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileHandlerTestBase.java
+++ 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileHandlerTestBase.java
@@ -33,6 +33,7 @@ import org.apache.uniffle.common.ShuffleDataSegment;
 import org.apache.uniffle.common.ShuffleIndexResult;
 import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.segment.FixedSizeSegmentSplitter;
+import org.apache.uniffle.common.util.ByteBufUtils;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
 import org.apache.uniffle.storage.handler.api.ServerReadHandler;
@@ -58,9 +59,11 @@ public class LocalFileHandlerTestBase {
 
   public static void writeTestData(List<ShufflePartitionedBlock> blocks, 
ShuffleWriteHandler handler,
       Map<Long, byte[]> expectedData, Set<Long> expectedBlockIds) throws 
Exception {
+    blocks.forEach(block -> block.getData().retain());
     handler.write(blocks);
     blocks.forEach(block -> expectedBlockIds.add(block.getBlockId()));
-    blocks.forEach(block -> expectedData.put(block.getBlockId(), 
block.getData()));
+    blocks.forEach(block -> expectedData.put(block.getBlockId(), 
ByteBufUtils.readBytes(block.getData())));
+    blocks.forEach(block -> block.getData().release());
   }
 
   public static void validateResult(ServerReadHandler readHandler, Set<Long> 
expectedBlockIds,

Reply via email to