This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new f8bb2cd4 [CELEBORN-12]Retry on CommitFile request (#1011)
f8bb2cd4 is described below

commit f8bb2cd47d5271583f0a343b1173b24d63bb7669
Author: Keyong Zhou <[email protected]>
AuthorDate: Sat Nov 26 20:56:24 2022 +0800

    [CELEBORN-12]Retry on CommitFile request (#1011)
---
 .../apache/celeborn/client/LifecycleManager.scala  | 42 +++++++----
 .../common/protocol/message/StatusCode.java        |  4 +-
 .../org/apache/celeborn/common/CelebornConf.scala  | 28 ++++++-
 docs/configuration/client.md                       |  2 +
 docs/configuration/worker.md                       |  2 +-
 .../apache/celeborn/tests/spark/HugeDataTest.scala |  2 +-
 ...geDataTest.scala => RetryCommitFilesTest.scala} | 16 ++--
 .../apache/celeborn/tests/spark/RssHashSuite.scala |  2 +-
 .../apache/celeborn/tests/spark/RssSortSuite.scala |  2 +-
 .../celeborn/tests/spark/SkewJoinSuite.scala       |  2 +-
 .../celeborn/tests/spark/SparkTestBase.scala       | 12 +--
 .../service/deploy/worker/storage/FileWriter.java  |  5 ++
 .../service/deploy/worker/CommitInfo.scala         | 28 +++++++
 .../service/deploy/worker/Controller.scala         | 85 +++++++++++++++++++---
 .../celeborn/service/deploy/worker/Worker.scala    |  3 +
 15 files changed, 190 insertions(+), 45 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 917f908b..c28621bf 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -85,6 +85,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) 
extends RpcEndpoin
     .maximumSize(rpcCacheSize)
     .build().asInstanceOf[Cache[Int, ByteBuffer]]
 
+  private val testCommitFileFailure = conf.testRetryCommitFiles
+
   @VisibleForTesting
   def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, 
PartitionLocationInfo] =
     shuffleAllocatedWorkers.get(shuffleId)
@@ -992,7 +994,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
           masterIds,
           slaveIds,
           shuffleMapperAttempts.get(shuffleId))
-        val res = requestCommitFiles(worker.endpoint, commitFiles)
+        val res = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
 
         res.status match {
           case StatusCode.SUCCESS => // do nothing
@@ -1638,21 +1640,35 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     }
   }
 
