You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by yi...@apache.org on 2022/02/12 03:49:36 UTC

[flink] branch master updated: [FLINK-25796][network] Avoid record copy for result partition of sort-shuffle if there are enough buffers for better performance

This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 3be35d9  [FLINK-25796][network] Avoid record copy for result partition of sort-shuffle if there are enough buffers for better performance
3be35d9 is described below

commit 3be35d9c64e4b28cc73e157325c935d005286d99
Author: kevin.cyj <ke...@alibaba-inc.com>
AuthorDate: Tue Jan 11 11:17:12 2022 +0800

    [FLINK-25796][network] Avoid record copy for result partition of sort-shuffle if there are enough buffers for better performance
    
    Currently, for the result partition of sort-shuffle, there is extra record-copy overhead introduced by clustering records by subpartition index. For small records, this overhead can cause a performance regression of up to 20%. This patch aims to solve the problem.
    
    In fact, the hash-based implementation is a natural way to achieve the goal of sorting records by partition index. However, it has some serious weaknesses. For example, when there are not enough buffers or there is data skew, it can waste buffers and reduce compression efficiency, which can cause a performance regression.
    
    This patch solves the issue by dynamically switching between the two implementations: if there are enough buffers, the hash-based implementation is used; otherwise, the sort-based implementation is used.
    
    This closes #18505.
---
 .../partition/{SortBuffer.java => DataBuffer.java} |  41 +--
 .../io/network/partition/HashBasedDataBuffer.java  | 313 +++++++++++++++++++++
 ...nSortedBuffer.java => SortBasedDataBuffer.java} |  83 ++++--
 .../partition/SortMergeResultPartition.java        | 228 ++++++++-------
 ...onSortedBufferTest.java => DataBufferTest.java} | 241 ++++++++++------
 .../partition/PartitionedFileWriteReadTest.java    |   3 +-
 .../partition/SortMergeResultPartitionTest.java    |  66 +++--
 7 files changed, 727 insertions(+), 248 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java
similarity index 53%
rename from flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBuffer.java
rename to flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java
index eb0d590..4d36bfe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBuffer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/DataBuffer.java
@@ -21,48 +21,55 @@ package org.apache.flink.runtime.io.network.partition;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.nio.ByteBuffer;
 
 /**
- * Data of different channels can be appended to a {@link SortBuffer} and after the {@link
- * SortBuffer} is finished, the appended data can be copied from it in channel index order.
+ * Data of different channels can be appended to a {@link DataBuffer} and after the {@link
+ * DataBuffer} is full or finished, the appended data can be copied from it in channel index order.
+ *
+ * <p>The lifecycle of a {@link DataBuffer} can be: new, write, [read, reset, write], finish, read,
+ * release. There can be multiple [read, reset, write] operations before finish.
  */
