FMX commented on code in PR #2924: URL: https://github.com/apache/celeborn/pull/2924#discussion_r1849986142
########## worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala: ########## @@ -758,6 +808,76 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler } } + class PushMergedDataCallback(callback: RpcResponseCallback) { + private val partitionIndex2StatusCode = new mutable.HashMap[Int, Byte]() Review Comment: Renaming this variable will be better. It could be failedPartitionStatuses. ########## client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java: ########## @@ -1437,71 +1465,82 @@ public void onFailure(Throwable e) { new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { - if (response.remaining() > 0) { - byte reason = response.get(); - if (reason == StatusCode.HARD_SPLIT.getValue()) { - logger.info( - "Push merged data to {} hard split required for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.", - addressPair, - shuffleId, - mapId, - attemptId, - Arrays.toString(partitionIds), - groupedBatchId, - Arrays.toString(batchIds)); - - ReviveRequest[] requests = - addAndGetReviveRequests( - shuffleId, mapId, attemptId, batches, StatusCode.HARD_SPLIT); - pushDataRetryPool.submit( - () -> - submitRetryPushMergedData( - pushState, - shuffleId, - mapId, - attemptId, - batches, - StatusCode.HARD_SPLIT, - groupedBatchId, - requests, - remainReviveTimes, - System.currentTimeMillis() - + conf.clientRpcRequestPartitionLocationAskTimeout() - .duration() - .toMillis())); - } else if (reason == StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) { - logger.debug( - "Push merged data to {} primary congestion required for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.", - addressPair, - shuffleId, - mapId, - attemptId, - Arrays.toString(partitionIds), - groupedBatchId, - Arrays.toString(batchIds)); - pushState.onCongestControl(hostPort); - callback.onSuccess(response); - } else if (reason == 
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) { - logger.debug( - "Push merged data to {} replica congestion required for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.", - addressPair, - shuffleId, - mapId, - attemptId, - Arrays.toString(partitionIds), - groupedBatchId, - Arrays.toString(batchIds)); - pushState.onCongestControl(hostPort); - callback.onSuccess(response); - } else { - // StageEnd. - response.rewind(); - pushState.onSuccess(hostPort); - callback.onSuccess(response); + byte reason = response.get(); + if (reason == StatusCode.HARD_SPLIT.getValue()) { + PushMergedDataUnsuccessfulPartitionInfo partitionInfo = + (PushMergedDataUnsuccessfulPartitionInfo) Message.decode(response); + int length = partitionInfo.unsuccessfulPartitionIndexes.length; + ArrayList<DataBatches.DataBatch> toRetryBatched = new ArrayList<>(); + DataBatchReviveInfo[] dataBatchReviveInfos = new DataBatchReviveInfo[length]; Review Comment: StringBuilder can replace this class. 
########## client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java: ########## @@ -388,13 +414,13 @@ private void submitRetryPushMergedData( int mapId, int attemptId, ArrayList<DataBatches.DataBatch> batches, - StatusCode cause, Integer oldGroupedBatchId, ReviveRequest[] reviveRequests, int remainReviveTimes, long reviveResponseDueTime) { HashMap<Pair<String, String>, DataBatches> newDataBatchesMap = new HashMap<>(); ArrayList<DataBatches.DataBatch> reviveFailedBatchesMap = new ArrayList<>(); + ArrayList<StatusCode> reviveFailedBatchesCauses = new ArrayList<>(); Review Comment: Use `List` instead. In Java, we prefer to decouple code from a specific implementation of the interface. ########## worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala: ########## @@ -758,6 +808,76 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler } } + class PushMergedDataCallback(callback: RpcResponseCallback) { + private val partitionIndex2StatusCode = new mutable.HashMap[Int, Byte]() + + def addUnsuccessfulPartition(index: Int, statusCode: StatusCode): Unit = { + partitionIndex2StatusCode.put(index, statusCode.getValue) + } + + def containIndex(index: Int): Boolean = { Review Comment: ```suggestion def isPartitionFailed(index: Int): Boolean = { ``` ########## worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala: ########## @@ -530,30 +538,44 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler s"While handling PushMergedData, throw $cause, fileWriter $fileWriterWithException has exception.", fileWriterWithException.getException) workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT) - callbackWithTimer.onFailure(new CelebornIOException(cause)) + pushMergedDataCallback.onFailure(new CelebornIOException(cause)) return } - if (fileWriters.exists(checkDiskFull(_) == true)) { - val (mapId, attemptId) = getMapAttempt(body) - logWarning( - 
s"return hard split for disk full with shuffle $shuffleKey map $mapId attempt $attemptId") - callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue))) - return + fileWriters.zipWithIndex.foreach { + case (fileWriter, index) => + if (fileWriter == null) { + if (!pushMergedDataCallback.containIndex(index)) { Review Comment: A more meaningful method name will be better. Maybe IsSplittedPartition. ########## worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala: ########## @@ -530,30 +538,44 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler s"While handling PushMergedData, throw $cause, fileWriter $fileWriterWithException has exception.", fileWriterWithException.getException) workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT) - callbackWithTimer.onFailure(new CelebornIOException(cause)) + pushMergedDataCallback.onFailure(new CelebornIOException(cause)) return } - if (fileWriters.exists(checkDiskFull(_) == true)) { - val (mapId, attemptId) = getMapAttempt(body) - logWarning( - s"return hard split for disk full with shuffle $shuffleKey map $mapId attempt $attemptId") - callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue))) - return + fileWriters.zipWithIndex.foreach { + case (fileWriter, index) => + if (fileWriter == null) { + if (!pushMergedDataCallback.containIndex(index)) { + pushMergedDataCallback.onFailure(new CelebornIOException(s"Partition $index's fileWriter not found, but it hasn't been identified in the previous validation step.")) + return + } + } else if (checkDiskFull(fileWriter)) { + logWarning( + s"return hard split for disk full with shuffle $shuffleKey map $mapId attempt $attemptId") + pushMergedDataCallback.addUnsuccessfulPartition(index, StatusCode.HARD_SPLIT) + } else if (fileWriter.isClosed) { + val fileInfo = fileWriter.getCurrentFileInfo + logWarning( + s"[handlePushMergedData] FileWriter is already closed! 
File path ${fileInfo.getFilePath} " + + s"length ${fileInfo.getFileLength}") + pushMergedDataCallback.addUnsuccessfulPartition(index, StatusCode.HARD_SPLIT) + } else { + val splitStatus = checkDiskFullAndSplit(fileWriter, isPrimary) + if (splitStatus == StatusCode.HARD_SPLIT) { + logWarning( + s"return hard split for disk full with shuffle $shuffleKey map $mapId attempt $attemptId") + workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT) + pushMergedDataCallback.addUnsuccessfulPartition(index, StatusCode.HARD_SPLIT) + } else if (splitStatus == StatusCode.SOFT_SPLIT) { + pushMergedDataCallback.addUnsuccessfulPartition(index, StatusCode.SOFT_SPLIT) + } + } + if (!pushMergedDataCallback.containIndex(index) || pushMergedDataCallback.getStatusCode( Review Comment: The method getStatusCode will return -1 by default. So this check can be simplified to `pushMergedDataCallback.getStatusCode( index) == StatusCode.SOFT_SPLIT.getValue` ########## worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala: ########## @@ -1277,11 +1391,11 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler |fileLength:${diskFileInfo.getFileLength}, |fileName:${diskFileInfo.getFilePath} |""".stripMargin) - return true + return StatusCode.HARD_SPLIT } } } - false + StatusCode.SUCCESS Review Comment: This change breaks the semantics of this method. I'd rather add a new status code named NO_SPLIT. ########## common/src/main/java/org/apache/celeborn/common/write/DataBatchReviveInfo.java: ########## @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.write; + +import org.apache.celeborn.common.protocol.message.StatusCode; + +public class DataBatchReviveInfo { Review Comment: This class is only used for logging, so it should not exist. ########## worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala: ########## @@ -758,6 +808,76 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler } } + class PushMergedDataCallback(callback: RpcResponseCallback) { + private val partitionIndex2StatusCode = new mutable.HashMap[Int, Byte]() + + def addUnsuccessfulPartition(index: Int, statusCode: StatusCode): Unit = { + partitionIndex2StatusCode.put(index, statusCode.getValue) + } + + def containIndex(index: Int): Boolean = { + partitionIndex2StatusCode.contains(index) + } + + def getStatusCode(index: Int): Byte = { Review Comment: ```suggestion def getPartitionFailedStatus(index:Int): Byte = { ``` ########## common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java: ########## @@ -96,4 +96,119 @@ public enum StatusCode { public final byte getValue() { return value; } + + public static StatusCode fromValue(byte value) { Review Comment: This method can be replaced by the following code.
``` private static final Map<Byte, StatusCode> lookup = Arrays.stream(StatusCode.values()).collect(Collectors.toMap(i -> i.getValue(), i -> i)); public static StatusCode fromValue(byte value) { StatusCode code = lookup.get(value); if (code != null) { return code; } throw new IllegalArgumentException("Unknown status code: " + value); } ``` ########## worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala: ########## @@ -758,6 +808,76 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler } } + class PushMergedDataCallback(callback: RpcResponseCallback) { + private val partitionIndex2StatusCode = new mutable.HashMap[Int, Byte]() + + def addUnsuccessfulPartition(index: Int, statusCode: StatusCode): Unit = { + partitionIndex2StatusCode.put(index, statusCode.getValue) + } + + def containIndex(index: Int): Boolean = { + partitionIndex2StatusCode.contains(index) + } + + def getStatusCode(index: Int): Byte = { + partitionIndex2StatusCode.getOrElse(index, -1) + } + + def unionUnsuccessfulPartition( + replicaPartitionIndexes: Array[Int], + replicaStatusCodes: Array[Byte]): Unit = { + if (replicaPartitionIndexes.length != replicaStatusCodes.length) { + throw new IllegalArgumentException( + "replicaPartitionIndexes and replicaStatusCodes must have the same size") + } + for (i <- replicaPartitionIndexes.indices) { + val index = replicaPartitionIndexes(i) + // if primary and replica have the same index, use the primary's status code + if (!partitionIndex2StatusCode.contains(index)) { + partitionIndex2StatusCode.put(index, replicaStatusCodes(i)) + } + } + } + + def getSortedSkipIndexes(): Array[Int] = { Review Comment: ```suggestion def getFailedPartitionIndexes(): Array[Int] = { ``` ########## worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala: ########## @@ -1330,21 +1445,28 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler batchOffsets match { case 
Some(batchOffsets) => var index = 0 + var indexOfSkipIndexes = 0 + val skipListLength = sortedSkipIndexes.length var fileWriter: PartitionDataWriter = null while (index < fileWriters.length) { - fileWriter = fileWriters(index) - if (!writePromise.isCompleted) { - val offset = body.readerIndex() + batchOffsets(index) - val length = - if (index == fileWriters.length - 1) { - body.readableBytes() - batchOffsets(index) - } else { - batchOffsets(index + 1) - batchOffsets(index) - } - val batchBody = body.slice(offset, length) - writeData(fileWriter, batchBody, shuffleKey) + if (indexOfSkipIndexes < skipListLength && index == sortedSkipIndexes( + indexOfSkipIndexes)) { + indexOfSkipIndexes += 1 } else { - fileWriter.decrementPendingWrites() + fileWriter = fileWriters(index) + if (!writePromise.isCompleted) { + val offset = body.readerIndex() + batchOffsets(index) + val length = + if (index == fileWriters.length - 1) { + body.readableBytes() - batchOffsets(index) + } else { + batchOffsets(index + 1) - batchOffsets(index) + } + val batchBody = body.slice(offset, length) + writeData(fileWriter, batchBody, shuffleKey) Review Comment: This PR is intended to optimize the processing logic for pushMergedData; consider creating a separate PR to handle fileWriters with failures. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@celeborn.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org