You are viewing a plain-text version of this content. (The hyperlink to the canonical version was not preserved in this plain-text copy.)
Posted to commits@flink.apache.org by se...@apache.org on 2018/05/02 09:29:53 UTC

[15/16] flink git commit: [FLINK-9256] [network] Fix NPE in SingleInputGate#updateInputChannel() for non-credit based flow control

[FLINK-9256] [network] Fix NPE in SingleInputGate#updateInputChannel() for non-credit based flow control

This closes #5914


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/56e2b0b5
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/56e2b0b5
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/56e2b0b5

Branch: refs/heads/release-1.5
Commit: 56e2b0b5d600935eae590a985643a5879f224d04
Parents: f1fa517
Author: Nico Kruber <ni...@data-artisans.com>
Authored: Wed Apr 25 18:28:48 2018 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Mon Apr 30 23:25:38 2018 +0200

----------------------------------------------------------------------
 .../runtime/io/network/NetworkEnvironment.java  |   4 +
 .../partition/consumer/SingleInputGate.java     |  20 +-
 .../io/network/NetworkEnvironmentTest.java      |   5 +-
 .../PartitionRequestClientHandlerTest.java      |   3 +-
 .../partition/InputGateConcurrentTest.java      |   9 +-
 .../partition/InputGateFairnessTest.java        |  17 +-
 .../consumer/LocalInputChannelTest.java         |   6 +-
 .../consumer/RemoteInputChannelTest.java        |   3 +-
 .../partition/consumer/SingleInputGateTest.java | 318 +++++++++++++++----
 .../partition/consumer/TestSingleInputGate.java |   3 +-
 .../partition/consumer/UnionInputGateTest.java  |   6 +-
 11 files changed, 301 insertions(+), 93 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
index 0a9dc0f..f254756 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
@@ -157,6 +157,10 @@ public class NetworkEnvironment {
 		return partitionRequestMaxBackoff;
 	}
 
