This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 39a40dd2a [CELEBORN-1845][CIP-14] Add MessageDispatcher to cppClient
39a40dd2a is described below

commit 39a40dd2a18f3da4b70c414f2a2aff5bdf35117a
Author: HolyLow <[email protected]>
AuthorDate: Wed Jan 22 20:07:22 2025 +0800

    [CELEBORN-1845][CIP-14] Add MessageDispatcher to cppClient
    
    ### What changes were proposed in this pull request?
    This PR adds MessageDispatcher class to cppClient.
    
    ### Why are the changes needed?
    MessageDispatcher is responsible for recording the connection between
    MessageFuture and MessagePromise, which is the basis of the asynchronous
    message-transfer mechanism.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and UTs.
    
    Closes #3077 from HolyLow/issue/celeborn-1845-add-message-dispatcher-to-cpp-client.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: mingji <[email protected]>
---
 cpp/celeborn/network/CMakeLists.txt                |   6 +-
 cpp/celeborn/network/Message.h                     |   4 +-
 cpp/celeborn/network/MessageDispatcher.cpp         | 257 +++++++++++++++++++++
 cpp/celeborn/network/MessageDispatcher.h           | 113 +++++++++
 cpp/celeborn/network/tests/CMakeLists.txt          |   3 +-
 .../network/tests/MessageDispatcherTest.cpp        | 200 ++++++++++++++++
 6 files changed, 579 insertions(+), 4 deletions(-)

diff --git a/cpp/celeborn/network/CMakeLists.txt b/cpp/celeborn/network/CMakeLists.txt
index 3a65828bd..1acf114bb 100644
--- a/cpp/celeborn/network/CMakeLists.txt
+++ b/cpp/celeborn/network/CMakeLists.txt
@@ -15,7 +15,8 @@
 add_library(
         network
         STATIC
-        Message.cpp)
+        Message.cpp
+        MessageDispatcher.cpp)
 
 target_include_directories(network PUBLIC ${CMAKE_BINARY_DIR})
 
