This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new d30c02e36 [CELEBORN-2235][CIP-14] Adapt Java end's serialization to CppWriterClient
d30c02e36 is described below
commit d30c02e3690a5ecfe3cbb32f2d21ba101d37c679
Author: HolyLow <[email protected]>
AuthorDate: Mon Jan 5 22:24:22 2026 +0800
[CELEBORN-2235][CIP-14] Adapt Java end's serialization to CppWriterClient
### What changes were proposed in this pull request?
This PR adapts the Java side's serialization to CppWriterClient, covering
RegisterShuffle/Response, Revive/Response, and MapperEnd/Response. A joint test
for the cpp-write java-read procedure is included as well.
### Why are the changes needed?
Support writing to the Celeborn server with CppWriterClient.
### Does this PR resolve a correctness bug?
No.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation and integration tests.
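For illustration, a minimal sketch of the serde round-trip this PR enables, modeled directly on the UtilsSuite case further down in this diff (the `MapperEnd` signature and `SerdeVersion.V1` tag are the ones introduced below):

```scala
import java.util.Collections

import org.apache.celeborn.common.network.protocol.SerdeVersion
import org.apache.celeborn.common.protocol.message.ControlMessages.MapperEnd
import org.apache.celeborn.common.util.Utils

// Build a control message tagged with the serde version the peer expects.
val mapperEnd = MapperEnd(
  1, 1, 1, 2, 1, Collections.emptyMap(), 1,
  Array.emptyIntArray, Array.emptyLongArray, SerdeVersion.V1)

// toTransportMessage serializes the case class to protobuf and stamps the
// SerdeVersion; fromTransportMessage restores the case class, version included.
val mapperEndTrans = Utils
  .fromTransportMessage(Utils.toTransportMessage(mapperEnd))
  .asInstanceOf[MapperEnd]
assert(mapperEnd.shuffleId == mapperEndTrans.shuffleId)
```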
Closes #3561 from HolyLow/issue/celeborn-2235-adapt-java-to-cpp-writer-serialization.
Authored-by: HolyLow <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.github/workflows/cpp_integration.yml | 13 +-
.../flink/client/FlinkShuffleClientImpl.java | 5 +-
.../apache/celeborn/client/ShuffleClientImpl.java | 56 +++---
.../apache/celeborn/client/LifecycleManager.scala | 114 ++++++-----
.../client/RequestLocationCallContext.scala | 17 +-
.../celeborn/client/ShuffleClientBaseSuiteJ.java | 12 +-
.../celeborn/client/ShuffleClientSuiteJ.java | 4 +-
.../common/protocol/message/ControlMessages.scala | 217 +++++++++++++--------
.../apache/celeborn/common/util/UtilsSuite.scala | 13 +-
cpp/celeborn/tests/CMakeLists.txt | 27 ++-
cpp/celeborn/tests/DataSumWithWriterClient.cpp | 96 +++++++++
...Z4.scala => CppWriteJavaReadTestWithNONE.scala} | 4 +-
....scala => JavaCppHybridReadWriteTestBase.scala} | 98 +++++++++-
.../cluster/JavaWriteCppReadTestWithLZ4.scala | 2 +-
.../cluster/JavaWriteCppReadTestWithNONE.scala | 2 +-
.../cluster/JavaWriteCppReadTestWithZSTD.scala | 2 +-
16 files changed, 492 insertions(+), 190 deletions(-)
diff --git a/.github/workflows/cpp_integration.yml b/.github/workflows/cpp_integration.yml
index 1c521f5d0..4f04c4bad 100644
--- a/.github/workflows/cpp_integration.yml
+++ b/.github/workflows/cpp_integration.yml
@@ -85,24 +85,31 @@ jobs:
check-latest: false
- name: Compile & Install Celeborn Java
run: build/mvn clean install -DskipTests
- - name: Run Java-Cpp Hybrid Integration Test
+ - name: Run Java-Write Cpp-Read Hybrid Integration Test (NONE Decompression)
run: |
build/mvn -pl worker \
test-compile exec:java \
-Dexec.classpathScope="test" \
-Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithNONE" \
-Dexec.args="-XX:MaxDirectMemorySize=2G"
- - name: Run Java-Cpp Hybrid Integration Test (LZ4 Decompression)
+ - name: Run Java-Write Cpp-Read Hybrid Integration Test (LZ4 Decompression)
run: |
build/mvn -pl worker \
test-compile exec:java \
-Dexec.classpathScope="test" \
-Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithLZ4" \
-Dexec.args="-XX:MaxDirectMemorySize=2G"
- - name: Run Java-Cpp Hybrid Integration Test (ZSTD Decompression)
+ - name: Run Java-Write Cpp-Read Hybrid Integration Test (ZSTD Decompression)
run: |
build/mvn -pl worker \
test-compile exec:java \
-Dexec.classpathScope="test" \
-Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithZSTD" \
-Dexec.args="-XX:MaxDirectMemorySize=2G"
+ - name: Run Cpp-Write Java-Read Hybrid Integration Test (NONE Compression)
+ run: |
+ build/mvn -pl worker \
+ test-compile exec:java \
+ -Dexec.classpathScope="test" \
+ -Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.CppWriteJavaReadTestWithNONE" \
+ -Dexec.args="-XX:MaxDirectMemorySize=2G"
\ No newline at end of file
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
index 87a80006a..6f52b4aab 100644
--- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java
@@ -46,6 +46,7 @@ import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.PushData;
+import org.apache.celeborn.common.network.protocol.SerdeVersion;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.MessageType;
@@ -528,7 +529,7 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {
public Optional<PartitionLocation> revive(
int shuffleId, int mapId, int attemptId, PartitionLocation location)
throws CelebornIOException {
- Set<Integer> mapIds = new HashSet<>();
+ List<Integer> mapIds = new ArrayList<>();
mapIds.add(mapId);
List<ReviveRequest> requests = new ArrayList<>();
ReviveRequest req =
@@ -543,7 +544,7 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {
requests.add(req);
PbChangeLocationResponse response =
lifecycleManagerRef.askSync(
- ControlMessages.Revive$.MODULE$.apply(shuffleId, mapIds, requests),
+ ControlMessages.Revive$.MODULE$.apply(shuffleId, mapIds, requests, SerdeVersion.V1),
conf.clientRpcRequestPartitionLocationAskTimeout(),
ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
// per partitionKey only serve single PartitionLocation in Client Cache.
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index cfe40a296..f6a7d9750 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -550,11 +550,11 @@ public class ShuffleClientImpl extends ShuffleClient {
numPartitions,
() ->
lifecycleManagerRef.askSync(
- RegisterShuffle$.MODULE$.apply(shuffleId, numMappers, numPartitions),
+ new RegisterShuffle(shuffleId, numMappers, numPartitions, SerdeVersion.V1),
conf.clientRpcRegisterShuffleAskTimeout(),
rpcMaxRetries,
rpcRetryWait,
- ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
+ ClassTag$.MODULE$.apply(RegisterShuffleResponse.class)));
}
@Override
@@ -593,7 +593,7 @@ public class ShuffleClientImpl extends ShuffleClient {
partitionId,
isSegmentGranularityVisible),
conf.clientRpcRegisterShuffleAskTimeout(),
- ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
+ ClassTag$.MODULE$.apply(RegisterShuffleResponse.class)));
return partitionLocationMap.get(partitionId);
}
@@ -709,23 +709,18 @@ public class ShuffleClientImpl extends ShuffleClient {
}
private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
- int shuffleId,
- int numMappers,
- int numPartitions,
- Callable<PbRegisterShuffleResponse> callable)
+ int shuffleId, int numMappers, int numPartitions, Callable<RegisterShuffleResponse> callable)
throws CelebornIOException {
int numRetries = registerShuffleMaxRetries;
StatusCode lastFailedStatusCode = null;
while (numRetries > 0) {
try {
- PbRegisterShuffleResponse response = callable.call();
- StatusCode respStatus = StatusCode.fromValue(response.getStatus());
+ RegisterShuffleResponse response = callable.call();
+ StatusCode respStatus = response.status();
if (StatusCode.SUCCESS.equals(respStatus)) {
ConcurrentHashMap<Integer, PartitionLocation> result =
JavaUtils.newConcurrentHashMap();
- Tuple2<List<PartitionLocation>, List<PartitionLocation>> locations =
- PbSerDeUtils.fromPbPackedPartitionLocationsPair(
- response.getPackedPartitionLocationsPair());
- for (PartitionLocation location : locations._1) {
+ PartitionLocation[] locations = response.partitionLocations();
+ for (PartitionLocation location : locations) {
pushExcludedWorkers.remove(location.hostAndPushPort());
if (location.hasPeer()) {
pushExcludedWorkers.remove(location.getPeer().hostAndPushPort());
@@ -900,43 +895,43 @@ public class ShuffleClientImpl extends ShuffleClient {
oldLocMap.put(req.partitionId, req.loc);
}
try {
- PbChangeLocationResponse response =
+ ChangeLocationResponse response =
lifecycleManagerRef.askSync(
- Revive$.MODULE$.apply(shuffleId, mapIds, requests),
+ Revive$.MODULE$.apply(
+ shuffleId, new ArrayList<>(mapIds), new ArrayList<>(requests), SerdeVersion.V1),
conf.clientRpcRequestPartitionLocationAskTimeout(),
- ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
+ ClassTag$.MODULE$.apply(ChangeLocationResponse.class));
- for (int i = 0; i < response.getEndedMapIdCount(); i++) {
- int mapId = response.getEndedMapId(i);
+ for (int i = 0; i < response.endedMapIds().size(); i++) {
+ int mapId = response.endedMapIds().get(i);
mapperEndMap.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()).add(mapId);
}
- for (int i = 0; i < response.getPartitionInfoCount(); i++) {
- PbChangeLocationPartitionInfo partitionInfo = response.getPartitionInfo(i);
- int partitionId = partitionInfo.getPartitionId();
- int statusCode = partitionInfo.getStatus();
- if (partitionInfo.getOldAvailable()) {
+ for (Map.Entry<Integer, Tuple3<StatusCode, Boolean, PartitionLocation>> entry :
+ response.newLocs().entrySet()) {
+ int partitionId = entry.getKey();
+ StatusCode statusCode = entry.getValue()._1();
+ if (entry.getValue()._2() != null) {
PartitionLocation oldLoc = oldLocMap.get(partitionId);
// Currently, revive only check if main location available, here won't remove peer loc.
pushExcludedWorkers.remove(oldLoc.hostAndPushPort());
}
- if (StatusCode.SUCCESS.getValue() == statusCode) {
- PartitionLocation loc =
- PbSerDeUtils.fromPbPartitionLocation(partitionInfo.getPartition());
+ if (StatusCode.SUCCESS == statusCode) {
+ PartitionLocation loc = entry.getValue()._3();
partitionLocationMap.put(partitionId, loc);
pushExcludedWorkers.remove(loc.hostAndPushPort());
if (loc.hasPeer()) {
pushExcludedWorkers.remove(loc.getPeer().hostAndPushPort());
}
- } else if (StatusCode.STAGE_ENDED.getValue() == statusCode) {
+ } else if (StatusCode.STAGE_ENDED == statusCode) {
stageEndShuffleSet.add(shuffleId);
return results;
- } else if (StatusCode.SHUFFLE_UNREGISTERED.getValue() == statusCode) {
+ } else if (StatusCode.SHUFFLE_UNREGISTERED == statusCode) {
logger.error("SHUFFLE_NOT_REGISTERED!");
return null;
}
- results.put(partitionId, statusCode);
+ results.put(partitionId, (int) (statusCode.getValue()));
}
return results;
@@ -1806,7 +1801,8 @@ public class ShuffleClientImpl extends ShuffleClient {
pushState.getFailedBatches(),
numPartitions,
crc32PerPartition,
- bytesPerPartition),
+ bytesPerPartition,
+ SerdeVersion.V1),
rpcMaxRetries,
rpcRetryWait,
ClassTag$.MODULE$.apply(MapperEndResponse.class));
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 231898535..f48f3cd72 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -156,7 +156,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
case class RegisterCallContext(context: RpcCallContext, partitionId: Int = -1) {
- def reply(response: PbRegisterShuffleResponse) = {
+ def reply(response: RegisterShuffleResponse) = {
context.reply(response)
}
}
@@ -360,14 +360,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case pb: PbRegisterShuffle =>
- val shuffleId = pb.getShuffleId
- val numMappers = pb.getNumMappers
- val numPartitions = pb.getNumPartitions
+ case RegisterShuffle(shuffleId, numMappers, numPartitions, serdeVersion) =>
logDebug(s"Received RegisterShuffle request, " +
s"$shuffleId, $numMappers, $numPartitions.")
offerAndReserveSlots(
RegisterCallContext(context),
+ serdeVersion,
shuffleId,
numMappers,
numPartitions)
@@ -384,31 +382,25 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
shufflePartitionType.putIfAbsent(shuffleId, PartitionType.MAP)
offerAndReserveSlots(
RegisterCallContext(context, partitionId),
+ // Use V1 as this is only supported in java
+ SerdeVersion.V1,
shuffleId,
numMappers,
numMappers,
partitionId,
isSegmentGranularityVisible)
- case pb: PbRevive =>
- val shuffleId = pb.getShuffleId
- val mapIds = pb.getMapIdList
- val partitionInfos = pb.getPartitionInfoList
-
+ case Revive(shuffleId, mapIds, reviveRequests, serdeVersion) =>
val partitionIds = new util.ArrayList[Integer]()
val epochs = new util.ArrayList[Integer]()
val oldPartitions = new util.ArrayList[PartitionLocation]()
val causes = new util.ArrayList[StatusCode]()
- (0 until partitionInfos.size()).foreach { idx =>
- val info = partitionInfos.get(idx)
- partitionIds.add(info.getPartitionId)
- epochs.add(info.getEpoch)
- if (info.hasPartition) {
- oldPartitions.add(PbSerDeUtils.fromPbPartitionLocation(info.getPartition))
- } else {
- oldPartitions.add(null)
- }
- causes.add(StatusCode.fromValue(info.getStatus))
+ (0 until reviveRequests.size()).foreach { idx =>
+ val request = reviveRequests.get(idx)
+ partitionIds.add(request.partitionId)
+ epochs.add(request.epoch)
+ oldPartitions.add(request.loc)
+ causes.add(request.cause)
}
logDebug(s"Received Revive request, number of partitions
${partitionIds.size()}")
handleRevive(
@@ -418,7 +410,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
partitionIds,
epochs,
oldPartitions,
- causes)
+ causes,
+ serdeVersion)
case pb: PbPartitionSplit =>
val shuffleId = pb.getShuffleId
@@ -428,7 +421,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
logTrace(s"Received split request, " +
s"$shuffleId, $partitionId, $epoch, $oldPartition")
changePartitionManager.handleRequestPartitionLocation(
- ChangeLocationsCallContext(context, 1),
+ // TODO: this message is not supported in cppClient yet.
+ ChangeLocationsCallContext(context, 1, SerdeVersion.V1),
shuffleId,
partitionId,
epoch,
@@ -444,7 +438,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
pushFailedBatch,
numPartitions,
crc32PerPartition,
- bytesWrittenPerPartition) =>
+ bytesWrittenPerPartition,
+ serdeVersion) =>
logTrace(s"Received MapperEnd TaskEnd request, " +
s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}")
val partitionType = getPartitionType(shuffleId)
@@ -459,7 +454,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
pushFailedBatch,
numPartitions,
crc32PerPartition,
- bytesWrittenPerPartition)
+ bytesWrittenPerPartition,
+ serdeVersion)
case PartitionType.MAP =>
handleMapPartitionEnd(
context,
@@ -467,7 +463,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
mapId,
attemptId,
partitionId,
- numMappers)
+ numMappers,
+ serdeVersion)
case _ =>
throw new UnsupportedOperationException(s"Not support $partitionType yet")
}
@@ -618,6 +615,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
private def offerAndReserveSlots(
context: RegisterCallContext,
+ serdeVersion: SerdeVersion,
shuffleId: Int,
numMappers: Int,
numPartitions: Int,
@@ -641,13 +639,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
processMapTaskReply(
shuffleId,
rpcContext,
+ serdeVersion,
partitionId,
getLatestLocs(shuffleId, p => p.getId == partitionId))
case PartitionType.REDUCE =>
if (rpcContext.isInstanceOf[LocalNettyRpcCallContext]) {
context.reply(RegisterShuffleResponse(
StatusCode.SUCCESS,
- getLatestLocs(shuffleId, _ => true)))
+ getLatestLocs(shuffleId, _ => true),
+ serdeVersion))
} else {
val cachedMsg = registerShuffleResponseRpcCache.get(
shuffleId,
@@ -656,7 +656,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
rpcContext.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(
RegisterShuffleResponse(
StatusCode.SUCCESS,
- getLatestLocs(shuffleId, _ => true)))
+ getLatestLocs(shuffleId, _ => true),
+ serdeVersion))
}
})
rpcContext.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
@@ -699,15 +700,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
def processMapTaskReply(
shuffleId: Int,
context: RpcCallContext,
+ serdeVersion: SerdeVersion,
partitionId: Int,
partitionLocations: Array[PartitionLocation]): Unit = {
// if any partition location resource exist just reply
if (partitionLocations.size > 0) {
- context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, partitionLocations))
+ context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, partitionLocations, serdeVersion))
} else {
// request new resource for this task
changePartitionManager.handleRequestPartitionLocation(
- ApplyNewLocationCallContext(context),
+ ApplyNewLocationCallContext(context, serdeVersion),
shuffleId,
partitionId,
-1,
@@ -717,13 +719,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
// Reply to all RegisterShuffle request for current shuffle id.
- def replyRegisterShuffle(response: PbRegisterShuffleResponse): Unit = {
+ def replyRegisterShuffle(response: RegisterShuffleResponse): Unit = {
registeringShuffleRequest.synchronized {
val serializedMsg: Option[ByteBuffer] = partitionType match {
case PartitionType.REDUCE =>
context.context match {
case remoteContext: RemoteNettyRpcCallContext =>
- if (response.getStatus == StatusCode.SUCCESS.getValue) {
+ if (response.status == StatusCode.SUCCESS) {
Option(remoteContext.nettyEnv.serialize(
response))
} else {
@@ -735,19 +737,19 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
case _ => Option.empty
}
- val locations = PbSerDeUtils.fromPbPackedPartitionLocationsPair(
- response.getPackedPartitionLocationsPair)._1.asScala
+ val locations = response.partitionLocations
registeringShuffleRequest.asScala
.get(shuffleId)
.foreach(_.asScala.foreach(context => {
partitionType match {
case PartitionType.MAP =>
- if (response.getStatus == StatusCode.SUCCESS.getValue) {
- val partitionLocations = locations.filter(_.getId == context.partitionId).toArray
+ if (response.status == StatusCode.SUCCESS) {
+ val partitionLocations = locations.filter(_.getId == context.partitionId)
processMapTaskReply(
shuffleId,
context.context,
+ serdeVersion,
context.partitionId,
partitionLocations)
} else {
@@ -757,7 +759,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
case PartitionType.REDUCE =>
if (context.context.isInstanceOf[
- LocalNettyRpcCallContext] || response.getStatus != StatusCode.SUCCESS.getValue) {
+ LocalNettyRpcCallContext] || response.status != StatusCode.SUCCESS) {
context.reply(response)
} else {
registerShuffleResponseRpcCache.put(shuffleId, serializedMsg.get)
@@ -780,17 +782,26 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
res.status match {
case StatusCode.REQUEST_FAILED =>
logInfo(s"OfferSlots RPC request failed for $shuffleId!")
- replyRegisterShuffle(RegisterShuffleResponse(StatusCode.REQUEST_FAILED, Array.empty))
+ replyRegisterShuffle(RegisterShuffleResponse(
+ StatusCode.REQUEST_FAILED,
+ Array.empty,
+ serdeVersion))
return
case StatusCode.SLOT_NOT_AVAILABLE =>
logInfo(s"OfferSlots for $shuffleId failed!")
- replyRegisterShuffle(RegisterShuffleResponse(StatusCode.SLOT_NOT_AVAILABLE, Array.empty))
+ replyRegisterShuffle(RegisterShuffleResponse(
+ StatusCode.SLOT_NOT_AVAILABLE,
+ Array.empty,
+ serdeVersion))
return
case StatusCode.SUCCESS =>
logDebug(s"OfferSlots for $shuffleId Success!Slots Info:
${res.workerResource}")
case StatusCode.WORKER_EXCLUDED =>
logInfo(s"OfferSlots for $shuffleId failed due to all workers be
excluded!")
-
replyRegisterShuffle(RegisterShuffleResponse(StatusCode.WORKER_EXCLUDED,
Array.empty))
+ replyRegisterShuffle(RegisterShuffleResponse(
+ StatusCode.WORKER_EXCLUDED,
+ Array.empty,
+ serdeVersion))
return
case _ => // won't happen
throw new UnsupportedOperationException()
@@ -823,7 +834,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
// If reserve slots failed, clear allocated resources, reply ReserveSlotFailed and return.
if (!reserveSlotsSuccess) {
logError(s"reserve buffer for $shuffleId failed, reply to all.")
- replyRegisterShuffle(RegisterShuffleResponse(StatusCode.RESERVE_SLOTS_FAILED, Array.empty))
+ replyRegisterShuffle(RegisterShuffleResponse(
+ StatusCode.RESERVE_SLOTS_FAILED,
+ Array.empty,
+ serdeVersion))
} else {
if (log.isDebugEnabled()) {
logDebug(s"ReserveSlots for $shuffleId success with details:$slots!")
@@ -851,7 +865,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
val allPrimaryPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
replyRegisterShuffle(RegisterShuffleResponse(
StatusCode.SUCCESS,
- allPrimaryPartitionLocations))
+ allPrimaryPartitionLocations,
+ serdeVersion))
}
}
@@ -862,9 +877,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
partitionIds: util.List[Integer],
oldEpochs: util.List[Integer],
oldPartitions: util.List[PartitionLocation],
- causes: util.List[StatusCode]): Unit = {
+ causes: util.List[StatusCode],
+ serdeVersion: SerdeVersion): Unit = {
val contextWrapper =
- ChangeLocationsCallContext(context, partitionIds.size())
+ ChangeLocationsCallContext(context, partitionIds.size(), serdeVersion)
// If shuffle not registered, reply ShuffleNotRegistered and return
if (!registeredShuffle.contains(shuffleId)) {
logError(s"[handleRevive] shuffle $shuffleId not registered!")
@@ -916,7 +932,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
pushFailedBatches: util.Map[String, LocationPushFailedBatches],
numPartitions: Int,
crc32PerPartition: Array[Int],
- bytesWrittenPerPartition: Array[Long]): Unit = {
+ bytesWrittenPerPartition: Array[Long],
+ serdeVersion: SerdeVersion): Unit = {
val (mapperAttemptFinishedSuccess, allMapperFinished) =
commitManager.finishMapperAttempt(
@@ -936,7 +953,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
// reply success
- context.reply(MapperEndResponse(StatusCode.SUCCESS))
+ context.reply(MapperEndResponse(StatusCode.SUCCESS, serdeVersion))
}
private def handleGetReducerFileGroup(
@@ -1205,7 +1222,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
mapId: Int,
attemptId: Int,
partitionId: Int,
- numMappers: Int): Unit = {
+ numMappers: Int,
+ serdeVersion: SerdeVersion): Unit = {
def reply(result: Boolean): Unit = {
val message =
s"to handle MapPartitionEnd for ${Utils.makeMapKey(shuffleId, mapId,
attemptId)}, " +
@@ -1213,10 +1231,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
result match {
case true => // if already committed by another try
logDebug(s"Succeed $message")
- context.reply(MapperEndResponse(StatusCode.SUCCESS))
+ context.reply(MapperEndResponse(StatusCode.SUCCESS, serdeVersion))
case false =>
logError(s"Failed $message, reply ${StatusCode.SHUFFLE_DATA_LOST}.")
- context.reply(MapperEndResponse(StatusCode.SHUFFLE_DATA_LOST))
+ context.reply(MapperEndResponse(StatusCode.SHUFFLE_DATA_LOST, serdeVersion))
}
}
diff --git a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
index 091960a4c..9de71dd46 100644
--- a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala
@@ -20,6 +20,7 @@ package org.apache.celeborn.client
import java.util
import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.network.protocol.SerdeVersion
import org.apache.celeborn.common.protocol.PartitionLocation
import org.apache.celeborn.common.protocol.message.ControlMessages.{ChangeLocationResponse, RegisterShuffleResponse}
import org.apache.celeborn.common.protocol.message.StatusCode
@@ -36,11 +37,12 @@ trait RequestLocationCallContext {
case class ChangeLocationsCallContext(
context: RpcCallContext,
- partitionCount: Int)
+ partitionCount: Int,
+ serdeVersion: SerdeVersion)
extends RequestLocationCallContext with Logging {
- val endedMapIds = new util.HashSet[Integer]()
+ val endedMapIds = new util.ArrayList[Integer]()
val newLocs =
- JavaUtils.newConcurrentHashMap[Integer, (StatusCode, Boolean, PartitionLocation)](
+ JavaUtils.newConcurrentHashMap[Integer, (StatusCode, java.lang.Boolean, PartitionLocation)](
partitionCount)
def markMapperEnd(mapId: Int): Unit = this.synchronized {
@@ -59,12 +61,13 @@ case class ChangeLocationsCallContext(
if (newLocs.size() == partitionCount || StatusCode.SHUFFLE_UNREGISTERED == status
|| StatusCode.STAGE_ENDED == status) {
- context.reply(ChangeLocationResponse(endedMapIds, newLocs))
+ context.reply(ChangeLocationResponse(endedMapIds, newLocs, serdeVersion))
}
}
}
-case class ApplyNewLocationCallContext(context: RpcCallContext) extends RequestLocationCallContext {
+case class ApplyNewLocationCallContext(context: RpcCallContext, serdeVersion: SerdeVersion)
+ extends RequestLocationCallContext {
override def reply(
partitionId: Int,
status: StatusCode,
@@ -72,8 +75,8 @@ case class ApplyNewLocationCallContext(context: RpcCallContext) extends RequestL
available: Boolean): Unit = {
partitionLocationOpt match {
case Some(partitionLocation) =>
- context.reply(RegisterShuffleResponse(status, Array(partitionLocation)))
- case None => context.reply(RegisterShuffleResponse(status, Array.empty))
+ context.reply(RegisterShuffleResponse(status, Array(partitionLocation), serdeVersion))
+ case None => context.reply(RegisterShuffleResponse(status, Array.empty, serdeVersion))
}
}
}
diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
index 7a7706973..2cc2cd1fd 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
@@ -30,9 +30,9 @@ import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.network.protocol.SerdeVersion;
import org.apache.celeborn.common.protocol.CompressionCodec;
import org.apache.celeborn.common.protocol.PartitionLocation;
-import org.apache.celeborn.common.protocol.PbRegisterShuffleResponse;
import org.apache.celeborn.common.protocol.message.ControlMessages;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
@@ -90,12 +90,14 @@ public abstract class ShuffleClientBaseSuiteJ {
primaryLocation.setPeer(replicaLocation);
when(endpointRef.askSync(
- ControlMessages.RegisterShuffle$.MODULE$.apply(TEST_SHUFFLE_ID, 1, 1),
- ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)))
+ new ControlMessages.RegisterShuffle(TEST_SHUFFLE_ID, 1, 1, SerdeVersion.V1),
+ ClassTag$.MODULE$.apply(ControlMessages.RegisterShuffleResponse.class)))
.thenAnswer(
t ->
- ControlMessages.RegisterShuffleResponse$.MODULE$.apply(
- StatusCode.SUCCESS, new PartitionLocation[] {primaryLocation}));
+ new ControlMessages.RegisterShuffleResponse(
+ StatusCode.SUCCESS,
+ new PartitionLocation[] {primaryLocation},
+ SerdeVersion.V1));
shuffleClient.setupLifecycleManagerRef(endpointRef);
when(clientFactory.createClient(
diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
index 0f4b5c30f..e6d450d87 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
@@ -263,13 +263,13 @@ public class ShuffleClientSuiteJ {
.thenAnswer(
t ->
RegisterShuffleResponse$.MODULE$.apply(
- statusCode, new PartitionLocation[] {primaryLocation}));
+ statusCode, new PartitionLocation[] {primaryLocation}, SerdeVersion.V1));
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
.thenAnswer(
t ->
RegisterShuffleResponse$.MODULE$.apply(
- statusCode, new PartitionLocation[] {primaryLocation}));
+ statusCode, new PartitionLocation[] {primaryLocation}, SerdeVersion.V1));
shuffleClient.setupLifecycleManagerRef(endpointRef);
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index eb9274632..36f164d69 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -131,17 +131,12 @@ object ControlMessages extends Logging {
workerEvent: WorkerEventType = WorkerEventType.None)
extends MasterMessage
- object RegisterShuffle {
- def apply(
- shuffleId: Int,
- numMappers: Int,
- numPartitions: Int): PbRegisterShuffle =
- PbRegisterShuffle.newBuilder()
- .setShuffleId(shuffleId)
- .setNumMappers(numMappers)
- .setNumPartitions(numPartitions)
- .build()
- }
+ case class RegisterShuffle(
+ shuffleId: Int,
+ numMappers: Int,
+ numPartitions: Int,
+ serdeVersion: SerdeVersion)
+ extends MasterMessage
object RegisterMapPartitionTask {
def apply(
@@ -161,17 +156,10 @@ object ControlMessages extends Logging {
.build()
}
- object RegisterShuffleResponse {
- def apply(
- status: StatusCode,
- partitionLocations: Array[PartitionLocation]): PbRegisterShuffleResponse = {
- val builder = PbRegisterShuffleResponse.newBuilder()
- .setStatus(status.getValue)
- builder.setPackedPartitionLocationsPair(
- PbSerDeUtils.toPbPackedPartitionLocationsPair(partitionLocations.toList))
- builder.build()
- }
- }
+ case class RegisterShuffleResponse(
+ status: StatusCode,
+ partitionLocations: Array[PartitionLocation],
+ serdeVersion: SerdeVersion) extends MasterMessage
case class RequestSlots(
applicationId: String,
@@ -195,29 +183,11 @@ object ControlMessages extends Logging {
packed: Boolean = false)
extends MasterMessage
- object Revive {
- def apply(
- shuffleId: Int,
- mapIds: util.Set[Integer],
- reviveRequests: util.Collection[ReviveRequest]): PbRevive = {
- val builder = PbRevive.newBuilder()
- .setShuffleId(shuffleId)
- .addAllMapId(mapIds)
-
- reviveRequests.asScala.foreach { req =>
- val partitionInfoBuilder = PbRevivePartitionInfo.newBuilder()
- .setPartitionId(req.partitionId)
- .setEpoch(req.epoch)
- .setStatus(req.cause.getValue)
- if (req.loc != null) {
- partitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(req.loc))
- }
- builder.addPartitionInfo(partitionInfoBuilder.build())
- }
-
- builder.build()
- }
- }
+ case class Revive(
+ shuffleId: Int,
+ mapIds: util.List[Integer],
+ reviveRequests: util.List[ReviveRequest],
+ serdeVersion: SerdeVersion) extends MasterMessage
object PartitionSplit {
def apply(
@@ -233,26 +203,10 @@ object ControlMessages extends Logging {
.build()
}
- object ChangeLocationResponse {
- def apply(
- mapIds: util.Set[Integer],
- newLocs: util.Map[Integer, (StatusCode, Boolean, PartitionLocation)]) : PbChangeLocationResponse = {
- val builder = PbChangeLocationResponse.newBuilder()
- builder.addAllEndedMapId(mapIds)
- newLocs.asScala.foreach { case (partitionId, (status, available, loc)) =>
- val pbChangeLocationPartitionInfoBuilder = PbChangeLocationPartitionInfo.newBuilder()
- .setPartitionId(partitionId)
- .setStatus(status.getValue)
- .setOldAvailable(available)
- if (loc != null) {
- pbChangeLocationPartitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(loc))
- }
- builder.addPartitionInfo(pbChangeLocationPartitionInfoBuilder.build())
- }
- builder.build()
- }
- }
+ case class ChangeLocationResponse(
+ endedMapIds: util.List[Integer],
+ newLocs: util.Map[Integer, (StatusCode, java.lang.Boolean, PartitionLocation)],
+ serdeVersion: SerdeVersion) extends MasterMessage
case class MapperEnd(
shuffleId: Int,
@@ -263,7 +217,8 @@ object ControlMessages extends Logging {
failedBatchSet: util.Map[String, LocationPushFailedBatches],
numPartitions: Int,
crc32PerPartition: Array[Int],
- bytesWrittenPerPartition: Array[Long])
+ bytesWrittenPerPartition: Array[Long],
+ serdeVersion: SerdeVersion)
extends MasterMessage
case class ReadReducerPartitionEnd(
@@ -275,7 +230,7 @@ object ControlMessages extends Logging {
bytesWritten: Long)
extends MasterMessage
- case class MapperEndResponse(status: StatusCode) extends MasterMessage
+ case class MapperEndResponse(status: StatusCode, serdeVersion: SerdeVersion) extends MasterMessage
case class ReadReducerPartitionEndResponse(status: StatusCode) extends MasterMessage
@@ -674,14 +629,23 @@ object ControlMessages extends Logging {
.build().toByteArray
new TransportMessage(MessageType.HEARTBEAT_FROM_WORKER_RESPONSE, payload)
- case pb: PbRegisterShuffle =>
- new TransportMessage(MessageType.REGISTER_SHUFFLE, pb.toByteArray)
+ case RegisterShuffle(shuffleId, numMappers, numPartitions, serdeVersion) =>
+ val payload = PbRegisterShuffle.newBuilder()
+ .setShuffleId(shuffleId)
+ .setNumMappers(numMappers)
+ .setNumPartitions(numPartitions)
+ .build().toByteArray
+ new TransportMessage(MessageType.REGISTER_SHUFFLE, payload, serdeVersion)
case pb: PbRegisterMapPartitionTask =>
new TransportMessage(MessageType.REGISTER_MAP_PARTITION_TASK, pb.toByteArray)
- case pb: PbRegisterShuffleResponse =>
- new TransportMessage(MessageType.REGISTER_SHUFFLE_RESPONSE, pb.toByteArray)
+ case RegisterShuffleResponse(status, partitionLocations, serdeVersion) =>
+ val payload = PbRegisterShuffleResponse.newBuilder()
+ .setStatus(status.getValue).setPackedPartitionLocationsPair(
+ PbSerDeUtils.toPbPackedPartitionLocationsPair(partitionLocations.toList))
+ .build().toByteArray
+ new TransportMessage(MessageType.REGISTER_SHUFFLE_RESPONSE, payload, serdeVersion)
case RequestSlots(
applicationId,
@@ -729,11 +693,39 @@ object ControlMessages extends Logging {
val payload = builder.build().toByteArray
new TransportMessage(MessageType.REQUEST_SLOTS_RESPONSE, payload)
- case pb: PbRevive =>
- new TransportMessage(MessageType.CHANGE_LOCATION, pb.toByteArray)
+ case Revive(shuffleId, mapIds, reviveRequests, serdeVersion) =>
+ val builder = PbRevive.newBuilder()
+ .setShuffleId(shuffleId)
+ .addAllMapId(mapIds)
- case pb: PbChangeLocationResponse =>
- new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray)
+ reviveRequests.asScala.foreach { req =>
+ val partitionInfoBuilder = PbRevivePartitionInfo.newBuilder()
+ .setPartitionId(req.partitionId)
+ .setEpoch(req.epoch)
+ .setStatus(req.cause.getValue)
+ if (req.loc != null) {
+ partitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(req.loc))
+ }
+ builder.addPartitionInfo(partitionInfoBuilder.build())
+ }
+ val payload = builder.build().toByteArray
+ new TransportMessage(MessageType.CHANGE_LOCATION, payload, serdeVersion)
+
+ case ChangeLocationResponse(mapIds, newLocs, serdeVersion) =>
+ val builder = PbChangeLocationResponse.newBuilder()
+ builder.addAllEndedMapId(mapIds)
+ newLocs.asScala.foreach { case (partitionId, (status, available, loc)) =>
+ val pbChangeLocationPartitionInfoBuilder = PbChangeLocationPartitionInfo.newBuilder()
+ .setPartitionId(partitionId)
+ .setStatus(status.getValue)
+ .setOldAvailable(available)
+ if (loc != null) {
+ pbChangeLocationPartitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(loc))
+ }
+ builder.addPartitionInfo(pbChangeLocationPartitionInfoBuilder.build())
+ }
+ val payload = builder.build().toByteArray
+ new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, payload, serdeVersion)
case MapperEnd(
shuffleId,
@@ -744,7 +736,8 @@ object ControlMessages extends Logging {
pushFailedBatch,
numPartitions,
crc32PerPartition,
- bytesWrittenPerPartition) =>
+ bytesWrittenPerPartition,
+ serdeVersion) =>
val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v)
(k, resultValue)
@@ -761,13 +754,13 @@ object ControlMessages extends Logging {
.addAllBytesWrittenPerPartition(bytesWrittenPerPartition.map(
java.lang.Long.valueOf).toSeq.asJava)
.build().toByteArray
- new TransportMessage(MessageType.MAPPER_END, payload)
+ new TransportMessage(MessageType.MAPPER_END, payload, serdeVersion)
- case MapperEndResponse(status) =>
+ case MapperEndResponse(status, serdeVersion) =>
val payload = PbMapperEndResponse.newBuilder()
.setStatus(status.getValue)
.build().toByteArray
- new TransportMessage(MessageType.MAPPER_END_RESPONSE, payload)
+ new TransportMessage(MessageType.MAPPER_END_RESPONSE, payload, serdeVersion)
case GetReducerFileGroup(shuffleId, isSegmentGranularityVisible, serdeVersion) =>
val payload = PbGetReducerFileGroup.newBuilder()
@@ -1132,13 +1125,23 @@ object ControlMessages extends Logging {
pbHeartbeatFromWorkerResponse.getWorkerEventType)
case REGISTER_SHUFFLE_VALUE =>
- PbRegisterShuffle.parseFrom(message.getPayload)
+ val pbRegisterShuffle = PbRegisterShuffle.parseFrom(message.getPayload)
+ RegisterShuffle(
+ pbRegisterShuffle.getShuffleId,
+ pbRegisterShuffle.getNumMappers,
+ pbRegisterShuffle.getNumPartitions,
+ message.getSerdeVersion)
case REGISTER_MAP_PARTITION_TASK_VALUE =>
PbRegisterMapPartitionTask.parseFrom(message.getPayload)
case REGISTER_SHUFFLE_RESPONSE_VALUE =>
- PbRegisterShuffleResponse.parseFrom(message.getPayload)
+ val pbRegisterShuffleResponse = PbRegisterShuffleResponse.parseFrom(message.getPayload)
+ RegisterShuffleResponse(
+ StatusCode.fromValue(pbRegisterShuffleResponse.getStatus),
+ PbSerDeUtils.fromPbPackedPartitionLocationsPair(
+ pbRegisterShuffleResponse.getPackedPartitionLocationsPair)._1.asScala.toArray,
+ message.getSerdeVersion)
case REQUEST_SLOTS_VALUE =>
val pbRequestSlots = PbRequestSlots.parseFrom(message.getPayload)
@@ -1175,10 +1178,51 @@ object ControlMessages extends Logging {
workerResource)
case CHANGE_LOCATION_VALUE =>
- PbRevive.parseFrom(message.getPayload)
+ val pbRevive = PbRevive.parseFrom(message.getPayload)
+ val shuffleId = pbRevive.getShuffleId
+ val partitionInfos = pbRevive.getPartitionInfoList
+ val reviveRequests = new util.ArrayList[ReviveRequest]()
+ (0 until partitionInfos.size).foreach { idx =>
+ val info = partitionInfos.get(idx)
+ var partition: PartitionLocation = null
+ if (info.hasPartition) {
+ partition = PbSerDeUtils.fromPbPartitionLocation(info.getPartition)
+ }
+ val reviveRequest = new ReviveRequest(
+ shuffleId,
+ -1,
+ -1,
+ info.getPartitionId,
+ info.getEpoch,
+ partition,
+ StatusCode.fromValue(info.getStatus))
+ reviveRequests.add(reviveRequest)
+ }
+ Revive(
+ pbRevive.getShuffleId,
+ pbRevive.getMapIdList,
+ reviveRequests,
+ message.getSerdeVersion)
case CHANGE_LOCATION_RESPONSE_VALUE =>
- PbChangeLocationResponse.parseFrom(message.getPayload)
+ val pbChangeLocationResponse = PbChangeLocationResponse.parseFrom(message.getPayload)
+ val newLocs =
+ new util.HashMap[Integer, (StatusCode, java.lang.Boolean, PartitionLocation)]()
+ val partitionInfos = pbChangeLocationResponse.getPartitionInfoList
+ (0 until partitionInfos.size).foreach { idx =>
+ val info = partitionInfos.get(idx)
+ var partition: PartitionLocation = null
+ if (info.hasPartition) {
+ partition = PbSerDeUtils.fromPbPartitionLocation(info.getPartition)
+ }
+ newLocs.put(
+ info.getPartitionId,
+ (StatusCode.fromValue(info.getStatus), info.getOldAvailable, partition))
+ }
+ ChangeLocationResponse(
+ pbChangeLocationResponse.getEndedMapIdList,
+ newLocs,
+ message.getSerdeVersion)
case MAPPER_END_VALUE =>
val pbMapperEnd = PbMapperEnd.parseFrom(message.getPayload)
@@ -1203,7 +1247,8 @@ object ControlMessages extends Logging {
}.toMap.asJava,
pbMapperEnd.getNumPartitions,
crc32Array,
- bytesWrittenPerPartitionArray)
+ bytesWrittenPerPartitionArray,
+ message.getSerdeVersion)
case READ_REDUCER_PARTITION_END_VALUE =>
val pbReadReducerPartitionEnd = PbReadReducerPartitionEnd.parseFrom(message.getPayload)
@@ -1220,7 +1265,9 @@ object ControlMessages extends Logging {
case MAPPER_END_RESPONSE_VALUE =>
val pbMapperEndResponse = PbMapperEndResponse.parseFrom(message.getPayload)
- MapperEndResponse(StatusCode.fromValue(pbMapperEndResponse.getStatus))
+ MapperEndResponse(
+ StatusCode.fromValue(pbMapperEndResponse.getStatus),
+ message.getSerdeVersion)
case GET_REDUCER_FILE_GROUP_VALUE =>
val pbGetReducerFileGroup = PbGetReducerFileGroup.parseFrom(message.getPayload)
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
index 7cadaf07e..8be472b64 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
@@ -28,6 +28,7 @@ import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.client.{MasterEndpointResolver, StaticMasterEndpointResolver}
import org.apache.celeborn.common.exception.CelebornException
import org.apache.celeborn.common.identity.DefaultIdentityProvider
+import org.apache.celeborn.common.network.protocol.SerdeVersion
import org.apache.celeborn.common.protocol.{PartitionLocation, TransportModuleConstants}
import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, MapperEnd}
import org.apache.celeborn.common.protocol.message.StatusCode
@@ -149,7 +150,17 @@ class UtilsSuite extends CelebornFunSuite {
test("MapperEnd class convert with pb") {
val mapperEnd =
- MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap(), 1, Array.emptyIntArray, Array.emptyLongArray)
+ MapperEnd(
+ 1,
+ 1,
+ 1,
+ 2,
+ 1,
+ Collections.emptyMap(),
+ 1,
+ Array.emptyIntArray,
+ Array.emptyLongArray,
+ SerdeVersion.V1)
val mapperEndTrans =
Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
assert(mapperEnd.shuffleId == mapperEndTrans.shuffleId)
diff --git a/cpp/celeborn/tests/CMakeLists.txt b/cpp/celeborn/tests/CMakeLists.txt
index 0bc5e41c9..104607ce2 100644
--- a/cpp/celeborn/tests/CMakeLists.txt
+++ b/cpp/celeborn/tests/CMakeLists.txt
@@ -35,4 +35,29 @@ target_link_libraries(
add_executable(cppDataSumWithReaderClient DataSumWithReaderClient.cpp)
-target_link_libraries(cppDataSumWithReaderClient dataSumWithReaderClient)
\ No newline at end of file
+target_link_libraries(cppDataSumWithReaderClient dataSumWithReaderClient)
+
+add_library(
+ dataSumWithWriterClient
+ DataSumWithWriterClient.cpp)
+
+target_link_libraries(
+ dataSumWithWriterClient
+ memory
+ utils
+ conf
+ proto
+ network
+ protocol
+ client
+ ${WANGLE}
+ ${FIZZ}
+ ${LIBSODIUM_LIBRARY}
+ ${FOLLY_WITH_DEPENDENCIES}
+ ${GLOG}
+ ${GFLAGS_LIBRARIES}
+)
+
+add_executable(cppDataSumWithWriterClient DataSumWithWriterClient.cpp)
+
+target_link_libraries(cppDataSumWithWriterClient dataSumWithWriterClient)
diff --git a/cpp/celeborn/tests/DataSumWithWriterClient.cpp b/cpp/celeborn/tests/DataSumWithWriterClient.cpp
new file mode 100644
index 000000000..ceaa01beb
--- /dev/null
+++ b/cpp/celeborn/tests/DataSumWithWriterClient.cpp
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/init/Init.h>
+#include <cstdio>
+#include <fstream>
+#include <iostream>
+
+#include <celeborn/client/ShuffleClient.h>
+
+int main(int argc, char** argv) {
+ folly::init(&argc, &argv, false);
+ // Read the configs.
+ assert(argc == 9);
+ std::string lifecycleManagerHost = argv[1];
+ int lifecycleManagerPort = std::atoi(argv[2]);
+ std::string appUniqueId = argv[3];
+ int shuffleId = std::atoi(argv[4]);
+ int attemptId = std::atoi(argv[5]);
+ int numMappers = std::atoi(argv[6]);
+ int numPartitions = std::atoi(argv[7]);
+ std::string resultFile = argv[8];
+ std::cout << "lifecycleManagerHost = " << lifecycleManagerHost
+ << ", lifecycleManagerPort = " << lifecycleManagerPort
+ << ", appUniqueId = " << appUniqueId
+ << ", shuffleId = " << shuffleId << ", attemptId = " << attemptId
+ << ", numMappers = " << numMappers
+ << ", numPartitions = " << numPartitions
+ << ", resultFile = " << resultFile << std::endl;
+
+ // Create shuffleClient and setup.
+ auto conf = std::make_shared<celeborn::conf::CelebornConf>();
+ auto clientEndpoint =
+ std::make_shared<celeborn::client::ShuffleClientEndpoint>(conf);
+ auto shuffleClient = celeborn::client::ShuffleClientImpl::create(
+ appUniqueId, conf, *clientEndpoint);
+ shuffleClient->setupLifecycleManagerRef(
+ lifecycleManagerHost, lifecycleManagerPort);
+
+ long maxData = 1000000;
+ size_t numData = 1000;
+ // Generate data, sum up and pushData.
+ std::vector<long> result(numPartitions, 0);
+ std::vector<size_t> dataCnt(numPartitions, 0);
+ for (int mapId = 0; mapId < numMappers; mapId++) {
+ for (int partitionId = 0; partitionId < numPartitions; partitionId++) {
+ std::string partitionData;
+ for (size_t i = 0; i < numData; i++) {
+ int data = std::rand() % maxData;
+ result[partitionId] += data;
+ dataCnt[partitionId]++;
+ partitionData += "-" + std::to_string(data);
+ }
+ shuffleClient->pushData(
+ shuffleId,
+ mapId,
+ attemptId,
+ partitionId,
+ reinterpret_cast<const uint8_t*>(partitionData.c_str()),
+ 0,
+ partitionData.size(),
+ numMappers,
+ numPartitions);
+ }
+ shuffleClient->mapperEnd(shuffleId, mapId, attemptId, numMappers);
+ }
+ for (int partitionId = 0; partitionId < numPartitions; partitionId++) {
+ std::cout << "partition " << partitionId
+ << " sum result = " << result[partitionId]
+ << ", dataCnt = " << dataCnt[partitionId] << std::endl;
+ }
+
+ // Write result to resultFile.
+ remove(resultFile.c_str());
+ std::ofstream of(resultFile);
+ for (int partitionId = 0; partitionId < numPartitions; partitionId++) {
+ of << result[partitionId] << std::endl;
+ }
+ of.close();
+
+ return 0;
+}
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/CppWriteJavaReadTestWithNONE.scala
similarity index 88%
copy from worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
copy to worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/CppWriteJavaReadTestWithNONE.scala
index bc1961384..b7fb62a4a 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/CppWriteJavaReadTestWithNONE.scala
@@ -19,9 +19,9 @@ package org.apache.celeborn.service.deploy.cluster
import org.apache.celeborn.common.protocol.CompressionCodec
-object JavaWriteCppReadTestWithLZ4 extends JavaWriteCppReadTestBase {
+object CppWriteJavaReadTestWithNONE extends JavaCppHybridReadWriteTestBase {
def main(args: Array[String]) = {
- testJavaWriteCppRead(CompressionCodec.LZ4)
+ testCppWriteJavaRead(CompressionCodec.NONE)
}
}
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala
similarity index 60%
rename from worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestBase.scala
rename to worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala
index e059754e3..325f9c8b7 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestBase.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaCppHybridReadWriteTestBase.scala
@@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
+import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.identity.UserIdentifier
import org.apache.celeborn.common.internal.Logging
@@ -35,7 +36,7 @@ import org.apache.celeborn.common.protocol.CompressionCodec
import org.apache.celeborn.common.util.Utils.runCommand
import org.apache.celeborn.service.deploy.MiniClusterFeature
-trait JavaWriteCppReadTestBase extends AnyFunSuite
+trait JavaCppHybridReadWriteTestBase extends AnyFunSuite
with Logging with MiniClusterFeature with BeforeAndAfterAll {
var masterPort = 0
@@ -147,4 +148,99 @@ trait JavaWriteCppReadTestBase extends AnyFunSuite
shuffleClient.shutdown()
}
+ def testCppWriteJavaRead(codec: CompressionCodec): Unit = {
+ beforeAll()
+ try {
+ runCppWriteJavaRead(codec)
+ } finally {
+ afterAll()
+ }
+ }
+
+ def runCppWriteJavaRead(codec: CompressionCodec): Unit = {
+ val appUniqueId = "test-app"
+ val shuffleId = 0
+ val attemptId = 0
+
+ // Create lifecycleManager.
+ val clientConf = new CelebornConf()
+ .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$masterPort")
+ .set(CelebornConf.SHUFFLE_COMPRESSION_CODEC.key, codec.name)
+ .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true")
+ .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K")
+ .set(CelebornConf.READ_LOCAL_SHUFFLE_FILE, false)
+ .set("celeborn.data.io.numConnectionsPerPeer", "1")
+ val lifecycleManager = new LifecycleManager(appUniqueId, clientConf)
+
+ // Create writer shuffleClient.
+ val shuffleClient =
+ new ShuffleClientImpl(appUniqueId, clientConf, UserIdentifier("mock", "mock"))
+ shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+ val numMappers = 2
+ val numPartitions = 2
+
+ // Launch cpp writer to write data, calculate result and write to specific result file.
+ val cppResultFile = "/tmp/celeborn-cpp-writer-result.txt"
+ val lifecycleManagerHost = lifecycleManager.getHost
+ val lifecycleManagerPort = lifecycleManager.getPort
+ val projectDirectory = new File(new File(".").getAbsolutePath)
+ val cppBinRelativeDirectory = "cpp/build/celeborn/tests/"
+ val cppBinFileName = "cppDataSumWithWriterClient"
+ val cppBinFilePath = s"$projectDirectory/$cppBinRelativeDirectory/$cppBinFileName"
+ // Execution command: $exec lifecycleManagerHost lifecycleManagerPort appUniqueId shuffleId attemptId numMappers numPartitions cppResultFile
+ val command = {
+ s"$cppBinFilePath $lifecycleManagerHost $lifecycleManagerPort
$appUniqueId $shuffleId $attemptId $numMappers $numPartitions $cppResultFile"
+ }
+ println(s"run command: $command")
+ val commandOutput = runCommand(command)
+ println(s"command output: $commandOutput")
+
+ val metricsCallback = new MetricsCallback {
+ override def incBytesRead(bytesWritten: Long): Unit = {}
+ override def incReadTime(time: Long): Unit = {}
+ }
+
+ var sums = new util.ArrayList[Long](numPartitions)
+ for (partitionId <- 0 until numPartitions) {
+ sums.add(0)
+ val inputStream = shuffleClient.readPartition(
+ shuffleId,
+ partitionId,
+ attemptId,
+ 0,
+ 0,
+ Integer.MAX_VALUE,
+ metricsCallback)
+ var c = inputStream.read()
+ var data: Long = 0
+ var dataCnt = 0
+ while (c != -1) {
+ if (c == '-') {
+ sums.set(partitionId, sums.get(partitionId) + data)
+ data = 0
+ dataCnt += 1
+ } else {
+ assert(c >= '0' && c <= '9')
+ data *= 10
+ data += c - '0'
+ }
+ c = inputStream.read()
+ }
+ sums.set(partitionId, sums.get(partitionId) + data)
+ println(s"partition $partitionId sum result = ${sums.get(partitionId)},
dataCnt = $dataCnt")
+ }
+
+ // Verify the sum result.
+ var lineCount = 0
+ for (line <- Source.fromFile(cppResultFile, "utf-8").getLines.toList) {
+ val data = line.toLong
+ Assert.assertEquals(data, sums.get(lineCount))
+ lineCount += 1
+ }
+ Assert.assertEquals(lineCount, numPartitions)
+ lifecycleManager.stop()
+ shuffleClient.shutdown()
+ }
+
}
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
index bc1961384..327754ed9 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithLZ4.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.service.deploy.cluster
import org.apache.celeborn.common.protocol.CompressionCodec
-object JavaWriteCppReadTestWithLZ4 extends JavaWriteCppReadTestBase {
+object JavaWriteCppReadTestWithLZ4 extends JavaCppHybridReadWriteTestBase {
def main(args: Array[String]) = {
testJavaWriteCppRead(CompressionCodec.LZ4)
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala
index a649f8350..18bb8a418 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithNONE.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.service.deploy.cluster
import org.apache.celeborn.common.protocol.CompressionCodec
-object JavaWriteCppReadTestWithNONE extends JavaWriteCppReadTestBase {
+object JavaWriteCppReadTestWithNONE extends JavaCppHybridReadWriteTestBase {
def main(args: Array[String]) = {
testJavaWriteCppRead(CompressionCodec.NONE)
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala
index f2ba2e769..de7cdf102 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaWriteCppReadTestWithZSTD.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.service.deploy.cluster
import org.apache.celeborn.common.protocol.CompressionCodec
-object JavaWriteCppReadTestWithZSTD extends JavaWriteCppReadTestBase {
+object JavaWriteCppReadTestWithZSTD extends JavaCppHybridReadWriteTestBase {
def main(args: Array[String]) = {
testJavaWriteCppRead(CompressionCodec.ZSTD)