HolyLow commented on code in PR #3553:
URL: https://github.com/apache/celeborn/pull/3553#discussion_r2579500304


##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -143,6 +311,335 @@ bool ShuffleClientImpl::cleanupShuffle(int shuffleId) {
   return true;
 }
 
+std::shared_ptr<PushState> ShuffleClientImpl::getPushState(
+    const std::string& mapKey) {
+  return pushStates_.computeIfAbsent(
+      mapKey, [&]() { return std::make_shared<PushState>(*conf_); });
+}
+
+void ShuffleClientImpl::initReviveManagerLocked() {
+  if (!reviveManager_) {
+    std::string uniqueName = appUniqueId_;
+    uniqueName += std::to_string(utils::currentTimeNanos());
+    reviveManager_ =
+        ReviveManager::create(uniqueName, *conf_, weak_from_this());
+  }
+}
+
+void ShuffleClientImpl::registerShuffle(
+    int shuffleId,
+    int numMappers,
+    int numPartitions) {
+  auto shuffleMutex = shuffleMutexes_.computeIfAbsent(
+      shuffleId, []() { return std::make_shared<std::mutex>(); });
+  // RegisterShuffle might be issued concurrently, we only allow one issue
+  // for each shuffleId.
+  std::lock_guard<std::mutex> lock(*shuffleMutex);
+  if (partitionLocationMaps_.containsKey(shuffleId)) {
+    return;
+  }
+  CELEBORN_CHECK(
+      lifecycleManagerRef_, "lifecycleManagerRef_ is not initialized");
+  const int maxRetries = conf_->clientRegisterShuffleMaxRetries();
+  int numRetries = 1;
+  for (; numRetries <= maxRetries; numRetries++) {
+    try {
+      // Send the query request to lifecycleManager.
+      auto registerShuffleResponse = lifecycleManagerRef_->askSync<
+          protocol::RegisterShuffle,
+          protocol::RegisterShuffleResponse>(
+          protocol::RegisterShuffle{shuffleId, numMappers, numPartitions},
+          conf_->clientRpcRegisterShuffleRpcAskTimeout());
+
+      switch (registerShuffleResponse->status) {
+        case protocol::StatusCode::SUCCESS: {
+          VLOG(1) << "success to registerShuffle, shuffleId " << shuffleId
+                  << " numMappers " << numMappers << " numPartitions "
+                  << numPartitions;
+          auto partitionLocationMap = std::make_shared<utils::ConcurrentHashMap<
+              int,
+              std::shared_ptr<const protocol::PartitionLocation>>>();
+          auto& partitionLocations =
+              registerShuffleResponse->partitionLocations;
+          for (int i = 0; i < partitionLocations.size(); i++) {
+            auto id = partitionLocations[i]->id;
+            partitionLocationMap->set(id, std::move(partitionLocations[i]));
+          }
+          partitionLocationMaps_.set(
+              shuffleId, std::move(partitionLocationMap));
+          return;
+        }
+        default: {
+          LOG(ERROR)
+              << "LifecycleManager request slots return protocol::StatusCode "
+              << registerShuffleResponse->status << " , shuffleId " << shuffleId
+              << " NumMappers " << numMappers << " numPartitions "
+              << numPartitions << " , retry again, remain retry times "
+              << maxRetries - numRetries;
+        }
+      }
+    } catch (std::exception& e) {
+      CELEBORN_FAIL(fmt::format(
+          "registerShuffle encounters error after {} tries, "
+          "shuffleId {} numMappers {} numPartitions {}, errorMsg: {}",
+          numRetries,
+          shuffleId,
+          numMappers,
+          numPartitions,
+          e.what()));
+      break;
+    }
+    std::this_thread::sleep_for(conf_->clientRegisterShuffleRetryWait());
+  }
+  partitionLocationMaps_.set(shuffleId, nullptr);
+  CELEBORN_FAIL(fmt::format(
+      "registerShuffle failed after {} tries, "
+      "shuffleId {} numMappers {} numPartitions {}",
+      maxRetries,
+      shuffleId,
+      numMappers,
+      numPartitions));
+}
+
+void ShuffleClientImpl::submitRetryPushData(
+    int shuffleId,
+    std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+    int batchId,
+    std::shared_ptr<PushDataCallback> pushDataCallback,
+    std::shared_ptr<PushState> pushState,
+    PtrReviveRequest request,
+    int remainReviveTimes,
+    long dueTimeMs) {
+  long reviveWaitTimeMs = dueTimeMs - utils::currentTimeMillis();
+  long accumulatedTimeMs = 0;
+  const long deltaMs = 50;
+  while (request->reviveStatus.load() ==
+             protocol::StatusCode::REVIVE_INITIALIZED &&
+         accumulatedTimeMs <= reviveWaitTimeMs) {
+    std::this_thread::sleep_for(utils::MS(deltaMs));
+    accumulatedTimeMs += deltaMs;
+  }
+  if (mapperEnded(shuffleId, request->mapId)) {
+    VLOG(1) << "Revive for push data success, but the mapper already ended "
+               "for shuffle "
+            << shuffleId << " map " << request->mapId << " attempt "
+            << request->attemptId << " partition " << request->partitionId
+            << " batch " << batchId << " location hostAndPushPort "
+            << request->loc->hostAndPushPort() << ".";
+    pushState->removeBatch(batchId, request->loc->hostAndPushPort());
+    return;
+  }
+  if (request->reviveStatus.load() != protocol::StatusCode::SUCCESS) {
+    // TODO: the exception message here should be assembled...
+    pushDataCallback->onFailure(std::make_unique<std::exception>());
+    return;
+  }
+  auto locationMapOptional = partitionLocationMaps_.get(shuffleId);
+  CELEBORN_CHECK(locationMapOptional.has_value());
+  auto newLocationOptional =
+      locationMapOptional.value()->get(request->partitionId);
+  CELEBORN_CHECK(newLocationOptional.has_value());
+  auto newLocation = newLocationOptional.value();
+  LOG(INFO) << "Revive for push data success, new location for shuffle "
+            << shuffleId << " map " << request->mapId << " attempt "
+            << request->attemptId << " partition " << request->partitionId
+            << " batch " << batchId << " is location hostAndPushPort "
+            << newLocation->hostAndPushPort() << ".";
+  pushDataCallback->updateLatestLocation(newLocation);
+
+  try {
+    CELEBORN_CHECK_GT(remainReviveTimes, 0, "no remainReviveTime left");
+    network::PushData pushData(
+        network::Message::nextRequestId(),
+        protocol::PartitionLocation::Mode::PRIMARY,
+        utils::makeShuffleKey(appUniqueId_, shuffleId),
+        newLocation->uniqueId(),
+        std::move(body));
+    auto client = clientFactory_->createClient(
+        newLocation->host, newLocation->pushPort, request->partitionId);
+    client->pushDataAsync(
+        pushData, conf_->clientPushDataTimeout(), pushDataCallback);
+  } catch (std::exception e) {

Review Comment:
   done.



##########
cpp/celeborn/client/ShuffleClient.h:
##########
@@ -121,44 +166,58 @@ class ShuffleClientImpl
       std::shared_ptr<PushState> pushState,
       PtrReviveRequest request,
       int remainReviveTimes,
-      long dueTimeMs) {}
+      long dueTimeMs);
 
