This is an automated email from the ASF dual-hosted git repository.

runzhiwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new f939700  RATIS-1123. Support multiple DataStream clients. (#244)
f939700 is described below

commit f939700549fa67d3e0cdea7103941822db9d0d55
Author: Tsz-Wo Nicholas Sze <[email protected]>
AuthorDate: Sat Oct 31 14:45:23 2020 +0800

    RATIS-1123. Support multiple DataStream clients. (#244)
---
 .../ratis/client/impl/OrderedStreamAsync.java      |  2 +-
 .../apache/ratis/netty/NettyDataStreamUtils.java   | 14 ++---
 .../ratis/netty/client/NettyClientStreamRpc.java   | 15 +++--
 .../ratis/netty/server/NettyServerStreamRpc.java   | 69 ++++++++++++++++++++--
 .../ratis/datastream/TestDataStreamBase.java       | 50 +++++++++-------
 .../ratis/datastream/TestDataStreamDisabled.java   |  6 +-
 .../ratis/datastream/TestDataStreamNetty.java      |  8 +--
 7 files changed, 117 insertions(+), 47 deletions(-)

diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java
index 6bfca51..12b2bb7 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java
@@ -84,7 +84,7 @@ public class OrderedStreamAsync {
   OrderedStreamAsync(ClientId clientId, DataStreamClientRpc 
dataStreamClientRpc, RaftProperties properties){
     this.dataStreamClientRpc = dataStreamClientRpc;
     this.slidingWindow = new SlidingWindow.Client<>(clientId);
-    this.requestSemaphore = new 
Semaphore(RaftClientConfigKeys.DataStream.outstandingRequestsMax(properties)*2);
+    this.requestSemaphore = new 
Semaphore(RaftClientConfigKeys.DataStream.outstandingRequestsMax(properties));
   }
 
   CompletableFuture<DataStreamReply> sendRequest(long streamId, long offset, 
ByteBuffer data, Type type){
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java
index 0aadfea..d3b8834 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamUtils.java
@@ -27,7 +27,7 @@ import org.apache.ratis.protocol.DataStreamPacketHeader;
 import org.apache.ratis.protocol.DataStreamReplyHeader;
 import org.apache.ratis.protocol.DataStreamRequestHeader;
 import org.apache.ratis.thirdparty.io.netty.buffer.ByteBuf;
-import org.apache.ratis.thirdparty.io.netty.buffer.PooledByteBufAllocator;
+import org.apache.ratis.thirdparty.io.netty.buffer.ByteBufAllocator;
 import org.apache.ratis.thirdparty.io.netty.buffer.Unpooled;
 
 import java.nio.ByteBuffer;
@@ -69,20 +69,20 @@ public interface NettyDataStreamUtils {
         .asReadOnlyByteBuffer();
   }
 
-  static void encodeDataStreamRequestByteBuffer(DataStreamRequestByteBuffer 
request, Consumer<ByteBuf> out) {
+  static void encodeDataStreamRequestByteBuffer(DataStreamRequestByteBuffer 
request, Consumer<ByteBuf> out,
+      ByteBufAllocator allocator) {
     ByteBuffer headerBuf = getDataStreamRequestHeaderProtoByteBuf(request);
-    final ByteBuf headerLenBuf =
-        
PooledByteBufAllocator.DEFAULT.directBuffer(DataStreamPacketHeader.getSizeOfHeaderLen());
+    final ByteBuf headerLenBuf = 
allocator.directBuffer(DataStreamPacketHeader.getSizeOfHeaderLen());
     headerLenBuf.writeInt(headerBuf.remaining());
     out.accept(headerLenBuf);
     out.accept(Unpooled.wrappedBuffer(headerBuf));
     out.accept(Unpooled.wrappedBuffer(request.slice()));
   }
 
-  static void encodeDataStreamReplyByteBuffer(DataStreamReplyByteBuffer reply, 
Consumer<ByteBuf> out) {
+  static void encodeDataStreamReplyByteBuffer(DataStreamReplyByteBuffer reply, 
Consumer<ByteBuf> out,
+      ByteBufAllocator allocator) {
     ByteBuffer headerBuf = getDataStreamReplyHeaderProtoByteBuf(reply);
-    final ByteBuf headerLenBuf =
-        
PooledByteBufAllocator.DEFAULT.directBuffer(DataStreamPacketHeader.getSizeOfHeaderLen());
+    final ByteBuf headerLenBuf = 
allocator.directBuffer(DataStreamPacketHeader.getSizeOfHeaderLen());
     headerLenBuf.writeInt(headerBuf.remaining());
     out.accept(headerLenBuf);
     out.accept(Unpooled.wrappedBuffer(headerBuf));
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
index 98ddc8f..3f4b791 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java
@@ -75,6 +75,10 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
     return new ChannelInboundHandlerAdapter(){
       @Override
       public void channelRead(ChannelHandlerContext ctx, Object msg) {
+        if (!(msg instanceof DataStreamReply)) {
+          LOG.error("{}: unexpected message {}", this, msg.getClass());
+          return;
+        }
         final DataStreamReply reply = (DataStreamReply) msg;
         LOG.debug("{}: read {}", this, reply);
         Optional.ofNullable(replies.get(reply.getStreamId()))
@@ -100,7 +104,7 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
     return new MessageToMessageEncoder<DataStreamRequestByteBuffer>() {
       @Override
       protected void encode(ChannelHandlerContext context, 
DataStreamRequestByteBuffer request, List<Object> out) {
-        NettyDataStreamUtils.encodeDataStreamRequestByteBuffer(request, 
out::add);
+        NettyDataStreamUtils.encodeDataStreamRequestByteBuffer(request, 
out::add, context.alloc());
       }
     };
   }
@@ -119,11 +123,14 @@ public class NettyClientStreamRpc implements DataStreamClientRpc {
   }
 
   @Override
-  public synchronized CompletableFuture<DataStreamReply> 
streamAsync(DataStreamRequest request) {
-    CompletableFuture<DataStreamReply> f = new CompletableFuture<>();
+  public CompletableFuture<DataStreamReply> streamAsync(DataStreamRequest 
request) {
+    final CompletableFuture<DataStreamReply> f = new CompletableFuture<>();
     final Queue<CompletableFuture<DataStreamReply>> q = 
replies.computeIfAbsent(
         request.getStreamId(), key -> new ConcurrentLinkedQueue<>());
-    q.offer(f);
+    if (!q.offer(f)) {
+      f.completeExceptionally(new IllegalStateException(this + ": Failed to 
offer a future for " + request));
+      return f;
+    }
     LOG.debug("{}: write {}", this, request);
     getChannel().writeAndFlush(request);
     return f;
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java
index 9c0226e..440ccda 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java
@@ -58,6 +58,7 @@ import java.nio.channels.WritableByteChannel;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
@@ -67,6 +68,7 @@ import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.CopyOnWriteArraySet;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
 
 public class NettyServerStreamRpc implements DataStreamServerRpc {
   public static final Logger LOG = 
LoggerFactory.getLogger(NettyServerStreamRpc.class);
@@ -118,12 +120,14 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
   }
 
   static class StreamInfo {
+    private final RaftClientRequest request;
     private final CompletableFuture<DataStream> stream;
     private final List<DataStreamOutput> outs;
     private final AtomicReference<CompletableFuture<?>> previous
         = new AtomicReference<>(CompletableFuture.completedFuture(null));
 
-    StreamInfo(CompletableFuture<DataStream> stream, List<DataStreamOutput> 
outs) {
+    StreamInfo(RaftClientRequest request, CompletableFuture<DataStream> 
stream, List<DataStreamOutput> outs) {
+      this.request = request;
       this.stream = stream;
       this.outs = outs;
     }
@@ -139,6 +143,58 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
     AtomicReference<CompletableFuture<?>> getPrevious() {
       return previous;
     }
+
+    @Override
+    public String toString() {
+      return getClass().getSimpleName() + ":" + request;
+    }
+  }
+
+  static class StreamMap {
+    static class Key {
+      private final ChannelId channelId;
+      private final long streamId;
+
+      Key(ChannelId channelId, long streamId) {
+        this.channelId = channelId;
+        this.streamId = streamId;
+      }
+
+      @Override
+      public boolean equals(Object obj) {
+        if (this == obj) {
+          return true;
+        } else if (obj == null || getClass() != obj.getClass()) {
+          return false;
+        }
+        final Key that = (Key) obj;
+        return this.streamId == that.streamId && 
Objects.equals(this.channelId, that.channelId);
+      }
+
+      @Override
+      public int hashCode() {
+        return Objects.hash(channelId, streamId);
+      }
+
+      @Override
+      public String toString() {
+        return channelId + "-" + streamId;
+      }
+    }
+
+    private final ConcurrentMap<Key, StreamInfo> map = new 
ConcurrentHashMap<>();
+
+    StreamInfo computeIfAbsent(Key key, Function<Key, StreamInfo> function) {
+      final StreamInfo info = map.computeIfAbsent(key, function);
+      LOG.debug("computeIfAbsent({}) returns {}", key, info);
+      return info;
+    }
+
+    StreamInfo get(Key key) {
+      final StreamInfo info = map.get(key);
+      LOG.debug("get({}) returns {}", key, info);
+      return info;
+    }
   }
 
   private final RaftServer server;
@@ -148,7 +204,7 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
   private final ChannelFuture channelFuture;
 
   private final StateMachine stateMachine;
-  private final ConcurrentMap<Long, StreamInfo> streams = new 
ConcurrentHashMap<>();
+  private final StreamMap streams = new StreamMap();
 
   private final Proxies proxies;
 
@@ -192,7 +248,7 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
     try {
       final RaftClientRequest request = ClientProtoUtils.toRaftClientRequest(
           RaftClientRequestProto.parseFrom(buf.nioBuffer()));
-      return new StreamInfo(stateMachine.data().stream(request), 
proxies.getDataStreamOutput());
+      return new StreamInfo(request, stateMachine.data().stream(request), 
proxies.getDataStreamOutput());
     } catch (Throwable e) {
       throw new CompletionException(e);
     }
@@ -254,14 +310,15 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
     final StreamInfo info;
     final CompletableFuture<Long> localWrite;
     final List<CompletableFuture<DataStreamReply>> remoteWrites = new 
ArrayList<>();
+    final StreamMap.Key key = new StreamMap.Key(ctx.channel().id(), 
request.getStreamId());
     if (isHeader) {
-      info = streams.computeIfAbsent(request.getStreamId(), id -> 
newStreamInfo(buf));
+      info = streams.computeIfAbsent(key, id -> newStreamInfo(buf));
       localWrite = CompletableFuture.completedFuture(0L);
       for (DataStreamOutput out : info.getDataStreamOutputs()) {
         remoteWrites.add(out.getHeaderFuture());
       }
     } else {
-      info = streams.get(request.getStreamId());
+      info = streams.get(key);
       localWrite = info.getStream().thenApply(stream -> writeTo(buf, stream));
       for (DataStreamOutput out : info.getDataStreamOutputs()) {
         remoteWrites.add(out.writeAsync(request.slice().nioBuffer()));
@@ -319,7 +376,7 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
     return new MessageToMessageEncoder<DataStreamReplyByteBuffer>() {
       @Override
       protected void encode(ChannelHandlerContext context, 
DataStreamReplyByteBuffer reply, List<Object> out) {
-        NettyDataStreamUtils.encodeDataStreamReplyByteBuffer(reply, out::add);
+        NettyDataStreamUtils.encodeDataStreamReplyByteBuffer(reply, out::add, 
context.alloc());
       }
     };
   }
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamBase.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamBase.java
index 2c10861..80c5a80 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamBase.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamBase.java
@@ -132,13 +132,16 @@ class TestDataStreamBase extends BaseTest {
   }
 
   protected RaftProperties properties;
-  protected DataStreamClientImpl client;
 
   private List<DataStreamServerImpl> servers;
   private List<RaftPeer> peers;
   private List<MultiDataStreamStateMachine> stateMachines;
 
-  protected void setupServer(){
+  protected void setup(int numServers){
+    peers = Arrays.stream(MiniRaftCluster.generateIds(numServers, 0))
+        .map(RaftPeerId::valueOf)
+        .map(id -> new RaftPeer(id, NetUtils.createLocalServerAddress()))
+        .collect(Collectors.toList());
     servers = new ArrayList<>(peers.size());
     stateMachines = new ArrayList<>(peers.size());
     // start stream servers on raft peers.
@@ -159,42 +162,45 @@ class TestDataStreamBase extends BaseTest {
     }
   }
 
-  protected void setupClient(){
-    client = new DataStreamClientImpl(peers.get(0), properties, null);
+  DataStreamClientImpl newDataStreamClientImpl() {
+    return new DataStreamClientImpl(peers.get(0), properties, null);
   }
 
   protected void shutdown() throws IOException {
-    client.close();
     for (DataStreamServerImpl server : servers) {
       server.close();
     }
   }
 
-  protected void setupRaftPeers(int numServers) {
-    peers = Arrays.stream(MiniRaftCluster.generateIds(numServers, 0))
-        .map(RaftPeerId::valueOf)
-        .map(id -> new RaftPeer(id, NetUtils.createLocalServerAddress()))
-        .collect(Collectors.toList());
-  }
-
-  protected void runTestDataStream(int numServers, int numStreams, int 
bufferSize, int bufferNum) throws Exception {
-    setupRaftPeers(numServers);
+  protected void runTestDataStream(int numServers, int numClients, int 
numStreams, int bufferSize, int bufferNum)
+      throws Exception {
     try {
-      setupServer();
-      setupClient();
-      runTestDataStream(numStreams, bufferSize, bufferNum);
+      setup(numServers);
+      runTestDataStream(numClients, numStreams, bufferSize, bufferNum);
     } finally {
       shutdown();
     }
   }
 
-  private void runTestDataStream(int numStreams, int bufferSize, int 
bufferNum) {
+  private void runTestDataStream(int numClients, int numStreams, int 
bufferSize, int bufferNum) throws Exception {
     final List<CompletableFuture<Void>> futures = new ArrayList<>();
-    for (int i = 0; i < numStreams; i++) {
-      futures.add(CompletableFuture.runAsync(
-          () -> runTestDataStream((DataStreamOutputImpl) client.stream(), 
bufferSize, bufferNum)));
+    final List<DataStreamClientImpl> clients = new ArrayList<>();
+    try {
+      for (int j = 0; j < numClients; j++) {
+        final DataStreamClientImpl client = newDataStreamClientImpl();
+        clients.add(client);
+        for (int i = 0; i < numStreams; i++) {
+          futures.add(CompletableFuture.runAsync(
+              () -> runTestDataStream((DataStreamOutputImpl) client.stream(), 
bufferSize, bufferNum)));
+        }
+      }
+      Assert.assertEquals(numClients*numStreams, futures.size());
+      futures.forEach(CompletableFuture::join);
+    } finally {
+      for (int j = 0; j < numClients; j++) {
+        clients.get(j).close();
+      }
     }
-    futures.forEach(CompletableFuture::join);
   }
 
   private void runTestDataStream(DataStreamOutputImpl out, int bufferSize, int 
bufferNum) {
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamDisabled.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamDisabled.java
index e22efb3..0414704 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamDisabled.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamDisabled.java
@@ -19,6 +19,7 @@ package org.apache.ratis.datastream;
 
 import org.apache.ratis.RaftConfigKeys;
 import org.apache.ratis.client.DisabledDataStreamClientFactory;
+import org.apache.ratis.client.impl.DataStreamClientImpl;
 import org.apache.ratis.conf.RaftProperties;
 import org.junit.Before;
 import org.junit.Rule;
@@ -37,10 +38,9 @@ public class TestDataStreamDisabled extends TestDataStreamBase {
 
   @Test
   public void testDataStreamDisabled() throws Exception {
-    setupRaftPeers(1);
     try {
-      setupServer();
-      setupClient();
+      setup(1);
+      final DataStreamClientImpl client = newDataStreamClientImpl();
       exception.expect(UnsupportedOperationException.class);
       exception.expectMessage(DisabledDataStreamClientFactory.class.getName()
           + "$1 does not support streamAsync");
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
index 7549189..b980dbd 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
@@ -33,13 +33,13 @@ public class TestDataStreamNetty extends TestDataStreamBase {
 
   @Test
   public void testDataStreamSingleServer() throws Exception {
-    runTestDataStream(1, 5, 1_000_000, 100);
-    runTestDataStream(1,5, 1_000, 10_000);
+    runTestDataStream(1, 2, 3, 1_000_000, 10);
+    runTestDataStream(1, 2, 3, 1_000, 10_000);
   }
 
   @Test
   public void testDataStreamMultipleServer() throws Exception {
-    runTestDataStream(3, 5, 1_000_000, 100);
-    runTestDataStream(3, 5, 1_000, 10_000);
+    runTestDataStream(3, 2, 3, 1_000_000, 100);
+    runTestDataStream(3, 2, 3, 1_000, 10_000);
   }
 }

Reply via email to