Posted to commits@spark.apache.org by GitBox <gi...@apache.org> on 2019/01/16 03:00:51 UTC

[spark] Diff for: [GitHub] cloud-fan closed pull request #23521: [SPARK-26604][CORE] Clean up channel registration for StreamManager

diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
index f08d8b0f984cf..43c3d23b6304d 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
@@ -90,7 +90,6 @@ protected void channelRead0(
     ManagedBuffer buf;
     try {
       streamManager.checkAuthorization(client, msg.streamChunkId.streamId);
-      streamManager.registerChannel(channel, msg.streamChunkId.streamId);
       buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex);
     } catch (Exception e) {
       logger.error(String.format("Error opening block %s for request from %s",
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index 0f6a8824d95e5..6fafcc131fa24 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -23,6 +23,7 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import io.netty.channel.Channel;
 import org.apache.commons.lang3.tuple.ImmutablePair;
@@ -49,7 +50,7 @@
     final Iterator<ManagedBuffer> buffers;
 
     // The channel associated to the stream
-    Channel associatedChannel = null;
+    final Channel associatedChannel;
 
     // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
     // that the caller only requests each chunk one at a time, in order.
@@ -58,9 +59,10 @@
     // Used to keep track of the number of chunks being transferred and not finished yet.
     volatile long chunksBeingTransferred = 0L;
 
-    StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+    StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
       this.appId = appId;
       this.buffers = Preconditions.checkNotNull(buffers);
+      this.associatedChannel = channel;
     }
   }
 
@@ -71,13 +73,6 @@ public OneForOneStreamManager() {
     streams = new ConcurrentHashMap<>();
   }
 
-  @Override
-  public void registerChannel(Channel channel, long streamId) {
-    if (streams.containsKey(streamId)) {
-      streams.get(streamId).associatedChannel = channel;
-    }
-  }
-
   @Override
   public ManagedBuffer getChunk(long streamId, int chunkIndex) {
     StreamState state = streams.get(streamId);
@@ -195,11 +190,19 @@ public long chunksBeingTransferred() {
    *
    * If an app ID is provided, only callers who've authenticated with the given app ID will be
    * allowed to fetch from this stream.
+   *
+   * This method also associates the stream with a single client connection, which is guaranteed
+   * to be the only reader of the stream. Once the connection is closed, the stream will never
+   * be used again, enabling cleanup by `connectionTerminated`.
    */
-  public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
+  public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
     long myStreamId = nextStreamId.getAndIncrement();
-    streams.put(myStreamId, new StreamState(appId, buffers));
+    streams.put(myStreamId, new StreamState(appId, buffers, channel));
     return myStreamId;
   }
 
+  @VisibleForTesting
+  public int numStreamStates() {
+    return streams.size();
+  }
 }
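
To see the reworked API end to end, here is a minimal sketch of the new registration flow. It is illustrative only, not code from the pull request: the class name StreamRegistrationSketch is made up, and it reuses the TestManagedBuffer helper and mocked Netty Channel that appear in the test suites later in this diff.

import java.util.ArrayList;
import java.util.List;

import io.netty.channel.Channel;
import org.mockito.Mockito;

import org.apache.spark.network.TestManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.server.OneForOneStreamManager;

public class StreamRegistrationSketch {
  public static void main(String[] args) {
    // Buffers the stream will serve, one per chunk.
    List<ManagedBuffer> buffers = new ArrayList<>();
    buffers.add(new TestManagedBuffer(10));
    buffers.add(new TestManagedBuffer(20));

    // Stand-in for the client connection that will read the stream.
    Channel channel = Mockito.mock(Channel.class);

    OneForOneStreamManager manager = new OneForOneStreamManager();

    // The channel is now bound at registration time; there is no separate
    // registerChannel() call to forget or to race against getChunk().
    long streamId = manager.registerStream("appId", buffers.iterator(), channel);
    System.out.println("registered stream " + streamId);

    // When the connection goes away, every stream registered against it is
    // removed and its unconsumed buffers are released.
    manager.connectionTerminated(channel);
    System.out.println("remaining streams: " + manager.numStreamStates()); // 0
  }
}

The point of the change is visible here: a stream is tied to exactly one channel from the moment it exists (associatedChannel is now final), so connectionTerminated can clean up without any window in which a registered stream has no associated channel.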
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
index c535295831606..e48d27be1126a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -60,16 +60,6 @@ public ManagedBuffer openStream(String streamId) {
     throw new UnsupportedOperationException();
   }
 
-  /**
-   * Associates a stream with a single client connection, which is guaranteed to be the only reader
-   * of the stream. The getChunk() method will be called serially on this connection and once the
-   * connection is closed, the stream will never be used again, enabling cleanup.
-   *
-   * This must be called before the first getChunk() on the stream, but it may be invoked multiple
-   * times with the same channel and stream id.
-   */
-  public void registerChannel(Channel channel, long streamId) { }
-
   /**
    * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
    * to read from the associated streams again, so any state can be cleaned up.
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java
index 2c72c53a33ae8..6c9239606bb8c 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java
@@ -64,8 +64,7 @@ public void handleChunkFetchRequest() throws Exception {
     managedBuffers.add(new TestManagedBuffer(20));
     managedBuffers.add(new TestManagedBuffer(30));
     managedBuffers.add(new TestManagedBuffer(40));
-    long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
-    streamManager.registerChannel(channel, streamId);
+    long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);
     TransportClient reverseClient = mock(TransportClient.class);
     ChunkFetchRequestHandler requestHandler = new ChunkFetchRequestHandler(reverseClient,
       rpcHandler.getStreamManager(), 2L);
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
index ad640415a8e6d..a87f6c11a2bfd 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
@@ -58,8 +58,10 @@ public void handleStreamRequest() throws Exception {
     managedBuffers.add(new TestManagedBuffer(20));
     managedBuffers.add(new TestManagedBuffer(30));
     managedBuffers.add(new TestManagedBuffer(40));
-    long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
-    streamManager.registerChannel(channel, streamId);
+    long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);
+
+    assert streamManager.numStreamStates() == 1;
+
     TransportClient reverseClient = mock(TransportClient.class);
     TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
       rpcHandler, 2L);
@@ -94,5 +96,8 @@ public void handleStreamRequest() throws Exception {
     requestHandler.handle(request3);
     verify(channel, times(1)).close();
     assert responseAndPromisePairs.size() == 3;
+
+    streamManager.connectionTerminated(channel);
+    assert streamManager.numStreamStates() == 0;
   }
 }
diff --git a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
index c647525d8f1bd..4248762c32389 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
@@ -37,14 +37,15 @@ public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
     TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
     buffers.add(buffer1);
     buffers.add(buffer2);
-    long streamId = manager.registerStream("appId", buffers.iterator());
 
     Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
-    manager.registerChannel(dummyChannel, streamId);
+    manager.registerStream("appId", buffers.iterator(), dummyChannel);
+    assert manager.numStreamStates() == 1;
 
     manager.connectionTerminated(dummyChannel);
 
     Mockito.verify(buffer1, Mockito.times(1)).release();
     Mockito.verify(buffer2, Mockito.times(1)).release();
+    assert manager.numStreamStates() == 0;
   }
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 788a845c57755..b25e48a164e6b 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -92,7 +92,7 @@ protected void handleMessage(
         OpenBlocks msg = (OpenBlocks) msgObj;
         checkAuth(client, msg.appId);
         long streamId = streamManager.registerStream(client.getClientId(),
-          new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds));
+          new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel());
         if (logger.isTraceEnabled()) {
           logger.trace("Registered streamId {} with {} buffers for client {} from host {}",
                        streamId,
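
On the server side, the calling pattern is now the one shown in the hunk above: the channel is taken from the requesting TransportClient at registration. A simplified, hypothetical fragment (BlockStreamRegistrar and openBlocks are invented names, not the actual ExternalShuffleBlockHandler code) illustrating that pattern:

import java.util.Iterator;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;

// Hypothetical helper, not part of Spark: registers a stream of block buffers
// on behalf of the requesting client, mirroring the handleMessage change above.
class BlockStreamRegistrar {
  private final OneForOneStreamManager streamManager = new OneForOneStreamManager();

  long openBlocks(TransportClient client, Iterator<ManagedBuffer> blocks) {
    // The client's own channel is passed at registration, so the stream is
    // dropped (and its buffers released) when that connection terminates.
    return streamManager.registerStream(client.getClientId(), blocks, client.getChannel());
  }
}
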
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 4cc9a16e1449f..537c277cd26b5 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -103,7 +103,8 @@ public void testOpenShuffleBlocks() {
     @SuppressWarnings("unchecked")
     ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
         (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
-    verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
+    verify(streamManager, times(1)).registerStream(anyString(), stream.capture(),
+      any());
     Iterator<ManagedBuffer> buffers = stream.getValue();
     assertEquals(block0Marker, buffers.next());
     assertEquals(block1Marker, buffers.next());
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 7076701421e2e..27f4f94ea55f8 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -59,7 +59,8 @@ class NettyBlockRpcServer(
         val blocksNum = openBlocks.blockIds.length
         val blocks = for (i <- (0 until blocksNum).view)
           yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
-        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
+        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
+          client.getChannel)
         logTrace(s"Registered streamId $streamId with $blocksNum buffers")
         responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
 


With regards,
Apache Git Services
