Repository: spark
Updated Branches:
  refs/heads/master d8f84f26e -> 2b9b72682


[SPARK-4740] Create multiple concurrent connections between two peer nodes in 
Netty.

It's been reported that when the number of disks is large and the number of 
nodes is small, Netty network throughput is low compared with NIO. We suspect 
the cause is that, because connections are reused, only a small number of 
disks serve shuffle files at any given time. This patch adds a new config 
parameter, spark.shuffle.io.numConnectionsPerPeer, that specifies the number 
of concurrent connections between two peer nodes, defaulting to 2.

Author: Reynold Xin <[email protected]>

Closes #3625 from rxin/SPARK-4740 and squashes the following commits:

ad4241a [Reynold Xin] Updated javadoc.
f33c72b [Reynold Xin] Code review feedback.
0fefabb [Reynold Xin] Use double check in synchronization.
41dfcb2 [Reynold Xin] Added test case.
9076b4a [Reynold Xin] Fixed two NPEs.
3e1306c [Reynold Xin] Minor style fix.
4f21673 [Reynold Xin] [SPARK-4740] Create multiple concurrent connections 
between two peer nodes in Netty.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2b9b7268
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2b9b7268
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2b9b7268

Branch: refs/heads/master
Commit: 2b9b72682e587909a84d3ace214c22cec830eeaf
Parents: d8f84f2
Author: Reynold Xin <[email protected]>
Authored: Tue Dec 9 17:49:59 2014 -0800
Committer: Reynold Xin <[email protected]>
Committed: Tue Dec 9 17:49:59 2014 -0800

----------------------------------------------------------------------
 .../network/client/TransportClientFactory.java  | 124 +++++++++++++------
 .../spark/network/util/TransportConf.java       |   5 +
 .../network/TransportClientFactorySuite.java    |  97 +++++++++++++--
 3 files changed, 180 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2b9b7268/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
----------------------------------------------------------------------
diff --git 
a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
 
b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 9afd5de..d26b9b4 100644
--- 
a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ 
b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -22,6 +22,7 @@ import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
 import java.util.List;