-  private def requestCommitFiles(
+  private def requestCommitFilesWithRetry(
       endpoint: RpcEndpointRef,
       message: CommitFiles): CommitFilesResponse = {
-    try {
-      endpoint.askSync[CommitFilesResponse](message)
-    } catch {
-      case e: Exception =>
-        logError(s"AskSync CommitFiles for ${message.shuffleId} failed.", e)
-        CommitFilesResponse(
-          StatusCode.FAILED,
-          List.empty.asJava,
-          List.empty.asJava,
-          message.masterIds,
-          message.slaveIds)
+    val maxRetries = conf.requestCommitFilesMaxRetries
+    var retryTimes = 0
+    while (retryTimes < maxRetries) {
+      try {
+        if (testCommitFileFailure && retryTimes < maxRetries - 1) {
+          endpoint.ask[CommitFilesResponse](message)
+          Thread.sleep(1000)
+          throw new Exception("Mock fail for CommitFiles")
+        } else {
+          return endpoint.askSync[CommitFilesResponse](message)
+        }
+      } catch {
+        case e: Exception =>
+          retryTimes += 1
+          logError(
+            s"AskSync CommitFiles for ${message.shuffleId} failed (attempt 
$retryTimes/$maxRetries).",
+            e)
+      }
     }
+
+    CommitFilesResponse(
+      StatusCode.FAILED,
+      List.empty.asJava,
+      List.empty.asJava,
+      message.masterIds,
+      message.slaveIds)
   }
 
   private def requestReleaseSlots(
diff --git 
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
index ced17e90..8a507e43 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
@@ -53,7 +53,9 @@ public enum StatusCode {
   WORKER_SHUTDOWN(25),
   NO_AVAILABLE_WORKING_DIR(26),
   WORKER_IN_BLACKLIST(27),
-  UNKNOWN_WORKER(28);
+  UNKNOWN_WORKER(28),
+
+  COMMIT_FILE_EXCEPTION(29);
 
   private final byte value;
 
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 16fcdd82..8881a9cb 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -531,6 +531,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable 
with Logging with Se
   def workerExcludedExpireTimeout: Long = get(WORKER_EXCLUDED_EXPIRE_TIMEOUT)
   def shuffleRangeReadFilterEnabled: Boolean = 
get(SHUFFLE_RANGE_READ_FILTER_ENABLED)
   def shufflePartitionType: PartitionType = 
PartitionType.valueOf(get(SHUFFLE_PARTITION_TYPE))
+  def requestCommitFilesMaxRetries: Int = get(COMMIT_FILE_REQUEST_MAX_RETRY)
 
   // //////////////////////////////////////////////////////
   //               Shuffle Compression                   //
@@ -550,6 +551,12 @@ class CelebornConf(loadDefaults: Boolean) extends 
Cloneable with Logging with Se
       }
     }
 
+  // //////////////////////////////////////////////////////
+  //                      test                           //
+  // //////////////////////////////////////////////////////
+  def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
+  def testRetryCommitFiles: Boolean = get(TEST_RETRY_COMMIT_FILE)
+
   def masterHost: String = get(MASTER_HOST)
 
   def masterPort: Int = get(MASTER_PORT)
@@ -643,7 +650,6 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable 
with Logging with Se
   def fetchTimeoutMs: Long = get(FETCH_TIMEOUT)
   def fetchMaxReqsInFlight: Int = get(FETCH_MAX_REQS_IN_FLIGHT)
   def fetchMaxRetries: Int = get(FETCH_MAX_RETRIES)
-  def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE)
 
   // //////////////////////////////////////////////////////
   //               Shuffle Client Push                   //
