You are viewing a plain-text version of this message. The canonical version is available in the commits@flink.apache.org mailing-list archive.
Posted to commits@flink.apache.org by pn...@apache.org on 2019/05/10 13:21:44 UTC

[flink] branch master updated: Merge pull request #8361 from pnowojski/f12434

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new d3fd7a6  Merge pull request #8361 from pnowojski/f12434
d3fd7a6 is described below

commit d3fd7a6794a8ffd16e00bb1867b6e62c3909a2d2
Author: Piotr Nowojski <pn...@users.noreply.github.com>
AuthorDate: Fri May 10 15:21:35 2019 +0200

    Merge pull request #8361 from pnowojski/f12434
    
    [FLINK-12434][network] Replace listeners with CompletableFuture in InputGates
---
 .../io/network/partition/consumer/InputGate.java   |  55 +++++-
 .../partition/consumer/InputGateListener.java      |  35 ----
 .../partition/consumer/SingleInputGate.java        |  87 +++++----
 .../network/partition/consumer/UnionInputGate.java | 206 +++++++++++----------
 .../partition/consumer/InputGateTestBase.java      |  91 +++++++++
 .../partition/consumer/SingleInputGateTest.java    |  54 ++----
 .../partition/consumer/TestSingleInputGate.java    |  44 +----
 .../partition/consumer/UnionInputGateTest.java     |  22 ++-
 .../consumer/StreamTestSingleInputGate.java        |  84 ++++++++-
 .../runtime/io/BarrierBufferMassiveRandomTest.java |   7 +-
 .../flink/streaming/runtime/io/MockInputGate.java  |   9 +-
 11 files changed, 413 insertions(+), 281 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
index c270d37..03ac822 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
@@ -22,6 +22,9 @@ import org.apache.flink.runtime.event.TaskEvent;
 
 import java.io.IOException;
 import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
  * An input gate consumes one or more partitions of a single produced intermediate result.
@@ -65,33 +68,65 @@ import java.util.Optional;
  * will have an input gate attached to it. This will provide its input, which will consist of one
  * subpartition from each partition of the intermediate result.
  */