+import java.util.Random;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -42,6 +43,7 @@ import org.slf4j.LoggerFactory;
 import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.server.TransportChannelHandler;
 import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -56,12 +58,31 @@ import org.apache.spark.network.util.TransportConf;
  * TransportClient, all given {@link TransportClientBootstrap}s will be run.
  */
 public class TransportClientFactory implements Closeable {
+
+  /** A simple data structure to track the pool of clients between two peer 
nodes. */
+  private static class ClientPool {
+    TransportClient[] clients;
+    Object[] locks;
+
+    public ClientPool(int size) {
+      clients = new TransportClient[size];
+      locks = new Object[size];
+      for (int i = 0; i < size; i++) {
+        locks[i] = new Object();
+      }
+    }
+  }
+
   private final Logger logger = 
LoggerFactory.getLogger(TransportClientFactory.class);
 
   private final TransportContext context;
   private final TransportConf conf;
   private final List<TransportClientBootstrap> clientBootstraps;
-  private final ConcurrentHashMap<SocketAddress, TransportClient> 
connectionPool;
+  private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
+
+  /** Random number generator for picking connections between peers. */
+  private final Random rand;
+  private final int numConnectionsPerPeer;
 
   private final Class<? extends Channel> socketChannelClass;
   private EventLoopGroup workerGroup;
@@ -73,7 +94,9 @@ public class TransportClientFactory implements Closeable {
     this.context = Preconditions.checkNotNull(context);
     this.conf = context.getConf();
     this.clientBootstraps = 
Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
-    this.connectionPool = new ConcurrentHashMap<SocketAddress, 
TransportClient>();
+    this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>();
+    this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
+    this.rand = new Random();
 
     IOMode ioMode = IOMode.valueOf(conf.ioMode());
     this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
@@ -84,10 +107,14 @@ public class TransportClientFactory implements Closeable {
   }
 
   /**
-   * Create a new {@link TransportClient} connecting to the given remote host 
/ port. This will
-   * reuse TransportClients if they are still active and are for the same 
remote address. Prior
-   * to the creation of a new TransportClient, we will execute all {@link 
TransportClientBootstrap}s
-   * that are registered with this factory.
+   * Create a {@link TransportClient} connecting to the given remote host / 
port.
+   *
+   * We maintain an array of clients (size determined by 
spark.shuffle.io.numConnectionsPerPeer)
+   * and randomly picks one to use. If no client was previously created in the 
randomly selected
+   * spot, this function creates a new client and places it there.
+   *
+   * Prior to the creation of a new TransportClient, we will execute all
+   * {@link TransportClientBootstrap}s that are registered with this factory.
    *
    * This blocks until a connection is successfully established and fully 
bootstrapped.
    *
@@ -97,23 +124,48 @@ public class TransportClientFactory implements Closeable {
     // Get connection from the connection pool first.
     // If it is not found or not active, create a new one.
     final InetSocketAddress address = new InetSocketAddress(remoteHost, 
remotePort);
-    TransportClient cachedClient = connectionPool.get(address);
-    if (cachedClient != null) {
-      if (cachedClient.isActive()) {
-        logger.trace("Returning cached connection to {}: {}", address, 
cachedClient);
-        return cachedClient;
-      } else {
-        logger.info("Found inactive connection to {}, closing it.", address);
-        connectionPool.remove(address, cachedClient); // Remove inactive 
clients.
+
+    // Create the ClientPool if we don't have it yet.
+    ClientPool clientPool = connectionPool.get(address);
+    if (clientPool == null) {
+      connectionPool.putIfAbsent(address, new 
ClientPool(numConnectionsPerPeer));
+      clientPool = connectionPool.get(address);
+    }
+
+    int clientIndex = rand.nextInt(numConnectionsPerPeer);
+    TransportClient cachedClient = clientPool.clients[clientIndex];
+
+    if (cachedClient != null && cachedClient.isActive()) {
+      logger.trace("Returning cached connection to {}: {}", address, 
cachedClient);
+      return cachedClient;
+    }
+
+    // If we reach here, we don't have an existing connection open. Let's 
create a new one.
+    // Multiple threads might race here to create new connections. Keep only 
one of them active.
+    synchronized (clientPool.locks[clientIndex]) {
+      cachedClient = clientPool.clients[clientIndex];
+
+      if (cachedClient != null) {
+        if (cachedClient.isActive()) {
+          logger.trace("Returning cached connection to {}: {}", address, 
cachedClient);
+          return cachedClient;
+        } else {
+          logger.info("Found inactive connection to {}, creating a new one.", 
address);
+        }
       }
+      clientPool.clients[clientIndex] = createClient(address);
+      return clientPool.clients[clientIndex];
     }
+  }
 
+  /** Create a completely new {@link TransportClient} to the remote address. */
+  private TransportClient createClient(InetSocketAddress address) throws 
IOException {
     logger.debug("Creating new connection to " + address);
 
     Bootstrap bootstrap = new Bootstrap();
     bootstrap.group(workerGroup)
       .channel(socketChannelClass)
-       // Disable Nagle's Algorithm since we don't want packets to wait
+      // Disable Nagle's Algorithm since we don't want packets to wait
       .option(ChannelOption.TCP_NODELAY, true)
       .option(ChannelOption.SO_KEEPALIVE, true)
       .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
@@ -130,7 +182,7 @@ public class TransportClientFactory implements Closeable {
     });
 
     // Connect to the remote server
-    long preConnect = System.currentTimeMillis();
+    long preConnect = System.nanoTime();
     ChannelFuture cf = bootstrap.connect(address);
     if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
       throw new IOException(
@@ -143,43 +195,37 @@ public class TransportClientFactory implements Closeable {
     assert client != null : "Channel future completed successfully with null 
client";
 
     // Execute any client bootstraps synchronously before marking the Client 
as successful.
-    long preBootstrap = System.currentTimeMillis();
+    long preBootstrap = System.nanoTime();
     logger.debug("Connection to {} successful, running bootstraps...", 
address);
     try {
       for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
         clientBootstrap.doBootstrap(client);
       }
     } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap 
may be written in Scala
-      long bootstrapTime = System.currentTimeMillis() - preBootstrap;
-      logger.error("Exception while bootstrapping client after " + 
bootstrapTime + " ms", e);
+      long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
+      logger.error("Exception while bootstrapping client after " + 
bootstrapTimeMs + " ms", e);
       client.close();
       throw Throwables.propagate(e);
     }
-    long postBootstrap = System.currentTimeMillis();
-
-    // Successful connection & bootstrap -- in the event that two threads 
raced to create a client,
-    // use the first one that was put into the connectionPool and close the 
one we made here.
-    TransportClient oldClient = connectionPool.putIfAbsent(address, client);
-    if (oldClient == null) {
-      logger.debug("Successfully created connection to {} after {} ms ({} ms 
spent in bootstraps)",
-        address, postBootstrap - preConnect, postBootstrap - preBootstrap);
-      return client;
-    } else {
-      logger.debug("Two clients were created concurrently after {} ms, second 
will be disposed.",
-        postBootstrap - preConnect);
-      client.close();
-      return oldClient;
-    }
+    long postBootstrap = System.nanoTime();
+
+    logger.debug("Successfully created connection to {} after {} ms ({} ms 
spent in bootstraps)",
+      address, (postBootstrap - preConnect) / 1000000, (postBootstrap - 
preBootstrap) / 1000000);
+
+    return client;
   }
 
   /** Close all connections in the connection pool, and shutdown the worker 
thread pool. */
   @Override
   public void close() {
-    for (TransportClient client : connectionPool.values()) {
-      try {
-        client.close();
-      } catch (RuntimeException e) {
-        logger.warn("Ignoring exception during close", e);
+    // Go through all clients and close them if they are active.
+    for (ClientPool clientPool : connectionPool.values()) {
+      for (int i = 0; i < clientPool.clients.length; i++) {
+        TransportClient client = clientPool.clients[i];
+        if (client != null) {
+          clientPool.clients[i] = null;
+          JavaUtils.closeQuietly(client);
+        }
       }
     }
     connectionPool.clear();

http://git-wip-us.apache.org/repos/asf/spark/blob/2b9b7268/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
----------------------------------------------------------------------
diff --git 
a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java 
b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 1af40ac..f605739 100644
--- 
a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ 
b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -40,6 +40,11 @@ public class TransportConf {
     return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
   }
 
+  /** Number of concurrent connections between two nodes for fetching data. **/
+  public int numConnectionsPerPeer() {
+    return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 2);
+  }
+
   /** Requested maximum length of the queue of incoming connections. Default 
-1 for no backlog. */
   public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2b9b7268/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
----------------------------------------------------------------------
diff --git 
a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
 
b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 822bef1..416dc1b 100644
--- 
a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ 
b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -18,7 +18,11 @@
 package org.apache.spark.network;
 
 import java.io.IOException;
-import java.util.concurrent.TimeoutException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.NoSuchElementException;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import org.junit.After;
 import org.junit.Before;
@@ -32,6 +36,7 @@ import org.apache.spark.network.client.TransportClientFactory;
 import org.apache.spark.network.server.NoOpRpcHandler;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.ConfigProvider;
 import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
@@ -57,16 +62,94 @@ public class TransportClientFactorySuite {
     JavaUtils.closeQuietly(server2);
   }
 
+  /**
+   * Request a bunch of clients for a single server to test that
+   * we create up to maxConnections clients.
+   *
+   * If concurrent is true, create multiple threads to create clients in 
parallel.
+   */
+  private void testClientReuse(final int maxConnections, boolean concurrent)
+    throws IOException, InterruptedException {
+    TransportConf conf = new TransportConf(new ConfigProvider() {
+      @Override
+      public String get(String name) {
+        if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) {
+          return Integer.toString(maxConnections);
+        } else {
+          throw new NoSuchElementException();
+        }
+      }
+    });
+
+    RpcHandler rpcHandler = new NoOpRpcHandler();
+    TransportContext context = new TransportContext(conf, rpcHandler);
+    final TransportClientFactory factory = context.createClientFactory();
+    final Set<TransportClient> clients = Collections.synchronizedSet(
+      new HashSet<TransportClient>());
+
+    final AtomicInteger failed = new AtomicInteger();
+    Thread[] attempts = new Thread[maxConnections * 10];
+
+    // Launch a bunch of threads to create new clients.
+    for (int i = 0; i < attempts.length; i++) {
+      attempts[i] = new Thread() {
+        @Override
+        public void run() {
+          try {
+            TransportClient client =
+              factory.createClient(TestUtils.getLocalHost(), 
server1.getPort());
+            assert (client.isActive());
+            clients.add(client);
+          } catch (IOException e) {
+            failed.incrementAndGet();
+          }
+        }
+      };
+
+      if (concurrent) {
+        attempts[i].start();
+      } else {
+        attempts[i].run();
+      }
+    }
+
+    // Wait until all the threads complete.
+    for (int i = 0; i < attempts.length; i++) {
+      attempts[i].join();
+    }
+
+    assert(failed.get() == 0);
+    assert(clients.size() == maxConnections);
+
+    for (TransportClient client : clients) {
+      client.close();
+    }
+  }
+
+  @Test
+  public void reuseClientsUpToConfigVariable() throws Exception {
+    testClientReuse(1, false);
+    testClientReuse(2, false);
+    testClientReuse(3, false);
+    testClientReuse(4, false);
+  }
+
   @Test
-  public void createAndReuseBlockClients() throws IOException {
+  public void reuseClientsUpToConfigVariableConcurrent() throws Exception {
+    testClientReuse(1, true);
+    testClientReuse(2, true);
+    testClientReuse(3, true);
+    testClientReuse(4, true);
+  }
+
+  @Test
+  public void returnDifferentClientsForDifferentServers() throws IOException {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), 
server1.getPort());
-    TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), 
server1.getPort());
-    TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), 
server2.getPort());
+    TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), 
server2.getPort());
     assertTrue(c1.isActive());
-    assertTrue(c3.isActive());
-    assertTrue(c1 == c2);
-    assertTrue(c1 != c3);
+    assertTrue(c2.isActive());
+    assertTrue(c1 != c2);
     factory.close();
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to