This is an automated email from the ASF dual-hosted git repository. vanzin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5668c42 [SPARK-27021][CORE] Cleanup of Netty event loop group for shuffle chunk fetch requests 5668c42 is described below commit 5668c42edf20bc577305437622272bf803b6019e Author: “attilapiros” <piros.attila.zs...@gmail.com> AuthorDate: Tue Mar 5 12:31:06 2019 -0800 [SPARK-27021][CORE] Cleanup of Netty event loop group for shuffle chunk fetch requests ## What changes were proposed in this pull request? Creating a Netty `EventLoopGroup` leads to creating a new Thread pool for handling the events. For stopping the threads of the pool the event loop group should be shut down which is properly done for transport servers and clients by calling for example the `shutdownGracefully()` method (for details see the `close()` method of `TransportClientFactory` and `TransportServer`). But there is a separate event loop group for shuffle chunk fetch requests which is in pipeline for handling fet [...] ## How was this patch tested? With existing unittest. This leak is in the production system too but its effect is spiking in the unittest. Checking the core unittest logs before the PR: ``` $ grep "LEAK IN SUITE" unit-tests.log | grep -o shuffle-chunk-fetch-handler | wc -l 381 ``` And after the PR without whitelisting in thread audit and with an extra `await` after the `chunkFetchWorkers.shutdownGracefully()`: ``` $ grep "LEAK IN SUITE" unit-tests.log | grep -o shuffle-chunk-fetch-handler | wc -l 0 ``` Closes #23930 from attilapiros/SPARK-27021. 
Authored-by: “attilapiros” <piros.attila.zs...@gmail.com> Signed-off-by: Marcelo Vanzin <van...@cloudera.com> --- .../org/apache/spark/network/TransportContext.java | 15 +-- .../spark/network/ChunkFetchIntegrationSuite.java | 4 +- .../network/RequestTimeoutIntegrationSuite.java | 10 +- .../apache/spark/network/RpcIntegrationSuite.java | 4 +- .../java/org/apache/spark/network/StreamSuite.java | 4 +- .../spark/network/TransportClientFactorySuite.java | 78 +++++++------- .../spark/network/crypto/AuthIntegrationSuite.java | 3 + .../apache/spark/network/sasl/SparkSaslSuite.java | 6 +- .../network/util/NettyMemoryMetricsSuite.java | 5 +- .../spark/network/sasl/SaslIntegrationSuite.java | 91 ++++++++-------- .../shuffle/ExternalShuffleIntegrationSuite.java | 4 +- .../shuffle/ExternalShuffleSecuritySuite.java | 10 +- .../spark/network/yarn/YarnShuffleService.java | 7 +- .../spark/deploy/ExternalShuffleService.scala | 8 +- .../network/netty/NettyBlockTransferService.scala | 3 + .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 3 + .../apache/spark/ExternalShuffleServiceSuite.scala | 15 ++- .../test/scala/org/apache/spark/ThreadAudit.scala | 16 ++- .../apache/spark/storage/BlockManagerSuite.scala | 115 +++++++++++---------- 19 files changed, 228 insertions(+), 173 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 0bc5dd5..d99b9bd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.Closeable; import java.util.ArrayList; import java.util.List; @@ -60,13 +61,12 @@ import org.apache.spark.network.util.TransportFrameDecoder; * channel. 
As each TransportChannelHandler contains a TransportClient, this enables server * processes to send messages back to the client on an existing channel. */ -public class TransportContext { +public class TransportContext implements Closeable { private static final Logger logger = LoggerFactory.getLogger(TransportContext.class); private final TransportConf conf; private final RpcHandler rpcHandler; private final boolean closeIdleConnections; - private final boolean isClientOnly; // Number of registered connections to the shuffle service private Counter registeredConnections = new Counter(); @@ -120,7 +120,6 @@ public class TransportContext { this.conf = conf; this.rpcHandler = rpcHandler; this.closeIdleConnections = closeIdleConnections; - this.isClientOnly = isClientOnly; if (conf.getModuleName() != null && conf.getModuleName().equalsIgnoreCase("shuffle") && @@ -200,9 +199,7 @@ public class TransportContext { // would require more logic to guarantee if this were not part of the same event loop. .addLast("handler", channelHandler); // Use a separate EventLoopGroup to handle ChunkFetchRequest messages for shuffle rpcs. 
- if (conf.getModuleName() != null && - conf.getModuleName().equalsIgnoreCase("shuffle") - && !isClientOnly) { + if (chunkFetchWorkers != null) { pipeline.addLast(chunkFetchWorkers, "chunkFetchHandler", chunkFetchHandler); } return channelHandler; @@ -240,4 +237,10 @@ public class TransportContext { public Counter getRegisteredConnections() { return registeredConnections; } + + public void close() { + if (chunkFetchWorkers != null) { + chunkFetchWorkers.shutdownGracefully(); + } + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index ab4dd04..5999b62 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -56,6 +56,7 @@ public class ChunkFetchIntegrationSuite { static final int BUFFER_CHUNK_INDEX = 0; static final int FILE_CHUNK_INDEX = 1; + static TransportContext context; static TransportServer server; static TransportClientFactory clientFactory; static StreamManager streamManager; @@ -117,7 +118,7 @@ public class ChunkFetchIntegrationSuite { return streamManager; } }; - TransportContext context = new TransportContext(conf, handler); + context = new TransportContext(conf, handler); server = context.createServer(); clientFactory = context.createClientFactory(); } @@ -127,6 +128,7 @@ public class ChunkFetchIntegrationSuite { bufferChunk.release(); server.close(); clientFactory.close(); + context.close(); testFile.delete(); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index c0724e0..15a28ba 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java 
+++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -48,6 +48,7 @@ import java.util.concurrent.TimeUnit; */ public class RequestTimeoutIntegrationSuite { + private TransportContext context; private TransportServer server; private TransportClientFactory clientFactory; @@ -79,6 +80,9 @@ public class RequestTimeoutIntegrationSuite { if (clientFactory != null) { clientFactory.close(); } + if (context != null) { + context.close(); + } } // Basic suite: First request completes quickly, and second waits for longer than network timeout. @@ -106,7 +110,7 @@ public class RequestTimeoutIntegrationSuite { } }; - TransportContext context = new TransportContext(conf, handler); + context = new TransportContext(conf, handler); server = context.createServer(); clientFactory = context.createClientFactory(); TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); @@ -153,7 +157,7 @@ public class RequestTimeoutIntegrationSuite { } }; - TransportContext context = new TransportContext(conf, handler); + context = new TransportContext(conf, handler); server = context.createServer(); clientFactory = context.createClientFactory(); @@ -204,7 +208,7 @@ public class RequestTimeoutIntegrationSuite { } }; - TransportContext context = new TransportContext(conf, handler); + context = new TransportContext(conf, handler); server = context.createServer(); clientFactory = context.createClientFactory(); TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 1c0aa4d..117f1e4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -44,6 +44,7 @@ 
import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { static TransportConf conf; + static TransportContext context; static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; @@ -90,7 +91,7 @@ public class RpcIntegrationSuite { @Override public StreamManager getStreamManager() { return new OneForOneStreamManager(); } }; - TransportContext context = new TransportContext(conf, rpcHandler); + context = new TransportContext(conf, rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); oneWayMsgs = new ArrayList<>(); @@ -160,6 +161,7 @@ public class RpcIntegrationSuite { public static void tearDown() { server.close(); clientFactory.close(); + context.close(); testData.cleanup(); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index f3050cb..485d8ad 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -51,6 +51,7 @@ public class StreamSuite { private static final String[] STREAMS = StreamTestHelper.STREAMS; private static StreamTestHelper testData; + private static TransportContext context; private static TransportServer server; private static TransportClientFactory clientFactory; @@ -93,7 +94,7 @@ public class StreamSuite { return streamManager; } }; - TransportContext context = new TransportContext(conf, handler); + context = new TransportContext(conf, handler); server = context.createServer(); clientFactory = context.createClientFactory(); } @@ -103,6 +104,7 @@ public class StreamSuite { server.close(); clientFactory.close(); testData.cleanup(); + context.close(); } @Test diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java 
b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index e95d25f..2c62114 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -64,6 +64,7 @@ public class TransportClientFactorySuite { public void tearDown() { JavaUtils.closeQuietly(server1); JavaUtils.closeQuietly(server2); + JavaUtils.closeQuietly(context); } /** @@ -80,49 +81,50 @@ public class TransportClientFactorySuite { TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); - TransportContext context = new TransportContext(conf, rpcHandler); - TransportClientFactory factory = context.createClientFactory(); - Set<TransportClient> clients = Collections.synchronizedSet( - new HashSet<TransportClient>()); - - AtomicInteger failed = new AtomicInteger(); - Thread[] attempts = new Thread[maxConnections * 10]; - - // Launch a bunch of threads to create new clients. - for (int i = 0; i < attempts.length; i++) { - attempts[i] = new Thread(() -> { - try { - TransportClient client = - factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - assertTrue(client.isActive()); - clients.add(client); - } catch (IOException e) { - failed.incrementAndGet(); - } catch (InterruptedException e) { - throw new RuntimeException(e); + try (TransportContext context = new TransportContext(conf, rpcHandler)) { + TransportClientFactory factory = context.createClientFactory(); + Set<TransportClient> clients = Collections.synchronizedSet( + new HashSet<TransportClient>()); + + AtomicInteger failed = new AtomicInteger(); + Thread[] attempts = new Thread[maxConnections * 10]; + + // Launch a bunch of threads to create new clients. 
+ for (int i = 0; i < attempts.length; i++) { + attempts[i] = new Thread(() -> { + try { + TransportClient client = + factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertTrue(client.isActive()); + clients.add(client); + } catch (IOException e) { + failed.incrementAndGet(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + if (concurrent) { + attempts[i].start(); + } else { + attempts[i].run(); } - }); + } - if (concurrent) { - attempts[i].start(); - } else { - attempts[i].run(); + // Wait until all the threads complete. + for (Thread attempt : attempts) { + attempt.join(); } - } - // Wait until all the threads complete. - for (Thread attempt : attempts) { - attempt.join(); - } + Assert.assertEquals(0, failed.get()); + Assert.assertEquals(clients.size(), maxConnections); - Assert.assertEquals(0, failed.get()); - Assert.assertEquals(clients.size(), maxConnections); + for (TransportClient client : clients) { + client.close(); + } - for (TransportClient client : clients) { - client.close(); + factory.close(); } - - factory.close(); } @Test @@ -204,8 +206,8 @@ public class TransportClientFactorySuite { throw new UnsupportedOperationException(); } }); - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); - try (TransportClientFactory factory = context.createClientFactory()) { + try (TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); + TransportClientFactory factory = context.createClientFactory()) { TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); assertTrue(c1.isActive()); long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 8751944..8a0ff54 100644 --- 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -196,6 +196,9 @@ public class AuthIntegrationSuite { if (server != null) { server.close(); } + if (ctx != null) { + ctx.close(); + } } private SecretKeyHolder createKeyHolder(String secret) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 59adf97..cf2d72f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -365,6 +365,7 @@ public class SparkSaslSuite { final TransportClient client; final TransportServer server; + final TransportContext ctx; private final boolean encrypt; private final boolean disableClientEncryption; @@ -396,7 +397,7 @@ public class SparkSaslSuite { when(keyHolder.getSaslUser(anyString())).thenReturn("user"); when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); - TransportContext ctx = new TransportContext(conf, rpcHandler); + this.ctx = new TransportContext(conf, rpcHandler); this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME); @@ -431,6 +432,9 @@ public class SparkSaslSuite { if (server != null) { server.close(); } + if (ctx != null) { + ctx.close(); + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/NettyMemoryMetricsSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/NettyMemoryMetricsSuite.java index 400b385..f049cad 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/util/NettyMemoryMetricsSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/util/NettyMemoryMetricsSuite.java @@ -60,11 +60,14 @@ public class 
NettyMemoryMetricsSuite { JavaUtils.closeQuietly(clientFactory); clientFactory = null; } - if (server != null) { JavaUtils.closeQuietly(server); server = null; } + if (context != null) { + JavaUtils.closeQuietly(context); + context = null; + } } @Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 02e6eb3..57c1c5e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -91,6 +91,7 @@ public class SaslIntegrationSuite { @AfterClass public static void afterAll() { server.close(); + context.close(); } @After @@ -153,13 +154,14 @@ public class SaslIntegrationSuite { @Test public void testNoSaslServer() { RpcHandler handler = new TestRpcHandler(); - TransportContext context = new TransportContext(conf, handler); - clientFactory = context.createClientFactory( - Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); - try (TransportServer server = context.createServer()) { - clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + try (TransportContext context = new TransportContext(conf, handler)) { + clientFactory = context.createClientFactory( + Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + try (TransportServer server = context.createServer()) { + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + } } } @@ -174,18 +176,15 @@ public class SaslIntegrationSuite { ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler( new 
OneForOneStreamManager(), blockResolver); TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder); - TransportContext blockServerContext = new TransportContext(conf, blockHandler); - TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap)); - TransportClient client1 = null; - TransportClient client2 = null; - TransportClientFactory clientFactory2 = null; - try { + try ( + TransportContext blockServerContext = new TransportContext(conf, blockHandler); + TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap)); // Create a client, and make a request to fetch blocks from a different app. - clientFactory = blockServerContext.createClientFactory( + TransportClientFactory clientFactory1 = blockServerContext.createClientFactory( Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); - client1 = clientFactory.createClient(TestUtils.getLocalHost(), - blockServer.getPort()); + TransportClient client1 = clientFactory1.createClient( + TestUtils.getLocalHost(), blockServer.getPort())) { AtomicReference<Throwable> exception = new AtomicReference<>(); @@ -223,41 +222,33 @@ public class SaslIntegrationSuite { StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); long streamId = stream.streamId; - // Create a second client, authenticated with a different app ID, and try to read from - // the stream created for the previous app. 
- clientFactory2 = blockServerContext.createClientFactory( - Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder))); - client2 = clientFactory2.createClient(TestUtils.getLocalHost(), - blockServer.getPort()); - - CountDownLatch chunkReceivedLatch = new CountDownLatch(1); - ChunkReceivedCallback callback = new ChunkReceivedCallback() { - @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { - chunkReceivedLatch.countDown(); - } - @Override - public void onFailure(int chunkIndex, Throwable t) { - exception.set(t); - chunkReceivedLatch.countDown(); - } - }; - - exception.set(null); - client2.fetchChunk(streamId, 0, callback); - chunkReceivedLatch.await(); - checkSecurityException(exception.get()); - } finally { - if (client1 != null) { - client1.close(); - } - if (client2 != null) { - client2.close(); - } - if (clientFactory2 != null) { - clientFactory2.close(); + try ( + // Create a second client, authenticated with a different app ID, and try to read from + // the stream created for the previous app. 
+ TransportClientFactory clientFactory2 = blockServerContext.createClientFactory( + Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder))); + TransportClient client2 = clientFactory2.createClient( + TestUtils.getLocalHost(), blockServer.getPort()) + ) { + CountDownLatch chunkReceivedLatch = new CountDownLatch(1); + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + chunkReceivedLatch.countDown(); + } + + @Override + public void onFailure(int chunkIndex, Throwable t) { + exception.set(t); + chunkReceivedLatch.countDown(); + } + }; + + exception.set(null); + client2.fetchChunk(streamId, 0, callback); + chunkReceivedLatch.await(); + checkSecurityException(exception.get()); } - blockServer.close(); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 526b96b..f5b1ec9 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -58,6 +58,7 @@ public class ExternalShuffleIntegrationSuite { static ExternalShuffleBlockHandler handler; static TransportServer server; static TransportConf conf; + static TransportContext transportContext; static byte[][] exec0Blocks = new byte[][] { new byte[123], @@ -87,7 +88,7 @@ public class ExternalShuffleIntegrationSuite { conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); handler = new ExternalShuffleBlockHandler(conf, null); - TransportContext transportContext = new TransportContext(conf, handler); + transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); } @@ -95,6 +96,7 @@ public class ExternalShuffleIntegrationSuite { 
public static void afterAll() { dataContext0.cleanup(); server.close(); + transportContext.close(); } @After diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 82caf39..67f79021 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -41,14 +41,14 @@ public class ExternalShuffleSecuritySuite { TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); TransportServer server; + TransportContext transportContext; @Before public void beforeEach() throws IOException { - TransportContext context = - new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null)); + transportContext = new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null)); TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, new TestSecretKeyHolder("my-app-id", "secret")); - this.server = context.createServer(Arrays.asList(bootstrap)); + this.server = transportContext.createServer(Arrays.asList(bootstrap)); } @After @@ -57,6 +57,10 @@ public class ExternalShuffleSecuritySuite { server.close(); server = null; } + if (transportContext != null) { + transportContext.close(); + transportContext = null; + } } @Test diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 7e8d3b2..25592e9 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -113,6 +113,8 @@ public class YarnShuffleService extends AuxiliaryService { // The actual 
server that serves shuffle files private TransportServer shuffleServer = null; + private TransportContext transportContext = null; + private Configuration _conf = null; // The recovery path used to shuffle service recovery @@ -184,7 +186,7 @@ public class YarnShuffleService extends AuxiliaryService { int port = conf.getInt( SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); - TransportContext transportContext = new TransportContext(transportConf, blockHandler); + transportContext = new TransportContext(transportConf, blockHandler); shuffleServer = transportContext.createServer(port, bootstraps); // the port should normally be fixed, but for tests its useful to find an open port port = shuffleServer.getPort(); @@ -318,6 +320,9 @@ public class YarnShuffleService extends AuxiliaryService { if (shuffleServer != null) { shuffleServer.close(); } + if (transportContext != null) { + transportContext.close(); + } if (blockHandler != null) { blockHandler.close(); } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index edfd2ea..12ed189 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -52,8 +52,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) - private val transportContext: TransportContext = - new TransportContext(transportConf, blockHandler, true) + private var transportContext: TransportContext = _ private var server: TransportServer = _ @@ -82,6 +81,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana } else { Nil } + transportContext = new TransportContext(transportConf, blockHandler, true) 
server = transportContext.createServer(port, bootstraps.asJava) shuffleServiceSource.registerMetricSet(server.getAllMetrics) @@ -107,6 +107,10 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana server.close() server = null } + if (transportContext != null) { + transportContext.close() + transportContext = null + } } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index dc55685..864e8ad 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -182,5 +182,8 @@ private[spark] class NettyBlockTransferService( if (clientFactory != null) { clientFactory.close() } + if (transportContext != null) { + transportContext.close() + } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 2540196..472db45 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -315,6 +315,9 @@ private[netty] class NettyRpcEnv( if (fileDownloadFactory != null) { fileDownloadFactory.close() } + if (transportContext != null) { + transportContext.close() + } } override def deserialize[T](deserializationAction: () => T): T = { diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 262e2a7..8b737cd 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import 
org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleClient} +import org.apache.spark.util.Utils /** * This suite creates an external shuffle server and routes all shuffle fetches through it. @@ -33,13 +34,14 @@ import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalSh */ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var server: TransportServer = _ + var transportContext: TransportContext = _ var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { super.beforeAll() val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) - val transportContext = new TransportContext(transportConf, rpcHandler) + transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() conf.set(config.SHUFFLE_MANAGER, "sort") @@ -48,11 +50,16 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { } override def afterAll() { - try { + Utils.tryLogNonFatalError{ server.close() - } finally { - super.afterAll() } + Utils.tryLogNonFatalError{ + rpcHandler.close() + } + Utils.tryLogNonFatalError{ + transportContext.close() + } + super.afterAll() } // This test ensures that the external shuffle service is actually in use for the other tests. diff --git a/core/src/test/scala/org/apache/spark/ThreadAudit.scala b/core/src/test/scala/org/apache/spark/ThreadAudit.scala index b3cea9d..6b91162 100644 --- a/core/src/test/scala/org/apache/spark/ThreadAudit.scala +++ b/core/src/test/scala/org/apache/spark/ThreadAudit.scala @@ -55,18 +55,26 @@ trait ThreadAudit extends Logging { * creates event loops. One is wrapped inside * [[org.apache.spark.network.server.TransportServer]] * the other one is inside [[org.apache.spark.network.client.TransportClient]]. 
- * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. - * Manually checked and all of them stopped properly. + * Calling [[SparkContext#stop]] will shut down the thread pool of this event group + * asynchronously. In each case proper stopping is checked manually. */ "rpc-client.*", "rpc-server.*", /** + * During [[org.apache.spark.network.TransportContext]] construction a separate event loop could + * be created for handling ChunkFetchRequest. + * Calling [[org.apache.spark.network.TransportContext#close]] will shut down the thread pool + * of this event group asynchronously. In each case proper stopping is checked manually. + */ + "shuffle-chunk-fetch-handler.*", + + /** * During [[SparkContext]] creation BlockManager creates event loops. One is wrapped inside * [[org.apache.spark.network.server.TransportServer]] * the other one is inside [[org.apache.spark.network.client.TransportClient]]. - * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. - * Manually checked and all of them stopped properly. + * Calling [[SparkContext#stop]] will shut down the thread pool of this event group + * asynchronously. In each case proper stopping is checked manually. */ "shuffle-client.*", "shuffle-server.*" diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 5dec4f5..115103f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -895,6 +895,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + allStores += store store.initialize("app-id") // The put should fail since a1 is not serializable. 
@@ -1360,74 +1361,76 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val tryAgainExecutor = "tryAgainExecutor" val succeedingExecutor = "succeedingExecutor" - // a server which delays response 50ms and must try twice for success. - def newShuffleServer(port: Int): (TransportServer, Int) = { - val failure = new Exception(tryAgainMsg) - val success = ByteBuffer.wrap(new Array[Byte](0)) + val failure = new Exception(tryAgainMsg) + val success = ByteBuffer.wrap(new Array[Byte](0)) - var secondExecutorFailedOnce = false - var thirdExecutorFailedOnce = false + var secondExecutorFailedOnce = false + var thirdExecutorFailedOnce = false - val handler = new NoOpRpcHandler { - override def receive( - client: TransportClient, - message: ByteBuffer, - callback: RpcResponseCallback): Unit = { - val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) - msgObj match { + val handler = new NoOpRpcHandler { + override def receive( + client: TransportClient, + message: ByteBuffer, + callback: RpcResponseCallback): Unit = { + val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) + msgObj match { - case exec: RegisterExecutor if exec.execId == timingoutExecutor => - () // No reply to generate client-side timeout + case exec: RegisterExecutor if exec.execId == timingoutExecutor => + () // No reply to generate client-side timeout - case exec: RegisterExecutor - if exec.execId == tryAgainExecutor && !secondExecutorFailedOnce => - secondExecutorFailedOnce = true - callback.onFailure(failure) + case exec: RegisterExecutor + if exec.execId == tryAgainExecutor && !secondExecutorFailedOnce => + secondExecutorFailedOnce = true + callback.onFailure(failure) - case exec: RegisterExecutor if exec.execId == tryAgainExecutor => - callback.onSuccess(success) + case exec: RegisterExecutor if exec.execId == tryAgainExecutor => + callback.onSuccess(success) - case exec: RegisterExecutor - if exec.execId == succeedingExecutor && 
!thirdExecutorFailedOnce => - thirdExecutorFailedOnce = true - callback.onFailure(failure) + case exec: RegisterExecutor + if exec.execId == succeedingExecutor && !thirdExecutorFailedOnce => + thirdExecutorFailedOnce = true + callback.onFailure(failure) - case exec: RegisterExecutor if exec.execId == succeedingExecutor => - callback.onSuccess(success) + case exec: RegisterExecutor if exec.execId == succeedingExecutor => + callback.onSuccess(success) - } } } - - val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 0) - val transCtx = new TransportContext(transConf, handler, true) - (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) } - val candidatePort = RandomUtils.nextInt(1024, 65536) - val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, - newShuffleServer, conf, "ShuffleServer") - - conf.set(SHUFFLE_SERVICE_ENABLED.key, "true") - conf.set(SHUFFLE_SERVICE_PORT.key, shufflePort.toString) - conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") - conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") - var e = intercept[SparkException] { - makeBlockManager(8000, timingoutExecutor) - }.getMessage - assert(e.contains("TimeoutException")) - - conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") - conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") - e = intercept[SparkException] { - makeBlockManager(8000, tryAgainExecutor) - }.getMessage - assert(e.contains(tryAgainMsg)) - - conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") - conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2") - makeBlockManager(8000, succeedingExecutor) - server.close() + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 0) + + Utils.tryWithResource(new TransportContext(transConf, handler, true)) { transCtx => + // a server which delays response 50ms and must try twice for success. 
+ def newShuffleServer(port: Int): (TransportServer, Int) = { + (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) + } + + val candidatePort = RandomUtils.nextInt(1024, 65536) + val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, + newShuffleServer, conf, "ShuffleServer") + + conf.set(SHUFFLE_SERVICE_ENABLED.key, "true") + conf.set(SHUFFLE_SERVICE_PORT.key, shufflePort.toString) + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") + var e = intercept[SparkException] { + makeBlockManager(8000, timingoutExecutor) + }.getMessage + assert(e.contains("TimeoutException")) + + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") + e = intercept[SparkException] { + makeBlockManager(8000, tryAgainExecutor) + }.getMessage + assert(e.contains(tryAgainMsg)) + + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2") + makeBlockManager(8000, succeedingExecutor) + server.close() + } } test("fetch remote block to local disk if block size is larger than threshold") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org