This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 5d6145896 [CELEBORN-1490][CIP-6] Extends message to support hybrid 
shuffle
5d6145896 is described below

commit 5d61458964dbae6ad0b1857e4ddf0e82070bcb6a
Author: Weijie Guo <[email protected]>
AuthorDate: Fri Aug 30 09:43:09 2024 +0800

    [CELEBORN-1490][CIP-6] Extends message to support hybrid shuffle
    
    ### What changes were proposed in this pull request?
    
    This is the first PR to support Hybrid Shuffle.
    
    Extends message to support hybrid shuffle.
    
    ### Why are the changes needed?
    Hybrid shuffle is a tiered storage architecture that introduces the concept
of a `segment`. Data is split into segments, and each segment's data is sent to
one selected tier, so the data as a whole may be spread across multiple tiers.
    
    This PR introduces segment-related messages. In addition, hybrid shuffle
needs to distinguish which subpartition the data comes from when consuming it,
so a `subPartitionId` field is added to `ReadData` via a new class,
`SubPartitionReadData`, introduced for compatibility.
    
    ### Does this PR introduce _any_ user-facing change?
    no.
    
    ### How was this patch tested?
    no need.
    
    Closes #2714 from reswqa/cip6-1-extend-message.
    
    Authored-by: Weijie Guo <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../plugin/flink/network/MessageDecoderExt.java    |   6 ++
 .../plugin/flink/network/ReadClientHandler.java    |   7 ++
 .../flink/protocol/SubPartitionReadData.java       | 106 +++++++++++++++++++++
 .../celeborn/common/network/protocol/Message.java  |  12 ++-
 .../network/protocol/SubPartitionReadData.java     |  90 +++++++++++++++++
 .../common/network/protocol/TransportMessage.java  |   4 +
 .../common/protocol/message/StatusCode.java        |   4 +-
 common/src/main/proto/TransportMessages.proto      |  28 ++++++
 8 files changed, 255 insertions(+), 2 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
index 12f1293ec..6eb06970e 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
@@ -25,6 +25,7 @@ import 
org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
 import org.apache.celeborn.common.network.protocol.*;
 import org.apache.celeborn.plugin.flink.buffer.FlinkNettyManagedBuffer;
 import org.apache.celeborn.plugin.flink.protocol.ReadData;
