You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2018/01/08 13:03:59 UTC
[05/15] flink git commit: [FLINK-7406][network] Implement Netty
receiver incoming pipeline for credit-based
[FLINK-7406][network] Implement Netty receiver incoming pipeline for credit-based
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/268867ce
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/268867ce
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/268867ce
Branch: refs/heads/master
Commit: 268867ce620a2c12879749db2ecb68bbe129cad5
Parents: 542419b
Author: Zhijiang <wa...@aliyun.com>
Authored: Thu Aug 10 13:29:13 2017 +0800
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Mon Jan 8 11:46:00 2018 +0100
----------------------------------------------------------------------
.../network/netty/CreditBasedClientHandler.java | 277 ++++++++
.../runtime/io/network/netty/NettyMessage.java | 15 +-
.../netty/PartitionRequestClientHandler.java | 8 +-
.../io/network/netty/PartitionRequestQueue.java | 3 +-
.../partition/consumer/RemoteInputChannel.java | 257 +++++--
.../netty/NettyMessageSerializationTest.java | 3 +-
.../PartitionRequestClientHandlerTest.java | 151 ++---
.../partition/InputGateConcurrentTest.java | 2 +-
.../partition/InputGateFairnessTest.java | 8 +-
.../consumer/RemoteInputChannelTest.java | 665 +++++++++++++++++--
10 files changed, 1175 insertions(+), 214 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java
new file mode 100644
index 0000000..1f18588
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.netty;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.netty.exception.LocalTransportException;
+import org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException;
+import org.apache.flink.runtime.io.network.netty.exception.TransportException;
+import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
+import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * Client-side (receiver) channel handler that reads {@link NettyMessage.BufferResponse} and
+ * {@link NettyMessage.ErrorResponse} messages sent by the producer and dispatches them to the
+ * {@link RemoteInputChannel} registered under the message's receiver id.
+ *
+ * <p>The only message this handler writes back to the producer is a
+ * {@link NettyMessage.CancelPartitionRequest}, sent when data arrives for an input channel that
+ * is unknown or already released.
+ *
+ * <p>The first error observed on the connection is stored, forwarded to all registered input
+ * channels, and re-thrown to callers that try to register further input channels.
+ */
+class CreditBasedClientHandler extends ChannelInboundHandlerAdapter {
+
+	private static final Logger LOG = LoggerFactory.getLogger(CreditBasedClientHandler.class);
+
+	/** Channels, which already requested partitions from the producers. */
+	private final ConcurrentMap<InputChannelID, RemoteInputChannel> inputChannels = new ConcurrentHashMap<>();
+
+	/** The first error reported for this connection, or {@code null} if none occurred so far. */
+	private final AtomicReference<Throwable> channelError = new AtomicReference<>();
+
+	/**
+	 * Set of cancelled partition requests. A request is cancelled iff an input channel is cleared
+	 * while data is still coming in for this channel.
+	 */
+	private final ConcurrentMap<InputChannelID, InputChannelID> cancelled = new ConcurrentHashMap<>();
+
+	/** The context captured in {@link #channelActive(ChannelHandlerContext)}; {@code null} before that. */
+	private volatile ChannelHandlerContext ctx;
+
+	// ------------------------------------------------------------------------
+	// Input channel/receiver registration
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Registers the given input channel under its id unless a channel with that id is already
+	 * registered.
+	 *
+	 * @param listener The input channel that should receive messages addressed to its id.
+	 * @throws IOException If an error has already been reported for this connection.
+	 */
+	void addInputChannel(RemoteInputChannel listener) throws IOException {
+		checkError();
+
+		// Atomic check-and-insert: a separate containsKey() + put() could let two concurrent
+		// registrations for the same id overwrite each other.
+		inputChannels.putIfAbsent(listener.getInputChannelId(), listener);
+	}
+
+	/**
+	 * Removes the registration of the given input channel.
+	 *
+	 * @param listener The input channel to unregister.
+	 */
+	void removeInputChannel(RemoteInputChannel listener) {
+		inputChannels.remove(listener.getInputChannelId());
+	}
+
+	/**
+	 * Sends a cancel request for the given input channel to the producer. The request is sent at
+	 * most once per channel id, and not at all while no channel context is available.
+	 *
+	 * @param inputChannelId The id of the input channel whose request is cancelled; ignored if {@code null}.
+	 */
+	void cancelRequestFor(InputChannelID inputChannelId) {
+		if (inputChannelId == null || ctx == null) {
+			return;
+		}
+
+		// Only the caller that inserts the id first writes the cancel message.
+		if (cancelled.putIfAbsent(inputChannelId, inputChannelId) == null) {
+			ctx.writeAndFlush(new NettyMessage.CancelPartitionRequest(inputChannelId));
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	// Network events
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Remembers the first channel context seen so that messages can be written to the producer
+	 * later on, then forwards the event.
+	 */
+	@Override
+	public void channelActive(final ChannelHandlerContext ctx) throws Exception {
+		if (this.ctx == null) {
+			this.ctx = ctx;
+		}
+
+		super.channelActive(ctx);
+	}
+
+	/**
+	 * Reports a {@link RemoteTransportException} to all registered input channels if the
+	 * connection is closed while input channels are still registered, then forwards the event.
+	 */
+	@Override
+	public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+		// Unexpected close. In normal operation, the client closes the connection after all input
+		// channels have been removed. This indicates a problem with the remote task manager.
+		if (!inputChannels.isEmpty()) {
+			final SocketAddress remoteAddr = ctx.channel().remoteAddress();
+
+			notifyAllChannelsOfErrorAndClose(new RemoteTransportException(
+				"Connection unexpectedly closed by remote task manager '" + remoteAddr + "'. "
+					+ "This might indicate that the remote task manager was lost.", remoteAddr));
+		}
+
+		super.channelInactive(ctx);
+	}
+
+	/**
+	 * Called on exceptions in the client handler pipeline.
+	 *
+	 * <p>Remote exceptions are received as regular payload.
+	 *
+	 * <p>A {@link TransportException} is forwarded as is. Any other cause is wrapped: a
+	 * "Connection reset by peer" {@link IOException} becomes a {@link RemoteTransportException},
+	 * everything else a {@link LocalTransportException}.
+	 */
+	@Override
+	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+
+		if (cause instanceof TransportException) {
+			notifyAllChannelsOfErrorAndClose(cause);
+		} else {
+			final SocketAddress remoteAddr = ctx.channel().remoteAddress();
+
+			final TransportException tex;
+
+			// Improve on the connection reset by peer error message.
+			// The constant is the receiver of equals() because getMessage() may return null.
+			if (cause instanceof IOException && "Connection reset by peer".equals(cause.getMessage())) {
+				tex = new RemoteTransportException("Lost connection to task manager '" + remoteAddr + "'. " +
+					"This indicates that the remote task manager was lost.", remoteAddr, cause);
+			} else {
+				tex = new LocalTransportException(cause.getMessage(), ctx.channel().localAddress(), cause);
+			}
+
+			notifyAllChannelsOfErrorAndClose(tex);
+		}
+	}
+
+	/**
+	 * Decodes the incoming message. Any failure while decoding is reported to all registered
+	 * input channels and closes the connection. The message is not passed on to further handlers.
+	 */
+	@Override
+	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+		try {
+			decodeMsg(msg);
+		} catch (Throwable t) {
+			notifyAllChannelsOfErrorAndClose(t);
+		}
+	}
+
+	/**
+	 * Stores the given cause as the connection error and, if it is the first one, notifies every
+	 * registered input channel of it, clears the registrations and closes the channel context.
+	 * Subsequent calls are no-ops.
+	 *
+	 * @param cause The error to report.
+	 */
+	private void notifyAllChannelsOfErrorAndClose(Throwable cause) {
+		if (channelError.compareAndSet(null, cause)) {
+			try {
+				for (RemoteInputChannel inputChannel : inputChannels.values()) {
+					inputChannel.onError(cause);
+				}
+			} catch (Throwable t) {
+				// We can only swallow the Exception at this point. :(
+				LOG.warn("An Exception was thrown during error notification of a remote input channel.", t);
+			} finally {
+				inputChannels.clear();
+
+				if (ctx != null) {
+					ctx.close();
+				}
+			}
+		}
+	}
+
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Checks for an error and rethrows it if one was reported.
+	 *
+	 * @throws IOException The reported error itself if it is an {@link IOException}, otherwise an
+	 *         {@link IOException} wrapping it.
+	 */
+	private void checkError() throws IOException {
+		final Throwable t = channelError.get();
+
+		if (t != null) {
+			if (t instanceof IOException) {
+				throw (IOException) t;
+			} else {
+				throw new IOException("There has been an error in the channel.", t);
+			}
+		}
+	}
+
+	/**
+	 * Dispatches a message received from the producer.
+	 *
+	 * <p>A buffer response for an unknown receiver is released and answered with a cancel
+	 * request. A fatal error response is reported to all input channels; a non-fatal one only to
+	 * the addressed input channel, if it is still registered.
+	 *
+	 * @param msg The received message.
+	 * @throws IllegalStateException If the message is neither a buffer nor an error response.
+	 */
+	private void decodeMsg(Object msg) throws Throwable {
+		final Class<?> msgClazz = msg.getClass();
+
+		// ---- Buffer --------------------------------------------------------
+		if (msgClazz == NettyMessage.BufferResponse.class) {
+			NettyMessage.BufferResponse bufferOrEvent = (NettyMessage.BufferResponse) msg;
+
+			RemoteInputChannel inputChannel = inputChannels.get(bufferOrEvent.receiverId);
+			if (inputChannel == null) {
+				// Nobody is interested in this data any more: free it and tell the producer to stop.
+				bufferOrEvent.releaseBuffer();
+
+				cancelRequestFor(bufferOrEvent.receiverId);
+
+				return;
+			}
+
+			decodeBufferOrEvent(inputChannel, bufferOrEvent);
+
+		} else if (msgClazz == NettyMessage.ErrorResponse.class) {
+			// ---- Error ---------------------------------------------------------
+			NettyMessage.ErrorResponse error = (NettyMessage.ErrorResponse) msg;
+
+			SocketAddress remoteAddr = ctx.channel().remoteAddress();
+
+			if (error.isFatalError()) {
+				notifyAllChannelsOfErrorAndClose(new RemoteTransportException(
+					"Fatal error at remote task manager '" + remoteAddr + "'.",
+					remoteAddr,
+					error.cause));
+			} else {
+				RemoteInputChannel inputChannel = inputChannels.get(error.receiverId);
+
+				if (inputChannel != null) {
+					if (error.cause.getClass() == PartitionNotFoundException.class) {
+						inputChannel.onFailedPartitionRequest();
+					} else {
+						inputChannel.onError(new RemoteTransportException(
+							"Error at remote task manager '" + remoteAddr + "'.",
+							remoteAddr,
+							error.cause));
+					}
+				}
+			}
+		} else {
+			throw new IllegalStateException("Received unknown message from producer: " + msg.getClass());
+		}
+	}
+
+	/**
+	 * Copies the payload of a buffer response into a buffer and hands it to the input channel.
+	 * Data buffers are copied into a buffer requested from the input channel; events are copied
+	 * into a freshly allocated heap segment. The Netty buffer of the response is always released.
+	 *
+	 * @param inputChannel The input channel the response is addressed to.
+	 * @param bufferOrEvent The received buffer response.
+	 * @throws IllegalStateException If a non-released input channel has no buffer available.
+	 */
+	private void decodeBufferOrEvent(RemoteInputChannel inputChannel, NettyMessage.BufferResponse bufferOrEvent) throws Throwable {
+		try {
+			if (bufferOrEvent.isBuffer()) {
+				// ---- Buffer ------------------------------------------------
+
+				// Early return for empty buffers. Otherwise Netty's readBytes() throws an
+				// IndexOutOfBoundsException.
+				if (bufferOrEvent.getSize() == 0) {
+					inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
+					return;
+				}
+
+				Buffer buffer = inputChannel.requestBuffer();
+				if (buffer != null) {
+					buffer.setSize(bufferOrEvent.getSize());
+					bufferOrEvent.getNettyBuffer().readBytes(buffer.getNioBuffer());
+
+					inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
+				} else if (inputChannel.isReleased()) {
+					// The channel was released concurrently: drop the data and stop the producer.
+					cancelRequestFor(bufferOrEvent.receiverId);
+				} else {
+					throw new IllegalStateException("No buffer available in credit-based input channel.");
+				}
+			} else {
+				// ---- Event -------------------------------------------------
+				// TODO We can just keep the serialized data in the Netty buffer and release it later at the reader
+				byte[] byteArray = new byte[bufferOrEvent.getSize()];
+				bufferOrEvent.getNettyBuffer().readBytes(byteArray);
+
+				MemorySegment memSeg = MemorySegmentFactory.wrap(byteArray);
+				Buffer buffer = new Buffer(memSeg, FreeingBufferRecycler.INSTANCE, false);
+
+				inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
+			}
+		} finally {
+			bufferOrEvent.releaseBuffer();
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
index 89fb9e8..db1b899 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
@@ -221,6 +221,8 @@ public abstract class NettyMessage {
final int sequenceNumber;
+ final int backlog;
+
// ---- Deserialization -----------------------------------------------
final boolean isBuffer;
@@ -232,7 +234,8 @@ public abstract class NettyMessage {
private BufferResponse(
ByteBuf retainedSlice, boolean isBuffer, int sequenceNumber,
- InputChannelID receiverId) {
+ InputChannelID receiverId,
+ int backlog) {
// When deserializing we first have to request a buffer from the respective buffer
// provider (at the handler) and copy the buffer from Netty's space to ours. Only
// retainedSlice is set in this case.
@@ -242,15 +245,17 @@ public abstract class NettyMessage {
this.isBuffer = isBuffer;
this.sequenceNumber = sequenceNumber;
this.receiverId = checkNotNull(receiverId);
+ this.backlog = backlog;
}
- BufferResponse(Buffer buffer, int sequenceNumber, InputChannelID receiverId) {
+ BufferResponse(Buffer buffer, int sequenceNumber, InputChannelID receiverId, int backlog) {
this.buffer = checkNotNull(buffer);
this.retainedSlice = null;
this.isBuffer = buffer.isBuffer();
this.size = buffer.getSize();
this.sequenceNumber = sequenceNumber;
this.receiverId = checkNotNull(receiverId);
+ this.backlog = backlog;
}
boolean isBuffer() {
@@ -280,7 +285,7 @@ public abstract class NettyMessage {
ByteBuf write(ByteBufAllocator allocator) throws IOException {
checkNotNull(buffer, "No buffer instance to serialize.");
- int length = 16 + 4 + 1 + 4 + buffer.getSize();
+ int length = 16 + 4 + 4 + 1 + 4 + buffer.getSize();
ByteBuf result = null;
try {
@@ -288,6 +293,7 @@ public abstract class NettyMessage {
receiverId.writeTo(result);
result.writeInt(sequenceNumber);
+ result.writeInt(backlog);
result.writeBoolean(buffer.isBuffer());
result.writeInt(buffer.getSize());
result.writeBytes(buffer.getNioBuffer());
@@ -309,12 +315,13 @@ public abstract class NettyMessage {
static BufferResponse readFrom(ByteBuf buffer) {
InputChannelID receiverId = InputChannelID.fromByteBuf(buffer);
int sequenceNumber = buffer.readInt();
+ int backlog = buffer.readInt();
boolean isBuffer = buffer.readBoolean();
int size = buffer.readInt();
ByteBuf retainedSlice = buffer.readSlice(size).retain();
- return new BufferResponse(retainedSlice, isBuffer, sequenceNumber, receiverId);
+ return new BufferResponse(retainedSlice, isBuffer, sequenceNumber, receiverId, backlog);
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
index 566b215..ab4798e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
@@ -276,7 +276,7 @@ class PartitionRequestClientHandler extends ChannelInboundHandlerAdapter {
// Early return for empty buffers. Otherwise Netty's readBytes() throws an
// IndexOutOfBoundsException.
if (bufferOrEvent.getSize() == 0) {
- inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber);
+ inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, -1);
return true;
}
@@ -295,7 +295,7 @@ class PartitionRequestClientHandler extends ChannelInboundHandlerAdapter {
buffer.setSize(bufferOrEvent.getSize());
bufferOrEvent.getNettyBuffer().readBytes(buffer.getNioBuffer());
- inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber);
+ inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, -1);
return true;
}
@@ -318,7 +318,7 @@ class PartitionRequestClientHandler extends ChannelInboundHandlerAdapter {
MemorySegment memSeg = MemorySegmentFactory.wrap(byteArray);
Buffer buffer = new Buffer(memSeg, FreeingBufferRecycler.INSTANCE, false);
- inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber);
+ inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, -1);
return true;
}
@@ -450,7 +450,7 @@ class PartitionRequestClientHandler extends ChannelInboundHandlerAdapter {
RemoteInputChannel inputChannel = inputChannels.get(stagedBufferResponse.receiverId);
if (inputChannel != null) {
- inputChannel.onBuffer(buffer, stagedBufferResponse.sequenceNumber);
+ inputChannel.onBuffer(buffer, stagedBufferResponse.sequenceNumber, -1);
success = true;
}
http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
index ff0f130..41f87ae 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
@@ -193,7 +193,8 @@ class PartitionRequestQueue extends ChannelInboundHandlerAdapter {
BufferResponse msg = new BufferResponse(
next.buffer(),
reader.getSequenceNumber(),
- reader.getReceiverId());
+ reader.getReceiverId(),
+ 0);
if (isEndOfPartitionEvent(next.buffer())) {
reader.notifySubpartitionConsumed();
http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
index cd00934..02c7b34 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
@@ -18,6 +18,7 @@
package org.apache.flink.runtime.io.network.partition.consumer;
+import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.event.TaskEvent;
import org.apache.flink.runtime.io.network.ConnectionID;
@@ -32,11 +33,13 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
import org.apache.flink.util.ExceptionUtils;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
import java.io.IOException;
import java.util.ArrayDeque;
+import java.util.Collections;
import java.util.List;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
@@ -82,17 +85,19 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
/** The initial number of exclusive buffers assigned to this channel. */
private int initialCredit;
- /** The current available buffers including both exclusive buffers and requested floating buffers. */
- private final ArrayDeque<Buffer> availableBuffers = new ArrayDeque<>();
+ /** The available buffer queue wraps both exclusive and requested floating buffers. */
+ private final AvailableBufferQueue bufferQueue = new AvailableBufferQueue();
/** The number of available buffers that have not been announced to the producer yet. */
private final AtomicInteger unannouncedCredit = new AtomicInteger(0);
- /** The number of unsent buffers in the producer's sub partition. */
- private final AtomicInteger senderBacklog = new AtomicInteger(0);
+ /** The number of required buffers that equals to sender's backlog plus initial credit. */
+ @GuardedBy("bufferQueue")
+ private int numRequiredBuffers;
/** The tag indicates whether this channel is waiting for additional floating buffers from the buffer pool. */
- private final AtomicBoolean isWaitingForFloatingBuffers = new AtomicBoolean(false);
+ @GuardedBy("bufferQueue")
+ private boolean isWaitingForFloatingBuffers;
public RemoteInputChannel(
SingleInputGate inputGate,
@@ -133,10 +138,11 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
checkArgument(segments.size() > 0, "The number of exclusive buffers per channel should be larger than 0.");
this.initialCredit = segments.size();
+ this.numRequiredBuffers = segments.size();
- synchronized(availableBuffers) {
+ synchronized(bufferQueue) {
for (MemorySegment segment : segments) {
- availableBuffers.add(new Buffer(segment, this));
+ bufferQueue.addExclusiveBuffer(new Buffer(segment, this), numRequiredBuffers);
}
}
}
@@ -211,7 +217,7 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
// ------------------------------------------------------------------------
@Override
- boolean isReleased() {
+ public boolean isReleased() {
return isReleased.get();
}
@@ -227,7 +233,8 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
void releaseAllResources() throws IOException {
if (isReleased.compareAndSet(false, true)) {
- // Gather all exclusive buffers and recycle them to global pool in batch
+ // Gather all exclusive buffers and recycle them to global pool in batch, because
+ // we do not want to trigger redistribution of buffers after each recycle.
final List<MemorySegment> exclusiveRecyclingSegments = new ArrayList<>();
synchronized (receivedBuffers) {
@@ -240,16 +247,8 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
}
}
}
-
- synchronized (availableBuffers) {
- Buffer buffer;
- while ((buffer = availableBuffers.poll()) != null) {
- if (buffer.getRecycler() == this) {
- exclusiveRecyclingSegments.add(buffer.getMemorySegment());
- } else {
- buffer.recycle();
- }
- }
+ synchronized (bufferQueue) {
+ bufferQueue.releaseAll(exclusiveRecyclingSegments);
}
if (exclusiveRecyclingSegments.size() > 0) {
@@ -287,81 +286,93 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
}
/**
- * Exclusive buffer is recycled to this input channel directly and it may trigger notify
- * credit to producer.
+ * Exclusive buffer is recycled to this input channel directly and it may trigger return extra
+ * floating buffer and notify increased credit to the producer.
*
* @param segment The exclusive segment of this channel.
*/
@Override
public void recycle(MemorySegment segment) {
- synchronized (availableBuffers) {
- // Important: the isReleased check should be inside the synchronized block.
- // that way the segment can also be returned to global pool after added into
- // the available queue during releasing all resources.
+ int numAddedBuffers;
+
+ synchronized (bufferQueue) {
+ // Important: check the isReleased state inside synchronized block, so there is no
+ // race condition when recycle and releaseAllResources running in parallel.
if (isReleased.get()) {
try {
- inputGate.returnExclusiveSegments(Arrays.asList(segment));
+ inputGate.returnExclusiveSegments(Collections.singletonList(segment));
return;
} catch (Throwable t) {
ExceptionUtils.rethrow(t);
}
}
- availableBuffers.add(new Buffer(segment, this));
+ numAddedBuffers = bufferQueue.addExclusiveBuffer(new Buffer(segment, this), numRequiredBuffers);
}
- if (unannouncedCredit.getAndAdd(1) == 0) {
+ if (numAddedBuffers > 0 && unannouncedCredit.getAndAdd(numAddedBuffers) == 0) {
notifyCreditAvailable();
}
}
public int getNumberOfAvailableBuffers() {
- synchronized (availableBuffers) {
- return availableBuffers.size();
+ synchronized (bufferQueue) {
+ return bufferQueue.getAvailableBufferSize();
}
}
+	/**
+	 * Returns the number of buffers this channel currently requires.
+	 *
+	 * <p>The value is read while holding the {@code bufferQueue} lock because the field is
+	 * declared as guarded by it; an unsynchronized read may observe a stale value.
+	 *
+	 * @return The current value of {@code numRequiredBuffers}.
+	 */
+	@VisibleForTesting
+	public int getNumberOfRequiredBuffers() {
+		synchronized (bufferQueue) {
+			return numRequiredBuffers;
+		}
+	}
+
/**
* The Buffer pool notifies this channel of an available floating buffer. If the channel is released or
* currently does not need extra buffers, the buffer should be recycled to the buffer pool. Otherwise,
- * the buffer will be added into the <tt>availableBuffers</tt> queue and the unannounced credit is
- * increased by one.
+ * the buffer will be added into the <tt>bufferQueue</tt> and the unannounced credit is increased
+ * by one.
*
* @param buffer Buffer that becomes available in buffer pool.
* @return True when this channel is waiting for more floating buffers, otherwise false.
*/
@Override
public boolean notifyBufferAvailable(Buffer buffer) {
- checkState(isWaitingForFloatingBuffers.get(), "This channel should be waiting for floating buffers.");
+ // Check the isReleased state outside synchronized block first to avoid
+ // deadlock with releaseAllResources running in parallel.
+ if (isReleased.get()) {
+ buffer.recycle();
+ return false;
+ }
- synchronized (availableBuffers) {
- // Important: the isReleased check should be inside the synchronized block.
- if (isReleased.get() || availableBuffers.size() >= senderBacklog.get()) {
- isWaitingForFloatingBuffers.set(false);
- buffer.recycle();
+ boolean needMoreBuffers = false;
+ synchronized (bufferQueue) {
+ checkState(isWaitingForFloatingBuffers, "This channel should be waiting for floating buffers.");
+ // Important: double check the isReleased state inside synchronized block, so there is no
+ // race condition when notifyBufferAvailable and releaseAllResources running in parallel.
+ if (isReleased.get() || bufferQueue.getAvailableBufferSize() >= numRequiredBuffers) {
+ buffer.recycle();
return false;
}
- availableBuffers.add(buffer);
-
- if (unannouncedCredit.getAndAdd(1) == 0) {
- notifyCreditAvailable();
- }
+ bufferQueue.addFloatingBuffer(buffer);
- if (availableBuffers.size() >= senderBacklog.get()) {
- isWaitingForFloatingBuffers.set(false);
- return false;
+ if (bufferQueue.getAvailableBufferSize() == numRequiredBuffers) {
+ isWaitingForFloatingBuffers = false;
} else {
- return true;
+ needMoreBuffers = true;
}
}
+
+ if (unannouncedCredit.getAndAdd(1) == 0) {
+ notifyCreditAvailable();
+ }
+
+ return needMoreBuffers;
}
@Override
public void notifyBufferDestroyed() {
- if (!isWaitingForFloatingBuffers.compareAndSet(true, false)) {
- throw new IllegalStateException("This channel should be waiting for floating buffers currently.");
- }
+ // Nothing to do actually.
}
// ------------------------------------------------------------------------
@@ -394,7 +405,58 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
return inputGate.getBufferProvider();
}
- public void onBuffer(Buffer buffer, int sequenceNumber) {
+	/**
+	 * Takes one buffer out of this channel's queue of available buffers so that incoming
+	 * network data can be copied into it. The queue is accessed under its own lock.
+	 *
+	 * @return An available buffer, or {@code null} if the queue holds none, which in
+	 *         credit-based mode is only expected after the channel has been released.
+	 */
+	@Nullable
+	public Buffer requestBuffer() {
+		final Buffer availableBuffer;
+		synchronized (bufferQueue) {
+			availableBuffer = bufferQueue.takeBuffer();
+		}
+		return availableBuffer;
+	}
+
+ /**
+ * Receives the backlog from the producer's buffer response. If the number of available
+ * buffers is less than backlog + initialCredit, it will request floating buffers from the buffer
+ * pool, and then notify unannounced credits to the producer.
+ *
+ * <p>The required buffer count is updated and floating buffers are requested while holding the
+ * {@code bufferQueue} lock. If the pool returns no buffer and this channel is successfully
+ * registered as a buffer listener, the channel is marked as waiting for floating buffers and
+ * the loop stops. If the registration is refused, the request is retried.
+ *
+ * <p>The producer is notified outside the lock, and only if at least one buffer was obtained
+ * and the unannounced credit was zero before the newly obtained buffers were added.
+ *
+ * <p>Does nothing if the channel has already been released.
+ *
+ * @param backlog The number of unsent buffers in the producer's sub partition.
+ * @throws IOException Presumably propagated from requesting buffers from the buffer pool;
+ *         not determinable from this method alone. TODO confirm against the pool's contract.
+ */
+ @VisibleForTesting
+ void onSenderBacklog(int backlog) throws IOException {
+ int numRequestedBuffers = 0;
+
+ synchronized (bufferQueue) {
+ // Important: check the isReleased state inside the synchronized block, so that there is no
+ // race condition when onSenderBacklog and releaseAllResources are running in parallel.
+ if (isReleased.get()) {
+ return;
+ }
+
+ numRequiredBuffers = backlog + initialCredit;
+ while (bufferQueue.getAvailableBufferSize() < numRequiredBuffers && !isWaitingForFloatingBuffers) {
+ Buffer buffer = inputGate.getBufferPool().requestBuffer();
+ if (buffer != null) {
+ bufferQueue.addFloatingBuffer(buffer);
+ numRequestedBuffers++;
+ } else if (inputGate.getBufferProvider().addBufferListener(this)) {
+ // The pool has no buffer right now: the channel is registered as a listener and stops
+ // requesting; further floating buffers arrive via the listener callback.
+ isWaitingForFloatingBuffers = true;
+ break;
+ }
+ }
+ }
+
+ if (numRequestedBuffers > 0 && unannouncedCredit.getAndAdd(numRequestedBuffers) == 0) {
+ notifyCreditAvailable();
+ }
+ }
+
+ public void onBuffer(Buffer buffer, int sequenceNumber, int backlog) throws IOException {
boolean success = false;
try {
@@ -416,6 +478,10 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
}
}
}
+
+ if (success && backlog >= 0) {
+ onSenderBacklog(backlog);
+ }
} finally {
if (!success) {
buffer.recycle();
@@ -423,16 +489,23 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
}
}
- public void onEmptyBuffer(int sequenceNumber) {
+ public void onEmptyBuffer(int sequenceNumber, int backlog) throws IOException {
+ boolean success = false;
+
synchronized (receivedBuffers) {
if (!isReleased.get()) {
if (expectedSequenceNumber == sequenceNumber) {
expectedSequenceNumber++;
+ success = true;
} else {
onError(new BufferReorderingException(expectedSequenceNumber, sequenceNumber));
}
}
}
+
+ if (success && backlog >= 0) {
+ onSenderBacklog(backlog);
+ }
}
public void onFailedPartitionRequest() {
@@ -462,4 +535,82 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
expectedSequenceNumber, actualSequenceNumber);
}
}
+
+	/**
+	 * Manages the exclusive and floating buffers of this channel, and handles the
+	 * internal buffer related logic.
+	 *
+	 * <p>This class does no locking of its own; all callers in the enclosing class access it
+	 * while synchronizing on the queue instance.
+	 */
+	private static class AvailableBufferQueue {
+
+		/** The current available floating buffers from the fixed buffer pool. */
+		private final ArrayDeque<Buffer> floatingBuffers;
+
+		/** The current available exclusive buffers from the global buffer pool. */
+		private final ArrayDeque<Buffer> exclusiveBuffers;
+
+		AvailableBufferQueue() {
+			this.exclusiveBuffers = new ArrayDeque<>();
+			this.floatingBuffers = new ArrayDeque<>();
+		}
+
+		/**
+		 * Adds an exclusive buffer (back) into the queue and recycles one floating buffer if the
+		 * number of available buffers in queue is more than the required amount.
+		 *
+		 * <p>If the queue is over the required amount but holds no floating buffer that could be
+		 * given back, the exclusive buffer is simply counted as added.
+		 *
+		 * @param buffer The exclusive buffer to add
+		 * @param numRequiredBuffers The number of required buffers
+		 *
+		 * @return How many buffers were added to the queue: 0 if a floating buffer was recycled
+		 *         in exchange, 1 otherwise
+		 */
+		int addExclusiveBuffer(Buffer buffer, int numRequiredBuffers) {
+			exclusiveBuffers.add(buffer);
+			if (getAvailableBufferSize() > numRequiredBuffers) {
+				// poll() returns null when there is no floating buffer to give back; guard
+				// against that instead of failing with a NullPointerException.
+				Buffer floatingBuffer = floatingBuffers.poll();
+				if (floatingBuffer != null) {
+					floatingBuffer.recycle();
+					return 0;
+				}
+			}
+			return 1;
+		}
+
+		/**
+		 * Adds a floating buffer to the queue.
+		 *
+		 * @param buffer The floating buffer to add
+		 */
+		void addFloatingBuffer(Buffer buffer) {
+			floatingBuffers.add(buffer);
+		}
+
+		/**
+		 * Takes the floating buffer first in order to make full use of floating
+		 * buffers reasonably.
+		 *
+		 * @return An available floating or exclusive buffer, or null if both queues
+		 *         are empty, e.g. after the channel is released.
+		 */
+		@Nullable
+		Buffer takeBuffer() {
+			if (floatingBuffers.size() > 0) {
+				return floatingBuffers.poll();
+			} else {
+				return exclusiveBuffers.poll();
+			}
+		}
+
+		/**
+		 * The floating buffer is recycled to local buffer pool directly, and the
+		 * exclusive buffer will be gathered to return to global buffer pool later.
+		 *
+		 * @param exclusiveSegments The list that we will add exclusive segments into.
+		 */
+		void releaseAll(List<MemorySegment> exclusiveSegments) {
+			Buffer buffer;
+			while ((buffer = floatingBuffers.poll()) != null) {
+				buffer.recycle();
+			}
+			while ((buffer = exclusiveBuffers.poll()) != null) {
+				exclusiveSegments.add(buffer.getMemorySegment());
+			}
+		}
+
+		/**
+		 * @return The total number of floating and exclusive buffers currently in the queue.
+		 */
+		int getAvailableBufferSize() {
+			return floatingBuffers.size() + exclusiveBuffers.size();
+		}
+	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
index 0651f97..8c87ceb 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
@@ -62,7 +62,7 @@ public class NettyMessageSerializationTest {
nioBuffer.putInt(i);
}
- NettyMessage.BufferResponse expected = new NettyMessage.BufferResponse(buffer, random.nextInt(), new InputChannelID());
+ NettyMessage.BufferResponse expected = new NettyMessage.BufferResponse(buffer, random.nextInt(), new InputChannelID(), random.nextInt());
NettyMessage.BufferResponse actual = encodeAndDecode(expected);
// Verify recycle has been called on buffer instance
@@ -85,6 +85,7 @@ public class NettyMessageSerializationTest {
assertEquals(expected.sequenceNumber, actual.sequenceNumber);
assertEquals(expected.receiverId, actual.receiverId);
+ assertEquals(expected.backlog, actual.backlog);
}
{
http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
index e1e5bd3..d3ff6c2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
@@ -30,23 +30,16 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.util.TestBufferFactory;
-import org.apache.flink.runtime.testutils.DiscardingRecycler;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledByteBufAllocator;
import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
-import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
import org.junit.Test;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
import java.io.IOException;
-import java.util.concurrent.atomic.AtomicReference;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@@ -80,19 +73,19 @@ public class PartitionRequestClientHandlerTest {
when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID());
when(inputChannel.getBufferProvider()).thenReturn(bufferProvider);
- final BufferResponse ReceivedBuffer = createBufferResponse(
- TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId());
+ final BufferResponse receivedBuffer = createBufferResponse(
+ TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2);
final PartitionRequestClientHandler client = new PartitionRequestClientHandler();
client.addInputChannel(inputChannel);
- client.channelRead(mock(ChannelHandlerContext.class), ReceivedBuffer);
+ client.channelRead(mock(ChannelHandlerContext.class), receivedBuffer);
}
/**
* Tests a fix for FLINK-1761.
*
- * <p> FLINK-1761 discovered an IndexOutOfBoundsException, when receiving buffers of size 0.
+ * <p>FLINK-1761 discovered an IndexOutOfBoundsException, when receiving buffers of size 0.
*/
@Test
public void testReceiveEmptyBuffer() throws Exception {
@@ -108,10 +101,11 @@ public class PartitionRequestClientHandlerTest {
final Buffer emptyBuffer = TestBufferFactory.createBuffer();
emptyBuffer.setSize(0);
+ final int backlog = 2;
final BufferResponse receivedBuffer = createBufferResponse(
- emptyBuffer, 0, inputChannel.getInputChannelId());
+ emptyBuffer, 0, inputChannel.getInputChannelId(), backlog);
- final PartitionRequestClientHandler client = new PartitionRequestClientHandler();
+ final CreditBasedClientHandler client = new CreditBasedClientHandler();
client.addInputChannel(inputChannel);
// Read the empty buffer
@@ -119,6 +113,51 @@ public class PartitionRequestClientHandlerTest {
// This should not throw an exception
verify(inputChannel, never()).onError(any(Throwable.class));
+ verify(inputChannel, times(1)).onEmptyBuffer(0, backlog);
+ }
+
+ /**
+ * Checks that a received {@link BufferResponse} is handed over to
+ * {@link RemoteInputChannel#onBuffer(Buffer, int, int)} with the buffer requested
+ * from the channel, the sequence number and the announced backlog.
+ */
+ @Test
+ public void testReceiveBuffer() throws Exception {
+ final int backlog = 2;
+ final Buffer channelBuffer = TestBufferFactory.createBuffer();
+
+ // The mocked channel hands out the buffer that the handler copies the payload into
+ final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
+ when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID());
+ when(inputChannel.requestBuffer()).thenReturn(channelBuffer);
+
+ final BufferResponse response = createBufferResponse(
+ TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), backlog);
+
+ final CreditBasedClientHandler handler = new CreditBasedClientHandler();
+ handler.addInputChannel(inputChannel);
+ handler.channelRead(mock(ChannelHandlerContext.class), response);
+
+ verify(inputChannel, times(1)).onBuffer(channelBuffer, 0, backlog);
+ }
+
+ /**
+ * Checks that {@link RemoteInputChannel#onError(Throwable)} is invoked when a
+ * {@link BufferResponse} arrives but the input channel cannot provide a buffer for it.
+ */
+ @Test
+ public void testThrowExceptionForNoAvailableBuffer() throws Exception {
+ // The mocked channel never has a buffer available
+ final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
+ when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID());
+ when(inputChannel.requestBuffer()).thenReturn(null);
+
+ final BufferResponse response = createBufferResponse(
+ TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2);
+
+ final CreditBasedClientHandler handler = new CreditBasedClientHandler();
+ handler.addInputChannel(inputChannel);
+ handler.channelRead(mock(ChannelHandlerContext.class), response);
+
+ verify(inputChannel, times(1)).onError(any(IllegalStateException.class));
 }
/**
@@ -136,8 +175,8 @@ public class PartitionRequestClientHandlerTest {
when(inputChannel.getBufferProvider()).thenReturn(bufferProvider);
final ErrorResponse partitionNotFound = new ErrorResponse(
- new PartitionNotFoundException(new ResultPartitionID()),
- inputChannel.getInputChannelId());
+ new PartitionNotFoundException(new ResultPartitionID()),
+ inputChannel.getInputChannelId());
final PartitionRequestClientHandler client = new PartitionRequestClientHandler();
client.addInputChannel(inputChannel);
@@ -169,95 +208,19 @@ public class PartitionRequestClientHandlerTest {
client.cancelRequestFor(inputChannel.getInputChannelId());
}
- /**
- * Tests that an unsuccessful message decode call for a staged message
- * does not leave the channel with auto read set to false.
- */
- @Test
- @SuppressWarnings("unchecked")
- public void testAutoReadAfterUnsuccessfulStagedMessage() throws Exception {
- PartitionRequestClientHandler handler = new PartitionRequestClientHandler();
- EmbeddedChannel channel = new EmbeddedChannel(handler);
-
- final AtomicReference<BufferListener> listener = new AtomicReference<>();
-
- BufferProvider bufferProvider = mock(BufferProvider.class);
- when(bufferProvider.addBufferListener(any(BufferListener.class))).thenAnswer(new Answer<Boolean>() {
- @Override
- @SuppressWarnings("unchecked")
- public Boolean answer(InvocationOnMock invocation) throws Throwable {
- listener.set((BufferListener) invocation.getArguments()[0]);
- return true;
- }
- });
-
- when(bufferProvider.requestBuffer()).thenReturn(null);
-
- InputChannelID channelId = new InputChannelID(0, 0);
- RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
- when(inputChannel.getInputChannelId()).thenReturn(channelId);
-
- // The 3rd staged msg has a null buffer provider
- when(inputChannel.getBufferProvider()).thenReturn(bufferProvider, bufferProvider, null);
-
- handler.addInputChannel(inputChannel);
-
- BufferResponse msg = createBufferResponse(createBuffer(true), 0, channelId);
-
- // Write 1st buffer msg. No buffer is available, therefore the buffer
- // should be staged and auto read should be set to false.
- assertTrue(channel.config().isAutoRead());
- channel.writeInbound(msg);
-
- // No buffer available, auto read false
- assertFalse(channel.config().isAutoRead());
-
- // Write more buffers... all staged.
- msg = createBufferResponse(createBuffer(true), 1, channelId);
- channel.writeInbound(msg);
-
- msg = createBufferResponse(createBuffer(true), 2, channelId);
- channel.writeInbound(msg);
-
- // Notify about buffer => handle 1st msg
- Buffer availableBuffer = createBuffer(false);
- listener.get().notifyBufferAvailable(availableBuffer);
-
- // Start processing of staged buffers (in run pending tasks). Make
- // sure that the buffer provider acts like it's destroyed.
- when(bufferProvider.addBufferListener(any(BufferListener.class))).thenReturn(false);
- when(bufferProvider.isDestroyed()).thenReturn(true);
-
- // Execute all tasks that are scheduled in the event loop. Further
- // eventLoop().execute() calls are directly executed, if they are
- // called in the scope of this call.
- channel.runPendingTasks();
-
- assertTrue(channel.config().isAutoRead());
- }
-
// ---------------------------------------------------------------------------------------------
- private static Buffer createBuffer(boolean fill) {
- MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(1024, null);
- if (fill) {
- for (int i = 0; i < 1024; i++) {
- segment.put(i, (byte) i);
- }
- }
- return new Buffer(segment, DiscardingRecycler.INSTANCE, true);
- }
-
/**
* Returns a deserialized buffer message as it would be received during runtime.
*/
private BufferResponse createBufferResponse(
Buffer buffer,
int sequenceNumber,
- InputChannelID receivingChannelId) throws IOException {
+ InputChannelID receivingChannelId,
+ int backlog) throws IOException {
// Mock buffer to serialize
- BufferResponse resp = new BufferResponse(buffer, sequenceNumber, receivingChannelId);
+ BufferResponse resp = new BufferResponse(buffer, sequenceNumber, receivingChannelId, backlog);
ByteBuf serialized = resp.write(UnpooledByteBufAllocator.DEFAULT);
http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
index 6f98119..81788c9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
@@ -216,7 +216,7 @@ public class InputGateConcurrentTest {
@Override
void addBuffer(Buffer buffer) throws Exception {
- channel.onBuffer(buffer, seq++);
+ channel.onBuffer(buffer, seq++, -1);
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
index 324a060..4e90265 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
@@ -206,9 +206,9 @@ public class InputGateFairnessTest {
channels[i] = channel;
for (int p = 0; p < buffersPerChannel; p++) {
- channel.onBuffer(mockBuffer, p);
+ channel.onBuffer(mockBuffer, p, -1);
}
- channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel);
+ channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel, -1);
gate.setInputChannel(new IntermediateResultPartitionID(), channel);
}
@@ -263,7 +263,7 @@ public class InputGateFairnessTest {
gate.setInputChannel(new IntermediateResultPartitionID(), channel);
}
- channels[11].onBuffer(mockBuffer, 0);
+ channels[11].onBuffer(mockBuffer, 0, -1);
channelSequenceNums[11]++;
// read all the buffers and the EOF event
@@ -325,7 +325,7 @@ public class InputGateFairnessTest {
Collections.shuffle(poss);
for (int i : poss) {
- partitions[i].onBuffer(buffer, sequenceNumbers[i]++);
+ partitions[i].onBuffer(buffer, sequenceNumbers[i]++, -1);
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
index d791ced..863f886 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
@@ -18,24 +18,28 @@
package org.apache.flink.runtime.io.network.partition.consumer;
-import org.apache.flink.core.memory.MemorySegment;
-import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.execution.CancelTaskException;
import org.apache.flink.runtime.io.network.ConnectionID;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.netty.PartitionRequestClient;
import org.apache.flink.runtime.io.network.partition.ProducerFailedException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.util.TestBufferFactory;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.apache.flink.runtime.taskmanager.TaskActions;
import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
import org.junit.Test;
-import scala.Tuple2;
import java.io.IOException;
+import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
@@ -43,12 +47,14 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
+import scala.Tuple2;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyListOf;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
@@ -66,10 +72,10 @@ public class RemoteInputChannelTest {
final Buffer buffer = TestBufferFactory.createBuffer();
// The test
- inputChannel.onBuffer(buffer.retain(), 0);
+ inputChannel.onBuffer(buffer.retain(), 0, -1);
// This does not yet throw the exception, but sets the error at the channel.
- inputChannel.onBuffer(buffer, 29);
+ inputChannel.onBuffer(buffer, 29, -1);
try {
inputChannel.getNextBuffer();
@@ -113,7 +119,7 @@ public class RemoteInputChannelTest {
for (int j = 0; j < 128; j++) {
// this is the same buffer over and over again which will be
// recycled by the RemoteInputChannel
- inputChannel.onBuffer(buffer.retain(), j);
+ inputChannel.onBuffer(buffer.retain(), j, -1);
}
if (inputChannel.isReleased()) {
@@ -301,81 +307,562 @@ public class RemoteInputChannelTest {
}
/**
- * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying the exclusive segment is
- * recycled to available buffers directly and it triggers notify of announced credit.
+ * Tests to verify the behaviour of three different operations if the number of available
+ * buffers is less than the number of required buffers.
+ *
+ * 1. Recycle the floating buffer
+ * 2. Recycle the exclusive buffer
+ * 3. Decrease the sender's backlog
*/
@Test
- public void testRecycleExclusiveBufferBeforeReleased() throws Exception {
- final SingleInputGate inputGate = mock(SingleInputGate.class);
- final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate));
-
- // Recycle exclusive segment
- inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel));
+ public void testAvailableBuffersLessThanRequiredBuffers() throws Exception {
+ // Setup
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(16, 32);
+ final int numExclusiveBuffers = 2;
+ final int numFloatingBuffers = 14;
- assertEquals("There should be one buffer available after recycle.",
- 1, inputChannel.getNumberOfAvailableBuffers());
- verify(inputChannel, times(1)).notifyCreditAvailable();
+ final SingleInputGate inputGate = createSingleInputGate();
+ final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+ inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel);
+ try {
+ final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
+ inputGate.setBufferPool(bufferPool);
+ inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
+
+ // Prepare the exclusive and floating buffers to verify recycle logic later
+ final Buffer exclusiveBuffer = inputChannel.requestBuffer();
+ assertNotNull(exclusiveBuffer);
+
+ final int numRecycleFloatingBuffers = 2;
+ final ArrayDeque<Buffer> floatingBufferQueue = new ArrayDeque<>(numRecycleFloatingBuffers);
+ for (int i = 0; i < numRecycleFloatingBuffers; i++) {
+ Buffer floatingBuffer = bufferPool.requestBuffer();
+ assertNotNull(floatingBuffer);
+ floatingBufferQueue.add(floatingBuffer);
+ }
- inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel));
+ verify(bufferPool, times(numRecycleFloatingBuffers)).requestBuffer();
+
+ // Receive the producer's backlog more than the number of available floating buffers
+ inputChannel.onSenderBacklog(14);
+
+ // The channel requests (backlog + numExclusiveBuffers) floating buffers from local pool.
+ // It does not get enough floating buffers and registers as a buffer listener
+ verify(bufferPool, times(15)).requestBuffer();
+ verify(bufferPool, times(1)).addBufferListener(inputChannel);
+ assertEquals("There should be 13 buffers available in the channel",
+ 13, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 16 buffers required in the channel",
+ 16, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 0 buffer available in local pool",
+ 0, bufferPool.getNumberOfAvailableMemorySegments());
+
+ // Increase the backlog
+ inputChannel.onSenderBacklog(16);
+
+ // The channel is already in the status of waiting for buffers and will not request any more
+ verify(bufferPool, times(15)).requestBuffer();
+ verify(bufferPool, times(1)).addBufferListener(inputChannel);
+ assertEquals("There should be 13 buffers available in the channel",
+ 13, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 18 buffers required in the channel",
+ 18, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 0 buffer available in local pool",
+ 0, bufferPool.getNumberOfAvailableMemorySegments());
+
+ // Recycle one floating buffer
+ floatingBufferQueue.poll().recycle();
+
+ // Assign the floating buffer to the listener and the channel is still waiting for more floating buffers
+ verify(bufferPool, times(15)).requestBuffer();
+ verify(bufferPool, times(1)).addBufferListener(inputChannel);
+ assertEquals("There should be 14 buffers available in the channel",
+ 14, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 18 buffers required in the channel",
+ 18, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 0 buffer available in local pool",
+ 0, bufferPool.getNumberOfAvailableMemorySegments());
+
+ // Recycle one more floating buffer
+ floatingBufferQueue.poll().recycle();
+
+ // Assign the floating buffer to the listener and the channel is still waiting for more floating buffers
+ verify(bufferPool, times(15)).requestBuffer();
+ verify(bufferPool, times(1)).addBufferListener(inputChannel);
+ assertEquals("There should be 15 buffers available in the channel",
+ 15, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 18 buffers required in the channel",
+ 18, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 0 buffer available in local pool",
+ 0, bufferPool.getNumberOfAvailableMemorySegments());
+
+ // Decrease the backlog
+ inputChannel.onSenderBacklog(15);
+
+ // Only the number of required buffers is changed by (backlog + numExclusiveBuffers)
+ verify(bufferPool, times(15)).requestBuffer();
+ verify(bufferPool, times(1)).addBufferListener(inputChannel);
+ assertEquals("There should be 15 buffers available in the channel",
+ 15, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 17 buffers required in the channel",
+ 17, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 0 buffer available in local pool",
+ 0, bufferPool.getNumberOfAvailableMemorySegments());
+
+ // Recycle one exclusive buffer
+ exclusiveBuffer.recycle();
+
+ // The exclusive buffer is returned to the channel directly
+ verify(bufferPool, times(15)).requestBuffer();
+ verify(bufferPool, times(1)).addBufferListener(inputChannel);
+ assertEquals("There should be 16 buffers available in the channel",
+ 16, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 17 buffers required in the channel",
+ 17, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 0 buffers available in local pool",
+ 0, bufferPool.getNumberOfAvailableMemorySegments());
+
+ } finally {
+ // Release all the buffer resources
+ inputChannel.releaseAllResources();
- assertEquals("There should be two buffers available after recycle.",
- 2, inputChannel.getNumberOfAvailableBuffers());
- // It should be called only once when increased from zero.
- verify(inputChannel, times(1)).notifyCreditAvailable();
+ networkBufferPool.destroyAllBufferPools();
+ networkBufferPool.destroy();
+ }
}
/**
- * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying the exclusive segment is
- * recycled to global pool via input gate when channel is released.
+ * Tests to verify the behaviour of recycling floating and exclusive buffers if the number of available
+ * buffers equals the number of required buffers.
*/
@Test
- public void testRecycleExclusiveBufferAfterReleased() throws Exception {
+ public void testAvailableBuffersEqualToRequiredBuffers() throws Exception {
// Setup
- final SingleInputGate inputGate = mock(SingleInputGate.class);
- final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate));
-
- inputChannel.releaseAllResources();
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(16, 32);
+ final int numExclusiveBuffers = 2;
+ final int numFloatingBuffers = 14;
- // Recycle exclusive segment after channel released
- inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel));
+ final SingleInputGate inputGate = createSingleInputGate();
+ final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+ inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel);
+ try {
+ final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
+ inputGate.setBufferPool(bufferPool);
+ inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
+
+ // Prepare the exclusive and floating buffers to verify recycle logic later
+ final Buffer exclusiveBuffer = inputChannel.requestBuffer();
+ assertNotNull(exclusiveBuffer);
+ final Buffer floatingBuffer = bufferPool.requestBuffer();
+ assertNotNull(floatingBuffer);
+ verify(bufferPool, times(1)).requestBuffer();
+
+ // Receive the producer's backlog
+ inputChannel.onSenderBacklog(12);
+
+ // The channel requests (backlog + numExclusiveBuffers) floating buffers from local pool
+ // and gets enough floating buffers
+ verify(bufferPool, times(14)).requestBuffer();
+ verify(bufferPool, times(0)).addBufferListener(inputChannel);
+ assertEquals("There should be 14 buffers available in the channel",
+ 14, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 14 buffers required in the channel",
+ 14, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 0 buffer available in local pool",
+ 0, bufferPool.getNumberOfAvailableMemorySegments());
+
+ // Recycle one floating buffer
+ floatingBuffer.recycle();
+
+ // The floating buffer is returned to local buffer directly because the channel is not waiting
+ // for floating buffers
+ verify(bufferPool, times(14)).requestBuffer();
+ verify(bufferPool, times(0)).addBufferListener(inputChannel);
+ assertEquals("There should be 14 buffers available in the channel",
+ 14, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 14 buffers required in the channel",
+ 14, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 1 buffer available in local pool",
+ 1, bufferPool.getNumberOfAvailableMemorySegments());
+
+ // Recycle one exclusive buffer
+ exclusiveBuffer.recycle();
+
+ // Return one extra floating buffer to the local pool because the number of available buffers
+ // already equals to required buffers
+ verify(bufferPool, times(14)).requestBuffer();
+ verify(bufferPool, times(0)).addBufferListener(inputChannel);
+ assertEquals("There should be 14 buffers available in the channel",
+ 14, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 14 buffers required in the channel",
+ 14, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 2 buffers available in local pool",
+ 2, bufferPool.getNumberOfAvailableMemorySegments());
+
+ } finally {
+ // Release all the buffer resources
+ inputChannel.releaseAllResources();
- assertEquals("Resource leak during recycling buffer after channel is released.",
- 0, inputChannel.getNumberOfAvailableBuffers());
- verify(inputChannel, times(0)).notifyCreditAvailable();
- verify(inputGate, times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class));
+ networkBufferPool.destroyAllBufferPools();
+ networkBufferPool.destroy();
+ }
}
/**
- * Tests {@link RemoteInputChannel#releaseAllResources()}, verifying the exclusive segments are
- * recycled to global pool via input gate and no resource leak.
+ * Tests to verify the behaviours of recycling floating and exclusive buffers if the number of available
+ * buffers is more than required buffers by decreasing the sender's backlog.
+ *
+ * <p>As pinned by the assertions below, the number of required buffers follows
+ * {@code backlog + numExclusiveBuffers} (14 for a backlog of 12, 12 for a backlog of 10), while
+ * lowering the backlog alone neither requests further buffers nor changes the number of buffers
+ * available in the channel or in the local pool.
*/
@Test
- public void testReleaseExclusiveBuffers() throws Exception {
+ public void testAvailableBuffersMoreThanRequiredBuffers() throws Exception {
// Setup
- final SingleInputGate inputGate = mock(SingleInputGate.class);
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(16, 32);
+ final int numExclusiveBuffers = 2;
+ final int numFloatingBuffers = 14;
+
+ final SingleInputGate inputGate = createSingleInputGate();
final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+ inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel);
+ try {
+ final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
+ inputGate.setBufferPool(bufferPool);
+ inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
+
+ // Prepare the exclusive and floating buffers to verify recycle logic later; the floating buffer is taken
+ // directly from the pool, which is the single requestBuffer() call verified right below
+ final Buffer exclusiveBuffer = inputChannel.requestBuffer();
+ assertNotNull(exclusiveBuffer);
+
+ final Buffer floatingBuffer = bufferPool.requestBuffer();
+ assertNotNull(floatingBuffer);
+
+ verify(bufferPool, times(1)).requestBuffer();
+
+ // Receive the producer's backlog (12 + 2 exclusive buffers = 14 required buffers)
+ inputChannel.onSenderBacklog(12);
+
+ // The channel gets enough floating buffers from local pool (14 requests in total, including the one above)
+ verify(bufferPool, times(14)).requestBuffer();
+ verify(bufferPool, times(0)).addBufferListener(inputChannel);
+ assertEquals("There should be 14 buffers available in the channel",
+ 14, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 14 buffers required in the channel",
+ 14, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 0 buffer available in local pool",
+ 0, bufferPool.getNumberOfAvailableMemorySegments());
+
+ // Decrease the backlog to make the number of available buffers more than required buffers
+ inputChannel.onSenderBacklog(10);
+
+ // Only the number of required buffers changes, to (backlog + numExclusiveBuffers) = 12
+ verify(bufferPool, times(14)).requestBuffer();
+ verify(bufferPool, times(0)).addBufferListener(inputChannel);
+ assertEquals("There should be 14 buffers available in the channel",
+ 14, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 12 buffers required in the channel",
+ 12, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 0 buffer available in local pool",
+ 0, bufferPool.getNumberOfAvailableMemorySegments());
+
+ // Recycle one exclusive buffer
+ exclusiveBuffer.recycle();
+
+ // One extra floating buffer is returned to the local pool because the number of available buffers
+ // is more than the number of required buffers
+ verify(bufferPool, times(14)).requestBuffer();
+ verify(bufferPool, times(0)).addBufferListener(inputChannel);
+ assertEquals("There should be 14 buffers available in the channel",
+ 14, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 12 buffers required in the channel",
+ 12, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 1 buffer available in local pool",
+ 1, bufferPool.getNumberOfAvailableMemorySegments());
+
+ // Recycle one floating buffer
+ floatingBuffer.recycle();
+
+ // The floating buffer is returned to the local pool directly because the channel is not waiting for
+ // floating buffers
+ verify(bufferPool, times(14)).requestBuffer();
+ verify(bufferPool, times(0)).addBufferListener(inputChannel);
+ assertEquals("There should be 14 buffers available in the channel",
+ 14, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 12 buffers required in the channel",
+ 12, inputChannel.getNumberOfRequiredBuffers());
+ assertEquals("There should be 2 buffers available in local pool",
+ 2, bufferPool.getNumberOfAvailableMemorySegments());
+
+ } finally {
+ // Release all the buffer resources
+ inputChannel.releaseAllResources();
- // Assign exclusive segments to channel
- final List<MemorySegment> exclusiveSegments = new ArrayList<>();
+ networkBufferPool.destroyAllBufferPools();
+ networkBufferPool.destroy();
+ }
+ }
+
+ /**
+ * Tests to verify that the buffer pool will distribute available floating buffers among
+ * all the channel listeners in a fair way.
+ *
+ * <p>Three channels with two exclusive buffers each compete for three floating buffers that are
+ * all taken from the pool up front. Once these buffers are recycled, every channel is expected to
+ * be notified exactly once and to end up with three available buffers.
+ */
+ @Test
+ public void testFairDistributionFloatingBuffers() throws Exception {
+ // Setup
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(12, 32);
final int numExclusiveBuffers = 2;
- for (int i = 0; i < numExclusiveBuffers; i++) {
- exclusiveSegments.add(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel));
+ final int numFloatingBuffers = 3;
+
+ final SingleInputGate inputGate = createSingleInputGate();
+ final RemoteInputChannel channel1 = spy(createRemoteInputChannel(inputGate));
+ final RemoteInputChannel channel2 = spy(createRemoteInputChannel(inputGate));
+ final RemoteInputChannel channel3 = spy(createRemoteInputChannel(inputGate));
+ inputGate.setInputChannel(channel1.partitionId.getPartitionId(), channel1);
+ inputGate.setInputChannel(channel2.partitionId.getPartitionId(), channel2);
+ inputGate.setInputChannel(channel3.partitionId.getPartitionId(), channel3);
+ try {
+ final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
+ inputGate.setBufferPool(bufferPool);
+ inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
+
+ // Exhaust all the floating buffers so that the channels cannot obtain any of them right away
+ final List<Buffer> floatingBuffers = new ArrayList<>(numFloatingBuffers);
+ for (int i = 0; i < numFloatingBuffers; i++) {
+ Buffer buffer = bufferPool.requestBuffer();
+ assertNotNull(buffer);
+ floatingBuffers.add(buffer);
+ }
+
+ // Receive the producer's backlog to trigger requesting floating buffers from the pool;
+ // as none is left, every channel is expected to register as a buffer listener as a result
+ channel1.onSenderBacklog(8);
+ channel2.onSenderBacklog(8);
+ channel3.onSenderBacklog(8);
+
+ verify(bufferPool, times(1)).addBufferListener(channel1);
+ verify(bufferPool, times(1)).addBufferListener(channel2);
+ verify(bufferPool, times(1)).addBufferListener(channel3);
+ assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel",
+ numExclusiveBuffers, channel1.getNumberOfAvailableBuffers());
+ assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel",
+ numExclusiveBuffers, channel2.getNumberOfAvailableBuffers());
+ assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel",
+ numExclusiveBuffers, channel3.getNumberOfAvailableBuffers());
+
+ // Recycle the three floating buffers to trigger the buffer-available notifications (one per channel)
+ for (Buffer buffer : floatingBuffers) {
+ buffer.recycle();
+ }
+
+ verify(channel1, times(1)).notifyBufferAvailable(any(Buffer.class));
+ verify(channel2, times(1)).notifyBufferAvailable(any(Buffer.class));
+ verify(channel3, times(1)).notifyBufferAvailable(any(Buffer.class));
+ assertEquals("There should be 3 buffers available in the channel", 3, channel1.getNumberOfAvailableBuffers());
+ assertEquals("There should be 3 buffers available in the channel", 3, channel2.getNumberOfAvailableBuffers());
+ assertEquals("There should be 3 buffers available in the channel", 3, channel3.getNumberOfAvailableBuffers());
+
+ } finally {
+ // Release all the buffer resources
+ channel1.releaseAllResources();
+ channel2.releaseAllResources();
+ channel3.releaseAllResources();
+
+ networkBufferPool.destroyAllBufferPools();
+ networkBufferPool.destroy();
}
- inputChannel.assignExclusiveSegments(exclusiveSegments);
+ }
+
+ /**
+ * Tests to verify that there is no race condition with two things running in parallel:
+ * requesting floating buffers on sender backlog and some other thread releasing
+ * the input channel.
+ *
+ * <p>One task keeps reporting backlogs from 1 to {@code numFloatingBuffers} until it observes that
+ * the channel is released, while a second task releases the channel. Afterwards the channel must
+ * hold no buffers and all 130 segments (128 floating + 2 exclusive) must be available again; note
+ * that the final check sums up the segments of the local pool and of the global network pool.
+ */
+ @Test
+ public void testConcurrentOnSenderBacklogAndRelease() throws Exception {
+ // Setup
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(130, 32);
+ final int numExclusiveBuffers = 2;
+ final int numFloatingBuffers = 128;
+
+ final ExecutorService executor = Executors.newFixedThreadPool(2);
+
+ final SingleInputGate inputGate = createSingleInputGate();
+ final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+ inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel);
+ try {
+ final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
+ inputGate.setBufferPool(bufferPool);
+ inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
+
+ final Callable<Void> requestBufferTask = new Callable<Void>() {
+ @Override
+ public Void call() throws Exception {
+ while (true) {
+ for (int j = 1; j <= numFloatingBuffers; j++) {
+ inputChannel.onSenderBacklog(j);
+ }
- assertEquals("The number of available buffers is not equal to the assigned amount.",
- numExclusiveBuffers, inputChannel.getNumberOfAvailableBuffers());
+ if (inputChannel.isReleased()) {
+ return null;
+ }
+ }
+ }
+ };
+
+ final Callable<Void> releaseTask = new Callable<Void>() {
+ @Override
+ public Void call() throws Exception {
+ inputChannel.releaseAllResources();
+
+ return null;
+ }
+ };
+
+ // Submit tasks and wait to finish
+ submitTasksAndWaitForResults(executor, new Callable[]{requestBufferTask, releaseTask});
+
+ assertEquals("There should be no buffers available in the channel.",
+ 0, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be 130 buffers available in local pool.",
+ 130, bufferPool.getNumberOfAvailableMemorySegments() + networkBufferPool.getNumberOfAvailableMemorySegments());
- // Release this channel
- inputChannel.releaseAllResources();
+ } finally {
+ // Release the buffer resources here if the channel has not been released yet (e.g. after an exception)
+ if (!inputChannel.isReleased()) {
+ inputChannel.releaseAllResources();
+ }
+
+ networkBufferPool.destroyAllBufferPools();
+ networkBufferPool.destroy();
- assertEquals("Resource leak after channel is released.",
- 0, inputChannel.getNumberOfAvailableBuffers());
- verify(inputGate, times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class));
+ executor.shutdown();
+ }
+ }
+
+ /**
+ * Tests to verify that there is no race condition with two things running in parallel:
+ * requesting floating buffers on sender backlog and some other thread recycling
+ * floating or exclusive buffers.
+ *
+ * <p>All exclusive and floating buffers are requested up front and are then recycled by two tasks
+ * while a third task reports every backlog from 1 to {@code backlog}. Afterwards the channel is
+ * expected to hold exactly as many buffers as it requires and the local pool to be empty.
+ */
+ @Test
+ public void testConcurrentOnSenderBacklogAndRecycle() throws Exception {
+ // Setup
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(248, 32);
+ final int numExclusiveSegments = 120;
+ final int numFloatingBuffers = 128;
+ final int backlog = 128;
+
+ final ExecutorService executor = Executors.newFixedThreadPool(3);
+
+ final SingleInputGate inputGate = createSingleInputGate();
+ final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+ inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel);
+ try {
+ final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
+ inputGate.setBufferPool(bufferPool);
+ inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments);
+
+ // Reports every backlog from 1 up to the final value, one call at a time
+ final Callable<Void> requestBufferTask = new Callable<Void>() {
+ @Override
+ public Void call() throws Exception {
+ for (int j = 1; j <= backlog; j++) {
+ inputChannel.onSenderBacklog(j);
+ }
+
+ return null;
+ }
+ };
+
+ // Submit tasks and wait to finish; the two recycle tasks take their buffers while being created here
+ submitTasksAndWaitForResults(executor, new Callable[]{
+ recycleExclusiveBufferTask(inputChannel, numExclusiveSegments),
+ recycleFloatingBufferTask(bufferPool, numFloatingBuffers),
+ requestBufferTask});
+
+ // All tasks are done, so the required count is read once for both the message and the expected value
+ final int numRequiredBuffers = inputChannel.getNumberOfRequiredBuffers();
+ assertEquals("There should be " + numRequiredBuffers + " buffers available in channel.",
+ numRequiredBuffers, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be no buffers available in local pool.",
+ 0, bufferPool.getNumberOfAvailableMemorySegments());
+
+ } finally {
+ try {
+ // Release all the buffer resources
+ inputChannel.releaseAllResources();
+
+ networkBufferPool.destroyAllBufferPools();
+ networkBufferPool.destroy();
+ } finally {
+ // Always stop the worker threads, even if the teardown above fails
+ executor.shutdown();
+ }
+ }
+ }
+
+ /**
+ * Tests to verify that there is no race condition with two things running in parallel:
+ * recycling the exclusive or floating buffers and some other thread releasing the
+ * input channel.
+ *
+ * <p>All exclusive and floating buffers are requested up front and are then recycled by two tasks
+ * while a third task releases the channel. Afterwards the channel must hold no buffers, all
+ * floating buffers must be back in the local pool and all exclusive segments back in the global pool.
+ */
+ @Test
+ public void testConcurrentRecycleAndRelease() throws Exception {
+ // Setup
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(248, 32);
+ final int numExclusiveSegments = 120;
+ final int numFloatingBuffers = 128;
+
+ final ExecutorService executor = Executors.newFixedThreadPool(3);
+
+ final SingleInputGate inputGate = createSingleInputGate();
+ final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+ inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel);
+ try {
+ final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
+ inputGate.setBufferPool(bufferPool);
+ inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments);
+
+ final Callable<Void> releaseTask = new Callable<Void>() {
+ @Override
+ public Void call() throws Exception {
+ inputChannel.releaseAllResources();
+
+ return null;
+ }
+ };
+
+ // Submit tasks and wait to finish; the two recycle tasks take their buffers while being created here
+ submitTasksAndWaitForResults(executor, new Callable[]{
+ recycleExclusiveBufferTask(inputChannel, numExclusiveSegments),
+ recycleFloatingBufferTask(bufferPool, numFloatingBuffers),
+ releaseTask});
+
+ assertEquals("There should be no buffers available in the channel.",
+ 0, inputChannel.getNumberOfAvailableBuffers());
+ assertEquals("There should be " + numFloatingBuffers + " buffers available in local pool.",
+ numFloatingBuffers, bufferPool.getNumberOfAvailableMemorySegments());
+ assertEquals("There should be " + numExclusiveSegments + " buffers available in global pool.",
+ numExclusiveSegments, networkBufferPool.getNumberOfAvailableMemorySegments());
+
+ } finally {
+ try {
+ // Release the buffer resources here if the channel has not been released yet (e.g. after an exception)
+ if (!inputChannel.isReleased()) {
+ inputChannel.releaseAllResources();
+ }
+
+ networkBufferPool.destroyAllBufferPools();
+ networkBufferPool.destroy();
+ } finally {
+ // Always stop the worker threads, even if the teardown above fails
+ executor.shutdown();
+ }
+ }
}
// ---------------------------------------------------------------------------------------------
+ private SingleInputGate createSingleInputGate() {
+ return new SingleInputGate(
+ "InputGate",
+ new JobID(),
+ new IntermediateDataSetID(),
+ ResultPartitionType.PIPELINED_CREDIT_BASED,
+ 0,
+ 1,
+ mock(TaskActions.class),
+ UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+ }
+
private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate)
throws IOException, InterruptedException {
@@ -403,4 +890,78 @@ public class RemoteInputChannelTest {
initialAndMaxRequestBackoff._2(),
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
}
+
+ /**
+ * Takes the given number of exclusive buffers out of the input channel right away and returns a
+ * callable task which, once executed, recycles every one of them.
+ *
+ * @param inputChannel The input channel to request the exclusive buffers from.
+ * @param numExclusiveSegments The number of exclusive buffers to request.
+ * @return The callable task that recycles the requested exclusive buffers.
+ */
+ private Callable<Void> recycleExclusiveBufferTask(RemoteInputChannel inputChannel, int numExclusiveSegments) {
+ // The buffers are requested eagerly so that only the recycling happens inside the task
+ final List<Buffer> requested = new ArrayList<>(numExclusiveSegments);
+ while (requested.size() < numExclusiveSegments) {
+ final Buffer exclusive = inputChannel.requestBuffer();
+ assertNotNull(exclusive);
+ requested.add(exclusive);
+ }
+
+ return new Callable<Void>() {
+ @Override
+ public Void call() throws Exception {
+ for (int i = 0; i < requested.size(); i++) {
+ requested.get(i).recycle();
+ }
+
+ return null;
+ }
+ };
+ }
+
+ /**
+ * Takes the given number of floating buffers out of the buffer pool right away and returns a
+ * callable task which, once executed, recycles every one of them.
+ *
+ * @param bufferPool The buffer pool to request the floating buffers from.
+ * @param numFloatingBuffers The number of floating buffers to request.
+ * @return The callable task that recycles the requested floating buffers.
+ */
+ private Callable<Void> recycleFloatingBufferTask(BufferPool bufferPool, int numFloatingBuffers) throws Exception {
+ // The buffers are requested eagerly so that only the recycling happens inside the task
+ final List<Buffer> requested = new ArrayList<>(numFloatingBuffers);
+ while (requested.size() < numFloatingBuffers) {
+ final Buffer floating = bufferPool.requestBuffer();
+ assertNotNull(floating);
+ requested.add(floating);
+ }
+
+ return new Callable<Void>() {
+ @Override
+ public Void call() throws Exception {
+ for (int i = 0; i < requested.size(); i++) {
+ requested.get(i).recycle();
+ }
+
+ return null;
+ }
+ };
+ }
+
+ /**
+ * Submits all the callable tasks to the executor and waits for the results.
+ *
+ * <p>Every task is awaited even if an earlier one failed, so that no task is still running when
+ * the caller starts to release the channel and destroy the buffer pools. The first failure is
+ * rethrown afterwards, with any later failures attached to it as suppressed exceptions.
+ *
+ * @param executor The executor service for running tasks.
+ * @param tasks The callable tasks to be submitted and executed.
+ * @throws Exception The first task failure, or an {@link InterruptedException} if the waiting
+ * thread is interrupted.
+ */
+ @SuppressWarnings({"unchecked", "rawtypes"}) // the raw Callable[] signature is kept for the existing callers
+ private void submitTasksAndWaitForResults(ExecutorService executor, Callable[] tasks) throws Exception {
+ final List<Future> results = Lists.newArrayListWithCapacity(tasks.length);
+
+ for (Callable task : tasks) {
+ results.add(executor.submit(task));
+ }
+
+ Exception firstFailure = null;
+ for (Future result : results) {
+ try {
+ result.get();
+ } catch (java.util.concurrent.ExecutionException e) {
+ // Remember the failure but keep waiting for the remaining tasks to finish
+ if (firstFailure == null) {
+ firstFailure = e;
+ } else {
+ firstFailure.addSuppressed(e);
+ }
+ }
+ }
+
+ if (firstFailure != null) {
+ throw firstFailure;
+ }
+ }
}