This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 42ff69f9f [CELEBORN-2229][CIP-14] Add support for celeborn.client.push.maxBytesSizeInFlight in CppClient
42ff69f9f is described below

commit 42ff69f9fdc9610daef4baa8919016bed5c36334
Author: afterincomparableyum <afterincomparableyum>
AuthorDate: Mon Jan 12 18:00:55 2026 +0800

    [CELEBORN-2229][CIP-14] Add support for celeborn.client.push.maxBytesSizeInFlight in CppClient
    
    ### What changes were proposed in this pull request?
    
    Add support for celeborn.client.push.maxBytesSizeInFlight in CppClient, similar to InFlightRequestTracker.java
    
    ### Does this PR resolve a correctness bug?
    
    No
    
    ### How was this patch tested?
    
    Compiled locally; CI/CD will run the unit tests.
    
    Closes #3568 from afterincomparableyum/cpp-client/celeborn-2229.
    
    Authored-by: afterincomparableyum <afterincomparableyum>
    Signed-off-by: Shuang <[email protected]>
---
 cpp/celeborn/client/ShuffleClient.cpp           |   3 +-
 cpp/celeborn/client/tests/PushStateTest.cpp     | 151 ++++++++++++++++++++----
 cpp/celeborn/client/writer/PushDataCallback.cpp |   4 +-
 cpp/celeborn/client/writer/PushDataCallback.h   |   1 +
 cpp/celeborn/client/writer/PushState.cpp        | 103 ++++++++++++++--
 cpp/celeborn/client/writer/PushState.h          |  16 ++-
 cpp/celeborn/conf/CelebornConf.cpp              |  41 +++++++
 cpp/celeborn/conf/CelebornConf.h                |  20 ++++
 8 files changed, 302 insertions(+), 37 deletions(-)

diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp
index 807fac2e0..401b99e2f 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -170,7 +170,8 @@ int ShuffleClientImpl::pushData(
   // Check limit.
   limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
   // Add inFlight requests.
-  pushState->addBatch(nextBatchId, hostAndPushPort);
+  const int batchBytesSize = length + kBatchHeaderSize;
+  pushState->addBatch(nextBatchId, batchBytesSize, hostAndPushPort);
   // Build pushData request.
   const auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
   auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
diff --git a/cpp/celeborn/client/tests/PushStateTest.cpp b/cpp/celeborn/client/tests/PushStateTest.cpp
index 94b6c3aa3..400538314 100644
--- a/cpp/celeborn/client/tests/PushStateTest.cpp
+++ b/cpp/celeborn/client/tests/PushStateTest.cpp
@@ -28,58 +28,97 @@ class PushStateTest : public testing::Test {
     conf::CelebornConf conf;
     conf.registerProperty(
         conf::CelebornConf::kClientPushLimitInFlightTimeoutMs,
-        std::to_string(pushTimeoutMs_));
+        std::to_string(kPushTimeoutMs_));
     conf.registerProperty(
         conf::CelebornConf::kClientPushLimitInFlightSleepDeltaMs,
-        std::to_string(pushSleepDeltaMs_));
+        std::to_string(kPushSleepDeltaMs_));
     conf.registerProperty(
         conf::CelebornConf::kClientPushMaxReqsInFlightTotal,
-        std::to_string(maxReqsInFlight_));
+        std::to_string(kMaxReqsInFlight_));
     conf.registerProperty(
         conf::CelebornConf::kClientPushMaxReqsInFlightPerWorker,
-        std::to_string(maxReqsInFlight_));
+        std::to_string(kMaxReqsInFlight_));
 
     pushState_ = std::make_unique<PushState>(conf);
   }
 
   std::unique_ptr<PushState> pushState_;
