This is an automated email from the ASF dual-hosted git repository.
ethanfeng pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/branch-0.3 by this push:
new 581923951 [CELEBORN-771][FLINK] Convert PushDataHandShake,
RegionFinish, RegionStart to PB
581923951 is described below
commit 581923951d669013651e925a1774fdb7699ec6fc
Author: SteNicholas <[email protected]>
AuthorDate: Fri Sep 22 11:36:45 2023 +0800
[CELEBORN-771][FLINK] Convert PushDataHandShake, RegionFinish, RegionStart
to PB
### What changes were proposed in this pull request?
`PushDataHandShake`, `RegionFinish`, and `RegionStart` are converted to
protobuf-based transport messages (`PbPushDataHandShake`, `PbRegionStart`,
and `PbRegionFinish`) to enhance Celeborn's compatibility.
### Why are the changes needed?
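1. Improves the flexibility of Celeborn's transport layer when RPC messages
need to change.
2. Keeps compatibility with 0.2 clients: the worker falls back to decoding the
legacy message format when a request cannot be parsed as a transport message.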
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- `RemoteShuffleOutputGateSuiteJ`
Closes #1910 from SteNicholas/CELEBORN-771.
Authored-by: SteNicholas <[email protected]>
Signed-off-by: mingji <[email protected]>
(cherry picked from commit 55e85055b80111b6f3662474c8fd500bf6589f2a)
---
.../flink/readclient/FlinkShuffleClientImpl.java | 70 ++++++----
.../common/network/protocol/PushDataHandShake.java | 1 +
.../common/network/protocol/RegionFinish.java | 1 +
.../common/network/protocol/RegionStart.java | 1 +
.../common/network/protocol/TransportMessage.java | 12 ++
common/src/main/proto/TransportMessages.proto | 29 +++++
.../service/deploy/worker/PushDataHandler.scala | 145 +++++++++++++++++----
7 files changed, 207 insertions(+), 52 deletions(-)
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
index 992716894..c7ed33b33 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
@@ -44,13 +44,16 @@ import
org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.PushData;
-import org.apache.celeborn.common.network.protocol.PushDataHandShake;
-import org.apache.celeborn.common.network.protocol.RegionFinish;
-import org.apache.celeborn.common.network.protocol.RegionStart;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbChangeLocationPartitionInfo;
import org.apache.celeborn.common.protocol.PbChangeLocationResponse;
+import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode;
+import org.apache.celeborn.common.protocol.PbPushDataHandShake;
+import org.apache.celeborn.common.protocol.PbRegionFinish;
+import org.apache.celeborn.common.protocol.PbRegionStart;
import org.apache.celeborn.common.protocol.ReviveRequest;
import org.apache.celeborn.common.protocol.TransportModuleConstants;
import org.apache.celeborn.common.protocol.message.ControlMessages;
@@ -332,18 +335,23 @@ public class FlinkShuffleClientImpl extends
ShuffleClientImpl {
location.getUniqueId());
logger.debug("PushDataHandShake location {}", location);
TransportClient client =
createClientWaitingInFlightRequest(location, mapKey, pushState);
- PushDataHandShake handShake =
- new PushDataHandShake(
- PRIMARY_MODE,
- shuffleKey,
- location.getUniqueId(),
- attemptId,
- numPartitions,
- bufferSize);
ByteBuffer pushDataHandShakeResponse;
try {
pushDataHandShakeResponse =
- client.sendRpcSync(handShake.toByteBuffer(),
conf.pushDataTimeoutMs());
+ client.sendRpcSync(
+ new TransportMessage(
+ MessageType.PUSH_DATA_HAND_SHAKE,
+ PbPushDataHandShake.newBuilder()
+ .setMode(Mode.forNumber(PRIMARY_MODE))
+ .setShuffleKey(shuffleKey)
+ .setPartitionUniqueId(location.getUniqueId())
+ .setAttemptId(attemptId)
+ .setNumPartitions(numPartitions)
+ .setBufferSize(bufferSize)
+ .build()
+ .toByteArray())
+ .toByteBuffer(),
+ conf.pushDataTimeoutMs());
} catch (IOException e) {
// ioexeption revive
return revive(shuffleId, mapId, attemptId, location);
@@ -378,18 +386,23 @@ public class FlinkShuffleClientImpl extends
ShuffleClientImpl {
location.getUniqueId());
logger.debug("RegionStart for location {}.", location.toString());
TransportClient client =
createClientWaitingInFlightRequest(location, mapKey, pushState);
- RegionStart regionStart =
- new RegionStart(
- PRIMARY_MODE,
- shuffleKey,
- location.getUniqueId(),
- attemptId,
- currentRegionIdx,
- isBroadcast);
ByteBuffer regionStartResponse;
try {
regionStartResponse =
- client.sendRpcSync(regionStart.toByteBuffer(),
conf.pushDataTimeoutMs());
+ client.sendRpcSync(
+ new TransportMessage(
+ MessageType.REGION_START,
+ PbRegionStart.newBuilder()
+ .setMode(Mode.forNumber(PRIMARY_MODE))
+ .setShuffleKey(shuffleKey)
+ .setPartitionUniqueId(location.getUniqueId())
+ .setAttemptId(attemptId)
+ .setCurrentRegionIndex(currentRegionIdx)
+ .setIsBroadcast(isBroadcast)
+ .build()
+ .toByteArray())
+ .toByteBuffer(),
+ conf.pushDataTimeoutMs());
} catch (IOException e) {
// ioexeption revive
return revive(shuffleId, mapId, attemptId, location);
@@ -459,9 +472,18 @@ public class FlinkShuffleClientImpl extends
ShuffleClientImpl {
location.getUniqueId());
logger.debug("RegionFinish for location {}.", location);
TransportClient client =
createClientWaitingInFlightRequest(location, mapKey, pushState);
- RegionFinish regionFinish =
- new RegionFinish(PRIMARY_MODE, shuffleKey,
location.getUniqueId(), attemptId);
- client.sendRpcSync(regionFinish.toByteBuffer(),
conf.pushDataTimeoutMs());
+ client.sendRpcSync(
+ new TransportMessage(
+ MessageType.REGION_FINISH,
+ PbRegionFinish.newBuilder()
+ .setMode(Mode.forNumber(PRIMARY_MODE))
+ .setShuffleKey(shuffleKey)
+ .setPartitionUniqueId(location.getUniqueId())
+ .setAttemptId(attemptId)
+ .build()
+ .toByteArray())
+ .toByteBuffer(),
+ conf.pushDataTimeoutMs());
return null;
});
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java
index 163fcaeb9..dc8c04816 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java
@@ -20,6 +20,7 @@ package org.apache.celeborn.common.network.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
+@Deprecated
public final class PushDataHandShake extends RequestMessage {
// 0 for primary, 1 for replica, see PartitionLocation.Mode
public final byte mode;
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java
index 62fe18e53..c7f804d84 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java
@@ -20,6 +20,7 @@ package org.apache.celeborn.common.network.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
+@Deprecated
public final class RegionFinish extends RequestMessage {
// 0 for primary, 1 for replica, see PartitionLocation.Mode
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java
index 322029d28..4081c5cf4 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java
@@ -20,6 +20,7 @@ package org.apache.celeborn.common.network.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
+@Deprecated
public final class RegionStart extends RequestMessage {
// 0 for primary, 1 for replica, see PartitionLocation.Mode
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index 769eef0d4..fd7d16350 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -20,6 +20,9 @@ package org.apache.celeborn.common.network.protocol;
import static
org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE;
import static
org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE;
import static
org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE;
+import static
org.apache.celeborn.common.protocol.MessageType.PUSH_DATA_HAND_SHAKE_VALUE;
+import static
org.apache.celeborn.common.protocol.MessageType.REGION_FINISH_VALUE;
+import static
org.apache.celeborn.common.protocol.MessageType.REGION_START_VALUE;
import static
org.apache.celeborn.common.protocol.MessageType.READ_ADD_CREDIT_VALUE;
import static
org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE;
@@ -37,6 +40,9 @@ import
org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbReadAddCredit;
+import org.apache.celeborn.common.protocol.PbPushDataHandShake;
+import org.apache.celeborn.common.protocol.PbRegionFinish;
+import org.apache.celeborn.common.protocol.PbRegionStart;
import org.apache.celeborn.common.protocol.PbStreamHandler;
public class TransportMessage implements Serializable {
@@ -76,6 +82,12 @@ public class TransportMessage implements Serializable {
return (T) PbBufferStreamEnd.parseFrom(payload);
case READ_ADD_CREDIT_VALUE:
return (T) PbReadAddCredit.parseFrom(payload);
+ case PUSH_DATA_HAND_SHAKE_VALUE:
+ return (T) PbPushDataHandShake.parseFrom(payload);
+ case REGION_START_VALUE:
+ return (T) PbRegionStart.parseFrom(payload);
+ case REGION_FINISH_VALUE:
+ return (T) PbRegionFinish.parseFrom(payload);
default:
logger.error("Unexpected type {}", type);
}
diff --git a/common/src/main/proto/TransportMessages.proto
b/common/src/main/proto/TransportMessages.proto
index e31725edc..fa91bd234 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -79,6 +79,9 @@ enum MessageType {
BACKLOG_ANNOUNCEMENT = 59;
BUFFER_STREAM_END = 60;
READ_ADD_CREDIT = 61;
+ PUSH_DATA_HAND_SHAKE = 56;
+ REGION_START = 57;
+ REGION_FINISH = 58;
}
message PbStorageInfo {
@@ -527,3 +530,29 @@ message PbReadAddCredit {
int64 streamId = 1;
int32 credit = 2;
}
+
+
+message PbPushDataHandShake {
+ PbPartitionLocation.Mode mode = 1;
+ string shuffleKey = 2;
+ string partitionUniqueId = 3;
+ int32 attemptId = 4;
+ int32 numPartitions = 5;
+ int32 bufferSize = 6;
+}
+
+message PbRegionStart {
+ PbPartitionLocation.Mode mode = 1;
+ string shuffleKey = 2;
+ string partitionUniqueId = 3;
+ int32 attemptId = 4;
+ int32 currentRegionIndex = 5;
+ bool isBroadcast = 6;
+}
+
+message PbRegionFinish {
+ PbPartitionLocation.Mode mode = 1;
+ string shuffleKey = 2;
+ string partitionUniqueId = 3;
+ int32 attemptId = 4;
+}
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 68facb601..0a201439a 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap,
ThreadPoolExecutor}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray}
import com.google.common.base.Throwables
+import com.google.protobuf.GeneratedMessageV3
import io.netty.buffer.ByteBuf
import org.apache.celeborn.common.exception.{AlreadyClosedException,
CelebornIOException}
@@ -30,10 +31,11 @@ import org.apache.celeborn.common.meta.{DiskStatus,
WorkerInfo, WorkerPartitionL
import org.apache.celeborn.common.metrics.source.Source
import org.apache.celeborn.common.network.buffer.{NettyManagedBuffer,
NioManagedBuffer}
import org.apache.celeborn.common.network.client.{RpcResponseCallback,
TransportClient, TransportClientFactory}
-import org.apache.celeborn.common.network.protocol.{Message, PushData,
PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage,
RpcFailure, RpcRequest, RpcResponse}
+import org.apache.celeborn.common.network.protocol.{Message, PushData,
PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage,
RpcFailure, RpcRequest, RpcResponse, TransportMessage}
import org.apache.celeborn.common.network.protocol.Message.Type
import org.apache.celeborn.common.network.server.BaseMessageHandler
-import org.apache.celeborn.common.protocol.{PartitionLocation,
PartitionSplitMode, PartitionType}
+import org.apache.celeborn.common.protocol.{PartitionLocation,
PartitionSplitMode, PartitionType, PbPushDataHandShake, PbRegionFinish,
PbRegionStart}
+import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.unsafe.Platform
import org.apache.celeborn.common.util.Utils
@@ -798,10 +800,8 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
pushData.`type`(),
shuffleKey,
pushData.partitionUniqueId,
- null,
location,
- callback,
- wrappedCallback)) return
+ callback)) return
val fileWriter =
getFileWriterAndCheck(pushData.`type`(), location, isPrimary, callback)
match {
@@ -849,45 +849,120 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
}
private def handleRpcRequest(client: TransportClient, rpcRequest:
RpcRequest): Unit = {
- val msg = Message.decode(rpcRequest.body().nioByteBuffer())
val requestId = rpcRequest.requestId
- val (mode, shuffleKey, partitionUniqueId, checkSplit) = msg match {
- case p: PushDataHandShake => (p.mode, p.shuffleKey, p.partitionUniqueId,
true)
- case rs: RegionStart => (rs.mode, rs.shuffleKey, rs.partitionUniqueId,
true)
- case rf: RegionFinish => (rf.mode, rf.shuffleKey, rf.partitionUniqueId,
false)
- }
+ val (pbMsg, msg, isLegacy, messageType, mode, shuffleKey,
partitionUniqueId, checkSplit) =
+ mapPartitionRpcRequest(rpcRequest)
handleCore(
client,
rpcRequest,
requestId,
() =>
handleMapPartitionRpcRequestCore(
- mode,
+ requestId,
+ pbMsg,
msg,
+ isLegacy,
+ messageType,
+ mode,
shuffleKey,
partitionUniqueId,
- requestId,
checkSplit,
new SimpleRpcResponseCallback(
client,
requestId,
shuffleKey)))
+ }
+ private def mapPartitionRpcRequest(rpcRequest: RpcRequest)
+ : Tuple8[GeneratedMessageV3, Message, Boolean, Type, Mode, String,
String, Boolean] = {
+ try {
+ val msg = TransportMessage.fromByteBuffer(
+
rpcRequest.body().nioByteBuffer()).getParsedPayload.asInstanceOf[GeneratedMessageV3]
+ msg match {
+ case p: PbPushDataHandShake =>
+ (
+ msg,
+ null,
+ false,
+ Type.PUSH_DATA_HAND_SHAKE,
+ p.getMode,
+ p.getShuffleKey,
+ p.getPartitionUniqueId,
+ true)
+ case rs: PbRegionStart =>
+ (
+ msg,
+ null,
+ false,
+ Type.REGION_START,
+ rs.getMode,
+ rs.getShuffleKey,
+ rs.getPartitionUniqueId,
+ true)
+ case rf: PbRegionFinish =>
+ (
+ msg,
+ null,
+ false,
+ Type.REGION_FINISH,
+ rf.getMode,
+ rf.getShuffleKey,
+ rf.getPartitionUniqueId,
+ false)
+ }
+ } catch {
+ case _: Exception =>
+ val msg = Message.decode(rpcRequest.body().nioByteBuffer())
+ msg match {
+ case p: PushDataHandShake =>
+ (
+ null,
+ msg,
+ true,
+ Type.PUSH_DATA_HAND_SHAKE,
+ Mode.forNumber(p.mode),
+ p.shuffleKey,
+ p.partitionUniqueId,
+ true)
+ case rs: RegionStart =>
+ (
+ null,
+ msg,
+ true,
+ Type.REGION_START,
+ Mode.forNumber(rs.mode),
+ rs.shuffleKey,
+ rs.partitionUniqueId,
+ true)
+ case rf: RegionFinish =>
+ (
+ null,
+ msg,
+ true,
+ Type.REGION_FINISH,
+ Mode.forNumber(rf.mode),
+ rf.shuffleKey,
+ rf.partitionUniqueId,
+ false)
+ }
+ }
}
private def handleMapPartitionRpcRequestCore(
- mode: Byte,
- message: Message,
+ requestId: Long,
+ pbMsg: GeneratedMessageV3,
+ msg: Message,
+ isLegacy: Boolean,
+ messageType: Message.Type,
+ mode: Mode,
shuffleKey: String,
partitionUniqueId: String,
- requestId: Long,
checkSplit: Boolean,
callback: RpcResponseCallback): Unit = {
- val isPrimary = PartitionLocation.getMode(mode) ==
PartitionLocation.Mode.PRIMARY
- val messageType = message.`type`()
log.debug(
s"requestId:$requestId, pushdata rpc:$messageType, mode:$mode,
shuffleKey:$shuffleKey, " +
s"partitionUniqueId:$partitionUniqueId")
+ val isPrimary = mode == Mode.Primary
val (workerSourcePrimary, workerSourceReplica) =
messageType match {
case Type.PUSH_DATA_HAND_SHAKE =>
@@ -924,10 +999,8 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
messageType,
shuffleKey,
partitionUniqueId,
- null,
location,
- callback,
- wrappedCallback)) return
+ callback)) return
val fileWriter =
getFileWriterAndCheck(messageType, location, isPrimary, callback) match {
@@ -957,13 +1030,31 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
try {
messageType match {
case Type.PUSH_DATA_HAND_SHAKE =>
+ val (numPartitions, bufferSize) =
+ if (isLegacy)
+ (
+ msg.asInstanceOf[PushDataHandShake].numPartitions,
+ msg.asInstanceOf[PushDataHandShake].bufferSize)
+ else
+ (
+ pbMsg.asInstanceOf[PbPushDataHandShake].getNumPartitions,
+ pbMsg.asInstanceOf[PbPushDataHandShake].getBufferSize)
fileWriter.asInstanceOf[MapPartitionFileWriter].pushDataHandShake(
- message.asInstanceOf[PushDataHandShake].numPartitions,
- message.asInstanceOf[PushDataHandShake].bufferSize)
+ numPartitions,
+ bufferSize)
case Type.REGION_START =>
+ val (currentRegionIndex, isBroadcast) =
+ if (isLegacy)
+ (
+ msg.asInstanceOf[RegionStart].currentRegionIndex,
+ msg.asInstanceOf[RegionStart].isBroadcast)
+ else
+ (
+ pbMsg.asInstanceOf[PbRegionStart].getCurrentRegionIndex,
+ Boolean.box(pbMsg.asInstanceOf[PbRegionStart].getIsBroadcast))
fileWriter.asInstanceOf[MapPartitionFileWriter].regionStart(
- message.asInstanceOf[RegionStart].currentRegionIndex,
- message.asInstanceOf[RegionStart].isBroadcast)
+ currentRegionIndex,
+ isBroadcast)
case Type.REGION_FINISH =>
fileWriter.asInstanceOf[MapPartitionFileWriter].regionFinish()
case _ => throw new IllegalArgumentException(s"Not support
$messageType yet")
@@ -1039,10 +1130,8 @@ class PushDataHandler extends BaseMessageHandler with
Logging {
messageType: Message.Type,
shuffleKey: String,
partitionUniqueId: String,
- body: ByteBuf,
location: PartitionLocation,
- callback: RpcResponseCallback,
- wrappedCallback: RpcResponseCallback): Boolean = {
+ callback: RpcResponseCallback): Boolean = {
if (location == null) {
val msg =
s"Partition location wasn't found for task(shuffle $shuffleKey,
uniqueId $partitionUniqueId)."