This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 35a14d246 [CELEBORN-1836][CIP-14] Add Message to cppClient
35a14d246 is described below

commit 35a14d24697d3332eb359606fe249c2ecaeb447c
Author: HolyLow <[email protected]>
AuthorDate: Sat Jan 18 12:52:32 2025 +0800

    [CELEBORN-1836][CIP-14] Add Message to cppClient
    
    ### What changes were proposed in this pull request?
    This PR adds the Message implementation to cppClient.
    
    ### Why are the changes needed?
    The Message is the data structure that is transferred between network stack layers.
    The decode/encode methods are supported and are compatible with the existing Java implementation.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and UTs.
    
    Closes #3066 from HolyLow/issue/celeborn-1836-add-message-to-cpp-client.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 cpp/celeborn/network/CMakeLists.txt        |  17 +++
 cpp/celeborn/network/Message.cpp           | 153 +++++++++++++++++++
 cpp/celeborn/network/Message.h             | 235 +++++++++++++++++++++++++++++
 cpp/celeborn/network/tests/CMakeLists.txt  |   4 +-
 cpp/celeborn/network/tests/MessageTest.cpp | 155 +++++++++++++++++++
 cpp/celeborn/protocol/CMakeLists.txt       |   1 +
 6 files changed, 564 insertions(+), 1 deletion(-)

diff --git a/cpp/celeborn/network/CMakeLists.txt b/cpp/celeborn/network/CMakeLists.txt
index 889d959f5..3a65828bd 100644
--- a/cpp/celeborn/network/CMakeLists.txt
+++ b/cpp/celeborn/network/CMakeLists.txt
@@ -12,6 +12,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# Static library holding the network-layer Message encode/decode code.
+add_library(network STATIC Message.cpp)
+
+target_include_directories(network PUBLIC ${CMAKE_BINARY_DIR})
+
+target_link_libraries(
+        network
+        memory
+        proto
+        utils
+        protocol
+        ${FOLLY_WITH_DEPENDENCIES}
+        ${GLOG}
+        ${GFLAGS_LIBRARIES})
 
 if(CELEBORN_BUILD_TESTS)
     add_subdirectory(tests)
