This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 595ab41f5 [CELEBORN-1881][CIP-14] Add WorkerPartitionReader to 
cppClient
595ab41f5 is described below

commit 595ab41f5e39dbd6d25a4326521af001788c0a8b
Author: HolyLow <[email protected]>
AuthorDate: Mon Mar 10 20:47:17 2025 +0800

    [CELEBORN-1881][CIP-14] Add WorkerPartitionReader to cppClient
    
    ### What changes were proposed in this pull request?
    This PR adds WorkerPartitionReader to cppClient.
    
    ### Why are the changes needed?
    WorkerPartitionReader is the building block of CelebornInputStream.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and UTs.
    
    Closes #3137 from 
HolyLow/issue/celeborn-1881-add-workerpartitionreader-to-cppclient.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: mingji <[email protected]>
---
 cpp/celeborn/CMakeLists.txt                        |   1 +
 cpp/celeborn/{ => client}/CMakeLists.txt           |  30 ++-
 .../client/reader/WorkerPartitionReader.cpp        | 148 +++++++++++++
 cpp/celeborn/client/reader/WorkerPartitionReader.h |  93 +++++++++
 cpp/celeborn/{ => client/tests}/CMakeLists.txt     |  34 ++-
 .../client/tests/WorkerPartitionReaderTest.cpp     | 229 +++++++++++++++++++++
 cpp/celeborn/network/TransportClient.h             |   8 +-
 7 files changed, 527 insertions(+), 16 deletions(-)

diff --git a/cpp/celeborn/CMakeLists.txt b/cpp/celeborn/CMakeLists.txt
index c5fe93782..b1effedb6 100644
--- a/cpp/celeborn/CMakeLists.txt
+++ b/cpp/celeborn/CMakeLists.txt
@@ -18,3 +18,4 @@ add_subdirectory(utils)
 add_subdirectory(conf)
 add_subdirectory(protocol)
 add_subdirectory(network)
