This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new e155ec12 [CELEBORN-190] doPushMergedData should also support revive
multiple times, not only twice (#1136)
e155ec12 is described below
commit e155ec122adc08b283c30fa6352a028b27029296
Author: Angerszhuuuu <[email protected]>
AuthorDate: Tue Jan 10 11:39:40 2023 +0800
[CELEBORN-190] doPushMergedData should also support revive multiple times,
not only twice (#1136)
---
.../apache/celeborn/client/ShuffleClientImpl.java | 132 +++++++++++++++------
.../org/apache/celeborn/common/CelebornConf.scala | 32 +++--
docs/configuration/client.md | 2 +
.../celeborn/tests/spark/RetryReviveTest.scala | 53 +++++++++
4 files changed, 177 insertions(+), 42 deletions(-)
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 3735dbe6..d7bc1700 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -85,6 +85,8 @@ public class ShuffleClientImpl extends ShuffleClient {
private final int registerShuffleMaxRetries;
private final long registerShuffleRetryWaitMs;
private int maxInFlight;
+ private int maxReviveTimes;
+ private boolean testRetryRevive;
private final AtomicInteger currentMaxReqsInFlight;
private int congestionAvoidanceFlag = 0;
private final int pushBufferMaxSize;
@@ -139,6 +141,8 @@ public class ShuffleClientImpl extends ShuffleClient {
registerShuffleMaxRetries = conf.registerShuffleMaxRetry();
registerShuffleRetryWaitMs = conf.registerShuffleRetryWaitMs();
maxInFlight = conf.pushMaxReqsInFlight();
+ maxReviveTimes = conf.pushMaxReviveTimes();
+ testRetryRevive = conf.testRetryRevive();
if (conf.pushDataSlowStart()) {
currentMaxReqsInFlight = new AtomicInteger(1);
@@ -178,11 +182,13 @@ public class ShuffleClientImpl extends ShuffleClient {
PartitionLocation loc,
RpcResponseCallback callback,
PushState pushState,
- StatusCode cause) {
+ StatusCode cause,
+ int remainReviveTimes) {
int partitionId = loc.getId();
if (!revive(
applicationId, shuffleId, mapId, attemptId, partitionId,
loc.getEpoch(), loc, cause)) {
- callback.onFailure(new IOException("Revive Failed"));
+ callback.onFailure(
+ new IOException("Revive Failed, remain revive times " +
remainReviveTimes));
} else if (mapperEnded(shuffleId, mapId, attemptId)) {
logger.debug(
"Retrying push data, but the mapper(map {} attempt {}) has ended.",
mapId, attemptId);
@@ -191,15 +197,20 @@ public class ShuffleClientImpl extends ShuffleClient {
PartitionLocation newLoc =
reducePartitionMap.get(shuffleId).get(partitionId);
logger.info("Revive success, new location for reduce {} is {}.",
partitionId, newLoc);
try {
- TransportClient client =
- dataClientFactory.createClient(newLoc.getHost(),
newLoc.getPushPort(), partitionId);
- NettyManagedBuffer newBuffer = new
NettyManagedBuffer(Unpooled.wrappedBuffer(body));
- String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
-
- PushData newPushData =
- new PushData(MASTER_MODE, shuffleKey, newLoc.getUniqueId(),
newBuffer);
- ChannelFuture future = client.pushData(newPushData, callback);
- pushState.pushStarted(batchId, future, callback);
+ if (!testRetryRevive || remainReviveTimes < 1) {
+ TransportClient client =
+ dataClientFactory.createClient(newLoc.getHost(),
newLoc.getPushPort(), partitionId);
+ NettyManagedBuffer newBuffer = new
NettyManagedBuffer(Unpooled.wrappedBuffer(body));
+ String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
+
+ PushData newPushData =
+ new PushData(MASTER_MODE, shuffleKey, newLoc.getUniqueId(),
newBuffer);
+ ChannelFuture future = client.pushData(newPushData, callback);
+ pushState.pushStarted(batchId, future, callback);
+ } else {
+ throw new RuntimeException(
+ "Mock push data submit retry failed. remainReviveTimes = " +
remainReviveTimes + ".");
+ }
} catch (Exception ex) {
logger.warn(
"Exception raised while pushing data for shuffle {} map {} attempt
{}" + " batch {}.",
@@ -221,8 +232,10 @@ public class ShuffleClientImpl extends ShuffleClient {
int attemptId,
ArrayList<DataBatches.DataBatch> batches,
StatusCode cause,
- Integer oldGroupedBatchId) {
+ Integer oldGroupedBatchId,
+ int remainReviveTimes) {
HashMap<String, DataBatches> newDataBatchesMap = new HashMap<>();
+ ArrayList<DataBatches.DataBatch> reviveFailedBatchesMap = new
ArrayList<>();
for (DataBatches.DataBatch batch : batches) {
int partitionId = batch.loc.getId();
if (!revive(
@@ -234,10 +247,16 @@ public class ShuffleClientImpl extends ShuffleClient {
batch.loc.getEpoch(),
batch.loc,
cause)) {
- pushState.exception.compareAndSet(
- null,
- new IOException("Revive Failed in retry push merged data for
location: " + batch.loc));
- return;
+
+ if (remainReviveTimes > 0) {
+ reviveFailedBatchesMap.add(batch);
+ } else {
+ pushState.exception.compareAndSet(
+ null,
+ new IOException(
+ "Revive Failed in retry push merged data for location: " +
batch.loc));
+ return;
+ }
} else if (mapperEnded(shuffleId, mapId, attemptId)) {
logger.debug(
"Retrying push data, but the mapper(map {} attempt {}) has
ended.", mapId, attemptId);
@@ -262,9 +281,24 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId,
newDataBatches.requireBatches(),
pushState,
- true);
+ remainReviveTimes);
+ }
+ if (reviveFailedBatchesMap.isEmpty()) {
+ pushState.removeBatch(oldGroupedBatchId);
+ } else {
+ pushDataRetryPool.submit(
+ () ->
+ submitRetryPushMergedData(
+ pushState,
+ applicationId,
+ shuffleId,
+ mapId,
+ attemptId,
+ reviveFailedBatchesMap,
+ cause,
+ oldGroupedBatchId,
+ remainReviveTimes - 1));
}
- pushState.removeBatch(oldGroupedBatchId);
}
private String genAddressPair(PartitionLocation loc) {
@@ -652,6 +686,8 @@ public class ShuffleClientImpl extends ShuffleClient {
RpcResponseCallback wrappedCallback =
new RpcResponseCallback() {
+ int remainReviveTimes = maxReviveTimes;
+
@Override
public void onSuccess(ByteBuffer response) {
if (response.remaining() > 0) {
@@ -683,7 +719,8 @@ public class ShuffleClientImpl extends ShuffleClient {
loc,
this,
pushState,
- StatusCode.HARD_SPLIT));
+ StatusCode.HARD_SPLIT,
+ remainReviveTimes));
} else if (reason ==
StatusCode.PUSH_DATA_SUCCESS_MASTER_CONGESTED.getValue()) {
logger.debug(
"Push data split for map {} attempt {} batch {} return
master congested.",
@@ -716,6 +753,12 @@ public class ShuffleClientImpl extends ShuffleClient {
if (pushState.exception.get() != null) {
return;
}
+
+ if (remainReviveTimes <= 0) {
+ callback.onFailure(e);
+ return;
+ }
+
logger.error(
"Push data to {}:{} failed for map {} attempt {} batch {}.",
loc.getHost(),
@@ -726,6 +769,7 @@ public class ShuffleClientImpl extends ShuffleClient {
e);
// async retry push data
if (!mapperEnded(shuffleId, mapId, attemptId)) {
+ remainReviveTimes = remainReviveTimes - 1;
pushDataRetryPool.submit(
() ->
submitRetryPushData(
@@ -736,9 +780,10 @@ public class ShuffleClientImpl extends ShuffleClient {
body,
nextBatchId,
loc,
- callback,
+ this,
pushState,
- getPushDataFailCause(e.getMessage())));
+ getPushDataFailCause(e.getMessage()),
+ remainReviveTimes));
} else {
pushState.removeBatch(nextBatchId);
logger.info(
@@ -753,10 +798,14 @@ public class ShuffleClientImpl extends ShuffleClient {
// do push data
try {
- TransportClient client =
- dataClientFactory.createClient(loc.getHost(), loc.getPushPort(),
partitionId);
- ChannelFuture future = client.pushData(pushData, wrappedCallback);
- pushState.pushStarted(nextBatchId, future, wrappedCallback);
+ if (!testRetryRevive) {
+ TransportClient client =
+ dataClientFactory.createClient(loc.getHost(), loc.getPushPort(),
partitionId);
+ ChannelFuture future = client.pushData(pushData, wrappedCallback);
+ pushState.pushStarted(nextBatchId, future, wrappedCallback);
+ } else {
+ throw new RuntimeException("Mock push data first time failed.");
+ }
} catch (Exception e) {
logger.warn("PushData failed", e);
wrappedCallback.onFailure(
@@ -778,7 +827,7 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId,
dataBatches.requireBatches(),
pushState,
- false);
+ maxReviveTimes);
}
}
@@ -894,7 +943,14 @@ public class ShuffleClientImpl extends ShuffleClient {
}
String[] tokens = entry.getKey().split("-");
doPushMergedData(
- tokens[0], applicationId, shuffleId, mapId, attemptId, batches,
pushState, false);
+ tokens[0],
+ applicationId,
+ shuffleId,
+ mapId,
+ attemptId,
+ batches,
+ pushState,
+ maxReviveTimes);
}
}
@@ -906,7 +962,7 @@ public class ShuffleClientImpl extends ShuffleClient {
int attemptId,
ArrayList<DataBatches.DataBatch> batches,
PushState pushState,
- boolean revived) {
+ int remainReviveTimes) {
final String[] splits = hostPort.split(":");
final String host = splits[0];
final int port = Integer.parseInt(splits[1]);
@@ -954,7 +1010,7 @@ public class ShuffleClientImpl extends ShuffleClient {
@Override
public void onFailure(Throwable e) {
String errorMsg =
- (revived ? "Revived push" : "Push")
+ (remainReviveTimes < maxReviveTimes ? "Revived push" : "Push")
+ " merged data to "
+ host
+ ":"
@@ -1001,7 +1057,8 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId,
batches,
StatusCode.HARD_SPLIT,
- groupedBatchId));
+ groupedBatchId,
+ remainReviveTimes));
} else if (reason ==
StatusCode.PUSH_DATA_SUCCESS_MASTER_CONGESTED.getValue()) {
logger.debug(
"Push data split for map {} attempt {} batches {} return
master congested.",
@@ -1036,7 +1093,7 @@ public class ShuffleClientImpl extends ShuffleClient {
if (pushState.exception.get() != null) {
return;
}
- if (revived) {
+ if (remainReviveTimes <= 0) {
callback.onFailure(e);
return;
}
@@ -1064,16 +1121,21 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId,
batches,
getPushDataFailCause(e.getMessage()),
- groupedBatchId));
+ groupedBatchId,
+ remainReviveTimes - 1));
}
}
};
// do push merged data
try {
- TransportClient client = dataClientFactory.createClient(host, port);
- ChannelFuture future = client.pushMergedData(mergedData,
wrappedCallback);
- pushState.pushStarted(groupedBatchId, future, wrappedCallback);
+ if (!testRetryRevive || remainReviveTimes < 1) {
+ TransportClient client = dataClientFactory.createClient(host, port);
+ ChannelFuture future = client.pushMergedData(mergedData,
wrappedCallback);
+ pushState.pushStarted(groupedBatchId, future, wrappedCallback);
+ } else {
+ throw new RuntimeException("Mock push merge data failed");
+ }
} catch (Exception e) {
logger.warn("PushMergedData failed", e);
wrappedCallback.onFailure(new
Exception(getPushDataFailCause(e.getMessage()).toString(), e));
diff --git
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 2d2da940..aeff72d4 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -552,13 +552,6 @@ class CelebornConf(loadDefaults: Boolean) extends
Cloneable with Logging with Se
}
}
- // //////////////////////////////////////////////////////
- // test //
- // //////////////////////////////////////////////////////
- def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
- def testRetryCommitFiles: Boolean = get(TEST_RETRY_COMMIT_FILE)
- def testPushDataTimeout: Boolean = get(TEST_PUSHDATA_TIMEOUT)
-
def masterHost: String = get(MASTER_HOST)
def masterPort: Int = get(MASTER_PORT)
@@ -663,6 +656,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable
with Logging with Se
def pushBufferMaxSize: Int = get(PUSH_BUFFER_MAX_SIZE).toInt
def pushQueueCapacity: Int = get(PUSH_QUEUE_CAPACITY)
def pushMaxReqsInFlight: Int = get(PUSH_MAX_REQS_IN_FLIGHT)
+ def pushMaxReviveTimes: Int = get(PUSH_MAX_REVIVE_TIMES)
def pushSortMemoryThreshold: Long = get(PUSH_SORT_MEMORY_THRESHOLD)
def pushRetryThreads: Int = get(PUSH_RETRY_THREADS)
def pushStageEndTimeout: Long =
@@ -821,6 +815,14 @@ class CelebornConf(loadDefaults: Boolean) extends
Cloneable with Logging with Se
get(COLUMNAR_SHUFFLE_DICTIONARY_ENCODING_MAX_FACTOR)
def columnarShuffleCodeGenEnabled: Boolean =
get(COLUMNAR_SHUFFLE_CODEGEN_ENABLED)
+
+ // //////////////////////////////////////////////////////
+ // test //
+ // //////////////////////////////////////////////////////
+ def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
+ def testRetryCommitFiles: Boolean = get(TEST_RETRY_COMMIT_FILE)
+ def testPushDataTimeout: Boolean = get(TEST_PUSHDATA_TIMEOUT)
+ def testRetryRevive: Boolean = get(TEST_RETRY_REVIVE)
}
object CelebornConf extends Logging {
@@ -1268,6 +1270,22 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(32)
+ val PUSH_MAX_REVIVE_TIMES: ConfigEntry[Int] =
+ buildConf("celeborn.push.revive.maxRetries")
+ .categories("client")
+ .version("0.3.0")
+ .doc("Max retry times for reviving when celeborn push data failed.")
+ .intConf
+ .createWithDefault(5)
+
+ val TEST_RETRY_REVIVE: ConfigEntry[Boolean] =
+ buildConf("celeborn.test.retryRevive")
+ .categories("client")
+ .doc("Fail push data and request for test")
+ .version("0.2.0")
+ .booleanConf
+ .createWithDefault(false)
+
val FETCH_TIMEOUT: ConfigEntry[Long] =
buildConf("celeborn.fetch.timeout")
.withAlternative("rss.fetch.chunk.timeout")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index a15192c8..9e2692c2 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -36,6 +36,7 @@ license: |
| celeborn.push.queue.capacity | 512 | Push buffer queue size for a task. The
maximum memory is `celeborn.push.buffer.max.size` *
`celeborn.push.queue.capacity`, default: 64KiB * 512 = 32MiB | 0.2.0 |
| celeborn.push.replicate.enabled | true | When true, Celeborn worker will
replicate shuffle data to another Celeborn worker asynchronously to ensure the
pushed shuffle data won't be lost after the node failure. | 0.2.0 |
| celeborn.push.retry.threads | 8 | Thread number to process shuffle re-send
push data requests. | 0.2.0 |
+| celeborn.push.revive.maxRetries | 5 | Max retry times for reviving when
celeborn push data failed. | 0.3.0 |
| celeborn.push.sortMemory.threshold | 64m | When SortBasedPusher use memory
over the threshold, will trigger push data. | 0.2.0 |
| celeborn.push.splitPartition.threads | 8 | Thread number to process shuffle
split request in shuffle client. | 0.2.0 |
| celeborn.push.stageEnd.timeout | <undefined> | Timeout for waiting
StageEnd. Default value should be `celeborn.rpc.askTimeout *
(celeborn.rpc.requestCommitFiles.maxRetries + 1)`. | 0.2.0 |
@@ -72,6 +73,7 @@ license: |
| celeborn.storage.hdfs.dir | <undefined> | HDFS dir configuration for
Celeborn to access HDFS. | 0.2.0 |
| celeborn.test.fetchFailure | false | Whether to test fetch chunk failure |
0.2.0 |
| celeborn.test.retryCommitFiles | false | Fail commitFile request for test |
0.2.0 |
+| celeborn.test.retryRevive | false | Fail push data and request for test |
0.2.0 |
| celeborn.worker.excluded.checkInterval | 30s | Interval for client to
refresh excluded worker list. | 0.2.0 |
| celeborn.worker.excluded.expireTimeout | 600s | Timeout time for
LifecycleManager to clear reserved excluded worker. | 0.2.0 |
<!--end-include-->
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
new file mode 100644
index 00000000..611bc0fe
--- /dev/null
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+
+class RetryReviveTest extends AnyFunSuite
+ with SparkTestBase
+ with BeforeAndAfterEach {
+
+ override def beforeAll(): Unit = {
+ logInfo("test initialized , setup celeborn mini cluster")
+ setUpMiniCluster(masterConfs = null)
+ }
+
+ override def beforeEach(): Unit = {
+ ShuffleClient.reset()
+ }
+
+ override def afterEach(): Unit = {
+ System.gc()
+ }
+
+ test("celeborn spark integration test - retry revive as configured times") {
+ val sparkConf = new SparkConf()
+ .set("spark.celeborn.test.retryRevive", "true")
+ .setAppName("rss-demo").setMaster("local[4]")
+ val ss = SparkSession.builder().config(updateSparkConf(sparkConf,
false)).getOrCreate()
+ ss.sparkContext.parallelize(1 to 1000, 2)
+ .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(16).collect()
+ ss.stop()
+ }
+}