This is an automated email from the ASF dual-hosted git repository.

vanzin pushed a commit to branch branch-2.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.3 by this push:
     new c45f8da  [SPARK-26604][CORE][BACKPORT-2.4] Clean up channel 
registration for StreamManager
c45f8da is described below

commit c45f8da3af6000645ee76544940a6bdc5477884b
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Thu Mar 7 19:48:20 2019 -0800

    [SPARK-26604][CORE][BACKPORT-2.4] Clean up channel registration for 
StreamManager
    
    ## What changes were proposed in this pull request?
    
    This is mostly a clean backport of 
https://github.com/apache/spark/pull/23521 to branch-2.4
    
    ## How was this patch tested?
    
    I've tested this with a hack in `TransportRequestHandler` to force 
`ChunkFetchRequest` to get dropped.
    
    Then making a number of `ExternalShuffleClient.fetchChunk` requests (which 
send `OpenBlocks` then `ChunkFetchRequest`) and closing out of my test harness. A 
heap dump later reveals that the `StreamState` references are unreachable.
    
    I haven't run this through the unit test suite, but I am doing that now. Wanted 
to get this up as I think folks are waiting for it for 2.4.1
    
    Closes #24013 from abellina/SPARK-26604_cherry_pick_2_4.
    
    Lead-authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Co-authored-by: Alessandro Bellina <abell...@yahoo-inc.com>
    Signed-off-by: Marcelo Vanzin <van...@cloudera.com>
    (cherry picked from commit 216eeec2bd319f1d6a1228c9bc8d8a579d5e6571)
    Signed-off-by: Marcelo Vanzin <van...@cloudera.com>
---
 .../network/server/OneForOneStreamManager.java     | 25 ++++++++++++----------
 .../apache/spark/network/server/StreamManager.java | 10 ---------
 .../network/server/TransportRequestHandler.java    |  1 -
 .../network/TransportRequestHandlerSuite.java      |  9 ++++++--
 .../server/OneForOneStreamManagerSuite.java        |  5 +++--
 .../shuffle/ExternalShuffleBlockHandler.java       |  2 +-
 .../shuffle/ExternalShuffleBlockHandlerSuite.java  |  3 ++-
 .../spark/network/netty/NettyBlockRpcServer.scala  |  3 ++-
 8 files changed, 29 insertions(+), 29 deletions(-)

diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
 
b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index 0f6a882..6fafcc1 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -23,6 +23,7 @@ import java.util.Random;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import io.netty.channel.Channel;
 import org.apache.commons.lang3.tuple.ImmutablePair;
@@ -49,7 +50,7 @@ public class OneForOneStreamManager extends StreamManager {
     final Iterator<ManagedBuffer> buffers;
 
     // The channel associated to the stream
-    Channel associatedChannel = null;
+    final Channel associatedChannel;
 
     // Used to keep track of the index of the buffer that the user has 
retrieved, just to ensure
     // that the caller only requests each chunk one at a time, in order.
@@ -58,9 +59,10 @@ public class OneForOneStreamManager extends StreamManager {
     // Used to keep track of the number of chunks being transferred and not 
finished yet.
     volatile long chunksBeingTransferred = 0L;
 
-    StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+    StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel 
channel) {
       this.appId = appId;
       this.buffers = Preconditions.checkNotNull(buffers);
+      this.associatedChannel = channel;
     }
   }
 
@@ -72,13 +74,6 @@ public class OneForOneStreamManager extends StreamManager {
   }
 
   @Override
