This is an automated email from the ASF dual-hosted git repository. chengpan pushed a commit to branch branch-0.3 in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
commit dd05933430375bd3ed56c3707b7982d567b7f5b8 Author: SteNicholas <[email protected]> AuthorDate: Fri Sep 22 11:36:45 2023 +0800 [CELEBORN-771][FLINK] Convert PushDataHandShake, RegionFinish, RegionStart to PB ### What changes were proposed in this pull request? `PushDataHandShake`, `RegionFinish`, and `RegionStart` should be merged into transport messages to improve Celeborn's compatibility. ### Why are the changes needed? 1. Improves Celeborn's transport flexibility when changing the RPC. 2. Makes it compatible with the 0.2 client. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - `RemoteShuffleOutputGateSuiteJ` Closes #1910 from SteNicholas/CELEBORN-771. Authored-by: SteNicholas <[email protected]> Signed-off-by: mingji <[email protected]> (cherry picked from commit 55e85055b80111b6f3662474c8fd500bf6589f2a) --- .../flink/readclient/FlinkShuffleClientImpl.java | 70 ++++++---- .../common/network/protocol/PushDataHandShake.java | 1 + .../common/network/protocol/RegionFinish.java | 1 + .../common/network/protocol/RegionStart.java | 1 + .../common/network/protocol/TransportMessage.java | 12 ++ common/src/main/proto/TransportMessages.proto | 29 +++++ .../service/deploy/worker/PushDataHandler.scala | 145 +++++++++++++++++---- 7 files changed, 207 insertions(+), 52 deletions(-) diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java index 992716894..c7ed33b33 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java @@ -44,13 +44,16 @@ import org.apache.celeborn.common.network.client.RpcResponseCallback; import org.apache.celeborn.common.network.client.TransportClient; import 
org.apache.celeborn.common.network.client.TransportClientFactory; import org.apache.celeborn.common.network.protocol.PushData; -import org.apache.celeborn.common.network.protocol.PushDataHandShake; -import org.apache.celeborn.common.network.protocol.RegionFinish; -import org.apache.celeborn.common.network.protocol.RegionStart; +import org.apache.celeborn.common.network.protocol.TransportMessage; import org.apache.celeborn.common.network.util.TransportConf; +import org.apache.celeborn.common.protocol.MessageType; import org.apache.celeborn.common.protocol.PartitionLocation; import org.apache.celeborn.common.protocol.PbChangeLocationPartitionInfo; import org.apache.celeborn.common.protocol.PbChangeLocationResponse; +import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode; +import org.apache.celeborn.common.protocol.PbPushDataHandShake; +import org.apache.celeborn.common.protocol.PbRegionFinish; +import org.apache.celeborn.common.protocol.PbRegionStart; import org.apache.celeborn.common.protocol.ReviveRequest; import org.apache.celeborn.common.protocol.TransportModuleConstants; import org.apache.celeborn.common.protocol.message.ControlMessages; @@ -332,18 +335,23 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl { location.getUniqueId()); logger.debug("PushDataHandShake location {}", location); TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); - PushDataHandShake handShake = - new PushDataHandShake( - PRIMARY_MODE, - shuffleKey, - location.getUniqueId(), - attemptId, - numPartitions, - bufferSize); ByteBuffer pushDataHandShakeResponse; try { pushDataHandShakeResponse = - client.sendRpcSync(handShake.toByteBuffer(), conf.pushDataTimeoutMs()); + client.sendRpcSync( + new TransportMessage( + MessageType.PUSH_DATA_HAND_SHAKE, + PbPushDataHandShake.newBuilder() + .setMode(Mode.forNumber(PRIMARY_MODE)) + .setShuffleKey(shuffleKey) + .setPartitionUniqueId(location.getUniqueId()) + .setAttemptId(attemptId) + 
.setNumPartitions(numPartitions) + .setBufferSize(bufferSize) + .build() + .toByteArray()) + .toByteBuffer(), + conf.pushDataTimeoutMs()); } catch (IOException e) { // ioexeption revive return revive(shuffleId, mapId, attemptId, location); @@ -378,18 +386,23 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl { location.getUniqueId()); logger.debug("RegionStart for location {}.", location.toString()); TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); - RegionStart regionStart = - new RegionStart( - PRIMARY_MODE, - shuffleKey, - location.getUniqueId(), - attemptId, - currentRegionIdx, - isBroadcast); ByteBuffer regionStartResponse; try { regionStartResponse = - client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataTimeoutMs()); + client.sendRpcSync( + new TransportMessage( + MessageType.REGION_START, + PbRegionStart.newBuilder() + .setMode(Mode.forNumber(PRIMARY_MODE)) + .setShuffleKey(shuffleKey) + .setPartitionUniqueId(location.getUniqueId()) + .setAttemptId(attemptId) + .setCurrentRegionIndex(currentRegionIdx) + .setIsBroadcast(isBroadcast) + .build() + .toByteArray()) + .toByteBuffer(), + conf.pushDataTimeoutMs()); } catch (IOException e) { // ioexeption revive return revive(shuffleId, mapId, attemptId, location); @@ -459,9 +472,18 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl { location.getUniqueId()); logger.debug("RegionFinish for location {}.", location); TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); - RegionFinish regionFinish = - new RegionFinish(PRIMARY_MODE, shuffleKey, location.getUniqueId(), attemptId); - client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataTimeoutMs()); + client.sendRpcSync( + new TransportMessage( + MessageType.REGION_FINISH, + PbRegionFinish.newBuilder() + .setMode(Mode.forNumber(PRIMARY_MODE)) + .setShuffleKey(shuffleKey) + .setPartitionUniqueId(location.getUniqueId()) + .setAttemptId(attemptId) + 
.build() + .toByteArray()) + .toByteBuffer(), + conf.pushDataTimeoutMs()); return null; }); } diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java index 163fcaeb9..dc8c04816 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java @@ -20,6 +20,7 @@ package org.apache.celeborn.common.network.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +@Deprecated public final class PushDataHandShake extends RequestMessage { // 0 for primary, 1 for replica, see PartitionLocation.Mode public final byte mode; diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java index 62fe18e53..c7f804d84 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java @@ -20,6 +20,7 @@ package org.apache.celeborn.common.network.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +@Deprecated public final class RegionFinish extends RequestMessage { // 0 for primary, 1 for replica, see PartitionLocation.Mode diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java index 322029d28..4081c5cf4 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java @@ -20,6 +20,7 @@ package org.apache.celeborn.common.network.protocol; import com.google.common.base.Objects; import 
io.netty.buffer.ByteBuf; +@Deprecated public final class RegionStart extends RequestMessage { // 0 for primary, 1 for replica, see PartitionLocation.Mode diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java index 769eef0d4..965c534d9 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java @@ -20,7 +20,10 @@ package org.apache.celeborn.common.network.protocol; import static org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE; import static org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE; import static org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE; +import static org.apache.celeborn.common.protocol.MessageType.PUSH_DATA_HAND_SHAKE_VALUE; import static org.apache.celeborn.common.protocol.MessageType.READ_ADD_CREDIT_VALUE; +import static org.apache.celeborn.common.protocol.MessageType.REGION_FINISH_VALUE; +import static org.apache.celeborn.common.protocol.MessageType.REGION_START_VALUE; import static org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE; import java.io.Serializable; @@ -36,7 +39,10 @@ import org.apache.celeborn.common.protocol.MessageType; import org.apache.celeborn.common.protocol.PbBacklogAnnouncement; import org.apache.celeborn.common.protocol.PbBufferStreamEnd; import org.apache.celeborn.common.protocol.PbOpenStream; +import org.apache.celeborn.common.protocol.PbPushDataHandShake; import org.apache.celeborn.common.protocol.PbReadAddCredit; +import org.apache.celeborn.common.protocol.PbRegionFinish; +import org.apache.celeborn.common.protocol.PbRegionStart; import org.apache.celeborn.common.protocol.PbStreamHandler; public class TransportMessage implements Serializable { @@ -76,6 +82,12 @@ public 
class TransportMessage implements Serializable { return (T) PbBufferStreamEnd.parseFrom(payload); case READ_ADD_CREDIT_VALUE: return (T) PbReadAddCredit.parseFrom(payload); + case PUSH_DATA_HAND_SHAKE_VALUE: + return (T) PbPushDataHandShake.parseFrom(payload); + case REGION_START_VALUE: + return (T) PbRegionStart.parseFrom(payload); + case REGION_FINISH_VALUE: + return (T) PbRegionFinish.parseFrom(payload); default: logger.error("Unexpected type {}", type); } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index e31725edc..fa91bd234 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -79,6 +79,9 @@ enum MessageType { BACKLOG_ANNOUNCEMENT = 59; BUFFER_STREAM_END = 60; READ_ADD_CREDIT = 61; + PUSH_DATA_HAND_SHAKE = 56; + REGION_START = 57; + REGION_FINISH = 58; } message PbStorageInfo { @@ -527,3 +530,29 @@ message PbReadAddCredit { int64 streamId = 1; int32 credit = 2; } + + +message PbPushDataHandShake { + PbPartitionLocation.Mode mode = 1; + string shuffleKey = 2; + string partitionUniqueId = 3; + int32 attemptId = 4; + int32 numPartitions = 5; + int32 bufferSize = 6; +} + +message PbRegionStart { + PbPartitionLocation.Mode mode = 1; + string shuffleKey = 2; + string partitionUniqueId = 3; + int32 attemptId = 4; + int32 currentRegionIndex = 5; + bool isBroadcast = 6; +} + +message PbRegionFinish { + PbPartitionLocation.Mode mode = 1; + string shuffleKey = 2; + string partitionUniqueId = 3; + int32 attemptId = 4; +} diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala index 68facb601..0a201439a 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala @@ -22,6 +22,7 @@ import 
java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor} import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray} import com.google.common.base.Throwables +import com.google.protobuf.GeneratedMessageV3 import io.netty.buffer.ByteBuf import org.apache.celeborn.common.exception.{AlreadyClosedException, CelebornIOException} @@ -30,10 +31,11 @@ import org.apache.celeborn.common.meta.{DiskStatus, WorkerInfo, WorkerPartitionL import org.apache.celeborn.common.metrics.source.Source import org.apache.celeborn.common.network.buffer.{NettyManagedBuffer, NioManagedBuffer} import org.apache.celeborn.common.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory} -import org.apache.celeborn.common.network.protocol.{Message, PushData, PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage, RpcFailure, RpcRequest, RpcResponse} +import org.apache.celeborn.common.network.protocol.{Message, PushData, PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage, RpcFailure, RpcRequest, RpcResponse, TransportMessage} import org.apache.celeborn.common.network.protocol.Message.Type import org.apache.celeborn.common.network.server.BaseMessageHandler -import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType} +import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, PbPushDataHandShake, PbRegionFinish, PbRegionStart} +import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.unsafe.Platform import org.apache.celeborn.common.util.Utils @@ -798,10 +800,8 @@ class PushDataHandler extends BaseMessageHandler with Logging { pushData.`type`(), shuffleKey, pushData.partitionUniqueId, - null, location, - callback, - wrappedCallback)) return + callback)) return val fileWriter = getFileWriterAndCheck(pushData.`type`(), location, isPrimary, 
callback) match { @@ -849,45 +849,120 @@ class PushDataHandler extends BaseMessageHandler with Logging { } private def handleRpcRequest(client: TransportClient, rpcRequest: RpcRequest): Unit = { - val msg = Message.decode(rpcRequest.body().nioByteBuffer()) val requestId = rpcRequest.requestId - val (mode, shuffleKey, partitionUniqueId, checkSplit) = msg match { - case p: PushDataHandShake => (p.mode, p.shuffleKey, p.partitionUniqueId, true) - case rs: RegionStart => (rs.mode, rs.shuffleKey, rs.partitionUniqueId, true) - case rf: RegionFinish => (rf.mode, rf.shuffleKey, rf.partitionUniqueId, false) - } + val (pbMsg, msg, isLegacy, messageType, mode, shuffleKey, partitionUniqueId, checkSplit) = + mapPartitionRpcRequest(rpcRequest) handleCore( client, rpcRequest, requestId, () => handleMapPartitionRpcRequestCore( - mode, + requestId, + pbMsg, msg, + isLegacy, + messageType, + mode, shuffleKey, partitionUniqueId, - requestId, checkSplit, new SimpleRpcResponseCallback( client, requestId, shuffleKey))) + } + private def mapPartitionRpcRequest(rpcRequest: RpcRequest) + : Tuple8[GeneratedMessageV3, Message, Boolean, Type, Mode, String, String, Boolean] = { + try { + val msg = TransportMessage.fromByteBuffer( + rpcRequest.body().nioByteBuffer()).getParsedPayload.asInstanceOf[GeneratedMessageV3] + msg match { + case p: PbPushDataHandShake => + ( + msg, + null, + false, + Type.PUSH_DATA_HAND_SHAKE, + p.getMode, + p.getShuffleKey, + p.getPartitionUniqueId, + true) + case rs: PbRegionStart => + ( + msg, + null, + false, + Type.REGION_START, + rs.getMode, + rs.getShuffleKey, + rs.getPartitionUniqueId, + true) + case rf: PbRegionFinish => + ( + msg, + null, + false, + Type.REGION_FINISH, + rf.getMode, + rf.getShuffleKey, + rf.getPartitionUniqueId, + false) + } + } catch { + case _: Exception => + val msg = Message.decode(rpcRequest.body().nioByteBuffer()) + msg match { + case p: PushDataHandShake => + ( + null, + msg, + true, + Type.PUSH_DATA_HAND_SHAKE, + Mode.forNumber(p.mode), 
+ p.shuffleKey, + p.partitionUniqueId, + true) + case rs: RegionStart => + ( + null, + msg, + true, + Type.REGION_START, + Mode.forNumber(rs.mode), + rs.shuffleKey, + rs.partitionUniqueId, + true) + case rf: RegionFinish => + ( + null, + msg, + true, + Type.REGION_FINISH, + Mode.forNumber(rf.mode), + rf.shuffleKey, + rf.partitionUniqueId, + false) + } + } } private def handleMapPartitionRpcRequestCore( - mode: Byte, - message: Message, + requestId: Long, + pbMsg: GeneratedMessageV3, + msg: Message, + isLegacy: Boolean, + messageType: Message.Type, + mode: Mode, shuffleKey: String, partitionUniqueId: String, - requestId: Long, checkSplit: Boolean, callback: RpcResponseCallback): Unit = { - val isPrimary = PartitionLocation.getMode(mode) == PartitionLocation.Mode.PRIMARY - val messageType = message.`type`() log.debug( s"requestId:$requestId, pushdata rpc:$messageType, mode:$mode, shuffleKey:$shuffleKey, " + s"partitionUniqueId:$partitionUniqueId") + val isPrimary = mode == Mode.Primary val (workerSourcePrimary, workerSourceReplica) = messageType match { case Type.PUSH_DATA_HAND_SHAKE => @@ -924,10 +999,8 @@ class PushDataHandler extends BaseMessageHandler with Logging { messageType, shuffleKey, partitionUniqueId, - null, location, - callback, - wrappedCallback)) return + callback)) return val fileWriter = getFileWriterAndCheck(messageType, location, isPrimary, callback) match { @@ -957,13 +1030,31 @@ class PushDataHandler extends BaseMessageHandler with Logging { try { messageType match { case Type.PUSH_DATA_HAND_SHAKE => + val (numPartitions, bufferSize) = + if (isLegacy) + ( + msg.asInstanceOf[PushDataHandShake].numPartitions, + msg.asInstanceOf[PushDataHandShake].bufferSize) + else + ( + pbMsg.asInstanceOf[PbPushDataHandShake].getNumPartitions, + pbMsg.asInstanceOf[PbPushDataHandShake].getBufferSize) fileWriter.asInstanceOf[MapPartitionFileWriter].pushDataHandShake( - message.asInstanceOf[PushDataHandShake].numPartitions, - 
message.asInstanceOf[PushDataHandShake].bufferSize) + numPartitions, + bufferSize) case Type.REGION_START => + val (currentRegionIndex, isBroadcast) = + if (isLegacy) + ( + msg.asInstanceOf[RegionStart].currentRegionIndex, + msg.asInstanceOf[RegionStart].isBroadcast) + else + ( + pbMsg.asInstanceOf[PbRegionStart].getCurrentRegionIndex, + Boolean.box(pbMsg.asInstanceOf[PbRegionStart].getIsBroadcast)) fileWriter.asInstanceOf[MapPartitionFileWriter].regionStart( - message.asInstanceOf[RegionStart].currentRegionIndex, - message.asInstanceOf[RegionStart].isBroadcast) + currentRegionIndex, + isBroadcast) case Type.REGION_FINISH => fileWriter.asInstanceOf[MapPartitionFileWriter].regionFinish() case _ => throw new IllegalArgumentException(s"Not support $messageType yet") @@ -1039,10 +1130,8 @@ class PushDataHandler extends BaseMessageHandler with Logging { messageType: Message.Type, shuffleKey: String, partitionUniqueId: String, - body: ByteBuf, location: PartitionLocation, - callback: RpcResponseCallback, - wrappedCallback: RpcResponseCallback): Boolean = { + callback: RpcResponseCallback): Boolean = { if (location == null) { val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, uniqueId $partitionUniqueId)."