+	public boolean isCreditBased() {
+		return enableCreditBased;
+	}
+
 	public KvStateRegistry getKvStateRegistry() {
 		return kvStateRegistry;
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index b9091b2..06e80ff 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -157,9 +157,11 @@ public class SingleInputGate implements InputGate {
 	 */
 	private BufferPool bufferPool;
 
-	/** Global network buffer pool to request and recycle exclusive buffers. */
+	/** Global network buffer pool to request and recycle exclusive buffers (only for credit-based). */
 	private NetworkBufferPool networkBufferPool;
 
+	private final boolean isCreditBased;
+
 	private boolean hasReceivedAllEndOfPartitionEvents;
 
 	/** Flag indicating whether partitions have been requested. */
@@ -189,7 +191,8 @@ public class SingleInputGate implements InputGate {
 		int consumedSubpartitionIndex,
 		int numberOfInputChannels,
 		TaskActions taskActions,
-		TaskIOMetricGroup metrics) {
+		TaskIOMetricGroup metrics,
+		boolean isCreditBased) {
 
 		this.owningTaskName = checkNotNull(owningTaskName);
 		this.jobId = checkNotNull(jobId);
@@ -208,6 +211,7 @@ public class SingleInputGate implements InputGate {
 		this.enqueuedInputChannelsWithData = new BitSet(numberOfInputChannels);
 
 		this.taskActions = checkNotNull(taskActions);
+		this.isCreditBased = isCreditBased;
 	}
 
 	// ------------------------------------------------------------------------
@@ -288,6 +292,7 @@ public class SingleInputGate implements InputGate {
 	 * @param networkBuffersPerChannel The number of exclusive buffers for each channel
 	 */
 	public void assignExclusiveSegments(NetworkBufferPool networkBufferPool, int networkBuffersPerChannel) throws IOException {
+		checkState(this.isCreditBased, "Bug in input gate setup logic: exclusive buffers only exist with credit-based flow control.");
 		checkState(this.networkBufferPool == null, "Bug in input gate setup logic: global buffer pool has" +
 			"already been set for this input gate.");
 
@@ -347,8 +352,13 @@ public class SingleInputGate implements InputGate {
 				}
 				else if (partitionLocation.isRemote()) {
 					newChannel = unknownChannel.toRemoteInputChannel(partitionLocation.getConnectionId());
-					((RemoteInputChannel)newChannel).assignExclusiveSegments(
-						networkBufferPool.requestMemorySegments(networkBuffersPerChannel));
+
+					if (this.isCreditBased) {
+						checkState(this.networkBufferPool != null, "Bug in input gate setup logic: " +
+							"global buffer pool has not been set for this input gate.");
+						((RemoteInputChannel) newChannel).assignExclusiveSegments(
+							networkBufferPool.requestMemorySegments(networkBuffersPerChannel));
+					}
 				}
 				else {
 					throw new IllegalStateException("Tried to update unknown channel with unknown channel.");
@@ -661,7 +671,7 @@ public class SingleInputGate implements InputGate {
 
 		final SingleInputGate inputGate = new SingleInputGate(
 			owningTaskName, jobId, consumedResultId, consumedPartitionType, consumedSubpartitionIndex,
-			icdd.length, taskActions, metrics);
+			icdd.length, taskActions, metrics, networkEnvironment.isCreditBased());
 
 		// Create the input channels. There is one input channel for each consumed partition.
 		final InputChannel[] inputChannels = new InputChannel[icdd.length];

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
index 317a214..f790b5f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
@@ -329,7 +329,7 @@ public class NetworkEnvironmentTest {
 	 *
 	 * @return input gate with some fake settings
 	 */
-	private static SingleInputGate createSingleInputGate(
+	private SingleInputGate createSingleInputGate(
 			final ResultPartitionType partitionType, final int channels) {
 		return spy(new SingleInputGate(
 			"Test Task Name",
@@ -339,7 +339,8 @@ public class NetworkEnvironmentTest {
 			0,
 			channels,
 			mock(TaskActions.class),
-			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()));
+			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+			enableCreditBasedFlowControl));
 	}
 
 	private static void createRemoteInputChannel(

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
index 13f7510..842aed8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
@@ -221,7 +221,8 @@ public class PartitionRequestClientHandlerTest {
 			0,
 			1,
 			mock(TaskActions.class),
-			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+			true);
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
index 289a398..73f3cfb 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
@@ -66,7 +66,8 @@ public class InputGateConcurrentTest {
 				new IntermediateDataSetID(), ResultPartitionType.PIPELINED,
 				0, numChannels,
 				mock(TaskActions.class),
-				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+				true);
 
 		for (int i = 0; i < numChannels; i++) {
 			LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
@@ -102,7 +103,8 @@ public class InputGateConcurrentTest {
 				0,
 				numChannels,
 				mock(TaskActions.class),
-				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+				true);
 
 		for (int i = 0; i < numChannels; i++) {
 			RemoteInputChannel channel = new RemoteInputChannel(
@@ -151,7 +153,8 @@ public class InputGateConcurrentTest {
 				0,
 				numChannels,
 				mock(TaskActions.class),
-				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+				true);
 
 		for (int i = 0, local = 0; i < numChannels; i++) {
 			if (localOrRemote.get(i)) {

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
index 45df56f..82a27cc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
@@ -93,7 +93,8 @@ public class InputGateFairnessTest {
 				new IntermediateDataSetID(),
 				0, numChannels,
 				mock(TaskActions.class),
-				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+				true);
 
 		for (int i = 0; i < numChannels; i++) {
 			LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
@@ -146,7 +147,8 @@ public class InputGateFairnessTest {
 				new IntermediateDataSetID(),
 				0, numChannels,
 				mock(TaskActions.class),
-				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+				true);
 
 			for (int i = 0; i < numChannels; i++) {
 				LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
@@ -196,7 +198,8 @@ public class InputGateFairnessTest {
 				new IntermediateDataSetID(),
 				0, numChannels,
 				mock(TaskActions.class),
-				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+				true);
 
 		final ConnectionManager connManager = createDummyConnectionManager();
 
@@ -251,7 +254,8 @@ public class InputGateFairnessTest {
 				new IntermediateDataSetID(),
 				0, numChannels,
 				mock(TaskActions.class),
-				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+				UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+				true);
 
 		final ConnectionManager connManager = createDummyConnectionManager();
 
@@ -349,11 +353,12 @@ public class InputGateFairnessTest {
 				int consumedSubpartitionIndex,
 				int numberOfInputChannels,
 				TaskActions taskActions,
-				TaskIOMetricGroup metrics) {
+				TaskIOMetricGroup metrics,
+				boolean isCreditBased) {
 
 			super(owningTaskName, jobId, consumedResultId, ResultPartitionType.PIPELINED,
 				consumedSubpartitionIndex,
-					numberOfInputChannels, taskActions, metrics);
+					numberOfInputChannels, taskActions, metrics, isCreditBased);
 
 			try {
 				Field f = SingleInputGate.class.getDeclaredField("inputChannelsWithData");

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
index c78b7b9..1ecb67f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
@@ -293,7 +293,8 @@ public class LocalInputChannelTest {
 			0,
 			1,
 			mock(TaskActions.class),
-			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()
+			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+			true
 		);
 
 		ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
@@ -490,7 +491,8 @@ public class LocalInputChannelTest {
 					subpartitionIndex,
 					numberOfInputChannels,
 					mock(TaskActions.class),
-					UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+					UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+					true);
 
 			// Set buffer pool
 			inputGate.setBufferPool(bufferPool);

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
index 97a5688..802cb93 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
@@ -889,7 +889,8 @@ public class RemoteInputChannelTest {
 			0,
 			1,
 			mock(TaskActions.class),
-			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+			true);
 	}
 
 	private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate)

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 8c54c1f..c244668 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -19,13 +19,13 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionLocation;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.ConnectionManager;
 import org.apache.flink.runtime.io.network.LocalConnectionManager;
@@ -45,31 +45,51 @@ import org.apache.flink.runtime.io.network.util.TestTaskEvent;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.taskmanager.TaskActions;
 
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
-import static org.mockito.Matchers.anyListOf;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+/**
+ * Tests for {@link SingleInputGate}.
+ */
+@RunWith(Parameterized.class)
 public class SingleInputGateTest {
 
+	@Parameterized.Parameter
+	public boolean enableCreditBasedFlowControl;
+
+	@Parameterized.Parameters(name = "Credit-based = {0}")
+	public static List<Boolean> parameters() {
+		return Arrays.asList(Boolean.TRUE, Boolean.FALSE);
+	}
+
 	/**
 	 * Tests basic correctness of buffer-or-event interleaving and correct <code>null</code> return
 	 * value after receiving all end-of-partition events.
@@ -324,12 +344,7 @@ public class SingleInputGateTest {
 		int initialBackoff = 137;
 		int maxBackoff = 1001;
 
-		NetworkEnvironment netEnv = mock(NetworkEnvironment.class);
-		when(netEnv.getResultPartitionManager()).thenReturn(new ResultPartitionManager());
-		when(netEnv.getTaskEventDispatcher()).thenReturn(new TaskEventDispatcher());
-		when(netEnv.getPartitionRequestInitialBackoff()).thenReturn(initialBackoff);
-		when(netEnv.getPartitionRequestMaxBackoff()).thenReturn(maxBackoff);
-		when(netEnv.getConnectionManager()).thenReturn(new LocalConnectionManager());
+		final NetworkEnvironment netEnv = createNetworkEnvironment(2, 8, initialBackoff, maxBackoff);
 
 		SingleInputGate gate = SingleInputGate.create(
 			"TestTask",
@@ -340,37 +355,43 @@ public class SingleInputGateTest {
 			mock(TaskActions.class),
 			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
 
-		assertEquals(gateDesc.getConsumedPartitionType(), gate.getConsumedPartitionType());
+		try {
+			assertEquals(gateDesc.getConsumedPartitionType(), gate.getConsumedPartitionType());
 
-		Map<IntermediateResultPartitionID, InputChannel> channelMap = gate.getInputChannels();
+			Map<IntermediateResultPartitionID, InputChannel> channelMap = gate.getInputChannels();
 
-		assertEquals(3, channelMap.size());
-		InputChannel localChannel = channelMap.get(partitionIds[0].getPartitionId());
-		assertEquals(LocalInputChannel.class, localChannel.getClass());
+			assertEquals(3, channelMap.size());
+			InputChannel localChannel = channelMap.get(partitionIds[0].getPartitionId());
+			assertEquals(LocalInputChannel.class, localChannel.getClass());
 
-		InputChannel remoteChannel = channelMap.get(partitionIds[1].getPartitionId());
-		assertEquals(RemoteInputChannel.class, remoteChannel.getClass());
+			InputChannel remoteChannel = channelMap.get(partitionIds[1].getPartitionId());
+			assertEquals(RemoteInputChannel.class, remoteChannel.getClass());
 
-		InputChannel unknownChannel = channelMap.get(partitionIds[2].getPartitionId());
-		assertEquals(UnknownInputChannel.class, unknownChannel.getClass());
+			InputChannel unknownChannel = channelMap.get(partitionIds[2].getPartitionId());
+			assertEquals(UnknownInputChannel.class, unknownChannel.getClass());
 
-		InputChannel[] channels = new InputChannel[]{localChannel, remoteChannel, unknownChannel};
-		for (InputChannel ch : channels) {
-			assertEquals(0, ch.getCurrentBackoff());
+			InputChannel[] channels =
+				new InputChannel[] {localChannel, remoteChannel, unknownChannel};
+			for (InputChannel ch : channels) {
+				assertEquals(0, ch.getCurrentBackoff());
 
-			assertTrue(ch.increaseBackoff());
-			assertEquals(initialBackoff, ch.getCurrentBackoff());
+				assertTrue(ch.increaseBackoff());
+				assertEquals(initialBackoff, ch.getCurrentBackoff());
 
-			assertTrue(ch.increaseBackoff());
-			assertEquals(initialBackoff * 2, ch.getCurrentBackoff());
+				assertTrue(ch.increaseBackoff());
+				assertEquals(initialBackoff * 2, ch.getCurrentBackoff());
 
-			assertTrue(ch.increaseBackoff());
-			assertEquals(initialBackoff * 2 * 2, ch.getCurrentBackoff());
+				assertTrue(ch.increaseBackoff());
+				assertEquals(initialBackoff * 2 * 2, ch.getCurrentBackoff());
 
-			assertTrue(ch.increaseBackoff());
-			assertEquals(maxBackoff, ch.getCurrentBackoff());
+				assertTrue(ch.increaseBackoff());
+				assertEquals(maxBackoff, ch.getCurrentBackoff());
 
-			assertFalse(ch.increaseBackoff());
+				assertFalse(ch.increaseBackoff());
+			}
+		} finally {
+			gate.releaseAllResources();
+			netEnv.shutdown();
 		}
 	}
 
@@ -379,26 +400,39 @@ public class SingleInputGateTest {
 	 */
 	@Test
 	public void testRequestBuffersWithRemoteInputChannel() throws Exception {
-		final SingleInputGate inputGate = new SingleInputGate(
-			"t1",
-			new JobID(),
-			new IntermediateDataSetID(),
-			ResultPartitionType.PIPELINED_BOUNDED,
-			0,
-			1,
-			mock(TaskActions.class),
-			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
-
-		RemoteInputChannel remote = mock(RemoteInputChannel.class);
-		inputGate.setInputChannel(new IntermediateResultPartitionID(), remote);
-
-		final int buffersPerChannel = 2;
-		NetworkBufferPool network = mock(NetworkBufferPool.class);
-		// Trigger requests of segments from global pool and assign buffers to remote input channel
-		inputGate.assignExclusiveSegments(network, buffersPerChannel);
-
-		verify(network, times(1)).requestMemorySegments(buffersPerChannel);
-		verify(remote, times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class));
+		final SingleInputGate inputGate = createInputGate(1, ResultPartitionType.PIPELINED_BOUNDED);
+		int buffersPerChannel = 2;
+		int extraNetworkBuffersPerGate = 8;
+		final NetworkEnvironment network = createNetworkEnvironment(buffersPerChannel,
+			extraNetworkBuffersPerGate, 0, 0);
+
+		try {
+			final ResultPartitionID resultPartitionId = new ResultPartitionID();
+			final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0);
+			addRemoteInputChannel(network, inputGate, connectionId, resultPartitionId, 0);
+
+			network.setupInputGate(inputGate);
+
+			NetworkBufferPool bufferPool = network.getNetworkBufferPool();
+			if (enableCreditBasedFlowControl) {
+				verify(bufferPool,
+					times(1)).requestMemorySegments(buffersPerChannel);
+				RemoteInputChannel remote = (RemoteInputChannel) inputGate.getInputChannels()
+					.get(resultPartitionId.getPartitionId());
+				// only the exclusive buffers should be assigned/available now
+				assertEquals(buffersPerChannel, remote.getNumberOfAvailableBuffers());
+
+				assertEquals(bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel,
+					bufferPool.getNumberOfAvailableMemorySegments());
+				// note: exclusive buffers are not handed out into LocalBufferPool and are thus not counted
+				assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers());
+			} else {
+				assertEquals(buffersPerChannel + extraNetworkBuffersPerGate, bufferPool.countBuffers());
+			}
+		} finally {
+			inputGate.releaseAllResources();
+			network.shutdown();
+		}
 	}
 
 	/**
@@ -407,51 +441,195 @@ public class SingleInputGateTest {
 	 */
 	@Test
 	public void testRequestBuffersWithUnknownInputChannel() throws Exception {
-		final SingleInputGate inputGate = createInputGate(1);
+		final SingleInputGate inputGate = createInputGate(1, ResultPartitionType.PIPELINED_BOUNDED);
+		int buffersPerChannel = 2;
+		int extraNetworkBuffersPerGate = 8;
+		final NetworkEnvironment network = createNetworkEnvironment(buffersPerChannel, extraNetworkBuffersPerGate, 0, 0);
 
-		UnknownInputChannel unknown = mock(UnknownInputChannel.class);
-		final ResultPartitionID resultPartitionId = new ResultPartitionID();
-		inputGate.setInputChannel(resultPartitionId.getPartitionId(), unknown);
+		try {
+			final ResultPartitionID resultPartitionId = new ResultPartitionID();
+			addUnknownInputChannel(network, inputGate, resultPartitionId, 0);
 
-		RemoteInputChannel remote = mock(RemoteInputChannel.class);
-		final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0);
-		when(unknown.toRemoteInputChannel(connectionId)).thenReturn(remote);
+			network.setupInputGate(inputGate);
+			NetworkBufferPool bufferPool = network.getNetworkBufferPool();
 
-		final int buffersPerChannel = 2;
-		NetworkBufferPool network = mock(NetworkBufferPool.class);
-		inputGate.assignExclusiveSegments(network, buffersPerChannel);
+			if (enableCreditBasedFlowControl) {
+				verify(bufferPool, times(0)).requestMemorySegments(buffersPerChannel);
 
-		// Trigger updates to remote input channel from unknown input channel
-		inputGate.updateInputChannel(new InputChannelDeploymentDescriptor(
-			resultPartitionId,
-			ResultPartitionLocation.createRemote(connectionId)));
+				assertEquals(bufferPool.getTotalNumberOfMemorySegments(),
+					bufferPool.getNumberOfAvailableMemorySegments());
+				// note: exclusive buffers are not handed out into LocalBufferPool and are thus not counted
+				assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers());
+			} else {
+				assertEquals(buffersPerChannel + extraNetworkBuffersPerGate, bufferPool.countBuffers());
+			}
+
+			// Trigger updates to remote input channel from unknown input channel
+			final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0);
+			inputGate.updateInputChannel(new InputChannelDeploymentDescriptor(
+				resultPartitionId,
+				ResultPartitionLocation.createRemote(connectionId)));
+
+			if (enableCreditBasedFlowControl) {
+				verify(bufferPool,
+					times(1)).requestMemorySegments(buffersPerChannel);
+				RemoteInputChannel remote = (RemoteInputChannel) inputGate.getInputChannels()
+					.get(resultPartitionId.getPartitionId());
+				// only the exclusive buffers should be assigned/available now
+				assertEquals(buffersPerChannel, remote.getNumberOfAvailableBuffers());
+
+				assertEquals(bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel,
+					bufferPool.getNumberOfAvailableMemorySegments());
+				// note: exclusive buffers are not handed out into LocalBufferPool and are thus not counted
+				assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers());
+			} else {
+				assertEquals(buffersPerChannel + extraNetworkBuffersPerGate, bufferPool.countBuffers());
+			}
+		} finally {
+			inputGate.releaseAllResources();
+			network.shutdown();
+		}
+	}
 
-		verify(network, times(1)).requestMemorySegments(buffersPerChannel);
-		verify(remote, times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class));
+	/**
+	 * Tests that input gate can successfully convert unknown input channels into local and remote
+	 * channels.
+	 */
+	@Test
+	public void testUpdateUnknownInputChannel() throws Exception {
+		final SingleInputGate inputGate = createInputGate(2);
+		int buffersPerChannel = 2;
+		final NetworkEnvironment network = createNetworkEnvironment(buffersPerChannel, 8, 0, 0);
+
+		try {
+			final ResultPartitionID localResultPartitionId = new ResultPartitionID();
+			addUnknownInputChannel(network, inputGate, localResultPartitionId, 0);
+
+			final ResultPartitionID remoteResultPartitionId = new ResultPartitionID();
+			addUnknownInputChannel(network, inputGate, remoteResultPartitionId, 1);
+
+			network.setupInputGate(inputGate);
+
+			assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()),
+				is(instanceOf((UnknownInputChannel.class))));
+			assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()),
+				is(instanceOf((UnknownInputChannel.class))));
+
+			// Trigger updates to remote input channel from unknown input channel
+			final ConnectionID remoteConnectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0);
+			inputGate.updateInputChannel(new InputChannelDeploymentDescriptor(
+				remoteResultPartitionId,
+				ResultPartitionLocation.createRemote(remoteConnectionId)));
+
+			assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()),
+				is(instanceOf((RemoteInputChannel.class))));
+			assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()),
+				is(instanceOf((UnknownInputChannel.class))));
+
+			// Trigger updates to local input channel from unknown input channel
+			inputGate.updateInputChannel(new InputChannelDeploymentDescriptor(
+				localResultPartitionId,
+				ResultPartitionLocation.createLocal()));
+
+			assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()),
+				is(instanceOf((RemoteInputChannel.class))));
+			assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()),
+				is(instanceOf((LocalInputChannel.class))));
+		} finally {
+			inputGate.releaseAllResources();
+			network.shutdown();
+		}
 	}
 
 	// ---------------------------------------------------------------------------------------------
 
