This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 5595f2f4 [CELEBORN-124]Add buffer stream. (#1069)
5595f2f4 is described below

commit 5595f2f4b3de2c1dbbbaba39d7b62d3efee65c5d
Author: Ethan Feng <[email protected]>
AuthorDate: Fri Jan 6 15:54:52 2023 +0800

    [CELEBORN-124]Add buffer stream. (#1069)
---
 .../org/apache/celeborn/common/meta/FileInfo.java  |   4 +
 .../celeborn/common/network/protocol/Message.java  |  15 ++-
 .../common/network/protocol/ReadAddCredit.java     |  79 +++++++++++++++
 .../celeborn/common/network/protocol/ReadData.java | 111 +++++++++++++++++++++
 .../common/network/server/BufferStreamManager.java |  72 +++++++++++++
 .../service/deploy/worker/FetchHandler.scala       |  21 +++-
 6 files changed, 300 insertions(+), 2 deletions(-)

diff --git a/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java 
b/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
index 50d3d8bb..6eefb83c 100644
--- a/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
+++ b/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
@@ -182,4 +182,8 @@ public class FileInfo {
   public void setBufferSize(int bufferSize) {
     this.bufferSize = bufferSize;
   }
+
+  /** Returns the buffer size most recently stored by {@link #setBufferSize(int)}. */
+  public int getBufferSize() {
+    return this.bufferSize;
+  }
 }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
index 02061426..7ba54cbd 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
@@ -84,7 +84,9 @@ public abstract class Message implements Encodable {
     PUSH_MERGED_DATA(12),
     REGION_START(13),
     REGION_FINISH(14),
-    PUSH_DATA_HAND_SHAKE(15);
+    PUSH_DATA_HAND_SHAKE(15),
+    READ_ADD_CREDIT(16),
+    READ_DATA(17);
 
     private final byte id;
 
@@ -138,6 +140,10 @@ public abstract class Message implements Encodable {
           return REGION_FINISH;
         case 15:
           return PUSH_DATA_HAND_SHAKE;
+        case 16:
+          return READ_ADD_CREDIT;
+        case 17:
+          return READ_DATA;
         case -1:
           throw new IllegalArgumentException("User type messages cannot be 
decoded.");
         default:
@@ -193,6 +199,13 @@ public abstract class Message implements Encodable {
 
       case PUSH_DATA_HAND_SHAKE:
         return PushDataHandShake.decode(in);
+
+      case READ_ADD_CREDIT:
+        return ReadAddCredit.decode(in);
+
+      case READ_DATA:
+        return ReadData.decode(in);
+
       default:
         throw new IllegalArgumentException("Unexpected message type: " + 
msgType);
     }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
new file mode 100644
index 00000000..ca34a5c1
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.common.network.protocol;
+
+import java.util.Objects;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request that adds {@code credit} to the buffer stream identified by {@code streamId}.
+ *
+ * <p>Wire format: {@code streamId} as an 8-byte long, followed by {@code credit} as a 4-byte int.
+ */
+public class ReadAddCredit extends RequestMessage {
+  private final long streamId;
+  private final int credit;
+
+  public ReadAddCredit(long streamId, int credit) {
+    this.streamId = streamId;
+    this.credit = credit;
+  }
+
+  /** Reads a message laid out as written by {@link #encode(ByteBuf)}. */
+  public static ReadAddCredit decode(ByteBuf buf) {
+    // Java evaluates arguments left to right, so the long is consumed before the int.
+    return new ReadAddCredit(buf.readLong(), buf.readInt());
+  }
+
+  @Override
+  public Type type() {
+    return Type.READ_ADD_CREDIT;
+  }
+
+  @Override
+  public int encodedLength() {
+    return Long.BYTES + Integer.BYTES;
+  }
+
+  @Override
+  public void encode(ByteBuf out) {
+    out.writeLong(streamId);
+    out.writeInt(credit);
+  }
+
+  public long getStreamId() {
+    return streamId;
+  }
+
+  public int getCredit() {
+    return credit;
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other == this) {
+      return true;
+    }
+    if (other == null || other.getClass() != getClass()) {
+      return false;
+    }
+    ReadAddCredit that = (ReadAddCredit) other;
+    return credit == that.credit && streamId == that.streamId;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(streamId, credit);
+  }
+
+  @Override
+  public String toString() {
+    return "ReadAddCredit{streamId=" + streamId + ", credit=" + credit + "}";
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
new file mode 100644
index 00000000..92e107cc
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.common.network.protocol;
+
+import java.util.Objects;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/**
+ * Message carrying a payload for the buffer stream identified by {@code streamId}.
+ *
+ * <p>Wire format: {@code streamId} (8-byte long), {@code backlog} (4-byte int), {@code offset}
+ * (8-byte long), payload length (4-byte int), then the payload bytes.
+ */
+public class ReadData extends RequestMessage {
+  private final long streamId;
+  private final int backlog;
+  private final long offset;
+  private final ByteBuf buf;
+
+  public ReadData(long streamId, int backlog, long offset, ByteBuf buf) {
+    this.streamId = streamId;
+    this.backlog = backlog;
+    this.offset = offset;
+    this.buf = buf;
+  }
+
+  /**
+   * Returns the exact number of bytes {@link #encode(ByteBuf)} writes. The two must agree, or
+   * the frame length announced by the encoder will not match the bytes on the wire.
+   */
+  @Override
+  public int encodedLength() {
+    // streamId + backlog + offset + payload length + payload
+    return Long.BYTES + Integer.BYTES + Long.BYTES + Integer.BYTES + buf.readableBytes();
+  }
+
+  @Override
+  public void encode(ByteBuf out) {
+    int length = buf.readableBytes();
+    out.writeLong(streamId);
+    out.writeInt(backlog);
+    out.writeLong(offset);
+    out.writeInt(length);
+    // Absolute-index copy: leaves the payload's readerIndex untouched, so encodedLength(),
+    // equals() and hashCode() still see the full payload after encoding.
+    out.writeBytes(buf, buf.readerIndex(), length);
+  }
+
+  /** Reads a message laid out as written by {@link #encode(ByteBuf)}. */
+  public static ReadData decode(ByteBuf in) {
+    long streamId = in.readLong();
+    int backlog = in.readInt();
+    long offset = in.readLong();
+    int length = in.readInt();
+    // Copy the payload into a dedicated unpooled buffer so it does not share memory with `in`.
+    ByteBuf payload = Unpooled.buffer(length, length);
+    in.readBytes(payload);
+    return new ReadData(streamId, backlog, offset, payload);
+  }
+
+  public long getStreamId() {
+    return streamId;
+  }
+
+  public int getBacklog() {
+    return backlog;
+  }
+
+  public long getOffset() {
+    return offset;
+  }
+
+  public ByteBuf getBuf() {
+    return buf;
+  }
+
+  @Override
+  public Type type() {
+    return Type.READ_DATA;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+    ReadData readData = (ReadData) o;
+    return streamId == readData.streamId
+        && backlog == readData.backlog
+        && offset == readData.offset
+        && Objects.equals(buf, readData.buf);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(streamId, backlog, offset, buf);
+  }
+
+  @Override
+  public String toString() {
+    return "ReadData{"
+        + "streamId="
+        + streamId
+        + ", backlog="
+        + backlog
+        + ", offset="
+        + offset
+        + ", buf="
+        + buf
+        + '}';
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
new file mode 100644
index 00000000..bba500b8
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.server;
+
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Tracks registered buffer streams, keyed by a generated stream id.
+ *
+ * <p>Each stream records the channel it was registered on and a buffer size, so all streams of
+ * a channel can be dropped when that connection terminates.
+ */
+public class BufferStreamManager {
+  private static final Logger logger = LoggerFactory.getLogger(BufferStreamManager.class);
+  private final AtomicLong nextStreamId;
+  protected final ConcurrentHashMap<Long, StreamState> streams;
+
+  /**
+   * Immutable per-stream state. Static because it never uses the enclosing manager, so
+   * instances do not carry a hidden reference to it.
+   */
+  protected static class StreamState {
+    private final Channel associatedChannel;
+    private final int bufferSize;
+
+    public StreamState(Channel associatedChannel, int bufferSize) {
+      this.associatedChannel = associatedChannel;
+      this.bufferSize = bufferSize;
+    }
+
+    public Channel getAssociatedChannel() {
+      return associatedChannel;
+    }
+
+    public int getBufferSize() {
+      return bufferSize;
+    }
+  }
+
+  public BufferStreamManager() {
+    // Random starting point; the (long) cast happens before "* 1000", so the product cannot
+    // overflow int. Presumably this keeps ids of different manager instances apart — confirm.
+    nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
+    streams = new ConcurrentHashMap<>();
+  }
+
+  /**
+   * Registers a new stream owned by {@code channel}.
+   *
+   * @param channel the channel the stream belongs to; see {@link #connectionTerminated}
+   * @param bufferSize buffer size recorded for the stream
+   * @return the newly allocated stream id
+   */
+  public long registerStream(Channel channel, int bufferSize) {
+    long streamId = nextStreamId.getAndIncrement();
+    streams.put(streamId, new StreamState(channel, bufferSize));
+    return streamId;
+  }
+
+  /**
+   * Adds {@code numCredit} credits to stream {@code streamId}.
+   *
+   * <p>NOTE(review): currently a no-op placeholder; the credit is ignored. TODO implement.
+   */
+  public void addCredit(int numCredit, long streamId) {}
+
+  /** Removes every stream whose associated channel is {@code channel} (compared by identity). */
+  public void connectionTerminated(Channel channel) {
+    streams
+        .entrySet()
+        .removeIf(
+            (Map.Entry<Long, StreamState> entry) ->
+                entry.getValue().getAssociatedChannel() == channel);
+  }
+}
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 62b846de..26eeafbc 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -18,10 +18,12 @@
 package org.apache.celeborn.service.deploy.worker
 
 import java.io.{FileNotFoundException, IOException}
+import java.nio.ByteBuffer
 import java.nio.charset.StandardCharsets
 import java.util.concurrent.atomic.AtomicBoolean
 
 import com.google.common.base.Throwables
+import io.netty.buffer.{ByteBuf, Unpooled}
 import io.netty.util.concurrent.{Future, GenericFutureListener}
 
 import org.apache.celeborn.common.exception.CelebornException
@@ -31,13 +33,14 @@ import org.apache.celeborn.common.metrics.source.RPCSource
 import org.apache.celeborn.common.network.buffer.NioManagedBuffer
 import org.apache.celeborn.common.network.client.TransportClient
 import org.apache.celeborn.common.network.protocol._
-import org.apache.celeborn.common.network.server.{BaseMessageHandler, 
ChunkStreamManager}
+import org.apache.celeborn.common.network.server.{BaseMessageHandler, 
BufferStreamManager, ChunkStreamManager}
 import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
 import org.apache.celeborn.common.protocol.PartitionType
 import 
org.apache.celeborn.service.deploy.worker.storage.{PartitionFilesSorter, 
StorageManager}
 
 class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with 
Logging {
   var chunkStreamManager = new ChunkStreamManager()
+  val bufferStreamManager = new BufferStreamManager()
   var workerSource: WorkerSource = _
   var rpcSource: RPCSource = _
   var storageManager: StorageManager = _
@@ -67,6 +70,9 @@ class FetchHandler(val conf: TransportConf) extends 
BaseMessageHandler with Logg
 
   override def receive(client: TransportClient, msg: RequestMessage): Unit = {
     msg match {
+      case r: ReadAddCredit =>
+        rpcSource.updateMessageMetrics(r, 0)
+        handleReadAddCredit(client, r)
       case r: ChunkFetchRequest =>
         rpcSource.updateMessageMetrics(r, 0)
         handleChunkFetchRequest(client, r)
@@ -120,6 +126,14 @@ class FetchHandler(val conf: TransportConf) extends 
BaseMessageHandler with Logg
               new NioManagedBuffer(streamHandle.toByteBuffer)))
           }
         case PartitionType.MAP =>
+          // return stream id
+          val streamId =
+            bufferStreamManager.registerStream(client.getChannel, 
fileInfo.getBufferSize)
+          val res = ByteBuffer.allocate(8)
+          res.putLong(streamId)
+          client.getChannel.writeAndFlush(new RpcResponse(
+            request.requestId,
+            new NioManagedBuffer(res)))
         case PartitionType.MAPGROUP =>
       } catch {
         case e: IOException =>
@@ -141,6 +155,10 @@ class FetchHandler(val conf: TransportConf) extends 
BaseMessageHandler with Logg
     }
   }
 
+  /**
+   * Handles a [[ReadAddCredit]] request by passing its credit and stream id to
+   * `bufferStreamManager.addCredit`. The `client` argument is currently unused.
+   */
+  def handleReadAddCredit(client: TransportClient, req: ReadAddCredit): Unit = {
+    val credit = req.getCredit
+    val streamId = req.getStreamId
+    bufferStreamManager.addCredit(credit, streamId)
+  }
+
   def handleChunkFetchRequest(client: TransportClient, req: 
ChunkFetchRequest): Unit = {
     workerSource.startTimer(WorkerSource.FetchChunkTime, req.toString)
     logTrace(s"Received req from 
${NettyUtils.getRemoteAddress(client.getChannel)}" +
@@ -188,6 +206,7 @@ class FetchHandler(val conf: TransportConf) extends 
BaseMessageHandler with Logg
 
+  /**
+   * Called when the client's channel becomes inactive. It notifies both the chunk and the
+   * buffer stream managers that the channel terminated (BufferStreamManager drops every stream
+   * registered on it), then logs the client address at debug level.
+   */
   override def channelInactive(client: TransportClient): Unit = {
     chunkStreamManager.connectionTerminated(client.getChannel)
+    bufferStreamManager.connectionTerminated(client.getChannel)
     logDebug(s"channel inactive ${client.getSocketAddress}")
   }
 

Reply via email to