FMX commented on code in PR #2924:
URL: https://github.com/apache/celeborn/pull/2924#discussion_r1849986142


##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala:
##########
@@ -758,6 +808,76 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     }
   }
 
+  class PushMergedDataCallback(callback: RpcResponseCallback) {
+    private val partitionIndex2StatusCode = new mutable.HashMap[Int, Byte]()

Review Comment:
   It would be better to rename this variable, e.g. to `failedPartitionStatuses`.



##########
client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java:
##########
@@ -1437,71 +1465,82 @@ public void onFailure(Throwable e) {
         new RpcResponseCallback() {
           @Override
           public void onSuccess(ByteBuffer response) {
-            if (response.remaining() > 0) {
-              byte reason = response.get();
-              if (reason == StatusCode.HARD_SPLIT.getValue()) {
-                logger.info(
-                    "Push merged data to {} hard split required for shuffle {} 
map {} attempt {} partition {} groupedBatch {} batch {}.",
-                    addressPair,
-                    shuffleId,
-                    mapId,
-                    attemptId,
-                    Arrays.toString(partitionIds),
-                    groupedBatchId,
-                    Arrays.toString(batchIds));
-
-                ReviveRequest[] requests =
-                    addAndGetReviveRequests(
-                        shuffleId, mapId, attemptId, batches, 
StatusCode.HARD_SPLIT);
-                pushDataRetryPool.submit(
-                    () ->
-                        submitRetryPushMergedData(
-                            pushState,
-                            shuffleId,
-                            mapId,
-                            attemptId,
-                            batches,
-                            StatusCode.HARD_SPLIT,
-                            groupedBatchId,
-                            requests,
-                            remainReviveTimes,
-                            System.currentTimeMillis()
-                                + 
conf.clientRpcRequestPartitionLocationAskTimeout()
-                                    .duration()
-                                    .toMillis()));
-              } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) {
-                logger.debug(
-                    "Push merged data to {} primary congestion required for 
shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
-                    addressPair,
-                    shuffleId,
-                    mapId,
-                    attemptId,
-                    Arrays.toString(partitionIds),
-                    groupedBatchId,
-                    Arrays.toString(batchIds));
-                pushState.onCongestControl(hostPort);
-                callback.onSuccess(response);
-              } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) {
-                logger.debug(
-                    "Push merged data to {} replica congestion required for 
shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
-                    addressPair,
-                    shuffleId,
-                    mapId,
-                    attemptId,
-                    Arrays.toString(partitionIds),
-                    groupedBatchId,
-                    Arrays.toString(batchIds));
-                pushState.onCongestControl(hostPort);
-                callback.onSuccess(response);
-              } else {
-                // StageEnd.
-                response.rewind();
-                pushState.onSuccess(hostPort);
-                callback.onSuccess(response);
+            byte reason = response.get();
+            if (reason == StatusCode.HARD_SPLIT.getValue()) {
+              PushMergedDataUnsuccessfulPartitionInfo partitionInfo =
+                  (PushMergedDataUnsuccessfulPartitionInfo) 
Message.decode(response);
+              int length = partitionInfo.unsuccessfulPartitionIndexes.length;
+              ArrayList<DataBatches.DataBatch> toRetryBatched = new 
ArrayList<>();
+              DataBatchReviveInfo[] dataBatchReviveInfos = new 
DataBatchReviveInfo[length];

Review Comment:
   StringBuilder can replace this class.



##########
client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java:
##########
@@ -388,13 +414,13 @@ private void submitRetryPushMergedData(
       int mapId,
       int attemptId,
       ArrayList<DataBatches.DataBatch> batches,
-      StatusCode cause,
       Integer oldGroupedBatchId,
       ReviveRequest[] reviveRequests,
       int remainReviveTimes,
       long reviveResponseDueTime) {
     HashMap<Pair<String, String>, DataBatches> newDataBatchesMap = new 
HashMap<>();
     ArrayList<DataBatches.DataBatch> reviveFailedBatchesMap = new 
ArrayList<>();
+    ArrayList<StatusCode> reviveFailedBatchesCauses = new ArrayList<>();

Review Comment:
   Use `List` instead. In Java, we prefer to decouple code from a specific implementation of an interface.



##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala:
##########
@@ -758,6 +808,76 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     }
   }
 
+  class PushMergedDataCallback(callback: RpcResponseCallback) {
+    private val partitionIndex2StatusCode = new mutable.HashMap[Int, Byte]()
+
+    def addUnsuccessfulPartition(index: Int, statusCode: StatusCode): Unit = {
+      partitionIndex2StatusCode.put(index, statusCode.getValue)
+    }
+
+    def containIndex(index: Int): Boolean = {

Review Comment:
   ```suggestion
       def isPartitionFailed(index: Int): Boolean = {
   ```



##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala:
##########
@@ -530,30 +538,44 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
         s"While handling PushMergedData, throw $cause, fileWriter 
$fileWriterWithException has exception.",
         fileWriterWithException.getException)
       workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT)
-      callbackWithTimer.onFailure(new CelebornIOException(cause))
+      pushMergedDataCallback.onFailure(new CelebornIOException(cause))
       return
     }
 
-    if (fileWriters.exists(checkDiskFull(_) == true)) {
-      val (mapId, attemptId) = getMapAttempt(body)
-      logWarning(
-        s"return hard split for disk full with shuffle $shuffleKey map $mapId 
attempt $attemptId")
-      
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
-      return
+    fileWriters.zipWithIndex.foreach {
+      case (fileWriter, index) =>
+        if (fileWriter == null) {
+          if (!pushMergedDataCallback.containIndex(index)) {

Review Comment:
   A more meaningful method name would be better, for example `IsSplittedPartition`.



##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala:
##########
@@ -530,30 +538,44 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
         s"While handling PushMergedData, throw $cause, fileWriter 
$fileWriterWithException has exception.",
         fileWriterWithException.getException)
       workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT)
-      callbackWithTimer.onFailure(new CelebornIOException(cause))
+      pushMergedDataCallback.onFailure(new CelebornIOException(cause))
       return
     }
 
-    if (fileWriters.exists(checkDiskFull(_) == true)) {
-      val (mapId, attemptId) = getMapAttempt(body)
-      logWarning(
-        s"return hard split for disk full with shuffle $shuffleKey map $mapId 
attempt $attemptId")
-      
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
-      return
+    fileWriters.zipWithIndex.foreach {
+      case (fileWriter, index) =>
+        if (fileWriter == null) {
+          if (!pushMergedDataCallback.containIndex(index)) {
+            pushMergedDataCallback.onFailure(new 
CelebornIOException(s"Partition $index's fileWriter not found, but it hasn't 
been identified in the previous validation step."))
+            return
+          }
+        } else if (checkDiskFull(fileWriter)) {
+          logWarning(
+            s"return hard split for disk full with shuffle $shuffleKey map 
$mapId attempt $attemptId")
+          pushMergedDataCallback.addUnsuccessfulPartition(index, 
StatusCode.HARD_SPLIT)
+        } else if (fileWriter.isClosed) {
+          val fileInfo = fileWriter.getCurrentFileInfo
+          logWarning(
+            s"[handlePushMergedData] FileWriter is already closed! File path 
${fileInfo.getFilePath} " +
+              s"length ${fileInfo.getFileLength}")
+          pushMergedDataCallback.addUnsuccessfulPartition(index, 
StatusCode.HARD_SPLIT)
+        } else {
+          val splitStatus = checkDiskFullAndSplit(fileWriter, isPrimary)
+          if (splitStatus == StatusCode.HARD_SPLIT) {
+            logWarning(
+              s"return hard split for disk full with shuffle $shuffleKey map 
$mapId attempt $attemptId")
+            workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
+            pushMergedDataCallback.addUnsuccessfulPartition(index, 
StatusCode.HARD_SPLIT)
+          } else if (splitStatus == StatusCode.SOFT_SPLIT) {
+            pushMergedDataCallback.addUnsuccessfulPartition(index, 
StatusCode.SOFT_SPLIT)
+          }
+        }
+        if (!pushMergedDataCallback.containIndex(index) || 
pushMergedDataCallback.getStatusCode(

Review Comment:
   The method `getStatusCode` returns -1 by default, so this check can be simplified to `pushMergedDataCallback.getStatusCode(index) == StatusCode.SOFT_SPLIT.getValue`.



##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala:
##########
@@ -1277,11 +1391,11 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
                |fileLength:${diskFileInfo.getFileLength},
                |fileName:${diskFileInfo.getFilePath}
                |""".stripMargin)
-          return true
+          return StatusCode.HARD_SPLIT
         }
       }
     }
-    false
+    StatusCode.SUCCESS

Review Comment:
   This change breaks the semantics of this method. I'd rather add a new status code named `NO_SPLIT`.



##########
common/src/main/java/org/apache/celeborn/common/write/DataBatchReviveInfo.java:
##########
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.write;
+
+import org.apache.celeborn.common.protocol.message.StatusCode;
+
+public class DataBatchReviveInfo {

Review Comment:
   This class is only used for logging, so it should be removed.



##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala:
##########
@@ -758,6 +808,76 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     }
   }
 
+  class PushMergedDataCallback(callback: RpcResponseCallback) {
+    private val partitionIndex2StatusCode = new mutable.HashMap[Int, Byte]()
+
+    def addUnsuccessfulPartition(index: Int, statusCode: StatusCode): Unit = {
+      partitionIndex2StatusCode.put(index, statusCode.getValue)
+    }
+
+    def containIndex(index: Int): Boolean = {
+      partitionIndex2StatusCode.contains(index)
+    }
+
+    def getStatusCode(index: Int): Byte = {

Review Comment:
   
   ```suggestion
       def getPartitionFailedStatus(index:Int): Byte = {
   ```



##########
common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java:
##########
@@ -96,4 +96,119 @@ public enum StatusCode {
   public final byte getValue() {
     return value;
   }
+
+  public static StatusCode fromValue(byte value) {

Review Comment:
   This method can be replaced by the following code:
   
   ```
     private static final Map<Byte, StatusCode> lookup =
         Arrays.stream(StatusCode.values()).collect(Collectors.toMap(i -> i.getValue(), i -> i));
   
     public static StatusCode fromValue(byte value) {
       StatusCode code = lookup.get(value);
       if (code != null) {
         return code;
       }
       throw new IllegalArgumentException("Unknown status code: " + value);
   
     }
   ```



##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala:
##########
@@ -758,6 +808,76 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     }
   }
 
+  class PushMergedDataCallback(callback: RpcResponseCallback) {
+    private val partitionIndex2StatusCode = new mutable.HashMap[Int, Byte]()
+
+    def addUnsuccessfulPartition(index: Int, statusCode: StatusCode): Unit = {
+      partitionIndex2StatusCode.put(index, statusCode.getValue)
+    }
+
+    def containIndex(index: Int): Boolean = {
+      partitionIndex2StatusCode.contains(index)
+    }
+
+    def getStatusCode(index: Int): Byte = {
+      partitionIndex2StatusCode.getOrElse(index, -1)
+    }
+
+    def unionUnsuccessfulPartition(
+        replicaPartitionIndexes: Array[Int],
+        replicaStatusCodes: Array[Byte]): Unit = {
+      if (replicaPartitionIndexes.length != replicaStatusCodes.length) {
+        throw new IllegalArgumentException(
+          "replicaPartitionIndexes and replicaStatusCodes must have the same 
size")
+      }
+      for (i <- replicaPartitionIndexes.indices) {
+        val index = replicaPartitionIndexes(i)
+        // if primary and replica have the same index, use the primary's 
status code
+        if (!partitionIndex2StatusCode.contains(index)) {
+          partitionIndex2StatusCode.put(index, replicaStatusCodes(i))
+        }
+      }
+    }
+
+    def getSortedSkipIndexes(): Array[Int] = {

Review Comment:
   ```suggestion
       def getFailedPartitionIndexes(): Array[Int] = {
   ```



##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala:
##########
@@ -1330,21 +1445,28 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     batchOffsets match {
       case Some(batchOffsets) =>
         var index = 0
+        var indexOfSkipIndexes = 0
+        val skipListLength = sortedSkipIndexes.length
         var fileWriter: PartitionDataWriter = null
         while (index < fileWriters.length) {
-          fileWriter = fileWriters(index)
-          if (!writePromise.isCompleted) {
-            val offset = body.readerIndex() + batchOffsets(index)
-            val length =
-              if (index == fileWriters.length - 1) {
-                body.readableBytes() - batchOffsets(index)
-              } else {
-                batchOffsets(index + 1) - batchOffsets(index)
-              }
-            val batchBody = body.slice(offset, length)
-            writeData(fileWriter, batchBody, shuffleKey)
+          if (indexOfSkipIndexes < skipListLength && index == 
sortedSkipIndexes(
+              indexOfSkipIndexes)) {
+            indexOfSkipIndexes += 1
           } else {
-            fileWriter.decrementPendingWrites()
+            fileWriter = fileWriters(index)
+            if (!writePromise.isCompleted) {
+              val offset = body.readerIndex() + batchOffsets(index)
+              val length =
+                if (index == fileWriters.length - 1) {
+                  body.readableBytes() - batchOffsets(index)
+                } else {
+                  batchOffsets(index + 1) - batchOffsets(index)
+                }
+              val batchBody = body.slice(offset, length)
+              writeData(fileWriter, batchBody, shuffleKey)

Review Comment:
   This PR is intended to optimize the processing logic of pushMergedData; the handling of fileWriters with failures could be done in a separate PR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@celeborn.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to