-  static constexpr int pushTimeoutMs_ = 100;
-  static constexpr int pushSleepDeltaMs_ = 10;
-  static constexpr int maxReqsInFlight_ = 2;
+  static constexpr int kPushTimeoutMs_ = 100;
+  static constexpr int kPushSleepDeltaMs_ = 10;
+  static constexpr int kMaxReqsInFlight_ = 2;
+  static constexpr int kDefaultBatchSize_ = 1024;
+};
+
+class PushStateBytesSizeTest : public testing::Test {
+ protected:
+  void SetUp() override {
+    conf::CelebornConf conf;
+    conf.registerProperty(
+        conf::CelebornConf::kClientPushLimitInFlightTimeoutMs,
+        std::to_string(kPushTimeoutMs_));
+    conf.registerProperty(
+        conf::CelebornConf::kClientPushLimitInFlightSleepDeltaMs,
+        std::to_string(kPushSleepDeltaMs_));
+    conf.registerProperty(
+        conf::CelebornConf::kClientPushMaxReqsInFlightTotal, "2");
+    conf.registerProperty(
+        conf::CelebornConf::kClientPushMaxReqsInFlightPerWorker, "100");
+    conf.registerProperty(
+        conf::CelebornConf::kClientPushMaxBytesSizeInFlightEnabled, "true");
+    conf.registerProperty(
+        conf::CelebornConf::kClientPushMaxBytesSizeInFlightTotal,
+        std::to_string(kMaxBytesSizeTotal_) + "B");
+    conf.registerProperty(
+        conf::CelebornConf::kClientPushMaxBytesSizeInFlightPerWorker,
+        std::to_string(kMaxBytesSizePerWorker_) + "B");
+    conf.registerProperty(
+        conf::CelebornConf::kClientPushBufferMaxSize,
+        std::to_string(kBufferMaxSize_) + "B");
+
+    pushState_ = std::make_unique<PushState>(conf);
+  }
+
+  std::unique_ptr<PushState> pushState_;
+  static constexpr int kPushTimeoutMs_ = 100;
+  static constexpr int kPushSleepDeltaMs_ = 10;
+  static constexpr int kBatchSize_ = 1024;
+  static constexpr long kMaxBytesSizeTotal_ = 3000;
+  static constexpr long kMaxBytesSizePerWorker_ = 2500;
+  static constexpr int kBufferMaxSize_ = 65536;
 };
 
 TEST_F(PushStateTest, limitMaxInFlight) {
   const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
-  const int addBatchCalls = maxReqsInFlight_ + 1;
+  const int addBatchCalls = kMaxReqsInFlight_ + 1;
   std::vector<bool> addBatchMarks(addBatchCalls, false);
   std::thread addBatchThread([&]() {
     for (auto i = 0; i < addBatchCalls; i++) {
-      pushState_->addBatch(i, hostAndPushPort);
+      pushState_->addBatch(i, kDefaultBatchSize_, hostAndPushPort);
       EXPECT_FALSE(pushState_->limitMaxInFlight(hostAndPushPort));
       addBatchMarks[i] = true;
     }
   });
 
-  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
-  for (auto i = 0; i < maxReqsInFlight_; i++) {
+  std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
+  for (auto i = 0; i < kMaxReqsInFlight_; i++) {
     EXPECT_TRUE(addBatchMarks[i]);
   }
-  EXPECT_FALSE(addBatchMarks[maxReqsInFlight_]);
+  EXPECT_FALSE(addBatchMarks[kMaxReqsInFlight_]);
 
   pushState_->removeBatch(0, hostAndPushPort);
   addBatchThread.join();
-  EXPECT_TRUE(addBatchMarks[maxReqsInFlight_]);
+  EXPECT_TRUE(addBatchMarks[kMaxReqsInFlight_]);
 }
 
 TEST_F(PushStateTest, limitMaxInFlightTimeout) {
   const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
-  const int addBatchCalls = maxReqsInFlight_ + 1;
+  const int addBatchCalls = kMaxReqsInFlight_ + 1;
   std::vector<bool> addBatchMarks(addBatchCalls, false);
   std::thread addBatchThread([&]() {
     for (auto i = 0; i < addBatchCalls; i++) {
-      pushState_->addBatch(i, hostAndPushPort);
+      pushState_->addBatch(i, kDefaultBatchSize_, hostAndPushPort);
       auto result = pushState_->limitMaxInFlight(hostAndPushPort);
-      if (i < maxReqsInFlight_) {
+      if (i < kMaxReqsInFlight_) {
         EXPECT_FALSE(result);
       } else {
         EXPECT_TRUE(result);
@@ -88,14 +127,14 @@ TEST_F(PushStateTest, limitMaxInFlightTimeout) {
     }
   });
 
-  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
-  for (auto i = 0; i < maxReqsInFlight_; i++) {
+  std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
+  for (auto i = 0; i < kMaxReqsInFlight_; i++) {
     EXPECT_TRUE(addBatchMarks[i]);
   }
-  EXPECT_FALSE(addBatchMarks[maxReqsInFlight_]);
+  EXPECT_FALSE(addBatchMarks[kMaxReqsInFlight_]);
 
   addBatchThread.join();
-  EXPECT_FALSE(addBatchMarks[maxReqsInFlight_]);
+  EXPECT_FALSE(addBatchMarks[kMaxReqsInFlight_]);
 }
 
 TEST_F(PushStateTest, limitZeroInFlight) {
@@ -103,12 +142,12 @@ TEST_F(PushStateTest, limitZeroInFlight) {
   const int addBatchCalls = 1;
   std::vector<bool> addBatchMarks(addBatchCalls, false);
   std::thread addBatchThread([&]() {
-    pushState_->addBatch(0, hostAndPushPort);
+    pushState_->addBatch(0, kDefaultBatchSize_, hostAndPushPort);
     EXPECT_FALSE(pushState_->limitZeroInFlight());
     addBatchMarks[0] = true;
   });
 
-  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+  std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
   EXPECT_FALSE(addBatchMarks[0]);
 
   pushState_->removeBatch(0, hostAndPushPort);
@@ -121,13 +160,13 @@ TEST_F(PushStateTest, limitZeroInFlightTimeout) {
   const int addBatchCalls = 1;
   std::vector<bool> addBatchMarks(addBatchCalls, false);
   std::thread addBatchThread([&]() {
-    pushState_->addBatch(0, hostAndPushPort);
+    pushState_->addBatch(0, kDefaultBatchSize_, hostAndPushPort);
     auto result = pushState_->limitZeroInFlight();
     EXPECT_TRUE(result);
     addBatchMarks[0] = !result;
   });
 
-  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+  std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
   EXPECT_FALSE(addBatchMarks[0]);
 
   addBatchThread.join();
@@ -153,3 +192,67 @@ TEST_F(PushStateTest, throwException) {
   }
   EXPECT_TRUE(exceptionThrowed);
 }
+
+TEST_F(PushStateBytesSizeTest, limitMaxInFlightByBytesSize) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  const int expectedAllowedBatches = 2;
+  const int addBatchCalls = expectedAllowedBatches + 1;
+  std::vector<bool> addBatchMarks(addBatchCalls, false);
+
+  std::thread addBatchThread([&]() {
+    for (auto i = 0; i < addBatchCalls; i++) {
+      pushState_->addBatch(i, kBatchSize_, hostAndPushPort);
+      auto result = pushState_->limitMaxInFlight(hostAndPushPort);
+      addBatchMarks[i] = true;
+      if (i < expectedAllowedBatches) {
+        EXPECT_FALSE(result) << "Batch " << i << " should be within limits";
+      }
+    }
+  });
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
+  for (auto i = 0; i < expectedAllowedBatches; i++) {
+    EXPECT_TRUE(addBatchMarks[i]) << "Batch " << i << " should have completed";
+  }
+
+  pushState_->removeBatch(0, hostAndPushPort);
+  addBatchThread.join();
+  EXPECT_TRUE(addBatchMarks[expectedAllowedBatches]);
+}
+
+TEST_F(PushStateBytesSizeTest, limitMaxInFlightByTotalBytesSize) {
+  const std::string hostAndPushPort1 = "xx.xx.xx.xx:8080";
+  const std::string hostAndPushPort2 = "yy.yy.yy.yy:8080";
+
+  pushState_->addBatch(0, kBatchSize_, hostAndPushPort1);
+  EXPECT_FALSE(pushState_->limitMaxInFlight(hostAndPushPort1));
+
+  pushState_->addBatch(1, kBatchSize_, hostAndPushPort2);
+  EXPECT_FALSE(pushState_->limitMaxInFlight(hostAndPushPort2));
+
+  std::atomic<bool> thirdBatchCompleted{false};
+  std::thread addBatchThread([&]() {
+    pushState_->addBatch(2, kBatchSize_, hostAndPushPort1);
+    pushState_->limitMaxInFlight(hostAndPushPort1);
+    thirdBatchCompleted = true;
+  });
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
+  EXPECT_FALSE(thirdBatchCompleted.load())
+      << "Third batch should be blocked due to total bytes limit";
+
+  pushState_->removeBatch(0, hostAndPushPort1);
+  addBatchThread.join();
+
+  EXPECT_TRUE(thirdBatchCompleted.load());
+}
+
+TEST_F(PushStateBytesSizeTest, cleanupClearsBytesSizeTracking) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+
+  pushState_->addBatch(0, kBatchSize_, hostAndPushPort);
+  pushState_->addBatch(1, kBatchSize_, hostAndPushPort);
+  pushState_->cleanup();
+
+  EXPECT_FALSE(pushState_->limitMaxInFlight(hostAndPushPort));
+}
diff --git a/cpp/celeborn/client/writer/PushDataCallback.cpp b/cpp/celeborn/client/writer/PushDataCallback.cpp
index 43550ba7f..3b0218088 100644
--- a/cpp/celeborn/client/writer/PushDataCallback.cpp
+++ b/cpp/celeborn/client/writer/PushDataCallback.cpp
@@ -73,6 +73,7 @@ PushDataCallback::PushDataCallback(
       numPartitions_(numPartitions),
       mapKey_(mapKey),
       batchId_(batchId),
+      batchBytesSize_(databody ? static_cast<int>(databody->size()) : 0),
       databody_(std::move(databody)),
       pushState_(pushState),
       weakClient_(weakClient),
@@ -208,7 +209,8 @@ void PushDataCallback::onFailure(std::unique_ptr<std::exception> exception) {
 
 void PushDataCallback::updateLatestLocation(
     std::shared_ptr<const protocol::PartitionLocation> latestLocation) {
-  pushState_->addBatch(batchId_, latestLocation->hostAndPushPort());
+  pushState_->addBatch(
+      batchId_, batchBytesSize_, latestLocation->hostAndPushPort());
   pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
   latestLocation_ = latestLocation;
 }
diff --git a/cpp/celeborn/client/writer/PushDataCallback.h b/cpp/celeborn/client/writer/PushDataCallback.h
index 9916cd191..b1c8334b8 100644
--- a/cpp/celeborn/client/writer/PushDataCallback.h
+++ b/cpp/celeborn/client/writer/PushDataCallback.h
@@ -85,6 +85,7 @@ class PushDataCallback : public network::RpcResponseCallback,
   const int numPartitions_;
   const std::string mapKey_;
   const int batchId_;
+  const int batchBytesSize_;
   const std::unique_ptr<memory::ReadOnlyByteBuffer> databody_;
   const std::shared_ptr<PushState> pushState_;
   const std::weak_ptr<ShuffleClientImpl> weakClient_;
diff --git a/cpp/celeborn/client/writer/PushState.cpp b/cpp/celeborn/client/writer/PushState.cpp
index b86a2aadb..8b3158e36 100644
--- a/cpp/celeborn/client/writer/PushState.cpp
+++ b/cpp/celeborn/client/writer/PushState.cpp
@@ -24,18 +24,35 @@ PushState::PushState(const conf::CelebornConf& conf)
     : waitInflightTimeoutMs_(conf.clientPushLimitInFlightTimeoutMs()),
       deltaMs_(conf.clientPushLimitInFlightSleepDeltaMs()),
       pushStrategy_(PushStrategy::create(conf)),
-      maxInFlightReqsTotal_(conf.clientPushMaxReqsInFlightTotal()) {}
+      maxInFlightReqsTotal_(conf.clientPushMaxReqsInFlightTotal()),
+      maxInFlightBytesSizeEnabled_(
+          conf.clientPushMaxBytesSizeInFlightEnabled()),
+      maxInFlightBytesSizeTotal_(conf.clientPushMaxBytesSizeInFlightTotal()),
+      maxInFlightBytesSizePerWorker_(
+          conf.clientPushMaxBytesSizeInFlightPerWorker()) {}
 
 int PushState::nextBatchId() {
   return currBatchId_.fetch_add(1);
 }
 
-void PushState::addBatch(int batchId, const std::string& hostAndPushPort) {
+void PushState::addBatch(
+    int batchId,
+    int batchBytesSize,
+    const std::string& hostAndPushPort) {
   auto batchIdSet = inflightBatchesPerAddress_.computeIfAbsent(
       hostAndPushPort,
       [&]() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
   batchIdSet->insert(batchId);
   totalInflightReqs_.fetch_add(1);
+
+  if (maxInFlightBytesSizeEnabled_) {
+    auto bytesSizePerAddress = inflightBytesSizePerAddress_.computeIfAbsent(
+        hostAndPushPort,
+        [&]() { return std::make_shared<std::atomic<long>>(0); });
+    bytesSizePerAddress->fetch_add(batchBytesSize);
+    inflightBatchBytesSizes_.set(batchId, batchBytesSize);
+    totalInflightBytes_.fetch_add(batchBytesSize);
+  }
 }
 
 void PushState::onSuccess(const std::string& hostAndPushPort) {
@@ -51,10 +68,24 @@ void PushState::removeBatch(int batchId, const std::string& hostAndPushPort) {
   if (batchIdSetOptional.has_value()) {
     auto batchIdSet = batchIdSetOptional.value();
     batchIdSet->erase(batchId);
-    totalInflightReqs_.fetch_sub(1);
   } else {
     LOG(WARNING) << "BatchIdSet of " << hostAndPushPort << " doesn't exist.";
   }
+
+  totalInflightReqs_.fetch_sub(1);
+
+  if (maxInFlightBytesSizeEnabled_) {
+    auto inflightBatchBytesSize = inflightBatchBytesSizes_.get(batchId);
+    inflightBatchBytesSizes_.erase(batchId);
+    if (inflightBatchBytesSize.has_value()) {
+      auto inflightBytesSize =
+          inflightBytesSizePerAddress_.get(hostAndPushPort);
+      if (inflightBytesSize.has_value()) {
+        inflightBytesSize.value()->fetch_sub(inflightBatchBytesSize.value());
+      }
+      totalInflightBytes_.fetch_sub(inflightBatchBytesSize.value());
+    }
+  }
 }
 
 bool PushState::limitMaxInFlight(const std::string& hostAndPushPort) {
@@ -67,22 +98,62 @@ bool PushState::limitMaxInFlight(const std::string& hostAndPushPort) {
   auto batchIdSet = inflightBatchesPerAddress_.computeIfAbsent(
       hostAndPushPort,
       [&]() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+  std::shared_ptr<std::atomic<long>> batchBytesSize = nullptr;
+  if (maxInFlightBytesSizeEnabled_) {
+    batchBytesSize = inflightBytesSizePerAddress_.computeIfAbsent(
+        hostAndPushPort,
+        [&]() { return std::make_shared<std::atomic<long>>(0); });
+  }
   long times = waitInflightTimeoutMs_ / deltaMs_;
   for (; times > 0; times--) {
-    if (totalInflightReqs_ <= maxInFlightReqsTotal_ &&
-        batchIdSet->size() <= currentMaxReqsInFlight) {
+    if (cleaned_.load()) {
+      return false;
+    }
+
+    bool reqCountWithinLimits =
+        (totalInflightReqs_ <= maxInFlightReqsTotal_ &&
+         static_cast<int>(batchIdSet->size()) <= currentMaxReqsInFlight);
+    bool byteSizeWithinLimits = false;
+
+    if (maxInFlightBytesSizeEnabled_ && batchBytesSize) {
+      byteSizeWithinLimits =
+          (totalInflightBytes_.load() <= maxInFlightBytesSizeTotal_ &&
+           batchBytesSize->load() <= maxInFlightBytesSizePerWorker_);
+    }
+
+    if (reqCountWithinLimits ||
+        (maxInFlightBytesSizeEnabled_ && byteSizeWithinLimits)) {
       break;
     }
+
     throwIfExceptionExists();
     std::this_thread::sleep_for(utils::MS(deltaMs_));
   }
 
   if (times <= 0) {
-    LOG(WARNING) << "After waiting for " << waitInflightTimeoutMs_
-                 << " ms, there are still " << batchIdSet->size()
-                 << " batches in flight for hostAndPushPort " << 
hostAndPushPort
-                 << ", which exceeds the current limit "
-                 << currentMaxReqsInFlight;
+    if (totalInflightReqs_ > maxInFlightReqsTotal_ ||
+        static_cast<int>(batchIdSet->size()) > currentMaxReqsInFlight) {
+      LOG(WARNING) << "After waiting for " << waitInflightTimeoutMs_
+                   << " ms, there are still " << totalInflightReqs_
+                   << " requests in flight (limit: " << maxInFlightReqsTotal_
+                   << "): " << batchIdSet->size()
+                   << " batches in flight for hostAndPushPort "
+                   << hostAndPushPort << ", which exceeds the current limit "
+                   << currentMaxReqsInFlight;
+    }
+    if (maxInFlightBytesSizeEnabled_ && batchBytesSize) {
+      if (totalInflightBytes_.load() > maxInFlightBytesSizeTotal_ ||
+          batchBytesSize->load() > maxInFlightBytesSizePerWorker_) {
+        LOG(WARNING) << "After waiting for " << waitInflightTimeoutMs_
+                     << " ms, there are still " << totalInflightBytes_.load()
+                     << " bytes in flight (limit: "
+                     << maxInFlightBytesSizeTotal_
+                     << "): " << batchBytesSize->load()
+                     << " bytes for hostAndPushPort " << hostAndPushPort
+                     << ", which exceeds the current limit "
+                     << maxInFlightBytesSizePerWorker_;
+      }
+    }
   }
   throwIfExceptionExists();
   return times <= 0;
@@ -93,6 +164,9 @@ bool PushState::limitZeroInFlight() {
 
   long times = waitInflightTimeoutMs_ / deltaMs_;
   for (; times > 0; times--) {
+    if (cleaned_.load()) {
+      return false;
+    }
     if (totalInflightReqs_ <= 0) {
       break;
     }
@@ -147,9 +221,18 @@ std::optional<std::string> PushState::getExceptionMsg() const {
 }
 
 void PushState::cleanup() {
+  LOG(INFO) << "Cleanup " << totalInflightReqs_.load()
+            << " requests in flight.";
+  cleaned_.store(true);
   inflightBatchesPerAddress_.clear();
   totalInflightReqs_ = 0;
   pushStrategy_->clear();
+
+  if (maxInFlightBytesSizeEnabled_) {
+    inflightBytesSizePerAddress_.clear();
+    inflightBatchBytesSizes_.clear();
+    totalInflightBytes_ = 0;
+  }
 }
 
 void PushState::throwIfExceptionExists() {
diff --git a/cpp/celeborn/client/writer/PushState.h 
b/cpp/celeborn/client/writer/PushState.h
index a32b0971c..38ee9607f 100644
--- a/cpp/celeborn/client/writer/PushState.h
+++ b/cpp/celeborn/client/writer/PushState.h
@@ -18,6 +18,7 @@
 #pragma once
 
 #include <atomic>
+#include <optional>
 
 #include "celeborn/client/writer/PushStrategy.h"
 #include "celeborn/conf/CelebornConf.h"
@@ -36,7 +37,8 @@ class PushState {
 
   int nextBatchId();
 
-  void addBatch(int batchId, const std::string& hostAndPushPort);
+  void
+  addBatch(int batchId, int batchBytesSize, const std::string& 
hostAndPushPort);
 
   void onSuccess(const std::string& hostAndPushPort);
 
@@ -48,6 +50,10 @@ class PushState {
   // block until the ongoing package num decreases below max limit. If the
   // limit operation succeeds before timeout, return false, otherwise return
   // true.
+  // When maxBytesSizeInFlight is enabled, the limit check considers both
+  // request count and byte size limits. The push is allowed if either:
+  // 1. Request count is within limits, or
+  // 2. Byte size is within limits (when enabled)
   bool limitMaxInFlight(const std::string& hostAndPushPort);
 
   // Check if the pushState's ongoing package num reaches zero, if not, block
@@ -68,15 +74,23 @@ class PushState {
 
   std::atomic<int> currBatchId_{1};
   std::atomic<long> totalInflightReqs_{0};
+  std::atomic<long> totalInflightBytes_{0};
   const long waitInflightTimeoutMs_;
   const long deltaMs_;
   const std::unique_ptr<PushStrategy> pushStrategy_;
   const int maxInFlightReqsTotal_;
+  const bool maxInFlightBytesSizeEnabled_;
+  const long maxInFlightBytesSizeTotal_;
+  const long maxInFlightBytesSizePerWorker_;
   utils::ConcurrentHashMap<
       std::string,
       std::shared_ptr<utils::ConcurrentHashSet<int>>>
       inflightBatchesPerAddress_;
+  utils::ConcurrentHashMap<std::string, std::shared_ptr<std::atomic<long>>>
+      inflightBytesSizePerAddress_;
+  utils::ConcurrentHashMap<int, int> inflightBatchBytesSizes_;
   folly::Synchronized<std::unique_ptr<std::exception>> exception_;
+  std::atomic<bool> cleaned_{false};
 };
 
 } // namespace client
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index 50b48aa7f..b6da2f702 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -162,6 +162,10 @@ CelebornConf::defaultProperties() {
               kShuffleCompressionCodec,
               protocol::toString(protocol::CompressionCodec::NONE)),
           NUM_PROP(kShuffleCompressionZstdCompressLevel, 1),
+          STR_PROP(kClientPushBufferMaxSize, "64kB"),
+          BOOL_PROP(kClientPushMaxBytesSizeInFlightEnabled, false),
+          NONE_PROP(kClientPushMaxBytesSizeInFlightTotal),
+          NONE_PROP(kClientPushMaxBytesSizeInFlightPerWorker),
           // NUM_PROP(kNumExample, 50'000),
           // BOOL_PROP(kBoolExample, false),
       };
@@ -267,6 +271,43 @@ long CelebornConf::clientPushLimitInFlightSleepDeltaMs() const {
       optionalProperty(kClientPushLimitInFlightSleepDeltaMs).value());
 }
 
+int CelebornConf::clientPushBufferMaxSize() const {
+  return toCapacity(
+      optionalProperty(kClientPushBufferMaxSize).value(), CapacityUnit::BYTE);
+}
+
+bool CelebornConf::clientPushMaxBytesSizeInFlightEnabled() const {
+  return optionalProperty<bool>(kClientPushMaxBytesSizeInFlightEnabled)
+      .value_or(false);
+}
+
+long CelebornConf::clientPushMaxBytesSizeInFlightTotal() const {
+  auto optionalValue = optionalProperty(kClientPushMaxBytesSizeInFlightTotal);
+  long maxBytesSizeInFlight = optionalValue.has_value()
+      ? toCapacity(optionalValue.value(), CapacityUnit::BYTE)
+      : 0L;
+  if (clientPushMaxBytesSizeInFlightEnabled() && maxBytesSizeInFlight > 0L) {
+    return maxBytesSizeInFlight;
+  }
+  // Default: maxReqsInFlightTotal * bufferMaxSize
+  return static_cast<long>(clientPushMaxReqsInFlightTotal()) *
+      clientPushBufferMaxSize();
+}
+
+long CelebornConf::clientPushMaxBytesSizeInFlightPerWorker() const {
+  auto optionalValue =
+      optionalProperty(kClientPushMaxBytesSizeInFlightPerWorker);
+  long maxBytesSizeInFlight = optionalValue.has_value()
+      ? toCapacity(optionalValue.value(), CapacityUnit::BYTE)
+      : 0L;
+  if (clientPushMaxBytesSizeInFlightEnabled() && maxBytesSizeInFlight > 0L) {
+    return maxBytesSizeInFlight;
+  }
+  // Default: maxReqsInFlightPerWorker * bufferMaxSize
+  return static_cast<long>(clientPushMaxReqsInFlightPerWorker()) *
+      clientPushBufferMaxSize();
+}
+
 Timeout CelebornConf::clientRpcRequestPartitionLocationRpcAskTimeout() const {
   return utils::toTimeout(toDuration(
       optionalProperty(kClientRpcRequestPartitionLocationAskTimeout).value()));
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index fb4294d2c..e4299a482 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -89,6 +89,18 @@ class CelebornConf : public BaseConf {
   static constexpr std::string_view kClientPushLimitInFlightSleepDeltaMs{
       "celeborn.client.push.limit.inFlight.sleepInterval"};
 
+  static constexpr std::string_view kClientPushBufferMaxSize{
+      "celeborn.client.push.buffer.max.size"};
+
+  static constexpr std::string_view kClientPushMaxBytesSizeInFlightEnabled{
+      "celeborn.client.push.maxBytesSizeInFlight.enabled"};
+
+  static constexpr std::string_view kClientPushMaxBytesSizeInFlightTotal{
+      "celeborn.client.push.maxBytesSizeInFlight.total"};
+
+  static constexpr std::string_view kClientPushMaxBytesSizeInFlightPerWorker{
+      "celeborn.client.push.maxBytesSizeInFlight.perWorker"};
+
   static constexpr std::string_view
       kClientRpcRequestPartitionLocationAskTimeout{
           "celeborn.client.rpc.requestPartition.askTimeout"};
@@ -159,6 +171,14 @@ class CelebornConf : public BaseConf {
 
   long clientPushLimitInFlightSleepDeltaMs() const;
 
+  int clientPushBufferMaxSize() const;
+
+  bool clientPushMaxBytesSizeInFlightEnabled() const;
+
+  long clientPushMaxBytesSizeInFlightTotal() const;
+
+  long clientPushMaxBytesSizeInFlightPerWorker() const;
+
   Timeout clientRpcRequestPartitionLocationRpcAskTimeout() const;
 
   Timeout clientRpcGetReducerFileGroupRpcAskTimeout() const;

Reply via email to