diff --git a/cpp/celeborn/network/Message.cpp b/cpp/celeborn/network/Message.cpp
new file mode 100644
index 000000000..f4ad4baca
--- /dev/null
+++ b/cpp/celeborn/network/Message.cpp
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/network/Message.h"
+
+namespace celeborn {
+namespace network {
+Message::Type Message::decodeType(uint8_t typeId) {
+  // Every wire id is numerically equal to the Type enumerator it denotes.
+  // The defined ids are 0..7 (CHUNK_FETCH_REQUEST..STREAM_HANDLE), 9
+  // (ONE_WAY_MESSAGE) and 11..22 (PUSH_DATA..HEARTBEAT); 8, 10 and anything
+  // above 22 are rejected.
+  const bool isKnownId = (typeId <= STREAM_HANDLE) ||
+      (typeId == ONE_WAY_MESSAGE) ||
+      (typeId >= PUSH_DATA && typeId <= HEARTBEAT);
+  if (!isKnownId) {
+    CELEBORN_FAIL("Unknown message type " + std::to_string(typeId));
+  }
+  return static_cast<Type>(typeId);
+}
+
+std::atomic<long> Message::currRequestId_{0};
+
+std::unique_ptr<memory::ReadOnlyByteBuffer> Message::encode() const {
+  // Frame layout written here:
+  //   int32 encodedLength | uint8 type | int32 bodyLength |
+  //   <encodedLength bytes from internalEncodeTo()> | <body bytes>
+  const auto bodyLength = static_cast<int32_t>(body_->remainingSize());
+  const int32_t encodedLength = internalEncodedLength();
+  // Size of everything that precedes the body: the fixed frame fields plus
+  // the subclass-specific header.
+  const size_t frameHeaderLength =
+      sizeof(int32_t) + sizeof(uint8_t) + sizeof(int32_t) + encodedLength;
+  auto buffer = memory::ByteBuffer::createWriteOnly(frameHeaderLength);
+  buffer->write<int32_t>(encodedLength);
+  buffer->write<uint8_t>(static_cast<uint8_t>(type_));
+  buffer->write<int32_t>(bodyLength);
+  internalEncodeTo(*buffer);
+  auto header = memory::ByteBuffer::toReadOnly(std::move(buffer));
+  // Return the prvalue directly; wrapping a local in std::move would only
+  // inhibit copy elision.
+  return memory::ByteBuffer::concat(*header, *body_);
+}
+
+std::unique_ptr<Message> Message::decodeFrom(
+    std::unique_ptr<memory::ReadOnlyByteBuffer>&& data) {
+  // Fixed frame fields, in the order encode() writes them.
+  const int32_t encodedLength = data->read<int32_t>();
+  const uint8_t typeId = data->read<uint8_t>();
+  const int32_t bodyLength = data->read<int32_t>();
+  // Both lengths come off the wire. Reject negative values first so the sum
+  // below is computed on non-negative values in size_t, where two int32
+  // magnitudes cannot overflow (unlike the signed int addition).
+  CELEBORN_CHECK(
+      encodedLength >= 0 && bodyLength >= 0, "negative Message frame length");
+  CELEBORN_CHECK_EQ(
+      static_cast<size_t>(encodedLength) + static_cast<size_t>(bodyLength),
+      data->remainingSize());
+  const Type type = decodeType(typeId);
+  switch (type) {
+    case RPC_RESPONSE:
+      return RpcResponse::decodeFrom(std::move(data));
+    case RPC_FAILURE:
+      return RpcFailure::decodeFrom(std::move(data));
+    case CHUNK_FETCH_SUCCESS:
+      return ChunkFetchSuccess::decodeFrom(std::move(data));
+    case CHUNK_FETCH_FAILURE:
+      return ChunkFetchFailure::decodeFrom(std::move(data));
+    default:
+      CELEBORN_FAIL("unsupported Message decode type " + std::to_string(type));
+  }
+}
+
+// The request id is a Java long on the wire, which is always 8 bytes. C++
+// `long` is only 4 bytes on LLP64 (Windows) and 32-bit targets, so the field
+// is sized and written as int64_t explicitly.
+int RpcRequest::internalEncodedLength() const {
+  // requestId (int64) followed by the body length (int32).
+  return sizeof(int64_t) + sizeof(int32_t);
+}
+
+void RpcRequest::internalEncodeTo(memory::WriteOnlyByteBuffer& buffer) const {
+  buffer.write<int64_t>(requestId_);
+  buffer.write<int32_t>(static_cast<int32_t>(body_->remainingSize()));
+}
+
+std::unique_ptr<RpcResponse> RpcResponse::decodeFrom(
+    std::unique_ptr<memory::ReadOnlyByteBuffer>&& data) {
+  // requestId is an 8-byte Java long on the wire; read it as int64_t so the
+  // width does not depend on the platform's `long`.
+  const int64_t requestId = data->read<int64_t>();
+  // Skip the int32 body-length field; whatever remains in data is the body.
+  data->skip(sizeof(int32_t));
+  return std::make_unique<RpcResponse>(requestId, std::move(data));
+}
+
+std::unique_ptr<RpcFailure> RpcFailure::decodeFrom(
+    std::unique_ptr<memory::ReadOnlyByteBuffer>&& data) {
+  // requestId is an 8-byte Java long on the wire.
+  const int64_t requestId = data->read<int64_t>();
+  // The error text is an int32 length prefix followed by exactly that many
+  // bytes, which must make up the rest of the frame.
+  const int32_t strLen = data->read<int32_t>();
+  CELEBORN_CHECK_EQ(data->remainingSize(), strLen);
+  std::string errorString = data->readToString(strLen);
+  return std::make_unique<RpcFailure>(requestId, std::move(errorString));
+}
+
+std::unique_ptr<ChunkFetchSuccess> ChunkFetchSuccess::decodeFrom(
+    std::unique_ptr<memory::ReadOnlyByteBuffer>&& data) {
+  // The StreamChunkSlice is read from the front of data. The bytes left
+  // behind become the message body.
+  protocol::StreamChunkSlice streamChunkSlice =
+      protocol::StreamChunkSlice::decodeFrom(*data);
+  return std::make_unique<ChunkFetchSuccess>(
+      streamChunkSlice, std::move(data));
+}
+
+std::unique_ptr<ChunkFetchFailure> ChunkFetchFailure::decodeFrom(
+    std::unique_ptr<memory::ReadOnlyByteBuffer>&& data) {
+  // Payload: a StreamChunkSlice, then an int32 length-prefixed error string
+  // that must occupy the rest of the frame.
+  protocol::StreamChunkSlice streamChunkSlice =
+      protocol::StreamChunkSlice::decodeFrom(*data);
+  const int32_t strLen = data->read<int32_t>();
+  CELEBORN_CHECK_EQ(data->remainingSize(), strLen);
+  std::string errorString = data->readToString(strLen);
+  return std::make_unique<ChunkFetchFailure>(
+      streamChunkSlice, std::move(errorString));
+}
+} // namespace network
+} // namespace celeborn
diff --git a/cpp/celeborn/network/Message.h b/cpp/celeborn/network/Message.h
new file mode 100644
index 000000000..a4b269aad
--- /dev/null
+++ b/cpp/celeborn/network/Message.h
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <atomic>
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "celeborn/memory/ByteBuffer.h"
+#include "celeborn/protocol/ControlMessages.h"
+#include "celeborn/utils/Exceptions.h"
+
+namespace celeborn {
+namespace network {
+/// A Message (RpcRequest, RpcResponse, ...) is the object that is directly
+/// decoded/encoded by the Java side's Netty encode/decode stack. The message
+/// is then decoded again to get application-level objects.
+class Message {
+ public:
+  /// Message type ids. encode() writes the value as one byte in the frame
+  /// header; ids 8 and 10 are not defined here.
+  enum Type {
+    UNKNOWN_TYPE = -1,
+    CHUNK_FETCH_REQUEST = 0,
+    CHUNK_FETCH_SUCCESS = 1,
+    CHUNK_FETCH_FAILURE = 2,
+    RPC_REQUEST = 3,
+    RPC_RESPONSE = 4,
+    RPC_FAILURE = 5,
+    OPEN_STREAM = 6,
+    STREAM_HANDLE = 7,
+    ONE_WAY_MESSAGE = 9,
+    PUSH_DATA = 11,
+    PUSH_MERGED_DATA = 12,
+    REGION_START = 13,
+    REGION_FINISH = 14,
+    PUSH_DATA_HAND_SHAKE = 15,
+    READ_ADD_CREDIT = 16,
+    READ_DATA = 17,
+    OPEN_STREAM_WITH_CREDIT = 18,
+    BACKLOG_ANNOUNCEMENT = 19,
+    TRANSPORTABLE_ERROR = 20,
+    BUFFER_STREAM_END = 21,
+    HEARTBEAT = 22,
+  };
+
+  /// Maps a wire type id to its Type; ids with no matching enumerator fail
+  /// via CELEBORN_FAIL.
+  static Type decodeType(uint8_t typeId);
+
+  Message(Type type, std::unique_ptr<memory::ReadOnlyByteBuffer>&& body)
+      : type_(type), body_(std::move(body)) {}
+
+  virtual ~Message() = default;
+
+  Type type() const {
+    return type_;
+  }
+
+  /// Returns body_->clone(); this message keeps its own body buffer.
+  std::unique_ptr<memory::ReadOnlyByteBuffer> body() const {
+    return body_->clone();
+  }
+
+  /// Builds the full frame: int32 encodedLength, uint8 type, int32 body
+  /// length, the subclass header from internalEncodeTo(), then the body.
+  std::unique_ptr<memory::ReadOnlyByteBuffer> encode() const;
+
+  /// Decodes one complete frame and dispatches on its type id. Only
+  /// RPC_RESPONSE, RPC_FAILURE, CHUNK_FETCH_SUCCESS and CHUNK_FETCH_FAILURE
+  /// are currently supported; other types fail.
+  static std::unique_ptr<Message> decodeFrom(
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& data);
+
+  /// Returns the current value of a static atomic counter, then increments it.
+  static long nextRequestId() {
+    return currRequestId_.fetch_add(1);
+  }
+
+ protected:
+  /// Byte length of the subclass-specific header written by
+  /// internalEncodeTo(). Types that support encode() must override this.
+  virtual int internalEncodedLength() const {
+    CELEBORN_UNREACHABLE(
+        "unsupported message encodeLength type " + std::to_string(type_));
+  }
+
+  /// Writes the subclass-specific header into buffer. Types that support
+  /// encode() must override this.
+  virtual void internalEncodeTo(memory::WriteOnlyByteBuffer& buffer) const {
+    CELEBORN_UNREACHABLE(
+        "unsupported message internalEncodeTo type " + std::to_string(type_));
+  }
+
+  Type type_;
+  std::unique_ptr<memory::ReadOnlyByteBuffer> body_;
+
+  static std::atomic<long> currRequestId_;
+};
+
+/// Outgoing RPC request: a request id plus an opaque body buffer.
+class RpcRequest : public Message {
+  // TODO: add decode method when required
+ public:
+  RpcRequest(long requestId, std::unique_ptr<memory::ReadOnlyByteBuffer>&& buf)
+      : Message(RPC_REQUEST, std::move(buf)), requestId_{requestId} {}
+
+  /// The copy receives a clone of the source's body buffer.
+  RpcRequest(const RpcRequest& other)
+      : Message(Type::RPC_REQUEST, other.body_->clone()),
+        requestId_{other.requestId_} {}
+
+  ~RpcRequest() override = default;
+
+  long requestId() const {
+    return requestId_;
+  }
+
+ private:
+  /// Length of the RPC-specific header emitted by internalEncodeTo().
+  int internalEncodedLength() const override;
+
+  /// Emits the request id followed by the body length.
+  void internalEncodeTo(memory::WriteOnlyByteBuffer& buffer) const override;
+
+  long requestId_;
+};
+
+/// Incoming RPC response: a request id plus an opaque body buffer.
+class RpcResponse : public Message {
+  // TODO: add encode method when required
+ public:
+  RpcResponse(
+      long requestId,
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& body)
+      : Message(RPC_RESPONSE, std::move(body)), requestId_(requestId) {}
+
+  /// The copy receives a clone of the source's body buffer.
+  RpcResponse(const RpcResponse& other)
+      : Message(RPC_RESPONSE, other.body_->clone()),
+        requestId_(other.requestId_) {}
+
+  /// Copies the request id and a clone of other's body. The clone is taken
+  /// before body_ is replaced, so self-assignment is safe. Returns *this,
+  /// as a conventional copy-assignment operator does.
+  RpcResponse& operator=(const RpcResponse& other) {
+    requestId_ = other.requestId_;
+    body_ = other.body_->clone();
+    return *this;
+  }
+
+  long requestId() const {
+    return requestId_;
+  }
+
+  static std::unique_ptr<RpcResponse> decodeFrom(
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& data);
+
+ private:
+  long requestId_;
+};
+
+class RpcFailure : public Message {
+ public:
+  RpcFailure(long requestId, std::string&& errorString)
+      : Message(RPC_FAILURE, memory::ReadOnlyByteBuffer::createEmptyBuffer()),
+        requestId_(requestId),
+        errorString_(std::move(errorString)) {}
+
+  RpcFailure(const RpcFailure& other)
+      : Message(RPC_FAILURE, memory::ReadOnlyByteBuffer::createEmptyBuffer()),
+        requestId_(other.requestId_),
+        errorString_(other.errorString_) {}
+
+  long requestId() const {
+    return requestId_;
+  }
+
+  std::string errorMsg() const {
+    return errorString_;
+  }
+
+  static std::unique_ptr<RpcFailure> decodeFrom(
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& data);
+
+ private:
+  long requestId_;
+  std::string errorString_;
+};
+
+/// Successful chunk fetch: the StreamChunkSlice it answers plus the chunk
+/// bytes as the message body.
+class ChunkFetchSuccess : public Message {
+ public:
+  /// The slice is only copied, so it is taken by const reference; this also
+  /// accepts const slices and temporaries.
+  ChunkFetchSuccess(
+      const protocol::StreamChunkSlice& streamChunkSlice,
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& body)
+      : Message(CHUNK_FETCH_SUCCESS, std::move(body)),
+        streamChunkSlice_(streamChunkSlice) {}
+
+  /// The copy receives a clone of the source's body buffer.
+  ChunkFetchSuccess(const ChunkFetchSuccess& other)
+      : Message(CHUNK_FETCH_SUCCESS, other.body_->clone()),
+        streamChunkSlice_(other.streamChunkSlice_) {}
+
+  static std::unique_ptr<ChunkFetchSuccess> decodeFrom(
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& data);
+
+  protocol::StreamChunkSlice streamChunkSlice() const {
+    return streamChunkSlice_;
+  }
+
+ private:
+  protocol::StreamChunkSlice streamChunkSlice_;
+};
+
+class ChunkFetchFailure : public Message {
+ public:
+  ChunkFetchFailure(
+      protocol::StreamChunkSlice& streamChunkSlice,
+      std::string&& errorString)
+      : Message(
+            CHUNK_FETCH_FAILURE,
+            memory::ReadOnlyByteBuffer::createEmptyBuffer()),
+        streamChunkSlice_(streamChunkSlice),
+        errorString_(std::move(errorString)) {}
+
+  ChunkFetchFailure(const ChunkFetchFailure& other)
+      : Message(
+            CHUNK_FETCH_FAILURE,
+            memory::ReadOnlyByteBuffer::createEmptyBuffer()),
+        streamChunkSlice_(other.streamChunkSlice_),
+        errorString_(other.errorString_) {}
+
+  static std::unique_ptr<ChunkFetchFailure> decodeFrom(
+      std::unique_ptr<memory::ReadOnlyByteBuffer>&& data);
+
+  protocol::StreamChunkSlice streamChunkSlice() const {
+    return streamChunkSlice_;
+  }
+
+  std::string errorMsg() const {
+    return errorString_;
+  }
+
+ private:
+  protocol::StreamChunkSlice streamChunkSlice_;
+  std::string errorString_;
+};
+} // namespace network
+} // namespace celeborn
diff --git a/cpp/celeborn/network/tests/CMakeLists.txt b/cpp/celeborn/network/tests/CMakeLists.txt
index 46a3fcf56..db38fb484 100644
--- a/cpp/celeborn/network/tests/CMakeLists.txt
+++ b/cpp/celeborn/network/tests/CMakeLists.txt
@@ -15,7 +15,8 @@
 
 add_executable(
         celeborn_network_test
-        FrameDecoderTest.cpp)
+        FrameDecoderTest.cpp
+        # Message encode/decode tests.
+        MessageTest.cpp)
 
 add_test(NAME celeborn_network_test COMMAND celeborn_network_test)
 
