This is an automated email from the ASF dual-hosted git repository.
ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 595ab41f5 [CELEBORN-1881][CIP-14] Add WorkerPartitionReader to cppClient
595ab41f5 is described below
commit 595ab41f5e39dbd6d25a4326521af001788c0a8b
Author: HolyLow <[email protected]>
AuthorDate: Mon Mar 10 20:47:17 2025 +0800
[CELEBORN-1881][CIP-14] Add WorkerPartitionReader to cppClient
### What changes were proposed in this pull request?
This PR adds WorkerPartitionReader to cppClient.
### Why are the changes needed?
WorkerPartitionReader is the building block of CelebornInputStream.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation and UTs.
Closes #3137 from
HolyLow/issue/celeborn-1881-add-workerpartitionreader-to-cppclient.
Authored-by: HolyLow <[email protected]>
Signed-off-by: mingji <[email protected]>
---
cpp/celeborn/CMakeLists.txt | 1 +
cpp/celeborn/{ => client}/CMakeLists.txt | 30 ++-
.../client/reader/WorkerPartitionReader.cpp | 148 +++++++++++++
cpp/celeborn/client/reader/WorkerPartitionReader.h | 93 +++++++++
cpp/celeborn/{ => client/tests}/CMakeLists.txt | 34 ++-
.../client/tests/WorkerPartitionReaderTest.cpp | 229 +++++++++++++++++++++
cpp/celeborn/network/TransportClient.h | 8 +-
7 files changed, 527 insertions(+), 16 deletions(-)
diff --git a/cpp/celeborn/CMakeLists.txt b/cpp/celeborn/CMakeLists.txt
index c5fe93782..b1effedb6 100644
--- a/cpp/celeborn/CMakeLists.txt
+++ b/cpp/celeborn/CMakeLists.txt
@@ -18,3 +18,4 @@ add_subdirectory(utils)
add_subdirectory(conf)
add_subdirectory(protocol)
add_subdirectory(network)
+add_subdirectory(client)
diff --git a/cpp/celeborn/CMakeLists.txt b/cpp/celeborn/client/CMakeLists.txt
similarity index 63%
copy from cpp/celeborn/CMakeLists.txt
copy to cpp/celeborn/client/CMakeLists.txt
index c5fe93782..eaa077e3a 100644
--- a/cpp/celeborn/CMakeLists.txt
+++ b/cpp/celeborn/client/CMakeLists.txt
@@ -12,9 +12,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-add_subdirectory(proto)
-add_subdirectory(memory)
-add_subdirectory(utils)
-add_subdirectory(conf)
-add_subdirectory(protocol)
-add_subdirectory(network)
+add_library(
+  client
+  reader/WorkerPartitionReader.cpp)
+
+target_include_directories(client PUBLIC ${CMAKE_BINARY_DIR})
+
+# WorkerPartitionReader reads conf::CelebornConf directly, so `conf` is listed
+# explicitly instead of being relied upon transitively through `network`.
+target_link_libraries(
+  client
+  network
+  conf
+  proto
+  memory
+  protocol
+  utils
+  ${WANGLE}
+  ${FIZZ}
+  ${LIBSODIUM_LIBRARY}
+  ${FOLLY_WITH_DEPENDENCIES}
+  ${GLOG}
+  ${GFLAGS_LIBRARIES}
+)
+
+if(CELEBORN_BUILD_TESTS)
+  add_subdirectory(tests)
+endif()
\ No newline at end of file
diff --git a/cpp/celeborn/client/reader/WorkerPartitionReader.cpp b/cpp/celeborn/client/reader/WorkerPartitionReader.cpp
new file mode 100644
index 000000000..a7983a01e
--- /dev/null
+++ b/cpp/celeborn/client/reader/WorkerPartitionReader.cpp
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/reader/WorkerPartitionReader.h"
+
+namespace celeborn {
+namespace client {
+// Factory for WorkerPartitionReader. The constructor is private so that every
+// instance is owned by a std::shared_ptr; the fetch callbacks capture
+// weak_from_this(), which only works for shared_ptr-owned objects.
+// std::make_shared is not usable here because it cannot reach the private
+// constructor.
+std::shared_ptr<WorkerPartitionReader> WorkerPartitionReader::create(
+    const std::shared_ptr<const conf::CelebornConf>& conf,
+    const std::string& shuffleKey,
+    const protocol::PartitionLocation& location,
+    int32_t startMapIndex,
+    int32_t endMapIndex,
+    network::TransportClientFactory* clientFactory) {
+  std::shared_ptr<WorkerPartitionReader> reader(new WorkerPartitionReader(
+      conf, shuffleKey, location, startMapIndex, endMapIndex, clientFactory));
+  return reader;
+}
+
+// Connects to the worker hosting `location` and synchronously opens a stream
+// over this partition's file for the given map index range. The returned
+// StreamHandler (stream id and chunk count) drives the later chunk fetches.
+// Only reachable through create().
+WorkerPartitionReader::WorkerPartitionReader(
+    const std::shared_ptr<const conf::CelebornConf>& conf,
+    const std::string& shuffleKey,
+    const protocol::PartitionLocation& location,
+    int32_t startMapIndex,
+    int32_t endMapIndex,
+    network::TransportClientFactory* clientFactory)
+    : shuffleKey_(shuffleKey),
+      location_(location),
+      startMapIndex_(startMapIndex),
+      endMapIndex_(endMapIndex),
+      fetchingChunkId_(0),
+      toConsumeChunkId_(0),
+      maxFetchChunksInFlight_(conf->clientFetchMaxReqsInFlight()),
+      fetchTimeout_(conf->clientFetchTimeout()) {
+  CELEBORN_CHECK_NOT_NULL(clientFactory);
+  client_ = clientFactory->createClient(location_.host, location_.fetchPort);
+
+  protocol::OpenStream openStreamMsg(
+      shuffleKey, location_.filename(), startMapIndex_, endMapIndex_);
+  network::RpcRequest openStreamRequest(
+      network::Message::nextRequestId(),
+      openStreamMsg.toTransportMessage().toReadOnlyByteBuffer());
+
+  // TODO: issuing a blocking RPC that may throw from inside a constructor
+  // might not be safe.
+  auto openStreamResponse = client_->sendRpcRequestSync(openStreamRequest);
+  auto responseBody = openStreamResponse.body();
+  auto responseMessage = protocol::TransportMessage(std::move(responseBody));
+  streamHandler_ =
+      protocol::StreamHandler::fromTransportMessage(responseMessage);
+}
+
+// Notifies the worker that this reader is done with the stream by sending a
+// BufferStreamEnd RPC without waiting for a response. Destructors are
+// implicitly noexcept, so a transport error must not escape (it would call
+// std::terminate); it is logged instead.
+WorkerPartitionReader::~WorkerPartitionReader() {
+  if (!client_ || !streamHandler_) {
+    // Nothing was opened, so there is no stream to end.
+    return;
+  }
+  try {
+    protocol::BufferStreamEnd bufferStreamEnd;
+    bufferStreamEnd.streamId = streamHandler_->streamId;
+    network::RpcRequest request(
+        network::Message::nextRequestId(),
+        bufferStreamEnd.toTransportMessage().toReadOnlyByteBuffer());
+    client_->sendRpcRequestWithoutResponse(request);
+  } catch (const std::exception& e) {
+    LOG(WARNING) << "WorkerPartitionReader: failed to send BufferStreamEnd"
+                 << " for stream " << streamHandler_->streamId
+                 << ": " << e.what();
+  }
+}
+
+// True while some chunk of the opened stream has not yet been handed out by
+// next().
+bool WorkerPartitionReader::hasNext() {
+  const auto totalChunks = streamHandler_->numChunks;
+  return toConsumeChunkId_ < totalChunks;
+}
+
+// Returns the next chunk of the stream, blocking until it has been fetched.
+// Before waiting, fetch requests for upcoming chunks are issued (bounded by
+// maxFetchChunksInFlight_). Throws if a fetch failure has been recorded, or if
+// called when hasNext() is false: in that case no chunk would ever arrive and
+// the wait below would never end.
+std::unique_ptr<memory::ReadOnlyByteBuffer> WorkerPartitionReader::next() {
+  initAndCheck();
+  if (!hasNext()) {
+    CELEBORN_FAIL("WorkerPartitionReader::next called with no chunk left");
+  }
+  fetchChunks();
+  std::unique_ptr<memory::ReadOnlyByteBuffer> result;
+  while (!result) {
+    // Re-check on every wake-up so a failure reported by the fetch callback
+    // is surfaced instead of waiting for a chunk that will not come.
+    initAndCheck();
+    // TODO: add metric or time tracing
+    chunkQueue_.try_dequeue_for(result, kDefaultConsumeIter);
+  }
+  toConsumeChunkId_++;
+  // Plain return: std::move on a local would block NRVO.
+  return result;
+}
+
+// Issues asynchronous fetch requests for upcoming chunks. It keeps at most
+// maxFetchChunksInFlight_ chunks requested but not yet consumed and never goes
+// past the stream's last chunk. Results are delivered through
+// onSuccess_/onFailure_.
+void WorkerPartitionReader::fetchChunks() {
+  initAndCheck();
+  while (fetchingChunkId_ - toConsumeChunkId_ < maxFetchChunksInFlight_ &&
+         fetchingChunkId_ < streamHandler_->numChunks) {
+    auto chunkId = fetchingChunkId_++;
+    // Presumably (offset 0, length INT_MAX) selects the whole chunk -- confirm
+    // against protocol::StreamChunkSlice.
+    auto streamChunkSlice = protocol::StreamChunkSlice{
+        streamHandler_->streamId, chunkId, 0, INT_MAX};
+    protocol::ChunkFetchRequest chunkFetchRequest;
+    chunkFetchRequest.streamChunkSlice = streamChunkSlice;
+    network::RpcRequest request(
+        network::Message::nextRequestId(),
+        chunkFetchRequest.toTransportMessage().toReadOnlyByteBuffer());
+    client_->fetchChunkAsync(
+        streamChunkSlice, request, onSuccess_, onFailure_);
+  }
+}
+
+// Lazily creates the fetch callbacks on first use, then rethrows (via
+// CELEBORN_FAIL) any fetch failure recorded by onFailure_. The callbacks are
+// built here rather than in the constructor because weak_from_this() is empty
+// until create() has placed the object in a std::shared_ptr.
+void WorkerPartitionReader::initAndCheck() {
+ if (!onSuccess_) {
+ // Both callbacks hold only a weak reference: if the reader is already gone
+ // when a fetch completes, the result is dropped instead of touching a
+ // destroyed object.
+ onSuccess_ = [weak_this = weak_from_this()](
+ protocol::StreamChunkSlice streamChunkSlice,
+ std::unique_ptr<memory::ReadOnlyByteBuffer> chunk) {
+ auto shared_this = weak_this.lock();
+ if (!shared_this) {
+ return;
+ }
+ // Hand the fetched chunk to next(), which dequeues from chunkQueue_.
+ shared_this->chunkQueue_.enqueue(std::move(chunk));
+ VLOG(1) << "WorkerPartitionReader::onSuccess: "
+ << streamChunkSlice.toString();
+ };
+
+ onFailure_ = [weak_this = weak_from_this()](
+ protocol::StreamChunkSlice streamChunkSlice,
+ std::unique_ptr<std::exception> exception) {
+ auto shared_this = weak_this.lock();
+ if (!shared_this) {
+ return;
+ }
+ // NOTE(review): `exception` is dereferenced without a null check. A null
+ // exception would crash here; if stored, it would also never be reported
+ // by the check below. Confirm callers always pass a non-null exception.
+ LOG(ERROR) << "WorkerPartitionReader::onFailure: "
+ << streamChunkSlice.toString()
+ << " msg: " << exception->what();
+ {
+ // Record the failure (replacing any earlier one); the next call to
+ // initAndCheck() rethrows it.
+ auto exp = shared_this->exception_.wlock();
+ *exp = std::move(exception);
+ }
+ };
+ }
+
+ {
+ auto exp = exception_.rlock();
+ if (*exp) {
+ // NOTE(review): if CELEBORN_FAIL treats its first argument as a format
+ // string, braces inside the message would be misinterpreted -- confirm.
+ CELEBORN_FAIL((*exp)->what());
+ }
+ }
+}
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/reader/WorkerPartitionReader.h b/cpp/celeborn/client/reader/WorkerPartitionReader.h
new file mode 100644
index 000000000..68fb00230
--- /dev/null
+++ b/cpp/celeborn/client/reader/WorkerPartitionReader.h
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "celeborn/network/TransportClient.h"
+#include "celeborn/protocol/PartitionLocation.h"
+
+namespace celeborn {
+namespace client {
+// Abstract, pull-style source of the chunks of one partition.
+class PartitionReader {
+ public:
+  virtual ~PartitionReader() = default;
+
+  // Whether next() still has a chunk to return.
+  virtual bool hasNext() = 0;
+
+  // Returns the next chunk of the partition.
+  virtual std::unique_ptr<memory::ReadOnlyByteBuffer> next() = 0;
+};
+
+// PartitionReader that reads one partition location from a Celeborn worker.
+// Construction opens a stream on the worker with a synchronous OpenStream
+// RPC. next() then fetches chunks asynchronously, with a bounded number of
+// outstanding requests, and returns them in the order the success callback
+// enqueues them. The destructor sends BufferStreamEnd for the stream.
+class WorkerPartitionReader
+ : public PartitionReader,
+ public std::enable_shared_from_this<WorkerPartitionReader> {
+ public:
+ // The only way to obtain an instance; it is always owned by a shared_ptr,
+ // which std::enable_shared_from_this (weak_from_this) requires.
+ static std::shared_ptr<WorkerPartitionReader> create(
+ const std::shared_ptr<const conf::CelebornConf>& conf,
+ const std::string& shuffleKey,
+ const protocol::PartitionLocation& location,
+ int32_t startMapIndex,
+ int32_t endMapIndex,
+ network::TransportClientFactory* clientFactory);
+
+ ~WorkerPartitionReader() override;
+
+ bool hasNext() override;
+
+ std::unique_ptr<memory::ReadOnlyByteBuffer> next() override;
+
+ private:
+ // Private so callers cannot create an instance outside a shared_ptr, which
+ // would leave weak_from_this() empty.
+ WorkerPartitionReader(
+ const std::shared_ptr<const conf::CelebornConf>& conf,
+ const std::string& shuffleKey,
+ const protocol::PartitionLocation& location,
+ int32_t startMapIndex,
+ int32_t endMapIndex,
+ network::TransportClientFactory* clientFactory);
+
+ // Issues async fetches until the in-flight limit or the last chunk is hit.
+ void fetchChunks();
+
+ // Lazily builds the fetch callbacks and rethrows any recorded fetch
+ // failure. Must not be called from the constructor: the callbacks capture
+ // weak_from_this(), which is empty until create() owns the object.
+ void initAndCheck();
+
+ // NOTE: members are constructed in declaration order and destroyed in
+ // reverse; keep that in mind before reordering them.
+ std::string shuffleKey_;
+ protocol::PartitionLocation location_;
+ std::shared_ptr<network::TransportClient> client_;
+ int32_t startMapIndex_;
+ int32_t endMapIndex_;
+ // Result of OpenStream: stream id and number of chunks.
+ std::unique_ptr<protocol::StreamHandler> streamHandler_;
+
+ // Id of the next chunk to request from the worker.
+ int32_t fetchingChunkId_;
+ // Id of the next chunk next() will return (= number of chunks consumed).
+ int32_t toConsumeChunkId_;
+ // Upper bound on chunks requested but not yet consumed.
+ int32_t maxFetchChunksInFlight_;
+ // NOTE(review): not read anywhere in WorkerPartitionReader.cpp in this
+ // change -- confirm whether fetches are meant to honor it.
+ Timeout fetchTimeout_;
+
+ // Filled by onSuccess_, drained by next().
+ folly::UMPSCQueue<std::unique_ptr<memory::ReadOnlyByteBuffer>, true>
+ chunkQueue_;
+ network::FetchChunkSuccessCallback onSuccess_;
+ network::FetchChunkFailureCallback onFailure_;
+ // Latest failure recorded by onFailure_; rethrown by initAndCheck().
+ folly::Synchronized<std::unique_ptr<std::exception>> exception_;
+
+ // How long next() waits on chunkQueue_ before re-checking for a failure.
+ static constexpr auto kDefaultConsumeIter = std::chrono::milliseconds(500);
+
+ // TODO: add other params, such as fetchChunkRetryCnt, fetchChunkMaxRetry
+};
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt
similarity index 56%
copy from cpp/celeborn/CMakeLists.txt
copy to cpp/celeborn/client/tests/CMakeLists.txt
index c5fe93782..6543ca28d 100644
--- a/cpp/celeborn/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -12,9 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-add_subdirectory(proto)
-add_subdirectory(memory)
-add_subdirectory(utils)
-add_subdirectory(conf)
-add_subdirectory(protocol)
-add_subdirectory(network)
+
+add_executable(
+  celeborn_client_test
+  WorkerPartitionReaderTest.cpp)
+
+add_test(NAME celeborn_client_test COMMAND celeborn_client_test)
+
+# Applies to the test executable defined in this directory; the `client`
+# library already sets its own include directories.
+target_include_directories(celeborn_client_test PRIVATE ${CMAKE_BINARY_DIR})
+
+target_link_libraries(
+  celeborn_client_test
+  PRIVATE
+  memory
+  conf
+  proto
+  protocol
+  utils
+  network
+  client
+  ${WANGLE}
+  ${FIZZ}
+  ${LIBSODIUM_LIBRARY}
+  ${FOLLY_WITH_DEPENDENCIES}
+  ${GLOG}
+  ${GFLAGS_LIBRARIES}
+  GTest::gtest
+  GTest::gmock
+  GTest::gtest_main)
\ No newline at end of file
diff --git a/cpp/celeborn/client/tests/WorkerPartitionReaderTest.cpp b/cpp/celeborn/client/tests/WorkerPartitionReaderTest.cpp
new file mode 100644
index 000000000..d5a07a992
--- /dev/null
+++ b/cpp/celeborn/client/tests/WorkerPartitionReaderTest.cpp
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <gtest/gtest.h>
+
+#include "celeborn/client/reader/WorkerPartitionReader.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+using namespace celeborn::network;
+using namespace celeborn::protocol;
+using namespace celeborn::memory;
+using namespace celeborn::conf;
+
+namespace {
+using MS = std::chrono::milliseconds;
+
+// Test double for TransportClient. It records the last request of each kind
+// and answers with values configured through the set* methods; the
+// overridden methods perform no network I/O.
+class MockTransportClient : public TransportClient {
+ public:
+  MockTransportClient()
+      : TransportClient(nullptr, nullptr, MS(100)),
+        syncResponse_(0, ReadOnlyByteBuffer::createEmptyBuffer()),
+        syncRequest_(0, ReadOnlyByteBuffer::createEmptyBuffer()),
+        nonResponseRequest_(0, ReadOnlyByteBuffer::createEmptyBuffer()),
+        fetchChunkRequest_(0, ReadOnlyByteBuffer::createEmptyBuffer()) {}
+
+  // Records `request` and replies with the configured response.
+  RpcResponse sendRpcRequestSync(const RpcRequest& request, Timeout timeout)
+      override {
+    syncRequest_ = request;
+    return syncResponse_;
+  }
+
+  // Records `request`; nothing is sent back.
+  void sendRpcRequestWithoutResponse(const RpcRequest& request) override {
+    nonResponseRequest_ = request;
+  }
+
+  // Records the slice and request, then synchronously invokes onSuccess with
+  // the configured chunk if one is set, or otherwise onFailure with the
+  // configured error if one is set. The configured result is moved into the
+  // callback.
+  void fetchChunkAsync(
+      const StreamChunkSlice& streamChunkSlice,
+      const RpcRequest& request,
+      FetchChunkSuccessCallback onSuccess,
+      FetchChunkFailureCallback onFailure) override {
+    streamChunkSlice_ = streamChunkSlice;
+    fetchChunkRequest_ = request;
+    if (fetchChunkOnSuccessResult_) {
+      onSuccess(streamChunkSlice, std::move(fetchChunkOnSuccessResult_));
+      return;
+    }
+    if (fetchChunkOnFailureResult_) {
+      onFailure(streamChunkSlice, std::move(fetchChunkOnFailureResult_));
+    }
+  }
+
+  void setSyncResponse(const RpcResponse& response) {
+    syncResponse_ = response;
+  }
+
+  void setFetchChunkSuccessResult(std::unique_ptr<ReadOnlyByteBuffer> result) {
+    fetchChunkOnSuccessResult_ = std::move(result);
+  }
+
+  void setFetchChunkFailureResult(std::unique_ptr<std::exception> result) {
+    fetchChunkOnFailureResult_ = std::move(result);
+  }
+
+  RpcRequest getSyncRequest() {
+    return syncRequest_;
+  }
+
+  RpcRequest getNoneResponseRequest() {
+    return nonResponseRequest_;
+  }
+
+  StreamChunkSlice getStreamChunkSlice() {
+    return streamChunkSlice_;
+  }
+
+  RpcRequest getFetchChunkRequest() {
+    return fetchChunkRequest_;
+  }
+
+ private:
+  RpcResponse syncResponse_;
+  RpcRequest syncRequest_;
+  RpcRequest nonResponseRequest_;
+  StreamChunkSlice streamChunkSlice_;
+  RpcRequest fetchChunkRequest_;
+  std::unique_ptr<ReadOnlyByteBuffer> fetchChunkOnSuccessResult_;
+  std::unique_ptr<std::exception> fetchChunkOnFailureResult_;
+};
+
+// Factory that always returns the same MockTransportClient, whatever host and
+// port are requested, so a test can configure and inspect that client.
+class MockTransportClientFactory : public TransportClientFactory {
+ public:
+  MockTransportClientFactory()
+      : TransportClientFactory(std::make_shared<CelebornConf>()),
+        transportClient_(std::make_shared<MockTransportClient>()) {}
+
+  std::shared_ptr<TransportClient> createClient(
+      const std::string& host,
+      uint16_t port) override {
+    std::shared_ptr<TransportClient> client = transportClient_;
+    return client;
+  }
+
+  // The shared mock, typed as the mock so tests can call its setters/getters.
+  std::shared_ptr<MockTransportClient> getClient() {
+    return transportClient_;
+  }
+
+ private:
+  std::shared_ptr<MockTransportClient> transportClient_;
+};
+
+// Copies `content` into a newly created write-only buffer and converts it to
+// a read-only buffer.
+std::unique_ptr<ReadOnlyByteBuffer> toReadOnlyByteBuffer(
+    const std::string& content) {
+  auto writable = ByteBuffer::createWriteOnly(content.size());
+  writable->writeFromString(content);
+  return ByteBuffer::toReadOnly(std::move(writable));
+}
+} // namespace
+
+// A stream with a single chunk: the reader opens it, returns exactly the
+// chunk the mocked client delivers, and then reports no more chunks.
+TEST(WorkerPartitionReaderTest, fetchChunkSuccess) {
+  const std::string shuffleKey = "test-shuffle-key";
+  PartitionLocation location;
+  location.host = "test-host";
+  location.fetchPort = 1011;
+  location.id = 1;
+  location.epoch = 2;
+  location.mode = PartitionLocation::Mode::PRIMARY;
+  location.storageInfo = std::make_unique<StorageInfo>();
+  const int32_t startMapIndex = 0;
+  const int32_t endMapIndex = 100;
+  MockTransportClientFactory mockedClientFactory;
+  auto conf = std::make_shared<CelebornConf>();
+  auto transportClient = mockedClientFactory.getClient();
+
+  // Build a PbStreamHandler with 1 chunk, pack it into an RpcResponse and set
+  // it as the reply to the OpenStream request.
+  PbStreamHandler pb;
+  const int streamId = 100;
+  const int numChunks = 1;
+  pb.set_streamid(streamId);
+  pb.set_numchunks(numChunks);
+  for (int i = 0; i < numChunks; i++) {
+    pb.add_chunkoffsets(i);
+  }
+  pb.set_fullpath("test-fullpath");
+  TransportMessage transportMessage(STREAM_HANDLER, pb.SerializeAsString());
+  RpcResponse response =
+      RpcResponse(1111, transportMessage.toReadOnlyByteBuffer());
+  transportClient->setSyncResponse(response);
+
+  // Set the chunk to be returned.
+  const std::string chunkBody = "test-chunk-body";
+  transportClient->setFetchChunkSuccessResult(toReadOnlyByteBuffer(chunkBody));
+
+  auto partitionReader = WorkerPartitionReader::create(
+      conf,
+      shuffleKey,
+      location,
+      startMapIndex,
+      endMapIndex,
+      &mockedClientFactory);
+  // ASSERT rather than EXPECT: next() cannot return a chunk when hasNext() is
+  // false, so stop here instead of calling it.
+  ASSERT_TRUE(partitionReader->hasNext());
+  auto chunk = partitionReader->next();
+  ASSERT_TRUE(chunk != nullptr);
+  EXPECT_EQ(chunk->remainingSize(), chunkBody.size());
+  EXPECT_EQ(chunk->readToString(chunk->remainingSize()), chunkBody);
+  EXPECT_FALSE(partitionReader->hasNext());
+}
+
+// A stream with a single chunk whose fetch fails: next() must surface the
+// failure as an exception.
+TEST(WorkerPartitionReaderTest, fetchChunkFailure) {
+  const std::string shuffleKey = "test-shuffle-key";
+  PartitionLocation location;
+  location.host = "test-host";
+  location.fetchPort = 1011;
+  location.id = 1;
+  location.epoch = 2;
+  location.mode = PartitionLocation::Mode::PRIMARY;
+  location.storageInfo = std::make_unique<StorageInfo>();
+  const int32_t startMapIndex = 0;
+  const int32_t endMapIndex = 100;
+  MockTransportClientFactory mockedClientFactory;
+  auto conf = std::make_shared<CelebornConf>();
+  auto transportClient = mockedClientFactory.getClient();
+
+  // Build a PbStreamHandler with 1 chunk, pack it into an RpcResponse and set
+  // it as the reply to the OpenStream request.
+  PbStreamHandler pb;
+  const int streamId = 100;
+  const int numChunks = 1;
+  pb.set_streamid(streamId);
+  pb.set_numchunks(numChunks);
+  for (int i = 0; i < numChunks; i++) {
+    pb.add_chunkoffsets(i);
+  }
+  pb.set_fullpath("test-fullpath");
+  TransportMessage transportMessage(STREAM_HANDLER, pb.SerializeAsString());
+  RpcResponse response =
+      RpcResponse(1111, transportMessage.toReadOnlyByteBuffer());
+  transportClient->setSyncResponse(response);
+
+  // Set the error to be returned.
+  transportClient->setFetchChunkFailureResult(
+      std::make_unique<std::runtime_error>("test-runtime-error"));
+
+  auto partitionReader = WorkerPartitionReader::create(
+      conf,
+      shuffleKey,
+      location,
+      startMapIndex,
+      endMapIndex,
+      &mockedClientFactory);
+  // ASSERT so the expected throw below is attributed to the fetch failure,
+  // not to calling next() on an empty stream.
+  ASSERT_TRUE(partitionReader->hasNext());
+  EXPECT_THROW(partitionReader->next(), std::exception);
+}
diff --git a/cpp/celeborn/network/TransportClient.h b/cpp/celeborn/network/TransportClient.h
index 35388d4d6..5bfa4afdc 100644
--- a/cpp/celeborn/network/TransportClient.h
+++ b/cpp/celeborn/network/TransportClient.h
@@ -65,7 +65,7 @@ class TransportClient {
std::unique_ptr<MessageDispatcher> dispatcher,
Timeout defaultTimeout);
- RpcResponse sendRpcRequestSync(const RpcRequest& request) {
+ virtual RpcResponse sendRpcRequestSync(const RpcRequest& request) {
return sendRpcRequestSync(request, defaultTimeout_);
}
@@ -74,9 +74,9 @@ class TransportClient {
Timeout timeout);
// Ignore the response, return immediately.
- void sendRpcRequestWithoutResponse(const RpcRequest& request);
+ virtual void sendRpcRequestWithoutResponse(const RpcRequest& request);
- void fetchChunkAsync(
+ virtual void fetchChunkAsync(
const protocol::StreamChunkSlice& streamChunkSlice,
const RpcRequest& request,
FetchChunkSuccessCallback onSuccess,
@@ -105,7 +105,7 @@ class TransportClientFactory {
public:
TransportClientFactory(const std::shared_ptr<conf::CelebornConf>& conf);
- std::shared_ptr<TransportClient> createClient(
+ virtual std::shared_ptr<TransportClient> createClient(
const std::string& host,
uint16_t port);