This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 5595f2f4 [CELEBORN-124]Add buffer stream. (#1069)
5595f2f4 is described below
commit 5595f2f4b3de2c1dbbbaba39d7b62d3efee65c5d
Author: Ethan Feng <[email protected]>
AuthorDate: Fri Jan 6 15:54:52 2023 +0800
[CELEBORN-124]Add buffer stream. (#1069)
---
.../org/apache/celeborn/common/meta/FileInfo.java | 4 +
.../celeborn/common/network/protocol/Message.java | 15 ++-
.../common/network/protocol/ReadAddCredit.java | 79 +++++++++++++++
.../celeborn/common/network/protocol/ReadData.java | 111 +++++++++++++++++++++
.../common/network/server/BufferStreamManager.java | 72 +++++++++++++
.../service/deploy/worker/FetchHandler.scala | 21 +++-
6 files changed, 300 insertions(+), 2 deletions(-)
diff --git a/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
b/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
index 50d3d8bb..6eefb83c 100644
--- a/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
+++ b/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
@@ -182,4 +182,8 @@ public class FileInfo {
public void setBufferSize(int bufferSize) {
this.bufferSize = bufferSize;
}
+
+  /** Returns the current value of the {@code bufferSize} field. */
+  public int getBufferSize() {
+    return this.bufferSize;
+  }
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
index 02061426..7ba54cbd 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
@@ -84,7 +84,9 @@ public abstract class Message implements Encodable {
PUSH_MERGED_DATA(12),
REGION_START(13),
REGION_FINISH(14),
- PUSH_DATA_HAND_SHAKE(15);
+ PUSH_DATA_HAND_SHAKE(15),
+ READ_ADD_CREDIT(16),
+ READ_DATA(17);
private final byte id;
@@ -138,6 +140,10 @@ public abstract class Message implements Encodable {
return REGION_FINISH;
case 15:
return PUSH_DATA_HAND_SHAKE;
+ case 16:
+ return READ_ADD_CREDIT;
+ case 17:
+ return READ_DATA;
case -1:
throw new IllegalArgumentException("User type messages cannot be
decoded.");
default:
@@ -193,6 +199,13 @@ public abstract class Message implements Encodable {
case PUSH_DATA_HAND_SHAKE:
return PushDataHandShake.decode(in);
+
+ case READ_ADD_CREDIT:
+ return ReadAddCredit.decode(in);
+
+ case READ_DATA:
+ return ReadData.decode(in);
+
default:
throw new IllegalArgumentException("Unexpected message type: " +
msgType);
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
new file mode 100644
index 00000000..ca34a5c1
--- /dev/null
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.common.network.protocol;
+
+import java.util.Objects;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request message that pairs a stream id with a credit amount.
+ *
+ * <p>Wire layout: {@code streamId} as an 8-byte long followed by {@code credit} as a 4-byte int.
+ */
+public class ReadAddCredit extends RequestMessage {
+  private final long streamId;
+  private final int credit;
+
+  /**
+   * @param streamId id of the stream the credit applies to
+   * @param credit credit amount carried by this message
+   */
+  public ReadAddCredit(long streamId, int credit) {
+    this.streamId = streamId;
+    this.credit = credit;
+  }
+
+  /** Decodes a message by reading the stream id (long) and then the credit (int) from {@code buf}. */
+  public static ReadAddCredit decode(ByteBuf buf) {
+    final long decodedStreamId = buf.readLong();
+    final int decodedCredit = buf.readInt();
+    return new ReadAddCredit(decodedStreamId, decodedCredit);
+  }
+
+  /** Number of bytes {@link #encode(ByteBuf)} writes: one long plus one int. */
+  @Override
+  public int encodedLength() {
+    return Long.BYTES + Integer.BYTES;
+  }
+
+  /** Writes the stream id followed by the credit into {@code buf}. */
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(streamId);
+    buf.writeInt(credit);
+  }
+
+  @Override
+  public Type type() {
+    return Type.READ_ADD_CREDIT;
+  }
+
+  public long getStreamId() {
+    return streamId;
+  }
+
+  public int getCredit() {
+    return credit;
+  }
+
+  /** Two messages are equal when they have the same runtime class, stream id and credit. */
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    ReadAddCredit other = (ReadAddCredit) o;
+    return credit == other.credit && streamId == other.streamId;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(streamId, credit);
+  }
+
+  @Override
+  public String toString() {
+    return "ReadAddCredit{streamId=" + streamId + ", credit=" + credit + "}";
+  }
+}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
new file mode 100644
index 00000000..92e107cc
--- /dev/null
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.common.network.protocol;
+
+import java.util.Objects;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/**
+ * Request message that carries a payload buffer together with a stream id, a backlog value and
+ * an offset.
+ *
+ * <p>Wire layout written by {@link #encode(ByteBuf)} and read back by {@link #decode(ByteBuf)}:
+ *
+ * <pre>
+ *   streamId (long, 8 bytes) | backlog (int, 4 bytes) | offset (long, 8 bytes)
+ *   | payload length (int, 4 bytes) | payload bytes
+ * </pre>
+ */
+public class ReadData extends RequestMessage {
+  /** Size of the fixed-width fields: streamId + backlog + offset + payload length. */
+  private static final int HEADER_LENGTH = 8 + 4 + 8 + 4;
+
+  private final long streamId;
+  private final int backlog;
+  private final long offset;
+  private final ByteBuf buf;
+
+  /**
+   * @param streamId id of the stream this data belongs to
+   * @param backlog backlog value carried by this message
+   * @param offset offset value carried by this message
+   * @param buf payload; its readable bytes are what gets encoded
+   */
+  public ReadData(long streamId, int backlog, long offset, ByteBuf buf) {
+    this.streamId = streamId;
+    this.backlog = backlog;
+    this.offset = offset;
+    this.buf = buf;
+  }
+
+  /**
+   * Returns exactly the number of bytes {@link #encode(ByteBuf)} writes: the fixed header plus
+   * the payload's readable bytes. (The previous sum counted one extra 4-byte field that is never
+   * written, so the reported length was 4 bytes too large.)
+   */
+  @Override
+  public int encodedLength() {
+    return HEADER_LENGTH + buf.readableBytes();
+  }
+
+  /** Writes the header fields followed by the payload's readable bytes into {@code buf}. */
+  @Override
+  public void encode(ByteBuf buf) {
+    final int payloadLength = this.buf.readableBytes();
+    buf.writeLong(streamId);
+    buf.writeInt(backlog);
+    buf.writeLong(offset);
+    buf.writeInt(payloadLength);
+    // Indexed copy: leaves the payload's readerIndex untouched, so encodedLength(), equals(),
+    // hashCode() and getBuf() report the same content before and after encoding.
+    buf.writeBytes(this.buf, this.buf.readerIndex(), payloadLength);
+  }
+
+  /**
+   * Decodes a message from {@code buf}: reads the header fields, then copies the announced number
+   * of payload bytes into a new unpooled buffer of exactly that capacity.
+   */
+  public static ReadData decode(ByteBuf buf) {
+    long streamId = buf.readLong();
+    int backlog = buf.readInt();
+    long offset = buf.readLong();
+    int payloadLength = buf.readInt();
+    // Initial and maximum capacity are both payloadLength, so readBytes fills it completely.
+    ByteBuf payload = Unpooled.buffer(payloadLength, payloadLength);
+    buf.readBytes(payload);
+    return new ReadData(streamId, backlog, offset, payload);
+  }
+
+  public long getStreamId() {
+    return streamId;
+  }
+
+  public int getBacklog() {
+    return backlog;
+  }
+
+  public long getOffset() {
+    return offset;
+  }
+
+  public ByteBuf getBuf() {
+    return buf;
+  }
+
+  @Override
+  public Type type() {
+    return Type.READ_DATA;
+  }
+
+  /** Equal when runtime class, stream id, backlog, offset and payload buffer are all equal. */
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+    ReadData readData = (ReadData) o;
+    return streamId == readData.streamId
+        && backlog == readData.backlog
+        && offset == readData.offset
+        && Objects.equals(buf, readData.buf);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(streamId, backlog, offset, buf);
+  }
+
+  @Override
+  public String toString() {
+    return "ReadData{"
+        + "streamId="
+        + streamId
+        + ", backlog="
+        + backlog
+        + ", offset="
+        + offset
+        + ", buf="
+        + buf
+        + '}';
+  }
+}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
b/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
new file mode 100644
index 00000000..bba500b8
--- /dev/null
+++
b/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.server;
+
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Registry of buffer streams. Every registered stream receives a unique id and is remembered
+ * together with the channel it was registered on and the buffer size supplied at registration.
+ */
+public class BufferStreamManager {
+  private static final Logger logger = LoggerFactory.getLogger(BufferStreamManager.class);
+
+  /** Source of stream ids; each registration takes the current value and advances it by one. */
+  private final AtomicLong nextStreamId;
+
+  /** Registered streams keyed by stream id. */
+  protected final ConcurrentHashMap<Long, StreamState> streams;
+
+  /** Per-stream record: the associated channel and the buffer size given at registration. */
+  protected class StreamState {
+    private final Channel associatedChannel;
+    private final int bufferSize;
+
+    public StreamState(Channel associatedChannel, int bufferSize) {
+      this.associatedChannel = associatedChannel;
+      this.bufferSize = bufferSize;
+    }
+
+    public Channel getAssociatedChannel() {
+      return associatedChannel;
+    }
+
+    public int getBufferSize() {
+      return bufferSize;
+    }
+  }
+
+  /** Creates an empty manager whose first stream id is a random non-negative int times 1000. */
+  public BufferStreamManager() {
+    // The cast applies before the multiplication, so the product is computed in long arithmetic.
+    final long firstStreamId = (long) new Random().nextInt(Integer.MAX_VALUE) * 1000;
+    nextStreamId = new AtomicLong(firstStreamId);
+    streams = new ConcurrentHashMap<>();
+  }
+
+  /**
+   * Registers a new stream for {@code channel}.
+   *
+   * @param channel channel the stream is associated with
+   * @param bufferSize buffer size recorded for the stream
+   * @return the id assigned to the new stream
+   */
+  public long registerStream(Channel channel, int bufferSize) {
+    final long assignedId = nextStreamId.getAndIncrement();
+    final StreamState state = new StreamState(channel, bufferSize);
+    streams.put(assignedId, state);
+    return assignedId;
+  }
+
+  /** Currently a no-op: the method body is empty. */
+  public void addCredit(int numCredit, long streamId) {}
+
+  /** Removes every registered stream whose associated channel is {@code channel} (by identity). */
+  public void connectionTerminated(Channel channel) {
+    streams
+        .entrySet()
+        .removeIf(
+            (Map.Entry<Long, StreamState> registered) ->
+                registered.getValue().getAssociatedChannel() == channel);
+  }
+}
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 62b846de..26eeafbc 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -18,10 +18,12 @@
package org.apache.celeborn.service.deploy.worker
import java.io.{FileNotFoundException, IOException}
+import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.concurrent.atomic.AtomicBoolean
import com.google.common.base.Throwables
+import io.netty.buffer.{ByteBuf, Unpooled}
import io.netty.util.concurrent.{Future, GenericFutureListener}
import org.apache.celeborn.common.exception.CelebornException
@@ -31,13 +33,14 @@ import org.apache.celeborn.common.metrics.source.RPCSource
import org.apache.celeborn.common.network.buffer.NioManagedBuffer
import org.apache.celeborn.common.network.client.TransportClient
import org.apache.celeborn.common.network.protocol._
-import org.apache.celeborn.common.network.server.{BaseMessageHandler,
ChunkStreamManager}
+import org.apache.celeborn.common.network.server.{BaseMessageHandler,
BufferStreamManager, ChunkStreamManager}
import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
import org.apache.celeborn.common.protocol.PartitionType
import
org.apache.celeborn.service.deploy.worker.storage.{PartitionFilesSorter,
StorageManager}
class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with
Logging {
var chunkStreamManager = new ChunkStreamManager()
+ val bufferStreamManager = new BufferStreamManager()
var workerSource: WorkerSource = _
var rpcSource: RPCSource = _
var storageManager: StorageManager = _
@@ -67,6 +70,9 @@ class FetchHandler(val conf: TransportConf) extends
BaseMessageHandler with Logg
override def receive(client: TransportClient, msg: RequestMessage): Unit = {
msg match {
+ case r: ReadAddCredit =>
+ rpcSource.updateMessageMetrics(r, 0)
+ handleReadAddCredit(client, r)
case r: ChunkFetchRequest =>
rpcSource.updateMessageMetrics(r, 0)
handleChunkFetchRequest(client, r)
@@ -120,6 +126,14 @@ class FetchHandler(val conf: TransportConf) extends
BaseMessageHandler with Logg
new NioManagedBuffer(streamHandle.toByteBuffer)))
}
case PartitionType.MAP =>
+ // return stream id
+ val streamId =
+ bufferStreamManager.registerStream(client.getChannel,
fileInfo.getBufferSize)
+ val res = ByteBuffer.allocate(8)
+ res.putLong(streamId)
+ client.getChannel.writeAndFlush(new RpcResponse(
+ request.requestId,
+ new NioManagedBuffer(res)))
case PartitionType.MAPGROUP =>
} catch {
case e: IOException =>
@@ -141,6 +155,10 @@ class FetchHandler(val conf: TransportConf) extends
BaseMessageHandler with Logg
}
}
+  /**
+   * Handles a [[ReadAddCredit]] request by passing its credit and stream id to
+   * `bufferStreamManager.addCredit`. The `client` argument is not used here.
+   */
+  def handleReadAddCredit(client: TransportClient, req: ReadAddCredit): Unit = {
+    val grantedCredit = req.getCredit
+    val targetStreamId = req.getStreamId
+    bufferStreamManager.addCredit(grantedCredit, targetStreamId)
+  }
+
def handleChunkFetchRequest(client: TransportClient, req:
ChunkFetchRequest): Unit = {
workerSource.startTimer(WorkerSource.FetchChunkTime, req.toString)
logTrace(s"Received req from
${NettyUtils.getRemoteAddress(client.getChannel)}" +
@@ -188,6 +206,7 @@ class FetchHandler(val conf: TransportConf) extends
BaseMessageHandler with Logg
override def channelInactive(client: TransportClient): Unit = {
chunkStreamManager.connectionTerminated(client.getChannel)
+ bufferStreamManager.connectionTerminated(client.getChannel) // also drop the buffer streams registered on this channel
logDebug(s"channel inactive ${client.getSocketAddress}")
}