+import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;
 
 public class MessageDecoderExt {
   public static Message decode(Message.Type type, ByteBuf in, boolean 
decodeBody) {
@@ -74,6 +75,11 @@ public class MessageDecoderExt {
         streamId = in.readLong();
         return new ReadData(streamId);
 
+      case SUBPARTITION_READ_DATA:
+        streamId = in.readLong();
+        int subPartitionId = in.readInt();
+        return new SubPartitionReadData(streamId, subPartitionId);
+
       case BACKLOG_ANNOUNCEMENT:
         streamId = in.readLong();
         int backlog = in.readInt();
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
index d773edb88..c1d0982fb 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
@@ -39,6 +39,7 @@ import 
org.apache.celeborn.common.network.protocol.TransportableError;
 import org.apache.celeborn.common.network.server.BaseMessageHandler;
 import org.apache.celeborn.common.util.JavaUtils;
 import org.apache.celeborn.plugin.flink.protocol.ReadData;
+import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;
 
 public class ReadClientHandler extends BaseMessageHandler {
   private static Logger logger = 
LoggerFactory.getLogger(ReadClientHandler.class);
@@ -65,6 +66,8 @@ public class ReadClientHandler extends BaseMessageHandler {
     } else {
       if (msg != null && msg instanceof ReadData) {
         ((ReadData) msg).getFlinkBuffer().release();
+      } else if (msg != null && msg instanceof SubPartitionReadData) {
+        ((SubPartitionReadData) msg).getFlinkBuffer().release();
       }
 
       logger.warn("Unexpected streamId received: {}", streamId);
@@ -83,6 +86,10 @@ public class ReadClientHandler extends BaseMessageHandler {
         ReadData readData = (ReadData) msg;
         processMessageInternal(readData.getStreamId(), readData);
         break;
+      case SUBPARTITION_READ_DATA:
+        SubPartitionReadData subPartitionReadData = (SubPartitionReadData) msg;
+        processMessageInternal(subPartitionReadData.getStreamId(), 
subPartitionReadData);
+        break;
       case BACKLOG_ANNOUNCEMENT:
         BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement) msg;
         processMessageInternal(backlogAnnouncement.getStreamId(), 
backlogAnnouncement);
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
new file mode 100644
index 000000000..b12f24d9b
--- /dev/null
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.protocol;
+
+import java.util.Objects;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+import org.apache.celeborn.common.network.protocol.ReadData;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+
+/**
+ * Comparing {@link ReadData}, this class has an additional field of 
subpartitionId. This class is
+ * added to keep the backward compatibility.
+ */
+public class SubPartitionReadData extends RequestMessage {
+  private final long streamId;
+  private final int subPartitionId;
+  private ByteBuf flinkBuffer;
+
+  @Override
+  public boolean needCopyOut() {
+    return true;
+  }
+
+  public SubPartitionReadData(long streamId, int subPartitionId) {
+    this.subPartitionId = subPartitionId;
+    this.streamId = streamId;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8 + 4;
+  }
+
+  // This method will not be called because ReadData won't be created at flink 
client.
+  @Override
+  public void encode(io.netty.buffer.ByteBuf buf) {
+    buf.writeLong(streamId);
+    buf.writeInt(subPartitionId);
+  }
+
+  public long getStreamId() {
+    return streamId;
+  }
+
+  public int getSubPartitionId() {
+    return subPartitionId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.SUBPARTITION_READ_DATA;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+    SubPartitionReadData readData = (SubPartitionReadData) o;
+    return streamId == readData.streamId
+        && subPartitionId == readData.subPartitionId
+        && Objects.equals(flinkBuffer, readData.flinkBuffer);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(streamId, subPartitionId, flinkBuffer);
+  }
+
+  @Override
+  public String toString() {
+    return "SubpartitionReadData{"
+        + "streamId="
+        + streamId
+        + ", subPartitionId="
+        + subPartitionId
+        + ", flinkBuffer="
+        + flinkBuffer
+        + '}';
+  }
+
+  public ByteBuf getFlinkBuffer() {
+    return flinkBuffer;
+  }
+
+  public void setFlinkBuffer(ByteBuf flinkBuffer) {
+    this.flinkBuffer = flinkBuffer;
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
index b9423683b..ceb894312 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
@@ -92,7 +92,11 @@ public abstract class Message implements Encodable {
     BACKLOG_ANNOUNCEMENT(19),
     TRANSPORTABLE_ERROR(20),
     BUFFER_STREAM_END(21),
-    HEARTBEAT(22);
+    HEARTBEAT(22),
+    SEGMENT_START(23),
+    NOTIFY_REQUIRED_SEGMENT(24),
+    SUBPARTITION_READ_DATA(25);
+
     private final byte id;
 
     Type(int id) {
@@ -164,6 +168,12 @@ public abstract class Message implements Encodable {
           return BUFFER_STREAM_END;
         case 22:
           return HEARTBEAT;
+        case 23:
+          return SEGMENT_START;
+        case 24:
+          return NOTIFY_REQUIRED_SEGMENT;
+        case 25:
+          return SUBPARTITION_READ_DATA;
         case -1:
           throw new IllegalArgumentException("User type messages cannot be 
decoded.");
         default:
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
new file mode 100644
index 000000000..bdea84989
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import java.util.Objects;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
+
+/**
+ * Comparing {@link ReadData}, this class has an additional field of 
subpartitionId. This class is
+ * added to keep the backward compatibility.
+ */
+public class SubPartitionReadData extends RequestMessage {
+  private long streamId;
+
+  private int subPartitionId;
+
+  public SubPartitionReadData(long streamId, int subPartitionId, ByteBuf buf) {
+    super(new NettyManagedBuffer(buf));
+    this.streamId = streamId;
+    this.subPartitionId = subPartitionId;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8 + 4;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(streamId);
+    buf.writeInt(subPartitionId);
+  }
+
+  public long getStreamId() {
+    return streamId;
+  }
+
+  public int getSubPartitionId() {
+    return subPartitionId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.SUBPARTITION_READ_DATA;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+    SubPartitionReadData readData = (SubPartitionReadData) o;
+    return streamId == readData.streamId
+        && subPartitionId == readData.subPartitionId
+        && super.equals(o);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(streamId, subPartitionId, super.hashCode());
+  }
+
+  @Override
+  public String toString() {
+    return "SubpartitionReadData{"
+        + "streamId="
+        + streamId
+        + ", subPartitionId="
+        + subPartitionId
+        + '}';
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index aa5d7ad51..239874e8c 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -109,6 +109,10 @@ public class TransportMessage implements Serializable {
         return (T) PbOpenStreamList.parseFrom(payload);
       case BATCH_OPEN_STREAM_RESPONSE_VALUE:
         return (T) PbOpenStreamListResponse.parseFrom(payload);
+      case SEGMENT_START_VALUE:
+        return (T) PbSegmentStart.parseFrom(payload);
+      case NOTIFY_REQUIRED_SEGMENT_VALUE:
+        return (T) PbNotifyRequiredSegment.parseFrom(payload);
       default:
         logger.error("Unexpected type {}", type);
     }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
index 0ebfad65a..ca8655bab 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java
@@ -82,7 +82,9 @@ public enum StatusCode {
   DESTROY_SLOTS_MOCK_FAILURE(48),
   COMMIT_FILES_MOCK_FAILURE(49),
   PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA(50),
-  OPEN_STREAM_FAILED(51);
+  OPEN_STREAM_FAILED(51),
+  SEGMENT_START_FAIL_REPLICA(52),
+  SEGMENT_START_FAIL_PRIMARY(53);
 
   private final byte value;
 
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 154da825c..9bde109f0 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -105,6 +105,8 @@ enum MessageType {
   REPORT_WORKER_DECOMMISSION = 82;
   REPORT_BARRIER_STAGE_ATTEMPT_FAILURE = 83;
   REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE = 84;
+  SEGMENT_START = 85;
+  NOTIFY_REQUIRED_SEGMENT = 86;
 }
 
 enum StreamType {
@@ -258,6 +260,7 @@ message PbRegisterMapPartitionTask {
   int32 mapId = 3;
   int32 attemptId = 4;
   int32 partitionId = 5;
+  bool isSegmentGranularityVisible = 6;
 }
 
 message PbRegisterShuffleResponse {
@@ -353,6 +356,7 @@ message PbMapperEndResponse {
 
 message PbGetReducerFileGroup {
   int32 shuffleId = 1;
+  bool isSegmentGranularityVisible = 2;
 }
 
 message PbGetReducerFileGroupResponse {
@@ -471,6 +475,7 @@ message PbReserveSlots {
   bool partitionSplitEnabled = 11;
   int32 availableStorageTypes = 12;
   PbPackedPartitionLocationsPair partitionLocationsPair = 13;
+  bool isSegmentGranularityVisible = 14;
 }
 
 message PbReserveSlotsResponse {
@@ -576,6 +581,13 @@ message PbFileInfo {
   int32 numSubpartitions = 6;
   int64 bytesFlushed = 7;
   bool partitionSplitEnabled = 8;
+  bool isSegmentGranularityVisible = 9;
+  map<int32, int32> partitionWritingSegment = 10;
+  repeated PbSegmentIndex segmentIndex = 11;
+}
+
+message PbSegmentIndex {
+  map<int32, int32> firstBufferIndexToSegment = 1;
 }
 
 message PbMapFileMeta {
@@ -653,6 +665,7 @@ message PbOpenStream {
   int32 endIndex = 4;
   int32 initialCredit = 5;
   bool readLocalShuffle = 6;
+  bool requireSubpartitionId = 7;
 }
 
 message PbStreamHandler {
@@ -760,6 +773,21 @@ message PbAuthenticationInitiationRequest {
   repeated PbSaslMechanism saslMechanisms = 3;
 }
 
+message PbSegmentStart {
+  PbPartitionLocation.Mode mode = 1;
+  string shuffleKey = 2;
+  string partitionUniqueId = 3;
+  int32 attemptId = 4;
+  int32 subPartitionId = 5;
+  int32 segmentId = 6;
+}
+
+message PbNotifyRequiredSegment {
+  int64 streamId = 1;
+  int32 requiredSegmentId = 2;
+  int32 subPartitionId = 3;
+}
+
 message PbAuthenticationInitiationResponse {
   string version = 1;
   bool authEnabled = 2;

Reply via email to