You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2019/05/10 13:23:35 UTC

[flink] 01/08: Revert "Merge pull request #8361 from pnowojski/f12434"

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit dd1eb082fdd291523ec931a01b295e407a2df1f5
Author: Piotr Nowojski <pi...@gmail.com>
AuthorDate: Fri May 10 15:22:23 2019 +0200

    Revert "Merge pull request #8361 from pnowojski/f12434"
    
    This reverts commit d3fd7a6794a8ffd16e00bb1867b6e62c3909a2d2.
---
 .../io/network/partition/consumer/InputGate.java   |  55 +-----
 .../partition/consumer/InputGateListener.java      |  35 ++++
 .../partition/consumer/SingleInputGate.java        |  87 ++++-----
 .../network/partition/consumer/UnionInputGate.java | 206 ++++++++++-----------
 .../partition/consumer/InputGateTestBase.java      |  91 ---------
 .../partition/consumer/SingleInputGateTest.java    |  54 ++++--
 .../partition/consumer/TestSingleInputGate.java    |  44 ++++-
 .../partition/consumer/UnionInputGateTest.java     |  22 +--
 .../consumer/StreamTestSingleInputGate.java        |  84 +--------
 .../runtime/io/BarrierBufferMassiveRandomTest.java |   7 +-
 .../flink/streaming/runtime/io/MockInputGate.java  |   9 +-
 11 files changed, 281 insertions(+), 413 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
index 03ac822..c270d37 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
@@ -22,9 +22,6 @@ import org.apache.flink.runtime.event.TaskEvent;
 
 import java.io.IOException;
 import java.util.Optional;
-import java.util.concurrent.CompletableFuture;
-
-import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
  * An input gate consumes one or more partitions of a single produced intermediate result.
@@ -68,65 +65,33 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  * will have an input gate attached to it. This will provide its input, which will consist of one
  * subpartition from each partition of the intermediate result.
  */
