HolyLow commented on code in PR #3553:
URL: https://github.com/apache/celeborn/pull/3553#discussion_r2579503042
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -143,6 +311,335 @@ bool ShuffleClientImpl::cleanupShuffle(int shuffleId) {
return true;
}
+std::shared_ptr<PushState> ShuffleClientImpl::getPushState(
+ const std::string& mapKey) {
+ return pushStates_.computeIfAbsent(
+ mapKey, [&]() { return std::make_shared<PushState>(*conf_); });
+}
+
+void ShuffleClientImpl::initReviveManagerLocked() {
+ if (!reviveManager_) {
+ std::string uniqueName = appUniqueId_;
+ uniqueName += std::to_string(utils::currentTimeNanos());
+ reviveManager_ =
+ ReviveManager::create(uniqueName, *conf_, weak_from_this());
+ }
+}
+
+void ShuffleClientImpl::registerShuffle(
+ int shuffleId,
+ int numMappers,
+ int numPartitions) {
+ auto shuffleMutex = shuffleMutexes_.computeIfAbsent(
+ shuffleId, []() { return std::make_shared<std::mutex>(); });
+ // RegisterShuffle might be issued concurrently, we only allow one issue
+ // for each shuffleId.
+ std::lock_guard<std::mutex> lock(*shuffleMutex);
+ if (partitionLocationMaps_.containsKey(shuffleId)) {
+ return;
+ }
+ CELEBORN_CHECK(
+ lifecycleManagerRef_, "lifecycleManagerRef_ is not initialized");
+ const int maxRetries = conf_->clientRegisterShuffleMaxRetries();
+ int numRetries = 1;
+ for (; numRetries <= maxRetries; numRetries++) {
+ try {
+ // Send the query request to lifecycleManager.
+ auto registerShuffleResponse = lifecycleManagerRef_->askSync<
+ protocol::RegisterShuffle,
+ protocol::RegisterShuffleResponse>(
+ protocol::RegisterShuffle{shuffleId, numMappers, numPartitions},
+ conf_->clientRpcRegisterShuffleRpcAskTimeout());
+
+ switch (registerShuffleResponse->status) {
+ case protocol::StatusCode::SUCCESS: {
+ VLOG(1) << "success to registerShuffle, shuffleId " << shuffleId
+ << " numMappers " << numMappers << " numPartitions "
+ << numPartitions;
+ auto partitionLocationMap =
std::make_shared<utils::ConcurrentHashMap<
+ int,
+ std::shared_ptr<const protocol::PartitionLocation>>>();
+ auto& partitionLocations =
+ registerShuffleResponse->partitionLocations;
+ for (int i = 0; i < partitionLocations.size(); i++) {
+ auto id = partitionLocations[i]->id;
+ partitionLocationMap->set(id, std::move(partitionLocations[i]));
+ }
+ partitionLocationMaps_.set(
+ shuffleId, std::move(partitionLocationMap));
+ return;
+ }
+ default: {
+ LOG(ERROR)
+ << "LifecycleManager request slots return protocol::StatusCode "
+ << registerShuffleResponse->status << " , shuffleId " <<
shuffleId
+ << " NumMappers " << numMappers << " numPartitions "
+ << numPartitions << " , retry again, remain retry times "
+ << maxRetries - numRetries;
+ }
+ }
+ } catch (std::exception& e) {
+ CELEBORN_FAIL(fmt::format(
+ "registerShuffle encounters error after {} tries, "
+ "shuffleId {} numMappers {} numPartitions {}, errorMsg: {}",
+ numRetries,
+ shuffleId,
+ numMappers,
+ numPartitions,
+ e.what()));
+ break;
+ }
+ std::this_thread::sleep_for(conf_->clientRegisterShuffleRetryWait());
+ }
+ partitionLocationMaps_.set(shuffleId, nullptr);
+ CELEBORN_FAIL(fmt::format(
+ "registerShuffle failed after {} tries, "
+ "shuffleId {} numMappers {} numPartitions {}",
+ maxRetries,
+ shuffleId,
+ numMappers,
+ numPartitions));
+}
+
+void ShuffleClientImpl::submitRetryPushData(
+ int shuffleId,
+ std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+ int batchId,
+ std::shared_ptr<PushDataCallback> pushDataCallback,
+ std::shared_ptr<PushState> pushState,
+ PtrReviveRequest request,
+ int remainReviveTimes,
+ long dueTimeMs) {
+ long reviveWaitTimeMs = dueTimeMs - utils::currentTimeMillis();
+ long accumulatedTimeMs = 0;
+ const long deltaMs = 50;
+ while (request->reviveStatus.load() ==
+ protocol::StatusCode::REVIVE_INITIALIZED &&
+ accumulatedTimeMs <= reviveWaitTimeMs) {
+ std::this_thread::sleep_for(utils::MS(deltaMs));
+ accumulatedTimeMs += deltaMs;
+ }
+ if (mapperEnded(shuffleId, request->mapId)) {
+ VLOG(1) << "Revive for push data success, but the mapper already ended "
+ "for shuffle "
+ << shuffleId << " map " << request->mapId << " attempt "
+ << request->attemptId << " partition " << request->partitionId
+ << " batch " << batchId << " location hostAndPushPort "
+ << request->loc->hostAndPushPort() << ".";
+ pushState->removeBatch(batchId, request->loc->hostAndPushPort());
+ return;
+ }
+ if (request->reviveStatus.load() != protocol::StatusCode::SUCCESS) {
+ // TODO: the exception message here should be assembled...
+ pushDataCallback->onFailure(std::make_unique<std::exception>());
+ return;
+ }
+ auto locationMapOptional = partitionLocationMaps_.get(shuffleId);
+ CELEBORN_CHECK(locationMapOptional.has_value());
+ auto newLocationOptional =
+ locationMapOptional.value()->get(request->partitionId);
+ CELEBORN_CHECK(newLocationOptional.has_value());
+ auto newLocation = newLocationOptional.value();
+ LOG(INFO) << "Revive for push data success, new location for shuffle "
+ << shuffleId << " map " << request->mapId << " attempt "
+ << request->attemptId << " partition " << request->partitionId
+ << " batch " << batchId << " is location hostAndPushPort "
+ << newLocation->hostAndPushPort() << ".";
+ pushDataCallback->updateLatestLocation(newLocation);
+
+ try {
+ CELEBORN_CHECK_GT(remainReviveTimes, 0, "no remainReviveTime left");
+ network::PushData pushData(
+ network::Message::nextRequestId(),
+ protocol::PartitionLocation::Mode::PRIMARY,
+ utils::makeShuffleKey(appUniqueId_, shuffleId),
+ newLocation->uniqueId(),
+ std::move(body));
+ auto client = clientFactory_->createClient(
+ newLocation->host, newLocation->pushPort, request->partitionId);
+ client->pushDataAsync(
+ pushData, conf_->clientPushDataTimeout(), pushDataCallback);
+ } catch (std::exception e) {
+ LOG(ERROR) << "Exception raised while pushing data for shuffle "
+ << shuffleId << " map " << request->mapId << " attempt "
+ << request->attemptId << " partition " << request->partitionId
+ << " batch " << batchId << " location hostAndPushPort "
+ << newLocation->hostAndPushPort() << " errorMsg " << e.what()
+ << ".";
+ // TODO: The failure should be treated better.
+ pushDataCallback->onFailure(std::make_unique<std::exception>(e));
+ }
+}
+
+bool ShuffleClientImpl::checkMapperEnded(
+ int shuffleId,
+ int mapId,
+ const std::string& mapKey) {
+ if (mapperEnded(shuffleId, mapId)) {
+ VLOG(1) << "Mapper already ended for shuffle " << shuffleId << " map "
+ << mapId;
+ if (auto pushStateOptional = pushStates_.get(mapKey);
+ pushStateOptional.has_value()) {
+ auto pushState = pushStateOptional.value();
+ pushState->cleanup();
+ }
+ return true;
+ }
+ return false;
+}
+
+bool ShuffleClientImpl::mapperEnded(int shuffleId, int mapId) {
+ if (auto mapperEndSetOptional = mapperEndSets_.get(shuffleId);
+ mapperEndSetOptional.has_value() &&
+ mapperEndSetOptional.value()->contains(mapId)) {
+ return true;
+ }
+ if (stageEnded(shuffleId)) {
+ return true;
+ }
+ return false;
+}
+
+bool ShuffleClientImpl::stageEnded(int shuffleId) {
+ return stageEndShuffleSet_.contains(shuffleId);
+}
+
+void ShuffleClientImpl::addRequestToReviveManager(
+ std::shared_ptr<protocol::ReviveRequest> reviveRequest) {
+ reviveManager_->addRequest(std::move(reviveRequest));
+}
+
+std::optional<std::unordered_map<int, int>> ShuffleClientImpl::reviveBatch(
+ int shuffleId,
+ const std::unordered_set<int>& mapIds,
+ const std::unordered_map<int, PtrReviveRequest>& requests) {
+ std::unordered_map<int, int> result;
+ auto partitionLocationMap = partitionLocationMaps_.get(shuffleId).value();
+ std::unordered_map<int, std::shared_ptr<const protocol::PartitionLocation>>
+ oldLocationMap;
+ protocol::Revive revive;
+ revive.shuffleId = shuffleId;
+ revive.mapIds.insert(mapIds.begin(), mapIds.end());
+ for (auto& [partitionId, request] : requests) {
+ oldLocationMap[request->partitionId] = request->loc;
+ revive.reviveRequests.insert(request);
+ }
+ try {
+ auto response =
+ lifecycleManagerRef_
+ ->askSync<protocol::Revive, protocol::ChangeLocationResponse>(
+ revive,
+ conf_->clientRpcRequestPartitionLocationRpcAskTimeout());
+ auto mapperEndSet = mapperEndSets_.computeIfAbsent(shuffleId, []() {
+ return std::make_shared<utils::ConcurrentHashSet<int>>();
+ });
+ for (auto endedMapId : response->endedMapIds) {
+ mapperEndSet->insert(endedMapId);
+ }
+ for (auto& partitionInfo : response->partitionInfos) {
+ switch (partitionInfo.status) {
+ case protocol::StatusCode::SUCCESS: {
+ partitionLocationMap->set(
+ partitionInfo.partitionId, partitionInfo.partition);
+ break;
+ }
+ case protocol::StatusCode::STAGE_ENDED: {
+ stageEndShuffleSet_.insert(shuffleId);
+ return {std::move(result)};
+ }
+ case protocol::StatusCode::SHUFFLE_NOT_REGISTERED: {
+ LOG(ERROR) << "shuffleId " << shuffleId << " not registered!";
+ return std::nullopt;
+ }
+ default: {
+ // noop
+ }
+ }
+ result[partitionInfo.partitionId] = partitionInfo.status;
+ }
+ return {std::move(result)};
+ } catch (std::exception& e) {
+ LOG(ERROR) << "reviveBatch failed: " << e.what();
+ return std::nullopt;
+ }
+}
+
+bool ShuffleClientImpl::revive(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ int epoch,
+ std::shared_ptr<const protocol::PartitionLocation> oldLocation,
+ protocol::StatusCode cause) {
+ auto request = std::make_shared<protocol::ReviveRequest>(
+ shuffleId, mapId, attemptId, partitionId, epoch, oldLocation, cause);
+ auto resultOptional =
+ reviveBatch(shuffleId, {mapId}, {{partitionId, request}});
+ if (mapperEnded(shuffleId, mapId)) {
+ VLOG(1) << "Revive success, but the mapper ended for shuffle " << shuffleId
+ << " map " << mapId << " attempt " << attemptId << " partition"
+ << partitionId << ", just return true(Assume revive
successfully).";
+ return true;
+ }
+ if (resultOptional.has_value()) {
+ auto result = resultOptional.value();
+ return result.find(partitionId) != result.end() &&
+ result[partitionId] == protocol::StatusCode::SUCCESS;
+ }
+ return false;
+}
+
+void ShuffleClientImpl::limitMaxInFlight(
+ const std::string& mapKey,
+ PushState& pushState,
+ const std::string hostAndPushPort) {
+ bool reachLimit = pushState.limitMaxInFlight(hostAndPushPort);
+ if (reachLimit) {
+ auto msg = fmt::format(
+ "Waiting timeout for task {} while limiting max "
+ "in-flight requests to {}.",
+ mapKey,
+ hostAndPushPort);
+ if (auto exceptionMsgOptional = pushState.getExceptionMsg();
+ exceptionMsgOptional.has_value()) {
+ msg += " PushState exception: " + exceptionMsgOptional.value();
+ }
+ CELEBORN_FAIL(msg);
+ }
+}
+
+void ShuffleClientImpl::limitZeroInFlight(
+ const std::string& mapKey,
+ PushState& pushState) {
+ bool reachLimit = pushState.limitZeroInFlight();
+ if (reachLimit) {
+ auto msg = fmt::format(
+ "Waiting timeout for task {} while limiting zero "
+ "in-flight requests.",
+ mapKey);
+ if (auto exceptionMsgOptional = pushState.getExceptionMsg();
+ exceptionMsgOptional.has_value()) {
+ msg += " PushState exception: " + exceptionMsgOptional.value();
+ }
+ CELEBORN_FAIL(msg);
+ }
+}
+
+std::optional<ShuffleClientImpl::PtrPartitionLocationMap>
+ShuffleClientImpl::getPartitionLocationMap(int shuffleId) {
+ return partitionLocationMaps_.get(shuffleId);
+}
+
+utils::ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>&
+ShuffleClientImpl::mapperEndSets() {
+ return mapperEndSets_;
+}
+
+void ShuffleClientImpl::addPushDataRetryTask(folly::Func&& task) {
+ pushDataRetryPool_->add(std::move(task));
Review Comment:
Yes, you are right that the pool is never initialized; that was a big mistake
on my part. The initialization needs some care, though, because we want to
reuse the critical resources across different ShuffleClients. Therefore I
extracted all the reusable resources into a standalone class called
`ShuffleClientEndpoint`, which can be shared across different clients.
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -143,6 +311,335 @@ bool ShuffleClientImpl::cleanupShuffle(int shuffleId) {
return true;
}
+std::shared_ptr<PushState> ShuffleClientImpl::getPushState(
+ const std::string& mapKey) {
+ return pushStates_.computeIfAbsent(
+ mapKey, [&]() { return std::make_shared<PushState>(*conf_); });
+}
+
+void ShuffleClientImpl::initReviveManagerLocked() {
+ if (!reviveManager_) {
+ std::string uniqueName = appUniqueId_;
+ uniqueName += std::to_string(utils::currentTimeNanos());
+ reviveManager_ =
+ ReviveManager::create(uniqueName, *conf_, weak_from_this());
+ }
+}
+
+void ShuffleClientImpl::registerShuffle(
+ int shuffleId,
+ int numMappers,
+ int numPartitions) {
+ auto shuffleMutex = shuffleMutexes_.computeIfAbsent(
+ shuffleId, []() { return std::make_shared<std::mutex>(); });
+ // RegisterShuffle might be issued concurrently, we only allow one issue
+ // for each shuffleId.
+ std::lock_guard<std::mutex> lock(*shuffleMutex);
+ if (partitionLocationMaps_.containsKey(shuffleId)) {
+ return;
+ }
+ CELEBORN_CHECK(
+ lifecycleManagerRef_, "lifecycleManagerRef_ is not initialized");
+ const int maxRetries = conf_->clientRegisterShuffleMaxRetries();
+ int numRetries = 1;
+ for (; numRetries <= maxRetries; numRetries++) {
+ try {
+ // Send the query request to lifecycleManager.
+ auto registerShuffleResponse = lifecycleManagerRef_->askSync<
+ protocol::RegisterShuffle,
+ protocol::RegisterShuffleResponse>(
+ protocol::RegisterShuffle{shuffleId, numMappers, numPartitions},
+ conf_->clientRpcRegisterShuffleRpcAskTimeout());
+
+ switch (registerShuffleResponse->status) {
+ case protocol::StatusCode::SUCCESS: {
+ VLOG(1) << "success to registerShuffle, shuffleId " << shuffleId
+ << " numMappers " << numMappers << " numPartitions "
+ << numPartitions;
+ auto partitionLocationMap =
std::make_shared<utils::ConcurrentHashMap<
+ int,
+ std::shared_ptr<const protocol::PartitionLocation>>>();
+ auto& partitionLocations =
+ registerShuffleResponse->partitionLocations;
+ for (int i = 0; i < partitionLocations.size(); i++) {
Review Comment:
done.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]