@@ -25,6 +26,7 @@ target_link_libraries(
         memory
         proto
         protocol
+        network
         utils
         ${WANGLE}
         ${FIZZ}
diff --git a/cpp/celeborn/network/tests/MessageTest.cpp b/cpp/celeborn/network/tests/MessageTest.cpp
new file mode 100644
index 000000000..0604f08a6
--- /dev/null
+++ b/cpp/celeborn/network/tests/MessageTest.cpp
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/network/Message.h"
+
+using namespace celeborn;
+using namespace celeborn::network;
+
+TEST(MessageTest, encodeRpcRequest) {
+  // Encodes an RpcRequest and checks every field of the resulting frame.
+  const std::string body = "test-body";
+  auto bodyBuffer = memory::ByteBuffer::createWriteOnly(body.size());
+  bodyBuffer->writeFromString(body);
+  const long requestId = 1000;
+  auto rpcRequest = std::make_unique<RpcRequest>(
+      requestId, memory::ByteBuffer::toReadOnly(std::move(bodyBuffer)));
+  // Compare lengths as int32_t (the wire type) rather than mixing signed
+  // reads with unsigned size_t values.
+  const auto bodyLength = static_cast<int32_t>(body.size());
+
+  auto encodedBuffer = rpcRequest->encode();
+  // encodedLength: requestId (8-byte Java long) + body length (int32).
+  EXPECT_EQ(
+      encodedBuffer->read<int32_t>(),
+      static_cast<int32_t>(sizeof(int64_t) + sizeof(int32_t)));
+  EXPECT_EQ(encodedBuffer->read<uint8_t>(), Message::Type::RPC_REQUEST);
+  EXPECT_EQ(encodedBuffer->read<int32_t>(), bodyLength);
+  EXPECT_EQ(encodedBuffer->read<int64_t>(), requestId);
+  EXPECT_EQ(encodedBuffer->read<int32_t>(), bodyLength);
+  EXPECT_EQ(encodedBuffer->readToString(body.size()), body);
+}
+
+TEST(MessageTest, decodeRpcResponse) {
+  // Hand-builds an RPC_RESPONSE frame and checks that decodeFrom() restores
+  // the request id and the body.
+  const std::string body = "test-body";
+  const long requestId = 1000;
+  const int headerLength = sizeof(int32_t) + sizeof(uint8_t) + sizeof(int32_t);
+  // requestId (8-byte Java long) + body length (int32).
+  const int encodedLength = sizeof(int64_t) + sizeof(int32_t);
+  const auto bodyLength = static_cast<int32_t>(body.size());
+  size_t size = headerLength + encodedLength + bodyLength;
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(size);
+  writeBuffer->write<int32_t>(encodedLength);
+  writeBuffer->write<uint8_t>(Message::Type::RPC_RESPONSE);
+  writeBuffer->write<int32_t>(bodyLength);
+  writeBuffer->write<int64_t>(requestId);
+  writeBuffer->write<int32_t>(bodyLength);
+  writeBuffer->writeFromString(body);
+  auto message = Message::decodeFrom(
+      memory::ByteBuffer::toReadOnly(std::move(writeBuffer)));
+  EXPECT_EQ(message->type(), Message::Type::RPC_RESPONSE);
+  auto rpcResponse = dynamic_cast<RpcResponse*>(message.get());
+  // Fail the test cleanly on a type mismatch instead of dereferencing null.
+  ASSERT_NE(rpcResponse, nullptr);
+  EXPECT_EQ(rpcResponse->requestId(), requestId);
+  auto rpcResponseBody = rpcResponse->body();
+  EXPECT_EQ(rpcResponseBody->remainingSize(), body.size());
+  EXPECT_EQ(rpcResponseBody->readToString(body.size()), body);
+}
+
+TEST(MessageTest, decodeRpcFailure) {
+  // Hand-builds an RPC_FAILURE frame and checks the decoded request id and
+  // error message, and that the decoded failure has an empty body.
+  const std::string failureMsg = "test-failure-msg";
+  const long requestId = 1000;
+  const int headerLength = sizeof(int32_t) + sizeof(uint8_t) + sizeof(int32_t);
+  // requestId (8-byte Java long) + error string length (int32).
+  const int encodedLength = sizeof(int64_t) + sizeof(int32_t);
+  const auto failureMsgLength = static_cast<int32_t>(failureMsg.size());
+  size_t size = headerLength + encodedLength + failureMsgLength;
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(size);
+  writeBuffer->write<int32_t>(encodedLength);
+  writeBuffer->write<uint8_t>(Message::Type::RPC_FAILURE);
+  writeBuffer->write<int32_t>(failureMsgLength);
+  writeBuffer->write<int64_t>(requestId);
+  writeBuffer->write<int32_t>(failureMsgLength);
+  writeBuffer->writeFromString(failureMsg);
+  auto message = Message::decodeFrom(
+      memory::ByteBuffer::toReadOnly(std::move(writeBuffer)));
+  EXPECT_EQ(message->type(), Message::Type::RPC_FAILURE);
+  auto rpcFailure = dynamic_cast<RpcFailure*>(message.get());
+  // Fail the test cleanly on a type mismatch instead of dereferencing null.
+  ASSERT_NE(rpcFailure, nullptr);
+  EXPECT_EQ(rpcFailure->requestId(), requestId);
+  auto rpcFailureBody = rpcFailure->body();
+  EXPECT_EQ(rpcFailureBody->remainingSize(), 0);
+  EXPECT_EQ(rpcFailure->errorMsg(), failureMsg);
+}
+
+TEST(MessageTest, decodeChunkFetchSuccess) {
+  // Hand-builds a CHUNK_FETCH_SUCCESS frame and checks the decoded
+  // StreamChunkSlice fields and the body bytes.
+  const long streamId = 1000;
+  const int chunkIndex = 1001;
+  const int offset = 1002;
+  const int len = 1003;
+  const std::string body = "test-body";
+  const int headerLength = sizeof(int32_t) + sizeof(uint8_t) + sizeof(int32_t);
+  // StreamChunkSlice: streamId, chunkIndex, offset, len.
+  const int encodedLength =
+      sizeof(long) + sizeof(int) + sizeof(int) + sizeof(int);
+  const int bodyLength = body.size();
+  size_t size = headerLength + encodedLength + bodyLength;
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(size);
+  writeBuffer->write<int32_t>(encodedLength);
+  writeBuffer->write<uint8_t>(Message::Type::CHUNK_FETCH_SUCCESS);
+  writeBuffer->write<int32_t>(bodyLength);
+  writeBuffer->write<long>(streamId);
+  writeBuffer->write<int>(chunkIndex);
+  writeBuffer->write<int>(offset);
+  writeBuffer->write<int>(len);
+  writeBuffer->writeFromString(body);
+  auto message = Message::decodeFrom(
+      memory::ByteBuffer::toReadOnly(std::move(writeBuffer)));
+  EXPECT_EQ(message->type(), Message::Type::CHUNK_FETCH_SUCCESS);
+  auto chunkFetchSuccess = dynamic_cast<ChunkFetchSuccess*>(message.get());
+  // Fail the test cleanly on a type mismatch instead of dereferencing null.
+  ASSERT_NE(chunkFetchSuccess, nullptr);
+  auto streamChunkSlice = chunkFetchSuccess->streamChunkSlice();
+  EXPECT_EQ(streamChunkSlice.streamId, streamId);
+  EXPECT_EQ(streamChunkSlice.chunkIndex, chunkIndex);
+  EXPECT_EQ(streamChunkSlice.offset, offset);
+  EXPECT_EQ(streamChunkSlice.len, len);
+  auto chunkFetchSuccessBody = chunkFetchSuccess->body();
+  EXPECT_EQ(chunkFetchSuccessBody->remainingSize(), body.size());
+  EXPECT_EQ(chunkFetchSuccessBody->readToString(body.size()), body);
+}
+
+TEST(MessageTest, decodeChunkFetchFailure) {
+  // Hand-builds a CHUNK_FETCH_FAILURE frame and checks the decoded
+  // StreamChunkSlice fields and the error message.
+  const long streamId = 1000;
+  const int chunkIndex = 1001;
+  const int offset = 1002;
+  const int len = 1003;
+  const std::string failureMsg = "test-failure-msg";
+  const int headerLength = sizeof(int32_t) + sizeof(uint8_t) + sizeof(int32_t);
+  // StreamChunkSlice (streamId, chunkIndex, offset, len) + string length.
+  const int encodedLength =
+      sizeof(long) + sizeof(int) + sizeof(int) + sizeof(int) + sizeof(int);
+  const int failureMsgLength = failureMsg.size();
+  size_t size = headerLength + encodedLength + failureMsgLength;
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(size);
+  writeBuffer->write<int32_t>(encodedLength);
+  writeBuffer->write<uint8_t>(Message::Type::CHUNK_FETCH_FAILURE);
+  writeBuffer->write<int32_t>(failureMsgLength);
+  writeBuffer->write<long>(streamId);
+  writeBuffer->write<int>(chunkIndex);
+  writeBuffer->write<int>(offset);
+  writeBuffer->write<int>(len);
+  writeBuffer->write<int>(failureMsgLength);
+  writeBuffer->writeFromString(failureMsg);
+  auto message = Message::decodeFrom(
+      memory::ByteBuffer::toReadOnly(std::move(writeBuffer)));
+  EXPECT_EQ(message->type(), Message::Type::CHUNK_FETCH_FAILURE);
+  auto chunkFetchFailure = dynamic_cast<ChunkFetchFailure*>(message.get());
+  // Fail the test cleanly on a type mismatch instead of dereferencing null.
+  ASSERT_NE(chunkFetchFailure, nullptr);
+  auto streamChunkSlice = chunkFetchFailure->streamChunkSlice();
+  EXPECT_EQ(streamChunkSlice.streamId, streamId);
+  EXPECT_EQ(streamChunkSlice.chunkIndex, chunkIndex);
+  EXPECT_EQ(streamChunkSlice.offset, offset);
+  EXPECT_EQ(streamChunkSlice.len, len);
+  EXPECT_EQ(chunkFetchFailure->errorMsg(), failureMsg);
+}
diff --git a/cpp/celeborn/protocol/CMakeLists.txt b/cpp/celeborn/protocol/CMakeLists.txt
index aa601c479..a1dc054d3 100644
--- a/cpp/celeborn/protocol/CMakeLists.txt
+++ b/cpp/celeborn/protocol/CMakeLists.txt
@@ -25,6 +25,7 @@ target_link_libraries(
         protocol
         memory
         proto
+        utils
         ${FOLLY_WITH_DEPENDENCIES}
         ${GLOG}
         ${GFLAGS_LIBRARIES}

Reply via email to