@@ -25,6 +26,9 @@ target_link_libraries(
         proto
         utils
         protocol
+        ${WANGLE}
+        ${FIZZ}
+        ${LIBSODIUM_LIBRARY}
         ${FOLLY_WITH_DEPENDENCIES}
         ${GLOG}
         ${GFLAGS_LIBRARIES}
diff --git a/cpp/celeborn/network/Message.h b/cpp/celeborn/network/Message.h
index a4b269aad..ddb6d9979 100644
--- a/cpp/celeborn/network/Message.h
+++ b/cpp/celeborn/network/Message.h
@@ -178,7 +178,7 @@ class RpcFailure : public Message {
 class ChunkFetchSuccess : public Message {
  public:
   ChunkFetchSuccess(
-      protocol::StreamChunkSlice& streamChunkSlice,
+      const protocol::StreamChunkSlice& streamChunkSlice,
       std::unique_ptr<memory::ReadOnlyByteBuffer>&& body)
       : Message(CHUNK_FETCH_SUCCESS, std::move(body)),
         streamChunkSlice_(streamChunkSlice) {}
@@ -201,7 +201,7 @@ class ChunkFetchSuccess : public Message {
 class ChunkFetchFailure : public Message {
  public:
   ChunkFetchFailure(
-      protocol::StreamChunkSlice& streamChunkSlice,
+      const protocol::StreamChunkSlice& streamChunkSlice,
       std::string&& errorString)
       : Message(
             CHUNK_FETCH_FAILURE,
diff --git a/cpp/celeborn/network/MessageDispatcher.cpp b/cpp/celeborn/network/MessageDispatcher.cpp
new file mode 100644
index 000000000..30161483d
--- /dev/null
+++ b/cpp/celeborn/network/MessageDispatcher.cpp
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/network/MessageDispatcher.h"
+
+#include "celeborn/protocol/TransportMessage.h"
+
+namespace celeborn {
+namespace network {
+/**
+ * Dispatches an incoming response message to the promise registered for it.
+ *
+ * RPC_RESPONSE / RPC_FAILURE are matched by requestId against
+ * requestIdRegistry_; CHUNK_FETCH_SUCCESS / CHUNK_FETCH_FAILURE are matched
+ * by StreamChunkSlice against streamChunkSliceRegistry_. The matching entry
+ * is removed while the registry lock is held, and its promise is fulfilled
+ * only after the lock has been released. Responses with no pending entry
+ * (e.g. already interrupted or cleaned up) are logged and dropped.
+ * Failure messages fail the promise with a std::runtime_error carrying the
+ * error text reported in the failure message.
+ */
+void MessageDispatcher::read(Context*, std::unique_ptr<Message> toRecvMsg) {
+  // Removes the entry registered under `key` from `syncRegistry` and returns
+  // it; `found` reports whether such an entry existed.
+  auto takePending = [](auto& syncRegistry, const auto& key, bool& found) {
+    return syncRegistry.withLock([&](auto& registry) {
+      auto search = registry.find(key);
+      found = search != registry.end();
+      if (!found) {
+        return MsgPromiseHolder{};
+      }
+      MsgPromiseHolder holder = std::move(search->second);
+      registry.erase(search);
+      return holder;
+    });
+  };
+
+  switch (toRecvMsg->type()) {
+    case Message::RPC_RESPONSE: {
+      auto* response = static_cast<RpcResponse*>(toRecvMsg.get());
+      const auto requestId = response->requestId();
+      bool found = false;
+      auto holder = takePending(requestIdRegistry_, requestId, found);
+      if (!found) {
+        LOG(WARNING)
+            << "requestId " << requestId
+            << " not found when handling RPC_RESPONSE. Might be outdated already, ignored.";
+        return;
+      }
+      holder.msgPromise.setValue(std::move(toRecvMsg));
+      return;
+    }
+    case Message::RPC_FAILURE: {
+      auto* failure = static_cast<RpcFailure*>(toRecvMsg.get());
+      const auto requestId = failure->requestId();
+      bool found = false;
+      auto holder = takePending(requestIdRegistry_, requestId, found);
+      if (!found) {
+        LOG(WARNING)
+            << "requestId " << requestId
+            << " not found when handling RPC_FAILURE. Might be outdated already, ignored.";
+      }
+      // The failure is logged even when no request is pending for it.
+      const std::string errorMsg = fmt::format(
+          "Rpc failed, requestId: {} errorMsg: {}",
+          requestId,
+          failure->errorMsg());
+      LOG(ERROR) << errorMsg;
+      if (found) {
+        // Propagate the reported error text instead of a bare std::exception.
+        holder.msgPromise.setException(std::runtime_error(errorMsg));
+      }
+      return;
+    }
+    case Message::CHUNK_FETCH_SUCCESS: {
+      auto* success = static_cast<ChunkFetchSuccess*>(toRecvMsg.get());
+      // Copied because toRecvMsg is moved into the promise below.
+      const auto streamChunkSlice = success->streamChunkSlice();
+      bool found = false;
+      auto holder =
+          takePending(streamChunkSliceRegistry_, streamChunkSlice, found);
+      if (!found) {
+        LOG(WARNING)
+            << "streamChunkSlice " << streamChunkSlice.toString()
+            << " not found when handling CHUNK_FETCH_SUCCESS. Might be outdated already, ignored.";
+        return;
+      }
+      holder.msgPromise.setValue(std::move(toRecvMsg));
+      return;
+    }
+    case Message::CHUNK_FETCH_FAILURE: {
+      auto* failure = static_cast<ChunkFetchFailure*>(toRecvMsg.get());
+      const auto streamChunkSlice = failure->streamChunkSlice();
+      bool found = false;
+      auto holder =
+          takePending(streamChunkSliceRegistry_, streamChunkSlice, found);
+      if (!found) {
+        LOG(WARNING)
+            << "streamChunkSlice " << streamChunkSlice.toString()
+            << " not found when handling CHUNK_FETCH_FAILURE. Might be outdated already, ignored.";
+      }
+      const std::string errorMsg = fmt::format(
+          "fetchChunk failed, streamChunkSlice: {}, errorMsg: {}",
+          streamChunkSlice.toString(),
+          failure->errorMsg());
+      LOG(ERROR) << errorMsg;
+      if (found) {
+        holder.msgPromise.setException(std::runtime_error(errorMsg));
+      }
+      return;
+    }
+    default: {
+      LOG(ERROR) << "unsupported msg for dispatcher";
+    }
+  }
+}
+
+/**
+ * Sends an RPC_REQUEST and returns a future that read() fulfills with the
+ * matching RPC_RESPONSE (or fails on RPC_FAILURE), keyed by requestId. The
+ * future is also failed by cleanup() if the dispatcher closes first.
+ *
+ * closed_ is re-checked while holding the registry lock. close() sets it
+ * before cleanup() drains this registry under the same lock, so a request
+ * can no longer be registered after the drain and then wait forever.
+ * A requestId that is already pending is rejected instead of reusing its
+ * entry, whose future has already been handed out.
+ *
+ * Throws (via CELEBORN_CHECK) if the dispatcher is closed, toSendMsg is not
+ * an RPC_REQUEST, or its requestId is already pending.
+ */
+folly::Future<std::unique_ptr<Message>> MessageDispatcher::operator()(
+    std::unique_ptr<Message> toSendMsg) {
+  CELEBORN_CHECK(!closed_);
+  CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST);
+  RpcRequest* request = static_cast<RpcRequest*>(toSendMsg.get());
+  const auto requestId = request->requestId();
+  auto f = requestIdRegistry_.withLock(
+      [&](auto& registry) -> folly::Future<std::unique_ptr<Message>> {
+        CELEBORN_CHECK(!closed_);
+        auto [it, inserted] = registry.try_emplace(requestId);
+        CELEBORN_CHECK(inserted);
+        auto& holder = it->second;
+        holder.requestTime = std::chrono::system_clock::now();
+        holder.msgPromise.setInterruptHandler(
+            [requestId, this](const folly::exception_wrapper&) {
+              this->requestIdRegistry_.lock()->erase(requestId);
+              LOG(WARNING) << "rpc request interrupted, requestId: "
+                           << requestId;
+            });
+        return holder.msgPromise.getFuture();
+      });
+
+  this->pipeline_->write(std::move(toSendMsg));
+
+  // If the dispatcher was closed while writing, cleanup() has already failed
+  // the promise registered above.
+  CELEBORN_CHECK(!closed_);
+  return f;
+}
+
+/**
+ * Sends a fetch-chunk request and returns a future that read() fulfills
+ * with the CHUNK_FETCH_SUCCESS (or fails on CHUNK_FETCH_FAILURE) whose
+ * StreamChunkSlice equals `streamChunkSlice`. The future is also failed by
+ * cleanup() if the dispatcher closes first.
+ *
+ * closed_ is re-checked while holding the registry lock, because close()
+ * sets it before cleanup() drains this registry under the same lock; this
+ * prevents a fetch from being registered after the drain and never
+ * completing. A slice that already has a pending fetch is rejected rather
+ * than overwritten.
+ *
+ * @param streamChunkSlice key used to match the response to this request.
+ * @param toSendMsg the request to write; must be an RPC_REQUEST.
+ */
+folly::Future<std::unique_ptr<Message>>
+MessageDispatcher::sendFetchChunkRequest(
+    const protocol::StreamChunkSlice& streamChunkSlice,
+    std::unique_ptr<Message> toSendMsg) {
+  CELEBORN_CHECK(!closed_);
+  CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST);
+  auto f = streamChunkSliceRegistry_.withLock(
+      [&](auto& registry) -> folly::Future<std::unique_ptr<Message>> {
+        CELEBORN_CHECK(!closed_);
+        auto [it, inserted] = registry.try_emplace(streamChunkSlice);
+        // Reusing a pending entry would hand out its future a second time.
+        CELEBORN_CHECK(inserted);
+        auto& holder = it->second;
+        holder.requestTime = std::chrono::system_clock::now();
+        holder.msgPromise.setInterruptHandler(
+            [streamChunkSlice, this](const folly::exception_wrapper&) {
+              LOG(WARNING) << "fetchChunk request interrupted, "
+                              "streamChunkSlice: "
+                           << streamChunkSlice.toString();
+              this->streamChunkSliceRegistry_.lock()->erase(streamChunkSlice);
+            });
+        return holder.msgPromise.getFuture();
+      });
+  this->pipeline_->write(std::move(toSendMsg));
+  // If the dispatcher was closed while writing, cleanup() has already failed
+  // the promise registered above.
+  CELEBORN_CHECK(!closed_);
+  return f;
+}
+
+/**
+ * Writes an RPC_REQUEST to the pipeline without registering a promise, so
+ * no response is tracked for it. The result of the write is not observed.
+ */
+void MessageDispatcher::sendRpcRequestWithoutResponse(
+    std::unique_ptr<Message> toSendMsg) {
+  CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST);
+  pipeline_->write(std::move(toSendMsg));
+}
+
+/**
+ * Handles end-of-stream: forwards the event down the pipeline, then closes
+ * this dispatcher, which fails every request still pending.
+ */
+void MessageDispatcher::readEOF(Context* context) {
+  LOG(ERROR) << "readEOF, start to close client";
+  context->fireReadEOF();
+  this->close();
+}
+
+/**
+ * Handles an inbound error: logs it, forwards it down the pipeline, then
+ * closes this dispatcher so pending requests are failed.
+ */
+void MessageDispatcher::readException(
+    Context* context,
+    folly::exception_wrapper ew) {
+  LOG(ERROR) << "readException: " << ew.what() << " , start to close client";
+  context->fireReadException(std::move(ew));
+  this->close();
+}
+
+/**
+ * Forwards transport activation down the pipeline; the dispatcher itself
+ * has nothing to do when the transport becomes active.
+ */
+void MessageDispatcher::transportActive(Context* context) {
+  context->fireTransportActive();
+}
+
+/**
+ * Forwards transport deactivation down the pipeline, then closes this
+ * dispatcher so pending requests are failed rather than left waiting.
+ */
+void MessageDispatcher::transportInactive(Context* context) {
+  LOG(ERROR) << "transportInactive, start to close client";
+  context->fireTransportInactive();
+  this->close();
+}
+
+/**
+ * Handles an outbound error: logs it and forwards it down the pipeline,
+ * closes this dispatcher, and returns the future produced by forwarding.
+ */
+folly::Future<folly::Unit> MessageDispatcher::writeException(
+    Context* context,
+    folly::exception_wrapper ew) {
+  LOG(ERROR) << "writeException: " << ew.what() << " , start to close client";
+  auto forwarded = context->fireWriteException(std::move(ew));
+  this->close();
+  return forwarded;
+}
+
+/**
+ * Closes the dispatcher. Whichever call (either overload, on any thread)
+ * first flips closed_ to true runs cleanup(), failing every pending
+ * request; every call then delegates to the base class close.
+ */
+folly::Future<folly::Unit> MessageDispatcher::close() {
+  // exchange() makes check-and-set a single atomic step, so concurrent
+  // closers cannot both run cleanup().
+  if (!closed_.exchange(true)) {
+    cleanup();
+  }
+  return ClientDispatcherBase::close();
+}
+
+folly::Future<folly::Unit> MessageDispatcher::close(Context* ctx) {
+  if (!closed_.exchange(true)) {
+    cleanup();
+  }
+  return ClientDispatcherBase::close(ctx);
+}
+
+/**
+ * Fails every pending request with a std::runtime_error and leaves both
+ * registries empty.
+ *
+ * Each registry is drained while its lock is held, and the promises are
+ * failed only after the lock is released. Continuations attached to the
+ * returned futures may run inline on this thread when a promise is failed.
+ * The interrupt handlers installed by operator() and sendFetchChunkRequest()
+ * lock these same non-recursive std::mutex registries, so failing promises
+ * under the lock could deadlock.
+ */
+void MessageDispatcher::cleanup() {
+  LOG(WARNING) << "Cleaning up client!";
+
+  // Moves all entries out of a registry under its lock; the moved-from map
+  // is cleared so the registry is left in a known-empty state.
+  auto drain = [](auto& syncRegistry) {
+    return syncRegistry.withLock([](auto& registry) {
+      auto drained = std::move(registry);
+      registry.clear();
+      return drained;
+    });
+  };
+
+  auto pendingRequests = drain(requestIdRegistry_);
+  for (auto& [requestId, promiseHolder] : pendingRequests) {
+    auto errorMsg =
+        fmt::format("Client closed, cancel ongoing requestId {}", requestId);
+    LOG(WARNING) << errorMsg;
+    promiseHolder.msgPromise.setException(std::runtime_error(errorMsg));
+  }
+
+  auto pendingFetches = drain(streamChunkSliceRegistry_);
+  for (auto& [streamChunkSlice, promiseHolder] : pendingFetches) {
+    auto errorMsg = fmt::format(
+        "Client closed, cancel ongoing streamChunkSlice {}",
+        streamChunkSlice.toString());
+    LOG(WARNING) << errorMsg;
+    promiseHolder.msgPromise.setException(std::runtime_error(errorMsg));
+  }
+}
+} // namespace network
+} // namespace celeborn
diff --git a/cpp/celeborn/network/MessageDispatcher.h b/cpp/celeborn/network/MessageDispatcher.h
new file mode 100644
index 000000000..24a233a7c
--- /dev/null
+++ b/cpp/celeborn/network/MessageDispatcher.h
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <wangle/bootstrap/ClientBootstrap.h>
+#include <wangle/channel/AsyncSocketHandler.h>
+#include <wangle/channel/EventBaseHandler.h>
+#include <wangle/service/ClientDispatcher.h>
+
+#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/network/Message.h"
+#include "celeborn/protocol/ControlMessages.h"
+#include "celeborn/utils/CelebornUtils.h"
+
+namespace celeborn {
+namespace network {
+using SerializePipeline =
+    wangle::Pipeline<folly::IOBufQueue&, std::unique_ptr<Message>>;
+
+/**
+ * Client-side dispatcher that pairs outgoing requests with the futures
+ * returned to callers and completes them when responses arrive.
+ *
+ *  - sendRpcRequest() / operator() record a promise under the request's
+ *    requestId, and sendFetchChunkRequest() records one under the
+ *    StreamChunkSlice being fetched. Both then write the request to the
+ *    pipeline and return the promise's future.
+ *  - sendRpcRequestWithoutResponse() only writes; nothing is recorded.
+ *  - read() receives response messages, looks up the recorded promise by
+ *    requestId or StreamChunkSlice, and fulfills or fails it.
+ *  - readEOF(), readException(), transportInactive() and writeException()
+ *    close the dispatcher; closing fails every request still pending.
+ */
+class MessageDispatcher : public wangle::ClientDispatcherBase<
+                              SerializePipeline,
+                              std::unique_ptr<Message>,
+                              std::unique_ptr<Message>> {
+ public:
+  // Dispatches a received response message to its pending request.
+  void read(Context*, std::unique_ptr<Message> toRecvMsg) override;
+
+  // Sends an RPC_REQUEST; forwards to operator().
+  virtual folly::Future<std::unique_ptr<Message>> sendRpcRequest(
+      std::unique_ptr<Message> toSendMsg) {
+    return (*this)(std::move(toSendMsg));
+  }
+
+  // Sends a request whose response is matched by `streamChunkSlice`.
+  virtual folly::Future<std::unique_ptr<Message>> sendFetchChunkRequest(
+      const protocol::StreamChunkSlice& streamChunkSlice,
+      std::unique_ptr<Message> toSendMsg);
+
+  // Writes an RPC_REQUEST without tracking any response.
+  virtual void sendRpcRequestWithoutResponse(
+      std::unique_ptr<Message> toSendMsg);
+
+  // Overrides the base dispatcher's call operator: sends an RPC_REQUEST
+  // whose response is matched by requestId.
+  folly::Future<std::unique_ptr<Message>> operator()(
+      std::unique_ptr<Message> toSendMsg) override;
+
+  // Pipeline events; every one except transportActive() closes the
+  // dispatcher after forwarding the event.
+  void readEOF(Context* ctx) override;
+
+  void readException(Context* ctx, folly::exception_wrapper e) override;
+
+  void transportActive(Context* ctx) override;
+
+  void transportInactive(Context* ctx) override;
+
+  folly::Future<folly::Unit> writeException(
+      Context* ctx,
+      folly::exception_wrapper e) override;
+
+  // Closes the dispatcher; pending requests are failed on closing.
+  folly::Future<folly::Unit> close() override;
+
+  folly::Future<folly::Unit> close(Context* ctx) override;
+
+  // Returns false once the dispatcher has been closed.
+  bool isAvailable() override {
+    return !closed_;
+  }
+
+ private:
+  // Fails every pending promise and empties both registries.
+  void cleanup();
+
+  using MsgPromise = folly::Promise<std::unique_ptr<Message>>;
+
+  // One pending request: the promise behind the caller's future and the
+  // time at which the request was registered.
+  struct MsgPromiseHolder {
+    MsgPromise msgPromise;
+    std::chrono::time_point<std::chrono::system_clock> requestTime;
+  };
+
+  // Pending RPC requests, keyed by requestId.
+  folly::Synchronized<std::unordered_map<long, MsgPromiseHolder>, std::mutex>
+      requestIdRegistry_;
+  // Pending fetch-chunk requests, keyed by the fetched StreamChunkSlice.
+  folly::Synchronized<
+      std::unordered_map<
+          protocol::StreamChunkSlice,
+          MsgPromiseHolder,
+          protocol::StreamChunkSlice::Hasher>,
+      std::mutex>
+      streamChunkSliceRegistry_;
+  // Becomes true when the dispatcher is closed.
+  std::atomic<bool> closed_{false};
+};
+} // namespace network
+} // namespace celeborn
diff --git a/cpp/celeborn/network/tests/CMakeLists.txt b/cpp/celeborn/network/tests/CMakeLists.txt
index db38fb484..fdd2c874e 100644
--- a/cpp/celeborn/network/tests/CMakeLists.txt
+++ b/cpp/celeborn/network/tests/CMakeLists.txt
@@ -16,7 +16,8 @@
 add_executable(
         celeborn_network_test
         FrameDecoderTest.cpp
-        MessageTest.cpp)
+        MessageTest.cpp
+        MessageDispatcherTest.cpp)
 
 add_test(NAME celeborn_network_test COMMAND celeborn_network_test)
 
diff --git a/cpp/celeborn/network/tests/MessageDispatcherTest.cpp b/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
new file mode 100644
index 000000000..8baf9f8d7
--- /dev/null
+++ b/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/network/FrameDecoder.h"
+#include "celeborn/network/MessageDispatcher.h"
+
+using namespace celeborn;
+using namespace celeborn::network;
+
+namespace {
+// Handler that captures each message written through it into a caller-owned
+// slot instead of passing it further down the pipeline, so tests can inspect
+// exactly what the dispatcher sent.
+class MockHandler : public wangle::Handler<
+                        std::unique_ptr<folly::IOBuf>,
+                        std::unique_ptr<Message>,
+                        std::unique_ptr<Message>,
+                        std::unique_ptr<folly::IOBuf>> {
+ public:
+  // `writtenMsg` receives every written message and must outlive this
+  // handler (and any pipeline it is moved into).
+  explicit MockHandler(std::unique_ptr<Message>& writtenMsg)
+      : writtenMsg_(writtenMsg) {}
+
+  void read(Context* /*ctx*/, std::unique_ptr<folly::IOBuf> /*msg*/)
+      override {}
+
+  folly::Future<folly::Unit> write(
+      Context* /*ctx*/,
+      std::unique_ptr<Message> msg) override {
+    writtenMsg_ = std::move(msg);
+    return folly::makeFuture();
+  }
+
+ private:
+  std::unique_ptr<Message>& writtenMsg_;
+};
+
+// Builds a finalized SerializePipeline holding a FrameDecoder followed by
+// `mockHandler`.
+SerializePipeline::Ptr createMockedPipeline(MockHandler&& mockHandler) {
+  auto pipeline = SerializePipeline::create();
+  // FrameDecoder is only here to form a complete pipeline that passes type
+  // checking; these tests never feed bytes through it.
+  pipeline->addBack(FrameDecoder());
+  pipeline->addBack(std::move(mockHandler));
+  pipeline->finalize();
+  return pipeline;
+}
+
+// Copies `content` into a newly allocated read-only buffer.
+std::unique_ptr<memory::ReadOnlyByteBuffer> toReadOnlyByteBuffer(
+    const std::string& content) {
+  auto buffer = memory::ByteBuffer::createWriteOnly(content.size());
+  buffer->writeFromString(content);
+  return memory::ByteBuffer::toReadOnly(std::move(buffer));
+}
+
+} // namespace
+
+// An RPC_RESPONSE carrying the requestId of a pending RPC request fulfills
+// that request's future with the response message.
+TEST(MessageDispatcherTest, sendRpcRequestAndReceiveResponse) {
+  std::unique_ptr<Message> sentMsg;
+  MockHandler mockHandler(sentMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const long requestId = 1001;
+  const std::string requestBody = "test-request-body";
+  auto rpcRequest = std::make_unique<RpcRequest>(
+      requestId, toReadOnlyByteBuffer(requestBody));
+  auto future = dispatcher->sendRpcRequest(std::move(rpcRequest));
+
+  EXPECT_FALSE(future.isReady());
+  // ASSERT before dereferencing: a failure then aborts only this test
+  // instead of crashing the whole test binary.
+  ASSERT_TRUE(sentMsg != nullptr);
+  EXPECT_EQ(sentMsg->type(), Message::RPC_REQUEST);
+  auto* sentRpcRequest = dynamic_cast<RpcRequest*>(sentMsg.get());
+  ASSERT_TRUE(sentRpcRequest != nullptr);
+  EXPECT_EQ(sentRpcRequest->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sentRpcRequest->body()->readToString(requestBody.size()), requestBody);
+
+  const std::string responseBody = "test-response-body";
+  auto rpcResponse = std::make_unique<RpcResponse>(
+      requestId, toReadOnlyByteBuffer(responseBody));
+  dispatcher->read(nullptr, std::move(rpcResponse));
+
+  // ASSERT, not EXPECT: get() on an unfulfilled future would block forever.
+  ASSERT_TRUE(future.isReady());
+  auto receivedMsg = std::move(future).get();
+  ASSERT_TRUE(receivedMsg != nullptr);
+  EXPECT_EQ(receivedMsg->type(), Message::RPC_RESPONSE);
+  auto* receivedRpcResponse = dynamic_cast<RpcResponse*>(receivedMsg.get());
+  ASSERT_TRUE(receivedRpcResponse != nullptr);
+  EXPECT_EQ(receivedRpcResponse->body()->remainingSize(), responseBody.size());
+  EXPECT_EQ(
+      receivedRpcResponse->body()->readToString(responseBody.size()),
+      responseBody);
+}
+
+// An RPC_FAILURE carrying the requestId of a pending RPC request fails that
+// request's future.
+TEST(MessageDispatcherTest, sendRpcRequestAndReceiveFailure) {
+  std::unique_ptr<Message> sentMsg;
+  MockHandler mockHandler(sentMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const long requestId = 1001;
+  const std::string requestBody = "test-request-body";
+  auto rpcRequest = std::make_unique<RpcRequest>(
+      requestId, toReadOnlyByteBuffer(requestBody));
+  auto future = dispatcher->sendRpcRequest(std::move(rpcRequest));
+
+  EXPECT_FALSE(future.isReady());
+  ASSERT_TRUE(sentMsg != nullptr);
+  EXPECT_EQ(sentMsg->type(), Message::RPC_REQUEST);
+  auto* sentRpcRequest = dynamic_cast<RpcRequest*>(sentMsg.get());
+  ASSERT_TRUE(sentRpcRequest != nullptr);
+  EXPECT_EQ(sentRpcRequest->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sentRpcRequest->body()->readToString(requestBody.size()), requestBody);
+
+  std::string errorMsg = "test-error-msg";
+  auto rpcFailure =
+      std::make_unique<RpcFailure>(requestId, std::move(errorMsg));
+  dispatcher->read(nullptr, std::move(rpcFailure));
+
+  EXPECT_TRUE(future.hasException());
+}
+
+// A CHUNK_FETCH_SUCCESS for the StreamChunkSlice of a pending fetch fulfills
+// that fetch's future with the success message.
+TEST(MessageDispatcherTest, sendFetchChunkRequestAndReceiveSuccess) {
+  std::unique_ptr<Message> sentMsg;
+  MockHandler mockHandler(sentMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const protocol::StreamChunkSlice streamChunkSlice{1001, 1002, 1003, 1004};
+  const long requestId = 1001;
+  const std::string requestBody = "test-request-body";
+  auto rpcRequest = std::make_unique<RpcRequest>(
+      requestId, toReadOnlyByteBuffer(requestBody));
+  auto future = dispatcher->sendFetchChunkRequest(
+      streamChunkSlice, std::move(rpcRequest));
+
+  EXPECT_FALSE(future.isReady());
+  ASSERT_TRUE(sentMsg != nullptr);
+  EXPECT_EQ(sentMsg->type(), Message::RPC_REQUEST);
+  auto* sentRpcRequest = dynamic_cast<RpcRequest*>(sentMsg.get());
+  ASSERT_TRUE(sentRpcRequest != nullptr);
+  EXPECT_EQ(sentRpcRequest->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sentRpcRequest->body()->readToString(requestBody.size()), requestBody);
+
+  const std::string chunkFetchSuccessBody = "test-chunk-fetch-success-body";
+  auto chunkFetchSuccess = std::make_unique<ChunkFetchSuccess>(
+      streamChunkSlice, toReadOnlyByteBuffer(chunkFetchSuccessBody));
+  dispatcher->read(nullptr, std::move(chunkFetchSuccess));
+
+  // ASSERT, not EXPECT: get() on an unfulfilled future would block forever.
+  ASSERT_TRUE(future.isReady());
+  auto receivedMsg = std::move(future).get();
+  ASSERT_TRUE(receivedMsg != nullptr);
+  EXPECT_EQ(receivedMsg->type(), Message::CHUNK_FETCH_SUCCESS);
+  auto* receivedChunkFetchSuccess =
+      dynamic_cast<ChunkFetchSuccess*>(receivedMsg.get());
+  ASSERT_TRUE(receivedChunkFetchSuccess != nullptr);
+  EXPECT_EQ(
+      receivedChunkFetchSuccess->body()->remainingSize(),
+      chunkFetchSuccessBody.size());
+  EXPECT_EQ(
+      receivedChunkFetchSuccess->body()->readToString(
+          chunkFetchSuccessBody.size()),
+      chunkFetchSuccessBody);
+}
+
+// A CHUNK_FETCH_FAILURE for the StreamChunkSlice of a pending fetch fails
+// that fetch's future.
+TEST(MessageDispatcherTest, sendFetchChunkRequestAndReceiveFailure) {
+  std::unique_ptr<Message> sentMsg;
+  MockHandler mockHandler(sentMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const protocol::StreamChunkSlice streamChunkSlice{1001, 1002, 1003, 1004};
+  const long requestId = 1001;
+  const std::string requestBody = "test-request-body";
+  auto rpcRequest = std::make_unique<RpcRequest>(
+      requestId, toReadOnlyByteBuffer(requestBody));
+  auto future = dispatcher->sendFetchChunkRequest(
+      streamChunkSlice, std::move(rpcRequest));
+
+  EXPECT_FALSE(future.isReady());
+  ASSERT_TRUE(sentMsg != nullptr);
+  EXPECT_EQ(sentMsg->type(), Message::RPC_REQUEST);
+  auto* sentRpcRequest = dynamic_cast<RpcRequest*>(sentMsg.get());
+  ASSERT_TRUE(sentRpcRequest != nullptr);
+  EXPECT_EQ(sentRpcRequest->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sentRpcRequest->body()->readToString(requestBody.size()), requestBody);
+
+  std::string errorMsg = "test-error-msg";
+  auto chunkFetchFailure = std::make_unique<ChunkFetchFailure>(
+      streamChunkSlice, std::move(errorMsg));
+  dispatcher->read(nullptr, std::move(chunkFetchFailure));
+
+  EXPECT_TRUE(future.hasException());
+}
+
+// A response whose requestId has no pending request is dropped and does not
+// disturb requests that are still pending.
+TEST(MessageDispatcherTest, receiveResponseForUnknownRequestIsIgnored) {
+  std::unique_ptr<Message> sentMsg;
+  MockHandler mockHandler(sentMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const long requestId = 1001;
+  auto future = dispatcher->sendRpcRequest(std::make_unique<RpcRequest>(
+      requestId, toReadOnlyByteBuffer("test-request-body")));
+
+  const long unknownRequestId = 2002;
+  dispatcher->read(
+      nullptr,
+      std::make_unique<RpcResponse>(
+          unknownRequestId, toReadOnlyByteBuffer("test-response-body")));
+
+  EXPECT_FALSE(future.isReady());
+}
+
+// Closing the dispatcher fails every pending request, makes the dispatcher
+// unavailable, and rejects further requests.
+TEST(MessageDispatcherTest, closeFailsPendingRequests) {
+  std::unique_ptr<Message> sentMsg;
+  MockHandler mockHandler(sentMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  auto rpcFuture = dispatcher->sendRpcRequest(std::make_unique<RpcRequest>(
+      1001, toReadOnlyByteBuffer("test-request-body")));
+  const protocol::StreamChunkSlice streamChunkSlice{1001, 1002, 1003, 1004};
+  auto fetchFuture = dispatcher->sendFetchChunkRequest(
+      streamChunkSlice,
+      std::make_unique<RpcRequest>(
+          1002, toReadOnlyByteBuffer("test-fetch-body")));
+  EXPECT_TRUE(dispatcher->isAvailable());
+
+  dispatcher->close();
+
+  EXPECT_FALSE(dispatcher->isAvailable());
+  EXPECT_TRUE(rpcFuture.hasException());
+  EXPECT_TRUE(fetchFuture.hasException());
+  EXPECT_ANY_THROW(dispatcher->sendRpcRequest(std::make_unique<RpcRequest>(
+      1003, toReadOnlyByteBuffer("test-request-body"))));
+}

Reply via email to