This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new d381df71 [CELEBORN-70] Add epoch for each commitFiles request (#1012)
d381df71 is described below

commit d381df71f86ed6cbc9efd690fc483c9f3075f19e
Author: Keyong Zhou <[email protected]>
AuthorDate: Sun Nov 27 21:05:14 2022 +0800

    [CELEBORN-70] Add epoch for each commitFiles request (#1012)
---
 .../apache/celeborn/client/LifecycleManager.scala  | 78 ++++++++++++++++------
 common/src/main/proto/TransportMessages.proto      |  1 +
 .../common/protocol/message/ControlMessages.scala  |  9 ++-
 .../service/deploy/worker/Controller.scala         | 39 +++++++----
 .../service/deploy/worker/PushDataHandler.scala    | 29 ++++----
 .../celeborn/service/deploy/worker/Worker.scala    |  6 +-
 6 files changed, 112 insertions(+), 50 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 139c4503..bd2887d7 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
 import java.util
 import java.util.{function, List => JList}
 import java.util.concurrent.{Callable, ConcurrentHashMap, 
ScheduledExecutorService, ScheduledFuture, TimeUnit}
-import java.util.concurrent.atomic.LongAdder
+import java.util.concurrent.atomic.{AtomicLong, LongAdder}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -85,7 +85,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) 
extends RpcEndpoin
     .maximumSize(rpcCacheSize)
     .build().asInstanceOf[Cache[Int, ByteBuffer]]
 
-  private val testCommitFileFailure = conf.testRetryCommitFiles
+  private val testRetryCommitFiles = conf.testRetryCommitFiles
+  private val commitEpoch = new AtomicLong()
 
   @VisibleForTesting
   def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, 
PartitionLocationInfo] =
@@ -988,22 +989,61 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
         val masterIds = masterParts.asScala.map(_.getUniqueId).asJava
         val slaveIds = slaveParts.asScala.map(_.getUniqueId).asJava
 
-        val commitFiles = CommitFiles(
-          applicationId,
-          shuffleId,
-          masterIds,
-          slaveIds,
-          shuffleMapperAttempts.get(shuffleId))
-        val res = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
-
-        res.status match {
-          case StatusCode.SUCCESS => // do nothing
-          case StatusCode.PARTIAL_SUCCESS | StatusCode.SHUFFLE_NOT_REGISTERED 
| StatusCode.FAILED =>
-            logDebug(s"Request $commitFiles return ${res.status} for " +
-              s"${Utils.makeShuffleKey(applicationId, shuffleId)}")
-            commitFilesFailedWorkers.put(worker, (res.status, 
System.currentTimeMillis()))
-          case _ => // won't happen
-        }
+        val res =
+          if (!testRetryCommitFiles) {
+            val commitFiles = CommitFiles(
+              applicationId,
+              shuffleId,
+              masterIds,
+              slaveIds,
+              shuffleMapperAttempts.get(shuffleId),
+              commitEpoch.incrementAndGet())
+            val res = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
+
+            res.status match {
+              case StatusCode.SUCCESS => // do nothing
+              case StatusCode.PARTIAL_SUCCESS | 
StatusCode.SHUFFLE_NOT_REGISTERED | StatusCode.FAILED =>
+                logDebug(s"Request $commitFiles return ${res.status} for " +
+                  s"${Utils.makeShuffleKey(applicationId, shuffleId)}")
+                commitFilesFailedWorkers.put(worker, (res.status, 
System.currentTimeMillis()))
+              case _ => // won't happen
+            }
+            res
+          } else {
+            // for test
+            val commitFiles1 = CommitFiles(
+              applicationId,
+              shuffleId,
+              masterIds.subList(0, masterIds.size() / 2),
+              slaveIds.subList(0, slaveIds.size() / 2),
+              shuffleMapperAttempts.get(shuffleId),
+              commitEpoch.incrementAndGet())
+            val res1 = requestCommitFilesWithRetry(worker.endpoint, 
commitFiles1)
+
+            val commitFiles = CommitFiles(
+              applicationId,
+              shuffleId,
+              masterIds.subList(masterIds.size() / 2, masterIds.size()),
+              slaveIds.subList(slaveIds.size() / 2, slaveIds.size()),
+              shuffleMapperAttempts.get(shuffleId),
+              commitEpoch.incrementAndGet())
+            val res2 = requestCommitFilesWithRetry(worker.endpoint, 
commitFiles)
+
+            
res1.committedMasterStorageInfos.putAll(res2.committedMasterStorageInfos)
+            
res1.committedSlaveStorageInfos.putAll(res2.committedSlaveStorageInfos)
+            res1.committedMapIdBitMap.putAll(res2.committedMapIdBitMap)
+            CommitFilesResponse(
+              status = if (res1.status == StatusCode.SUCCESS) res2.status else 
res1.status,
+              (res1.committedMasterIds.asScala ++ 
res2.committedMasterIds.asScala).toList.asJava,
+              (res1.committedSlaveIds.asScala ++ 
res2.committedSlaveIds.asScala).toList.asJava,
+              (res1.failedMasterIds.asScala ++ 
res2.failedMasterIds.asScala).toList.asJava,
+              (res1.failedSlaveIds.asScala ++ 
res2.failedSlaveIds.asScala).toList.asJava,
+              res1.committedMasterStorageInfos,
+              res1.committedSlaveStorageInfos,
+              res1.committedMapIdBitMap,
+              res1.totalWritten + res2.totalWritten,
+              res1.fileCount + res2.fileCount)
+          }
 
         // record committed partitionIds
         committedMasterIds.addAll(res.committedMasterIds)
