This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new e41ee2dc9 [CELEBORN-1721][CIP-12] Support HARD_SPLIT in PushMergedData
e41ee2dc9 is described below
commit e41ee2dc9b643f91ae8dca3512d6a5f2305a2a80
Author: jiang13021 <[email protected]>
AuthorDate: Fri Dec 6 09:20:36 2024 +0800
[CELEBORN-1721][CIP-12] Support HARD_SPLIT in PushMergedData
### What changes were proposed in this pull request?
As the title states: support returning HARD_SPLIT status (with per-partition split information) in the PushMergedData response.
### Why are the changes needed?
https://docs.google.com/document/d/1Jaix22vME0m1Q-JtTHF9WYsrsxBWBwzwmifcPxNQZHk/edit?tab=t.0#heading=h.iadpu3t4rywi
(Thanks to cfmcgrady, littlexyw, ErikFang, waitinfuture, RexXiong, and FMX for
their efforts on the proposal)
### Does this PR introduce _any_ user-facing change?
The response of pushMergedData has been modified; however, the changes are
backward compatible.
### How was this patch tested?
UT: org.apache.celeborn.service.deploy.cluster.PushMergedDataHardSplitSuite
Closes #2924 from jiang13021/cip-12.
Authored-by: jiang13021 <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../apache/celeborn/client/ShuffleClientImpl.java | 133 +++--
.../common/network/protocol/TransportMessage.java | 2 +
.../common/protocol/message/StatusCode.java | 18 +-
common/src/main/proto/TransportMessages.proto | 8 +-
.../common/protocol/message/ControlMessages.scala | 6 +
.../service/deploy/worker/PushDataHandler.scala | 533 ++++++++++++++-------
.../deploy/cluster/PushMergedDataSplitSuite.scala | 162 +++++++
7 files changed, 655 insertions(+), 207 deletions(-)
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 224ac57aa..f1e6a57e3 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -30,6 +30,7 @@ import scala.reflect.ClassTag$;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
+import com.google.protobuf.InvalidProtocolBufferException;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.commons.lang3.StringUtils;
@@ -51,6 +52,7 @@ import
org.apache.celeborn.common.network.client.TransportClientBootstrap;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.PushData;
import org.apache.celeborn.common.network.protocol.PushMergedData;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.sasl.SaslClientBootstrap;
import org.apache.celeborn.common.network.sasl.SaslCredentials;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
@@ -1438,9 +1440,64 @@ public class ShuffleClientImpl extends ShuffleClient {
new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
- if (response.remaining() > 0) {
- byte reason = response.get();
- if (reason == StatusCode.HARD_SPLIT.getValue()) {
+ byte reason = response.get();
+ if (reason == StatusCode.HARD_SPLIT.getValue()) {
+ ArrayList<DataBatches.DataBatch> batchesNeedResubmit;
+ if (response.remaining() > 0) {
+ batchesNeedResubmit = new ArrayList<>();
+ PbPushMergedDataSplitPartitionInfo partitionInfo;
+ try {
+ partitionInfo =
TransportMessage.fromByteBuffer(response).getParsedPayload();
+ } catch (CelebornIOException | InvalidProtocolBufferException
e) {
+ callback.onFailure(
+ new CelebornIOException("parse pushMergedData response
failed", e));
+ return;
+ }
+ List<Integer> splitPartitionIndexes =
partitionInfo.getSplitPartitionIndexesList();
+ List<Integer> statusCodeList =
partitionInfo.getStatusCodesList();
+ StringBuilder dataBatchReviveInfos = new StringBuilder();
+ for (int i = 0; i < splitPartitionIndexes.size(); i++) {
+ int partitionIndex = splitPartitionIndexes.get(i);
+ int batchId = batches.get(partitionIndex).batchId;
+ dataBatchReviveInfos.append(
+ String.format(
+ "(batchId=%d, partitionId=%d, cause=%s)",
+ batchId,
+ partitionIds[partitionIndex],
+
StatusCode.fromValue(statusCodeList.get(i).byteValue())));
+ if (statusCodeList.get(i) ==
StatusCode.SOFT_SPLIT.getValue()) {
+ PartitionLocation loc = batches.get(i).loc;
+ if (!newerPartitionLocationExists(
+ reducePartitionMap.get(shuffleId), loc.getId(),
loc.getEpoch(), false)) {
+ ReviveRequest reviveRequest =
+ new ReviveRequest(
+ shuffleId,
+ mapId,
+ attemptId,
+ loc.getId(),
+ loc.getEpoch(),
+ loc,
+ StatusCode.SOFT_SPLIT);
+ reviveManager.addRequest(reviveRequest);
+ }
+ } else {
+ batchesNeedResubmit.add(batches.get(partitionIndex));
+ }
+ }
+ logger.info(
+ "Push merged data to {} partial success required for
shuffle {} map {} attempt {} groupedBatch {}. split batches {}.",
+ addressPair,
+ shuffleId,
+ mapId,
+ attemptId,
+ groupedBatchId,
+ dataBatchReviveInfos);
+ } else {
+ // Workers that do not incorporate changes from [CELEBORN-1721]
+ // will respond with a status of HARD_SPLIT,
+ // but will not include a PbPushMergedDataSplitPartitionInfo.
+ // For backward compatibility, all batches must be resubmitted.
+ batchesNeedResubmit = batches;
logger.info(
"Push merged data to {} hard split required for shuffle {}
map {} attempt {} partition {} groupedBatch {} batch {}.",
addressPair,
@@ -1450,10 +1507,14 @@ public class ShuffleClientImpl extends ShuffleClient {
Arrays.toString(partitionIds),
groupedBatchId,
Arrays.toString(batchIds));
-
+ }
+ if (batchesNeedResubmit.isEmpty()) {
+ pushState.onSuccess(hostPort);
+ callback.onSuccess(ByteBuffer.wrap(new byte[]
{StatusCode.SOFT_SPLIT.getValue()}));
+ } else {
ReviveRequest[] requests =
addAndGetReviveRequests(
- shuffleId, mapId, attemptId, batches,
StatusCode.HARD_SPLIT);
+ shuffleId, mapId, attemptId, batchesNeedResubmit,
StatusCode.HARD_SPLIT);
pushDataRetryPool.submit(
() ->
submitRetryPushMergedData(
@@ -1461,7 +1522,7 @@ public class ShuffleClientImpl extends ShuffleClient {
shuffleId,
mapId,
attemptId,
- batches,
+ batchesNeedResubmit,
StatusCode.HARD_SPLIT,
groupedBatchId,
requests,
@@ -1470,39 +1531,37 @@ public class ShuffleClientImpl extends ShuffleClient {
+
conf.clientRpcRequestPartitionLocationAskTimeout()
.duration()
.toMillis()));
- } else if (reason ==
StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) {
- logger.debug(
- "Push merged data to {} primary congestion required for
shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
- addressPair,
- shuffleId,
- mapId,
- attemptId,
- Arrays.toString(partitionIds),
- groupedBatchId,
- Arrays.toString(batchIds));
- pushState.onCongestControl(hostPort);
- callback.onSuccess(response);
- } else if (reason ==
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) {
- logger.debug(
- "Push merged data to {} replica congestion required for
shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
- addressPair,
- shuffleId,
- mapId,
- attemptId,
- Arrays.toString(partitionIds),
- groupedBatchId,
- Arrays.toString(batchIds));
- pushState.onCongestControl(hostPort);
- callback.onSuccess(response);
- } else {
- // StageEnd.
- response.rewind();
- pushState.onSuccess(hostPort);
- callback.onSuccess(response);
}
- } else {
- pushState.onSuccess(hostPort);
+ } else if (reason ==
StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) {
+ logger.debug(
+ "Push merged data to {} primary congestion required for
shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
+ addressPair,
+ shuffleId,
+ mapId,
+ attemptId,
+ Arrays.toString(partitionIds),
+ groupedBatchId,
+ Arrays.toString(batchIds));
+ pushState.onCongestControl(hostPort);
callback.onSuccess(response);
+ } else if (reason ==
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) {
+ logger.debug(
+ "Push merged data to {} replica congestion required for
shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
+ addressPair,
+ shuffleId,
+ mapId,
+ attemptId,
+ Arrays.toString(partitionIds),
+ groupedBatchId,
+ Arrays.toString(batchIds));
+ pushState.onCongestControl(hostPort);
+ callback.onSuccess(response);
+ } else if (reason == StatusCode.MAP_ENDED.getValue()) {
+ pushState.onSuccess(hostPort);
+ callback.onSuccess(ByteBuffer.wrap(new byte[]
{StatusCode.MAP_ENDED.getValue()}));
+ } else { // success
+ pushState.onSuccess(hostPort);
+ callback.onSuccess(ByteBuffer.wrap(new byte[]
{StatusCode.SUCCESS.getValue()}));
}
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index 239874e8c..01a9a37f9 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -113,6 +113,8 @@ public class TransportMessage implements Serializable {
return (T) PbSegmentStart.parseFrom(payload);
case NOTIFY_REQUIRED_SEGMENT_VALUE:
return (T) PbNotifyRequiredSegment.parseFrom(payload);
+ case PUSH_MERGED_DATA_SPLIT_PARTITION_INFO_VALUE:
+ return (T) PbPushMergedDataSplitPartitionInfo.parseFrom(payload);
default:
logger.error("Unexpected type {}", type);
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
index ca8655bab..086368ca1 100644
---
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
+++
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
@@ -17,6 +17,10 @@
package org.apache.celeborn.common.protocol.message;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.stream.Collectors;
+
public enum StatusCode {
// 1/0 Status
SUCCESS(0),
@@ -84,7 +88,8 @@ public enum StatusCode {
PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA(50),
OPEN_STREAM_FAILED(51),
SEGMENT_START_FAIL_REPLICA(52),
- SEGMENT_START_FAIL_PRIMARY(53);
+ SEGMENT_START_FAIL_PRIMARY(53),
+ NO_SPLIT(54);
private final byte value;
@@ -96,4 +101,15 @@ public enum StatusCode {
public final byte getValue() {
return value;
}
+
+ private static final Map<Byte, StatusCode> lookup =
+ Arrays.stream(StatusCode.values()).collect(Collectors.toMap(i ->
i.getValue(), i -> i));
+
+ public static StatusCode fromValue(byte value) {
+ StatusCode code = lookup.get(value);
+ if (code != null) {
+ return code;
+ }
+ throw new IllegalArgumentException("Unknown status code: " + value);
+ }
}
diff --git a/common/src/main/proto/TransportMessages.proto
b/common/src/main/proto/TransportMessages.proto
index 5c03ba291..d7b439a68 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -111,6 +111,7 @@ enum MessageType {
BATCH_UNREGISTER_SHUFFLE_RESPONSE= 88;
REVISE_LOST_SHUFFLES = 89;
REVISE_LOST_SHUFFLES_RESPONSE = 90;
+ PUSH_MERGED_DATA_SPLIT_PARTITION_INFO = 91;
}
enum StreamType {
@@ -875,4 +876,9 @@ message PbReviseLostShuffles{
message PbReviseLostShufflesResponse{
bool success = 1;
string message = 2;
-}
\ No newline at end of file
+}
+
+message PbPushMergedDataSplitPartitionInfo {
+ repeated int32 splitPartitionIndexes = 1;
+ repeated int32 statusCodes = 2;
+}
diff --git
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index bd1a6a11e..3c9b22c54 100644
---
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -602,6 +602,9 @@ object ControlMessages extends Logging {
MessageType.REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE,
pb.toByteArray)
+ case pb: PbPushMergedDataSplitPartitionInfo =>
+ new TransportMessage(MessageType.PUSH_MERGED_DATA_SPLIT_PARTITION_INFO,
pb.toByteArray)
+
case HeartbeatFromWorker(
host,
rpcPort,
@@ -1400,6 +1403,9 @@ object ControlMessages extends Logging {
case REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE_VALUE =>
PbReportBarrierStageAttemptFailureResponse.parseFrom(message.getPayload)
+
+ case PUSH_MERGED_DATA_SPLIT_PARTITION_INFO_VALUE =>
+ PbPushMergedDataSplitPartitionInfo.parseFrom(message.getPayload)
}
}
}
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index af6d4cc66..0890f8355 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -18,9 +18,11 @@
package org.apache.celeborn.service.deploy.worker
import java.nio.ByteBuffer
+import java.util
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray}
+import scala.collection.mutable
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration
import scala.util.{Failure, Success, Try}
@@ -37,7 +39,7 @@ import
org.apache.celeborn.common.network.client.{RpcResponseCallback, Transport
import org.apache.celeborn.common.network.protocol.{Message, PushData,
PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage,
RpcFailure, RpcRequest, RpcResponse, TransportMessage}
import org.apache.celeborn.common.network.protocol.Message.Type
import org.apache.celeborn.common.network.server.BaseMessageHandler
-import org.apache.celeborn.common.protocol.{PartitionLocation,
PartitionSplitMode, PartitionType, PbPushDataHandShake, PbRegionFinish,
PbRegionStart, PbSegmentStart}
+import org.apache.celeborn.common.protocol.{PartitionLocation,
PartitionSplitMode, PartitionType, PbPushDataHandShake,
PbPushMergedDataSplitPartitionInfo, PbRegionFinish, PbRegionStart,
PbSegmentStart}
import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.unsafe.Platform
@@ -183,7 +185,7 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
// Fetch real batchId from body will add more cost and no meaning for
replicate.
val doReplicate = location != null && location.hasPeer && isPrimary
- val softSplit = new AtomicBoolean(false)
+ var softSplit = false
if (location == null) {
val (mapId, attemptId) = getMapAttempt(body)
@@ -248,7 +250,14 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
return
}
- if (checkDiskFullAndSplit(fileWriter, isPrimary, softSplit,
callbackWithTimer)) return
+ val splitStatus = checkDiskFullAndSplit(fileWriter, isPrimary)
+ if (splitStatus == StatusCode.HARD_SPLIT) {
+ workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
+
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+ return
+ } else if (splitStatus == StatusCode.SOFT_SPLIT) {
+ softSplit = true
+ }
fileWriter.incrementPendingWrites()
@@ -261,7 +270,7 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
fileWriter.decrementPendingWrites()
return
}
- val writePromise = Promise[Unit]()
+ val writePromise = Promise[Array[StatusCode]]()
// for primary, send data to replica
if (doReplicate) {
pushData.body().retain()
@@ -288,34 +297,39 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
val wrappedCallback = new RpcResponseCallback() {
override def onSuccess(response: ByteBuffer): Unit = {
Try(Await.result(writePromise.future, Duration.Inf)) match {
- case Success(_) =>
- if (response.remaining() > 0) {
- val resp = ByteBuffer.allocate(response.remaining())
- resp.put(response)
- resp.flip()
- callbackWithTimer.onSuccess(resp)
- } else if (softSplit.get()) {
- // TODO Currently if the worker is in soft split status,
given the guess that the client
- // will fast stop pushing data to the worker, we won't
return congest status. But
- // in the long term, especially if this issue could
frequently happen, we may need to return
- // congest&softSplit status together
- callbackWithTimer.onSuccess(
-
ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
+ case Success(result) =>
+ if (result(0) != StatusCode.SUCCESS) {
+
callback.onSuccess(ByteBuffer.wrap(Array[Byte](result(0).getValue)))
} else {
- Option(CongestionController.instance()) match {
- case Some(congestionController) =>
- if (congestionController.isUserCongested(
- fileWriter.getUserCongestionControlContext)) {
- // Check whether primary congest the data though the
replicas doesn't congest
- // it(the response is empty)
- callbackWithTimer.onSuccess(
- ByteBuffer.wrap(
-
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
- } else {
+ if (response.remaining() > 0) {
+ val resp = ByteBuffer.allocate(response.remaining())
+ resp.put(response)
+ resp.flip()
+ callbackWithTimer.onSuccess(resp)
+ } else if (softSplit) {
+ // TODO Currently if the worker is in soft split status,
given the guess that the client
+ // will fast stop pushing data to the worker, we won't
return congest status. But
+ // in the long term, especially if this issue could
frequently happen, we may need to return
+ // congest&softSplit status together
+ callbackWithTimer.onSuccess(
+
ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
+ } else {
+ Option(CongestionController.instance()) match {
+ case Some(congestionController) =>
+ if (congestionController.isUserCongested(
+ fileWriter.getUserCongestionControlContext)) {
+ // Check whether primary congest the data though
the replicas doesn't congest
+ // it(the response is empty)
+ callbackWithTimer.onSuccess(
+ ByteBuffer.wrap(
+ Array[Byte](
+
StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+ } else {
+
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+ }
+ case None =>
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
- }
- case None =>
-
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+ }
}
}
case Failure(e) => callbackWithTimer.onFailure(e)
@@ -376,29 +390,33 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
// congest&softSplit status together
writeLocalData(Seq(fileWriter), body, shuffleKey, isPrimary, None,
writePromise)
Try(Await.result(writePromise.future, Duration.Inf)) match {
- case Success(_) =>
- if (softSplit.get()) {
- callbackWithTimer.onSuccess(
- ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
+ case Success(result) =>
+ if (result(0) != StatusCode.SUCCESS) {
+
callback.onSuccess(ByteBuffer.wrap(Array[Byte](result(0).getValue)))
} else {
- Option(CongestionController.instance()) match {
- case Some(congestionController) =>
- if (congestionController.isUserCongested(
- fileWriter.getUserCongestionControlContext)) {
- if (isPrimary) {
- callbackWithTimer.onSuccess(
- ByteBuffer.wrap(
-
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+ if (softSplit) {
+ callbackWithTimer.onSuccess(
+ ByteBuffer.wrap(Array[Byte](StatusCode.SOFT_SPLIT.getValue)))
+ } else {
+ Option(CongestionController.instance()) match {
+ case Some(congestionController) =>
+ if (congestionController.isUserCongested(
+ fileWriter.getUserCongestionControlContext)) {
+ if (isPrimary) {
+ callbackWithTimer.onSuccess(
+ ByteBuffer.wrap(
+
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+ } else {
+ callbackWithTimer.onSuccess(
+ ByteBuffer.wrap(
+
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue)))
+ }
} else {
- callbackWithTimer.onSuccess(
- ByteBuffer.wrap(
-
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue)))
+ callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
}
- } else {
+ case None =>
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
- }
- case None =>
- callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+ }
}
}
case Failure(e) => callbackWithTimer.onFailure(e)
@@ -414,6 +432,7 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
val batchOffsets = pushMergedData.batchOffsets
val body = pushMergedData.body.asInstanceOf[NettyManagedBuffer].getBuf
val isPrimary = mode == PartitionLocation.Mode.PRIMARY
+ val (mapId, attemptId) = getMapAttempt(body)
val key = s"${pushMergedData.requestId}"
val callbackWithTimer =
@@ -430,6 +449,7 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
key,
callback)
}
+ val pushMergedDataCallback = new PushMergedDataCallback(callbackWithTimer)
// For test
if (isPrimary && testPushPrimaryDataTimeout &&
@@ -458,7 +478,6 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
while (index < partitionIdToLocations.length) {
val (id, loc) = partitionIdToLocations(index)
if (loc == null) {
- val (mapId, attemptId) = getMapAttempt(body)
// MapperAttempts for a shuffle exists after any CommitFiles request
succeeds.
// A shuffle can trigger multiple CommitFiles requests, for reasons
like: HARD_SPLIT happens, StageEnd.
// If MapperAttempts but the value is -1 for the mapId(-1 means the
map has not yet finished),
@@ -468,14 +487,13 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
logDebug(s"Receive push merged data from speculative " +
s"task(shuffle $shuffleKey, map $mapId, attempt $attemptId), " +
s"but this mapper has already been ended.")
- callbackWithTimer.onSuccess(
- ByteBuffer.wrap(Array[Byte](StatusCode.MAP_ENDED.getValue)))
+ pushMergedDataCallback.onSuccess(StatusCode.MAP_ENDED)
+ return
} else {
logDebug(s"[Case1] Receive push merged data for committed hard
split partition of " +
s"(shuffle $shuffleKey, map $mapId attempt $attemptId)")
workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
- callbackWithTimer.onSuccess(
- ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+ pushMergedDataCallback.addSplitPartition(index,
StatusCode.HARD_SPLIT)
}
} else {
if (storageManager.shuffleKeySet().contains(shuffleKey)) {
@@ -485,16 +503,15 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
logDebug(s"[Case2] Receive push merged data for committed hard
split partition of " +
s"(shuffle $shuffleKey, map $mapId attempt $attemptId)")
workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
- callbackWithTimer.onSuccess(
- ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+ pushMergedDataCallback.addSplitPartition(index,
StatusCode.HARD_SPLIT)
} else {
logWarning(s"While handling PushMergedData, Partition location
wasn't found for " +
s"task(shuffle $shuffleKey, map $mapId, attempt $attemptId,
uniqueId $id).")
- callbackWithTimer.onFailure(
+ pushMergedDataCallback.onFailure(
new
CelebornIOException(StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND))
+ return
}
}
- return
}
index += 1
}
@@ -502,7 +519,9 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
// During worker shutdown, worker will return HARD_SPLIT for all existed
partition.
// This should before return exception to make current push data can
revive and retry.
if (shutdown.get()) {
-
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+ partitionIdToLocations.indices.foreach(index =>
+ pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT))
+ pushMergedDataCallback.onSuccess(StatusCode.HARD_SPLIT)
return
}
@@ -519,31 +538,48 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
s"While handling PushMergedData, throw $cause, fileWriter
$fileWriterWithException has exception.",
fileWriterWithException.getException)
workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT)
- callbackWithTimer.onFailure(new CelebornIOException(cause))
+ pushMergedDataCallback.onFailure(new CelebornIOException(cause))
return
}
- if (fileWriters.exists(checkDiskFull(_) == true)) {
- val (mapId, attemptId) = getMapAttempt(body)
- logWarning(
- s"return hard split for disk full with shuffle $shuffleKey map $mapId
attempt $attemptId")
-
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
- return
+ var fileWriterIndex = 0
+ val totalFileWriters = fileWriters.length
+ while (fileWriterIndex < totalFileWriters) {
+ val fileWriter = fileWriters(fileWriterIndex)
+ if (fileWriter == null) {
+ if (!pushMergedDataCallback.isHardSplitPartition(fileWriterIndex)) {
+ pushMergedDataCallback.onFailure(
+ new CelebornIOException(s"Partition $fileWriterIndex's fileWriter
not found," +
+ s" but it hasn't been identified in the previous validation
step."))
+ return
+ }
+ } else {
+ if (fileWriter.isClosed) {
+ val fileInfo = fileWriter.getCurrentFileInfo
+ logWarning(
+ s"[handlePushMergedData] FileWriter is already closed! File path
${fileInfo.getFilePath} " +
+ s"length ${fileInfo.getFileLength}")
+ pushMergedDataCallback.addSplitPartition(fileWriterIndex,
StatusCode.HARD_SPLIT)
+ } else {
+ val splitStatus = checkDiskFullAndSplit(fileWriter, isPrimary)
+ if (splitStatus == StatusCode.HARD_SPLIT) {
+ logWarning(
+ s"return hard split for disk full with shuffle $shuffleKey map
$mapId attempt $attemptId")
+ workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
+ pushMergedDataCallback.addSplitPartition(fileWriterIndex,
StatusCode.HARD_SPLIT)
+ } else if (splitStatus == StatusCode.SOFT_SPLIT) {
+ pushMergedDataCallback.addSplitPartition(fileWriterIndex,
StatusCode.SOFT_SPLIT)
+ }
+ }
+ if (!pushMergedDataCallback.isHardSplitPartition(fileWriterIndex)) {
+ fileWriter.incrementPendingWrites()
+ }
+ }
+ fileWriterIndex += 1
}
- fileWriters.foreach(_.incrementPendingWrites())
-
- val closedFileWriter = fileWriters.find(_.isClosed)
- if (closedFileWriter.isDefined) {
- val fileInfo = closedFileWriter.get.getCurrentFileInfo
- logWarning(
- s"[handlePushMergedData] FileWriter is already closed! File path
${fileInfo.getFilePath} " +
- s"length ${fileInfo.getFileLength}")
- callbackWithTimer.onFailure(new CelebornIOException("File already
closed!"))
- fileWriters.foreach(_.decrementPendingWrites())
- return
- }
- val writePromise = Promise[Unit]()
+ val hardSplitIndexes = pushMergedDataCallback.getHardSplitIndexes
+ val writePromise = Promise[Array[StatusCode]]()
// for primary, send data to replica
if (doReplicate) {
pushMergedData.body().retain()
@@ -562,7 +598,7 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
workerSource.incCounter(WorkerSource.REPLICATE_DATA_CREATE_CONNECTION_FAIL_COUNT)
logError(
s"PushMergedData replication failed caused by unavailable peer
for partitionLocation: $location")
- callbackWithTimer.onFailure(
+ pushMergedDataCallback.onFailure(
new
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA))
return
}
@@ -570,30 +606,71 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
// Handle the response from replica
val wrappedCallback = new RpcResponseCallback() {
override def onSuccess(response: ByteBuffer): Unit = {
+ val replicaReason = response.get()
+ if (replicaReason == StatusCode.HARD_SPLIT.getValue) {
+ if (response.remaining() > 0) {
+ try {
+ val pushMergedDataResponse:
PbPushMergedDataSplitPartitionInfo =
+ TransportMessage.fromByteBuffer(
+
response).getParsedPayload[PbPushMergedDataSplitPartitionInfo]()
+ pushMergedDataCallback.unionReplicaSplitPartitions(
+ pushMergedDataResponse.getSplitPartitionIndexesList,
+ pushMergedDataResponse.getStatusCodesList)
+ } catch {
+ case e: CelebornIOException =>
+ pushMergedDataCallback.onFailure(e)
+ return
+ case e: IllegalArgumentException =>
+ pushMergedDataCallback.onFailure(new
CelebornIOException(e))
+ return
+ }
+ } else {
+ // During the rolling upgrade of the worker cluster, it is
possible for the primary worker
+ // to be upgraded to a new version that includes the changes
from [CELEBORN-1721], while
+ // the replica worker is still running on an older version
that does not have these changes.
+ // In this scenario, the replica may return a response with
a status of HARD_SPLIT, but
+ // will not provide a PbPushMergedDataSplitPartitionInfo.
+ logWarning(
+ s"The response status from the replica (shuffle
$shuffleKey map $mapId attempt $attemptId) is HARD_SPLIT, but no
PbPushMergedDataSplitPartitionInfo is present.")
+ partitionIdToLocations.indices.foreach(index =>
+ pushMergedDataCallback.addSplitPartition(index,
StatusCode.HARD_SPLIT))
+ pushMergedDataCallback.onSuccess(StatusCode.HARD_SPLIT)
+ return
+ }
+ }
Try(Await.result(writePromise.future, Duration.Inf)) match {
- case Success(_) =>
+ case Success(result) =>
+ var index = 0
+ while (index < result.length) {
+ if (result(index) == StatusCode.HARD_SPLIT) {
+ pushMergedDataCallback.addSplitPartition(index,
result(index))
+ }
+ index += 1
+ }
// Only primary data enable replication will push data to
replica
- if (response.remaining() > 0) {
- val resp = ByteBuffer.allocate(response.remaining())
- resp.put(response)
- resp.flip()
- callbackWithTimer.onSuccess(resp)
- } else {
- Option(CongestionController.instance()) match {
- case Some(congestionController) if fileWriters.nonEmpty
=>
- if (congestionController.isUserCongested(
- fileWriters.head.getUserCongestionControlContext))
{
- // Check whether primary congest the data though the
replicas doesn't congest
- // it(the response is empty)
- callbackWithTimer.onSuccess(
- ByteBuffer.wrap(
-
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+ Option(CongestionController.instance()) match {
+ case Some(congestionController) if fileWriters.nonEmpty =>
+ if (congestionController.isUserCongested(
+ fileWriters.head.getUserCongestionControlContext)) {
+ // Check whether primary congest the data though the
replicas doesn't congest
+ // it(the response is empty)
+ pushMergedDataCallback.onSuccess(
+ StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED)
+ } else {
+ if (replicaReason ==
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue) {
+ pushMergedDataCallback.onSuccess(
+ StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED)
} else {
-
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+ pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
}
- case None =>
-
callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
- }
+ }
+ case None =>
+ if (replicaReason ==
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue) {
+ pushMergedDataCallback.onSuccess(
+ StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED)
+ } else {
+ pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
+ }
}
case Failure(e) => callbackWithTimer.onFailure(e)
}
@@ -606,17 +683,17 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
// 3. Throw IOException by channel, convert to
PUSH_DATA_CONNECTION_EXCEPTION_REPLICA
if
(e.getMessage.startsWith(StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA.name())) {
workerSource.incCounter(WorkerSource.REPLICATE_DATA_WRITE_FAIL_COUNT)
- callbackWithTimer.onFailure(e)
+ pushMergedDataCallback.onFailure(e)
} else if
(e.getMessage.startsWith(StatusCode.PUSH_DATA_TIMEOUT_REPLICA.name())) {
workerSource.incCounter(WorkerSource.REPLICATE_DATA_TIMEOUT_COUNT)
- callbackWithTimer.onFailure(e)
+ pushMergedDataCallback.onFailure(e)
} else if (ExceptionUtils.connectFail(e.getMessage)) {
workerSource.incCounter(WorkerSource.REPLICATE_DATA_CONNECTION_EXCEPTION_COUNT)
- callbackWithTimer.onFailure(
+ pushMergedDataCallback.onFailure(
new
CelebornIOException(StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_REPLICA))
} else {
workerSource.incCounter(WorkerSource.REPLICATE_DATA_FAIL_NON_CRITICAL_CAUSE_COUNT)
- callbackWithTimer.onFailure(
+ pushMergedDataCallback.onFailure(
new
CelebornIOException(StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA))
}
}
@@ -642,39 +719,56 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
logError(
s"PushMergedData replication failed during connecting peer for
partitionLocation: $location",
e)
- callbackWithTimer.onFailure(
+ pushMergedDataCallback.onFailure(
new
CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA))
}
}
})
- writeLocalData(fileWriters, body, shuffleKey, isPrimary,
Some(batchOffsets), writePromise)
+ writeLocalData(
+ fileWriters,
+ body,
+ shuffleKey,
+ isPrimary,
+ Some(batchOffsets),
+ writePromise,
+ hardSplitIndexes)
} else {
// This code could be executed if
// 1. the client doesn't enable pushing data to the replica; the primary
worker could hit here
// 2. the client enables pushing data to the replica; the replica
worker could hit here
- writeLocalData(fileWriters, body, shuffleKey, isPrimary,
Some(batchOffsets), writePromise)
+ writeLocalData(
+ fileWriters,
+ body,
+ shuffleKey,
+ isPrimary,
+ Some(batchOffsets),
+ writePromise,
+ hardSplitIndexes)
Try(Await.result(writePromise.future, Duration.Inf)) match {
- case Success(_) =>
+ case Success(result) =>
+ var index = 0
+ while (index < result.length) {
+ if (result(index) == StatusCode.HARD_SPLIT) {
+ pushMergedDataCallback.addSplitPartition(index, result(index))
+ }
+ index += 1
+ }
Option(CongestionController.instance()) match {
case Some(congestionController) if fileWriters.nonEmpty =>
if (congestionController.isUserCongested(
fileWriters.head.getUserCongestionControlContext)) {
if (isPrimary) {
- callbackWithTimer.onSuccess(
- ByteBuffer.wrap(
-
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue)))
+
pushMergedDataCallback.onSuccess(StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED)
} else {
- callbackWithTimer.onSuccess(
- ByteBuffer.wrap(
-
Array[Byte](StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue)))
+
pushMergedDataCallback.onSuccess(StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED)
}
} else {
- callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+ pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
}
case None =>
- callbackWithTimer.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+ pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
}
- case Failure(e) => callbackWithTimer.onFailure(e)
+ case Failure(e) => pushMergedDataCallback.onFailure(e)
}
}
}
@@ -691,11 +785,15 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
var exceptionFileWriterIndex: Option[Int] = None
while (i < partitionIdToLocations.length) {
val (_, workingPartition) = partitionIdToLocations(i)
- val fileWriter =
workingPartition.asInstanceOf[WorkingPartition].getFileWriter
- if (fileWriter.getException != null) {
- exceptionFileWriterIndex = Some(i)
+ if (workingPartition != null) {
+ val fileWriter =
workingPartition.asInstanceOf[WorkingPartition].getFileWriter
+ if (fileWriter.getException != null) {
+ exceptionFileWriterIndex = Some(i)
+ }
+ fileWriters(i) = fileWriter
+ } else {
+ fileWriters(i) = null
}
- fileWriters(i) = fileWriter
i += 1
}
(fileWriters, exceptionFileWriterIndex)
@@ -747,6 +845,76 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
}
}
+ class PushMergedDataCallback(callback: RpcResponseCallback) {
+ private val splitPartitionStatuses = new mutable.HashMap[Int, Byte]()
+
+ def addSplitPartition(index: Int, statusCode: StatusCode): Unit = {
+ splitPartitionStatuses.put(index, statusCode.getValue)
+ }
+
+ def isHardSplitPartition(index: Int): Boolean = {
+ splitPartitionStatuses.getOrElse(index, -1) ==
StatusCode.HARD_SPLIT.getValue
+ }
+
+ def unionReplicaSplitPartitions(
+ replicaPartitionIndexes: util.List[Integer],
+ replicaStatusCodes: util.List[Integer]): Unit = {
+ if (replicaPartitionIndexes.size() != replicaStatusCodes.size()) {
+ throw new IllegalArgumentException(
+ "replicaPartitionIndexes and replicaStatusCodes must have the same
size")
+ }
+ for (i <- 0 until replicaPartitionIndexes.size()) {
+ val index = replicaPartitionIndexes.get(i)
+ // The priority of HARD_SPLIT is higher than that of SOFT_SPLIT.
+ if (!isHardSplitPartition(index)) {
+ splitPartitionStatuses.put(index,
replicaStatusCodes.get(i).byteValue())
+ }
+ }
+ }
+
+ /**
+ * Returns the ordered indexes of partitions that are not writable.
+ * A partition is considered not writable only when its recorded status is
HARD_SPLIT; other recorded statuses (e.g. SOFT_SPLIT) are excluded.
+ */
+ def getHardSplitIndexes: Array[Int] = {
+ splitPartitionStatuses.collect {
+ case (partitionIndex, statusCode) if statusCode ==
StatusCode.HARD_SPLIT.getValue =>
+ partitionIndex
+ }.toSeq.sorted.toArray
+ }
+
+ def onSuccess(status: StatusCode): Unit = {
+ val splitPartitionIndexes = new util.ArrayList[Integer]()
+ val statusCodes = new util.ArrayList[Integer]()
+ splitPartitionStatuses.foreach {
+ case (partitionIndex, statusCode) =>
+ splitPartitionIndexes.add(partitionIndex)
+ statusCodes.add(statusCode)
+ }
+ if (splitPartitionStatuses.isEmpty || status == StatusCode.MAP_ENDED) {
+ callback.onSuccess(
+ ByteBuffer.wrap(Array[Byte](status.getValue)))
+ } else {
+ val pushMergedDataInfo =
PbPushMergedDataSplitPartitionInfo.newBuilder()
+ .addAllSplitPartitionIndexes(splitPartitionIndexes)
+ .addAllStatusCodes(statusCodes)
+ .build()
+ val pushMergedDataInfoByteBuffer =
Utils.toTransportMessage(pushMergedDataInfo)
+ .asInstanceOf[TransportMessage]
+ .toByteBuffer
+ val response = ByteBuffer.allocate(1 +
pushMergedDataInfoByteBuffer.remaining())
+ response.put(StatusCode.HARD_SPLIT.getValue)
+ response.put(pushMergedDataInfoByteBuffer)
+ response.flip()
+ callback.onSuccess(response)
+ }
+ }
+
+ def onFailure(exception: Throwable): Unit = {
+ callback.onFailure(exception)
+ }
+ }
+
private def handleCore(
client: TransportClient,
message: RequestMessage,
@@ -821,18 +989,28 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
fileWriter.decrementPendingWrites()
return
}
- val writePromise = Promise[Unit]()
+ val writePromise = Promise[Array[StatusCode]]()
writeLocalData(Seq(fileWriter), body, shuffleKey, isPrimary, None,
writePromise)
// for primary, send data to replica
if (location.hasPeer && isPrimary) {
// TODO: replicate the data to the peer here (presumably — confirm intended follow-up)
Try(Await.result(writePromise.future, Duration.Inf)) match {
- case Success(_) =>
wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+ case Success(result) =>
+ if (result(0) != StatusCode.SUCCESS) {
+ wrappedCallback.onFailure(new CelebornIOException("Write data
failed!"))
+ } else {
+ wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+ }
case Failure(e) => wrappedCallback.onFailure(e)
}
} else {
Try(Await.result(writePromise.future, Duration.Inf)) match {
- case Success(_) =>
wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+ case Success(result) =>
+ if (result(0) != StatusCode.SUCCESS) {
+ wrappedCallback.onFailure(new CelebornIOException("Write data
failed!"))
+ } else {
+ wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+ }
case Failure(e) => wrappedCallback.onFailure(e)
}
}
@@ -1030,9 +1208,11 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
if (checkSplit && (messageType == Type.REGION_START || messageType ==
Type.PUSH_DATA_HAND_SHAKE) && isPartitionSplitEnabled &&
checkDiskFullAndSplit(
fileWriter,
- isPrimary,
- null,
- callback)) return
+ isPrimary) == StatusCode.HARD_SPLIT) {
+ workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
+
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+ return
+ }
try {
messageType match {
@@ -1226,9 +1406,12 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
private def checkDiskFullAndSplit(
fileWriter: PartitionDataWriter,
- isPrimary: Boolean,
- softSplit: AtomicBoolean,
- callback: RpcResponseCallback): Boolean = {
+ isPrimary: Boolean): StatusCode = {
+ if (fileWriter.needHardSplitForMemoryShuffleStorage()) {
+ logInfo(
+ s"Do hardSplit for memory shuffle file
fileLength:${fileWriter.getMemoryFileInfo.getFileLength}")
+ return StatusCode.HARD_SPLIT
+ }
val diskFull = checkDiskFull(fileWriter)
logTrace(
s"""
@@ -1239,24 +1422,14 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
|fileLength:${fileWriter.getCurrentFileInfo.getFileLength}
|fileName:${fileWriter.getCurrentFileInfo.getFilePath}
|""".stripMargin)
- if (fileWriter.needHardSplitForMemoryShuffleStorage()) {
- workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
-
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
- logInfo(
- s"Do hardSplit for memory shuffle file
fileLength:${fileWriter.getMemoryFileInfo.getFileLength}")
- return true
- }
-
val diskFileInfo = fileWriter.getDiskFileInfo
if (diskFileInfo != null) {
if (workerPartitionSplitEnabled && ((diskFull &&
diskFileInfo.getFileLength > partitionSplitMinimumSize) ||
(isPrimary && diskFileInfo.getFileLength >
fileWriter.getSplitThreshold))) {
- if (softSplit != null && fileWriter.getSplitMode ==
PartitionSplitMode.SOFT &&
+ if (fileWriter.getSplitMode == PartitionSplitMode.SOFT &&
(fileWriter.getDiskFileInfo.getFileLength <
partitionSplitMaximumSize)) {
- softSplit.set(true)
+ return StatusCode.SOFT_SPLIT
} else {
- workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
-
callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
logInfo(
s"""
|CheckDiskFullAndSplit hardSplit
@@ -1266,11 +1439,11 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
|fileLength:${diskFileInfo.getFileLength},
|fileName:${diskFileInfo.getFilePath}
|""".stripMargin)
- return true
+ return StatusCode.HARD_SPLIT
}
}
}
- false
+ StatusCode.NO_SPLIT
}
private def getReplicateClient(host: String, port: Int, partitionId: Int):
TransportClient = {
@@ -1287,10 +1460,18 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
shuffleKey: String,
isPrimary: Boolean,
batchOffsets: Option[Array[Int]],
- writePromise: Promise[Unit]): Unit = {
- def writeData(fileWriter: PartitionDataWriter, body: ByteBuf, shuffleKey:
String): Unit = {
+ writePromise: Promise[Array[StatusCode]],
+ hardSplitIndexes: Array[Int] = Array.empty[Int]): Unit = {
+ val length = fileWriters.length
+ val result = new Array[StatusCode](length)
+ def writeData(
+ fileWriter: PartitionDataWriter,
+ body: ByteBuf,
+ shuffleKey: String,
+ index: Int): Unit = {
try {
fileWriter.write(body)
+ result(index) = StatusCode.SUCCESS
} catch {
case e: Exception =>
if (e.isInstanceOf[AlreadyClosedException]) {
@@ -1302,47 +1483,63 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
// TODO just info log for ended attempt
logWarning(s"Append data failed for task(shuffle $shuffleKey, map
$mapId, attempt" +
s" $attemptId), caused by AlreadyClosedException, endedAttempt
$endedAttempt, error message: ${e.getMessage}")
+ workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
+ result(index) = StatusCode.HARD_SPLIT
} else {
logError("Exception encountered when write.", e)
+ workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT)
+ val cause =
+ if (isPrimary) {
+ StatusCode.PUSH_DATA_WRITE_FAIL_PRIMARY
+ } else {
+ StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA
+ }
+ writePromise.failure(new CelebornIOException(cause))
}
- val cause =
- if (isPrimary) {
- StatusCode.PUSH_DATA_WRITE_FAIL_PRIMARY
- } else {
- StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA
- }
- workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT)
- writePromise.failure(new CelebornIOException(cause))
fileWriter.decrementPendingWrites()
}
}
batchOffsets match {
case Some(batchOffsets) =>
var index = 0
+ val hardSplitIterator = hardSplitIndexes.iterator
+ var currentHardSplitIndex = nextValueOrElse(hardSplitIterator, -1)
var fileWriter: PartitionDataWriter = null
while (index < fileWriters.length) {
- fileWriter = fileWriters(index)
- if (!writePromise.isCompleted) {
- val offset = body.readerIndex() + batchOffsets(index)
- val length =
- if (index == fileWriters.length - 1) {
- body.readableBytes() - batchOffsets(index)
- } else {
- batchOffsets(index + 1) - batchOffsets(index)
- }
- val batchBody = body.slice(offset, length)
- writeData(fileWriter, batchBody, shuffleKey)
+ if (index == currentHardSplitIndex) {
+ currentHardSplitIndex = nextValueOrElse(hardSplitIterator, -1)
} else {
- fileWriter.decrementPendingWrites()
+ fileWriter = fileWriters(index)
+ if (!writePromise.isCompleted) {
+ val offset = body.readerIndex() + batchOffsets(index)
+ val length =
+ if (index == fileWriters.length - 1) {
+ body.readableBytes() - batchOffsets(index)
+ } else {
+ batchOffsets(index + 1) - batchOffsets(index)
+ }
+ val batchBody = body.slice(offset, length)
+ writeData(fileWriter, batchBody, shuffleKey, index)
+ } else {
+ fileWriter.decrementPendingWrites()
+ }
}
index += 1
}
case _ =>
- writeData(fileWriters.head, body, shuffleKey)
+ writeData(fileWriters.head, body, shuffleKey, 0)
}
if (!writePromise.isCompleted) {
workerSource.incCounter(WorkerSource.WRITE_DATA_SUCCESS_COUNT)
- writePromise.success()
+ writePromise.success(result)
+ }
+ }
+
+ private def nextValueOrElse(iterator: Iterator[Int], defaultValue: Int): Int
= {
+ if (iterator.hasNext) {
+ iterator.next()
+ } else {
+ defaultValue
}
}
diff --git
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala
new file mode 100644
index 000000000..cdd0e758e
--- /dev/null
+++
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.cluster
+
+import java.nio.charset.StandardCharsets
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.commons.lang3.RandomStringUtils
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.identity.UserIdentifier
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.meta.WorkerInfo
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+class PushMergedDataSplitSuite extends AnyFunSuite
+ with Logging with MiniClusterFeature with BeforeAndAfterAll {
+
+ var masterEndpoint = ""
+ override def beforeAll(): Unit = {
+ val conf = Map("celeborn.worker.flusher.buffer.size" -> "0")
+
+ logInfo("test initialized , setup Celeborn mini cluster")
+ val (master, _) = setupMiniClusterWithRandomPorts(conf, conf, 2)
+ masterEndpoint = master.conf.get(CelebornConf.MASTER_ENDPOINTS.key)
+ }
+
+ override def afterAll(): Unit = {
+ logInfo("all test complete , stop Celeborn mini cluster")
+ super.shutdownMiniCluster()
+ }
+
+ test("push merged data and partial partition are split") {
+ val SHUFFLE_ID = 0
+ val MAP_ID = 0
+ val ATTEMPT_ID = 0
+ val MAP_NUM = 1
+ val PARTITION_NUM = 3
+
+ Array("SOFT", "HARD").foreach {
+ splitMode =>
+ val APP = s"app-${System.currentTimeMillis()}"
+ val clientConf = new CelebornConf()
+ .set(CelebornConf.MASTER_ENDPOINTS.key, masterEndpoint)
+ .set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "1")
+ .set(CelebornConf.SHUFFLE_PARTITION_SPLIT_THRESHOLD.key, "5K")
+ .set(CelebornConf.SHUFFLE_PARTITION_SPLIT_MODE.key, splitMode)
+ val lifecycleManager = new LifecycleManager(APP, clientConf)
+ val shuffleClient = new ShuffleClientImpl(APP, clientConf,
UserIdentifier("mock", "mock"))
+ shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+ // ping and reserveSlots
+ val DATA0 =
RandomStringUtils.secure().next(10).getBytes(StandardCharsets.UTF_8)
+ shuffleClient.pushData(
+ SHUFFLE_ID,
+ MAP_ID,
+ ATTEMPT_ID,
+ 0,
+ DATA0,
+ 0,
+ DATA0.length,
+ MAP_NUM,
+ PARTITION_NUM)
+
+ // find the worker that has at least 2 partitions
+ val partitionLocationMap =
+ shuffleClient.getPartitionLocation(SHUFFLE_ID, MAP_NUM,
PARTITION_NUM)
+ val worker2PartitionIds = mutable.Map.empty[WorkerInfo,
ArrayBuffer[Int]]
+ for (partitionId <- 0 until PARTITION_NUM) {
+ val partitionLocation = partitionLocationMap.get(partitionId)
+ worker2PartitionIds
+ .getOrElseUpdate(partitionLocation.getWorker, ArrayBuffer.empty)
+ .append(partitionId)
+ }
+ val partitions = worker2PartitionIds.values.filter(_.size >= 2).head
+ assert(partitions.length >= 2)
+
+ // prepare merged data
+ val PARTITION0_DATA =
RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8)
+ shuffleClient.mergeData(
+ SHUFFLE_ID,
+ MAP_ID,
+ ATTEMPT_ID,
+ partitions(0),
+ PARTITION0_DATA,
+ 0,
+ PARTITION0_DATA.length,
+ MAP_NUM,
+ PARTITION_NUM)
+
+ val PARTITION1_DATA =
RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8)
+ shuffleClient.mergeData(
+ SHUFFLE_ID,
+ MAP_ID,
+ ATTEMPT_ID,
+ partitions(1),
+ PARTITION1_DATA,
+ 0,
+ PARTITION1_DATA.length,
+ MAP_NUM,
+ PARTITION_NUM)
+
+ // pushData until partition(0) is split
+ val GIANT_DATA =
+ RandomStringUtils.secure().next(1024 *
100).getBytes(StandardCharsets.UTF_8)
+ shuffleClient.pushData(
+ SHUFFLE_ID,
+ MAP_ID,
+ ATTEMPT_ID,
+ partitions(0),
+ GIANT_DATA,
+ 0,
+ GIANT_DATA.length,
+ MAP_NUM,
+ PARTITION_NUM)
+ for (_ <- 0 until 5) {
+ val TRIGGER_DATA =
RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8)
+ shuffleClient.pushData(
+ SHUFFLE_ID,
+ MAP_ID,
+ ATTEMPT_ID,
+ partitions(0),
+ TRIGGER_DATA,
+ 0,
+ TRIGGER_DATA.length,
+ MAP_NUM,
+ PARTITION_NUM)
+ Thread.sleep(5 * 1000) // wait for flush
+ }
+ assert(
+ partitionLocationMap.get(partitions(0)).getEpoch > 0
+ ) // means partition(0) will be split
+
+ // push merged data, we expect that partition(0) will be split, while
partition(1) will not be split
+ shuffleClient.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID)
+ shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM)
+ assert(
+ partitionLocationMap.get(partitions(1)).getEpoch == 0
+ ) // means partition(1) will not be split
+ }
+ }
+}