This is an automated email from the ASF dual-hosted git repository.

runzhiwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new e99f0d4  RATIS-1099. DataStreamServerRpc should connect other peers automatically (#225)
e99f0d4 is described below

commit e99f0d4b4fcf187ec3623fc550ecfa8d4df80cc7
Author: Tsz-Wo Nicholas Sze <[email protected]>
AuthorDate: Wed Oct 21 08:25:13 2020 +0800

    RATIS-1099. DataStreamServerRpc should connect other peers automatically (#225)
    
    * RATIS-1099. DataStreamServerRpc should connect other peers automatically.
    
    * Add NettyServerStreamRpc.Proxies.
    
    * Remove unused import.
---
 .../org/apache/ratis/client/DataStreamClient.java  |   7 +-
 .../java/org/apache/ratis/util/PeerProxyMap.java   |   4 +
 .../apache/ratis/netty/NettyDataStreamFactory.java |  12 +-
 .../ratis/netty/server/NettyServerStreamRpc.java   | 166 +++++++++++++--------
 .../org/apache/ratis/server/DataStreamServer.java  |   9 +-
 .../ratis/server/DataStreamServerFactory.java      |  17 +--
 .../apache/ratis/server/DataStreamServerRpc.java   |  16 +-
 .../ratis/server/impl/DataStreamServerImpl.java    |  28 +---
 .../apache/ratis/datastream/TestDataStream.java    |  60 ++++----
 9 files changed, 161 insertions(+), 158 deletions(-)

diff --git a/ratis-client/src/main/java/org/apache/ratis/client/DataStreamClient.java b/ratis-client/src/main/java/org/apache/ratis/client/DataStreamClient.java
index 201f3a8..b72e1fd 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/DataStreamClient.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/DataStreamClient.java
@@ -25,11 +25,13 @@ import org.apache.ratis.protocol.RaftPeer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.Closeable;
+
 /**
  * A client interface that sends request to the streaming pipeline.
  * Associated with it will be a Netty Client.
  */
-public interface DataStreamClient extends DataStreamApi {
+public interface DataStreamClient extends DataStreamApi, Closeable {
 
   Logger LOG = LoggerFactory.getLogger(DataStreamClient.class);
 
@@ -39,9 +41,6 @@ public interface DataStreamClient extends DataStreamApi {
   /** add information of the raft peers to communicate with */
   void addPeers(Iterable<RaftPeer> peers);
 
-  /** close the client */
-  void close();
-
   /** start the client */
   void start();
 
diff --git a/ratis-common/src/main/java/org/apache/ratis/util/PeerProxyMap.java b/ratis-common/src/main/java/org/apache/ratis/util/PeerProxyMap.java
index 5fb6802..80f026e 100644
--- a/ratis-common/src/main/java/org/apache/ratis/util/PeerProxyMap.java
+++ b/ratis-common/src/main/java/org/apache/ratis/util/PeerProxyMap.java
@@ -97,6 +97,10 @@ public class PeerProxyMap<PROXY extends Closeable> implements Closeable {
     this.createProxy = this::createProxyImpl;
   }
 
+  public String getName() {
+    return name;
+  }
+
   public PROXY getProxy(RaftPeerId id) throws IOException {
     Objects.requireNonNull(id, "id == null");
     PeerAndProxy p = peers.get(id);
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamFactory.java b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamFactory.java
index f9af46e..aa76b58 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamFactory.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyDataStreamFactory.java
@@ -17,7 +17,6 @@
  */
 package org.apache.ratis.netty;
 
-import java.util.List;
 import org.apache.ratis.client.DataStreamClientRpc;
 import org.apache.ratis.client.DataStreamClientFactory;
 import org.apache.ratis.conf.Parameters;
@@ -44,13 +43,8 @@ public class NettyDataStreamFactory implements DataStreamServerFactory, DataStre
   }
 
   @Override
-  public DataStreamServerRpc newDataStreamServerRpc(RaftPeer server, 
StateMachine stateMachine) {
-    return new NettyServerStreamRpc(server, stateMachine);
-  }
-
-  @Override
-  public DataStreamServerRpc newDataStreamServerRpc(
-      RaftPeer server, List<RaftPeer> peers, StateMachine stateMachine, 
RaftProperties properties) {
-    return new NettyServerStreamRpc(server, peers, stateMachine, properties);
+  public DataStreamServerRpc newDataStreamServerRpc(RaftPeer server, 
StateMachine stateMachine,
+      RaftProperties properties) {
+    return new NettyServerStreamRpc(server, stateMachine, properties);
   }
 }
diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java
index 18a4695..385a9f7 100644
--- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java
+++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/NettyServerStreamRpc.java
@@ -19,11 +19,13 @@
 package org.apache.ratis.netty.server;
 
 import org.apache.ratis.client.DataStreamClient;
-import org.apache.ratis.client.impl.ClientProtoUtils;
 import org.apache.ratis.client.api.DataStreamOutput;
+import org.apache.ratis.client.impl.ClientProtoUtils;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
+import org.apache.ratis.io.CloseAsync;
 import org.apache.ratis.proto.RaftProtos;
+import org.apache.ratis.protocol.DataStreamReply;
 import org.apache.ratis.protocol.RaftClientRequest;
 import org.apache.ratis.protocol.RaftPeer;
 import org.apache.ratis.server.DataStreamServerRpc;
@@ -38,25 +40,76 @@ import org.apache.ratis.thirdparty.io.netty.channel.socket.SocketChannel;
 import 
org.apache.ratis.thirdparty.io.netty.channel.socket.nio.NioServerSocketChannel;
 import org.apache.ratis.thirdparty.io.netty.handler.logging.LogLevel;
 import org.apache.ratis.thirdparty.io.netty.handler.logging.LoggingHandler;
+import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.NetUtils;
+import org.apache.ratis.util.PeerProxyMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.ArrayList;
+import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.util.List;
 import java.nio.channels.WritableByteChannel;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionException;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.CopyOnWriteArraySet;
 import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
 
 public class NettyServerStreamRpc implements DataStreamServerRpc {
   public static final Logger LOG = 
LoggerFactory.getLogger(NettyServerStreamRpc.class);
 
-  private final RaftPeer raftServer;
+  /**
+   * Proxies to other peers.
+   *
+   * Invariant: all the {@link #peers} must exist in the {@link #map}.
+   */
+  static class Proxies {
+    private final Set<RaftPeer> peers = new CopyOnWriteArraySet<>();
+    private final PeerProxyMap<DataStreamClient> map;
+
+    Proxies(PeerProxyMap<DataStreamClient> map) {
+      this.map = map;
+    }
+
+    void addPeers(Collection<RaftPeer> newPeers) {
+      // add to the map first in order to preserve the invariant.
+      map.addPeers(newPeers);
+      // must use atomic addAll
+      peers.addAll(newPeers);
+    }
+
+    List<DataStreamOutput> getDataStreamOutput() throws IOException {
+      final List<DataStreamOutput> outs = new ArrayList<>();
+      try {
+        getDataStreamOutput(outs);
+      } catch (IOException e) {
+        outs.forEach(CloseAsync::closeAsync);
+        throw e;
+      }
+      return outs;
+    }
+
+    private void getDataStreamOutput(List<DataStreamOutput> outs) throws 
IOException {
+      for (RaftPeer peer : peers) {
+        try {
+          outs.add(map.getProxy(peer.getId()).stream());
+        } catch (IOException e) {
+          throw new IOException(map.getName() + ": Failed to 
getDataStreamOutput for " + peer, e);
+        }
+      }
+    }
+
+    void close() {
+      map.close();
+    }
+  }
+
+  private final String name;
   private final EventLoopGroup bossGroup = new NioEventLoopGroup();
   private final EventLoopGroup workerGroup = new NioEventLoopGroup();
   private final ChannelFuture channelFuture;
@@ -65,23 +118,35 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
   private final ConcurrentMap<Long, CompletableFuture<DataStream>> streams = 
new ConcurrentHashMap<>();
   private final ConcurrentMap<Long, List<DataStreamOutput>> peersStreamOutput 
= new ConcurrentHashMap<>();
 
-  private List<DataStreamClient> clients = new ArrayList<>();
+  private final Proxies proxies;
 
-  public NettyServerStreamRpc(RaftPeer server, StateMachine stateMachine) {
-    this.raftServer = server;
+  public NettyServerStreamRpc(RaftPeer server, StateMachine stateMachine, 
RaftProperties properties) {
+    this.name = server + "-" + getClass().getSimpleName();
     this.stateMachine = stateMachine;
-    this.channelFuture = buildChannel();
+    this.channelFuture = new ServerBootstrap()
+        .group(bossGroup, workerGroup)
+        .channel(NioServerSocketChannel.class)
+        .handler(new LoggingHandler(LogLevel.INFO))
+        .childHandler(getInitializer())
+        .childOption(ChannelOption.SO_KEEPALIVE, true)
+        .localAddress(NetUtils.createSocketAddr(server.getAddress()))
+        .bind();
+
+    this.proxies = new Proxies(new PeerProxyMap<>(name, peer -> 
newClient(peer, properties)));
   }
 
-  public NettyServerStreamRpc(
-      RaftPeer server, List<RaftPeer> otherPeers,
-      StateMachine stateMachine, RaftProperties properties){
-    this(server, stateMachine);
-    setupClient(otherPeers, properties);
+  static DataStreamClient newClient(RaftPeer peer, RaftProperties properties) {
+    final DataStreamClient client = DataStreamClient.newBuilder()
+        .setRaftServer(peer)
+        .setProperties(properties)
+        .build();
+    client.start();
+    return client;
   }
 
-  private List<DataStreamOutput> getDataStreamOutput() {
-    return clients.stream().map(client -> 
client.stream()).collect(Collectors.toList());
+  @Override
+  public void addPeers(Collection<RaftPeer> newPeers) {
+    proxies.addPeers(newPeers);
   }
 
   private CompletableFuture<DataStream> getDataStreamFuture(ByteBuf buf) {
@@ -121,34 +186,34 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
   private ChannelInboundHandler getServerHandler(){
     return new ChannelInboundHandlerAdapter(){
       @Override
-      public void channelRead(ChannelHandlerContext ctx, Object msg) {
+      public void channelRead(ChannelHandlerContext ctx, Object msg) throws 
IOException {
         final DataStreamRequestByteBuf request = (DataStreamRequestByteBuf)msg;
         final ByteBuf buf = request.slice();
         final boolean isHeader = request.getStreamOffset() == -1;
 
-        CompletableFuture<?>[] parallelWrites = new 
CompletableFuture<?>[clients.size() + 1];
-
-        final CompletableFuture<?> localWrites = isHeader?
-                streams.computeIfAbsent(request.getStreamId(), id -> 
getDataStreamFuture(buf))
+        final CompletableFuture<Long> localWrite = isHeader?
+                streams.computeIfAbsent(request.getStreamId(), id -> 
getDataStreamFuture(buf)).thenApply(stream -> 0L)
                 : streams.get(request.getStreamId()).thenApply(stream -> 
writeTo(buf, stream));
-        parallelWrites[0] = localWrites;
-        peersStreamOutput.putIfAbsent(request.getStreamId(), 
getDataStreamOutput());
 
-          // do not need to forward header request
+        final List<CompletableFuture<DataStreamReply>> remoteWrites = new 
ArrayList<>();
         if (isHeader) {
-          for (int i = 0; i < 
peersStreamOutput.get(request.getStreamId()).size(); i++) {
-            parallelWrites[i + 1] = 
peersStreamOutput.get(request.getStreamId()).get(i).getHeaderFuture();
+          // do not need to forward header request
+          final List<DataStreamOutput> outs = proxies.getDataStreamOutput();
+          peersStreamOutput.put(request.getStreamId(), outs);
+          for(DataStreamOutput out : outs) {
+            remoteWrites.add(out.getHeaderFuture());
           }
         } else {
           // body
-          for (int i = 0; i < clients.size(); i++) {
-            parallelWrites[i + 1]  =
-              
peersStreamOutput.get(request.getStreamId()).get(i).writeAsync(request.slice().nioBuffer());
+          for(DataStreamOutput out : 
peersStreamOutput.get(request.getStreamId())) {
+            remoteWrites.add(out.writeAsync(request.slice().nioBuffer()));
           }
         }
-        CompletableFuture.allOf(parallelWrites).whenComplete((t, r) -> {
+
+        JavaUtils.allOf(remoteWrites).thenCombine(localWrite, (v, 
bytesWritten) -> {
               buf.release();
               sendReply(request, ctx);
+              return null;
         });
       }
     };
@@ -166,45 +231,17 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
     };
   }
 
-  ChannelFuture buildChannel() {
-    return new ServerBootstrap()
-        .group(bossGroup, workerGroup)
-        .channel(NioServerSocketChannel.class)
-        .handler(new LoggingHandler(LogLevel.INFO))
-        .childHandler(getInitializer())
-        .childOption(ChannelOption.SO_KEEPALIVE, true)
-        .localAddress(NetUtils.createSocketAddr(raftServer.getAddress()))
-        .bind();
-  }
-
-  private void setupClient(List<RaftPeer> otherPeers, RaftProperties 
properties) {
-    for (RaftPeer peer : otherPeers) {
-      clients.add(DataStreamClient.newBuilder()
-              .setParameters(null)
-              .setRaftServer(peer)
-              .setProperties(properties)
-              .build());
-    }
-  }
-
   private Channel getChannel() {
     return channelFuture.awaitUninterruptibly().channel();
   }
 
   @Override
-  public void startServer() {
+  public void start() {
     channelFuture.syncUninterruptibly();
   }
 
-  // TODO: RATIS-1099 build connection with other server automatically.
-  public void startClientToPeers() {
-    for (DataStreamClient client : clients) {
-      client.start();
-    }
-  }
-
   @Override
-  public void closeServer() {
+  public void close() {
     final ChannelFuture f = getChannel().close();
     f.syncUninterruptibly();
     bossGroup.shutdownGracefully(0, 100, TimeUnit.MILLISECONDS);
@@ -216,8 +253,11 @@ public class NettyServerStreamRpc implements DataStreamServerRpc {
       LOG.error("Interrupt EventLoopGroup terminate", e);
     }
 
-    for (DataStreamClient client : clients) {
-      client.close();
-    }
+    proxies.close();
+  }
+
+  @Override
+  public String toString() {
+    return name;
   }
 }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServer.java b/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServer.java
index 434aee6..ced98f1 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServer.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServer.java
@@ -17,17 +17,14 @@
  */
 package org.apache.ratis.server;
 
+import java.io.Closeable;
+
 /**
  * Interface for streaming server.
  */
-public interface DataStreamServer {
+public interface DataStreamServer extends Closeable {
   /**
    * Get network interface for server.
    */
   DataStreamServerRpc getServerRpc();
-
-  /**
-   * close server.
-   */
-  void close();
 }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServerFactory.java b/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServerFactory.java
index b9390ee..be3db98 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServerFactory.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServerFactory.java
@@ -17,17 +17,17 @@
  */
 package org.apache.ratis.server;
 
-import java.util.List;
 import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.datastream.DataStreamFactory;
 import org.apache.ratis.protocol.RaftPeer;
 import org.apache.ratis.server.impl.ServerFactory;
 import org.apache.ratis.statemachine.StateMachine;
 
+/** A {@link DataStreamFactory} to create server-side objects. */
 public interface DataStreamServerFactory extends DataStreamFactory {
 
   static DataStreamServerFactory cast(DataStreamFactory dataStreamFactory) {
-    if (dataStreamFactory instanceof DataStreamFactory) {
+    if (dataStreamFactory instanceof DataStreamServerFactory) {
       return (DataStreamServerFactory)dataStreamFactory;
     }
     throw new ClassCastException("Cannot cast " + dataStreamFactory.getClass()
@@ -35,15 +35,6 @@ public interface DataStreamServerFactory extends DataStreamFactory {
         + "; rpc type is " + dataStreamFactory.getDataStreamType());
   }
 
-  /**
-   * Server implementation for streaming in Raft group
-   */
-  DataStreamServerRpc newDataStreamServerRpc(RaftPeer server, StateMachine 
stateMachine);
-
-  /**
-   * Server implementation for streaming in Raft group. The server will 
forward requests
-   * to peers.
-   */
-  DataStreamServerRpc newDataStreamServerRpc(
-      RaftPeer server, List<RaftPeer> peers, StateMachine stateMachine, 
RaftProperties properties);
+  /** Create a new {@link DataStreamServerRpc}. */
+  DataStreamServerRpc newDataStreamServerRpc(RaftPeer server, StateMachine 
stateMachine, RaftProperties properties);
 }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServerRpc.java b/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServerRpc.java
index f8a270e..ff12b57 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServerRpc.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/DataStreamServerRpc.java
@@ -17,19 +17,21 @@
  */
 package org.apache.ratis.server;
 
+import org.apache.ratis.protocol.RaftPeer;
+
+import java.io.Closeable;
+import java.util.Collection;
+
 /**
  * A server interface handling incoming streams
  * Relays those streams to other servers after persisting
  */
-public interface DataStreamServerRpc {
+public interface DataStreamServerRpc extends Closeable {
   /**
    * start server
    */
-  void startServer();
-
-  /**
-   * shutdown server
-   */
-  void closeServer();
+  void start();
 
+  /** Add the given peers */
+  void addPeers(Collection<RaftPeer> peers);
 }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/DataStreamServerImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/DataStreamServerImpl.java
index 4a3cdc7..2504f6f 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/DataStreamServerImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/DataStreamServerImpl.java
@@ -18,7 +18,7 @@
 
 package org.apache.ratis.server.impl;
 
-import java.util.List;
+import java.io.IOException;
 import org.apache.ratis.RaftConfigKeys;
 import org.apache.ratis.conf.Parameters;
 import org.apache.ratis.conf.RaftProperties;
@@ -34,32 +34,14 @@ import org.slf4j.LoggerFactory;
 public class DataStreamServerImpl implements DataStreamServer {
   public static final Logger LOG = 
LoggerFactory.getLogger(DataStreamServerImpl.class);
 
-  private DataStreamServerRpc serverRpc;
-  private RaftPeer raftServer;
-  private final StateMachine stateMachine;
+  private final DataStreamServerRpc serverRpc;
 
   public DataStreamServerImpl(RaftPeer server, StateMachine stateMachine,
       RaftProperties properties, Parameters parameters){
-    this.raftServer = server;
-    this.stateMachine = stateMachine;
-
-    final SupportedDataStreamType type = 
RaftConfigKeys.DataStream.type(properties, LOG::info);
-
-    this.serverRpc = DataStreamServerFactory.cast(type.newFactory(parameters))
-        .newDataStreamServerRpc(raftServer, stateMachine);
-  }
-
-  public DataStreamServerImpl(RaftPeer server,
-      RaftProperties properties,
-      Parameters parameters,
-      StateMachine stateMachine,
-      List<RaftPeer> peers){
-    this.raftServer = server;
-    this.stateMachine = stateMachine;
     final SupportedDataStreamType type = 
RaftConfigKeys.DataStream.type(properties, LOG::info);
 
     this.serverRpc = DataStreamServerFactory.cast(type.newFactory(parameters))
-        .newDataStreamServerRpc(server, peers, stateMachine, properties);
+        .newDataStreamServerRpc(server, stateMachine, properties);
   }
 
   @Override
@@ -68,7 +50,7 @@ public class DataStreamServerImpl implements DataStreamServer {
   }
 
   @Override
-  public void close(){
-    serverRpc.closeServer();
+  public void close() throws IOException {
+    serverRpc.close();
   }
 }
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStream.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStream.java
index 84bbc79..9af676d 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStream.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStream.java
@@ -18,17 +18,18 @@
 
 package org.apache.ratis.datastream;
 
+import java.io.IOException;
 import java.util.stream.Collectors;
 import org.apache.ratis.BaseTest;
 import org.apache.ratis.MiniRaftCluster;
 import org.apache.ratis.client.api.DataStreamOutput;
 import org.apache.ratis.client.impl.DataStreamClientImpl;
 import org.apache.ratis.conf.RaftProperties;
-import org.apache.ratis.netty.server.NettyServerStreamRpc;
 import org.apache.ratis.protocol.DataStreamReply;
 import org.apache.ratis.protocol.RaftClientRequest;
 import org.apache.ratis.protocol.RaftPeer;
 import org.apache.ratis.protocol.RaftPeerId;
+import org.apache.ratis.server.DataStreamServerRpc;
 import org.apache.ratis.server.impl.DataStreamServerImpl;
 import org.apache.ratis.statemachine.impl.BaseStateMachine;
 import org.apache.ratis.util.JavaUtils;
@@ -51,7 +52,7 @@ public class TestDataStream extends BaseTest {
     return (byte) ('A' + pos%MODULUS);
   }
 
-  class SingleDataStreamStateMachine extends BaseStateMachine {
+  static class SingleDataStreamStateMachine extends BaseStateMachine {
     private int byteWritten = 0;
     private RaftClientRequest writeRequest;
 
@@ -127,24 +128,17 @@ public class TestDataStream extends BaseTest {
     for (int i = 0; i < peers.size(); i++) {
       SingleDataStreamStateMachine singleDataStreamStateMachine = new 
SingleDataStreamStateMachine();
       singleDataStreamStateMachines.add(singleDataStreamStateMachine);
-      DataStreamServerImpl streamServer;
+      final DataStreamServerImpl streamServer = new DataStreamServerImpl(
+          peers.get(i), singleDataStreamStateMachine, properties, null);
+      final DataStreamServerRpc rpc = streamServer.getServerRpc();
       if (i == 0) {
         // only the first server routes requests to peers.
         List<RaftPeer> otherPeers = new ArrayList<>(peers);
         otherPeers.remove(peers.get(i));
-        streamServer = new DataStreamServerImpl(
-            peers.get(i), properties, null, singleDataStreamStateMachine, 
otherPeers);
-      } else {
-        streamServer = new DataStreamServerImpl(
-            peers.get(i), singleDataStreamStateMachine, properties, null);
+        rpc.addPeers(otherPeers);
       }
+      rpc.start();
       servers.add(streamServer);
-      streamServer.getServerRpc().startServer();
-    }
-
-    // start peer clients on stream servers
-    for (DataStreamServerImpl streamServer : servers) {
-      ((NettyServerStreamRpc) 
streamServer.getServerRpc()).startClientToPeers();
     }
   }
 
@@ -153,35 +147,37 @@ public class TestDataStream extends BaseTest {
     client.start();
   }
 
-  public void shutDownSetup(){
+  public void shutdown() throws IOException {
     client.close();
-    servers.stream().forEach(s -> s.close());
+    for (DataStreamServerImpl server : servers) {
+      server.close();
+    }
   }
 
   @Test
-  public void testDataStream(){
-    properties = new RaftProperties();
-    peers = Arrays.stream(MiniRaftCluster.generateIds(1, 0))
-                       .map(RaftPeerId::valueOf)
-                       .map(id -> new RaftPeer(id, 
NetUtils.createLocalServerAddress())).collect(
-            Collectors.toList());
-
-    setupServer();
-    setupClient();
-    runTestDataStream();
+  public void testDataStreamSingleServer() throws Exception {
+    runTestDataStream(1);
   }
 
   @Test
-  public void testDataStreamMultipleServer(){
+  public void testDataStreamMultipleServer() throws Exception {
+    runTestDataStream(3);
+  }
+
+  void runTestDataStream(int numServers) throws Exception {
     properties = new RaftProperties();
-    peers = Arrays.asList(MiniRaftCluster.generateIds(3, 0)).stream()
+    peers = Arrays.stream(MiniRaftCluster.generateIds(numServers, 0))
         .map(RaftPeerId::valueOf)
-        .map(id -> new RaftPeer(id, 
NetUtils.createLocalServerAddress())).collect(
-            Collectors.toList());
+        .map(id -> new RaftPeer(id, NetUtils.createLocalServerAddress()))
+        .collect(Collectors.toList());
 
     setupServer();
     setupClient();
-    runTestDataStream();
+    try {
+      runTestDataStream();
+    } finally {
+      shutdown();
+    }
   }
 
   public void runTestDataStream(){
@@ -219,8 +215,6 @@ public class TestDataStream extends BaseTest {
       }
       Assert.assertEquals(dataSize, s.getByteWritten());
     }
-
-    shutDownSetup();
   }
 
   static ByteBuffer initBuffer(int offset, int size) {

Reply via email to