This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new f0d4319c1 [CELEBORN-1106] Ensure data is written into flush buffer before sending message to client
f0d4319c1 is described below

commit f0d4319c1851c5cf5cbca2b6dc2a71a349a5137d
Author: Aravind Patnam <[email protected]>
AuthorDate: Mon Nov 13 21:17:29 2023 +0800

    [CELEBORN-1106] Ensure data is written into flush buffer before sending message to client
    
    ### What changes were proposed in this pull request?
    This change ensures that the data is at least written into the flush buffer before a response is sent back to the client. Previously, the response could be sent before the write had happened.
    
    ### Why are the changes needed?
    Currently the primary sends a response back to the client before the data is even written into the flush buffer for local persistence; the persist itself happens asynchronously. Gating the response on the flush-buffer write also prevents data corruption issues where the data ends up present only on the replica and not on the primary, while the client fetches only from the primary.
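    
    As a rough illustrative sketch (not the actual PushDataHandler code, which is in the diff below; `write`, `ack` and `fail` are hypothetical stand-ins for the FileWriter write and the RPC callback), the fix gates the client response on a Promise that is completed only after the local write:
    
    ```scala
    import scala.concurrent.{Await, Promise}
    import scala.concurrent.duration.Duration
    import scala.util.{Failure, Success, Try}
    
    object WriteThenAckSketch {
      def writeThenAck(write: () => Unit, ack: () => Unit, fail: Throwable => Unit): Unit = {
        val writePromise = Promise[Unit]()
        // Complete the promise according to the outcome of the local write.
        Try(write()) match {
          case Success(_) => writePromise.success(())
          case Failure(e) => writePromise.failure(e)
        }
        // Only respond to the client once the write outcome is known,
        // instead of acknowledging up front while the write is still pending.
        Try(Await.result(writePromise.future, Duration.Inf)) match {
          case Success(_) => ack()
          case Failure(e) => fail(e)
        }
      }
    }
    ```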
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Verified via the CI run, and also tested on our internal cluster
    
    Closes #2064 from akpatnam25/CELEBORN-1106.
    
    Lead-authored-by: Aravind Patnam <[email protected]>
    Co-authored-by: Aravind Patnam <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../service/deploy/worker/PushDataHandler.scala    | 350 +++++++++++----------
 1 file changed, 185 insertions(+), 165 deletions(-)

diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index d7b111eb3..208ff8bca 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -22,6 +22,9 @@ import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray}
 
 import scala.collection.JavaConverters._
+import scala.concurrent.{Await, Promise}
+import scala.concurrent.duration.Duration
+import scala.util.{Failure, Success, Try}
 
 import com.google.common.base.Throwables
 import com.google.protobuf.GeneratedMessageV3
@@ -254,7 +257,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
       fileWriter.decrementPendingWrites()
       return;
     }
-
+    val writePromise = Promise[Unit]()
     // for primary, send data to replica
     if (doReplicate) {
       pushData.body().retain()
@@ -280,34 +283,38 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
           // Handle the response from replica
           val wrappedCallback = new RpcResponseCallback() {
             override def onSuccess(response: ByteBuffer): Unit = {
-              if (response.remaining() > 0) {
-                val resp = ByteBuffer.allocate(response.remaining())
-                resp.put(response)
-                resp.flip()
-                callbackWithTimer.onSuccess(resp)
-              } else if (softSplit.get()) {
-                // TODO Currently if the worker is in soft split status, given the guess that the client
-                // will fast stop pushing data to the worker, we won't return congest status. But
-                // in the long term, especially if this issue could frequently happen, we may need to return
-                // congest&softSplit status together
-                callbackWithTimer.onSuccess(
-                  ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
-              } else {
-                Option(CongestionController.instance()) match {
-                  case Some(congestionController) =>
-                    if (congestionController.isUserCongested(
-                        fileWriter.getFileInfo.getUserIdentifier)) {
-                      // Check whether primary congest the data though the replicas doesn't congest
-                      // it(the response is empty)
-                      callbackWithTimer.onSuccess(
-                        ByteBuffer.wrap(
-                          Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
-                    } else {
-                      callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+              Try(Await.result(writePromise.future, Duration.Inf)) match {
+                case Success(_) =>
+                  if (response.remaining() > 0) {
+                    val resp = ByteBuffer.allocate(response.remaining())
+                    resp.put(response)
+                    resp.flip()
+                    callbackWithTimer.onSuccess(resp)
+                  } else if (softSplit.get()) {
+                    // TODO Currently if the worker is in soft split status, given the guess that the client
+                    // will fast stop pushing data to the worker, we won't return congest status. But
+                    // in the long term, especially if this issue could frequently happen, we may need to return
+                    // congest&softSplit status together
+                    callbackWithTimer.onSuccess(
+                      ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
+                  } else {
+                    Option(CongestionController.instance()) match {
+                      case Some(congestionController) =>
+                        if (congestionController.isUserCongested(
+                            fileWriter.getFileInfo.getUserIdentifier)) {
+                          // Check whether primary congest the data though the replicas doesn't congest
+                          // it(the response is empty)
+                          callbackWithTimer.onSuccess(
+                            ByteBuffer.wrap(
+                              Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+                        } else {
+                          callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+                        }
+                      case None =>
+                        callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
                     }
-                  case None =>
-                    callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
-                }
+                  }
+                case Failure(e) => callbackWithTimer.onFailure(e)
               }
             }
 
@@ -350,6 +357,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
           }
         }
       })
+      writeLocalData(Seq(fileWriter), body, shuffleKey, isPrimary, None, writePromise)
     } else {
       // The codes here could be executed if
       // 1. the client doesn't enable push data to the replica, the primary worker could hit here
@@ -358,47 +366,36 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
       // will fast stop pushing data to the worker, we won't return congest status. But
       // in the long term, especially if this issue could frequently happen, we may need to return
       // congest&softSplit status together
-      if (softSplit.get()) {
-        callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
-      } else {
-        Option(CongestionController.instance()) match {
-          case Some(congestionController) =>
-            if (congestionController.isUserCongested(fileWriter.getFileInfo.getUserIdentifier)) {
-              if (isPrimary) {
-                callbackWithTimer.onSuccess(
-                  ByteBuffer.wrap(
-                    Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
-              } else {
-                callbackWithTimer.onSuccess(
-                  ByteBuffer.wrap(
-                    Array[Byte](StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue)))
-              }
-            } else {
-              callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+      writeLocalData(Seq(fileWriter), body, shuffleKey, isPrimary, None, writePromise)
+      Try(Await.result(writePromise.future, Duration.Inf)) match {
+        case Success(_) =>
+          if (softSplit.get()) {
+            callbackWithTimer.onSuccess(
+              ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
+          } else {
+            Option(CongestionController.instance()) match {
+              case Some(congestionController) =>
+                if (congestionController.isUserCongested(
+                    fileWriter.getFileInfo.getUserIdentifier)) {
+                  if (isPrimary) {
+                    callbackWithTimer.onSuccess(
+                      ByteBuffer.wrap(
+                        Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+                  } else {
+                    callbackWithTimer.onSuccess(
+                      ByteBuffer.wrap(
+                        Array[Byte](StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue)))
+                  }
+                } else {
+                  callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+                }
+              case None =>
+                callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
             }
-          case None =>
-            callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
-        }
+          }
+        case Failure(e) => callbackWithTimer.onFailure(e)
       }
     }
-
-    try {
-      fileWriter.write(body)
-    } catch {
-      case e: AlreadyClosedException =>
-        fileWriter.decrementPendingWrites()
-        val (mapId, attemptId) = getMapAttempt(body)
-        val endedAttempt =
-          if (shuffleMapperAttempts.containsKey(shuffleKey)) {
-            shuffleMapperAttempts.get(shuffleKey).get(mapId)
-          } else -1
-        // TODO just info log for ended attempt
-        logError(
-          s"[handlePushData] Append data failed for task(shuffle $shuffleKey, map $mapId, attempt" +
-            s" $attemptId), caused by AlreadyClosedException, endedAttempt $endedAttempt, error message: ${e.getMessage}")
-      case e: Exception =>
-        logError("Exception encountered when write.", e)
-    }
   }
 
   def handlePushMergedData(
@@ -525,7 +522,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
       fileWriters.foreach(_.decrementPendingWrites())
       return
     }
-
+    val writePromise = Promise[Unit]()
     // for primary, send data to replica
     if (doReplicate) {
       pushMergedData.body().retain()
@@ -552,28 +549,32 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
           // Handle the response from replica
           val wrappedCallback = new RpcResponseCallback() {
             override def onSuccess(response: ByteBuffer): Unit = {
-              // Only primary data enable replication will push data to replica
-              if (response.remaining() > 0) {
-                val resp = ByteBuffer.allocate(response.remaining())
-                resp.put(response)
-                resp.flip()
-                callbackWithTimer.onSuccess(resp)
-              } else {
-                Option(CongestionController.instance()) match {
-                  case Some(congestionController) if fileWriters.nonEmpty =>
-                    if (congestionController.isUserCongested(
-                        fileWriters.head.getFileInfo.getUserIdentifier)) {
-                      // Check whether primary congest the data though the replicas doesn't congest
-                      // it(the response is empty)
-                      callbackWithTimer.onSuccess(
-                        ByteBuffer.wrap(
-                          Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
-                    } else {
-                      callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+              Try(Await.result(writePromise.future, Duration.Inf)) match {
+                case Success(_) =>
+                  // Only primary data enable replication will push data to replica
+                  if (response.remaining() > 0) {
+                    val resp = ByteBuffer.allocate(response.remaining())
+                    resp.put(response)
+                    resp.flip()
+                    callbackWithTimer.onSuccess(resp)
+                  } else {
+                    Option(CongestionController.instance()) match {
+                      case Some(congestionController) if fileWriters.nonEmpty =>
+                        if (congestionController.isUserCongested(
+                            fileWriters.head.getFileInfo.getUserIdentifier)) {
+                          // Check whether primary congest the data though the replicas doesn't congest
+                          // it(the response is empty)
+                          callbackWithTimer.onSuccess(
+                            ByteBuffer.wrap(
+                              Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+                        } else {
+                          callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+                        }
+                      case None =>
+                        callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
                     }
-                  case None =>
-                    callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
-                }
+                  }
+                case Failure(e) => callbackWithTimer.onFailure(e)
               }
             }
 
@@ -621,69 +622,36 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
           }
         }
       })
+      writeLocalData(fileWriters, body, shuffleKey, isPrimary, Some(batchOffsets), writePromise)
     } else {
       // The codes here could be executed if
       // 1. the client doesn't enable push data to the replica, the primary worker could hit here
       // 2. the client enables push data to the replica, and the replica worker could hit here
-      Option(CongestionController.instance()) match {
-        case Some(congestionController) if fileWriters.nonEmpty =>
-          if (congestionController.isUserCongested(
-              fileWriters.head.getFileInfo.getUserIdentifier)) {
-            if (isPrimary) {
-              callbackWithTimer.onSuccess(
-                ByteBuffer.wrap(
-                  Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
-            } else {
-              callbackWithTimer.onSuccess(
-                ByteBuffer.wrap(
-                  Array[Byte](StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue)))
-            }
-          } else {
-            callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+      writeLocalData(fileWriters, body, shuffleKey, isPrimary, Some(batchOffsets), writePromise)
+      Try(Await.result(writePromise.future, Duration.Inf)) match {
+        case Success(_) =>
+          Option(CongestionController.instance()) match {
+            case Some(congestionController) if fileWriters.nonEmpty =>
+              if (congestionController.isUserCongested(
+                  fileWriters.head.getFileInfo.getUserIdentifier)) {
+                if (isPrimary) {
+                  callbackWithTimer.onSuccess(
+                    ByteBuffer.wrap(
+                      
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+                } else {
+                  callbackWithTimer.onSuccess(
+                    ByteBuffer.wrap(
+                      
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue)))
+                }
+              } else {
+                callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+              }
+            case None =>
+              callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
           }
-        case None =>
-          callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+        case Failure(e) => callbackWithTimer.onFailure(e)
       }
     }
-
-    index = 0
-    var fileWriter: FileWriter = null
-    var alreadyClosed = false
-    while (index < fileWriters.length) {
-      fileWriter = fileWriters(index)
-      val offset = body.readerIndex() + batchOffsets(index)
-      val length =
-        if (index == fileWriters.length - 1) {
-          body.readableBytes() - batchOffsets(index)
-        } else {
-          batchOffsets(index + 1) - batchOffsets(index)
-        }
-      val batchBody = body.slice(offset, length)
-
-      try {
-        if (!alreadyClosed) {
-          fileWriter.write(batchBody)
-        } else {
-          fileWriter.decrementPendingWrites()
-        }
-      } catch {
-        case e: AlreadyClosedException =>
-          fileWriter.decrementPendingWrites()
-          alreadyClosed = true
-          val (mapId, attemptId) = getMapAttempt(body)
-          val endedAttempt =
-            if (shuffleMapperAttempts.containsKey(shuffleKey)) {
-              shuffleMapperAttempts.get(shuffleKey).get(mapId)
-            } else -1
-          // TODO just info log for ended attempt
-          logError(
-            s"[handlePushMergedData] Append data failed for task(shuffle $shuffleKey, map $mapId, attempt" +
-              s" $attemptId), caused by AlreadyClosedException, endedAttempt $endedAttempt, error message: ${e.getMessage}")
-        case e: Exception =>
-          logError("Exception encountered when write.", e)
-      }
-      index += 1
-    }
   }
 
   /**
@@ -827,31 +795,20 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
       fileWriter.decrementPendingWrites()
       return;
     }
-
+    val writePromise = Promise[Unit]()
+    writeLocalData(Seq(fileWriter), body, shuffleKey, isPrimary, None, writePromise)
     // for primary, send data to replica
     if (location.hasPeer && isPrimary) {
       // to do
-      wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+      Try(Await.result(writePromise.future, Duration.Inf)) match {
+        case Success(_) => wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+        case Failure(e) => wrappedCallback.onFailure(e)
+      }
     } else {
-      wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
-    }
-
-    try {
-      fileWriter.write(body)
-    } catch {
-      case e: AlreadyClosedException =>
-        fileWriter.decrementPendingWrites()
-        val (mapId, attemptId) = getMapAttempt(body)
-        val endedAttempt =
-          if (shuffleMapperAttempts.containsKey(shuffleKey)) {
-            shuffleMapperAttempts.get(shuffleKey).get(mapId)
-          } else -1
-        // TODO just info log for ended attempt
-        logError(
-          s"[handleMapPartitionPushData] Append data failed for task(shuffle $shuffleKey, map $mapId, attempt" +
-            s" $attemptId), caused by AlreadyClosedException, endedAttempt $endedAttempt, error message: ${e.getMessage}")
-      case e: Exception =>
-        logError("Exception encountered when write.", e)
+      Try(Await.result(writePromise.future, Duration.Inf)) match {
+        case Success(_) => wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+        case Failure(e) => wrappedCallback.onFailure(e)
+      }
     }
   }
 
@@ -1253,6 +1210,69 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
     }
   }
 
+  private def writeLocalData(
+      fileWriters: Seq[FileWriter],
+      body: ByteBuf,
+      shuffleKey: String,
+      isPrimary: Boolean,
+      batchOffsets: Option[Array[Int]],
+      writePromise: Promise[Unit]): Unit = {
+    def writeData(fileWriter: FileWriter, body: ByteBuf, shuffleKey: String): Unit = {
+      try {
+        fileWriter.write(body)
+      } catch {
+        case e: Exception =>
+          if (e.isInstanceOf[AlreadyClosedException]) {
+            val (mapId, attemptId) = getMapAttempt(body)
+            val endedAttempt =
+              if (shuffleMapperAttempts.containsKey(shuffleKey)) {
+                shuffleMapperAttempts.get(shuffleKey).get(mapId)
+              } else -1
+            // TODO just info log for ended attempt
+            logWarning(s"Append data failed for task(shuffle $shuffleKey, map $mapId, attempt" +
+              s" $attemptId), caused by AlreadyClosedException, endedAttempt $endedAttempt, error message: ${e.getMessage}")
+          } else {
+            logError("Exception encountered when write.", e)
+          }
+          val cause =
+            if (isPrimary) {
+              StatusCode.PUSH_DATA_WRITE_FAIL_PRIMARY
+            } else {
+              StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA
+            }
+          writePromise.failure(new CelebornIOException(cause))
+          fileWriter.decrementPendingWrites()
+      }
+    }
+    batchOffsets match {
+      case Some(batchOffsets) =>
+        var index = 0
+        var fileWriter: FileWriter = null
+        while (index < fileWriters.length) {
+          if (!writePromise.isCompleted) {
+            fileWriter = fileWriters(index)
+            val offset = body.readerIndex() + batchOffsets(index)
+            val length =
+              if (index == fileWriters.length - 1) {
+                body.readableBytes() - batchOffsets(index)
+              } else {
+                batchOffsets(index + 1) - batchOffsets(index)
+              }
+            val batchBody = body.slice(offset, length)
+            writeData(fileWriter, batchBody, shuffleKey)
+          } else {
+            fileWriter.decrementPendingWrites()
+          }
+          index += 1
+        }
+      case _ =>
+        writeData(fileWriters.head, body, shuffleKey)
+    }
+    if (!writePromise.isCompleted) {
+      writePromise.success()
+    }
+  }
+
   /**
    * Invoked when the channel associated with the given client is active.
    */
