HolyLow commented on code in PR #3177:
URL: https://github.com/apache/celeborn/pull/3177#discussion_r2022446128
##########
common/src/main/scala/org/apache/celeborn/common/serializer/JavaSerializer.scala:
##########
@@ -98,19 +99,51 @@ private[celeborn] class JavaSerializerInstance(
override def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteBufferOutputStream()
+ val msg = Utils.toTransportMessage(t)
+ msg match {
+ case transMsg: TransportMessage =>
+ // Check if the msg is a TransportMessage with CPP languageType.
+ // If so, write the marker and the body explicitly.
+ if (transMsg.getLanguageType == LanguageType.CPP) {
+ val out = new DataOutputStream(bos)
+ out.writeByte(LanguageType.CPP.getMarker)
+ out.write(transMsg.toByteBuffer.array)
+ out.close()
+ return bos.toByteBuffer
+ }
+ case _ =>
+ }
val out = serializeStream(bos)
out.writeObject(Utils.toTransportMessage(t))
out.close()
bos.toByteBuffer
}
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+ bytes.mark
+ val languageMarker = bytes.get
+ bytes.reset
+ // If the languageMarker byte is CPP, deserialize directly.
+ if (languageMarker == LanguageType.CPP.getMarker) {
+ bytes.get
+ return Utils.fromTransportMessage(
+ TransportMessage.fromByteBuffer(bytes,
LanguageType.CPP)).asInstanceOf[T]
+ }
val bis = new ByteBufferInputStream(bytes)
Review Comment:
You are right; this is better. Refactored as suggested.
##########
common/src/main/scala/org/apache/celeborn/common/serializer/JavaSerializer.scala:
##########
@@ -98,19 +99,51 @@ private[celeborn] class JavaSerializerInstance(
override def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteBufferOutputStream()
+ val msg = Utils.toTransportMessage(t)
+ msg match {
+ case transMsg: TransportMessage =>
+ // Check if the msg is a TransportMessage with CPP languageType.
+ // If so, write the marker and the body explicitly.
+ if (transMsg.getLanguageType == LanguageType.CPP) {
+ val out = new DataOutputStream(bos)
+ out.writeByte(LanguageType.CPP.getMarker)
+ out.write(transMsg.toByteBuffer.array)
+ out.close()
+ return bos.toByteBuffer
+ }
+ case _ =>
+ }
val out = serializeStream(bos)
out.writeObject(Utils.toTransportMessage(t))
out.close()
bos.toByteBuffer
}
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+ bytes.mark
+ val languageMarker = bytes.get
+ bytes.reset
+ // If the languageMarker byte is CPP, deserialize directly.
+ if (languageMarker == LanguageType.CPP.getMarker) {
+ bytes.get
+ return Utils.fromTransportMessage(
+ TransportMessage.fromByteBuffer(bytes,
LanguageType.CPP)).asInstanceOf[T]
+ }
val bis = new ByteBufferInputStream(bytes)
val in = deserializeStream(bis)
Utils.fromTransportMessage(in.readObject()).asInstanceOf[T]
}
override def deserialize[T: ClassTag](bytes: ByteBuffer, loader:
ClassLoader): T = {
+ bytes.mark
+ val languageMarker = bytes.get
+ bytes.reset
+ // If the languageMarker byte is CPP, deserialize directly.
+ if (languageMarker == LanguageType.CPP.getMarker) {
+ bytes.get
+ return Utils.fromTransportMessage(
+ TransportMessage.fromByteBuffer(bytes,
LanguageType.CPP)).asInstanceOf[T]
+ }
val bis = new ByteBufferInputStream(bytes)
Review Comment:
Refactored as suggested.
--
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]