This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new f2751c280 [CELEBORN-1829] Replace waitThreadPoll's thread pool with 
ScheduledExecutorService in Controller
f2751c280 is described below

commit f2751c2802407d6e999cab6bfc50e24f163f0e4a
Author: zhengtao <[email protected]>
AuthorDate: Sat Jan 18 13:00:04 2025 +0800

    [CELEBORN-1829] Replace waitThreadPoll's thread pool with 
ScheduledExecutorService in Controller
    
    ### What changes were proposed in this pull request?
    1. Replace waitThreadPoll's thread pool with ScheduledExecutorService.
    2. commitFiles should reply when shuffleCommitTimeout is exceeded.
    
    ### Why are the changes needed?
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Cluster test & UT.
    
    Closes #3059 from zaynt4606/clb1829.
    
    Authored-by: zhengtao <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../org/apache/celeborn/common/CelebornConf.scala  |  14 +-
 docs/configuration/worker.md                       |   2 +-
 .../service/deploy/worker/Controller.scala         | 119 +++++++++++++----
 .../celeborn/service/deploy/worker/Worker.scala    |  13 +-
 .../deploy/worker/{storage => }/WorkerSuite.scala  | 147 ++++++++++++++++++++-
 5 files changed, 252 insertions(+), 43 deletions(-)

diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 80d31d747..791c6fc2f 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -828,7 +828,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable 
with Logging with Se
   def workerReplicateThreads: Int = get(WORKER_REPLICATE_THREADS)
   def workerCommitThreads: Int =
     if (hasHDFSStorage) Math.max(128, get(WORKER_COMMIT_THREADS)) else 
get(WORKER_COMMIT_THREADS)
-  def workerCommitFilesWaitThreads: Int = get(WORKER_COMMIT_FILES_WAIT_THREADS)
+  def workerCommitFilesCheckInterval: Long = 
get(WORKER_COMMIT_FILES_CHECK_INTERVAL)
   def workerCleanThreads: Int = get(WORKER_CLEAN_THREADS)
   def workerShuffleCommitTimeout: Long = get(WORKER_SHUFFLE_COMMIT_TIMEOUT)
   def maxPartitionSizeToEstimate: Long =