-public interface InputGate extends AutoCloseable {
+public abstract class InputGate implements AutoCloseable {
+
+	public static final CompletableFuture<?> AVAILABLE = CompletableFuture.completedFuture(null);
+
+	protected CompletableFuture<?> isAvailable = new CompletableFuture<>();
 
-	int getNumberOfInputChannels();
+	public abstract int getNumberOfInputChannels();
 
-	String getOwningTaskName();
+	public abstract String getOwningTaskName();
 
-	boolean isFinished();
+	public abstract boolean isFinished();
 
-	void requestPartitions() throws IOException, InterruptedException;
+	public abstract void requestPartitions() throws IOException, InterruptedException;
 
 	/**
 	 * Blocking call waiting for next {@link BufferOrEvent}.
 	 *
 	 * @return {@code Optional.empty()} if {@link #isFinished()} returns true.
 	 */
-	Optional<BufferOrEvent> getNextBufferOrEvent() throws IOException, InterruptedException;
+	public abstract Optional<BufferOrEvent> getNextBufferOrEvent() throws IOException, InterruptedException;
 
 	/**
 	 * Poll the {@link BufferOrEvent}.
 	 *
 	 * @return {@code Optional.empty()} if there is no data to return or if {@link #isFinished()} returns true.
 	 */
-	Optional<BufferOrEvent> pollNextBufferOrEvent() throws IOException, InterruptedException;
+	public abstract Optional<BufferOrEvent> pollNextBufferOrEvent() throws IOException, InterruptedException;
+
+	public abstract void sendTaskEvent(TaskEvent event) throws IOException;
+
+	public abstract int getPageSize();
+
+	/**
+	 * @return a future that is completed once more records are available. If more records are
+	 * available immediately, {@link #AVAILABLE} should be returned.
+	 */
+	public CompletableFuture<?> isAvailable() {
+		return isAvailable;
+	}
 
-	void sendTaskEvent(TaskEvent event) throws IOException;
+	protected void resetIsAvailable() {
+		// replace a completed future; comparing to AVAILABLE first skips the volatile read in isDone()
+		if (isAvailable == AVAILABLE || isAvailable.isDone()) {
+			isAvailable = new CompletableFuture<>();
+		}
+	}
 
-	void registerListener(InputGateListener listener);
+	/**
+	 * Simple pojo for INPUT, DATA and moreAvailable.
+	 */
+	protected static class InputWithData<INPUT, DATA> {
+		protected final INPUT input;
+		protected final DATA data;
+		protected final boolean moreAvailable;
 
-	int getPageSize();
+		InputWithData(INPUT input, DATA data, boolean moreAvailable) {
+			this.input = checkNotNull(input);
+			this.data = checkNotNull(data);
+			this.moreAvailable = moreAvailable;
+		}
+	}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateListener.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateListener.java
deleted file mode 100644
index 00fa782..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateListener.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.network.partition.consumer;
-
-/**
- * Listener interface implemented by consumers of {@link InputGate} instances
- * that want to be notified of availability of buffer or event instances.
- */
-public interface InputGateListener {
-
-	/**
-	 * Notification callback if the input gate moves from zero to non-zero
-	 * available input channels with data.
-	 *
-	 * @param inputGate Input Gate that became available.
-	 */
-	void notifyInputGateNonEmpty(InputGate inputGate);
-
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index d40af83..19912b2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -56,6 +56,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Timer;
+import java.util.concurrent.CompletableFuture;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -101,7 +102,7 @@ import static org.apache.flink.util.Preconditions.checkState;
  * in two partitions (Partition 1 and 2). Each of these partitions is further partitioned into two
  * subpartitions -- one for each parallel reduce subtask.
  */
-public class SingleInputGate implements InputGate {
+public class SingleInputGate extends InputGate {
 
 	private static final Logger LOG = LoggerFactory.getLogger(SingleInputGate.class);
 
@@ -172,9 +173,6 @@ public class SingleInputGate implements InputGate {
 	/** Flag indicating whether all resources have been released. */
 	private volatile boolean isReleased;
 
-	/** Registered listener to forward buffer notifications to. */
-	private volatile InputGateListener inputGateListener;
-
 	private final List<TaskEvent> pendingEvents = new ArrayList<>();
 
 	private int numberOfUninitializedChannels;
@@ -531,12 +529,21 @@ public class SingleInputGate implements InputGate {
 		}
 
 		requestPartitions();
+		Optional<InputWithData<InputChannel, BufferAndAvailability>> next = waitAndGetNextData(blocking);
+		if (!next.isPresent()) {
+			return Optional.empty();
+		}
 
-		InputChannel currentChannel;
-		boolean moreAvailable;
-		Optional<BufferAndAvailability> result = Optional.empty();
+		InputWithData<InputChannel, BufferAndAvailability> inputWithData = next.get();
+		return Optional.of(transformToBufferOrEvent(
+			inputWithData.data.buffer(),
+			inputWithData.moreAvailable,
+			inputWithData.input));
+	}
 
-		do {
+	private Optional<InputWithData<InputChannel, BufferAndAvailability>> waitAndGetNextData(boolean blocking)
+			throws IOException, InterruptedException {
+		while (true) {
 			synchronized (inputChannelsWithData) {
 				while (inputChannelsWithData.size() == 0) {
 					if (isReleased) {
@@ -547,30 +554,43 @@ public class SingleInputGate implements InputGate {
 						inputChannelsWithData.wait();
 					}
 					else {
+						resetIsAvailable();
 						return Optional.empty();
 					}
 				}
 
-				currentChannel = inputChannelsWithData.remove();
-				enqueuedInputChannelsWithData.clear(currentChannel.getChannelIndex());
-				moreAvailable = !inputChannelsWithData.isEmpty();
-			}
+				InputChannel inputChannel = inputChannelsWithData.remove();
 
-			result = currentChannel.getNextBuffer();
-		} while (!result.isPresent());
+				Optional<BufferAndAvailability> result = inputChannel.getNextBuffer();
 
-		// this channel was now removed from the non-empty channels queue
-		// we re-add it in case it has more data, because in that case no "non-empty" notification
-		// will come for that channel
-		if (result.get().moreAvailable()) {
-			queueChannel(currentChannel);
-			moreAvailable = true;
+				if (result.isPresent() && result.get().moreAvailable()) {
+					// enqueue the inputChannel at the end to avoid starvation
+					inputChannelsWithData.add(inputChannel);
+				} else {
+					enqueuedInputChannelsWithData.clear(inputChannel.getChannelIndex());
+				}
+
+				if (inputChannelsWithData.isEmpty()) {
+					resetIsAvailable();
+				}
+
+				if (result.isPresent()) {
+					return Optional.of(new InputWithData<>(
+						inputChannel,
+						result.get(),
+						!inputChannelsWithData.isEmpty()));
+				}
+			}
 		}
+	}
 
-		final Buffer buffer = result.get().buffer();
+	private BufferOrEvent transformToBufferOrEvent(
+			Buffer buffer,
+			boolean moreAvailable,
+			InputChannel currentChannel) throws IOException, InterruptedException {
 		numBytesIn.inc(buffer.getSizeUnsafe());
 		if (buffer.isBuffer()) {
-			return Optional.of(new BufferOrEvent(buffer, currentChannel.getChannelIndex(), moreAvailable));
+			return new BufferOrEvent(buffer, currentChannel.getChannelIndex(), moreAvailable);
 		}
 		else {
 			final AbstractEvent event = EventSerializer.fromBuffer(buffer, getClass().getClassLoader());
@@ -589,11 +609,10 @@ public class SingleInputGate implements InputGate {
 				}
 
 				currentChannel.notifySubpartitionConsumed();
-
 				currentChannel.releaseAllResources();
 			}
 
-			return Optional.of(new BufferOrEvent(event, currentChannel.getChannelIndex(), moreAvailable));
+			return new BufferOrEvent(event, currentChannel.getChannelIndex(), moreAvailable);
 		}
 	}
 
@@ -614,15 +633,6 @@ public class SingleInputGate implements InputGate {
 	// Channel notifications
 	// ------------------------------------------------------------------------
 
-	@Override
-	public void registerListener(InputGateListener inputGateListener) {
-		if (this.inputGateListener == null) {
-			this.inputGateListener = inputGateListener;
-		} else {
-			throw new IllegalStateException("Multiple listeners");
-		}
-	}
-
 	void notifyChannelNonEmpty(InputChannel channel) {
 		queueChannel(checkNotNull(channel));
 	}
@@ -634,6 +644,8 @@ public class SingleInputGate implements InputGate {
 	private void queueChannel(InputChannel channel) {
 		int availableChannels;
 
+		CompletableFuture<?> toNotify = null;
+
 		synchronized (inputChannelsWithData) {
 			if (enqueuedInputChannelsWithData.get(channel.getChannelIndex())) {
 				return;
@@ -645,14 +657,13 @@ public class SingleInputGate implements InputGate {
 
 			if (availableChannels == 0) {
 				inputChannelsWithData.notifyAll();
+				toNotify = isAvailable;
+				isAvailable = AVAILABLE;
 			}
 		}
 
-		if (availableChannels == 0) {
-			InputGateListener listener = inputGateListener;
-			if (listener != null) {
-				listener.notifyInputGateNonEmpty(this);
-			}
+		if (toNotify != null) {
+			toNotify.complete(null);
 		}
 	}
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index ea83004..5019cfc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -25,11 +25,12 @@ import org.apache.flink.shaded.guava18.com.google.common.collect.Maps;
 import org.apache.flink.shaded.guava18.com.google.common.collect.Sets;
 
 import java.io.IOException;
-import java.util.ArrayDeque;
-import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -63,27 +64,22 @@ import static org.apache.flink.util.Preconditions.checkState;
  *
  * <strong>It is NOT possible to recursively union union input gates.</strong>
  */
-public class UnionInputGate implements InputGate, InputGateListener {
+public class UnionInputGate extends InputGate {
 
 	/** The input gates to union. */
 	private final InputGate[] inputGates;
 
 	private final Set<InputGate> inputGatesWithRemainingData;
 
-	/** Gates, which notified this input gate about available data. */
-	private final ArrayDeque<InputGate> inputGatesWithData = new ArrayDeque<>();
-
 	/**
-	 * Guardian against enqueuing an {@link InputGate} multiple times on {@code inputGatesWithData}.
+	 * Gates, which notified this input gate about available data. We are using it as a FIFO
+	 * queue of {@link InputGate}s to avoid starvation and provide some basic fairness.
 	 */
-	private final Set<InputGate> enqueuedInputGatesWithData = new HashSet<>();
+	private final LinkedHashSet<InputGate> inputGatesWithData = new LinkedHashSet<>();
 
 	/** The total number of input channels across all unioned input gates. */
 	private final int totalNumberOfInputChannels;
 
-	/** Registered listener to forward input gate notifications to. */
-	private volatile InputGateListener inputGateListener;
-
 	/**
 	 * A mapping from input gate to (logical) channel index offset. Valid channel indexes go from 0
 	 * (inclusive) to the total number of input channels (exclusive).
@@ -102,20 +98,31 @@ public class UnionInputGate implements InputGate, InputGateListener {
 
 		int currentNumberOfInputChannels = 0;
 
-		for (InputGate inputGate : inputGates) {
-			if (inputGate instanceof UnionInputGate) {
-				// if we want to add support for this, we need to implement pollNextBufferOrEvent()
-				throw new UnsupportedOperationException("Cannot union a union of input gates.");
-			}
+		synchronized (inputGatesWithData) {
+			for (InputGate inputGate : inputGates) {
+				if (inputGate instanceof UnionInputGate) {
+					// nested unions are rejected; the old reason (missing pollNextBufferOrEvent()) no longer applies - TODO: check if this can be lifted
+					throw new UnsupportedOperationException("Cannot union a union of input gates.");
+				}
+
+				// The offset to use for buffer or event instances received from this input gate.
+				inputGateToIndexOffsetMap.put(checkNotNull(inputGate), currentNumberOfInputChannels);
+				inputGatesWithRemainingData.add(inputGate);
+
+				currentNumberOfInputChannels += inputGate.getNumberOfInputChannels();
 
-			// The offset to use for buffer or event instances received from this input gate.
-			inputGateToIndexOffsetMap.put(checkNotNull(inputGate), currentNumberOfInputChannels);
-			inputGatesWithRemainingData.add(inputGate);
+				CompletableFuture<?> available = inputGate.isAvailable();
 
-			currentNumberOfInputChannels += inputGate.getNumberOfInputChannels();
+				if (available.isDone()) {
+					inputGatesWithData.add(inputGate);
+				} else {
+					available.thenRun(() -> queueInputGate(inputGate));
+				}
+			}
 
-			// Register the union gate as a listener for all input gates
-			inputGate.registerListener(this);
+			if (!inputGatesWithData.isEmpty()) {
+				isAvailable = AVAILABLE;
+			}
 		}
 
 		this.totalNumberOfInputChannels = currentNumberOfInputChannels;
@@ -159,6 +166,15 @@ public class UnionInputGate implements InputGate, InputGateListener {
 
 	@Override
 	public Optional<BufferOrEvent> getNextBufferOrEvent() throws IOException, InterruptedException {
+		return getNextBufferOrEvent(true);
+	}
+
+	@Override
+	public Optional<BufferOrEvent> pollNextBufferOrEvent() throws IOException, InterruptedException {
+		return getNextBufferOrEvent(false);
+	}
+
+	private Optional<BufferOrEvent> getNextBufferOrEvent(boolean blocking) throws IOException, InterruptedException {
 		if (inputGatesWithRemainingData.isEmpty()) {
 			return Optional.empty();
 		}
@@ -166,72 +182,84 @@ public class UnionInputGate implements InputGate, InputGateListener {
 		// Make sure to request the partitions, if they have not been requested before.
 		requestPartitions();
 
-		InputGateWithData inputGateWithData = waitAndGetNextInputGate();
-		InputGate inputGate = inputGateWithData.inputGate;
-		BufferOrEvent bufferOrEvent = inputGateWithData.bufferOrEvent;
-
-		if (bufferOrEvent.moreAvailable()) {
-			// this buffer or event was now removed from the non-empty gates queue
-			// we re-add it in case it has more data, because in that case no "non-empty" notification
-			// will come for that gate
-			queueInputGate(inputGate);
-		}
-
-		if (bufferOrEvent.isEvent()
-			&& bufferOrEvent.getEvent().getClass() == EndOfPartitionEvent.class
-			&& inputGate.isFinished()) {
-
-			checkState(!bufferOrEvent.moreAvailable());
-			if (!inputGatesWithRemainingData.remove(inputGate)) {
-				throw new IllegalStateException("Couldn't find input gate in set of remaining " +
-					"input gates.");
-			}
+		Optional<InputWithData<InputGate, BufferOrEvent>> next = waitAndGetNextData(blocking);
+		if (!next.isPresent()) {
+			return Optional.empty();
 		}
 
-		// Set the channel index to identify the input channel (across all unioned input gates)
-		final int channelIndexOffset = inputGateToIndexOffsetMap.get(inputGate);
-
-		bufferOrEvent.setChannelIndex(channelIndexOffset + bufferOrEvent.getChannelIndex());
-		bufferOrEvent.setMoreAvailable(bufferOrEvent.moreAvailable() || inputGateWithData.moreInputGatesAvailable);
+		InputWithData<InputGate, BufferOrEvent> inputWithData = next.get();
 
-		return Optional.of(bufferOrEvent);
+		handleEndOfPartitionEvent(inputWithData.data, inputWithData.input);
+		return Optional.of(adjustForUnionInputGate(
+			inputWithData.data,
+			inputWithData.input,
+			inputWithData.moreAvailable));
 	}
 
-	@Override
-	public Optional<BufferOrEvent> pollNextBufferOrEvent() throws UnsupportedOperationException {
-		throw new UnsupportedOperationException();
-	}
-
-	private InputGateWithData waitAndGetNextInputGate() throws IOException, InterruptedException {
+	private Optional<InputWithData<InputGate, BufferOrEvent>> waitAndGetNextData(boolean blocking)
+			throws IOException, InterruptedException {
 		while (true) {
-			InputGate inputGate;
-			boolean moreInputGatesAvailable;
 			synchronized (inputGatesWithData) {
 				while (inputGatesWithData.size() == 0) {
-					inputGatesWithData.wait();
+					if (blocking) {
+						inputGatesWithData.wait();
+					} else {
+						resetIsAvailable();
+						return Optional.empty();
+					}
+				}
+
+				Iterator<InputGate> inputGateIterator = inputGatesWithData.iterator();
+				final InputGate inputGate = inputGateIterator.next();
+				inputGateIterator.remove();
+
+				// In case of inputGatesWithData being inaccurate do not block on an empty inputGate, but just poll the data.
+				Optional<BufferOrEvent> bufferOrEvent = inputGate.pollNextBufferOrEvent();
+
+				if (bufferOrEvent.isPresent() && bufferOrEvent.get().moreAvailable()) {
+					// enqueue the inputGate at the end to avoid starvation
+					inputGatesWithData.add(inputGate);
+				} else {
+					inputGate.isAvailable().thenRun(() -> queueInputGate(inputGate));
+				}
+
+				if (inputGatesWithData.isEmpty()) {
+					resetIsAvailable();
 				}
-				inputGate = inputGatesWithData.remove();
-				enqueuedInputGatesWithData.remove(inputGate);
-				moreInputGatesAvailable = enqueuedInputGatesWithData.size() > 0;
-			}
 
-			// In case of inputGatesWithData being inaccurate do not block on an empty inputGate, but just poll the data.
-			Optional<BufferOrEvent> bufferOrEvent = inputGate.pollNextBufferOrEvent();
-			if (bufferOrEvent.isPresent()) {
-				return new InputGateWithData(inputGate, bufferOrEvent.get(), moreInputGatesAvailable);
+				if (bufferOrEvent.isPresent()) {
+					return Optional.of(new InputWithData<>(
+						inputGate,
+						bufferOrEvent.get(),
+						!inputGatesWithData.isEmpty()));
+				}
 			}
 		}
 	}
 
-	private static class InputGateWithData {
-		private final InputGate inputGate;
-		private final BufferOrEvent bufferOrEvent;
-		private final boolean moreInputGatesAvailable;
+	private BufferOrEvent adjustForUnionInputGate(
+		BufferOrEvent bufferOrEvent,
+		InputGate inputGate,
+		boolean moreInputGatesAvailable) {
+		// Set the channel index to identify the input channel (across all unioned input gates)
+		final int channelIndexOffset = inputGateToIndexOffsetMap.get(inputGate);
+
+		bufferOrEvent.setChannelIndex(channelIndexOffset + bufferOrEvent.getChannelIndex());
+		bufferOrEvent.setMoreAvailable(bufferOrEvent.moreAvailable() || moreInputGatesAvailable);
+
+		return bufferOrEvent;
+	}
+
+	private void handleEndOfPartitionEvent(BufferOrEvent bufferOrEvent, InputGate inputGate) {
+		if (bufferOrEvent.isEvent()
+			&& bufferOrEvent.getEvent().getClass() == EndOfPartitionEvent.class
+			&& inputGate.isFinished()) {
 
-		InputGateWithData(InputGate inputGate, BufferOrEvent bufferOrEvent, boolean moreInputGatesAvailable) {
-			this.inputGate = checkNotNull(inputGate);
-			this.bufferOrEvent = checkNotNull(bufferOrEvent);
-			this.moreInputGatesAvailable = moreInputGatesAvailable;
+			checkState(!bufferOrEvent.moreAvailable());
+			if (!inputGatesWithRemainingData.remove(inputGate)) {
+				throw new IllegalStateException("Couldn't find input gate in set of remaining " +
+					"input gates.");
+			}
 		}
 	}
 
@@ -243,15 +271,6 @@ public class UnionInputGate implements InputGate, InputGateListener {
 	}
 
 	@Override
-	public void registerListener(InputGateListener listener) {
-		if (this.inputGateListener == null) {
-			this.inputGateListener = listener;
-		} else {
-			throw new IllegalStateException("Multiple listeners");
-		}
-	}
-
-	@Override
 	public int getPageSize() {
 		int pageSize = -1;
 		for (InputGate gate : inputGates) {
@@ -268,34 +287,29 @@ public class UnionInputGate implements InputGate, InputGateListener {
 	public void close() throws IOException {
 	}
 
-	@Override
-	public void notifyInputGateNonEmpty(InputGate inputGate) {
-		queueInputGate(checkNotNull(inputGate));
-	}
-
 	private void queueInputGate(InputGate inputGate) {
-		int availableInputGates;
+		checkNotNull(inputGate);
+
+		CompletableFuture<?> toNotify = null;
 
 		synchronized (inputGatesWithData) {
-			if (enqueuedInputGatesWithData.contains(inputGate)) {
+			if (inputGatesWithData.contains(inputGate)) {
 				return;
 			}
 
-			availableInputGates = inputGatesWithData.size();
+			int availableInputGates = inputGatesWithData.size();
 
 			inputGatesWithData.add(inputGate);
-			enqueuedInputGatesWithData.add(inputGate);
 
 			if (availableInputGates == 0) {
 				inputGatesWithData.notifyAll();
+				toNotify = isAvailable;
+				isAvailable = AVAILABLE;
 			}
 		}
 
-		if (availableInputGates == 0) {
-			InputGateListener listener = inputGateListener;
-			if (listener != null) {
-				listener.notifyInputGateNonEmpty(this);
-			}
+		if (toNotify != null) {
+			toNotify.complete(null);
 		}
 	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
new file mode 100644
index 0000000..ddf027c
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.consumer;
+
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Test base for {@link InputGate}.
+ */
+@RunWith(Parameterized.class)
+public abstract class InputGateTestBase {
+
+	@Parameter
+	public boolean enableCreditBasedFlowControl;
+
+	@Parameters(name = "Credit-based = {0}")
+	public static List<Boolean> parameters() {
+		return Arrays.asList(Boolean.TRUE, Boolean.FALSE);
+	}
+
+	protected void testIsAvailable(
+			InputGate inputGateToTest,
+			SingleInputGate inputGateToNotify,
+			TestInputChannel inputChannelWithNewData) throws Exception {
+
+		assertFalse(inputGateToTest.isAvailable().isDone());
+		assertFalse(inputGateToTest.pollNextBufferOrEvent().isPresent());
+
+		CompletableFuture<?> isAvailable = inputGateToTest.isAvailable();
+
+		assertFalse(inputGateToTest.isAvailable().isDone());
+		assertFalse(inputGateToTest.pollNextBufferOrEvent().isPresent());
+
+		assertEquals(isAvailable, inputGateToTest.isAvailable());
+
+		inputChannelWithNewData.readBuffer();
+		inputGateToNotify.notifyChannelNonEmpty(inputChannelWithNewData);
+
+		assertTrue(isAvailable.isDone());
+		assertTrue(inputGateToTest.isAvailable().isDone());
+	}
+
+	protected SingleInputGate createInputGate() {
+		return createInputGate(2);
+	}
+
+	protected SingleInputGate createInputGate(int numberOfInputChannels) {
+		return createInputGate(numberOfInputChannels, ResultPartitionType.PIPELINED);
+	}
+
+	protected SingleInputGate createInputGate(
+			int numberOfInputChannels, ResultPartitionType partitionType) {
+		SingleInputGate inputGate = createSingleInputGate(
+			numberOfInputChannels,
+			partitionType,
+			enableCreditBasedFlowControl);
+
+		assertEquals(partitionType, inputGate.getConsumedPartitionType());
+		return inputGate;
+	}
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 71e4f5a..a6f824d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -50,18 +50,13 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.taskmanager.NoOpTaskActions;
 
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
-import java.util.Arrays;
-import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicReference;
 
-import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
@@ -80,17 +75,7 @@ import static org.mockito.Mockito.when;
 /**
  * Tests for {@link SingleInputGate}.
  */
-@RunWith(Parameterized.class)
-public class SingleInputGateTest {
-
-	@Parameterized.Parameter
-	public boolean enableCreditBasedFlowControl;
-
-	@Parameterized.Parameters(name = "Credit-based = {0}")
-	public static List<Boolean> parameters() {
-		return Arrays.asList(Boolean.TRUE, Boolean.FALSE);
-	}
-
+public class SingleInputGateTest extends InputGateTestBase {
 	/**
 	 * Tests basic correctness of buffer-or-event interleaving and correct <code>null</code> return
 	 * value after receiving all end-of-partition events.
@@ -131,6 +116,15 @@ public class SingleInputGateTest {
 		assertTrue(inputGate.isFinished());
 	}
 
+	@Test
+	public void testIsAvailable() throws Exception {
+		final SingleInputGate inputGate = createInputGate(1);
+		TestInputChannel inputChannel = new TestInputChannel(inputGate, 0);
+		inputGate.setInputChannel(new IntermediateResultPartitionID(), inputChannel);
+
+		testIsAvailable(inputGate, inputGate, inputChannel);
+	}
+
 	@Test(timeout = 120 * 1000)
 	public void testIsMoreAvailableReadingFromSingleInputChannel() throws Exception {
 		// Setup
@@ -546,22 +540,6 @@ public class SingleInputGateTest {
 
 	// ---------------------------------------------------------------------------------------------
 
-	private SingleInputGate createInputGate() {
-		return createInputGate(2);
-	}
-
-	private SingleInputGate createInputGate(int numberOfInputChannels) {
-		return createInputGate(numberOfInputChannels, ResultPartitionType.PIPELINED);
-	}
-
-	private SingleInputGate createInputGate(int numberOfInputChannels, ResultPartitionType partitionType) {
-		SingleInputGate inputGate = createSingleInputGate(numberOfInputChannels, partitionType, enableCreditBasedFlowControl);
-
-		assertEquals(partitionType, inputGate.getConsumedPartitionType());
-
-		return inputGate;
-	}
-
 	private void addUnknownInputChannel(
 			NetworkEnvironment network,
 			SingleInputGate inputGate,
@@ -619,17 +597,7 @@ public class SingleInputGateTest {
 		assertEquals(expectedChannelIndex, bufferOrEvent.get().getChannelIndex());
 		assertEquals(expectedMoreAvailable, bufferOrEvent.get().moreAvailable());
 		if (!expectedMoreAvailable) {
-			try {
-				assertFalse(inputGate.pollNextBufferOrEvent().isPresent());
-			}
-			catch (UnsupportedOperationException ex) {
-				/**
-				 * {@link UnionInputGate#pollNextBufferOrEvent()} is unsupported at the moment.
-				 */
-				if (!(inputGate instanceof UnionInputGate)) {
-					throw ex;
-				}
-			}
+			assertFalse(inputGate.pollNextBufferOrEvent().isPresent());
 		}
 	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
index 464ab7c..f60cdb9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
@@ -20,17 +20,8 @@ package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-
-import java.lang.reflect.Field;
-import java.util.ArrayDeque;
-
 import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
 import static org.apache.flink.util.Preconditions.checkArgument;
-import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.spy;
 
 /**
  * A test input gate to mock reading data.
@@ -44,39 +35,8 @@ public class TestSingleInputGate {
 	public TestSingleInputGate(int numberOfInputChannels, boolean initialize) {
 		checkArgument(numberOfInputChannels >= 1);
 
-		SingleInputGate realGate = createSingleInputGate(numberOfInputChannels);
-
-		this.inputGate = spy(realGate);
-
-		// Notify about late registrations (added for DataSinkTaskTest#testUnionDataSinkTask).
-		// After merging registerInputOutput and invoke, we have to make sure that the test
-		// notifications happen at the expected time. In real programs, this is guaranteed by
-		// the instantiation and request partition life cycle.
-		try {
-			Field f = realGate.getClass().getDeclaredField("inputChannelsWithData");
-			f.setAccessible(true);
-			final ArrayDeque<InputChannel> notifications = (ArrayDeque<InputChannel>) f.get(realGate);
-
-			doAnswer(new Answer<Void>() {
-				@Override
-				public Void answer(InvocationOnMock invocation) throws Throwable {
-					invocation.callRealMethod();
-
-					synchronized (notifications) {
-						if (!notifications.isEmpty()) {
-							InputGateListener listener = (InputGateListener) invocation.getArguments()[0];
-							listener.notifyInputGateNonEmpty(inputGate);
-						}
-					}
-
-					return null;
-				}
-			}).when(inputGate).registerListener(any(InputGateListener.class));
-		} catch (Exception e) {
-			throw new RuntimeException(e);
-		}
-
-		this.inputChannels = new TestInputChannel[numberOfInputChannels];
+		inputGate = createSingleInputGate(numberOfInputChannels);
+		inputChannels = new TestInputChannel[numberOfInputChannels];
 
 		if (initialize) {
 			for (int i = 0; i < numberOfInputChannels; i++) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
index 082ccec..9dca613 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
@@ -18,9 +18,10 @@
 
 package org.apache.flink.runtime.io.network.partition.consumer;
 
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+
 import org.junit.Test;
 
-import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
 import static org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.verifyBufferOrEvent;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -29,7 +30,7 @@ import static org.junit.Assert.assertTrue;
 /**
  * Tests for {@link UnionInputGate}.
  */
-public class UnionInputGateTest {
+public class UnionInputGateTest extends InputGateTestBase {
 
 	/**
 	 * Tests basic correctness of buffer-or-event interleaving and correct <code>null</code> return
@@ -41,8 +42,8 @@ public class UnionInputGateTest {
 	@Test(timeout = 120 * 1000)
 	public void testBasicGetNextLogic() throws Exception {
 		// Setup
-		final SingleInputGate ig1 = createSingleInputGate(3);
-		final SingleInputGate ig2 = createSingleInputGate(5);
+		final SingleInputGate ig1 = createInputGate(3);
+		final SingleInputGate ig2 = createInputGate(5);
 
 		final UnionInputGate union = new UnionInputGate(new SingleInputGate[]{ig1, ig2});
 
@@ -101,4 +102,17 @@ public class UnionInputGateTest {
 		assertTrue(union.isFinished());
 		assertFalse(union.getNextBufferOrEvent().isPresent());
 	}
+
+	@Test
+	public void testIsAvailable() throws Exception {
+		final SingleInputGate firstGate = createInputGate(1);
+		final TestInputChannel firstChannel = new TestInputChannel(firstGate, 0);
+		firstGate.setInputChannel(new IntermediateResultPartitionID(), firstChannel);
+
+		final SingleInputGate secondGate = createInputGate(1);
+		final TestInputChannel secondChannel = new TestInputChannel(secondGate, 0);
+		secondGate.setInputChannel(new IntermediateResultPartitionID(), secondChannel);
+		// only the first gate and its channel are handed to the shared availability check
+		testIsAvailable(new UnionInputGate(firstGate, secondGate), firstGate, firstChannel);
+	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
index dbb81ab..b9fe84f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
@@ -21,12 +21,16 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer;
 import org.apache.flink.runtime.io.network.api.serialization.SpanningRecordSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.BufferListener;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannel.BufferAndAvailability;
 import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel.BufferAndAvailabilityProvider;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
@@ -40,7 +44,6 @@ import java.util.concurrent.ConcurrentLinkedQueue;
 
 import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.buildSingleBuffer;
 import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.createBufferBuilder;
-import static org.mockito.Mockito.doReturn;
 
 /**
  * Test {@link InputGate} that allows setting multiple channels. Use
@@ -76,7 +79,7 @@ public class StreamTestSingleInputGate<T> extends TestSingleInputGate {
 		inputQueues = new ConcurrentLinkedQueue[numInputChannels];
 
 		setupInputChannels();
-		doReturn(bufferSize).when(inputGate).getPageSize();
+		inputGate.setBufferPool(new NoOpBufferPool(bufferSize));
 	}
 
 	@SuppressWarnings("unchecked")
@@ -219,4 +222,81 @@ public class StreamTestSingleInputGate<T> extends TestSingleInputGate {
 			return isEvent;
 		}
 	}
+	/** {@link BufferPool} stub that only reports a fixed memory segment size; lazyDestroy() is a no-op and every other operation throws {@link UnsupportedOperationException}. */
+	private static class NoOpBufferPool implements BufferPool {
+		private final int bufferSize;
+
+		public NoOpBufferPool(int bufferSize) {
+			this.bufferSize = bufferSize;
+		}
+
+		@Override
+		public void lazyDestroy() {
+		}
+
+		@Override
+		public int getMemorySegmentSize() {
+			return bufferSize;
+		}
+
+		@Override
+		public Buffer requestBuffer() throws IOException {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public Buffer requestBufferBlocking() throws IOException, InterruptedException {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public BufferBuilder requestBufferBuilderBlocking() throws IOException, InterruptedException {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public boolean addBufferListener(BufferListener listener) {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public boolean isDestroyed() {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public int getNumberOfRequiredMemorySegments() {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public int getMaxNumberOfMemorySegments() {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public int getNumBuffers() {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public void setNumBuffers(int numBuffers) throws IOException {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public int getNumberOfAvailableMemorySegments() {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public int bestEffortGetNumOfUsedBuffers() {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public void recycle(MemorySegment memorySegment) {
+			throw new UnsupportedOperationException();
+		}
+	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java
index ded52ff..b1a3ad5 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java
@@ -27,7 +27,6 @@ import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
-import org.apache.flink.runtime.io.network.partition.consumer.InputGateListener;
 
 import org.junit.Test;
 
@@ -130,7 +129,7 @@ public class BarrierBufferMassiveRandomTest {
 		}
 	}
 
-	private static class RandomGeneratingInputGate implements InputGate {
+	private static class RandomGeneratingInputGate extends InputGate {
 
 		private final int numberOfChannels;
 		private final BufferPool[] bufferPools;
@@ -151,6 +150,7 @@ public class BarrierBufferMassiveRandomTest {
 			this.bufferPools = bufferPools;
 			this.barrierGens = barrierGens;
 			this.owningTaskName = owningTaskName;
+			this.isAvailable = AVAILABLE;
 		}
 
 		@Override
@@ -199,9 +199,6 @@ public class BarrierBufferMassiveRandomTest {
 		public void sendTaskEvent(TaskEvent event) {}
 
 		@Override
-		public void registerListener(InputGateListener listener) {}
-
-		@Override
 		public int getPageSize() {
 			return PAGE_SIZE;
 		}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
index a29cbf5..30941ab 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
@@ -22,7 +22,6 @@ import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
-import org.apache.flink.runtime.io.network.partition.consumer.InputGateListener;
 
 import java.util.ArrayDeque;
 import java.util.List;
@@ -32,7 +31,7 @@ import java.util.Queue;
 /**
  * Mock {@link InputGate}.
  */
-public class MockInputGate implements InputGate {
+public class MockInputGate extends InputGate {
 
 	private final int pageSize;
 
@@ -56,6 +55,8 @@ public class MockInputGate implements InputGate {
 		this.bufferOrEvents = new ArrayDeque<BufferOrEvent>(bufferOrEvents);
 		this.closed = new boolean[numberOfChannels];
 		this.owningTaskName = owningTaskName;
+
+		isAvailable = AVAILABLE;
 	}
 
 	@Override
@@ -111,10 +112,6 @@ public class MockInputGate implements InputGate {
 	}
 
 	@Override
-	public void registerListener(InputGateListener listener) {
-	}
-
-	@Override
 	public void close() {
 	}
 }