You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original archive page and is not preserved in this text.
Posted to commits@flink.apache.org by pn...@apache.org on 2019/05/22 10:26:33 UTC

[flink] 08/10: [hotfix][network] Move MemorySegmentProvider from SingleInputGate to RemoteInputChannel

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 8589c82b4cbc8a73417974107186f9ccc3a629d3
Author: Andrey Zagrebin <az...@gmail.com>
AuthorDate: Fri May 10 20:25:59 2019 +0200

    [hotfix][network] Move MemorySegmentProvider from SingleInputGate to RemoteInputChannel
---
 .../runtime/io/network/NetworkEnvironment.java     |  6 +-
 .../partition/consumer/RemoteInputChannel.java     | 21 +++++--
 .../partition/consumer/SingleInputGate.java        | 33 ++--------
 .../partition/consumer/SingleInputGateFactory.java | 14 ++++-
 .../partition/consumer/UnknownInputChannel.java    | 13 +++-
 .../runtime/io/network/NetworkEnvironmentTest.java | 41 +++++++------
 ...editBasedPartitionRequestClientHandlerTest.java | 21 ++++---
 .../netty/PartitionRequestClientHandlerTest.java   | 54 ++--------------
 .../network/netty/PartitionRequestClientTest.java  | 17 ++++--
 .../network/partition/InputChannelTestUtils.java   | 71 ++++++++++++++++++++++
 .../partition/consumer/InputChannelBuilder.java    | 14 ++++-
 .../partition/consumer/RemoteInputChannelTest.java | 63 +++++++++++++------
 .../partition/consumer/SingleInputGateTest.java    |  3 +-
 .../StreamNetworkBenchmarkEnvironment.java         |  3 +-
 14 files changed, 225 insertions(+), 149 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
