Posted to commits@flink.apache.org by yi...@apache.org on 2022/07/21 10:13:06 UTC

[flink] branch master updated: [FLINK-28512][network] Select HashBasedDataBuffer and SortBasedDataBuffer dynamically based on the number of network buffers that can be allocated for SortMergeResultPartition

This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f0c4ab91a8 [FLINK-28512][network] Select HashBasedDataBuffer and SortBasedDataBuffer dynamically based on the number of network buffers that can be allocated for SortMergeResultPartition
5f0c4ab91a8 is described below

commit 5f0c4ab91a8ffc80ff04d0324934b2993fc5b533
Author: Tan Yuxin <ta...@gmail.com>
AuthorDate: Thu Jul 21 18:12:55 2022 +0800

    [FLINK-28512][network] Select HashBasedDataBuffer and SortBasedDataBuffer dynamically based on the number of network buffers that can be allocated for SortMergeResultPartition
    
    This closes #20315.
---
 .../runtime/io/network/partition/DataBuffer.java   |   3 -
 .../io/network/partition/HashBasedDataBuffer.java  |  81 ++++++------
 .../io/network/partition/SortBasedDataBuffer.java  |  88 ++++---------
 .../partition/SortMergeResultPartition.java        | 136 ++++++++++++---------
 .../io/network/partition/DataBufferTest.java       |  60 +++++----
 .../partition/SortMergeResultPartitionTest.java    |  29 ++---
 6 files changed, 183 insertions(+), 214 deletions(-)
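
For context on the diff below: the hash-vs-sort decision moves out of setup() and into a new requestNetworkBuffers() method that runs each time a data buffer is created for a data region, after requesting every buffer the pool will currently give out. A minimal sketch of the decision rule follows; the class and method names are illustrative only, not part of Flink's API:

    // Sketch only: BufferKindSketch is a hypothetical name restating the
    // rule added in requestNetworkBuffers() in the diff below.
    final class BufferKindSketch {
        // Hash-based appending is chosen only when at least two buffers per
        // subpartition were obtained for the current data region.
        static boolean useHashBuffer(int numFreeSegments, int numSubpartitions) {
            return numFreeSegments >= 2 * numSubpartitions;
        }

        public static void main(String[] args) {
            // Values mirror the updated tests: 10 subpartitions throughout.
            System.out.println(useHashBuffer(100, 10)); // true  -> HashBasedDataBuffer
            System.out.println(useHashBuffer(15, 10));  // false -> SortBasedDataBuffer
        }
    }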

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java
index 4d36bfe35ea..ed5e5a42b43 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java
@@ -58,9 +58,6 @@ public interface DataBuffer {
     /** Returns true if not all data appended to this {@link DataBuffer} is consumed. */
     boolean hasRemaining();
 
-    /** Resets this {@link DataBuffer} to be reused for data appending. */
-    void reset();
-
     /** Finishes this {@link DataBuffer}, after which no more records can be appended. */
     void finish();
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java
index 03410278236..7114b1f41a5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java
@@ -23,7 +23,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
-import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 
@@ -32,6 +32,7 @@ import javax.annotation.Nullable;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayDeque;
+import java.util.LinkedList;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -47,8 +48,11 @@ import static org.apache.flink.util.Preconditions.checkState;
  */
 public class HashBasedDataBuffer implements DataBuffer {
 
-    /** A buffer pool to request memory segments from. */
-    private final BufferPool bufferPool;
+    /** A list of {@link MemorySegment}s used to store data in memory. */
+    private final LinkedList<MemorySegment> freeSegments;
+
+    /** {@link BufferRecycler} used to recycle {@link #freeSegments}. */
+    private final BufferRecycler bufferRecycler;
 
     /** Number of guaranteed buffers that can be allocated from the buffer pool for data sorting. */
     private final int numGuaranteedBuffers;
@@ -56,6 +60,9 @@ public class HashBasedDataBuffer implements DataBuffer {
     /** Buffers containing data for all subpartitions. */
     private final ArrayDeque<BufferConsumer>[] buffers;
 
+    /** Size of buffers requested from the buffer pool. All buffers must be of the same size. */
+    private final int bufferSize;
+
     // ---------------------------------------------------------------------------------------------
     // Statistics and states
     // ---------------------------------------------------------------------------------------------
@@ -66,9 +73,6 @@ public class HashBasedDataBuffer implements DataBuffer {
     /** Total number of records already appended to this sort buffer. */
     private long numTotalRecords;
 
-    /** Whether this sort buffer is full and ready to read data from. */
-    private boolean isFull;
-
     /** Whether this sort buffer is finished. One can only read a finished sort buffer. */
     private boolean isFinished;
 
@@ -99,14 +103,19 @@ public class HashBasedDataBuffer implements DataBuffer {
     private long numTotalBytesRead;
 
     public HashBasedDataBuffer(
-            BufferPool bufferPool,
+            LinkedList<MemorySegment> freeSegments,
+            BufferRecycler bufferRecycler,
             int numSubpartitions,
+            int bufferSize,
             int numGuaranteedBuffers,
             @Nullable int[] customReadOrder) {
         checkArgument(numGuaranteedBuffers > 0, "No guaranteed buffers for sort.");
 
-        this.bufferPool = checkNotNull(bufferPool);
+        this.freeSegments = checkNotNull(freeSegments);
+        this.bufferRecycler = checkNotNull(bufferRecycler);
+        this.bufferSize = bufferSize;
         this.numGuaranteedBuffers = numGuaranteedBuffers;
+        checkState(numGuaranteedBuffers <= freeSegments.size(), "Wrong number of free segments.");
 
         this.builders = new BufferBuilder[numSubpartitions];
         this.buffers = new ArrayDeque[numSubpartitions];
@@ -134,7 +143,6 @@ public class HashBasedDataBuffer implements DataBuffer {
     public boolean append(ByteBuffer source, int targetChannel, Buffer.DataType dataType)
             throws IOException {
         checkArgument(source.hasRemaining(), "Cannot append empty data.");
-        checkState(!isFull, "Sort buffer is already full.");
         checkState(!isFinished, "Sort buffer is already finished.");
         checkState(!isReleased, "Sort buffer is already released.");
 
@@ -145,19 +153,18 @@ public class HashBasedDataBuffer implements DataBuffer {
             writeEvent(source, targetChannel, dataType);
         }
 
-        isFull = source.hasRemaining();
-        if (!isFull) {
-            ++numTotalRecords;
+        if (source.hasRemaining()) {
+            return true;
         }
+        ++numTotalRecords;
         numTotalBytes += totalBytes - source.remaining();
-        return isFull;
+        return false;
     }
 
     private void writeEvent(ByteBuffer source, int targetChannel, Buffer.DataType dataType) {
         BufferBuilder builder = builders[targetChannel];
         if (builder != null) {
             builder.finish();
-            buffers[targetChannel].add(builder.createBufferConsumerFromBeginning());
             builder.close();
             builders[targetChannel] = null;
         }
@@ -172,14 +179,19 @@ public class HashBasedDataBuffer implements DataBuffer {
         buffers[targetChannel].add(consumer);
     }
 
-    private void writeRecord(ByteBuffer source, int targetChannel) throws IOException {
+    private void writeRecord(ByteBuffer source, int targetChannel) {
+        BufferBuilder builder = builders[targetChannel];
+        int availableBytes = builder != null ? builder.getWritableBytes() : 0;
+        if (source.remaining()
+                > availableBytes
+                        + (numGuaranteedBuffers - numBuffersOccupied) * (long) bufferSize) {
+            return;
+        }
+
         do {
-            BufferBuilder builder = builders[targetChannel];
             if (builder == null) {
-                builder = requestBufferFromPool();
-                if (builder == null) {
-                    break;
-                }
+                builder = new BufferBuilder(freeSegments.poll(), bufferRecycler);
+                buffers[targetChannel].add(builder.createBufferConsumer());
                 ++numBuffersOccupied;
                 builders[targetChannel] = builder;
             }
@@ -187,29 +199,16 @@ public class HashBasedDataBuffer implements DataBuffer {
             builder.append(source);
             if (builder.isFull()) {
                 builder.finish();
-                buffers[targetChannel].add(builder.createBufferConsumerFromBeginning());
                 builder.close();
                 builders[targetChannel] = null;
+                builder = null;
             }
         } while (source.hasRemaining());
     }
 
-    private BufferBuilder requestBufferFromPool() throws IOException {
-        try {
-            // blocking request buffers if there is still guaranteed memory
-            if (numBuffersOccupied < numGuaranteedBuffers) {
-                return bufferPool.requestBufferBuilderBlocking();
-            }
-        } catch (InterruptedException e) {
-            throw new IOException("Interrupted while requesting buffer.", e);
-        }
-
-        return bufferPool.requestBufferBuilder();
-    }
-
     @Override
     public BufferWithChannel getNextBuffer(MemorySegment transitBuffer) {
-        checkState(isFull, "Sort buffer is not ready to be read.");
+        checkState(isFinished, "Sort buffer is not ready to be read.");
         checkState(!isReleased, "Sort buffer is already released.");
 
         BufferWithChannel buffer = null;
@@ -250,27 +249,15 @@ public class HashBasedDataBuffer implements DataBuffer {
         return numTotalBytesRead < numTotalBytes;
     }
 
-    @Override
-    public void reset() {
-        checkState(!isFinished, "Sort buffer has been finished.");
-        checkState(!isReleased, "Sort buffer has been released.");
-
-        isFull = false;
-        readOrderIndex = 0;
-    }
-
     @Override
     public void finish() {
-        checkState(!isFull, "DataBuffer must not be full.");
         checkState(!isFinished, "DataBuffer is already finished.");
 
-        isFull = true;
         isFinished = true;
         for (int channel = 0; channel < builders.length; ++channel) {
             BufferBuilder builder = builders[channel];
             if (builder != null) {
                 builder.finish();
-                buffers[channel].add(builder.createBufferConsumerFromBeginning());
                 builder.close();
                 builders[channel] = null;
             }
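
One detail worth noting in the HashBasedDataBuffer change above: writeRecord() now rejects a record up front when it cannot fit into the remaining guaranteed capacity, so append() can report a full buffer without ever leaving a partially copied record behind. A self-contained restatement of that pre-check, with hypothetical names:

    // Hypothetical sketch restating the capacity pre-check added to
    // writeRecord() above; none of these names exist in Flink itself.
    final class CapacityCheckSketch {
        static boolean fits(
                int recordBytes,        // source.remaining()
                int writableInBuilder,  // 0 when no builder is open for the channel
                int numGuaranteedBuffers,
                int numBuffersOccupied,
                int bufferSize) {
            // long arithmetic, as in the patch, to avoid int overflow
            long capacity =
                    writableInBuilder
                            + (numGuaranteedBuffers - numBuffersOccupied) * (long) bufferSize;
            return recordBytes <= capacity;
        }
    }
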
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java
index cd2acaa6d66..993518295c5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java
@@ -21,7 +21,7 @@ package org.apache.flink.runtime.io.network.partition;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
-import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 
 import javax.annotation.Nullable;
@@ -31,6 +31,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.LinkedList;
 
 import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
 import static org.apache.flink.util.Preconditions.checkArgument;
@@ -58,8 +59,11 @@ public class SortBasedDataBuffer implements DataBuffer {
      */
     private static final int INDEX_ENTRY_SIZE = 4 + 4 + 8;
 
-    /** A buffer pool to request memory segments from. */
-    private final BufferPool bufferPool;
+    /** A list of {@link MemorySegment}s used to store data in memory. */
+    private final LinkedList<MemorySegment> freeSegments;
+
+    /** {@link BufferRecycler} used to recycle {@link #freeSegments}. */
+    private final BufferRecycler bufferRecycler;
 
     /** A segment list as a joint buffer which stores all records and index entries. */
     private final ArrayList<MemorySegment> segments = new ArrayList<>();
@@ -89,9 +93,6 @@ public class SortBasedDataBuffer implements DataBuffer {
     /** Total number of bytes already read from this sort buffer. */
     private long numTotalBytesRead;
 
-    /** Whether this sort buffer is full and ready to read data from. */
-    private boolean isFull;
-
     /** Whether this sort buffer is finished. One can only read a finished sort buffer. */
     private boolean isFinished;
 
@@ -125,7 +126,8 @@ public class SortBasedDataBuffer implements DataBuffer {
     private int readOrderIndex = -1;
 
     public SortBasedDataBuffer(
-            BufferPool bufferPool,
+            LinkedList<MemorySegment> freeSegments,
+            BufferRecycler bufferRecycler,
             int numSubpartitions,
             int bufferSize,
             int numGuaranteedBuffers,
@@ -133,9 +135,11 @@ public class SortBasedDataBuffer implements DataBuffer {
         checkArgument(bufferSize > INDEX_ENTRY_SIZE, "Buffer size is too small.");
         checkArgument(numGuaranteedBuffers > 0, "No guaranteed buffers for sort.");
 
-        this.bufferPool = checkNotNull(bufferPool);
+        this.freeSegments = checkNotNull(freeSegments);
+        this.bufferRecycler = checkNotNull(bufferRecycler);
         this.bufferSize = bufferSize;
         this.numGuaranteedBuffers = numGuaranteedBuffers;
+        checkState(numGuaranteedBuffers <= freeSegments.size(), "Wrong number of free segments.");
         this.firstIndexEntryAddresses = new long[numSubpartitions];
         this.lastIndexEntryAddresses = new long[numSubpartitions];
 
@@ -162,7 +166,6 @@ public class SortBasedDataBuffer implements DataBuffer {
     public boolean append(ByteBuffer source, int targetChannel, DataType dataType)
             throws IOException {
         checkArgument(source.hasRemaining(), "Cannot append empty data.");
-        checkState(!isFull, "Sort buffer is already full.");
         checkState(!isFinished, "Sort buffer is already finished.");
         checkState(!isReleased, "Sort buffer is already released.");
 
@@ -170,11 +173,6 @@ public class SortBasedDataBuffer implements DataBuffer {
 
         // return true directly if it cannot allocate enough buffers for the given record
         if (!allocateBuffersForRecord(totalBytes)) {
-            isFull = true;
-            if (hasRemaining()) {
-                // prepare for reading
-                updateReadChannelAndIndexEntryAddress();
-            }
             return true;
         }
 
@@ -224,7 +222,7 @@ public class SortBasedDataBuffer implements DataBuffer {
         }
     }
 
-    private boolean allocateBuffersForRecord(int numRecordBytes) throws IOException {
+    private boolean allocateBuffersForRecord(int numRecordBytes) {
         int numBytesRequired = INDEX_ENTRY_SIZE + numRecordBytes;
         int availableBytes =
                 writeSegmentIndex == segments.size() ? 0 : bufferSize - writeSegmentOffset;
@@ -240,16 +238,16 @@ public class SortBasedDataBuffer implements DataBuffer {
             availableBytes = 0;
         }
 
+        if (availableBytes + (numGuaranteedBuffers - segments.size()) * (long) bufferSize
+                < numBytesRequired) {
+            return false;
+        }
+
         // allocate exactly enough buffers for the appended record
         do {
-            MemorySegment segment = requestBufferFromPool();
-            if (segment == null) {
-                // return false if we can not allocate enough buffers for the appended record
-                return false;
-            }
-
+            MemorySegment segment = freeSegments.poll();
             availableBytes += bufferSize;
-            addBuffer(segment);
+            addBuffer(checkNotNull(segment));
         } while (availableBytes < numBytesRequired);
 
         return true;
@@ -257,31 +255,18 @@ public class SortBasedDataBuffer implements DataBuffer {
 
     private void addBuffer(MemorySegment segment) {
         if (segment.size() != bufferSize) {
-            bufferPool.recycle(segment);
+            bufferRecycler.recycle(segment);
             throw new IllegalStateException("Illegal memory segment size.");
         }
 
         if (isReleased) {
-            bufferPool.recycle(segment);
+            bufferRecycler.recycle(segment);
             throw new IllegalStateException("Sort buffer is already released.");
         }
 
         segments.add(segment);
     }
 
-    private MemorySegment requestBufferFromPool() throws IOException {
-        try {
-            // blocking request buffers if there is still guaranteed memory
-            if (segments.size() < numGuaranteedBuffers) {
-                return bufferPool.requestMemorySegmentBlocking();
-            }
-        } catch (InterruptedException e) {
-            throw new IOException("Interrupted while requesting buffer.");
-        }
-
-        return bufferPool.requestMemorySegment();
-    }
-
     private void updateWriteSegmentIndexAndOffset(int numBytes) {
         writeSegmentOffset += numBytes;
 
@@ -294,7 +279,7 @@ public class SortBasedDataBuffer implements DataBuffer {
 
     @Override
     public BufferWithChannel getNextBuffer(MemorySegment transitBuffer) {
-        checkState(isFull, "Sort buffer is not ready to be read.");
+        checkState(isFinished, "Sort buffer is not ready to be read.");
         checkState(!isReleased, "Sort buffer is already released.");
 
         if (!hasRemaining()) {
@@ -426,38 +411,13 @@ public class SortBasedDataBuffer implements DataBuffer {
         return numTotalBytesRead < numTotalBytes;
     }
 
-    @Override
-    public void reset() {
-        checkState(!isFinished, "Sort buffer has been finished.");
-        checkState(!isReleased, "Sort buffer has been released.");
-        checkState(!hasRemaining(), "Still has remaining data.");
-
-        for (MemorySegment segment : segments) {
-            bufferPool.recycle(segment);
-        }
-        segments.clear();
-
-        // initialized with -1 means the corresponding channel has no data
-        Arrays.fill(firstIndexEntryAddresses, -1L);
-        Arrays.fill(lastIndexEntryAddresses, -1L);
-
-        isFull = false;
-        writeSegmentIndex = 0;
-        writeSegmentOffset = 0;
-        readIndexEntryAddress = 0;
-        recordRemainingBytes = 0;
-        readOrderIndex = -1;
-    }
-
     @Override
     public void finish() {
-        checkState(!isFull, "DataBuffer must not be full.");
         checkState(!isFinished, "DataBuffer is already finished.");
 
         isFinished = true;
 
         // prepare for reading
-        isFull = true;
         updateReadChannelAndIndexEntryAddress();
     }
 
@@ -474,7 +434,7 @@ public class SortBasedDataBuffer implements DataBuffer {
         isReleased = true;
 
         for (MemorySegment segment : segments) {
-            bufferPool.recycle(segment);
+            bufferRecycler.recycle(segment);
         }
         segments.clear();
     }
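
For readers new to this class, the INDEX_ENTRY_SIZE constant (4 + 4 + 8) kept above describes the 16-byte index entry stored alongside each record: the record length, the data type, and the address of the next index entry of the same channel. Below is a plausible byte-level sketch of that layout; the committed code may pack these fields differently, this only illustrates the sizes:

    import java.nio.ByteBuffer;

    // Hypothetical encoder for the 16-byte index entry layout implied by
    // INDEX_ENTRY_SIZE = 4 + 4 + 8; illustrative only.
    final class IndexEntrySketch {
        static ByteBuffer encode(int recordLength, int dataTypeOrdinal, long nextEntryAddress) {
            ByteBuffer entry = ByteBuffer.allocate(4 + 4 + 8);
            entry.putInt(recordLength);      // 4 bytes: record length
            entry.putInt(dataTypeOrdinal);   // 4 bytes: Buffer.DataType ordinal
            entry.putLong(nextEntryAddress); // 8 bytes: next entry of the same channel
            entry.flip();
            return entry;
        }
    }
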
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
index 26e648e3379..1620c36beea 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
@@ -43,6 +43,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Queue;
 import java.util.Random;
@@ -84,9 +85,6 @@ public class SortMergeResultPartition extends ResultPartition {
     @GuardedBy("lock")
     private PartitionedFile resultFile;
 
-    /** Buffers cut from the network buffer pool for data writing. */
-    private final List<MemorySegment> writeSegments = new ArrayList<>();
-
     private boolean hasNotifiedEndOfUserRecords;
 
     /** Size of network buffer and write buffer. */
@@ -115,6 +113,9 @@ public class SortMergeResultPartition extends ResultPartition {
      */
     private final SortMergeResultPartitionReadScheduler readScheduler;
 
+    /** All available network buffers that can be used by this result partition for a data region. */
+    private final LinkedList<MemorySegment> freeSegments = new LinkedList<>();
+
     /**
      * Number of guaranteed network buffers that can be used by {@link #unicastDataBuffer} and {@link
      * #broadcastDataBuffer}.
@@ -190,42 +191,7 @@ public class SortMergeResultPartition extends ResultPartition {
         readBufferPool.initialize();
         super.setup();
 
-        int numRequiredBuffer = bufferPool.getNumberOfRequiredMemorySegments();
-        if (numRequiredBuffer < 2) {
-            throw new IOException(
-                    String.format(
-                            "Too few sort buffers, please increase %s.",
-                            NettyShuffleEnvironmentOptions.NETWORK_SORT_SHUFFLE_MIN_BUFFERS));
-        }
-
-        int expectedWriteBuffers = 0;
-        if (numRequiredBuffer >= 2 * numSubpartitions) {
-            useHashBuffer = true;
-        } else if (networkBufferSize >= NUM_WRITE_BUFFER_BYTES) {
-            expectedWriteBuffers = 1;
-        } else {
-            expectedWriteBuffers =
-                    Math.min(EXPECTED_WRITE_BATCH_SIZE, NUM_WRITE_BUFFER_BYTES / networkBufferSize);
-        }
-
-        int numBuffersForWrite = Math.min(numRequiredBuffer / 2, expectedWriteBuffers);
-        numBuffersForSort = numRequiredBuffer - numBuffersForWrite;
-
-        try {
-            for (int i = 0; i < numBuffersForWrite; ++i) {
-                MemorySegment segment = bufferPool.requestMemorySegmentBlocking();
-                writeSegments.add(segment);
-            }
-        } catch (InterruptedException exception) {
-            // the setup method does not allow InterruptedException
-            throw new IOException(exception);
-        }
-
-        LOG.info(
-                "Sort-merge partition {} initialized, num sort buffers: {}, num write buffers: {}.",
-                getPartitionId(),
-                numBuffersForSort,
-                numBuffersForWrite);
+        LOG.info("Sort-merge partition {} initialized.", getPartitionId());
     }
 
     @Override
@@ -296,13 +262,13 @@ public class SortMergeResultPartition extends ResultPartition {
         }
 
         if (!dataBuffer.hasRemaining()) {
-            dataBuffer.reset();
+            dataBuffer.release();
             writeLargeRecord(record, targetSubpartition, dataType, isBroadcast);
             return;
         }
 
         flushDataBuffer(dataBuffer, isBroadcast);
-        dataBuffer.reset();
+        dataBuffer.release();
         if (record.hasRemaining()) {
             emit(record, targetSubpartition, dataType, isBroadcast);
         }
@@ -317,7 +283,9 @@ public class SortMergeResultPartition extends ResultPartition {
     private DataBuffer getUnicastDataBuffer() throws IOException {
         flushBroadcastDataBuffer();
 
-        if (unicastDataBuffer != null && !unicastDataBuffer.isFinished()) {
+        if (unicastDataBuffer != null
+                && !unicastDataBuffer.isFinished()
+                && !unicastDataBuffer.isReleased()) {
             return unicastDataBuffer;
         }
 
@@ -328,7 +296,9 @@ public class SortMergeResultPartition extends ResultPartition {
     private DataBuffer getBroadcastDataBuffer() throws IOException {
         flushUnicastDataBuffer();
 
-        if (broadcastDataBuffer != null && !broadcastDataBuffer.isFinished()) {
+        if (broadcastDataBuffer != null
+                && !broadcastDataBuffer.isFinished()
+                && !broadcastDataBuffer.isReleased()) {
             return broadcastDataBuffer;
         }
 
@@ -336,12 +306,20 @@ public class SortMergeResultPartition extends ResultPartition {
         return broadcastDataBuffer;
     }
 
-    private DataBuffer createNewDataBuffer() {
+    private DataBuffer createNewDataBuffer() throws IOException {
+        requestNetworkBuffers();
+
         if (useHashBuffer) {
             return new HashBasedDataBuffer(
-                    bufferPool, numSubpartitions, numBuffersForSort, subpartitionOrder);
+                    freeSegments,
+                    bufferPool,
+                    numSubpartitions,
+                    networkBufferSize,
+                    numBuffersForSort,
+                    subpartitionOrder);
         } else {
             return new SortBasedDataBuffer(
+                    freeSegments,
                     bufferPool,
                     numSubpartitions,
                     networkBufferSize,
@@ -350,12 +328,53 @@ public class SortMergeResultPartition extends ResultPartition {
         }
     }
 
+    private void requestNetworkBuffers() throws IOException {
+        int numRequiredBuffer = bufferPool.getNumberOfRequiredMemorySegments();
+        if (numRequiredBuffer < 2) {
+            throw new IOException(
+                    String.format(
+                            "Too few sort buffers, please increase %s.",
+                            NettyShuffleEnvironmentOptions.NETWORK_SORT_SHUFFLE_MIN_BUFFERS));
+        }
+
+        try {
+            while (freeSegments.size() < numRequiredBuffer) {
+                freeSegments.add(checkNotNull(bufferPool.requestMemorySegmentBlocking()));
+            }
+        } catch (InterruptedException exception) {
+            throw new IOException("Failed to allocate buffers for result partition.", exception);
+        }
+
+        // avoid taking too many buffers in one result partition
+        while (freeSegments.size() < bufferPool.getMaxNumberOfMemorySegments()) {
+            MemorySegment segment = bufferPool.requestMemorySegment();
+            if (segment == null) {
+                break;
+            }
+            freeSegments.add(segment);
+        }
+
+        useHashBuffer = false;
+        int numWriteBuffers = 0;
+        if (freeSegments.size() >= 2 * numSubpartitions) {
+            useHashBuffer = true;
+        } else if (networkBufferSize >= NUM_WRITE_BUFFER_BYTES) {
+            numWriteBuffers = 1;
+        } else {
+            numWriteBuffers =
+                    Math.min(EXPECTED_WRITE_BATCH_SIZE, NUM_WRITE_BUFFER_BYTES / networkBufferSize);
+        }
+        numWriteBuffers = Math.min(freeSegments.size() / 2, numWriteBuffers);
+        numBuffersForSort = freeSegments.size() - numWriteBuffers;
+    }
+
     private void flushDataBuffer(DataBuffer dataBuffer, boolean isBroadcast) throws IOException {
         if (dataBuffer == null || dataBuffer.isReleased() || !dataBuffer.hasRemaining()) {
             return;
         }
+        dataBuffer.finish();
 
-        Queue<MemorySegment> segments = new ArrayDeque<>(writeSegments);
+        Queue<MemorySegment> segments = new ArrayDeque<>(freeSegments);
         int numBuffersToWrite =
                 useHashBuffer
                         ? EXPECTED_WRITE_BATCH_SIZE
@@ -366,7 +385,7 @@ public class SortMergeResultPartition extends ResultPartition {
         do {
             if (toWrite.size() >= numBuffersToWrite) {
                 writeBuffers(toWrite);
-                segments = new ArrayDeque<>(writeSegments);
+                segments = new ArrayDeque<>(freeSegments);
             }
 
             BufferWithChannel bufferWithChannel = dataBuffer.getNextBuffer(segments.poll());
@@ -378,11 +397,12 @@ public class SortMergeResultPartition extends ResultPartition {
             updateStatistics(bufferWithChannel.getBuffer(), isBroadcast);
             toWrite.add(compressBufferIfPossible(bufferWithChannel));
         } while (true);
+
+        releaseFreeBuffers();
     }
 
     private void flushBroadcastDataBuffer() throws IOException {
         if (broadcastDataBuffer != null) {
-            broadcastDataBuffer.finish();
             flushDataBuffer(broadcastDataBuffer, true);
             broadcastDataBuffer.release();
             broadcastDataBuffer = null;
@@ -391,7 +411,6 @@ public class SortMergeResultPartition extends ResultPartition {
 
     private void flushUnicastDataBuffer() throws IOException {
         if (unicastDataBuffer != null) {
-            unicastDataBuffer.finish();
             flushDataBuffer(unicastDataBuffer, false);
             unicastDataBuffer.release();
             unicastDataBuffer = null;
@@ -421,19 +440,17 @@ public class SortMergeResultPartition extends ResultPartition {
     private void writeLargeRecord(
             ByteBuffer record, int targetSubpartition, DataType dataType, boolean isBroadcast)
             throws IOException {
-        // for the hash-based data buffer implementation, a large record will be appended to the
-        // data buffer directly and spilled to multiple data regions
-        checkState(!useHashBuffer, "No buffers available for writing.");
+        // a large record will be spilled to a separate data region
         fileWriter.startNewRegion(isBroadcast);
 
         List<BufferWithChannel> toWrite = new ArrayList<>();
-        Queue<MemorySegment> segments = new ArrayDeque<>(writeSegments);
+        Queue<MemorySegment> segments = new ArrayDeque<>(freeSegments);
 
         while (record.hasRemaining()) {
             if (segments.isEmpty()) {
                 fileWriter.writeBuffers(toWrite);
                 toWrite.clear();
-                segments = new ArrayDeque<>(writeSegments);
+                segments = new ArrayDeque<>(freeSegments);
             }
 
             int toCopy = Math.min(record.remaining(), networkBufferSize);
@@ -447,6 +464,7 @@ public class SortMergeResultPartition extends ResultPartition {
         }
 
         fileWriter.writeBuffers(toWrite);
+        releaseFreeBuffers();
     }
 
     private void writeBuffers(List<BufferWithChannel> buffers) throws IOException {
@@ -480,18 +498,16 @@ public class SortMergeResultPartition extends ResultPartition {
         }
     }
 
-    private void releaseWriteBuffers() {
+    private void releaseFreeBuffers() {
         if (bufferPool != null) {
-            for (MemorySegment segment : writeSegments) {
-                bufferPool.recycle(segment);
-            }
-            writeSegments.clear();
+            freeSegments.forEach(buffer -> bufferPool.recycle(buffer));
+            freeSegments.clear();
         }
     }
 
     @Override
     public void close() {
-        releaseWriteBuffers();
+        releaseFreeBuffers();
         // the close method will always be called by the task thread, so there is no need to make
         // the sort buffer fields volatile and visible to the cancel thread immediately
         releaseDataBuffer(unicastDataBuffer);
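
The split at the end of requestNetworkBuffers() above replaces the buffer accounting that used to happen once in setup(), so it is worth restating in isolation. A compilable sketch with illustrative names; the two size parameters correspond to NUM_WRITE_BUFFER_BYTES and EXPECTED_WRITE_BATCH_SIZE in the class:

    // Hypothetical restatement of the sort/write buffer split computed in
    // requestNetworkBuffers() above; names are illustrative only.
    final class BufferSplitSketch {
        static int numBuffersForSort(
                int numFreeSegments,
                int numSubpartitions,
                int networkBufferSize,
                int numWriteBufferBytes,      // NUM_WRITE_BUFFER_BYTES
                int expectedWriteBatchSize) { // EXPECTED_WRITE_BATCH_SIZE
            if (numFreeSegments >= 2 * numSubpartitions) {
                // hash mode: every obtained buffer is used for appending
                return numFreeSegments;
            }
            int numWriteBuffers =
                    networkBufferSize >= numWriteBufferBytes
                            ? 1
                            : Math.min(
                                    expectedWriteBatchSize,
                                    numWriteBufferBytes / networkBufferSize);
            numWriteBuffers = Math.min(numFreeSegments / 2, numWriteBuffers);
            return numFreeSegments - numWriteBuffers;
        }
    }
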
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java
index a6b8b6cdfb2..bd2583e8f32 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java
@@ -34,6 +34,7 @@ import java.nio.ByteBuffer;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Queue;
 import java.util.Random;
@@ -81,12 +82,10 @@ public class DataBufferTest {
 
         // fill the sort buffer with randomly generated data
         int totalBytesWritten = 0;
+        int[] subpartitionReadOrder = getRandomSubpartitionOrder(numSubpartitions);
         DataBuffer dataBuffer =
                 createDataBuffer(
-                        bufferPoolSize,
-                        bufferSize,
-                        numSubpartitions,
-                        getRandomSubpartitionOrder(numSubpartitions));
+                        bufferPoolSize, bufferSize, numSubpartitions, subpartitionReadOrder);
         int numDataBuffers = 5;
         while (numDataBuffers > 0) {
             // record size may be larger than buffer size so a record may span multiple segments
@@ -112,31 +111,34 @@ public class DataBufferTest {
                 totalBytesWritten += record.remaining();
             }
 
-            while (isFull && dataBuffer.hasRemaining()) {
+            if (!isFull) {
+                continue;
+            }
+            dataBuffer.finish();
+            --numDataBuffers;
+
+            while (dataBuffer.hasRemaining()) {
                 BufferWithChannel buffer = copyIntoSegment(bufferSize, dataBuffer);
                 if (buffer == null) {
                     break;
                 }
                 addBufferRead(buffer, buffersRead, numBytesRead);
             }
-
-            if (isFull) {
-                --numDataBuffers;
-                dataBuffer.reset();
-            }
+            dataBuffer =
+                    createDataBuffer(
+                            bufferPoolSize, bufferSize, numSubpartitions, subpartitionReadOrder);
         }
 
         // read all data from the sort buffer
         if (dataBuffer.hasRemaining()) {
             assertTrue(dataBuffer instanceof HashBasedDataBuffer);
-            dataBuffer.reset();
             dataBuffer.finish();
             while (dataBuffer.hasRemaining()) {
                 addBufferRead(copyIntoSegment(bufferSize, dataBuffer), buffersRead, numBytesRead);
             }
         }
 
-        assertEquals(totalBytesWritten, dataBuffer.numTotalBytes());
+        assertEquals(0, dataBuffer.numTotalBytes());
         checkWriteReadResult(
                 numSubpartitions, numBytesWritten, numBytesRead, dataWritten, buffersRead);
     }
@@ -300,7 +302,7 @@ public class DataBufferTest {
 
         // append should fail for insufficient capacity
         int numRecords = bufferPoolSize - 1;
-        long numBytes = useHashBuffer ? bufferSize * bufferPoolSize : bufferSize * numRecords;
+        long numBytes = bufferSize * numRecords;
         appendAndCheckResult(dataBuffer, bufferSize + 1, true, numBytes, numRecords, true);
     }
 
@@ -310,9 +312,7 @@ public class DataBufferTest {
         int bufferSize = 1024;
 
         DataBuffer dataBuffer = createDataBuffer(bufferPoolSize, bufferSize, 1);
-        long numBytes = useHashBuffer ? bufferPoolSize * bufferSize : 0;
-        appendAndCheckResult(
-                dataBuffer, bufferPoolSize * bufferSize + 1, true, numBytes, 0, useHashBuffer);
+        appendAndCheckResult(dataBuffer, bufferPoolSize * bufferSize + 1, true, 0, 0, false);
     }
 
     private void appendAndCheckResult(
@@ -378,8 +378,12 @@ public class DataBufferTest {
         NetworkBufferPool globalPool = new NetworkBufferPool(bufferPoolSize, bufferSize);
         BufferPool bufferPool = globalPool.createBufferPool(bufferPoolSize, bufferPoolSize);
 
+        LinkedList<MemorySegment> segments = new LinkedList<>();
+        for (int i = 0; i < bufferPoolSize; ++i) {
+            segments.add(bufferPool.requestMemorySegmentBlocking());
+        }
         DataBuffer dataBuffer =
-                new SortBasedDataBuffer(bufferPool, 1, bufferSize, bufferPoolSize, null);
+                new SortBasedDataBuffer(segments, bufferPool, 1, bufferSize, bufferPoolSize, null);
         dataBuffer.append(ByteBuffer.allocate(recordSize), 0, Buffer.DataType.DATA_BUFFER);
 
         assertEquals(bufferPoolSize, bufferPool.bestEffortGetNumOfUsedBuffers());
@@ -396,22 +400,36 @@ public class DataBufferTest {
     }
 
     private DataBuffer createDataBuffer(int bufferPoolSize, int bufferSize, int numSubpartitions)
-            throws IOException {
+            throws Exception {
         return createDataBuffer(bufferPoolSize, bufferSize, numSubpartitions, null);
     }
 
     private DataBuffer createDataBuffer(
             int bufferPoolSize, int bufferSize, int numSubpartitions, int[] customReadOrder)
-            throws IOException {
+            throws Exception {
         NetworkBufferPool globalPool = new NetworkBufferPool(bufferPoolSize, bufferSize);
         BufferPool bufferPool = globalPool.createBufferPool(bufferPoolSize, bufferPoolSize);
 
+        LinkedList<MemorySegment> segments = new LinkedList<>();
+        for (int i = 0; i < bufferPoolSize; ++i) {
+            segments.add(bufferPool.requestMemorySegmentBlocking());
+        }
         if (useHashBuffer) {
             return new HashBasedDataBuffer(
-                    bufferPool, numSubpartitions, bufferPoolSize, customReadOrder);
+                    segments,
+                    bufferPool,
+                    numSubpartitions,
+                    bufferSize,
+                    bufferPoolSize,
+                    customReadOrder);
         } else {
             return new SortBasedDataBuffer(
-                    bufferPool, numSubpartitions, bufferSize, bufferPoolSize, customReadOrder);
+                    segments,
+                    bufferPool,
+                    numSubpartitions,
+                    bufferSize,
+                    bufferPoolSize,
+                    customReadOrder);
         }
     }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
index 53cb37007a8..646b66de791 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
@@ -246,15 +246,13 @@ public class SortMergeResultPartitionTest extends TestLogger {
     @Test
     public void testWriteLargeRecord() throws Exception {
         int numBuffers = useHashDataBuffer ? 100 : 15;
-        int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
         SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
 
         ByteBuffer recordWritten = generateRandomData(bufferSize * numBuffers, new Random());
         partition.emitRecord(recordWritten, 0);
         assertEquals(
-                useHashDataBuffer ? numBuffers : numWriteBuffers,
-                bufferPool.bestEffortGetNumOfUsedBuffers());
+                useHashDataBuffer ? numBuffers : 0, bufferPool.bestEffortGetNumOfUsedBuffers());
 
         partition.finish();
         partition.close();
@@ -309,15 +307,13 @@ public class SortMergeResultPartitionTest extends TestLogger {
     @Test(expected = IllegalStateException.class)
     public void testReleaseWhileWriting() throws Exception {
         int numBuffers = useHashDataBuffer ? 100 : 15;
-        int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
-        int numBuffersForSort = numBuffers - numWriteBuffers;
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
         SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
-        assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
+        assertEquals(0, bufferPool.bestEffortGetNumOfUsedBuffers());
 
-        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 0);
-        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 1);
+        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0);
+        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1);
 
         partition.emitRecord(ByteBuffer.allocate(bufferSize), 2);
         assertNull(partition.getResultFile());
@@ -336,15 +332,13 @@ public class SortMergeResultPartitionTest extends TestLogger {
     @Test
     public void testRelease() throws Exception {
         int numBuffers = useHashDataBuffer ? 100 : 15;
-        int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
-        int numBuffersForSort = numBuffers - numWriteBuffers;
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
         SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
-        assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
+        assertEquals(0, bufferPool.bestEffortGetNumOfUsedBuffers());
 
-        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 0);
-        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 1);
+        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0);
+        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1);
         partition.finish();
         partition.close();
 
@@ -371,17 +365,14 @@ public class SortMergeResultPartitionTest extends TestLogger {
     @Test
     public void testCloseReleasesAllBuffers() throws Exception {
         int numBuffers = useHashDataBuffer ? 100 : 15;
-        int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
-        int numBuffersForSort = numBuffers - numWriteBuffers;
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
         SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
-        assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
+        assertEquals(0, bufferPool.bestEffortGetNumOfUsedBuffers());
 
-        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 5);
+        partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 5);
         assertEquals(
-                useHashDataBuffer ? numBuffers - 1 : numBuffers,
-                bufferPool.bestEffortGetNumOfUsedBuffers());
+                useHashDataBuffer ? numBuffers : 0, bufferPool.bestEffortGetNumOfUsedBuffers());
 
         partition.close();
         assertTrue(bufferPool.isDestroyed());