This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new e155ec12 [CELEBORN-190] doPushMergedData should also support revive multiple times, not only twice (#1136)
e155ec12 is described below

commit e155ec122adc08b283c30fa6352a028b27029296
Author: Angerszhuuuu <[email protected]>
AuthorDate: Tue Jan 10 11:39:40 2023 +0800

    [CELEBORN-190] doPushMergedData should also support revive multiple times, not only twice (#1136)
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 132 +++++++++++++++------
 .../org/apache/celeborn/common/CelebornConf.scala  |  32 +++--
 docs/configuration/client.md                       |   2 +
 .../celeborn/tests/spark/RetryReviveTest.scala     |  53 +++++++++
 4 files changed, 177 insertions(+), 42 deletions(-)

diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 3735dbe6..d7bc1700 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -85,6 +85,8 @@ public class ShuffleClientImpl extends ShuffleClient {
   private final int registerShuffleMaxRetries;
   private final long registerShuffleRetryWaitMs;
   private int maxInFlight;
+  private int maxReviveTimes;
+  private boolean testRetryRevive;
   private final AtomicInteger currentMaxReqsInFlight;
   private int congestionAvoidanceFlag = 0;
   private final int pushBufferMaxSize;
@@ -139,6 +141,8 @@ public class ShuffleClientImpl extends ShuffleClient {
     registerShuffleMaxRetries = conf.registerShuffleMaxRetry();
     registerShuffleRetryWaitMs = conf.registerShuffleRetryWaitMs();
     maxInFlight = conf.pushMaxReqsInFlight();
+    maxReviveTimes = conf.pushMaxReviveTimes();
+    testRetryRevive = conf.testRetryRevive();
 
     if (conf.pushDataSlowStart()) {
       currentMaxReqsInFlight = new AtomicInteger(1);
@@ -178,11 +182,13 @@ public class ShuffleClientImpl extends ShuffleClient {
       PartitionLocation loc,
       RpcResponseCallback callback,
       PushState pushState,
-      StatusCode cause) {
+      StatusCode cause,
+      int remainReviveTimes) {
     int partitionId = loc.getId();
     if (!revive(
         applicationId, shuffleId, mapId, attemptId, partitionId, 
loc.getEpoch(), loc, cause)) {
-      callback.onFailure(new IOException("Revive Failed"));
+      callback.onFailure(
+          new IOException("Revive Failed, remain revive times " + 
remainReviveTimes));
     } else if (mapperEnded(shuffleId, mapId, attemptId)) {
       logger.debug(
           "Retrying push data, but the mapper(map {} attempt {}) has ended.", 
mapId, attemptId);
@@ -191,15 +197,20 @@ public class ShuffleClientImpl extends ShuffleClient {
       PartitionLocation newLoc = 
reducePartitionMap.get(shuffleId).get(partitionId);
       logger.info("Revive success, new location for reduce {} is {}.", 
partitionId, newLoc);
       try {
-        TransportClient client =
-            dataClientFactory.createClient(newLoc.getHost(), 
newLoc.getPushPort(), partitionId);
-        NettyManagedBuffer newBuffer = new 
NettyManagedBuffer(Unpooled.wrappedBuffer(body));
-        String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
-
-        PushData newPushData =
-            new PushData(MASTER_MODE, shuffleKey, newLoc.getUniqueId(), 
newBuffer);
-        ChannelFuture future = client.pushData(newPushData, callback);
-        pushState.pushStarted(batchId, future, callback);
+        if (!testRetryRevive || remainReviveTimes < 1) {
+          TransportClient client =
+              dataClientFactory.createClient(newLoc.getHost(), 
newLoc.getPushPort(), partitionId);
+          NettyManagedBuffer newBuffer = new 
NettyManagedBuffer(Unpooled.wrappedBuffer(body));
+          String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
+
+          PushData newPushData =
+              new PushData(MASTER_MODE, shuffleKey, newLoc.getUniqueId(), 
newBuffer);
+          ChannelFuture future = client.pushData(newPushData, callback);
+          pushState.pushStarted(batchId, future, callback);
+        } else {
+          throw new RuntimeException(
+              "Mock push data submit retry failed. remainReviveTimes = " + 
remainReviveTimes + ".");
+        }
       } catch (Exception ex) {
         logger.warn(
             "Exception raised while pushing data for shuffle {} map {} attempt 
{}" + " batch {}.",
@@ -221,8 +232,10 @@ public class ShuffleClientImpl extends ShuffleClient {
       int attemptId,
       ArrayList<DataBatches.DataBatch> batches,
       StatusCode cause,
-      Integer oldGroupedBatchId) {
+      Integer oldGroupedBatchId,
+      int remainReviveTimes) {
     HashMap<String, DataBatches> newDataBatchesMap = new HashMap<>();
+    ArrayList<DataBatches.DataBatch> reviveFailedBatchesMap = new 
ArrayList<>();
     for (DataBatches.DataBatch batch : batches) {
       int partitionId = batch.loc.getId();
       if (!revive(
@@ -234,10 +247,16 @@ public class ShuffleClientImpl extends ShuffleClient {
           batch.loc.getEpoch(),
           batch.loc,
           cause)) {
-        pushState.exception.compareAndSet(
-            null,
-            new IOException("Revive Failed in retry push merged data for 
location: " + batch.loc));
-        return;
+
+        if (remainReviveTimes > 0) {
+          reviveFailedBatchesMap.add(batch);
+        } else {
+          pushState.exception.compareAndSet(
+              null,
+              new IOException(
+                  "Revive Failed in retry push merged data for location: " + 
batch.loc));
+          return;
+        }
       } else if (mapperEnded(shuffleId, mapId, attemptId)) {
         logger.debug(
             "Retrying push data, but the mapper(map {} attempt {}) has 
ended.", mapId, attemptId);
@@ -262,9 +281,24 @@ public class ShuffleClientImpl extends ShuffleClient {
           attemptId,
           newDataBatches.requireBatches(),
           pushState,
-          true);
+          remainReviveTimes);
+    }
+    if (reviveFailedBatchesMap.isEmpty()) {
+      pushState.removeBatch(oldGroupedBatchId);
+    } else {
+      pushDataRetryPool.submit(
+          () ->
+              submitRetryPushMergedData(
+                  pushState,
+                  applicationId,
+                  shuffleId,
+                  mapId,
+                  attemptId,
+                  reviveFailedBatchesMap,
+                  cause,
+                  oldGroupedBatchId,
+                  remainReviveTimes - 1));
     }
-    pushState.removeBatch(oldGroupedBatchId);
   }
 
   private String genAddressPair(PartitionLocation loc) {
@@ -652,6 +686,8 @@ public class ShuffleClientImpl extends ShuffleClient {
 
       RpcResponseCallback wrappedCallback =
           new RpcResponseCallback() {
+            int remainReviveTimes = maxReviveTimes;
+
             @Override
             public void onSuccess(ByteBuffer response) {
               if (response.remaining() > 0) {
@@ -683,7 +719,8 @@ public class ShuffleClientImpl extends ShuffleClient {
                               loc,
                               this,
                               pushState,
-                              StatusCode.HARD_SPLIT));
+                              StatusCode.HARD_SPLIT,
+                              remainReviveTimes));
                 } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_MASTER_CONGESTED.getValue()) {
                   logger.debug(
                       "Push data split for map {} attempt {} batch {} return 
master congested.",
@@ -716,6 +753,12 @@ public class ShuffleClientImpl extends ShuffleClient {
               if (pushState.exception.get() != null) {
                 return;
               }
+
+              if (remainReviveTimes <= 0) {
+                callback.onFailure(e);
+                return;
+              }
+
               logger.error(
                   "Push data to {}:{} failed for map {} attempt {} batch {}.",
                   loc.getHost(),
@@ -726,6 +769,7 @@ public class ShuffleClientImpl extends ShuffleClient {
                   e);
               // async retry push data
               if (!mapperEnded(shuffleId, mapId, attemptId)) {
+                remainReviveTimes = remainReviveTimes - 1;
                 pushDataRetryPool.submit(
                     () ->
                         submitRetryPushData(
@@ -736,9 +780,10 @@ public class ShuffleClientImpl extends ShuffleClient {
                             body,
                             nextBatchId,
                             loc,
-                            callback,
+                            this,
                             pushState,
-                            getPushDataFailCause(e.getMessage())));
+                            getPushDataFailCause(e.getMessage()),
+                            remainReviveTimes));
               } else {
                 pushState.removeBatch(nextBatchId);
                 logger.info(
@@ -753,10 +798,14 @@ public class ShuffleClientImpl extends ShuffleClient {
 
       // do push data
       try {
-        TransportClient client =
-            dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), 
partitionId);
-        ChannelFuture future = client.pushData(pushData, wrappedCallback);
-        pushState.pushStarted(nextBatchId, future, wrappedCallback);
+        if (!testRetryRevive) {
+          TransportClient client =
+              dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), 
partitionId);
+          ChannelFuture future = client.pushData(pushData, wrappedCallback);
+          pushState.pushStarted(nextBatchId, future, wrappedCallback);
+        } else {
+          throw new RuntimeException("Mock push data first time failed.");
+        }
       } catch (Exception e) {
         logger.warn("PushData failed", e);
         wrappedCallback.onFailure(
@@ -778,7 +827,7 @@ public class ShuffleClientImpl extends ShuffleClient {
             attemptId,
             dataBatches.requireBatches(),
             pushState,
-            false);
+            maxReviveTimes);
       }
     }
 
@@ -894,7 +943,14 @@ public class ShuffleClientImpl extends ShuffleClient {
       }
       String[] tokens = entry.getKey().split("-");
       doPushMergedData(
-          tokens[0], applicationId, shuffleId, mapId, attemptId, batches, 
pushState, false);
+          tokens[0],
+          applicationId,
+          shuffleId,
+          mapId,
+          attemptId,
+          batches,
+          pushState,
+          maxReviveTimes);
     }
   }
 
@@ -906,7 +962,7 @@ public class ShuffleClientImpl extends ShuffleClient {
       int attemptId,
       ArrayList<DataBatches.DataBatch> batches,
       PushState pushState,
-      boolean revived) {
+      int remainReviveTimes) {
     final String[] splits = hostPort.split(":");
     final String host = splits[0];
     final int port = Integer.parseInt(splits[1]);
@@ -954,7 +1010,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           @Override
           public void onFailure(Throwable e) {
             String errorMsg =
-                (revived ? "Revived push" : "Push")
+                (remainReviveTimes < maxReviveTimes ? "Revived push" : "Push")
                     + " merged data to "
                     + host
                     + ":"
@@ -1001,7 +1057,8 @@ public class ShuffleClientImpl extends ShuffleClient {
                             attemptId,
                             batches,
                             StatusCode.HARD_SPLIT,
-                            groupedBatchId));
+                            groupedBatchId,
+                            remainReviveTimes));
               } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_MASTER_CONGESTED.getValue()) {
                 logger.debug(
                     "Push data split for map {} attempt {} batchs {} return 
master congested.",
@@ -1036,7 +1093,7 @@ public class ShuffleClientImpl extends ShuffleClient {
             if (pushState.exception.get() != null) {
               return;
             }
-            if (revived) {
+            if (remainReviveTimes <= 0) {
               callback.onFailure(e);
               return;
             }
@@ -1064,16 +1121,21 @@ public class ShuffleClientImpl extends ShuffleClient {
                           attemptId,
                           batches,
                           getPushDataFailCause(e.getMessage()),
-                          groupedBatchId));
+                          groupedBatchId,
+                          remainReviveTimes - 1));
             }
           }
         };
 
     // do push merged data
     try {
-      TransportClient client = dataClientFactory.createClient(host, port);
-      ChannelFuture future = client.pushMergedData(mergedData, 
wrappedCallback);
-      pushState.pushStarted(groupedBatchId, future, wrappedCallback);
+      if (!testRetryRevive || remainReviveTimes < 1) {
+        TransportClient client = dataClientFactory.createClient(host, port);
+        ChannelFuture future = client.pushMergedData(mergedData, 
wrappedCallback);
+        pushState.pushStarted(groupedBatchId, future, wrappedCallback);
+      } else {
+        throw new RuntimeException("Mock push merge data failed");
+      }
     } catch (Exception e) {
       logger.warn("PushMergedData failed", e);
       wrappedCallback.onFailure(new 
Exception(getPushDataFailCause(e.getMessage()).toString(), e));
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 2d2da940..aeff72d4 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -552,13 +552,6 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
       }
     }
 
-  // //////////////////////////////////////////////////////
-  //                      test                           //
-  // //////////////////////////////////////////////////////
-  def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
-  def testRetryCommitFiles: Boolean = get(TEST_RETRY_COMMIT_FILE)
-  def testPushDataTimeout: Boolean = get(TEST_PUSHDATA_TIMEOUT)
-
   def masterHost: String = get(MASTER_HOST)
 
   def masterPort: Int = get(MASTER_PORT)
@@ -663,6 +656,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   def pushBufferMaxSize: Int = get(PUSH_BUFFER_MAX_SIZE).toInt
   def pushQueueCapacity: Int = get(PUSH_QUEUE_CAPACITY)
   def pushMaxReqsInFlight: Int = get(PUSH_MAX_REQS_IN_FLIGHT)
+  def pushMaxReviveTimes: Int = get(PUSH_MAX_REVIVE_TIMES)
   def pushSortMemoryThreshold: Long = get(PUSH_SORT_MEMORY_THRESHOLD)
   def pushRetryThreads: Int = get(PUSH_RETRY_THREADS)
   def pushStageEndTimeout: Long =
@@ -821,6 +815,14 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
     get(COLUMNAR_SHUFFLE_DICTIONARY_ENCODING_MAX_FACTOR)
 
   def columnarShuffleCodeGenEnabled: Boolean = 
get(COLUMNAR_SHUFFLE_CODEGEN_ENABLED)
+
+  // //////////////////////////////////////////////////////
+  //                      test                           //
+  // //////////////////////////////////////////////////////
+  def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
+  def testRetryCommitFiles: Boolean = get(TEST_RETRY_COMMIT_FILE)
+  def testPushDataTimeout: Boolean = get(TEST_PUSHDATA_TIMEOUT)
+  def testRetryRevive: Boolean = get(TEST_RETRY_REVIVE)
 }
 
 object CelebornConf extends Logging {
@@ -1268,6 +1270,22 @@ object CelebornConf extends Logging {
       .intConf
       .createWithDefault(32)
 
+  val PUSH_MAX_REVIVE_TIMES: ConfigEntry[Int] =
+    buildConf("celeborn.push.revive.maxRetries")
+      .categories("client")
+      .version("0.3.0")
+      .doc("Max retry times for reviving when celeborn push data failed.")
+      .intConf
+      .createWithDefault(5)
+
+  val TEST_RETRY_REVIVE: ConfigEntry[Boolean] =
+    buildConf("celeborn.test.retryRevive")
+      .categories("client")
+      .doc("Fail push data and request for test")
+      .version("0.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val FETCH_TIMEOUT: ConfigEntry[Long] =
     buildConf("celeborn.fetch.timeout")
       .withAlternative("rss.fetch.chunk.timeout")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index a15192c8..9e2692c2 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -36,6 +36,7 @@ license: |
 | celeborn.push.queue.capacity | 512 | Push buffer queue size for a task. The 
maximum memory is `celeborn.push.buffer.max.size` * 
`celeborn.push.queue.capacity`, default: 64KiB * 512 = 32MiB | 0.2.0 | 
 | celeborn.push.replicate.enabled | true | When true, Celeborn worker will 
replicate shuffle data to another Celeborn worker asynchronously to ensure the 
pushed shuffle data won't be lost after the node failure. | 0.2.0 | 
 | celeborn.push.retry.threads | 8 | Thread number to process shuffle re-send 
push data requests. | 0.2.0 | 
+| celeborn.push.revive.maxRetries | 5 | Max retry times for reviving when 
celeborn push data failed. | 0.3.0 | 
 | celeborn.push.sortMemory.threshold | 64m | When SortBasedPusher use memory 
over the threshold, will trigger push data. | 0.2.0 | 
 | celeborn.push.splitPartition.threads | 8 | Thread number to process shuffle 
split request in shuffle client. | 0.2.0 | 
 | celeborn.push.stageEnd.timeout | &lt;undefined&gt; | Timeout for waiting 
StageEnd. Default value should be `celeborn.rpc.askTimeout * 
(celeborn.rpc.requestCommitFiles.maxRetries + 1)`. | 0.2.0 | 
@@ -72,6 +73,7 @@ license: |
 | celeborn.storage.hdfs.dir | &lt;undefined&gt; | HDFS dir configuration for 
Celeborn to access HDFS. | 0.2.0 | 
 | celeborn.test.fetchFailure | false | Wheter to test fetch chunk failure | 
0.2.0 | 
 | celeborn.test.retryCommitFiles | false | Fail commitFile request for test | 
0.2.0 | 
+| celeborn.test.retryRevive | false | Fail push data and request for test | 
0.2.0 | 
 | celeborn.worker.excluded.checkInterval | 30s | Interval for client to 
refresh excluded worker list. | 0.2.0 | 
 | celeborn.worker.excluded.expireTimeout | 600s | Timeout time for 
LifecycleManager to clear reserved excluded worker. | 0.2.0 | 
 <!--end-include-->
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
new file mode 100644
index 00000000..611bc0fe
--- /dev/null
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+
+class RetryReviveTest extends AnyFunSuite
+  with SparkTestBase
+  with BeforeAndAfterEach {
+
+  override def beforeAll(): Unit = {
+    logInfo("test initialized , setup celeborn mini cluster")
+    setUpMiniCluster(masterConfs = null)
+  }
+
+  override def beforeEach(): Unit = {
+    ShuffleClient.reset()
+  }
+
+  override def afterEach(): Unit = {
+    System.gc()
+  }
+
+  test("celeborn spark integration test - retry revive as configured times") {
+    val sparkConf = new SparkConf()
+      .set("spark.celeborn.test.retryRevive", "true")
+      .setAppName("rss-demo").setMaster("local[4]")
+    val ss = SparkSession.builder().config(updateSparkConf(sparkConf, 
false)).getOrCreate()
+    ss.sparkContext.parallelize(1 to 1000, 2)
+      .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(16).collect()
+    ss.stop()
+  }
+}

Reply via email to