-  // TODO: currently this function serves as a stub. will be updated in future
-  //  commits.
-  virtual bool mapperEnded(int shuffleId, int mapId) {
-    return true;
-  }
+  virtual bool mapperEnded(int shuffleId, int mapId);
 
-  // TODO: currently this function serves as a stub. will be updated in future
-  //  commits.
   virtual void addRequestToReviveManager(
-      std::shared_ptr<protocol::ReviveRequest> reviveRequest) {}
+      std::shared_ptr<protocol::ReviveRequest> reviveRequest);
 
-  // TODO: currently this function serves as a stub. will be updated in future
-  //  commits.
   virtual std::optional<std::unordered_map<int, int>> reviveBatch(
       int shuffleId,
       const std::unordered_set<int>& mapIds,
-      const std::unordered_map<int, PtrReviveRequest>& requests) {
-    return std::nullopt;
-  }
+      const std::unordered_map<int, PtrReviveRequest>& requests);
 
   virtual std::optional<PtrPartitionLocationMap> getPartitionLocationMap(
-      int shuffleId) {
-    return partitionLocationMaps_.get(shuffleId);
-  }
+      int shuffleId);
 
   virtual utils::
       ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>&
-      mapperEndSets() {
-    return mapperEndSets_;
-  }
+      mapperEndSets();
 
