HolyLow commented on code in PR #3543:
URL: https://github.com/apache/celeborn/pull/3543#discussion_r2562908552


##########
cpp/celeborn/client/writer/PushDataCallback.cpp:
##########
@@ -0,0 +1,264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushDataCallback.h"
+#include "celeborn/conf/CelebornConf.h"
+
+namespace celeborn {
+namespace client {
+
+std::shared_ptr<PushDataCallback> PushDataCallback::create(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    int numMappers,
+    int numPartitions,
+    const std::string& mapKey,
+    int batchId,
+    std::unique_ptr<memory::ReadOnlyByteBuffer> databody,
+    std::shared_ptr<PushState> pushState,
+    std::weak_ptr<ShuffleClientImpl> weakClient,
+    int remainingReviveTimes,
+    std::shared_ptr<const protocol::PartitionLocation> latestLocation) {
+  return std::shared_ptr<PushDataCallback>(new PushDataCallback(
+      shuffleId,
+      mapId,
+      attemptId,
+      partitionId,
+      numMappers,
+      numPartitions,
+      mapKey,
+      batchId,
+      std::move(databody),
+      pushState,
+      weakClient,
+      remainingReviveTimes,
+      latestLocation));
+}
+
+PushDataCallback::PushDataCallback(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    int numMappers,
+    int numPartitions,
+    const std::string& mapKey,
+    int batchId,
+    std::unique_ptr<memory::ReadOnlyByteBuffer> databody,
+    std::shared_ptr<PushState> pushState,
+    std::weak_ptr<ShuffleClientImpl> weakClient,
+    int remainingReviveTimes,
+    std::shared_ptr<const protocol::PartitionLocation> latestLocation)
+    : shuffleId_(shuffleId),
+      mapId_(mapId),
+      attemptId_(attemptId),
+      partitionId_(partitionId),
+      numMappers_(numMappers),
+      numPartitions_(numPartitions),
+      mapKey_(mapKey),
+      batchId_(batchId),
+      databody_(std::move(databody)),
+      pushState_(pushState),
+      weakClient_(weakClient),
+      remainingReviveTimes_(remainingReviveTimes),
+      latestLocation_(latestLocation) {}
+
+void PushDataCallback::onSuccess(
+    std::unique_ptr<memory::ReadOnlyByteBuffer> response) {
+  auto sharedClient = weakClient_.lock();
+  if (!sharedClient) {
+    LOG(WARNING) << "ShuffleClientImpl has expired when "
+                    "PushDataCallbackOnSuccess, ignored, shuffle "
+                 << shuffleId_ << " map " << mapId_ << " attempt " << attemptId_
+                 << " partition " << partitionId_ << " batch " << batchId_
+                 << ".";
+    return;
+  }
+  if (response->remainingSize() <= 0) {
+    pushState_->onSuccess(latestLocation_->hostAndPushPort());
+    pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+    return;
+  }
+  protocol::StatusCode reason =
+      static_cast<protocol::StatusCode>(response->read<uint8_t>());
+  switch (reason) {
+    case protocol::StatusCode::MAP_ENDED: {
+      auto mapperEndSet = sharedClient->mapperEndSets().computeIfAbsent(
+          shuffleId_,
+          []() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+      mapperEndSet->insert(mapId_);
+      break;
+    }
+    case protocol::StatusCode::SOFT_SPLIT: {
+      VLOG(1) << "Push data to " << latestLocation_->hostAndPushPort()
+              << " soft split required for shuffle " << shuffleId_ << " map "
+              << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_ << ".";
+      if (!ShuffleClientImpl::newerPartitionLocationExists(
+              sharedClient->getPartitionLocationMap(shuffleId_).value(),
+              partitionId_,
+              latestLocation_->epoch)) {
+        auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+            shuffleId_,
+            mapId_,
+            attemptId_,
+            partitionId_,
+            latestLocation_->epoch,
+            latestLocation_,
+            protocol::StatusCode::SOFT_SPLIT);
+        sharedClient->addRequestToReviveManager(reviveRequest);
+      }
+      pushState_->onSuccess(latestLocation_->hostAndPushPort());
+      pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+      break;
+    }
+    case protocol::StatusCode::HARD_SPLIT: {
+      VLOG(1) << "Push data to " << latestLocation_->hostAndPushPort()
+              << " hard split required for shuffle " << shuffleId_ << " map "
+              << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_ << ".";
+      reviveAndRetryPushData(*sharedClient, protocol::StatusCode::HARD_SPLIT);
+      break;
+    }
+    case protocol::StatusCode::PUSH_DATA_SUCCESS_PRIMARY_CONGESTED: {
+      VLOG(1) << "Push data to " << latestLocation_->hostAndPushPort()
+              << " primary congestion required for shuffle " << shuffleId_
+              << " map " << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_ << ".";
+      pushState_->onCongestControl(latestLocation_->hostAndPushPort());
+      pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+      break;
+    }
+    case protocol::StatusCode::PUSH_DATA_SUCCESS_REPLICA_CONGESTED: {
+      VLOG(1) << "Push data to " << latestLocation_->hostAndPushPort()
+              << " primary congestion required for shuffle " << shuffleId_
+              << " map " << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_ << ".";
+      pushState_->onCongestControl(latestLocation_->hostAndPushPort());
+      pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+      break;
+    }
+    default: {
+      // This is treated as success.
+      LOG(WARNING) << "unhandled PushData success protocol::StatusCode: "
+                   << reason;
+    }
+  }
+}
+
+void PushDataCallback::onFailure(std::unique_ptr<std::exception> exception) {
+  auto sharedClient = weakClient_.lock();

Review Comment:
   For safety, I think it is more appropriate to keep the checking lines 
verbose. Besides, if we extract the checking logic into a separate function, 
that function still needs to return the shared_ptr, and we still need to check 
whether the returned value is valid and then exit the function accordingly. 
That way, no lines of code are saved, and the logic becomes even more verbose.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to