You are viewing a plain-text version of this content. (The canonical link, shown as "here" on the original page, is not preserved in this plain-text copy.)
Posted to commits@flink.apache.org by se...@apache.org on 2018/05/02 09:29:53 UTC
[15/16] flink git commit: [FLINK-9256] [network] Fix NPE in SingleInputGate#updateInputChannel() for non-credit based flow control
[FLINK-9256] [network] Fix NPE in SingleInputGate#updateInputChannel() for non-credit based flow control
This closes #5914
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/56e2b0b5
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/56e2b0b5
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/56e2b0b5
Branch: refs/heads/release-1.5
Commit: 56e2b0b5d600935eae590a985643a5879f224d04
Parents: f1fa517
Author: Nico Kruber <ni...@data-artisans.com>
Authored: Wed Apr 25 18:28:48 2018 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Mon Apr 30 23:25:38 2018 +0200
----------------------------------------------------------------------
.../runtime/io/network/NetworkEnvironment.java | 4 +
.../partition/consumer/SingleInputGate.java | 20 +-
.../io/network/NetworkEnvironmentTest.java | 5 +-
.../PartitionRequestClientHandlerTest.java | 3 +-
.../partition/InputGateConcurrentTest.java | 9 +-
.../partition/InputGateFairnessTest.java | 17 +-
.../consumer/LocalInputChannelTest.java | 6 +-
.../consumer/RemoteInputChannelTest.java | 3 +-
.../partition/consumer/SingleInputGateTest.java | 318 +++++++++++++++----
.../partition/consumer/TestSingleInputGate.java | 3 +-
.../partition/consumer/UnionInputGateTest.java | 6 +-
11 files changed, 301 insertions(+), 93 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
index 0a9dc0f..f254756 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
@@ -157,6 +157,10 @@ public class NetworkEnvironment {
return partitionRequestMaxBackoff;
}
+ public boolean isCreditBased() {
+ return enableCreditBased;
+ }
+
public KvStateRegistry getKvStateRegistry() {
return kvStateRegistry;
}
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index b9091b2..06e80ff 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -157,9 +157,11 @@ public class SingleInputGate implements InputGate {
*/
private BufferPool bufferPool;
- /** Global network buffer pool to request and recycle exclusive buffers. */
+ /** Global network buffer pool to request and recycle exclusive buffers (only for credit-based). */
private NetworkBufferPool networkBufferPool;
+ private final boolean isCreditBased;
+
private boolean hasReceivedAllEndOfPartitionEvents;
/** Flag indicating whether partitions have been requested. */
@@ -189,7 +191,8 @@ public class SingleInputGate implements InputGate {
int consumedSubpartitionIndex,
int numberOfInputChannels,
TaskActions taskActions,
- TaskIOMetricGroup metrics) {
+ TaskIOMetricGroup metrics,
+ boolean isCreditBased) {
this.owningTaskName = checkNotNull(owningTaskName);
this.jobId = checkNotNull(jobId);
@@ -208,6 +211,7 @@ public class SingleInputGate implements InputGate {
this.enqueuedInputChannelsWithData = new BitSet(numberOfInputChannels);
this.taskActions = checkNotNull(taskActions);
+ this.isCreditBased = isCreditBased;
}
// ------------------------------------------------------------------------
@@ -288,6 +292,7 @@ public class SingleInputGate implements InputGate {
* @param networkBuffersPerChannel The number of exclusive buffers for each channel
*/
public void assignExclusiveSegments(NetworkBufferPool networkBufferPool, int networkBuffersPerChannel) throws IOException {
+ checkState(this.isCreditBased, "Bug in input gate setup logic: exclusive buffers only exist with credit-based flow control.");
checkState(this.networkBufferPool == null, "Bug in input gate setup logic: global buffer pool has" +
"already been set for this input gate.");
@@ -347,8 +352,13 @@ public class SingleInputGate implements InputGate {
}
else if (partitionLocation.isRemote()) {
newChannel = unknownChannel.toRemoteInputChannel(partitionLocation.getConnectionId());
- ((RemoteInputChannel)newChannel).assignExclusiveSegments(
- networkBufferPool.requestMemorySegments(networkBuffersPerChannel));
+
+ if (this.isCreditBased) {
+ checkState(this.networkBufferPool != null, "Bug in input gate setup logic: " +
+ "global buffer pool has not been set for this input gate.");
+ ((RemoteInputChannel) newChannel).assignExclusiveSegments(
+ networkBufferPool.requestMemorySegments(networkBuffersPerChannel));
+ }
}
else {
throw new IllegalStateException("Tried to update unknown channel with unknown channel.");
@@ -661,7 +671,7 @@ public class SingleInputGate implements InputGate {
final SingleInputGate inputGate = new SingleInputGate(
owningTaskName, jobId, consumedResultId, consumedPartitionType, consumedSubpartitionIndex,
- icdd.length, taskActions, metrics);
+ icdd.length, taskActions, metrics, networkEnvironment.isCreditBased());
// Create the input channels. There is one input channel for each consumed partition.
final InputChannel[] inputChannels = new InputChannel[icdd.length];
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
index 317a214..f790b5f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
@@ -329,7 +329,7 @@ public class NetworkEnvironmentTest {
*
* @return input gate with some fake settings
*/
- private static SingleInputGate createSingleInputGate(
+ private SingleInputGate createSingleInputGate(
final ResultPartitionType partitionType, final int channels) {
return spy(new SingleInputGate(
"Test Task Name",
@@ -339,7 +339,8 @@ public class NetworkEnvironmentTest {
0,
channels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()));
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ enableCreditBasedFlowControl));
}
private static void createRemoteInputChannel(
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
index 13f7510..842aed8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
@@ -221,7 +221,8 @@ public class PartitionRequestClientHandlerTest {
0,
1,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
}
/**
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
index 289a398..73f3cfb 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
@@ -66,7 +66,8 @@ public class InputGateConcurrentTest {
new IntermediateDataSetID(), ResultPartitionType.PIPELINED,
0, numChannels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
for (int i = 0; i < numChannels; i++) {
LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
@@ -102,7 +103,8 @@ public class InputGateConcurrentTest {
0,
numChannels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
for (int i = 0; i < numChannels; i++) {
RemoteInputChannel channel = new RemoteInputChannel(
@@ -151,7 +153,8 @@ public class InputGateConcurrentTest {
0,
numChannels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
for (int i = 0, local = 0; i < numChannels; i++) {
if (localOrRemote.get(i)) {
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
index 45df56f..82a27cc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
@@ -93,7 +93,8 @@ public class InputGateFairnessTest {
new IntermediateDataSetID(),
0, numChannels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
for (int i = 0; i < numChannels; i++) {
LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
@@ -146,7 +147,8 @@ public class InputGateFairnessTest {
new IntermediateDataSetID(),
0, numChannels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
for (int i = 0; i < numChannels; i++) {
LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
@@ -196,7 +198,8 @@ public class InputGateFairnessTest {
new IntermediateDataSetID(),
0, numChannels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
final ConnectionManager connManager = createDummyConnectionManager();
@@ -251,7 +254,8 @@ public class InputGateFairnessTest {
new IntermediateDataSetID(),
0, numChannels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
final ConnectionManager connManager = createDummyConnectionManager();
@@ -349,11 +353,12 @@ public class InputGateFairnessTest {
int consumedSubpartitionIndex,
int numberOfInputChannels,
TaskActions taskActions,
- TaskIOMetricGroup metrics) {
+ TaskIOMetricGroup metrics,
+ boolean isCreditBased) {
super(owningTaskName, jobId, consumedResultId, ResultPartitionType.PIPELINED,
consumedSubpartitionIndex,
- numberOfInputChannels, taskActions, metrics);
+ numberOfInputChannels, taskActions, metrics, isCreditBased);
try {
Field f = SingleInputGate.class.getDeclaredField("inputChannelsWithData");
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
index c78b7b9..1ecb67f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
@@ -293,7 +293,8 @@ public class LocalInputChannelTest {
0,
1,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true
);
ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
@@ -490,7 +491,8 @@ public class LocalInputChannelTest {
subpartitionIndex,
numberOfInputChannels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
// Set buffer pool
inputGate.setBufferPool(bufferPool);
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
index 97a5688..802cb93 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
@@ -889,7 +889,8 @@ public class RemoteInputChannelTest {
0,
1,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
}
private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate)
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 8c54c1f..c244668 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -19,13 +19,13 @@
package org.apache.flink.runtime.io.network.partition.consumer;
import org.apache.flink.api.common.JobID;
-import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionLocation;
import org.apache.flink.runtime.event.TaskEvent;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.ConnectionID;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.LocalConnectionManager;
@@ -45,31 +45,51 @@ import org.apache.flink.runtime.io.network.util.TestTaskEvent;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.taskmanager.TaskActions;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
import java.io.IOException;
import java.net.InetSocketAddress;
+import java.util.Arrays;
+import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
-import static org.mockito.Matchers.anyListOf;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
+/**
+ * Tests for {@link SingleInputGate}.
+ */
+@RunWith(Parameterized.class)
public class SingleInputGateTest {
+ @Parameterized.Parameter
+ public boolean enableCreditBasedFlowControl;
+
+ @Parameterized.Parameters(name = "Credit-based = {0}")
+ public static List<Boolean> parameters() {
+ return Arrays.asList(Boolean.TRUE, Boolean.FALSE);
+ }
+
/**
* Tests basic correctness of buffer-or-event interleaving and correct <code>null</code> return
* value after receiving all end-of-partition events.
@@ -324,12 +344,7 @@ public class SingleInputGateTest {
int initialBackoff = 137;
int maxBackoff = 1001;
- NetworkEnvironment netEnv = mock(NetworkEnvironment.class);
- when(netEnv.getResultPartitionManager()).thenReturn(new ResultPartitionManager());
- when(netEnv.getTaskEventDispatcher()).thenReturn(new TaskEventDispatcher());
- when(netEnv.getPartitionRequestInitialBackoff()).thenReturn(initialBackoff);
- when(netEnv.getPartitionRequestMaxBackoff()).thenReturn(maxBackoff);
- when(netEnv.getConnectionManager()).thenReturn(new LocalConnectionManager());
+ final NetworkEnvironment netEnv = createNetworkEnvironment(2, 8, initialBackoff, maxBackoff);
SingleInputGate gate = SingleInputGate.create(
"TestTask",
@@ -340,37 +355,43 @@ public class SingleInputGateTest {
mock(TaskActions.class),
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
- assertEquals(gateDesc.getConsumedPartitionType(), gate.getConsumedPartitionType());
+ try {
+ assertEquals(gateDesc.getConsumedPartitionType(), gate.getConsumedPartitionType());
- Map<IntermediateResultPartitionID, InputChannel> channelMap = gate.getInputChannels();
+ Map<IntermediateResultPartitionID, InputChannel> channelMap = gate.getInputChannels();
- assertEquals(3, channelMap.size());
- InputChannel localChannel = channelMap.get(partitionIds[0].getPartitionId());
- assertEquals(LocalInputChannel.class, localChannel.getClass());
+ assertEquals(3, channelMap.size());
+ InputChannel localChannel = channelMap.get(partitionIds[0].getPartitionId());
+ assertEquals(LocalInputChannel.class, localChannel.getClass());
- InputChannel remoteChannel = channelMap.get(partitionIds[1].getPartitionId());
- assertEquals(RemoteInputChannel.class, remoteChannel.getClass());
+ InputChannel remoteChannel = channelMap.get(partitionIds[1].getPartitionId());
+ assertEquals(RemoteInputChannel.class, remoteChannel.getClass());
- InputChannel unknownChannel = channelMap.get(partitionIds[2].getPartitionId());
- assertEquals(UnknownInputChannel.class, unknownChannel.getClass());
+ InputChannel unknownChannel = channelMap.get(partitionIds[2].getPartitionId());
+ assertEquals(UnknownInputChannel.class, unknownChannel.getClass());
- InputChannel[] channels = new InputChannel[]{localChannel, remoteChannel, unknownChannel};
- for (InputChannel ch : channels) {
- assertEquals(0, ch.getCurrentBackoff());
+ InputChannel[] channels =
+ new InputChannel[] {localChannel, remoteChannel, unknownChannel};
+ for (InputChannel ch : channels) {
+ assertEquals(0, ch.getCurrentBackoff());
- assertTrue(ch.increaseBackoff());
- assertEquals(initialBackoff, ch.getCurrentBackoff());
+ assertTrue(ch.increaseBackoff());
+ assertEquals(initialBackoff, ch.getCurrentBackoff());
- assertTrue(ch.increaseBackoff());
- assertEquals(initialBackoff * 2, ch.getCurrentBackoff());
+ assertTrue(ch.increaseBackoff());
+ assertEquals(initialBackoff * 2, ch.getCurrentBackoff());
- assertTrue(ch.increaseBackoff());
- assertEquals(initialBackoff * 2 * 2, ch.getCurrentBackoff());
+ assertTrue(ch.increaseBackoff());
+ assertEquals(initialBackoff * 2 * 2, ch.getCurrentBackoff());
- assertTrue(ch.increaseBackoff());
- assertEquals(maxBackoff, ch.getCurrentBackoff());
+ assertTrue(ch.increaseBackoff());
+ assertEquals(maxBackoff, ch.getCurrentBackoff());
- assertFalse(ch.increaseBackoff());
+ assertFalse(ch.increaseBackoff());
+ }
+ } finally {
+ gate.releaseAllResources();
+ netEnv.shutdown();
}
}
@@ -379,26 +400,39 @@ public class SingleInputGateTest {
*/
@Test
public void testRequestBuffersWithRemoteInputChannel() throws Exception {
- final SingleInputGate inputGate = new SingleInputGate(
- "t1",
- new JobID(),
- new IntermediateDataSetID(),
- ResultPartitionType.PIPELINED_BOUNDED,
- 0,
- 1,
- mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
-
- RemoteInputChannel remote = mock(RemoteInputChannel.class);
- inputGate.setInputChannel(new IntermediateResultPartitionID(), remote);
-
- final int buffersPerChannel = 2;
- NetworkBufferPool network = mock(NetworkBufferPool.class);
- // Trigger requests of segments from global pool and assign buffers to remote input channel
- inputGate.assignExclusiveSegments(network, buffersPerChannel);
-
- verify(network, times(1)).requestMemorySegments(buffersPerChannel);
- verify(remote, times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class));
+ final SingleInputGate inputGate = createInputGate(1, ResultPartitionType.PIPELINED_BOUNDED);
+ int buffersPerChannel = 2;
+ int extraNetworkBuffersPerGate = 8;
+ final NetworkEnvironment network = createNetworkEnvironment(buffersPerChannel,
+ extraNetworkBuffersPerGate, 0, 0);
+
+ try {
+ final ResultPartitionID resultPartitionId = new ResultPartitionID();
+ final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0);
+ addRemoteInputChannel(network, inputGate, connectionId, resultPartitionId, 0);
+
+ network.setupInputGate(inputGate);
+
+ NetworkBufferPool bufferPool = network.getNetworkBufferPool();
+ if (enableCreditBasedFlowControl) {
+ verify(bufferPool,
+ times(1)).requestMemorySegments(buffersPerChannel);
+ RemoteInputChannel remote = (RemoteInputChannel) inputGate.getInputChannels()
+ .get(resultPartitionId.getPartitionId());
+ // only the exclusive buffers should be assigned/available now
+ assertEquals(buffersPerChannel, remote.getNumberOfAvailableBuffers());
+
+ assertEquals(bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel,
+ bufferPool.getNumberOfAvailableMemorySegments());
+ // note: exclusive buffers are not handed out into LocalBufferPool and are thus not counted
+ assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers());
+ } else {
+ assertEquals(buffersPerChannel + extraNetworkBuffersPerGate, bufferPool.countBuffers());
+ }
+ } finally {
+ inputGate.releaseAllResources();
+ network.shutdown();
+ }
}
/**
@@ -407,51 +441,195 @@ public class SingleInputGateTest {
*/
@Test
public void testRequestBuffersWithUnknownInputChannel() throws Exception {
- final SingleInputGate inputGate = createInputGate(1);
+ final SingleInputGate inputGate = createInputGate(1, ResultPartitionType.PIPELINED_BOUNDED);
+ int buffersPerChannel = 2;
+ int extraNetworkBuffersPerGate = 8;
+ final NetworkEnvironment network = createNetworkEnvironment(buffersPerChannel, extraNetworkBuffersPerGate, 0, 0);
- UnknownInputChannel unknown = mock(UnknownInputChannel.class);
- final ResultPartitionID resultPartitionId = new ResultPartitionID();
- inputGate.setInputChannel(resultPartitionId.getPartitionId(), unknown);
+ try {
+ final ResultPartitionID resultPartitionId = new ResultPartitionID();
+ addUnknownInputChannel(network, inputGate, resultPartitionId, 0);
- RemoteInputChannel remote = mock(RemoteInputChannel.class);
- final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0);
- when(unknown.toRemoteInputChannel(connectionId)).thenReturn(remote);
+ network.setupInputGate(inputGate);
+ NetworkBufferPool bufferPool = network.getNetworkBufferPool();
- final int buffersPerChannel = 2;
- NetworkBufferPool network = mock(NetworkBufferPool.class);
- inputGate.assignExclusiveSegments(network, buffersPerChannel);
+ if (enableCreditBasedFlowControl) {
+ verify(bufferPool, times(0)).requestMemorySegments(buffersPerChannel);
- // Trigger updates to remote input channel from unknown input channel
- inputGate.updateInputChannel(new InputChannelDeploymentDescriptor(
- resultPartitionId,
- ResultPartitionLocation.createRemote(connectionId)));
+ assertEquals(bufferPool.getTotalNumberOfMemorySegments(),
+ bufferPool.getNumberOfAvailableMemorySegments());
+ // note: exclusive buffers are not handed out into LocalBufferPool and are thus not counted
+ assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers());
+ } else {
+ assertEquals(buffersPerChannel + extraNetworkBuffersPerGate, bufferPool.countBuffers());
+ }
+
+ // Trigger updates to remote input channel from unknown input channel
+ final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0);
+ inputGate.updateInputChannel(new InputChannelDeploymentDescriptor(
+ resultPartitionId,
+ ResultPartitionLocation.createRemote(connectionId)));
+
+ if (enableCreditBasedFlowControl) {
+ verify(bufferPool,
+ times(1)).requestMemorySegments(buffersPerChannel);
+ RemoteInputChannel remote = (RemoteInputChannel) inputGate.getInputChannels()
+ .get(resultPartitionId.getPartitionId());
+ // only the exclusive buffers should be assigned/available now
+ assertEquals(buffersPerChannel, remote.getNumberOfAvailableBuffers());
+
+ assertEquals(bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel,
+ bufferPool.getNumberOfAvailableMemorySegments());
+ // note: exclusive buffers are not handed out into LocalBufferPool and are thus not counted
+ assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers());
+ } else {
+ assertEquals(buffersPerChannel + extraNetworkBuffersPerGate, bufferPool.countBuffers());
+ }
+ } finally {
+ inputGate.releaseAllResources();
+ network.shutdown();
+ }
+ }
- verify(network, times(1)).requestMemorySegments(buffersPerChannel);
- verify(remote, times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class));
+ /**
+ * Tests that input gate can successfully convert unknown input channels into local and remote
+ * channels.
+ */
+ @Test
+ public void testUpdateUnknownInputChannel() throws Exception {
+ final SingleInputGate inputGate = createInputGate(2);
+ int buffersPerChannel = 2;
+ final NetworkEnvironment network = createNetworkEnvironment(buffersPerChannel, 8, 0, 0);
+
+ try {
+ final ResultPartitionID localResultPartitionId = new ResultPartitionID();
+ addUnknownInputChannel(network, inputGate, localResultPartitionId, 0);
+
+ final ResultPartitionID remoteResultPartitionId = new ResultPartitionID();
+ addUnknownInputChannel(network, inputGate, remoteResultPartitionId, 1);
+
+ network.setupInputGate(inputGate);
+
+ assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()),
+ is(instanceOf((UnknownInputChannel.class))));
+ assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()),
+ is(instanceOf((UnknownInputChannel.class))));
+
+ // Trigger updates to remote input channel from unknown input channel
+ final ConnectionID remoteConnectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0);
+ inputGate.updateInputChannel(new InputChannelDeploymentDescriptor(
+ remoteResultPartitionId,
+ ResultPartitionLocation.createRemote(remoteConnectionId)));
+
+ assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()),
+ is(instanceOf((RemoteInputChannel.class))));
+ assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()),
+ is(instanceOf((UnknownInputChannel.class))));
+
+ // Trigger updates to local input channel from unknown input channel
+ inputGate.updateInputChannel(new InputChannelDeploymentDescriptor(
+ localResultPartitionId,
+ ResultPartitionLocation.createLocal()));
+
+ assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()),
+ is(instanceOf((RemoteInputChannel.class))));
+ assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()),
+ is(instanceOf((LocalInputChannel.class))));
+ } finally {
+ inputGate.releaseAllResources();
+ network.shutdown();
+ }
}
// ---------------------------------------------------------------------------------------------
- private static SingleInputGate createInputGate() {
+ private NetworkEnvironment createNetworkEnvironment( // Builds a test NetworkEnvironment with local connections, SYNC I/O and the test class's flow-control setting.
+ int buffersPerChannel,
+ int extraNetworkBuffersPerGate,
+ int initialBackoff,
+ int maxBackoff) {
+ return new NetworkEnvironment(
+ spy(new NetworkBufferPool(100, 32)), // wrapped in spy(); NOTE(review): presumably 100 segments of size 32 so tests can verify pool calls — confirm
+ new LocalConnectionManager(),
+ new ResultPartitionManager(),
+ new TaskEventDispatcher(),
+ new KvStateRegistry(),
+ null, // NOTE(review): presumably the KvState server, unused by these tests — confirm against the constructor
+ null, // NOTE(review): presumably the KvState client proxy, unused by these tests — confirm
+ IOManager.IOMode.SYNC,
+ initialBackoff,
+ maxBackoff,
+ buffersPerChannel,
+ extraNetworkBuffersPerGate,
+ enableCreditBasedFlowControl); // test-class field declared outside this view; presumably the parameterized flow-control mode
+ }
+
+ private SingleInputGate createInputGate() { // defaults to 2 input channels
return createInputGate(2);
}
- private static SingleInputGate createInputGate(int numberOfInputChannels) {
+ private SingleInputGate createInputGate(int numberOfInputChannels) { // consumes PIPELINED partitions by default
+ return createInputGate(numberOfInputChannels, ResultPartitionType.PIPELINED);
+ }
+
+ private SingleInputGate createInputGate( // Creates a gate with the given channel count and consumed partition type, using the test class's flow-control setting.
+ int numberOfInputChannels, ResultPartitionType partitionType) {
SingleInputGate inputGate = new SingleInputGate(
"Test Task Name",
new JobID(),
new IntermediateDataSetID(),
- ResultPartitionType.PIPELINED,
+ partitionType,
0,
numberOfInputChannels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ enableCreditBasedFlowControl); // new trailing constructor argument introduced by this commit
- assertEquals(ResultPartitionType.PIPELINED, inputGate.getConsumedPartitionType());
+ assertEquals(partitionType, inputGate.getConsumedPartitionType()); // sanity check: the gate reports the requested partition type
return inputGate;
}
+ private void addUnknownInputChannel( // Creates an UnknownInputChannel at channelIndex and registers it in the gate under partitionId.getPartitionId().
+ NetworkEnvironment network,
+ SingleInputGate inputGate,
+ ResultPartitionID partitionId,
+ int channelIndex) {
+ UnknownInputChannel unknown =
+ createUnknownInputChannel(network, inputGate, partitionId, channelIndex);
+ inputGate.setInputChannel(partitionId.getPartitionId(), unknown);
+ }
+
+ private UnknownInputChannel createUnknownInputChannel( // Builds (without registering) an UnknownInputChannel wired to the environment's managers and partition-request backoff settings.
+ NetworkEnvironment network,
+ SingleInputGate inputGate,
+ ResultPartitionID partitionId,
+ int channelIndex) {
+ return new UnknownInputChannel(
+ inputGate,
+ channelIndex,
+ partitionId,
+ network.getResultPartitionManager(),
+ network.getTaskEventDispatcher(),
+ network.getConnectionManager(),
+ network.getPartitionRequestInitialBackoff(),
+ network.getPartitionRequestMaxBackoff(),
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup() // a fresh, unregistered metric group for each channel
+ );
+ }
+
+ private void addRemoteInputChannel( // Registers a RemoteInputChannel for connectionId, derived from a fresh UnknownInputChannel via toRemoteInputChannel().
+ NetworkEnvironment network,
+ SingleInputGate inputGate,
+ ConnectionID connectionId,
+ ResultPartitionID partitionId,
+ int channelIndex) {
+ RemoteInputChannel remote =
+ createUnknownInputChannel(network, inputGate, partitionId, channelIndex)
+ .toRemoteInputChannel(connectionId);
+ inputGate.setInputChannel(partitionId.getPartitionId(), remote);
+ }
+
static void verifyBufferOrEvent(
InputGate inputGate,
boolean expectedIsBuffer,
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
index 0ae6e74..33dc1ca 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
@@ -60,7 +60,8 @@ public class TestSingleInputGate {
0,
numberOfInputChannels,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
this.inputGate = spy(realGate);
http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
index 912cd5b..081d97d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
@@ -50,13 +50,15 @@ public class UnionInputGateTest {
new IntermediateDataSetID(), ResultPartitionType.PIPELINED,
0, 3,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
final SingleInputGate ig2 = new SingleInputGate(
testTaskName, new JobID(),
new IntermediateDataSetID(), ResultPartitionType.PIPELINED,
0, 5,
mock(TaskActions.class),
- UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+ true);
final UnionInputGate union = new UnionInputGate(new SingleInputGate[]{ig1, ig2});