This is an automated email from the ASF dual-hosted git repository.
feiwang pushed a commit to branch branch-0.6
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/branch-0.6 by this push:
     new d42debd2c [CELEBORN-1917] Support celeborn.client.push.maxBytesSizeInFlight
d42debd2c is described below
commit d42debd2ce1eb56a8b01e526a5286cec010cca5f
Author: DDDominik <[email protected]>
AuthorDate: Tue Jul 22 23:07:56 2025 +0800
[CELEBORN-1917] Support celeborn.client.push.maxBytesSizeInFlight
### What changes were proposed in this pull request?
Add a size limitation on in-flight push data by introducing new
configurations `celeborn.client.push.maxBytesSizeInFlight.perWorker/total`,
which default to `celeborn.client.push.buffer.max.size *
celeborn.client.push.maxReqsInFlight.perWorker/total`.
For backward compatibility, also add a switch:
`celeborn.client.push.maxBytesSizeInFlight.enabled`.
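For illustration only, a hypothetical Spark-side setup might look like the
following (when the Celeborn client runs inside Spark, client settings take a
`spark.` prefix; the values are examples, not recommendations):

```properties
# Turn on the byte-size limit for in-flight push data (off by default).
spark.celeborn.client.push.maxBytesSizeInFlight.enabled=true
# Optional explicit caps; when unset, each defaults to
# celeborn.client.push.buffer.max.size * the corresponding maxReqsInFlight limit.
spark.celeborn.client.push.maxBytesSizeInFlight.total=32m
spark.celeborn.client.push.maxBytesSizeInFlight.perWorker=4m
```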
### Why are the changes needed?
Celeborn already supports limiting the number of in-flight push requests via
`celeborn.client.push.maxReqsInFlight.perWorker/total`. This is a good
constraint on memory usage when most requests do not exceed
`celeborn.client.push.buffer.max.size`. However, in a vectorized shuffle (such
as Blaze or Gluten), a single request can be far larger than the max buffer
size, leading to too much in-flight data and resulting in OOM.
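The fallback default for the new byte limits can be sketched in isolation
(a standalone illustration, not Celeborn's actual `CelebornConf` code; class
and constant names are hypothetical, with values mirroring the documented
defaults of 64 KiB buffers and 256 total in-flight requests):

```java
// Standalone sketch of the fallback default for the byte-size limit.
// Not Celeborn source; class and constant names are illustrative only.
public class MaxBytesInFlightDefault {
    // Mirrors celeborn.client.push.buffer.max.size (default 64 KiB).
    static final long BUFFER_MAX_SIZE = 64 * 1024;
    // Mirrors celeborn.client.push.maxReqsInFlight.total (default 256).
    static final int MAX_REQS_IN_FLIGHT_TOTAL = 256;

    // An explicit, positive byte limit wins; otherwise the limit is derived
    // from the request-count limit times the push buffer size.
    public static long maxBytesInFlightTotal(long configuredBytes) {
        return configuredBytes > 0
            ? configuredBytes
            : MAX_REQS_IN_FLIGHT_TOTAL * BUFFER_MAX_SIZE;
    }

    public static void main(String[] args) {
        // Unset -> derived default: 256 * 64 KiB = 16 MiB.
        System.out.println(maxBytesInFlightTotal(0L));
    }
}
```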
### Does this PR introduce _any_ user-facing change?
Yes, new client configurations are added.
### How was this patch tested?
Tested in a local environment.
Closes #3248 from DDDominik/CELEBORN-1917.
Lead-authored-by: DDDominik <[email protected]>
Co-authored-by: SteNicholas <[email protected]>
Co-authored-by: DDDominik <[email protected]>
Signed-off-by: mingji <[email protected]>
(cherry picked from commit 0ed590dc81e160c8df9d89f9345b012177fd4938)
Signed-off-by: Wang, Fei <[email protected]>
---
.../flink/client/FlinkShuffleClientImpl.java | 2 +-
.../apache/celeborn/client/ShuffleClientImpl.java | 7 +-
.../celeborn/client/write/DataPushQueueSuiteJ.java | 2 +-
.../common/write/InFlightRequestTracker.java | 94 ++++++++++++++++++----
.../apache/celeborn/common/write/PushState.java | 4 +-
.../org/apache/celeborn/common/CelebornConf.scala | 53 ++++++++++++
docs/configuration/client.md | 3 +
7 files changed, 143 insertions(+), 22 deletions(-)
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
index 2cb9df565..5b402b2ed 100644
--- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
@@ -340,7 +340,7 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {
       limitMaxInFlight(mapKey, pushState, location.hostAndPushPort());
       // add inFlight requests
-      pushState.addBatch(nextBatchId, location.hostAndPushPort());
+      pushState.addBatch(nextBatchId, totalLength, location.hostAndPushPort());
       // build PushData request
       NettyManagedBuffer buffer = new NettyManagedBuffer(data);
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 828eb3158..2a272e83a 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -1034,7 +1034,7 @@ public class ShuffleClientImpl extends ShuffleClient {
       limitMaxInFlight(mapKey, pushState, loc.hostAndPushPort());
       // add inFlight requests
-      pushState.addBatch(nextBatchId, loc.hostAndPushPort());
+      pushState.addBatch(nextBatchId, body.length, loc.hostAndPushPort());
       // build PushData request
       NettyManagedBuffer buffer = new NettyManagedBuffer(Unpooled.wrappedBuffer(body));
@@ -1078,7 +1078,7 @@ public class ShuffleClientImpl extends ShuffleClient {
         @Override
         public void updateLatestPartition(PartitionLocation newloc) {
-          pushState.addBatch(nextBatchId, newloc.hostAndPushPort());
+          pushState.addBatch(nextBatchId, body.length, newloc.hostAndPushPort());
           pushState.removeBatch(nextBatchId, this.latest.hostAndPushPort());
           this.latest = newloc;
         }
@@ -1409,7 +1409,8 @@ public class ShuffleClientImpl extends ShuffleClient {
     final int port = Integer.parseInt(hostPortArr[1]);
     int groupedBatchId = pushState.nextBatchId();
-    pushState.addBatch(groupedBatchId, hostPort);
+    int groupedBatchBytesSize = batches.stream().mapToInt(batch -> batch.body.length).sum();
+    pushState.addBatch(groupedBatchId, groupedBatchBytesSize, hostPort);
     final int numBatches = batches.size();
     final Integer[] partitionIds = new Integer[numBatches];
diff --git a/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java
index 99a27eb3d..7081b79d0 100644
--- a/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java
@@ -124,7 +124,7 @@ public class DataPushQueueSuiteJ {
     for (int i = 0; i < numPartitions; i++) {
       byte[] b = intToBytes(workerData.get(i % numWorker).get(i / numWorker));
       int batchId = pushState.nextBatchId();
-      pushState.addBatch(batchId, reducePartitionMap.get(i).hostAndPushPort());
+      pushState.addBatch(batchId, b.length, reducePartitionMap.get(i).hostAndPushPort());
       partitionBatchIdMap.put(i, batchId);
       dataPusher.addTask(i, b, b.length);
     }
diff --git a/common/src/main/java/org/apache/celeborn/common/write/InFlightRequestTracker.java b/common/src/main/java/org/apache/celeborn/common/write/InFlightRequestTracker.java
index 2fd6a3789..613f7784b 100644
--- a/common/src/main/java/org/apache/celeborn/common/write/InFlightRequestTracker.java
+++ b/common/src/main/java/org/apache/celeborn/common/write/InFlightRequestTracker.java
@@ -18,6 +18,7 @@
package org.apache.celeborn.common.write;
import java.io.IOException;
+import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
@@ -45,9 +46,15 @@ public class InFlightRequestTracker {
   private final AtomicInteger batchId = new AtomicInteger();
   private final ConcurrentHashMap<String, Set<Integer>> inflightBatchesPerAddress =
       JavaUtils.newConcurrentHashMap();
+  private ConcurrentHashMap<String, LongAdder> inflightBytesSizePerAddress = null;
+  private ConcurrentHashMap<Integer, Integer> inflightBatchBytesSizes = null;
   private final int maxInFlightReqsTotal;
+  private final boolean maxInFlightBytesSizeEnabled;
+  private Long maxInFlightBytesSizeTotal = null;
+  private Long maxInFlightBytesSizePerWorker = null;
   private final LongAdder totalInflightReqs = new LongAdder();
+  private LongAdder totalInflightBytes = null;
   private volatile boolean cleaned = false;
@@ -57,14 +64,29 @@ public class InFlightRequestTracker {
     this.pushState = pushState;
     this.pushStrategy = PushStrategy.getStrategy(conf);
     this.maxInFlightReqsTotal = conf.clientPushMaxReqsInFlightTotal();
+    this.maxInFlightBytesSizeEnabled = conf.clientPushMaxBytesSizeInFlightEnabled();
+    if (this.maxInFlightBytesSizeEnabled) {
+      this.inflightBytesSizePerAddress = JavaUtils.newConcurrentHashMap();
+      this.inflightBatchBytesSizes = JavaUtils.newConcurrentHashMap();
+      this.maxInFlightBytesSizeTotal = conf.clientPushMaxBytesSizeInFlightTotal();
+      this.maxInFlightBytesSizePerWorker = conf.clientPushMaxBytesSizeInFlightPerWorker();
+      this.totalInflightBytes = new LongAdder();
+    }
   }

-  public void addBatch(int batchId, String hostAndPushPort) {
+  public void addBatch(int batchId, int batchBytesSize, String hostAndPushPort) {
     Set<Integer> batchIdSetPerPair =
         inflightBatchesPerAddress.computeIfAbsent(
             hostAndPushPort, id -> ConcurrentHashMap.newKeySet());
     batchIdSetPerPair.add(batchId);
     totalInflightReqs.increment();
+    if (maxInFlightBytesSizeEnabled) {
+      LongAdder bytesSizePerPair =
+          inflightBytesSizePerAddress.computeIfAbsent(hostAndPushPort, id -> new LongAdder());
+      bytesSizePerPair.add(batchBytesSize);
+      inflightBatchBytesSizes.put(batchId, batchBytesSize);
+      totalInflightBytes.add(batchBytesSize);
+    }
   }
public void removeBatch(int batchId, String hostAndPushPort) {
@@ -75,6 +97,15 @@ public class InFlightRequestTracker {
logger.info("Batches of {} in flight is null.", hostAndPushPort);
}
totalInflightReqs.decrement();
+    if (maxInFlightBytesSizeEnabled) {
+      int inflightBatchBytesSize =
+          -Optional.ofNullable(inflightBatchBytesSizes.remove(batchId)).orElse(0);
+      LongAdder inflightBytesSize = inflightBytesSizePerAddress.get(hostAndPushPort);
+      if (inflightBytesSize != null) {
+        inflightBytesSize.add(inflightBatchBytesSize);
+      }
+      totalInflightBytes.add(inflightBatchBytesSize);
+    }
}
public void onSuccess(String hostAndPushPort) {
@@ -90,6 +121,12 @@ public class InFlightRequestTracker {
hostAndPort, pair -> ConcurrentHashMap.newKeySet());
}
+  public LongAdder getBatchBytesSizeByAddressPair(String hostAndPort) {
+    return maxInFlightBytesSizeEnabled
+        ? inflightBytesSizePerAddress.computeIfAbsent(hostAndPort, id -> new LongAdder())
+        : new LongAdder();
+  }
+
public boolean limitMaxInFlight(String hostAndPushPort) throws IOException {
if (pushState.exception.get() != null) {
throw pushState.exception.get();
@@ -99,6 +136,7 @@ public class InFlightRequestTracker {
     int currentMaxReqsInFlight = pushStrategy.getCurrentMaxReqsInFlight(hostAndPushPort);
     Set<Integer> batchIdSet = getBatchIdSetByAddressPair(hostAndPushPort);
+    LongAdder batchBytesSize = getBatchBytesSizeByAddressPair(hostAndPushPort);
     long times = waitInflightTimeoutMs / delta;
     try {
       while (times > 0) {
@@ -106,8 +144,11 @@ public class InFlightRequestTracker {
           // MapEnd cleans up push state, which does not exceed the max requests in flight limit.
           return false;
         } else {
- if (totalInflightReqs.sum() <= maxInFlightReqsTotal
- && batchIdSet.size() <= currentMaxReqsInFlight) {
+ if ((totalInflightReqs.sum() <= maxInFlightReqsTotal
+ && batchIdSet.size() <= currentMaxReqsInFlight)
+ || (maxInFlightBytesSizeEnabled
+ && totalInflightBytes.sum() <= maxInFlightBytesSizeTotal
+ && batchBytesSize.sum() <= maxInFlightBytesSizePerWorker)) {
break;
}
if (pushState.exception.get() != null) {
@@ -122,17 +163,35 @@ public class InFlightRequestTracker {
}
if (times <= 0) {
- logger.warn(
- "After waiting for {} ms, "
- + "there are still {} requests in flight (limit: {}): "
- + "{} batches for hostAndPushPort {}, "
- + "which exceeds the current limit {}.",
- waitInflightTimeoutMs,
- totalInflightReqs.sum(),
- maxInFlightReqsTotal,
- batchIdSet.size(),
- hostAndPushPort,
- currentMaxReqsInFlight);
+ if (totalInflightReqs.sum() > maxInFlightReqsTotal
+ || batchIdSet.size() > currentMaxReqsInFlight) {
+ logger.warn(
+ "After waiting for {} ms, "
+ + "there are still {} requests in flight (limit: {}): "
+ + "{} batches for hostAndPushPort {}, "
+ + "which exceeds the current limit {}.",
+ waitInflightTimeoutMs,
+ totalInflightReqs.sum(),
+ maxInFlightReqsTotal,
+ batchIdSet.size(),
+ hostAndPushPort,
+ currentMaxReqsInFlight);
+ }
+ if (maxInFlightBytesSizeEnabled
+ && (totalInflightBytes.sum() > maxInFlightBytesSizeTotal
+ || batchBytesSize.sum() > maxInFlightBytesSizePerWorker)) {
+ logger.warn(
+ "After waiting for {} ms, "
+ + "there are still {} bytes in flight (limit: {}): "
+ + "{} bytes for hostAndPushPort {}, "
+ + "which exceeds the current limit {}.",
+ waitInflightTimeoutMs,
+ totalInflightBytes.sum(),
+ maxInFlightBytesSizeTotal,
+ batchBytesSize.sum(),
+ hostAndPushPort,
+ maxInFlightBytesSizePerWorker);
+ }
}
if (pushState.exception.get() != null) {
@@ -201,9 +260,14 @@ public class InFlightRequestTracker {
logger.info(
"Cleanup {} requests and {} batches in flight.",
totalInflightReqs.sum(),
- inflightBatchesPerAddress.size());
+ inflightBatchesPerAddress.values().stream().mapToInt(Set::size).sum());
cleaned = true;
inflightBatchesPerAddress.clear();
pushStrategy.clear();
+ if (maxInFlightBytesSizeEnabled) {
+ logger.info("Cleanup {} bytes in flight.", totalInflightBytes.sum());
+ inflightBytesSizePerAddress.clear();
+ inflightBatchBytesSizes.clear();
+ }
}
}
diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushState.java b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
index afa22bb8a..56bea92bd 100644
--- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java
+++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
@@ -65,8 +65,8 @@ public class PushState {
     return inFlightRequestTracker.nextBatchId();
   }

-  public void addBatch(int batchId, String hostAndPushPort) {
-    inFlightRequestTracker.addBatch(batchId, hostAndPushPort);
+  public void addBatch(int batchId, int batchBytesSize, String hostAndPushPort) {
+    inFlightRequestTracker.addBatch(batchId, batchBytesSize, hostAndPushPort);
   }
public void removeBatch(int batchId, String hostAndPushPort) {
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 20f6d09cd..5148f31b9 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1022,6 +1022,24 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
     get(CLIENT_PUSH_EXCLUDE_WORKER_ON_FAILURE_ENABLED)
   def clientPushMaxReqsInFlightPerWorker: Int = get(CLIENT_PUSH_MAX_REQS_IN_FLIGHT_PERWORKER)
   def clientPushMaxReqsInFlightTotal: Int = get(CLIENT_PUSH_MAX_REQS_IN_FLIGHT_TOTAL)
+  def clientPushMaxBytesSizeInFlightEnabled: Boolean =
+    get(CLIENT_PUSH_MAX_BYTES_SIZE_IN_FLIGHT_ENABLED)
+  def clientPushMaxBytesSizeInFlightTotal: Long = {
+    val maxBytesSizeInFlight = get(CLIENT_PUSH_MAX_BYTES_SIZE_IN_FLIGHT_TOTAL).getOrElse(0L)
+    if (clientPushMaxBytesSizeInFlightEnabled && maxBytesSizeInFlight > 0L) {
+      maxBytesSizeInFlight
+    } else {
+      clientPushMaxReqsInFlightTotal * clientPushBufferMaxSize
+    }
+  }
+  def clientPushMaxBytesSizeInFlightPerWorker: Long = {
+    val maxBytesSizeInFlight = get(CLIENT_PUSH_MAX_BYTES_SIZE_IN_FLIGHT_PERWORKER).getOrElse(0L)
+    if (clientPushMaxBytesSizeInFlightEnabled && maxBytesSizeInFlight > 0L) {
+      maxBytesSizeInFlight
+    } else {
+      clientPushMaxReqsInFlightPerWorker * clientPushBufferMaxSize
+    }
+  }
   def clientPushMaxReviveTimes: Int = get(CLIENT_PUSH_MAX_REVIVE_TIMES)
   def clientPushReviveInterval: Long = get(CLIENT_PUSH_REVIVE_INTERVAL)
   def clientPushReviveBatchSize: Int = get(CLIENT_PUSH_REVIVE_BATCHSIZE)
@@ -4599,6 +4617,15 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(512)
+  val CLIENT_PUSH_MAX_BYTES_SIZE_IN_FLIGHT_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.client.push.maxBytesSizeInFlight.enabled")
+      .withAlternative("celeborn.push.maxBytesSizeInFlight.enabled")
+      .categories("client")
+      .version("0.6.1")
+      .doc("Whether `celeborn.client.push.maxBytesSizeInFlight.perWorker/total` is enabled")
+      .booleanConf
+      .createWithDefault(false)
+
val CLIENT_PUSH_MAX_REQS_IN_FLIGHT_TOTAL: ConfigEntry[Int] =
buildConf("celeborn.client.push.maxReqsInFlight.total")
.withAlternative("celeborn.push.maxReqsInFlight")
@@ -4610,6 +4637,19 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(256)
+  val CLIENT_PUSH_MAX_BYTES_SIZE_IN_FLIGHT_TOTAL: OptionalConfigEntry[Long] =
+    buildConf("celeborn.client.push.maxBytesSizeInFlight.total")
+      .categories("client")
+      .version("0.6.1")
+      .doc(
+        "Bytes size of total Netty in-flight requests. The maximum memory is " +
+          "`celeborn.client.push.maxReqsInFlight.total` * `celeborn.client.push.buffer.max.size` " +
+          "* compression ratio(1 in worst case): 64KiB * 256 = 16MiB. " +
+          "This is an addition to `celeborn.client.push.maxReqsInFlight.total` " +
+          "in cases where records are huge and exceed the maximum memory.")
+      .bytesConf(ByteUnit.BYTE)
+      .createOptional
+
val CLIENT_PUSH_MAX_REQS_IN_FLIGHT_PERWORKER: ConfigEntry[Int] =
buildConf("celeborn.client.push.maxReqsInFlight.perWorker")
.categories("client")
@@ -4622,6 +4662,19 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(32)
+  val CLIENT_PUSH_MAX_BYTES_SIZE_IN_FLIGHT_PERWORKER: OptionalConfigEntry[Long] =
+    buildConf("celeborn.client.push.maxBytesSizeInFlight.perWorker")
+      .categories("client")
+      .version("0.6.1")
+      .doc(
+        "Bytes size of Netty in-flight requests per worker. Default max memory of in flight requests " +
+          "per worker is `celeborn.client.push.maxReqsInFlight.perWorker` * `celeborn.client.push.buffer.max.size` " +
+          "* compression ratio(1 in worst case): 64KiB * 32 = 2MiB. " +
+          "This is an alternative to `celeborn.client.push.maxReqsInFlight.perWorker` " +
+          "in cases where records are huge and exceed the maximum memory.")
+      .bytesConf(ByteUnit.BYTE)
+      .createOptional
+
val CLIENT_PUSH_MAX_REVIVE_TIMES: ConfigEntry[Int] =
buildConf("celeborn.client.push.revive.maxRetries")
.categories("client")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 5c4cbf006..0401fe266 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -54,6 +54,9 @@ license: |
| celeborn.client.push.limit.inFlight.sleepInterval | 50ms | false | Sleep interval when check netty in-flight requests to be done. | 0.3.0 | celeborn.push.limit.inFlight.sleepInterval |
| celeborn.client.push.limit.inFlight.timeout | <undefined> | false | Timeout for netty in-flight requests to be done. Default value should be `celeborn.client.push.timeout * 2`. | 0.3.0 | celeborn.push.limit.inFlight.timeout |
| celeborn.client.push.limit.strategy | SIMPLE | false | The strategy used to control the push speed. Valid strategies are SIMPLE and SLOWSTART. The SLOWSTART strategy usually works with congestion control mechanism on the worker side. | 0.3.0 |  |
+| celeborn.client.push.maxBytesSizeInFlight.enabled | false | false | Whether `celeborn.client.push.maxBytesSizeInFlight.perWorker/total` is enabled | 0.6.1 | celeborn.push.maxBytesSizeInFlight.enabled |
+| celeborn.client.push.maxBytesSizeInFlight.perWorker | <undefined> | false | Bytes size of Netty in-flight requests per worker. Default max memory of in flight requests per worker is `celeborn.client.push.maxReqsInFlight.perWorker` * `celeborn.client.push.buffer.max.size` * compression ratio(1 in worst case): 64KiB * 32 = 2MiB. This is an alternative to `celeborn.client.push.maxReqsInFlight.perWorker` in cases where records are huge and exceed the maximum memory. | 0.6.1 |  |
+| celeborn.client.push.maxBytesSizeInFlight.total | <undefined> | false | Bytes size of total Netty in-flight requests. The maximum memory is `celeborn.client.push.maxReqsInFlight.total` * `celeborn.client.push.buffer.max.size` * compression ratio(1 in worst case): 64KiB * 256 = 16MiB. This is an addition to `celeborn.client.push.maxReqsInFlight.total` in cases where records are huge and exceed the maximum memory. | 0.6.1 |  |
| celeborn.client.push.maxReqsInFlight.perWorker | 32 | false | Amount of Netty in-flight requests per worker. Default max memory of in flight requests per worker is `celeborn.client.push.maxReqsInFlight.perWorker` * `celeborn.client.push.buffer.max.size` * compression ratio(1 in worst case): 64KiB * 32 = 2MiB. The maximum memory will not exceed `celeborn.client.push.maxReqsInFlight.total`. | 0.3.0 |  |
| celeborn.client.push.maxReqsInFlight.total | 256 | false | Amount of total Netty in-flight requests. The maximum memory is `celeborn.client.push.maxReqsInFlight.total` * `celeborn.client.push.buffer.max.size` * compression ratio(1 in worst case): 64KiB * 256 = 16MiB | 0.3.0 | celeborn.push.maxReqsInFlight |
| celeborn.client.push.queue.capacity | 512 | false | Push buffer queue size for a task. The maximum memory is `celeborn.client.push.buffer.max.size` * `celeborn.client.push.queue.capacity`, default: 64KiB * 512 = 32MiB | 0.3.0 | celeborn.push.queue.capacity |