@@ -1364,6 +1370,23 @@ object CelebornConf extends Logging {
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("3s")
 
+  val COMMIT_FILE_REQUEST_MAX_RETRY: ConfigEntry[Int] =
+    buildConf("celeborn.rpc.requestCommitFiles.maxRetries")
+      .categories("client")
+      .doc("Max retry times for requestCommitFiles RPC.")
+      .version("1.0.0")
+      .intConf
+      .checkValue(v => v > 0, "value must be positive")
+      .createWithDefault(2)
+
+  val TEST_RETRY_COMMIT_FILE: ConfigEntry[Boolean] =
+    buildConf("celeborn.test.retryCommitFiles")
+      .categories("client")
+      .doc("Fail commitFile request for test")
+      .version("0.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val MASTER_HOST: ConfigEntry[String] =
     buildConf("celeborn.master.host")
       .categories("master")
@@ -1796,8 +1819,7 @@ object CelebornConf extends Logging {
       .categories("worker")
       .doc("Timeout for a Celeborn worker to commit files of a shuffle.")
       .version("0.2.0")
-      .timeConf(TimeUnit.SECONDS)
-      .createWithDefaultString("120s")
+      .fallbackConf(RPC_ASK_TIMEOUT)
 
   val PARTITION_SORTER_SORT_TIMEOUT: ConfigEntry[Long] =
     buildConf("celeborn.worker.partitionSorter.sort.timeout")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 111452cb..3910f1d7 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -41,6 +41,7 @@ license: |
 | celeborn.rpc.cache.expireTime | 15s | The time before a cache item is 
removed. | 0.2.0 | 
 | celeborn.rpc.cache.size | 256 | The max cache items count for rpc cache. | 
0.2.0 | 
 | celeborn.rpc.maxParallelism | 1024 | Max parallelism of client on sending 
RPC requests. | 0.2.0 | 
+| celeborn.rpc.requestCommitFiles.maxRetries | 2 | Max retry times for 
requestCommitFiles RPC. | 1.0.0 | 
 | celeborn.shuffle.batchHandleChangePartition.enabled | false | When true, 
LifecycleManager will handle change partition request in batch. Otherwise, 
LifecycleManager will process the requests one by one | 0.2.0 | 
 | celeborn.shuffle.batchHandleChangePartition.interval | 100ms | Interval for 
LifecycleManager to schedule handling change partition requests in batch. | 
0.2.0 | 
 | celeborn.shuffle.batchHandleChangePartition.threads | 8 | Threads number for 
LifecycleManager to handle change partition request in batch. | 0.2.0 | 
@@ -62,6 +63,7 @@ license: |
 | celeborn.slots.reserve.retryWait | 3s | Wait time before next retry if 
reserve slots failed. | 0.2.0 | 
 | celeborn.storage.hdfs.dir | &lt;undefined&gt; | HDFS dir configuration for 
Celeborn to access HDFS. | 0.2.0 | 
 | celeborn.test.fetchFailure | false | Whether to test fetch chunk failure | 
0.2.0 | 
+| celeborn.test.retryCommitFiles | false | Fail commitFile request for test | 
0.2.0 | 
 | celeborn.worker.excluded.checkInterval | 30s | Interval for client to 
refresh excluded worker list. | 0.2.0 | 
 | celeborn.worker.excluded.expireTimeout | 600s | Timeout time for 
LifecycleManager to clear reserved excluded worker. | 0.2.0 | 
 <!--end-include-->
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index 03881f93..bf1e48ac 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -73,7 +73,7 @@ license: |
 | celeborn.worker.replicate.port | 0 | Server port for Worker to receive 
replicate data request from other Workers. | 0.2.0 | 
 | celeborn.worker.replicate.threads | 64 | Thread number of worker to 
replicate shuffle data. | 0.2.0 | 
 | celeborn.worker.rpc.port | 0 | Server port for Worker to receive RPC 
request. | 0.2.0 | 
-| celeborn.worker.shuffle.commit.timeout | 120s | Timeout for a Celeborn 
worker to commit files of a shuffle. | 0.2.0 | 
+| celeborn.worker.shuffle.commit.timeout | &lt;value of 
celeborn.rpc.askTimeout&gt; | Timeout for a Celeborn worker to commit files of 
a shuffle. | 0.2.0 | 
 | celeborn.worker.storage.baseDir.number | 16 | How many directories will be 
used if `celeborn.worker.storage.dirs` is not set. The directory name is a 
combination of `celeborn.worker.storage.baseDir.prefix` and from one(inclusive) 
to `celeborn.worker.storage.baseDir.number`(inclusive) step by one. | 0.2.0 | 
 | celeborn.worker.storage.baseDir.prefix | /mnt/disk | Base directory for 
Celeborn worker to write if `celeborn.worker.storage.dirs` is not set. | 0.2.0 
| 
 | celeborn.worker.storage.dirs | &lt;undefined&gt; | Directory list to store 
shuffle data. It's recommended to configure one directory on each disk. Storage 
size limit can be set for each directory. For the sake of performance, there 
should be no more than 2 flush threads on the same disk partition if you are 
using HDD, and should be 8 or more flush threads on the same disk partition if 
you are using SSD. For example: 
`dir1[:capacity=][:disktype=][:flushthread=],dir2[:capacity=][:disktyp [...]
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
index a45258d6..932998e7 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
@@ -31,7 +31,7 @@ class HugeDataTest extends AnyFunSuite
 
   override def beforeAll(): Unit = {
     logInfo("test initialized , setup rss mini cluster")
-    tuple = setupRssMiniCluster()
+    tuple = setupRssMiniClusterSpark()
   }
 
   override def afterAll(): Unit = {
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryCommitFilesTest.scala
similarity index 74%
copy from 
tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
copy to 
tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryCommitFilesTest.scala
index a45258d6..afacdf5d 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryCommitFilesTest.scala
@@ -24,14 +24,16 @@ import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.ShuffleClient
 
-class HugeDataTest extends AnyFunSuite
+class RetryCommitFilesTest extends AnyFunSuite
   with SparkTestBase
   with BeforeAndAfterAll
   with BeforeAndAfterEach {
 
   override def beforeAll(): Unit = {
     logInfo("test initialized , setup rss mini cluster")
-    tuple = setupRssMiniCluster()
+    val workerConf = Map(
+      "celeborn.test.retryCommitFiles" -> s"true")
+    tuple = setupRssMiniClusterSpark(masterConfs = null, workerConfs = 
workerConf)
   }
 
   override def afterAll(): Unit = {
@@ -47,11 +49,13 @@ class HugeDataTest extends AnyFunSuite
     System.gc()
   }
 
-  test("celeborn spark integration test - huge data") {
-    val sparkConf = new 
SparkConf().setAppName("rss-demo").setMaster("local[4]")
+  test("celeborn spark integration test - retry commit files") {
+    val sparkConf = new SparkConf()
+      .set("spark.celeborn.test.retryCommitFiles", "true")
+      .setAppName("rss-demo").setMaster("local[4]")
     val ss = SparkSession.builder().config(updateSparkConf(sparkConf, 
false)).getOrCreate()
-    ss.sparkContext.parallelize(1 to 10000, 2)
-      .map { i => (i, Range(1, 10000).mkString(",")) }.groupByKey(16).collect()
+    ss.sparkContext.parallelize(1 to 1000, 2)
+      .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(16).collect()
     ss.stop()
   }
 }
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
index 2146ae87..6e73c217 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
@@ -31,7 +31,7 @@ class RssHashSuite extends AnyFunSuite
 
   override def beforeAll(): Unit = {
     logInfo("test initialized , setup rss mini cluster")
-    tuple = setupRssMiniCluster()
+    tuple = setupRssMiniClusterSpark()
   }
 
   override def afterAll(): Unit = {
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
index 89f53d25..ea9537f0 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
@@ -31,7 +31,7 @@ class RssSortSuite extends AnyFunSuite
 
   override def beforeAll(): Unit = {
     logInfo("test initialized , setup rss mini cluster")
-    tuple = setupRssMiniCluster()
+    tuple = setupRssMiniClusterSpark()
   }
 
   override def afterAll(): Unit = {
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
index bf6e81e0..19e74005 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
@@ -35,7 +35,7 @@ class SkewJoinSuite extends AnyFunSuite
 
   override def beforeAll(): Unit = {
     logInfo("test initialized , setup rss mini cluster")
-    tuple = setupRssMiniCluster()
+    tuple = setupRssMiniClusterSpark()
   }
 
   override def afterAll(): Unit = {
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
index 6e3aac2d..92d79d5f 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
@@ -76,7 +76,9 @@ trait SparkTestBase extends Logging with MiniClusterFeature {
     tuple._12.interrupt()
   }
 
-  def setupRssMiniCluster(): (
+  def setupRssMiniClusterSpark(
+      masterConfs: Map[String, String] = null,
+      workerConfs: Map[String, String] = null): (
       Master,
       RpcEnv,
       Worker,
@@ -91,10 +93,10 @@ trait SparkTestBase extends Logging with MiniClusterFeature 
{
       Thread) = {
     Thread.sleep(3000L)
 
-    val (master, masterRpcEnv) = createMaster()
-    val (worker1, workerRpcEnv1) = createWorker()
-    val (worker2, workerRpcEnv2) = createWorker()
-    val (worker3, workerRpcEnv3) = createWorker()
+    val (master, masterRpcEnv) = createMaster(masterConfs)
+    val (worker1, workerRpcEnv1) = createWorker(workerConfs)
+    val (worker2, workerRpcEnv2) = createWorker(workerConfs)
+    val (worker3, workerRpcEnv3) = createWorker(workerConfs)
     val masterThread = runnerWrap(masterRpcEnv.awaitTermination())
     val workerThread1 = runnerWrap(worker1.initialize())
     val workerThread2 = runnerWrap(worker2.initialize())
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java
index 004a41c5..52e3975d 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java
@@ -200,6 +200,11 @@ public final class FileWriter implements DeviceObserver {
     final int numBytes = data.readableBytes();
     MemoryTracker.instance().incrementDiskBuffer(numBytes);
     synchronized (this) {
+      if (closed) {
+        String msg = "FileWriter has already closed!, fileName " + 
fileInfo.getFilePath();
+        logger.warn(msg);
+        throw new AlreadyClosedException(msg);
+      }
       if (rangeReadFilter) {
         mapIdBitMap.add(mapId);
       }
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/CommitInfo.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/CommitInfo.scala
new file mode 100644
index 00000000..4778d148
--- /dev/null
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/CommitInfo.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker
+
+import 
org.apache.celeborn.common.protocol.message.ControlMessages.CommitFilesResponse
+
+class CommitInfo(var response: CommitFilesResponse, var status: Int)
+
+object CommitInfo {
+  val COMMIT_NOTSTARTED: Int = 0
+  val COMMIT_INPROCESS: Int = 1
+  val COMMIT_FINISHED: Int = 2
+}
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index 18e48433..992779a8 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -49,6 +49,8 @@ private[deploy] class Controller(
   var workerSource: WorkerSource = _
   var storageManager: StorageManager = _
   var shuffleMapperAttempts: ConcurrentHashMap[String, Array[Int]] = _
+  // shuffleKey -> (CommitInfo)
+  var shuffleCommitInfos: ConcurrentHashMap[String, CommitInfo] = _
   var workerInfo: WorkerInfo = _
   var partitionLocationInfo: PartitionLocationInfo = _
   var timer: HashedWheelTimer = _
@@ -57,10 +59,13 @@ private[deploy] class Controller(
   val minPartitionSizeToEstimate = conf.minPartitionSizeToEstimate
   var shutdown: AtomicBoolean = _
 
+  val testCommitFileFailure = conf.testRetryCommitFiles
+
   def init(worker: Worker): Unit = {
     workerSource = worker.workerSource
     storageManager = worker.storageManager
     shuffleMapperAttempts = worker.shuffleMapperAttempts
+    shuffleCommitInfos = worker.shuffleCommitInfos
     workerInfo = worker.workerInfo
     partitionLocationInfo = worker.partitionLocationInfo
     timer = worker.timer
@@ -312,7 +317,8 @@ private[deploy] class Controller(
       slaveIds: jList[String],
       mapAttempts: Array[Int]): Unit = {
     // return null if shuffleKey does not exist
-    if (!partitionLocationInfo.containsShuffle(shuffleKey)) {
+    if (!partitionLocationInfo.containsShuffle(shuffleKey) && 
!shuffleCommitInfos.containsKey(
+        shuffleKey)) {
       logError(s"Shuffle $shuffleKey doesn't exist!")
       context.reply(
         CommitFilesResponse(
@@ -324,6 +330,45 @@ private[deploy] class Controller(
       return
     }
 
+    val shuffleCommitTimeout = conf.workerShuffleCommitTimeout
+
+    shuffleCommitInfos.putIfAbsent(shuffleKey, new CommitInfo(null, 
CommitInfo.COMMIT_NOTSTARTED))
+    val commitInfo = shuffleCommitInfos.get(shuffleKey)
+
+    def waitForCommitFinish(): Unit = {
+      val delta = 100
+      var times = 0
+      while (delta * times < shuffleCommitTimeout) {
+        commitInfo.synchronized {
+          if (commitInfo.status == CommitInfo.COMMIT_FINISHED) {
+            context.reply(commitInfo.response)
+            return
+          }
+        }
+        Thread.sleep(delta)
+        times += 1
+      }
+    }
+
+    commitInfo.synchronized {
+      if (commitInfo.status == CommitInfo.COMMIT_FINISHED) {
+        logInfo(s"${shuffleKey} CommitFinished, just return the response")
+        context.reply(commitInfo.response)
+        return
+      } else if (commitInfo.status == CommitInfo.COMMIT_INPROCESS) {
+        logInfo(s"${shuffleKey} CommitFiles inprogress, wait for finish")
+        commitThreadPool.submit(new Runnable {
+          override def run(): Unit = {
+            waitForCommitFinish()
+          }
+        })
+        return
+      } else {
+        logInfo(s"Start commitFiles for ${shuffleKey}")
+        commitInfo.status = CommitInfo.COMMIT_INPROCESS
+      }
+    }
+
     // close and flush files.
     shuffleMapperAttempts.putIfAbsent(shuffleKey, mapAttempts)
 
@@ -391,10 +436,10 @@ private[deploy] class Controller(
       val totalSize = partitionSizeList.asScala.sum
       val fileCount = partitionSizeList.size()
       // reply
-      if (failedMasterIds.isEmpty && failedSlaveIds.isEmpty) {
-        logInfo(s"CommitFiles for $shuffleKey success with 
${committedMasterIds.size()}" +
-          s" master partitions and ${committedSlaveIds.size()} slave 
partitions!")
-        context.reply(
+      val response =
+        if (failedMasterIds.isEmpty && failedSlaveIds.isEmpty) {
+          logInfo(s"CommitFiles for $shuffleKey success with 
${committedMasterIds.size()}" +
+            s" master partitions and ${committedSlaveIds.size()} slave 
partitions!")
           CommitFilesResponse(
             StatusCode.SUCCESS,
             committedMasterIdList,
@@ -405,11 +450,10 @@ private[deploy] class Controller(
             committedSlaveStorageAndDiskHintList,
             committedMapIdBitMapList,
             totalSize,
-            fileCount))
-      } else {
-        logWarning(s"CommitFiles for $shuffleKey failed with 
${failedMasterIds.size()} master" +
-          s" partitions and ${failedSlaveIds.size()} slave partitions!")
-        context.reply(
+            fileCount)
+        } else {
+          logWarning(s"CommitFiles for $shuffleKey failed with 
${failedMasterIds.size()} master" +
+            s" partitions and ${failedSlaveIds.size()} slave partitions!")
           CommitFilesResponse(
             StatusCode.PARTIAL_SUCCESS,
             committedMasterIdList,
@@ -420,13 +464,20 @@ private[deploy] class Controller(
             committedSlaveStorageAndDiskHintList,
             committedMapIdBitMapList,
             totalSize,
-            fileCount))
+            fileCount)
+        }
+      if (testCommitFileFailure) {
+        Thread.sleep(5000)
       }
+      commitInfo.synchronized {
+        commitInfo.response = response
+        commitInfo.status = CommitInfo.COMMIT_FINISHED
+      }
+      context.reply(response)
     }
 
     if (future != null) {
       val result = new AtomicReference[CompletableFuture[Unit]]()
-      val shuffleCommitTimeout = conf.workerShuffleCommitTimeout
 
       val timeout = timer.newTimeout(
         new TimerTask {
@@ -458,6 +509,16 @@ private[deploy] class Controller(
                 case throwable: Throwable =>
                   logError("While handling commitFiles, exception occurs.", 
throwable)
               }
+              commitInfo.synchronized {
+                commitInfo.response = CommitFilesResponse(
+                  StatusCode.COMMIT_FILE_EXCEPTION,
+                  List.empty.asJava,
+                  List.empty.asJava,
+                  masterIds,
+                  slaveIds)
+
+                commitInfo.status = CommitInfo.COMMIT_FINISHED
+              }
             } else {
               // finish, cancel timeout job first.
               timeout.cancel()
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index a0067a3a..8633795b 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -179,6 +179,8 @@ private[celeborn] class Worker(
   val shuffleMapperAttempts = new ConcurrentHashMap[String, Array[Int]]()
   val partitionLocationInfo = new PartitionLocationInfo
 
+  val shuffleCommitInfos = new ConcurrentHashMap[String, CommitInfo]()
+
   private val rssHARetryClient = new RssHARetryClient(rpcEnv, conf)
 
   // (workerInfo -> last connect timeout timestamp)
@@ -420,6 +422,7 @@ private[celeborn] class Worker(
       partitionLocationInfo.removeMasterPartitions(shuffleKey)
       partitionLocationInfo.removeSlavePartitions(shuffleKey)
       shuffleMapperAttempts.remove(shuffleKey)
+      shuffleCommitInfos.remove(shuffleKey)
       workerInfo.releaseSlots(shuffleKey)
       logInfo(s"Cleaned up expired shuffle $shuffleKey")
     }

Reply via email to