-	private static SingleInputGate createInputGate() {
+	private NetworkEnvironment createNetworkEnvironment(
+			int buffersPerChannel,
+			int extraNetworkBuffersPerGate,
+			int initialBackoff,
+			int maxBackoff) {
+		return new NetworkEnvironment(
+			spy(new NetworkBufferPool(100, 32)),
+			new LocalConnectionManager(),
+			new ResultPartitionManager(),
+			new TaskEventDispatcher(),
+			new KvStateRegistry(),
+			null,
+			null,
+			IOManager.IOMode.SYNC,
+			initialBackoff,
+			maxBackoff,
+			buffersPerChannel,
+			extraNetworkBuffersPerGate,
+			enableCreditBasedFlowControl);
+	}
+
+	private SingleInputGate createInputGate() {
 		return createInputGate(2);
 	}
 
-	private static SingleInputGate createInputGate(int numberOfInputChannels) {
+	private SingleInputGate createInputGate(int numberOfInputChannels) {
+		return createInputGate(numberOfInputChannels, ResultPartitionType.PIPELINED);
+	}
+
+	private SingleInputGate createInputGate(
+			int numberOfInputChannels, ResultPartitionType partitionType) {
 		SingleInputGate inputGate = new SingleInputGate(
 			"Test Task Name",
 			new JobID(),
 			new IntermediateDataSetID(),
-			ResultPartitionType.PIPELINED,
+			partitionType,
 			0,
 			numberOfInputChannels,
 			mock(TaskActions.class),
-			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+			enableCreditBasedFlowControl);
 
-		assertEquals(ResultPartitionType.PIPELINED, inputGate.getConsumedPartitionType());
+		assertEquals(partitionType, inputGate.getConsumedPartitionType());
 
 		return inputGate;
 	}
 
+	private void addUnknownInputChannel(
+			NetworkEnvironment network,
+			SingleInputGate inputGate,
+			ResultPartitionID partitionId,
+			int channelIndex) {
+		UnknownInputChannel unknown =
+			createUnknownInputChannel(network, inputGate, partitionId, channelIndex);
+		inputGate.setInputChannel(partitionId.getPartitionId(), unknown);
+	}
+
+	private UnknownInputChannel createUnknownInputChannel(
+			NetworkEnvironment network,
+			SingleInputGate inputGate,
+			ResultPartitionID partitionId,
+			int channelIndex) {
+		return new UnknownInputChannel(
+			inputGate,
+			channelIndex,
+			partitionId,
+			network.getResultPartitionManager(),
+			network.getTaskEventDispatcher(),
+			network.getConnectionManager(),
+			network.getPartitionRequestInitialBackoff(),
+			network.getPartitionRequestMaxBackoff(),
+			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()
+		);
+	}
+
+	private void addRemoteInputChannel(
+			NetworkEnvironment network,
+			SingleInputGate inputGate,
+			ConnectionID connectionId,
+			ResultPartitionID partitionId,
+			int channelIndex) {
+		RemoteInputChannel remote =
+			createUnknownInputChannel(network, inputGate, partitionId, channelIndex)
+				.toRemoteInputChannel(connectionId);
+		inputGate.setInputChannel(partitionId.getPartitionId(), remote);
+	}
+
 	static void verifyBufferOrEvent(
 			InputGate inputGate,
 			boolean expectedIsBuffer,

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
index 0ae6e74..33dc1ca 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
@@ -60,7 +60,8 @@ public class TestSingleInputGate {
 			0,
 			numberOfInputChannels,
 			mock(TaskActions.class),
-			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+			true);
 
 		this.inputGate = spy(realGate);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
index 912cd5b..081d97d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
@@ -50,13 +50,15 @@ public class UnionInputGateTest {
 			new IntermediateDataSetID(), ResultPartitionType.PIPELINED,
 			0, 3,
 			mock(TaskActions.class),
-			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+			true);
 		final SingleInputGate ig2 = new SingleInputGate(
 			testTaskName, new JobID(),
 			new IntermediateDataSetID(), ResultPartitionType.PIPELINED,
 			0, 5,
 			mock(TaskActions.class),
-			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+			UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+			true);
 
 		final UnionInputGate union = new UnionInputGate(new SingleInputGate[]{ig1, ig2});