This is an automated email from the ASF dual-hosted git repository.

runzhiwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new cca0e6c  RATIS-1121. Support multiple streams. (#243)
cca0e6c is described below

commit cca0e6cab13c900219996782472e7086326df120
Author: Tsz-Wo Nicholas Sze <[email protected]>
AuthorDate: Sat Oct 31 09:13:49 2020 +0800

    RATIS-1121. Support multiple streams. (#243)
    
    * RATIS-1121. Support multiple streams.
    
    * Change back byteWritten to int and writeRequest to non-volatile.
---
 .../ratis/netty/client/NettyClientStreamRpc.java   | 20 ++++---
 .../ratis/netty/server/NettyServerStreamRpc.java   |  2 +
 .../ratis/datastream/TestDataStreamBase.java       | 70 +++++++++++++++-------
 .../ratis/datastream/TestDataStreamNetty.java      |  8 +--
 4 files changed, 67 insertions(+), 33 deletions(-)

diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
index e769e2c..98ddc8f 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
@@ -38,11 +38,13 @@ import org.apache.ratis.util.NetUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Optional;
 import java.util.Queue;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.ConcurrentMap;
 import java.util.function.Supplier;
 
 public class NettyClientStreamRpc implements DataStreamClientRpc {
@@ -51,7 +53,7 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
   private final RaftPeer server;
   private final EventLoopGroup workerGroup = new NioEventLoopGroup();
   private final Supplier<Channel> channel;
-  private final Queue<CompletableFuture<DataStreamReply>> replies = new LinkedList<>();
+  private final ConcurrentMap<Long, Queue<CompletableFuture<DataStreamReply>>> replies = new ConcurrentHashMap<>();
 
   public NettyClientStreamRpc(RaftPeer server, RaftProperties properties){
     this.server = server;
@@ -69,16 +71,15 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
     return channel.get();
   }
 
-  synchronized CompletableFuture<DataStreamReply> pollReply() {
-    return replies.poll();
-  }
-
   private ChannelInboundHandler getClientHandler(){
     return new ChannelInboundHandlerAdapter(){
       @Override
       public void channelRead(ChannelHandlerContext ctx, Object msg) {
         final DataStreamReply reply = (DataStreamReply) msg;
-        pollReply().complete(reply);
+        LOG.debug("{}: read {}", this, reply);
+        Optional.ofNullable(replies.get(reply.getStreamId()))
+            .map(Queue::poll)
+            .ifPresent(f -> f.complete(reply));
       }
     };
   }
@@ -120,7 +121,10 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
   @Override
   public synchronized CompletableFuture<DataStreamReply> streamAsync(DataStreamRequest request) {
     CompletableFuture<DataStreamReply> f = new CompletableFuture<>();
-    replies.offer(f);
+    final Queue<CompletableFuture<DataStreamReply>> q = replies.computeIfAbsent(
+        request.getStreamId(), key -> new ConcurrentLinkedQueue<>());
+    q.offer(f);
+    LOG.debug("{}: write {}", this, request);
     getChannel().writeAndFlush(request);
     return f;
   }
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java
index 4743fcd..9c0226e 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java
@@ -224,6 +224,7 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
   private void sendReplySuccess(DataStreamRequestByteBuf request, long bytesWritten, ChannelHandlerContext ctx) {
     final DataStreamReplyByteBuffer reply = new DataStreamReplyByteBuffer(
         request.getStreamId(), request.getStreamOffset(), null, bytesWritten, true, request.getType());
+    LOG.debug("{}: write {}", this, reply);
     ctx.writeAndFlush(reply);
   }
 
@@ -246,6 +247,7 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
   }
 
   private void read(ChannelHandlerContext ctx, DataStreamRequestByteBuf request) {
+    LOG.debug("{}: read {}", this, request);
     final ByteBuf buf = request.slice();
     final boolean isHeader = request.getType() == Type.STREAM_HEADER;
 
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamBase.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamBase.java
index 01b9fae..2c10861 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamBase.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamBase.java
@@ -19,8 +19,8 @@ package org.apache.ratis.datastream;
 
 import org.apache.ratis.BaseTest;
 import org.apache.ratis.MiniRaftCluster;
-import org.apache.ratis.client.api.DataStreamOutput;
 import org.apache.ratis.client.impl.DataStreamClientImpl;
+import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.protocol.DataStreamReply;
 import org.apache.ratis.protocol.RaftClientRequest;
@@ -30,6 +30,7 @@ import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
 import org.apache.ratis.server.DataStreamServerRpc;
 import org.apache.ratis.server.impl.DataStreamServerImpl;
 import org.apache.ratis.statemachine.impl.BaseStateMachine;
+import org.apache.ratis.statemachine.StateMachine.DataStream;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.NetUtils;
 import org.junit.Assert;
@@ -41,6 +42,8 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.stream.Collectors;
 
@@ -51,7 +54,22 @@ class TestDataStreamBase extends BaseTest {
     return (byte) ('A' + pos%MODULUS);
   }
 
-  static class SingleDataStreamStateMachine extends BaseStateMachine {
+  static class MultiDataStreamStateMachine extends BaseStateMachine {
+    final ConcurrentMap<Long, SingleDataStream> streams = new ConcurrentHashMap<>();
+
+    @Override
+    public CompletableFuture<DataStream> stream(RaftClientRequest request) {
+      final SingleDataStream s = new SingleDataStream();
+      streams.put(request.getCallId(), s);
+      return s.stream(request);
+    }
+
+    SingleDataStream getSingleDataStream(long callId) {
+      return streams.get(callId);
+    }
+  }
+
+  static class SingleDataStream {
     private int byteWritten = 0;
     private RaftClientRequest writeRequest;
 
@@ -99,7 +117,6 @@ class TestDataStreamBase extends BaseTest {
       }
     };
 
-    @Override
     public CompletableFuture<DataStream> stream(RaftClientRequest request) {
       writeRequest = request;
       return CompletableFuture.completedFuture(stream);
@@ -119,17 +136,17 @@ class TestDataStreamBase extends BaseTest {
 
   private List<DataStreamServerImpl> servers;
   private List<RaftPeer> peers;
-  private List<SingleDataStreamStateMachine> singleDataStreamStateMachines;
+  private List<MultiDataStreamStateMachine> stateMachines;
 
   protected void setupServer(){
     servers = new ArrayList<>(peers.size());
-    singleDataStreamStateMachines = new ArrayList<>(peers.size());
+    stateMachines = new ArrayList<>(peers.size());
     // start stream servers on raft peers.
     for (int i = 0; i < peers.size(); i++) {
-      SingleDataStreamStateMachine singleDataStreamStateMachine = new SingleDataStreamStateMachine();
-      singleDataStreamStateMachines.add(singleDataStreamStateMachine);
+      final MultiDataStreamStateMachine stateMachine = new MultiDataStreamStateMachine();
+      stateMachines.add(stateMachine);
       final DataStreamServerImpl streamServer = new DataStreamServerImpl(
-          peers.get(i), singleDataStreamStateMachine, properties, null);
+          peers.get(i), stateMachine, properties, null);
       final DataStreamServerRpc rpc = streamServer.getServerRpc();
       if (i == 0) {
         // only the first server routes requests to peers.
@@ -160,21 +177,27 @@ class TestDataStreamBase extends BaseTest {
         .collect(Collectors.toList());
   }
 
-  protected void runTestDataStream(int numServers, int bufferSize, int bufferNum) throws Exception {
+  protected void runTestDataStream(int numServers, int numStreams, int bufferSize, int bufferNum) throws Exception {
     setupRaftPeers(numServers);
     try {
       setupServer();
       setupClient();
-      runTestDataStream(bufferSize, bufferNum);
+      runTestDataStream(numStreams, bufferSize, bufferNum);
     } finally {
       shutdown();
     }
   }
 
-  private void runTestDataStream(int bufferSize, int bufferNum) {
-    final DataStreamOutput out = client.stream();
-    DataStreamClientImpl.DataStreamOutputImpl impl = (DataStreamClientImpl.DataStreamOutputImpl) out;
+  private void runTestDataStream(int numStreams, int bufferSize, int bufferNum) {
+    final List<CompletableFuture<Void>> futures = new ArrayList<>();
+    for (int i = 0; i < numStreams; i++) {
+      futures.add(CompletableFuture.runAsync(
+          () -> runTestDataStream((DataStreamOutputImpl) client.stream(), bufferSize, bufferNum)));
+    }
+    futures.forEach(CompletableFuture::join);
+  }
 
+  private void runTestDataStream(DataStreamOutputImpl out, int bufferSize, int bufferNum) {
     final List<CompletableFuture<DataStreamReply>> futures = new ArrayList<>();
     final List<Integer> sizes = new ArrayList<>();
 
@@ -191,7 +214,7 @@ class TestDataStreamBase extends BaseTest {
     }
 
     { // check header
-      final DataStreamReply reply = impl.getHeaderFuture().join();
+      final DataStreamReply reply = out.getHeaderFuture().join();
       Assert.assertTrue(reply.isSuccess());
       Assert.assertEquals(0, reply.getBytesWritten());
       Assert.assertEquals(reply.getType(), Type.STREAM_HEADER);
@@ -205,14 +228,19 @@ class TestDataStreamBase extends BaseTest {
       Assert.assertEquals(reply.getType(), Type.STREAM_DATA);
     }
 
-    for (SingleDataStreamStateMachine s : singleDataStreamStateMachines) {
-      RaftClientRequest writeRequest = s.getWriteRequest();
-      if (writeRequest.getClientId().equals(impl.getHeader().getClientId())) {
-        Assert.assertEquals(writeRequest.getCallId(), impl.getHeader().getCallId());
-        Assert.assertEquals(writeRequest.getRaftGroupId(), impl.getHeader().getRaftGroupId());
-        Assert.assertEquals(writeRequest.getServerId(), impl.getHeader().getServerId());
+    final RaftClientRequest header = out.getHeader();
+    for (MultiDataStreamStateMachine s : stateMachines) {
+      final SingleDataStream stream = s.getSingleDataStream(header.getCallId());
+      if (stream == null) {
+        continue;
+      }
+      final RaftClientRequest writeRequest = stream.getWriteRequest();
+      if (writeRequest.getClientId().equals(header.getClientId())) {
+        Assert.assertEquals(writeRequest.getCallId(), header.getCallId());
+        Assert.assertEquals(writeRequest.getRaftGroupId(), header.getRaftGroupId());
+        Assert.assertEquals(writeRequest.getServerId(), header.getServerId());
       }
-      Assert.assertEquals(dataSize, s.getByteWritten());
+      Assert.assertEquals(dataSize, stream.getByteWritten());
     }
   }
 
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
index 2725b01..7549189 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
@@ -33,13 +33,13 @@ public class TestDataStreamNetty extends TestDataStreamBase {
 
   @Test
   public void testDataStreamSingleServer() throws Exception {
-    runTestDataStream(1, 1_000_000, 100);
-    runTestDataStream(1,1_000, 10_000);
+    runTestDataStream(1, 5, 1_000_000, 100);
+    runTestDataStream(1,5, 1_000, 10_000);
   }
 
   @Test
   public void testDataStreamMultipleServer() throws Exception {
-    runTestDataStream(3, 1_000_000, 100);
-    runTestDataStream(3, 1_000, 10_000);
+    runTestDataStream(3, 5, 1_000_000, 100);
+    runTestDataStream(3, 5, 1_000, 10_000);
   }
 }

Reply via email to