This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 42ff69f9f [CELEBORN-2229][CIP-14] Add support for
celeborn.client.push.maxBytesSizeInFlight in CppClient
42ff69f9f is described below
commit 42ff69f9fdc9610daef4baa8919016bed5c36334
Author: afterincomparableyum <afterincomparableyum>
AuthorDate: Mon Jan 12 18:00:55 2026 +0800
[CELEBORN-2229][CIP-14] Add support for
celeborn.client.push.maxBytesSizeInFlight in CppClient
### What changes were proposed in this pull request?
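Add support for the celeborn.client.push.maxBytesSizeInFlight.* settings
(enabled, total, perWorker) in CppClient, similar to
InFlightRequestTracker.java. When total or perWorker is not set, the limit
defaults to the matching maxReqsInFlight value multiplied by
celeborn.client.push.buffer.max.size.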
### Does this PR resolve a correctness bug?
No
### How was this patch tested?
Compiled locally; CI/CD will run the unit tests.
Closes #3568 from afterincomparableyum/cpp-client/celeborn-2229.
Authored-by: afterincomparableyum <afterincomparableyum>
Signed-off-by: Shuang <[email protected]>
---
cpp/celeborn/client/ShuffleClient.cpp | 3 +-
cpp/celeborn/client/tests/PushStateTest.cpp | 151 ++++++++++++++++++++----
cpp/celeborn/client/writer/PushDataCallback.cpp | 4 +-
cpp/celeborn/client/writer/PushDataCallback.h | 1 +
cpp/celeborn/client/writer/PushState.cpp | 103 ++++++++++++++--
cpp/celeborn/client/writer/PushState.h | 16 ++-
cpp/celeborn/conf/CelebornConf.cpp | 41 +++++++
cpp/celeborn/conf/CelebornConf.h | 20 ++++
8 files changed, 302 insertions(+), 37 deletions(-)
diff --git a/cpp/celeborn/client/ShuffleClient.cpp
b/cpp/celeborn/client/ShuffleClient.cpp
index 807fac2e0..401b99e2f 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -170,7 +170,8 @@ int ShuffleClientImpl::pushData(
// Check limit.
limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
// Add inFlight requests.
- pushState->addBatch(nextBatchId, hostAndPushPort);
+ const int batchBytesSize = length + kBatchHeaderSize;
+ pushState->addBatch(nextBatchId, batchBytesSize, hostAndPushPort);
// Build pushData request.
const auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
diff --git a/cpp/celeborn/client/tests/PushStateTest.cpp
b/cpp/celeborn/client/tests/PushStateTest.cpp
index 94b6c3aa3..400538314 100644
--- a/cpp/celeborn/client/tests/PushStateTest.cpp
+++ b/cpp/celeborn/client/tests/PushStateTest.cpp
@@ -28,58 +28,97 @@ class PushStateTest : public testing::Test {
conf::CelebornConf conf;
conf.registerProperty(
conf::CelebornConf::kClientPushLimitInFlightTimeoutMs,
- std::to_string(pushTimeoutMs_));
+ std::to_string(kPushTimeoutMs_));
conf.registerProperty(
conf::CelebornConf::kClientPushLimitInFlightSleepDeltaMs,
- std::to_string(pushSleepDeltaMs_));
+ std::to_string(kPushSleepDeltaMs_));
conf.registerProperty(
conf::CelebornConf::kClientPushMaxReqsInFlightTotal,
- std::to_string(maxReqsInFlight_));
+ std::to_string(kMaxReqsInFlight_));
conf.registerProperty(
conf::CelebornConf::kClientPushMaxReqsInFlightPerWorker,
- std::to_string(maxReqsInFlight_));
+ std::to_string(kMaxReqsInFlight_));
pushState_ = std::make_unique<PushState>(conf);
}
std::unique_ptr<PushState> pushState_;
- static constexpr int pushTimeoutMs_ = 100;
- static constexpr int pushSleepDeltaMs_ = 10;
- static constexpr int maxReqsInFlight_ = 2;
+ static constexpr int kPushTimeoutMs_ = 100;
+ static constexpr int kPushSleepDeltaMs_ = 10;
+ static constexpr int kMaxReqsInFlight_ = 2;
+ static constexpr int kDefaultBatchSize_ = 1024;
+};
+
+class PushStateBytesSizeTest : public testing::Test {
+ protected:
+ void SetUp() override {
+ conf::CelebornConf conf;
+ conf.registerProperty(
+ conf::CelebornConf::kClientPushLimitInFlightTimeoutMs,
+ std::to_string(kPushTimeoutMs_));
+ conf.registerProperty(
+ conf::CelebornConf::kClientPushLimitInFlightSleepDeltaMs,
+ std::to_string(kPushSleepDeltaMs_));
+ conf.registerProperty(
+ conf::CelebornConf::kClientPushMaxReqsInFlightTotal, "2");
+ conf.registerProperty(
+ conf::CelebornConf::kClientPushMaxReqsInFlightPerWorker, "100");
+ conf.registerProperty(
+ conf::CelebornConf::kClientPushMaxBytesSizeInFlightEnabled, "true");
+ conf.registerProperty(
+ conf::CelebornConf::kClientPushMaxBytesSizeInFlightTotal,
+ std::to_string(kMaxBytesSizeTotal_) + "B");
+ conf.registerProperty(
+ conf::CelebornConf::kClientPushMaxBytesSizeInFlightPerWorker,
+ std::to_string(kMaxBytesSizePerWorker_) + "B");
+ conf.registerProperty(
+ conf::CelebornConf::kClientPushBufferMaxSize,
+ std::to_string(kBufferMaxSize_) + "B");
+
+ pushState_ = std::make_unique<PushState>(conf);
+ }
+
+ std::unique_ptr<PushState> pushState_;
+ static constexpr int kPushTimeoutMs_ = 100;
+ static constexpr int kPushSleepDeltaMs_ = 10;
+ static constexpr int kBatchSize_ = 1024;
+ static constexpr long kMaxBytesSizeTotal_ = 3000;
+ static constexpr long kMaxBytesSizePerWorker_ = 2500;
+ static constexpr int kBufferMaxSize_ = 65536;
};
TEST_F(PushStateTest, limitMaxInFlight) {
const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
- const int addBatchCalls = maxReqsInFlight_ + 1;
+ const int addBatchCalls = kMaxReqsInFlight_ + 1;
std::vector<bool> addBatchMarks(addBatchCalls, false);
std::thread addBatchThread([&]() {
for (auto i = 0; i < addBatchCalls; i++) {
- pushState_->addBatch(i, hostAndPushPort);
+ pushState_->addBatch(i, kDefaultBatchSize_, hostAndPushPort);
EXPECT_FALSE(pushState_->limitMaxInFlight(hostAndPushPort));
addBatchMarks[i] = true;
}
});
- std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
- for (auto i = 0; i < maxReqsInFlight_; i++) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
+ for (auto i = 0; i < kMaxReqsInFlight_; i++) {
EXPECT_TRUE(addBatchMarks[i]);
}
- EXPECT_FALSE(addBatchMarks[maxReqsInFlight_]);
+ EXPECT_FALSE(addBatchMarks[kMaxReqsInFlight_]);
pushState_->removeBatch(0, hostAndPushPort);
addBatchThread.join();
- EXPECT_TRUE(addBatchMarks[maxReqsInFlight_]);
+ EXPECT_TRUE(addBatchMarks[kMaxReqsInFlight_]);
}
TEST_F(PushStateTest, limitMaxInFlightTimeout) {
const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
- const int addBatchCalls = maxReqsInFlight_ + 1;
+ const int addBatchCalls = kMaxReqsInFlight_ + 1;
std::vector<bool> addBatchMarks(addBatchCalls, false);
std::thread addBatchThread([&]() {
for (auto i = 0; i < addBatchCalls; i++) {
- pushState_->addBatch(i, hostAndPushPort);
+ pushState_->addBatch(i, kDefaultBatchSize_, hostAndPushPort);
auto result = pushState_->limitMaxInFlight(hostAndPushPort);
- if (i < maxReqsInFlight_) {
+ if (i < kMaxReqsInFlight_) {
EXPECT_FALSE(result);
} else {
EXPECT_TRUE(result);
@@ -88,14 +127,14 @@ TEST_F(PushStateTest, limitMaxInFlightTimeout) {
}
});
- std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
- for (auto i = 0; i < maxReqsInFlight_; i++) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
+ for (auto i = 0; i < kMaxReqsInFlight_; i++) {
EXPECT_TRUE(addBatchMarks[i]);
}
- EXPECT_FALSE(addBatchMarks[maxReqsInFlight_]);
+ EXPECT_FALSE(addBatchMarks[kMaxReqsInFlight_]);
addBatchThread.join();
- EXPECT_FALSE(addBatchMarks[maxReqsInFlight_]);
+ EXPECT_FALSE(addBatchMarks[kMaxReqsInFlight_]);
}
TEST_F(PushStateTest, limitZeroInFlight) {
@@ -103,12 +142,12 @@ TEST_F(PushStateTest, limitZeroInFlight) {
const int addBatchCalls = 1;
std::vector<bool> addBatchMarks(addBatchCalls, false);
std::thread addBatchThread([&]() {
- pushState_->addBatch(0, hostAndPushPort);
+ pushState_->addBatch(0, kDefaultBatchSize_, hostAndPushPort);
EXPECT_FALSE(pushState_->limitZeroInFlight());
addBatchMarks[0] = true;
});
- std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+ std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
EXPECT_FALSE(addBatchMarks[0]);
pushState_->removeBatch(0, hostAndPushPort);
@@ -121,13 +160,13 @@ TEST_F(PushStateTest, limitZeroInFlightTimeout) {
const int addBatchCalls = 1;
std::vector<bool> addBatchMarks(addBatchCalls, false);
std::thread addBatchThread([&]() {
- pushState_->addBatch(0, hostAndPushPort);
+ pushState_->addBatch(0, kDefaultBatchSize_, hostAndPushPort);
auto result = pushState_->limitZeroInFlight();
EXPECT_TRUE(result);
addBatchMarks[0] = !result;
});
- std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+ std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
EXPECT_FALSE(addBatchMarks[0]);
addBatchThread.join();
@@ -153,3 +192,67 @@ TEST_F(PushStateTest, throwException) {
}
EXPECT_TRUE(exceptionThrowed);
}
+
+TEST_F(PushStateBytesSizeTest, limitMaxInFlightByBytesSize) {
+ const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+ const int expectedAllowedBatches = 2;
+ const int addBatchCalls = expectedAllowedBatches + 1;
+ std::vector<bool> addBatchMarks(addBatchCalls, false);
+
+ std::thread addBatchThread([&]() {
+ for (auto i = 0; i < addBatchCalls; i++) {
+ pushState_->addBatch(i, kBatchSize_, hostAndPushPort);
+ auto result = pushState_->limitMaxInFlight(hostAndPushPort);
+ addBatchMarks[i] = true;
+ if (i < expectedAllowedBatches) {
+ EXPECT_FALSE(result) << "Batch " << i << " should be within limits";
+ }
+ }
+ });
+
+ std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
+ for (auto i = 0; i < expectedAllowedBatches; i++) {
+ EXPECT_TRUE(addBatchMarks[i]) << "Batch " << i << " should have completed";
+ }
+
+ pushState_->removeBatch(0, hostAndPushPort);
+ addBatchThread.join();
+ EXPECT_TRUE(addBatchMarks[expectedAllowedBatches]);
+}
+
+TEST_F(PushStateBytesSizeTest, limitMaxInFlightByTotalBytesSize) {
+ const std::string hostAndPushPort1 = "xx.xx.xx.xx:8080";
+ const std::string hostAndPushPort2 = "yy.yy.yy.yy:8080";
+
+ pushState_->addBatch(0, kBatchSize_, hostAndPushPort1);
+ EXPECT_FALSE(pushState_->limitMaxInFlight(hostAndPushPort1));
+
+ pushState_->addBatch(1, kBatchSize_, hostAndPushPort2);
+ EXPECT_FALSE(pushState_->limitMaxInFlight(hostAndPushPort2));
+
+ std::atomic<bool> thirdBatchCompleted{false};
+ std::thread addBatchThread([&]() {
+ pushState_->addBatch(2, kBatchSize_, hostAndPushPort1);
+ pushState_->limitMaxInFlight(hostAndPushPort1);
+ thirdBatchCompleted = true;
+ });
+
+ std::this_thread::sleep_for(std::chrono::milliseconds(kPushSleepDeltaMs_));
+ EXPECT_FALSE(thirdBatchCompleted.load())
+ << "Third batch should be blocked due to total bytes limit";
+
+ pushState_->removeBatch(0, hostAndPushPort1);
+ addBatchThread.join();
+
+ EXPECT_TRUE(thirdBatchCompleted.load());
+}
+
+TEST_F(PushStateBytesSizeTest, cleanupClearsBytesSizeTracking) {
+ const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+
+ pushState_->addBatch(0, kBatchSize_, hostAndPushPort);
+ pushState_->addBatch(1, kBatchSize_, hostAndPushPort);
+ pushState_->cleanup();
+
+ EXPECT_FALSE(pushState_->limitMaxInFlight(hostAndPushPort));
+}
diff --git a/cpp/celeborn/client/writer/PushDataCallback.cpp
b/cpp/celeborn/client/writer/PushDataCallback.cpp
index 43550ba7f..3b0218088 100644
--- a/cpp/celeborn/client/writer/PushDataCallback.cpp
+++ b/cpp/celeborn/client/writer/PushDataCallback.cpp
@@ -73,6 +73,7 @@ PushDataCallback::PushDataCallback(
numPartitions_(numPartitions),
mapKey_(mapKey),
batchId_(batchId),
+ batchBytesSize_(databody ? static_cast<int>(databody->size()) : 0),
databody_(std::move(databody)),
pushState_(pushState),
weakClient_(weakClient),
@@ -208,7 +209,8 @@ void
PushDataCallback::onFailure(std::unique_ptr<std::exception> exception) {
void PushDataCallback::updateLatestLocation(
std::shared_ptr<const protocol::PartitionLocation> latestLocation) {
- pushState_->addBatch(batchId_, latestLocation->hostAndPushPort());
+ pushState_->addBatch(
+ batchId_, batchBytesSize_, latestLocation->hostAndPushPort());
pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
latestLocation_ = latestLocation;
}
diff --git a/cpp/celeborn/client/writer/PushDataCallback.h
b/cpp/celeborn/client/writer/PushDataCallback.h
index 9916cd191..b1c8334b8 100644
--- a/cpp/celeborn/client/writer/PushDataCallback.h
+++ b/cpp/celeborn/client/writer/PushDataCallback.h
@@ -85,6 +85,7 @@ class PushDataCallback : public network::RpcResponseCallback,
const int numPartitions_;
const std::string mapKey_;
const int batchId_;
+ const int batchBytesSize_;
const std::unique_ptr<memory::ReadOnlyByteBuffer> databody_;
const std::shared_ptr<PushState> pushState_;
const std::weak_ptr<ShuffleClientImpl> weakClient_;
diff --git a/cpp/celeborn/client/writer/PushState.cpp
b/cpp/celeborn/client/writer/PushState.cpp
index b86a2aadb..8b3158e36 100644
--- a/cpp/celeborn/client/writer/PushState.cpp
+++ b/cpp/celeborn/client/writer/PushState.cpp
@@ -24,18 +24,35 @@ PushState::PushState(const conf::CelebornConf& conf)
: waitInflightTimeoutMs_(conf.clientPushLimitInFlightTimeoutMs()),
deltaMs_(conf.clientPushLimitInFlightSleepDeltaMs()),
pushStrategy_(PushStrategy::create(conf)),
- maxInFlightReqsTotal_(conf.clientPushMaxReqsInFlightTotal()) {}
+ maxInFlightReqsTotal_(conf.clientPushMaxReqsInFlightTotal()),
+ maxInFlightBytesSizeEnabled_(
+ conf.clientPushMaxBytesSizeInFlightEnabled()),
+ maxInFlightBytesSizeTotal_(conf.clientPushMaxBytesSizeInFlightTotal()),
+ maxInFlightBytesSizePerWorker_(
+ conf.clientPushMaxBytesSizeInFlightPerWorker()) {}
int PushState::nextBatchId() {
return currBatchId_.fetch_add(1);
}
-void PushState::addBatch(int batchId, const std::string& hostAndPushPort) {
+void PushState::addBatch(
+ int batchId,
+ int batchBytesSize,
+ const std::string& hostAndPushPort) {
auto batchIdSet = inflightBatchesPerAddress_.computeIfAbsent(
hostAndPushPort,
[&]() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
batchIdSet->insert(batchId);
totalInflightReqs_.fetch_add(1);
+
+ if (maxInFlightBytesSizeEnabled_) {
+ auto bytesSizePerAddress = inflightBytesSizePerAddress_.computeIfAbsent(
+ hostAndPushPort,
+ [&]() { return std::make_shared<std::atomic<long>>(0); });
+ bytesSizePerAddress->fetch_add(batchBytesSize);
+ inflightBatchBytesSizes_.set(batchId, batchBytesSize);
+ totalInflightBytes_.fetch_add(batchBytesSize);
+ }
}
void PushState::onSuccess(const std::string& hostAndPushPort) {
@@ -51,10 +68,24 @@ void PushState::removeBatch(int batchId, const std::string&
hostAndPushPort) {
if (batchIdSetOptional.has_value()) {
auto batchIdSet = batchIdSetOptional.value();
batchIdSet->erase(batchId);
- totalInflightReqs_.fetch_sub(1);
} else {
LOG(WARNING) << "BatchIdSet of " << hostAndPushPort << " doesn't exist.";
}
+
+ totalInflightReqs_.fetch_sub(1);
+
+ if (maxInFlightBytesSizeEnabled_) {
+ auto inflightBatchBytesSize = inflightBatchBytesSizes_.get(batchId);
+ inflightBatchBytesSizes_.erase(batchId);
+ if (inflightBatchBytesSize.has_value()) {
+ auto inflightBytesSize =
+ inflightBytesSizePerAddress_.get(hostAndPushPort);
+ if (inflightBytesSize.has_value()) {
+ inflightBytesSize.value()->fetch_sub(inflightBatchBytesSize.value());
+ }
+ totalInflightBytes_.fetch_sub(inflightBatchBytesSize.value());
+ }
+ }
}
bool PushState::limitMaxInFlight(const std::string& hostAndPushPort) {
@@ -67,22 +98,62 @@ bool PushState::limitMaxInFlight(const std::string&
hostAndPushPort) {
auto batchIdSet = inflightBatchesPerAddress_.computeIfAbsent(
hostAndPushPort,
[&]() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+ std::shared_ptr<std::atomic<long>> batchBytesSize = nullptr;
+ if (maxInFlightBytesSizeEnabled_) {
+ batchBytesSize = inflightBytesSizePerAddress_.computeIfAbsent(
+ hostAndPushPort,
+ [&]() { return std::make_shared<std::atomic<long>>(0); });
+ }
long times = waitInflightTimeoutMs_ / deltaMs_;
for (; times > 0; times--) {
- if (totalInflightReqs_ <= maxInFlightReqsTotal_ &&
- batchIdSet->size() <= currentMaxReqsInFlight) {
+ if (cleaned_.load()) {
+ return false;
+ }
+
+ bool reqCountWithinLimits =
+ (totalInflightReqs_ <= maxInFlightReqsTotal_ &&
+ static_cast<int>(batchIdSet->size()) <= currentMaxReqsInFlight);
+ bool byteSizeWithinLimits = false;
+
+ if (maxInFlightBytesSizeEnabled_ && batchBytesSize) {
+ byteSizeWithinLimits =
+ (totalInflightBytes_.load() <= maxInFlightBytesSizeTotal_ &&
+ batchBytesSize->load() <= maxInFlightBytesSizePerWorker_);
+ }
+
+ if (reqCountWithinLimits ||
+ (maxInFlightBytesSizeEnabled_ && byteSizeWithinLimits)) {
break;
}
+
throwIfExceptionExists();
std::this_thread::sleep_for(utils::MS(deltaMs_));
}
if (times <= 0) {
- LOG(WARNING) << "After waiting for " << waitInflightTimeoutMs_
- << " ms, there are still " << batchIdSet->size()
- << " batches in flight for hostAndPushPort " <<
hostAndPushPort
- << ", which exceeds the current limit "
- << currentMaxReqsInFlight;
+ if (totalInflightReqs_ > maxInFlightReqsTotal_ ||
+ static_cast<int>(batchIdSet->size()) > currentMaxReqsInFlight) {
+ LOG(WARNING) << "After waiting for " << waitInflightTimeoutMs_
+ << " ms, there are still " << totalInflightReqs_
+ << " requests in flight (limit: " << maxInFlightReqsTotal_
+ << "): " << batchIdSet->size()
+ << " batches in flight for hostAndPushPort "
+ << hostAndPushPort << ", which exceeds the current limit "
+ << currentMaxReqsInFlight;
+ }
+ if (maxInFlightBytesSizeEnabled_ && batchBytesSize) {
+ if (totalInflightBytes_.load() > maxInFlightBytesSizeTotal_ ||
+ batchBytesSize->load() > maxInFlightBytesSizePerWorker_) {
+ LOG(WARNING) << "After waiting for " << waitInflightTimeoutMs_
+ << " ms, there are still " << totalInflightBytes_.load()
+ << " bytes in flight (limit: "
+ << maxInFlightBytesSizeTotal_
+ << "): " << batchBytesSize->load()
+ << " bytes for hostAndPushPort " << hostAndPushPort
+ << ", which exceeds the current limit "
+ << maxInFlightBytesSizePerWorker_;
+ }
+ }
}
throwIfExceptionExists();
return times <= 0;
@@ -93,6 +164,9 @@ bool PushState::limitZeroInFlight() {
long times = waitInflightTimeoutMs_ / deltaMs_;
for (; times > 0; times--) {
+ if (cleaned_.load()) {
+ return false;
+ }
if (totalInflightReqs_ <= 0) {
break;
}
@@ -147,9 +221,18 @@ std::optional<std::string> PushState::getExceptionMsg()
const {
}
void PushState::cleanup() {
+ LOG(INFO) << "Cleanup " << totalInflightReqs_.load()
+ << " requests in flight.";
+ cleaned_.store(true);
inflightBatchesPerAddress_.clear();
totalInflightReqs_ = 0;
pushStrategy_->clear();
+
+ if (maxInFlightBytesSizeEnabled_) {
+ inflightBytesSizePerAddress_.clear();
+ inflightBatchBytesSizes_.clear();
+ totalInflightBytes_ = 0;
+ }
}
void PushState::throwIfExceptionExists() {
diff --git a/cpp/celeborn/client/writer/PushState.h
b/cpp/celeborn/client/writer/PushState.h
index a32b0971c..38ee9607f 100644
--- a/cpp/celeborn/client/writer/PushState.h
+++ b/cpp/celeborn/client/writer/PushState.h
@@ -18,6 +18,7 @@
#pragma once
#include <atomic>
+#include <optional>
#include "celeborn/client/writer/PushStrategy.h"
#include "celeborn/conf/CelebornConf.h"
@@ -36,7 +37,8 @@ class PushState {
int nextBatchId();
- void addBatch(int batchId, const std::string& hostAndPushPort);
+ void
+ addBatch(int batchId, int batchBytesSize, const std::string&
hostAndPushPort);
void onSuccess(const std::string& hostAndPushPort);
@@ -48,6 +50,10 @@ class PushState {
// block until the ongoing package num decreases below max limit. If the
// limit operation succeeds before timeout, return false, otherwise return
// true.
+ // When maxBytesSizeInFlight is enabled, the limit check considers both
+ // request count and byte size limits. The push is allowed if either:
+ // 1. Request count is within limits, or
+ // 2. Byte size is within limits (when enabled)
bool limitMaxInFlight(const std::string& hostAndPushPort);
// Check if the pushState's ongoing package num reaches zero, if not, block
@@ -68,15 +74,23 @@ class PushState {
std::atomic<int> currBatchId_{1};
std::atomic<long> totalInflightReqs_{0};
+ std::atomic<long> totalInflightBytes_{0};
const long waitInflightTimeoutMs_;
const long deltaMs_;
const std::unique_ptr<PushStrategy> pushStrategy_;
const int maxInFlightReqsTotal_;
+ const bool maxInFlightBytesSizeEnabled_;
+ const long maxInFlightBytesSizeTotal_;
+ const long maxInFlightBytesSizePerWorker_;
utils::ConcurrentHashMap<
std::string,
std::shared_ptr<utils::ConcurrentHashSet<int>>>
inflightBatchesPerAddress_;
+ utils::ConcurrentHashMap<std::string, std::shared_ptr<std::atomic<long>>>
+ inflightBytesSizePerAddress_;
+ utils::ConcurrentHashMap<int, int> inflightBatchBytesSizes_;
folly::Synchronized<std::unique_ptr<std::exception>> exception_;
+ std::atomic<bool> cleaned_{false};
};
} // namespace client
diff --git a/cpp/celeborn/conf/CelebornConf.cpp
b/cpp/celeborn/conf/CelebornConf.cpp
index 50b48aa7f..b6da2f702 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -162,6 +162,10 @@ CelebornConf::defaultProperties() {
kShuffleCompressionCodec,
protocol::toString(protocol::CompressionCodec::NONE)),
NUM_PROP(kShuffleCompressionZstdCompressLevel, 1),
+ STR_PROP(kClientPushBufferMaxSize, "64kB"),
+ BOOL_PROP(kClientPushMaxBytesSizeInFlightEnabled, false),
+ NONE_PROP(kClientPushMaxBytesSizeInFlightTotal),
+ NONE_PROP(kClientPushMaxBytesSizeInFlightPerWorker),
// NUM_PROP(kNumExample, 50'000),
// BOOL_PROP(kBoolExample, false),
};
@@ -267,6 +271,43 @@ long CelebornConf::clientPushLimitInFlightSleepDeltaMs()
const {
optionalProperty(kClientPushLimitInFlightSleepDeltaMs).value());
}
+int CelebornConf::clientPushBufferMaxSize() const {
+ return toCapacity(
+ optionalProperty(kClientPushBufferMaxSize).value(), CapacityUnit::BYTE);
+}
+
+bool CelebornConf::clientPushMaxBytesSizeInFlightEnabled() const {
+ return optionalProperty<bool>(kClientPushMaxBytesSizeInFlightEnabled)
+ .value_or(false);
+}
+
+long CelebornConf::clientPushMaxBytesSizeInFlightTotal() const {
+ auto optionalValue = optionalProperty(kClientPushMaxBytesSizeInFlightTotal);
+ long maxBytesSizeInFlight = optionalValue.has_value()
+ ? toCapacity(optionalValue.value(), CapacityUnit::BYTE)
+ : 0L;
+ if (clientPushMaxBytesSizeInFlightEnabled() && maxBytesSizeInFlight > 0L) {
+ return maxBytesSizeInFlight;
+ }
+ // Default: maxReqsInFlightTotal * bufferMaxSize
+ return static_cast<long>(clientPushMaxReqsInFlightTotal()) *
+ clientPushBufferMaxSize();
+}
+
+long CelebornConf::clientPushMaxBytesSizeInFlightPerWorker() const {
+ auto optionalValue =
+ optionalProperty(kClientPushMaxBytesSizeInFlightPerWorker);
+ long maxBytesSizeInFlight = optionalValue.has_value()
+ ? toCapacity(optionalValue.value(), CapacityUnit::BYTE)
+ : 0L;
+ if (clientPushMaxBytesSizeInFlightEnabled() && maxBytesSizeInFlight > 0L) {
+ return maxBytesSizeInFlight;
+ }
+ // Default: maxReqsInFlightPerWorker * bufferMaxSize
+ return static_cast<long>(clientPushMaxReqsInFlightPerWorker()) *
+ clientPushBufferMaxSize();
+}
+
Timeout CelebornConf::clientRpcRequestPartitionLocationRpcAskTimeout() const {
return utils::toTimeout(toDuration(
optionalProperty(kClientRpcRequestPartitionLocationAskTimeout).value()));
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index fb4294d2c..e4299a482 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -89,6 +89,18 @@ class CelebornConf : public BaseConf {
static constexpr std::string_view kClientPushLimitInFlightSleepDeltaMs{
"celeborn.client.push.limit.inFlight.sleepInterval"};
+ static constexpr std::string_view kClientPushBufferMaxSize{
+ "celeborn.client.push.buffer.max.size"};
+
+ static constexpr std::string_view kClientPushMaxBytesSizeInFlightEnabled{
+ "celeborn.client.push.maxBytesSizeInFlight.enabled"};
+
+ static constexpr std::string_view kClientPushMaxBytesSizeInFlightTotal{
+ "celeborn.client.push.maxBytesSizeInFlight.total"};
+
+ static constexpr std::string_view kClientPushMaxBytesSizeInFlightPerWorker{
+ "celeborn.client.push.maxBytesSizeInFlight.perWorker"};
+
static constexpr std::string_view
kClientRpcRequestPartitionLocationAskTimeout{
"celeborn.client.rpc.requestPartition.askTimeout"};
@@ -159,6 +171,14 @@ class CelebornConf : public BaseConf {
long clientPushLimitInFlightSleepDeltaMs() const;
+ int clientPushBufferMaxSize() const;
+
+ bool clientPushMaxBytesSizeInFlightEnabled() const;
+
+ long clientPushMaxBytesSizeInFlightTotal() const;
+
+ long clientPushMaxBytesSizeInFlightPerWorker() const;
+
Timeout clientRpcRequestPartitionLocationRpcAskTimeout() const;
Timeout clientRpcGetReducerFileGroupRpcAskTimeout() const;