You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by uc...@apache.org on 2016/12/02 08:42:38 UTC
[4/6] flink git commit: [FLINK-5169] [network] Add tests for channel consumption
[FLINK-5169] [network] Add tests for channel consumption
This closes #2882.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/c0cdc5c4
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/c0cdc5c4
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/c0cdc5c4
Branch: refs/heads/master
Commit: c0cdc5c4ec08e35a8ea319d1bbf2b24e03e24fd3
Parents: d3ac0ad
Author: Stephan Ewen <se...@apache.org>
Authored: Sun Nov 27 18:15:40 2016 +0100
Committer: Ufuk Celebi <uc...@apache.org>
Committed: Thu Dec 1 21:42:49 2016 +0100
----------------------------------------------------------------------
.../partition/PipelinedSubpartition.java | 8 +
.../partition/consumer/LocalInputChannel.java | 4 +-
.../partition/consumer/SingleInputGate.java | 4 +-
.../partition/consumer/UnionInputGate.java | 2 +-
.../partition/InputChannelTestUtils.java | 89 +++++
.../partition/InputGateConcurrentTest.java | 323 +++++++++++++++
.../partition/InputGateFairnessTest.java | 395 +++++++++++++++++++
7 files changed, 820 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
index e9400f0..9e2f5ba 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
@@ -183,6 +183,14 @@ class PipelinedSubpartition extends ResultSubpartition {
return readView;
}
+ // ------------------------------------------------------------------------
+
+ int getCurrentNumberOfBuffers() {
+ return buffers.size();
+ }
+
+ // ------------------------------------------------------------------------
+
@Override
public String toString() {
final long numBuffers;
http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
index d5308a8..1936da2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
@@ -64,7 +64,7 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit
private volatile boolean isReleased;
- LocalInputChannel(
+ public LocalInputChannel(
SingleInputGate inputGate,
int channelIndex,
ResultPartitionID partitionId,
@@ -76,7 +76,7 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit
0, 0, metrics);
}
- LocalInputChannel(
+ public LocalInputChannel(
SingleInputGate inputGate,
int channelIndex,
ResultPartitionID partitionId,
http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index bcbb2c4..b4d8d2c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -261,7 +261,7 @@ public class SingleInputGate implements InputGate {
this.bufferPool = checkNotNull(bufferPool);
}
- void setInputChannel(IntermediateResultPartitionID partitionId, InputChannel inputChannel) {
+ public void setInputChannel(IntermediateResultPartitionID partitionId, InputChannel inputChannel) {
synchronized (requestLock) {
if (inputChannels.put(checkNotNull(partitionId), checkNotNull(inputChannel)) == null
&& inputChannel.getClass() == UnknownInputChannel.class) {
@@ -546,7 +546,7 @@ public class SingleInputGate implements InputGate {
inputChannelsWithData.add(channel);
if (availableChannels == 0) {
- inputChannelsWithData.notify();
+ inputChannelsWithData.notifyAll();
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index e8ccbb4..55c78af 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -225,7 +225,7 @@ public class UnionInputGate implements InputGate, InputGateListener {
inputGatesWithData.add(inputGate);
if (availableInputGates == 0) {
- inputGatesWithData.notify();
+ inputGatesWithData.notifyAll();
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
new file mode 100644
index 0000000..e292576
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.runtime.io.network.ConnectionID;
+import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferProvider;
+import org.apache.flink.runtime.io.network.netty.PartitionRequestClient;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Some utility methods used for testing InputChannels and InputGates.
+ */
+class InputChannelTestUtils {
+
+	/**
+	 * Creates a simple mock {@link Buffer} of the given size that is never recycled.
+	 *
+	 * <p>The mock returns {@code true} from {@link Buffer#isBuffer()}, {@code size} from
+	 * {@link Buffer#getSize()}, and {@code false} from {@link Buffer#isRecycled()}.
+	 *
+	 * @param size the size the mock buffer reports
+	 * @return the mock buffer
+	 */
+	public static Buffer createMockBuffer(int size) {
+		final Buffer mockBuffer = mock(Buffer.class);
+		when(mockBuffer.isBuffer()).thenReturn(true);
+		when(mockBuffer.getSize()).thenReturn(size);
+		when(mockBuffer.isRecycled()).thenReturn(false);
+
+		return mockBuffer;
+	}
+
+	/**
+	 * Creates a mock result partition manager that ignores all IDs and simply hands out read
+	 * views of the given subpartitions in array order, one per call to
+	 * {@code createSubpartitionView}. The listener passed to that call is forwarded to
+	 * {@link ResultSubpartition#createReadView}.
+	 *
+	 * <p>The array is read lazily (on each view request), so it may still be filled after this
+	 * method returns. Requesting more views than there are sources fails with an
+	 * {@link IllegalStateException}.
+	 *
+	 * @param sources the subpartitions to hand out, in sequence
+	 * @return the mock result partition manager
+	 */
+	public static ResultPartitionManager createResultPartitionManager(final ResultSubpartition[] sources) throws Exception {
+
+		final Answer<ResultSubpartitionView> viewCreator = new Answer<ResultSubpartitionView>() {
+
+			/** Index of the next source to hand out. */
+			private int num = 0;
+
+			@Override
+			public ResultSubpartitionView answer(InvocationOnMock invocation) throws Throwable {
+				if (num >= sources.length) {
+					throw new IllegalStateException("Requested more subpartition views than the "
+							+ sources.length + " available sources.");
+				}
+
+				// argument 3 is the BufferAvailabilityListener to notify about new data
+				BufferAvailabilityListener channel = (BufferAvailabilityListener) invocation.getArguments()[3];
+				return sources[num++].createReadView(null, channel);
+			}
+		};
+
+		ResultPartitionManager manager = mock(ResultPartitionManager.class);
+		when(manager.createSubpartitionView(
+				any(ResultPartitionID.class), anyInt(), any(BufferProvider.class), any(BufferAvailabilityListener.class)))
+				.thenAnswer(viewCreator);
+
+		return manager;
+	}
+
+	/**
+	 * Creates a mock connection manager whose {@code createPartitionRequestClient} returns the
+	 * same mock {@link PartitionRequestClient} for any connection ID.
+	 *
+	 * @return the mock connection manager
+	 */
+	public static ConnectionManager createDummyConnectionManager() throws Exception {
+		final PartitionRequestClient mockClient = mock(PartitionRequestClient.class);
+
+		final ConnectionManager connManager = mock(ConnectionManager.class);
+		when(connManager.createPartitionRequestClient(any(ConnectionID.class))).thenReturn(mockClient);
+
+		return connManager;
+	}
+
+	// ------------------------------------------------------------------------
+
+	/** This class is not meant to be instantiated. */
+	private InputChannelTestUtils() {}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
new file mode 100644
index 0000000..6570679
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.network.ConnectionID;
+import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.TaskEventDispatcher;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
+import org.apache.flink.runtime.taskmanager.TaskActions;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createDummyConnectionManager;
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createResultPartitionManager;
+import static org.junit.Assert.assertNotNull;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests that a {@link SingleInputGate} hands out exactly the expected number of buffers while a
+ * producer thread concurrently adds buffers to its local and/or remote input channels.
+ */
+public class InputGateConcurrentTest {
+
+	@Test
+	public void testConsumptionWithLocalChannels() throws Exception {
+		final int numChannels = 11;
+		final int buffersPerChannel = 1000;
+
+		final ResultPartition resultPartition = mock(ResultPartition.class);
+
+		final PipelinedSubpartition[] partitions = new PipelinedSubpartition[numChannels];
+		final Source[] sources = new Source[numChannels];
+
+		// the manager reads the array only when views are requested, so it is filled below
+		final ResultPartitionManager resultPartitionManager = createResultPartitionManager(partitions);
+
+		final SingleInputGate gate = createInputGate(numChannels);
+
+		for (int i = 0; i < numChannels; i++) {
+			LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
+					resultPartitionManager, mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+
+			partitions[i] = new PipelinedSubpartition(0, resultPartition);
+			sources[i] = new PipelinedSubpartitionSource(partitions[i]);
+		}
+
+		runProducerConsumer(sources, gate, numChannels * buffersPerChannel);
+	}
+
+	@Test
+	public void testConsumptionWithRemoteChannels() throws Exception {
+		final int numChannels = 11;
+		final int buffersPerChannel = 1000;
+
+		final ConnectionManager connManager = createDummyConnectionManager();
+		final Source[] sources = new Source[numChannels];
+
+		final SingleInputGate gate = createInputGate(numChannels);
+
+		for (int i = 0; i < numChannels; i++) {
+			RemoteInputChannel channel = new RemoteInputChannel(
+					gate, i, new ResultPartitionID(), mock(ConnectionID.class),
+					connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+
+			sources[i] = new RemoteChannelSource(channel);
+		}
+
+		runProducerConsumer(sources, gate, numChannels * buffersPerChannel);
+	}
+
+	@Test
+	public void testConsumptionWithMixedChannels() throws Exception {
+		final int numChannels = 61;
+		final int numLocalChannels = 20;
+		final int buffersPerChannel = 1000;
+
+		// randomly decide which channels are local (true) and which are remote (false)
+		List<Boolean> localOrRemote = new ArrayList<>(numChannels);
+		for (int i = 0; i < numChannels; i++) {
+			localOrRemote.add(i < numLocalChannels);
+		}
+		Collections.shuffle(localOrRemote);
+
+		final ConnectionManager connManager = createDummyConnectionManager();
+		final ResultPartition resultPartition = mock(ResultPartition.class);
+
+		final PipelinedSubpartition[] localPartitions = new PipelinedSubpartition[numLocalChannels];
+		final ResultPartitionManager resultPartitionManager = createResultPartitionManager(localPartitions);
+
+		final Source[] sources = new Source[numChannels];
+
+		final SingleInputGate gate = createInputGate(numChannels);
+
+		for (int i = 0, local = 0; i < numChannels; i++) {
+			if (localOrRemote.get(i)) {
+				// local channel
+				PipelinedSubpartition psp = new PipelinedSubpartition(0, resultPartition);
+				localPartitions[local++] = psp;
+				sources[i] = new PipelinedSubpartitionSource(psp);
+
+				LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
+						resultPartitionManager, mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+				gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+			}
+			else {
+				// remote channel
+				RemoteInputChannel channel = new RemoteInputChannel(
+						gate, i, new ResultPartitionID(), mock(ConnectionID.class),
+						connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+				gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+
+				sources[i] = new RemoteChannelSource(channel);
+			}
+		}
+
+		runProducerConsumer(sources, gate, numChannels * buffersPerChannel);
+	}
+
+	// ------------------------------------------------------------------------
+	//  utilities
+	// ------------------------------------------------------------------------
+
+	/** Creates an input gate for the given number of channels with mocked task actions and dummy metrics. */
+	private static SingleInputGate createInputGate(int numChannels) {
+		return new SingleInputGate(
+				"Test Task Name",
+				new JobID(),
+				new ExecutionAttemptID(),
+				new IntermediateDataSetID(),
+				0,
+				numChannels,
+				mock(TaskActions.class),
+				new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+	}
+
+	/**
+	 * Concurrently runs a producer that adds {@code numBuffers} buffers to the given sources and
+	 * a consumer that reads {@code numBuffers} buffers from the given gate, then rethrows any
+	 * exception or failed assertion from either thread.
+	 */
+	private static void runProducerConsumer(Source[] sources, SingleInputGate gate, int numBuffers) throws Exception {
+		final ProducerThread producer = new ProducerThread(sources, numBuffers, 4, 10);
+		final ConsumerThread consumer = new ConsumerThread(gate, numBuffers);
+		producer.start();
+		consumer.start();
+
+		// the 'sync()' call checks for exceptions and failed assertions
+		try {
+			producer.sync();
+		}
+		catch (Throwable t) {
+			// without buffers the consumer would stay blocked forever, so release it first
+			consumer.interrupt();
+			throw t;
+		}
+		consumer.sync();
+	}
+
+	// ------------------------------------------------------------------------
+	//  test data sources
+	// ------------------------------------------------------------------------
+
+	/** Something the producer thread can add buffers to. */
+	private static abstract class Source {
+
+		abstract void addBuffer(Buffer buffer) throws Exception;
+	}
+
+	/** Adds buffers to a pipelined subpartition, which a local input channel reads. */
+	private static class PipelinedSubpartitionSource extends Source {
+
+		final PipelinedSubpartition partition;
+
+		PipelinedSubpartitionSource(PipelinedSubpartition partition) {
+			this.partition = partition;
+		}
+
+		@Override
+		void addBuffer(Buffer buffer) throws Exception {
+			partition.add(buffer);
+		}
+	}
+
+	/** Hands buffers directly to a remote input channel, with increasing sequence numbers. */
+	private static class RemoteChannelSource extends Source {
+
+		final RemoteInputChannel channel;
+		private int seq = 0;
+
+		RemoteChannelSource(RemoteInputChannel channel) {
+			this.channel = channel;
+		}
+
+		@Override
+		void addBuffer(Buffer buffer) throws Exception {
+			channel.onBuffer(buffer, seq++);
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	//  testing threads
+	// ------------------------------------------------------------------------
+
+	/** A thread whose {@link #sync()} joins it and rethrows whatever {@link #go()} threw. */
+	private static abstract class CheckedThread extends Thread {
+
+		private volatile Throwable error;
+
+		public abstract void go() throws Exception;
+
+		@Override
+		public void run() {
+			try {
+				go();
+			}
+			catch (Throwable t) {
+				error = t;
+			}
+		}
+
+		public void sync() throws Exception {
+			join();
+
+			// propagate the error
+			if (error != null) {
+				if (error instanceof Error) {
+					throw (Error) error;
+				}
+				else if (error instanceof Exception) {
+					throw (Exception) error;
+				}
+				else {
+					throw new Exception(error.getMessage(), error);
+				}
+			}
+		}
+	}
+
+	/**
+	 * Adds {@code numTotal} buffers in chunks of 1 to {@code maxChunk} buffers to randomly
+	 * chosen sources, calling {@link Thread#yield()} roughly every {@code yieldAfter} buffers.
+	 */
+	private static class ProducerThread extends CheckedThread {
+
+		private final Random rnd = new Random();
+		private final Source[] sources;
+		private final int numTotal;
+		private final int maxChunk;
+		private final int yieldAfter;
+
+		ProducerThread(Source[] sources, int numTotal, int maxChunk, int yieldAfter) {
+			this.sources = sources;
+			this.numTotal = numTotal;
+			this.maxChunk = maxChunk;
+			this.yieldAfter = yieldAfter;
+		}
+
+		@Override
+		public void go() throws Exception {
+			final Buffer buffer = InputChannelTestUtils.createMockBuffer(100);
+			int nextYield = numTotal - yieldAfter;
+
+			for (int i = numTotal; i > 0;) {
+				final int nextChannel = rnd.nextInt(sources.length);
+				final int chunk = Math.min(i, rnd.nextInt(maxChunk) + 1);
+
+				final Source next = sources[nextChannel];
+
+				for (int k = chunk; k > 0; --k) {
+					next.addBuffer(buffer);
+				}
+
+				i -= chunk;
+
+				if (i <= nextYield) {
+					nextYield -= yieldAfter;
+					//noinspection CallToThreadYield
+					Thread.yield();
+				}
+			}
+		}
+	}
+
+	/** Reads {@code numBuffers} buffers or events from the gate, asserting that none is null. */
+	private static class ConsumerThread extends CheckedThread {
+
+		private final SingleInputGate gate;
+		private final int numBuffers;
+
+		ConsumerThread(SingleInputGate gate, int numBuffers) {
+			this.gate = gate;
+			this.numBuffers = numBuffers;
+		}
+
+		@Override
+		public void go() throws Exception {
+			for (int i = numBuffers; i > 0; --i) {
+				assertNotNull(gate.getNextBufferOrEvent());
+			}
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
new file mode 100644
index 0000000..b35612a
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.network.ConnectionID;
+import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.TaskEventDispatcher;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
+import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
+import org.apache.flink.runtime.taskmanager.TaskActions;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createDummyConnectionManager;
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createMockBuffer;
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createResultPartitionManager;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests that a {@link SingleInputGate} consumes its input channels fairly: after every read, the
+ * numbers of buffers still queued in the individual channels differ by at most one.
+ */
+public class InputGateFairnessTest {
+
+	@Test
+	public void testFairConsumptionLocalChannelsPreFilled() throws Exception {
+		final int numChannels = 37;
+		final int buffersPerChannel = 27;
+
+		final ResultPartition resultPartition = mock(ResultPartition.class);
+		final Buffer mockBuffer = createMockBuffer(42);
+
+		// ----- create some source channels and fill them with buffers -----
+
+		final PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels];
+
+		for (int i = 0; i < numChannels; i++) {
+			PipelinedSubpartition partition = new PipelinedSubpartition(0, resultPartition);
+
+			for (int p = 0; p < buffersPerChannel; p++) {
+				partition.add(mockBuffer);
+			}
+
+			partition.finish();
+			sources[i] = partition;
+		}
+
+		// ----- create reading side -----
+
+		ResultPartitionManager resultPartitionManager = createResultPartitionManager(sources);
+
+		SingleInputGate gate = createFairnessVerifyingGate(numChannels);
+
+		for (int i = 0; i < numChannels; i++) {
+			LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
+					resultPartitionManager, mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+		}
+
+		// read all the buffers and the EOF event
+		for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) {
+			assertNotNull(gate.getNextBufferOrEvent());
+			assertFairDistribution(sources);
+		}
+
+		assertNull(gate.getNextBufferOrEvent());
+	}
+
+	@Test
+	public void testFairConsumptionLocalChannels() throws Exception {
+		final int numChannels = 37;
+		final int buffersPerChannel = 27;
+
+		final ResultPartition resultPartition = mock(ResultPartition.class);
+		final Buffer mockBuffer = createMockBuffer(42);
+
+		// ----- create some (initially empty) source channels -----
+
+		final PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels];
+
+		for (int i = 0; i < numChannels; i++) {
+			sources[i] = new PipelinedSubpartition(0, resultPartition);
+		}
+
+		// ----- create reading side -----
+
+		ResultPartitionManager resultPartitionManager = createResultPartitionManager(sources);
+
+		SingleInputGate gate = createFairnessVerifyingGate(numChannels);
+
+		for (int i = 0; i < numChannels; i++) {
+			LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
+					resultPartitionManager, mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+		}
+
+		// seed one initial buffer
+		sources[12].add(mockBuffer);
+
+		// read buffers while periodically refilling the channels (no EOF event is produced here)
+		for (int i = 0; i < numChannels * buffersPerChannel; i++) {
+			assertNotNull(gate.getNextBufferOrEvent());
+			assertFairDistribution(sources);
+
+			if (i % (2 * numChannels) == 0) {
+				// add three buffers to each channel, in random order
+				fillRandom(sources, 3, mockBuffer);
+			}
+		}
+
+		// there is still more in the queues
+	}
+
+	@Test
+	public void testFairConsumptionRemoteChannelsPreFilled() throws Exception {
+		final int numChannels = 37;
+		final int buffersPerChannel = 27;
+
+		final Buffer mockBuffer = createMockBuffer(42);
+
+		// ----- create some source channels and fill them with buffers -----
+
+		SingleInputGate gate = createFairnessVerifyingGate(numChannels);
+
+		final ConnectionManager connManager = createDummyConnectionManager();
+
+		final RemoteInputChannel[] channels = new RemoteInputChannel[numChannels];
+
+		for (int i = 0; i < numChannels; i++) {
+			RemoteInputChannel channel = new RemoteInputChannel(
+					gate, i, new ResultPartitionID(), mock(ConnectionID.class),
+					connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+
+			channels[i] = channel;
+
+			for (int p = 0; p < buffersPerChannel; p++) {
+				channel.onBuffer(mockBuffer, p);
+			}
+			channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel);
+
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+		}
+
+		// read all the buffers and the EOF event
+		for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) {
+			assertNotNull(gate.getNextBufferOrEvent());
+			assertFairDistribution(channels);
+		}
+
+		assertNull(gate.getNextBufferOrEvent());
+	}
+
+	@Test
+	public void testFairConsumptionRemoteChannels() throws Exception {
+		final int numChannels = 37;
+		final int buffersPerChannel = 27;
+
+		final Buffer mockBuffer = createMockBuffer(42);
+
+		// ----- create some (initially empty) source channels -----
+
+		SingleInputGate gate = createFairnessVerifyingGate(numChannels);
+
+		final ConnectionManager connManager = createDummyConnectionManager();
+
+		final RemoteInputChannel[] channels = new RemoteInputChannel[numChannels];
+		final int[] channelSequenceNums = new int[numChannels];
+
+		for (int i = 0; i < numChannels; i++) {
+			RemoteInputChannel channel = new RemoteInputChannel(
+					gate, i, new ResultPartitionID(), mock(ConnectionID.class),
+					connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+
+			channels[i] = channel;
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+		}
+
+		// seed one initial buffer
+		channels[11].onBuffer(mockBuffer, 0);
+		channelSequenceNums[11]++;
+
+		// read buffers while periodically refilling the channels (no EOF event is produced here)
+		for (int i = 0; i < numChannels * buffersPerChannel; i++) {
+			assertNotNull(gate.getNextBufferOrEvent());
+			assertFairDistribution(channels);
+
+			if (i % (2 * numChannels) == 0) {
+				// add three buffers to each channel, in random order
+				fillRandom(channels, channelSequenceNums, 3, mockBuffer);
+			}
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	// Utilities
+	// ------------------------------------------------------------------------
+
+	/** Creates a {@link FairnessVerifyingInputGate} for the given number of channels. */
+	private static SingleInputGate createFairnessVerifyingGate(int numChannels) {
+		return new FairnessVerifyingInputGate(
+				"Test Task Name",
+				new JobID(),
+				new ExecutionAttemptID(),
+				new IntermediateDataSetID(),
+				0, numChannels,
+				mock(TaskActions.class),
+				new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+	}
+
+	/** Asserts that the numbers of buffers queued in the given subpartitions differ by at most one. */
+	private static void assertFairDistribution(PipelinedSubpartition[] partitions) {
+		final int[] queueSizes = new int[partitions.length];
+		for (int i = 0; i < partitions.length; i++) {
+			queueSizes[i] = partitions[i].getCurrentNumberOfBuffers();
+		}
+		assertFairDistribution(queueSizes);
+	}
+
+	/** Asserts that the numbers of buffers queued in the given channels differ by at most one. */
+	private static void assertFairDistribution(RemoteInputChannel[] channels) {
+		final int[] queueSizes = new int[channels.length];
+		for (int i = 0; i < channels.length; i++) {
+			queueSizes[i] = channels[i].getNumberOfQueuedBuffers();
+		}
+		assertFairDistribution(queueSizes);
+	}
+
+	/** Asserts that the largest and the smallest of the given queue sizes differ by at most one. */
+	private static void assertFairDistribution(int[] queueSizes) {
+		int min = Integer.MAX_VALUE;
+		int max = 0;
+
+		for (int size : queueSizes) {
+			min = Math.min(min, size);
+			max = Math.max(max, size);
+		}
+
+		assertTrue("unfair consumption: min queue size " + min + ", max queue size " + max,
+				max == min || max == min + 1);
+	}
+
+	/** Adds {@code numPerPartition} copies of the buffer to every partition, in random order. */
+	private static void fillRandom(PipelinedSubpartition[] partitions, int numPerPartition, Buffer buffer) throws Exception {
+		for (int i : shuffledIndices(partitions.length, numPerPartition)) {
+			partitions[i].add(buffer);
+		}
+	}
+
+	/**
+	 * Hands {@code numPerPartition} copies of the buffer to every channel, in random order,
+	 * using and advancing each channel's entry in {@code sequenceNumbers}.
+	 */
+	private static void fillRandom(
+			RemoteInputChannel[] partitions,
+			int[] sequenceNumbers,
+			int numPerPartition,
+			Buffer buffer) throws Exception {
+
+		for (int i : shuffledIndices(partitions.length, numPerPartition)) {
+			partitions[i].onBuffer(buffer, sequenceNumbers[i]++);
+		}
+	}
+
+	/** Returns every index in {@code [0, numIndices)} exactly {@code numCopies} times, shuffled. */
+	private static ArrayList<Integer> shuffledIndices(int numIndices, int numCopies) {
+		final ArrayList<Integer> indices = new ArrayList<>(numIndices * numCopies);
+
+		for (int i = 0; i < numIndices; i++) {
+			for (int k = 0; k < numCopies; k++) {
+				indices.add(i);
+			}
+		}
+
+		Collections.shuffle(indices);
+		return indices;
+	}
+
+	// ------------------------------------------------------------------------
+
+	/**
+	 * An input gate that, before each {@link #getNextBufferOrEvent()} call, checks the parent's
+	 * {@code inputChannelsWithData} queue (accessed via reflection): it must not hold more entries
+	 * than there are input channels, and must not contain any channel twice.
+	 */
+	private static class FairnessVerifyingInputGate extends SingleInputGate {
+
+		/** The parent's {@code inputChannelsWithData} queue, obtained via reflection. */
+		private final ArrayDeque<InputChannel> channelsWithData;
+
+		/** Scratch set reused by the duplicate check. */
+		private final HashSet<InputChannel> uniquenessChecker;
+
+		@SuppressWarnings("unchecked")
+		public FairnessVerifyingInputGate(
+				String owningTaskName,
+				JobID jobId,
+				ExecutionAttemptID executionId,
+				IntermediateDataSetID consumedResultId,
+				int consumedSubpartitionIndex,
+				int numberOfInputChannels,
+				TaskActions taskActions,
+				TaskIOMetricGroup metrics) {
+
+			super(owningTaskName, jobId, executionId, consumedResultId, consumedSubpartitionIndex,
+					numberOfInputChannels, taskActions, metrics);
+
+			try {
+				Field f = SingleInputGate.class.getDeclaredField("inputChannelsWithData");
+				f.setAccessible(true);
+				channelsWithData = (ArrayDeque<InputChannel>) f.get(this);
+			}
+			catch (Exception e) {
+				throw new RuntimeException(e);
+			}
+
+			this.uniquenessChecker = new HashSet<>();
+		}
+
+		@Override
+		public BufferOrEvent getNextBufferOrEvent() throws IOException, InterruptedException {
+			synchronized (channelsWithData) {
+				assertTrue("too many input channels", channelsWithData.size() <= getNumberOfInputChannels());
+				ensureUnique(channelsWithData);
+			}
+
+			return super.getNextBufferOrEvent();
+		}
+
+		/** Fails if the given collection contains the same channel more than once. */
+		private void ensureUnique(Collection<InputChannel> channels) {
+			final HashSet<InputChannel> uniquenessChecker = this.uniquenessChecker;
+
+			// clear first, so that entries left over by an earlier failed check cannot cause false alarms
+			uniquenessChecker.clear();
+
+			for (InputChannel channel : channels) {
+				if (!uniquenessChecker.add(channel)) {
+					fail("Duplicate channel in input gate: " + channel);
+				}
+			}
+
+			uniquenessChecker.clear();
+		}
+	}
+}