You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text version.
Posted to commits@flink.apache.org by pn...@apache.org on 2020/06/18 09:25:34 UTC

[flink] 01/05: [FLINK-18094][network] Fixed UnionInputGate#getChannel.

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 75c0b5b1f4e0545ee4c55349dd6633bbd13cf128
Author: Arvid Heise <ar...@ververica.com>
AuthorDate: Tue Jun 16 09:21:41 2020 +0200

    [FLINK-18094][network] Fixed UnionInputGate#getChannel.
    
    The method assumed that the unioned gates have consecutive gate indexes starting at 0; gates are now looked up by their gate index instead of by array position.
---
 .../io/network/partition/consumer/UnionInputGate.java | 19 +++++++++++--------
 .../network/partition/consumer/InputGateTestBase.java |  2 +-
 .../partition/consumer/UnionInputGateTest.java        | 18 ++++++++++++++++++
 3 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index e863e10..ad8361c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -29,6 +29,7 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.LinkedHashSet;
+import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
@@ -70,7 +71,7 @@ import static org.apache.flink.util.Preconditions.checkState;
 public class UnionInputGate extends InputGate {
 
 	/** The input gates to union. */
-	private final InputGate[] inputGates;
+	private final Map<Integer, InputGate> inputGatesByGateIndex;
 
 	private final Set<IndexedInputGate> inputGatesWithRemainingData;
 
@@ -89,7 +90,7 @@ public class UnionInputGate extends InputGate {
 	private final int[] inputGateChannelIndexOffsets;
 
 	public UnionInputGate(IndexedInputGate... inputGates) {
-		this.inputGates = checkNotNull(inputGates);
+		inputGatesByGateIndex = Arrays.stream(inputGates).collect(Collectors.toMap(IndexedInputGate::getGateIndex, ig -> ig));
 		checkArgument(inputGates.length > 1, "Union input gate should union at least two input gates.");
 
 		if (Arrays.stream(inputGates).map(IndexedInputGate::getGateIndex).distinct().count() != inputGates.length) {
@@ -100,8 +101,9 @@ public class UnionInputGate extends InputGate {
 		this.inputGatesWithRemainingData = Sets.newHashSetWithExpectedSize(inputGates.length);
 
 		final int maxGateIndex = Arrays.stream(inputGates).mapToInt(IndexedInputGate::getGateIndex).max().orElse(0);
-		inputGateChannelIndexOffsets = new int[maxGateIndex + 1];
 		int totalNumberOfInputChannels = Arrays.stream(inputGates).mapToInt(IndexedInputGate::getNumberOfInputChannels).sum();
+
+		inputGateChannelIndexOffsets = new int[maxGateIndex + 1];
 		inputChannelToInputGateIndex = new int[totalNumberOfInputChannels];
 
 		int currentNumberOfInputChannels = 0;
@@ -141,8 +143,9 @@ public class UnionInputGate extends InputGate {
 
 	@Override
 	public InputChannel getChannel(int channelIndex) {
-		int gateIndex = this.inputChannelToInputGateIndex[channelIndex];
-		return inputGates[gateIndex].getChannel(channelIndex - inputGateChannelIndexOffsets[gateIndex]);
+		int gateIndex = inputChannelToInputGateIndex[channelIndex];
+		return inputGatesByGateIndex.get(gateIndex)
+			.getChannel(channelIndex - inputGateChannelIndexOffsets[gateIndex]);
 	}
 
 	@Override
@@ -253,7 +256,7 @@ public class UnionInputGate extends InputGate {
 
 	@Override
 	public void sendTaskEvent(TaskEvent event) throws IOException {
-		for (InputGate inputGate : inputGates) {
+		for (InputGate inputGate : inputGatesByGateIndex.values()) {
 			inputGate.sendTaskEvent(event);
 		}
 	}
@@ -277,7 +280,7 @@ public class UnionInputGate extends InputGate {
 
 	@Override
 	public void requestPartitions() throws IOException {
-		for (InputGate inputGate : inputGates) {
+		for (InputGate inputGate : inputGatesByGateIndex.values()) {
 			inputGate.requestPartitions();
 		}
 	}
@@ -332,7 +335,7 @@ public class UnionInputGate extends InputGate {
 
 	@Override
 	public void registerBufferReceivedListener(BufferReceivedListener listener) {
-		for (InputGate inputGate : inputGates) {
+		for (InputGate inputGate : inputGatesByGateIndex.values()) {
 			inputGate.registerBufferReceivedListener(listener);
 		}
 	}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
index 14f42f4..ae52a3c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
@@ -35,7 +35,7 @@ import static org.junit.Assert.assertTrue;
  */
 public abstract class InputGateTestBase {
 
-	private int gateIndex;
+	int gateIndex;
 
 	@Before
 	public void resetGateIndex() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
index 8b4a7df..93b131e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
@@ -164,4 +164,22 @@ public class UnionInputGateTest extends InputGateTestBase {
 		// Check that updated input channel is visible via UnionInputGate
 		assertThat(unionInputGate.getChannel(1), Matchers.is(inputGate2.getChannel(0)));
 	}
+
+	@Test
+	public void testGetChannelWithShiftedGateIndexes() {
+		gateIndex = 2;
+		final SingleInputGate inputGate1 = createInputGate(1);
+		TestInputChannel inputChannel1 = new TestInputChannel(inputGate1, 0);
+		inputGate1.setInputChannels(inputChannel1);
+
+		final SingleInputGate inputGate2 = createInputGate(1);
+		TestInputChannel inputChannel2 = new TestInputChannel(inputGate2, 0);
+		inputGate2.setInputChannels(inputChannel2);
+
+		UnionInputGate unionInputGate = new UnionInputGate(inputGate1, inputGate2);
+
+		assertThat(unionInputGate.getChannel(0), Matchers.is(inputChannel1));
+		// Check that the second gate's channel is reachable at union channel index 1 despite the shifted gate indexes
+		assertThat(unionInputGate.getChannel(1), Matchers.is(inputChannel2));
+	}
 }