This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new d1ed98c7 [#133] feat(netty): Rewrite protocol. (#826)
d1ed98c7 is described below
commit d1ed98c7499db298743ee57597c5097b7bd478d4
Author: Xianming Lei <[email protected]>
AuthorDate: Sat Apr 22 22:32:13 2023 +0800
[#133] feat(netty): Rewrite protocol. (#826)
### What changes were proposed in this pull request?
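Rewrite the Netty protocol: add request/response messages for reading local and in-memory shuffle data, and introduce a common `RequestMessage` base class.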
### Why are the changes needed?
So that Netty can replace gRPC as the transport.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests (NettyProtocolTest).
Co-authored-by: leixianming <[email protected]>
---
.../Encodable.java => DecodeException.java} | 18 ++-
.../Encodable.java => EncodeException.java} | 18 ++-
.../uniffle/common/netty/protocol/Decoders.java | 35 ++++++
.../uniffle/common/netty/protocol/Encodable.java | 4 +-
.../uniffle/common/netty/protocol/Encoders.java | 26 +++++
.../netty/protocol/GetLocalShuffleDataRequest.java | 115 +++++++++++++++++++
.../protocol/GetLocalShuffleDataResponse.java | 70 ++++++++++++
.../protocol/GetLocalShuffleIndexRequest.java | 91 +++++++++++++++
.../protocol/GetLocalShuffleIndexResponse.java | 80 +++++++++++++
.../protocol/GetMemoryShuffleDataRequest.java | 125 +++++++++++++++++++++
.../protocol/GetMemoryShuffleDataResponse.java | 85 ++++++++++++++
.../{Encodable.java => RequestMessage.java} | 15 ++-
.../uniffle/common/netty/protocol/RpcResponse.java | 21 ----
.../netty/protocol/SendShuffleDataRequest.java | 14 +--
.../apache/uniffle/common/util/ByteBufUtils.java | 5 +-
.../common/netty/protocol/NettyProtocolTest.java | 125 ++++++++++++++++++++-
.../netty/protocol/NettyProtocolTestUtils.java | 2 +-
17 files changed, 800 insertions(+), 49 deletions(-)
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
b/common/src/main/java/org/apache/uniffle/common/netty/DecodeException.java
similarity index 69%
copy from
common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
copy to
common/src/main/java/org/apache/uniffle/common/netty/DecodeException.java
index 0ec305fa..18c625dc 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/DecodeException.java
@@ -15,13 +15,21 @@
* limitations under the License.
*/
-package org.apache.uniffle.common.netty.protocol;
+package org.apache.uniffle.common.netty;
-import io.netty.buffer.ByteBuf;
+import org.apache.uniffle.common.exception.RssException;
-public interface Encodable {
+public class DecodeException extends RssException {
- int encodedLength();
+ public DecodeException(String message) {
+ super(message);
+ }
- void encode(ByteBuf buf);
+ public DecodeException(Throwable e) {
+ super(e);
+ }
+
+ public DecodeException(String message, Throwable e) {
+ super(message, e);
+ }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
b/common/src/main/java/org/apache/uniffle/common/netty/EncodeException.java
similarity index 69%
copy from
common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
copy to
common/src/main/java/org/apache/uniffle/common/netty/EncodeException.java
index 0ec305fa..ec6855c4 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/EncodeException.java
@@ -15,13 +15,21 @@
* limitations under the License.
*/
-package org.apache.uniffle.common.netty.protocol;
+package org.apache.uniffle.common.netty;
-import io.netty.buffer.ByteBuf;
+import org.apache.uniffle.common.exception.RssException;
-public interface Encodable {
+public class EncodeException extends RssException {
- int encodedLength();
+ public EncodeException(String message) {
+ super(message);
+ }
- void encode(ByteBuf buf);
+ public EncodeException(Throwable e) {
+ super(e);
+ }
+
+ public EncodeException(String message, Throwable e) {
+ super(message, e);
+ }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
index 4b969a62..b53d73b9 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
@@ -18,10 +18,13 @@
package org.apache.uniffle.common.netty.protocol;
import java.util.List;
+import java.util.Map;
import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
import io.netty.buffer.ByteBuf;
+import org.apache.uniffle.common.BufferSegment;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.ByteBufUtils;
@@ -53,4 +56,36 @@ public class Decoders {
return new ShuffleBlockInfo(shuffleId, partId, blockId,
length, crc, data, serverInfos, uncompressLength, freeMemory,
taskAttemptId);
}
+
+ public static Map<Integer, List<Long>> decodePartitionToBlockIds(ByteBuf
byteBuf) {
+ Map<Integer, List<Long>> partitionToBlockIds = Maps.newHashMap();
+ int mapSize = byteBuf.readInt();
+ for (int i = 0; i < mapSize; i++) {
+ int partitionId = byteBuf.readInt();
+ int blockListSize = byteBuf.readInt();
+ List<Long> blocks = Lists.newArrayList();
+ for (int j = 0; j < blockListSize; j++) {
+ blocks.add(byteBuf.readLong());
+ }
+ partitionToBlockIds.put(partitionId, blocks);
+ }
+ return partitionToBlockIds;
+ }
+
+ public static List<BufferSegment> decodeBufferSegments(ByteBuf byteBuf) {
+ List<BufferSegment> bufferSegments = Lists.newArrayList();
+ int size = byteBuf.readInt();
+ for (int i = 0; i < size; i++) {
+ long blockId = byteBuf.readLong();
+ int offset = byteBuf.readInt();
+ int length = byteBuf.readInt();
+ int uncompressLength = byteBuf.readInt();
+ long crc = byteBuf.readLong();
+ long taskAttemptId = byteBuf.readLong();
+ BufferSegment bufferSegment = new BufferSegment(blockId, offset, length,
uncompressLength, crc, taskAttemptId);
+ bufferSegments.add(bufferSegment);
+ }
+ return bufferSegments;
+ }
+
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
index 0ec305fa..c3fb8165 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
@@ -19,9 +19,11 @@ package org.apache.uniffle.common.netty.protocol;
import io.netty.buffer.ByteBuf;
+import org.apache.uniffle.common.netty.EncodeException;
+
public interface Encodable {
int encodedLength();
- void encode(ByteBuf buf);
+ void encode(ByteBuf buf) throws EncodeException;
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java
index e819355a..f63f7d98 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java
@@ -21,6 +21,8 @@ import java.util.List;
import io.netty.buffer.ByteBuf;
+import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.ByteBufUtils;
@@ -67,4 +69,28 @@ public class Encoders {
return encodeLength;
}
+ public static void encodePartitionRanges(List<PartitionRange>
partitionRanges, ByteBuf byteBuf) {
+ byteBuf.writeInt(partitionRanges.size());
+ for (PartitionRange partitionRange : partitionRanges) {
+ byteBuf.writeInt(partitionRange.getStart());
+ byteBuf.writeInt(partitionRange.getEnd());
+ }
+ }
+
+ public static void encodeBufferSegments(List<BufferSegment> bufferSegments,
ByteBuf byteBuf) {
+ byteBuf.writeInt(bufferSegments.size());
+ for (BufferSegment bufferSegment : bufferSegments) {
+ byteBuf.writeLong(bufferSegment.getBlockId());
+ byteBuf.writeInt(bufferSegment.getOffset());
+ byteBuf.writeInt(bufferSegment.getLength());
+ byteBuf.writeInt(bufferSegment.getUncompressLength());
+ byteBuf.writeLong(bufferSegment.getCrc());
+ byteBuf.writeLong(bufferSegment.getTaskAttemptId());
+ }
+ }
+
+ public static int encodeLengthOfBufferSegments(List<BufferSegment>
bufferSegments) {
+ return Integer.BYTES + bufferSegments.size() * (3 * Long.BYTES + 3 *
Integer.BYTES);
+ }
+
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataRequest.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataRequest.java
new file mode 100644
index 00000000..7cf6f05b
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataRequest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+public class GetLocalShuffleDataRequest extends RequestMessage {
+ private String appId;
+ private int shuffleId;
+ private int partitionId;
+ private int partitionNumPerRange;
+ private int partitionNum;
+ private long offset;
+ private int length;
+ private long timestamp;
+
+ public GetLocalShuffleDataRequest(long requestId, String appId, int
shuffleId, int partitionId,
+ int partitionNumPerRange, int partitionNum, long offset, int length,
long timestamp) {
+ super(requestId);
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.partitionId = partitionId;
+ this.partitionNumPerRange = partitionNumPerRange;
+ this.partitionNum = partitionNum;
+ this.offset = offset;
+ this.length = length;
+ this.timestamp = timestamp;
+ }
+
+ @Override
+ public Type type() {
+ return Type.GET_LOCAL_SHUFFLE_DATA_REQUEST;
+ }
+
+ @Override
+ public int encodedLength() {
+ return REQUEST_ID_ENCODE_LENGTH + ByteBufUtils.encodedLength(appId) + 2 *
Long.BYTES + 5 * Integer.BYTES;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(getRequestId());
+ ByteBufUtils.writeLengthAndString(buf, appId);
+ buf.writeInt(shuffleId);
+ buf.writeInt(partitionId);
+ buf.writeInt(partitionNumPerRange);
+ buf.writeInt(partitionNum);
+ buf.writeLong(offset);
+ buf.writeInt(length);
+ buf.writeLong(timestamp);
+ }
+
+ public static GetLocalShuffleDataRequest decode(ByteBuf byteBuf) {
+ long requestId = byteBuf.readLong();
+ String appId = ByteBufUtils.readLengthAndString(byteBuf);
+ int shuffleId = byteBuf.readInt();
+ int partitionId = byteBuf.readInt();
+ int partitionNumPerRange = byteBuf.readInt();
+ int partitionNum = byteBuf.readInt();
+ long offset = byteBuf.readLong();
+ int length = byteBuf.readInt();
+ long timestamp = byteBuf.readLong();
+ return new GetLocalShuffleDataRequest(requestId, appId, shuffleId,
partitionId, partitionNumPerRange,
+ partitionNum, offset, length, timestamp);
+ }
+
+ public String getAppId() {
+ return appId;
+ }
+
+ public int getShuffleId() {
+ return shuffleId;
+ }
+
+ public int getPartitionId() {
+ return partitionId;
+ }
+
+ public int getPartitionNumPerRange() {
+ return partitionNumPerRange;
+ }
+
+ public int getPartitionNum() {
+ return partitionNum;
+ }
+
+ public long getOffset() {
+ return offset;
+ }
+
+ public int getLength() {
+ return length;
+ }
+
+ public long getTimestamp() {
+ return timestamp;
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
new file mode 100644
index 00000000..d955f22d
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataResponse.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+public class GetLocalShuffleDataResponse extends RpcResponse {
+ private ByteBuf data;
+
+ public GetLocalShuffleDataResponse(long requestId, StatusCode statusCode,
byte[] data) {
+ this(requestId, statusCode, null, data);
+ }
+
+ public GetLocalShuffleDataResponse(long requestId, StatusCode statusCode,
String retMessage, byte[] data) {
+ this(requestId, statusCode, retMessage, Unpooled.wrappedBuffer(data));
+ }
+
+ public GetLocalShuffleDataResponse(long requestId, StatusCode statusCode,
String retMessage, ByteBuf data) {
+ super(requestId, statusCode, retMessage);
+ this.data = data;
+ }
+
+ @Override
+ public int encodedLength() {
+ return super.encodedLength() + Integer.BYTES + data.readableBytes();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ super.encode(buf);
+ ByteBufUtils.copyByteBuf(data, buf);
+ data.release();
+ }
+
+ public static GetLocalShuffleDataResponse decode(ByteBuf byteBuf) {
+ long requestId = byteBuf.readLong();
+ StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
+ String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
+ ByteBuf data = ByteBufUtils.readSlice(byteBuf);
+ return new GetLocalShuffleDataResponse(requestId, statusCode, retMessage,
data);
+ }
+
+ @Override
+ public Type type() {
+ return Type.GET_LOCAL_SHUFFLE_DATA_RESPONSE;
+ }
+
+ public ByteBuf getData() {
+ return data;
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexRequest.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexRequest.java
new file mode 100644
index 00000000..a3cbb305
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexRequest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+public class GetLocalShuffleIndexRequest extends RequestMessage {
+ private String appId;
+ private int shuffleId;
+ private int partitionId;
+ private int partitionNumPerRange;
+ private int partitionNum;
+
+ public GetLocalShuffleIndexRequest(long requestId, String appId, int
shuffleId, int partitionId,
+ int partitionNumPerRange, int partitionNum) {
+ super(requestId);
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.partitionId = partitionId;
+ this.partitionNumPerRange = partitionNumPerRange;
+ this.partitionNum = partitionNum;
+ }
+
+ @Override
+ public Type type() {
+ return Type.GET_LOCAL_SHUFFLE_INDEX_REQUEST;
+ }
+
+ @Override
+ public int encodedLength() {
+ return REQUEST_ID_ENCODE_LENGTH + ByteBufUtils.encodedLength(appId) + 4 *
Integer.BYTES;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(getRequestId());
+ ByteBufUtils.writeLengthAndString(buf, appId);
+ buf.writeInt(shuffleId);
+ buf.writeInt(partitionId);
+ buf.writeInt(partitionNumPerRange);
+ buf.writeInt(partitionNum);
+ }
+
+ public static GetLocalShuffleIndexRequest decode(ByteBuf byteBuf) {
+ long requestId = byteBuf.readLong();
+ String appId = ByteBufUtils.readLengthAndString(byteBuf);
+ int shuffleId = byteBuf.readInt();
+ int partitionId = byteBuf.readInt();
+ int partitionNumPerRange = byteBuf.readInt();
+ int partitionNum = byteBuf.readInt();
+ return new GetLocalShuffleIndexRequest(requestId, appId, shuffleId,
partitionId, partitionNumPerRange,
+ partitionNum);
+ }
+
+ public String getAppId() {
+ return appId;
+ }
+
+ public int getShuffleId() {
+ return shuffleId;
+ }
+
+ public int getPartitionId() {
+ return partitionId;
+ }
+
+ public int getPartitionNumPerRange() {
+ return partitionNumPerRange;
+ }
+
+ public int getPartitionNum() {
+ return partitionNum;
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
new file mode 100644
index 00000000..dc8c730b
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+public class GetLocalShuffleIndexResponse extends RpcResponse {
+ private ByteBuf indexData;
+ private long fileLength;
+
+ public GetLocalShuffleIndexResponse(long requestId, StatusCode statusCode,
byte[] indexData, long fileLength) {
+ this(requestId, statusCode, null, indexData, fileLength);
+ }
+
+ public GetLocalShuffleIndexResponse(long requestId, StatusCode statusCode,
String retMessage,
+ byte[] indexData, long fileLength) {
+ this(requestId, statusCode, retMessage, Unpooled.wrappedBuffer(indexData),
fileLength);
+ }
+
+ public GetLocalShuffleIndexResponse(long requestId, StatusCode statusCode,
String retMessage,
+ ByteBuf indexData, long fileLength) {
+ super(requestId, statusCode, retMessage);
+ this.indexData = indexData;
+ this.fileLength = fileLength;
+ }
+
+ @Override
+ public int encodedLength() {
+ return super.encodedLength() + Integer.BYTES + indexData.readableBytes() +
Long.BYTES;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ super.encode(buf);
+ ByteBufUtils.copyByteBuf(indexData, buf);
+ indexData.release();
+ buf.writeLong(fileLength);
+ }
+
+ public static GetLocalShuffleIndexResponse decode(ByteBuf byteBuf) {
+ long requestId = byteBuf.readLong();
+ StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
+ String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
+ ByteBuf indexData = ByteBufUtils.readSlice(byteBuf);
+ long fileLength = byteBuf.readLong();
+ return new GetLocalShuffleIndexResponse(requestId, statusCode, retMessage,
indexData, fileLength);
+ }
+
+ @Override
+ public Type type() {
+ return Type.GET_LOCAL_SHUFFLE_INDEX_RESPONSE;
+ }
+
+ public ByteBuf getIndexData() {
+ return indexData;
+ }
+
+ public long getFileLength() {
+ return fileLength;
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
new file mode 100644
index 00000000..4fe77996
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.io.IOException;
+
+import io.netty.buffer.ByteBuf;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.common.netty.DecodeException;
+import org.apache.uniffle.common.netty.EncodeException;
+import org.apache.uniffle.common.util.ByteBufUtils;
+import org.apache.uniffle.common.util.RssUtils;
+
+public class GetMemoryShuffleDataRequest extends RequestMessage {
+ private String appId;
+ private int shuffleId;
+ private int partitionId;
+ private long lastBlockId;
+ private int readBufferSize;
+ private long timestamp;
+ private Roaring64NavigableMap expectedTaskIdsBitmap;
+
+ public GetMemoryShuffleDataRequest(long requestId, String appId, int
shuffleId, int partitionId, long lastBlockId,
+ int readBufferSize, long timestamp, Roaring64NavigableMap
expectedTaskIdsBitmap) {
+ super(requestId);
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ this.partitionId = partitionId;
+ this.lastBlockId = lastBlockId;
+ this.readBufferSize = readBufferSize;
+ this.timestamp = timestamp;
+ this.expectedTaskIdsBitmap = expectedTaskIdsBitmap;
+ }
+
+ @Override
+ public Type type() {
+ return Type.GET_MEMORY_SHUFFLE_DATA_REQUEST;
+ }
+
+ @Override
+ public int encodedLength() {
+ return (int) (REQUEST_ID_ENCODE_LENGTH + ByteBufUtils.encodedLength(appId)
+ 4 * Integer.BYTES
+ + 2 * Long.BYTES +
expectedTaskIdsBitmap.serializedSizeInBytes());
+ }
+
+ @Override
+ public void encode(ByteBuf buf) throws EncodeException {
+ buf.writeLong(getRequestId());
+ ByteBufUtils.writeLengthAndString(buf, appId);
+ buf.writeInt(shuffleId);
+ buf.writeInt(partitionId);
+ buf.writeLong(lastBlockId);
+ buf.writeInt(readBufferSize);
+ buf.writeLong(timestamp);
+ buf.writeInt((int) expectedTaskIdsBitmap.serializedSizeInBytes());
+ try {
+ buf.writeBytes(RssUtils.serializeBitMap(expectedTaskIdsBitmap));
+ } catch (IOException ioException) {
+ throw new EncodeException("serializeBitMap failed while encode
GetMemoryShuffleDataRequest!", ioException);
+ }
+ }
+
+ public static GetMemoryShuffleDataRequest decode(ByteBuf byteBuf) throws
DecodeException {
+ long requestId = byteBuf.readLong();
+ String appId = ByteBufUtils.readLengthAndString(byteBuf);
+ int shuffleId = byteBuf.readInt();
+ int partitionId = byteBuf.readInt();
+ long lastBlockId = byteBuf.readLong();
+ int readBufferSize = byteBuf.readInt();
+ long timestamp = byteBuf.readLong();
+ byte[] bytes = ByteBufUtils.readByteArray(byteBuf);
+ Roaring64NavigableMap expectedTaskIdsBitmap;
+ try {
+ expectedTaskIdsBitmap = RssUtils.deserializeBitMap(bytes);
+ } catch (IOException ioException) {
+ throw new DecodeException("serializeBitMap failed while decode
GetMemoryShuffleDataRequest!", ioException);
+ }
+ return new GetMemoryShuffleDataRequest(requestId, appId, shuffleId,
partitionId, lastBlockId, readBufferSize,
+ timestamp, expectedTaskIdsBitmap);
+ }
+
+ public String getAppId() {
+ return appId;
+ }
+
+ public int getShuffleId() {
+ return shuffleId;
+ }
+
+ public int getPartitionId() {
+ return partitionId;
+ }
+
+ public long getLastBlockId() {
+ return lastBlockId;
+ }
+
+ public int getReadBufferSize() {
+ return readBufferSize;
+ }
+
+ public long getTimestamp() {
+ return timestamp;
+ }
+
+ public Roaring64NavigableMap getExpectedTaskIdsBitmap() {
+ return expectedTaskIdsBitmap;
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
new file mode 100644
index 00000000..619140be
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataResponse.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+public class GetMemoryShuffleDataResponse extends RpcResponse {
+ private List<BufferSegment> bufferSegments;
+ private ByteBuf data;
+
+ public GetMemoryShuffleDataResponse(long requestId, StatusCode statusCode,
+ List<BufferSegment> bufferSegments, byte[] data) {
+ this(requestId, statusCode, null, bufferSegments, data);
+ }
+
+ public GetMemoryShuffleDataResponse(long requestId, StatusCode statusCode,
String retMessage,
+ List<BufferSegment> bufferSegments, byte[] data) {
+ this(requestId, statusCode, retMessage, bufferSegments,
Unpooled.wrappedBuffer(data));
+ }
+
+ public GetMemoryShuffleDataResponse(long requestId, StatusCode statusCode,
String retMessage,
+ List<BufferSegment> bufferSegments, ByteBuf data) {
+ super(requestId, statusCode, retMessage);
+ this.bufferSegments = bufferSegments;
+ this.data = data;
+ }
+
+ @Override
+ public int encodedLength() {
+ return super.encodedLength() +
Encoders.encodeLengthOfBufferSegments(bufferSegments)
+ + Integer.BYTES + data.readableBytes();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ super.encode(buf);
+ Encoders.encodeBufferSegments(bufferSegments, buf);
+ ByteBufUtils.copyByteBuf(data, buf);
+ data.release();
+ }
+
+ public static GetMemoryShuffleDataResponse decode(ByteBuf byteBuf) {
+ long requestId = byteBuf.readLong();
+ StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt());
+ String retMessage = ByteBufUtils.readLengthAndString(byteBuf);
+ List<BufferSegment> bufferSegments =
Decoders.decodeBufferSegments(byteBuf);
+ ByteBuf data = ByteBufUtils.readSlice(byteBuf);
+ return new GetMemoryShuffleDataResponse(requestId, statusCode, retMessage,
bufferSegments, data);
+ }
+
+ @Override
+ public Type type() {
+ return Type.GET_MEMORY_SHUFFLE_DATA_RESPONSE;
+ }
+
+ public List<BufferSegment> getBufferSegments() {
+ return bufferSegments;
+ }
+
+ public ByteBuf getData() {
+ return data;
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
similarity index 74%
copy from
common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
copy to
common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
index 0ec305fa..695484db 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
@@ -17,11 +17,16 @@
package org.apache.uniffle.common.netty.protocol;
-import io.netty.buffer.ByteBuf;
+public abstract class RequestMessage extends Message {
+ private final long requestId;
+ public static final int REQUEST_ID_ENCODE_LENGTH = Long.BYTES;
-public interface Encodable {
+ public RequestMessage(long requestId) {
+ super();
+ this.requestId = requestId;
+ }
- int encodedLength();
-
- void encode(ByteBuf buf);
+ public long getRequestId() {
+ return requestId;
+ }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
index 686c8d6e..9fef38cb 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
@@ -17,8 +17,6 @@
package org.apache.uniffle.common.netty.protocol;
-import java.util.Objects;
-
import io.netty.buffer.ByteBuf;
import org.apache.uniffle.common.rpc.StatusCode;
@@ -84,23 +82,4 @@ public class RpcResponse extends Message {
public Type type() {
return Type.RPC_RESPONSE;
}
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- RpcResponse that = (RpcResponse) o;
- return requestId == that.requestId
- && statusCode == that.statusCode
- && Objects.equals(retMessage, that.retMessage);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(requestId, statusCode, retMessage);
- }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
index cc1117a8..4c8c82b2 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
@@ -27,8 +27,7 @@ import io.netty.buffer.ByteBuf;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.util.ByteBufUtils;
-public class SendShuffleDataRequest extends Message {
- public long requestId;
+public class SendShuffleDataRequest extends RequestMessage {
private String appId;
private int shuffleId;
private long requireId;
@@ -37,7 +36,7 @@ public class SendShuffleDataRequest extends Message {
public SendShuffleDataRequest(long requestId, String appId, int shuffleId,
long requireId,
Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks, long timestamp) {
- this.requestId = requestId;
+ super(requestId);
this.appId = appId;
this.shuffleId = shuffleId;
this.requireId = requireId;
@@ -52,7 +51,8 @@ public class SendShuffleDataRequest extends Message {
@Override
public int encodedLength() {
- int encodeLength = Long.BYTES + ByteBufUtils.encodedLength(appId) +
Integer.BYTES + Long.BYTES + Integer.BYTES;
+ int encodeLength =
+ REQUEST_ID_ENCODE_LENGTH + ByteBufUtils.encodedLength(appId) +
Integer.BYTES + Long.BYTES + Integer.BYTES;
for (Map.Entry<Integer, List<ShuffleBlockInfo>> entry :
partitionToBlocks.entrySet()) {
encodeLength += 2 * Integer.BYTES;
for (ShuffleBlockInfo sbi : entry.getValue()) {
@@ -64,7 +64,7 @@ public class SendShuffleDataRequest extends Message {
@Override
public void encode(ByteBuf buf) {
- buf.writeLong(requestId);
+ buf.writeLong(getRequestId());
ByteBufUtils.writeLengthAndString(buf, appId);
buf.writeInt(shuffleId);
buf.writeLong(requireId);
@@ -108,10 +108,6 @@ public class SendShuffleDataRequest extends Message {
}
}
- public long getRequestId() {
- return requestId;
- }
-
public String getAppId() {
return appId;
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
index 1872674b..afdfae27 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
@@ -24,7 +24,10 @@ import io.netty.buffer.ByteBuf;
public class ByteBufUtils {
public static int encodedLength(String s) {
- return 4 + s.getBytes(StandardCharsets.UTF_8).length;
+    // Size = an int length prefix plus the UTF-8 payload; a null string contributes only the
+    // prefix. NOTE(review): assumed to mirror writeLengthAndString's handling of null — verify.
+    int payloadBytes = (s == null) ? 0 : s.getBytes(StandardCharsets.UTF_8).length;
+    return Integer.BYTES + payloadBytes;
}
public static int encodedLength(ByteBuf buf) {
diff --git
a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
index d2da5512..419c70d9 100644
---
a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
+++
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java
@@ -21,11 +21,14 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
+import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.apache.uniffle.common.BufferSegment;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.rpc.StatusCode;
@@ -81,8 +84,128 @@ public class NettyProtocolTest {
rpcResponse.encode(byteBuf);
assertEquals(byteBuf.readableBytes(), encodeLength);
RpcResponse rpcResponse1 = RpcResponse.decode(byteBuf);
- assertTrue(rpcResponse.equals(rpcResponse1));
+ assertEquals(rpcResponse.getRequestId(), rpcResponse1.getRequestId());
+ assertEquals(rpcResponse.getRetMessage(), rpcResponse1.getRetMessage());
+ assertEquals(rpcResponse.getStatusCode(), rpcResponse1.getStatusCode());
assertEquals(rpcResponse.encodedLength(), rpcResponse1.encodedLength());
byteBuf.release();
}
+
+  /** Round-trips a GetLocalShuffleDataRequest through encode/decode and checks every field. */
+  @Test
+  public void testGetLocalShuffleDataRequest() {
+    GetLocalShuffleDataRequest request = new GetLocalShuffleDataRequest(1, "test_app",
+        1, 1, 1, 100, 0, 200, System.currentTimeMillis());
+    int encodeLength = request.encodedLength();
+    ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
+    request.encode(byteBuf);
+    // encodedLength() must account for exactly the bytes encode() wrote; the fixed max capacity
+    // only catches over-writes, so under-writes are caught here.
+    assertEquals(encodeLength, byteBuf.readableBytes());
+    GetLocalShuffleDataRequest decoded = GetLocalShuffleDataRequest.decode(byteBuf);
+
+    assertEquals(request.getRequestId(), decoded.getRequestId());
+    assertEquals(request.getAppId(), decoded.getAppId());
+    assertEquals(request.getShuffleId(), decoded.getShuffleId());
+    assertEquals(request.getPartitionId(), decoded.getPartitionId());
+    assertEquals(request.getPartitionNumPerRange(), decoded.getPartitionNumPerRange());
+    assertEquals(request.getPartitionNum(), decoded.getPartitionNum());
+    assertEquals(request.getOffset(), decoded.getOffset());
+    assertEquals(request.getLength(), decoded.getLength());
+    assertEquals(request.getTimestamp(), decoded.getTimestamp());
+    byteBuf.release();
+  }
+
+  /** Round-trips a GetLocalShuffleDataResponse through encode/decode and checks every field. */
+  @Test
+  public void testGetLocalShuffleDataResponse() {
+    byte[] data = new byte[]{1, 2, 3};
+    GetLocalShuffleDataResponse response = new GetLocalShuffleDataResponse(
+        1, StatusCode.SUCCESS, "", Unpooled.wrappedBuffer(data).retain());
+    int encodeLength = response.encodedLength();
+    ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
+    response.encode(byteBuf);
+    // encodedLength() must account for exactly the bytes encode() wrote.
+    assertEquals(encodeLength, byteBuf.readableBytes());
+    GetLocalShuffleDataResponse decoded = GetLocalShuffleDataResponse.decode(byteBuf);
+
+    assertEquals(response.getRequestId(), decoded.getRequestId());
+    assertEquals(response.getRetMessage(), decoded.getRetMessage());
+    assertEquals(response.getStatusCode(), decoded.getStatusCode());
+    // ByteBuf.equals compares readable content, not identity.
+    assertEquals(response.getData(), decoded.getData());
+    byteBuf.release();
+  }
+
+  /** Round-trips a GetLocalShuffleIndexRequest through encode/decode and checks every field. */
+  @Test
+  public void testGetLocalShuffleIndexRequest() {
+    GetLocalShuffleIndexRequest request =
+        new GetLocalShuffleIndexRequest(1, "test_app", 1, 1, 1, 100);
+    int encodeLength = request.encodedLength();
+    ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
+    request.encode(byteBuf);
+    // encodedLength() must account for exactly the bytes encode() wrote.
+    assertEquals(encodeLength, byteBuf.readableBytes());
+    GetLocalShuffleIndexRequest decoded = GetLocalShuffleIndexRequest.decode(byteBuf);
+
+    assertEquals(request.getRequestId(), decoded.getRequestId());
+    assertEquals(request.getAppId(), decoded.getAppId());
+    assertEquals(request.getShuffleId(), decoded.getShuffleId());
+    assertEquals(request.getPartitionId(), decoded.getPartitionId());
+    assertEquals(request.getPartitionNumPerRange(), decoded.getPartitionNumPerRange());
+    assertEquals(request.getPartitionNum(), decoded.getPartitionNum());
+    byteBuf.release();
+  }
+
+  /** Round-trips a GetLocalShuffleIndexResponse through encode/decode and checks every field. */
+  @Test
+  public void testGetLocalShuffleIndexResponse() {
+    byte[] indexData = new byte[]{1, 2, 3};
+    GetLocalShuffleIndexResponse response = new GetLocalShuffleIndexResponse(
+        1, StatusCode.SUCCESS, "", Unpooled.wrappedBuffer(indexData).retain(), 23);
+    int encodeLength = response.encodedLength();
+    ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
+    response.encode(byteBuf);
+    // encodedLength() must account for exactly the bytes encode() wrote.
+    assertEquals(encodeLength, byteBuf.readableBytes());
+    GetLocalShuffleIndexResponse decoded = GetLocalShuffleIndexResponse.decode(byteBuf);
+
+    assertEquals(response.getRequestId(), decoded.getRequestId());
+    assertEquals(response.getStatusCode(), decoded.getStatusCode());
+    assertEquals(response.getRetMessage(), decoded.getRetMessage());
+    assertEquals(response.getFileLength(), decoded.getFileLength());
+    assertEquals(response.getIndexData(), decoded.getIndexData());
+    byteBuf.release();
+  }
+
+  /** Round-trips a GetMemoryShuffleDataRequest through encode/decode and checks every field. */
+  @Test
+  public void testGetMemoryShuffleDataRequest() {
+    Roaring64NavigableMap expectedTaskIdsBitmap = Roaring64NavigableMap.bitmapOf(1, 2, 3, 4, 5);
+    GetMemoryShuffleDataRequest request = new GetMemoryShuffleDataRequest(1, "test_app",
+        1, 1, 1, 64, System.currentTimeMillis(), expectedTaskIdsBitmap);
+    int encodeLength = request.encodedLength();
+    ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
+    request.encode(byteBuf);
+    // encodedLength() must account for exactly the bytes encode() wrote.
+    assertEquals(encodeLength, byteBuf.readableBytes());
+    GetMemoryShuffleDataRequest decoded = GetMemoryShuffleDataRequest.decode(byteBuf);
+
+    assertEquals(request.getRequestId(), decoded.getRequestId());
+    assertEquals(request.getAppId(), decoded.getAppId());
+    assertEquals(request.getShuffleId(), decoded.getShuffleId());
+    assertEquals(request.getPartitionId(), decoded.getPartitionId());
+    assertEquals(request.getLastBlockId(), decoded.getLastBlockId());
+    assertEquals(request.getReadBufferSize(), decoded.getReadBufferSize());
+    assertEquals(request.getTimestamp(), decoded.getTimestamp());
+    assertEquals(request.getExpectedTaskIdsBitmap().getLongCardinality(),
+        decoded.getExpectedTaskIdsBitmap().getLongCardinality());
+    // Cardinality alone would accept a same-sized but wrong bitmap; compare the actual ids.
+    assertTrue(Arrays.equals(request.getExpectedTaskIdsBitmap().toArray(),
+        decoded.getExpectedTaskIdsBitmap().toArray()));
+    byteBuf.release();
+  }
+
+  /** Round-trips a GetMemoryShuffleDataResponse through encode/decode and checks every field. */
+  @Test
+  public void testGetMemoryShuffleDataResponse() {
+    byte[] data = new byte[]{1, 2, 3, 4, 5};
+    List<BufferSegment> bufferSegments = Lists.newArrayList(
+        new BufferSegment(1, 0, 5, 10, 123, 1),
+        new BufferSegment(1, 0, 5, 10, 345, 1));
+    GetMemoryShuffleDataResponse response = new GetMemoryShuffleDataResponse(
+        1, StatusCode.SUCCESS, "", bufferSegments, Unpooled.wrappedBuffer(data).retain());
+    int encodeLength = response.encodedLength();
+    ByteBuf byteBuf = Unpooled.buffer(encodeLength, encodeLength);
+    response.encode(byteBuf);
+    // encodedLength() must account for exactly the bytes encode() wrote.
+    assertEquals(encodeLength, byteBuf.readableBytes());
+    GetMemoryShuffleDataResponse decoded = GetMemoryShuffleDataResponse.decode(byteBuf);
+
+    assertEquals(response.getRequestId(), decoded.getRequestId());
+    assertEquals(response.getStatusCode(), decoded.getStatusCode());
+    assertEquals(response.getRetMessage(), decoded.getRetMessage());
+    // ByteBuf.equals compares readable content, not identity.
+    assertEquals(response.getData(), decoded.getData());
+    // List.equals checks size and each segment in order, so a dropped or extra segment fails
+    // (the previous fixed two-iteration loop could not detect either).
+    assertEquals(response.getBufferSegments(), decoded.getBufferSegments());
+    byteBuf.release();
+  }
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTestUtils.java
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTestUtils.java
index 29f1c122..ea4fb60f 100644
---
a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTestUtils.java
+++
b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTestUtils.java
@@ -73,7 +73,7 @@ public class NettyProtocolTestUtils {
if (req1 == null || req2 == null) {
return false;
}
- boolean isEqual = req1.requestId == req2.requestId
+ boolean isEqual = req1.getRequestId() == req2.getRequestId()
&& req1.getShuffleId() == req2.getShuffleId()
&& req1.getRequireId() == req2.getRequireId()
&& req1.getTimestamp() == req2.getTimestamp()