-  public void registerChannel(Channel channel, long streamId) {
-    if (streams.containsKey(streamId)) {
-      streams.get(streamId).associatedChannel = channel;
-    }
-  }
-
-  @Override
   public ManagedBuffer getChunk(long streamId, int chunkIndex) {
     StreamState state = streams.get(streamId);
     if (chunkIndex != state.curChunk) {
@@ -195,11 +190,19 @@ public class OneForOneStreamManager extends StreamManager 
{
    *
    * If an app ID is provided, only callers who've authenticated with the 
given app ID will be
    * allowed to fetch from this stream.
+   *
+   * This method also associates the stream with a single client connection, 
which is guaranteed
+   * to be the only reader of the stream. Once the connection is closed, the 
stream will never
+   * be used again, enabling cleanup by `connectionTerminated`.
    */
-  public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
+  public long registerStream(String appId, Iterator<ManagedBuffer> buffers, 
Channel channel) {
     long myStreamId = nextStreamId.getAndIncrement();
-    streams.put(myStreamId, new StreamState(appId, buffers));
+    streams.put(myStreamId, new StreamState(appId, buffers, channel));
     return myStreamId;
   }
 
+  @VisibleForTesting
+  public int numStreamStates() {
+    return streams.size();
+  }
 }
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
 
b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
index c535295..e48d27b 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -61,16 +61,6 @@ public abstract class StreamManager {
   }
 
   /**
-   * Associates a stream with a single client connection, which is guaranteed 
to be the only reader
-   * of the stream. The getChunk() method will be called serially on this 
connection and once the
-   * connection is closed, the stream will never be used again, enabling 
cleanup.
-   *
-   * This must be called before the first getChunk() on the stream, but it may 
be invoked multiple
-   * times with the same channel and stream id.
-   */
-  public void registerChannel(Channel channel, long streamId) { }
-
-  /**
    * Indicates that the given channel has been terminated. After this occurs, 
we are guaranteed not
    * to read from the associated streams again, so any state can be cleaned up.
    */
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
 
b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index e944535..ed439a7 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -133,7 +133,6 @@ public class TransportRequestHandler extends 
MessageHandler<RequestMessage> {
     ManagedBuffer buf;
     try {
       streamManager.checkAuthorization(reverseClient, 
req.streamChunkId.streamId);
-      streamManager.registerChannel(channel, req.streamChunkId.streamId);
       buf = streamManager.getChunk(req.streamChunkId.streamId, 
req.streamChunkId.chunkIndex);
     } catch (Exception e) {
       logger.error(String.format("Error opening block %s for request from %s",
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
index 2656cbe..0b565f2 100644
--- 
a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
@@ -62,8 +62,10 @@ public class TransportRequestHandlerSuite {
     managedBuffers.add(new TestManagedBuffer(20));
     managedBuffers.add(new TestManagedBuffer(30));
     managedBuffers.add(new TestManagedBuffer(40));
-    long streamId = streamManager.registerStream("test-app", 
managedBuffers.iterator());
-    streamManager.registerChannel(channel, streamId);
+    long streamId = streamManager.registerStream("test-app", 
managedBuffers.iterator(), channel);
+
+    assert streamManager.numStreamStates() == 1;
+
     TransportClient reverseClient = mock(TransportClient.class);
     TransportRequestHandler requestHandler = new 
TransportRequestHandler(channel, reverseClient,
       rpcHandler, 2L);
@@ -98,6 +100,9 @@ public class TransportRequestHandlerSuite {
     requestHandler.handle(request3);
     verify(channel, times(1)).close();
     assert responseAndPromisePairs.size() == 3;
+
+    streamManager.connectionTerminated(channel);
+    assert streamManager.numStreamStates() == 0;
   }
 
   private class ExtendedChannelPromise extends DefaultChannelPromise {
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
index c647525..4248762 100644
--- 
a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
@@ -37,14 +37,15 @@ public class OneForOneStreamManagerSuite {
     TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
     buffers.add(buffer1);
     buffers.add(buffer2);
-    long streamId = manager.registerStream("appId", buffers.iterator());
 
     Channel dummyChannel = Mockito.mock(Channel.class, 
Mockito.RETURNS_SMART_NULLS);
-    manager.registerChannel(dummyChannel, streamId);
+    manager.registerStream("appId", buffers.iterator(), dummyChannel);
+    assert manager.numStreamStates() == 1;
 
     manager.connectionTerminated(dummyChannel);
 
     Mockito.verify(buffer1, Mockito.times(1)).release();
     Mockito.verify(buffer2, Mockito.times(1)).release();
+    assert manager.numStreamStates() == 0;
   }
 }
diff --git 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index fc7bba4..d6335f0 100644
--- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -91,7 +91,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
         OpenBlocks msg = (OpenBlocks) msgObj;
         checkAuth(client, msg.appId);
         long streamId = streamManager.registerStream(client.getClientId(),
-          new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds));
+          new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), 
client.getChannel());
         if (logger.isTraceEnabled()) {
           logger.trace("Registered streamId {} with {} buffers for client {} 
from host {}",
                        streamId,
diff --git 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 7846b71..1e4eda0 100644
--- 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -101,7 +101,8 @@ public class ExternalShuffleBlockHandlerSuite {
     @SuppressWarnings("unchecked")
     ArgumentCaptor<Iterator<ManagedBuffer>> stream = 
(ArgumentCaptor<Iterator<ManagedBuffer>>)
         (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
-    verify(streamManager, times(1)).registerStream(anyString(), 
stream.capture());
+    verify(streamManager, times(1)).registerStream(anyString(), 
stream.capture(),
+      any());
     Iterator<ManagedBuffer> buffers = stream.getValue();
     assertEquals(block0Marker, buffers.next());
     assertEquals(block1Marker, buffers.next());
diff --git 
a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala 
b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index eb4cf94..0a9c5d0 100644
--- 
a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ 
b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -59,7 +59,8 @@ class NettyBlockRpcServer(
         val blocksNum = openBlocks.blockIds.length
         val blocks = for (i <- (0 until blocksNum).view)
           yield 
blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
-        val streamId = streamManager.registerStream(appId, 
blocks.iterator.asJava)
+        val streamId = streamManager.registerStream(appId, 
blocks.iterator.asJava,
+          client.getChannel)
         logTrace(s"Registered streamId $streamId with $blocksNum buffers")
         responseContext.onSuccess(new StreamHandle(streamId, 
blocksNum).toByteBuffer)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to