This is an automated email from the ASF dual-hosted git repository.
smallzhongfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new d7fbfc04 [#133] feat(netty): Add Encoder and Decoder. (#742)
d7fbfc04 is described below
commit d7fbfc04d36485f8cb026d6bd986928ffa1677b4
Author: Xianming Lei <[email protected]>
AuthorDate: Sun Mar 26 18:00:48 2023 +0800
[#133] feat(netty): Add Encoder and Decoder. (#742)
### What changes were proposed in this pull request?
Add Netty Encoder and Decoder. Add the `SendShuffleDataRequest` protocol for
tests. Other protocols will be added in follow-up PRs.
### Why are the changes needed?
For #133
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests (UTs).
---
.../hadoop/mapreduce/task/reduce/FetcherTest.java | 2 +-
.../apache/uniffle/common/ShuffleBlockInfo.java | 18 +-
.../apache/uniffle/common/netty/FrameDecoder.java | 23 +++
.../uniffle/common/netty/MessageEncoder.java | 58 +++++++
.../common/netty/TransportFrameDecoder.java | 192 +++++++++++++++++++++
.../uniffle/common/netty/protocol/Decoders.java | 56 ++++++
.../uniffle/common/netty/protocol/Encoders.java | 70 ++++++++
.../uniffle/common/netty/protocol/Message.java | 64 ++++++-
.../uniffle/common/netty/protocol/RpcResponse.java | 21 +++
.../netty/protocol/SendShuffleDataRequest.java | 134 ++++++++++++++
.../apache/uniffle/common/util/ByteBufUtils.java | 23 +++
.../common/netty/EncoderAndDecoderTest.java | 174 +++++++++++++++++++
.../common/netty/protocol/NettyProtocolTest.java | 88 ++++++++++
.../netty/protocol/NettyProtocolTestUtils.java | 86 +++++++++
.../test/ShuffleServerFaultToleranceTest.java | 7 +-
.../apache/uniffle/test/ShuffleServerGrpcTest.java | 4 +-
.../test/ShuffleServerWithMemLocalHdfsTest.java | 17 +-
.../uniffle/test/ShuffleServerWithMemoryTest.java | 39 +++--
.../client/impl/grpc/ShuffleServerGrpcClient.java | 2 +-
19 files changed, 1039 insertions(+), 39 deletions(-)
diff --git
a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index 54f24672..2ec6b709 100644
---
a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++
b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -360,7 +360,7 @@ public class FetcherTest {
shuffleBlockInfoList.forEach(block -> {
ByteBuffer uncompressedBuffer =
ByteBuffer.allocate(block.getUncompressLength());
codec.decompress(
- ByteBuffer.wrap(block.getData()),
+ block.getData().nioBuffer(),
block.getUncompressLength(),
uncompressedBuffer,
0
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
index 8bcad46f..4d0f73a5 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
@@ -19,6 +19,9 @@ package org.apache.uniffle.common;
import java.util.List;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
public class ShuffleBlockInfo {
private int partitionId;
@@ -27,14 +30,21 @@ public class ShuffleBlockInfo {
private int shuffleId;
private long crc;
private long taskAttemptId;
- private byte[] data;
+ private ByteBuf data;
private List<ShuffleServerInfo> shuffleServerInfos;
private int uncompressLength;
private long freeMemory;
public ShuffleBlockInfo(int shuffleId, int partitionId, long blockId, int
length, long crc,
byte[] data, List<ShuffleServerInfo> shuffleServerInfos,
- int uncompressLength, int freeMemory, long taskAttemptId) {
+ int uncompressLength, long freeMemory, long taskAttemptId) {
+ this(shuffleId, partitionId, blockId, length, crc,
Unpooled.wrappedBuffer(data),
+ shuffleServerInfos, uncompressLength, freeMemory, taskAttemptId);
+ }
+
+ public ShuffleBlockInfo(int shuffleId, int partitionId, long blockId, int
length, long crc,
+ ByteBuf data, List<ShuffleServerInfo> shuffleServerInfos,
+ int uncompressLength, long freeMemory, long taskAttemptId) {
this.partitionId = partitionId;
this.blockId = blockId;
this.length = length;
@@ -56,7 +66,7 @@ public class ShuffleBlockInfo {
}
// calculate the data size for this block in memory including metadata which
are
- // blockId, crc, taskAttemptId, length, uncompressLength
+ // partitionId, blockId, crc, taskAttemptId, length, uncompressLength
public int getSize() {
return length + 3 * 8 + 2 * 4;
}
@@ -65,7 +75,7 @@ public class ShuffleBlockInfo {
return crc;
}
- public byte[] getData() {
+ public ByteBuf getData() {
return data;
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java
b/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java
new file mode 100644
index 00000000..53366dc3
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/netty/FrameDecoder.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty;
+
+public interface FrameDecoder {
+ String HANDLER_NAME = "FrameDecoder";
+ int HEADER_SIZE = Integer.BYTES + Byte.BYTES;
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
new file mode 100644
index 00000000..4167e53a
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.protocol.Message;
+
+/**
+ * Encoder used by the server side to encode server-to-client responses.
+ * This encoder is stateless so it is safe to be shared by multiple threads.
+ * The content of encode consists of two parts, header and message body.
+ * The encoded binary stream contains encodeLength (4 bytes), messageType (1
byte)
+ * and messageBody (encodeLength bytes).
+ */
[email protected]
+public class MessageEncoder extends ChannelOutboundHandlerAdapter {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(MessageEncoder.class);
+
+ @Override
+ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise
promise) {
+ // todo: support zero copy
+ Message message = (Message) msg;
+ int encodeLength = message.encodedLength();
+ ByteBuf byteBuf = ctx.alloc().buffer(FrameDecoder.HEADER_SIZE +
encodeLength);
+ try {
+ byteBuf.writeInt(encodeLength);
+ byteBuf.writeByte(message.type().id());
+ message.encode(byteBuf);
+ } catch (Exception e) {
+ LOG.error("Unexpected exception during process encode!", e);
+ byteBuf.release();
+ }
+ ctx.writeAndFlush(byteBuf);
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
b/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
new file mode 100644
index 00000000..76e86d0c
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty;
+
+import java.util.LinkedList;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+
+import org.apache.uniffle.common.netty.protocol.Message;
+
+/**
+ * A customized frame decoder that allows intercepting raw data.
+ *
+ * <p>This behaves like Netty's frame decoder (with hard coded parameters that
match this library's
+ * needs), except it allows an interceptor to be installed to read data
directly before it's framed.
+ *
+ * <p>Unlike Netty's frame decoder, each frame is dispatched to child handlers
as soon as it's
+ * decoded, instead of building as many frames as the current buffer allows
and dispatching all of
+ * them. This allows a child handler to install an interceptor if needed.
+ *
+ * <p>If an interceptor is installed, framing stops, and data is instead fed
directly to the
+ * interceptor. When the interceptor indicates that it doesn't need to read
any more data, framing
+ * resumes. Interceptors should not hold references to the data buffers
provided to their handle()
+ * method.
+ */
+public class TransportFrameDecoder extends ChannelInboundHandlerAdapter
implements FrameDecoder {
+ private int msgSize = -1;
+ private Message.Type curType = Message.Type.UNKNOWN_TYPE;
+ private ByteBuf headerBuf = Unpooled.buffer(HEADER_SIZE, HEADER_SIZE);
+ private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
+ private static final int UNKNOWN_FRAME_SIZE = -1;
+
+ private final LinkedList<ByteBuf> buffers = new LinkedList<>();
+
+ private long totalSize = 0;
+ private long nextFrameSize = UNKNOWN_FRAME_SIZE;
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object data) {
+ ByteBuf in = (ByteBuf) data;
+ buffers.add(in);
+ totalSize += in.readableBytes();
+
+ while (!buffers.isEmpty()) {
+ ByteBuf frame = decodeNext();
+ if (frame == null) {
+ break;
+ }
+ // todo: An exception may be thrown during the decoding process, causing
frame.release() to fail to be called
+ Message msg = Message.decode(curType, frame);
+ frame.release();
+ ctx.fireChannelRead(msg);
+ clear();
+ }
+ }
+
+ private void clear() {
+ curType = Message.Type.UNKNOWN_TYPE;
+ msgSize = -1;
+ headerBuf.clear();
+ }
+
+ private long decodeFrameSize() {
+ if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < HEADER_SIZE) {
+ return nextFrameSize;
+ }
+
+ // We know there's enough data. If the first buffer contains all the data,
great. Otherwise,
+ // hold the bytes for the frame length in a composite buffer until we have
enough data to read
+ // the frame size. Normally, it should be rare to need more than one
buffer to read the frame
+ // size.
+ ByteBuf first = buffers.getFirst();
+ if (first.readableBytes() >= HEADER_SIZE) {
+ msgSize = first.readInt();
+ curType = Message.Type.decode(first);
+ nextFrameSize = msgSize;
+ totalSize -= HEADER_SIZE;
+ if (!first.isReadable()) {
+ buffers.removeFirst().release();
+ }
+ return nextFrameSize;
+ }
+
+ while (headerBuf.readableBytes() < HEADER_SIZE) {
+ ByteBuf next = buffers.getFirst();
+ int toRead = Math.min(next.readableBytes(), HEADER_SIZE -
headerBuf.readableBytes());
+ headerBuf.writeBytes(next, toRead);
+ if (!next.isReadable()) {
+ buffers.removeFirst().release();
+ }
+ }
+
+ msgSize = headerBuf.readInt();
+ curType = Message.Type.decode(headerBuf);
+ nextFrameSize = msgSize;
+ totalSize -= HEADER_SIZE;
+ return nextFrameSize;
+ }
+
+ private ByteBuf decodeNext() {
+ long frameSize = decodeFrameSize();
+ if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
+ return null;
+ }
+
+ // Reset size for next frame.
+ nextFrameSize = UNKNOWN_FRAME_SIZE;
+
+ Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame:
%s", frameSize);
+ Preconditions.checkArgument(frameSize > 0, "Frame length should be
positive: %s", frameSize);
+
+ // If the first buffer holds the entire frame, return it.
+ int remaining = (int) frameSize;
+ if (buffers.getFirst().readableBytes() >= remaining) {
+ return nextBufferForFrame(remaining);
+ }
+
+ // Otherwise, create a composite buffer.
+ CompositeByteBuf frame =
buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE);
+ while (remaining > 0) {
+ ByteBuf next = nextBufferForFrame(remaining);
+ remaining -= next.readableBytes();
+ frame.addComponent(next).writerIndex(frame.writerIndex() +
next.readableBytes());
+ }
+ assert remaining == 0;
+ return frame;
+ }
+
+ /**
+ * Takes the first buffer in the internal list, and either adjust it to fit
in the frame (by
+ * taking a slice out of it) or remove it from the internal list.
+ */
+ private ByteBuf nextBufferForFrame(int bytesToRead) {
+ ByteBuf buf = buffers.getFirst();
+ ByteBuf frame;
+
+ if (buf.readableBytes() > bytesToRead) {
+ frame = buf.retain().readSlice(bytesToRead);
+ totalSize -= bytesToRead;
+ } else {
+ frame = buf;
+ buffers.removeFirst();
+ totalSize -= frame.readableBytes();
+ }
+
+ return frame;
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ super.channelInactive(ctx);
+ }
+
+ @Override
+ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+ // Release all buffers that are still in our ownership.
+ // Doing this in handlerRemoved(...) guarantees that this will happen in
all cases:
+ // - When the Channel becomes inactive
+ // - When the decoder is removed from the ChannelPipeline
+ for (ByteBuf b : buffers) {
+ b.release();
+ }
+ buffers.clear();
+ headerBuf.release();
+ super.handlerRemoved(ctx);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
throws Exception {
+ super.exceptionCaught(ctx, cause);
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
new file mode 100644
index 00000000..4b969a62
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+public class Decoders {
+ public static ShuffleServerInfo decodeShuffleServerInfo(ByteBuf byteBuf) {
+ String id = ByteBufUtils.readLengthAndString(byteBuf);
+ String host = ByteBufUtils.readLengthAndString(byteBuf);
+ int grpcPort = byteBuf.readInt();
+ int nettyPort = byteBuf.readInt();
+ return new ShuffleServerInfo(id, host, grpcPort, nettyPort);
+ }
+
+ public static ShuffleBlockInfo decodeShuffleBlockInfo(ByteBuf byteBuf) {
+ int partId = byteBuf.readInt();
+ long blockId = byteBuf.readLong();
+ int length = byteBuf.readInt();
+ int shuffleId = byteBuf.readInt();
+ long crc = byteBuf.readLong();
+ long taskAttemptId = byteBuf.readLong();
+ ByteBuf data = ByteBufUtils.readSlice(byteBuf);
+ int lengthOfShuffleServers = byteBuf.readInt();
+ List<ShuffleServerInfo> serverInfos = Lists.newArrayList();
+ for (int k = 0; k < lengthOfShuffleServers; k++) {
+ serverInfos.add(decodeShuffleServerInfo(byteBuf));
+ }
+ int uncompressLength = byteBuf.readInt();
+ long freeMemory = byteBuf.readLong();
+ return new ShuffleBlockInfo(shuffleId, partId, blockId,
+ length, crc, data, serverInfos, uncompressLength, freeMemory,
taskAttemptId);
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java
new file mode 100644
index 00000000..e819355a
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+public class Encoders {
+ public static void encodeShuffleServerInfo(ShuffleServerInfo
shuffleServerInfo, ByteBuf byteBuf) {
+ ByteBufUtils.writeLengthAndString(byteBuf, shuffleServerInfo.getId());
+ ByteBufUtils.writeLengthAndString(byteBuf, shuffleServerInfo.getHost());
+ byteBuf.writeInt(shuffleServerInfo.getGrpcPort());
+ byteBuf.writeInt(shuffleServerInfo.getNettyPort());
+ }
+
+ public static void encodeShuffleBlockInfo(ShuffleBlockInfo shuffleBlockInfo,
ByteBuf byteBuf) {
+ byteBuf.writeInt(shuffleBlockInfo.getPartitionId());
+ byteBuf.writeLong(shuffleBlockInfo.getBlockId());
+ byteBuf.writeInt(shuffleBlockInfo.getLength());
+ byteBuf.writeInt(shuffleBlockInfo.getShuffleId());
+ byteBuf.writeLong(shuffleBlockInfo.getCrc());
+ byteBuf.writeLong(shuffleBlockInfo.getTaskAttemptId());
+ // todo: avoid copy
+ ByteBufUtils.copyByteBuf(shuffleBlockInfo.getData(), byteBuf);
+ shuffleBlockInfo.getData().release();
+ List<ShuffleServerInfo> shuffleServerInfoList =
shuffleBlockInfo.getShuffleServerInfos();
+ byteBuf.writeInt(shuffleServerInfoList.size());
+ for (ShuffleServerInfo shuffleServerInfo : shuffleServerInfoList) {
+ Encoders.encodeShuffleServerInfo(shuffleServerInfo, byteBuf);
+ }
+ byteBuf.writeInt(shuffleBlockInfo.getUncompressLength());
+ byteBuf.writeLong(shuffleBlockInfo.getFreeMemory());
+ }
+
+ public static int encodeLengthOfShuffleServerInfo(ShuffleServerInfo
shuffleServerInfo) {
+ return ByteBufUtils.encodedLength(shuffleServerInfo.getId())
+ + ByteBufUtils.encodedLength(shuffleServerInfo.getHost())
+ + 2 * Integer.BYTES;
+ }
+
+ public static int encodeLengthOfShuffleBlockInfo(ShuffleBlockInfo
shuffleBlockInfo) {
+ int encodeLength = 4 * Long.BYTES + 4 * Integer.BYTES
+ + ByteBufUtils.encodedLength(shuffleBlockInfo.getData()) +
Integer.BYTES;
+ for (ShuffleServerInfo shuffleServerInfo :
shuffleBlockInfo.getShuffleServerInfos()) {
+ encodeLength += encodeLengthOfShuffleServerInfo(shuffleServerInfo);
+ }
+ return encodeLength;
+ }
+
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
index 6eb2813b..b0a3da1f 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
@@ -25,7 +25,27 @@ public abstract class Message implements Encodable {
public enum Type implements Encodable {
UNKNOWN_TYPE(-1),
- RPC_RESPONSE(0);
+ RPC_RESPONSE(0),
+ SHUFFLE_REGISTER_REQUEST(1),
+ SHUFFLE_UNREGISTER_REQUEST(2),
+ SEND_SHUFFLE_DATA_REQUEST(3),
+ GET_LOCAL_SHUFFLE_INDEX_REQUEST(4),
+ GET_LOCAL_SHUFFLE_DATA_REQUEST(5),
+ GET_MEMORY_SHUFFLE_DATA_REQUEST(6),
+ SHUFFLE_COMMIT_REQUEST(7),
+ REPORT_SHUFFLE_RESULT_REQUEST(8),
+ GET_SHUFFLE_RESULT_REQUEST(9),
+ GET_SHUFFLE_RESULT_FOR_MULTI_PART_REQUEST(10),
+ FINISH_SHUFFLE_REQUEST(11),
+ REQUIRE_BUFFER_REQUEST(12),
+ APP_HEART_BEAT_REQUEST(13),
+ GET_LOCAL_SHUFFLE_INDEX_RESPONSE(14),
+ GET_LOCAL_SHUFFLE_DATA_RESPONSE(15),
+ GET_MEMORY_SHUFFLE_DATA_RESPONSE(16),
+ SHUFFLE_COMMIT_RESPONSE(17),
+ GET_SHUFFLE_RESULT_RESPONSE(18),
+ GET_SHUFFLE_RESULT_FOR_MULTI_PART_RESPONSE(19),
+ REQUIRE_BUFFER_RESPONSE(20);
private final byte id;
@@ -53,6 +73,46 @@ public abstract class Message implements Encodable {
switch (id) {
case 0:
return RPC_RESPONSE;
+ case 1:
+ return SHUFFLE_REGISTER_REQUEST;
+ case 2:
+ return SHUFFLE_UNREGISTER_REQUEST;
+ case 3:
+ return SEND_SHUFFLE_DATA_REQUEST;
+ case 4:
+ return GET_LOCAL_SHUFFLE_INDEX_REQUEST;
+ case 5:
+ return GET_LOCAL_SHUFFLE_DATA_REQUEST;
+ case 6:
+ return GET_MEMORY_SHUFFLE_DATA_REQUEST;
+ case 7:
+ return SHUFFLE_COMMIT_REQUEST;
+ case 8:
+ return REPORT_SHUFFLE_RESULT_REQUEST;
+ case 9:
+ return GET_SHUFFLE_RESULT_REQUEST;
+ case 10:
+ return GET_SHUFFLE_RESULT_FOR_MULTI_PART_REQUEST;
+ case 11:
+ return FINISH_SHUFFLE_REQUEST;
+ case 12:
+ return REQUIRE_BUFFER_REQUEST;
+ case 13:
+ return APP_HEART_BEAT_REQUEST;
+ case 14:
+ return GET_LOCAL_SHUFFLE_INDEX_RESPONSE;
+ case 15:
+ return GET_LOCAL_SHUFFLE_DATA_RESPONSE;
+ case 16:
+ return GET_MEMORY_SHUFFLE_DATA_RESPONSE;
+ case 17:
+ return SHUFFLE_COMMIT_RESPONSE;
+ case 18:
+ return GET_SHUFFLE_RESULT_RESPONSE;
+ case 19:
+ return GET_SHUFFLE_RESULT_FOR_MULTI_PART_RESPONSE;
+ case 20:
+ return REQUIRE_BUFFER_RESPONSE;
case -1:
throw new IllegalArgumentException("User type messages cannot be
decoded.");
default:
@@ -65,6 +125,8 @@ public abstract class Message implements Encodable {
switch (msgType) {
case RPC_RESPONSE:
return RpcResponse.decode(in);
+ case SEND_SHUFFLE_DATA_REQUEST:
+ return SendShuffleDataRequest.decode(in);
default:
throw new IllegalArgumentException("Unexpected message type: " +
msgType);
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
index 9fef38cb..686c8d6e 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
@@ -17,6 +17,8 @@
package org.apache.uniffle.common.netty.protocol;
+import java.util.Objects;
+
import io.netty.buffer.ByteBuf;
import org.apache.uniffle.common.rpc.StatusCode;
@@ -82,4 +84,23 @@ public class RpcResponse extends Message {
public Type type() {
return Type.RPC_RESPONSE;
}
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ RpcResponse that = (RpcResponse) o;
+ return requestId == that.requestId
+ && statusCode == that.statusCode
+ && Objects.equals(retMessage, that.retMessage);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(requestId, statusCode, retMessage);
+ }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
new file mode 100644
index 00000000..cc1117a8
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+public class SendShuffleDataRequest extends Message {
+ public long requestId;
+ private String appId;
+ private int shuffleId;
+ private long requireId;
+ private Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks;
+ private long timestamp;
+
+ public SendShuffleDataRequest(long requestId, String appId, int shuffleId,
long requireId,
+ Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks, long timestamp) {
+ this.requestId = requestId;
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.requireId = requireId;
+ this.partitionToBlocks = partitionToBlocks;
+ this.timestamp = timestamp;
+ }
+
+ @Override
+ public Type type() {
+ return Type.SEND_SHUFFLE_DATA_REQUEST;
+ }
+
+ @Override
+ public int encodedLength() {
+ int encodeLength = Long.BYTES + ByteBufUtils.encodedLength(appId) +
Integer.BYTES + Long.BYTES + Integer.BYTES;
+ for (Map.Entry<Integer, List<ShuffleBlockInfo>> entry :
partitionToBlocks.entrySet()) {
+ encodeLength += 2 * Integer.BYTES;
+ for (ShuffleBlockInfo sbi : entry.getValue()) {
+ encodeLength += Encoders.encodeLengthOfShuffleBlockInfo(sbi);
+ }
+ }
+ return encodeLength + Long.BYTES;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ ByteBufUtils.writeLengthAndString(buf, appId);
+ buf.writeInt(shuffleId);
+ buf.writeLong(requireId);
+ encodePartitionData(buf);
+ buf.writeLong(timestamp);
+ }
+
+ private static Map<Integer, List<ShuffleBlockInfo>>
decodePartitionData(ByteBuf byteBuf) {
+ Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks = Maps.newHashMap();
+ int lengthOfPartitionData = byteBuf.readInt();
+ for (int i = 0; i < lengthOfPartitionData; i++) {
+ int partitionId = byteBuf.readInt();
+ int lengthOfShuffleBlocks = byteBuf.readInt();
+ List<ShuffleBlockInfo> shuffleBlockInfoList = Lists.newArrayList();
+ for (int j = 0; j < lengthOfShuffleBlocks; j++) {
+ shuffleBlockInfoList.add(Decoders.decodeShuffleBlockInfo(byteBuf));
+ }
+ partitionToBlocks.put(partitionId, shuffleBlockInfoList);
+ }
+ return partitionToBlocks;
+ }
+
+ public static SendShuffleDataRequest decode(ByteBuf byteBuf) {
+ long requestId = byteBuf.readLong();
+ String appId = ByteBufUtils.readLengthAndString(byteBuf);
+ int shuffleId = byteBuf.readInt();
+ long requireId = byteBuf.readLong();
+ Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks =
decodePartitionData(byteBuf);
+ long timestamp = byteBuf.readLong();
+ return new SendShuffleDataRequest(requestId, appId, shuffleId, requireId,
partitionToBlocks, timestamp);
+ }
+
+ private void encodePartitionData(ByteBuf buf) {
+ buf.writeInt(partitionToBlocks.size());
+ for (Map.Entry<Integer, List<ShuffleBlockInfo>> entry :
partitionToBlocks.entrySet()) {
+ buf.writeInt(entry.getKey());
+ buf.writeInt(entry.getValue().size());
+ for (ShuffleBlockInfo sbi : entry.getValue()) {
+ Encoders.encodeShuffleBlockInfo(sbi, buf);
+ }
+ }
+ }
+
+ public long getRequestId() {
+ return requestId;
+ }
+
+ public String getAppId() {
+ return appId;
+ }
+
+ public int getShuffleId() {
+ return shuffleId;
+ }
+
+ public long getRequireId() {
+ return requireId;
+ }
+
+ public Map<Integer, List<ShuffleBlockInfo>> getPartitionToBlocks() {
+ return partitionToBlocks;
+ }
+
+ public long getTimestamp() {
+ return timestamp;
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
index 6b1f0dd0..1872674b 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
@@ -27,6 +27,10 @@ public class ByteBufUtils {
return 4 + s.getBytes(StandardCharsets.UTF_8).length;
}
+ public static int encodedLength(ByteBuf buf) {
+ return 4 + buf.readableBytes();
+ }
+
public static final void writeLengthAndString(ByteBuf buf, String str) {
if (str == null) {
buf.writeInt(-1);
@@ -49,9 +53,28 @@ public class ByteBufUtils {
return new String(bytes, StandardCharsets.UTF_8);
}
+ public static final void copyByteBuf(ByteBuf from, ByteBuf to) {
+ to.writeInt(from.readableBytes());
+ to.writeBytes(from);
+ from.resetReaderIndex();
+ }
+
+ public static final byte[] readByteArray(ByteBuf byteBuf) {
+ int length = byteBuf.readInt();
+ byte[] data = new byte[length];
+ byteBuf.readBytes(data);
+ return data;
+ }
+
+ public static final ByteBuf readSlice(ByteBuf from) {
+ int length = from.readInt();
+ return from.retain().readSlice(length);
+ }
+
public static final byte[] readBytes(ByteBuf buf) {
byte[] bytes = new byte[buf.readableBytes()];
buf.readBytes(bytes);
+ buf.resetReaderIndex();
return bytes;
}
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
b/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
new file mode 100644
index 00000000..7adce841
--- /dev/null
+++
b/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+
+import com.google.common.collect.Maps;
+import io.netty.bootstrap.Bootstrap;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.netty.protocol.NettyProtocolTestUtils;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+import org.apache.uniffle.common.netty.protocol.SendShuffleDataRequest;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.NettyUtils;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class EncoderAndDecoderTest {
+ private EventLoopGroup bossGroup;
+ private EventLoopGroup workerGroup;
+ private ChannelFuture channelFuture;
+ private static final SendShuffleDataRequest DATA_REQUEST =
generateShuffleDataRequest();
+ private static final String EXPECTED_MESSAGE = "test_message";
+ private static final long REQUEST_ID = 1;
+ private static final StatusCode STATUS_CODE = StatusCode.SUCCESS;
+ private static final int PORT = 12345;
+
+ static class MockResponseHandler extends ChannelInboundHandlerAdapter {
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) throws
Exception {
+ if (msg instanceof RpcResponse) {
+ RpcResponse rpcResponse = (RpcResponse) msg;
+ assertEquals(REQUEST_ID, rpcResponse.getRequestId());
+ assertEquals(STATUS_CODE, rpcResponse.getStatusCode());
+ assertEquals(EXPECTED_MESSAGE, rpcResponse.getRetMessage());
+ } else if (msg instanceof SendShuffleDataRequest) {
+ SendShuffleDataRequest sendShuffleDataRequest =
(SendShuffleDataRequest) msg;
+
assertTrue(NettyProtocolTestUtils.compareSendShuffleDataRequest(sendShuffleDataRequest,
DATA_REQUEST));
+
sendShuffleDataRequest.getPartitionToBlocks().values().stream().flatMap(Collection::stream)
+ .forEach(shuffleBlockInfo -> shuffleBlockInfo.getData().release());
+
sendShuffleDataRequest.getPartitionToBlocks().values().stream().flatMap(Collection::stream)
+ .forEach(shuffleBlockInfo ->
assertEquals(0,shuffleBlockInfo.getData().refCnt()));
+ RpcResponse rpcResponse = new RpcResponse(REQUEST_ID, STATUS_CODE,
EXPECTED_MESSAGE);
+ ctx.writeAndFlush(rpcResponse);
+ } else {
+ throw new RssException("receive unexpected message!");
+ }
+ super.channelRead(ctx, msg);
+ }
+ }
+
+ @Test
+ public void test() throws InterruptedException {
+ EventLoopGroup workerGroup = NettyUtils.createEventLoop(IOMode.NIO, 2,
"netty-client");
+ PooledByteBufAllocator pooledByteBufAllocator =
+ NettyUtils.createPooledByteBufAllocator(
+ true, false /* allowCache */, 2);
+ Bootstrap bootstrap = new Bootstrap();
+ bootstrap
+ .group(workerGroup)
+ .channel(NettyUtils.getClientChannelClass(IOMode.NIO))
+ .option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 30000)
+ .option(ChannelOption.ALLOCATOR, pooledByteBufAllocator);
+ final AtomicReference<Channel> channelRef = new AtomicReference<>();
+ bootstrap.handler(
+ new ChannelInitializer<SocketChannel>() {
+ @Override
+ public void initChannel(SocketChannel ch) {
+ ch.pipeline().addLast("ClientEncoder", new MessageEncoder())
+ .addLast("ClientDecoder", new TransportFrameDecoder())
+ .addLast("ClientResponseHandler", new MockResponseHandler());
+ channelRef.set(ch);
+ }
+ });
+ bootstrap.connect("localhost", PORT);
+ // wait for initChannel
+ Thread.sleep(200);
+ channelRef.get().writeAndFlush(DATA_REQUEST);
+ channelRef.get().closeFuture().await(3L, TimeUnit.SECONDS);
+
DATA_REQUEST.getPartitionToBlocks().values().stream().flatMap(Collection::stream)
+ .forEach(shuffleBlockInfo -> shuffleBlockInfo.getData().release());
+ }
+
+ private static SendShuffleDataRequest generateShuffleDataRequest() {
+ String appId = "test_app";
+ byte[] data = new byte[]{1, 2, 3};
+ List<ShuffleServerInfo> shuffleServerInfoList = Arrays.asList(new
ShuffleServerInfo("aaa", 1),
+ new ShuffleServerInfo("bbb", 2));
+ List<ShuffleBlockInfo> shuffleBlockInfoList1 =
+ Arrays.asList(new ShuffleBlockInfo(1, 1, 1, 10, 123,
+ Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList,
5, 0, 1),
+ new ShuffleBlockInfo(1, 1, 1, 10, 123,
+ Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList,
5, 0, 1));
+ List<ShuffleBlockInfo> shuffleBlockInfoList2 =
+ Arrays.asList(new ShuffleBlockInfo(1, 2, 1, 10, 123,
+ Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList,
5, 0, 1),
+ new ShuffleBlockInfo(1, 1, 2, 10, 123,
+ Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList,
5, 0, 1));
+ Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks = Maps.newHashMap();
+ partitionToBlocks.put(1, shuffleBlockInfoList1);
+ partitionToBlocks.put(2, shuffleBlockInfoList2);
+ return new SendShuffleDataRequest(1L, appId, 1, 1, partitionToBlocks,
12345);
+ }
+
+ @BeforeEach
+ public void startNettyServer() {
+ bossGroup = NettyUtils.createEventLoop(IOMode.NIO, 1, "netty-boss-group");
+ workerGroup = NettyUtils.createEventLoop(IOMode.NIO, 5,
"netty-worker-group");
+ ServerBootstrap serverBootstrap = new ServerBootstrap().group(bossGroup,
workerGroup)
+
.channel(NioServerSocketChannel.class);
+ serverBootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ public void initChannel(final SocketChannel ch) {
+ ch.pipeline().addLast("ServerEncoder", new MessageEncoder())
+ .addLast("ServerDecoder", new TransportFrameDecoder())
+ .addLast("ServerResponseHandler", new MockResponseHandler());
+ }
+ })
+ .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+ .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+ .childOption(ChannelOption.TCP_NODELAY, true)
+ .childOption(ChannelOption.SO_KEEPALIVE, true);
+ channelFuture = serverBootstrap.bind(PORT);
+ channelFuture.syncUninterruptibly();
+ }
+
+ @AfterEach
+ public void stopNettyServer() {
+ channelFuture.channel().close().awaitUninterruptibly(10L,
TimeUnit.SECONDS);
+ bossGroup.shutdownGracefully();
+ workerGroup.shutdownGracefully();
+ }
+}
diff --git
a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
new file mode 100644
index 00000000..d2da5512
--- /dev/null
+++
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.rpc.StatusCode;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class NettyProtocolTest {
+ @Test
+ public void testSendShuffleDataRequest() {
+ String appId = "test_app";
+ byte[] data = new byte[]{1, 2, 3};
+ List<ShuffleServerInfo> shuffleServerInfoList = Arrays.asList(new
ShuffleServerInfo("aaa", 1),
+ new ShuffleServerInfo("bbb", 2));
+ List<ShuffleBlockInfo> shuffleBlockInfoList1 =
+ Arrays.asList(new ShuffleBlockInfo(1, 1, 1, 10, 123,
+ Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList,
5, 0, 1),
+ new ShuffleBlockInfo(1, 1, 1, 10, 123,
+ Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList,
5, 0, 1));
+ List<ShuffleBlockInfo> shuffleBlockInfoList2 =
+ Arrays.asList(new ShuffleBlockInfo(1, 2, 1, 10, 123,
+ Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList,
5, 0, 1),
+ new ShuffleBlockInfo(1, 1, 2, 10, 123,
+ Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList,
5, 0, 1));
+ Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks = Maps.newHashMap();
+ partitionToBlocks.put(1, shuffleBlockInfoList1);
+ partitionToBlocks.put(2, shuffleBlockInfoList2);
+ SendShuffleDataRequest sendShuffleDataRequest =
+ new SendShuffleDataRequest(1L, appId, 1, 1, partitionToBlocks, 12345);
+ int encodeLength = sendShuffleDataRequest.encodedLength();
+
+ ByteBuf byteBuf = Unpooled.buffer(sendShuffleDataRequest.encodedLength());
+ sendShuffleDataRequest.encode(byteBuf);
+ assertEquals(byteBuf.readableBytes(), encodeLength);
+ SendShuffleDataRequest sendShuffleDataRequest1 =
sendShuffleDataRequest.decode(byteBuf);
+
assertTrue(NettyProtocolTestUtils.compareSendShuffleDataRequest(sendShuffleDataRequest,
sendShuffleDataRequest1));
+ assertEquals(encodeLength, sendShuffleDataRequest1.encodedLength());
+ byteBuf.release();
+ for (ShuffleBlockInfo shuffleBlockInfo :
sendShuffleDataRequest1.getPartitionToBlocks().get(1)) {
+ shuffleBlockInfo.getData().release();
+ }
+ for (ShuffleBlockInfo shuffleBlockInfo :
sendShuffleDataRequest1.getPartitionToBlocks().get(2)) {
+ shuffleBlockInfo.getData().release();
+ }
+ assertEquals(0, byteBuf.refCnt());
+ }
+
+ @Test
+ public void testRpcResponse() {
+ RpcResponse rpcResponse = new RpcResponse(1, StatusCode.SUCCESS,
"test_message");
+ int encodeLength = rpcResponse.encodedLength();
+ ByteBuf byteBuf = Unpooled.buffer(encodeLength);
+ rpcResponse.encode(byteBuf);
+ assertEquals(byteBuf.readableBytes(), encodeLength);
+ RpcResponse rpcResponse1 = RpcResponse.decode(byteBuf);
+ assertTrue(rpcResponse.equals(rpcResponse1));
+ assertEquals(rpcResponse.encodedLength(), rpcResponse1.encodedLength());
+ byteBuf.release();
+ }
+}
diff --git
a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTestUtils.java
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTestUtils.java
new file mode 100644
index 00000000..29f1c122
--- /dev/null
+++
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTestUtils.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+
+public class NettyProtocolTestUtils {
+
+ private static boolean compareShuffleBlockInfo(ShuffleBlockInfo blockInfo1,
ShuffleBlockInfo blockInfo2) {
+ return blockInfo1.getPartitionId() == blockInfo2.getPartitionId()
+ && blockInfo1.getBlockId() == blockInfo2.getBlockId()
+ && blockInfo1.getLength() == blockInfo2.getLength()
+ && blockInfo1.getShuffleId() == blockInfo2.getShuffleId()
+ && blockInfo1.getCrc() == blockInfo2.getCrc()
+ && blockInfo1.getTaskAttemptId() == blockInfo2.getTaskAttemptId()
+ && blockInfo1.getUncompressLength() == blockInfo2.getUncompressLength()
+ && blockInfo1.getFreeMemory() == blockInfo2.getFreeMemory()
+ && blockInfo1.getData().equals(blockInfo2.getData())
+ &&
blockInfo1.getShuffleServerInfos().equals(blockInfo2.getShuffleServerInfos());
+ }
+
+ private static boolean compareBlockList(List<ShuffleBlockInfo> list1,
List<ShuffleBlockInfo> list2) {
+ if (list1 == null || list2 == null || list1.size() != list2.size()) {
+ return false;
+ }
+ for (int i = 0; i < list1.size(); i++) {
+ if (!compareShuffleBlockInfo(list1.get(i), list2.get(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private static boolean comparePartitionToBlockList(Map<Integer,
List<ShuffleBlockInfo>> m1,
+ Map<Integer, List<ShuffleBlockInfo>> m2) {
+ if (m1 == null || m2 == null || m1.size() != m2.size()) {
+ return false;
+ }
+ Iterator<Map.Entry<Integer, List<ShuffleBlockInfo>>> iter1 =
m1.entrySet().iterator();
+ while (iter1.hasNext()) {
+ Map.Entry<Integer, List<ShuffleBlockInfo>> entry1 = iter1.next();
+ if (!compareBlockList(entry1.getValue(), m2.get(entry1.getKey()))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ public static boolean compareSendShuffleDataRequest(SendShuffleDataRequest
req1,
+ SendShuffleDataRequest req2) {
+ if (req1 == req2) {
+ return true;
+ }
+ if (req1 == null || req2 == null) {
+ return false;
+ }
+ boolean isEqual = req1.requestId == req2.requestId
+ && req1.getShuffleId() == req2.getShuffleId()
+ && req1.getRequireId() == req2.getRequireId()
+ && req1.getTimestamp() == req2.getTimestamp()
+ && req1.getAppId().equals(req2.getAppId());
+ if (!isEqual) {
+ return false;
+ }
+ return comparePartitionToBlockList(req1.getPartitionToBlocks(),
req2.getPartitionToBlocks());
+ }
+}
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
index 675108aa..3c4a3635 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerFaultToleranceTest.java
@@ -42,6 +42,7 @@ import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.util.ByteBufUtils;
import org.apache.uniffle.coordinator.CoordinatorConf;
import org.apache.uniffle.coordinator.CoordinatorServer;
import org.apache.uniffle.server.MockedShuffleServer;
@@ -119,7 +120,7 @@ public class ShuffleServerFaultToleranceTest extends
ShuffleReadWriteBase {
Map<Long, byte[]> expectedData = Maps.newHashMap();
expectedData.clear();
blocks.forEach((block) -> {
- expectedData.put(block.getBlockId(), block.getData());
+ expectedData.put(block.getBlockId(),
ByteBufUtils.readBytes(block.getData()));
});
ShuffleDataResult sdr = clientReadHandler.readShuffleData();
TestUtils.validateResult(expectedData, sdr);
@@ -148,7 +149,7 @@ public class ShuffleServerFaultToleranceTest extends
ShuffleReadWriteBase {
ShuffleHandlerFactory.getInstance().createShuffleReadHandler(request);
sdr = clientReadHandler.readShuffleData();
blocks2.forEach((block) -> {
- expectedData.put(block.getBlockId(), block.getData());
+ expectedData.put(block.getBlockId(),
ByteBufUtils.readBytes(block.getData()));
});
TestUtils.validateResult(expectedData, sdr);
for (BufferSegment bs : sdr.getBufferSegments()) {
@@ -166,7 +167,7 @@ public class ShuffleServerFaultToleranceTest extends
ShuffleReadWriteBase {
expectBlockIds, dataMap, mockSSI);
expectedData.clear();
blocks3.forEach((block) -> {
- expectedData.put(block.getBlockId(), block.getData());
+ expectedData.put(block.getBlockId(),
ByteBufUtils.readBytes(block.getData()));
});
rssdr = getRssSendShuffleDataRequest(testAppId, shuffleId, partitionId,
blocks3);
shuffleServerClients.get(1).sendShuffleData(rssdr);
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
index 2c9a5df9..fecb9738 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
@@ -28,7 +28,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
-import com.google.protobuf.ByteString;
+import com.google.protobuf.UnsafeByteOperations;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
@@ -424,7 +424,7 @@ public class ShuffleServerGrpcTest extends
IntegrationTestBase {
.setLength(sbi.getLength())
.setTaskAttemptId(sbi.getTaskAttemptId())
.setUncompressLength(sbi.getUncompressLength())
- .setData(ByteString.copyFrom(sbi.getData()))
+
.setData(UnsafeByteOperations.unsafeWrap(sbi.getData().nioBuffer()))
.build());
}
shuffleData.add(RssProtos.ShuffleData.newBuilder().setPartitionId(ptb.getKey())
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
index 5fc74979..89397649 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
@@ -39,6 +39,7 @@ import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.util.ByteBufUtils;
import org.apache.uniffle.coordinator.CoordinatorConf;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.server.buffer.ShuffleBuffer;
@@ -142,9 +143,9 @@ public class ShuffleServerWithMemLocalHdfsTest extends
ShuffleReadWriteBase {
ssi, handlers);
Map<Long, byte[]> expectedData = Maps.newHashMap();
expectedData.clear();
- expectedData.put(blocks.get(0).getBlockId(), blocks.get(0).getData());
- expectedData.put(blocks.get(1).getBlockId(), blocks.get(1).getData());
- expectedData.put(blocks.get(2).getBlockId(), blocks.get(1).getData());
+ expectedData.put(blocks.get(0).getBlockId(),
ByteBufUtils.readBytes(blocks.get(0).getData()));
+ expectedData.put(blocks.get(1).getBlockId(),
ByteBufUtils.readBytes(blocks.get(1).getData()));
+ expectedData.put(blocks.get(2).getBlockId(),
ByteBufUtils.readBytes(blocks.get(1).getData()));
ShuffleDataResult sdr = composedClientReadHandler.readShuffleData();
validateResult(expectedData, sdr);
processBlockIds.addLong(blocks.get(0).getBlockId());
@@ -169,8 +170,8 @@ public class ShuffleServerWithMemLocalHdfsTest extends
ShuffleReadWriteBase {
// notice: the 1-th segment is skipped, because it is processed
sdr = composedClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks2.get(0).getBlockId(), blocks2.get(0).getData());
- expectedData.put(blocks2.get(1).getBlockId(), blocks2.get(1).getData());
+ expectedData.put(blocks2.get(0).getBlockId(),
ByteBufUtils.readBytes(blocks2.get(0).getData()));
+ expectedData.put(blocks2.get(1).getBlockId(),
ByteBufUtils.readBytes(blocks2.get(1).getData()));
validateResult(expectedData, sdr);
processBlockIds.addLong(blocks2.get(0).getBlockId());
processBlockIds.addLong(blocks2.get(1).getBlockId());
@@ -179,7 +180,7 @@ public class ShuffleServerWithMemLocalHdfsTest extends
ShuffleReadWriteBase {
// read the 3-th segment from localFile
sdr = composedClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks2.get(2).getBlockId(), blocks2.get(2).getData());
+ expectedData.put(blocks2.get(2).getBlockId(),
ByteBufUtils.readBytes(blocks2.get(2).getData()));
validateResult(expectedData, sdr);
processBlockIds.addLong(blocks2.get(2).getBlockId());
sdr.getBufferSegments().forEach(bs ->
composedClientReadHandler.updateConsumedBlockInfo(bs, checkSkippedMetrics));
@@ -200,8 +201,8 @@ public class ShuffleServerWithMemLocalHdfsTest extends
ShuffleReadWriteBase {
// read the 4-th segment from HDFS
sdr = composedClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks3.get(0).getBlockId(), blocks3.get(0).getData());
- expectedData.put(blocks3.get(1).getBlockId(), blocks3.get(1).getData());
+ expectedData.put(blocks3.get(0).getBlockId(),
ByteBufUtils.readBytes(blocks3.get(0).getData()));
+ expectedData.put(blocks3.get(1).getBlockId(),
ByteBufUtils.readBytes(blocks3.get(1).getData()));
validateResult(expectedData, sdr);
processBlockIds.addLong(blocks3.get(0).getBlockId());
processBlockIds.addLong(blocks3.get(1).getBlockId());
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
index b2f32d8c..8e47c131 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemoryTest.java
@@ -39,6 +39,7 @@ import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.util.ByteBufUtils;
import org.apache.uniffle.coordinator.CoordinatorConf;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.server.buffer.ShuffleBuffer;
@@ -121,17 +122,17 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
// start to read data, one block data for every call
ShuffleDataResult sdr = memoryClientReadHandler.readShuffleData();
Map<Long, byte[]> expectedData = Maps.newHashMap();
- expectedData.put(blocks.get(0).getBlockId(), blocks.get(0).getData());
+ expectedData.put(blocks.get(0).getBlockId(),
ByteBufUtils.readBytes(blocks.get(0).getData()));
validateResult(expectedData, sdr);
sdr = memoryClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks.get(1).getBlockId(), blocks.get(1).getData());
+ expectedData.put(blocks.get(1).getBlockId(),
ByteBufUtils.readBytes(blocks.get(1).getData()));
validateResult(expectedData, sdr);
sdr = memoryClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks.get(2).getBlockId(), blocks.get(2).getData());
+ expectedData.put(blocks.get(2).getBlockId(),
ByteBufUtils.readBytes(blocks.get(2).getData()));
validateResult(expectedData, sdr);
// no data in cache, empty return
@@ -155,8 +156,8 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
// read from memory with ComposedClientReadHandler
sdr = composedClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks.get(0).getBlockId(), blocks.get(0).getData());
- expectedData.put(blocks.get(1).getBlockId(), blocks.get(1).getData());
+ expectedData.put(blocks.get(0).getBlockId(),
ByteBufUtils.readBytes(blocks.get(0).getData()));
+ expectedData.put(blocks.get(1).getBlockId(),
ByteBufUtils.readBytes(blocks.get(1).getData()));
validateResult(expectedData, sdr);
// send data to shuffle server, flush should happen
@@ -192,21 +193,21 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
// when segment filter is introduced, there is no need to read duplicated
data
sdr = composedClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks.get(2).getBlockId(), blocks.get(2).getData());
- expectedData.put(blocks2.get(0).getBlockId(), blocks2.get(0).getData());
+ expectedData.put(blocks.get(2).getBlockId(),
ByteBufUtils.readBytes(blocks.get(2).getData()));
+ expectedData.put(blocks2.get(0).getBlockId(),
ByteBufUtils.readBytes(blocks2.get(0).getData()));
validateResult(expectedData, sdr);
processBlockIds.addLong(blocks.get(2).getBlockId());
processBlockIds.addLong(blocks2.get(0).getBlockId());
sdr = composedClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks2.get(1).getBlockId(), blocks2.get(1).getData());
+ expectedData.put(blocks2.get(1).getBlockId(),
ByteBufUtils.readBytes(blocks2.get(1).getData()));
validateResult(expectedData, sdr);
processBlockIds.addLong(blocks2.get(1).getBlockId());
sdr = composedClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks2.get(2).getBlockId(), blocks2.get(2).getData());
+ expectedData.put(blocks2.get(2).getBlockId(),
ByteBufUtils.readBytes(blocks2.get(2).getData()));
validateResult(expectedData, sdr);
processBlockIds.addLong(blocks2.get(2).getBlockId());
@@ -256,7 +257,7 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
// start to read data, one block data for every call
ShuffleDataResult sdr = memoryClientReadHandler.readShuffleData();
Map<Long, byte[]> expectedData = Maps.newHashMap();
- expectedData.put(blocks.get(0).getBlockId(), blocks.get(0).getData());
+ expectedData.put(blocks.get(0).getBlockId(),
ByteBufUtils.readBytes(blocks.get(0).getData()));
validateResult(expectedData, sdr);
// read by different reader, the first block should be skipped.
exceptTaskIds.removeLong(blocks.get(0).getTaskAttemptId());
@@ -264,17 +265,17 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
testAppId, shuffleId, partitionId, 20, shuffleServerClient,
exceptTaskIds);
sdr = memoryClientReadHandler2.readShuffleData();
expectedData.clear();
- expectedData.put(blocks.get(1).getBlockId(), blocks.get(1).getData());
+ expectedData.put(blocks.get(1).getBlockId(),
ByteBufUtils.readBytes(blocks.get(1).getData()));
validateResult(expectedData, sdr);
sdr = memoryClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks.get(1).getBlockId(), blocks.get(1).getData());
+ expectedData.put(blocks.get(1).getBlockId(),
ByteBufUtils.readBytes(blocks.get(1).getData()));
validateResult(expectedData, sdr);
sdr = memoryClientReadHandler2.readShuffleData();
expectedData.clear();
- expectedData.put(blocks.get(2).getBlockId(), blocks.get(2).getData());
+ expectedData.put(blocks.get(2).getBlockId(),
ByteBufUtils.readBytes(blocks.get(2).getData()));
validateResult(expectedData, sdr);
// no data in cache, empty return
sdr = memoryClientReadHandler2.readShuffleData();
@@ -321,9 +322,9 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
new ShuffleServerInfo(LOCALHOST, SHUFFLE_SERVER_PORT), handlers);
Map<Long, byte[]> expectedData = Maps.newHashMap();
expectedData.clear();
- expectedData.put(blocks.get(0).getBlockId(), blocks.get(0).getData());
- expectedData.put(blocks.get(1).getBlockId(), blocks.get(1).getData());
- expectedData.put(blocks.get(2).getBlockId(), blocks.get(1).getData());
+ expectedData.put(blocks.get(0).getBlockId(),
ByteBufUtils.readBytes(blocks.get(0).getData()));
+ expectedData.put(blocks.get(1).getBlockId(),
ByteBufUtils.readBytes(blocks.get(1).getData()));
+ expectedData.put(blocks.get(2).getBlockId(),
ByteBufUtils.readBytes(blocks.get(1).getData()));
ShuffleDataResult sdr = composedClientReadHandler.readShuffleData();
validateResult(expectedData, sdr);
processBlockIds.addLong(blocks.get(0).getBlockId());
@@ -360,8 +361,8 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
// notice: the 1-th segment is skipped, because it is processed
sdr = composedClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks2.get(0).getBlockId(), blocks2.get(0).getData());
- expectedData.put(blocks2.get(1).getBlockId(), blocks2.get(1).getData());
+ expectedData.put(blocks2.get(0).getBlockId(),
ByteBufUtils.readBytes(blocks2.get(0).getData()));
+ expectedData.put(blocks2.get(1).getBlockId(),
ByteBufUtils.readBytes(blocks2.get(1).getData()));
validateResult(expectedData, sdr);
processBlockIds.addLong(blocks2.get(0).getBlockId());
processBlockIds.addLong(blocks2.get(1).getBlockId());
@@ -369,7 +370,7 @@ public class ShuffleServerWithMemoryTest extends
ShuffleReadWriteBase {
// read the 3-th segment from localFile
sdr = composedClientReadHandler.readShuffleData();
expectedData.clear();
- expectedData.put(blocks2.get(2).getBlockId(), blocks2.get(2).getData());
+ expectedData.put(blocks2.get(2).getBlockId(),
ByteBufUtils.readBytes(blocks2.get(2).getData()));
validateResult(expectedData, sdr);
processBlockIds.addLong(blocks2.get(2).getBlockId());
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index f2ed1e87..94c7d070 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -320,7 +320,7 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
.setLength(sbi.getLength())
.setTaskAttemptId(sbi.getTaskAttemptId())
.setUncompressLength(sbi.getUncompressLength())
- .setData(ByteString.copyFrom(sbi.getData()))
+
.setData(UnsafeByteOperations.unsafeWrap(sbi.getData().nioBuffer()))
.build());
size += sbi.getSize();
blockNum++;