This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 35a14d246 [CELEBORN-1836][CIP-14] Add Message to cppClient
35a14d246 is described below
commit 35a14d24697d3332eb359606fe249c2ecaeb447c
Author: HolyLow <[email protected]>
AuthorDate: Sat Jan 18 12:52:32 2025 +0800
[CELEBORN-1836][CIP-14] Add Message to cppClient
### What changes were proposed in this pull request?
This PR adds Message implementation to cppClient.
### Why are the changes needed?
The Message is the data structure transferred between the layers of the
network stack.
Encode/decode methods are provided and are compatible with the existing
Java implementation.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation and unit tests.
Closes #3066 from HolyLow/issue/celeborn-1836-add-message-to-cpp-client.
Authored-by: HolyLow <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
cpp/celeborn/network/CMakeLists.txt | 17 +++
cpp/celeborn/network/Message.cpp | 153 +++++++++++++++++++
cpp/celeborn/network/Message.h | 235 +++++++++++++++++++++++++++++
cpp/celeborn/network/tests/CMakeLists.txt | 4 +-
cpp/celeborn/network/tests/MessageTest.cpp | 155 +++++++++++++++++++
cpp/celeborn/protocol/CMakeLists.txt | 1 +
6 files changed, 564 insertions(+), 1 deletion(-)
diff --git a/cpp/celeborn/network/CMakeLists.txt b/cpp/celeborn/network/CMakeLists.txt
index 889d959f5..3a65828bd 100644
--- a/cpp/celeborn/network/CMakeLists.txt
+++ b/cpp/celeborn/network/CMakeLists.txt
@@ -12,6 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# Static library holding the network-layer Message implementation.
+add_library(network STATIC Message.cpp)
+
+# Make headers located under the build tree includable (presumably the
+# generated proto headers -- confirm).
+target_include_directories(network PUBLIC ${CMAKE_BINARY_DIR})
+
+# Link order is kept as-is; it can matter for static archives.
+target_link_libraries(
+  network
+  memory
+  proto
+  utils
+  protocol
+  ${FOLLY_WITH_DEPENDENCIES}
+  ${GLOG}
+  ${GFLAGS_LIBRARIES})
if(CELEBORN_BUILD_TESTS)
add_subdirectory(tests)
diff --git a/cpp/celeborn/network/Message.cpp b/cpp/celeborn/network/Message.cpp
new file mode 100644
index 000000000..f4ad4baca
--- /dev/null
+++ b/cpp/celeborn/network/Message.cpp
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/network/Message.h"
+
+namespace celeborn {
+namespace network {
+// Translates the one-byte type tag read from a frame into a Message::Type.
+// Tags without a matching enumerator (8, 10 and anything above 22) are
+// rejected via CELEBORN_FAIL.
+Message::Type Message::decodeType(uint8_t typeId) {
+  switch (typeId) {
+    case CHUNK_FETCH_REQUEST:
+    case CHUNK_FETCH_SUCCESS:
+    case CHUNK_FETCH_FAILURE:
+    case RPC_REQUEST:
+    case RPC_RESPONSE:
+    case RPC_FAILURE:
+    case OPEN_STREAM:
+    case STREAM_HANDLE:
+    case ONE_WAY_MESSAGE:
+    case PUSH_DATA:
+    case PUSH_MERGED_DATA:
+    case REGION_START:
+    case REGION_FINISH:
+    case PUSH_DATA_HAND_SHAKE:
+    case READ_ADD_CREDIT:
+    case READ_DATA:
+    case OPEN_STREAM_WITH_CREDIT:
+    case BACKLOG_ANNOUNCEMENT:
+    case TRANSPORTABLE_ERROR:
+    case BUFFER_STREAM_END:
+    case HEARTBEAT:
+      // Every enumerator's value equals its wire tag, so a cast suffices.
+      return static_cast<Type>(typeId);
+    default:
+      CELEBORN_FAIL("Unknown message type " + std::to_string(typeId));
+  }
+}
+
+// Backing counter for Message::nextRequestId(); the first id handed out is 0.
+std::atomic<long> Message::currRequestId_{0};
+
+// Serializes this message into a single frame laid out as
+//   int32 encodedLength | uint8 type | int32 bodyLength |
+//   <encodedLength bytes written by internalEncodeTo()> | <body bytes>
+// Subclasses that do not override internalEncodedLength()/internalEncodeTo()
+// cannot be encoded (the base implementations are unreachable).
+std::unique_ptr<memory::ReadOnlyByteBuffer> Message::encode() const {
+  // The wire format carries lengths as int32; make the narrowing explicit.
+  const auto bodyLength = static_cast<int32_t>(body_->remainingSize());
+  const int32_t encodedLength = internalEncodedLength();
+  const size_t headerLength = sizeof(int32_t) + sizeof(uint8_t) +
+      sizeof(int32_t) + static_cast<size_t>(encodedLength);
+  auto buffer = memory::ByteBuffer::createWriteOnly(headerLength);
+  buffer->write<int32_t>(encodedLength);
+  buffer->write<uint8_t>(static_cast<uint8_t>(type_));
+  buffer->write<int32_t>(bodyLength);
+  internalEncodeTo(*buffer);
+  auto header = memory::ByteBuffer::toReadOnly(std::move(buffer));
+  // Return the concatenation directly: wrapping a local in std::move on
+  // return would only inhibit copy elision.
+  return memory::ByteBuffer::concat(*header, *body_);
+}
+
+// Parses a complete frame: int32 encodedLength, uint8 type tag, int32
+// bodyLength, then exactly encodedLength + bodyLength bytes. The three header
+// fields are consumed here and the rest of the buffer is handed to the
+// concrete message's decoder. Only RPC_RESPONSE, RPC_FAILURE,
+// CHUNK_FETCH_SUCCESS and CHUNK_FETCH_FAILURE are decodable; any other type
+// fails.
+std::unique_ptr<Message> Message::decodeFrom(
+    std::unique_ptr<memory::ReadOnlyByteBuffer>&& data) {
+  const int32_t encodedLength = data->read<int32_t>();
+  const uint8_t typeId = data->read<uint8_t>();
+  const int32_t bodyLength = data->read<int32_t>();
+  // Both lengths are read from the incoming buffer, so add them in 64 bits:
+  // summing two large int32 values would be signed overflow (undefined
+  // behavior). Comparing as signed also means negative lengths simply fail.
+  const int64_t expectedRemaining =
+      static_cast<int64_t>(encodedLength) + static_cast<int64_t>(bodyLength);
+  CELEBORN_CHECK_EQ(
+      expectedRemaining, static_cast<int64_t>(data->remainingSize()));
+  const Type type = decodeType(typeId);
+  switch (type) {
+    case RPC_RESPONSE:
+      return RpcResponse::decodeFrom(std::move(data));
+    case RPC_FAILURE:
+      return RpcFailure::decodeFrom(std::move(data));
+    case CHUNK_FETCH_SUCCESS:
+      return ChunkFetchSuccess::decodeFrom(std::move(data));
+    case CHUNK_FETCH_FAILURE:
+      return ChunkFetchFailure::decodeFrom(std::move(data));
+    default:
+      CELEBORN_FAIL("unsupported Message decode type " + std::to_string(type));
+  }
+}
+
+// RpcRequest's encoded section is the request id followed by the int32 body
+// length. The id is written as a fixed 8-byte integer: `long` is only 4 bytes
+// on some platforms (e.g. 64-bit Windows), while the Java peer's long is
+// always 8 bytes.
+int RpcRequest::internalEncodedLength() const {
+  return sizeof(int64_t) + sizeof(int32_t);
+}
+
+void RpcRequest::internalEncodeTo(memory::WriteOnlyByteBuffer& buffer) const {
+  buffer.write<int64_t>(static_cast<int64_t>(requestId_));
+  buffer.write<int32_t>(static_cast<int32_t>(body_->remainingSize()));
+}
+
+// Decodes an RpcResponse from the bytes following the frame header: the
+// request id, an int32 copy of the body length, then the body itself, which
+// becomes the message body.
+std::unique_ptr<RpcResponse> RpcResponse::decodeFrom(
+    std::unique_ptr<memory::ReadOnlyByteBuffer>&& data) {
+  // Fixed-width read: the Java side writes the id as an 8-byte long, whereas
+  // C++ `long` is only 4 bytes on some platforms.
+  const int64_t requestId = data->read<int64_t>();
+  // Skip the repeated body length; Message::decodeFrom has already checked
+  // the frame's total size.
+  data->skip(sizeof(int32_t));
+  return std::make_unique<RpcResponse>(requestId, std::move(data));
+}
+
+// Decodes an RpcFailure from the bytes following the frame header: the
+// request id, then an int32-length-prefixed error string that must span the
+// rest of the buffer exactly.
+std::unique_ptr<RpcFailure> RpcFailure::decodeFrom(
+    std::unique_ptr<memory::ReadOnlyByteBuffer>&& data) {
+  // Fixed-width read: the Java side writes the id as an 8-byte long, whereas
+  // C++ `long` is only 4 bytes on some platforms.
+  const int64_t requestId = data->read<int64_t>();
+  const int32_t strLen = data->read<int32_t>();
+  // Compare as signed 64-bit values so a negative length simply fails the
+  // check instead of relying on signed/unsigned conversion.
+  CELEBORN_CHECK_EQ(
+      static_cast<int64_t>(data->remainingSize()),
+      static_cast<int64_t>(strLen));
+  std::string errorString = data->readToString(strLen);
+  return std::make_unique<RpcFailure>(requestId, std::move(errorString));
+}
+
+// Decodes a ChunkFetchSuccess: the stream chunk slice is read off the front
+// of the buffer, and the buffer itself (now positioned at the chunk bytes) is
+// moved in as the message body.
+std::unique_ptr<ChunkFetchSuccess> ChunkFetchSuccess::decodeFrom(
+    std::unique_ptr<memory::ReadOnlyByteBuffer>&& data) {
+  protocol::StreamChunkSlice slice =
+      protocol::StreamChunkSlice::decodeFrom(*data);
+  return std::make_unique<ChunkFetchSuccess>(slice, std::move(data));
+}
+
+// Decodes a ChunkFetchFailure: the stream chunk slice, then an
+// int-length-prefixed error string that must cover the rest of the buffer
+// exactly. No body is kept.
+std::unique_ptr<ChunkFetchFailure> ChunkFetchFailure::decodeFrom(
+    std::unique_ptr<memory::ReadOnlyByteBuffer>&& data) {
+  protocol::StreamChunkSlice slice =
+      protocol::StreamChunkSlice::decodeFrom(*data);
+  const int strLen = data->read<int>();
+  CELEBORN_CHECK_EQ(data->remainingSize(), strLen);
+  return std::make_unique<ChunkFetchFailure>(
+      slice, data->readToString(strLen));
+}
+} // namespace network
+} // namespace celeborn
diff --git a/cpp/celeborn/network/Message.h b/cpp/celeborn/network/Message.h
new file mode 100644
index 000000000..a4b269aad
--- /dev/null
+++ b/cpp/celeborn/network/Message.h
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <atomic>
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "celeborn/memory/ByteBuffer.h"
+#include "celeborn/protocol/ControlMessages.h"
+#include "celeborn/utils/Exceptions.h"
+
+namespace celeborn {
+namespace network {
+/// Base class for the objects (RpcRequest, RpcResponse, ...) that are
+/// directly encoded/decoded by the Java side's netty encoder/decoder stack.
+/// A Message is a type tag, a type-specific encoded section and a body; the
+/// message is then decoded again to obtain application-level objects.
+class Message {
+ public:
+  /// Wire type tags. encode() writes the value as a single byte, so the
+  /// numbers must match the Java implementation. Tags 8 and 10 have no
+  /// enumerator here.
+  enum Type {
+    UNKNOWN_TYPE = -1,
+    CHUNK_FETCH_REQUEST = 0,
+    CHUNK_FETCH_SUCCESS = 1,
+    CHUNK_FETCH_FAILURE = 2,
+    RPC_REQUEST = 3,
+    RPC_RESPONSE = 4,
+    RPC_FAILURE = 5,
+    OPEN_STREAM = 6,
+    STREAM_HANDLE = 7,
+    ONE_WAY_MESSAGE = 9,
+    PUSH_DATA = 11,
+    PUSH_MERGED_DATA = 12,
+    REGION_START = 13,
+    REGION_FINISH = 14,
+    PUSH_DATA_HAND_SHAKE = 15,
+    READ_ADD_CREDIT = 16,
+    READ_DATA = 17,
+    OPEN_STREAM_WITH_CREDIT = 18,
+    BACKLOG_ANNOUNCEMENT = 19,
+    TRANSPORTABLE_ERROR = 20,
+    BUFFER_STREAM_END = 21,
+    HEARTBEAT = 22,
+  };
+
+  /// Maps a wire tag to its Type; fails for tags without an enumerator.
+  static Type decodeType(uint8_t typeId);
+
+  /// Decodes a whole frame (header included) into the matching subclass.
+  static std::unique_ptr<Message> decodeFrom(
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& data);
+
+  /// Returns the next value of a process-wide atomic counter; each call
+  /// yields a distinct, increasing id.
+  static long nextRequestId() {
+    return currRequestId_.fetch_add(1);
+  }
+
+  Message(Type type, std::unique_ptr<memory::ReadOnlyByteBuffer>&& body)
+      : type_(type), body_(std::move(body)) {}
+
+  virtual ~Message() = default;
+
+  Type type() const { return type_; }
+
+  /// Returns body_->clone(); the message keeps its own buffer.
+  std::unique_ptr<memory::ReadOnlyByteBuffer> body() const {
+    return body_->clone();
+  }
+
+  /// Serializes the message into one frame: header, encoded section, body.
+  std::unique_ptr<memory::ReadOnlyByteBuffer> encode() const;
+
+ protected:
+  /// Size in bytes of the type-specific encoded section. Subclasses that
+  /// support encode() must override this.
+  virtual int internalEncodedLength() const {
+    CELEBORN_UNREACHABLE(
+        "unsupported message encodeLength type " + std::to_string(type_));
+  }
+
+  /// Writes the type-specific encoded section; it should produce exactly
+  /// internalEncodedLength() bytes, since encode() sizes its header by it.
+  virtual void internalEncodeTo(memory::WriteOnlyByteBuffer& buffer) const {
+    CELEBORN_UNREACHABLE(
+        "unsupported message internalEncodeTo type " + std::to_string(type_));
+  }
+
+  Type type_;
+  std::unique_ptr<memory::ReadOnlyByteBuffer> body_;
+
+  static std::atomic<long> currRequestId_;
+};
+
+/// Outgoing RPC request: a request id plus an opaque body. Only encoding is
+/// implemented.
+class RpcRequest : public Message {
+  // TODO: add decode method when required
+ public:
+  RpcRequest(long requestId, std::unique_ptr<memory::ReadOnlyByteBuffer>&& buf)
+      : Message(RPC_REQUEST, std::move(buf)), requestId_(requestId) {}
+
+  /// Copies the id and clones the other request's body buffer.
+  RpcRequest(const RpcRequest& other)
+      : Message(Type::RPC_REQUEST, other.body_->clone()),
+        requestId_(other.requestId_) {}
+
+  ~RpcRequest() override = default;
+
+  long requestId() const { return requestId_; }
+
+ private:
+  // Encoded section: the request id followed by the int32 body length.
+  int internalEncodedLength() const override;
+
+  void internalEncodeTo(memory::WriteOnlyByteBuffer& buffer) const override;
+
+  long requestId_;
+};
+
+/// Incoming RPC response: the id of the request it answers plus an opaque
+/// body. Only decoding is implemented.
+class RpcResponse : public Message {
+  // TODO: add encode support (internalEncodedLength/internalEncodeTo) when
+  // required; until then encode() reaches the base-class unreachable path.
+ public:
+  RpcResponse(
+      long requestId,
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& body)
+      : Message(RPC_RESPONSE, std::move(body)), requestId_(requestId) {}
+
+  RpcResponse(const RpcResponse& other)
+      : Message(RPC_RESPONSE, other.body_->clone()),
+        requestId_(other.requestId_) {}
+
+  /// Copies the id and clones the other response's body. Returns *this like
+  /// a conventional assignment operator. Self-assignment is safe because the
+  /// clone is created before body_ is replaced.
+  RpcResponse& operator=(const RpcResponse& other) {
+    requestId_ = other.requestId();
+    body_ = other.body_->clone();
+    return *this;
+  }
+
+  long requestId() const {
+    return requestId_;
+  }
+
+  static std::unique_ptr<RpcResponse> decodeFrom(
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& data);
+
+ private:
+  long requestId_;
+};
+
+class RpcFailure : public Message {
+ public:
+ RpcFailure(long requestId, std::string&& errorString)
+ : Message(RPC_FAILURE, memory::ReadOnlyByteBuffer::createEmptyBuffer()),
+ requestId_(requestId),
+ errorString_(std::move(errorString)) {}
+
+ RpcFailure(const RpcFailure& other)
+ : Message(RPC_FAILURE, memory::ReadOnlyByteBuffer::createEmptyBuffer()),
+ requestId_(other.requestId_),
+ errorString_(other.errorString_) {}
+
+ long requestId() const {
+ return requestId_;
+ }
+
+ std::string errorMsg() const {
+ return errorString_;
+ }
+
+ static std::unique_ptr<RpcFailure> decodeFrom(
+ std::unique_ptr<memory::ReadOnlyByteBuffer>&& data);
+
+ private:
+ long requestId_;
+ std::string errorString_;
+};
+
+/// Successful chunk fetch: identifies the fetched chunk slice and carries the
+/// chunk bytes as the message body.
+class ChunkFetchSuccess : public Message {
+ public:
+  // The slice is taken by const reference (it is copied into the message),
+  // so temporaries and const slices can be passed as well.
+  ChunkFetchSuccess(
+      const protocol::StreamChunkSlice& streamChunkSlice,
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& body)
+      : Message(CHUNK_FETCH_SUCCESS, std::move(body)),
+        streamChunkSlice_(streamChunkSlice) {}
+
+  ChunkFetchSuccess(const ChunkFetchSuccess& other)
+      : Message(CHUNK_FETCH_SUCCESS, other.body_->clone()),
+        streamChunkSlice_(other.streamChunkSlice_) {}
+
+  static std::unique_ptr<ChunkFetchSuccess> decodeFrom(
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& data);
+
+  protocol::StreamChunkSlice streamChunkSlice() const {
+    return streamChunkSlice_;
+  }
+
+ private:
+  protocol::StreamChunkSlice streamChunkSlice_;
+};
+
+class ChunkFetchFailure : public Message {
+ public:
+ ChunkFetchFailure(
+ protocol::StreamChunkSlice& streamChunkSlice,
+ std::string&& errorString)
+ : Message(
+ CHUNK_FETCH_FAILURE,
+ memory::ReadOnlyByteBuffer::createEmptyBuffer()),
+ streamChunkSlice_(streamChunkSlice),
+ errorString_(std::move(errorString)) {}
+
+ ChunkFetchFailure(const ChunkFetchFailure& other)
+ : Message(
+ CHUNK_FETCH_FAILURE,
+ memory::ReadOnlyByteBuffer::createEmptyBuffer()),
+ streamChunkSlice_(other.streamChunkSlice_),
+ errorString_(other.errorString_) {}
+
+ static std::unique_ptr<ChunkFetchFailure> decodeFrom(
+ std::unique_ptr<memory::ReadOnlyByteBuffer>&& data);
+
+ protocol::StreamChunkSlice streamChunkSlice() const {
+ return streamChunkSlice_;
+ }
+
+ std::string errorMsg() const {
+ return errorString_;
+ }
+
+ private:
+ protocol::StreamChunkSlice streamChunkSlice_;
+ std::string errorString_;
+};
+} // namespace network
+} // namespace celeborn
diff --git a/cpp/celeborn/network/tests/CMakeLists.txt b/cpp/celeborn/network/tests/CMakeLists.txt
index 46a3fcf56..db38fb484 100644
--- a/cpp/celeborn/network/tests/CMakeLists.txt
+++ b/cpp/celeborn/network/tests/CMakeLists.txt
@@ -15,7 +15,8 @@
add_executable(
celeborn_network_test
- FrameDecoderTest.cpp)
+ FrameDecoderTest.cpp
+ MessageTest.cpp)
add_test(NAME celeborn_network_test COMMAND celeborn_network_test)
@@ -25,6 +26,7 @@ target_link_libraries(
memory
proto
protocol
+ network
utils
${WANGLE}
${FIZZ}
diff --git a/cpp/celeborn/network/tests/MessageTest.cpp b/cpp/celeborn/network/tests/MessageTest.cpp
new file mode 100644
index 000000000..0604f08a6
--- /dev/null
+++ b/cpp/celeborn/network/tests/MessageTest.cpp
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/network/Message.h"
+
+using namespace celeborn;
+using namespace celeborn::network;
+
+// Encodes an RpcRequest and checks the full frame byte layout, including that
+// nothing trails the body.
+TEST(MessageTest, encodeRpcRequest) {
+  const std::string body = "test-body";
+  const auto bodySize = static_cast<int32_t>(body.size());
+  auto bodyBuffer = memory::ByteBuffer::createWriteOnly(body.size());
+  bodyBuffer->writeFromString(body);
+  const long requestId = 1000;
+  auto rpcRequest = std::make_unique<RpcRequest>(
+      requestId, memory::ByteBuffer::toReadOnly(std::move(bodyBuffer)));
+
+  auto encodedBuffer = rpcRequest->encode();
+  // Frame header: encoded-section length, type tag, body length.
+  EXPECT_EQ(
+      encodedBuffer->read<int32_t>(),
+      static_cast<int32_t>(sizeof(int64_t) + sizeof(int32_t)));
+  EXPECT_EQ(encodedBuffer->read<uint8_t>(), Message::Type::RPC_REQUEST);
+  EXPECT_EQ(encodedBuffer->read<int32_t>(), bodySize);
+  // Encoded section: 8-byte request id, then the body length again.
+  EXPECT_EQ(encodedBuffer->read<int64_t>(), requestId);
+  EXPECT_EQ(encodedBuffer->read<int32_t>(), bodySize);
+  EXPECT_EQ(encodedBuffer->readToString(body.size()), body);
+  EXPECT_EQ(encodedBuffer->remainingSize(), 0);
+}
+
+// Builds an RpcResponse frame by hand and checks it decodes to the right
+// subclass with the id and body intact.
+TEST(MessageTest, decodeRpcResponse) {
+  const std::string body = "test-body";
+  const long requestId = 1000;
+  const int headerLength = sizeof(int32_t) + sizeof(uint8_t) + sizeof(int32_t);
+  const int encodedLength = sizeof(int64_t) + sizeof(int32_t);
+  const int bodyLength = body.size();
+  size_t size = headerLength + encodedLength + bodyLength;
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(size);
+  writeBuffer->write<int32_t>(encodedLength);
+  writeBuffer->write<uint8_t>(Message::Type::RPC_RESPONSE);
+  writeBuffer->write<int32_t>(bodyLength);
+  writeBuffer->write<int64_t>(requestId);
+  writeBuffer->write<int32_t>(bodyLength);
+  writeBuffer->writeFromString(body);
+  auto message = Message::decodeFrom(
+      memory::ByteBuffer::toReadOnly(std::move(writeBuffer)));
+  EXPECT_EQ(message->type(), Message::Type::RPC_RESPONSE);
+  auto rpcResponse = dynamic_cast<RpcResponse*>(message.get());
+  // Fail the test cleanly (instead of crashing) if the subclass is wrong.
+  ASSERT_NE(rpcResponse, nullptr);
+  EXPECT_EQ(rpcResponse->requestId(), requestId);
+  auto rpcResponseBody = rpcResponse->body();
+  EXPECT_EQ(rpcResponseBody->remainingSize(), body.size());
+  EXPECT_EQ(rpcResponseBody->readToString(body.size()), body);
+}
+
+// Builds an RpcFailure frame by hand and checks the decoded id, empty body
+// and error text.
+TEST(MessageTest, decodeRpcFailure) {
+  const std::string failureMsg = "test-failure-msg";
+  const long requestId = 1000;
+  const int failureMsgLength = failureMsg.size();
+  const int headerLength = sizeof(int32_t) + sizeof(uint8_t) + sizeof(int32_t);
+  // Mirrors the Java encoder's layout for RpcFailure: there is no body; the
+  // length-prefixed error string belongs to the encoded section.
+  const int encodedLength =
+      sizeof(int64_t) + sizeof(int32_t) + failureMsgLength;
+  const int bodyLength = 0;
+  size_t size = headerLength + encodedLength + bodyLength;
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(size);
+  writeBuffer->write<int32_t>(encodedLength);
+  writeBuffer->write<uint8_t>(Message::Type::RPC_FAILURE);
+  writeBuffer->write<int32_t>(bodyLength);
+  writeBuffer->write<int64_t>(requestId);
+  writeBuffer->write<int32_t>(failureMsgLength);
+  writeBuffer->writeFromString(failureMsg);
+  auto message = Message::decodeFrom(
+      memory::ByteBuffer::toReadOnly(std::move(writeBuffer)));
+  EXPECT_EQ(message->type(), Message::Type::RPC_FAILURE);
+  auto rpcFailure = dynamic_cast<RpcFailure*>(message.get());
+  // Fail the test cleanly (instead of crashing) if the subclass is wrong.
+  ASSERT_NE(rpcFailure, nullptr);
+  EXPECT_EQ(rpcFailure->requestId(), requestId);
+  auto rpcFailureBody = rpcFailure->body();
+  EXPECT_EQ(rpcFailureBody->remainingSize(), 0);
+  EXPECT_EQ(rpcFailure->errorMsg(), failureMsg);
+}
+
+// Builds a ChunkFetchSuccess frame by hand and checks the decoded slice and
+// chunk body.
+TEST(MessageTest, decodeChunkFetchSuccess) {
+  const long streamId = 1000;
+  const int chunkIndex = 1001;
+  const int offset = 1002;
+  const int len = 1003;
+  const std::string body = "test-body";
+  const int headerLength = sizeof(int32_t) + sizeof(uint8_t) + sizeof(int32_t);
+  // Encoded section: the stream chunk slice (8-byte stream id + 3 ints).
+  const int encodedLength =
+      sizeof(int64_t) + sizeof(int) + sizeof(int) + sizeof(int);
+  const int bodyLength = body.size();
+  size_t size = headerLength + encodedLength + bodyLength;
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(size);
+  writeBuffer->write<int32_t>(encodedLength);
+  writeBuffer->write<uint8_t>(Message::Type::CHUNK_FETCH_SUCCESS);
+  writeBuffer->write<int32_t>(bodyLength);
+  writeBuffer->write<int64_t>(streamId);
+  writeBuffer->write<int>(chunkIndex);
+  writeBuffer->write<int>(offset);
+  writeBuffer->write<int>(len);
+  writeBuffer->writeFromString(body);
+  auto message = Message::decodeFrom(
+      memory::ByteBuffer::toReadOnly(std::move(writeBuffer)));
+  EXPECT_EQ(message->type(), Message::Type::CHUNK_FETCH_SUCCESS);
+  auto chunkFetchSuccess = dynamic_cast<ChunkFetchSuccess*>(message.get());
+  // Fail the test cleanly (instead of crashing) if the subclass is wrong.
+  ASSERT_NE(chunkFetchSuccess, nullptr);
+  auto streamChunkSlice = chunkFetchSuccess->streamChunkSlice();
+  EXPECT_EQ(streamChunkSlice.streamId, streamId);
+  EXPECT_EQ(streamChunkSlice.chunkIndex, chunkIndex);
+  EXPECT_EQ(streamChunkSlice.offset, offset);
+  EXPECT_EQ(streamChunkSlice.len, len);
+  auto chunkFetchSuccessBody = chunkFetchSuccess->body();
+  EXPECT_EQ(chunkFetchSuccessBody->remainingSize(), body.size());
+  EXPECT_EQ(chunkFetchSuccessBody->readToString(body.size()), body);
+}
+
+// Builds a ChunkFetchFailure frame by hand and checks the decoded slice and
+// error text.
+TEST(MessageTest, decodeChunkFetchFailure) {
+  const long streamId = 1000;
+  const int chunkIndex = 1001;
+  const int offset = 1002;
+  const int len = 1003;
+  const std::string failureMsg = "test-failure-msg";
+  const int failureMsgLength = failureMsg.size();
+  const int headerLength = sizeof(int32_t) + sizeof(uint8_t) + sizeof(int32_t);
+  // Mirrors the Java encoder's layout for ChunkFetchFailure: there is no
+  // body; the slice and the length-prefixed error string form the encoded
+  // section.
+  const int encodedLength = sizeof(int64_t) + sizeof(int) + sizeof(int) +
+      sizeof(int) + sizeof(int) + failureMsgLength;
+  const int bodyLength = 0;
+  size_t size = headerLength + encodedLength + bodyLength;
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(size);
+  writeBuffer->write<int32_t>(encodedLength);
+  writeBuffer->write<uint8_t>(Message::Type::CHUNK_FETCH_FAILURE);
+  writeBuffer->write<int32_t>(bodyLength);
+  writeBuffer->write<int64_t>(streamId);
+  writeBuffer->write<int>(chunkIndex);
+  writeBuffer->write<int>(offset);
+  writeBuffer->write<int>(len);
+  writeBuffer->write<int>(failureMsgLength);
+  writeBuffer->writeFromString(failureMsg);
+  auto message = Message::decodeFrom(
+      memory::ByteBuffer::toReadOnly(std::move(writeBuffer)));
+  EXPECT_EQ(message->type(), Message::Type::CHUNK_FETCH_FAILURE);
+  auto chunkFetchFailure = dynamic_cast<ChunkFetchFailure*>(message.get());
+  // Fail the test cleanly (instead of crashing) if the subclass is wrong.
+  ASSERT_NE(chunkFetchFailure, nullptr);
+  auto streamChunkSlice = chunkFetchFailure->streamChunkSlice();
+  EXPECT_EQ(streamChunkSlice.streamId, streamId);
+  EXPECT_EQ(streamChunkSlice.chunkIndex, chunkIndex);
+  EXPECT_EQ(streamChunkSlice.offset, offset);
+  EXPECT_EQ(streamChunkSlice.len, len);
+  EXPECT_EQ(chunkFetchFailure->errorMsg(), failureMsg);
+}
diff --git a/cpp/celeborn/protocol/CMakeLists.txt b/cpp/celeborn/protocol/CMakeLists.txt
index aa601c479..a1dc054d3 100644
--- a/cpp/celeborn/protocol/CMakeLists.txt
+++ b/cpp/celeborn/protocol/CMakeLists.txt
@@ -25,6 +25,7 @@ target_link_libraries(
protocol
memory
proto
+ utils
${FOLLY_WITH_DEPENDENCIES}
${GLOG}
${GFLAGS_LIBRARIES}