This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new f35b6b80a [CELEBORN-2206][CIP-14] Support PushData and Revive in Cpp's ShuffleClient
f35b6b80a is described below

commit f35b6b80ac13af3cb28cc8522bc209256a289f76
Author: HolyLow <[email protected]>
AuthorDate: Mon Dec 8 14:00:36 2025 +0800

    [CELEBORN-2206][CIP-14] Support PushData and Revive in Cpp's ShuffleClient
    
    ### What changes were proposed in this pull request?
    This PR supports PushData and Revive in Cpp's ShuffleClient so that the Cpp module is capable of writing to the Celeborn server.
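    
    A minimal usage sketch of the new write path (the include path, app id, host, port, ids, and sizes are placeholder assumptions; the wiring mirrors cpp/celeborn/tests/DataSumWithReaderClient.cpp):
    
    ```cpp
    #include <memory>
    #include <string>
    #include <vector>
    
    #include "celeborn/client/ShuffleClient.h"
    
    int main() {
      auto conf = std::make_shared<celeborn::conf::CelebornConf>();
      // The endpoint owns the push-retry thread pool and the client factory,
      // and may be shared by multiple ShuffleClient instances.
      celeborn::client::ShuffleClientEndpoint endpoint(conf);
      auto shuffleClient = celeborn::client::ShuffleClientImpl::create(
          "app-unique-id", conf, endpoint);
      std::string lifecycleManagerHost = "localhost";
      shuffleClient->setupLifecycleManagerRef(lifecycleManagerHost, 9097);
    
      // Push one batch, then mark the map task as finished.
      std::vector<uint8_t> payload(1024, 42);
      shuffleClient->pushData(
          /*shuffleId=*/0, /*mapId=*/0, /*attemptId=*/0, /*partitionId=*/0,
          payload.data(), /*offset=*/0, payload.size(),
          /*numMappers=*/1, /*numPartitions=*/1);
      shuffleClient->mapperEnd(
          /*shuffleId=*/0, /*mapId=*/0, /*attemptId=*/0, /*numMappers=*/1);
      return 0;
    }
    ```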
    
    ### Why are the changes needed?
    This PR enables the Cpp module to write to the Celeborn server.
    
    ### Does this PR resolve a correctness bug?
    No.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation.
    
    Closes #3553 from HolyLow/issue/celeborn-2215-support-PushData-and-Revive-in-cpp-ShuffleClient.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
---
 cpp/celeborn/client/ShuffleClient.cpp              | 540 ++++++++++++++++++++-
 cpp/celeborn/client/ShuffleClient.h                | 137 ++++--
 cpp/celeborn/client/tests/PushDataCallbackTest.cpp |  10 +-
 cpp/celeborn/client/tests/ReviveManagerTest.cpp    |  11 +-
 cpp/celeborn/conf/CelebornConf.cpp                 |  97 ++--
 cpp/celeborn/conf/CelebornConf.h                   |  39 +-
 cpp/celeborn/memory/ByteBuffer.h                   |   7 +-
 cpp/celeborn/network/TransportClient.cpp           |  11 +-
 cpp/celeborn/network/TransportClient.h             |   5 +-
 cpp/celeborn/tests/DataSumWithReaderClient.cpp     |   6 +-
 cpp/celeborn/utils/CelebornUtils.cpp               |   4 +
 cpp/celeborn/utils/CelebornUtils.h                 |   8 +
 12 files changed, 805 insertions(+), 70 deletions(-)

diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp
index 22d19256c..807fac2e0 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -21,19 +21,46 @@
 
 namespace celeborn {
 namespace client {
+
+ShuffleClientEndpoint::ShuffleClientEndpoint(
+    const std::shared_ptr<const conf::CelebornConf>& conf)
+    : conf_(conf),
+      pushDataRetryPool_(std::make_shared<folly::IOThreadPoolExecutor>(
+          conf_->clientPushRetryThreads(),
+          std::make_shared<folly::NamedThreadFactory>(
+              "client-pushdata-retrier"))),
+      clientFactory_(std::make_shared<network::TransportClientFactory>(conf_)) {
+}
+
+std::shared_ptr<folly::IOThreadPoolExecutor>
+ShuffleClientEndpoint::pushDataRetryPool() const {
+  return pushDataRetryPool_;
+}
+
+std::shared_ptr<network::TransportClientFactory>
+ShuffleClientEndpoint::clientFactory() const {
+  return clientFactory_;
+}
+
 std::shared_ptr<ShuffleClientImpl> ShuffleClientImpl::create(
     const std::string& appUniqueId,
     const std::shared_ptr<const conf::CelebornConf>& conf,
-    const std::shared_ptr<network::TransportClientFactory>& clientFactory) {
+    const ShuffleClientEndpoint& clientEndpoint) {
   return std::shared_ptr<ShuffleClientImpl>(
-      new ShuffleClientImpl(appUniqueId, conf, clientFactory));
+      new ShuffleClientImpl(appUniqueId, conf, clientEndpoint));
 }
 
 ShuffleClientImpl::ShuffleClientImpl(
     const std::string& appUniqueId,
     const std::shared_ptr<const conf::CelebornConf>& conf,
-    const std::shared_ptr<network::TransportClientFactory>& clientFactory)
-    : appUniqueId_(appUniqueId), conf_(conf), clientFactory_(clientFactory) {}
+    const ShuffleClientEndpoint& clientEndpoint)
+    : appUniqueId_(appUniqueId),
+      conf_(conf),
+      clientFactory_(clientEndpoint.clientFactory()),
+      pushDataRetryPool_(clientEndpoint.pushDataRetryPool()) {
+  CELEBORN_CHECK_NOT_NULL(clientFactory_);
+  CELEBORN_CHECK_NOT_NULL(pushDataRetryPool_);
+}
 
 void ShuffleClientImpl::setupLifecycleManagerRef(std::string& host, int port) {
   auto managerClient = clientFactory_->createClient(host, port);
@@ -47,6 +74,8 @@ void ShuffleClientImpl::setupLifecycleManagerRef(std::string& host, int port) {
         port,
         managerClient,
         *conf_);
+
+    initReviveManagerLocked();
   }
 }
 
@@ -54,6 +83,172 @@ void ShuffleClientImpl::setupLifecycleManagerRef(
     std::shared_ptr<network::NettyRpcEndpointRef>& lifecycleManagerRef) {
   std::lock_guard<std::mutex> lock(mutex_);
   lifecycleManagerRef_ = lifecycleManagerRef;
+
+  initReviveManagerLocked();
+}
+
+std::shared_ptr<utils::ConcurrentHashMap<
+    int,
+    std::shared_ptr<const protocol::PartitionLocation>>>
+ShuffleClientImpl::getPartitionLocation(
+    int shuffleId,
+    int numMappers,
+    int numPartitions) {
+  auto partitionLocationOptional = partitionLocationMaps_.get(shuffleId);
+  if (partitionLocationOptional.has_value()) {
+    return partitionLocationOptional.value();
+  }
+
+  registerShuffle(shuffleId, numMappers, numPartitions);
+
+  partitionLocationOptional = partitionLocationMaps_.get(shuffleId);
+  CELEBORN_CHECK(
+      partitionLocationOptional.has_value(),
+      "partitionLocation is empty because registerShuffle failed");
+  auto partitionLocationMap = partitionLocationOptional.value();
+  CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+  return partitionLocationMap;
+}
+
+int ShuffleClientImpl::pushData(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    const uint8_t* data,
+    size_t offset,
+    size_t length,
+    int numMappers,
+    int numPartitions) {
+  const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+  if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+    return 0;
+  }
+
+  auto partitionLocationMap =
+      getPartitionLocation(shuffleId, numMappers, numPartitions);
+  CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+  auto partitionLocationOptional = partitionLocationMap->get(partitionId);
+  if (!partitionLocationOptional.has_value()) {
+    if (!revive(
+            shuffleId,
+            mapId,
+            attemptId,
+            partitionId,
+            -1,
+            nullptr,
+            protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
+      CELEBORN_FAIL(fmt::format(
+          "Revive for shuffleId {} partitionId {} failed.",
+          shuffleId,
+          partitionId));
+    }
+    partitionLocationOptional = partitionLocationMap->get(partitionId);
+  }
+  if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+    return 0;
+  }
+
+  CELEBORN_CHECK(partitionLocationOptional.has_value());
+  auto partitionLocation = partitionLocationOptional.value();
+  auto pushState = getPushState(mapKey);
+  const int nextBatchId = pushState->nextBatchId();
+
+  // TODO: compression in writing is not supported.
+
+  auto writeBuffer =
+      memory::ByteBuffer::createWriteOnly(kBatchHeaderSize + length);
+  // TODO: the Java side uses Platform to write the data. We simply assume
+  //  little-endian here.
+  writeBuffer->writeLE<int>(mapId);
+  writeBuffer->writeLE<int>(attemptId);
+  writeBuffer->writeLE<int>(nextBatchId);
+  writeBuffer->writeLE<int>(length);
+  writeBuffer->writeFromBuffer(data, offset, length);
+
+  auto hostAndPushPort = partitionLocation->hostAndPushPort();
+  // Check limit.
+  limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
+  // Add inFlight requests.
+  pushState->addBatch(nextBatchId, hostAndPushPort);
+  // Build pushData request.
+  const auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
+  auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+  network::PushData pushData(
+      network::Message::nextRequestId(),
+      protocol::PartitionLocation::Mode::PRIMARY,
+      shuffleKey,
+      partitionLocation->uniqueId(),
+      body->clone());
+  // Build callback.
+  auto pushDataCallback = PushDataCallback::create(
+      shuffleId,
+      mapId,
+      attemptId,
+      partitionId,
+      numMappers,
+      numPartitions,
+      mapKey,
+      nextBatchId,
+      body->clone(),
+      pushState,
+      weak_from_this(),
+      conf_->clientPushMaxReviveTimes(),
+      partitionLocation);
+  // Do push data.
+  auto client = clientFactory_->createClient(
+      partitionLocation->host, partitionLocation->pushPort, partitionId);
+  client->pushDataAsync(
+      pushData, conf_->clientPushDataTimeout(), pushDataCallback);
+  return body->remainingSize();
+}
+
+void ShuffleClientImpl::mapperEnd(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int numMappers) {
+  mapPartitionMapperEnd(shuffleId, mapId, attemptId, numMappers, -1);
+}
+
+void ShuffleClientImpl::mapPartitionMapperEnd(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int numMappers,
+    int partitionId) {
+  auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+  auto pushState = getPushState(mapKey);
+
+  try {
+    limitZeroInFlight(mapKey, *pushState);
+
+    auto mapperEndResponse =
+        lifecycleManagerRef_
+            ->askSync<protocol::MapperEnd, protocol::MapperEndResponse>(
+                protocol::MapperEnd{
+                    shuffleId, mapId, attemptId, numMappers, partitionId});
+    if (mapperEndResponse->status != protocol::StatusCode::SUCCESS) {
+      CELEBORN_FAIL(
+          "MapperEnd failed. protocol::StatusCode " +
+          std::to_string(mapperEndResponse->status));
+    }
+  } catch (std::exception& e) {
+    LOG(ERROR) << "mapperEnd failed, error msg: " << e.what();
+    pushStates_.erase(mapKey);
+    CELEBORN_FAIL(e.what());
+  }
+  pushStates_.erase(mapKey);
+}
+
+void ShuffleClientImpl::cleanup(int shuffleId, int mapId, int attemptId) {
+  auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+  auto pushStateOptional = pushStates_.erase(mapKey);
+  if (pushStateOptional.has_value()) {
+    auto pushState = pushStateOptional.value();
+    pushState->setException(
+        std::make_unique<std::runtime_error>(mapKey + " is cleaned up"));
+  }
 }
 
 std::unique_ptr<CelebornInputStream> ShuffleClientImpl::readPartition(
@@ -143,6 +338,343 @@ bool ShuffleClientImpl::cleanupShuffle(int shuffleId) {
   return true;
 }
 
+std::shared_ptr<PushState> ShuffleClientImpl::getPushState(
+    const std::string& mapKey) {
+  return pushStates_.computeIfAbsent(
+      mapKey, [&]() { return std::make_shared<PushState>(*conf_); });
+}
+
+void ShuffleClientImpl::initReviveManagerLocked() {
+  if (!reviveManager_) {
+    std::string uniqueName = appUniqueId_;
+    uniqueName += std::to_string(utils::currentTimeNanos());
+    reviveManager_ =
+        ReviveManager::create(uniqueName, *conf_, weak_from_this());
+  }
+}
+
+void ShuffleClientImpl::registerShuffle(
+    int shuffleId,
+    int numMappers,
+    int numPartitions) {
+  auto shuffleMutex = shuffleMutexes_.computeIfAbsent(
+      shuffleId, []() { return std::make_shared<std::mutex>(); });
+  // RegisterShuffle might be issued concurrently; we allow only one issue
+  // per shuffleId.
+  std::lock_guard<std::mutex> lock(*shuffleMutex);
+  if (partitionLocationMaps_.containsKey(shuffleId)) {
+    return;
+  }
+  CELEBORN_CHECK(
+      lifecycleManagerRef_, "lifecycleManagerRef_ is not initialized");
+  const int maxRetries = conf_->clientRegisterShuffleMaxRetries();
+  int numRetries = 1;
+  for (; numRetries <= maxRetries; numRetries++) {
+    try {
+      // Send the query request to lifecycleManager.
+      auto registerShuffleResponse = lifecycleManagerRef_->askSync<
+          protocol::RegisterShuffle,
+          protocol::RegisterShuffleResponse>(
+          protocol::RegisterShuffle{shuffleId, numMappers, numPartitions},
+          conf_->clientRpcRegisterShuffleRpcAskTimeout());
+
+      switch (registerShuffleResponse->status) {
+        case protocol::StatusCode::SUCCESS: {
+          VLOG(1) << "success to registerShuffle, shuffleId " << shuffleId
+                  << " numMappers " << numMappers << " numPartitions "
+                  << numPartitions;
+          auto partitionLocationMap = std::make_shared<utils::ConcurrentHashMap<
+              int,
+              std::shared_ptr<const protocol::PartitionLocation>>>();
+          auto& partitionLocations =
+              registerShuffleResponse->partitionLocations;
+          for (auto i = 0; i < partitionLocations.size(); i++) {
+            auto id = partitionLocations[i]->id;
+            partitionLocationMap->set(id, std::move(partitionLocations[i]));
+          }
+          partitionLocationMaps_.set(
+              shuffleId, std::move(partitionLocationMap));
+          return;
+        }
+        default: {
+          LOG(ERROR)
+              << "LifecycleManager request slots return protocol::StatusCode "
+              << registerShuffleResponse->status << " , shuffleId " << 
shuffleId
+              << " numMappers " << numMappers << " numPartitions "
+              << numPartitions << " , retry again, remain retry times "
+              << maxRetries - numRetries;
+        }
+      }
+    } catch (std::exception& e) {
+      CELEBORN_FAIL(fmt::format(
+          "registerShuffle encounters error after {} tries, "
+          "shuffleId {} numMappers {} numPartitions {}, errorMsg: {}",
+          numRetries,
+          shuffleId,
+          numMappers,
+          numPartitions,
+          e.what()));
+      break;
+    }
+    std::this_thread::sleep_for(conf_->clientRegisterShuffleRetryWait());
+  }
+  partitionLocationMaps_.set(shuffleId, nullptr);
+  CELEBORN_FAIL(fmt::format(
+      "registerShuffle failed after {} tries, "
+      "shuffleId {} numMappers {} numPartitions {}",
+      maxRetries,
+      shuffleId,
+      numMappers,
+      numPartitions));
+}
+
+void ShuffleClientImpl::submitRetryPushData(
+    int shuffleId,
+    std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+    int batchId,
+    std::shared_ptr<PushDataCallback> pushDataCallback,
+    std::shared_ptr<PushState> pushState,
+    PtrReviveRequest request,
+    int remainReviveTimes,
+    long dueTimeMs) {
+  long reviveWaitTimeMs = dueTimeMs - utils::currentTimeMillis();
+  long accumulatedTimeMs = 0;
+  const long deltaMs = 50;
+  while (request->reviveStatus.load() ==
+             protocol::StatusCode::REVIVE_INITIALIZED &&
+         accumulatedTimeMs <= reviveWaitTimeMs) {
+    std::this_thread::sleep_for(utils::MS(deltaMs));
+    accumulatedTimeMs += deltaMs;
+  }
+  if (mapperEnded(shuffleId, request->mapId)) {
+    if (request->loc) {
+      VLOG(1) << "Revive for push data success, but the mapper already ended "
+                 "for shuffle "
+              << shuffleId << " map " << request->mapId << " attempt "
+              << request->attemptId << " partition " << request->partitionId
+              << " batch " << batchId << " location hostAndPushPort "
+              << request->loc->hostAndPushPort() << ".";
+      pushState->removeBatch(batchId, request->loc->hostAndPushPort());
+    } else {
+      VLOG(1) << "Revive for push data success, but the mapper already ended "
+                 "for shuffle "
+              << shuffleId << " map " << request->mapId << " attempt "
+              << request->attemptId << " partition " << request->partitionId
+              << " batch " << batchId << " no location available.";
+    }
+    return;
+  }
+  if (request->reviveStatus.load() != protocol::StatusCode::SUCCESS) {
+    // TODO: the exception message here should be assembled.
+    pushDataCallback->onFailure(std::make_unique<std::exception>());
+    return;
+  }
+  auto locationMapOptional = partitionLocationMaps_.get(shuffleId);
+  CELEBORN_CHECK(locationMapOptional.has_value());
+  auto newLocationOptional =
+      locationMapOptional.value()->get(request->partitionId);
+  CELEBORN_CHECK(newLocationOptional.has_value());
+  auto newLocation = newLocationOptional.value();
+  LOG(INFO) << "Revive for push data success, new location for shuffle "
+            << shuffleId << " map " << request->mapId << " attempt "
+            << request->attemptId << " partition " << request->partitionId
+            << " batch " << batchId << " is location hostAndPushPort "
+            << newLocation->hostAndPushPort() << ".";
+  pushDataCallback->updateLatestLocation(newLocation);
+
+  try {
+    CELEBORN_CHECK_GT(remainReviveTimes, 0, "no remainReviveTime left");
+    network::PushData pushData(
+        network::Message::nextRequestId(),
+        protocol::PartitionLocation::Mode::PRIMARY,
+        utils::makeShuffleKey(appUniqueId_, shuffleId),
+        newLocation->uniqueId(),
+        std::move(body));
+    auto client = clientFactory_->createClient(
+        newLocation->host, newLocation->pushPort, request->partitionId);
+    client->pushDataAsync(
+        pushData, conf_->clientPushDataTimeout(), pushDataCallback);
+  } catch (const std::exception& e) {
+    LOG(ERROR) << "Exception raised while pushing data for shuffle "
+               << shuffleId << " map " << request->mapId << " attempt "
+               << request->attemptId << " partition " << request->partitionId
+               << " batch " << batchId << " location hostAndPushPort "
+               << newLocation->hostAndPushPort() << " errorMsg " << e.what()
+               << ".";
+    // TODO: The failure should be treated better.
+    pushDataCallback->onFailure(std::make_unique<std::exception>(e));
+  }
+}
+
+bool ShuffleClientImpl::checkMapperEnded(
+    int shuffleId,
+    int mapId,
+    const std::string& mapKey) {
+  if (mapperEnded(shuffleId, mapId)) {
+    VLOG(1) << "Mapper already ended for shuffle " << shuffleId << " map "
+            << mapId;
+    if (auto pushStateOptional = pushStates_.get(mapKey);
+        pushStateOptional.has_value()) {
+      auto pushState = pushStateOptional.value();
+      pushState->cleanup();
+    }
+    return true;
+  }
+  return false;
+}
+
+bool ShuffleClientImpl::mapperEnded(int shuffleId, int mapId) {
+  if (auto mapperEndSetOptional = mapperEndSets_.get(shuffleId);
+      mapperEndSetOptional.has_value() &&
+      mapperEndSetOptional.value()->contains(mapId)) {
+    return true;
+  }
+  if (stageEnded(shuffleId)) {
+    return true;
+  }
+  return false;
+}
+
+bool ShuffleClientImpl::stageEnded(int shuffleId) {
+  return stageEndShuffleSet_.contains(shuffleId);
+}
+
+void ShuffleClientImpl::addRequestToReviveManager(
+    std::shared_ptr<protocol::ReviveRequest> reviveRequest) {
+  reviveManager_->addRequest(std::move(reviveRequest));
+}
+
+std::optional<std::unordered_map<int, int>> ShuffleClientImpl::reviveBatch(
+    int shuffleId,
+    const std::unordered_set<int>& mapIds,
+    const std::unordered_map<int, PtrReviveRequest>& requests) {
+  std::unordered_map<int, int> result;
+  auto partitionLocationMap = partitionLocationMaps_.get(shuffleId).value();
+  std::unordered_map<int, std::shared_ptr<const protocol::PartitionLocation>>
+      oldLocationMap;
+  protocol::Revive revive;
+  revive.shuffleId = shuffleId;
+  revive.mapIds.insert(mapIds.begin(), mapIds.end());
+  for (auto& [partitionId, request] : requests) {
+    oldLocationMap[request->partitionId] = request->loc;
+    revive.reviveRequests.insert(request);
+  }
+  try {
+    auto response =
+        lifecycleManagerRef_
+            ->askSync<protocol::Revive, protocol::ChangeLocationResponse>(
+                revive,
+                conf_->clientRpcRequestPartitionLocationRpcAskTimeout());
+    auto mapperEndSet = mapperEndSets_.computeIfAbsent(shuffleId, []() {
+      return std::make_shared<utils::ConcurrentHashSet<int>>();
+    });
+    for (auto endedMapId : response->endedMapIds) {
+      mapperEndSet->insert(endedMapId);
+    }
+    for (auto& partitionInfo : response->partitionInfos) {
+      switch (partitionInfo.status) {
+        case protocol::StatusCode::SUCCESS: {
+          partitionLocationMap->set(
+              partitionInfo.partitionId, partitionInfo.partition);
+          break;
+        }
+        case protocol::StatusCode::STAGE_ENDED: {
+          stageEndShuffleSet_.insert(shuffleId);
+          return {std::move(result)};
+        }
+        case protocol::StatusCode::SHUFFLE_NOT_REGISTERED: {
+          LOG(ERROR) << "shuffleId " << shuffleId << " not registered!";
+          return std::nullopt;
+        }
+        default: {
+          // noop
+        }
+      }
+      result[partitionInfo.partitionId] = partitionInfo.status;
+    }
+    return {std::move(result)};
+  } catch (std::exception& e) {
+    LOG(ERROR) << "reviveBatch failed: " << e.what();
+    return std::nullopt;
+  }
+}
+
+bool ShuffleClientImpl::revive(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    int epoch,
+    std::shared_ptr<const protocol::PartitionLocation> oldLocation,
+    protocol::StatusCode cause) {
+  auto request = std::make_shared<protocol::ReviveRequest>(
+      shuffleId, mapId, attemptId, partitionId, epoch, oldLocation, cause);
+  auto resultOptional =
+      reviveBatch(shuffleId, {mapId}, {{partitionId, request}});
+  if (mapperEnded(shuffleId, mapId)) {
+    VLOG(1) << "Revive success, but the mapper ended for shuffle " << shuffleId
+            << " map " << mapId << " attempt " << attemptId << " partition"
+            << partitionId << ", just return true(Assume revive 
successfully).";
+    return true;
+  }
+  if (resultOptional.has_value()) {
+    auto result = resultOptional.value();
+    return result.find(partitionId) != result.end() &&
+        result[partitionId] == protocol::StatusCode::SUCCESS;
+  }
+  return false;
+}
+
+void ShuffleClientImpl::limitMaxInFlight(
+    const std::string& mapKey,
+    PushState& pushState,
+    const std::string& hostAndPushPort) {
+  bool reachLimit = pushState.limitMaxInFlight(hostAndPushPort);
+  if (reachLimit) {
+    auto msg = fmt::format(
+        "Waiting timeout for task {} while limiting max "
+        "in-flight requests to {}.",
+        mapKey,
+        hostAndPushPort);
+    if (auto exceptionMsgOptional = pushState.getExceptionMsg();
+        exceptionMsgOptional.has_value()) {
+      msg += " PushState exception: " + exceptionMsgOptional.value();
+    }
+    CELEBORN_FAIL(msg);
+  }
+}
+
+void ShuffleClientImpl::limitZeroInFlight(
+    const std::string& mapKey,
+    PushState& pushState) {
+  bool reachLimit = pushState.limitZeroInFlight();
+  if (reachLimit) {
+    auto msg = fmt::format(
+        "Waiting timeout for task {} while limiting zero "
+        "in-flight requests.",
+        mapKey);
+    if (auto exceptionMsgOptional = pushState.getExceptionMsg();
+        exceptionMsgOptional.has_value()) {
+      msg += " PushState exception: " + exceptionMsgOptional.value();
+    }
+    CELEBORN_FAIL(msg);
+  }
+}
+
+std::optional<ShuffleClientImpl::PtrPartitionLocationMap>
+ShuffleClientImpl::getPartitionLocationMap(int shuffleId) {
+  return partitionLocationMaps_.get(shuffleId);
+}
+
+utils::ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>&
+ShuffleClientImpl::mapperEndSets() {
+  return mapperEndSets_;
+}
+
+void ShuffleClientImpl::addPushDataRetryTask(folly::Func&& task) {
+  pushDataRetryPool_->add(std::move(task));
+}
+
 bool ShuffleClientImpl::newerPartitionLocationExists(
     std::shared_ptr<utils::ConcurrentHashMap<
         int,
diff --git a/cpp/celeborn/client/ShuffleClient.h b/cpp/celeborn/client/ShuffleClient.h
index dc71a39bc..3e8cb9d37 100644
--- a/cpp/celeborn/client/ShuffleClient.h
+++ b/cpp/celeborn/client/ShuffleClient.h
@@ -32,6 +32,25 @@ class ShuffleClient {
   virtual void setupLifecycleManagerRef(
       std::shared_ptr<network::NettyRpcEndpointRef>& lifecycleManagerRef) = 0;
 
+  virtual int pushData(
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      const uint8_t* data,
+      size_t offset,
+      size_t length,
+      int numMappers,
+      int numPartitions) = 0;
+
+  // TODO: PushMergedData is not supported yet.
+
+  virtual void
+  mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers) = 0;
+
+  // Cleanup states of a map task.
+  virtual void cleanup(int shuffleId, int mapId, int attemptId) = 0;
+
   virtual void updateReducerFileGroup(int shuffleId) = 0;
 
   virtual std::unique_ptr<CelebornInputStream> readPartition(
@@ -57,6 +76,23 @@ class ShuffleClient {
 class ReviveManager;
 class PushDataCallback;
 
+/// ShuffleClientEndpoint holds all the resources of a ShuffleClient, including
+/// thread pools and client factories. The endpoint can be reused by multiple
+/// ShuffleClient instances to avoid creating too many resources.
+class ShuffleClientEndpoint {
+ public:
+  ShuffleClientEndpoint(const std::shared_ptr<const conf::CelebornConf>& conf);
+
+  std::shared_ptr<folly::IOThreadPoolExecutor> pushDataRetryPool() const;
+
+  std::shared_ptr<network::TransportClientFactory> clientFactory() const;
+
+ private:
+  const std::shared_ptr<const conf::CelebornConf> conf_;
+  std::shared_ptr<folly::IOThreadPoolExecutor> pushDataRetryPool_;
+  std::shared_ptr<network::TransportClientFactory> clientFactory_;
+};
+
 class ShuffleClientImpl
     : public ShuffleClient,
       public std::enable_shared_from_this<ShuffleClientImpl> {
@@ -75,13 +111,41 @@ class ShuffleClientImpl
   static std::shared_ptr<ShuffleClientImpl> create(
       const std::string& appUniqueId,
       const std::shared_ptr<const conf::CelebornConf>& conf,
-      const std::shared_ptr<network::TransportClientFactory>& clientFactory);
+      const ShuffleClientEndpoint& clientEndpoint);
 
   void setupLifecycleManagerRef(std::string& host, int port) override;
 
   void setupLifecycleManagerRef(std::shared_ptr<network::NettyRpcEndpointRef>&
                                     lifecycleManagerRef) override;
 
+  std::shared_ptr<utils::ConcurrentHashMap<
+      int,
+      std::shared_ptr<const protocol::PartitionLocation>>>
+  getPartitionLocation(int shuffleId, int numMappers, int numPartitions);
+
+  int pushData(
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      const uint8_t* data,
+      size_t offset,
+      size_t length,
+      int numMappers,
+      int numPartitions) override;
+
+  void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers)
+      override;
+
+  void mapPartitionMapperEnd(
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int numMappers,
+      int partitionId);
+
+  void cleanup(int shuffleId, int mapId, int attemptId) override;
+
   std::unique_ptr<CelebornInputStream> readPartition(
       int shuffleId,
       int partitionId,
@@ -109,10 +173,8 @@ class ShuffleClientImpl
   ShuffleClientImpl(
       const std::string& appUniqueId,
       const std::shared_ptr<const conf::CelebornConf>& conf,
-      const std::shared_ptr<network::TransportClientFactory>& clientFactory);
+      const ShuffleClientEndpoint& clientEndpoint);
 
-  // TODO: currently this function serves as a stub. will be updated in future
-  //  commits.
   virtual void submitRetryPushData(
       int shuffleId,
       std::unique_ptr<memory::ReadOnlyByteBuffer> body,
@@ -121,44 +183,58 @@ class ShuffleClientImpl
       std::shared_ptr<PushState> pushState,
       PtrReviveRequest request,
       int remainReviveTimes,
-      long dueTimeMs) {}
+      long dueTimeMs);
 
-  // TODO: currently this function serves as a stub. will be updated in future
-  //  commits.
-  virtual bool mapperEnded(int shuffleId, int mapId) {
-    return true;
-  }
+  virtual bool mapperEnded(int shuffleId, int mapId);
 
-  // TODO: currently this function serves as a stub. will be updated in future
-  //  commits.
   virtual void addRequestToReviveManager(
-      std::shared_ptr<protocol::ReviveRequest> reviveRequest) {}
+      std::shared_ptr<protocol::ReviveRequest> reviveRequest);
 
-  // TODO: currently this function serves as a stub. will be updated in future
-  //  commits.
   virtual std::optional<std::unordered_map<int, int>> reviveBatch(
       int shuffleId,
       const std::unordered_set<int>& mapIds,
-      const std::unordered_map<int, PtrReviveRequest>& requests) {
-    return std::nullopt;
-  }
+      const std::unordered_map<int, PtrReviveRequest>& requests);
 
   virtual std::optional<PtrPartitionLocationMap> getPartitionLocationMap(
-      int shuffleId) {
-    return partitionLocationMaps_.get(shuffleId);
-  }
+      int shuffleId);
 
   virtual utils::
       ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>&
-      mapperEndSets() {
-    return mapperEndSets_;
-  }
+      mapperEndSets();
 
-  virtual void addPushDataRetryTask(folly::Func&& task) {
-    pushDataRetryPool_->add(std::move(task));
-  }
+  virtual void addPushDataRetryTask(folly::Func&& task);
 
  private:
+  std::shared_ptr<PushState> getPushState(const std::string& mapKey);
+
+  void initReviveManagerLocked();
+
+  void registerShuffle(int shuffleId, int numMappers, int numPartitions);
+
+  bool checkMapperEnded(int shuffleId, int mapId, const std::string& mapKey);
+
+  bool stageEnded(int shuffleId);
+
+  bool revive(
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      int epoch,
+      std::shared_ptr<const protocol::PartitionLocation> oldLocation,
+      protocol::StatusCode cause);
+
+  // Check whether the pushState's in-flight request count reaches the max
+  // limit; if so, block until it drops below the limit.
+  void limitMaxInFlight(
+      const std::string& mapKey,
+      PushState& pushState,
+      const std::string& hostAndPushPort);
+
+  // Check whether the pushState's in-flight request count has reached zero;
+  // if not, block until it drops to zero.
+  void limitZeroInFlight(const std::string& mapKey, PushState& pushState);
+
   // TODO: no support for WAIT as it is not used.
   static bool newerPartitionLocationExists(
       std::shared_ptr<utils::ConcurrentHashMap<
@@ -170,6 +246,8 @@ class ShuffleClientImpl
   std::shared_ptr<protocol::GetReducerFileGroupResponse>
   getReducerFileGroupInfo(int shuffleId);
 
+  static constexpr size_t kBatchHeaderSize = 4 * 4;
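+  // Batch header layout: four little-endian int32 fields (mapId, attemptId,
+  // batchId, length) written before the payload in pushData().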
+
   const std::string appUniqueId_;
   std::shared_ptr<const conf::CelebornConf> conf_;
   std::shared_ptr<network::NettyRpcEndpointRef> lifecycleManagerRef_;
@@ -177,13 +255,18 @@ class ShuffleClientImpl
   std::shared_ptr<folly::IOExecutor> pushDataRetryPool_;
   std::shared_ptr<ReviveManager> reviveManager_;
   std::mutex mutex_;
+  utils::ConcurrentHashMap<int, std::shared_ptr<std::mutex>> shuffleMutexes_;
   utils::ConcurrentHashMap<
       int,
       std::shared_ptr<protocol::GetReducerFileGroupResponse>>
       reducerFileGroupInfos_;
 utils::ConcurrentHashMap<int, PtrPartitionLocationMap> partitionLocationMaps_;
+  utils::ConcurrentHashMap<std::string, std::shared_ptr<PushState>> pushStates_;
   utils::ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>
       mapperEndSets_;
+  utils::ConcurrentHashSet<int> stageEndShuffleSet_;
+
+  // TODO: pushExcludedWorker is not supported yet
 };
 } // namespace client
 } // namespace celeborn
diff --git a/cpp/celeborn/client/tests/PushDataCallbackTest.cpp b/cpp/celeborn/client/tests/PushDataCallbackTest.cpp
index 171c8e714..f91637488 100644
--- a/cpp/celeborn/client/tests/PushDataCallbackTest.cpp
+++ b/cpp/celeborn/client/tests/PushDataCallbackTest.cpp
@@ -106,7 +106,9 @@ class MockShuffleClient : public ShuffleClientImpl {
       : ShuffleClientImpl(
             "mock",
             std::make_shared<conf::CelebornConf>(),
-            nullptr) {}
+            dummyEndpoint()) {}
+
+  static const ShuffleClientEndpoint& dummyEndpoint();
 
   FuncOnSubmitRetryPushData onSubmitRetryPushData_ =
       [](int,
@@ -126,6 +128,12 @@ class MockShuffleClient : public ShuffleClientImpl {
   };
 };
 
+const ShuffleClientEndpoint& MockShuffleClient::dummyEndpoint() {
+  static auto conf = std::make_shared<conf::CelebornConf>();
+  static auto dummy = ShuffleClientEndpoint(conf);
+  return dummy;
+}
+
 std::unique_ptr<memory::ReadOnlyByteBuffer> createReadOnlyByteBuffer(
     uint8_t code) {
   auto writeBuffer = memory::ByteBuffer::createWriteOnly(1);
diff --git a/cpp/celeborn/client/tests/ReviveManagerTest.cpp b/cpp/celeborn/client/tests/ReviveManagerTest.cpp
index 8db844485..7bae1ab2a 100644
--- a/cpp/celeborn/client/tests/ReviveManagerTest.cpp
+++ b/cpp/celeborn/client/tests/ReviveManagerTest.cpp
@@ -72,7 +72,10 @@ class MockShuffleClient : public ShuffleClientImpl {
       : ShuffleClientImpl(
             "mock",
             std::make_shared<conf::CelebornConf>(),
-            nullptr) {}
+            dummyEndpoint()) {}
+
+  static const ShuffleClientEndpoint& dummyEndpoint();
+
   std::function<bool(int, int)> onMapperEnded_ = [](int, int) { return false; };
   std::function<std::optional<std::unordered_map<int, int>>(
       int,
@@ -89,6 +92,12 @@ class MockShuffleClient : public ShuffleClientImpl {
     return {std::make_shared<PartitionLocationMap>()};
   };
 };
+
+const ShuffleClientEndpoint& MockShuffleClient::dummyEndpoint() {
+  static auto conf = std::make_shared<conf::CelebornConf>();
+  static auto dummy = ShuffleClientEndpoint(conf);
+  return dummy;
+}
 } // namespace
 
 class ReviveManagerTest : public testing::Test {
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index 1d58516b3..50b48aa7f 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -131,39 +131,50 @@ Duration toDuration(const std::string& str) {
 
 } // namespace
 
-const std::unordered_map<std::string, folly::Optional<std::string>>
-    CelebornConf::kDefaultProperties = {
-        STR_PROP(kRpcAskTimeout, "60s"),
-        STR_PROP(kRpcLookupTimeout, "30s"),
-        STR_PROP(kClientPushReviveInterval, "100ms"),
-        NUM_PROP(kClientPushReviveBatchSize, 2048),
-        STR_PROP(kClientPushLimitStrategy, kSimplePushStrategy),
-        NUM_PROP(kClientPushMaxReqsInFlightPerWorker, 32),
-        NUM_PROP(kClientPushMaxReqsInFlightTotal, 256),
-        NUM_PROP(kClientPushLimitInFlightTimeoutMs, 240000),
-        NUM_PROP(kClientPushLimitInFlightSleepDeltaMs, 50),
-        STR_PROP(kClientRpcRequestPartitionLocationAskTimeout, "60s"),
-        STR_PROP(kClientRpcGetReducerFileGroupRpcAskTimeout, "60s"),
-        STR_PROP(kNetworkConnectTimeout, "10s"),
-        STR_PROP(kClientFetchTimeout, "600s"),
-        NUM_PROP(kNetworkIoNumConnectionsPerPeer, 1),
-        NUM_PROP(kNetworkIoClientThreads, 0),
-        NUM_PROP(kClientFetchMaxReqsInFlight, 3),
-        STR_PROP(
-            kShuffleCompressionCodec,
-            protocol::toString(protocol::CompressionCodec::NONE)),
-        NUM_PROP(kShuffleCompressionZstdCompressLevel, 1),
-        // NUM_PROP(kNumExample, 50'000),
-        // BOOL_PROP(kBoolExample, false),
-};
+const std::unordered_map<std::string, folly::Optional<std::string>>&
+CelebornConf::defaultProperties() {
+  static const std::unordered_map<std::string, folly::Optional<std::string>>
+      defaultProp = {
+          STR_PROP(kRpcAskTimeout, "60s"),
+          STR_PROP(kRpcLookupTimeout, "30s"),
+          STR_PROP(kClientIoConnectionTimeout, "300s"),
+          STR_PROP(kClientRpcRegisterShuffleAskTimeout, "60s"),
+          NUM_PROP(kClientRegisterShuffleMaxRetries, 3),
+          STR_PROP(kClientRegisterShuffleRetryWait, "3s"),
+          NUM_PROP(kClientPushRetryThreads, 8),
+          STR_PROP(kClientPushTimeout, "120s"),
+          STR_PROP(kClientPushReviveInterval, "100ms"),
+          NUM_PROP(kClientPushReviveBatchSize, 2048),
+          NUM_PROP(kClientPushMaxReviveTimes, 5),
+          STR_PROP(kClientPushLimitStrategy, kSimplePushStrategy),
+          NUM_PROP(kClientPushMaxReqsInFlightPerWorker, 32),
+          NUM_PROP(kClientPushMaxReqsInFlightTotal, 256),
+          NUM_PROP(kClientPushLimitInFlightTimeoutMs, 240000),
+          NUM_PROP(kClientPushLimitInFlightSleepDeltaMs, 50),
+          STR_PROP(kClientRpcRequestPartitionLocationAskTimeout, "60s"),
+          STR_PROP(kClientRpcGetReducerFileGroupRpcAskTimeout, "60s"),
+          STR_PROP(kNetworkConnectTimeout, "10s"),
+          STR_PROP(kClientFetchTimeout, "600s"),
+          NUM_PROP(kNetworkIoNumConnectionsPerPeer, 1),
+          NUM_PROP(kNetworkIoClientThreads, 0),
+          NUM_PROP(kClientFetchMaxReqsInFlight, 3),
+          STR_PROP(
+              kShuffleCompressionCodec,
+              protocol::toString(protocol::CompressionCodec::NONE)),
+          NUM_PROP(kShuffleCompressionZstdCompressLevel, 1),
+          // NUM_PROP(kNumExample, 50'000),
+          // BOOL_PROP(kBoolExample, false),
+      };
+  return defaultProp;
+}
 
 CelebornConf::CelebornConf() {
-  registeredProps_ = kDefaultProperties;
+  registeredProps_ = defaultProperties();
 }
 
 CelebornConf::CelebornConf(const std::string& filename) {
   initialize(filename);
-  registeredProps_ = kDefaultProperties;
+  registeredProps_ = defaultProperties();
 }
 
 CelebornConf::CelebornConf(const CelebornConf& other) {
@@ -193,6 +204,34 @@ Timeout CelebornConf::rpcLookupTimeout() const {
       toDuration(optionalProperty(kRpcLookupTimeout).value()));
 }
 
+Timeout CelebornConf::clientIoConnectionTimeout() const {
+  return utils::toTimeout(
+      toDuration(optionalProperty(kClientIoConnectionTimeout).value()));
+}
+
+Timeout CelebornConf::clientRpcRegisterShuffleRpcAskTimeout() const {
+  return utils::toTimeout(toDuration(
+      optionalProperty(kClientRpcRegisterShuffleAskTimeout).value()));
+}
+
+int CelebornConf::clientRegisterShuffleMaxRetries() const {
+  return std::stoi(optionalProperty(kClientRegisterShuffleMaxRetries).value());
+}
+
+Timeout CelebornConf::clientRegisterShuffleRetryWait() const {
+  return utils::toTimeout(
+      toDuration(optionalProperty(kClientRegisterShuffleRetryWait).value()));
+}
+
+int CelebornConf::clientPushRetryThreads() const {
+  return std::stoi(optionalProperty(kClientPushRetryThreads).value());
+}
+
+Timeout CelebornConf::clientPushDataTimeout() const {
+  return utils::toTimeout(
+      toDuration(optionalProperty(kClientPushTimeout).value()));
+}
+
 Timeout CelebornConf::clientPushReviveInterval() const {
   return utils::toTimeout(
       toDuration(optionalProperty(kClientPushReviveInterval).value()));
@@ -202,6 +241,10 @@ int CelebornConf::clientPushReviveBatchSize() const {
   return std::stoi(optionalProperty(kClientPushReviveBatchSize).value());
 }
 
+int CelebornConf::clientPushMaxReviveTimes() const {
+  return std::stoi(optionalProperty(kClientPushMaxReviveTimes).value());
+}
+
 std::string CelebornConf::clientPushLimitStrategy() const {
   return optionalProperty(kClientPushLimitStrategy).value();
 }
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index 530a59781..fb4294d2c 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -37,20 +37,41 @@ namespace conf {
 
 class CelebornConf : public BaseConf {
  public:
-  static const std::unordered_map<std::string, folly::Optional<std::string>>
-      kDefaultProperties;
+  static const std::unordered_map<std::string, folly::Optional<std::string>>&
+  defaultProperties();
 
   static constexpr std::string_view kRpcAskTimeout{"celeborn.rpc.askTimeout"};
 
   static constexpr std::string_view kRpcLookupTimeout{
       "celeborn.rpc.lookupTimeout"};
 
+  static constexpr std::string_view kClientIoConnectionTimeout{
+      "celeborn.client.io.connectionTimeout"};
+
+  static constexpr std::string_view kClientRpcRegisterShuffleAskTimeout{
+      "celeborn.client.rpc.registerShuffle.askTimeout"};
+
+  static constexpr std::string_view kClientRegisterShuffleMaxRetries{
+      "celeborn.client.registerShuffle.maxRetries"};
+
+  static constexpr std::string_view kClientRegisterShuffleRetryWait{
+      "celeborn.client.registerShuffle.retryWait"};
+
+  static constexpr std::string_view kClientPushRetryThreads{
+      "celeborn.client.push.retry.threads"};
+
+  static constexpr std::string_view kClientPushTimeout{
+      "celeborn.client.push.timeout"};
+
   static constexpr std::string_view kClientPushReviveInterval{
       "celeborn.client.push.revive.interval"};
 
   static constexpr std::string_view kClientPushReviveBatchSize{
       "celeborn.client.push.revive.batchSize"};
 
+  static constexpr std::string_view kClientPushMaxReviveTimes{
+      "celeborn.client.push.revive.maxRetries"};
+
   static constexpr std::string_view kClientPushLimitStrategy{
       "celeborn.client.push.limit.strategy"};
 
@@ -110,10 +131,24 @@ class CelebornConf : public BaseConf {
 
   Timeout rpcLookupTimeout() const;
 
+  Timeout clientIoConnectionTimeout() const;
+
+  Timeout clientRpcRegisterShuffleRpcAskTimeout() const;
+
+  int clientRegisterShuffleMaxRetries() const;
+
+  Timeout clientRegisterShuffleRetryWait() const;
+
+  int clientPushRetryThreads() const;
+
+  Timeout clientPushDataTimeout() const;
+
   Timeout clientPushReviveInterval() const;
 
   int clientPushReviveBatchSize() const;
 
+  int clientPushMaxReviveTimes() const;
+
   std::string clientPushLimitStrategy() const;
 
   int clientPushMaxReqsInFlightPerWorker() const;
diff --git a/cpp/celeborn/memory/ByteBuffer.h b/cpp/celeborn/memory/ByteBuffer.h
index 0c1027fe2..434d51b7f 100644
--- a/cpp/celeborn/memory/ByteBuffer.h
+++ b/cpp/celeborn/memory/ByteBuffer.h
@@ -172,8 +172,11 @@ class WriteOnlyByteBuffer : public ByteBuffer {
     appender_->push(reinterpret_cast<const uint8_t*>(ptr), data.size());
   }
 
-  void writeFromBuffer(const void* data, const size_t len) const {
-    appender_->push(static_cast<const uint8_t*>(data), len);
+  void writeFromBuffer(
+      const uint8_t* data,
+      const size_t offset,
+      const size_t length) const {
+    appender_->push(data + offset, length);
   }
 
   size_t size() const {
diff --git a/cpp/celeborn/network/TransportClient.cpp b/cpp/celeborn/network/TransportClient.cpp
index c3e4d81fd..c4b865d1c 100644
--- a/cpp/celeborn/network/TransportClient.cpp
+++ b/cpp/celeborn/network/TransportClient.cpp
@@ -165,7 +165,7 @@ SerializePipeline::Ptr MessagePipelineFactory::newPipeline(
 }
 
 TransportClientFactory::TransportClientFactory(
-    const std::shared_ptr<conf::CelebornConf>& conf) {
+    const std::shared_ptr<const conf::CelebornConf>& conf) {
   numConnectionsPerPeer_ = conf->networkIoNumConnectionsPerPeer();
   rpcLookupTimeout_ = conf->rpcLookupTimeout();
   connectTimeout_ = conf->networkConnectTimeout();
@@ -180,6 +180,13 @@ TransportClientFactory::TransportClientFactory(
 std::shared_ptr<TransportClient> TransportClientFactory::createClient(
     const std::string& host,
     uint16_t port) {
+  return createClient(host, port, std::rand());
+}
+
+std::shared_ptr<TransportClient> TransportClientFactory::createClient(
+    const std::string& host,
+    uint16_t port,
+    int32_t partitionId) {
   auto address = folly::SocketAddress(host, port);
   auto pool = clientPools_.withLock([&](auto& registry) {
     auto iter = registry.find(address);
@@ -191,7 +198,7 @@ std::shared_ptr<TransportClient> TransportClientFactory::createClient(
     registry[address] = createdPool;
     return createdPool;
   });
-  auto clientId = std::rand() % numConnectionsPerPeer_;
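+  // Derive the pooled-connection index from partitionId so pushes for the
+  // same partition reuse the same connection.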
+  auto clientId = partitionId % numConnectionsPerPeer_;
   {
     std::lock_guard<std::mutex> lock(pool->mutex);
     // TODO: auto-disconnect if the connection is idle for a long time?
diff --git a/cpp/celeborn/network/TransportClient.h b/cpp/celeborn/network/TransportClient.h
index e3ece7d22..78c87414c 100644
--- a/cpp/celeborn/network/TransportClient.h
+++ b/cpp/celeborn/network/TransportClient.h
@@ -119,12 +119,15 @@ class MessagePipelineFactory
 
 class TransportClientFactory {
  public:
-  TransportClientFactory(const std::shared_ptr<conf::CelebornConf>& conf);
+  TransportClientFactory(const std::shared_ptr<const conf::CelebornConf>& conf);
 
   virtual std::shared_ptr<TransportClient> createClient(
       const std::string& host,
       uint16_t port);
 
+  virtual std::shared_ptr<TransportClient>
+  createClient(const std::string& host, uint16_t port, int32_t partitionId);
+
  private:
   struct ClientPool {
     std::mutex mutex;
diff --git a/cpp/celeborn/tests/DataSumWithReaderClient.cpp b/cpp/celeborn/tests/DataSumWithReaderClient.cpp
index ac62d13fe..533303323 100644
--- a/cpp/celeborn/tests/DataSumWithReaderClient.cpp
+++ b/cpp/celeborn/tests/DataSumWithReaderClient.cpp
@@ -44,10 +44,10 @@ int main(int argc, char** argv) {
   auto conf = std::make_shared<celeborn::conf::CelebornConf>();
   conf->registerProperty(
       celeborn::conf::CelebornConf::kShuffleCompressionCodec, compressCodec);
-  auto clientFactory =
-      std::make_shared<celeborn::network::TransportClientFactory>(conf);
+  auto clientEndpoint =
+      std::make_shared<celeborn::client::ShuffleClientEndpoint>(conf);
   auto shuffleClient = celeborn::client::ShuffleClientImpl::create(
-      appUniqueId, conf, clientFactory);
+      appUniqueId, conf, *clientEndpoint);
   shuffleClient->setupLifecycleManagerRef(
       lifecycleManagerHost, lifecycleManagerPort);
 
diff --git a/cpp/celeborn/utils/CelebornUtils.cpp b/cpp/celeborn/utils/CelebornUtils.cpp
index 24340cb79..675850445 100644
--- a/cpp/celeborn/utils/CelebornUtils.cpp
+++ b/cpp/celeborn/utils/CelebornUtils.cpp
@@ -23,6 +23,10 @@ std::string makeShuffleKey(const std::string& appId, const int shuffleId) {
   return appId + "-" + std::to_string(shuffleId);
 }
 
+std::string makeMapKey(int shuffleId, int mapId, int attemptId) {
+  return fmt::format("{}-{}-{}", shuffleId, mapId, attemptId);
+}
+
 void writeUTF(memory::WriteOnlyByteBuffer& buffer, const std::string& msg) {
   buffer.write<short>(msg.size());
   buffer.writeFromString(msg);
diff --git a/cpp/celeborn/utils/CelebornUtils.h b/cpp/celeborn/utils/CelebornUtils.h
index ac1419914..f79f329a5 100644
--- a/cpp/celeborn/utils/CelebornUtils.h
+++ b/cpp/celeborn/utils/CelebornUtils.h
@@ -48,6 +48,8 @@ std::vector<T> toVector(const std::set<T>& in) {
 
 std::string makeShuffleKey(const std::string& appId, int shuffleId);
 
+std::string makeMapKey(int shuffleId, int mapId, int attemptId);
+
 void writeUTF(memory::WriteOnlyByteBuffer& buffer, const std::string& msg);
 
 void writeRpcAddress(
@@ -68,6 +70,12 @@ inline uint64_t currentTimeMillis() {
       .count();
 }
 
+inline uint64_t currentTimeNanos() {
+  return std::chrono::duration_cast<std::chrono::nanoseconds>(
+             std::chrono::high_resolution_clock::now().time_since_epoch())
+      .count();
+}
+
 /// parse string like "Any-Host-Str:Port#1:Port#2:...:Port#num", split into
 /// {"Any-Host-Str", "Port#1", "Port#2", ..., "Port#num"}. Note that the
 /// "Any-Host-Str" might contain ':' in IPV6 address.