@@ -1651,7 +1691,7 @@ class LifecycleManager(appId: String, val conf: 
CelebornConf) extends RpcEndpoin
     var retryTimes = 0
     while (retryTimes < maxRetries) {
       try {
-        if (testCommitFileFailure && retryTimes < maxRetries - 1) {
+        if (testRetryCommitFiles && retryTimes < maxRetries - 1) {
           endpoint.ask[CommitFilesResponse](message)
           Thread.sleep(1000)
           throw new Exception("Mock fail for CommitFiles")
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index a8c4f766..81762fd9 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -330,6 +330,7 @@ message PbCommitFiles {
   repeated string masterIds = 3;
   repeated string slaveIds = 4;
   repeated int32 mapAttempts = 5;
+  int64 epoch = 6;
 }
 
 message PbCommitFilesResponse {
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 7eff974b..9697aabb 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -368,7 +368,8 @@ object ControlMessages extends Logging {
       shuffleId: Int,
       masterIds: util.List[String],
       slaveIds: util.List[String],
-      mapAttempts: Array[Int])
+      mapAttempts: Array[Int],
+      epoch: Long)
     extends WorkerMessage
 
   case class CommitFilesResponse(
@@ -686,13 +687,14 @@ object ControlMessages extends Logging {
         .build().toByteArray
       new TransportMessage(MessageType.RESERVE_SLOTS_RESPONSE, payload)
 
-    case CommitFiles(applicationId, shuffleId, masterIds, slaveIds, 
mapAttempts) =>
+    case CommitFiles(applicationId, shuffleId, masterIds, slaveIds, 
mapAttempts, epoch) =>
       val payload = PbCommitFiles.newBuilder()
         .setApplicationId(applicationId)
         .setShuffleId(shuffleId)
         .addAllMasterIds(masterIds)
         .addAllSlaveIds(slaveIds)
         .addAllMapAttempts(mapAttempts.map(new Integer(_)).toIterable.asJava)
+        .setEpoch(epoch)
         .build().toByteArray
       new TransportMessage(MessageType.COMMIT_FILES, payload)
 
@@ -985,7 +987,8 @@ object ControlMessages extends Logging {
           pbCommitFiles.getShuffleId,
           pbCommitFiles.getMasterIdsList,
           pbCommitFiles.getSlaveIdsList,
-          pbCommitFiles.getMapAttemptsList.asScala.map(_.toInt).toArray)
+          pbCommitFiles.getMapAttemptsList.asScala.map(_.toInt).toArray,
+          pbCommitFiles.getEpoch)
 
       case COMMIT_FILES_RESPONSE =>
         val pbCommitFilesResponse = 
PbCommitFilesResponse.parseFrom(message.getPayload)
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index 992779a8..3fb27b48 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -20,7 +20,7 @@ package org.apache.celeborn.service.deploy.worker
 import java.io.IOException
 import java.util.{ArrayList => jArrayList, HashMap => jHashMap, List => jList, 
Set => jSet}
 import java.util.concurrent._
-import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray, 
AtomicReference}
 import java.util.function.BiFunction
 
 import scala.collection.JavaConverters._
@@ -48,9 +48,9 @@ private[deploy] class Controller(
 
   var workerSource: WorkerSource = _
   var storageManager: StorageManager = _
-  var shuffleMapperAttempts: ConcurrentHashMap[String, Array[Int]] = _
-  // shuffleKey -> (CommitInfo)
-  var shuffleCommitInfos: ConcurrentHashMap[String, CommitInfo] = _
+  var shuffleMapperAttempts: ConcurrentHashMap[String, AtomicIntegerArray] = _
+  // shuffleKey -> (epoch -> CommitInfo)
+  var shuffleCommitInfos: ConcurrentHashMap[String, ConcurrentHashMap[Long, 
CommitInfo]] = _
   var workerInfo: WorkerInfo = _
   var partitionLocationInfo: PartitionLocationInfo = _
   var timer: HashedWheelTimer = _
@@ -59,7 +59,7 @@ private[deploy] class Controller(
   val minPartitionSizeToEstimate = conf.minPartitionSizeToEstimate
   var shutdown: AtomicBoolean = _
 
-  val testCommitFileFailure = conf.testRetryCommitFiles
+  val testRetryCommitFiles = conf.testRetryCommitFiles
 
   def init(worker: Worker): Unit = {
     workerSource = worker.workerSource
@@ -104,13 +104,13 @@ private[deploy] class Controller(
         logDebug(s"ReserveSlots for $shuffleKey finished.")
       }
 
-    case CommitFiles(applicationId, shuffleId, masterIds, slaveIds, 
mapAttempts) =>
+    case CommitFiles(applicationId, shuffleId, masterIds, slaveIds, 
mapAttempts, epoch) =>
       val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
       workerSource.sample(WorkerSource.CommitFilesTime, shuffleKey) {
         logDebug(s"Received CommitFiles request, $shuffleKey, master files" +
           s" ${masterIds.asScala.mkString(",")}; slave files 
${slaveIds.asScala.mkString(",")}.")
         val commitFilesTimeMs = Utils.timeIt({
-          handleCommitFiles(context, shuffleKey, masterIds, slaveIds, 
mapAttempts)
+          handleCommitFiles(context, shuffleKey, masterIds, slaveIds, 
mapAttempts, epoch)
         })
         logDebug(s"Done processed CommitFiles request with shuffleKey 
$shuffleKey, in " +
           s"$commitFilesTimeMs ms.")
@@ -315,10 +315,11 @@ private[deploy] class Controller(
       shuffleKey: String,
       masterIds: jList[String],
       slaveIds: jList[String],
-      mapAttempts: Array[Int]): Unit = {
+      mapAttempts: Array[Int],
+      epoch: Long): Unit = {
     // return null if shuffleKey does not exist
     if (!partitionLocationInfo.containsShuffle(shuffleKey) && 
!shuffleCommitInfos.containsKey(
-        shuffleKey)) {
+        shuffleKey) && !shuffleCommitInfos.get(shuffleKey).containsKey(epoch)) 
{
       logError(s"Shuffle $shuffleKey doesn't exist!")
       context.reply(
         CommitFilesResponse(
@@ -332,8 +333,10 @@ private[deploy] class Controller(
 
     val shuffleCommitTimeout = conf.workerShuffleCommitTimeout
 
-    shuffleCommitInfos.putIfAbsent(shuffleKey, new CommitInfo(null, 
CommitInfo.COMMIT_NOTSTARTED))
-    val commitInfo = shuffleCommitInfos.get(shuffleKey)
+    shuffleCommitInfos.putIfAbsent(shuffleKey, new ConcurrentHashMap[Long, 
CommitInfo]())
+    val epochCommitMap = shuffleCommitInfos.get(shuffleKey)
+    epochCommitMap.putIfAbsent(epoch, new CommitInfo(null, 
CommitInfo.COMMIT_NOTSTARTED))
+    val commitInfo = epochCommitMap.get(epoch)
 
     def waitForCommitFinish(): Unit = {
       val delta = 100
@@ -370,7 +373,17 @@ private[deploy] class Controller(
     }
 
     // close and flush files.
-    shuffleMapperAttempts.putIfAbsent(shuffleKey, mapAttempts)
+    shuffleMapperAttempts.putIfAbsent(shuffleKey, new 
AtomicIntegerArray(mapAttempts))
+    val attempts = shuffleMapperAttempts.get(shuffleKey)
+    if (mapAttempts(0) != -1) {
+      attempts.synchronized {
+        if (attempts.get(0) == -1) {
+          0 until attempts.length() foreach (idx => {
+            attempts.set(idx, mapAttempts(idx))
+          })
+        }
+      }
+    }
 
     // Use ConcurrentSet to avoid excessive lock contention.
     val committedMasterIds = ConcurrentHashMap.newKeySet[String]()
@@ -466,7 +479,7 @@ private[deploy] class Controller(
             totalSize,
             fileCount)
         }
-      if (testCommitFileFailure) {
+      if (testRetryCommitFiles) {
         Thread.sleep(5000)
       }
       commitInfo.synchronized {
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 9d7f2316..808abebd 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -19,12 +19,11 @@ package org.apache.celeborn.service.deploy.worker
 
 import java.nio.ByteBuffer
 import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
-import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray}
 
 import com.google.common.base.Throwables
 import io.netty.buffer.ByteBuf
 
-import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.AlreadyClosedException
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
@@ -44,7 +43,7 @@ class PushDataHandler extends BaseMessageHandler with Logging 
{
   var workerSource: WorkerSource = _
   var rpcSource: RPCSource = _
   var partitionLocationInfo: PartitionLocationInfo = _
-  var shuffleMapperAttempts: ConcurrentHashMap[String, Array[Int]] = _
+  var shuffleMapperAttempts: ConcurrentHashMap[String, AtomicIntegerArray] = _
   var replicateThreadPool: ThreadPoolExecutor = _
   var unavailablePeers: ConcurrentHashMap[WorkerInfo, Long] = _
   var pushClientFactory: TransportClientFactory = _
@@ -155,12 +154,18 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
 
     if (location == null) {
       val (mapId, attemptId) = getMapAttempt(body)
-      if (shuffleMapperAttempts.containsKey(shuffleKey) &&
-        -1 != shuffleMapperAttempts.get(shuffleKey)(mapId)) {
-        // partition data has already been committed
-        logInfo(s"Receive push data from speculative task(shuffle $shuffleKey, 
map $mapId, " +
-          s" attempt $attemptId), but this mapper has already been ended.")
-        
wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.STAGE_ENDED.getValue)))
+      if (shuffleMapperAttempts.containsKey(shuffleKey)) {
+        if (-1 != shuffleMapperAttempts.get(shuffleKey).get(mapId)) {
+          // partition data has already been committed
+          logInfo(s"Receive push data from speculative task(shuffle 
$shuffleKey, map $mapId, " +
+            s" attempt $attemptId), but this mapper has already been ended.")
+          
wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.STAGE_ENDED.getValue)))
+        } else {
+          logInfo(
+            s"Receive push data for committed hard split partition of (shuffle 
$shuffleKey, " +
+              s"map $mapId attempt $attemptId)")
+          
wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+        }
       } else {
         val msg = s"Partition location wasn't found for task(shuffle 
$shuffleKey, map $mapId, " +
           s"attempt $attemptId, uniqueId ${pushData.partitionUniqueId})."
@@ -252,7 +257,7 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
         val (mapId, attemptId) = getMapAttempt(body)
         val endedAttempt =
           if (shuffleMapperAttempts.containsKey(shuffleKey)) {
-            shuffleMapperAttempts.get(shuffleKey)(mapId)
+            shuffleMapperAttempts.get(shuffleKey).get(mapId)
           } else -1
         // TODO just info log for ended attempt
         logWarning(s"Append data failed for task(shuffle $shuffleKey, map 
$mapId, attempt" +
@@ -313,7 +318,7 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
       if (loc == null) {
         val (mapId, attemptId) = getMapAttempt(body)
         if (shuffleMapperAttempts.containsKey(shuffleKey)
-          && -1 != shuffleMapperAttempts.get(shuffleKey)(mapId)) {
+          && -1 != shuffleMapperAttempts.get(shuffleKey).get(mapId)) {
           val msg = s"Receive push data from speculative task(shuffle 
$shuffleKey, map $mapId," +
             s" attempt $attemptId), but this mapper has already been ended."
           logInfo(msg)
@@ -422,7 +427,7 @@ class PushDataHandler extends BaseMessageHandler with 
Logging {
           val (mapId, attemptId) = getMapAttempt(body)
           val endedAttempt =
             if (shuffleMapperAttempts.containsKey(shuffleKey)) {
-              shuffleMapperAttempts.get(shuffleKey)(mapId)
+              shuffleMapperAttempts.get(shuffleKey).get(mapId)
             } else -1
           // TODO just info log for ended attempt
           logWarning(s"Append data failed for task(shuffle $shuffleKey, map 
$mapId, attempt" +
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index 8633795b..501f6fb7 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -21,7 +21,7 @@ import java.io.IOException
 import java.lang.{Long => JLong}
 import java.util.{HashMap => JHashMap, HashSet => JHashSet}
 import java.util.concurrent._
-import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -176,10 +176,10 @@ private[celeborn] class Worker(
   // whether this Worker registered to Master successfully
   val registered = new AtomicBoolean(false)
 
-  val shuffleMapperAttempts = new ConcurrentHashMap[String, Array[Int]]()
+  val shuffleMapperAttempts = new ConcurrentHashMap[String, 
AtomicIntegerArray]()
   val partitionLocationInfo = new PartitionLocationInfo
 
-  val shuffleCommitInfos = new ConcurrentHashMap[String, CommitInfo]()
+  val shuffleCommitInfos = new ConcurrentHashMap[String, 
ConcurrentHashMap[Long, CommitInfo]]()
 
   private val rssHARetryClient = new RssHARetryClient(rpcEnv, conf)
 

Reply via email to