You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2020/06/18 09:25:38 UTC

[flink] 05/05: [FLINK-18094][network] Buffers are only addressed through InputChannelInfo.

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 1d211b93c4d7a180f3bb5ab131929201af29b8dc
Author: Arvid Heise <ar...@ververica.com>
AuthorDate: Mon Jun 15 21:43:32 2020 +0200

    [FLINK-18094][network] Buffers are only addressed through InputChannelInfo.
    
    This removes the need to translate the InputChannelInfo back and forth to flattened indexes across all InputGates.
    All index-based data structures are replaced by maps that associate a certain state with a given InputChannelInfo. For performance reasons, these maps are fully initialized upon construction, so that no nodes need to be added or removed at runtime and only values are updated.
    Additionally, this commit unifies the creation of BarrierHandlers (similar signature) and removes the error-prone offset handling from CheckpointedInputGate.
---
 .../network/api/reader/AbstractRecordReader.java   | 22 +++---
 .../network/partition/consumer/BufferOrEvent.java  | 33 ++++----
 .../partition/consumer/SingleInputGate.java        |  4 +-
 .../network/partition/consumer/UnionInputGate.java | 22 ++----
 .../io/network/api/writer/RecordWriterTest.java    |  5 +-
 .../partition/consumer/LocalInputChannelTest.java  |  4 +-
 .../partition/consumer/SingleInputGateBuilder.java | 24 +++++-
 .../partition/consumer/SingleInputGateTest.java    | 19 +++--
 .../io/AlternatingCheckpointBarrierHandler.java    |  8 +-
 .../runtime/io/CheckpointBarrierAligner.java       | 81 ++++++++++---------
 .../runtime/io/CheckpointBarrierHandler.java       |  6 +-
 .../runtime/io/CheckpointBarrierTracker.java       |  5 +-
 .../runtime/io/CheckpointBarrierUnaligner.java     | 80 +++++++------------
 .../runtime/io/CheckpointedInputGate.java          | 43 +---------
 .../streaming/runtime/io/InputProcessorUtil.java   | 91 +++++-----------------
 .../runtime/io/StreamTaskNetworkInput.java         | 25 +++++-
 .../AlternatingCheckpointBarrierHandlerTest.java   | 45 ++++++-----
 .../CheckpointBarrierAlignerMassiveRandomTest.java | 18 ++++-
 .../io/CheckpointBarrierAlignerTestBase.java       | 12 +--
 .../runtime/io/CheckpointBarrierTrackerTest.java   |  7 +-
 ...CheckpointBarrierUnalignerCancellationTest.java |  5 +-
 .../runtime/io/CheckpointBarrierUnalignerTest.java | 50 ++++++------
 .../CreditBasedCheckpointBarrierAlignerTest.java   |  2 +-
 .../runtime/io/InputProcessorUtilTest.java         | 31 --------
 .../flink/streaming/runtime/io/MockInputGate.java  |  2 +-
 .../runtime/io/StreamTaskNetworkInputTest.java     | 27 ++++---
 26 files changed, 287 insertions(+), 384 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java
index 1c98d0c..5632370 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network.api.reader;
 
 import org.apache.flink.core.io.IOReadableWritable;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
 import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
 import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