@@ -3483,13 +3483,13 @@ object CelebornConf extends Logging {
       .intConf
       .createWithDefault(32)
 
-  val WORKER_COMMIT_FILES_WAIT_THREADS: ConfigEntry[Int] =
-    buildConf("celeborn.worker.commitFiles.wait.threads")
+  val WORKER_COMMIT_FILES_CHECK_INTERVAL: ConfigEntry[Long] =
+    buildConf("celeborn.worker.commitFiles.check.interval")
       .categories("worker")
-      .version("0.5.0")
-      .doc("Thread number of worker to wait for commit shuffle data files to 
finish.")
-      .intConf
-      .createWithDefault(32)
+      .version("0.6.0")
+      .doc("Time length for a window about checking whether commit shuffle 
data files finished.")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("100")
 
   val WORKER_CLEAN_THREADS: ConfigEntry[Int] =
     buildConf("celeborn.worker.clean.threads")
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index 29fb8457a..14c7e791c 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -56,9 +56,9 @@ license: |
 | celeborn.worker.bufferStream.threadsPerMountpoint | 8 | false | Threads 
count for read buffer per mount point. | 0.3.0 |  | 
 | celeborn.worker.clean.threads | 64 | false | Thread number of worker to 
clean up expired shuffle keys. | 0.3.2 |  | 
 | celeborn.worker.closeIdleConnections | false | false | Whether worker will 
close idle connections. | 0.2.0 |  | 
+| celeborn.worker.commitFiles.check.interval | 100 | false | Time length for a 
window about checking whether commit shuffle data files finished. | 0.6.0 |  | 
 | celeborn.worker.commitFiles.threads | 32 | false | Thread number of worker 
to commit shuffle data files asynchronously. It's recommended to set at least 
`128` when `HDFS` is enabled in `celeborn.storage.availableTypes`. | 0.3.0 | 
celeborn.worker.commit.threads | 
 | celeborn.worker.commitFiles.timeout | 120s | false | Timeout for a Celeborn 
worker to commit files of a shuffle. It's recommended to set at least `240s` 
when `HDFS` is enabled in `celeborn.storage.availableTypes`. | 0.3.0 | 
celeborn.worker.shuffle.commit.timeout | 
-| celeborn.worker.commitFiles.wait.threads | 32 | false | Thread number of 
worker to wait for commit shuffle data files to finish. | 0.5.0 |  | 
 | celeborn.worker.congestionControl.check.interval | 10ms | false | Interval 
of worker checks congestion if celeborn.worker.congestionControl.enabled is 
true. | 0.3.2 |  | 
 | celeborn.worker.congestionControl.diskBuffer.high.watermark | 
9223372036854775807b | false | If the total bytes in disk buffer exceeds this 
configure, will start to congest users whose produce rate is higher than the 
potential average consume rate. The congestion will stop if the produce rate is 
lower or equal to the average consume rate, or the total pending bytes lower 
than celeborn.worker.congestionControl.diskBuffer.low.watermark | 0.3.0 | 
celeborn.worker.congestionControl.high.wat [...]
 | celeborn.worker.congestionControl.diskBuffer.low.watermark | 
9223372036854775807b | false | Will stop congest users if the total pending 
bytes of disk buffer is lower than this configuration | 0.3.0 | 
celeborn.worker.congestionControl.low.watermark | 
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index 446d28c8d..65285f647 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -52,18 +52,24 @@ private[deploy] class Controller(
   var shuffleMapperAttempts: ConcurrentHashMap[String, AtomicIntegerArray] = _
   // shuffleKey -> (epoch -> CommitInfo)
   var shuffleCommitInfos: ConcurrentHashMap[String, ConcurrentHashMap[Long, 
CommitInfo]] = _
+  // shuffleKey -> (epoch -> (commitWaitTimestamp, RpcContext))
+  var shuffleCommitTime
+      : ConcurrentHashMap[String, ConcurrentHashMap[Long, (Long, 
RpcCallContext)]] =
+    _
   var shufflePartitionType: ConcurrentHashMap[String, PartitionType] = _
   var shufflePushDataTimeout: ConcurrentHashMap[String, Long] = _
   var workerInfo: WorkerInfo = _
   var partitionLocationInfo: WorkerPartitionLocationInfo = _
   var timer: HashedWheelTimer = _
   var commitThreadPool: ThreadPoolExecutor = _
-  var waitThreadPool: ThreadPoolExecutor = _
+  var commitFinishedChecker: ScheduledExecutorService = _
   var asyncReplyPool: ScheduledExecutorService = _
   val minPartitionSizeToEstimate = conf.minPartitionSizeToEstimate
   var shutdown: AtomicBoolean = _
   val defaultPushdataTimeout = conf.pushDataTimeoutMs
   val mockCommitFilesFailure = conf.testMockCommitFilesFailure
+  val shuffleCommitTimeout = conf.workerShuffleCommitTimeout
+  val workerCommitFilesCheckInterval = conf.workerCommitFilesCheckInterval
 
   def init(worker: Worker): Unit = {
     storageManager = worker.storageManager
@@ -71,13 +77,24 @@ private[deploy] class Controller(
     shufflePushDataTimeout = worker.shufflePushDataTimeout
     shuffleMapperAttempts = worker.shuffleMapperAttempts
     shuffleCommitInfos = worker.shuffleCommitInfos
+    shuffleCommitTime = worker.shuffleCommitTime
     workerInfo = worker.workerInfo
     partitionLocationInfo = worker.partitionLocationInfo
     timer = worker.timer
     commitThreadPool = worker.commitThreadPool
-    waitThreadPool = worker.waitThreadPool
     asyncReplyPool = worker.asyncReplyPool
     shutdown = worker.shutdown
+
+    commitFinishedChecker = worker.commitFinishedChecker
+    commitFinishedChecker.scheduleWithFixedDelay(
+      new Runnable {
+        override def run(): Unit = {
+          checkCommitTimeout(shuffleCommitTime)
+        }
+      },
+      0,
+      workerCommitFilesCheckInterval,
+      TimeUnit.MILLISECONDS)
   }
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
@@ -313,7 +330,7 @@ private[deploy] class Controller(
                 }
 
                 val fileWriter = 
location.asInstanceOf[WorkingPartition].getFileWriter
-                waitMapPartitionRegionFinished(fileWriter, 
conf.workerShuffleCommitTimeout)
+                waitMapPartitionRegionFinished(fileWriter, 
shuffleCommitTimeout)
                 val bytes = fileWriter.close()
                 if (bytes > 0L) {
                   if (fileWriter.getStorageInfo == null) {
@@ -410,27 +427,20 @@ private[deploy] class Controller(
       return
     }
 
-    val shuffleCommitTimeout = conf.workerShuffleCommitTimeout
-
-    shuffleCommitInfos.putIfAbsent(shuffleKey, 
JavaUtils.newConcurrentHashMap[Long, CommitInfo]())
+    shuffleCommitInfos.putIfAbsent(
+      shuffleKey,
+      JavaUtils.newConcurrentHashMap[Long, CommitInfo]())
     val epochCommitMap = shuffleCommitInfos.get(shuffleKey)
-    epochCommitMap.putIfAbsent(epoch, new CommitInfo(null, 
CommitInfo.COMMIT_NOTSTARTED))
-    val commitInfo = epochCommitMap.get(epoch)
 
-    def waitForCommitFinish(): Unit = {
-      val delta = 100
-      var times = 0
-      while (delta * times < shuffleCommitTimeout) {
-        commitInfo.synchronized {
-          if (commitInfo.status == CommitInfo.COMMIT_FINISHED) {
-            context.reply(commitInfo.response)
-            return
-          }
-        }
-        Thread.sleep(delta)
-        times += 1
-      }
-    }
+    // to store the primaryIds and replicaIds
+    val response = CommitFilesResponse(
+      null,
+      List.empty.asJava,
+      List.empty.asJava,
+      primaryIds,
+      replicaIds)
+    epochCommitMap.putIfAbsent(epoch, new CommitInfo(response, 
CommitInfo.COMMIT_NOTSTARTED))
+    val commitInfo = epochCommitMap.get(epoch)
 
     commitInfo.synchronized {
       if (commitInfo.status == CommitInfo.COMMIT_FINISHED) {
@@ -439,12 +449,14 @@ private[deploy] class Controller(
         return
       } else if (commitInfo.status == CommitInfo.COMMIT_INPROCESS) {
         logInfo(s"$shuffleKey CommitFiles inprogress, wait for finish")
-        // should not use commitThreadPool in case of block by commit files.
-        waitThreadPool.submit(new Runnable {
-          override def run(): Unit = {
-            waitForCommitFinish()
-          }
-        })
+        // Replace the ThreadPool to avoid blocking
+        // Read and write safety of epoch in epochWaitTimeMap is guaranteed 
by commitInfo's lock
+        shuffleCommitTime.putIfAbsent(
+          shuffleKey,
+          JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]())
+        val epochWaitTimeMap = shuffleCommitTime.get(shuffleKey)
+        val commitStartWaitTime = System.currentTimeMillis()
+        epochWaitTimeMap.put(epoch, (commitStartWaitTime, context))
         return
       } else {
         logInfo(s"Start commitFiles for $shuffleKey")
@@ -730,6 +742,57 @@ private[deploy] class Controller(
     }
   }
 
+  def checkCommitTimeout(shuffleCommitTime: ConcurrentHashMap[
+    String,
+    ConcurrentHashMap[Long, (Long, RpcCallContext)]]): Unit = {
+
+    val currentTime = System.currentTimeMillis()
+    val commitTimeIterator = shuffleCommitTime.entrySet().iterator()
+    while (commitTimeIterator.hasNext) {
+      val timeMapEntry = commitTimeIterator.next()
+      val shuffleKey = timeMapEntry.getKey
+      val epochWaitTimeMap = timeMapEntry.getValue
+      val epochIterator = epochWaitTimeMap.entrySet().iterator()
+
+      while (epochIterator.hasNext && 
shuffleCommitInfos.containsKey(shuffleKey)) {
+        val epochWaitTimeEntry = epochIterator.next()
+        val epoch = epochWaitTimeEntry.getKey
+        val (commitStartWaitTime, context) = epochWaitTimeEntry.getValue
+        try {
+          val commitInfo = shuffleCommitInfos.get(shuffleKey).get(epoch)
+          commitInfo.synchronized {
+            if (commitInfo.status == CommitInfo.COMMIT_FINISHED) {
+              context.reply(commitInfo.response)
+              epochIterator.remove()
+            } else {
+              if (currentTime - commitStartWaitTime >= shuffleCommitTimeout) {
+                val replyResponse = CommitFilesResponse(
+                  StatusCode.COMMIT_FILE_EXCEPTION,
+                  List.empty.asJava,
+                  List.empty.asJava,
+                  commitInfo.response.failedPrimaryIds,
+                  commitInfo.response.failedReplicaIds)
+                commitInfo.status = CommitInfo.COMMIT_FINISHED
+                commitInfo.response = replyResponse
+                context.reply(replyResponse)
+                epochIterator.remove()
+              }
+            }
+          }
+        } catch {
+          case error: Exception =>
+            epochIterator.remove()
+            logWarning(
+              s"Exception occurs when checkCommitTimeout for 
shuffleKey-epoch:$shuffleKey-$epoch, error: $error")
+        }
+      }
+      if (!shuffleCommitInfos.containsKey(shuffleKey)) {
+        logWarning(s"Shuffle $shuffleKey commit expired when 
checkCommitTimeout.")
+        commitTimeIterator.remove()
+      }
+    }
+  }
+
   private def updateShuffleMapperAttempts(
       mapAttempts: Array[Int],
       shuffleMapperAttempts: AtomicIntegerArray): Unit = {
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index ac9dc6328..f9eb6f7f7 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -311,6 +311,10 @@ private[celeborn] class Worker(
   val shuffleCommitInfos: ConcurrentHashMap[String, ConcurrentHashMap[Long, 
CommitInfo]] =
     JavaUtils.newConcurrentHashMap[String, ConcurrentHashMap[Long, 
CommitInfo]]()
 
+  val shuffleCommitTime
+      : ConcurrentHashMap[String, ConcurrentHashMap[Long, (Long, 
RpcCallContext)]] =
+    JavaUtils.newConcurrentHashMap[String, ConcurrentHashMap[Long, (Long, 
RpcCallContext)]]()
+
   private val masterClient = new MasterClient(internalRpcEnvInUse, conf, true)
   secretRegistry.initialize(masterClient)
 
@@ -331,8 +335,8 @@ private[celeborn] class Worker(
     ThreadUtils.newDaemonCachedThreadPool("worker-data-replicator", 
conf.workerReplicateThreads)
   val commitThreadPool: ThreadPoolExecutor =
     ThreadUtils.newDaemonCachedThreadPool("worker-files-committer", 
conf.workerCommitThreads)
-  val waitThreadPool: ThreadPoolExecutor =
-    ThreadUtils.newDaemonCachedThreadPool("worker-commit-waiter", 
conf.workerCommitFilesWaitThreads)
+  val commitFinishedChecker: ScheduledExecutorService =
+    ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-commit-checker")
   val cleanThreadPool: ThreadPoolExecutor =
     ThreadUtils.newDaemonCachedThreadPool(
       "worker-expired-shuffle-cleaner",
@@ -593,13 +597,13 @@ private[celeborn] class Worker(
         forwardMessageScheduler.shutdown()
         replicateThreadPool.shutdown()
         commitThreadPool.shutdown()
-        waitThreadPool.shutdown();
+        commitFinishedChecker.shutdown();
         asyncReplyPool.shutdown()
       } else {
         forwardMessageScheduler.shutdownNow()
         replicateThreadPool.shutdownNow()
         commitThreadPool.shutdownNow()
-        waitThreadPool.shutdownNow();
+        commitFinishedChecker.shutdownNow();
         asyncReplyPool.shutdownNow()
       }
       workerSource.appActiveConnections.clear()
@@ -757,6 +761,7 @@ private[celeborn] class Worker(
         shufflePushDataTimeout.remove(shuffleKey)
         shuffleMapperAttempts.remove(shuffleKey)
         shuffleCommitInfos.remove(shuffleKey)
+        shuffleCommitTime.remove(shuffleKey)
         workerInfo.releaseSlots(shuffleKey)
         val applicationId = Utils.splitShuffleKey(shuffleKey)._1
         if (!workerInfo.getApplicationIdSet.contains(applicationId)) {
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/WorkerSuite.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala
similarity index 50%
rename from 
worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/WorkerSuite.scala
rename to 
worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala
index 6f0521bf5..c28aa9f58 100644
--- 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/WorkerSuite.scala
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala
@@ -15,26 +15,31 @@
  * limitations under the License.
  */
 
-package org.apache.celeborn.service.deploy.worker.storage
+package org.apache.celeborn.service.deploy.worker
 
 import java.io.File
 import java.nio.file.{Files, Paths}
 import java.util
 import java.util.{HashSet => JHashSet}
+import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import org.junit.Assert
 import org.mockito.MockitoSugar._
-import org.scalatest.BeforeAndAfterEach
+import org.scalatest.{shortstacks, BeforeAndAfterEach}
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.protocol.{PartitionLocation, 
PartitionSplitMode, PartitionType}
+import 
org.apache.celeborn.common.protocol.message.ControlMessages.CommitFilesResponse
+import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.quota.ResourceConsumption
+import org.apache.celeborn.common.rpc.RpcCallContext
 import org.apache.celeborn.common.util.{CelebornExitKind, JavaUtils, 
ThreadUtils}
-import org.apache.celeborn.service.deploy.worker.{Worker, WorkerArguments}
+import org.apache.celeborn.service.deploy.worker.storage.PartitionDataWriter
 
 class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach {
   private var worker: Worker = _
@@ -161,4 +166,140 @@ class WorkerSuite extends AnyFunSuite with 
BeforeAndAfterEach {
       Map("app1" -> ResourceConsumption(1024, 1, 0, 0)).asJava)).asJava)
     assert(worker.resourceConsumptionSource.gauges().size == 4)
   }
+
+  test("test checkCommitTimeout in controller") {
+    conf.set(CelebornConf.WORKER_STORAGE_DIRS.key, "/tmp")
+    conf.set(CelebornConf.WORKER_SHUFFLE_COMMIT_TIMEOUT.key, "1000")
+    worker = new Worker(conf, workerArgs)
+    val controller = worker.controller
+    controller.init(worker)
+    val shuffleCommitInfos = controller.shuffleCommitInfos
+    val shuffleCommitTime = controller.shuffleCommitTime
+    ThreadUtils.shutdown(worker.controller.commitFinishedChecker)
+    shuffleCommitInfos.clear()
+    shuffleCommitTime.clear()
+
+    val shuffleKey = "1"
+    val context = mock[RpcCallContext]
+    shuffleCommitInfos.putIfAbsent(
+      shuffleKey,
+      JavaUtils.newConcurrentHashMap[Long, CommitInfo]())
+    val epochCommitMap = shuffleCommitInfos.get(shuffleKey)
+    val primaryIds = List("0", "1", "2", "3")
+    val replicaIds = List("4", "5", "6", "7")
+    val epoch0: Long = 0
+    val epoch1: Long = 1
+    val epoch2: Long = 2
+    val epoch3: Long = 3
+    val startWaitTime = System.currentTimeMillis()
+
+    // update an INPROCESS commitInfo
+    val response0 = CommitFilesResponse(
+      null,
+      List.empty.asJava,
+      List.empty.asJava,
+      primaryIds.asJava,
+      replicaIds.asJava)
+    epochCommitMap.putIfAbsent(epoch0, new CommitInfo(response0, 
CommitInfo.COMMIT_INPROCESS))
+
+    val commitInfo0 = epochCommitMap.get(epoch0)
+    commitInfo0.synchronized {
+      shuffleCommitTime.putIfAbsent(
+        shuffleKey,
+        JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]())
+      val epochWaitTimeMap = shuffleCommitTime.get(shuffleKey)
+      epochWaitTimeMap.put(epoch0, (startWaitTime, context))
+    }
+
+    assert(shuffleCommitTime.get(shuffleKey).get(epoch0)._1 == startWaitTime)
+    assert(epochCommitMap.get(epoch0).status == CommitInfo.COMMIT_INPROCESS)
+
+    // update an INPROCESS commitInfo
+    val response1 = CommitFilesResponse(
+      null,
+      List.empty.asJava,
+      List.empty.asJava,
+      primaryIds.asJava,
+      replicaIds.asJava)
+    epochCommitMap.putIfAbsent(epoch1, new CommitInfo(response1, 
CommitInfo.COMMIT_INPROCESS))
+
+    val commitInfo1 = epochCommitMap.get(epoch1)
+    commitInfo1.synchronized {
+      shuffleCommitTime.putIfAbsent(
+        shuffleKey,
+        JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]())
+      val epochWaitTimeMap = shuffleCommitTime.get(shuffleKey)
+      epochWaitTimeMap.put(epoch1, (startWaitTime, context))
+    }
+
+    assert(shuffleCommitTime.get(shuffleKey).get(epoch1)._1 == startWaitTime)
+    assert(epochCommitMap.get(epoch1).status == CommitInfo.COMMIT_INPROCESS)
+
+    // update a FINISHED commitInfo
+    val response2 = CommitFilesResponse(
+      StatusCode.SUCCESS,
+      primaryIds.asJava,
+      replicaIds.asJava,
+      List.empty.asJava,
+      List.empty.asJava)
+    epochCommitMap.put(epoch2, new CommitInfo(response2, 
CommitInfo.COMMIT_FINISHED))
+
+    val commitInfo2 = epochCommitMap.get(epoch2)
+    commitInfo2.synchronized {
+      shuffleCommitTime.putIfAbsent(
+        shuffleKey,
+        JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]())
+      val epochWaitTimeMap = shuffleCommitTime.get(shuffleKey)
+      // epoch2 has already timed out
+      epochWaitTimeMap.put(epoch2, (startWaitTime, context))
+    }
+
+    assert(shuffleCommitTime.get(shuffleKey).get(epoch2)._1 == startWaitTime)
+    assert(epochCommitMap.get(epoch2).status == CommitInfo.COMMIT_FINISHED)
+
+    // add a new shuffleKey2 to shuffleCommitTime but not to shuffleCommitInfos
+    val shuffleKey2 = "2"
+    shuffleCommitTime.putIfAbsent(
+      shuffleKey2,
+      JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]())
+    shuffleCommitTime.get(shuffleKey2).put(epoch0, (startWaitTime, context))
+    assert(shuffleCommitTime.containsKey(shuffleKey2))
+    assert(!shuffleCommitInfos.containsKey(shuffleKey2))
+
+    // add an epoch to shuffleCommitTime but not to shuffleCommitInfos
+    shuffleCommitTime.get(shuffleKey).put(epoch3, (startWaitTime, context))
+    assert(shuffleCommitTime.get(shuffleKey).get(epoch3)._1 == startWaitTime)
+    assert(!shuffleCommitInfos.get(shuffleKey).containsKey(epoch3))
+
+    // update status of epoch1 to FINISHED
+    epochCommitMap.get(epoch1).status = CommitInfo.COMMIT_FINISHED
+    assert(epochCommitMap.get(epoch1).status == CommitInfo.COMMIT_FINISHED)
+
+    // first timeout check
+    controller.checkCommitTimeout(shuffleCommitTime)
+    assert(epochCommitMap.get(epoch0).status == CommitInfo.COMMIT_INPROCESS)
+
+    // shuffleCommitTime will be removed when shuffleCommitInfos contains no 
shuffleKey
+    assert(!shuffleCommitTime.containsKey(shuffleKey2))
+    assert(!shuffleCommitInfos.containsKey(shuffleKey2))
+
+    // epoch will be removed when shuffleCommitInfos contains no epoch
+    assert(!shuffleCommitTime.get(shuffleKey).containsKey(epoch3))
+
+    // FINISHED status of epoch1 will be removed from shuffleCommitTime
+    assert(shuffleCommitTime.get(shuffleKey).get(epoch1) == null)
+
+    // timeout after 1000 ms
+    Thread.sleep(2000)
+    controller.checkCommitTimeout(shuffleCommitTime)
+
+    // remove epoch0 from shuffleCommitTime on timeout
+    assert(shuffleCommitTime.get(shuffleKey).get(epoch0) == null)
+    assert(epochCommitMap.get(epoch0).status == CommitInfo.COMMIT_FINISHED)
+    assert(epochCommitMap.get(epoch0).response.status == 
StatusCode.COMMIT_FILE_EXCEPTION)
+
+    // timeout but SUCCESS epoch2 can reply
+    assert(shuffleCommitTime.get(shuffleKey).get(epoch2) == null)
+    assert(epochCommitMap.get(epoch2).response.status == StatusCode.SUCCESS)
+  }
 }

Reply via email to