Posted to commits@flink.apache.org by yi...@apache.org on 2022/07/21 10:13:06 UTC
[flink] branch master updated: [FLINK-28512][network] Select HashBasedDataBuffer and SortBasedDataBuffer dynamically based on the number of network buffers that can be allocated for SortMergeResultPartition
This is an automated email from the ASF dual-hosted git repository.
yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 5f0c4ab91a8 [FLINK-28512][network] Select HashBasedDataBuffer and SortBasedDataBuffer dynamically based on the number of network buffers that can be allocated for SortMergeResultPartition
5f0c4ab91a8 is described below
commit 5f0c4ab91a8ffc80ff04d0324934b2993fc5b533
Author: Tan Yuxin <ta...@gmail.com>
AuthorDate: Thu Jul 21 18:12:55 2022 +0800
[FLINK-28512][network] Select HashBasedDataBuffer and SortBasedDataBuffer dynamically based on the number of network buffers that can be allocated for SortMergeResultPartition
This closes #20315.
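For context, the selection policy this commit moves from setup() into per-region buffer requesting can be condensed to the sketch below. It is an illustrative paraphrase of the requestNetworkBuffers() and createNewDataBuffer() hunks in the diff, not the committed code; the wrapper class is hypothetical, and the write-buffer constants are passed in as parameters rather than hard-coded.

import java.util.LinkedList;

import org.apache.flink.core.memory.MemorySegment;

final class DataBufferChoiceSketch {

    /**
     * HashBasedDataBuffer keeps one partially filled BufferBuilder per subpartition,
     * so it only pays off when at least two buffers per subpartition are free;
     * otherwise SortBasedDataBuffer is chosen and some buffers are reserved for writing.
     */
    static boolean useHashBuffer(LinkedList<MemorySegment> freeSegments, int numSubpartitions) {
        return freeSegments.size() >= 2 * numSubpartitions;
    }

    /**
     * For the sort-based path, how many of the free buffers to set aside for file
     * writing (mirrors the tail of requestNetworkBuffers(); the hash-based path
     * uses all free buffers for appending and none for batched writing).
     */
    static int numWriteBuffers(
            int numFreeSegments,
            int networkBufferSize,
            int numWriteBufferBytes,
            int expectedWriteBatchSize) {
        int expected =
                networkBufferSize >= numWriteBufferBytes
                        ? 1
                        : Math.min(expectedWriteBatchSize, numWriteBufferBytes / networkBufferSize);
        return Math.min(numFreeSegments / 2, expected);
    }
}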
---
.../runtime/io/network/partition/DataBuffer.java | 3 -
.../io/network/partition/HashBasedDataBuffer.java | 81 ++++++------
.../io/network/partition/SortBasedDataBuffer.java | 88 ++++---------
.../partition/SortMergeResultPartition.java | 136 ++++++++++++---------
.../io/network/partition/DataBufferTest.java | 60 +++++----
.../partition/SortMergeResultPartitionTest.java | 29 ++---
6 files changed, 183 insertions(+), 214 deletions(-)
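Before the diff proper, it may also help to see how the per-region buffer acquisition that feeds the choice above works. This is a minimal sketch assuming only the BufferPool methods already used in the diff (requestMemorySegmentBlocking(), requestMemorySegment(), getNumberOfRequiredMemorySegments(), getMaxNumberOfMemorySegments()); the wrapper class is hypothetical and the check for too few required buffers is omitted.

import java.io.IOException;
import java.util.LinkedList;

import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.network.buffer.BufferPool;

final class BufferAcquisitionSketch {

    /** Blocks for the guaranteed minimum, then takes extra buffers best-effort. */
    static void requestNetworkBuffers(
            BufferPool bufferPool, LinkedList<MemorySegment> freeSegments) throws IOException {
        try {
            // the guaranteed part must be satisfied, so block for it
            while (freeSegments.size() < bufferPool.getNumberOfRequiredMemorySegments()) {
                freeSegments.add(bufferPool.requestMemorySegmentBlocking());
            }
        } catch (InterruptedException exception) {
            throw new IOException("Failed to allocate buffers for result partition.", exception);
        }

        // opportunistically take more buffers, but stop at the pool's maximum so one
        // result partition does not monopolize the network memory
        MemorySegment segment;
        while (freeSegments.size() < bufferPool.getMaxNumberOfMemorySegments()
                && (segment = bufferPool.requestMemorySegment()) != null) {
            freeSegments.add(segment);
        }
    }
}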
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java
index 4d36bfe35ea..ed5e5a42b43 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java
@@ -58,9 +58,6 @@ public interface DataBuffer {
/** Returns true if not all data appended to this {@link DataBuffer} is consumed. */
boolean hasRemaining();
- /** Resets this {@link DataBuffer} to be reused for data appending. */
- void reset();
-
/** Finishes this {@link DataBuffer} which means no record can be appended any more. */
void finish();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java
index 03410278236..7114b1f41a5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java
@@ -23,7 +23,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
-import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
@@ -32,6 +32,7 @@ import javax.annotation.Nullable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
+import java.util.LinkedList;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -47,8 +48,11 @@ import static org.apache.flink.util.Preconditions.checkState;
*/
public class HashBasedDataBuffer implements DataBuffer {
- /** A buffer pool to request memory segments from. */
- private final BufferPool bufferPool;
+ /** A list of {@link MemorySegment}s used to store data in memory. */
+ private final LinkedList<MemorySegment> freeSegments;
+
+ /** {@link BufferRecycler} used to recycle {@link #freeSegments}. */
+ private final BufferRecycler bufferRecycler;
/** Number of guaranteed buffers can be allocated from the buffer pool for data sort. */
private final int numGuaranteedBuffers;
@@ -56,6 +60,9 @@ public class HashBasedDataBuffer implements DataBuffer {
/** Buffers containing data for all subpartitions. */
private final ArrayDeque<BufferConsumer>[] buffers;
+ /** Size of buffers requested from buffer pool. All buffers must be of the same size. */
+ private final int bufferSize;
+
// ---------------------------------------------------------------------------------------------
// Statistics and states
// ---------------------------------------------------------------------------------------------
@@ -66,9 +73,6 @@ public class HashBasedDataBuffer implements DataBuffer {
/** Total number of records already appended to this sort buffer. */
private long numTotalRecords;
- /** Whether this sort buffer is full and ready to read data from. */
- private boolean isFull;
-
/** Whether this sort buffer is finished. One can only read a finished sort buffer. */
private boolean isFinished;
@@ -99,14 +103,19 @@ public class HashBasedDataBuffer implements DataBuffer {
private long numTotalBytesRead;
public HashBasedDataBuffer(
- BufferPool bufferPool,
+ LinkedList<MemorySegment> freeSegments,
+ BufferRecycler bufferRecycler,
int numSubpartitions,
+ int bufferSize,
int numGuaranteedBuffers,
@Nullable int[] customReadOrder) {
checkArgument(numGuaranteedBuffers > 0, "No guaranteed buffers for sort.");
- this.bufferPool = checkNotNull(bufferPool);
+ this.freeSegments = checkNotNull(freeSegments);
+ this.bufferRecycler = checkNotNull(bufferRecycler);
+ this.bufferSize = bufferSize;
this.numGuaranteedBuffers = numGuaranteedBuffers;
+ checkState(numGuaranteedBuffers <= freeSegments.size(), "Wrong number of free segments.");
this.builders = new BufferBuilder[numSubpartitions];
this.buffers = new ArrayDeque[numSubpartitions];
@@ -134,7 +143,6 @@ public class HashBasedDataBuffer implements DataBuffer {
public boolean append(ByteBuffer source, int targetChannel, Buffer.DataType dataType)
throws IOException {
checkArgument(source.hasRemaining(), "Cannot append empty data.");
- checkState(!isFull, "Sort buffer is already full.");
checkState(!isFinished, "Sort buffer is already finished.");
checkState(!isReleased, "Sort buffer is already released.");
@@ -145,19 +153,18 @@ public class HashBasedDataBuffer implements DataBuffer {
writeEvent(source, targetChannel, dataType);
}
- isFull = source.hasRemaining();
- if (!isFull) {
- ++numTotalRecords;
+ if (source.hasRemaining()) {
+ return true;
}
+ ++numTotalRecords;
numTotalBytes += totalBytes - source.remaining();
- return isFull;
+ return false;
}
private void writeEvent(ByteBuffer source, int targetChannel, Buffer.DataType dataType) {
BufferBuilder builder = builders[targetChannel];
if (builder != null) {
builder.finish();
- buffers[targetChannel].add(builder.createBufferConsumerFromBeginning());
builder.close();
builders[targetChannel] = null;
}
@@ -172,14 +179,19 @@ public class HashBasedDataBuffer implements DataBuffer {
buffers[targetChannel].add(consumer);
}
- private void writeRecord(ByteBuffer source, int targetChannel) throws IOException {
+ private void writeRecord(ByteBuffer source, int targetChannel) {
+ BufferBuilder builder = builders[targetChannel];
+ int availableBytes = builder != null ? builder.getWritableBytes() : 0;
+ if (source.remaining()
+ > availableBytes
+ + (numGuaranteedBuffers - numBuffersOccupied) * (long) bufferSize) {
+ return;
+ }
+
do {
- BufferBuilder builder = builders[targetChannel];
if (builder == null) {
- builder = requestBufferFromPool();
- if (builder == null) {
- break;
- }
+ builder = new BufferBuilder(freeSegments.poll(), bufferRecycler);
+ buffers[targetChannel].add(builder.createBufferConsumer());
++numBuffersOccupied;
builders[targetChannel] = builder;
}
@@ -187,29 +199,16 @@ public class HashBasedDataBuffer implements DataBuffer {
builder.append(source);
if (builder.isFull()) {
builder.finish();
- buffers[targetChannel].add(builder.createBufferConsumerFromBeginning());
builder.close();
builders[targetChannel] = null;
+ builder = null;
}
} while (source.hasRemaining());
}
- private BufferBuilder requestBufferFromPool() throws IOException {
- try {
- // blocking request buffers if there is still guaranteed memory
- if (numBuffersOccupied < numGuaranteedBuffers) {
- return bufferPool.requestBufferBuilderBlocking();
- }
- } catch (InterruptedException e) {
- throw new IOException("Interrupted while requesting buffer.", e);
- }
-
- return bufferPool.requestBufferBuilder();
- }
-
@Override
public BufferWithChannel getNextBuffer(MemorySegment transitBuffer) {
- checkState(isFull, "Sort buffer is not ready to be read.");
+ checkState(isFinished, "Sort buffer is not ready to be read.");
checkState(!isReleased, "Sort buffer is already released.");
BufferWithChannel buffer = null;
@@ -250,27 +249,15 @@ public class HashBasedDataBuffer implements DataBuffer {
return numTotalBytesRead < numTotalBytes;
}
- @Override
- public void reset() {
- checkState(!isFinished, "Sort buffer has been finished.");
- checkState(!isReleased, "Sort buffer has been released.");
-
- isFull = false;
- readOrderIndex = 0;
- }
-
@Override
public void finish() {
- checkState(!isFull, "DataBuffer must not be full.");
checkState(!isFinished, "DataBuffer is already finished.");
- isFull = true;
isFinished = true;
for (int channel = 0; channel < builders.length; ++channel) {
BufferBuilder builder = builders[channel];
if (builder != null) {
builder.finish();
- buffers[channel].add(builder.createBufferConsumerFromBeginning());
builder.close();
builders[channel] = null;
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java
index cd2acaa6d66..993518295c5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java
@@ -21,7 +21,7 @@ package org.apache.flink.runtime.io.network.partition;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.buffer.Buffer;
-import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import javax.annotation.Nullable;
@@ -31,6 +31,7 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.LinkedList;
import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
import static org.apache.flink.util.Preconditions.checkArgument;
@@ -58,8 +59,11 @@ public class SortBasedDataBuffer implements DataBuffer {
*/
private static final int INDEX_ENTRY_SIZE = 4 + 4 + 8;
- /** A buffer pool to request memory segments from. */
- private final BufferPool bufferPool;
+ /** A list of {@link MemorySegment}s used to store data in memory. */
+ private final LinkedList<MemorySegment> freeSegments;
+
+ /** {@link BufferRecycler} used to recycle {@link #freeSegments}. */
+ private final BufferRecycler bufferRecycler;
/** A segment list as a joint buffer which stores all records and index entries. */
private final ArrayList<MemorySegment> segments = new ArrayList<>();
@@ -89,9 +93,6 @@ public class SortBasedDataBuffer implements DataBuffer {
/** Total number of bytes already read from this sort buffer. */
private long numTotalBytesRead;
- /** Whether this sort buffer is full and ready to read data from. */
- private boolean isFull;
-
/** Whether this sort buffer is finished. One can only read a finished sort buffer. */
private boolean isFinished;
@@ -125,7 +126,8 @@ public class SortBasedDataBuffer implements DataBuffer {
private int readOrderIndex = -1;
public SortBasedDataBuffer(
- BufferPool bufferPool,
+ LinkedList<MemorySegment> freeSegments,
+ BufferRecycler bufferRecycler,
int numSubpartitions,
int bufferSize,
int numGuaranteedBuffers,
@@ -133,9 +135,11 @@ public class SortBasedDataBuffer implements DataBuffer {
checkArgument(bufferSize > INDEX_ENTRY_SIZE, "Buffer size is too small.");
checkArgument(numGuaranteedBuffers > 0, "No guaranteed buffers for sort.");
- this.bufferPool = checkNotNull(bufferPool);
+ this.freeSegments = checkNotNull(freeSegments);
+ this.bufferRecycler = checkNotNull(bufferRecycler);
this.bufferSize = bufferSize;
this.numGuaranteedBuffers = numGuaranteedBuffers;
+ checkState(numGuaranteedBuffers <= freeSegments.size(), "Wrong number of free segments.");
this.firstIndexEntryAddresses = new long[numSubpartitions];
this.lastIndexEntryAddresses = new long[numSubpartitions];
@@ -162,7 +166,6 @@ public class SortBasedDataBuffer implements DataBuffer {
public boolean append(ByteBuffer source, int targetChannel, DataType dataType)
throws IOException {
checkArgument(source.hasRemaining(), "Cannot append empty data.");
- checkState(!isFull, "Sort buffer is already full.");
checkState(!isFinished, "Sort buffer is already finished.");
checkState(!isReleased, "Sort buffer is already released.");
@@ -170,11 +173,6 @@ public class SortBasedDataBuffer implements DataBuffer {
// return true directly if it can not allocate enough buffers for the given record
if (!allocateBuffersForRecord(totalBytes)) {
- isFull = true;
- if (hasRemaining()) {
- // prepare for reading
- updateReadChannelAndIndexEntryAddress();
- }
return true;
}
@@ -224,7 +222,7 @@ public class SortBasedDataBuffer implements DataBuffer {
}
}
- private boolean allocateBuffersForRecord(int numRecordBytes) throws IOException {
+ private boolean allocateBuffersForRecord(int numRecordBytes) {
int numBytesRequired = INDEX_ENTRY_SIZE + numRecordBytes;
int availableBytes =
writeSegmentIndex == segments.size() ? 0 : bufferSize - writeSegmentOffset;
@@ -240,16 +238,16 @@ public class SortBasedDataBuffer implements DataBuffer {
availableBytes = 0;
}
+ if (availableBytes + (numGuaranteedBuffers - segments.size()) * (long) bufferSize
+ < numBytesRequired) {
+ return false;
+ }
+
// allocate exactly enough buffers for the appended record
do {
- MemorySegment segment = requestBufferFromPool();
- if (segment == null) {
- // return false if we can not allocate enough buffers for the appended record
- return false;
- }
-
+ MemorySegment segment = freeSegments.poll();
availableBytes += bufferSize;
- addBuffer(segment);
+ addBuffer(checkNotNull(segment));
} while (availableBytes < numBytesRequired);
return true;
@@ -257,31 +255,18 @@ public class SortBasedDataBuffer implements DataBuffer {
private void addBuffer(MemorySegment segment) {
if (segment.size() != bufferSize) {
- bufferPool.recycle(segment);
+ bufferRecycler.recycle(segment);
throw new IllegalStateException("Illegal memory segment size.");
}
if (isReleased) {
- bufferPool.recycle(segment);
+ bufferRecycler.recycle(segment);
throw new IllegalStateException("Sort buffer is already released.");
}
segments.add(segment);
}
- private MemorySegment requestBufferFromPool() throws IOException {
- try {
- // blocking request buffers if there is still guaranteed memory
- if (segments.size() < numGuaranteedBuffers) {
- return bufferPool.requestMemorySegmentBlocking();
- }
- } catch (InterruptedException e) {
- throw new IOException("Interrupted while requesting buffer.");
- }
-
- return bufferPool.requestMemorySegment();
- }
-
private void updateWriteSegmentIndexAndOffset(int numBytes) {
writeSegmentOffset += numBytes;
@@ -294,7 +279,7 @@ public class SortBasedDataBuffer implements DataBuffer {
@Override
public BufferWithChannel getNextBuffer(MemorySegment transitBuffer) {
- checkState(isFull, "Sort buffer is not ready to be read.");
+ checkState(isFinished, "Sort buffer is not ready to be read.");
checkState(!isReleased, "Sort buffer is already released.");
if (!hasRemaining()) {
@@ -426,38 +411,13 @@ public class SortBasedDataBuffer implements DataBuffer {
return numTotalBytesRead < numTotalBytes;
}
- @Override
- public void reset() {
- checkState(!isFinished, "Sort buffer has been finished.");
- checkState(!isReleased, "Sort buffer has been released.");
- checkState(!hasRemaining(), "Still has remaining data.");
-
- for (MemorySegment segment : segments) {
- bufferPool.recycle(segment);
- }
- segments.clear();
-
- // initialized with -1 means the corresponding channel has no data
- Arrays.fill(firstIndexEntryAddresses, -1L);
- Arrays.fill(lastIndexEntryAddresses, -1L);
-
- isFull = false;
- writeSegmentIndex = 0;
- writeSegmentOffset = 0;
- readIndexEntryAddress = 0;
- recordRemainingBytes = 0;
- readOrderIndex = -1;
- }
-
@Override
public void finish() {
- checkState(!isFull, "DataBuffer must not be full.");
checkState(!isFinished, "DataBuffer is already finished.");
isFinished = true;
// prepare for reading
- isFull = true;
updateReadChannelAndIndexEntryAddress();
}
@@ -474,7 +434,7 @@ public class SortBasedDataBuffer implements DataBuffer {
isReleased = true;
for (MemorySegment segment : segments) {
- bufferPool.recycle(segment);
+ bufferRecycler.recycle(segment);
}
segments.clear();
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
index 26e648e3379..1620c36beea 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
@@ -43,6 +43,7 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
+import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Random;
@@ -84,9 +85,6 @@ public class SortMergeResultPartition extends ResultPartition {
@GuardedBy("lock")
private PartitionedFile resultFile;
- /** Buffers cut from the network buffer pool for data writing. */
- private final List<MemorySegment> writeSegments = new ArrayList<>();
-
private boolean hasNotifiedEndOfUserRecords;
/** Size of network buffer and write buffer. */
@@ -115,6 +113,9 @@ public class SortMergeResultPartition extends ResultPartition {
*/
private final SortMergeResultPartitionReadScheduler readScheduler;
+ /** All available network buffers can be used by this result partition for a data region. */
+ private final LinkedList<MemorySegment> freeSegments = new LinkedList<>();
+
/**
* Number of guaranteed network buffers can be used by {@link #unicastDataBuffer} and {@link
* #broadcastDataBuffer}.
@@ -190,42 +191,7 @@ public class SortMergeResultPartition extends ResultPartition {
readBufferPool.initialize();
super.setup();
- int numRequiredBuffer = bufferPool.getNumberOfRequiredMemorySegments();
- if (numRequiredBuffer < 2) {
- throw new IOException(
- String.format(
- "Too few sort buffers, please increase %s.",
- NettyShuffleEnvironmentOptions.NETWORK_SORT_SHUFFLE_MIN_BUFFERS));
- }
-
- int expectedWriteBuffers = 0;
- if (numRequiredBuffer >= 2 * numSubpartitions) {
- useHashBuffer = true;
- } else if (networkBufferSize >= NUM_WRITE_BUFFER_BYTES) {
- expectedWriteBuffers = 1;
- } else {
- expectedWriteBuffers =
- Math.min(EXPECTED_WRITE_BATCH_SIZE, NUM_WRITE_BUFFER_BYTES / networkBufferSize);
- }
-
- int numBuffersForWrite = Math.min(numRequiredBuffer / 2, expectedWriteBuffers);
- numBuffersForSort = numRequiredBuffer - numBuffersForWrite;
-
- try {
- for (int i = 0; i < numBuffersForWrite; ++i) {
- MemorySegment segment = bufferPool.requestMemorySegmentBlocking();
- writeSegments.add(segment);
- }
- } catch (InterruptedException exception) {
- // the setup method does not allow InterruptedException
- throw new IOException(exception);
- }
-
- LOG.info(
- "Sort-merge partition {} initialized, num sort buffers: {}, num write buffers: {}.",
- getPartitionId(),
- numBuffersForSort,
- numBuffersForWrite);
+ LOG.info("Sort-merge partition {} initialized.", getPartitionId());
}
@Override
@@ -296,13 +262,13 @@ public class SortMergeResultPartition extends ResultPartition {
}
if (!dataBuffer.hasRemaining()) {
- dataBuffer.reset();
+ dataBuffer.release();
writeLargeRecord(record, targetSubpartition, dataType, isBroadcast);
return;
}
flushDataBuffer(dataBuffer, isBroadcast);
- dataBuffer.reset();
+ dataBuffer.release();
if (record.hasRemaining()) {
emit(record, targetSubpartition, dataType, isBroadcast);
}
@@ -317,7 +283,9 @@ public class SortMergeResultPartition extends ResultPartition {
private DataBuffer getUnicastDataBuffer() throws IOException {
flushBroadcastDataBuffer();
- if (unicastDataBuffer != null && !unicastDataBuffer.isFinished()) {
+ if (unicastDataBuffer != null
+ && !unicastDataBuffer.isFinished()
+ && !unicastDataBuffer.isReleased()) {
return unicastDataBuffer;
}
@@ -328,7 +296,9 @@ public class SortMergeResultPartition extends ResultPartition {
private DataBuffer getBroadcastDataBuffer() throws IOException {
flushUnicastDataBuffer();
- if (broadcastDataBuffer != null && !broadcastDataBuffer.isFinished()) {
+ if (broadcastDataBuffer != null
+ && !broadcastDataBuffer.isFinished()
+ && !broadcastDataBuffer.isReleased()) {
return broadcastDataBuffer;
}
@@ -336,12 +306,20 @@ public class SortMergeResultPartition extends ResultPartition {
return broadcastDataBuffer;
}
- private DataBuffer createNewDataBuffer() {
+ private DataBuffer createNewDataBuffer() throws IOException {
+ requestNetworkBuffers();
+
if (useHashBuffer) {
return new HashBasedDataBuffer(
- bufferPool, numSubpartitions, numBuffersForSort, subpartitionOrder);
+ freeSegments,
+ bufferPool,
+ numSubpartitions,
+ networkBufferSize,
+ numBuffersForSort,
+ subpartitionOrder);
} else {
return new SortBasedDataBuffer(
+ freeSegments,
bufferPool,
numSubpartitions,
networkBufferSize,
@@ -350,12 +328,53 @@ public class SortMergeResultPartition extends ResultPartition {
}
}
+ private void requestNetworkBuffers() throws IOException {
+ int numRequiredBuffer = bufferPool.getNumberOfRequiredMemorySegments();
+ if (numRequiredBuffer < 2) {
+ throw new IOException(
+ String.format(
+ "Too few sort buffers, please increase %s.",
+ NettyShuffleEnvironmentOptions.NETWORK_SORT_SHUFFLE_MIN_BUFFERS));
+ }
+
+ try {
+ while (freeSegments.size() < numRequiredBuffer) {
+ freeSegments.add(checkNotNull(bufferPool.requestMemorySegmentBlocking()));
+ }
+ } catch (InterruptedException exception) {
+ throw new IOException("Failed to allocate buffers for result partition.", exception);
+ }
+
+ // avoid taking too many buffers in one result partition
+ while (freeSegments.size() < bufferPool.getMaxNumberOfMemorySegments()) {
+ MemorySegment segment = bufferPool.requestMemorySegment();
+ if (segment == null) {
+ break;
+ }
+ freeSegments.add(segment);
+ }
+
+ useHashBuffer = false;
+ int numWriteBuffers = 0;
+ if (freeSegments.size() >= 2 * numSubpartitions) {
+ useHashBuffer = true;
+ } else if (networkBufferSize >= NUM_WRITE_BUFFER_BYTES) {
+ numWriteBuffers = 1;
+ } else {
+ numWriteBuffers =
+ Math.min(EXPECTED_WRITE_BATCH_SIZE, NUM_WRITE_BUFFER_BYTES / networkBufferSize);
+ }
+ numWriteBuffers = Math.min(freeSegments.size() / 2, numWriteBuffers);
+ numBuffersForSort = freeSegments.size() - numWriteBuffers;
+ }
+
private void flushDataBuffer(DataBuffer dataBuffer, boolean isBroadcast) throws IOException {
if (dataBuffer == null || dataBuffer.isReleased() || !dataBuffer.hasRemaining()) {
return;
}
+ dataBuffer.finish();
- Queue<MemorySegment> segments = new ArrayDeque<>(writeSegments);
+ Queue<MemorySegment> segments = new ArrayDeque<>(freeSegments);
int numBuffersToWrite =
useHashBuffer
? EXPECTED_WRITE_BATCH_SIZE
@@ -366,7 +385,7 @@ public class SortMergeResultPartition extends ResultPartition {
do {
if (toWrite.size() >= numBuffersToWrite) {
writeBuffers(toWrite);
- segments = new ArrayDeque<>(writeSegments);
+ segments = new ArrayDeque<>(freeSegments);
}
BufferWithChannel bufferWithChannel = dataBuffer.getNextBuffer(segments.poll());
@@ -378,11 +397,12 @@ public class SortMergeResultPartition extends ResultPartition {
updateStatistics(bufferWithChannel.getBuffer(), isBroadcast);
toWrite.add(compressBufferIfPossible(bufferWithChannel));
} while (true);
+
+ releaseFreeBuffers();
}
private void flushBroadcastDataBuffer() throws IOException {
if (broadcastDataBuffer != null) {
- broadcastDataBuffer.finish();
flushDataBuffer(broadcastDataBuffer, true);
broadcastDataBuffer.release();
broadcastDataBuffer = null;
@@ -391,7 +411,6 @@ public class SortMergeResultPartition extends ResultPartition {
private void flushUnicastDataBuffer() throws IOException {
if (unicastDataBuffer != null) {
- unicastDataBuffer.finish();
flushDataBuffer(unicastDataBuffer, false);
unicastDataBuffer.release();
unicastDataBuffer = null;
@@ -421,19 +440,17 @@ public class SortMergeResultPartition extends ResultPartition {
private void writeLargeRecord(
ByteBuffer record, int targetSubpartition, DataType dataType, boolean isBroadcast)
throws IOException {
- // for the hash-based data buffer implementation, a large record will be appended to the
- // data buffer directly and spilled to multiple data regions
- checkState(!useHashBuffer, "No buffers available for writing.");
+ // a large record will be spilled to a separated data region
fileWriter.startNewRegion(isBroadcast);
List<BufferWithChannel> toWrite = new ArrayList<>();
- Queue<MemorySegment> segments = new ArrayDeque<>(writeSegments);
+ Queue<MemorySegment> segments = new ArrayDeque<>(freeSegments);
while (record.hasRemaining()) {
if (segments.isEmpty()) {
fileWriter.writeBuffers(toWrite);
toWrite.clear();
- segments = new ArrayDeque<>(writeSegments);
+ segments = new ArrayDeque<>(freeSegments);
}
int toCopy = Math.min(record.remaining(), networkBufferSize);
@@ -447,6 +464,7 @@ public class SortMergeResultPartition extends ResultPartition {
}
fileWriter.writeBuffers(toWrite);
+ releaseFreeBuffers();
}
private void writeBuffers(List<BufferWithChannel> buffers) throws IOException {
@@ -480,18 +498,16 @@ public class SortMergeResultPartition extends ResultPartition {
}
}
- private void releaseWriteBuffers() {
+ private void releaseFreeBuffers() {
if (bufferPool != null) {
- for (MemorySegment segment : writeSegments) {
- bufferPool.recycle(segment);
- }
- writeSegments.clear();
+ freeSegments.forEach(buffer -> bufferPool.recycle(buffer));
+ freeSegments.clear();
}
}
@Override
public void close() {
- releaseWriteBuffers();
+ releaseFreeBuffers();
// the close method will be always called by the task thread, so there is no need to make
// the sort buffer fields volatile and visible to the cancel thread immediately
releaseDataBuffer(unicastDataBuffer);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java
index a6b8b6cdfb2..bd2583e8f32 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java
@@ -34,6 +34,7 @@ import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Random;
@@ -81,12 +82,10 @@ public class DataBufferTest {
// fill the sort buffer with randomly generated data
int totalBytesWritten = 0;
+ int[] subpartitionReadOrder = getRandomSubpartitionOrder(numSubpartitions);
DataBuffer dataBuffer =
createDataBuffer(
- bufferPoolSize,
- bufferSize,
- numSubpartitions,
- getRandomSubpartitionOrder(numSubpartitions));
+ bufferPoolSize, bufferSize, numSubpartitions, subpartitionReadOrder);
int numDataBuffers = 5;
while (numDataBuffers > 0) {
// record size may be larger than buffer size so a record may span multiple segments
@@ -112,31 +111,34 @@ public class DataBufferTest {
totalBytesWritten += record.remaining();
}
- while (isFull && dataBuffer.hasRemaining()) {
+ if (!isFull) {
+ continue;
+ }
+ dataBuffer.finish();
+ --numDataBuffers;
+
+ while (dataBuffer.hasRemaining()) {
BufferWithChannel buffer = copyIntoSegment(bufferSize, dataBuffer);
if (buffer == null) {
break;
}
addBufferRead(buffer, buffersRead, numBytesRead);
}
-
- if (isFull) {
- --numDataBuffers;
- dataBuffer.reset();
- }
+ dataBuffer =
+ createDataBuffer(
+ bufferPoolSize, bufferSize, numSubpartitions, subpartitionReadOrder);
}
// read all data from the sort buffer
if (dataBuffer.hasRemaining()) {
assertTrue(dataBuffer instanceof HashBasedDataBuffer);
- dataBuffer.reset();
dataBuffer.finish();
while (dataBuffer.hasRemaining()) {
addBufferRead(copyIntoSegment(bufferSize, dataBuffer), buffersRead, numBytesRead);
}
}
- assertEquals(totalBytesWritten, dataBuffer.numTotalBytes());
+ assertEquals(0, dataBuffer.numTotalBytes());
checkWriteReadResult(
numSubpartitions, numBytesWritten, numBytesRead, dataWritten, buffersRead);
}
@@ -300,7 +302,7 @@ public class DataBufferTest {
// append should fail for insufficient capacity
int numRecords = bufferPoolSize - 1;
- long numBytes = useHashBuffer ? bufferSize * bufferPoolSize : bufferSize * numRecords;
+ long numBytes = bufferSize * numRecords;
appendAndCheckResult(dataBuffer, bufferSize + 1, true, numBytes, numRecords, true);
}
@@ -310,9 +312,7 @@ public class DataBufferTest {
int bufferSize = 1024;
DataBuffer dataBuffer = createDataBuffer(bufferPoolSize, bufferSize, 1);
- long numBytes = useHashBuffer ? bufferPoolSize * bufferSize : 0;
- appendAndCheckResult(
- dataBuffer, bufferPoolSize * bufferSize + 1, true, numBytes, 0, useHashBuffer);
+ appendAndCheckResult(dataBuffer, bufferPoolSize * bufferSize + 1, true, 0, 0, false);
}
private void appendAndCheckResult(
@@ -378,8 +378,12 @@ public class DataBufferTest {
NetworkBufferPool globalPool = new NetworkBufferPool(bufferPoolSize, bufferSize);
BufferPool bufferPool = globalPool.createBufferPool(bufferPoolSize, bufferPoolSize);
+ LinkedList<MemorySegment> segments = new LinkedList<>();
+ for (int i = 0; i < bufferPoolSize; ++i) {
+ segments.add(bufferPool.requestMemorySegmentBlocking());
+ }
DataBuffer dataBuffer =
- new SortBasedDataBuffer(bufferPool, 1, bufferSize, bufferPoolSize, null);
+ new SortBasedDataBuffer(segments, bufferPool, 1, bufferSize, bufferPoolSize, null);
dataBuffer.append(ByteBuffer.allocate(recordSize), 0, Buffer.DataType.DATA_BUFFER);
assertEquals(bufferPoolSize, bufferPool.bestEffortGetNumOfUsedBuffers());
@@ -396,22 +400,36 @@ public class DataBufferTest {
}
private DataBuffer createDataBuffer(int bufferPoolSize, int bufferSize, int numSubpartitions)
- throws IOException {
+ throws Exception {
return createDataBuffer(bufferPoolSize, bufferSize, numSubpartitions, null);
}
private DataBuffer createDataBuffer(
int bufferPoolSize, int bufferSize, int numSubpartitions, int[] customReadOrder)
- throws IOException {
+ throws Exception {
NetworkBufferPool globalPool = new NetworkBufferPool(bufferPoolSize, bufferSize);
BufferPool bufferPool = globalPool.createBufferPool(bufferPoolSize, bufferPoolSize);
+ LinkedList<MemorySegment> segments = new LinkedList<>();
+ for (int i = 0; i < bufferPoolSize; ++i) {
+ segments.add(bufferPool.requestMemorySegmentBlocking());
+ }
if (useHashBuffer) {
return new HashBasedDataBuffer(
- bufferPool, numSubpartitions, bufferPoolSize, customReadOrder);
+ segments,
+ bufferPool,
+ numSubpartitions,
+ bufferSize,
+ bufferPoolSize,
+ customReadOrder);
} else {
return new SortBasedDataBuffer(
- bufferPool, numSubpartitions, bufferSize, bufferPoolSize, customReadOrder);
+ segments,
+ bufferPool,
+ numSubpartitions,
+ bufferSize,
+ bufferPoolSize,
+ customReadOrder);
}
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
index 53cb37007a8..646b66de791 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
@@ -246,15 +246,13 @@ public class SortMergeResultPartitionTest extends TestLogger {
@Test
public void testWriteLargeRecord() throws Exception {
int numBuffers = useHashDataBuffer ? 100 : 15;
- int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
ByteBuffer recordWritten = generateRandomData(bufferSize * numBuffers, new Random());
partition.emitRecord(recordWritten, 0);
assertEquals(
- useHashDataBuffer ? numBuffers : numWriteBuffers,
- bufferPool.bestEffortGetNumOfUsedBuffers());
+ useHashDataBuffer ? numBuffers : 0, bufferPool.bestEffortGetNumOfUsedBuffers());
partition.finish();
partition.close();
@@ -309,15 +307,13 @@ public class SortMergeResultPartitionTest extends TestLogger {
@Test(expected = IllegalStateException.class)
public void testReleaseWhileWriting() throws Exception {
int numBuffers = useHashDataBuffer ? 100 : 15;
- int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
- int numBuffersForSort = numBuffers - numWriteBuffers;
BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
- assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
+ assertEquals(0, bufferPool.bestEffortGetNumOfUsedBuffers());
- partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 0);
- partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 1);
+ partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0);
+ partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1);
partition.emitRecord(ByteBuffer.allocate(bufferSize), 2);
assertNull(partition.getResultFile());
@@ -336,15 +332,13 @@ public class SortMergeResultPartitionTest extends TestLogger {
@Test
public void testRelease() throws Exception {
int numBuffers = useHashDataBuffer ? 100 : 15;
- int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
- int numBuffersForSort = numBuffers - numWriteBuffers;
BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
- assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
+ assertEquals(0, bufferPool.bestEffortGetNumOfUsedBuffers());
- partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 0);
- partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 1);
+ partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0);
+ partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1);
partition.finish();
partition.close();
@@ -371,17 +365,14 @@ public class SortMergeResultPartitionTest extends TestLogger {
@Test
public void testCloseReleasesAllBuffers() throws Exception {
int numBuffers = useHashDataBuffer ? 100 : 15;
- int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
- int numBuffersForSort = numBuffers - numWriteBuffers;
BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
- assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
+ assertEquals(0, bufferPool.bestEffortGetNumOfUsedBuffers());
- partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 5);
+ partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 5);
assertEquals(
- useHashDataBuffer ? numBuffers - 1 : numBuffers,
- bufferPool.bestEffortGetNumOfUsedBuffers());
+ useHashDataBuffer ? numBuffers : 0, bufferPool.bestEffortGetNumOfUsedBuffers());
partition.close();
assertTrue(bufferPool.isDestroyed());