@@ -27,6 +28,9 @@ import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 
 import java.io.IOException;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 /**
  * A record-oriented reader.
@@ -37,7 +41,7 @@ import java.io.IOException;
  */
 abstract class AbstractRecordReader<T extends IOReadableWritable> extends AbstractReader implements ReaderBase {
 
-	private final RecordDeserializer<T>[] recordDeserializers;
+	private final Map<InputChannelInfo, RecordDeserializer<T>> recordDeserializers;
 
 	private RecordDeserializer<T> currentRecordDeserializer;
 
@@ -58,10 +62,10 @@ abstract class AbstractRecordReader<T extends IOReadableWritable> extends Abstra
 		super(inputGate);
 
 		// Initialize one deserializer per input channel
-		this.recordDeserializers = new SpillingAdaptiveSpanningRecordDeserializer[inputGate.getNumberOfInputChannels()];
-		for (int i = 0; i < recordDeserializers.length; i++) {
-			recordDeserializers[i] = new SpillingAdaptiveSpanningRecordDeserializer<T>(tmpDirectories);
-		}
+		recordDeserializers = inputGate.getChannelInfos().stream()
+			.collect(Collectors.toMap(
+				Function.identity(),
+				channelInfo -> new SpillingAdaptiveSpanningRecordDeserializer<>(tmpDirectories)));
 	}
 
 	protected boolean getNextRecord(T target) throws IOException, InterruptedException {
@@ -96,15 +100,15 @@ abstract class AbstractRecordReader<T extends IOReadableWritable> extends Abstra
 			final BufferOrEvent bufferOrEvent = inputGate.getNext().orElseThrow(IllegalStateException::new);
 
 			if (bufferOrEvent.isBuffer()) {
-				currentRecordDeserializer = recordDeserializers[bufferOrEvent.getChannelIndex()];
+				currentRecordDeserializer = recordDeserializers.get(bufferOrEvent.getChannelInfo());
 				currentRecordDeserializer.setNextBuffer(bufferOrEvent.getBuffer());
 			}
 			else {
 				// sanity check for leftover data in deserializers. events should only come between
 				// records, not in the middle of a fragment
-				if (recordDeserializers[bufferOrEvent.getChannelIndex()].hasUnfinishedData()) {
+				if (recordDeserializers.get(bufferOrEvent.getChannelInfo()).hasUnfinishedData()) {
 					throw new IOException(
-							"Received an event in channel " + bufferOrEvent.getChannelIndex() + " while still having "
+							"Received an event in channel " + bufferOrEvent.getChannelInfo() + " while still having "
 							+ "data from a record. This indicates broken serialization logic. "
 							+ "If you are using custom serialization code (Writable or Value types), check their "
 							+ "serialization routines. In the case of Kryo, check the respective Kryo serializer.");
@@ -125,7 +129,7 @@ abstract class AbstractRecordReader<T extends IOReadableWritable> extends Abstra
 	}
 
 	public void clearBuffers() {
-		for (RecordDeserializer<?> deserializer : recordDeserializers) {
+		for (RecordDeserializer<?> deserializer : recordDeserializers.values()) {
 			Buffer buffer = deserializer.getCurrentBuffer();
 			if (buffer != null && !buffer.isRecycled()) {
 				buffer.recycleBuffer();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java
index 1ec864d..498e338 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java
@@ -19,10 +19,10 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 
-import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -42,34 +42,34 @@ public class BufferOrEvent {
 	 */
 	private boolean moreAvailable;
 
-	private int channelIndex;
+	private InputChannelInfo channelInfo;
 
 	private final int size;
 
-	public BufferOrEvent(Buffer buffer, int channelIndex, boolean moreAvailable) {
+	public BufferOrEvent(Buffer buffer, InputChannelInfo channelInfo, boolean moreAvailable) {
 		this.buffer = checkNotNull(buffer);
 		this.event = null;
-		this.channelIndex = channelIndex;
+		this.channelInfo = channelInfo;
 		this.moreAvailable = moreAvailable;
 		this.size = buffer.getSize();
 	}
 
-	public BufferOrEvent(AbstractEvent event, int channelIndex, boolean moreAvailable, int size) {
+	public BufferOrEvent(AbstractEvent event, InputChannelInfo channelInfo, boolean moreAvailable, int size) {
 		this.buffer = null;
 		this.event = checkNotNull(event);
-		this.channelIndex = channelIndex;
+		this.channelInfo = channelInfo;
 		this.moreAvailable = moreAvailable;
 		this.size = size;
 	}
 
 	@VisibleForTesting
-	public BufferOrEvent(Buffer buffer, int channelIndex) {
-		this(buffer, channelIndex, true);
+	public BufferOrEvent(Buffer buffer, InputChannelInfo channelInfo) {
+		this(buffer, channelInfo, true);
 	}
 
 	@VisibleForTesting
-	public BufferOrEvent(AbstractEvent event, int channelIndex) {
-		this(event, channelIndex, true, 0);
+	public BufferOrEvent(AbstractEvent event, InputChannelInfo channelInfo) {
+		this(event, channelInfo, true, 0);
 	}
 
 	public boolean isBuffer() {
@@ -88,13 +88,12 @@ public class BufferOrEvent {
 		return event;
 	}
 
-	public int getChannelIndex() {
-		return channelIndex;
+	public InputChannelInfo getChannelInfo() {
+		return channelInfo;
 	}
 
-	public void setChannelIndex(int channelIndex) {
-		checkArgument(channelIndex >= 0);
-		this.channelIndex = channelIndex;
+	public void setChannelInfo(InputChannelInfo channelInfo) {
+		this.channelInfo = channelInfo;
 	}
 
 	public boolean moreAvailable() {
@@ -103,8 +102,8 @@ public class BufferOrEvent {
 
 	@Override
 	public String toString() {
-		return String.format("BufferOrEvent [%s, channelIndex = %d, size = %d]",
-				isBuffer() ? buffer : event, channelIndex, size);
+		return String.format("BufferOrEvent [%s, channelInfo = %s, size = %d]",
+				isBuffer() ? buffer : event, channelInfo, size);
 	}
 
 	public void setMoreAvailable(boolean moreAvailable) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index 0f227f9..0bd06c0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -669,7 +669,7 @@ public class SingleInputGate extends IndexedInputGate {
 	}
 
 	private BufferOrEvent transformBuffer(Buffer buffer, boolean moreAvailable, InputChannel currentChannel) {
-		return new BufferOrEvent(decompressBufferIfNeeded(buffer), currentChannel.getChannelIndex(), moreAvailable);
+		return new BufferOrEvent(decompressBufferIfNeeded(buffer), currentChannel.getChannelInfo(), moreAvailable);
 	}
 
 	private BufferOrEvent transformEvent(
@@ -700,7 +700,7 @@ public class SingleInputGate extends IndexedInputGate {
 			currentChannel.releaseAllResources();
 		}
 
-		return new BufferOrEvent(event, currentChannel.getChannelIndex(), moreAvailable, buffer.getSize());
+		return new BufferOrEvent(event, currentChannel.getChannelInfo(), moreAvailable, buffer.getSize());
 	}
 
 	private Buffer decompressBufferIfNeeded(Buffer buffer) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index ad8361c..c05eef7 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -176,10 +176,11 @@ public class UnionInputGate extends InputGate {
 		InputWithData<IndexedInputGate, BufferOrEvent> inputWithData = next.get();
 
 		handleEndOfPartitionEvent(inputWithData.data, inputWithData.input);
-		return Optional.of(adjustForUnionInputGate(
-			inputWithData.data,
-			inputWithData.input,
-			inputWithData.moreAvailable));
+		if (!inputWithData.data.moreAvailable()) {
+			inputWithData.data.setMoreAvailable(inputWithData.moreAvailable);
+		}
+
+		return Optional.of(inputWithData.data);
 	}
 
 	private Optional<InputWithData<IndexedInputGate, BufferOrEvent>> waitAndGetNextData(boolean blocking)
@@ -217,19 +218,6 @@ public class UnionInputGate extends InputGate {
 		}
 	}
 
-	private BufferOrEvent adjustForUnionInputGate(
-			BufferOrEvent bufferOrEvent,
-			IndexedInputGate inputGate,
-			boolean moreInputGatesAvailable) {
-		// Set the channel index to identify the input channel (across all unioned input gates)
-		final int channelIndexOffset = inputGateChannelIndexOffsets[inputGate.getGateIndex()];
-
-		bufferOrEvent.setChannelIndex(channelIndexOffset + bufferOrEvent.getChannelIndex());
-		bufferOrEvent.setMoreAvailable(bufferOrEvent.moreAvailable() || moreInputGatesAvailable);
-
-		return bufferOrEvent;
-	}
-
 	private void handleEndOfPartitionEvent(BufferOrEvent bufferOrEvent, InputGate inputGate) {
 		if (bufferOrEvent.isEvent()
 			&& bufferOrEvent.getEvent().getClass() == EndOfPartitionEvent.class
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
index 6e75dff..caa934a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
@@ -657,12 +658,12 @@ public class RecordWriterTest {
 	static BufferOrEvent parseBuffer(BufferConsumer bufferConsumer, int targetChannel) throws IOException {
 		Buffer buffer = buildSingleBuffer(bufferConsumer);
 		if (buffer.isBuffer()) {
-			return new BufferOrEvent(buffer, targetChannel);
+			return new BufferOrEvent(buffer, new InputChannelInfo(0, targetChannel));
 		} else {
 			// is event:
 			AbstractEvent event = EventSerializer.fromBuffer(buffer, RecordWriterTest.class.getClassLoader());
 			buffer.recycleBuffer(); // the buffer is not needed anymore
-			return new BufferOrEvent(event, targetChannel);
+			return new BufferOrEvent(event, new InputChannelInfo(0, targetChannel));
 		}
 	}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
index e9e16db..b6c0f48 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
@@ -629,11 +629,11 @@ public class LocalInputChannelTest {
 						boe.get().getBuffer().recycleBuffer();
 
 						// Check that we don't receive too many buffers
-						if (++numberOfBuffersPerChannel[boe.get().getChannelIndex()]
+						if (++numberOfBuffersPerChannel[boe.get().getChannelInfo().getInputChannelIdx()]
 								> numberOfExpectedBuffersPerChannel) {
 
 							throw new IllegalStateException("Received more buffers than expected " +
-									"on channel " + boe.get().getChannelIndex() + ".");
+									"on channel " + boe.get().getChannelInfo() + ".");
 						}
 					}
 				}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java
index b279998..8ae12cfc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java
@@ -29,7 +29,11 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.taskmanager.NettyShuffleEnvironmentConfiguration;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
+import java.util.function.BiFunction;
+import java.util.stream.IntStream;
 
 /**
  * Utility class to encapsulate the logic of building a {@link SingleInputGate} instance.
@@ -54,6 +58,9 @@ public class SingleInputGateBuilder {
 
 	private MemorySegmentProvider segmentProvider = InputChannelTestUtils.StubMemorySegmentProvider.getInstance();
 
+	@Nullable
+	private BiFunction<InputChannelBuilder, SingleInputGate, InputChannel> channelFactory = null;
+
 	private SupplierWithException<BufferPool, IOException> bufferPoolFactory = () -> {
 		throw new UnsupportedOperationException();
 	};
@@ -112,8 +119,17 @@ public class SingleInputGateBuilder {
 		return this;
 	}
 
+	/**
+	 * Adds automatic initialization of all channels with the given factory.
+	 */
+	public SingleInputGateBuilder setChannelFactory(
+			BiFunction<InputChannelBuilder, SingleInputGate, InputChannel> channelFactory) {
+		this.channelFactory = channelFactory;
+		return this;
+	}
+
 	public SingleInputGate build() {
-		return new SingleInputGate(
+		SingleInputGate gate = new SingleInputGate(
 			"Single Input Gate",
 			gateIndex,
 			intermediateDataSetID,
@@ -124,5 +140,11 @@ public class SingleInputGateBuilder {
 			bufferPoolFactory,
 			bufferDecompressor,
 			segmentProvider);
+		if (channelFactory != null) {
+			gate.setInputChannels(IntStream.range(0, numberOfChannels)
+				.mapToObj(index -> channelFactory.apply(InputChannelBuilder.newBuilder().setChannelIndex(index), gate))
+				.toArray(InputChannel[]::new));
+		}
+		return gate;
 	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 0bd9a41..ea3f668 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -821,14 +821,13 @@ public class SingleInputGateTest extends InputGateTestBase {
 		// Setup
 		final SingleInputGate inputGate = createInputGate(network, 2, ResultPartitionType.PIPELINED);
 
-		final int channelIndex1 = 0, channelIndex2 = 1;
 		final RemoteInputChannel remoteInputChannel1 = InputChannelBuilder.newBuilder()
-			.setChannelIndex(channelIndex1)
+			.setChannelIndex(0)
 			.setupFromNettyShuffleEnvironment(network)
 			.setConnectionManager(new TestingConnectionManager())
 			.buildRemoteChannel(inputGate);
 		final RemoteInputChannel remoteInputChannel2 = InputChannelBuilder.newBuilder()
-			.setChannelIndex(channelIndex2)
+			.setChannelIndex(1)
 			.setupFromNettyShuffleEnvironment(network)
 			.setConnectionManager(new TestingConnectionManager())
 			.buildRemoteChannel(inputGate);
@@ -838,12 +837,12 @@ public class SingleInputGateTest extends InputGateTestBase {
 		inputGate.registerBufferReceivedListener(new BufferReceivedListener() {
 			@Override
 			public void notifyBufferReceived(Buffer buffer, InputChannelInfo channelInfo) {
-				notifications.add(new BufferOrEvent(buffer, channelInfo.getInputChannelIdx()));
+				notifications.add(new BufferOrEvent(buffer, channelInfo));
 			}
 
 			@Override
 			public void notifyBarrierReceived(CheckpointBarrier barrier, InputChannelInfo channelInfo) {
-				notifications.add(new BufferOrEvent(barrier, channelInfo.getInputChannelIdx()));
+				notifications.add(new BufferOrEvent(barrier, channelInfo));
 			}
 		});
 		setupInputGate(inputGate, remoteInputChannel1, remoteInputChannel2);
@@ -873,10 +872,10 @@ public class SingleInputGateTest extends InputGateTestBase {
 		}
 
 		assertEquals(getIds(asList(
-			new BufferOrEvent(new CheckpointBarrier(0, 0, options), channelIndex2),
-			new BufferOrEvent(createBuffer(11), channelIndex1),
-			new BufferOrEvent(new CheckpointBarrier(1, 0, options), channelIndex1),
-			new BufferOrEvent(createBuffer(22), channelIndex2)
+			new BufferOrEvent(new CheckpointBarrier(0, 0, options), remoteInputChannel2.getChannelInfo()),
+			new BufferOrEvent(createBuffer(11), remoteInputChannel1.getChannelInfo()),
+			new BufferOrEvent(new CheckpointBarrier(1, 0, options), remoteInputChannel1.getChannelInfo()),
+			new BufferOrEvent(createBuffer(22), remoteInputChannel2.getChannelInfo())
 		)), getIds(notifications));
 	}
 
@@ -1071,7 +1070,7 @@ public class SingleInputGateTest extends InputGateTestBase {
 		final Optional<BufferOrEvent> bufferOrEvent = inputGate.getNext();
 		assertTrue(bufferOrEvent.isPresent());
 		assertEquals(expectedIsBuffer, bufferOrEvent.get().isBuffer());
-		assertEquals(expectedChannelIndex, bufferOrEvent.get().getChannelIndex());
+		assertEquals(inputGate.getChannel(expectedChannelIndex).getChannelInfo(), bufferOrEvent.get().getChannelInfo());
 		assertEquals(expectedMoreAvailable, bufferOrEvent.get().moreAvailable());
 		if (!expectedMoreAvailable) {
 			assertFalse(inputGate.pollNext().isPresent());
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandler.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandler.java
index 3b27d95..2fb6e72 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandler.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandler.java
@@ -51,12 +51,12 @@ class AlternatingCheckpointBarrierHandler extends CheckpointBarrierHandler {
 	}
 
 	@Override
-	public boolean isBlocked(int channelIndex) {
-		return activeHandler.isBlocked(channelIndex);
+	public boolean isBlocked(InputChannelInfo channelInfo) {
+		return activeHandler.isBlocked(channelInfo);
 	}
 
 	@Override
-	public void processBarrier(CheckpointBarrier receivedBarrier, int channelIndex) throws Exception {
+	public void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception {
 		if (receivedBarrier.getId() < lastSeenBarrierId) {
 			return;
 		}
@@ -70,7 +70,7 @@ class AlternatingCheckpointBarrierHandler extends CheckpointBarrierHandler {
 				new CheckpointException(format("checkpoint subsumed by %d", lastSeenBarrierId), CHECKPOINT_DECLINED_SUBSUMED));
 		}
 
-		activeHandler.processBarrier(receivedBarrier, channelIndex);
+		activeHandler.processBarrier(receivedBarrier, channelInfo);
 	}
 
 	@Override
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAligner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAligner.java
index 01892bc..a052a70 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAligner.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAligner.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
@@ -31,9 +32,11 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
-import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
@@ -46,14 +49,8 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler {
 
 	private static final Logger LOG = LoggerFactory.getLogger(CheckpointBarrierAligner.class);
 
-	/** Used to get InputGate by channel index. */
-	private final InputGate[] channelIndexToInputGate;
-
-	/** Used to get channel index offset by InputGate. */
-	private final Map<InputGate, Integer> inputGateToChannelIndexOffset;
-
 	/** Flags that indicate whether a channel is currently blocked/buffered. */
-	private final boolean[] blockedChannels;
+	private final Map<InputChannelInfo, Boolean> blockedChannels;
 
 	/** The total number of channels that this buffer handles data from. */
 	private final int totalNumberOfInputChannels;
@@ -78,18 +75,20 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler {
 	/** The time (in nanoseconds) that the latest alignment took. */
 	private long latestAlignmentDurationNanos;
 
+	private final InputGate[] inputGates;
+
 	CheckpointBarrierAligner(
 			String taskName,
-			InputGate[] channelIndexToInputGate,
-			Map<InputGate, Integer> inputGateToChannelIndexOffset,
-			AbstractInvokable toNotifyOnCheckpoint) {
+			AbstractInvokable toNotifyOnCheckpoint,
+			InputGate... inputGates) {
 		super(toNotifyOnCheckpoint);
-		this.taskName = taskName;
-		this.channelIndexToInputGate = checkNotNull(channelIndexToInputGate);
-		this.inputGateToChannelIndexOffset = checkNotNull(inputGateToChannelIndexOffset);
-		this.totalNumberOfInputChannels = channelIndexToInputGate.length;
 
-		this.blockedChannels = new boolean[totalNumberOfInputChannels];
+		this.taskName = taskName;
+		this.inputGates = inputGates;
+		blockedChannels = Arrays.stream(inputGates)
+			.flatMap(gate -> gate.getChannelInfos().stream())
+			.collect(Collectors.toMap(Function.identity(), info -> false));
+		totalNumberOfInputChannels = blockedChannels.size();
 	}
 
 	@Override
@@ -104,12 +103,12 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler {
 	public void releaseBlocksAndResetBarriers() {
 		LOG.debug("{}: End of stream alignment, feeding buffered data back.", taskName);
 
-		for (int i = 0; i < blockedChannels.length; i++) {
-			if (blockedChannels[i]) {
-				resumeConsumption(i);
+		blockedChannels.entrySet().forEach(blockedChannel -> {
+			if (blockedChannel.getValue()) {
+				resumeConsumption(blockedChannel.getKey());
 			}
-			blockedChannels[i] = false;
-		}
+			blockedChannel.setValue(false);
+		});
 
 		// the next barrier that comes must assume it is the first
 		numBarriersReceived = 0;
@@ -121,17 +120,17 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler {
 	}
 
 	@Override
-	public boolean isBlocked(int channelIndex) {
-		return blockedChannels[channelIndex];
+	public boolean isBlocked(InputChannelInfo channelInfo) {
+		return blockedChannels.get(channelInfo);
 	}
 
 	@Override
-	public void processBarrier(CheckpointBarrier receivedBarrier, int channelIndex) throws Exception {
+	public void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception {
 		final long barrierId = receivedBarrier.getId();
 
 		// fast path for single channel cases
 		if (totalNumberOfInputChannels == 1) {
-			resumeConsumption(channelIndex);
+			resumeConsumption(channelInfo);
 			if (barrierId > currentCheckpointId) {
 				// new checkpoint
 				currentCheckpointId = barrierId;
@@ -147,7 +146,7 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler {
 
 			if (barrierId == currentCheckpointId) {
 				// regular case
-				onBarrier(channelIndex);
+				onBarrier(channelInfo);
 			}
 			else if (barrierId > currentCheckpointId) {
 				// we did not complete the current checkpoint, another started before
@@ -167,21 +166,21 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler {
 				releaseBlocksAndResetBarriers();
 
 				// begin a new checkpoint
-				beginNewAlignment(barrierId, channelIndex, receivedBarrier.getTimestamp());
+				beginNewAlignment(barrierId, channelInfo, receivedBarrier.getTimestamp());
 			}
 			else {
 				// ignore trailing barrier from an earlier checkpoint (obsolete now)
-				resumeConsumption(channelIndex);
+				resumeConsumption(channelInfo);
 			}
 		}
 		else if (barrierId > currentCheckpointId) {
 			// first barrier of a new checkpoint
-			beginNewAlignment(barrierId, channelIndex, receivedBarrier.getTimestamp());
+			beginNewAlignment(barrierId, channelInfo, receivedBarrier.getTimestamp());
 		}
 		else {
 			// either the current checkpoint was canceled (numBarriers == 0) or
 			// this barrier is from an old subsumed checkpoint
-			resumeConsumption(channelIndex);
+			resumeConsumption(channelInfo);
 		}
 
 		// check if we have all barriers - since canceled checkpoints always have zero barriers
@@ -202,11 +201,11 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler {
 
 	protected void beginNewAlignment(
 			long checkpointId,
-			int channelIndex,
+			InputChannelInfo channelInfo,
 			long checkpointTimestamp) throws IOException {
 		markCheckpointStart(checkpointTimestamp);
 		currentCheckpointId = checkpointId;
-		onBarrier(channelIndex);
+		onBarrier(channelInfo);
 
 		startOfAlignmentTimestamp = System.nanoTime();
 
@@ -218,20 +217,20 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler {
 	/**
 	 * Blocks the given channel index, from which a barrier has been received.
 	 *
-	 * @param channelIndex The channel index to block.
+	 * @param channelInfo The channel to block.
 	 */
-	protected void onBarrier(int channelIndex) throws IOException {
-		if (!blockedChannels[channelIndex]) {
-			blockedChannels[channelIndex] = true;
+	protected void onBarrier(InputChannelInfo channelInfo) throws IOException {
+		if (!blockedChannels.get(channelInfo)) {
+			blockedChannels.put(channelInfo, true);
 
 			numBarriersReceived++;
 
 			if (LOG.isDebugEnabled()) {
-				LOG.debug("{}: Received barrier from channel {}.", taskName, channelIndex);
+				LOG.debug("{}: Received barrier from channel {}.", taskName, channelInfo);
 			}
 		}
 		else {
-			throw new IOException("Stream corrupt: Repeated barrier for same checkpoint on input " + channelIndex);
+			throw new IOException("Stream corrupt: Repeated barrier for same checkpoint on input " + channelInfo);
 		}
 	}
 
@@ -339,11 +338,11 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler {
 		return numBarriersReceived > 0;
 	}
 
-	private void resumeConsumption(int channelIndex) {
-		InputGate inputGate = channelIndexToInputGate[channelIndex];
+	private void resumeConsumption(InputChannelInfo channelInfo) {
+		InputGate inputGate = inputGates[channelInfo.getGateIdx()];
 		checkState(!inputGate.isFinished(), "InputGate already finished.");
 
-		inputGate.resumeConsumption(channelIndex - inputGateToChannelIndexOffset.get(inputGate));
+		inputGate.resumeConsumption(channelInfo.getInputChannelIdx());
 	}
 
 	@VisibleForTesting
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierHandler.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierHandler.java
index 15382ee..fb0a319 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierHandler.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierHandler.java
@@ -58,10 +58,10 @@ public abstract class CheckpointBarrierHandler implements Closeable {
 	/**
 	 * Checks whether the channel with the given index is blocked.
 	 *
-	 * @param channelIndex The channel index to check.
+	 * @param channelInfo The channel to check.
 	 * @return True if the channel is blocked, false if not.
 	 */
-	public boolean isBlocked(int channelIndex) {
+	public boolean isBlocked(InputChannelInfo channelInfo) {
 		return false;
 	}
 
@@ -69,7 +69,7 @@ public abstract class CheckpointBarrierHandler implements Closeable {
 	public void close() throws IOException {
 	}
 
-	public abstract void processBarrier(CheckpointBarrier receivedBarrier, int channelIndex) throws Exception;
+	public abstract void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception;
 
 	public abstract void processCancellationBarrier(CancelCheckpointMarker cancelBarrier) throws Exception;
 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTracker.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTracker.java
index 7dbfbaa..6b89854 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTracker.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTracker.java
@@ -21,6 +21,7 @@ package org.apache.flink.streaming.runtime.io;
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
@@ -75,7 +76,7 @@ public class CheckpointBarrierTracker extends CheckpointBarrierHandler {
 		this.pendingCheckpoints = new ArrayDeque<>();
 	}
 
-	public void processBarrier(CheckpointBarrier receivedBarrier, int channelIndex) throws Exception {
+	public void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception {
 		final long barrierId = receivedBarrier.getId();
 
 		// fast path for single channel trackers
@@ -86,7 +87,7 @@ public class CheckpointBarrierTracker extends CheckpointBarrierHandler {
 
 		// general path for multiple input channels
 		if (LOG.isDebugEnabled()) {
-			LOG.debug("Received barrier for checkpoint {} from channel {}", barrierId, channelIndex);
+			LOG.debug("Received barrier for checkpoint {} from channel {}", barrierId, channelInfo);
 		}
 
 		// find the checkpoint barrier in the queue of pending barriers
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java
index 114c03e..8d7cdbf 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferReceivedListener;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.streaming.runtime.tasks.SubtaskCheckpointCoordinator;
 
@@ -41,10 +42,11 @@ import javax.annotation.concurrent.ThreadSafe;
 import java.io.Closeable;
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.function.Function;
-import java.util.stream.IntStream;
+import java.util.stream.Collectors;
 
 import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED;
 import static org.apache.flink.util.CloseableIterator.ofElement;
@@ -66,20 +68,11 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 	 * Tag the state of which input channel has pending in-flight buffers; that is, already received buffers that
 	 * predate the checkpoint barrier of the current checkpoint.
 	 */
-	private final boolean[] hasInflightBuffers;
+	private final Map<InputChannelInfo, Boolean> hasInflightBuffers;
 
 	private int numBarrierConsumed;
 
 	/**
-	 * Contains the offsets of the channel indices for each gate when flattening the channels of all gates.
-	 *
-	 * <p>For example, consider 3 gates with 4 channels, {@code gateChannelOffsets = [0, 4, 8]}.
-	 */
-	private final int[] gateChannelOffsets;
-
-	private final InputChannelInfo[] channelInfos;
-
-	/**
 	 * The checkpoint id to guarantee that we would trigger only one checkpoint when reading the same barrier from
 	 * different channels.
 	 *
@@ -92,31 +85,17 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 	private final ThreadSafeUnaligner threadSafeUnaligner;
 
 	CheckpointBarrierUnaligner(
-			int[] numberOfInputChannelsPerGate,
 			SubtaskCheckpointCoordinator checkpointCoordinator,
 			String taskName,
-			AbstractInvokable toNotifyOnCheckpoint) {
+			AbstractInvokable toNotifyOnCheckpoint,
+			InputGate... inputGates) {
 		super(toNotifyOnCheckpoint);
 
 		this.taskName = taskName;
-
-		final int numGates = numberOfInputChannelsPerGate.length;
-
-		gateChannelOffsets = new int[numGates];
-		for (int index = 1; index < numGates; index++) {
-			gateChannelOffsets[index] = gateChannelOffsets[index - 1] + numberOfInputChannelsPerGate[index - 1];
-		}
-
-		final int totalNumChannels = gateChannelOffsets[numGates - 1] + numberOfInputChannelsPerGate[numGates - 1];
-		hasInflightBuffers = new boolean[totalNumChannels];
-
-		channelInfos = IntStream.range(0, numGates)
-			.mapToObj(gateIndex -> IntStream.range(0, numberOfInputChannelsPerGate[gateIndex])
-				.mapToObj(channelIndex -> new InputChannelInfo(gateIndex, channelIndex)))
-			.flatMap(Function.identity())
-			.toArray(InputChannelInfo[]::new);
-
-		threadSafeUnaligner = new ThreadSafeUnaligner(totalNumChannels,	checkNotNull(checkpointCoordinator), this);
+		hasInflightBuffers = Arrays.stream(inputGates)
+			.flatMap(gate -> gate.getChannelInfos().stream())
+			.collect(Collectors.toMap(Function.identity(), info -> false));
+		threadSafeUnaligner = new ThreadSafeUnaligner(checkNotNull(checkpointCoordinator), this, inputGates);
 	}
 
 	/**
@@ -127,7 +106,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 	 * <p>Note this is also suitable for the trigger case of local input channel.
 	 */
 	@Override
-	public void processBarrier(CheckpointBarrier receivedBarrier, int channelIndex) throws Exception {
+	public void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception {
 		long barrierId = receivedBarrier.getId();
 		if (currentConsumedCheckpointId > barrierId || (currentConsumedCheckpointId == barrierId && !isCheckpointPending())) {
 			// ignore old and cancelled barriers
@@ -136,13 +115,13 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 		if (currentConsumedCheckpointId < barrierId) {
 			currentConsumedCheckpointId = barrierId;
 			numBarrierConsumed = 0;
-			Arrays.fill(hasInflightBuffers, true);
+			hasInflightBuffers.entrySet().forEach(hasInflightBuffer -> hasInflightBuffer.setValue(true));
 		}
 		if (currentConsumedCheckpointId == barrierId) {
-			hasInflightBuffers[channelIndex] = false;
+			hasInflightBuffers.put(channelInfo, false);
 			numBarrierConsumed++;
 		}
-		threadSafeUnaligner.notifyBarrierReceived(receivedBarrier, channelInfos[channelIndex]);
+		threadSafeUnaligner.notifyBarrierReceived(receivedBarrier, channelInfo);
 	}
 
 	@Override
@@ -184,7 +163,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 				checkpointId,
 				currentConsumedCheckpointId);
 
-			Arrays.fill(hasInflightBuffers, false);
+			hasInflightBuffers.entrySet().forEach(hasInflightBuffer -> hasInflightBuffer.setValue(false));
 			numBarrierConsumed = 0;
 		}
 	}
@@ -213,7 +192,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 		if (checkpointId > currentConsumedCheckpointId) {
 			return true;
 		}
-		return hasInflightBuffers[getFlattenedChannelIndex(channelInfo)];
+		return hasInflightBuffers.get(channelInfo);
 	}
 
 	@Override
@@ -231,10 +210,6 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 		return numBarrierConsumed > 0;
 	}
 
-	private int getFlattenedChannelIndex(InputChannelInfo channelInfo) {
-		return gateChannelOffsets[channelInfo.getGateIdx()] + channelInfo.getInputChannelIdx();
-	}
-
 	@VisibleForTesting
 	int getNumOpenChannels() {
 		return threadSafeUnaligner.getNumOpenChannels();
@@ -259,7 +234,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 		 * Tag the state of which input channel has not received the barrier, such that newly arriving buffers need
 		 * to be written in the unaligned checkpoint.
 		 */
-		private final boolean[] storeNewBuffers;
+		private final Map<InputChannelInfo, Boolean> storeNewBuffers;
 
 		/** The number of input channels which has received or processed the barrier. */
 		private int numBarriersReceived;
@@ -282,9 +257,11 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 
 		private final CheckpointBarrierUnaligner handler;
 
-		ThreadSafeUnaligner(int totalNumChannels, SubtaskCheckpointCoordinator checkpointCoordinator, CheckpointBarrierUnaligner handler) {
-			this.numOpenChannels = totalNumChannels;
-			this.storeNewBuffers = new boolean[totalNumChannels];
+		ThreadSafeUnaligner(SubtaskCheckpointCoordinator checkpointCoordinator, CheckpointBarrierUnaligner handler, InputGate... inputGates) {
+			storeNewBuffers = Arrays.stream(inputGates)
+				.flatMap(gate -> gate.getChannelInfos().stream())
+				.collect(Collectors.toMap(Function.identity(), info -> false));
+			numOpenChannels = storeNewBuffers.size();
 			this.checkpointCoordinator = checkpointCoordinator;
 			this.handler = handler;
 		}
@@ -298,13 +275,12 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 				handler.executeInTaskThread(() -> handler.notifyCheckpoint(barrier), "notifyCheckpoint");
 			}
 
-			int channelIndex = handler.getFlattenedChannelIndex(channelInfo);
-			if (barrierId == currentReceivedCheckpointId && storeNewBuffers[channelIndex]) {
+			if (barrierId == currentReceivedCheckpointId && storeNewBuffers.get(channelInfo)) {
 				if (LOG.isDebugEnabled()) {
-					LOG.debug("{}: Received barrier from channel {} @ {}.", handler.taskName, channelIndex, barrierId);
+					LOG.debug("{}: Received barrier from channel {} @ {}.", handler.taskName, channelInfo, barrierId);
 				}
 
-				storeNewBuffers[channelIndex] = false;
+				storeNewBuffers.put(channelInfo, false);
 
 				if (++numBarriersReceived == numOpenChannels) {
 					allBarriersReceivedFuture.complete(null);
@@ -314,7 +290,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 
 		@Override
 		public synchronized void notifyBufferReceived(Buffer buffer, InputChannelInfo channelInfo) {
-			if (storeNewBuffers[handler.getFlattenedChannelIndex(channelInfo)]) {
+			if (storeNewBuffers.get(channelInfo)) {
 				checkpointCoordinator.getChannelStateWriter().addInputData(
 					currentReceivedCheckpointId,
 					channelInfo,
@@ -350,7 +326,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 			}
 
 			currentReceivedCheckpointId = barrierId;
-			Arrays.fill(storeNewBuffers, true);
+			storeNewBuffers.entrySet().forEach(storeNewBuffer -> storeNewBuffer.setValue(true));
 			numBarriersReceived = 0;
 			allBarriersReceivedFuture = new CompletableFuture<>();
 			checkpointCoordinator.initCheckpoint(barrierId, barrier.getCheckpointOptions());
@@ -397,7 +373,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 				return false;
 			}
 
-			Arrays.fill(storeNewBuffers, false);
+			storeNewBuffers.entrySet().forEach(storeNewBuffer -> storeNewBuffer.setValue(false));
 			numBarriersReceived = 0;
 			return true;
 		}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointedInputGate.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointedInputGate.java
index 9428515..cb503b1 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointedInputGate.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointedInputGate.java
@@ -28,10 +28,6 @@ import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
-import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import java.io.Closeable;
 import java.io.IOException;
@@ -47,39 +43,14 @@ import static org.apache.flink.util.Preconditions.checkState;
  */
 @Internal
 public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEvent>, Closeable {
-
-	private static final Logger LOG = LoggerFactory.getLogger(CheckpointedInputGate.class);
-
 	private final CheckpointBarrierHandler barrierHandler;
 
 	/** The gate that the buffer draws its input from. */
 	private final InputGate inputGate;
 
-	private final int channelIndexOffset;
-
 	/** Indicate end of the input. */
 	private boolean isFinished;
 
-	public CheckpointedInputGate(
-			InputGate inputGate,
-			String taskName,
-			AbstractInvokable toNotifyOnCheckpoint) {
-		this(
-			inputGate,
-			new CheckpointBarrierAligner(
-				taskName,
-				InputProcessorUtil.generateChannelIndexToInputGateMap(inputGate),
-				InputProcessorUtil.generateInputGateToChannelIndexOffsetMap(inputGate),
-				toNotifyOnCheckpoint)
-		);
-	}
-
-	public CheckpointedInputGate(
-			InputGate inputGate,
-			CheckpointBarrierHandler barrierHandler) {
-		this(inputGate, barrierHandler, 0);
-	}
-
 	/**
 	 * Creates a new checkpoint stream aligner.
 	 *
@@ -89,15 +60,11 @@ public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEven
 	 *
 	 * @param inputGate The input gate to draw the buffers and events from.
 	 * @param barrierHandler Handler that controls which channels are blocked.
-	 * @param channelIndexOffset Optional offset added to channelIndex returned from the inputGate
-	 *                           before passing it to the barrierHandler.
 	 */
 	public CheckpointedInputGate(
 			InputGate inputGate,
-			CheckpointBarrierHandler barrierHandler,
-			int channelIndexOffset) {
+			CheckpointBarrierHandler barrierHandler) {
 		this.inputGate = inputGate;
-		this.channelIndexOffset = channelIndexOffset;
 		this.barrierHandler = barrierHandler;
 	}
 
@@ -116,14 +83,14 @@ public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEven
 			}
 
 			BufferOrEvent bufferOrEvent = next.get();
-			checkState(!barrierHandler.isBlocked(offsetChannelIndex(bufferOrEvent.getChannelIndex())));
+			checkState(!barrierHandler.isBlocked(bufferOrEvent.getChannelInfo()));
 
 			if (bufferOrEvent.isBuffer()) {
 				return next;
 			}
 			else if (bufferOrEvent.getEvent().getClass() == CheckpointBarrier.class) {
 				CheckpointBarrier checkpointBarrier = (CheckpointBarrier) bufferOrEvent.getEvent();
-				barrierHandler.processBarrier(checkpointBarrier, offsetChannelIndex(bufferOrEvent.getChannelIndex()));
+				barrierHandler.processBarrier(checkpointBarrier, bufferOrEvent.getChannelInfo());
 				return next;
 			}
 			else if (bufferOrEvent.getEvent().getClass() == CancelCheckpointMarker.class) {
@@ -152,10 +119,6 @@ public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEven
 		return barrierHandler.getAllBarriersReceivedFuture(checkpointId);
 	}
 
-	private int offsetChannelIndex(int channelIndex) {
-		return channelIndex + channelIndexOffset;
-	}
-
 	private Optional<BufferOrEvent> handleEmptyBuffer() {
 		if (inputGate.isFinished()) {
 			isFinished = true;
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java
index 762761b..3ed8584 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java
@@ -31,9 +31,7 @@ import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Comparator;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.stream.IntStream;
+import java.util.List;
 
 /**
  * Utility for creating {@link CheckpointedInputGate} based on checkpoint mode
@@ -71,108 +69,55 @@ public class InputProcessorUtil {
 			String taskName,
 			List<IndexedInputGate>... inputGates) {
 
-		IntStream numberOfInputChannelsPerGate =
-			Arrays
-				.stream(inputGates)
-				.flatMap(collection -> collection.stream())
-				.sorted(Comparator.comparingInt(IndexedInputGate::getGateIndex))
-				.mapToInt(InputGate::getNumberOfInputChannels);
-
-		Map<InputGate, Integer> inputGateToChannelIndexOffset = generateInputGateToChannelIndexOffsetMap(unionedInputGates);
-		// Note that numberOfInputChannelsPerGate and inputGateToChannelIndexOffset have a bit different
-		// indexing and purposes.
-		//
-		// The numberOfInputChannelsPerGate is indexed based on flattened input gates, and sorted based on GateIndex,
-		// so that it can be used in combination with InputChannelInfo class.
-		//
-		// The inputGateToChannelIndexOffset is based upon unioned input gates and it's use for translating channel
-		// indexes from perspective of UnionInputGate to perspective of SingleInputGate.
-
+		IndexedInputGate[] sortedInputGates = Arrays.stream(inputGates)
+			.flatMap(Collection::stream)
+			.sorted(Comparator.comparing(IndexedInputGate::getGateIndex))
+			.toArray(IndexedInputGate[]::new);
 		CheckpointBarrierHandler barrierHandler = createCheckpointBarrierHandler(
 			config,
-			numberOfInputChannelsPerGate,
+			sortedInputGates,
 			checkpointCoordinator,
 			taskName,
-			generateChannelIndexToInputGateMap(unionedInputGates),
-			inputGateToChannelIndexOffset,
 			toNotifyOnCheckpoint);
 		registerCheckpointMetrics(taskIOMetricGroup, barrierHandler);
 
+		InputGate[] unionedInputGates = Arrays.stream(inputGates)
+			.map(InputGateUtil::createInputGate)
+			.toArray(InputGate[]::new);
 		barrierHandler.getBufferReceivedListener().ifPresent(listener -> {
 			for (final InputGate inputGate : unionedInputGates) {
 				inputGate.registerBufferReceivedListener(listener);
 			}
 		});
 
-		CheckpointedInputGate[] checkpointedInputGates = new CheckpointedInputGate[unionedInputGates.length];
-
-		for (int i = 0; i < unionedInputGates.length; i++) {
-			checkpointedInputGates[i] = new CheckpointedInputGate(
-				unionedInputGates[i], barrierHandler, inputGateToChannelIndexOffset.get(unionedInputGates[i]));
-		}
-
-		return checkpointedInputGates;
+		return Arrays.stream(unionedInputGates)
+			.map(unionedInputGate -> new CheckpointedInputGate(unionedInputGate, barrierHandler))
+			.toArray(CheckpointedInputGate[]::new);
 	}
 
 	private static CheckpointBarrierHandler createCheckpointBarrierHandler(
 			StreamConfig config,
-			IntStream numberOfInputChannelsPerGate,
+			InputGate[] inputGates,
 			SubtaskCheckpointCoordinator checkpointCoordinator,
 			String taskName,
-			InputGate[] channelIndexToInputGate,
-			Map<InputGate, Integer> inputGateToChannelIndexOffset,
 			AbstractInvokable toNotifyOnCheckpoint) {
 		switch (config.getCheckpointMode()) {
 			case EXACTLY_ONCE:
 				if (config.isUnalignedCheckpointsEnabled()) {
 					return new AlternatingCheckpointBarrierHandler(
-						new CheckpointBarrierAligner(
-							taskName,
-							channelIndexToInputGate,
-							inputGateToChannelIndexOffset,
-							toNotifyOnCheckpoint),
-						new CheckpointBarrierUnaligner(
-							numberOfInputChannelsPerGate.toArray(),
-							checkpointCoordinator,
-							taskName,
-							toNotifyOnCheckpoint),
+						new CheckpointBarrierAligner(taskName, toNotifyOnCheckpoint, inputGates),
+						new CheckpointBarrierUnaligner(checkpointCoordinator, taskName, toNotifyOnCheckpoint, inputGates),
 						toNotifyOnCheckpoint);
 				}
-				return new CheckpointBarrierAligner(
-					taskName,
-					channelIndexToInputGate,
-					inputGateToChannelIndexOffset,
-					toNotifyOnCheckpoint);
+				return new CheckpointBarrierAligner(taskName, toNotifyOnCheckpoint, inputGates);
 			case AT_LEAST_ONCE:
-				return new CheckpointBarrierTracker(numberOfInputChannelsPerGate.sum(), toNotifyOnCheckpoint);
+				int numInputChannels = Arrays.stream(inputGates).mapToInt(InputGate::getNumberOfInputChannels).sum();
+				return new CheckpointBarrierTracker(numInputChannels, toNotifyOnCheckpoint);
 			default:
 				throw new UnsupportedOperationException("Unrecognized Checkpointing Mode: " + config.getCheckpointMode());
 		}
 	}
 
-	static InputGate[] generateChannelIndexToInputGateMap(InputGate ...inputGates) {
-		int numberOfInputChannels = Arrays.stream(inputGates).mapToInt(InputGate::getNumberOfInputChannels).sum();
-		InputGate[] channelIndexToInputGate = new InputGate[numberOfInputChannels];
-		int channelIndexOffset = 0;
-		for (InputGate inputGate: inputGates) {
-			for (int i = 0; i < inputGate.getNumberOfInputChannels(); ++i) {
-				channelIndexToInputGate[channelIndexOffset + i] = inputGate;
-			}
-			channelIndexOffset += inputGate.getNumberOfInputChannels();
-		}
-		return channelIndexToInputGate;
-	}
-
-	static Map<InputGate, Integer> generateInputGateToChannelIndexOffsetMap(InputGate ...inputGates) {
-		Map<InputGate, Integer> inputGateToChannelIndexOffset = new HashMap<>();
-		int channelIndexOffset = 0;
-		for (InputGate inputGate: inputGates) {
-			inputGateToChannelIndexOffset.put(inputGate, channelIndexOffset);
-			channelIndexOffset += inputGate.getNumberOfInputChannels();
-		}
-		return inputGateToChannelIndexOffset;
-	}
-
 	private static void registerCheckpointMetrics(TaskIOMetricGroup taskIOMetricGroup, CheckpointBarrierHandler barrierHandler) {
 		taskIOMetricGroup.gauge(MetricNames.CHECKPOINT_ALIGNMENT_TIME, barrierHandler::getAlignmentDurationNanos);
 		taskIOMetricGroup.gauge(MetricNames.CHECKPOINT_START_DELAY_TIME, barrierHandler::getCheckpointStartDelayNanos);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java
index 7134b44..8585dab 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java
@@ -23,6 +23,7 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.io.InputStatus;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
@@ -41,7 +42,12 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
 import org.apache.flink.streaming.runtime.streamstatus.StatusWatermarkValve;
 import org.apache.flink.streaming.runtime.streamstatus.StreamStatus;
 
+import javax.annotation.Nonnull;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 
@@ -74,6 +80,8 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> {
 
 	private final int inputIndex;
 
+	private final Map<InputChannelInfo, Integer> channelIndexes;
+
 	private int lastChannel = UNSPECIFIED;
 
 	private RecordDeserializer<DeserializationDelegate<StreamElement>> currentRecordDeserializer = null;
@@ -98,6 +106,18 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> {
 
 		this.statusWatermarkValve = checkNotNull(statusWatermarkValve);
 		this.inputIndex = inputIndex;
+		this.channelIndexes = getChannelIndexes(checkpointedInputGate);
+	}
+
+	@Nonnull
+	private static Map<InputChannelInfo, Integer> getChannelIndexes(CheckpointedInputGate checkpointedInputGate) {
+		int index = 0;
+		List<InputChannelInfo> channelInfos = checkpointedInputGate.getChannelInfos();
+		Map<InputChannelInfo, Integer> channelIndexes = new HashMap<>(channelInfos.size());
+		for (InputChannelInfo channelInfo : channelInfos) {
+			channelIndexes.put(channelInfo, index++);
+		}
+		return channelIndexes;
 	}
 
 	@VisibleForTesting
@@ -114,6 +134,7 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> {
 		this.recordDeserializers = recordDeserializers;
 		this.statusWatermarkValve = statusWatermarkValve;
 		this.inputIndex = inputIndex;
+		this.channelIndexes = getChannelIndexes(checkpointedInputGate);
 	}
 
 	@Override
@@ -168,7 +189,7 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> {
 
 	private void processBufferOrEvent(BufferOrEvent bufferOrEvent) throws IOException {
 		if (bufferOrEvent.isBuffer()) {
-			lastChannel = bufferOrEvent.getChannelIndex();
+			lastChannel = channelIndexes.get(bufferOrEvent.getChannelInfo());
 			checkState(lastChannel != StreamTaskInput.UNSPECIFIED);
 			currentRecordDeserializer = recordDeserializers[lastChannel];
 			checkState(currentRecordDeserializer != null,
@@ -186,7 +207,7 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> {
 
 			// release the record deserializer immediately,
 			// which is very valuable in case of bounded stream
-			releaseDeserializer(bufferOrEvent.getChannelIndex());
+			releaseDeserializer(channelIndexes.get(bufferOrEvent.getChannelInfo()));
 		}
 	}
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java
index 16a6bf2..ea147df 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
@@ -43,7 +44,6 @@ import java.util.Arrays;
 import java.util.List;
 
 import static java.util.Collections.singletonList;
-import static java.util.Collections.singletonMap;
 import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT;
 import static org.apache.flink.runtime.checkpoint.CheckpointType.SAVEPOINT;
 import static org.apache.flink.runtime.io.network.api.serialization.EventSerializer.toBuffer;
@@ -88,14 +88,14 @@ public class AlternatingCheckpointBarrierHandlerTest {
 		SingleInputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(2).build();
 		inputGate.setInputChannels(new TestInputChannel(inputGate, 0), new TestInputChannel(inputGate, 1));
 		TestInvokable target = new TestInvokable();
-		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate, inputGate}, singletonMap(inputGate, 0), target);
-		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target);
+		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", target, inputGate);
+		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", target, inputGate);
 		AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target);
 
 		for (int i = 0; i < 4; i++) {
 			int channel = i % 2;
 			CheckpointType type = channel == 0 ? CHECKPOINT : SAVEPOINT;
-			barrierHandler.processBarrier(new CheckpointBarrier(i, 0, new CheckpointOptions(type, CheckpointStorageLocationReference.getDefault())), channel);
+			barrierHandler.processBarrier(new CheckpointBarrier(i, 0, new CheckpointOptions(type, CheckpointStorageLocationReference.getDefault())), new InputChannelInfo(0, channel));
 			assertEquals(type.isSavepoint(), alignedHandler.isCheckpointPending());
 			assertNotEquals(alignedHandler.isCheckpointPending(), unalignedHandler.isCheckpointPending());
 
@@ -118,12 +118,12 @@ public class AlternatingCheckpointBarrierHandlerTest {
 		SingleInputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(2).build();
 		inputGate.setInputChannels(new TestInputChannel(inputGate, 0), new TestInputChannel(inputGate, 1));
 		TestInvokable target = new TestInvokable();
-		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate, inputGate}, singletonMap(inputGate, 0), target);
-		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target);
+		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", target, inputGate);
+		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", target, inputGate);
 		AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target);
 
 		final long id = 1;
-		unalignedHandler.processBarrier(new CheckpointBarrier(id, 0, new CheckpointOptions(CHECKPOINT, CheckpointStorageLocationReference.getDefault())), 0);
+		unalignedHandler.processBarrier(new CheckpointBarrier(id, 0, new CheckpointOptions(CHECKPOINT, CheckpointStorageLocationReference.getDefault())), new InputChannelInfo(0, 0));
 
 		assertInflightDataEquals(unalignedHandler, barrierHandler, id, inputGate.getNumberOfInputChannels());
 		assertFalse(barrierHandler.getAllBarriersReceivedFuture(id).isDone());
@@ -134,16 +134,16 @@ public class AlternatingCheckpointBarrierHandlerTest {
 		SingleInputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(2).build();
 		inputGate.setInputChannels(new TestInputChannel(inputGate, 0), new TestInputChannel(inputGate, 1));
 		TestInvokable target = new TestInvokable();
-		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate, inputGate}, singletonMap(inputGate, 0), target);
-		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target);
+		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", target, inputGate);
+		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", target, inputGate);
 		AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target);
 
 		long checkpointId = 10;
 		long outOfOrderSavepointId = 5;
 		long initialAlignedCheckpointId = alignedHandler.getLatestCheckpointId();
 
-		barrierHandler.processBarrier(new CheckpointBarrier(checkpointId, 0, new CheckpointOptions(CHECKPOINT, CheckpointStorageLocationReference.getDefault())), 0);
-		barrierHandler.processBarrier(new CheckpointBarrier(outOfOrderSavepointId, 0, new CheckpointOptions(SAVEPOINT, CheckpointStorageLocationReference.getDefault())), 1);
+		barrierHandler.processBarrier(new CheckpointBarrier(checkpointId, 0, new CheckpointOptions(CHECKPOINT, CheckpointStorageLocationReference.getDefault())), new InputChannelInfo(0, 0));
+		barrierHandler.processBarrier(new CheckpointBarrier(outOfOrderSavepointId, 0, new CheckpointOptions(SAVEPOINT, CheckpointStorageLocationReference.getDefault())), new InputChannelInfo(0, 1));
 
 		assertEquals(checkpointId, barrierHandler.getLatestCheckpointId());
 		assertInflightDataEquals(unalignedHandler, barrierHandler, checkpointId, inputGate.getNumberOfInputChannels());
@@ -154,10 +154,13 @@ public class AlternatingCheckpointBarrierHandlerTest {
 	public void testEndOfPartition() throws Exception {
 		int totalChannels = 5;
 		int closedChannels = 2;
-		SingleInputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(totalChannels).build();
+		SingleInputGate inputGate = new SingleInputGateBuilder()
+			.setNumberOfChannels(totalChannels)
+			.setChannelFactory(InputChannelBuilder::buildLocalChannel)
+			.build();
 		TestInvokable target = new TestInvokable();
-		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate}, singletonMap(inputGate, 0), target);
-		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target);
+		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", target, inputGate);
+		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", target, inputGate);
 		AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target);
 		for (int i = 0; i < closedChannels; i++) {
 			barrierHandler.processEndOfPartition();
@@ -174,19 +177,19 @@ public class AlternatingCheckpointBarrierHandlerTest {
 		TestInputChannel slow = new TestInputChannel(gate, 1, false, true);
 		gate.setInputChannels(fast, slow);
 		AlternatingCheckpointBarrierHandler barrierHandler = barrierHandler(gate, target);
-		CheckpointedInputGate checkpointedGate = new CheckpointedInputGate(gate, barrierHandler, 0 /* offset */);
+		CheckpointedInputGate checkpointedGate = new CheckpointedInputGate(gate, barrierHandler);
 
 		sendBarrier(barrierId, checkpointType, fast, checkpointedGate);
 
 		assertEquals(checkpointType.isSavepoint(), target.triggeredCheckpoints.isEmpty());
-		assertEquals(checkpointType.isSavepoint(), barrierHandler.isBlocked(fast.getChannelIndex()));
-		assertFalse(barrierHandler.isBlocked(slow.getChannelIndex()));
+		assertEquals(checkpointType.isSavepoint(), barrierHandler.isBlocked(fast.getChannelInfo()));
+		assertFalse(barrierHandler.isBlocked(slow.getChannelInfo()));
 
 		sendBarrier(barrierId, checkpointType, slow, checkpointedGate);
 
 		assertEquals(singletonList(barrierId), target.triggeredCheckpoints);
 		for (InputChannel channel : gate.getInputChannels().values()) {
-			assertFalse(barrierHandler.isBlocked(channel.getChannelIndex()));
+			assertFalse(barrierHandler.isBlocked(channel.getChannelInfo()));
 			assertEquals(
 				String.format("channel %d should be resumed", channel.getChannelIndex()),
 				checkpointType.isSavepoint(),
@@ -205,8 +208,8 @@ public class AlternatingCheckpointBarrierHandlerTest {
 		InputGate[] channelIndexToInputGate = new InputGate[inputGate.getNumberOfInputChannels()];
 		Arrays.fill(channelIndexToInputGate, inputGate);
 		return new AlternatingCheckpointBarrierHandler(
-			new CheckpointBarrierAligner(taskName, channelIndexToInputGate, singletonMap(inputGate, 0), target),
-			new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, taskName, target),
+			new CheckpointBarrierAligner(taskName, target, inputGate),
+			new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, taskName, target, inputGate),
 			target);
 	}
 
@@ -250,7 +253,7 @@ public class AlternatingCheckpointBarrierHandlerTest {
 			channels[i] = new TestInputChannel(gate, i, false, true);
 		}
 		gate.setInputChannels(channels);
-		return new CheckpointedInputGate(gate, barrierHandler(gate, target), 0);
+		return new CheckpointedInputGate(gate, barrierHandler(gate, target));
 	}
 
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerMassiveRandomTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerMassiveRandomTest.java
index a653916..f415e8f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerMassiveRandomTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerMassiveRandomTest.java
@@ -19,6 +19,7 @@ package org.apache.flink.streaming.runtime.io;
 
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
@@ -34,10 +35,13 @@ import org.junit.Test;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Optional;
 import java.util.Random;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 /**
  * The test generates two random streams (input channels) which independently
@@ -65,8 +69,7 @@ public class CheckpointBarrierAlignerMassiveRandomTest {
 			CheckpointedInputGate checkpointedInputGate =
 				new CheckpointedInputGate(
 					myIG,
-					"Testing: No task associated",
-					new DummyCheckpointInvokable());
+					new CheckpointBarrierAligner("Testing: No task associated", new DummyCheckpointInvokable(), myIG));
 
 			for (int i = 0; i < 2000000; i++) {
 				BufferOrEvent boe = checkpointedInputGate.pollNext().get();
@@ -161,6 +164,13 @@ public class CheckpointBarrierAlignerMassiveRandomTest {
 		}
 
 		@Override
+		public List<InputChannelInfo> getChannelInfos() {
+			return IntStream.range(0, numberOfChannels)
+					.mapToObj(channelIndex -> new InputChannelInfo(0, channelIndex))
+					.collect(Collectors.toList());
+		}
+
+		@Override
 		public Optional<BufferOrEvent> getNext() throws IOException {
 			currentChannel = (currentChannel + 1) % numberOfChannels;
 			if (channelBlocked[currentChannel]) {
@@ -179,7 +189,7 @@ public class CheckpointBarrierAlignerMassiveRandomTest {
 							++currentBarriers[currentChannel],
 							System.currentTimeMillis(),
 							CheckpointOptions.forCheckpointWithDefaultLocation()),
-						currentChannel));
+						new InputChannelInfo(0, currentChannel)));
 			} else {
 				Buffer buffer = bufferPools[currentChannel].requestBuffer();
 				if (buffer == null) {
@@ -188,7 +198,7 @@ public class CheckpointBarrierAlignerMassiveRandomTest {
 					return getNext();
 				}
 				buffer.getMemorySegment().putLong(0, c++);
-				return Optional.of(new BufferOrEvent(buffer, currentChannel));
+				return Optional.of(new BufferOrEvent(buffer, new InputChannelInfo(0, currentChannel)));
 			}
 		}
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerTestBase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerTestBase.java
index ea4e004..0c478a0 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerTestBase.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerTestBase.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
@@ -835,12 +836,13 @@ public abstract class CheckpointBarrierAlignerTestBase {
 	// ------------------------------------------------------------------------
 
 	private static BufferOrEvent createBarrier(long checkpointId, int channel) {
-		return new BufferOrEvent(new CheckpointBarrier(
-			checkpointId, System.currentTimeMillis(), CheckpointOptions.forCheckpointWithDefaultLocation()), channel);
+		return new BufferOrEvent(
+			new CheckpointBarrier(checkpointId, System.currentTimeMillis(), CheckpointOptions.forCheckpointWithDefaultLocation()),
+			new InputChannelInfo(0, channel));
 	}
 
 	private static BufferOrEvent createCancellationBarrier(long checkpointId, int channel) {
-		return new BufferOrEvent(new CancelCheckpointMarker(checkpointId), channel);
+		return new BufferOrEvent(new CancelCheckpointMarker(checkpointId), new InputChannelInfo(0, channel));
 	}
 
 	private static BufferOrEvent createBuffer(int channel) {
@@ -857,11 +859,11 @@ public abstract class CheckpointBarrierAlignerTestBase {
 		// retain an additional time so it does not get disposed after being read by the input gate
 		buf.retainBuffer();
 
-		return new BufferOrEvent(buf, channel);
+		return new BufferOrEvent(buf, new InputChannelInfo(0, channel));
 	}
 
 	private static BufferOrEvent createEndOfPartition(int channel) {
-		return new BufferOrEvent(EndOfPartitionEvent.INSTANCE, channel);
+		return new BufferOrEvent(EndOfPartitionEvent.INSTANCE, new InputChannelInfo(0, channel));
 	}
 
 	private static void check(BufferOrEvent expected, BufferOrEvent present, int pageSize) {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTrackerTest.java
index 1deca6a..5a3eafc 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTrackerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTrackerTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.io;
 
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
@@ -369,16 +370,16 @@ public class CheckpointBarrierTrackerTest {
 	}
 
 	private static BufferOrEvent createBarrier(long id, int channel) {
-		return new BufferOrEvent(new CheckpointBarrier(id, System.currentTimeMillis(), CheckpointOptions.forCheckpointWithDefaultLocation()), channel);
+		return new BufferOrEvent(new CheckpointBarrier(id, System.currentTimeMillis(), CheckpointOptions.forCheckpointWithDefaultLocation()), new InputChannelInfo(0, channel));
 	}
 
 	private static BufferOrEvent createCancellationBarrier(long id, int channel) {
-		return new BufferOrEvent(new CancelCheckpointMarker(id), channel);
+		return new BufferOrEvent(new CancelCheckpointMarker(id), new InputChannelInfo(0, channel));
 	}
 
 	private static BufferOrEvent createBuffer(int channel) {
 		return new BufferOrEvent(
-				new NetworkBuffer(MemorySegmentFactory.wrap(new byte[]{1, 2}), FreeingBufferRecycler.INSTANCE), channel);
+				new NetworkBuffer(MemorySegmentFactory.wrap(new byte[]{1, 2}), FreeingBufferRecycler.INSTANCE), new InputChannelInfo(0, channel));
 	}
 
 	// ------------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java
index 1bb4972..37d4865 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.io;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.event.RuntimeEvent;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
@@ -77,13 +78,13 @@ public class CheckpointBarrierUnalignerCancellationTest {
 	@Test
 	public void test() throws Exception {
 		TestInvokable invokable = new TestInvokable();
-		CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner(new int[]{numChannels}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
+		CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate(0, numChannels));
 
 		for (RuntimeEvent e : events) {
 			if (e instanceof CancelCheckpointMarker) {
 				unaligner.processCancellationBarrier((CancelCheckpointMarker) e);
 			} else if (e instanceof CheckpointBarrier) {
-				unaligner.processBarrier((CheckpointBarrier) e, channel);
+				unaligner.processBarrier((CheckpointBarrier) e, new InputChannelInfo(0, channel));
 			} else {
 				throw new IllegalArgumentException("unexpected event type: " + e);
 			}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java
index 0ab8ee2..35e5b08 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java
@@ -30,8 +30,8 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
+import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder;
-import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
@@ -479,7 +479,7 @@ public class CheckpointBarrierUnalignerTest {
 	}
 
 	/**
-	 * Tests the race condition between {@link CheckpointBarrierUnaligner#processBarrier(CheckpointBarrier, int)}
+	 * Tests the race condition between {@link CheckpointBarrierHandler#processBarrier(CheckpointBarrier, InputChannelInfo)}
 	 * and {@link ThreadSafeUnaligner#notifyBarrierReceived(CheckpointBarrier, InputChannelInfo)}. The barrier
 	 * notification will trigger an async checkpoint (ch1) via mailbox, and meanwhile the barrier processing will
 	 * execute the next checkpoint (ch2) directly in advance. When the ch1 action is taken from mailbox to execute,
@@ -488,7 +488,7 @@ public class CheckpointBarrierUnalignerTest {
 	@Test
 	public void testConcurrentProcessBarrierAndNotifyBarrierReceived() throws Exception {
 		final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable();
-		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
+		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate());
 		final InputChannelInfo channelInfo = new InputChannelInfo(0, 0);
 		final ExecutorService executor = Executors.newFixedThreadPool(1);
 
@@ -502,7 +502,7 @@ public class CheckpointBarrierUnalignerTest {
 			result.get();
 
 			// Execute the checkpoint (ch1) directly because it is triggered by main thread.
-			handler.processBarrier(buildCheckpointBarrier(1), 0);
+			handler.processBarrier(buildCheckpointBarrier(1), new InputChannelInfo(0, 0));
 
 			// Run the previous queued mailbox action to execute ch0.
 			invokable.runMailboxStep();
@@ -523,8 +523,7 @@ public class CheckpointBarrierUnalignerTest {
 	@Test
 	public void testProcessCancellationBarrierAfterNotifyBarrierReceived() throws Exception {
 		final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable();
-		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(
-			new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
+		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate());
 
 		ThreadSafeUnaligner unaligner = handler.getThreadSafeUnaligner();
 		// should trigger respective checkpoint
@@ -541,16 +540,15 @@ public class CheckpointBarrierUnalignerTest {
 	/**
 	 * Tests {@link CheckpointBarrierUnaligner#processCancellationBarrier(CancelCheckpointMarker)}
 	 * abort the current pending checkpoint triggered by
-	 * {@link CheckpointBarrierUnaligner#processBarrier(CheckpointBarrier, int)}.
+	 * {@link CheckpointBarrierHandler#processBarrier(CheckpointBarrier, InputChannelInfo)}.
 	 */
 	@Test
 	public void testProcessCancellationBarrierAfterProcessBarrier() throws Exception {
 		final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable();
-		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(
-			new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
+		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate());
 
 		// should trigger respective checkpoint
-		handler.processBarrier(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), 0);
+		handler.processBarrier(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), new InputChannelInfo(0, 0));
 
 		assertTrue(handler.isCheckpointPending());
 		assertTrue(handler.getThreadSafeUnaligner().isCheckpointPending());
@@ -563,15 +561,14 @@ public class CheckpointBarrierUnalignerTest {
 	@Test
 	public void testProcessCancellationBarrierBeforeProcessAndReceiveBarrier() throws Exception {
 		final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable();
-		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(
-			new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
+		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate());
 
 		handler.processCancellationBarrier(new CancelCheckpointMarker(DEFAULT_CHECKPOINT_ID));
 
 		verifyTriggeredCheckpoint(handler, invokable, DEFAULT_CHECKPOINT_ID);
 
 		// it would not trigger checkpoint since the respective cancellation barrier already happened before
-		handler.processBarrier(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), 0);
+		handler.processBarrier(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), new InputChannelInfo(0, 0));
 		handler.getThreadSafeUnaligner().notifyBarrierReceived(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), new InputChannelInfo(0, 0));
 
 		verifyTriggeredCheckpoint(handler, invokable, DEFAULT_CHECKPOINT_ID);
@@ -608,8 +605,7 @@ public class CheckpointBarrierUnalignerTest {
 	public void testEndOfStreamWithPendingCheckpoint() throws Exception {
 		final int numberOfChannels = 2;
 		final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable();
-		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(
-			new int[] { numberOfChannels }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
+		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate(0, numberOfChannels));
 
 		ThreadSafeUnaligner unaligner = handler.getThreadSafeUnaligner();
 		// should trigger respective checkpoint
@@ -639,26 +635,26 @@ public class CheckpointBarrierUnalignerTest {
 				checkpointId,
 				System.currentTimeMillis(),
 				CheckpointOptions.forCheckpointWithDefaultLocation()),
-			channel);
+			new InputChannelInfo(0, channel));
 	}
 
 	private BufferOrEvent createCancellationBarrier(long checkpointId, int channel) {
 		sizeCounter++;
-		return new BufferOrEvent(new CancelCheckpointMarker(checkpointId), channel);
+		return new BufferOrEvent(new CancelCheckpointMarker(checkpointId), new InputChannelInfo(0, channel));
 	}
 
 	private BufferOrEvent createBuffer(int channel) {
 		final int size = sizeCounter++;
-		return new BufferOrEvent(TestBufferFactory.createBuffer(size), channel);
+		return new BufferOrEvent(TestBufferFactory.createBuffer(size), new InputChannelInfo(0, channel));
 	}
 
 	private static BufferOrEvent createEndOfPartition(int channel) {
-		return new BufferOrEvent(EndOfPartitionEvent.INSTANCE, channel);
+		return new BufferOrEvent(EndOfPartitionEvent.INSTANCE, new InputChannelInfo(0, channel));
 	}
 
 	private CheckpointedInputGate createInputGate(
 			int numberOfChannels,
-			AbstractInvokable toNotify) throws IOException, InterruptedException {
+			AbstractInvokable toNotify) throws IOException {
 		final NettyShuffleEnvironment environment = new NettyShuffleEnvironmentBuilder().build();
 		SingleInputGate gate = new SingleInputGateBuilder()
 			.setNumberOfChannels(numberOfChannels)
@@ -687,12 +683,12 @@ public class CheckpointBarrierUnalignerTest {
 			if (bufferOrEvent.isEvent()) {
 				bufferOrEvent = new BufferOrEvent(
 					EventSerializer.toBuffer(bufferOrEvent.getEvent()),
-					bufferOrEvent.getChannelIndex(),
+					bufferOrEvent.getChannelInfo(),
 					bufferOrEvent.moreAvailable());
 			}
-			((RemoteInputChannel) inputGate.getChannel(bufferOrEvent.getChannelIndex())).onBuffer(
+			((RemoteInputChannel) inputGate.getChannel(bufferOrEvent.getChannelInfo().getInputChannelIdx())).onBuffer(
 				bufferOrEvent.getBuffer(),
-				sequenceNumbers[bufferOrEvent.getChannelIndex()]++,
+				sequenceNumbers[bufferOrEvent.getChannelInfo().getInputChannelIdx()]++,
 				0);
 
 			while (inputGate.pollNext().map(output::add).isPresent()) {
@@ -702,12 +698,12 @@ public class CheckpointBarrierUnalignerTest {
 		return sequence;
 	}
 
-	private CheckpointedInputGate createCheckpointedInputGate(InputGate gate, AbstractInvokable toNotify) {
+	private CheckpointedInputGate createCheckpointedInputGate(IndexedInputGate gate, AbstractInvokable toNotify) {
 		final CheckpointBarrierUnaligner barrierHandler = new CheckpointBarrierUnaligner(
-			new int[]{ gate.getNumberOfInputChannels() },
 			new TestSubtaskCheckpointCoordinator(channelStateWriter),
 			"Test",
-			toNotify);
+			toNotify,
+			gate);
 		barrierHandler.getBufferReceivedListener().ifPresent(gate::registerBufferReceivedListener);
 		return new CheckpointedInputGate(gate, barrierHandler);
 	}
@@ -719,7 +715,7 @@ public class CheckpointBarrierUnalignerTest {
 
 	private Collection<BufferOrEvent> getAndResetInflightData() {
 		final List<BufferOrEvent> inflightData = channelStateWriter.getAddedInput().entries().stream()
-			.map(entry -> new BufferOrEvent(entry.getValue(), entry.getKey().getInputChannelIdx()))
+			.map(entry -> new BufferOrEvent(entry.getValue(), entry.getKey()))
 			.collect(Collectors.toList());
 		channelStateWriter.reset();
 		return inflightData;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CreditBasedCheckpointBarrierAlignerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CreditBasedCheckpointBarrierAlignerTest.java
index 2837890..eb1b403 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CreditBasedCheckpointBarrierAlignerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CreditBasedCheckpointBarrierAlignerTest.java
@@ -28,6 +28,6 @@ public class CreditBasedCheckpointBarrierAlignerTest extends CheckpointBarrierAl
 
 	@Override
 	CheckpointedInputGate createBarrierBuffer(InputGate gate, AbstractInvokable toNotify) {
-		return new CheckpointedInputGate(gate, "Testing", toNotify);
+		return new CheckpointedInputGate(gate, new CheckpointBarrierAligner("Testing", toNotify, gate));
 	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java
index a17ead0..9112bf0 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java
@@ -25,9 +25,6 @@ import org.apache.flink.runtime.checkpoint.channel.MockChannelStateWriter;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.buffer.BufferReceivedListener;
 import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
-import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
-import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
-import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder;
 import org.apache.flink.streaming.api.CheckpointingMode;
@@ -41,10 +38,8 @@ import org.junit.Test;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.Map;
 import java.util.stream.Collectors;
 
-import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
 /**
@@ -53,32 +48,6 @@ import static org.junit.Assert.assertTrue;
 public class InputProcessorUtilTest {
 
 	@Test
-	public void testGenerateChannelIndexToInputGateMap() {
-		SingleInputGate ig1 = new SingleInputGateBuilder().setNumberOfChannels(2).build();
-		SingleInputGate ig2 = new SingleInputGateBuilder().setNumberOfChannels(3).build();
-
-		InputGate[] channelIndexToInputGateMap = InputProcessorUtil.generateChannelIndexToInputGateMap(ig1, ig2);
-		assertEquals(5, channelIndexToInputGateMap.length);
-		assertEquals(ig1, channelIndexToInputGateMap[0]);
-		assertEquals(ig1, channelIndexToInputGateMap[1]);
-		assertEquals(ig2, channelIndexToInputGateMap[2]);
-		assertEquals(ig2, channelIndexToInputGateMap[3]);
-		assertEquals(ig2, channelIndexToInputGateMap[4]);
-	}
-
-	@Test
-	public void testGenerateInputGateToChannelIndexOffsetMap() {
-		SingleInputGate ig1 = new SingleInputGateBuilder().setNumberOfChannels(3).build();
-		SingleInputGate ig2 = new SingleInputGateBuilder().setNumberOfChannels(2).build();
-
-		Map<InputGate, Integer> inputGateToChannelIndexOffsetMap =
-			InputProcessorUtil.generateInputGateToChannelIndexOffsetMap(ig1, ig2);
-		assertEquals(2, inputGateToChannelIndexOffsetMap.size());
-		assertEquals(0, inputGateToChannelIndexOffsetMap.get(ig1).intValue());
-		assertEquals(3, inputGateToChannelIndexOffsetMap.get(ig2).intValue());
-	}
-
-	@Test
 	public void testCreateCheckpointedMultipleInputGate() throws Exception {
 		try (CloseableRegistry registry = new CloseableRegistry()) {
 			MockEnvironment environment = new MockEnvironmentBuilder().build();
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
index a536cbd..8779200 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
@@ -113,7 +113,7 @@ public class MockInputGate extends InputGate {
 			return Optional.empty();
 		}
 
-		int channelIdx = next.getChannelIndex();
+		int channelIdx = next.getChannelInfo().getInputChannelIdx();
 		if (closed[channelIdx]) {
 			throw new RuntimeException("Inconsistent: Channel " + channelIdx
 				+ " has data even though it is already closed.");
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java
index b61da52..5fd2aef 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.core.io.InputStatus;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
@@ -109,7 +110,7 @@ public class StreamTaskNetworkInputTest {
 		CheckpointBarrier barrier = new CheckpointBarrier(0, 0, CheckpointOptions.forCheckpointWithDefaultLocation());
 
 		List<BufferOrEvent> buffers = new ArrayList<>(2);
-		buffers.add(new BufferOrEvent(barrier, 0));
+		buffers.add(new BufferOrEvent(barrier, new InputChannelInfo(0, 0)));
 		buffers.add(createDataBuffer());
 
 		VerifyRecordsDataOutput output = new VerifyRecordsDataOutput<>();
@@ -121,22 +122,24 @@ public class StreamTaskNetworkInputTest {
 
 	@Test
 	public void testSnapshotWithTwoInputGates() throws Exception {
-		CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner(
-				new int[]{ 1, 1 },
-				TestSubtaskCheckpointCoordinator.INSTANCE,
-				"test",
-				new DummyCheckpointInvokable());
-
 		SingleInputGate inputGate1 = new SingleInputGateBuilder().setSingleInputGateIndex(0).build();
 		RemoteInputChannel channel1 = InputChannelBuilder.newBuilder().buildRemoteChannel(inputGate1);
 		inputGate1.setInputChannels(channel1);
-		inputGate1.registerBufferReceivedListener(unaligner.getBufferReceivedListener().get());
-		StreamTaskNetworkInput<Long> input1 = createInput(unaligner, inputGate1);
 
 		SingleInputGate inputGate2 = new SingleInputGateBuilder().setSingleInputGateIndex(1).build();
 		RemoteInputChannel channel2 = InputChannelBuilder.newBuilder().buildRemoteChannel(inputGate2);
 		inputGate2.setInputChannels(channel2);
+
+		CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner(
+			TestSubtaskCheckpointCoordinator.INSTANCE,
+			"test",
+			new DummyCheckpointInvokable(),
+			inputGate1,
+			inputGate2);
+		inputGate1.registerBufferReceivedListener(unaligner.getBufferReceivedListener().get());
 		inputGate2.registerBufferReceivedListener(unaligner.getBufferReceivedListener().get());
+
+		StreamTaskNetworkInput<Long> input1 = createInput(unaligner, inputGate1);
 		StreamTaskNetworkInput<Long> input2 = createInput(unaligner, inputGate2);
 
 		CheckpointBarrier barrier = new CheckpointBarrier(0, 0L, CheckpointOptions.forCheckpointWithDefaultLocation());
@@ -194,10 +197,10 @@ public class StreamTaskNetworkInputTest {
 			new CheckpointedInputGate(
 				inputGate.getInputGate(),
 				new CheckpointBarrierUnaligner(
-					new int[] { numInputChannels },
 					TestSubtaskCheckpointCoordinator.INSTANCE,
 					"test",
-					new DummyCheckpointInvokable())),
+					new DummyCheckpointInvokable(),
+					inputGate.getInputGate())),
 			inSerializer,
 			new StatusWatermarkValve(numInputChannels, output),
 			0,
@@ -261,7 +264,7 @@ public class StreamTaskNetworkInputTest {
 		serializeRecord(42L, bufferBuilder);
 		serializeRecord(44L, bufferBuilder);
 
-		return new BufferOrEvent(bufferConsumer.build(), 0, false);
+		return new BufferOrEvent(bufferConsumer.build(), new InputChannelInfo(0, 0), false);
 	}
 
 	private StreamTaskNetworkInput createStreamTaskNetworkInput(List<BufferOrEvent> buffers, DataOutput output) {