-public interface SortBuffer {
+public interface DataBuffer {
 
     /**
-     * Appends data of the specified channel to this {@link SortBuffer} and returns true if all
-     * bytes of the source buffer is copied to this {@link SortBuffer} successfully, otherwise if
-     * returns false, nothing will be copied.
+     * Appends data of the specified channel to this {@link DataBuffer} and returns true if this
+     * {@link DataBuffer} is full.
      */
     boolean append(ByteBuffer source, int targetChannel, Buffer.DataType dataType)
             throws IOException;
 
     /**
-     * Copies data in this {@link SortBuffer} to the target {@link MemorySegment} in channel index
+     * Copies data in this {@link DataBuffer} to the target {@link MemorySegment} in channel index
      * order and returns {@link BufferWithChannel} which contains the copied data and the
      * corresponding channel index.
      */
-    BufferWithChannel copyIntoSegment(MemorySegment target);
+    BufferWithChannel getNextBuffer(@Nullable MemorySegment transitBuffer);
 
-    /** Returns the number of records written to this {@link SortBuffer}. */
-    long numRecords();
+    /** Returns the total number of records written to this {@link DataBuffer}. */
+    long numTotalRecords();
 
-    /** Returns the number of bytes written to this {@link SortBuffer}. */
-    long numBytes();
+    /** Returns the total number of bytes written to this {@link DataBuffer}. */
+    long numTotalBytes();
 
-    /** Returns true if there is still data can be consumed in this {@link SortBuffer}. */
+    /** Returns true if not all data appended to this {@link DataBuffer} is consumed. */
     boolean hasRemaining();
 
-    /** Finishes this {@link SortBuffer} which means no record can be appended any more. */
+    /** Resets this {@link DataBuffer} to be reused for data appending. */
+    void reset();
+
+    /** Finishes this {@link DataBuffer} which means no record can be appended any more. */
     void finish();
 
-    /** Whether this {@link SortBuffer} is finished or not. */
+    /** Whether this {@link DataBuffer} is finished or not. */
     boolean isFinished();
 
-    /** Releases this {@link SortBuffer} which releases all resources. */
+    /** Releases this {@link DataBuffer} which releases all resources. */
     void release();
 
-    /** Whether this {@link SortBuffer} is released or not. */
+    /** Whether this {@link DataBuffer} is released or not. */
     boolean isReleased();
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java
new file mode 100644
index 0000000..0341027
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/HashBasedDataBuffer.java
@@ -0,0 +1,313 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayDeque;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * A {@link DataBuffer} implementation which sorts all appended records only by subpartition
+ * index. Records of the same subpartition keep the appended order.
+ *
+ * <p>Different from the {@link SortBasedDataBuffer}, in this {@link DataBuffer} implementation,
+ * memory segment boundary serves as the natural data boundary of different subpartitions, which
+ * means that one memory segment can never contain data from different subpartitions.
+ */
+public class HashBasedDataBuffer implements DataBuffer {
+
+    /** A buffer pool to request memory segments from. */
+    private final BufferPool bufferPool;
+
+    /** Number of buffers that may be requested from {@link #bufferPool} in blocking mode. */
+    private final int numGuaranteedBuffers;
+
+    /** Finished buffers (records and events) queued per subpartition in appended order. */
+    private final ArrayDeque<BufferConsumer>[] buffers;
+
+    // ---------------------------------------------------------------------------------------------
+    // Statistics and states
+    // ---------------------------------------------------------------------------------------------
+
+    /** Total number of bytes already appended to this data buffer. */
+    private long numTotalBytes;
+
+    /** Total number of complete records and events already appended to this data buffer. */
+    private long numTotalRecords;
+
+    /** Whether this data buffer is full (or finished) and ready to read data from. */
+    private boolean isFull;
+
+    /** Whether this data buffer is finished. No data can be appended after finishing. */
+    private boolean isFinished;
+
+    /** Whether this data buffer is released. A released data buffer can not be used. */
+    private boolean isReleased;
+
+    // ---------------------------------------------------------------------------------------------
+    // For writing
+    // ---------------------------------------------------------------------------------------------
+
+    /** Per-channel buffer currently being filled with record data, or null if none. */
+    private final BufferBuilder[] builders;
+
+    /** Number of pooled buffers requested by this data buffer and not yet read out. */
+    private int numBuffersOccupied;
+
+    // ---------------------------------------------------------------------------------------------
+    // For reading
+    // ---------------------------------------------------------------------------------------------
+
+    /** Position in {@link #subpartitionReadOrder} of the channel currently being read. */
+    private int readOrderIndex;
+
+    /** Data of different subpartitions in this data buffer will be read in this order. */
+    private final int[] subpartitionReadOrder;
+
+    /** Total number of bytes already read from this data buffer. */
+    private long numTotalBytesRead;
+
+    public HashBasedDataBuffer(
+            BufferPool bufferPool,
+            int numSubpartitions,
+            int numGuaranteedBuffers,
+            @Nullable int[] customReadOrder) {
+        checkArgument(numGuaranteedBuffers > 0, "No guaranteed buffers for sort.");
+
+        this.bufferPool = checkNotNull(bufferPool);
+        this.numGuaranteedBuffers = numGuaranteedBuffers;
+
+        this.builders = new BufferBuilder[numSubpartitions];
+        this.buffers = new ArrayDeque[numSubpartitions];
+        for (int channel = 0; channel < numSubpartitions; ++channel) {
+            this.buffers[channel] = new ArrayDeque<>();
+        }
+
+        this.subpartitionReadOrder = new int[numSubpartitions];
+        if (customReadOrder != null) {
+            checkArgument(customReadOrder.length == numSubpartitions, "Illegal data read order.");
+            System.arraycopy(customReadOrder, 0, this.subpartitionReadOrder, 0, numSubpartitions);
+        } else { // default: read subpartitions in index order
+            for (int channel = 0; channel < numSubpartitions; ++channel) {
+                this.subpartitionReadOrder[channel] = channel;
+            }
+        }
+    }
+
+    /**
+     * Only part of the target record may be written if this {@link HashBasedDataBuffer} gets full.
+     * The remaining data of the target record will be written to the next data region (a new data
+     * buffer or this data buffer after reset). Returns true if this data buffer is full.
+     */
+    @Override
+    public boolean append(ByteBuffer source, int targetChannel, Buffer.DataType dataType)
+            throws IOException {
+        checkArgument(source.hasRemaining(), "Cannot append empty data.");
+        checkState(!isFull, "Sort buffer is already full.");
+        checkState(!isFinished, "Sort buffer is already finished.");
+        checkState(!isReleased, "Sort buffer is already released.");
+
+        int totalBytes = source.remaining();
+        if (dataType.isBuffer()) {
+            writeRecord(source, targetChannel);
+        } else {
+            writeEvent(source, targetChannel, dataType);
+        }
+
+        isFull = source.hasRemaining(); // leftover bytes mean no buffer was available
+        if (!isFull) {
+            ++numTotalRecords; // count a record only once it is fully written
+        }
+        numTotalBytes += totalBytes - source.remaining();
+        return isFull;
+    }
+
+    private void writeEvent(ByteBuffer source, int targetChannel, Buffer.DataType dataType) {
+        BufferBuilder builder = builders[targetChannel];
+        if (builder != null) { // flush pending record data first to keep order
+            builder.finish();
+            buffers[targetChannel].add(builder.createBufferConsumerFromBeginning());
+            builder.close();
+            builders[targetChannel] = null;
+        }
+
+        MemorySegment segment =
+                MemorySegmentFactory.allocateUnpooledOffHeapMemory(source.remaining());
+        segment.put(0, source, segment.size());
+        BufferConsumer consumer =
+                new BufferConsumer(
+                        new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE, dataType),
+                        segment.size());
+        buffers[targetChannel].add(consumer);
+    }
+
+    private void writeRecord(ByteBuffer source, int targetChannel) throws IOException {
+        do {
+            BufferBuilder builder = builders[targetChannel];
+            if (builder == null) {
+                builder = requestBufferFromPool();
+                if (builder == null) {
+                    break; // no buffer available: leave the rest for later
+                }
+                ++numBuffersOccupied;
+                builders[targetChannel] = builder;
+            }
+
+            builder.append(source);
+            if (builder.isFull()) {
+                builder.finish();
+                buffers[targetChannel].add(builder.createBufferConsumerFromBeginning());
+                builder.close();
+                builders[targetChannel] = null;
+            }
+        } while (source.hasRemaining());
+    }
+
+    private BufferBuilder requestBufferFromPool() throws IOException {
+        try {
+            // request blocking while within the guaranteed quota, otherwise non-blocking
+            if (numBuffersOccupied < numGuaranteedBuffers) {
+                return bufferPool.requestBufferBuilderBlocking();
+            }
+        } catch (InterruptedException e) {
+            throw new IOException("Interrupted while requesting buffer.", e);
+        }
+
+        return bufferPool.requestBufferBuilder(); // may return null
+    }
+
+    @Override
+    public BufferWithChannel getNextBuffer(MemorySegment transitBuffer) {
+        checkState(isFull, "Sort buffer is not ready to be read.");
+        checkState(!isReleased, "Sort buffer is already released.");
+
+        BufferWithChannel buffer = null;
+        if (!hasRemaining() || readOrderIndex >= subpartitionReadOrder.length) {
+            return null; // nothing more to read in this round
+        }
+
+        int targetChannel = subpartitionReadOrder[readOrderIndex];
+        while (buffer == null) {
+            BufferConsumer consumer = buffers[targetChannel].poll();
+            if (consumer != null) {
+                buffer = new BufferWithChannel(consumer.build(), targetChannel);
+                numBuffersOccupied -= buffer.getBuffer().isBuffer() ? 1 : 0;
+                numTotalBytesRead += buffer.getBuffer().readableBytes();
+                consumer.close();
+            } else {
+                if (++readOrderIndex >= subpartitionReadOrder.length) {
+                    break; // all channels drained
+                }
+                targetChannel = subpartitionReadOrder[readOrderIndex];
+            }
+        }
+        return buffer;
+    }
+
+    @Override
+    public long numTotalRecords() {
+        return numTotalRecords;
+    }
+
+    @Override
+    public long numTotalBytes() {
+        return numTotalBytes;
+    }
+
+    @Override
+    public boolean hasRemaining() {
+        return numTotalBytesRead < numTotalBytes;
+    }
+
+    @Override
+    public void reset() {
+        checkState(!isFinished, "Sort buffer has been finished.");
+        checkState(!isReleased, "Sort buffer has been released.");
+
+        isFull = false;
+        readOrderIndex = 0; // start reading from the first channel again
+    }
+
+    @Override
+    public void finish() {
+        checkState(!isFull, "DataBuffer must not be full.");
+        checkState(!isFinished, "DataBuffer is already finished.");
+
+        isFull = true; // a finished data buffer is ready to be read
+        isFinished = true;
+        for (int channel = 0; channel < builders.length; ++channel) {
+            BufferBuilder builder = builders[channel];
+            if (builder != null) { // flush partially filled buffers
+                builder.finish();
+                buffers[channel].add(builder.createBufferConsumerFromBeginning());
+                builder.close();
+                builders[channel] = null;
+            }
+        }
+    }
+
+    @Override
+    public boolean isFinished() {
+        return isFinished;
+    }
+
+    @Override
+    public void release() {
+        if (isReleased) {
+            return;
+        }
+        isReleased = true;
+
+        for (int channel = 0; channel < builders.length; ++channel) {
+            BufferBuilder builder = builders[channel];
+            if (builder != null) {
+                builder.close();
+                builders[channel] = null;
+            }
+        }
+
+        for (ArrayDeque<BufferConsumer> buffer : buffers) {
+            BufferConsumer consumer = buffer.poll();
+            while (consumer != null) {
+                consumer.close();
+                consumer = buffer.poll();
+            }
+        }
+    }
+
+    @Override
+    public boolean isReleased() {
+        return isReleased;
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PartitionSortedBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java
similarity index 87%
rename from flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PartitionSortedBuffer.java
rename to flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java
index 9f0ca0d..cd2acaa 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PartitionSortedBuffer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortBasedDataBuffer.java
@@ -38,19 +38,19 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
- * A {@link SortBuffer} implementation which sorts all appended records only by subpartition index.
+ * A {@link DataBuffer} implementation which sorts all appended records only by subpartition index.
  * Records of the same subpartition keep the appended order.
  *
  * <p>It maintains a list of {@link MemorySegment}s as a joint buffer. Data will be appended to the
  * joint buffer sequentially. When writing a record, an index entry will be appended first. An index
  * entry consists of 4 fields: 4 bytes for record length, 4 bytes for {@link DataType} and 8 bytes
  * for address pointing to the next index entry of the same channel which will be used to index the
- * next record to read when coping data from this {@link SortBuffer}. For simplicity, no index entry
+ * next record to read when copying data from this {@link DataBuffer}. For simplicity, no index entry
  * can span multiple segments. The corresponding record data is seated right after its index entry
  * and different from the index entry, records have variable length thus may span multiple segments.
  */
 @NotThreadSafe
-public class PartitionSortedBuffer implements SortBuffer {
+public class SortBasedDataBuffer implements DataBuffer {
 
     /**
      * Size of an index entry: 4 bytes for record length, 4 bytes for data type and 8 bytes for
@@ -89,6 +89,9 @@ public class PartitionSortedBuffer implements SortBuffer {
     /** Total number of bytes already read from this sort buffer. */
     private long numTotalBytesRead;
 
+    /** Whether this sort buffer is full and ready to read data from. */
+    private boolean isFull;
+
     /** Whether this sort buffer is finished. One can only read a finished sort buffer. */
     private boolean isFinished;
 
@@ -121,7 +124,7 @@ public class PartitionSortedBuffer implements SortBuffer {
     /** Used to index the current available channel to read data from. */
     private int readOrderIndex = -1;
 
-    public PartitionSortedBuffer(
+    public SortBasedDataBuffer(
             BufferPool bufferPool,
             int numSubpartitions,
             int bufferSize,
@@ -151,18 +154,28 @@ public class PartitionSortedBuffer implements SortBuffer {
         }
     }
 
+    /**
+     * No partial record will be written to this {@link SortBasedDataBuffer}, which means that
+     * either all data of target record will be written or nothing will be written.
+     */
     @Override
     public boolean append(ByteBuffer source, int targetChannel, DataType dataType)
             throws IOException {
         checkArgument(source.hasRemaining(), "Cannot append empty data.");
+        checkState(!isFull, "Sort buffer is already full.");
         checkState(!isFinished, "Sort buffer is already finished.");
         checkState(!isReleased, "Sort buffer is already released.");
 
         int totalBytes = source.remaining();
 
-        // return false directly if it can not allocate enough buffers for the given record
+        // return true directly if it can not allocate enough buffers for the given record
         if (!allocateBuffersForRecord(totalBytes)) {
-            return false;
+            isFull = true;
+            if (hasRemaining()) {
+                // prepare for reading
+                updateReadChannelAndIndexEntryAddress();
+            }
+            return true;
         }
 
         // write the index entry and record or event data
@@ -172,7 +185,7 @@ public class PartitionSortedBuffer implements SortBuffer {
         ++numTotalRecords;
         numTotalBytes += totalBytes;
 
-        return true;
+        return false;
     }
 
     private void writeIndex(int channelIndex, int numRecordBytes, Buffer.DataType dataType) {
@@ -280,11 +293,14 @@ public class PartitionSortedBuffer implements SortBuffer {
     }
 
     @Override
-    public BufferWithChannel copyIntoSegment(MemorySegment target) {
-        checkState(hasRemaining(), "No data remaining.");
-        checkState(isFinished, "Should finish the sort buffer first before coping any data.");
+    public BufferWithChannel getNextBuffer(MemorySegment transitBuffer) {
+        checkState(isFull, "Sort buffer is not ready to be read.");
         checkState(!isReleased, "Sort buffer is already released.");
 
+        if (!hasRemaining()) {
+            return null;
+        }
+
         int numBytesCopied = 0;
         DataType bufferDataType = DataType.DATA_BUFFER;
         int channelIndex = subpartitionReadOrder[readOrderIndex];
@@ -309,13 +325,13 @@ public class PartitionSortedBuffer implements SortBuffer {
             sourceSegmentOffset += INDEX_ENTRY_SIZE;
 
             // allocate a temp buffer for the event if the target buffer is not big enough
-            if (bufferDataType.isEvent() && target.size() < length) {
-                target = MemorySegmentFactory.allocateUnpooledSegment(length);
+            if (bufferDataType.isEvent() && transitBuffer.size() < length) {
+                transitBuffer = MemorySegmentFactory.allocateUnpooledSegment(length);
             }
 
             numBytesCopied +=
                     copyRecordOrEvent(
-                            target,
+                            transitBuffer,
                             numBytesCopied,
                             sourceSegmentIndex,
                             sourceSegmentOffset,
@@ -329,10 +345,11 @@ public class PartitionSortedBuffer implements SortBuffer {
                 }
                 readIndexEntryAddress = nextReadIndexEntryAddress;
             }
-        } while (numBytesCopied < target.size() && bufferDataType.isBuffer());
+        } while (numBytesCopied < transitBuffer.size() && bufferDataType.isBuffer());
 
         numTotalBytesRead += numBytesCopied;
-        Buffer buffer = new NetworkBuffer(target, (buf) -> {}, bufferDataType, numBytesCopied);
+        Buffer buffer =
+                new NetworkBuffer(transitBuffer, (buf) -> {}, bufferDataType, numBytesCopied);
         return new BufferWithChannel(buffer, channelIndex);
     }
 
@@ -395,12 +412,12 @@ public class PartitionSortedBuffer implements SortBuffer {
     }
 
     @Override
-    public long numRecords() {
+    public long numTotalRecords() {
         return numTotalRecords;
     }
 
     @Override
-    public long numBytes() {
+    public long numTotalBytes() {
         return numTotalBytes;
     }
 
@@ -410,12 +427,37 @@ public class PartitionSortedBuffer implements SortBuffer {
     }
 
     @Override
+    public void reset() {
+        checkState(!isFinished, "Sort buffer has been finished.");
+        checkState(!isReleased, "Sort buffer has been released.");
+        checkState(!hasRemaining(), "Still has remaining data.");
+
+        for (MemorySegment segment : segments) {
+            bufferPool.recycle(segment);
+        }
+        segments.clear();
+
+        // initialized with -1 means the corresponding channel has no data
+        Arrays.fill(firstIndexEntryAddresses, -1L);
+        Arrays.fill(lastIndexEntryAddresses, -1L);
+
+        isFull = false;
+        writeSegmentIndex = 0;
+        writeSegmentOffset = 0;
+        readIndexEntryAddress = 0;
+        recordRemainingBytes = 0;
+        readOrderIndex = -1;
+    }
+
+    @Override
     public void finish() {
-        checkState(!isFinished, "SortBuffer is already finished.");
+        checkState(!isFull, "DataBuffer must not be full.");
+        checkState(!isFinished, "DataBuffer is already finished.");
 
         isFinished = true;
 
         // prepare for reading
+        isFull = true;
         updateReadChannelAndIndexEntryAddress();
     }
 
@@ -426,20 +468,15 @@ public class PartitionSortedBuffer implements SortBuffer {
 
     @Override
     public void release() {
-        // the sort buffer can be released by other threads
         if (isReleased) {
             return;
         }
-
         isReleased = true;
 
         for (MemorySegment segment : segments) {
             bufferPool.recycle(segment);
         }
         segments.clear();
-
-        numTotalBytes = 0;
-        numTotalRecords = 0;
     }
 
     @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
index eabf8db..190ca45 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
@@ -54,10 +54,10 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
- * {@link SortMergeResultPartition} appends records and events to {@link SortBuffer} and after the
- * {@link SortBuffer} is full, all data in the {@link SortBuffer} will be copied and spilled to a
+ * {@link SortMergeResultPartition} appends records and events to {@link DataBuffer} and after the
+ * {@link DataBuffer} is full, all data in the {@link DataBuffer} will be copied and spilled to a
  * {@link PartitionedFile} in subpartition index order sequentially. Large records that can not be
- * appended to an empty {@link SortBuffer} will be spilled to the result {@link PartitionedFile}
+ * appended to an empty {@link DataBuffer} will be spilled to the result {@link PartitionedFile}
  * separately.
  */
 @NotThreadSafe
@@ -69,6 +69,14 @@ public class SortMergeResultPartition extends ResultPartition {
      */
     private static final int NUM_WRITE_BUFFER_BYTES = 8 * 1024 * 1024;
 
+    /**
+     * Expected number of buffers for batch data writing. 512 means that at most 1024 buffers
+     * (including the headers) will be written in one request. This value is chosen because the
+     * writev system call limits the number of buffers that can be written in one invocation,
+     * and the advertised limit is 1024 (please see the writev man page for more information).
+     */
+    private static final int EXPECTED_WRITE_BATCH_SIZE = 512;
+
     private final Object lock = new Object();
 
     /** {@link PartitionedFile} produced by this result partition. */
@@ -93,7 +101,7 @@ public class SortMergeResultPartition extends ResultPartition {
      */
     private final String resultFileBasePath;
 
-    /** Subpartition orders of coping data from {@link SortBuffer} and writing to file. */
+    /** Subpartition order of copying data from {@link DataBuffer} and writing to file. */
     private final int[] subpartitionOrder;
 
     /**
@@ -107,16 +115,22 @@ public class SortMergeResultPartition extends ResultPartition {
     private final SortMergeResultPartitionReadScheduler readScheduler;
 
     /**
-     * Number of guaranteed network buffers can be used by {@link #unicastSortBuffer} and {@link
-     * #broadcastSortBuffer}.
+     * Number of guaranteed network buffers can be used by {@link #unicastDataBuffer} and {@link
+     * #broadcastDataBuffer}.
      */
     private int numBuffersForSort;
 
-    /** {@link SortBuffer} for records sent by {@link #broadcastRecord(ByteBuffer)}. */
-    private SortBuffer broadcastSortBuffer;
+    /**
+     * If true, {@link HashBasedDataBuffer} will be used, otherwise, {@link SortBasedDataBuffer}
+     * will be used.
+     */
+    private boolean useHashBuffer;
+
+    /** {@link DataBuffer} for records sent by {@link #broadcastRecord(ByteBuffer)}. */
+    private DataBuffer broadcastDataBuffer;
 
-    /** {@link SortBuffer} for records sent by {@link #emitRecord(ByteBuffer, int)}. */
-    private SortBuffer unicastSortBuffer;
+    /** {@link DataBuffer} for records sent by {@link #emitRecord(ByteBuffer, int)}. */
+    private DataBuffer unicastDataBuffer;
 
     public SortMergeResultPartition(
             String owningTaskName,
@@ -175,23 +189,29 @@ public class SortMergeResultPartition extends ResultPartition {
         readBufferPool.initialize();
         super.setup();
 
-        int expectedWriteBuffers = NUM_WRITE_BUFFER_BYTES / networkBufferSize;
-        if (networkBufferSize > NUM_WRITE_BUFFER_BYTES) {
-            expectedWriteBuffers = 1;
-        }
-
         int numRequiredBuffer = bufferPool.getNumberOfRequiredMemorySegments();
-        int numWriteBuffers = Math.min(numRequiredBuffer / 2, expectedWriteBuffers);
-        if (numWriteBuffers < 1) {
+        if (numRequiredBuffer < 2) {
             throw new IOException(
                     String.format(
                             "Too few sort buffers, please increase %s.",
                             NettyShuffleEnvironmentOptions.NETWORK_SORT_SHUFFLE_MIN_BUFFERS));
         }
-        numBuffersForSort = numRequiredBuffer - numWriteBuffers;
+
+        int expectedWriteBuffers = 0;
+        if (numRequiredBuffer >= 2 * numSubpartitions) {
+            useHashBuffer = true;
+        } else if (networkBufferSize >= NUM_WRITE_BUFFER_BYTES) {
+            expectedWriteBuffers = 1;
+        } else {
+            expectedWriteBuffers =
+                    Math.min(EXPECTED_WRITE_BATCH_SIZE, NUM_WRITE_BUFFER_BYTES / networkBufferSize);
+        }
+
+        int numBuffersForWrite = Math.min(numRequiredBuffer / 2, expectedWriteBuffers);
+        numBuffersForSort = numRequiredBuffer - numBuffersForWrite;
 
         try {
-            for (int i = 0; i < numWriteBuffers; ++i) {
+            for (int i = 0; i < numBuffersForWrite; ++i) {
                 MemorySegment segment = bufferPool.requestMemorySegmentBlocking();
                 writeSegments.add(segment);
             }
@@ -204,7 +224,7 @@ public class SortMergeResultPartition extends ResultPartition {
                 "Sort-merge partition {} initialized, num sort buffers: {}, num write buffers: {}.",
                 getPartitionId(),
                 numBuffersForSort,
-                numWriteBuffers);
+                numBuffersForWrite);
     }
 
     @Override
@@ -259,105 +279,112 @@ public class SortMergeResultPartition extends ResultPartition {
             throws IOException {
         checkInProduceState();
 
-        SortBuffer sortBuffer = isBroadcast ? getBroadcastSortBuffer() : getUnicastSortBuffer();
-        if (sortBuffer.append(record, targetSubpartition, dataType)) {
+        DataBuffer dataBuffer = isBroadcast ? getBroadcastDataBuffer() : getUnicastDataBuffer();
+        if (!dataBuffer.append(record, targetSubpartition, dataType)) {
             return;
         }
 
-        if (!sortBuffer.hasRemaining()) {
-            // the record can not be appended to the free sort buffer because it is too large
-            sortBuffer.finish();
-            sortBuffer.release();
+        if (!dataBuffer.hasRemaining()) {
+            dataBuffer.reset();
             writeLargeRecord(record, targetSubpartition, dataType, isBroadcast);
             return;
         }
 
-        flushSortBuffer(sortBuffer, isBroadcast);
-        emit(record, targetSubpartition, dataType, isBroadcast);
+        flushDataBuffer(dataBuffer, isBroadcast);
+        dataBuffer.reset();
+        if (record.hasRemaining()) {
+            emit(record, targetSubpartition, dataType, isBroadcast);
+        }
     }
 
-    private void releaseSortBuffer(SortBuffer sortBuffer) {
-        if (sortBuffer != null) {
-            sortBuffer.release();
+    private void releaseDataBuffer(DataBuffer dataBuffer) {
+        if (dataBuffer != null) {
+            dataBuffer.release();
         }
     }
 
-    private SortBuffer getUnicastSortBuffer() throws IOException {
-        flushBroadcastSortBuffer();
+    private DataBuffer getUnicastDataBuffer() throws IOException {
+        flushBroadcastDataBuffer();
 
-        if (unicastSortBuffer != null && !unicastSortBuffer.isFinished()) {
-            return unicastSortBuffer;
+        if (unicastDataBuffer != null && !unicastDataBuffer.isFinished()) {
+            return unicastDataBuffer;
         }
 
-        unicastSortBuffer =
-                new PartitionSortedBuffer(
-                        bufferPool,
-                        numSubpartitions,
-                        networkBufferSize,
-                        numBuffersForSort,
-                        subpartitionOrder);
-        return unicastSortBuffer;
+        unicastDataBuffer = createNewDataBuffer();
+        return unicastDataBuffer;
     }
 
-    private SortBuffer getBroadcastSortBuffer() throws IOException {
-        flushUnicastSortBuffer();
+    private DataBuffer getBroadcastDataBuffer() throws IOException {
+        flushUnicastDataBuffer();
 
-        if (broadcastSortBuffer != null && !broadcastSortBuffer.isFinished()) {
-            return broadcastSortBuffer;
+        if (broadcastDataBuffer != null && !broadcastDataBuffer.isFinished()) {
+            return broadcastDataBuffer;
         }
 
-        broadcastSortBuffer =
-                new PartitionSortedBuffer(
-                        bufferPool,
-                        numSubpartitions,
-                        networkBufferSize,
-                        numBuffersForSort,
-                        subpartitionOrder);
-        return broadcastSortBuffer;
+        broadcastDataBuffer = createNewDataBuffer();
+        return broadcastDataBuffer;
     }
 
-    private void flushSortBuffer(SortBuffer sortBuffer, boolean isBroadcast) throws IOException {
-        if (sortBuffer == null || sortBuffer.isReleased()) {
-            return;
+    private DataBuffer createNewDataBuffer() {
+        if (useHashBuffer) {
+            return new HashBasedDataBuffer(
+                    bufferPool, numSubpartitions, numBuffersForSort, subpartitionOrder);
+        } else {
+            return new SortBasedDataBuffer(
+                    bufferPool,
+                    numSubpartitions,
+                    networkBufferSize,
+                    numBuffersForSort,
+                    subpartitionOrder);
         }
-        sortBuffer.finish();
-
-        if (sortBuffer.hasRemaining()) {
-            fileWriter.startNewRegion(isBroadcast);
+    }
 
-            List<BufferWithChannel> toWrite = new ArrayList<>();
-            Queue<MemorySegment> segments = getWriteSegments();
+    private void flushDataBuffer(DataBuffer dataBuffer, boolean isBroadcast) throws IOException {
+        if (dataBuffer == null || dataBuffer.isReleased() || !dataBuffer.hasRemaining()) {
+            return;
+        }
 
-            while (sortBuffer.hasRemaining()) {
-                if (segments.isEmpty()) {
-                    fileWriter.writeBuffers(toWrite);
-                    toWrite.clear();
-                    segments = getWriteSegments();
-                }
+        Queue<MemorySegment> segments = new ArrayDeque<>(writeSegments);
+        int numBuffersToWrite =
+                useHashBuffer
+                        ? EXPECTED_WRITE_BATCH_SIZE
+                        : Math.min(EXPECTED_WRITE_BATCH_SIZE, segments.size());
+        List<BufferWithChannel> toWrite = new ArrayList<>(numBuffersToWrite);
 
-                BufferWithChannel bufferWithChannel =
-                        sortBuffer.copyIntoSegment(checkNotNull(segments.poll()));
-                updateStatistics(bufferWithChannel.getBuffer(), isBroadcast);
-                toWrite.add(compressBufferIfPossible(bufferWithChannel));
+        fileWriter.startNewRegion(isBroadcast);
+        do {
+            if (toWrite.size() >= numBuffersToWrite) {
+                writeBuffers(toWrite);
+                segments = new ArrayDeque<>(writeSegments);
             }
 
-            fileWriter.writeBuffers(toWrite);
-        }
-
-        releaseSortBuffer(sortBuffer);
-    }
+            BufferWithChannel bufferWithChannel = dataBuffer.getNextBuffer(segments.poll());
+            if (bufferWithChannel == null) {
+                writeBuffers(toWrite);
+                break;
+            }
 
-    private void flushBroadcastSortBuffer() throws IOException {
-        flushSortBuffer(broadcastSortBuffer, true);
+            updateStatistics(bufferWithChannel.getBuffer(), isBroadcast);
+            toWrite.add(compressBufferIfPossible(bufferWithChannel));
+        } while (true);
     }
 
-    private void flushUnicastSortBuffer() throws IOException {
-        flushSortBuffer(unicastSortBuffer, false);
+    private void flushBroadcastDataBuffer() throws IOException {
+        if (broadcastDataBuffer != null) {
+            broadcastDataBuffer.finish();
+            flushDataBuffer(broadcastDataBuffer, true);
+            broadcastDataBuffer.release();
+            broadcastDataBuffer = null;
+        }
     }
 
-    private Queue<MemorySegment> getWriteSegments() {
-        checkState(!writeSegments.isEmpty(), "Task has been canceled.");
-        return new ArrayDeque<>(writeSegments);
+    private void flushUnicastDataBuffer() throws IOException {
+        if (unicastDataBuffer != null) {
+            unicastDataBuffer.finish();
+            flushDataBuffer(unicastDataBuffer, false);
+            unicastDataBuffer.release();
+            unicastDataBuffer = null;
+        }
     }
 
     private BufferWithChannel compressBufferIfPossible(BufferWithChannel bufferWithChannel) {
@@ -383,16 +410,19 @@ public class SortMergeResultPartition extends ResultPartition {
     private void writeLargeRecord(
             ByteBuffer record, int targetSubpartition, DataType dataType, boolean isBroadcast)
             throws IOException {
+        // for the hash-based data buffer implementation, a large record will be appended to the
+        // data buffer directly and spilled to multiple data regions
+        checkState(!useHashBuffer, "No buffers available for writing.");
         fileWriter.startNewRegion(isBroadcast);
 
         List<BufferWithChannel> toWrite = new ArrayList<>();
-        Queue<MemorySegment> segments = getWriteSegments();
+        Queue<MemorySegment> segments = new ArrayDeque<>(writeSegments);
 
         while (record.hasRemaining()) {
             if (segments.isEmpty()) {
                 fileWriter.writeBuffers(toWrite);
                 toWrite.clear();
-                segments = getWriteSegments();
+                segments = new ArrayDeque<>(writeSegments);
             }
 
             int toCopy = Math.min(record.remaining(), networkBufferSize);
@@ -408,6 +438,12 @@ public class SortMergeResultPartition extends ResultPartition {
         fileWriter.writeBuffers(toWrite);
     }
 
+    private void writeBuffers(List<BufferWithChannel> buffers) throws IOException {
+        fileWriter.writeBuffers(buffers);
+        buffers.forEach(buffer -> buffer.getBuffer().recycleBuffer());
+        buffers.clear();
+    }
+
     @Override
     public void notifyEndOfData(StopMode mode) throws IOException {
         if (!hasNotifiedEndOfUserRecords) {
@@ -420,9 +456,9 @@ public class SortMergeResultPartition extends ResultPartition {
     public void finish() throws IOException {
         broadcastEvent(EndOfPartitionEvent.INSTANCE, false);
         checkState(
-                unicastSortBuffer == null || unicastSortBuffer.isReleased(),
+                unicastDataBuffer == null,
                 "The unicast sort buffer should be either null or released.");
-        flushBroadcastSortBuffer();
+        flushBroadcastDataBuffer();
 
         synchronized (lock) {
             checkState(!isReleased(), "Result partition is already released.");
@@ -447,8 +483,8 @@ public class SortMergeResultPartition extends ResultPartition {
         releaseWriteBuffers();
         // the close method will be always called by the task thread, so there is need to make
         // the sort buffer fields volatile and visible to the cancel thread intermediately
-        releaseSortBuffer(unicastSortBuffer);
-        releaseSortBuffer(broadcastSortBuffer);
+        releaseDataBuffer(unicastDataBuffer);
+        releaseDataBuffer(broadcastDataBuffer);
         super.close();
 
         IOUtils.closeQuietly(fileWriter);
@@ -475,8 +511,8 @@ public class SortMergeResultPartition extends ResultPartition {
     @Override
     public void flushAll() {
         try {
-            flushUnicastSortBuffer();
-            flushBroadcastSortBuffer();
+            flushUnicastDataBuffer();
+            flushBroadcastDataBuffer();
         } catch (IOException e) {
             LOG.error("Failed to flush the current sort buffer.", e);
         }
@@ -485,8 +521,8 @@ public class SortMergeResultPartition extends ResultPartition {
     @Override
     public void flush(int subpartitionIndex) {
         try {
-            flushUnicastSortBuffer();
-            flushBroadcastSortBuffer();
+            flushUnicastDataBuffer();
+            flushBroadcastDataBuffer();
         } catch (IOException e) {
             LOG.error("Failed to flush the current sort buffer.", e);
         }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionSortedBufferTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java
similarity index 56%
rename from flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionSortedBufferTest.java
rename to flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java
index fa9c27c..a6b8b6c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionSortedBufferTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/DataBufferTest.java
@@ -22,9 +22,12 @@ import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -38,16 +41,29 @@ import java.util.Random;
 import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
-/** Tests for {@link PartitionSortedBuffer}. */
-public class PartitionSortedBufferTest {
+/** Tests for {@link SortBasedDataBuffer} and {@link HashBasedDataBuffer}. */
+@RunWith(Parameterized.class)
+public class DataBufferTest {
+
+    private final boolean useHashBuffer;
+
+    @Parameterized.Parameters(name = "UseHashBuffer = {0}")
+    public static Object[] parameters() {
+        return new Object[] {true, false};
+    }
+
+    public DataBufferTest(boolean useHashBuffer) {
+        this.useHashBuffer = useHashBuffer;
+    }
 
     @Test
-    public void testWriteAndReadSortBuffer() throws Exception {
+    public void testWriteAndReadDataBuffer() throws Exception {
         int numSubpartitions = 10;
         int bufferSize = 1024;
-        int bufferPoolSize = 1000;
+        int bufferPoolSize = 512;
         Random random = new Random(1111);
 
         // used to store data written to and read from sort buffer for correctness check
@@ -65,13 +81,14 @@ public class PartitionSortedBufferTest {
 
         // fill the sort buffer with randomly generated data
         int totalBytesWritten = 0;
-        SortBuffer sortBuffer =
-                createSortBuffer(
+        DataBuffer dataBuffer =
+                createDataBuffer(
                         bufferPoolSize,
                         bufferSize,
                         numSubpartitions,
                         getRandomSubpartitionOrder(numSubpartitions));
-        while (true) {
+        int numDataBuffers = 5;
+        while (numDataBuffers > 0) {
             // record size may be larger than buffer size so a record may span multiple segments
             int recordSize = random.nextInt(bufferSize * 4 - 1) + 1;
             byte[] bytes = new byte[recordSize];
@@ -86,30 +103,71 @@ public class PartitionSortedBufferTest {
             // select a random data type
             boolean isBuffer = random.nextBoolean();
             DataType dataType = isBuffer ? DataType.DATA_BUFFER : DataType.EVENT_BUFFER;
-            if (!sortBuffer.append(record, subpartition, dataType)) {
-                sortBuffer.finish();
-                break;
+            boolean isFull = dataBuffer.append(record, subpartition, dataType);
+
+            record.flip();
+            if (record.hasRemaining()) {
+                dataWritten[subpartition].add(new DataAndType(record, dataType));
+                numBytesWritten[subpartition] += record.remaining();
+                totalBytesWritten += record.remaining();
+            }
+
+            while (isFull && dataBuffer.hasRemaining()) {
+                BufferWithChannel buffer = copyIntoSegment(bufferSize, dataBuffer);
+                if (buffer == null) {
+                    break;
+                }
+                addBufferRead(buffer, buffersRead, numBytesRead);
+            }
+
+            if (isFull) {
+                --numDataBuffers;
+                dataBuffer.reset();
             }
-            record.rewind();
-            dataWritten[subpartition].add(new DataAndType(record, dataType));
-            numBytesWritten[subpartition] += recordSize;
-            totalBytesWritten += recordSize;
         }
 
         // read all data from the sort buffer
-        while (sortBuffer.hasRemaining()) {
-            MemorySegment readBuffer = MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
-            BufferWithChannel bufferAndChannel = sortBuffer.copyIntoSegment(readBuffer);
-            int subpartition = bufferAndChannel.getChannelIndex();
-            buffersRead[subpartition].add(bufferAndChannel.getBuffer());
-            numBytesRead[subpartition] += bufferAndChannel.getBuffer().readableBytes();
+        if (dataBuffer.hasRemaining()) {
+            assertTrue(dataBuffer instanceof HashBasedDataBuffer);
+            dataBuffer.reset();
+            dataBuffer.finish();
+            while (dataBuffer.hasRemaining()) {
+                addBufferRead(copyIntoSegment(bufferSize, dataBuffer), buffersRead, numBytesRead);
+            }
         }
 
-        assertEquals(totalBytesWritten, sortBuffer.numBytes());
+        assertEquals(totalBytesWritten, dataBuffer.numTotalBytes());
         checkWriteReadResult(
                 numSubpartitions, numBytesWritten, numBytesRead, dataWritten, buffersRead);
     }
 
+    private BufferWithChannel copyIntoSegment(int bufferSize, DataBuffer dataBuffer) {
+        if (useHashBuffer) {
+            BufferWithChannel buffer = dataBuffer.getNextBuffer(null);
+            if (buffer == null || !buffer.getBuffer().isBuffer()) {
+                return buffer;
+            }
+
+            MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
+            int numBytes = buffer.getBuffer().readableBytes();
+            segment.put(0, buffer.getBuffer().getNioBufferReadable(), numBytes);
+            buffer.getBuffer().recycleBuffer();
+            return new BufferWithChannel(
+                    new NetworkBuffer(segment, MemorySegment::free, DataType.DATA_BUFFER, numBytes),
+                    buffer.getChannelIndex());
+        } else {
+            MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
+            return dataBuffer.getNextBuffer(segment);
+        }
+    }
+
+    private void addBufferRead(
+            BufferWithChannel buffer, Queue<Buffer>[] buffersRead, int[] numBytesRead) {
+        int channel = buffer.getChannelIndex();
+        buffersRead[channel].add(buffer.getBuffer());
+        numBytesRead[channel] += buffer.getBuffer().readableBytes();
+    }
+
     public static void checkWriteReadResult(
             int numSubpartitions,
             int[] numBytesWritten,
@@ -166,33 +224,33 @@ public class PartitionSortedBufferTest {
             ByteBuffer.allocate(1024)
         };
 
-        SortBuffer sortBuffer = createSortBuffer(bufferPoolSize, bufferSize, numSubpartitions);
+        DataBuffer dataBuffer = createDataBuffer(bufferPoolSize, bufferSize, numSubpartitions);
         for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) {
             ByteBuffer record = subpartitionRecords[subpartition];
             if (record != null) {
-                sortBuffer.append(record, subpartition, Buffer.DataType.DATA_BUFFER);
+                dataBuffer.append(record, subpartition, Buffer.DataType.DATA_BUFFER);
                 record.rewind();
             }
         }
-        sortBuffer.finish();
+        dataBuffer.finish();
 
-        checkReadResult(sortBuffer, subpartitionRecords[0], 0, bufferSize);
+        checkReadResult(dataBuffer, subpartitionRecords[0], 0, bufferSize);
 
         ByteBuffer expected1 = subpartitionRecords[2].duplicate();
         expected1.limit(bufferSize);
-        checkReadResult(sortBuffer, expected1.slice(), 2, bufferSize);
+        checkReadResult(dataBuffer, expected1.slice(), 2, bufferSize);
 
         ByteBuffer expected2 = subpartitionRecords[2].duplicate();
         expected2.position(bufferSize);
-        checkReadResult(sortBuffer, expected2.slice(), 2, bufferSize);
+        checkReadResult(dataBuffer, expected2.slice(), 2, bufferSize);
 
-        checkReadResult(sortBuffer, subpartitionRecords[4], 4, bufferSize);
+        checkReadResult(dataBuffer, subpartitionRecords[4], 4, bufferSize);
     }
 
     private void checkReadResult(
-            SortBuffer sortBuffer, ByteBuffer expectedBuffer, int expectedChannel, int bufferSize) {
+            DataBuffer dataBuffer, ByteBuffer expectedBuffer, int expectedChannel, int bufferSize) {
         MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
-        BufferWithChannel bufferWithChannel = sortBuffer.copyIntoSegment(segment);
+        BufferWithChannel bufferWithChannel = dataBuffer.getNextBuffer(segment);
         assertEquals(expectedChannel, bufferWithChannel.getChannelIndex());
         assertEquals(expectedBuffer, bufferWithChannel.getBuffer().getNioBufferReadable());
     }
@@ -201,32 +259,32 @@ public class PartitionSortedBufferTest {
     public void testWriteEmptyData() throws Exception {
         int bufferSize = 1024;
 
-        SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
+        DataBuffer dataBuffer = createDataBuffer(1, bufferSize, 1);
 
         ByteBuffer record = ByteBuffer.allocate(1);
         record.position(1);
 
-        sortBuffer.append(record, 0, Buffer.DataType.DATA_BUFFER);
+        dataBuffer.append(record, 0, Buffer.DataType.DATA_BUFFER);
     }
 
     @Test(expected = IllegalStateException.class)
-    public void testWriteFinishedSortBuffer() throws Exception {
+    public void testWriteFinishedDataBuffer() throws Exception {
         int bufferSize = 1024;
 
-        SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
-        sortBuffer.finish();
+        DataBuffer dataBuffer = createDataBuffer(1, bufferSize, 1);
+        dataBuffer.finish();
 
-        sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
+        dataBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
     }
 
     @Test(expected = IllegalStateException.class)
-    public void testWriteReleasedSortBuffer() throws Exception {
+    public void testWriteReleasedDataBuffer() throws Exception {
         int bufferSize = 1024;
 
-        SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
-        sortBuffer.release();
+        DataBuffer dataBuffer = createDataBuffer(1, bufferSize, 1);
+        dataBuffer.release();
 
-        sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
+        dataBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
     }
 
     @Test
@@ -234,16 +292,16 @@ public class PartitionSortedBufferTest {
         int bufferPoolSize = 10;
         int bufferSize = 1024;
 
-        SortBuffer sortBuffer = createSortBuffer(bufferPoolSize, bufferSize, 1);
+        DataBuffer dataBuffer = createDataBuffer(bufferPoolSize, bufferSize, 1);
 
         for (int i = 1; i < bufferPoolSize; ++i) {
-            appendAndCheckResult(sortBuffer, bufferSize, true, bufferSize * i, i, true);
+            appendAndCheckResult(dataBuffer, bufferSize, false, bufferSize * i, i, true);
         }
 
         // append should fail for insufficient capacity
         int numRecords = bufferPoolSize - 1;
-        appendAndCheckResult(
-                sortBuffer, bufferSize, false, bufferSize * numRecords, numRecords, true);
+        long numBytes = useHashBuffer ? bufferSize * bufferPoolSize : bufferSize * numRecords;
+        appendAndCheckResult(dataBuffer, bufferSize + 1, true, numBytes, numRecords, true);
     }
 
     @Test
@@ -251,66 +309,68 @@ public class PartitionSortedBufferTest {
         int bufferPoolSize = 10;
         int bufferSize = 1024;
 
-        SortBuffer sortBuffer = createSortBuffer(bufferPoolSize, bufferSize, 1);
-        // append should fail for insufficient capacity
-        appendAndCheckResult(sortBuffer, bufferPoolSize * bufferSize, false, 0, 0, false);
+        DataBuffer dataBuffer = createDataBuffer(bufferPoolSize, bufferSize, 1);
+        long numBytes = useHashBuffer ? bufferPoolSize * bufferSize : 0;
+        appendAndCheckResult(
+                dataBuffer, bufferPoolSize * bufferSize + 1, true, numBytes, 0, useHashBuffer);
     }
 
     private void appendAndCheckResult(
-            SortBuffer sortBuffer,
+            DataBuffer dataBuffer,
             int recordSize,
-            boolean isSuccessful,
+            boolean isFull,
             long numBytes,
             long numRecords,
             boolean hasRemaining)
             throws IOException {
         ByteBuffer largeRecord = ByteBuffer.allocate(recordSize);
 
-        assertEquals(isSuccessful, sortBuffer.append(largeRecord, 0, Buffer.DataType.DATA_BUFFER));
-        assertEquals(numBytes, sortBuffer.numBytes());
-        assertEquals(numRecords, sortBuffer.numRecords());
-        assertEquals(hasRemaining, sortBuffer.hasRemaining());
+        assertEquals(isFull, dataBuffer.append(largeRecord, 0, Buffer.DataType.DATA_BUFFER));
+        assertEquals(numBytes, dataBuffer.numTotalBytes());
+        assertEquals(numRecords, dataBuffer.numTotalRecords());
+        assertEquals(hasRemaining, dataBuffer.hasRemaining());
     }
 
     @Test(expected = IllegalStateException.class)
-    public void testReadUnfinishedSortBuffer() throws Exception {
+    public void testReadUnfinishedDataBuffer() throws Exception {
         int bufferSize = 1024;
 
-        SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
-        sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
+        DataBuffer dataBuffer = createDataBuffer(1, bufferSize, 1);
+        dataBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
 
-        assertTrue(sortBuffer.hasRemaining());
-        sortBuffer.copyIntoSegment(MemorySegmentFactory.allocateUnpooledSegment(bufferSize));
+        assertTrue(dataBuffer.hasRemaining());
+        dataBuffer.getNextBuffer(MemorySegmentFactory.allocateUnpooledSegment(bufferSize));
     }
 
     @Test(expected = IllegalStateException.class)
-    public void testReadReleasedSortBuffer() throws Exception {
+    public void testReadReleasedDataBuffer() throws Exception {
         int bufferSize = 1024;
 
-        SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
-        sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
-        sortBuffer.finish();
-        assertTrue(sortBuffer.hasRemaining());
+        DataBuffer dataBuffer = createDataBuffer(1, bufferSize, 1);
+        dataBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
+        dataBuffer.finish();
+        assertTrue(dataBuffer.hasRemaining());
 
-        sortBuffer.release();
-        assertFalse(sortBuffer.hasRemaining());
+        dataBuffer.release();
+        assertTrue(dataBuffer.hasRemaining());
 
-        sortBuffer.copyIntoSegment(MemorySegmentFactory.allocateUnpooledSegment(bufferSize));
+        dataBuffer.getNextBuffer(MemorySegmentFactory.allocateUnpooledSegment(bufferSize));
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testReadEmptySortBuffer() throws Exception {
+    @Test
+    public void testReadEmptyDataBuffer() throws Exception {
         int bufferSize = 1024;
 
-        SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
-        sortBuffer.finish();
+        DataBuffer dataBuffer = createDataBuffer(1, bufferSize, 1);
+        dataBuffer.finish();
 
-        assertFalse(sortBuffer.hasRemaining());
-        sortBuffer.copyIntoSegment(MemorySegmentFactory.allocateUnpooledSegment(bufferSize));
+        assertFalse(dataBuffer.hasRemaining());
+        assertNull(
+                dataBuffer.getNextBuffer(MemorySegmentFactory.allocateUnpooledSegment(bufferSize)));
     }
 
     @Test
-    public void testReleaseSortBuffer() throws Exception {
+    public void testReleaseDataBuffer() throws Exception {
         int bufferPoolSize = 10;
         int bufferSize = 1024;
         int recordSize = (bufferPoolSize - 1) * bufferSize;
@@ -318,36 +378,41 @@ public class PartitionSortedBufferTest {
         NetworkBufferPool globalPool = new NetworkBufferPool(bufferPoolSize, bufferSize);
         BufferPool bufferPool = globalPool.createBufferPool(bufferPoolSize, bufferPoolSize);
 
-        SortBuffer sortBuffer =
-                new PartitionSortedBuffer(bufferPool, 1, bufferSize, bufferPoolSize, null);
-        sortBuffer.append(ByteBuffer.allocate(recordSize), 0, Buffer.DataType.DATA_BUFFER);
+        DataBuffer dataBuffer =
+                new SortBasedDataBuffer(bufferPool, 1, bufferSize, bufferPoolSize, null);
+        dataBuffer.append(ByteBuffer.allocate(recordSize), 0, Buffer.DataType.DATA_BUFFER);
 
         assertEquals(bufferPoolSize, bufferPool.bestEffortGetNumOfUsedBuffers());
-        assertTrue(sortBuffer.hasRemaining());
-        assertEquals(1, sortBuffer.numRecords());
-        assertEquals(recordSize, sortBuffer.numBytes());
+        assertTrue(dataBuffer.hasRemaining());
+        assertEquals(1, dataBuffer.numTotalRecords());
+        assertEquals(recordSize, dataBuffer.numTotalBytes());
 
         // should release all data and resources
-        sortBuffer.release();
+        dataBuffer.release();
         assertEquals(0, bufferPool.bestEffortGetNumOfUsedBuffers());
-        assertFalse(sortBuffer.hasRemaining());
-        assertEquals(0, sortBuffer.numRecords());
-        assertEquals(0, sortBuffer.numBytes());
+        assertTrue(dataBuffer.hasRemaining());
+        assertEquals(1, dataBuffer.numTotalRecords());
+        assertEquals(recordSize, dataBuffer.numTotalBytes());
     }
 
-    private SortBuffer createSortBuffer(int bufferPoolSize, int bufferSize, int numSubpartitions)
+    private DataBuffer createDataBuffer(int bufferPoolSize, int bufferSize, int numSubpartitions)
             throws IOException {
-        return createSortBuffer(bufferPoolSize, bufferSize, numSubpartitions, null);
+        return createDataBuffer(bufferPoolSize, bufferSize, numSubpartitions, null);
     }
 
-    private SortBuffer createSortBuffer(
+    private DataBuffer createDataBuffer(
             int bufferPoolSize, int bufferSize, int numSubpartitions, int[] customReadOrder)
             throws IOException {
         NetworkBufferPool globalPool = new NetworkBufferPool(bufferPoolSize, bufferSize);
         BufferPool bufferPool = globalPool.createBufferPool(bufferPoolSize, bufferPoolSize);
 
-        return new PartitionSortedBuffer(
-                bufferPool, numSubpartitions, bufferSize, bufferPoolSize, customReadOrder);
+        if (useHashBuffer) {
+            return new HashBasedDataBuffer(
+                    bufferPool, numSubpartitions, bufferPoolSize, customReadOrder);
+        } else {
+            return new SortBasedDataBuffer(
+                    bufferPool, numSubpartitions, bufferSize, bufferPoolSize, customReadOrder);
+        }
     }
 
     public static int[] getRandomSubpartitionOrder(int numSubpartitions) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionedFileWriteReadTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionedFileWriteReadTest.java
index 10a7d26..53c1f2c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionedFileWriteReadTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionedFileWriteReadTest.java
@@ -90,8 +90,7 @@ public class PartitionedFileWriteReadTest {
                 }
             }
 
-            int[] writeOrder =
-                    PartitionSortedBufferTest.getRandomSubpartitionOrder(numSubpartitions);
+            int[] writeOrder = DataBufferTest.getRandomSubpartitionOrder(numSubpartitions);
             for (int index = 0; index < numSubpartitions; ++index) {
                 int subpartition = writeOrder[index];
                 fileWriter.writeBuffers(regionBuffers[subpartition]);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
index 143225c..faea851 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
@@ -37,6 +37,8 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import java.io.File;
 import java.io.IOException;
@@ -59,6 +61,7 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
 /** Tests for {@link SortMergeResultPartition}. */
+@RunWith(Parameterized.class)
 public class SortMergeResultPartitionTest extends TestLogger {
 
     private static final int bufferSize = 1024;
@@ -69,6 +72,8 @@ public class SortMergeResultPartitionTest extends TestLogger {
 
     private static final int numThreads = 4;
 
+    private final boolean useHashDataBuffer; // test parameter, see parameters()
+
     private final TestBufferAvailabilityListener listener = new TestBufferAvailabilityListener();
 
     private FileChannelManager fileChannelManager;
@@ -100,10 +105,19 @@ public class SortMergeResultPartitionTest extends TestLogger {
         readIOExecutor.shutdown();
     }
 
+    public SortMergeResultPartitionTest(boolean useHashDataBuffer) {
+        this.useHashDataBuffer = useHashDataBuffer; // supplied by the Parameterized runner
+    }
+
+    @Parameterized.Parameters(name = "UseHashDataBuffer = {0}")
+    public static Object[] parameters() {
+        return new Object[] {true, false}; // run each test with both values
+    }
+
     @Test
     public void testWriteAndRead() throws Exception {
+        int numBuffers = useHashDataBuffer ? 100 : 15;
         int numSubpartitions = 10;
-        int numBuffers = 100;
         int numRecords = 1000;
         Random random = new Random();
 
@@ -111,7 +125,7 @@ public class SortMergeResultPartitionTest extends TestLogger {
         SortMergeResultPartition partition =
                 createSortMergedPartition(numSubpartitions, bufferPool);
 
-        Queue<PartitionSortedBufferTest.DataAndType>[] dataWritten = new Queue[numSubpartitions];
+        Queue<DataBufferTest.DataAndType>[] dataWritten = new Queue[numSubpartitions];
         Queue<Buffer>[] buffersRead = new Queue[numSubpartitions];
         for (int i = 0; i < numSubpartitions; ++i) {
             dataWritten[i] = new ArrayDeque<>();
@@ -169,18 +183,18 @@ public class SortMergeResultPartitionTest extends TestLogger {
                             new NetworkBuffer(
                                     segment, (buf) -> {}, buffer.getDataType(), numBytes));
                 });
-        PartitionSortedBufferTest.checkWriteReadResult(
+        DataBufferTest.checkWriteReadResult(
                 numSubpartitions, numBytesWritten, numBytesRead, dataWritten, buffersRead);
     }
 
     private void recordDataWritten(
             ByteBuffer record,
-            Queue<PartitionSortedBufferTest.DataAndType>[] dataWritten,
+            Queue<DataBufferTest.DataAndType>[] dataWritten, // written records, by subpartition
             int subpartition,
             int[] numBytesWritten,
             Buffer.DataType dataType) {
         record.rewind();
-        dataWritten[subpartition].add(new PartitionSortedBufferTest.DataAndType(record, dataType));
+        dataWritten[subpartition].add(new DataBufferTest.DataAndType(record, dataType));
         numBytesWritten[subpartition] += record.remaining();
     }
 
@@ -231,14 +245,16 @@ public class SortMergeResultPartitionTest extends TestLogger {
 
     @Test
     public void testWriteLargeRecord() throws Exception {
-        int numBuffers = 100;
-        int numWriteBuffers = numBuffers / 2;
+        int numBuffers = useHashDataBuffer ? 100 : 15;
+        int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
         SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
 
         ByteBuffer recordWritten = generateRandomData(bufferSize * numBuffers, new Random());
         partition.emitRecord(recordWritten, 0);
-        assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
+        assertEquals(
+                useHashDataBuffer ? numBuffers : numWriteBuffers,
+                bufferPool.bestEffortGetNumOfUsedBuffers());
 
         partition.finish();
         partition.close();
@@ -261,7 +277,7 @@ public class SortMergeResultPartitionTest extends TestLogger {
     @Test
     public void testDataBroadcast() throws Exception {
         int numSubpartitions = 10;
-        int numBuffers = 100;
+        int numBuffers = useHashDataBuffer ? 100 : 15;
         int numRecords = 10000;
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
@@ -292,22 +308,26 @@ public class SortMergeResultPartitionTest extends TestLogger {
 
     @Test
     public void testFlush() throws Exception {
-        int numBuffers = 10;
-        int numWriteBuffers = numBuffers / 2;
+        int numBuffers = useHashDataBuffer ? 100 : 15;
+        int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
         SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
         assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
 
         partition.emitRecord(ByteBuffer.allocate(bufferSize), 0);
         partition.emitRecord(ByteBuffer.allocate(bufferSize), 1);
-        assertEquals(3 + numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
+        assertEquals(
+                (useHashDataBuffer ? 2 : 3) + numWriteBuffers,
+                bufferPool.bestEffortGetNumOfUsedBuffers());
 
         partition.flush(0);
         assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
 
         partition.emitRecord(ByteBuffer.allocate(bufferSize), 2);
         partition.emitRecord(ByteBuffer.allocate(bufferSize), 3);
-        assertEquals(3 + numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
+        assertEquals(
+                (useHashDataBuffer ? 2 : 3) + numWriteBuffers,
+                bufferPool.bestEffortGetNumOfUsedBuffers());
 
         partition.flushAll();
         assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
@@ -321,8 +341,8 @@ public class SortMergeResultPartitionTest extends TestLogger {
 
     @Test(expected = IllegalStateException.class)
     public void testReleaseWhileWriting() throws Exception {
-        int numBuffers = 10;
-        int numWriteBuffers = numBuffers / 2;
+        int numBuffers = useHashDataBuffer ? 100 : 15;
+        int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
         int numBuffersForSort = numBuffers - numWriteBuffers;
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
@@ -348,8 +368,8 @@ public class SortMergeResultPartitionTest extends TestLogger {
 
     @Test
     public void testRelease() throws Exception {
-        int numBuffers = 10;
-        int numWriteBuffers = numBuffers / 2;
+        int numBuffers = useHashDataBuffer ? 100 : 15;
+        int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
         int numBuffersForSort = numBuffers - numWriteBuffers;
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
@@ -383,8 +403,8 @@ public class SortMergeResultPartitionTest extends TestLogger {
 
     @Test
     public void testCloseReleasesAllBuffers() throws Exception {
-        int numBuffers = 100;
-        int numWriteBuffers = numBuffers / 2;
+        int numBuffers = useHashDataBuffer ? 100 : 15;
+        int numWriteBuffers = useHashDataBuffer ? 0 : numBuffers / 2;
         int numBuffersForSort = numBuffers - numWriteBuffers;
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
@@ -392,7 +412,9 @@ public class SortMergeResultPartitionTest extends TestLogger {
         assertEquals(numWriteBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
 
         partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffersForSort - 1)), 5);
-        assertEquals(numBuffers, bufferPool.bestEffortGetNumOfUsedBuffers());
+        assertEquals(
+                useHashDataBuffer ? numBuffers - 1 : numBuffers,
+                bufferPool.bestEffortGetNumOfUsedBuffers());
 
         partition.close();
         assertTrue(bufferPool.isDestroyed());
@@ -434,10 +456,10 @@ public class SortMergeResultPartitionTest extends TestLogger {
     }
 
     private void testNumBytesProducedCounter(boolean isBroadcast) throws IOException {
-
+        int numBuffers = useHashDataBuffer ? 100 : 15;
         int numSubpartitions = 10;
 
-        BufferPool bufferPool = globalPool.createBufferPool(10, 10);
+        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
         SortMergeResultPartition partition =
                 createSortMergedPartition(numSubpartitions, bufferPool);