This is an automated email from the ASF dual-hosted git repository.
ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 39a40dd2a [CELEBORN-1845][CIP-14] Add MessageDispatcher to cppClient
39a40dd2a is described below
commit 39a40dd2a18f3da4b70c414f2a2aff5bdf35117a
Author: HolyLow <[email protected]>
AuthorDate: Wed Jan 22 20:07:22 2025 +0800
[CELEBORN-1845][CIP-14] Add MessageDispatcher to cppClient
### What changes were proposed in this pull request?
This PR adds the MessageDispatcher class to cppClient.
### Why are the changes needed?
MessageDispatcher is responsible for recording the association between each
MessageFuture and the MessagePromise that completes it, which is the basis of
the asynchronous message transfer mechanism.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation and UTs.
Closes #3077 from
HolyLow/issue/celeborn-1845-add-message-dispatcher-to-cpp-client.
Authored-by: HolyLow <[email protected]>
Signed-off-by: mingji <[email protected]>
---
cpp/celeborn/network/CMakeLists.txt | 6 +-
cpp/celeborn/network/Message.h | 4 +-
cpp/celeborn/network/MessageDispatcher.cpp | 257 +++++++++++++++++++++
cpp/celeborn/network/MessageDispatcher.h | 113 +++++++++
cpp/celeborn/network/tests/CMakeLists.txt | 3 +-
.../network/tests/MessageDispatcherTest.cpp | 200 ++++++++++++++++
6 files changed, 579 insertions(+), 4 deletions(-)
diff --git a/cpp/celeborn/network/CMakeLists.txt
b/cpp/celeborn/network/CMakeLists.txt
index 3a65828bd..1acf114bb 100644
--- a/cpp/celeborn/network/CMakeLists.txt
+++ b/cpp/celeborn/network/CMakeLists.txt
@@ -15,7 +15,8 @@
add_library(
network
STATIC
- Message.cpp)
+ Message.cpp
+ MessageDispatcher.cpp)
target_include_directories(network PUBLIC ${CMAKE_BINARY_DIR})
@@ -25,6 +26,9 @@ target_link_libraries(
proto
utils
protocol
+ ${WANGLE}
+ ${FIZZ}
+ ${LIBSODIUM_LIBRARY}
${FOLLY_WITH_DEPENDENCIES}
${GLOG}
${GFLAGS_LIBRARIES}
diff --git a/cpp/celeborn/network/Message.h b/cpp/celeborn/network/Message.h
index a4b269aad..ddb6d9979 100644
--- a/cpp/celeborn/network/Message.h
+++ b/cpp/celeborn/network/Message.h
@@ -178,7 +178,7 @@ class RpcFailure : public Message {
class ChunkFetchSuccess : public Message {
public:
ChunkFetchSuccess(
- protocol::StreamChunkSlice& streamChunkSlice,
+ const protocol::StreamChunkSlice& streamChunkSlice,
std::unique_ptr<memory::ReadOnlyByteBuffer>&& body)
: Message(CHUNK_FETCH_SUCCESS, std::move(body)),
streamChunkSlice_(streamChunkSlice) {}
@@ -201,7 +201,7 @@ class ChunkFetchSuccess : public Message {
class ChunkFetchFailure : public Message {
public:
ChunkFetchFailure(
- protocol::StreamChunkSlice& streamChunkSlice,
+ const protocol::StreamChunkSlice& streamChunkSlice,
std::string&& errorString)
: Message(
CHUNK_FETCH_FAILURE,
diff --git a/cpp/celeborn/network/MessageDispatcher.cpp
b/cpp/celeborn/network/MessageDispatcher.cpp
new file mode 100644
index 000000000..30161483d
--- /dev/null
+++ b/cpp/celeborn/network/MessageDispatcher.cpp
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/network/MessageDispatcher.h"
+
+#include <optional>
+#include <stdexcept>
+
+#include "celeborn/protocol/TransportMessage.h"
+
+namespace celeborn {
+namespace network {
+/**
+ * Dispatches an inbound message to the promise registered for it.
+ *
+ * RPC_RESPONSE / RPC_FAILURE are matched by requestId, CHUNK_FETCH_SUCCESS /
+ * CHUNK_FETCH_FAILURE by StreamChunkSlice. A matched promise is removed from
+ * its registry and then fulfilled with the message (success) or failed with a
+ * std::runtime_error carrying the peer's error text (failure). Unmatched
+ * messages are logged and dropped; unsupported types are logged.
+ */
+void MessageDispatcher::read(Context*, std::unique_ptr<Message> toRecvMsg) {
+  // Removes the holder registered under `key` and returns it, or
+  // std::nullopt when nothing is registered (e.g. the request was already
+  // interrupted or cleaned up). Only the lookup/erase happens under the
+  // registry lock; the caller fulfills the promise after the lock has been
+  // released, because fulfilling may run continuations inline.
+  auto takeHolder = [](auto& syncRegistry,
+                       const auto& key) -> std::optional<MsgPromiseHolder> {
+    return syncRegistry.withLock(
+        [&](auto& registry) -> std::optional<MsgPromiseHolder> {
+          auto search = registry.find(key);
+          if (search == registry.end()) {
+            return std::nullopt;
+          }
+          std::optional<MsgPromiseHolder> holder{std::move(search->second)};
+          registry.erase(search);
+          return holder;
+        });
+  };
+
+  switch (toRecvMsg->type()) {
+    case Message::RPC_RESPONSE: {
+      // static_cast (not reinterpret_cast): a base-to-derived downcast whose
+      // validity is established by the type() switch above.
+      auto* response = static_cast<RpcResponse*>(toRecvMsg.get());
+      auto holder = takeHolder(requestIdRegistry_, response->requestId());
+      if (!holder) {
+        LOG(WARNING) << "requestId " << response->requestId()
+                     << " not found when handling RPC_RESPONSE. Might be "
+                        "outdated already, ignored.";
+        return;
+      }
+      holder->msgPromise.setValue(std::move(toRecvMsg));
+      return;
+    }
+    case Message::RPC_FAILURE: {
+      auto* failure = static_cast<RpcFailure*>(toRecvMsg.get());
+      auto holder = takeHolder(requestIdRegistry_, failure->requestId());
+      if (!holder) {
+        LOG(WARNING) << "requestId " << failure->requestId()
+                     << " not found when handling RPC_FAILURE. Might be "
+                        "outdated already, ignored.";
+      }
+      std::string errorMsg = fmt::format(
+          "Rpc failed, requestId: {} errorMsg: {}",
+          failure->requestId(),
+          failure->errorMsg());
+      LOG(ERROR) << errorMsg;
+      if (holder) {
+        // Hand the peer's error text to the caller instead of a bare,
+        // message-less std::exception.
+        holder->msgPromise.setException(std::runtime_error(errorMsg));
+      }
+      return;
+    }
+    case Message::CHUNK_FETCH_SUCCESS: {
+      auto* success = static_cast<ChunkFetchSuccess*>(toRecvMsg.get());
+      const auto streamChunkSlice = success->streamChunkSlice();
+      auto holder = takeHolder(streamChunkSliceRegistry_, streamChunkSlice);
+      if (!holder) {
+        LOG(WARNING) << "streamChunkSlice " << streamChunkSlice.toString()
+                     << " not found when handling CHUNK_FETCH_SUCCESS. Might "
+                        "be outdated already, ignored.";
+        return;
+      }
+      holder->msgPromise.setValue(std::move(toRecvMsg));
+      return;
+    }
+    case Message::CHUNK_FETCH_FAILURE: {
+      auto* failure = static_cast<ChunkFetchFailure*>(toRecvMsg.get());
+      const auto streamChunkSlice = failure->streamChunkSlice();
+      auto holder = takeHolder(streamChunkSliceRegistry_, streamChunkSlice);
+      if (!holder) {
+        LOG(WARNING) << "streamChunkSlice " << streamChunkSlice.toString()
+                     << " not found when handling CHUNK_FETCH_FAILURE. Might "
+                        "be outdated already, ignored.";
+      }
+      std::string errorMsg = fmt::format(
+          "fetchChunk failed, streamChunkSlice: {}, errorMsg: {}",
+          streamChunkSlice.toString(),
+          failure->errorMsg());
+      LOG(ERROR) << errorMsg;
+      if (holder) {
+        holder->msgPromise.setException(std::runtime_error(errorMsg));
+      }
+      return;
+    }
+    default: {
+      LOG(ERROR) << "unsupported msg for dispatcher";
+    }
+  }
+}
+
+/**
+ * Sends an RPC_REQUEST and returns a future completed by read() when the
+ * matching RPC_RESPONSE / RPC_FAILURE (same requestId) arrives, or failed by
+ * cleanup() if the dispatcher is closed first. Interrupting the returned
+ * future unregisters the request.
+ *
+ * Throws (via CELEBORN_CHECK) if the dispatcher is closed, the message is not
+ * an RPC_REQUEST, or a request with the same id is already in flight.
+ */
+folly::Future<std::unique_ptr<Message>> MessageDispatcher::operator()(
+    std::unique_ptr<Message> toSendMsg) {
+  CELEBORN_CHECK(!closed_);
+  CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST);
+  auto* request = static_cast<RpcRequest*>(toSendMsg.get());
+  const auto requestId = request->requestId();
+  auto f = requestIdRegistry_.withLock(
+      [&](auto& registry) -> folly::Future<std::unique_ptr<Message>> {
+        // Re-check under the registry lock: close() sets closed_ before
+        // cleanup() takes this same lock to fail pending promises, so a
+        // request registered after that would never be completed.
+        CELEBORN_CHECK(!closed_);
+        // A second in-flight request with the same id would otherwise reuse
+        // (and clobber) the first request's promise.
+        auto [it, inserted] = registry.try_emplace(requestId);
+        CELEBORN_CHECK(inserted);
+        auto& holder = it->second;
+        holder.requestTime = std::chrono::system_clock::now();
+        auto& p = holder.msgPromise;
+        p.setInterruptHandler(
+            [requestId, this](const folly::exception_wrapper&) {
+              this->requestIdRegistry_.lock()->erase(requestId);
+              LOG(WARNING) << "rpc request interrupted, requestId: "
+                           << requestId;
+            });
+        return p.getFuture();
+      });
+
+  this->pipeline_->write(std::move(toSendMsg));
+
+  CELEBORN_CHECK(!closed_);
+  return f;
+}
+
+/**
+ * Sends a fetch-chunk request (carried as an RPC_REQUEST) and returns a
+ * future completed by read() when the CHUNK_FETCH_SUCCESS /
+ * CHUNK_FETCH_FAILURE for `streamChunkSlice` arrives, or failed by cleanup()
+ * if the dispatcher is closed first. Interrupting the returned future
+ * unregisters the request.
+ *
+ * Throws (via CELEBORN_CHECK) if the dispatcher is closed, the message is not
+ * an RPC_REQUEST, or a fetch for the same slice is already in flight.
+ */
+folly::Future<std::unique_ptr<Message>>
+MessageDispatcher::sendFetchChunkRequest(
+    const protocol::StreamChunkSlice& streamChunkSlice,
+    std::unique_ptr<Message> toSendMsg) {
+  CELEBORN_CHECK(!closed_);
+  CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST);
+  auto f = streamChunkSliceRegistry_.withLock(
+      [&](auto& registry) -> folly::Future<std::unique_ptr<Message>> {
+        // Re-check under the registry lock so a fetch cannot be registered
+        // after cleanup() has already failed and cleared this registry.
+        CELEBORN_CHECK(!closed_);
+        // Reusing an in-flight slice's promise would clobber its handler and
+        // fail on the second getFuture().
+        auto [it, inserted] = registry.try_emplace(streamChunkSlice);
+        CELEBORN_CHECK(inserted);
+        auto& holder = it->second;
+        holder.requestTime = std::chrono::system_clock::now();
+        auto& p = holder.msgPromise;
+        p.setInterruptHandler(
+            [streamChunkSlice, this](const folly::exception_wrapper&) {
+              LOG(WARNING) << "fetchChunk request interrupted, "
+                              "streamChunkSlice: "
+                           << streamChunkSlice.toString();
+              this->streamChunkSliceRegistry_.lock()->erase(streamChunkSlice);
+            });
+        return p.getFuture();
+      });
+  this->pipeline_->write(std::move(toSendMsg));
+  CELEBORN_CHECK(!closed_);
+  return f;
+}
+
+/**
+ * Fire-and-forget send: writes the RPC_REQUEST to the pipeline without
+ * registering a promise, so any reply would be reported as "not found" by
+ * read(). NOTE(review): unlike the other send paths this does not check
+ * closed_ — confirm that is intentional.
+ */
+void MessageDispatcher::sendRpcRequestWithoutResponse(
+    std::unique_ptr<Message> toSendMsg) {
+  CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST);
+  pipeline_->write(std::move(toSendMsg));
+}
+
+/**
+ * Inbound EOF: log it, pass the event on through the pipeline context, then
+ * close this dispatcher so every pending request is failed.
+ */
+void MessageDispatcher::readEOF(Context* ctx) {
+  LOG(ERROR) << "readEOF, start to close client";
+  ctx->fireReadEOF(); // propagate before tearing down our own state
+  close();
+}
+
+/**
+ * Inbound exception: log its description, forward the exception through the
+ * pipeline context, then close this dispatcher.
+ */
+void MessageDispatcher::readException(
+    Context* ctx,
+    folly::exception_wrapper e) {
+  // Capture the description before `e` is moved into the context.
+  const auto description = e.what();
+  LOG(ERROR) << "readException: " << description << " , start to close client";
+  ctx->fireReadException(std::move(e));
+  close();
+}
+
+/**
+ * Transport became active: the dispatcher has nothing to set up, so the
+ * event is simply passed on through the pipeline context.
+ */
+void MessageDispatcher::transportActive(Context* ctx) {
+  ctx->fireTransportActive();
+}
+
+/**
+ * Transport became inactive: log it, pass the event on through the pipeline
+ * context, then close this dispatcher so pending requests are failed.
+ */
+void MessageDispatcher::transportInactive(Context* ctx) {
+  LOG(ERROR) << "transportInactive, start to close client";
+  ctx->fireTransportInactive(); // propagate first, then tear down
+  close();
+}
+
+/**
+ * Outbound exception: log it, forward it through the pipeline context, close
+ * this dispatcher, and return the future produced by the forwarding call.
+ */
+folly::Future<folly::Unit> MessageDispatcher::writeException(
+    Context* ctx,
+    folly::exception_wrapper e) {
+  LOG(ERROR) << "writeException: " << e.what() << " , start to close client";
+  // Keep the forwarded future so it can be returned after closing.
+  folly::Future<folly::Unit> propagated = ctx->fireWriteException(std::move(e));
+  close();
+  return propagated;
+}
+
+/**
+ * Marks the dispatcher closed, fails all pending requests (first call only),
+ * then delegates to ClientDispatcherBase::close().
+ */
+folly::Future<folly::Unit> MessageDispatcher::close() {
+  // exchange() makes the test-and-set atomic, so concurrent close paths
+  // (e.g. a user close racing transportInactive) run cleanup() exactly once.
+  if (!closed_.exchange(true)) {
+    cleanup();
+  }
+  return ClientDispatcherBase::close();
+}
+
+/**
+ * Pipeline-initiated close: marks the dispatcher closed, fails all pending
+ * requests (first call only), then delegates to
+ * ClientDispatcherBase::close(ctx).
+ */
+folly::Future<folly::Unit> MessageDispatcher::close(Context* ctx) {
+  // Atomic test-and-set: only the first closer runs cleanup().
+  if (!closed_.exchange(true)) {
+    cleanup();
+  }
+
+  return ClientDispatcherBase::close(ctx);
+}
+
+/**
+ * Fails every pending RPC and chunk-fetch promise with a std::runtime_error
+ * and empties both registries.
+ */
+void MessageDispatcher::cleanup() {
+  LOG(WARNING) << "Cleaning up client!";
+  // Detach the pending promises while holding each registry lock, but fail
+  // them only after the lock is released: setException() may run
+  // continuations inline, and a continuation that re-enters this dispatcher
+  // (or interrupts another request, whose handler locks the registry) would
+  // otherwise self-deadlock on the non-recursive std::mutex.
+  auto pendingRequests = requestIdRegistry_.withLock([](auto& registry) {
+    auto taken = std::move(registry);
+    registry.clear(); // leave the moved-from map in a known-empty state
+    return taken;
+  });
+  for (auto& [requestId, promiseHolder] : pendingRequests) {
+    auto errorMsg =
+        fmt::format("Client closed, cancel ongoing requestId {}", requestId);
+    LOG(WARNING) << errorMsg;
+    promiseHolder.msgPromise.setException(std::runtime_error(errorMsg));
+  }
+
+  auto pendingChunkFetches =
+      streamChunkSliceRegistry_.withLock([](auto& registry) {
+        auto taken = std::move(registry);
+        registry.clear();
+        return taken;
+      });
+  for (auto& [streamChunkSlice, promiseHolder] : pendingChunkFetches) {
+    auto errorMsg = fmt::format(
+        "Client closed, cancel ongoing streamChunkSlice {}",
+        streamChunkSlice.toString());
+    LOG(WARNING) << errorMsg;
+    promiseHolder.msgPromise.setException(std::runtime_error(errorMsg));
+  }
+}
+} // namespace network
+} // namespace celeborn
diff --git a/cpp/celeborn/network/MessageDispatcher.h
b/cpp/celeborn/network/MessageDispatcher.h
new file mode 100644
index 000000000..24a233a7c
--- /dev/null
+++ b/cpp/celeborn/network/MessageDispatcher.h
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <wangle/bootstrap/ClientBootstrap.h>
+#include <wangle/channel/AsyncSocketHandler.h>
+#include <wangle/channel/EventBaseHandler.h>
+#include <wangle/service/ClientDispatcher.h>
+
+#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/network/Message.h"
+#include "celeborn/protocol/ControlMessages.h"
+#include "celeborn/utils/CelebornUtils.h"
+
+namespace celeborn {
+namespace network {
+// Pipeline the dispatcher is attached to: its read side takes
+// folly::IOBufQueue& and its write side takes std::unique_ptr<Message>.
+using SerializePipeline =
+ wangle::Pipeline<folly::IOBufQueue&, std::unique_ptr<Message>>;
+
+/**
+ * MessageDispatcher is the client-side dispatcher on top of a
+ * SerializePipeline. It is responsible for:
+ * 1. Recording the association between each MessageFuture handed to a
+ *    caller and the MessagePromise that completes it. A promise is
+ *    registered when a request is sent (keyed by requestId for RPCs, or by
+ *    StreamChunkSlice for chunk fetches), and is removed and fulfilled when
+ *    the matching response is received via read().
+ * 2. Sending messages through dedicated interfaces (sendRpcRequest /
+ *    operator(), sendFetchChunkRequest, sendRpcRequestWithoutResponse), each
+ *    of which writes the message to the pipeline.
+ * 3. Dispatching received messages in read() by message type, fulfilling or
+ *    failing the corresponding MessagePromise.
+ * 4. Handling and reporting network issues (EOF, inactive transport,
+ *    read/write exceptions) by closing the dispatcher, which fails every
+ *    pending request.
+ *
+ * Both registries are guarded by a std::mutex; closed_ is atomic.
+ */
+class MessageDispatcher : public wangle::ClientDispatcherBase<
+ SerializePipeline,
+ std::unique_ptr<Message>,
+ std::unique_ptr<Message>> {
+public:
+ // Routes RPC_RESPONSE/RPC_FAILURE (by requestId) and CHUNK_FETCH_SUCCESS/
+ // CHUNK_FETCH_FAILURE (by StreamChunkSlice) to their registered promises.
+ void read(Context*, std::unique_ptr<Message> toRecvMsg) override;
+
+ // Sends an RPC_REQUEST and returns a future for its response; forwards to
+ // operator().
+ virtual folly::Future<std::unique_ptr<Message>> sendRpcRequest(
+ std::unique_ptr<Message> toSendMsg) {
+ return operator()(std::move(toSendMsg));
+ }
+
+ // Sends a fetch request (an RPC_REQUEST message) and returns a future that
+ // read() completes when the result for `streamChunkSlice` arrives.
+ virtual folly::Future<std::unique_ptr<Message>> sendFetchChunkRequest(
+ const protocol::StreamChunkSlice& streamChunkSlice,
+ std::unique_ptr<Message> toSendMsg);
+
+ // Writes an RPC_REQUEST without registering a promise (no response is
+ // awaited).
+ virtual void sendRpcRequestWithoutResponse(
+ std::unique_ptr<Message> toSendMsg);
+
+ // Registers a promise keyed by the request's requestId, writes the request
+ // and returns the promise's future. Requires an RPC_REQUEST message and an
+ // open dispatcher.
+ folly::Future<std::unique_ptr<Message>> operator()(
+ std::unique_ptr<Message> toSendMsg) override;
+
+ // readEOF / readException / transportInactive / writeException log the
+ // event, propagate it through the pipeline context, then call close().
+ void readEOF(Context* ctx) override;
+
+ void readException(Context* ctx, folly::exception_wrapper e) override;
+
+ // Only propagates the event; nothing to set up.
+ void transportActive(Context* ctx) override;
+
+ void transportInactive(Context* ctx) override;
+
+ folly::Future<folly::Unit> writeException(
+ Context* ctx,
+ folly::exception_wrapper e) override;
+
+ // Both close overloads mark the dispatcher closed, fail all pending
+ // requests via cleanup(), then delegate to ClientDispatcherBase.
+ folly::Future<folly::Unit> close() override;
+
+ folly::Future<folly::Unit> close(Context* ctx) override;
+
+ // True until the dispatcher has been closed.
+ bool isAvailable() override {
+ return !closed_;
+ }
+
+private:
+ // Fails every pending promise in both registries and clears them.
+ void cleanup();
+
+ using MsgPromise = folly::Promise<std::unique_ptr<Message>>;
+ // A pending request's promise plus the time it was registered.
+ // NOTE(review): requestTime is written but not read in this file;
+ // presumably intended for timeout handling — confirm.
+ struct MsgPromiseHolder {
+ MsgPromise msgPromise;
+ std::chrono::time_point<std::chrono::system_clock> requestTime;
+ };
+ // Pending RPCs keyed by requestId.
+ folly::Synchronized<std::unordered_map<long, MsgPromiseHolder>, std::mutex>
+ requestIdRegistry_;
+ // Pending chunk fetches keyed by StreamChunkSlice.
+ folly::Synchronized<
+ std::unordered_map<
+ protocol::StreamChunkSlice,
+ MsgPromiseHolder,
+ protocol::StreamChunkSlice::Hasher>,
+ std::mutex>
+ streamChunkSliceRegistry_;
+ // Set once by close(); checked by the send paths and isAvailable().
+ std::atomic<bool> closed_{false};
+};
+} // namespace network
+} // namespace celeborn
diff --git a/cpp/celeborn/network/tests/CMakeLists.txt
b/cpp/celeborn/network/tests/CMakeLists.txt
index db38fb484..fdd2c874e 100644
--- a/cpp/celeborn/network/tests/CMakeLists.txt
+++ b/cpp/celeborn/network/tests/CMakeLists.txt
@@ -16,7 +16,8 @@
add_executable(
celeborn_network_test
FrameDecoderTest.cpp
- MessageTest.cpp)
+ MessageTest.cpp
+ MessageDispatcherTest.cpp)
add_test(NAME celeborn_network_test COMMAND celeborn_network_test)
diff --git a/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
b/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
new file mode 100644
index 000000000..8baf9f8d7
--- /dev/null
+++ b/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/network/FrameDecoder.h"
+#include "celeborn/network/MessageDispatcher.h"
+
+using namespace celeborn;
+using namespace celeborn::network;
+
+namespace {
+// Handler placed at the back of the test pipeline. It ignores inbound bytes
+// and stores the last outbound Message into a caller-owned slot, so a test
+// can inspect exactly what the dispatcher wrote.
+class MockHandler : public wangle::Handler<
+                        std::unique_ptr<folly::IOBuf>,
+                        std::unique_ptr<Message>,
+                        std::unique_ptr<Message>,
+                        std::unique_ptr<folly::IOBuf>> {
+ public:
+  // `writtenMsg` must outlive this handler and any pipeline holding it.
+  explicit MockHandler(std::unique_ptr<Message>& writtenMsg)
+      : writtenMsg_(writtenMsg) {}
+
+  void read(Context* ctx, std::unique_ptr<folly::IOBuf> msg) override {}
+
+  folly::Future<folly::Unit> write(Context* ctx, std::unique_ptr<Message> msg)
+      override {
+    writtenMsg_ = std::move(msg);
+    // Report the write as already completed.
+    return folly::makeFuture();
+  }
+
+ private:
+  std::unique_ptr<Message>& writtenMsg_;
+};
+
+// Builds a finalized SerializePipeline with `mockHandler` at its back, so
+// every message the dispatcher writes lands in the handler's slot.
+SerializePipeline::Ptr createMockedPipeline(MockHandler&& mockHandler) {
+  auto mocked = SerializePipeline::create();
+  // The FrameDecoder only completes the handler chain so that it
+  // type-checks; the tests never push inbound bytes through it.
+  mocked->addBack(FrameDecoder());
+  mocked->addBack(std::move(mockHandler));
+  mocked->finalize();
+  return mocked;
+}
+
+// Copies `content` into a newly created write-only buffer and converts it
+// to a ReadOnlyByteBuffer.
+std::unique_ptr<memory::ReadOnlyByteBuffer> toReadOnlyByteBuffer(
+    const std::string& content) {
+  auto writable = memory::ByteBuffer::createWriteOnly(content.size());
+  writable->writeFromString(content);
+  return memory::ByteBuffer::toReadOnly(std::move(writable));
+}
+
+} // namespace
+
+// An RPC request is written to the pipeline unchanged, and a matching
+// RPC_RESPONSE completes the returned future with that response.
+TEST(MessageDispatcherTest, sendRpcRequestAndReceiveResponse) {
+  std::unique_ptr<Message> sendedMsg;
+  MockHandler mockHandler(sendedMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const long requestId = 1001;
+  const std::string requestBody = "test-request-body";
+  auto rpcRequest = std::make_unique<RpcRequest>(
+      requestId, toReadOnlyByteBuffer(requestBody));
+  auto future = dispatcher->sendRpcRequest(std::move(rpcRequest));
+
+  EXPECT_FALSE(future.isReady());
+  // ASSERT (not EXPECT) on anything dereferenced below, so a failure aborts
+  // this test instead of crashing the whole test binary.
+  ASSERT_NE(sendedMsg.get(), nullptr);
+  EXPECT_EQ(sendedMsg->type(), Message::RPC_REQUEST);
+  auto sendedRpcRequest = dynamic_cast<RpcRequest*>(sendedMsg.get());
+  ASSERT_NE(sendedRpcRequest, nullptr);
+  EXPECT_EQ(sendedRpcRequest->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sendedRpcRequest->body()->readToString(requestBody.size()), requestBody);
+
+  const std::string responseBody = "test-response-body";
+  auto rpcResponse = std::make_unique<RpcResponse>(
+      requestId, toReadOnlyByteBuffer(responseBody));
+  dispatcher->read(nullptr, std::move(rpcResponse));
+
+  ASSERT_TRUE(future.isReady());
+  auto receivedMsg = std::move(future).get();
+  ASSERT_NE(receivedMsg.get(), nullptr);
+  EXPECT_EQ(receivedMsg->type(), Message::RPC_RESPONSE);
+  auto receivedRpcResponse = dynamic_cast<RpcResponse*>(receivedMsg.get());
+  ASSERT_NE(receivedRpcResponse, nullptr);
+  EXPECT_EQ(receivedRpcResponse->body()->remainingSize(), responseBody.size());
+  EXPECT_EQ(
+      receivedRpcResponse->body()->readToString(responseBody.size()),
+      responseBody);
+}
+
+// A matching RPC_FAILURE completes the returned future with an exception.
+TEST(MessageDispatcherTest, sendRpcRequestAndReceiveFailure) {
+  std::unique_ptr<Message> sendedMsg;
+  MockHandler mockHandler(sendedMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const long requestId = 1001;
+  const std::string requestBody = "test-request-body";
+  auto rpcRequest = std::make_unique<RpcRequest>(
+      requestId, toReadOnlyByteBuffer(requestBody));
+  auto future = dispatcher->sendRpcRequest(std::move(rpcRequest));
+
+  EXPECT_FALSE(future.isReady());
+  ASSERT_NE(sendedMsg.get(), nullptr);
+  EXPECT_EQ(sendedMsg->type(), Message::RPC_REQUEST);
+  auto sendedRpcRequest = dynamic_cast<RpcRequest*>(sendedMsg.get());
+  ASSERT_NE(sendedRpcRequest, nullptr);
+  EXPECT_EQ(sendedRpcRequest->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sendedRpcRequest->body()->readToString(requestBody.size()), requestBody);
+
+  const std::string errorMsg = "test-error-msg";
+  auto copiedErrorMsg = errorMsg;
+  auto rpcFailure =
+      std::make_unique<RpcFailure>(requestId, std::move(copiedErrorMsg));
+  dispatcher->read(nullptr, std::move(rpcFailure));
+
+  // folly's hasException() requires a ready future, so check readiness first
+  // to get a clear failure instead of a FutureNotReady throw.
+  ASSERT_TRUE(future.isReady());
+  EXPECT_TRUE(future.hasException());
+}
+
+// A fetch-chunk request is written unchanged, and a CHUNK_FETCH_SUCCESS for
+// the same StreamChunkSlice completes the future with that message.
+TEST(MessageDispatcherTest, sendFetchChunkRequestAndReceiveSuccess) {
+  std::unique_ptr<Message> sendedMsg;
+  MockHandler mockHandler(sendedMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const protocol::StreamChunkSlice streamChunkSlice{1001, 1002, 1003, 1004};
+  const long requestId = 1001;
+  const std::string requestBody = "test-request-body";
+  auto rpcRequest = std::make_unique<RpcRequest>(
+      requestId, toReadOnlyByteBuffer(requestBody));
+  auto future = dispatcher->sendFetchChunkRequest(
+      streamChunkSlice, std::move(rpcRequest));
+
+  EXPECT_FALSE(future.isReady());
+  ASSERT_NE(sendedMsg.get(), nullptr);
+  EXPECT_EQ(sendedMsg->type(), Message::RPC_REQUEST);
+  auto sendedRpcRequest = dynamic_cast<RpcRequest*>(sendedMsg.get());
+  ASSERT_NE(sendedRpcRequest, nullptr);
+  EXPECT_EQ(sendedRpcRequest->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sendedRpcRequest->body()->readToString(requestBody.size()), requestBody);
+
+  const std::string chunkFetchSuccessBody = "test-chunk-fetch-success-body";
+  auto chunkFetchSuccess = std::make_unique<ChunkFetchSuccess>(
+      streamChunkSlice, toReadOnlyByteBuffer(chunkFetchSuccessBody));
+  dispatcher->read(nullptr, std::move(chunkFetchSuccess));
+
+  ASSERT_TRUE(future.isReady());
+  auto receivedMsg = std::move(future).get();
+  ASSERT_NE(receivedMsg.get(), nullptr);
+  EXPECT_EQ(receivedMsg->type(), Message::CHUNK_FETCH_SUCCESS);
+  auto receivedChunkFetchSuccess =
+      dynamic_cast<ChunkFetchSuccess*>(receivedMsg.get());
+  ASSERT_NE(receivedChunkFetchSuccess, nullptr);
+  EXPECT_EQ(
+      receivedChunkFetchSuccess->body()->remainingSize(),
+      chunkFetchSuccessBody.size());
+  EXPECT_EQ(
+      receivedChunkFetchSuccess->body()->readToString(
+          chunkFetchSuccessBody.size()),
+      chunkFetchSuccessBody);
+}
+
+// A CHUNK_FETCH_FAILURE for the same StreamChunkSlice completes the future
+// with an exception.
+TEST(MessageDispatcherTest, sendFetchChunkRequestAndReceiveFailure) {
+  std::unique_ptr<Message> sendedMsg;
+  MockHandler mockHandler(sendedMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const protocol::StreamChunkSlice streamChunkSlice{1001, 1002, 1003, 1004};
+  const long requestId = 1001;
+  const std::string requestBody = "test-request-body";
+  auto rpcRequest = std::make_unique<RpcRequest>(
+      requestId, toReadOnlyByteBuffer(requestBody));
+  auto future = dispatcher->sendFetchChunkRequest(
+      streamChunkSlice, std::move(rpcRequest));
+
+  EXPECT_FALSE(future.isReady());
+  ASSERT_NE(sendedMsg.get(), nullptr);
+  EXPECT_EQ(sendedMsg->type(), Message::RPC_REQUEST);
+  auto sendedRpcRequest = dynamic_cast<RpcRequest*>(sendedMsg.get());
+  ASSERT_NE(sendedRpcRequest, nullptr);
+  EXPECT_EQ(sendedRpcRequest->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sendedRpcRequest->body()->readToString(requestBody.size()), requestBody);
+
+  const std::string errorMsg = "test-error-msg";
+  auto copiedErrorMsg = errorMsg;
+  auto chunkFetchFailure = std::make_unique<ChunkFetchFailure>(
+      streamChunkSlice, std::move(copiedErrorMsg));
+  dispatcher->read(nullptr, std::move(chunkFetchFailure));
+
+  // Check readiness first: hasException() on an unready folly future throws.
+  ASSERT_TRUE(future.isReady());
+  EXPECT_TRUE(future.hasException());
+}