-public abstract class InputGate implements AutoCloseable {
-
-	public static final CompletableFuture<?> AVAILABLE = CompletableFuture.completedFuture(null);
-
-	protected CompletableFuture<?> isAvailable = new CompletableFuture<>();
+public interface InputGate extends AutoCloseable {
 
-	public abstract int getNumberOfInputChannels();
+	int getNumberOfInputChannels();
 
-	public abstract String getOwningTaskName();
+	String getOwningTaskName();
 
-	public abstract boolean isFinished();
+	boolean isFinished();
 
-	public abstract void requestPartitions() throws IOException, InterruptedException;
+	void requestPartitions() throws IOException, InterruptedException;
 
 	/**
 	 * Blocking call waiting for next {@link BufferOrEvent}.
 	 *
 	 * @return {@code Optional.empty()} if {@link #isFinished()} returns true.
 	 */
-	public abstract Optional<BufferOrEvent> getNextBufferOrEvent() throws IOException, InterruptedException;
+	Optional<BufferOrEvent> getNextBufferOrEvent() throws IOException, InterruptedException;
 
 	/**
 	 * Poll the {@link BufferOrEvent}.
 	 *
 	 * @return {@code Optional.empty()} if there is no data to return or if {@link #isFinished()} returns true.
 	 */
-	public abstract Optional<BufferOrEvent> pollNextBufferOrEvent() throws IOException, InterruptedException;
-
-	public abstract void sendTaskEvent(TaskEvent event) throws IOException;
-
-	public abstract int getPageSize();
-
-	/**
-	 * @return a future that is completed if there are more records available. If there more records
-	 * available immediately, {@link #AVAILABLE} should be returned.
-	 */
-	public CompletableFuture<?> isAvailable() {
-		return isAvailable;
-	}
+	Optional<BufferOrEvent> pollNextBufferOrEvent() throws IOException, InterruptedException;
 
-	protected void resetIsAvailable() {
-		// try to avoid volatile access in isDone()}
-		if (isAvailable == AVAILABLE || isAvailable.isDone()) {
-			isAvailable = new CompletableFuture<>();
-		}
-	}
+	void sendTaskEvent(TaskEvent event) throws IOException;
 
-	/**
-	 * Simple pojo for INPUT, DATA and moreAvailable.
-	 */
-	protected static class InputWithData<INPUT, DATA> {
-		protected final INPUT input;
-		protected final DATA data;
-		protected final boolean moreAvailable;
+	void registerListener(InputGateListener listener);
 
-		InputWithData(INPUT input, DATA data, boolean moreAvailable) {
-			this.input = checkNotNull(input);
-			this.data = checkNotNull(data);
-			this.moreAvailable = moreAvailable;
-		}
-	}
+	int getPageSize();
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateListener.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateListener.java
new file mode 100644
index 0000000..00fa782
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateListener.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.consumer;
+
+/**
+ * Listener interface implemented by consumers of {@link InputGate} instances
+ * that want to be notified of availability of buffer or event instances.
+ */
+public interface InputGateListener {
+
+	/**
+	 * Notification callback if the input gate moves from zero to non-zero
+	 * available input channels with data.
+	 *
+	 * @param inputGate Input Gate that became available.
+	 */
+	void notifyInputGateNonEmpty(InputGate inputGate);
+
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index 19912b2..d40af83 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -56,7 +56,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Timer;
-import java.util.concurrent.CompletableFuture;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -102,7 +101,7 @@ import static org.apache.flink.util.Preconditions.checkState;
  * in two partitions (Partition 1 and 2). Each of these partitions is further partitioned into two
  * subpartitions -- one for each parallel reduce subtask.
  */
-public class SingleInputGate extends InputGate {
+public class SingleInputGate implements InputGate {
 
 	private static final Logger LOG = LoggerFactory.getLogger(SingleInputGate.class);
 
@@ -173,6 +172,9 @@ public class SingleInputGate extends InputGate {
 	/** Flag indicating whether all resources have been released. */
 	private volatile boolean isReleased;
 
+	/** Registered listener to forward buffer notifications to. */
+	private volatile InputGateListener inputGateListener;
+
 	private final List<TaskEvent> pendingEvents = new ArrayList<>();
 
 	private int numberOfUninitializedChannels;
@@ -529,21 +531,12 @@ public class SingleInputGate extends InputGate {
 		}
 
 		requestPartitions();
-		Optional<InputWithData<InputChannel, BufferAndAvailability>> next = waitAndGetNextData(blocking);
-		if (!next.isPresent()) {
-			return Optional.empty();
-		}
 
-		InputWithData<InputChannel, BufferAndAvailability> inputWithData = next.get();
-		return Optional.of(transformToBufferOrEvent(
-			inputWithData.data.buffer(),
-			inputWithData.moreAvailable,
-			inputWithData.input));
-	}
+		InputChannel currentChannel;
+		boolean moreAvailable;
+		Optional<BufferAndAvailability> result = Optional.empty();
 
-	private Optional<InputWithData<InputChannel, BufferAndAvailability>> waitAndGetNextData(boolean blocking)
-			throws IOException, InterruptedException {
-		while (true) {
+		do {
 			synchronized (inputChannelsWithData) {
 				while (inputChannelsWithData.size() == 0) {
 					if (isReleased) {
@@ -554,43 +547,30 @@ public class SingleInputGate extends InputGate {
 						inputChannelsWithData.wait();
 					}
 					else {
-						resetIsAvailable();
 						return Optional.empty();
 					}
 				}
 
-				InputChannel inputChannel = inputChannelsWithData.remove();
-
-				Optional<BufferAndAvailability> result = inputChannel.getNextBuffer();
+				currentChannel = inputChannelsWithData.remove();
+				enqueuedInputChannelsWithData.clear(currentChannel.getChannelIndex());
+				moreAvailable = !inputChannelsWithData.isEmpty();
+			}
 
-				if (result.isPresent() && result.get().moreAvailable()) {
-					// enqueue the inputChannel at the end to avoid starvation
-					inputChannelsWithData.add(inputChannel);
-				} else {
-					enqueuedInputChannelsWithData.clear(inputChannel.getChannelIndex());
-				}
+			result = currentChannel.getNextBuffer();
+		} while (!result.isPresent());
 
-				if (inputChannelsWithData.isEmpty()) {
-					resetIsAvailable();
-				}
-
-				if (result.isPresent()) {
-					return Optional.of(new InputWithData<>(
-						inputChannel,
-						result.get(),
-						!inputChannelsWithData.isEmpty()));
-				}
-			}
+		// this channel was now removed from the non-empty channels queue
+		// we re-add it in case it has more data, because in that case no "non-empty" notification
+		// will come for that channel
+		if (result.get().moreAvailable()) {
+			queueChannel(currentChannel);
+			moreAvailable = true;
 		}
-	}
 
-	private BufferOrEvent transformToBufferOrEvent(
-			Buffer buffer,
-			boolean moreAvailable,
-			InputChannel currentChannel) throws IOException, InterruptedException {
+		final Buffer buffer = result.get().buffer();
 		numBytesIn.inc(buffer.getSizeUnsafe());
 		if (buffer.isBuffer()) {
-			return new BufferOrEvent(buffer, currentChannel.getChannelIndex(), moreAvailable);
+			return Optional.of(new BufferOrEvent(buffer, currentChannel.getChannelIndex(), moreAvailable));
 		}
 		else {
 			final AbstractEvent event = EventSerializer.fromBuffer(buffer, getClass().getClassLoader());
@@ -609,10 +589,11 @@ public class SingleInputGate extends InputGate {
 				}
 
 				currentChannel.notifySubpartitionConsumed();
+
 				currentChannel.releaseAllResources();
 			}
 
-			return new BufferOrEvent(event, currentChannel.getChannelIndex(), moreAvailable);
+			return Optional.of(new BufferOrEvent(event, currentChannel.getChannelIndex(), moreAvailable));
 		}
 	}
 
@@ -633,6 +614,15 @@ public class SingleInputGate extends InputGate {
 	// Channel notifications
 	// ------------------------------------------------------------------------
 
+	@Override
+	public void registerListener(InputGateListener inputGateListener) {
+		if (this.inputGateListener == null) {
+			this.inputGateListener = inputGateListener;
+		} else {
+			throw new IllegalStateException("Multiple listeners");
+		}
+	}
+
 	void notifyChannelNonEmpty(InputChannel channel) {
 		queueChannel(checkNotNull(channel));
 	}
@@ -644,8 +634,6 @@ public class SingleInputGate extends InputGate {
 	private void queueChannel(InputChannel channel) {
 		int availableChannels;
 
-		CompletableFuture<?> toNotify = null;
-
 		synchronized (inputChannelsWithData) {
 			if (enqueuedInputChannelsWithData.get(channel.getChannelIndex())) {
 				return;
@@ -657,13 +645,14 @@ public class SingleInputGate extends InputGate {
 
 			if (availableChannels == 0) {
 				inputChannelsWithData.notifyAll();
-				toNotify = isAvailable;
-				isAvailable = AVAILABLE;
 			}
 		}
 
-		if (toNotify != null) {
-			toNotify.complete(null);
+		if (availableChannels == 0) {
+			InputGateListener listener = inputGateListener;
+			if (listener != null) {
+				listener.notifyInputGateNonEmpty(this);
+			}
 		}
 	}
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index 5019cfc..ea83004 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -25,12 +25,11 @@ import org.apache.flink.shaded.guava18.com.google.common.collect.Maps;
 import org.apache.flink.shaded.guava18.com.google.common.collect.Sets;
 
 import java.io.IOException;
-import java.util.Iterator;
-import java.util.LinkedHashSet;
+import java.util.ArrayDeque;
+import java.util.HashSet;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
-import java.util.concurrent.CompletableFuture;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -64,22 +63,27 @@ import static org.apache.flink.util.Preconditions.checkState;
  *
  * <strong>It is NOT possible to recursively union union input gates.</strong>
  */
-public class UnionInputGate extends InputGate {
+public class UnionInputGate implements InputGate, InputGateListener {
 
 	/** The input gates to union. */
 	private final InputGate[] inputGates;
 
 	private final Set<InputGate> inputGatesWithRemainingData;
 
+	/** Gates, which notified this input gate about available data. */
+	private final ArrayDeque<InputGate> inputGatesWithData = new ArrayDeque<>();
+
 	/**
-	 * Gates, which notified this input gate about available data. We are using it as a FIFO
-	 * queue of {@link InputGate}s to avoid starvation and provide some basic fairness.
+	 * Guardian against enqueuing an {@link InputGate} multiple times on {@code inputGatesWithData}.
 	 */
-	private final LinkedHashSet<InputGate> inputGatesWithData = new LinkedHashSet<>();
+	private final Set<InputGate> enqueuedInputGatesWithData = new HashSet<>();
 
 	/** The total number of input channels across all unioned input gates. */
 	private final int totalNumberOfInputChannels;
 
+	/** Registered listener to forward input gate notifications to. */
+	private volatile InputGateListener inputGateListener;
+
 	/**
 	 * A mapping from input gate to (logical) channel index offset. Valid channel indexes go from 0
 	 * (inclusive) to the total number of input channels (exclusive).
@@ -98,31 +102,20 @@ public class UnionInputGate extends InputGate {
 
 		int currentNumberOfInputChannels = 0;
 
-		synchronized (inputGatesWithData) {
-			for (InputGate inputGate : inputGates) {
-				if (inputGate instanceof UnionInputGate) {
-					// if we want to add support for this, we need to implement pollNextBufferOrEvent()
-					throw new UnsupportedOperationException("Cannot union a union of input gates.");
-				}
-
-				// The offset to use for buffer or event instances received from this input gate.
-				inputGateToIndexOffsetMap.put(checkNotNull(inputGate), currentNumberOfInputChannels);
-				inputGatesWithRemainingData.add(inputGate);
-
-				currentNumberOfInputChannels += inputGate.getNumberOfInputChannels();
+		for (InputGate inputGate : inputGates) {
+			if (inputGate instanceof UnionInputGate) {
+				// if we want to add support for this, we need to implement pollNextBufferOrEvent()
+				throw new UnsupportedOperationException("Cannot union a union of input gates.");
+			}
 
-				CompletableFuture<?> available = inputGate.isAvailable();
+			// The offset to use for buffer or event instances received from this input gate.
+			inputGateToIndexOffsetMap.put(checkNotNull(inputGate), currentNumberOfInputChannels);
+			inputGatesWithRemainingData.add(inputGate);
 
-				if (available.isDone()) {
-					inputGatesWithData.add(inputGate);
-				} else {
-					available.thenRun(() -> queueInputGate(inputGate));
-				}
-			}
+			currentNumberOfInputChannels += inputGate.getNumberOfInputChannels();
 
-			if (!inputGatesWithData.isEmpty()) {
-				isAvailable = AVAILABLE;
-			}
+			// Register the union gate as a listener for all input gates
+			inputGate.registerListener(this);
 		}
 
 		this.totalNumberOfInputChannels = currentNumberOfInputChannels;
@@ -166,15 +159,6 @@ public class UnionInputGate extends InputGate {
 
 	@Override
 	public Optional<BufferOrEvent> getNextBufferOrEvent() throws IOException, InterruptedException {
-		return getNextBufferOrEvent(true);
-	}
-
-	@Override
-	public Optional<BufferOrEvent> pollNextBufferOrEvent() throws IOException, InterruptedException {
-		return getNextBufferOrEvent(false);
-	}
-
-	private Optional<BufferOrEvent> getNextBufferOrEvent(boolean blocking) throws IOException, InterruptedException {
 		if (inputGatesWithRemainingData.isEmpty()) {
 			return Optional.empty();
 		}
@@ -182,87 +166,75 @@ public class UnionInputGate extends InputGate {
 		// Make sure to request the partitions, if they have not been requested before.
 		requestPartitions();
 
-		Optional<InputWithData<InputGate, BufferOrEvent>> next = waitAndGetNextData(blocking);
-		if (!next.isPresent()) {
-			return Optional.empty();
-		}
-
-		InputWithData<InputGate, BufferOrEvent> inputWithData = next.get();
-
-		handleEndOfPartitionEvent(inputWithData.data, inputWithData.input);
-		return Optional.of(adjustForUnionInputGate(
-			inputWithData.data,
-			inputWithData.input,
-			inputWithData.moreAvailable));
-	}
-
-	private Optional<InputWithData<InputGate, BufferOrEvent>> waitAndGetNextData(boolean blocking)
-			throws IOException, InterruptedException {
-		while (true) {
-			synchronized (inputGatesWithData) {
-				while (inputGatesWithData.size() == 0) {
-					if (blocking) {
-						inputGatesWithData.wait();
-					} else {
-						resetIsAvailable();
-						return Optional.empty();
-					}
-				}
+		InputGateWithData inputGateWithData = waitAndGetNextInputGate();
+		InputGate inputGate = inputGateWithData.inputGate;
+		BufferOrEvent bufferOrEvent = inputGateWithData.bufferOrEvent;
 
-				Iterator<InputGate> inputGateIterator = inputGatesWithData.iterator();
-				final InputGate inputGate = inputGateIterator.next();
-				inputGateIterator.remove();
-
-				// In case of inputGatesWithData being inaccurate do not block on an empty inputGate, but just poll the data.
-				Optional<BufferOrEvent> bufferOrEvent = inputGate.pollNextBufferOrEvent();
-
-				if (bufferOrEvent.isPresent() && bufferOrEvent.get().moreAvailable()) {
-					// enqueue the inputGate at the end to avoid starvation
-					inputGatesWithData.add(inputGate);
-				} else {
-					inputGate.isAvailable().thenRun(() -> queueInputGate(inputGate));
-				}
+		if (bufferOrEvent.moreAvailable()) {
+			// this buffer or event was now removed from the non-empty gates queue
+			// we re-add it in case it has more data, because in that case no "non-empty" notification
+			// will come for that gate
+			queueInputGate(inputGate);
+		}
 
-				if (inputGatesWithData.isEmpty()) {
-					resetIsAvailable();
-				}
+		if (bufferOrEvent.isEvent()
+			&& bufferOrEvent.getEvent().getClass() == EndOfPartitionEvent.class
+			&& inputGate.isFinished()) {
 
-				if (bufferOrEvent.isPresent()) {
-					return Optional.of(new InputWithData<>(
-						inputGate,
-						bufferOrEvent.get(),
-						!inputGatesWithData.isEmpty()));
-				}
+			checkState(!bufferOrEvent.moreAvailable());
+			if (!inputGatesWithRemainingData.remove(inputGate)) {
+				throw new IllegalStateException("Couldn't find input gate in set of remaining " +
+					"input gates.");
 			}
 		}
-	}
 
-	private BufferOrEvent adjustForUnionInputGate(
-		BufferOrEvent bufferOrEvent,
-		InputGate inputGate,
-		boolean moreInputGatesAvailable) {
 		// Set the channel index to identify the input channel (across all unioned input gates)
 		final int channelIndexOffset = inputGateToIndexOffsetMap.get(inputGate);
 
 		bufferOrEvent.setChannelIndex(channelIndexOffset + bufferOrEvent.getChannelIndex());
-		bufferOrEvent.setMoreAvailable(bufferOrEvent.moreAvailable() || moreInputGatesAvailable);
+		bufferOrEvent.setMoreAvailable(bufferOrEvent.moreAvailable() || inputGateWithData.moreInputGatesAvailable);
 
-		return bufferOrEvent;
+		return Optional.of(bufferOrEvent);
 	}
 
-	private void handleEndOfPartitionEvent(BufferOrEvent bufferOrEvent, InputGate inputGate) {
-		if (bufferOrEvent.isEvent()
-			&& bufferOrEvent.getEvent().getClass() == EndOfPartitionEvent.class
-			&& inputGate.isFinished()) {
+	@Override
+	public Optional<BufferOrEvent> pollNextBufferOrEvent() throws UnsupportedOperationException {
+		throw new UnsupportedOperationException();
+	}
 
-			checkState(!bufferOrEvent.moreAvailable());
-			if (!inputGatesWithRemainingData.remove(inputGate)) {
-				throw new IllegalStateException("Couldn't find input gate in set of remaining " +
-					"input gates.");
+	private InputGateWithData waitAndGetNextInputGate() throws IOException, InterruptedException {
+		while (true) {
+			InputGate inputGate;
+			boolean moreInputGatesAvailable;
+			synchronized (inputGatesWithData) {
+				while (inputGatesWithData.size() == 0) {
+					inputGatesWithData.wait();
+				}
+				inputGate = inputGatesWithData.remove();
+				enqueuedInputGatesWithData.remove(inputGate);
+				moreInputGatesAvailable = enqueuedInputGatesWithData.size() > 0;
+			}
+
+			// In case of inputGatesWithData being inaccurate do not block on an empty inputGate, but just poll the data.
+			Optional<BufferOrEvent> bufferOrEvent = inputGate.pollNextBufferOrEvent();
+			if (bufferOrEvent.isPresent()) {
+				return new InputGateWithData(inputGate, bufferOrEvent.get(), moreInputGatesAvailable);
 			}
 		}
 	}
 
+	private static class InputGateWithData {
+		private final InputGate inputGate;
+		private final BufferOrEvent bufferOrEvent;
+		private final boolean moreInputGatesAvailable;
+
+		InputGateWithData(InputGate inputGate, BufferOrEvent bufferOrEvent, boolean moreInputGatesAvailable) {
+			this.inputGate = checkNotNull(inputGate);
+			this.bufferOrEvent = checkNotNull(bufferOrEvent);
+			this.moreInputGatesAvailable = moreInputGatesAvailable;
+		}
+	}
+
 	@Override
 	public void sendTaskEvent(TaskEvent event) throws IOException {
 		for (InputGate inputGate : inputGates) {
@@ -271,6 +243,15 @@ public class UnionInputGate extends InputGate {
 	}
 
 	@Override
+	public void registerListener(InputGateListener listener) {
+		if (this.inputGateListener == null) {
+			this.inputGateListener = listener;
+		} else {
+			throw new IllegalStateException("Multiple listeners");
+		}
+	}
+
+	@Override
 	public int getPageSize() {
 		int pageSize = -1;
 		for (InputGate gate : inputGates) {
@@ -287,29 +268,34 @@ public class UnionInputGate extends InputGate {
 	public void close() throws IOException {
 	}
 
-	private void queueInputGate(InputGate inputGate) {
-		checkNotNull(inputGate);
+	@Override
+	public void notifyInputGateNonEmpty(InputGate inputGate) {
+		queueInputGate(checkNotNull(inputGate));
+	}
 
-		CompletableFuture<?> toNotify = null;
+	private void queueInputGate(InputGate inputGate) {
+		int availableInputGates;
 
 		synchronized (inputGatesWithData) {
-			if (inputGatesWithData.contains(inputGate)) {
+			if (enqueuedInputGatesWithData.contains(inputGate)) {
 				return;
 			}
 
-			int availableInputGates = inputGatesWithData.size();
+			availableInputGates = inputGatesWithData.size();
 
 			inputGatesWithData.add(inputGate);
+			enqueuedInputGatesWithData.add(inputGate);
 
 			if (availableInputGates == 0) {
 				inputGatesWithData.notifyAll();
-				toNotify = isAvailable;
-				isAvailable = AVAILABLE;
 			}
 		}
 
-		if (toNotify != null) {
-			toNotify.complete(null);
+		if (availableInputGates == 0) {
+			InputGateListener listener = inputGateListener;
+			if (listener != null) {
+				listener.notifyInputGateNonEmpty(this);
+			}
 		}
 	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
deleted file mode 100644
index ddf027c..0000000
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.network.partition.consumer;
-
-import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
-
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
-import org.junit.runners.Parameterized.Parameter;
-import org.junit.runners.Parameterized.Parameters;
-
-import java.util.Arrays;
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-
-import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-
-/**
- * Test base for {@link InputGate}.
- */
-@RunWith(Parameterized.class)
-public abstract class InputGateTestBase {
-
-	@Parameter
-	public boolean enableCreditBasedFlowControl;
-
-	@Parameters(name = "Credit-based = {0}")
-	public static List<Boolean> parameters() {
-		return Arrays.asList(Boolean.TRUE, Boolean.FALSE);
-	}
-
-	protected void testIsAvailable(
-			InputGate inputGateToTest,
-			SingleInputGate inputGateToNotify,
-			TestInputChannel inputChannelWithNewData) throws Exception {
-
-		assertFalse(inputGateToTest.isAvailable().isDone());
-		assertFalse(inputGateToTest.pollNextBufferOrEvent().isPresent());
-
-		CompletableFuture<?> isAvailable = inputGateToTest.isAvailable();
-
-		assertFalse(inputGateToTest.isAvailable().isDone());
-		assertFalse(inputGateToTest.pollNextBufferOrEvent().isPresent());
-
-		assertEquals(isAvailable, inputGateToTest.isAvailable());
-
-		inputChannelWithNewData.readBuffer();
-		inputGateToNotify.notifyChannelNonEmpty(inputChannelWithNewData);
-
-		assertTrue(isAvailable.isDone());
-		assertTrue(inputGateToTest.isAvailable().isDone());
-	}
-
-	protected SingleInputGate createInputGate() {
-		return createInputGate(2);
-	}
-
-	protected SingleInputGate createInputGate(int numberOfInputChannels) {
-		return createInputGate(numberOfInputChannels, ResultPartitionType.PIPELINED);
-	}
-
-	protected SingleInputGate createInputGate(
-			int numberOfInputChannels, ResultPartitionType partitionType) {
-		SingleInputGate inputGate = createSingleInputGate(
-			numberOfInputChannels,
-			partitionType,
-			enableCreditBasedFlowControl);
-
-		assertEquals(partitionType, inputGate.getConsumedPartitionType());
-		return inputGate;
-	}
-}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index a6f824d..71e4f5a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -50,13 +50,18 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.taskmanager.NoOpTaskActions;
 
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
@@ -75,7 +80,17 @@ import static org.mockito.Mockito.when;
 /**
  * Tests for {@link SingleInputGate}.
  */
-public class SingleInputGateTest extends InputGateTestBase {
+@RunWith(Parameterized.class)
+public class SingleInputGateTest {
+
+	@Parameterized.Parameter
+	public boolean enableCreditBasedFlowControl;
+
+	@Parameterized.Parameters(name = "Credit-based = {0}")
+	public static List<Boolean> parameters() {
+		return Arrays.asList(Boolean.TRUE, Boolean.FALSE);
+	}
+
 	/**
 	 * Tests basic correctness of buffer-or-event interleaving and correct <code>null</code> return
 	 * value after receiving all end-of-partition events.
@@ -116,15 +131,6 @@ public class SingleInputGateTest extends InputGateTestBase {
 		assertTrue(inputGate.isFinished());
 	}
 
-	@Test
-	public void testIsAvailable() throws Exception {
-		final SingleInputGate inputGate = createInputGate(1);
-		TestInputChannel inputChannel = new TestInputChannel(inputGate, 0);
-		inputGate.setInputChannel(new IntermediateResultPartitionID(), inputChannel);
-
-		testIsAvailable(inputGate, inputGate, inputChannel);
-	}
-
 	@Test(timeout = 120 * 1000)
 	public void testIsMoreAvailableReadingFromSingleInputChannel() throws Exception {
 		// Setup
@@ -540,6 +546,22 @@ public class SingleInputGateTest extends InputGateTestBase {
 
 	// ---------------------------------------------------------------------------------------------
 
+	private SingleInputGate createInputGate() {
+		return createInputGate(2);
+	}
+
+	private SingleInputGate createInputGate(int numberOfInputChannels) {
+		return createInputGate(numberOfInputChannels, ResultPartitionType.PIPELINED);
+	}
+
+	private SingleInputGate createInputGate(int numberOfInputChannels, ResultPartitionType partitionType) {
+		SingleInputGate inputGate = createSingleInputGate(numberOfInputChannels, partitionType, enableCreditBasedFlowControl);
+
+		assertEquals(partitionType, inputGate.getConsumedPartitionType());
+
+		return inputGate;
+	}
+
 	private void addUnknownInputChannel(
 			NetworkEnvironment network,
 			SingleInputGate inputGate,
@@ -597,7 +619,17 @@ public class SingleInputGateTest extends InputGateTestBase {
 		assertEquals(expectedChannelIndex, bufferOrEvent.get().getChannelIndex());
 		assertEquals(expectedMoreAvailable, bufferOrEvent.get().moreAvailable());
 		if (!expectedMoreAvailable) {
-			assertFalse(inputGate.pollNextBufferOrEvent().isPresent());
+			try {
+				assertFalse(inputGate.pollNextBufferOrEvent().isPresent());
+			}
+			catch (UnsupportedOperationException ex) {
+				/**
+				 * {@link UnionInputGate#pollNextBufferOrEvent()} is unsupported at the moment.
+				 */
+				if (!(inputGate instanceof UnionInputGate)) {
+					throw ex;
+				}
+			}
 		}
 	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
index f60cdb9..464ab7c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
@@ -20,8 +20,17 @@ package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.lang.reflect.Field;
+import java.util.ArrayDeque;
+
 import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
 import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.spy;
 
 /**
  * A test input gate to mock reading data.
@@ -35,8 +44,39 @@ public class TestSingleInputGate {
 	public TestSingleInputGate(int numberOfInputChannels, boolean initialize) {
 		checkArgument(numberOfInputChannels >= 1);
 
-		inputGate = createSingleInputGate(numberOfInputChannels);
-		inputChannels = new TestInputChannel[numberOfInputChannels];
+		SingleInputGate realGate = createSingleInputGate(numberOfInputChannels);
+
+		this.inputGate = spy(realGate);
+
+		// Notify about late registrations (added for DataSinkTaskTest#testUnionDataSinkTask).
+		// After merging registerInputOutput and invoke, we have to make sure that the test
+		// notifications happen at the expected time. In real programs, this is guaranteed by
+		// the instantiation and request partition life cycle.
+		try {
+			Field f = realGate.getClass().getDeclaredField("inputChannelsWithData");
+			f.setAccessible(true);
+			final ArrayDeque<InputChannel> notifications = (ArrayDeque<InputChannel>) f.get(realGate);
+
+			doAnswer(new Answer<Void>() {
+				@Override
+				public Void answer(InvocationOnMock invocation) throws Throwable {
+					invocation.callRealMethod();
+
+					synchronized (notifications) {
+						if (!notifications.isEmpty()) {
+							InputGateListener listener = (InputGateListener) invocation.getArguments()[0];
+							listener.notifyInputGateNonEmpty(inputGate);
+						}
+					}
+
+					return null;
+				}
+			}).when(inputGate).registerListener(any(InputGateListener.class));
+		} catch (Exception e) {
+			throw new RuntimeException(e);
+		}
+
+		this.inputChannels = new TestInputChannel[numberOfInputChannels];
 
 		if (initialize) {
 			for (int i = 0; i < numberOfInputChannels; i++) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
index 9dca613..082ccec 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
@@ -18,10 +18,9 @@
 
 package org.apache.flink.runtime.io.network.partition.consumer;
 
-import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
-
 import org.junit.Test;
 
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
 import static org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.verifyBufferOrEvent;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -30,7 +29,7 @@ import static org.junit.Assert.assertTrue;
 /**
  * Tests for {@link UnionInputGate}.
  */
-public class UnionInputGateTest extends InputGateTestBase {
+public class UnionInputGateTest {
 
 	/**
 	 * Tests basic correctness of buffer-or-event interleaving and correct <code>null</code> return
@@ -42,8 +41,8 @@ public class UnionInputGateTest extends InputGateTestBase {
 	@Test(timeout = 120 * 1000)
 	public void testBasicGetNextLogic() throws Exception {
 		// Setup
-		final SingleInputGate ig1 = createInputGate(3);
-		final SingleInputGate ig2 = createInputGate(5);
+		final SingleInputGate ig1 = createSingleInputGate(3);
+		final SingleInputGate ig2 = createSingleInputGate(5);
 
 		final UnionInputGate union = new UnionInputGate(new SingleInputGate[]{ig1, ig2});
 
@@ -102,17 +101,4 @@ public class UnionInputGateTest extends InputGateTestBase {
 		assertTrue(union.isFinished());
 		assertFalse(union.getNextBufferOrEvent().isPresent());
 	}
-
-	@Test
-	public void testIsAvailable() throws Exception {
-		final SingleInputGate inputGate1 = createInputGate(1);
-		TestInputChannel inputChannel1 = new TestInputChannel(inputGate1, 0);
-		inputGate1.setInputChannel(new IntermediateResultPartitionID(), inputChannel1);
-
-		final SingleInputGate inputGate2 = createInputGate(1);
-		TestInputChannel inputChannel2 = new TestInputChannel(inputGate2, 0);
-		inputGate2.setInputChannel(new IntermediateResultPartitionID(), inputChannel2);
-
-		testIsAvailable(new UnionInputGate(inputGate1, inputGate2), inputGate1, inputChannel1);
-	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
index b9fe84f..dbb81ab 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
@@ -21,16 +21,12 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer;
 import org.apache.flink.runtime.io.network.api.serialization.SpanningRecordSerializer;
-import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
-import org.apache.flink.runtime.io.network.buffer.BufferListener;
-import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannel.BufferAndAvailability;
 import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel.BufferAndAvailabilityProvider;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
@@ -44,6 +40,7 @@ import java.util.concurrent.ConcurrentLinkedQueue;
 
 import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.buildSingleBuffer;
 import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.createBufferBuilder;
+import static org.mockito.Mockito.doReturn;
 
 /**
  * Test {@link InputGate} that allows setting multiple channels. Use
@@ -79,7 +76,7 @@ public class StreamTestSingleInputGate<T> extends TestSingleInputGate {
 		inputQueues = new ConcurrentLinkedQueue[numInputChannels];
 
 		setupInputChannels();
-		inputGate.setBufferPool(new NoOpBufferPool(bufferSize));
+		doReturn(bufferSize).when(inputGate).getPageSize();
 	}
 
 	@SuppressWarnings("unchecked")
@@ -222,81 +219,4 @@ public class StreamTestSingleInputGate<T> extends TestSingleInputGate {
 			return isEvent;
 		}
 	}
-
-	private static class NoOpBufferPool implements BufferPool {
-		private int bufferSize;
-
-		public NoOpBufferPool(int bufferSize) {
-			this.bufferSize = bufferSize;
-		}
-
-		@Override
-		public void lazyDestroy() {
-		}
-
-		@Override
-		public int getMemorySegmentSize() {
-			return bufferSize;
-		}
-
-		@Override
-		public Buffer requestBuffer() throws IOException {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public Buffer requestBufferBlocking() throws IOException, InterruptedException {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public BufferBuilder requestBufferBuilderBlocking() throws IOException, InterruptedException {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public boolean addBufferListener(BufferListener listener) {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public boolean isDestroyed() {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public int getNumberOfRequiredMemorySegments() {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public int getMaxNumberOfMemorySegments() {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public int getNumBuffers() {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public void setNumBuffers(int numBuffers) throws IOException {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public int getNumberOfAvailableMemorySegments() {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public int bestEffortGetNumOfUsedBuffers() {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public void recycle(MemorySegment memorySegment) {
-			throw new UnsupportedOperationException();
-		}
-	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java
index b1a3ad5..ded52ff 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGateListener;
 
 import org.junit.Test;
 
@@ -129,7 +130,7 @@ public class BarrierBufferMassiveRandomTest {
 		}
 	}
 
-	private static class RandomGeneratingInputGate extends InputGate {
+	private static class RandomGeneratingInputGate implements InputGate {
 
 		private final int numberOfChannels;
 		private final BufferPool[] bufferPools;
@@ -150,7 +151,6 @@ public class BarrierBufferMassiveRandomTest {
 			this.bufferPools = bufferPools;
 			this.barrierGens = barrierGens;
 			this.owningTaskName = owningTaskName;
-			this.isAvailable = AVAILABLE;
 		}
 
 		@Override
@@ -199,6 +199,9 @@ public class BarrierBufferMassiveRandomTest {
 		public void sendTaskEvent(TaskEvent event) {}
 
 		@Override
+		public void registerListener(InputGateListener listener) {}
+
+		@Override
 		public int getPageSize() {
 			return PAGE_SIZE;
 		}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
index 30941ab..a29cbf5 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
@@ -22,6 +22,7 @@ import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGateListener;
 
 import java.util.ArrayDeque;
 import java.util.List;
@@ -31,7 +32,7 @@ import java.util.Queue;
 /**
  * Mock {@link InputGate}.
  */
-public class MockInputGate extends InputGate {
+public class MockInputGate implements InputGate {
 
 	private final int pageSize;
 
@@ -55,8 +56,6 @@ public class MockInputGate extends InputGate {
 		this.bufferOrEvents = new ArrayDeque<BufferOrEvent>(bufferOrEvents);
 		this.closed = new boolean[numberOfChannels];
 		this.owningTaskName = owningTaskName;
-
-		isAvailable = AVAILABLE;
 	}
 
 	@Override
@@ -112,6 +111,10 @@ public class MockInputGate extends InputGate {
 	}
 
 	@Override
+	public void registerListener(InputGateListener listener) {
+	}
+
+	@Override
 	public void close() {
 	}
 }