This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 2f068226 [CELEBORN-119] Add timeout for pushdata (#1097)
2f068226 is described below
commit 2f0682265e2044604ccfbe8185124081b0002568
Author: Keyong Zhou <[email protected]>
AuthorDate: Tue Dec 20 20:40:42 2022 +0800
[CELEBORN-119] Add timeout for pushdata (#1097)
---
.../apache/celeborn/client/ShuffleClientImpl.java | 75 ++++++++++---------
.../apache/celeborn/client/write/PushState.java | 83 +++++++++++++++++-----
.../common/network/client/TransportClient.java | 2 +-
.../common/protocol/message/StatusCode.java | 6 +-
.../org/apache/celeborn/common/CelebornConf.scala | 15 +++-
docs/configuration/client.md | 2 +-
docs/configuration/worker.md | 1 +
.../celeborn/tests/spark/PushdataTimeoutTest.scala | 80 +++++++++++++++++++++
.../service/deploy/worker/PushDataHandler.scala | 20 ++++++
9 files changed, 221 insertions(+), 63 deletions(-)
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index a7193d03..65c65cec 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -178,7 +178,7 @@ public class ShuffleClientImpl extends ShuffleClient {
} else if (mapperEnded(shuffleId, mapId, attemptId)) {
logger.debug(
"Retrying push data, but the mapper(map {} attempt {}) has ended.",
mapId, attemptId);
- pushState.inFlightBatches.remove(batchId);
+ pushState.removeBatch(batchId);
} else {
PartitionLocation newLoc =
reducePartitionMap.get(shuffleId).get(partitionId);
logger.info("Revive success, new location for reduce {} is {}.",
partitionId, newLoc);
@@ -191,7 +191,7 @@ public class ShuffleClientImpl extends ShuffleClient {
PushData newPushData =
new PushData(MASTER_MODE, shuffleKey, newLoc.getUniqueId(),
newBuffer);
ChannelFuture future = client.pushData(newPushData, callback);
- pushState.addFuture(batchId, future);
+ pushState.pushStarted(batchId, future, callback);
} catch (Exception ex) {
logger.warn(
"Exception raised while pushing data for shuffle {} map {} attempt
{}" + " batch {}.",
@@ -256,7 +256,7 @@ public class ShuffleClientImpl extends ShuffleClient {
pushState,
true);
}
- pushState.inFlightBatches.remove(oldGroupedBatchId);
+ pushState.removeBatch(oldGroupedBatchId);
}
private String genAddressPair(PartitionLocation loc) {
@@ -360,18 +360,20 @@ public class ShuffleClientImpl extends ShuffleClient {
throw pushState.exception.get();
}
- ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches =
pushState.inFlightBatches;
long timeoutMs = conf.pushLimitInFlightTimeoutMs();
long delta = conf.pushLimitInFlightSleepDeltaMs();
long times = timeoutMs / delta;
try {
while (times > 0) {
- if (inFlightBatches.size() <= limit) {
+ if (pushState.inflightBatchCount() <= limit) {
break;
}
if (pushState.exception.get() != null) {
throw pushState.exception.get();
}
+
+ pushState.failExpiredBatch();
+
Thread.sleep(delta);
times--;
}
@@ -384,10 +386,9 @@ public class ShuffleClientImpl extends ShuffleClient {
"After waiting for {} ms, there are still {} batches in flight for
map {}, "
+ "which exceeds the limit {}.",
timeoutMs,
- inFlightBatches.size(),
+ pushState.inflightBatchCount(),
mapKey,
limit);
- logger.error("Map: {} in flight batches: {}", mapKey, inFlightBatches);
throw new IOException("wait timeout for task " + mapKey,
pushState.exception.get());
}
if (pushState.exception.get() != null) {
@@ -507,7 +508,7 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId);
PushState pushState = pushStates.get(mapKey);
if (pushState != null) {
- pushState.cancelFutures();
+ pushState.cleanup();
}
return 0;
}
@@ -546,7 +547,7 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId);
PushState pushState = pushStates.get(mapKey);
if (pushState != null) {
- pushState.cancelFutures();
+ pushState.cleanup();
}
return 0;
}
@@ -593,7 +594,7 @@ public class ShuffleClientImpl extends ShuffleClient {
limitMaxInFlight(mapKey, pushState, currentMaxReqsInFlight);
// add inFlight requests
- pushState.inFlightBatches.put(nextBatchId, loc);
+ pushState.addBatch(nextBatchId);
// build PushData request
NettyManagedBuffer buffer = new
NettyManagedBuffer(Unpooled.wrappedBuffer(body));
@@ -604,7 +605,7 @@ public class ShuffleClientImpl extends ShuffleClient {
new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
- pushState.inFlightBatches.remove(nextBatchId);
+ pushState.removeBatch(nextBatchId);
// TODO Need to adjust maxReqsInFlight if server response is
congested, see
// CELEBORN-62
if (response.remaining() > 0 && response.get() ==
StatusCode.STAGE_ENDED.getValue()) {
@@ -612,7 +613,6 @@ public class ShuffleClientImpl extends ShuffleClient {
.computeIfAbsent(shuffleId, (id) ->
ConcurrentHashMap.newKeySet())
.add(mapKey);
}
- pushState.removeFuture(nextBatchId);
logger.debug(
"Push data to {}:{} success for map {} attempt {} batch {}.",
loc.getHost(),
@@ -626,7 +626,6 @@ public class ShuffleClientImpl extends ShuffleClient {
public void onFailure(Throwable e) {
pushState.exception.compareAndSet(
null, new IOException("Revived PushData failed!", e));
- pushState.removeFuture(nextBatchId);
logger.error(
"Push data to {}:{} failed for map {} attempt {} batch {}.",
loc.getHost(),
@@ -679,6 +678,7 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId,
nextBatchId);
congestionControl();
+ callback.onSuccess(response);
} else if (reason ==
StatusCode.PUSH_DATA_SUCCESS_SLAVE_CONGESTED.getValue()) {
logger.debug(
"Push data split for map {} attempt {} batch {} return
slave congested.",
@@ -686,6 +686,7 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId,
nextBatchId);
congestionControl();
+ callback.onSuccess(response);
} else {
response.rewind();
slowStart();
@@ -726,7 +727,7 @@ public class ShuffleClientImpl extends ShuffleClient {
pushState,
getPushDataFailCause(e.getMessage())));
} else {
- pushState.inFlightBatches.remove(nextBatchId);
+ pushState.removeBatch(nextBatchId);
logger.info(
"Mapper shuffleId:{} mapId:{} attempt:{} already ended,
remove batchId:{}.",
shuffleId,
@@ -742,7 +743,7 @@ public class ShuffleClientImpl extends ShuffleClient {
TransportClient client =
dataClientFactory.createClient(loc.getHost(), loc.getPushPort(),
partitionId);
ChannelFuture future = client.pushData(pushData, wrappedCallback);
- pushState.addFuture(nextBatchId, future);
+ pushState.pushStarted(nextBatchId, future, wrappedCallback);
} catch (Exception e) {
logger.warn("PushData failed", e);
wrappedCallback.onFailure(
@@ -898,7 +899,7 @@ public class ShuffleClientImpl extends ShuffleClient {
final int port = Integer.parseInt(splits[1]);
int groupedBatchId = pushState.batchId.addAndGet(1);
- pushState.inFlightBatches.put(groupedBatchId, batches.get(0).loc);
+ pushState.addBatch(groupedBatchId);
final int numBatches = batches.size();
final String[] partitionUniqueIds = new String[numBatches];
@@ -928,7 +929,7 @@ public class ShuffleClientImpl extends ShuffleClient {
mapId,
attemptId,
groupedBatchId);
- pushState.inFlightBatches.remove(groupedBatchId);
+ pushState.removeBatch(groupedBatchId);
// TODO Need to adjust maxReqsInFlight if server response is
congested, see CELEBORN-62
if (response.remaining() > 0 && response.get() ==
StatusCode.STAGE_ENDED.getValue()) {
mapperEndMap
@@ -995,6 +996,7 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId,
Arrays.toString(batchIds));
congestionControl();
+ callback.onSuccess(response);
} else if (reason ==
StatusCode.PUSH_DATA_SUCCESS_SLAVE_CONGESTED.getValue()) {
logger.debug(
"Push data split for map {} attempt {} batchs {} return
slave congested.",
@@ -1002,6 +1004,7 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId,
Arrays.toString(batchIds));
congestionControl();
+ callback.onSuccess(response);
} else {
// Should not happen in current architecture.
response.rewind();
@@ -1056,7 +1059,8 @@ public class ShuffleClientImpl extends ShuffleClient {
// do push merged data
try {
TransportClient client = dataClientFactory.createClient(host, port);
- client.pushMergedData(mergedData, wrappedCallback);
+ ChannelFuture future = client.pushMergedData(mergedData,
wrappedCallback);
+ pushState.pushStarted(groupedBatchId, future, wrappedCallback);
} catch (Exception e) {
logger.warn("PushMergedData failed", e);
wrappedCallback.onFailure(new
Exception(getPushDataFailCause(e.getMessage()).toString(), e));
@@ -1114,7 +1118,7 @@ public class ShuffleClientImpl extends ShuffleClient {
PushState pushState = pushStates.remove(mapKey);
if (pushState != null) {
pushState.exception.compareAndSet(null, new IOException("Cleaned Up"));
- pushState.cancelFutures();
+ pushState.cleanup();
}
}
@@ -1303,6 +1307,8 @@ public class ShuffleClientImpl extends ShuffleClient {
} else if (StatusCode.PUSH_DATA_FAIL_MASTER.getMessage().equals(message)
|| connectFail(message)) {
cause = StatusCode.PUSH_DATA_FAIL_MASTER;
+ } else if (StatusCode.PUSH_DATA_TIMEOUT.getMessage().equals(message)) {
+ cause = StatusCode.PUSH_DATA_TIMEOUT;
} else {
cause = StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE;
}
@@ -1338,7 +1344,7 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId);
PushState pushState = pushStates.get(mapKey);
if (pushState != null) {
- pushState.cancelFutures();
+ pushState.cleanup();
}
return 0;
}
@@ -1367,7 +1373,7 @@ public class ShuffleClientImpl extends ShuffleClient {
limitMaxInFlight(mapKey, pushState, maxInFlight);
// add inFlight requests
- pushState.inFlightBatches.put(nextBatchId, location);
+ pushState.addBatch(nextBatchId);
// build PushData request
NettyManagedBuffer buffer = new NettyManagedBuffer(data);
@@ -1379,8 +1385,7 @@ public class ShuffleClientImpl extends ShuffleClient {
@Override
public void onSuccess(ByteBuffer response) {
closeCallBack.getAsBoolean();
- pushState.inFlightBatches.remove(nextBatchId);
- pushState.removeFuture(nextBatchId);
+ pushState.removeBatch(nextBatchId);
if (response.remaining() > 0) {
byte reason = response.get();
if (reason == StatusCode.STAGE_ENDED.getValue()) {
@@ -1401,8 +1406,7 @@ public class ShuffleClientImpl extends ShuffleClient {
@Override
public void onFailure(Throwable e) {
closeCallBack.getAsBoolean();
- pushState.inFlightBatches.remove(nextBatchId);
- pushState.removeFuture(nextBatchId);
+ pushState.removeBatch(nextBatchId);
if (pushState.exception.get() != null) {
return;
}
@@ -1432,7 +1436,7 @@ public class ShuffleClientImpl extends ShuffleClient {
TransportClient client =
dataClientFactory.createClient(location.getHost(),
location.getPushPort(), partitionId);
ChannelFuture future = client.pushData(pushData, callback);
- pushState.addFuture(nextBatchId, future);
+ pushState.pushStarted(nextBatchId, future, callback);
} catch (Exception e) {
logger.warn("PushData byteBuf failed", e);
callback.onFailure(new
Exception(getPushDataFailCause(e.getMessage()).toString(), e));
@@ -1454,7 +1458,6 @@ public class ShuffleClientImpl extends ShuffleClient {
shuffleId,
mapId,
attemptId,
- location,
() -> {
String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
logger.info(
@@ -1473,7 +1476,7 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId,
numPartitions,
bufferSize);
- client.sendRpcSync(handShake.toByteBuffer(),
conf.pushDataRpcTimeoutMs());
+ client.sendRpcSync(handShake.toByteBuffer(),
conf.pushDataTimeoutMs());
return null;
});
}
@@ -1492,7 +1495,6 @@ public class ShuffleClientImpl extends ShuffleClient {
shuffleId,
mapId,
attemptId,
- location,
() -> {
String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
logger.info(
@@ -1512,7 +1514,7 @@ public class ShuffleClientImpl extends ShuffleClient {
currentRegionIdx,
isBroadcast);
ByteBuffer regionStartResponse =
- client.sendRpcSync(regionStart.toByteBuffer(),
conf.pushDataRpcTimeoutMs());
+ client.sendRpcSync(regionStart.toByteBuffer(),
conf.pushDataTimeoutMs());
if (regionStartResponse.hasRemaining()
&& regionStartResponse.get() ==
StatusCode.HARD_SPLIT.getValue()) {
// if split then revive
@@ -1561,7 +1563,6 @@ public class ShuffleClientImpl extends ShuffleClient {
shuffleId,
mapId,
attemptId,
- location,
() -> {
final String shuffleKey = Utils.makeShuffleKey(applicationId,
shuffleId);
logger.info(
@@ -1574,17 +1575,13 @@ public class ShuffleClientImpl extends ShuffleClient {
dataClientFactory.createClient(location.getHost(),
location.getPushPort());
RegionFinish regionFinish =
new RegionFinish(MASTER_MODE, shuffleKey,
location.getUniqueId(), attemptId);
- client.sendRpcSync(regionFinish.toByteBuffer(),
conf.pushDataRpcTimeoutMs());
+ client.sendRpcSync(regionFinish.toByteBuffer(),
conf.pushDataTimeoutMs());
return null;
});
}
private <R> R sendMessageInternal(
- int shuffleId,
- int mapId,
- int attemptId,
- PartitionLocation location,
- ThrowingExceptionSupplier<R, Exception> supplier)
+ int shuffleId, int mapId, int attemptId, ThrowingExceptionSupplier<R,
Exception> supplier)
throws IOException {
PushState pushState = null;
int batchId = 0;
@@ -1606,11 +1603,11 @@ public class ShuffleClientImpl extends ShuffleClient {
// add inFlight requests
batchId = pushState.batchId.incrementAndGet();
- pushState.inFlightBatches.put(batchId, location);
+ pushState.addBatch(batchId);
return retrySendMessage(supplier);
} finally {
if (pushState != null) {
- pushState.inFlightBatches.remove(batchId);
+ pushState.removeBatch(batchId);
}
}
}
diff --git
a/client/src/main/java/org/apache/celeborn/client/write/PushState.java
b/client/src/main/java/org/apache/celeborn/client/write/PushState.java
index 67a719aa..a8f96240 100644
--- a/client/src/main/java/org/apache/celeborn/client/write/PushState.java
+++ b/client/src/main/java/org/apache/celeborn/client/write/PushState.java
@@ -18,8 +18,6 @@
package org.apache.celeborn.client.write;
import java.io.IOException;
-import java.util.HashSet;
-import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@@ -29,41 +27,90 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.message.StatusCode;
public class PushState {
+ class BatchInfo {
+ ChannelFuture channelFuture;
+ long pushTime;
+ RpcResponseCallback callback;
+ }
+
private static final Logger logger =
LoggerFactory.getLogger(PushState.class);
private int pushBufferMaxSize;
+ private long pushTimeout;
public final AtomicInteger batchId = new AtomicInteger();
- public final ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches =
+ private final ConcurrentHashMap<Integer, BatchInfo> inflightBatchInfos =
new ConcurrentHashMap<>();
- public final ConcurrentHashMap<Integer, ChannelFuture> futures = new
ConcurrentHashMap<>();
public AtomicReference<IOException> exception = new AtomicReference<>();
public PushState(CelebornConf conf) {
pushBufferMaxSize = conf.pushBufferMaxSize();
+ pushTimeout = conf.pushDataTimeoutMs();
+ }
+
+ public void addBatch(int batchId) {
+ inflightBatchInfos.computeIfAbsent(batchId, id -> new BatchInfo());
+ }
+
+ public void removeBatch(int batchId) {
+ BatchInfo info = inflightBatchInfos.remove(batchId);
+ if (info != null && info.channelFuture != null) {
+ info.channelFuture.cancel(true);
+ }
}
- public void addFuture(int batchId, ChannelFuture future) {
- futures.put(batchId, future);
+ public int inflightBatchCount() {
+ return inflightBatchInfos.size();
}
- public void removeFuture(int batchId) {
- futures.remove(batchId);
+ public synchronized void failExpiredBatch() {
+ long currentTime = System.currentTimeMillis();
+ inflightBatchInfos
+ .values()
+ .forEach(
+ info -> {
+ if (currentTime - info.pushTime > pushTimeout) {
+ if (info.callback != null) {
+ info.channelFuture.cancel(true);
+ info.callback.onFailure(
+ new
IOException(StatusCode.PUSH_DATA_TIMEOUT.getMessage()));
+ info.channelFuture = null;
+ info.callback = null;
+ }
+ }
+ });
+ }
+
+ public void pushStarted(int batchId, ChannelFuture future,
RpcResponseCallback callback) {
+ BatchInfo info = inflightBatchInfos.get(batchId);
+    // In rare cases info could be null. For example, a speculative task has
one thread pushing,
+    // and another thread retry-pushing. At time 1, thread 1 finds StageEnded,
then it cleans up
+    // PushState; at the same time thread 2 pushes data and calls pushStarted,
+    // at which point info will be null.
+ if (info != null) {
+ info.pushTime = System.currentTimeMillis();
+ info.channelFuture = future;
+ info.callback = callback;
+ }
}
- public synchronized void cancelFutures() {
- if (!futures.isEmpty()) {
- Set<Integer> keys = new HashSet<>(futures.keySet());
- logger.debug("Cancel all {} futures.", keys.size());
- for (Integer batchId : keys) {
- ChannelFuture future = futures.remove(batchId);
- if (future != null) {
- future.cancel(true);
- }
- }
+ public void cleanup() {
+ if (!inflightBatchInfos.isEmpty()) {
+ logger.debug("Cancel all {} futures.", inflightBatchInfos.size());
+ inflightBatchInfos
+ .values()
+ .forEach(
+ entry -> {
+ if (entry.channelFuture != null) {
+ entry.channelFuture.cancel(true);
+ }
+ });
+ inflightBatchInfos.clear();
}
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
index 4ffc579d..88e8c42a 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
@@ -288,7 +288,7 @@ public class TransportClient implements Closeable {
} else {
String errorMsg =
String.format(
- "Failed to send RPC %s to %s: %s, channel will be closed",
+ "Failed to send request %s to %s: %s, channel will be closed",
requestId, NettyUtils.getRemoteAddress(channel),
future.cause());
logger.warn(errorMsg);
channel.close();
diff --git
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
index 3b3f067d..0fb46765 100644
---
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
+++
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
@@ -66,7 +66,9 @@ public enum StatusCode {
REGION_START_FAIL_SLAVE(34),
REGION_START_FAIL_MASTER(35),
REGION_FINISH_FAIL_SLAVE(36),
- REGION_FINISH_FAIL_MASTER(37);
+ REGION_FINISH_FAIL_MASTER(37),
+
+ PUSH_DATA_TIMEOUT(38);
private final byte value;
@@ -103,6 +105,8 @@ public enum StatusCode {
msg = "RegionFinishFailMaster";
} else if (value == REGION_FINISH_FAIL_SLAVE.getValue()) {
msg = "RegionFinishFailSlave";
+ } else if (value == PUSH_DATA_TIMEOUT.getValue()) {
+ msg = "PushDataTimeout";
}
return msg;
diff --git
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 16b078ba..0dfbc3b8 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -556,6 +556,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable
with Logging with Se
// //////////////////////////////////////////////////////
def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
def testRetryCommitFiles: Boolean = get(TEST_RETRY_COMMIT_FILE)
+ def testPushDataTimeout: Boolean = get(TEST_PUSHDATA_TIMEOUT)
def masterHost: String = get(MASTER_HOST)
@@ -679,7 +680,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable
with Logging with Se
def rpcCacheSize: Int = get(RPC_CACHE_SIZE)
def rpcCacheConcurrencyLevel: Int = get(RPC_CACHE_CONCURRENCY_LEVEL)
def rpcCacheExpireTime: Long = get(RPC_CACHE_EXPIRE_TIME)
- def pushDataRpcTimeoutMs = get(PUSH_DATA_RPC_TIMEOUT)
+ def pushDataTimeoutMs = get(PUSH_DATA_TIMEOUT)
def registerShuffleRpcAskTimeout: RpcTimeout =
new RpcTimeout(
@@ -2192,8 +2193,8 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("5s")
- val PUSH_DATA_RPC_TIMEOUT: ConfigEntry[Long] =
- buildConf("celeborn.push.data.rpc.timeout")
+ val PUSH_DATA_TIMEOUT: ConfigEntry[Long] =
+ buildConf("celeborn.push.data.timeout")
.withAlternative("rss.push.data.rpc.timeout")
.categories("client")
.version("0.2.0")
@@ -2201,6 +2202,14 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("120s")
+ val TEST_PUSHDATA_TIMEOUT: ConfigEntry[Boolean] =
+ buildConf("celeborn.test.pushdataTimeout")
+ .categories("worker")
+ .version("0.2.0")
+      .doc("Whether to test pushdata timeout")
+ .booleanConf
+ .createWithDefault(false)
+
val REGISTER_SHUFFLE_RPC_ASK_TIMEOUT: OptionalConfigEntry[Long] =
buildConf("celeborn.rpc.registerShuffle.askTimeout")
.categories("client")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 8ec3aafc..b58fac84 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -27,7 +27,7 @@ license: |
| celeborn.master.endpoints | <localhost>:9097 | Endpoints of master
nodes for celeborn client to connect, allowed pattern is:
`<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:9097,clb2:9098,clb3:9099`. If
the port is omitted, 9097 will be used. | 0.2.0 |
| celeborn.push.buffer.initial.size | 8k | | 0.2.0 |
| celeborn.push.buffer.max.size | 64k | Max size of reducer partition buffer
memory for shuffle hash writer. The pushed data will be buffered in memory
before sending to Celeborn worker. For performance consideration keep this
buffer size higher than 32K. Example: If reducer amount is 2000, buffer size is
64K, then each task will consume up to `64KiB * 2000 = 125MiB` heap memory. |
0.2.0 |
-| celeborn.push.data.rpc.timeout | 120s | Timeout for a task to push data rpc
message. | 0.2.0 |
+| celeborn.push.data.timeout | 120s | Timeout for a task to push data rpc
message. | 0.2.0 |
| celeborn.push.limit.inFlight.sleepInterval | 50ms | Sleep interval when
check netty in-flight requests to be done. | 0.2.0 |
| celeborn.push.limit.inFlight.timeout | 240s | Timeout for netty in-flight
requests to be done. | 0.2.0 |
| celeborn.push.maxReqsInFlight | 32 | Amount of Netty in-flight requests. The
maximum memory is `celeborn.push.maxReqsInFlight` *
`celeborn.push.buffer.max.size` * compression ratio(1 in worst case), default:
64Kib * 32 = 2Mib | 0.2.0 |
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index 0ae29ef9..44b62816 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -29,6 +29,7 @@ license: |
| celeborn.shuffle.chuck.size | 8m | Max chunk size of reducer's merged
shuffle data. For example, if a reducer's shuffle data is 128M and the data
will need 16 fetch chunk requests to fetch. | 0.2.0 |
| celeborn.shuffle.minPartitionSizeToEstimate | 8mb | Ignore partition size
smaller than this configuration of partition size for estimation. | 0.2.0 |
| celeborn.storage.hdfs.dir | <undefined> | HDFS dir configuration for
Celeborn to access HDFS. | 0.2.0 |
+| celeborn.test.pushdataTimeout | false | Whether to test pushdata timeout |
0.2.0 |
| celeborn.worker.closeIdleConnections | false | Whether worker will close
idle connections. | 0.2.0 |
| celeborn.worker.commit.threads | 32 | Thread number of worker to commit
shuffle data files asynchronously. | 0.2.0 |
| celeborn.worker.directMemoryRatioForMemoryShuffleStorage | 0.1 | Max ratio
of direct memory to store shuffle data | 0.2.0 |
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala
new file mode 100644
index 00000000..ed876524
--- /dev/null
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+
+class PushdataTimeoutTest extends AnyFunSuite
+ with SparkTestBase
+ with BeforeAndAfterAll
+ with BeforeAndAfterEach {
+
+ override def beforeAll(): Unit = {
+    logInfo("test initialized, setup rss mini cluster")
+ val workerConf = Map(
+ "celeborn.test.pushdataTimeout" -> s"true")
+ tuple = setupRssMiniClusterSpark(masterConfs = null, workerConfs =
workerConf)
+ }
+
+ override def afterAll(): Unit = {
+    logInfo("all tests complete, stop rss mini cluster")
+ clearMiniCluster(tuple)
+ }
+
+ override def beforeEach(): Unit = {
+ ShuffleClient.reset()
+ }
+
+ override def afterEach(): Unit = {
+ System.gc()
+ }
+
+ test("celeborn spark integration test - pushdata timeout") {
+ val sparkConf = new
SparkConf().setAppName("rss-demo").setMaster("local[4]")
+ .set("spark.celeborn.push.data.timeout", "10s")
+ val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+ val combineResult = combine(sparkSession)
+ val groupbyResult = groupBy(sparkSession)
+ val repartitionResult = repartition(sparkSession)
+ val sqlResult = runsql(sparkSession)
+
+ Thread.sleep(3000L)
+ sparkSession.stop()
+
+ val rssSparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, false)).getOrCreate()
+ val rssCombineResult = combine(rssSparkSession)
+ val rssGroupbyResult = groupBy(rssSparkSession)
+ val rssRepartitionResult = repartition(rssSparkSession)
+ val rssSqlResult = runsql(rssSparkSession)
+
+ assert(combineResult.equals(rssCombineResult))
+ assert(groupbyResult.equals(rssGroupbyResult))
+ assert(repartitionResult.equals(rssRepartitionResult))
+ assert(combineResult.equals(rssCombineResult))
+ assert(sqlResult.equals(rssSqlResult))
+
+ rssSparkSession.stop()
+
+ }
+}
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 933f3b6b..d91c55c9 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.{AtomicBoolean,
AtomicIntegerArray}
import com.google.common.base.Throwables
import io.netty.buffer.ByteBuf
+import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.AlreadyClosedException
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
@@ -55,6 +56,8 @@ class PushDataHandler extends BaseMessageHandler with Logging
{
var partitionSplitMinimumSize: Long = _
var shutdown: AtomicBoolean = _
var storageManager: StorageManager = _
+ var conf: CelebornConf = _
+ @volatile var pushDataTimeoutTested = false
def init(worker: Worker): Unit = {
workerSource = worker.workerSource
@@ -71,6 +74,7 @@ class PushDataHandler extends BaseMessageHandler with Logging
{
partitionSplitMinimumSize = worker.conf.partitionSplitMinimumSize
storageManager = worker.storageManager
shutdown = worker.shutdown
+ conf = worker.conf
logInfo(s"diskReserveSize $diskReserveSize")
}
@@ -122,6 +126,12 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
val body = pushData.body.asInstanceOf[NettyManagedBuffer].getBuf
val isMaster = mode == PartitionLocation.Mode.MASTER
+ // For test
+ if (conf.testPushDataTimeout && !pushDataTimeoutTested) {
+ pushDataTimeoutTested = true
+ return
+ }
+
val key = s"${pushData.requestId}"
if (isMaster) {
workerSource.startTimer(WorkerSource.MasterPushDataTime, key)
@@ -315,6 +325,12 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
workerSource.startTimer(WorkerSource.SlavePushDataTime, key)
}
+ // For test
+ if (conf.testPushDataTimeout && !PushDataHandler.pushDataTimeoutTested) {
+ PushDataHandler.pushDataTimeoutTested = true
+ return
+ }
+
val wrappedCallback = new RpcResponseCallback() {
override def onSuccess(response: ByteBuffer): Unit = {
if (isMaster) {
@@ -897,3 +913,7 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
(PackedPartitionId.getRawPartitionId(id),
PackedPartitionId.getAttemptId(id))
}
}
+
+object PushDataHandler {
+ @volatile var pushDataTimeoutTested = false
+}