-  virtual void addPushDataRetryTask(folly::Func&& task) {
-    pushDataRetryPool_->add(std::move(task));
-  }
+  virtual void addPushDataRetryTask(folly::Func&& task);
 
  private:
+  std::shared_ptr<PushState> getPushState(const std::string& mapKey);
+
+  void initReviveManagerLocked();
+
+  void registerShuffle(int shuffleId, int numMappers, int numPartitions);
+
+  bool checkMapperEnded(int shuffleId, int mapId, const std::string& mapKey);
+
+  bool stageEnded(int shuffleId);
+
+  bool revive(
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      int epoch,
+      std::shared_ptr<const protocol::PartitionLocation> oldLocation,
+      protocol::StatusCode cause);
+
+  // Check if the pushState's ongoing package num reaches the max limit, if so,
+  // block until the ongoing package num decreases below max limit.
+  void limitMaxInFlight(
+      const std::string& mapKey,
+      PushState& pushState,
+      const std::string hostAndPushPort);

Review Comment:
   done.



##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -143,6 +311,335 @@ bool ShuffleClientImpl::cleanupShuffle(int shuffleId) {
   return true;
 }
 
+std::shared_ptr<PushState> ShuffleClientImpl::getPushState(
+    const std::string& mapKey) {
+  return pushStates_.computeIfAbsent(
+      mapKey, [&]() { return std::make_shared<PushState>(*conf_); });
+}
+
+void ShuffleClientImpl::initReviveManagerLocked() {
+  if (!reviveManager_) {
+    std::string uniqueName = appUniqueId_;
+    uniqueName += std::to_string(utils::currentTimeNanos());
+    reviveManager_ =
+        ReviveManager::create(uniqueName, *conf_, weak_from_this());
+  }
+}
+
+void ShuffleClientImpl::registerShuffle(
+    int shuffleId,
+    int numMappers,
+    int numPartitions) {
+  auto shuffleMutex = shuffleMutexes_.computeIfAbsent(
+      shuffleId, []() { return std::make_shared<std::mutex>(); });
+  // RegisterShuffle might be issued concurrently, we only allow one issue
+  // for each shuffleId.
+  std::lock_guard<std::mutex> lock(*shuffleMutex);
+  if (partitionLocationMaps_.containsKey(shuffleId)) {
+    return;
+  }
+  CELEBORN_CHECK(
+      lifecycleManagerRef_, "lifecycleManagerRef_ is not initialized");
+  const int maxRetries = conf_->clientRegisterShuffleMaxRetries();
+  int numRetries = 1;
+  for (; numRetries <= maxRetries; numRetries++) {
+    try {
+      // Send the query request to lifecycleManager.
+      auto registerShuffleResponse = lifecycleManagerRef_->askSync<
+          protocol::RegisterShuffle,
+          protocol::RegisterShuffleResponse>(
+          protocol::RegisterShuffle{shuffleId, numMappers, numPartitions},
+          conf_->clientRpcRegisterShuffleRpcAskTimeout());
+
+      switch (registerShuffleResponse->status) {
+        case protocol::StatusCode::SUCCESS: {
+          VLOG(1) << "success to registerShuffle, shuffleId " << shuffleId
+                  << " numMappers " << numMappers << " numPartitions "
+                  << numPartitions;
+          auto partitionLocationMap = std::make_shared<utils::ConcurrentHashMap<
+              int,
+              std::shared_ptr<const protocol::PartitionLocation>>>();
+          auto& partitionLocations =
+              registerShuffleResponse->partitionLocations;
+          for (int i = 0; i < partitionLocations.size(); i++) {
+            auto id = partitionLocations[i]->id;
+            partitionLocationMap->set(id, std::move(partitionLocations[i]));
+          }
+          partitionLocationMaps_.set(
+              shuffleId, std::move(partitionLocationMap));
+          return;
+        }
+        default: {
+          LOG(ERROR)
+              << "LifecycleManager request slots return protocol::StatusCode "
+              << registerShuffleResponse->status << " , shuffleId " << shuffleId
+              << " NumMappers " << numMappers << " numPartitions "
+              << numPartitions << " , retry again, remain retry times "
+              << maxRetries - numRetries;
+        }
+      }
+    } catch (std::exception& e) {
+      CELEBORN_FAIL(fmt::format(
+          "registerShuffle encounters error after {} tries, "
+          "shuffleId {} numMappers {} numPartitions {}, errorMsg: {}",
+          numRetries,
+          shuffleId,
+          numMappers,
+          numPartitions,
+          e.what()));
+      break;
+    }
+    std::this_thread::sleep_for(conf_->clientRegisterShuffleRetryWait());
+  }
+  partitionLocationMaps_.set(shuffleId, nullptr);
+  CELEBORN_FAIL(fmt::format(
+      "registerShuffle failed after {} tries, "
+      "shuffleId {} numMappers {} numPartitions {}",
+      maxRetries,
+      shuffleId,
+      numMappers,
+      numPartitions));
+}
+
+void ShuffleClientImpl::submitRetryPushData(
+    int shuffleId,
+    std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+    int batchId,
+    std::shared_ptr<PushDataCallback> pushDataCallback,
+    std::shared_ptr<PushState> pushState,
+    PtrReviveRequest request,
+    int remainReviveTimes,
+    long dueTimeMs) {
+  long reviveWaitTimeMs = dueTimeMs - utils::currentTimeMillis();
+  long accumulatedTimeMs = 0;
+  const long deltaMs = 50;
+  while (request->reviveStatus.load() ==
+             protocol::StatusCode::REVIVE_INITIALIZED &&
+         accumulatedTimeMs <= reviveWaitTimeMs) {
+    std::this_thread::sleep_for(utils::MS(deltaMs));
+    accumulatedTimeMs += deltaMs;
+  }
+  if (mapperEnded(shuffleId, request->mapId)) {
+    VLOG(1) << "Revive for push data success, but the mapper already ended "
+               "for shuffle "
+            << shuffleId << " map " << request->mapId << " attempt "
+            << request->attemptId << " partition " << request->partitionId
+            << " batch " << batchId << " location hostAndPushPort "
+            << request->loc->hostAndPushPort() << ".";
+    pushState->removeBatch(batchId, request->loc->hostAndPushPort());
+    return;
+  }
+  if (request->reviveStatus.load() != protocol::StatusCode::SUCCESS) {
+    // TODO: the exception message here should be assembled...
+    pushDataCallback->onFailure(std::make_unique<std::exception>());
+    return;
+  }
+  auto locationMapOptional = partitionLocationMaps_.get(shuffleId);
+  CELEBORN_CHECK(locationMapOptional.has_value());
+  auto newLocationOptional =
+      locationMapOptional.value()->get(request->partitionId);
+  CELEBORN_CHECK(newLocationOptional.has_value());
+  auto newLocation = newLocationOptional.value();
+  LOG(INFO) << "Revive for push data success, new location for shuffle "
+            << shuffleId << " map " << request->mapId << " attempt "
+            << request->attemptId << " partition " << request->partitionId
+            << " batch " << batchId << " is location hostAndPushPort "
+            << newLocation->hostAndPushPort() << ".";
+  pushDataCallback->updateLatestLocation(newLocation);
+
+  try {
+    CELEBORN_CHECK_GT(remainReviveTimes, 0, "no remainReviveTime left");
+    network::PushData pushData(
+        network::Message::nextRequestId(),
+        protocol::PartitionLocation::Mode::PRIMARY,
+        utils::makeShuffleKey(appUniqueId_, shuffleId),
+        newLocation->uniqueId(),
+        std::move(body));
+    auto client = clientFactory_->createClient(
+        newLocation->host, newLocation->pushPort, request->partitionId);
+    client->pushDataAsync(
+        pushData, conf_->clientPushDataTimeout(), pushDataCallback);
+  } catch (std::exception e) {
+    LOG(ERROR) << "Exception raised while pushing data for shuffle "
+               << shuffleId << " map " << request->mapId << " attempt "
+               << request->attemptId << " partition " << request->partitionId
+               << " batch " << batchId << " location hostAndPushPort "
+               << newLocation->hostAndPushPort() << " errorMsg " << e.what()
+               << ".";
+    // TODO: The failure should be treated better.
+    pushDataCallback->onFailure(std::make_unique<std::exception>(e));
+  }
+}
+
+bool ShuffleClientImpl::checkMapperEnded(
+    int shuffleId,
+    int mapId,
+    const std::string& mapKey) {
+  if (mapperEnded(shuffleId, mapId)) {
+    VLOG(1) << "Mapper already ended for shuffle " << shuffleId << " map "
+            << mapId;
+    if (auto pushStateOptional = pushStates_.get(mapKey);
+        pushStateOptional.has_value()) {
+      auto pushState = pushStateOptional.value();
+      pushState->cleanup();
+    }
+    return true;
+  }
+  return false;
+}
+
+bool ShuffleClientImpl::mapperEnded(int shuffleId, int mapId) {
+  if (auto mapperEndSetOptional = mapperEndSets_.get(shuffleId);
+      mapperEndSetOptional.has_value() &&
+      mapperEndSetOptional.value()->contains(mapId)) {
+    return true;
+  }
+  if (stageEnded(shuffleId)) {
+    return true;
+  }
+  return false;
+}
+
+bool ShuffleClientImpl::stageEnded(int shuffleId) {
+  return stageEndShuffleSet_.contains(shuffleId);
+}
+
+void ShuffleClientImpl::addRequestToReviveManager(
+    std::shared_ptr<protocol::ReviveRequest> reviveRequest) {
+  reviveManager_->addRequest(std::move(reviveRequest));
+}
+
+std::optional<std::unordered_map<int, int>> ShuffleClientImpl::reviveBatch(
+    int shuffleId,
+    const std::unordered_set<int>& mapIds,
+    const std::unordered_map<int, PtrReviveRequest>& requests) {
+  std::unordered_map<int, int> result;
+  auto partitionLocationMap = partitionLocationMaps_.get(shuffleId).value();
+  std::unordered_map<int, std::shared_ptr<const protocol::PartitionLocation>>
+      oldLocationMap;
+  protocol::Revive revive;
+  revive.shuffleId = shuffleId;
+  revive.mapIds.insert(mapIds.begin(), mapIds.end());
+  for (auto& [partitionId, request] : requests) {
+    oldLocationMap[request->partitionId] = request->loc;
+    revive.reviveRequests.insert(request);
+  }
+  try {
+    auto response =
+        lifecycleManagerRef_
+            ->askSync<protocol::Revive, protocol::ChangeLocationResponse>(
+                revive,
+                conf_->clientRpcRequestPartitionLocationRpcAskTimeout());
+    auto mapperEndSet = mapperEndSets_.computeIfAbsent(shuffleId, []() {
+      return std::make_shared<utils::ConcurrentHashSet<int>>();
+    });
+    for (auto endedMapId : response->endedMapIds) {
+      mapperEndSet->insert(endedMapId);
+    }
+    for (auto& partitionInfo : response->partitionInfos) {
+      switch (partitionInfo.status) {
+        case protocol::StatusCode::SUCCESS: {
+          partitionLocationMap->set(
+              partitionInfo.partitionId, partitionInfo.partition);
+          break;
+        }
+        case protocol::StatusCode::STAGE_ENDED: {
+          stageEndShuffleSet_.insert(shuffleId);
+          return {std::move(result)};
+        }
+        case protocol::StatusCode::SHUFFLE_NOT_REGISTERED: {
+          LOG(ERROR) << "shuffleId " << shuffleId << " not registered!";
+          return std::nullopt;
+        }
+        default: {
+          // noop
+        }
+      }
+      result[partitionInfo.partitionId] = partitionInfo.status;
+    }
+    return {std::move(result)};
+  } catch (std::exception& e) {
+    LOG(ERROR) << "reviveBatch failed: " << e.what();
+    return std::nullopt;
+  }
+}
+
+bool ShuffleClientImpl::revive(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    int epoch,
+    std::shared_ptr<const protocol::PartitionLocation> oldLocation,
+    protocol::StatusCode cause) {
+  auto request = std::make_shared<protocol::ReviveRequest>(
+      shuffleId, mapId, attemptId, partitionId, epoch, oldLocation, cause);
+  auto resultOptional =
+      reviveBatch(shuffleId, {mapId}, {{partitionId, request}});
+  if (mapperEnded(shuffleId, mapId)) {
+    VLOG(1) << "Revive success, but the mapper ended for shuffle " << shuffleId
+            << " map " << mapId << " attempt " << attemptId << " partition"
+            << partitionId << ", just return true(Assume revive 
successfully).";
+    return true;
+  }
+  if (resultOptional.has_value()) {
+    auto result = resultOptional.value();
+    return result.find(partitionId) != result.end() &&
+        result[partitionId] == protocol::StatusCode::SUCCESS;
+  }
+  return false;
+}
+
+void ShuffleClientImpl::limitMaxInFlight(
+    const std::string& mapKey,
+    PushState& pushState,
+    const std::string hostAndPushPort) {

Review Comment:
   done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to