+add_subdirectory(client)
diff --git a/cpp/celeborn/CMakeLists.txt b/cpp/celeborn/client/CMakeLists.txt
similarity index 63%
copy from cpp/celeborn/CMakeLists.txt
copy to cpp/celeborn/client/CMakeLists.txt
index c5fe93782..eaa077e3a 100644
--- a/cpp/celeborn/CMakeLists.txt
+++ b/cpp/celeborn/client/CMakeLists.txt
@@ -12,9 +12,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-add_subdirectory(proto)
-add_subdirectory(memory)
-add_subdirectory(utils)
-add_subdirectory(conf)
-add_subdirectory(protocol)
-add_subdirectory(network)
+add_library(
+        client
+        reader/WorkerPartitionReader.cpp)
+
+target_include_directories(client PUBLIC ${CMAKE_BINARY_DIR})
+
+# WorkerPartitionReader calls conf::CelebornConf directly, so depend on the
+# conf library explicitly rather than relying on it arriving transitively.
+target_link_libraries(
+        client
+        network
+        proto
+        memory
+        protocol
+        utils
+        conf
+        ${WANGLE}
+        ${FIZZ}
+        ${LIBSODIUM_LIBRARY}
+        ${FOLLY_WITH_DEPENDENCIES}
+        ${GLOG}
+        ${GFLAGS_LIBRARIES}
+)
+
+if(CELEBORN_BUILD_TESTS)
+    add_subdirectory(tests)
+endif()
\ No newline at end of file
diff --git a/cpp/celeborn/client/reader/WorkerPartitionReader.cpp b/cpp/celeborn/client/reader/WorkerPartitionReader.cpp
new file mode 100644
index 000000000..a7983a01e
--- /dev/null
+++ b/cpp/celeborn/client/reader/WorkerPartitionReader.cpp
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/reader/WorkerPartitionReader.h"
+
+namespace celeborn {
+namespace client {
+std::shared_ptr<WorkerPartitionReader> WorkerPartitionReader::create(
+    const std::shared_ptr<const conf::CelebornConf>& conf,
+    const std::string& shuffleKey,
+    const protocol::PartitionLocation& location,
+    int32_t startMapIndex,
+    int32_t endMapIndex,
+    network::TransportClientFactory* clientFactory) {
+  // The constructor is private, so std::make_shared cannot reach it. Build the
+  // object here and hand it to a shared_ptr straight away, so that
+  // weak_from_this() works for the callbacks set up later.
+  auto* reader = new WorkerPartitionReader(
+      conf, shuffleKey, location, startMapIndex, endMapIndex, clientFactory);
+  return std::shared_ptr<WorkerPartitionReader>(reader);
+}
+
+WorkerPartitionReader::WorkerPartitionReader(
+    const std::shared_ptr<const conf::CelebornConf>& conf,
+    const std::string& shuffleKey,
+    const protocol::PartitionLocation& location,
+    int32_t startMapIndex,
+    int32_t endMapIndex,
+    network::TransportClientFactory* clientFactory)
+    : shuffleKey_(shuffleKey),
+      location_(location),
+      startMapIndex_(startMapIndex),
+      endMapIndex_(endMapIndex),
+      fetchingChunkId_(0),
+      toConsumeChunkId_(0),
+      maxFetchChunksInFlight_(conf->clientFetchMaxReqsInFlight()),
+      fetchTimeout_(conf->clientFetchTimeout()) {
+  CELEBORN_CHECK_NOT_NULL(clientFactory);
+  client_ = clientFactory->createClient(location_.host, location_.fetchPort);
+
+  // Ask the worker to open a stream over this location's file, restricted to
+  // the requested map index range.
+  protocol::OpenStream openStream(
+      shuffleKey, location_.filename(), startMapIndex_, endMapIndex_);
+
+  network::RpcRequest request(
+      network::Message::nextRequestId(),
+      openStream.toTransportMessage().toReadOnlyByteBuffer());
+
+  // TODO: this call blocks and may throw. Doing it in the constructor means
+  // any failure surfaces as a throwing constructor.
+  // The wait is bounded by the configured fetch timeout (fetchTimeout_)
+  // rather than the client's default timeout.
+  auto response = client_->sendRpcRequestSync(request, fetchTimeout_);
+  auto body = response.body();
+  auto transportMessage = protocol::TransportMessage(std::move(body));
+  // Stream metadata (streamId, numChunks) used by the fetch loop and close.
+  streamHandler_ =
+      protocol::StreamHandler::fromTransportMessage(transportMessage);
+}
+
+WorkerPartitionReader::~WorkerPartitionReader() {
+  // Nothing to close if the stream handle or client is missing.
+  if (!streamHandler_ || !client_) {
+    return;
+  }
+  // Destructors are implicitly noexcept: an exception escaping here would
+  // call std::terminate. Closing the stream is best-effort, so log and
+  // swallow any failure instead of propagating it.
+  try {
+    protocol::BufferStreamEnd bufferStreamEnd;
+    bufferStreamEnd.streamId = streamHandler_->streamId;
+    network::RpcRequest request(
+        network::Message::nextRequestId(),
+        bufferStreamEnd.toTransportMessage().toReadOnlyByteBuffer());
+    client_->sendRpcRequestWithoutResponse(request);
+  } catch (const std::exception& e) {
+    LOG(WARNING) << "WorkerPartitionReader: failed to send BufferStreamEnd "
+                 << "for stream " << streamHandler_->streamId
+                 << " msg: " << e.what();
+  }
+}
+
+bool WorkerPartitionReader::hasNext() {
+  // Chunks are handed out strictly by count, so the reader has more to give
+  // while the stream's chunk total exceeds the number already consumed.
+  const auto totalChunks = streamHandler_->numChunks;
+  return totalChunks > toConsumeChunkId_;
+}
+
+std::unique_ptr<memory::ReadOnlyByteBuffer> WorkerPartitionReader::next() {
+  // On an exhausted reader fetchChunks() issues no request, so the wait loop
+  // below would never receive a chunk. Fail fast instead of hanging.
+  if (!hasNext()) {
+    CELEBORN_FAIL("WorkerPartitionReader::next called with no chunk left");
+  }
+  initAndCheck();
+  fetchChunks();
+  std::unique_ptr<memory::ReadOnlyByteBuffer> result;
+  while (!result) {
+    // Re-check on every wake-up so a failure recorded by the fetch callback
+    // is rethrown here instead of waiting forever.
+    initAndCheck();
+    // TODO: add metric or time tracing
+    chunkQueue_.try_dequeue_for(result, kDefaultConsumeIter);
+  }
+  toConsumeChunkId_++;
+  return result;
+}
+
+void WorkerPartitionReader::fetchChunks() {
+  initAndCheck();
+  // Keep at most maxFetchChunksInFlight_ chunks requested but not yet
+  // consumed, and never request past the end of the stream.
+  while (fetchingChunkId_ - toConsumeChunkId_ < maxFetchChunksInFlight_ &&
+         fetchingChunkId_ < streamHandler_->numChunks) {
+    auto chunkId = fetchingChunkId_++;
+    // NOTE(review): 0 / INT_MAX presumably mean offset 0 and unbounded
+    // length, i.e. the whole chunk. Confirm against StreamChunkSlice.
+    auto streamChunkSlice = protocol::StreamChunkSlice{
+        streamHandler_->streamId, chunkId, 0, INT_MAX};
+    protocol::ChunkFetchRequest chunkFetchRequest;
+    chunkFetchRequest.streamChunkSlice = streamChunkSlice;
+    network::RpcRequest request(
+        network::Message::nextRequestId(),
+        chunkFetchRequest.toTransportMessage().toReadOnlyByteBuffer());
+    client_->fetchChunkAsync(
+        streamChunkSlice, request, onSuccess_, onFailure_);
+  }
+}
+
+void WorkerPartitionReader::initAndCheck() {
+  // The callbacks capture weak_from_this(), which needs the object to be
+  // owned by a shared_ptr already. That is why they are created lazily here
+  // rather than in the constructor. Both are installed together, so testing
+  // one of them is enough.
+  if (!onSuccess_) {
+    auto weakSelf = weak_from_this();
+
+    onSuccess_ = [weakSelf](
+                     protocol::StreamChunkSlice streamChunkSlice,
+                     std::unique_ptr<memory::ReadOnlyByteBuffer> chunk) {
+      // Drop the chunk silently if the reader has already been destroyed.
+      if (auto self = weakSelf.lock()) {
+        self->chunkQueue_.enqueue(std::move(chunk));
+        VLOG(1) << "WorkerPartitionReader::onSuccess: "
+                << streamChunkSlice.toString();
+      }
+    };
+
+    onFailure_ = [weakSelf](
+                     protocol::StreamChunkSlice streamChunkSlice,
+                     std::unique_ptr<std::exception> exception) {
+      if (auto self = weakSelf.lock()) {
+        LOG(ERROR) << "WorkerPartitionReader::onFailure: "
+                   << streamChunkSlice.toString()
+                   << " msg: " << exception->what();
+        // Park the error; the next initAndCheck() call rethrows it.
+        *self->exception_.wlock() = std::move(exception);
+      }
+    };
+  }
+
+  // Rethrow any failure that a fetch callback has recorded.
+  auto recorded = exception_.rlock();
+  if (*recorded) {
+    CELEBORN_FAIL((*recorded)->what());
+  }
+}
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/reader/WorkerPartitionReader.h b/cpp/celeborn/client/reader/WorkerPartitionReader.h
new file mode 100644
index 000000000..68fb00230
--- /dev/null
+++ b/cpp/celeborn/client/reader/WorkerPartitionReader.h
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "celeborn/network/TransportClient.h"
+#include "celeborn/protocol/PartitionLocation.h"
+
+namespace celeborn {
+namespace client {
+// Abstract source of partition data, consumed one chunk at a time.
+class PartitionReader {
+ public:
+  virtual ~PartitionReader() = default;
+
+  // Whether next() has another chunk to return. In WorkerPartitionReader this
+  // means fewer chunks have been consumed than the stream reported.
+  virtual bool hasNext() = 0;
+
+  // Returns the next chunk of partition data as a read-only buffer.
+  virtual std::unique_ptr<memory::ReadOnlyByteBuffer> next() = 0;
+};
+
+// PartitionReader that pulls one partition location's data directly from the
+// worker serving it.
+// - Construction opens a stream on the worker with a blocking OpenStream RPC.
+// - next() keeps up to maxFetchChunksInFlight_ chunk fetches outstanding.
+// - Chunks are returned in the order the fetch callbacks deliver them.
+class WorkerPartitionReader
+    : public PartitionReader,
+      public std::enable_shared_from_this<WorkerPartitionReader> {
+ public:
+  // Only allow using create method to get the shared_ptr holder. This is
+  // required by the std::enable_shared_from_this functionality.
+  //
+  // Opens the stream on the worker at location.host:location.fetchPort, using
+  // a client obtained from clientFactory, which must be non-null. Exceptions
+  // raised while opening the stream propagate out of this call.
+  static std::shared_ptr<WorkerPartitionReader> create(
+      const std::shared_ptr<const conf::CelebornConf>& conf,
+      const std::string& shuffleKey,
+      const protocol::PartitionLocation& location,
+      int32_t startMapIndex,
+      int32_t endMapIndex,
+      network::TransportClientFactory* clientFactory);
+
+  // Sends BufferStreamEnd for the opened stream without waiting for a reply.
+  ~WorkerPartitionReader() override;
+
+  // True while next() has returned fewer chunks than the stream's numChunks.
+  bool hasNext() override;
+
+  // Issues chunk fetches up to the in-flight limit, then blocks until a chunk
+  // is available. Any failure reported by a fetch callback is rethrown via
+  // CELEBORN_FAIL.
+  std::unique_ptr<memory::ReadOnlyByteBuffer> next() override;
+
+ private:
+  // Disable creating the object directly to make sure that
+  // std::enable_shared_from_this works properly.
+  WorkerPartitionReader(
+      const std::shared_ptr<const conf::CelebornConf>& conf,
+      const std::string& shuffleKey,
+      const protocol::PartitionLocation& location,
+      int32_t startMapIndex,
+      int32_t endMapIndex,
+      network::TransportClientFactory* clientFactory);
+
+  // Requests further chunks while below maxFetchChunksInFlight_ outstanding
+  // and below the stream's chunk count.
+  void fetchChunks();
+
+  // Lazily installs onSuccess_/onFailure_, then rethrows any recorded
+  // failure.
+  // This function cannot be called within constructor!
+  // (The callbacks capture weak_from_this().)
+  void initAndCheck();
+
+  // Identity of the partition being read; sent in the OpenStream request.
+  std::string shuffleKey_;
+  protocol::PartitionLocation location_;
+  // Connection to the worker, obtained from the factory in the constructor.
+  std::shared_ptr<network::TransportClient> client_;
+  // Map index range forwarded in the OpenStream request.
+  int32_t startMapIndex_;
+  int32_t endMapIndex_;
+  // Stream metadata (streamId, numChunks) from the OpenStream reply.
+  std::unique_ptr<protocol::StreamHandler> streamHandler_;
+
+  // Index of the next chunk to request from the worker.
+  int32_t fetchingChunkId_;
+  // Number of chunks already returned by next().
+  int32_t toConsumeChunkId_;
+  // Cap on fetchingChunkId_ - toConsumeChunkId_; from
+  // conf->clientFetchMaxReqsInFlight().
+  int32_t maxFetchChunksInFlight_;
+  // From conf->clientFetchTimeout().
+  Timeout fetchTimeout_;
+
+  // Unbounded multi-producer/single-consumer queue: filled by onSuccess_,
+  // drained by next().
+  folly::UMPSCQueue<std::unique_ptr<memory::ReadOnlyByteBuffer>, true>
+      chunkQueue_;
+  // Fetch callbacks passed to fetchChunkAsync; created by initAndCheck().
+  network::FetchChunkSuccessCallback onSuccess_;
+  network::FetchChunkFailureCallback onFailure_;
+  // Most recent failure stored by onFailure_; rethrown by initAndCheck().
+  folly::Synchronized<std::unique_ptr<std::exception>> exception_;
+
+  // How long next() waits on chunkQueue_ before re-checking exception_.
+  static constexpr auto kDefaultConsumeIter = std::chrono::milliseconds(500);
+
+  // TODO: add other params, such as fetchChunkRetryCnt, fetchChunkMaxRetry
+};
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt
similarity index 56%
copy from cpp/celeborn/CMakeLists.txt
copy to cpp/celeborn/client/tests/CMakeLists.txt
index c5fe93782..6543ca28d 100644
--- a/cpp/celeborn/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -12,9 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-add_subdirectory(proto)
-add_subdirectory(memory)
-add_subdirectory(utils)
-add_subdirectory(conf)
-add_subdirectory(protocol)
-add_subdirectory(network)
+
+add_executable(
+        celeborn_client_test
+        WorkerPartitionReaderTest.cpp)
+
+add_test(NAME celeborn_client_test COMMAND celeborn_client_test)
+
+# Include path for headers under the build tree, set on this test target
+# itself; the parent directory already configures the `client` library.
+target_include_directories(celeborn_client_test PRIVATE ${CMAKE_BINARY_DIR})
+
+target_link_libraries(
+        celeborn_client_test
+        PRIVATE
+        memory
+        conf
+        proto
+        protocol
+        utils
+        network
+        client
+        ${WANGLE}
+        ${FIZZ}
+        ${LIBSODIUM_LIBRARY}
+        ${FOLLY_WITH_DEPENDENCIES}
+        ${GLOG}
+        ${GFLAGS_LIBRARIES}
+        GTest::gtest
+        GTest::gmock
+        GTest::gtest_main)
\ No newline at end of file
diff --git a/cpp/celeborn/client/tests/WorkerPartitionReaderTest.cpp b/cpp/celeborn/client/tests/WorkerPartitionReaderTest.cpp
new file mode 100644
index 000000000..d5a07a992
--- /dev/null
+++ b/cpp/celeborn/client/tests/WorkerPartitionReaderTest.cpp
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <gtest/gtest.h>
+
+#include "celeborn/client/reader/WorkerPartitionReader.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+using namespace celeborn::network;
+using namespace celeborn::protocol;
+using namespace celeborn::memory;
+using namespace celeborn::conf;
+
+namespace {
+using MS = std::chrono::milliseconds;
+
+class MockTransportClient : public TransportClient {
+ public:
+  // The base is built with null arguments and a 100ms default timeout. The
+  // network calls WorkerPartitionReader makes are served by the overrides
+  // below, which record their inputs and replay canned results.
+  MockTransportClient()
+      : TransportClient(nullptr, nullptr, MS(100)),
+        cannedSyncResponse_(0, ReadOnlyByteBuffer::createEmptyBuffer()),
+        lastSyncRequest_(0, ReadOnlyByteBuffer::createEmptyBuffer()),
+        lastNoResponseRequest_(0, ReadOnlyByteBuffer::createEmptyBuffer()),
+        lastFetchChunkRequest_(0, ReadOnlyByteBuffer::createEmptyBuffer()) {}
+
+  // Remembers the request and answers with the canned response.
+  RpcResponse sendRpcRequestSync(const RpcRequest& request, Timeout timeout)
+      override {
+    lastSyncRequest_ = request;
+    return cannedSyncResponse_;
+  }
+
+  // Remembers the request; there is no reply to produce.
+  void sendRpcRequestWithoutResponse(const RpcRequest& request) override {
+    lastNoResponseRequest_ = request;
+  }
+
+  // Remembers the slice and request, then synchronously calls onSuccess with
+  // the canned chunk if one is set. Otherwise it calls onFailure with the
+  // canned error if one is set. The canned value is moved into the callback.
+  void fetchChunkAsync(
+      const StreamChunkSlice& streamChunkSlice,
+      const RpcRequest& request,
+      FetchChunkSuccessCallback onSuccess,
+      FetchChunkFailureCallback onFailure) override {
+    lastStreamChunkSlice_ = streamChunkSlice;
+    lastFetchChunkRequest_ = request;
+    if (pendingChunk_) {
+      onSuccess(streamChunkSlice, std::move(pendingChunk_));
+      return;
+    }
+    if (pendingFailure_) {
+      onFailure(streamChunkSlice, std::move(pendingFailure_));
+    }
+  }
+
+  void setSyncResponse(const RpcResponse& response) {
+    cannedSyncResponse_ = response;
+  }
+
+  void setFetchChunkSuccessResult(std::unique_ptr<ReadOnlyByteBuffer> result) {
+    pendingChunk_ = std::move(result);
+  }
+
+  void setFetchChunkFailureResult(std::unique_ptr<std::exception> result) {
+    pendingFailure_ = std::move(result);
+  }
+
+  RpcRequest getSyncRequest() {
+    return lastSyncRequest_;
+  }
+
+  RpcRequest getNoneResponseRequest() {
+    return lastNoResponseRequest_;
+  }
+
+  StreamChunkSlice getStreamChunkSlice() {
+    return lastStreamChunkSlice_;
+  }
+
+  RpcRequest getFetchChunkRequest() {
+    return lastFetchChunkRequest_;
+  }
+
+ private:
+  RpcResponse cannedSyncResponse_;
+  RpcRequest lastSyncRequest_;
+  RpcRequest lastNoResponseRequest_;
+  StreamChunkSlice lastStreamChunkSlice_;
+  RpcRequest lastFetchChunkRequest_;
+  std::unique_ptr<ReadOnlyByteBuffer> pendingChunk_;
+  std::unique_ptr<std::exception> pendingFailure_;
+};
+
+class MockTransportClientFactory : public TransportClientFactory {
+ public:
+  MockTransportClientFactory()
+      : TransportClientFactory(std::make_shared<CelebornConf>()),
+        sharedClient_(std::make_shared<MockTransportClient>()) {}
+
+  // Ignores the endpoint and always hands back the same mock, so a test can
+  // program the client before the reader asks for it.
+  std::shared_ptr<TransportClient> createClient(
+      const std::string& /*host*/,
+      uint16_t /*port*/) override {
+    return sharedClient_;
+  }
+
+  // The mock that createClient() returns, typed for test configuration.
+  std::shared_ptr<MockTransportClient> getClient() {
+    return sharedClient_;
+  }
+
+ private:
+  std::shared_ptr<MockTransportClient> sharedClient_;
+};
+
+// Copies `content` into a freshly sized buffer and returns it read-only.
+std::unique_ptr<ReadOnlyByteBuffer> toReadOnlyByteBuffer(
+    const std::string& content) {
+  auto writable = ByteBuffer::createWriteOnly(content.size());
+  writable->writeFromString(content);
+  return ByteBuffer::toReadOnly(std::move(writable));
+}
+} // namespace
+
+TEST(WorkerPartitionReaderTest, fetchChunkSuccess) {
+  const std::string shuffleKey = "test-shuffle-key";
+  PartitionLocation location;
+  location.host = "test-host";
+  location.fetchPort = 1011;
+  location.id = 1;
+  location.epoch = 2;
+  location.mode = PartitionLocation::Mode::PRIMARY;
+  location.storageInfo = std::make_unique<StorageInfo>();
+  const int32_t startMapIndex = 0;
+  const int32_t endMapIndex = 100;
+  MockTransportClientFactory mockedClientFactory;
+  auto conf = std::make_shared<CelebornConf>();
+  auto transportClient = mockedClientFactory.getClient();
+
+  // Build a PbStreamHandler with 1 chunk, pack it into an RpcResponse and
+  // install it as the reply to the reader's OpenStream request.
+  PbStreamHandler pb;
+  const int streamId = 100;
+  const int numChunks = 1;
+  pb.set_streamid(streamId);
+  pb.set_numchunks(numChunks);
+  for (int i = 0; i < numChunks; i++) {
+    pb.add_chunkoffsets(i);
+  }
+  pb.set_fullpath("test-fullpath");
+  TransportMessage transportMessage(STREAM_HANDLER, pb.SerializeAsString());
+  RpcResponse response =
+      RpcResponse(1111, transportMessage.toReadOnlyByteBuffer());
+  transportClient->setSyncResponse(response);
+
+  // Set the chunk the mocked fetchChunkAsync delivers.
+  const std::string chunkBody = "test-chunk-body";
+  transportClient->setFetchChunkSuccessResult(toReadOnlyByteBuffer(chunkBody));
+
+  auto partitionReader = WorkerPartitionReader::create(
+      conf,
+      shuffleKey,
+      location,
+      startMapIndex,
+      endMapIndex,
+      &mockedClientFactory);
+  // Stop here if no chunk is reported: calling next() on an exhausted reader
+  // is not a meaningful test step.
+  ASSERT_TRUE(partitionReader->hasNext());
+  auto chunk = partitionReader->next();
+  EXPECT_EQ(chunk->remainingSize(), chunkBody.size());
+  EXPECT_EQ(chunk->readToString(chunk->remainingSize()), chunkBody);
+  EXPECT_FALSE(partitionReader->hasNext());
+}
+
+TEST(WorkerPartitionReaderTest, fetchChunkFailure) {
+  const std::string shuffleKey = "test-shuffle-key";
+  PartitionLocation location;
+  location.host = "test-host";
+  location.fetchPort = 1011;
+  location.id = 1;
+  location.epoch = 2;
+  location.mode = PartitionLocation::Mode::PRIMARY;
+  location.storageInfo = std::make_unique<StorageInfo>();
+  const int32_t startMapIndex = 0;
+  const int32_t endMapIndex = 100;
+  MockTransportClientFactory mockedClientFactory;
+  auto conf = std::make_shared<CelebornConf>();
+  auto transportClient = mockedClientFactory.getClient();
+
+  // Build a PbStreamHandler with 1 chunk, pack it into an RpcResponse and
+  // install it as the reply to the reader's OpenStream request.
+  PbStreamHandler pb;
+  const int streamId = 100;
+  const int numChunks = 1;
+  pb.set_streamid(streamId);
+  pb.set_numchunks(numChunks);
+  for (int i = 0; i < numChunks; i++) {
+    pb.add_chunkoffsets(i);
+  }
+  pb.set_fullpath("test-fullpath");
+  TransportMessage transportMessage(STREAM_HANDLER, pb.SerializeAsString());
+  RpcResponse response =
+      RpcResponse(1111, transportMessage.toReadOnlyByteBuffer());
+  transportClient->setSyncResponse(response);
+
+  // Set the error the mocked fetchChunkAsync reports.
+  transportClient->setFetchChunkFailureResult(
+      std::make_unique<std::runtime_error>("test-runtime-error"));
+
+  auto partitionReader = WorkerPartitionReader::create(
+      conf,
+      shuffleKey,
+      location,
+      startMapIndex,
+      endMapIndex,
+      &mockedClientFactory);
+  // Stop here if no chunk is reported: the failure path is only reachable
+  // when next() actually issues a fetch.
+  ASSERT_TRUE(partitionReader->hasNext());
+  EXPECT_THROW(partitionReader->next(), std::exception);
+}
diff --git a/cpp/celeborn/network/TransportClient.h b/cpp/celeborn/network/TransportClient.h
index 35388d4d6..5bfa4afdc 100644
--- a/cpp/celeborn/network/TransportClient.h
+++ b/cpp/celeborn/network/TransportClient.h
@@ -65,7 +65,7 @@ class TransportClient {
       std::unique_ptr<MessageDispatcher> dispatcher,
       Timeout defaultTimeout);
 
-  RpcResponse sendRpcRequestSync(const RpcRequest& request) {
+  virtual RpcResponse sendRpcRequestSync(const RpcRequest& request) {
     return sendRpcRequestSync(request, defaultTimeout_);
   }
 
@@ -74,9 +74,9 @@ class TransportClient {
       Timeout timeout);
 
   // Ignore the response, return immediately.
-  void sendRpcRequestWithoutResponse(const RpcRequest& request);
+  virtual void sendRpcRequestWithoutResponse(const RpcRequest& request);
 
-  void fetchChunkAsync(
+  virtual void fetchChunkAsync(
       const protocol::StreamChunkSlice& streamChunkSlice,
       const RpcRequest& request,
       FetchChunkSuccessCallback onSuccess,
@@ -105,7 +105,7 @@ class TransportClientFactory {
  public:
   TransportClientFactory(const std::shared_ptr<conf::CelebornConf>& conf);
 
-  std::shared_ptr<TransportClient> createClient(
+  virtual std::shared_ptr<TransportClient> createClient(
       const std::string& host,
       uint16_t port);
 

Reply via email to