You are viewing a plain text version of this content. The canonical link was provided in the original HTML version of this message but was lost in plain-text conversion.
Posted to commits@spark.apache.org by GitBox <gi...@apache.org> on 2019/01/16 03:00:51 UTC
[spark] Diff for: [GitHub] cloud-fan closed pull request #23521:
[SPARK-26604][CORE] Clean up channel registration for StreamManager
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
index f08d8b0f984cf..43c3d23b6304d 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
@@ -90,7 +90,6 @@ protected void channelRead0(
ManagedBuffer buf;
try {
streamManager.checkAuthorization(client, msg.streamChunkId.streamId);
- streamManager.registerChannel(channel, msg.streamChunkId.streamId);
buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex);
} catch (Exception e) {
logger.error(String.format("Error opening block %s for request from %s",
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index 0f6a8824d95e5..6fafcc131fa24 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -23,6 +23,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.apache.commons.lang3.tuple.ImmutablePair;
@@ -49,7 +50,7 @@
final Iterator<ManagedBuffer> buffers;
// The channel associated to the stream
- Channel associatedChannel = null;
+ final Channel associatedChannel;
// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
// that the caller only requests each chunk one at a time, in order.
@@ -58,9 +59,10 @@
// Used to keep track of the number of chunks being transferred and not finished yet.
volatile long chunksBeingTransferred = 0L;
- StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+ StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
+ this.associatedChannel = channel;
}
}
@@ -71,13 +73,6 @@ public OneForOneStreamManager() {
streams = new ConcurrentHashMap<>();
}
- @Override
- public void registerChannel(Channel channel, long streamId) {
- if (streams.containsKey(streamId)) {
- streams.get(streamId).associatedChannel = channel;
- }
- }
-
@Override
public ManagedBuffer getChunk(long streamId, int chunkIndex) {
StreamState state = streams.get(streamId);
@@ -195,11 +190,19 @@ public long chunksBeingTransferred() {
*
* If an app ID is provided, only callers who've authenticated with the given app ID will be
* allowed to fetch from this stream.
+ *
+ * This method also associates the stream with a single client connection, which is guaranteed
+ * to be the only reader of the stream. Once the connection is closed, the stream will never
+ * be used again, enabling cleanup by `connectionTerminated`.
*/
- public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
+ public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
long myStreamId = nextStreamId.getAndIncrement();
- streams.put(myStreamId, new StreamState(appId, buffers));
+ streams.put(myStreamId, new StreamState(appId, buffers, channel));
return myStreamId;
}
+ @VisibleForTesting
+ public int numStreamStates() {
+ return streams.size();
+ }
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
index c535295831606..e48d27be1126a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -60,16 +60,6 @@ public ManagedBuffer openStream(String streamId) {
throw new UnsupportedOperationException();
}
- /**
- * Associates a stream with a single client connection, which is guaranteed to be the only reader
- * of the stream. The getChunk() method will be called serially on this connection and once the
- * connection is closed, the stream will never be used again, enabling cleanup.
- *
- * This must be called before the first getChunk() on the stream, but it may be invoked multiple
- * times with the same channel and stream id.
- */
- public void registerChannel(Channel channel, long streamId) { }
-
/**
* Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
* to read from the associated streams again, so any state can be cleaned up.
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java
index 2c72c53a33ae8..6c9239606bb8c 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java
@@ -64,8 +64,7 @@ public void handleChunkFetchRequest() throws Exception {
managedBuffers.add(new TestManagedBuffer(20));
managedBuffers.add(new TestManagedBuffer(30));
managedBuffers.add(new TestManagedBuffer(40));
- long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
- streamManager.registerChannel(channel, streamId);
+ long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);
TransportClient reverseClient = mock(TransportClient.class);
ChunkFetchRequestHandler requestHandler = new ChunkFetchRequestHandler(reverseClient,
rpcHandler.getStreamManager(), 2L);
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
index ad640415a8e6d..a87f6c11a2bfd 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java
@@ -58,8 +58,10 @@ public void handleStreamRequest() throws Exception {
managedBuffers.add(new TestManagedBuffer(20));
managedBuffers.add(new TestManagedBuffer(30));
managedBuffers.add(new TestManagedBuffer(40));
- long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
- streamManager.registerChannel(channel, streamId);
+ long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);
+
+ assert streamManager.numStreamStates() == 1;
+
TransportClient reverseClient = mock(TransportClient.class);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
rpcHandler, 2L);
@@ -94,5 +96,8 @@ public void handleStreamRequest() throws Exception {
requestHandler.handle(request3);
verify(channel, times(1)).close();
assert responseAndPromisePairs.size() == 3;
+
+ streamManager.connectionTerminated(channel);
+ assert streamManager.numStreamStates() == 0;
}
}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
index c647525d8f1bd..4248762c32389 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
@@ -37,14 +37,15 @@ public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
buffers.add(buffer1);
buffers.add(buffer2);
- long streamId = manager.registerStream("appId", buffers.iterator());
Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
- manager.registerChannel(dummyChannel, streamId);
+ manager.registerStream("appId", buffers.iterator(), dummyChannel);
+ assert manager.numStreamStates() == 1;
manager.connectionTerminated(dummyChannel);
Mockito.verify(buffer1, Mockito.times(1)).release();
Mockito.verify(buffer2, Mockito.times(1)).release();
+ assert manager.numStreamStates() == 0;
}
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 788a845c57755..b25e48a164e6b 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -92,7 +92,7 @@ protected void handleMessage(
OpenBlocks msg = (OpenBlocks) msgObj;
checkAuth(client, msg.appId);
long streamId = streamManager.registerStream(client.getClientId(),
- new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds));
+ new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel());
if (logger.isTraceEnabled()) {
logger.trace("Registered streamId {} with {} buffers for client {} from host {}",
streamId,
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 4cc9a16e1449f..537c277cd26b5 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -103,7 +103,8 @@ public void testOpenShuffleBlocks() {
@SuppressWarnings("unchecked")
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
- verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
+ verify(streamManager, times(1)).registerStream(anyString(), stream.capture(),
+ any());
Iterator<ManagedBuffer> buffers = stream.getValue();
assertEquals(block0Marker, buffers.next());
assertEquals(block1Marker, buffers.next());
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 7076701421e2e..27f4f94ea55f8 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -59,7 +59,8 @@ class NettyBlockRpcServer(
val blocksNum = openBlocks.blockIds.length
val blocks = for (i <- (0 until blocksNum).view)
yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
- val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
+ val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
+ client.getChannel)
logTrace(s"Registered streamId $streamId with $blocksNum buffers")
responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org