This is an automated email from the ASF dual-hosted git repository.
ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 30850a658 [CELEBORN-1932][CIP-14] Adapt java's serialization to
support cpp serialization for GetReducerFileGroup/Response
30850a658 is described below
commit 30850a6586abe319d96869f1da028defd8e9546f
Author: HolyLow <[email protected]>
AuthorDate: Thu Apr 3 14:43:47 2025 +0800
[CELEBORN-1932][CIP-14] Adapt java's serialization to support cpp
serialization for GetReducerFileGroup/Response
### What changes were proposed in this pull request?
The existing Java serialization is adapted to support multi-language
serialization. In addition, GetReducerFileGroup/Response is adapted to support
both the Java and C++ serialization modes.
### Why are the changes needed?
To allow CppClient to communicate with JavaServer.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation.
Closes #3177 from HolyLow/issue/celeborn-1932-adapt-java-serialization.
Authored-by: HolyLow <[email protected]>
Signed-off-by: mingji <[email protected]>
---
.../apache/celeborn/client/ShuffleClientImpl.java | 8 ++---
.../org/apache/celeborn/client/CommitManager.scala | 8 +++--
.../apache/celeborn/client/LifecycleManager.scala | 17 +++++----
.../celeborn/client/commit/CommitHandler.scala | 6 +++-
.../client/commit/MapPartitionCommitHandler.scala | 9 +++--
.../commit/ReducePartitionCommitHandler.scala | 42 ++++++++++++++++------
.../celeborn/client/ShuffleClientSuiteJ.java | 25 ++++++++-----
.../common/network/protocol/SerdeVersion.java | 41 +++++++++++++++++++++
.../common/network/protocol/TransportMessage.java | 17 ++++++++-
.../common/protocol/message/ControlMessages.scala | 22 +++++++-----
.../celeborn/common/rpc/netty/NettyRpcEnv.scala | 16 ++++++++-
.../common/serializer/JavaSerializer.scala | 31 ++++++++++++++++
12 files changed, 198 insertions(+), 44 deletions(-)
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index e7f4a9083..12ef7d330 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -53,10 +53,8 @@ import
org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientBootstrap;
import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.PushData;
-import org.apache.celeborn.common.network.protocol.PushMergedData;
-import org.apache.celeborn.common.network.protocol.TransportMessage;
-import org.apache.celeborn.common.network.protocol.TransportMessagesHelper;
+import org.apache.celeborn.common.network.protocol.*;
+import org.apache.celeborn.common.network.protocol.SerdeVersion;
import org.apache.celeborn.common.network.sasl.SaslClientBootstrap;
import org.apache.celeborn.common.network.sasl.SaslCredentials;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
@@ -1803,7 +1801,7 @@ public class ShuffleClientImpl extends ShuffleClient {
}
try {
GetReducerFileGroup getReducerFileGroup =
- new GetReducerFileGroup(shuffleId, isSegmentGranularityVisible);
+ new GetReducerFileGroup(shuffleId, isSegmentGranularityVisible,
SerdeVersion.V1);
GetReducerFileGroupResponse response =
lifecycleManagerRef.askSync(
diff --git
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index ad0d66e3a..bffd05430 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -34,6 +34,7 @@ import org.apache.celeborn.client.listener.{WorkersStatus,
WorkerStatusListener}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.WorkerInfo
+import org.apache.celeborn.common.network.protocol.SerdeVersion
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType,
StorageInfo}
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc.RpcCallContext
@@ -275,8 +276,11 @@ class CommitManager(appUniqueId: String, val conf:
CelebornConf, lifecycleManage
getCommitHandler(shuffleId).waitStageEnd(shuffleId)
}
- def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit
= {
- getCommitHandler(shuffleId).handleGetReducerFileGroup(context, shuffleId)
+ def handleGetReducerFileGroup(
+ context: RpcCallContext,
+ shuffleId: Int,
+ serdeVersion: SerdeVersion): Unit = {
+ getCommitHandler(shuffleId).handleGetReducerFileGroup(context, shuffleId,
serdeVersion)
}
// exposed for test
diff --git
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 35d1945e1..9eac9afb8 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -45,7 +45,7 @@ import org.apache.celeborn.common.identity.{IdentityProvider,
UserIdentifier}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ApplicationMeta,
ShufflePartitionLocationInfo, WorkerInfo}
import org.apache.celeborn.common.metrics.source.Role
-import org.apache.celeborn.common.network.protocol.TransportMessagesHelper
+import org.apache.celeborn.common.network.protocol.{SerdeVersion,
TransportMessagesHelper}
import org.apache.celeborn.common.network.sasl.registration.RegistrationInfo
import org.apache.celeborn.common.protocol._
import org.apache.celeborn.common.protocol.RpcNameConstants.WORKER_EP
@@ -432,10 +432,13 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
throw new UnsupportedOperationException(s"Not support $partitionType
yet")
}
- case GetReducerFileGroup(shuffleId: Int, isSegmentGranularityVisible:
Boolean) =>
+ case GetReducerFileGroup(
+ shuffleId: Int,
+ isSegmentGranularityVisible: Boolean,
+ serdeVersion: SerdeVersion) =>
logDebug(
s"Received GetShuffleFileGroup request for shuffleId $shuffleId,
isSegmentGranularityVisible $isSegmentGranularityVisible")
- handleGetReducerFileGroup(context, shuffleId,
isSegmentGranularityVisible)
+ handleGetReducerFileGroup(context, shuffleId,
isSegmentGranularityVisible, serdeVersion)
case pb: PbGetShuffleId =>
val appShuffleId = pb.getAppShuffleId
@@ -845,7 +848,8 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
private def handleGetReducerFileGroup(
context: RpcCallContext,
shuffleId: Int,
- isSegmentGranularityVisible: Boolean): Unit = {
+ isSegmentGranularityVisible: Boolean,
+ serdeVersion: SerdeVersion): Unit = {
// If isSegmentGranularityVisible is set to true, the downstream reduce
task may start early than upstream map task, e.g. flink hybrid shuffle.
// Under these circumstances, there's a possibility that the shuffle might
not yet be registered when the downstream reduce task send GetReduceFileGroup
request,
// so we shouldn't send a SHUFFLE_NOT_REGISTERED response directly, should
enqueue this request to pending list, and response to the downstream reduce
task the ReduceFileGroup when the upstream map task register shuffle done
@@ -854,10 +858,11 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
context.reply(GetReducerFileGroupResponse(
StatusCode.SHUFFLE_NOT_REGISTERED,
JavaUtils.newConcurrentHashMap(),
- Array.empty))
+ Array.empty,
+ serdeVersion = serdeVersion))
return
}
- commitManager.handleGetReducerFileGroup(context, shuffleId)
+ commitManager.handleGetReducerFileGroup(context, shuffleId, serdeVersion)
}
private def handleGetShuffleIdForApp(
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index ea86b828d..658eb47a7 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -34,6 +34,7 @@ import
org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, Shuffl
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo,
WorkerInfo}
+import org.apache.celeborn.common.network.protocol.SerdeVersion
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
import
org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles,
CommitFilesResponse}
import org.apache.celeborn.common.protocol.message.StatusCode
@@ -178,7 +179,10 @@ abstract class CommitHandler(
* partitions are complete by the time the method is called, as downstream
tasks may start early before all tasks
* are completed.So map partition may need refresh reducer file group if
needed.
*/
- def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit
+ def handleGetReducerFileGroup(
+ context: RpcCallContext,
+ shuffleId: Int,
+ serdeVersion: SerdeVersion): Unit
def removeExpiredShuffle(shuffleId: Int): Unit = {
reducerFileGroupsMap.remove(shuffleId)
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index a08f1e0d5..4f31018e5 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -31,6 +31,7 @@ import
org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, Shu
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo,
WorkerInfo}
+import org.apache.celeborn.common.network.protocol.SerdeVersion
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
import
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.protocol.message.StatusCode
@@ -230,7 +231,10 @@ class MapPartitionCommitHandler(
shuffleIsSegmentGranularityVisible.get(shuffleId)
}
- override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId:
Int): Unit = {
+ override def handleGetReducerFileGroup(
+ context: RpcCallContext,
+ shuffleId: Int,
+ serdeVersion: SerdeVersion): Unit = {
// TODO: if support the downstream map task start early before the
upstream reduce task, it should
// waiting the upstream task register shuffle, then reply these
GetReducerFileGroup.
// Note that flink hybrid shuffle should support it in the future.
@@ -244,7 +248,8 @@ class MapPartitionCommitHandler(
StatusCode.SUCCESS,
reducerFileGroupsMap.getOrDefault(shuffleId,
JavaUtils.newConcurrentHashMap()),
getMapperAttempts(shuffleId),
- succeedPartitionIds))
+ succeedPartitionIds,
+ serdeVersion = serdeVersion))
}
override def releasePartitionResource(shuffleId: Int, partitionId: Int):
Unit = {
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 98fe624fb..5bdd1c550 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -34,6 +34,7 @@ import
org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, Shu
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.ShufflePartitionLocationInfo
+import org.apache.celeborn.common.network.protocol.SerdeVersion
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
import
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.protocol.message.StatusCode
@@ -65,8 +66,10 @@ class ReducePartitionCommitHandler(
sharedRpcPool)
with Logging {
+ class MultiSerdeVersionRpcContext(val ctx: RpcCallContext, val serdeVersion:
SerdeVersion) {}
+
private val getReducerFileGroupRequest =
- JavaUtils.newConcurrentHashMap[Int, util.Set[RpcCallContext]]()
+ JavaUtils.newConcurrentHashMap[Int,
util.Set[MultiSerdeVersionRpcContext]]()
private val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
private val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
@@ -300,7 +303,7 @@ class ReducePartitionCommitHandler(
numMappers: Int,
isSegmentGranularityVisible: Boolean): Unit = {
super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible)
- getReducerFileGroupRequest.put(shuffleId, new
util.HashSet[RpcCallContext]())
+ getReducerFileGroupRequest.put(shuffleId, new
util.HashSet[MultiSerdeVersionRpcContext]())
initMapperAttempts(shuffleId, numMappers)
}
@@ -314,7 +317,16 @@ class ReducePartitionCommitHandler(
}
}
- private def replyGetReducerFileGroup(context: RpcCallContext, shuffleId:
Int): Unit = {
+ private def replyGetReducerFileGroup(
+ context: MultiSerdeVersionRpcContext,
+ shuffleId: Int): Unit = {
+ replyGetReducerFileGroup(context.ctx, shuffleId, context.serdeVersion)
+ }
+
+ private def replyGetReducerFileGroup(
+ context: RpcCallContext,
+ shuffleId: Int,
+ serdeVersion: SerdeVersion): Unit = {
if (isStageDataLost(shuffleId)) {
context.reply(
GetReducerFileGroupResponse(
@@ -328,7 +340,8 @@ class ReducePartitionCommitHandler(
var response = GetReducerFileGroupResponse(
StatusCode.SUCCESS,
reducerFileGroupsMap.getOrDefault(shuffleId,
JavaUtils.newConcurrentHashMap()),
- getMapperAttempts(shuffleId))
+ getMapperAttempts(shuffleId),
+ serdeVersion = serdeVersion)
// only check whether broadcast enabled for the UTs
if (getReducerFileGroupResponseBroadcastEnabled) {
@@ -348,7 +361,8 @@ class ReducePartitionCommitHandler(
pushFailedBatches =
shufflePushFailedBatches.getOrDefault(
shuffleId,
- new util.HashMap[String, util.Set[PushFailedBatch]]()))
+ new util.HashMap[String, util.Set[PushFailedBatch]]()),
+ serdeVersion = serdeVersion)
val serializedMsg =
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
@@ -382,22 +396,30 @@ class ReducePartitionCommitHandler(
response: GetReducerFileGroupResponse): GetReducerFileGroupResponse = {
lifecycleManager.broadcastGetReducerFileGroupResponse(shuffleId, response)
match {
case Some(broadcastBytes) if broadcastBytes.nonEmpty =>
- GetReducerFileGroupResponse(response.status, broadcast =
broadcastBytes)
+ GetReducerFileGroupResponse(
+ response.status,
+ broadcast = broadcastBytes,
+ serdeVersion = response.serdeVersion)
case _ => response
}
}
- override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId:
Int): Unit = {
+ override def handleGetReducerFileGroup(
+ context: RpcCallContext,
+ shuffleId: Int,
+ serdeVersion: SerdeVersion): Unit = {
// Quick return for ended stage, avoid occupy sync lock.
if (isStageEnd(shuffleId)) {
- replyGetReducerFileGroup(context, shuffleId)
+ replyGetReducerFileGroup(context, shuffleId, serdeVersion)
} else {
getReducerFileGroupRequest.synchronized {
// If setStageEnd() called after isStageEnd and before got lock,
should reply here.
if (isStageEnd(shuffleId)) {
- replyGetReducerFileGroup(context, shuffleId)
+ replyGetReducerFileGroup(context, shuffleId, serdeVersion)
} else {
- getReducerFileGroupRequest.get(shuffleId).add(context)
+ getReducerFileGroupRequest.get(shuffleId).add(new
MultiSerdeVersionRpcContext(
+ context,
+ serdeVersion))
}
}
}
diff --git
a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
index a5076a59f..85cf0ba10 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
@@ -48,6 +48,7 @@ import
org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.network.protocol.SerdeVersion;
import org.apache.celeborn.common.protocol.CompressionCodec;
import org.apache.celeborn.common.protocol.PartitionLocation;
import
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse$;
@@ -428,7 +429,8 @@ public class ShuffleClientSuiteJ {
new int[0],
Collections.emptySet(),
Collections.emptyMap(),
- new byte[0]);
+ new byte[0],
+ SerdeVersion.V1);
});
when(endpointRef.askSync(any(), any(), any(Integer.class),
any(Long.class), any()))
@@ -441,7 +443,8 @@ public class ShuffleClientSuiteJ {
new int[0],
Collections.emptySet(),
Collections.emptyMap(),
- new byte[0]);
+ new byte[0],
+ SerdeVersion.V1);
});
shuffleClient =
@@ -485,7 +488,8 @@ public class ShuffleClientSuiteJ {
new int[0],
Collections.emptySet(),
Collections.emptyMap(),
- new byte[0]);
+ new byte[0],
+ SerdeVersion.V1);
});
when(endpointRef.askSync(any(), any(), any(Integer.class),
any(Long.class), any()))
@@ -497,7 +501,8 @@ public class ShuffleClientSuiteJ {
new int[0],
Collections.emptySet(),
Collections.emptyMap(),
- new byte[0]);
+ new byte[0],
+ SerdeVersion.V1);
});
shuffleClient =
@@ -519,7 +524,8 @@ public class ShuffleClientSuiteJ {
new int[0],
Collections.emptySet(),
Collections.emptyMap(),
- new byte[0]);
+ new byte[0],
+ SerdeVersion.V1);
});
when(endpointRef.askSync(any(), any(), any(Integer.class),
any(Long.class), any()))
@@ -531,7 +537,8 @@ public class ShuffleClientSuiteJ {
new int[0],
Collections.emptySet(),
Collections.emptyMap(),
- new byte[0]);
+ new byte[0],
+ SerdeVersion.V1);
});
shuffleClient =
@@ -553,7 +560,8 @@ public class ShuffleClientSuiteJ {
new int[0],
Collections.emptySet(),
Collections.emptyMap(),
- new byte[0]);
+ new byte[0],
+ SerdeVersion.V1);
});
when(endpointRef.askSync(any(), any(), any(Integer.class),
any(Long.class), any()))
@@ -565,7 +573,8 @@ public class ShuffleClientSuiteJ {
new int[0],
Collections.emptySet(),
Collections.emptyMap(),
- new byte[0]);
+ new byte[0],
+ SerdeVersion.V1);
});
shuffleClient =
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/SerdeVersion.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SerdeVersion.java
new file mode 100644
index 000000000..177b9bb7d
--- /dev/null
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SerdeVersion.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+/**
+ * SerdeVersion represents which ser/de version the message is deserialized
from / will be
+ * serialized into. For V1 (used by legacy java engine), the ser/de is
dependent on java's
+ * serialization stack, and the leading byte would be 0xAC according to Java's
serialization stack.
+ * For V2 (used by cpp client), the ser/de is language-agnostic, the leading
byte would be 0xFF as
+ * defined in CelebornCpp module. In this way, messages from/for different
version could be
+ * distinguished and ser/de accordingly.
+ */
+public enum SerdeVersion {
+ V1((byte) 0xAC),
+ V2((byte) 0xFF);
+
+ private final byte marker;
+
+ SerdeVersion(byte marker) {
+ this.marker = marker;
+ }
+
+ public byte getMarker() {
+ return marker;
+ }
+}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index 01a9a37f9..137c2e710 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -36,11 +36,17 @@ public class TransportMessage implements Serializable {
@Deprecated private final transient MessageType type;
private final int messageTypeValue;
private final byte[] payload;
+ private final SerdeVersion serdeVersion;
public TransportMessage(MessageType type, byte[] payload) {
+ this(type, payload, SerdeVersion.V1);
+ }
+
+ public TransportMessage(MessageType type, byte[] payload, SerdeVersion
serdeVersion) {
this.type = type;
this.messageTypeValue = type.getNumber();
this.payload = payload;
+ this.serdeVersion = serdeVersion;
}
public MessageType getType() {
@@ -55,6 +61,10 @@ public class TransportMessage implements Serializable {
return payload;
}
+ public SerdeVersion getSerdeVersion() {
+ return serdeVersion;
+ }
+
public <T extends GeneratedMessageV3> T getParsedPayload() throws
InvalidProtocolBufferException {
switch (messageTypeValue) {
case OPEN_STREAM_VALUE:
@@ -132,6 +142,11 @@ public class TransportMessage implements Serializable {
}
public static TransportMessage fromByteBuffer(ByteBuffer buffer) throws
CelebornIOException {
+ return fromByteBuffer(buffer, SerdeVersion.V1);
+ }
+
+ public static TransportMessage fromByteBuffer(ByteBuffer buffer,
SerdeVersion serdeVersion)
+ throws CelebornIOException {
int messageTypeValue = buffer.getInt();
if (MessageType.forNumber(messageTypeValue) == null) {
throw new CelebornIOException("Decode failed, fallback to legacy
messages.");
@@ -140,6 +155,6 @@ public class TransportMessage implements Serializable {
byte[] payload = new byte[payloadLen];
buffer.get(payload);
MessageType msgType = MessageType.forNumber(messageTypeValue);
- return new TransportMessage(msgType, payload);
+ return new TransportMessage(msgType, payload, serdeVersion);
}
}
diff --git
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 949b13322..8719dea7a 100644
---
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -28,7 +28,7 @@ import org.roaringbitmap.RoaringBitmap
import org.apache.celeborn.common.identity.UserIdentifier
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{DiskInfo, WorkerInfo, WorkerStatus}
-import org.apache.celeborn.common.network.protocol.TransportMessage
+import org.apache.celeborn.common.network.protocol.{SerdeVersion,
TransportMessage}
import org.apache.celeborn.common.protocol._
import org.apache.celeborn.common.protocol.MessageType._
import org.apache.celeborn.common.quota.ResourceConsumption
@@ -279,7 +279,10 @@ object ControlMessages extends Logging {
case class MapperEndResponse(status: StatusCode) extends MasterMessage
- case class GetReducerFileGroup(shuffleId: Int, isSegmentGranularityVisible:
Boolean)
+ case class GetReducerFileGroup(
+ shuffleId: Int,
+ isSegmentGranularityVisible: Boolean,
+ serdeVersion: SerdeVersion)
extends MasterMessage
// util.Set[String] -> util.Set[Path.toString]
@@ -290,7 +293,8 @@ object ControlMessages extends Logging {
attempts: Array[Int] = Array.emptyIntArray,
partitionIds: util.Set[Integer] = Collections.emptySet[Integer](),
pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] =
Collections.emptyMap(),
- broadcast: Array[Byte] = Array.emptyByteArray)
+ broadcast: Array[Byte] = Array.emptyByteArray,
+ serdeVersion: SerdeVersion = SerdeVersion.V1)
extends MasterMessage
object WorkerExclude {
@@ -747,12 +751,12 @@ object ControlMessages extends Logging {
.build().toByteArray
new TransportMessage(MessageType.MAPPER_END_RESPONSE, payload)
- case GetReducerFileGroup(shuffleId, isSegmentGranularityVisible) =>
+ case GetReducerFileGroup(shuffleId, isSegmentGranularityVisible,
serdeVersion) =>
val payload = PbGetReducerFileGroup.newBuilder()
.setShuffleId(shuffleId)
.setIsSegmentGranularityVisible(isSegmentGranularityVisible)
.build().toByteArray
- new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload)
+ new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload,
serdeVersion)
case GetReducerFileGroupResponse(
status,
@@ -760,7 +764,8 @@ object ControlMessages extends Logging {
attempts,
partitionIds,
failedBatches,
- broadcast) =>
+ broadcast,
+ serdeVersion) =>
val builder = PbGetReducerFileGroupResponse
.newBuilder()
.setStatus(status.getValue)
@@ -780,7 +785,7 @@ object ControlMessages extends Logging {
}.asJava)
builder.setBroadcast(ByteString.copyFrom(broadcast))
val payload = builder.build().toByteArray
- new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE,
payload)
+ new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE,
payload, serdeVersion)
case pb: PbWorkerExclude =>
new TransportMessage(MessageType.WORKER_EXCLUDE, pb.toByteArray)
@@ -1177,7 +1182,8 @@ object ControlMessages extends Logging {
val pbGetReducerFileGroup =
PbGetReducerFileGroup.parseFrom(message.getPayload)
GetReducerFileGroup(
pbGetReducerFileGroup.getShuffleId,
- pbGetReducerFileGroup.getIsSegmentGranularityVisible)
+ pbGetReducerFileGroup.getIsSegmentGranularityVisible,
+ message.getSerdeVersion)
case GET_REDUCER_FILE_GROUP_RESPONSE_VALUE =>
val pbGetReducerFileGroupResponse = PbGetReducerFileGroupResponse
diff --git
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
index 1ce989134..b2de46559 100644
---
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
@@ -34,7 +34,7 @@ import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.network.TransportContext
import org.apache.celeborn.common.network.client._
-import org.apache.celeborn.common.network.protocol.{RequestMessage =>
NRequestMessage, RpcRequest}
+import org.apache.celeborn.common.network.protocol.{RequestMessage =>
NRequestMessage, RpcRequest, SerdeVersion, TransportMessage}
import org.apache.celeborn.common.network.sasl.{SaslClientBootstrap,
SaslServerBootstrap}
import
org.apache.celeborn.common.network.sasl.registration.{RegistrationClientBootstrap,
RegistrationServerBootstrap}
import org.apache.celeborn.common.network.server._
@@ -504,6 +504,20 @@ private[celeborn] class RequestMessage(
writeRpcAddress(out, senderAddress)
writeRpcAddress(out, receiver.address)
out.writeUTF(receiver.name)
+ val msg = Utils.toTransportMessage(content)
+ msg match {
+ case transMsg: TransportMessage =>
+ // Check if the msg is a TransportMessage with language-agnostic V2
serdeVersion.
+ // If so, write the marker and the body explicitly.
+ if (transMsg.getSerdeVersion == SerdeVersion.V2) {
+ val out = new DataOutputStream(bos)
+ out.writeByte(SerdeVersion.V2.getMarker)
+ out.write(transMsg.toByteBuffer.array)
+ out.close()
+ return bos.toByteBuffer
+ }
+ case _ =>
+ }
val s = nettyEnv.serializeStream(out)
try {
s.writeObject(Utils.toTransportMessage(content))
diff --git
a/common/src/main/scala/org/apache/celeborn/common/serializer/JavaSerializer.scala
b/common/src/main/scala/org/apache/celeborn/common/serializer/JavaSerializer.scala
index d38161583..3a813ef9c 100644
---
a/common/src/main/scala/org/apache/celeborn/common/serializer/JavaSerializer.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/serializer/JavaSerializer.scala
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
import scala.reflect.ClassTag
import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.network.protocol.{SerdeVersion,
TransportMessage}
import org.apache.celeborn.common.util.{ByteBufferInputStream,
ByteBufferOutputStream, Utils}
private[celeborn] class JavaSerializationStream(
@@ -98,6 +99,20 @@ private[celeborn] class JavaSerializerInstance(
override def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteBufferOutputStream()
+ val msg = Utils.toTransportMessage(t)
+ msg match {
+ case transMsg: TransportMessage =>
+ // Check if the msg is a TransportMessage with language-agnostic V2
serdeVersion.
+ // If so, write the marker and the body explicitly.
+ if (transMsg.getSerdeVersion == SerdeVersion.V2) {
+ val out = new DataOutputStream(bos)
+ out.writeByte(SerdeVersion.V2.getMarker)
+ out.write(transMsg.toByteBuffer.array)
+ out.close()
+ return bos.toByteBuffer
+ }
+ case _ =>
+ }
val out = serializeStream(bos)
out.writeObject(Utils.toTransportMessage(t))
out.close()
@@ -105,12 +120,28 @@ private[celeborn] class JavaSerializerInstance(
}
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+ bytes.mark
+ val serdeVersion = bytes.get
+ // If the serdeVersion byte is V2, deserialize directly.
+ if (serdeVersion == SerdeVersion.V2.getMarker) {
+ return Utils.fromTransportMessage(
+ TransportMessage.fromByteBuffer(bytes,
SerdeVersion.V2)).asInstanceOf[T]
+ }
+ bytes.reset
val bis = new ByteBufferInputStream(bytes)
val in = deserializeStream(bis)
Utils.fromTransportMessage(in.readObject()).asInstanceOf[T]
}
override def deserialize[T: ClassTag](bytes: ByteBuffer, loader:
ClassLoader): T = {
+ bytes.mark
+ val serdeVersion = bytes.get
+ // If the serdeVersion byte is V2, deserialize directly.
+ if (serdeVersion == SerdeVersion.V2.getMarker) {
+ return Utils.fromTransportMessage(
+ TransportMessage.fromByteBuffer(bytes,
SerdeVersion.V2)).asInstanceOf[T]
+ }
+ bytes.reset
val bis = new ByteBufferInputStream(bytes)
val in = deserializeStream(bis, loader)
Utils.fromTransportMessage(in.readObject()).asInstanceOf[T]