index 97016ef..459669c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
@@ -138,8 +138,8 @@ public class NetworkEnvironment {
 		ResultPartitionFactory resultPartitionFactory =
 			new ResultPartitionFactory(resultPartitionManager, ioManager);
 
-		SingleInputGateFactory singleInputGateFactory =
-			new SingleInputGateFactory(config, connectionManager, resultPartitionManager, taskEventPublisher);
+		SingleInputGateFactory singleInputGateFactory = new SingleInputGateFactory(
+			config, connectionManager, resultPartitionManager, taskEventPublisher, networkBufferPool);
 
 		return new NetworkEnvironment(
 			config,
@@ -247,7 +247,7 @@ public class NetworkEnvironment {
 					config.floatingNetworkBuffersPerGate() : Integer.MAX_VALUE;
 
 				// assign exclusive buffers to input channels directly and use the rest for floating buffers
-				gate.assignExclusiveSegments(networkBufferPool);
+				gate.assignExclusiveSegments();
 				bufferPool = networkBufferPool.createBufferPool(0, maxNumberOfMemorySegments);
 			} else {
 				maxNumberOfMemorySegments = gate.getConsumedPartitionType().isBounded() ?
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
index 98182c6..397c4fe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentProvider;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.ConnectionManager;
@@ -34,6 +35,7 @@ import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.util.ExceptionUtils;
 
+import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 import javax.annotation.concurrent.GuardedBy;
 
@@ -103,6 +105,10 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
 	@GuardedBy("bufferQueue")
 	private boolean isWaitingForFloatingBuffers;
 
+	/** Global memory segment provider to request and recycle exclusive buffers (only for credit-based). */
+	@Nonnull
+	private final MemorySegmentProvider memorySegmentProvider;
+
 	public RemoteInputChannel(
 		SingleInputGate inputGate,
 		int channelIndex,
@@ -111,23 +117,26 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
 		ConnectionManager connectionManager,
 		int initialBackOff,
 		int maxBackoff,
-		InputChannelMetrics metrics) {
+		InputChannelMetrics metrics,
+		@Nonnull MemorySegmentProvider memorySegmentProvider) {
 
-		super(inputGate, channelIndex, partitionId, initialBackOff, maxBackoff, metrics.getNumBytesInRemoteCounter(), metrics.getNumBuffersInRemoteCounter());
+		super(inputGate, channelIndex, partitionId, initialBackOff, maxBackoff,
+			metrics.getNumBytesInRemoteCounter(), metrics.getNumBuffersInRemoteCounter());
 
 		this.connectionId = checkNotNull(connectionId);
 		this.connectionManager = checkNotNull(connectionManager);
+		this.memorySegmentProvider = memorySegmentProvider;
 	}
 
 	/**
 	 * Assigns exclusive buffers to this input channel, and this method should be called only once
 	 * after this input channel is created.
 	 */
-	void assignExclusiveSegments(Collection<MemorySegment> segments) {
+	void assignExclusiveSegments() throws IOException {
 		checkState(this.initialCredit == 0, "Bug in input channel setup logic: exclusive buffers have " +
 			"already been set for this input channel.");
 
-		checkNotNull(segments);
+		Collection<MemorySegment> segments = checkNotNull(memorySegmentProvider.requestMemorySegments());
 		checkArgument(!segments.isEmpty(), "The number of exclusive buffers per channel should be larger than 0.");
 
 		this.initialCredit = segments.size();
@@ -247,7 +256,7 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
 			}
 
 			if (exclusiveRecyclingSegments.size() > 0) {
-				inputGate.returnExclusiveSegments(exclusiveRecyclingSegments);
+				memorySegmentProvider.recycleMemorySegments(exclusiveRecyclingSegments);
 			}
 
 			// The released flag has to be set before closing the connection to ensure that
@@ -297,7 +306,7 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
 			// after releaseAllResources() released all buffers (see below for details).
 			if (isReleased.get()) {
 				try {
-					inputGate.returnExclusiveSegments(Collections.singletonList(segment));
+					memorySegmentProvider.recycleMemorySegments(Collections.singletonList(segment));
 					return;
 				} catch (Throwable t) {
 					ExceptionUtils.rethrow(t);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index 6b8cec1..087d912 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -18,9 +18,8 @@
 
 package org.apache.flink.runtime.io.network.partition.consumer;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.core.memory.MemorySegment;
-import org.apache.flink.core.memory.MemorySegmentProvider;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionLocation;
@@ -155,9 +154,6 @@ public class SingleInputGate extends InputGate {
 	 */
 	private BufferPool bufferPool;
 
-	/** Global memory segment provider to request and recycle exclusive buffers (only for credit-based). */
-	private MemorySegmentProvider memorySegmentProvider;
-
 	private final boolean isCreditBased;
 
 	private boolean hasReceivedAllEndOfPartitionEvents;
@@ -289,35 +285,19 @@ public class SingleInputGate extends InputGate {
 
 	/**
 	 * Assign the exclusive buffers to all remote input channels directly for credit-based mode.
-	 *
-	 * @param memorySegmentProvider The global memory segment provider to request and recycle exclusive buffers
 	 */
-	public void assignExclusiveSegments(MemorySegmentProvider memorySegmentProvider) throws IOException {
+	@VisibleForTesting
+	public void assignExclusiveSegments() throws IOException {
 		checkState(this.isCreditBased, "Bug in input gate setup logic: exclusive buffers only exist with credit-based flow control.");
-		checkState(this.memorySegmentProvider == null,
-			"Bug in input gate setup logic: global memory segment provider has already been set for this input gate.");
-
-		this.memorySegmentProvider = checkNotNull(memorySegmentProvider);
-
 		synchronized (requestLock) {
 			for (InputChannel inputChannel : inputChannels.values()) {
 				if (inputChannel instanceof RemoteInputChannel) {
-					((RemoteInputChannel) inputChannel).assignExclusiveSegments(
-						memorySegmentProvider.requestMemorySegments());
+					((RemoteInputChannel) inputChannel).assignExclusiveSegments();
 				}
 			}
 		}
 	}
 
-	/**
-	 * The exclusive segments are recycled to network buffer pool directly when input channel is released.
-	 *
-	 * @param segments The exclusive segments need to be recycled
-	 */
-	public void returnExclusiveSegments(List<MemorySegment> segments) throws IOException {
-		memorySegmentProvider.recycleMemorySegments(segments);
-	}
-
 	public void setInputChannel(IntermediateResultPartitionID partitionId, InputChannel inputChannel) {
 		synchronized (requestLock) {
 			if (inputChannels.put(checkNotNull(partitionId), checkNotNull(inputChannel)) == null
@@ -354,10 +334,7 @@ public class SingleInputGate extends InputGate {
 					newChannel = unknownChannel.toRemoteInputChannel(partitionLocation.getConnectionId());
 
 					if (this.isCreditBased) {
-						checkState(this.memorySegmentProvider != null, "Bug in input gate setup logic: " +
-							"global buffer pool has not been set for this input gate.");
-						((RemoteInputChannel) newChannel).assignExclusiveSegments(
-							memorySegmentProvider.requestMemorySegments());
+						((RemoteInputChannel) newChannel).assignExclusiveSegments();
 					}
 				}
 				else {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java
index 60f511f..e3ed90e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.core.memory.MemorySegmentProvider;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
@@ -62,17 +63,22 @@ public class SingleInputGateFactory {
 	@Nonnull
 	private final TaskEventPublisher taskEventPublisher;
 
+	@Nonnull
+	private final MemorySegmentProvider memorySegmentProvider;
+
 	public SingleInputGateFactory(
 			@Nonnull NetworkEnvironmentConfiguration networkConfig,
 			@Nonnull ConnectionManager connectionManager,
 			@Nonnull ResultPartitionManager partitionManager,
-			@Nonnull TaskEventPublisher taskEventPublisher) {
+			@Nonnull TaskEventPublisher taskEventPublisher,
+			@Nonnull MemorySegmentProvider memorySegmentProvider) {
 		this.isCreditBased = networkConfig.isCreditBased();
 		this.partitionRequestInitialBackoff = networkConfig.partitionRequestInitialBackoff();
 		this.partitionRequestMaxBackoff = networkConfig.partitionRequestMaxBackoff();
 		this.connectionManager = connectionManager;
 		this.partitionManager = partitionManager;
 		this.taskEventPublisher = taskEventPublisher;
+		this.memorySegmentProvider = memorySegmentProvider;
 	}
 
 	/**
@@ -124,7 +130,8 @@ public class SingleInputGateFactory {
 					connectionManager,
 					partitionRequestInitialBackoff,
 					partitionRequestMaxBackoff,
-					metrics);
+					metrics,
+					memorySegmentProvider);
 
 				numRemoteChannels++;
 			}
@@ -135,7 +142,8 @@ public class SingleInputGateFactory {
 					connectionManager,
 					partitionRequestInitialBackoff,
 					partitionRequestMaxBackoff,
-					metrics);
+					metrics,
+					memorySegmentProvider);
 
 				numUnknownChannels++;
 			}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java
index 0ed01e2..826595f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.partition.consumer;
 
+import org.apache.flink.core.memory.MemorySegmentProvider;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.ConnectionManager;
@@ -26,6 +27,8 @@ import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
 
+import javax.annotation.Nonnull;
+
 import java.io.IOException;
 import java.util.Optional;
 
@@ -50,6 +53,9 @@ class UnknownInputChannel extends InputChannel {
 
 	private final InputChannelMetrics metrics;
 
+	@Nonnull
+	private final MemorySegmentProvider memorySegmentProvider;
+
 	public UnknownInputChannel(
 			SingleInputGate gate,
 			int channelIndex,
@@ -59,7 +65,8 @@ class UnknownInputChannel extends InputChannel {
 			ConnectionManager connectionManager,
 			int initialBackoff,
 			int maxBackoff,
-			InputChannelMetrics metrics) {
+			InputChannelMetrics metrics,
+			@Nonnull MemorySegmentProvider memorySegmentProvider) {
 
 		super(gate, channelIndex, partitionId, initialBackoff, maxBackoff, null, null);
 
@@ -69,6 +76,7 @@ class UnknownInputChannel extends InputChannel {
 		this.metrics = checkNotNull(metrics);
 		this.initialBackoff = initialBackoff;
 		this.maxBackoff = maxBackoff;
+		this.memorySegmentProvider = memorySegmentProvider;
 	}
 
 	@Override
@@ -118,7 +126,8 @@ class UnknownInputChannel extends InputChannel {
 	// ------------------------------------------------------------------------
 
 	public RemoteInputChannel toRemoteInputChannel(ConnectionID producerAddress) {
-		return new RemoteInputChannel(inputGate, channelIndex, partitionId, checkNotNull(producerAddress), connectionManager, initialBackoff, maxBackoff, metrics);
+		return new RemoteInputChannel(inputGate, channelIndex, partitionId, checkNotNull(producerAddress),
+			connectionManager, initialBackoff, maxBackoff, metrics, memorySegmentProvider);
 	}
 
 	public LocalInputChannel toLocalInputChannel() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
index c06bfa5..42c5a96 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network;
 
 import org.apache.flink.configuration.TaskManagerOptions;
+import org.apache.flink.core.memory.MemorySegmentProvider;
 import org.apache.flink.runtime.io.network.partition.ResultPartition;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionBuilder;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
@@ -117,10 +118,10 @@ public class NetworkEnvironmentTest {
 		assertEquals(enableCreditBasedFlowControl ? 8 : 8 * 2 + 8, ig4.getBufferPool().getMaxNumberOfMemorySegments());
 
 		int invokations = enableCreditBasedFlowControl ? 1 : 0;
-		verify(ig1, times(invokations)).assignExclusiveSegments(network.getNetworkBufferPool());
-		verify(ig2, times(invokations)).assignExclusiveSegments(network.getNetworkBufferPool());
-		verify(ig3, times(invokations)).assignExclusiveSegments(network.getNetworkBufferPool());
-		verify(ig4, times(invokations)).assignExclusiveSegments(network.getNetworkBufferPool());
+		verify(ig1, times(invokations)).assignExclusiveSegments();
+		verify(ig2, times(invokations)).assignExclusiveSegments();
+		verify(ig3, times(invokations)).assignExclusiveSegments();
+		verify(ig4, times(invokations)).assignExclusiveSegments();
 
 		for (ResultPartition rp : resultPartitions) {
 			rp.release();
@@ -197,19 +198,19 @@ public class NetworkEnvironmentTest {
 		// set up remote input channels for the exclusive buffers of the credit-based flow control
 		// (note that this does not obey the partition types which is ok for the scope of the test)
 		if (enableCreditBasedFlowControl) {
-			createRemoteInputChannel(ig4, 0, rp1, connManager);
-			createRemoteInputChannel(ig4, 0, rp2, connManager);
-			createRemoteInputChannel(ig4, 0, rp3, connManager);
-			createRemoteInputChannel(ig4, 0, rp4, connManager);
+			createRemoteInputChannel(ig4, 0, rp1, connManager, network.getNetworkBufferPool());
+			createRemoteInputChannel(ig4, 0, rp2, connManager, network.getNetworkBufferPool());
+			createRemoteInputChannel(ig4, 0, rp3, connManager, network.getNetworkBufferPool());
+			createRemoteInputChannel(ig4, 0, rp4, connManager, network.getNetworkBufferPool());
 
-			createRemoteInputChannel(ig1, 1, rp1, connManager);
-			createRemoteInputChannel(ig1, 1, rp4, connManager);
+			createRemoteInputChannel(ig1, 1, rp1, connManager, network.getNetworkBufferPool());
+			createRemoteInputChannel(ig1, 1, rp4, connManager, network.getNetworkBufferPool());
 
-			createRemoteInputChannel(ig2, 1, rp2, connManager);
-			createRemoteInputChannel(ig2, 2, rp4, connManager);
+			createRemoteInputChannel(ig2, 1, rp2, connManager, network.getNetworkBufferPool());
+			createRemoteInputChannel(ig2, 2, rp4, connManager, network.getNetworkBufferPool());
 
-			createRemoteInputChannel(ig3, 1, rp3, connManager);
-			createRemoteInputChannel(ig3, 3, rp4, connManager);
+			createRemoteInputChannel(ig3, 1, rp3, connManager, network.getNetworkBufferPool());
+			createRemoteInputChannel(ig3, 3, rp4, connManager, network.getNetworkBufferPool());
 		}
 
 		// overall task to register
@@ -243,10 +244,10 @@ public class NetworkEnvironmentTest {
 		assertEquals(enableCreditBasedFlowControl ? 8 : 4 * 2 + 8, ig4.getBufferPool().getMaxNumberOfMemorySegments());
 
 		int invokations = enableCreditBasedFlowControl ? 1 : 0;
-		verify(ig1, times(invokations)).assignExclusiveSegments(network.getNetworkBufferPool());
-		verify(ig2, times(invokations)).assignExclusiveSegments(network.getNetworkBufferPool());
-		verify(ig3, times(invokations)).assignExclusiveSegments(network.getNetworkBufferPool());
-		verify(ig4, times(invokations)).assignExclusiveSegments(network.getNetworkBufferPool());
+		verify(ig1, times(invokations)).assignExclusiveSegments();
+		verify(ig2, times(invokations)).assignExclusiveSegments();
+		verify(ig3, times(invokations)).assignExclusiveSegments();
+		verify(ig4, times(invokations)).assignExclusiveSegments();
 
 		for (ResultPartition rp : resultPartitions) {
 			rp.release();
@@ -301,11 +302,13 @@ public class NetworkEnvironmentTest {
 			SingleInputGate inputGate,
 			int channelIndex,
 			ResultPartition resultPartition,
-			ConnectionManager connManager) {
+			ConnectionManager connManager,
+			MemorySegmentProvider memorySegmentProvider) {
 		InputChannelBuilder.newBuilder()
 			.setChannelIndex(channelIndex)
 			.setPartitionId(resultPartition.getPartitionId())
 			.setConnectionManager(connManager)
+			.setMemorySegmentProvider(memorySegmentProvider)
 			.buildRemoteAndSetToGate(inputGate);
 	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
index f2808fc..92ae98d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
@@ -31,6 +31,7 @@ import org.apache.flink.runtime.io.network.netty.NettyMessage.ErrorResponse;
 import org.apache.flink.runtime.io.network.netty.NettyMessage.PartitionRequest;
 import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
 import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
@@ -44,8 +45,8 @@ import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
 import org.junit.Test;
 
 import static org.apache.flink.runtime.io.network.netty.PartitionRequestClientHandlerTest.createBufferResponse;
-import static org.apache.flink.runtime.io.network.netty.PartitionRequestClientHandlerTest.createRemoteInputChannel;
 import static org.apache.flink.runtime.io.network.netty.PartitionRequestQueueTest.blockChannel;
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createRemoteInputChannel;
 import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.junit.Assert.assertEquals;
@@ -138,11 +139,13 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 	public void testReceiveBuffer() throws Exception {
 		final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32, 2);
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+		final RemoteInputChannel inputChannel = InputChannelBuilder.newBuilder()
+			.setMemorySegmentProvider(networkBufferPool)
+			.buildRemoteAndSetToGate(inputGate);
 		try {
 			final BufferPool bufferPool = networkBufferPool.createBufferPool(8, 8);
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 
 			final CreditBasedPartitionRequestClientHandler handler = new CreditBasedPartitionRequestClientHandler();
 			handler.addInputChannel(inputChannel);
@@ -170,7 +173,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 	@Test
 	public void testThrowExceptionForNoAvailableBuffer() throws Exception {
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate));
+		final RemoteInputChannel inputChannel = spy(InputChannelBuilder.newBuilder().buildRemoteAndSetToGate(inputGate));
 
 		final CreditBasedPartitionRequestClientHandler handler = new CreditBasedPartitionRequestClientHandler();
 		handler.addInputChannel(inputChannel);
@@ -246,12 +249,12 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 
 		final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32, 2);
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel1 = createRemoteInputChannel(inputGate, client);
-		final RemoteInputChannel inputChannel2 = createRemoteInputChannel(inputGate, client);
+		final RemoteInputChannel inputChannel1 = createRemoteInputChannel(inputGate, client, networkBufferPool);
+		final RemoteInputChannel inputChannel2 = createRemoteInputChannel(inputGate, client, networkBufferPool);
 		try {
 			final BufferPool bufferPool = networkBufferPool.createBufferPool(6, 6);
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 
 			inputChannel1.requestSubpartition(0);
 			inputChannel2.requestSubpartition(0);
@@ -346,11 +349,11 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 
 		final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32, 2);
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, client);
+		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, client, networkBufferPool);
 		try {
 			final BufferPool bufferPool = networkBufferPool.createBufferPool(6, 6);
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 
 			inputChannel.requestSubpartition(0);
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
index 9017bf8..9cd7347 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.runtime.io.network.netty;
 
-import org.apache.flink.runtime.io.network.ConnectionID;
-import org.apache.flink.runtime.io.network.ConnectionManager;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferListener;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
@@ -129,11 +127,13 @@ public class PartitionRequestClientHandlerTest {
 	public void testReceiveBuffer() throws Exception {
 		final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32, 2);
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+		final RemoteInputChannel inputChannel = InputChannelBuilder.newBuilder()
+			.setMemorySegmentProvider(networkBufferPool)
+			.buildRemoteAndSetToGate(inputGate);
 		try {
 			final BufferPool bufferPool = networkBufferPool.createBufferPool(8, 8);
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 
 			final PartitionRequestClientHandler handler = new PartitionRequestClientHandler();
 			handler.addInputChannel(inputChannel);
@@ -204,52 +204,6 @@ public class PartitionRequestClientHandlerTest {
 	// ---------------------------------------------------------------------------------------------
 
 	/**
-	 * Creates and returns a remote input channel for the specific input gate.
-	 *
-	 * @param inputGate The input gate owns the created input channel.
-	 * @return The new created remote input channel.
-	 */
-	static RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate) throws Exception {
-		return createRemoteInputChannel(inputGate, null);
-	}
-
-	/**
-	 * Creates and returns a remote input channel for the specific input gate with specific partition request client.
-	 *
-	 * @param inputGate The input gate owns the created input channel.
-	 * @param client The client is used to send partition request.
-	 * @return The new created remote input channel.
-	 */
-	static RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate, PartitionRequestClient client) throws Exception {
-		return createRemoteInputChannel(inputGate, client, 0, 0);
-	}
-
-	/**
-	 * Creates and returns a remote input channel for the specific input gate with specific partition request client.
-	 *
-	 * @param inputGate The input gate owns the created input channel.
-	 * @param client The client is used to send partition request.
-	 * @param initialBackoff initial back off (in ms) for retriggering subpartition requests (must be <tt>&gt; 0</tt> to activate)
-	 * @param maxBackoff after which delay (in ms) to stop retriggering subpartition requests
-	 * @return The new created remote input channel.
-	 */
-	static RemoteInputChannel createRemoteInputChannel(
-			SingleInputGate inputGate,
-			PartitionRequestClient client,
-			int initialBackoff,
-			int maxBackoff) throws Exception {
-		final ConnectionManager connectionManager = mock(ConnectionManager.class);
-		when(connectionManager.createPartitionRequestClient(any(ConnectionID.class)))
-			.thenReturn(client);
-
-		return InputChannelBuilder.newBuilder()
-			.setConnectionManager(connectionManager)
-			.setInitialBackoff(initialBackoff)
-			.setMaxBackoff(maxBackoff)
-			.buildRemoteAndSetToGate(inputGate);
-	}
-
-	/**
 	 * Returns a deserialized buffer message as it would be received during runtime.
 	 */
 	static BufferResponse createBufferResponse(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientTest.java
index 6b5560e..dcc3ad4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.netty.NettyMessage.PartitionRequest;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder;
 import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
 
@@ -29,8 +30,9 @@ import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
 
 import org.junit.Test;
 
-import static org.apache.flink.runtime.io.network.netty.PartitionRequestClientHandlerTest.createRemoteInputChannel;
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createRemoteInputChannel;
 import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.mockConnectionManagerWithPartitionRequestClient;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
@@ -55,12 +57,17 @@ public class PartitionRequestClientTest {
 		final int numExclusiveBuffers = 2;
 		final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32, numExclusiveBuffers);
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, client, 1, 2);
+		final RemoteInputChannel inputChannel = InputChannelBuilder.newBuilder()
+			.setConnectionManager(mockConnectionManagerWithPartitionRequestClient(client))
+			.setInitialBackoff(1)
+			.setMaxBackoff(2)
+			.setMemorySegmentProvider(networkBufferPool)
+			.buildRemoteAndSetToGate(inputGate);
 
 		try {
 			final BufferPool bufferPool = networkBufferPool.createBufferPool(6, 6);
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 
 			// first subpartition request
 			inputChannel.requestSubpartition(0);
@@ -109,12 +116,12 @@ public class PartitionRequestClientTest {
 		final int numExclusiveBuffers = 2;
 		final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32, numExclusiveBuffers);
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, client);
+		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, client, networkBufferPool);
 
 		try {
 			final BufferPool bufferPool = networkBufferPool.createBufferPool(6, 6);
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 			inputChannel.requestSubpartition(0);
 
 			// The input channel should only send one partition request
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
index 806c556..49981ce 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
@@ -18,8 +18,11 @@
 
 package org.apache.flink.runtime.io.network.partition;
 
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentProvider;
 import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.TaskEventPublisher;
 import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
 import org.apache.flink.runtime.io.network.netty.PartitionRequestClient;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder;
@@ -32,6 +35,9 @@ import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
+import java.util.Collection;
+import java.util.Collections;
+
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
 import static org.mockito.Mockito.mock;
@@ -121,6 +127,57 @@ public class InputChannelTestUtils {
 			.buildRemoteAndSetToGate(inputGate);
 	}
 
+	/**
+	 * Builds a {@link RemoteInputChannel} with the given {@link MemorySegmentProvider}, sets it on
+	 * {@code inputGate} and returns it. The channel's connection manager always hands out
+	 * {@code client} as the partition request client.
+	 */
+	public static RemoteInputChannel createRemoteInputChannel(
+		SingleInputGate inputGate,
+		PartitionRequestClient client,
+		MemorySegmentProvider memorySegmentProvider) {
+
+		return InputChannelBuilder.newBuilder()
+			.setConnectionManager(mockConnectionManagerWithPartitionRequestClient(client))
+			.setMemorySegmentProvider(memorySegmentProvider)
+			.buildRemoteAndSetToGate(inputGate);
+	}
+
+	/**
+	 * Returns a stub {@link ConnectionManager} whose {@code createPartitionRequestClient} always
+	 * returns {@code client}; all other methods do nothing or return {@code 0}.
+	 */
+	public static ConnectionManager mockConnectionManagerWithPartitionRequestClient(PartitionRequestClient client) {
+		return new ConnectionManager() {
+			@Override
+			public void start(ResultPartitionProvider partitionProvider, TaskEventPublisher taskEventDispatcher) {
+			}
+
+			@Override
+			public PartitionRequestClient createPartitionRequestClient(ConnectionID connectionId) {
+				return client;
+			}
+
+			@Override
+			public void closeOpenChannelConnections(ConnectionID connectionId) {
+			}
+
+			@Override
+			public int getNumberOfActiveConnections() {
+				return 0;
+			}
+
+			@Override
+			public int getDataPort() {
+				return 0;
+			}
+
+			@Override
+			public void shutdown() {
+			}
+		};
+	}
+
 	public static InputChannelMetrics newUnregisteredInputChannelMetrics() {
 		return new InputChannelMetrics(UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
 	}
@@ -129,4 +186,29 @@ public class InputChannelTestUtils {
 
 	/** This class is not meant to be instantiated. */
 	private InputChannelTestUtils() {}
+
+	/**
+	 * Test stub for {@link MemorySegmentProvider} that provides no memory segments and ignores
+	 * segments handed back for recycling. Use {@link #getInstance()} to obtain the shared instance.
+	 */
+	public static class StubMemorySegmentProvider implements MemorySegmentProvider {
+		private static final MemorySegmentProvider INSTANCE = new StubMemorySegmentProvider();
+
+		public static MemorySegmentProvider getInstance() {
+			return INSTANCE;
+		}
+
+		private StubMemorySegmentProvider() {
+		}
+
+		@Override
+		public Collection<MemorySegment> requestMemorySegments() {
+			return Collections.emptyList();
+		}
+
+		@Override
+		public void recycleMemorySegments(Collection<MemorySegment> segments) {
+			// Intentionally empty: this stub never hands out segments, so there is nothing to recycle.
+		}
+	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java
index 549dcc5..4153e8c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.partition.consumer;
 
+import org.apache.flink.core.memory.MemorySegmentProvider;
 import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.ConnectionManager;
 import org.apache.flink.runtime.io.network.LocalConnectionManager;
@@ -47,6 +48,7 @@ public class InputChannelBuilder {
 	private int initialBackoff = 0;
 	private int maxBackoff = 0;
 	private InputChannelMetrics metrics = InputChannelTestUtils.newUnregisteredInputChannelMetrics();
+	private MemorySegmentProvider memorySegmentProvider = InputChannelTestUtils.StubMemorySegmentProvider.getInstance();
 
 	public static InputChannelBuilder newBuilder() {
 		return new InputChannelBuilder();
@@ -92,11 +94,21 @@ public class InputChannelBuilder {
 		return this;
 	}
 
+	/**
+	 * Sets the {@link MemorySegmentProvider} passed on to the input channels this builder creates
+	 * (where the channel constructor accepts one). Defaults to {@link InputChannelTestUtils.StubMemorySegmentProvider}.
+	 */
+	public InputChannelBuilder setMemorySegmentProvider(MemorySegmentProvider memorySegmentProvider) {
+		this.memorySegmentProvider = memorySegmentProvider;
+		return this;
+	}
+
 	InputChannelBuilder setupFromNetworkEnvironment(NetworkEnvironment network) {
 		this.partitionManager = network.getResultPartitionManager();
 		this.connectionManager = network.getConnectionManager();
 		this.initialBackoff = network.getConfiguration().partitionRequestInitialBackoff();
 		this.maxBackoff = network.getConfiguration().partitionRequestMaxBackoff();
+		this.memorySegmentProvider = network.getNetworkBufferPool();
 		return this;
 	}
 
@@ -110,7 +122,8 @@ public class InputChannelBuilder {
 			connectionManager,
 			initialBackoff,
 			maxBackoff,
-			metrics);
+			metrics,
+			memorySegmentProvider);
 		inputGate.setInputChannel(partitionId.getPartitionId(), channel);
 		return channel;
 	}
@@ -138,7 +151,8 @@ public class InputChannelBuilder {
 			connectionManager,
 			initialBackoff,
 			maxBackoff,
-			metrics);
+			metrics,
+			memorySegmentProvider);
 		inputGate.setInputChannel(partitionId.getPartitionId(), channel);
 		return channel;
 	}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
index 935ddb6..e7f0648 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.partition.consumer;
 
+import org.apache.flink.core.memory.MemorySegmentProvider;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.ConnectionManager;
@@ -329,12 +330,12 @@ public class RemoteInputChannelTest {
 		final int numFloatingBuffers = 14;
 
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, networkBufferPool);
 		Throwable thrown = null;
 		try {
 			final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 			inputChannel.requestSubpartition(0);
 
 			// Prepare the exclusive and floating buffers to verify recycle logic later
@@ -467,12 +468,12 @@ public class RemoteInputChannelTest {
 		final int numFloatingBuffers = 14;
 
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, networkBufferPool);
 		Throwable thrown = null;
 		try {
 			final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 			inputChannel.requestSubpartition(0);
 
 			// Prepare the exclusive and floating buffers to verify recycle logic later
@@ -541,12 +542,12 @@ public class RemoteInputChannelTest {
 		final int numFloatingBuffers = 14;
 
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, networkBufferPool);
 		Throwable thrown = null;
 		try {
 			final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 			inputChannel.requestSubpartition(0);
 
 			// Prepare the exclusive and floating buffers to verify recycle logic later
@@ -630,14 +631,14 @@ public class RemoteInputChannelTest {
 		final int numFloatingBuffers = 3;
 
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel channel1 = spy(createRemoteInputChannel(inputGate));
-		final RemoteInputChannel channel2 = spy(createRemoteInputChannel(inputGate));
-		final RemoteInputChannel channel3 = spy(createRemoteInputChannel(inputGate));
+		final RemoteInputChannel channel1 = spy(createRemoteInputChannel(inputGate, networkBufferPool));
+		final RemoteInputChannel channel2 = spy(createRemoteInputChannel(inputGate, networkBufferPool));
+		final RemoteInputChannel channel3 = spy(createRemoteInputChannel(inputGate, networkBufferPool));
 		Throwable thrown = null;
 		try {
 			final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 			channel1.requestSubpartition(0);
 			channel2.requestSubpartition(0);
 			channel3.requestSubpartition(0);
@@ -760,12 +761,12 @@ public class RemoteInputChannelTest {
 		final ExecutorService executor = Executors.newFixedThreadPool(2);
 
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel  = createRemoteInputChannel(inputGate);
+		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, networkBufferPool);
 		Throwable thrown = null;
 		try {
 			final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 			inputChannel.requestSubpartition(0);
 
 			final Callable<Void> requestBufferTask = new Callable<Void>() {
@@ -822,12 +823,12 @@ public class RemoteInputChannelTest {
 		final ExecutorService executor = Executors.newFixedThreadPool(3);
 
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel  = createRemoteInputChannel(inputGate);
+		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, networkBufferPool);
 		Throwable thrown = null;
 		try {
 			final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 			inputChannel.requestSubpartition(0);
 
 			final Callable<Void> requestBufferTask = new Callable<Void>() {
@@ -873,12 +874,12 @@ public class RemoteInputChannelTest {
 		final ExecutorService executor = Executors.newFixedThreadPool(3);
 
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel  = createRemoteInputChannel(inputGate);
+		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, networkBufferPool);
 		Throwable thrown = null;
 		try {
 			final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 			inputChannel.requestSubpartition(0);
 
 			final Callable<Void> releaseTask = new Callable<Void>() {
@@ -927,12 +928,12 @@ public class RemoteInputChannelTest {
 		final ExecutorService executor = Executors.newFixedThreadPool(2);
 
 		final SingleInputGate inputGate = createSingleInputGate(1);
-		final RemoteInputChannel inputChannel  = createRemoteInputChannel(inputGate);
+		final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate, networkBufferPool);
 		Throwable thrown = null;
 		try {
 			final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
 			inputGate.setBufferPool(bufferPool);
-			inputGate.assignExclusiveSegments(networkBufferPool);
+			inputGate.assignExclusiveSegments();
 			inputChannel.requestSubpartition(0);
 
 			final Callable<Void> bufferPoolInteractionsTask = () -> {
@@ -993,9 +994,18 @@ public class RemoteInputChannelTest {
 	// ---------------------------------------------------------------------------------------------
 
 	private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate)
 			throws IOException, InterruptedException {
+
+		return createRemoteInputChannel(inputGate, InputChannelTestUtils.StubMemorySegmentProvider.getInstance());
+	}
+
+	private RemoteInputChannel createRemoteInputChannel(
+			SingleInputGate inputGate,
+			MemorySegmentProvider memorySegmentProvider)
+			throws IOException, InterruptedException {
 
-		return createRemoteInputChannel(inputGate, mock(PartitionRequestClient.class), 0, 0);
+		return createRemoteInputChannel(
+				inputGate, mock(PartitionRequestClient.class), 0, 0, memorySegmentProvider);
 	}
 
 	private RemoteInputChannel createRemoteInputChannel(
@@ -1005,6 +1015,18 @@ public class RemoteInputChannelTest {
 			int maxBackoff)
 			throws IOException, InterruptedException {
 
+		return createRemoteInputChannel(inputGate, partitionRequestClient, initialBackoff, maxBackoff,
+				InputChannelTestUtils.StubMemorySegmentProvider.getInstance());
+	}
+
+	private RemoteInputChannel createRemoteInputChannel(
+			SingleInputGate inputGate,
+			PartitionRequestClient partitionRequestClient,
+			int initialBackoff,
+			int maxBackoff,
+			MemorySegmentProvider memorySegmentProvider)
+			throws IOException, InterruptedException {
+
 		final ConnectionManager connectionManager = mock(ConnectionManager.class);
 		when(connectionManager.createPartitionRequestClient(any(ConnectionID.class)))
 				.thenReturn(partitionRequestClient);
@@ -1013,6 +1035,7 @@ public class RemoteInputChannelTest {
 			.setConnectionManager(connectionManager)
 			.setInitialBackoff(initialBackoff)
 			.setMaxBackoff(maxBackoff)
+			.setMemorySegmentProvider(memorySegmentProvider)
 			.buildRemoteAndSetToGate(inputGate);
 	}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 702ad34..dd031f8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -336,7 +336,8 @@ public class SingleInputGateTest extends InputGateTestBase {
 			netEnv.getConfiguration(),
 			netEnv.getConnectionManager(),
 			netEnv.getResultPartitionManager(),
-			new TaskEventDispatcher())
+			new TaskEventDispatcher(),
+			netEnv.getNetworkBufferPool())
 			.create(
 				"TestTask",
 				new JobID(),
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java
index 2b1b976..2bed27e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java
@@ -253,7 +253,8 @@ public class StreamNetworkBenchmarkEnvironment<T extends IOReadableWritable> {
 				environment.getConfiguration(),
 				environment.getConnectionManager(),
 				environment.getResultPartitionManager(),
-				new TaskEventDispatcher())
+				new TaskEventDispatcher(),
+				environment.getNetworkBufferPool())
 				.create(
 					"receiving task[" + channel + "]",
 					jobId,