You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by yi...@apache.org on 2022/08/02 06:50:12 UTC

[flink] branch master updated (70e8d7e251c -> 47d0b6d26c0)

This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


    from 70e8d7e251c [FLINK-27769][sql-gateway]Introduce the REST endpoint framework
     new 10c20837126 [hotfix] Migrate CreditBasedPartitionRequestClientHandlerTest, NettyMessageClientSideSerializationTest, SingleInputGateTest and BlockCompressionTest to Junit5/AssertJ
     new 47d0b6d26c0 [FLINK-28382][network] Introduce LZO and ZSTD compression based on aircompressor for blocking shuffle

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../generated/all_taskmanager_network_section.html |   6 +
 .../netty_shuffle_environment_configuration.html   |   6 +
 .../NettyShuffleEnvironmentOptions.java            |  14 +-
 flink-runtime/pom.xml                              |  38 +-
 ...lockCompressor.java => AirBlockCompressor.java} |  76 ++--
 .../io/compression/AirBlockDecompressor.java       | 103 +++++
 ...ssionFactory.java => AirCompressorFactory.java} |  25 +-
 .../io/compression/BlockCompressionFactory.java    |  19 +-
 .../runtime/io/compression/BlockCompressor.java    |   8 +-
 .../runtime/io/compression/BlockDecompressor.java  |  10 +-
 ...eption.java => BufferCompressionException.java} |  14 +-
 ...tion.java => BufferDecompressionException.java} |  15 +-
 ...ompressionFactory.java => CompressorUtils.java} |  31 +-
 .../io/compression/Lz4BlockCompressionFactory.java |   7 -
 .../runtime/io/compression/Lz4BlockCompressor.java |  24 +-
 .../io/compression/Lz4BlockDecompressor.java       |  40 +-
 .../io/network/buffer/BufferCompressor.java        |  47 ++-
 .../io/network/buffer/BufferDecompressor.java      |  40 +-
 .../src/main/resources/META-INF/NOTICE             |   4 +-
 .../io/compression/BlockCompressionTest.java       |  71 ++--
 .../io/network/buffer/BufferCompressionTest.java   |  12 +
 ...editBasedPartitionRequestClientHandlerTest.java | 230 +++++------
 .../NettyMessageClientSideSerializationTest.java   |  79 ++--
 .../partition/consumer/SingleInputGateTest.java    | 434 +++++++++++----------
 .../io/CompressedHeaderlessChannelTest.java        |  15 +-
 25 files changed, 807 insertions(+), 561 deletions(-)
 copy flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/{Lz4BlockCompressor.java => AirBlockCompressor.java} (51%)
 create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/AirBlockDecompressor.java
 copy flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/{Lz4BlockCompressionFactory.java => AirCompressorFactory.java} (58%)
 rename flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/{DataCorruptionException.java => BufferCompressionException.java} (68%)
 rename flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/{InsufficientBufferException.java => BufferDecompressionException.java} (65%)
 copy flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/{Lz4BlockCompressionFactory.java => CompressorUtils.java} (54%)
 copy {flink-connectors/flink-sql-connector-kafka => flink-runtime}/src/main/resources/META-INF/NOTICE (81%)


[flink] 01/02: [hotfix] Migrate CreditBasedPartitionRequestClientHandlerTest, NettyMessageClientSideSerializationTest, SingleInputGateTest and BlockCompressionTest to Junit5/AssertJ

Posted by yi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 10c20837126d7e50c4b2c22678a6853193ae7534
Author: Weijie Guo <re...@163.com>
AuthorDate: Fri Jul 8 01:00:33 2022 +0800

    [hotfix] Migrate CreditBasedPartitionRequestClientHandlerTest, NettyMessageClientSideSerializationTest, SingleInputGateTest and BlockCompressionTest to Junit5/AssertJ
    
    This closes #20216.
---
 .../io/compression/BlockCompressionTest.java       |  57 +--
 ...editBasedPartitionRequestClientHandlerTest.java | 224 +++++------
 .../NettyMessageClientSideSerializationTest.java   |  60 +--
 .../partition/consumer/SingleInputGateTest.java    | 428 +++++++++++----------
 4 files changed, 384 insertions(+), 385 deletions(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java
index 1f57ce2ad19..fd8a05db6f9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java
@@ -18,19 +18,19 @@
 
 package org.apache.flink.runtime.io.compression;
 
-import org.junit.Assert;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.nio.ByteBuffer;
 
 import static org.apache.flink.runtime.io.compression.Lz4BlockCompressionFactory.HEADER_LENGTH;
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Tests for block compression. */
-public class BlockCompressionTest {
+class BlockCompressionTest {
 
     @Test
-    public void testLz4() {
+    void testLz4() {
         BlockCompressionFactory factory = new Lz4BlockCompressionFactory();
         runArrayTest(factory, 32768);
         runArrayTest(factory, 16);
@@ -54,12 +54,16 @@ public class BlockCompressionTest {
         int compressedOff = 32;
 
         // 1. test compress with insufficient target
-        byte[] insufficientArray = new byte[compressedOff + HEADER_LENGTH + 1];
-        try {
-            compressor.compress(data, originalOff, originalLen, insufficientArray, compressedOff);
-            Assert.fail("expect exception here");
-        } catch (InsufficientBufferException ex) {
-        }
+        byte[] insufficientCompressArray = new byte[compressedOff + HEADER_LENGTH + 1];
+        assertThatThrownBy(
+                        () ->
+                                compressor.compress(
+                                        data,
+                                        originalOff,
+                                        originalLen,
+                                        insufficientCompressArray,
+                                        compressedOff))
+                .isInstanceOf(InsufficientBufferException.class);
 
         // 2. test normal compress
         byte[] compressedData =
@@ -70,17 +74,16 @@ public class BlockCompressionTest {
         int decompressedOff = 16;
 
         // 3. test decompress with insufficient target
-        insufficientArray = new byte[decompressedOff + originalLen - 1];
-        try {
-            decompressor.decompress(
-                    compressedData,
-                    compressedOff,
-                    compressedLen,
-                    insufficientArray,
-                    decompressedOff);
-            Assert.fail("expect exception here");
-        } catch (InsufficientBufferException ex) {
-        }
+        byte[] insufficientDecompressArray = new byte[decompressedOff + originalLen - 1];
+        assertThatThrownBy(
+                        () ->
+                                decompressor.decompress(
+                                        compressedData,
+                                        compressedOff,
+                                        compressedLen,
+                                        insufficientDecompressArray,
+                                        decompressedOff))
+                .isInstanceOf(InsufficientBufferException.class);
 
         // 4. test normal decompress
         byte[] decompressedData = new byte[decompressedOff + originalLen];
@@ -91,10 +94,10 @@ public class BlockCompressionTest {
                         compressedLen,
                         decompressedData,
                         decompressedOff);
-        assertEquals(originalLen, decompressedLen);
+        assertThat(decompressedLen).isEqualTo(originalLen);
 
         for (int i = 0; i < originalLen; i++) {
-            assertEquals(data[originalOff + i], decompressedData[decompressedOff + i]);
+            assertThat(decompressedData[decompressedOff + i]).isEqualTo(data[originalOff + i]);
         }
     }
 
@@ -129,7 +132,7 @@ public class BlockCompressionTest {
             compressedData = ByteBuffer.allocate(maxCompressedLen);
         }
         int compressedLen = compressor.compress(data, originalOff, originalLen, compressedData, 0);
-        assertEquals(compressedLen, compressedData.position());
+        assertThat(compressedData.position()).isEqualTo(compressedLen);
         compressedData.flip();
 
         int compressedOff = 32;
@@ -159,11 +162,11 @@ public class BlockCompressionTest {
         int decompressedLen =
                 decompressor.decompress(
                         copiedCompressedData, compressedOff, compressedLen, decompressedData, 0);
-        assertEquals(decompressedLen, decompressedData.position());
+        assertThat(decompressedData.position()).isEqualTo(decompressedLen);
         decompressedData.flip();
 
         for (int i = 0; i < decompressedLen; i++) {
-            assertEquals((byte) i, decompressedData.get());
+            assertThat(decompressedData.get()).isEqualTo((byte) i);
         }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
index f3074f9682b..8ad892b1d7a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
@@ -56,24 +56,17 @@ import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
 import org.apache.flink.shaded.netty4.io.netty.channel.epoll.Epoll;
 import org.apache.flink.shaded.netty4.io.netty.channel.unix.Errors;
 
-import org.junit.Assume;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
 
 import java.io.IOException;
 
 import static org.apache.flink.runtime.io.network.netty.PartitionRequestQueueTest.blockChannel;
 import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createRemoteInputChannel;
 import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.instanceOf;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertSame;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.Assumptions.assumeThat;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -83,7 +76,7 @@ import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 /** Test for {@link CreditBasedPartitionRequestClientHandler}. */
-public class CreditBasedPartitionRequestClientHandlerTest {
+class CreditBasedPartitionRequestClientHandlerTest {
 
     /**
      * Tests a fix for FLINK-1627.
@@ -96,9 +89,10 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      *
      * @see <a href="https://issues.apache.org/jira/browse/FLINK-1627">FLINK-1627</a>
      */
-    @Test(timeout = 60000)
+    @Test
+    @Timeout(60)
     @SuppressWarnings("unchecked")
-    public void testReleaseInputChannelDuringDecode() throws Exception {
+    void testReleaseInputChannelDuringDecode() throws Exception {
         // Mocks an input channel in a state as it was released during a decode.
         final BufferProvider bufferProvider = mock(BufferProvider.class);
         when(bufferProvider.requestBuffer()).thenReturn(null);
@@ -130,7 +124,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * <p>FLINK-1761 discovered an IndexOutOfBoundsException, when receiving buffers of size 0.
      */
     @Test
-    public void testReceiveEmptyBuffer() throws Exception {
+    void testReceiveEmptyBuffer() throws Exception {
         // Minimal mock of a remote input channel
         final BufferProvider bufferProvider = mock(BufferProvider.class);
         when(bufferProvider.requestBuffer()).thenReturn(TestBufferFactory.createBuffer(0));
@@ -168,7 +162,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * BufferResponse} is received.
      */
     @Test
-    public void testReceiveBuffer() throws Exception {
+    void testReceiveBuffer() throws Exception {
         final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32);
         final SingleInputGate inputGate = createSingleInputGate(1, networkBufferPool);
         final RemoteInputChannel inputChannel =
@@ -193,8 +187,8 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                             new NetworkBufferAllocator(handler));
             handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse);
 
-            assertEquals(1, inputChannel.getNumberOfQueuedBuffers());
-            assertEquals(2, inputChannel.getSenderBacklog());
+            assertThat(inputChannel.getNumberOfQueuedBuffers()).isEqualTo(1);
+            assertThat(inputChannel.getSenderBacklog()).isEqualTo(2);
         } finally {
             releaseResource(inputGate, networkBufferPool);
         }
@@ -204,7 +198,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * Verifies that {@link BufferResponse} of compressed {@link Buffer} can be handled correctly.
      */
     @Test
-    public void testReceiveCompressedBuffer() throws Exception {
+    void testReceiveCompressedBuffer() throws Exception {
         int bufferSize = 1024;
         String compressionCodec = "LZ4";
         BufferCompressor compressor = new BufferCompressor(bufferSize, compressionCodec);
@@ -236,12 +230,12 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                             inputChannel.getInputChannelId(),
                             2,
                             new NetworkBufferAllocator(handler));
-            assertTrue(bufferResponse.isCompressed);
+            assertThat(bufferResponse.isCompressed).isTrue();
             handler.channelRead(null, bufferResponse);
 
             Buffer receivedBuffer = inputChannel.getNextReceivedBuffer();
-            assertNotNull(receivedBuffer);
-            assertTrue(receivedBuffer.isCompressed());
+            assertThat(receivedBuffer).isNotNull();
+            assertThat(receivedBuffer.isCompressed()).isTrue();
             receivedBuffer.recycleBuffer();
         } finally {
             releaseResource(inputGate, networkBufferPool);
@@ -250,7 +244,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 
     /** Verifies that {@link NettyMessage.BacklogAnnouncement} can be handled correctly. */
     @Test
-    public void testReceiveBacklogAnnouncement() throws Exception {
+    void testReceiveBacklogAnnouncement() throws Exception {
         int bufferSize = 1024;
         int numBuffers = 10;
         NetworkBufferPool networkBufferPool = new NetworkBufferPool(numBuffers, bufferSize);
@@ -268,26 +262,26 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                     new CreditBasedPartitionRequestClientHandler();
             handler.addInputChannel(inputChannel);
 
-            assertEquals(2, inputChannel.getNumberOfAvailableBuffers());
-            assertEquals(0, inputChannel.unsynchronizedGetFloatingBuffersAvailable());
+            assertThat(inputChannel.getNumberOfAvailableBuffers()).isEqualTo(2);
+            assertThat(inputChannel.unsynchronizedGetFloatingBuffersAvailable()).isZero();
 
             int backlog = 5;
             NettyMessage.BacklogAnnouncement announcement =
                     new NettyMessage.BacklogAnnouncement(backlog, inputChannel.getInputChannelId());
             handler.channelRead(null, announcement);
-            assertEquals(7, inputChannel.getNumberOfAvailableBuffers());
-            assertEquals(7, inputChannel.getNumberOfRequiredBuffers());
-            assertEquals(backlog, inputChannel.getSenderBacklog());
-            assertEquals(5, inputChannel.unsynchronizedGetFloatingBuffersAvailable());
+            assertThat(inputChannel.getNumberOfAvailableBuffers()).isEqualTo(7);
+            assertThat(inputChannel.getNumberOfRequiredBuffers()).isEqualTo(7);
+            assertThat(inputChannel.getSenderBacklog()).isEqualTo(backlog);
+            assertThat(inputChannel.unsynchronizedGetFloatingBuffersAvailable()).isEqualTo(5);
 
             backlog = 12;
             announcement =
                     new NettyMessage.BacklogAnnouncement(backlog, inputChannel.getInputChannelId());
             handler.channelRead(null, announcement);
-            assertEquals(10, inputChannel.getNumberOfAvailableBuffers());
-            assertEquals(14, inputChannel.getNumberOfRequiredBuffers());
-            assertEquals(backlog, inputChannel.getSenderBacklog());
-            assertEquals(8, inputChannel.unsynchronizedGetFloatingBuffersAvailable());
+            assertThat(inputChannel.getNumberOfAvailableBuffers()).isEqualTo(10);
+            assertThat(inputChannel.getNumberOfRequiredBuffers()).isEqualTo(14);
+            assertThat(inputChannel.getSenderBacklog()).isEqualTo(backlog);
+            assertThat(inputChannel.unsynchronizedGetFloatingBuffersAvailable()).isEqualTo(8);
         } finally {
             releaseResource(inputGate, networkBufferPool);
         }
@@ -298,7 +292,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * BufferResponse} is received but no available buffer in input channel.
      */
     @Test
-    public void testThrowExceptionForNoAvailableBuffer() throws Exception {
+    void testThrowExceptionForNoAvailableBuffer() throws Exception {
         final SingleInputGate inputGate = createSingleInputGate(1);
         final RemoteInputChannel inputChannel =
                 spy(InputChannelBuilder.newBuilder().buildRemoteChannel(inputGate));
@@ -307,10 +301,9 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                 new CreditBasedPartitionRequestClientHandler();
         handler.addInputChannel(inputChannel);
 
-        assertEquals(
-                "There should be no buffers available in the channel.",
-                0,
-                inputChannel.getNumberOfAvailableBuffers());
+        assertThat(inputChannel.getNumberOfAvailableBuffers())
+                .as("There should be no buffers available in the channel.")
+                .isEqualTo(0);
 
         final BufferResponse bufferResponse =
                 createBufferResponse(
@@ -319,7 +312,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                         inputChannel.getInputChannelId(),
                         2,
                         new NetworkBufferAllocator(handler));
-        assertNull(bufferResponse.getBuffer());
+        assertThat(bufferResponse.getBuffer()).isNull();
 
         handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse);
         verify(inputChannel, times(1)).onError(any(IllegalStateException.class));
@@ -330,7 +323,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * PartitionNotFoundException} is received.
      */
     @Test
-    public void testReceivePartitionNotFoundException() throws Exception {
+    void testReceivePartitionNotFoundException() throws Exception {
         // Minimal mock of a remote input channel
         final BufferProvider bufferProvider = mock(BufferProvider.class);
         when(bufferProvider.requestBuffer()).thenReturn(TestBufferFactory.createBuffer(0));
@@ -360,7 +353,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
     }
 
     @Test
-    public void testCancelBeforeActive() throws Exception {
+    void testCancelBeforeActive() throws Exception {
 
         final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
         when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID());
@@ -382,7 +375,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * changed.
      */
     @Test
-    public void testNotifyCreditAvailable() throws Exception {
+    void testNotifyCreditAvailable() throws Exception {
         final CreditBasedPartitionRequestClientHandler handler =
                 new CreditBasedPartitionRequestClientHandler();
         final NetworkBufferAllocator allocator = new NetworkBufferAllocator(handler);
@@ -409,20 +402,18 @@ public class CreditBasedPartitionRequestClientHandlerTest {
             inputChannels[1].requestSubpartition();
 
             // The two input channels should send partition requests
-            assertTrue(channel.isWritable());
+            assertThat(channel.isWritable()).isTrue();
             Object readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(PartitionRequest.class));
-            assertEquals(
-                    inputChannels[0].getInputChannelId(),
-                    ((PartitionRequest) readFromOutbound).receiverId);
-            assertEquals(2, ((PartitionRequest) readFromOutbound).credit);
+            assertThat(readFromOutbound).isInstanceOf(PartitionRequest.class);
+            assertThat(inputChannels[0].getInputChannelId())
+                    .isEqualTo(((PartitionRequest) readFromOutbound).receiverId);
+            assertThat(((PartitionRequest) readFromOutbound).credit).isEqualTo(2);
 
             readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(PartitionRequest.class));
-            assertEquals(
-                    inputChannels[1].getInputChannelId(),
-                    ((PartitionRequest) readFromOutbound).receiverId);
-            assertEquals(2, ((PartitionRequest) readFromOutbound).credit);
+            assertThat(readFromOutbound).isInstanceOf(PartitionRequest.class);
+            assertThat(inputChannels[1].getInputChannelId())
+                    .isEqualTo(((PartitionRequest) readFromOutbound).receiverId);
+            assertThat(((PartitionRequest) readFromOutbound).credit).isEqualTo(2);
 
             // The buffer response will take one available buffer from input channel, and it will
             // trigger
@@ -444,26 +435,24 @@ public class CreditBasedPartitionRequestClientHandlerTest {
             handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse1);
             handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse2);
 
-            assertEquals(2, inputChannels[0].getUnannouncedCredit());
-            assertEquals(2, inputChannels[1].getUnannouncedCredit());
+            assertThat(inputChannels[0].getUnannouncedCredit()).isEqualTo(2);
+            assertThat(inputChannels[1].getUnannouncedCredit()).isEqualTo(2);
 
             channel.runPendingTasks();
 
             // The two input channels should notify credits availability via the writable channel
             readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(AddCredit.class));
-            assertEquals(
-                    inputChannels[0].getInputChannelId(),
-                    ((AddCredit) readFromOutbound).receiverId);
-            assertEquals(2, ((AddCredit) readFromOutbound).credit);
+            assertThat(readFromOutbound).isInstanceOf(AddCredit.class);
+            assertThat(inputChannels[0].getInputChannelId())
+                    .isEqualTo(((AddCredit) readFromOutbound).receiverId);
+            assertThat(((AddCredit) readFromOutbound).credit).isEqualTo(2);
 
             readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(AddCredit.class));
-            assertEquals(
-                    inputChannels[1].getInputChannelId(),
-                    ((AddCredit) readFromOutbound).receiverId);
-            assertEquals(2, ((AddCredit) readFromOutbound).credit);
-            assertNull(channel.readOutbound());
+            assertThat(readFromOutbound).isInstanceOf(AddCredit.class);
+            assertThat(inputChannels[1].getInputChannelId())
+                    .isEqualTo(((AddCredit) readFromOutbound).receiverId);
+            assertThat(((AddCredit) readFromOutbound).credit).isEqualTo(2);
+            assertThat((Object) channel.readOutbound()).isNull();
 
             ByteBuf channelBlockingBuffer = blockChannel(channel);
 
@@ -478,29 +467,29 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                             allocator);
             handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse3);
 
-            assertEquals(1, inputChannels[0].getUnannouncedCredit());
-            assertEquals(0, inputChannels[1].getUnannouncedCredit());
+            assertThat(inputChannels[0].getUnannouncedCredit()).isEqualTo(1);
+            assertThat(inputChannels[1].getUnannouncedCredit()).isZero();
 
             channel.runPendingTasks();
 
             // The input channel will not notify credits via un-writable channel
-            assertFalse(channel.isWritable());
-            assertNull(channel.readOutbound());
+            assertThat(channel.isWritable()).isFalse();
+            assertThat((Object) channel.readOutbound()).isNull();
 
             // Flush the buffer to make the channel writable again
             channel.flush();
-            assertSame(channelBlockingBuffer, channel.readOutbound());
+            assertThat(channelBlockingBuffer).isSameAs(channel.readOutbound());
 
             // The input channel should notify credits via channel's writability changed event
-            assertTrue(channel.isWritable());
+            assertThat(channel.isWritable()).isTrue();
             readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(AddCredit.class));
-            assertEquals(1, ((AddCredit) readFromOutbound).credit);
-            assertEquals(0, inputChannels[0].getUnannouncedCredit());
-            assertEquals(0, inputChannels[1].getUnannouncedCredit());
+            assertThat(readFromOutbound).isInstanceOf(AddCredit.class);
+            assertThat(((AddCredit) readFromOutbound).credit).isEqualTo(1);
+            assertThat(inputChannels[0].getUnannouncedCredit()).isZero();
+            assertThat(inputChannels[1].getUnannouncedCredit()).isZero();
 
             // no more messages
-            assertNull(channel.readOutbound());
+            assertThat((Object) channel.readOutbound()).isNull();
         } finally {
             releaseResource(inputGate, networkBufferPool);
             channel.close();
@@ -512,7 +501,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
      * message is not sent actually when this input channel is released.
      */
     @Test
-    public void testNotifyCreditAvailableAfterReleased() throws Exception {
+    void testNotifyCreditAvailableAfterReleased() throws Exception {
         final CreditBasedPartitionRequestClientHandler handler =
                 new CreditBasedPartitionRequestClientHandler();
         final EmbeddedChannel channel = new EmbeddedChannel(handler);
@@ -536,8 +525,8 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 
             // This should send the partition request
             Object readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(PartitionRequest.class));
-            assertEquals(2, ((PartitionRequest) readFromOutbound).credit);
+            assertThat(readFromOutbound).isInstanceOf(PartitionRequest.class);
+            assertThat(((PartitionRequest) readFromOutbound).credit).isEqualTo(2);
 
             // Trigger request floating buffers via buffer response to notify credits available
             final BufferResponse bufferResponse =
@@ -549,7 +538,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                             new NetworkBufferAllocator(handler));
             handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse);
 
-            assertEquals(2, inputChannel.getUnannouncedCredit());
+            assertThat(inputChannel.getUnannouncedCredit()).isEqualTo(2);
 
             // Release the input channel
             inputGate.close();
@@ -557,11 +546,10 @@ public class CreditBasedPartitionRequestClientHandlerTest {
             // it should send a close request after releasing the input channel,
             // but will not notify credits for a released input channel.
             readFromOutbound = channel.readOutbound();
-            assertThat(readFromOutbound, instanceOf(CloseRequest.class));
+            assertThat(readFromOutbound).isInstanceOf(CloseRequest.class);
 
             channel.runPendingTasks();
-
-            assertNull(channel.readOutbound());
+            assertThat((Object) channel.readOutbound()).isNull();
         } finally {
             releaseResource(inputGate, networkBufferPool);
             channel.close();
@@ -569,27 +557,27 @@ public class CreditBasedPartitionRequestClientHandlerTest {
     }
 
     @Test
-    public void testReadBufferResponseBeforeReleasingChannel() throws Exception {
+    void testReadBufferResponseBeforeReleasingChannel() throws Exception {
         testReadBufferResponseWithReleasingOrRemovingChannel(false, true);
     }
 
     @Test
-    public void testReadBufferResponseBeforeRemovingChannel() throws Exception {
+    void testReadBufferResponseBeforeRemovingChannel() throws Exception {
         testReadBufferResponseWithReleasingOrRemovingChannel(true, true);
     }
 
     @Test
-    public void testReadBufferResponseAfterReleasingChannel() throws Exception {
+    void testReadBufferResponseAfterReleasingChannel() throws Exception {
         testReadBufferResponseWithReleasingOrRemovingChannel(false, false);
     }
 
     @Test
-    public void testReadBufferResponseAfterRemovingChannel() throws Exception {
+    void testReadBufferResponseAfterRemovingChannel() throws Exception {
         testReadBufferResponseWithReleasingOrRemovingChannel(true, false);
     }
 
     @Test
-    public void testDoNotFailHandlerOnSingleChannelFailure() throws Exception {
+    void testDoNotFailHandlerOnSingleChannelFailure() throws Exception {
         // Setup
         final int bufferSize = 1024;
         final String expectedMessage = "test exception on buffer";
@@ -620,13 +608,11 @@ public class CreditBasedPartitionRequestClientHandlerTest {
             // The handler should not be tagged as error for above expected exception
             handler.checkError();
 
-            try {
-                // The input channel should be tagged as error and the respective exception is
-                // thrown via #getNext
-                inputGate.getNext();
-            } catch (IOException ignored) {
-                assertEquals(expectedMessage, ignored.getMessage());
-            }
+            // The input channel should be tagged as error and the respective exception is
+            // thrown via #getNext
+            assertThatThrownBy(inputGate::getNext)
+                    .isInstanceOf(IOException.class)
+                    .hasMessage(expectedMessage);
         } finally {
             // Cleanup
             releaseResource(inputGate, networkBufferPool);
@@ -634,7 +620,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
     }
 
     @Test
-    public void testExceptionWrap() {
+    void testExceptionWrap() {
         testExceptionWrap(LocalTransportException.class, new Exception());
         testExceptionWrap(LocalTransportException.class, new Exception("some error"));
         testExceptionWrap(
@@ -642,7 +628,7 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 
         // Only when Epoll is available the following exception could be initiated normally
         // since it relies on the native strerror method.
-        Assume.assumeTrue(Epoll.isAvailable());
+        assumeThat(Epoll.isAvailable()).isTrue();
         testExceptionWrap(
                 RemoteTransportException.class,
                 new Errors.NativeIoException("readAddress", Errors.ERRNO_ECONNRESET_NEGATIVE));
@@ -665,19 +651,16 @@ public class CreditBasedPartitionRequestClientHandlerTest {
                         handler);
 
         embeddedChannel.writeInbound(1);
-        try {
-            handler.checkError();
-            fail(
-                    String.format(
-                            "The handler should wrap the exception %s as %s, but it does not.",
-                            cause, expectedClass));
-        } catch (IOException e) {
-            assertThat(e, instanceOf(expectedClass));
-        }
+        assertThatThrownBy(() -> handler.checkError())
+                .withFailMessage(
+                        String.format(
+                                "The handler should wrap the exception %s as %s, but it does not.",
+                                cause, expectedClass))
+                .isInstanceOf(expectedClass);
     }
 
     @Test
-    public void testAnnounceBufferSize() throws Exception {
+    void testAnnounceBufferSize() throws Exception {
         final CreditBasedPartitionRequestClientHandler handler =
                 new CreditBasedPartitionRequestClientHandler();
         final EmbeddedChannel channel = new EmbeddedChannel(handler);
@@ -709,13 +692,13 @@ public class CreditBasedPartitionRequestClientHandlerTest {
             channel.runPendingTasks();
 
             NettyMessage.NewBufferSize readOutbound = channel.readOutbound();
-            assertThat(readOutbound, instanceOf(NettyMessage.NewBufferSize.class));
-            assertThat(readOutbound.receiverId, is(inputChannels[0].getInputChannelId()));
-            assertThat(readOutbound.bufferSize, is(333));
+            assertThat(readOutbound).isInstanceOf(NettyMessage.NewBufferSize.class);
+            assertThat(inputChannels[0].getInputChannelId()).isEqualTo(readOutbound.receiverId);
+            assertThat(readOutbound.bufferSize).isEqualTo(333);
 
             readOutbound = channel.readOutbound();
-            assertThat(readOutbound.receiverId, is(inputChannels[1].getInputChannelId()));
-            assertThat(readOutbound.bufferSize, is(333));
+            assertThat(inputChannels[1].getInputChannelId()).isEqualTo(readOutbound.receiverId);
+            assertThat(readOutbound.bufferSize).isEqualTo(333);
 
         } finally {
             releaseResource(inputGate, networkBufferPool);
@@ -766,19 +749,20 @@ public class CreditBasedPartitionRequestClientHandlerTest {
 
             handler.channelRead(null, bufferResponse);
 
-            assertEquals(0, inputChannel.getNumberOfQueuedBuffers());
+            assertThat(inputChannel.getNumberOfQueuedBuffers()).isZero();
             if (!readBeforeReleasingOrRemoving) {
-                assertNull(bufferResponse.getBuffer());
+                assertThat(bufferResponse.getBuffer()).isNull();
             } else {
-                assertNotNull(bufferResponse.getBuffer());
-                assertTrue(bufferResponse.getBuffer().isRecycled());
+                assertThat(bufferResponse.getBuffer()).isNotNull();
+                assertThat(bufferResponse.getBuffer().isRecycled()).isTrue();
             }
 
             embeddedChannel.runScheduledPendingTasks();
             NettyMessage.CancelPartitionRequest cancelPartitionRequest =
                     embeddedChannel.readOutbound();
-            assertNotNull(cancelPartitionRequest);
-            assertEquals(inputChannel.getInputChannelId(), cancelPartitionRequest.receiverId);
+            assertThat(cancelPartitionRequest).isNotNull();
+            assertThat(inputChannel.getInputChannelId())
+                    .isEqualTo(cancelPartitionRequest.receiverId);
         } finally {
             releaseResource(inputGate, networkBufferPool);
             embeddedChannel.close();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java
index ee42d8dd748..925a1444c32 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java
@@ -30,13 +30,14 @@ import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
 import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
-import org.apache.flink.util.TestLogger;
+import org.apache.flink.util.TestLoggerExtension;
 
 import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
 
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
 
 import java.io.IOException;
 import java.util.Random;
@@ -51,15 +52,14 @@ import static org.apache.flink.runtime.io.network.netty.NettyTestUtil.verifyErro
 import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createRemoteInputChannel;
 import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate;
 import static org.apache.flink.util.Preconditions.checkArgument;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /**
  * Tests for the serialization and deserialization of the various {@link NettyMessage} sub-classes
  * sent from server side to client side.
  */
-public class NettyMessageClientSideSerializationTest extends TestLogger {
+@ExtendWith(TestLoggerExtension.class)
+class NettyMessageClientSideSerializationTest {
 
     private static final int BUFFER_SIZE = 1024;
 
@@ -78,8 +78,8 @@ public class NettyMessageClientSideSerializationTest extends TestLogger {
 
     private InputChannelID inputChannelId;
 
-    @Before
-    public void setup() throws IOException, InterruptedException {
+    @BeforeEach
+    void setup() throws IOException, InterruptedException {
         networkBufferPool = new NetworkBufferPool(8, BUFFER_SIZE);
         inputGate = createSingleInputGate(1, networkBufferPool);
         RemoteInputChannel inputChannel =
@@ -100,8 +100,8 @@ public class NettyMessageClientSideSerializationTest extends TestLogger {
         inputChannelId = inputChannel.getInputChannelId();
     }
 
-    @After
-    public void tearDown() throws IOException {
+    @AfterEach
+    void tearDown() throws IOException {
         if (inputGate != null) {
             inputGate.close();
         }
@@ -117,43 +117,43 @@ public class NettyMessageClientSideSerializationTest extends TestLogger {
     }
 
     @Test
-    public void testErrorResponseWithoutErrorMessage() {
+    void testErrorResponseWithoutErrorMessage() {
         testErrorResponse(new ErrorResponse(new IllegalStateException(), inputChannelId));
     }
 
     @Test
-    public void testErrorResponseWithErrorMessage() {
+    void testErrorResponseWithErrorMessage() {
         testErrorResponse(
                 new ErrorResponse(
                         new IllegalStateException("Illegal illegal illegal"), inputChannelId));
     }
 
     @Test
-    public void testErrorResponseWithFatalError() {
+    void testErrorResponseWithFatalError() {
         testErrorResponse(new ErrorResponse(new IllegalStateException("Illegal illegal illegal")));
     }
 
     @Test
-    public void testOrdinaryBufferResponse() {
+    void testOrdinaryBufferResponse() {
         testBufferResponse(false, false);
     }
 
     @Test
-    public void testBufferResponseWithReadOnlySlice() {
+    void testBufferResponseWithReadOnlySlice() {
         testBufferResponse(true, false);
     }
 
     @Test
-    public void testCompressedBufferResponse() {
+    void testCompressedBufferResponse() {
         testBufferResponse(false, true);
     }
 
     @Test
-    public void testBacklogAnnouncement() {
+    void testBacklogAnnouncement() {
         BacklogAnnouncement expected = new BacklogAnnouncement(1024, inputChannelId);
         BacklogAnnouncement actual = encodeAndDecode(expected, channel);
-        assertEquals(expected.backlog, actual.backlog);
-        assertEquals(expected.receiverId, actual.receiverId);
+        assertThat(actual.backlog).isEqualTo(expected.backlog);
+        assertThat(actual.receiverId).isEqualTo(expected.receiverId);
     }
 
     private void testErrorResponse(ErrorResponse expect) {
@@ -189,22 +189,22 @@ public class NettyMessageClientSideSerializationTest extends TestLogger {
                         random.nextInt(Integer.MAX_VALUE));
         BufferResponse actual = encodeAndDecode(expected, channel);
 
-        assertTrue(buffer.isRecycled());
-        assertTrue(testBuffer.isRecycled());
-        assertNotNull(
-                "The request input channel should always have available buffers in this test.",
-                actual.getBuffer());
+        assertThat(buffer.isRecycled()).isTrue();
+        assertThat(testBuffer.isRecycled()).isTrue();
+        assertThat(actual.getBuffer())
+                .as("The request input channel should always have available buffers in this test.")
+                .isNotNull();
 
         Buffer decodedBuffer = actual.getBuffer();
         if (testCompressedBuffer) {
-            assertTrue(actual.isCompressed);
+            assertThat(actual.isCompressed).isTrue();
             decodedBuffer = decompress(decodedBuffer);
         }
 
         verifyBufferResponseHeader(expected, actual);
-        assertEquals(BUFFER_SIZE, decodedBuffer.readableBytes());
+        assertThat(decodedBuffer.readableBytes()).isEqualTo(BUFFER_SIZE);
         for (int i = 0; i < BUFFER_SIZE; i += 8) {
-            assertEquals(i, decodedBuffer.asByteBuf().readLong());
+            assertThat(decodedBuffer.asByteBuf().readLong()).isEqualTo(i);
         }
 
         // Release the received message.
@@ -213,7 +213,7 @@ public class NettyMessageClientSideSerializationTest extends TestLogger {
             decodedBuffer.recycleBuffer();
         }
 
-        assertTrue(actual.getBuffer().isRecycled());
+        assertThat(actual.getBuffer().isRecycled()).isTrue();
     }
 
     private Buffer decompress(Buffer buffer) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index dc0e383660f..44249a57bfa 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -72,7 +72,7 @@ import org.apache.flink.util.CompressedSerializedValue;
 
 import org.apache.flink.shaded.guava30.com.google.common.io.Closer;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -97,35 +97,40 @@ import static org.apache.flink.runtime.io.network.util.TestBufferFactory.createB
 import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault;
 import static org.apache.flink.runtime.util.NettyShuffleDescriptorBuilder.createRemoteWithIdAndLocation;
 import static org.apache.flink.util.Preconditions.checkState;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.instanceOf;
-import static org.hamcrest.Matchers.is;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Tests for {@link SingleInputGate}. */
 public class SingleInputGateTest extends InputGateTestBase {
 
-    @Test(expected = CheckpointException.class)
-    public void testCheckpointsDeclinedUnlessAllChannelsAreKnown() throws CheckpointException {
+    @Test
+    void testCheckpointsDeclinedUnlessAllChannelsAreKnown() throws CheckpointException {
         SingleInputGate gate =
                 createInputGate(createNettyShuffleEnvironment(), 1, ResultPartitionType.PIPELINED);
         gate.setInputChannels(
                 new InputChannelBuilder().setChannelIndex(0).buildUnknownChannel(gate));
-        gate.checkpointStarted(
-                new CheckpointBarrier(1L, 1L, alignedNoTimeout(CHECKPOINT, getDefault())));
+        assertThatThrownBy(
+                        () ->
+                                gate.checkpointStarted(
+                                        new CheckpointBarrier(
+                                                1L,
+                                                1L,
+                                                alignedNoTimeout(CHECKPOINT, getDefault()))))
+                .isInstanceOf(CheckpointException.class);
     }
 
-    @Test(expected = CheckpointException.class)
-    public void testCheckpointsDeclinedUnlessStateConsumed() throws CheckpointException {
+    @Test
+    void testCheckpointsDeclinedUnlessStateConsumed() throws CheckpointException {
         SingleInputGate gate = createInputGate(createNettyShuffleEnvironment());
         checkState(!gate.getStateConsumedFuture().isDone());
-        gate.checkpointStarted(
-                new CheckpointBarrier(1L, 1L, alignedNoTimeout(CHECKPOINT, getDefault())));
+        assertThatThrownBy(
+                        () ->
+                                gate.checkpointStarted(
+                                        new CheckpointBarrier(
+                                                1L,
+                                                1L,
+                                                alignedNoTimeout(CHECKPOINT, getDefault()))))
+                .isInstanceOf(CheckpointException.class);
     }
 
     /**
@@ -133,7 +138,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * exclusive buffers for {@link RemoteInputChannel}s, but should not request partitions.
      */
     @Test
-    public void testSetupLogic() throws Exception {
+    void testSetupLogic() throws Exception {
         final NettyShuffleEnvironment environment = createNettyShuffleEnvironment();
         final SingleInputGate inputGate = createInputGate(environment);
         try (Closer closer = Closer.create()) {
@@ -141,52 +146,53 @@ public class SingleInputGateTest extends InputGateTestBase {
             closer.register(inputGate::close);
 
             // before setup
-            assertNull(inputGate.getBufferPool());
+            assertThat(inputGate.getBufferPool()).isNull();
             for (InputChannel inputChannel : inputGate.getInputChannels().values()) {
-                assertTrue(
-                        inputChannel instanceof RecoveredInputChannel
-                                || inputChannel instanceof UnknownInputChannel);
+                assertThat(
+                                inputChannel instanceof RecoveredInputChannel
+                                        || inputChannel instanceof UnknownInputChannel)
+                        .isTrue();
                 if (inputChannel instanceof RecoveredInputChannel) {
-                    assertEquals(
-                            0,
-                            ((RecoveredInputChannel) inputChannel)
-                                    .bufferManager.getNumberOfAvailableBuffers());
+                    assertThat(
+                                    ((RecoveredInputChannel) inputChannel)
+                                            .bufferManager.getNumberOfAvailableBuffers())
+                            .isEqualTo(0);
                 }
             }
 
             inputGate.setup();
 
             // after setup
-            assertNotNull(inputGate.getBufferPool());
-            assertEquals(1, inputGate.getBufferPool().getNumberOfRequiredMemorySegments());
+            assertThat(inputGate.getBufferPool()).isNotNull();
+            assertThat(inputGate.getBufferPool().getNumberOfRequiredMemorySegments()).isEqualTo(1);
             for (InputChannel inputChannel : inputGate.getInputChannels().values()) {
                 if (inputChannel instanceof RemoteRecoveredInputChannel) {
-                    assertEquals(
-                            0,
-                            ((RemoteRecoveredInputChannel) inputChannel)
-                                    .bufferManager.getNumberOfAvailableBuffers());
+                    assertThat(
+                                    ((RemoteRecoveredInputChannel) inputChannel)
+                                            .bufferManager.getNumberOfAvailableBuffers())
+                            .isEqualTo(0);
                 } else if (inputChannel instanceof LocalRecoveredInputChannel) {
-                    assertEquals(
-                            0,
-                            ((LocalRecoveredInputChannel) inputChannel)
-                                    .bufferManager.getNumberOfAvailableBuffers());
+                    assertThat(
+                                    ((LocalRecoveredInputChannel) inputChannel)
+                                            .bufferManager.getNumberOfAvailableBuffers())
+                            .isEqualTo(0);
                 }
             }
 
             inputGate.convertRecoveredInputChannels();
-            assertNotNull(inputGate.getBufferPool());
-            assertEquals(1, inputGate.getBufferPool().getNumberOfRequiredMemorySegments());
+            assertThat(inputGate.getBufferPool()).isNotNull();
+            assertThat(inputGate.getBufferPool().getNumberOfRequiredMemorySegments()).isEqualTo(1);
             for (InputChannel inputChannel : inputGate.getInputChannels().values()) {
                 if (inputChannel instanceof RemoteInputChannel) {
-                    assertEquals(
-                            2, ((RemoteInputChannel) inputChannel).getNumberOfAvailableBuffers());
+                    assertThat(((RemoteInputChannel) inputChannel).getNumberOfAvailableBuffers())
+                            .isEqualTo(2);
                 }
             }
         }
     }
 
     @Test
-    public void testPartitionRequestLogic() throws Exception {
+    void testPartitionRequestLogic() throws Exception {
         final NettyShuffleEnvironment environment = new NettyShuffleEnvironmentBuilder().build();
         final SingleInputGate gate = createInputGate(environment);
 
@@ -203,15 +209,16 @@ public class SingleInputGateTest extends InputGateTestBase {
             gate.pollNext();
 
             final InputChannel remoteChannel = gate.getChannel(0);
-            assertThat(remoteChannel, instanceOf(RemoteInputChannel.class));
-            assertNotNull(((RemoteInputChannel) remoteChannel).getPartitionRequestClient());
-            assertEquals(2, ((RemoteInputChannel) remoteChannel).getInitialCredit());
+            assertThat(remoteChannel).isInstanceOf(RemoteInputChannel.class);
+            assertThat(((RemoteInputChannel) remoteChannel).getPartitionRequestClient())
+                    .isNotNull();
+            assertThat(((RemoteInputChannel) remoteChannel).getInitialCredit()).isEqualTo(2);
 
             final InputChannel localChannel = gate.getChannel(1);
-            assertThat(localChannel, instanceOf(LocalInputChannel.class));
-            assertNotNull(((LocalInputChannel) localChannel).getSubpartitionView());
+            assertThat(localChannel).isInstanceOf(LocalInputChannel.class);
+            assertThat(((LocalInputChannel) localChannel).getSubpartitionView()).isNotNull();
 
-            assertThat(gate.getChannel(2), instanceOf(UnknownInputChannel.class));
+            assertThat(gate.getChannel(2)).isInstanceOf(UnknownInputChannel.class);
         }
     }
 
@@ -220,7 +227,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * value after receiving all end-of-partition events.
      */
     @Test
-    public void testBasicGetNextLogic() throws Exception {
+    void testBasicGetNextLogic() throws Exception {
         // Setup
         final SingleInputGate inputGate = createInputGate();
 
@@ -248,20 +255,19 @@ public class SingleInputGateTest extends InputGateTestBase {
         verifyBufferOrEvent(inputGate, true, 0, true);
         verifyBufferOrEvent(inputGate, false, 1, true);
         // we have received EndOfData on a single channel only
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
-                inputGate.hasReceivedEndOfData());
+        assertThat(inputGate.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA);
         verifyBufferOrEvent(inputGate, false, 0, true);
-        assertFalse(inputGate.isFinished());
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.DRAINED, inputGate.hasReceivedEndOfData());
+        assertThat(inputGate.isFinished()).isFalse();
+        assertThat(inputGate.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.DRAINED);
         verifyBufferOrEvent(inputGate, false, 1, true);
         verifyBufferOrEvent(inputGate, false, 0, false);
 
         // Return null when the input gate has received all end-of-partition events
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.DRAINED, inputGate.hasReceivedEndOfData());
-        assertTrue(inputGate.isFinished());
+        assertThat(inputGate.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.DRAINED);
+        assertThat(inputGate.isFinished()).isTrue();
 
         for (TestInputChannel ic : inputChannels) {
             ic.assertReturnedEventsAreRecycled();
@@ -269,7 +275,7 @@ public class SingleInputGateTest extends InputGateTestBase {
     }
 
     @Test
-    public void testDrainFlagComputation() throws Exception {
+    void testDrainFlagComputation() throws Exception {
         // Setup
         final SingleInputGate inputGate1 = createInputGate();
         final SingleInputGate inputGate2 = createInputGate();
@@ -299,23 +305,21 @@ public class SingleInputGateTest extends InputGateTestBase {
 
         verifyBufferOrEvent(inputGate1, false, 0, true);
         // we have received EndOfData on a single channel only
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
-                inputGate1.hasReceivedEndOfData());
+        assertThat(inputGate1.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA);
         verifyBufferOrEvent(inputGate1, false, 1, true);
         // one of the channels said we should not drain
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.STOPPED, inputGate1.hasReceivedEndOfData());
+        assertThat(inputGate1.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.STOPPED);
 
         verifyBufferOrEvent(inputGate2, false, 0, true);
         // we have received EndOfData on a single channel only
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
-                inputGate2.hasReceivedEndOfData());
+        assertThat(inputGate2.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA);
         verifyBufferOrEvent(inputGate2, false, 1, true);
         // both channels said we should drain
-        assertEquals(
-                PullingAsyncDataInput.EndOfDataStatus.DRAINED, inputGate2.hasReceivedEndOfData());
+        assertThat(inputGate2.hasReceivedEndOfData())
+                .isEqualTo(PullingAsyncDataInput.EndOfDataStatus.DRAINED);
     }
 
     /**
@@ -323,7 +327,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * SingleInputGate#getNext()}.
      */
     @Test
-    public void testGetCompressedBuffer() throws Exception {
+    void testGetCompressedBuffer() throws Exception {
         int bufferSize = 1024;
         String compressionCodec = "LZ4";
         BufferCompressor compressor = new BufferCompressor(bufferSize, compressionCodec);
@@ -340,15 +344,15 @@ public class SingleInputGateTest extends InputGateTestBase {
             Buffer uncompressedBuffer = new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE);
             uncompressedBuffer.setSize(bufferSize);
             Buffer compressedBuffer = compressor.compressToOriginalBuffer(uncompressedBuffer);
-            assertTrue(compressedBuffer.isCompressed());
+            assertThat(compressedBuffer.isCompressed()).isTrue();
 
             inputChannel.read(compressedBuffer);
             inputGate.setInputChannels(inputChannel);
             inputGate.notifyChannelNonEmpty(inputChannel);
 
             Optional<BufferOrEvent> bufferOrEvent = inputGate.getNext();
-            assertTrue(bufferOrEvent.isPresent());
-            assertTrue(bufferOrEvent.get().isBuffer());
+            assertThat(bufferOrEvent.isPresent()).isTrue();
+            assertThat(bufferOrEvent.get().isBuffer()).isTrue();
             ByteBuffer buffer =
                     bufferOrEvent
                             .get()
@@ -356,29 +360,29 @@ public class SingleInputGateTest extends InputGateTestBase {
                             .getNioBufferReadable()
                             .order(ByteOrder.LITTLE_ENDIAN);
             for (int i = 0; i < bufferSize; i += 8) {
-                assertEquals(i, buffer.getLong());
+                assertThat(buffer.getLong()).isEqualTo(i);
             }
         }
     }
 
     @Test
-    public void testNotifyAfterEndOfPartition() throws Exception {
+    void testNotifyAfterEndOfPartition() throws Exception {
         final SingleInputGate inputGate = createInputGate(2);
         TestInputChannel inputChannel = new TestInputChannel(inputGate, 0);
         inputGate.setInputChannels(inputChannel, new TestInputChannel(inputGate, 1));
 
         inputChannel.readEndOfPartitionEvent();
         inputChannel.notifyChannelNonEmpty();
-        assertEquals(EndOfPartitionEvent.INSTANCE, inputGate.pollNext().get().getEvent());
+        assertThat(inputGate.pollNext().get().getEvent()).isEqualTo(EndOfPartitionEvent.INSTANCE);
 
         // gate is still active because of secondary channel
         // test if released channel is enqueued
         inputChannel.notifyChannelNonEmpty();
-        assertFalse(inputGate.pollNext().isPresent());
+        assertThat(inputGate.pollNext().isPresent()).isFalse();
     }
 
     @Test
-    public void testIsAvailable() throws Exception {
+    void testIsAvailable() throws Exception {
         final SingleInputGate inputGate = createInputGate(1);
         TestInputChannel inputChannel = new TestInputChannel(inputGate, 0);
         inputGate.setInputChannels(inputChannel);
@@ -387,7 +391,7 @@ public class SingleInputGateTest extends InputGateTestBase {
     }
 
     @Test
-    public void testIsAvailableAfterFinished() throws Exception {
+    void testIsAvailableAfterFinished() throws Exception {
         final SingleInputGate inputGate = createInputGate(1);
         TestInputChannel inputChannel = new TestInputChannel(inputGate, 0);
         inputGate.setInputChannels(inputChannel);
@@ -401,7 +405,7 @@ public class SingleInputGateTest extends InputGateTestBase {
     }
 
     @Test
-    public void testIsMoreAvailableReadingFromSingleInputChannel() throws Exception {
+    void testIsMoreAvailableReadingFromSingleInputChannel() throws Exception {
         // Setup
         final SingleInputGate inputGate = createInputGate();
 
@@ -423,7 +427,7 @@ public class SingleInputGateTest extends InputGateTestBase {
     }
 
     @Test
-    public void testBackwardsEventWithUninitializedChannel() throws Exception {
+    void testBackwardsEventWithUninitializedChannel() throws Exception {
         // Setup environment
         TestingTaskEventPublisher taskEventPublisher = new TestingTaskEventPublisher();
 
@@ -464,14 +468,14 @@ public class SingleInputGateTest extends InputGateTestBase {
             setupInputGate(inputGate, inputChannels);
 
             // Only the local channel can request
-            assertEquals(1, partitionManager.counter);
+            assertThat(partitionManager.counter).isEqualTo(1);
 
             // Send event backwards and initialize unknown channel afterwards
             final TaskEvent event = new TestTaskEvent();
             inputGate.sendTaskEvent(event);
 
             // Only the local channel can send out the event
-            assertEquals(1, taskEventPublisher.counter);
+            assertThat(taskEventPublisher.counter).isEqualTo(1);
 
             // After the update, the pending event should be send to local channel
 
@@ -480,8 +484,8 @@ public class SingleInputGateTest extends InputGateTestBase {
                     location,
                     createRemoteWithIdAndLocation(unknownPartitionId.getPartitionId(), location));
 
-            assertEquals(2, partitionManager.counter);
-            assertEquals(2, taskEventPublisher.counter);
+            assertThat(partitionManager.counter).isEqualTo(2);
+            assertThat(taskEventPublisher.counter).isEqualTo(2);
         }
     }
 
@@ -492,7 +496,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * listener.
      */
     @Test
-    public void testUpdateChannelBeforeRequest() throws Exception {
+    void testUpdateChannelBeforeRequest() throws Exception {
         SingleInputGate inputGate = createInputGate(1);
 
         TestingResultPartitionManager partitionManager =
@@ -511,7 +515,7 @@ public class SingleInputGateTest extends InputGateTestBase {
                 location,
                 createRemoteWithIdAndLocation(resultPartitionID.getPartitionId(), location));
 
-        assertEquals(0, partitionManager.counter);
+        assertThat(partitionManager.counter).isEqualTo(0);
     }
 
     /**
@@ -519,7 +523,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * data.
      */
     @Test
-    public void testReleaseWhilePollingChannel() throws Exception {
+    void testReleaseWhilePollingChannel() throws Exception {
         final AtomicReference<Exception> asyncException = new AtomicReference<>();
 
         // Setup the input gate with a single channel that does nothing
@@ -558,7 +562,7 @@ public class SingleInputGateTest extends InputGateTestBase {
         }
 
         // Verify that async consumer is in blocking request
-        assertTrue("Did not trigger blocking buffer request.", success);
+        assertThat(success).as("Did not trigger blocking buffer request.").isTrue();
 
         // Release the input gate
         inputGate.close();
@@ -568,13 +572,13 @@ public class SingleInputGateTest extends InputGateTestBase {
         // call will never return.
         asyncConsumer.join();
 
-        assertNotNull(asyncException.get());
-        assertEquals(IllegalStateException.class, asyncException.get().getClass());
+        assertThat(asyncException.get()).isNotNull();
+        assertThat(asyncException.get().getClass()).isEqualTo(IllegalStateException.class);
     }
 
     /** Tests request back off configuration is correctly forwarded to the channels. */
     @Test
-    public void testRequestBackoffConfiguration() throws Exception {
+    void testRequestBackoffConfiguration() throws Exception {
         IntermediateResultPartitionID[] partitionIds =
                 new IntermediateResultPartitionID[] {
                     new IntermediateResultPartitionID(),
@@ -605,11 +609,11 @@ public class SingleInputGateTest extends InputGateTestBase {
             closer.register(netEnv::close);
             closer.register(gate::close);
 
-            assertEquals(ResultPartitionType.PIPELINED, gate.getConsumedPartitionType());
+            assertThat(gate.getConsumedPartitionType()).isEqualTo(ResultPartitionType.PIPELINED);
 
             Map<SubpartitionInfo, InputChannel> channelMap = gate.getInputChannels();
 
-            assertEquals(3, channelMap.size());
+            assertThat(channelMap.size()).isEqualTo(3);
             channelMap
                     .values()
                     .forEach(
@@ -621,39 +625,39 @@ public class SingleInputGateTest extends InputGateTestBase {
                                 }
                             });
             InputChannel localChannel = channelMap.get(createSubpartitionInfo(partitionIds[0]));
-            assertEquals(LocalInputChannel.class, localChannel.getClass());
+            assertThat(localChannel.getClass()).isEqualTo(LocalInputChannel.class);
 
             InputChannel remoteChannel = channelMap.get(createSubpartitionInfo(partitionIds[1]));
-            assertEquals(RemoteInputChannel.class, remoteChannel.getClass());
+            assertThat(remoteChannel.getClass()).isEqualTo(RemoteInputChannel.class);
 
             InputChannel unknownChannel = channelMap.get(createSubpartitionInfo(partitionIds[2]));
-            assertEquals(UnknownInputChannel.class, unknownChannel.getClass());
+            assertThat(unknownChannel.getClass()).isEqualTo(UnknownInputChannel.class);
 
             InputChannel[] channels =
                     new InputChannel[] {localChannel, remoteChannel, unknownChannel};
             for (InputChannel ch : channels) {
-                assertEquals(0, ch.getCurrentBackoff());
+                assertThat(ch.getCurrentBackoff()).isEqualTo(0);
 
-                assertTrue(ch.increaseBackoff());
-                assertEquals(initialBackoff, ch.getCurrentBackoff());
+                assertThat(ch.increaseBackoff()).isTrue();
+                assertThat(ch.getCurrentBackoff()).isEqualTo(initialBackoff);
 
-                assertTrue(ch.increaseBackoff());
-                assertEquals(initialBackoff * 2, ch.getCurrentBackoff());
+                assertThat(ch.increaseBackoff()).isTrue();
+                assertThat(ch.getCurrentBackoff()).isEqualTo(initialBackoff * 2);
 
-                assertTrue(ch.increaseBackoff());
-                assertEquals(initialBackoff * 2 * 2, ch.getCurrentBackoff());
+                assertThat(ch.increaseBackoff()).isTrue();
+                assertThat(ch.getCurrentBackoff()).isEqualTo(initialBackoff * 2 * 2);
 
-                assertTrue(ch.increaseBackoff());
-                assertEquals(maxBackoff, ch.getCurrentBackoff());
+                assertThat(ch.increaseBackoff()).isTrue();
+                assertThat(ch.getCurrentBackoff()).isEqualTo(maxBackoff);
 
-                assertFalse(ch.increaseBackoff());
+                assertThat(ch.increaseBackoff()).isFalse();
             }
         }
     }
 
     /** Tests that input gate requests and assigns network buffers for remote input channel. */
     @Test
-    public void testRequestBuffersWithRemoteInputChannel() throws Exception {
+    void testRequestBuffersWithRemoteInputChannel() throws Exception {
         final NettyShuffleEnvironment network = createNettyShuffleEnvironment();
         final SingleInputGate inputGate =
                 createInputGate(network, 1, ResultPartitionType.PIPELINED_BOUNDED);
@@ -674,14 +678,13 @@ public class SingleInputGateTest extends InputGateTestBase {
 
             NetworkBufferPool bufferPool = network.getNetworkBufferPool();
             // only the exclusive buffers should be assigned/available now
-            assertEquals(buffersPerChannel, remote.getNumberOfAvailableBuffers());
+            assertThat(remote.getNumberOfAvailableBuffers()).isEqualTo(buffersPerChannel);
 
-            assertEquals(
-                    bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel - 1,
-                    bufferPool.getNumberOfAvailableMemorySegments());
+            assertThat(bufferPool.getNumberOfAvailableMemorySegments())
+                    .isEqualTo(bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel - 1);
             // note: exclusive buffers are not handed out into LocalBufferPool and are thus not
             // counted
-            assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers());
+            assertThat(bufferPool.countBuffers()).isEqualTo(extraNetworkBuffersPerGate);
         }
     }
 
@@ -690,7 +693,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * to remote input channel.
      */
     @Test
-    public void testRequestBuffersWithUnknownInputChannel() throws Exception {
+    void testRequestBuffersWithUnknownInputChannel() throws Exception {
         final NettyShuffleEnvironment network = createNettyShuffleEnvironment();
         final SingleInputGate inputGate =
                 createInputGate(network, 1, ResultPartitionType.PIPELINED_BOUNDED);
@@ -709,12 +712,11 @@ public class SingleInputGateTest extends InputGateTestBase {
             inputGate.setup();
             NetworkBufferPool bufferPool = network.getNetworkBufferPool();
 
-            assertEquals(
-                    bufferPool.getTotalNumberOfMemorySegments() - 1,
-                    bufferPool.getNumberOfAvailableMemorySegments());
+            assertThat(bufferPool.getNumberOfAvailableMemorySegments())
+                    .isEqualTo(bufferPool.getTotalNumberOfMemorySegments() - 1);
             // note: exclusive buffers are not handed out into LocalBufferPool and are thus not
             // counted
-            assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers());
+            assertThat(bufferPool.countBuffers()).isEqualTo(extraNetworkBuffersPerGate);
 
             // Trigger updates to remote input channel from unknown input channel
             inputGate.updateInputChannel(
@@ -730,14 +732,13 @@ public class SingleInputGateTest extends InputGateTestBase {
                                             createSubpartitionInfo(
                                                     resultPartitionId.getPartitionId()));
             // only the exclusive buffers should be assigned/available now
-            assertEquals(buffersPerChannel, remote.getNumberOfAvailableBuffers());
+            assertThat(remote.getNumberOfAvailableBuffers()).isEqualTo(buffersPerChannel);
 
-            assertEquals(
-                    bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel - 1,
-                    bufferPool.getNumberOfAvailableMemorySegments());
+            assertThat(bufferPool.getNumberOfAvailableMemorySegments())
+                    .isEqualTo(bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel - 1);
             // note: exclusive buffers are not handed out into LocalBufferPool and are thus not
             // counted
-            assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers());
+            assertThat(bufferPool.countBuffers()).isEqualTo(extraNetworkBuffersPerGate);
         }
     }
 
@@ -746,7 +747,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * channels.
      */
     @Test
-    public void testUpdateUnknownInputChannel() throws Exception {
+    void testUpdateUnknownInputChannel() throws Exception {
         final NettyShuffleEnvironment network = createNettyShuffleEnvironment();
 
         final ResultPartition localResultPartition =
@@ -785,15 +786,19 @@ public class SingleInputGateTest extends InputGateTestBase {
             inputGate.setup();
 
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            .get(createSubpartitionInfo(remoteResultPartitionId.getPartitionId())),
-                    is(instanceOf((UnknownInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    remoteResultPartitionId.getPartitionId())))
+                    .isInstanceOf(UnknownInputChannel.class);
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            .get(createSubpartitionInfo(localResultPartitionId.getPartitionId())),
-                    is(instanceOf((UnknownInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    localResultPartitionId.getPartitionId())))
+                    .isInstanceOf(UnknownInputChannel.class);
 
             ResourceID localLocation = ResourceID.generate();
 
@@ -804,15 +809,19 @@ public class SingleInputGateTest extends InputGateTestBase {
                             remoteResultPartitionId.getPartitionId(), ResourceID.generate()));
 
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            .get(createSubpartitionInfo(remoteResultPartitionId.getPartitionId())),
-                    is(instanceOf((RemoteInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    remoteResultPartitionId.getPartitionId())))
+                    .isInstanceOf(RemoteInputChannel.class);
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            .get(createSubpartitionInfo(localResultPartitionId.getPartitionId())),
-                    is(instanceOf((UnknownInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    localResultPartitionId.getPartitionId())))
+                    .isInstanceOf(UnknownInputChannel.class);
 
             // Trigger updates to local input channel from unknown input channel
             inputGate.updateInputChannel(
@@ -821,21 +830,24 @@ public class SingleInputGateTest extends InputGateTestBase {
                             localResultPartitionId.getPartitionId(), localLocation));
 
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            .get(createSubpartitionInfo(remoteResultPartitionId.getPartitionId())),
-                    is(instanceOf((RemoteInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    remoteResultPartitionId.getPartitionId())))
+                    .isInstanceOf(RemoteInputChannel.class);
             assertThat(
-                    inputGate
-                            .getInputChannels()
-                            .get(createSubpartitionInfo(localResultPartitionId.getPartitionId())),
-                    is(instanceOf((LocalInputChannel.class))));
+                            inputGate
+                                    .getInputChannels()
+                                    .get(
+                                            createSubpartitionInfo(
+                                                    localResultPartitionId.getPartitionId())))
+                    .isInstanceOf(LocalInputChannel.class);
         }
     }
 
     @Test
-    public void testSingleInputGateWithSubpartitionIndexRange()
-            throws IOException, InterruptedException {
+    void testSingleInputGateWithSubpartitionIndexRange() throws IOException, InterruptedException {
 
         IntermediateResultPartitionID[] partitionIds =
                 new IntermediateResultPartitionID[] {
@@ -872,13 +884,13 @@ public class SingleInputGateTest extends InputGateTestBase {
         SubpartitionInfo info5 = createSubpartitionInfo(partitionIds[2], 0);
         SubpartitionInfo info6 = createSubpartitionInfo(partitionIds[2], 1);
 
-        assertThat(gate.getInputChannels().size(), is(6));
-        assertThat(gate.getInputChannels().get(info1).getConsumedSubpartitionIndex(), is(0));
-        assertThat(gate.getInputChannels().get(info2).getConsumedSubpartitionIndex(), is(1));
-        assertThat(gate.getInputChannels().get(info3).getConsumedSubpartitionIndex(), is(0));
-        assertThat(gate.getInputChannels().get(info4).getConsumedSubpartitionIndex(), is(1));
-        assertThat(gate.getInputChannels().get(info5).getConsumedSubpartitionIndex(), is(0));
-        assertThat(gate.getInputChannels().get(info6).getConsumedSubpartitionIndex(), is(1));
+        assertThat(gate.getInputChannels().size()).isEqualTo(6);
+        assertThat(gate.getInputChannels().get(info1).getConsumedSubpartitionIndex()).isEqualTo(0);
+        assertThat(gate.getInputChannels().get(info2).getConsumedSubpartitionIndex()).isEqualTo(1);
+        assertThat(gate.getInputChannels().get(info3).getConsumedSubpartitionIndex()).isEqualTo(0);
+        assertThat(gate.getInputChannels().get(info4).getConsumedSubpartitionIndex()).isEqualTo(1);
+        assertThat(gate.getInputChannels().get(info5).getConsumedSubpartitionIndex()).isEqualTo(0);
+        assertThat(gate.getInputChannels().get(info6).getConsumedSubpartitionIndex()).isEqualTo(1);
 
         assertChannelsType(gate, LocalRecoveredInputChannel.class, Arrays.asList(info1, info2));
         assertChannelsType(gate, RemoteRecoveredInputChannel.class, Arrays.asList(info3, info4));
@@ -886,8 +898,8 @@ public class SingleInputGateTest extends InputGateTestBase {
 
         // test setup
         gate.setup();
-        assertNotNull(gate.getBufferPool());
-        assertEquals(1, gate.getBufferPool().getNumberOfRequiredMemorySegments());
+        assertThat(gate.getBufferPool()).isNotNull();
+        assertThat(gate.getBufferPool().getNumberOfRequiredMemorySegments()).isEqualTo(1);
 
         gate.finishReadRecoveredState();
         while (!gate.getStateConsumedFuture().isDone()) {
@@ -902,10 +914,11 @@ public class SingleInputGateTest extends InputGateTestBase {
         assertChannelsType(gate, UnknownInputChannel.class, Arrays.asList(info5, info6));
         for (InputChannel inputChannel : gate.getInputChannels().values()) {
             if (inputChannel instanceof RemoteInputChannel) {
-                assertNotNull(((RemoteInputChannel) inputChannel).getPartitionRequestClient());
-                assertEquals(2, ((RemoteInputChannel) inputChannel).getInitialCredit());
+                assertThat(((RemoteInputChannel) inputChannel).getPartitionRequestClient())
+                        .isNotNull();
+                assertThat(((RemoteInputChannel) inputChannel).getInitialCredit()).isEqualTo(2);
             } else if (inputChannel instanceof LocalInputChannel) {
-                assertNotNull(((LocalInputChannel) inputChannel).getSubpartitionView());
+                assertThat(((LocalInputChannel) inputChannel).getSubpartitionView()).isNotNull();
             }
         }
 
@@ -920,12 +933,12 @@ public class SingleInputGateTest extends InputGateTestBase {
     private void assertChannelsType(
             SingleInputGate gate, Class<?> clazz, List<SubpartitionInfo> infos) {
         for (SubpartitionInfo subpartitionInfo : infos) {
-            assertThat(gate.getInputChannels().get(subpartitionInfo), instanceOf(clazz));
+            assertThat(gate.getInputChannels().get(subpartitionInfo)).isInstanceOf(clazz);
         }
     }
 
     @Test
-    public void testQueuedBuffers() throws Exception {
+    void testQueuedBuffers() throws Exception {
         final NettyShuffleEnvironment network = createNettyShuffleEnvironment();
 
         final BufferWritingResultPartition resultPartition =
@@ -966,10 +979,10 @@ public class SingleInputGateTest extends InputGateTestBase {
             setupInputGate(inputGate, inputChannels);
 
             remoteInputChannel.onBuffer(createBuffer(1), 0, 0);
-            assertEquals(1, inputGate.getNumberOfQueuedBuffers());
+            assertThat(inputGate.getNumberOfQueuedBuffers()).isEqualTo(1);
 
             resultPartition.emitRecord(ByteBuffer.allocate(1), 0);
-            assertEquals(2, inputGate.getNumberOfQueuedBuffers());
+            assertThat(inputGate.getNumberOfQueuedBuffers()).isEqualTo(2);
         }
     }
 
@@ -979,7 +992,7 @@ public class SingleInputGateTest extends InputGateTestBase {
      * the {@link SingleInputGate} would not swallow or transform the original exception.
      */
     @Test
-    public void testPartitionNotFoundExceptionWhileGetNextBuffer() throws Exception {
+    void testPartitionNotFoundExceptionWhileGetNextBuffer() throws Exception {
         final SingleInputGate inputGate = InputChannelTestUtils.createSingleInputGate(1);
         final LocalInputChannel localChannel =
                 createLocalInputChannel(inputGate, new ResultPartitionManager());
@@ -987,17 +1000,16 @@ public class SingleInputGateTest extends InputGateTestBase {
 
         inputGate.setInputChannels(localChannel);
         localChannel.setError(new PartitionNotFoundException(partitionId));
-        try {
-            inputGate.getNext();
-
-            fail("Should throw a PartitionNotFoundException.");
-        } catch (PartitionNotFoundException notFound) {
-            assertThat(partitionId, is(notFound.getPartitionId()));
-        }
+        assertThatThrownBy(inputGate::getNext)
+                .isInstanceOfSatisfying(
+                        PartitionNotFoundException.class,
+                        (notFoundException) ->
+                                assertThat(notFoundException.getPartitionId())
+                                        .isEqualTo(partitionId));
     }
 
     @Test
-    public void testAnnounceBufferSize() throws Exception {
+    void testAnnounceBufferSize() throws Exception {
         final SingleInputGate inputGate = InputChannelTestUtils.createSingleInputGate(2);
         final LocalInputChannel localChannel =
                 createLocalInputChannel(
@@ -1028,7 +1040,7 @@ public class SingleInputGateTest extends InputGateTestBase {
     }
 
     @Test
-    public void testInputGateRemovalFromNettyShuffleEnvironment() throws Exception {
+    void testInputGateRemovalFromNettyShuffleEnvironment() throws Exception {
         NettyShuffleEnvironment network = createNettyShuffleEnvironment();
 
         try (Closer closer = Closer.create()) {
@@ -1038,18 +1050,18 @@ public class SingleInputGateTest extends InputGateTestBase {
             Map<InputGateID, SingleInputGate> createdInputGatesById =
                     createInputGateWithLocalChannels(network, numberOfGates, 1);
 
-            assertEquals(numberOfGates, createdInputGatesById.size());
+            assertThat(createdInputGatesById.size()).isEqualTo(numberOfGates);
 
             for (InputGateID id : createdInputGatesById.keySet()) {
-                assertThat(network.getInputGate(id).isPresent(), is(true));
+                assertThat(network.getInputGate(id).isPresent()).isTrue();
                 createdInputGatesById.get(id).close();
-                assertThat(network.getInputGate(id).isPresent(), is(false));
+                assertThat(network.getInputGate(id).isPresent()).isFalse();
             }
         }
     }
 
     @Test
-    public void testSingleInputGateInfo() {
+    void testSingleInputGateInfo() {
         final int numSingleInputGates = 2;
         final int numInputChannels = 3;
 
@@ -1064,14 +1076,14 @@ public class SingleInputGateTest extends InputGateTestBase {
             for (InputChannel inputChannel : gate.getInputChannels().values()) {
                 InputChannelInfo channelInfo = inputChannel.getChannelInfo();
 
-                assertEquals(i, channelInfo.getGateIdx());
-                assertEquals(channelCounter++, channelInfo.getInputChannelIdx());
+                assertThat(channelInfo.getGateIdx()).isEqualTo(i);
+                assertThat(channelInfo.getInputChannelIdx()).isEqualTo(channelCounter++);
             }
         }
     }
 
     @Test
-    public void testGetUnfinishedChannels() throws IOException, InterruptedException {
+    void testGetUnfinishedChannels() throws IOException, InterruptedException {
         SingleInputGate inputGate =
                 new SingleInputGateBuilder()
                         .setSingleInputGateIndex(1)
@@ -1085,35 +1097,36 @@ public class SingleInputGateTest extends InputGateTestBase {
                 };
         inputGate.setInputChannels(inputChannels);
 
-        assertEquals(
-                Arrays.asList(
-                        inputChannels[0].getChannelInfo(),
-                        inputChannels[1].getChannelInfo(),
-                        inputChannels[2].getChannelInfo()),
-                inputGate.getUnfinishedChannels());
+        assertThat(inputGate.getUnfinishedChannels())
+                .isEqualTo(
+                        Arrays.asList(
+                                inputChannels[0].getChannelInfo(),
+                                inputChannels[1].getChannelInfo(),
+                                inputChannels[2].getChannelInfo()));
 
         inputChannels[1].readEndOfPartitionEvent();
         inputGate.notifyChannelNonEmpty(inputChannels[1]);
         inputGate.getNext();
-        assertEquals(
-                Arrays.asList(inputChannels[0].getChannelInfo(), inputChannels[2].getChannelInfo()),
-                inputGate.getUnfinishedChannels());
+        assertThat(inputGate.getUnfinishedChannels())
+                .isEqualTo(
+                        Arrays.asList(
+                                inputChannels[0].getChannelInfo(),
+                                inputChannels[2].getChannelInfo()));
 
         inputChannels[0].readEndOfPartitionEvent();
         inputGate.notifyChannelNonEmpty(inputChannels[0]);
         inputGate.getNext();
-        assertEquals(
-                Collections.singletonList(inputChannels[2].getChannelInfo()),
-                inputGate.getUnfinishedChannels());
+        assertThat(inputGate.getUnfinishedChannels())
+                .isEqualTo(Collections.singletonList(inputChannels[2].getChannelInfo()));
 
         inputChannels[2].readEndOfPartitionEvent();
         inputGate.notifyChannelNonEmpty(inputChannels[2]);
         inputGate.getNext();
-        assertEquals(Collections.emptyList(), inputGate.getUnfinishedChannels());
+        assertThat(inputGate.getUnfinishedChannels()).isEqualTo(Collections.emptyList());
     }
 
     @Test
-    public void testBufferInUseCount() throws Exception {
+    void testBufferInUseCount() throws Exception {
         // Setup
         final SingleInputGate inputGate = createInputGate();
 
@@ -1125,17 +1138,17 @@ public class SingleInputGateTest extends InputGateTestBase {
         inputGate.setInputChannels(inputChannels);
 
         // It should be no buffers when all channels are empty.
-        assertThat(inputGate.getBuffersInUseCount(), is(0));
+        assertThat(inputGate.getBuffersInUseCount()).isEqualTo(0);
 
         // Add buffers into channels.
         inputChannels[0].readBuffer();
-        assertThat(inputGate.getBuffersInUseCount(), is(1));
+        assertThat(inputGate.getBuffersInUseCount()).isEqualTo(1);
 
         inputChannels[0].readBuffer();
-        assertThat(inputGate.getBuffersInUseCount(), is(2));
+        assertThat(inputGate.getBuffersInUseCount()).isEqualTo(2);
 
         inputChannels[1].readBuffer();
-        assertThat(inputGate.getBuffersInUseCount(), is(3));
+        assertThat(inputGate.getBuffersInUseCount()).isEqualTo(3);
     }
 
     // ---------------------------------------------------------------------------------------------
@@ -1278,14 +1291,13 @@ public class SingleInputGateTest extends InputGateTestBase {
             throws IOException, InterruptedException {
 
         final Optional<BufferOrEvent> bufferOrEvent = inputGate.getNext();
-        assertTrue(bufferOrEvent.isPresent());
-        assertEquals(expectedIsBuffer, bufferOrEvent.get().isBuffer());
-        assertEquals(
-                inputGate.getChannel(expectedChannelIndex).getChannelInfo(),
-                bufferOrEvent.get().getChannelInfo());
-        assertEquals(expectedMoreAvailable, bufferOrEvent.get().moreAvailable());
+        assertThat(bufferOrEvent.isPresent()).isTrue();
+        assertThat(bufferOrEvent.get().isBuffer()).isEqualTo(expectedIsBuffer);
+        assertThat(bufferOrEvent.get().getChannelInfo())
+                .isEqualTo(inputGate.getChannel(expectedChannelIndex).getChannelInfo());
+        assertThat(bufferOrEvent.get().moreAvailable()).isEqualTo(expectedMoreAvailable);
         if (!expectedMoreAvailable) {
-            assertFalse(inputGate.pollNext().isPresent());
+            assertThat(inputGate.pollNext().isPresent()).isFalse();
         }
     }
 


[flink] 02/02: [FLINK-28382][network] Introduce LZO and ZSTD compression based on aircompressor for blocking shuffle

Posted by yi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 47d0b6d26c052a817a66f7b719eecf01387cb0d3
Author: Weijie Guo <re...@163.com>
AuthorDate: Sat Jul 30 00:53:52 2022 +0800

    [FLINK-28382][network] Introduce LZO and ZSTD compression based on aircompressor for blocking shuffle
    
    This closes #20216.
---
 .../generated/all_taskmanager_network_section.html |   6 ++
 .../netty_shuffle_environment_configuration.html   |   6 ++
 .../NettyShuffleEnvironmentOptions.java            |  14 ++-
 flink-runtime/pom.xml                              |  38 +++++++-
 ...lockCompressor.java => AirBlockCompressor.java} |  76 +++++++--------
 .../io/compression/AirBlockDecompressor.java       | 103 +++++++++++++++++++++
 ...ssionFactory.java => AirCompressorFactory.java} |  25 +++--
 .../io/compression/BlockCompressionFactory.java    |  19 +++-
 .../runtime/io/compression/BlockCompressor.java    |   8 +-
 .../runtime/io/compression/BlockDecompressor.java  |  10 +-
 ...eption.java => BufferCompressionException.java} |  14 +--
 ...tion.java => BufferDecompressionException.java} |  15 ++-
 ...ompressionFactory.java => CompressorUtils.java} |  31 +++++--
 .../io/compression/Lz4BlockCompressionFactory.java |   7 --
 .../runtime/io/compression/Lz4BlockCompressor.java |  24 ++---
 .../io/compression/Lz4BlockDecompressor.java       |  40 ++++----
 .../io/network/buffer/BufferCompressor.java        |  47 +++++++---
 .../io/network/buffer/BufferDecompressor.java      |  40 ++++++--
 flink-runtime/src/main/resources/META-INF/NOTICE   |   9 ++
 .../io/compression/BlockCompressionTest.java       |  22 +++--
 .../io/network/buffer/BufferCompressionTest.java   |  12 +++
 ...editBasedPartitionRequestClientHandlerTest.java |   8 +-
 .../NettyMessageClientSideSerializationTest.java   |  21 +++--
 .../partition/consumer/SingleInputGateTest.java    |   8 +-
 .../io/CompressedHeaderlessChannelTest.java        |  15 ++-
 25 files changed, 437 insertions(+), 181 deletions(-)

diff --git a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
index a6705820858..a626890a3fc 100644
--- a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
+++ b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
@@ -20,6 +20,12 @@
             <td>String</td>
             <td>The blocking shuffle type, either "mmap" or "file". The "auto" means selecting the property type automatically based on system memory architecture (64 bit for mmap and 32 bit for file). Note that the memory usage of mmap is not accounted by configured memory limits, but some resource frameworks like yarn would track this memory usage and kill the container once memory exceeding some threshold. Also note that this option is experimental and might be changed in the future.</td>
         </tr>
+        <tr>
+            <td><h5>taskmanager.network.compression.codec</h5></td>
+            <td style="word-wrap: break-word;">"LZ4"</td>
+            <td>String</td>
+            <td>The codec to be used when compressing shuffle data, only "LZ4", "LZO" and "ZSTD" are supported now. Through tpc-ds test of these three algorithms, the results show that "LZ4" algorithm has the highest compression and decompression speed, but the compression ratio is the lowest. "ZSTD" has the highest compression ratio, but the compression and decompression speed is the slowest, and LZO is between the two. Also note that this option is experimental and might be changed in  [...]
+        </tr>
         <tr>
             <td><h5>taskmanager.network.detailed-metrics</h5></td>
             <td style="word-wrap: break-word;">false</td>
diff --git a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
index c9f19ff76ba..8bbd176d85e 100644
--- a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
+++ b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
@@ -38,6 +38,12 @@
             <td>String</td>
             <td>The blocking shuffle type, either "mmap" or "file". The "auto" means selecting the property type automatically based on system memory architecture (64 bit for mmap and 32 bit for file). Note that the memory usage of mmap is not accounted by configured memory limits, but some resource frameworks like yarn would track this memory usage and kill the container once memory exceeding some threshold. Also note that this option is experimental and might be changed in the future.</td>
         </tr>
+        <tr>
+            <td><h5>taskmanager.network.compression.codec</h5></td>
+            <td style="word-wrap: break-word;">"LZ4"</td>
+            <td>String</td>
+            <td>The codec to be used when compressing shuffle data, only "LZ4", "LZO" and "ZSTD" are supported now. Through tpc-ds test of these three algorithms, the results show that "LZ4" algorithm has the highest compression and decompression speed, but the compression ratio is the lowest. "ZSTD" has the highest compression ratio, but the compression and decompression speed is the slowest, and LZO is between the two. Also note that this option is experimental and might be changed in  [...]
+        </tr>
         <tr>
             <td><h5>taskmanager.network.detailed-metrics</h5></td>
             <td style="word-wrap: break-word;">false</td>
diff --git a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
index 0286bc27254..51ba09faad3 100644
--- a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
+++ b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.configuration;
 
+import org.apache.flink.annotation.Experimental;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.annotation.docs.Documentation;
 
@@ -88,12 +89,21 @@ public class NettyShuffleEnvironmentOptions {
                                     + "ratio is high.");
 
     /** The codec to be used when compressing shuffle data. */
-    @Documentation.ExcludeFromDocumentation("Currently, LZ4 is the only legal option.")
+    @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+    @Experimental
     public static final ConfigOption<String> SHUFFLE_COMPRESSION_CODEC =
             key("taskmanager.network.compression.codec")
                     .stringType()
                     .defaultValue("LZ4")
-                    .withDescription("The codec to be used when compressing shuffle data.");
+                    .withDescription(
+                            "The codec to be used when compressing shuffle data, only \"LZ4\", \"LZO\" "
+                                    + "and \"ZSTD\" are supported now. Through tpc-ds test of these "
+                                    + "three algorithms, the results show that \"LZ4\" algorithm has "
+                                    + "the highest compression and decompression speed, but the "
+                                    + "compression ratio is the lowest. \"ZSTD\" has the highest "
+                                    + "compression ratio, but the compression and decompression "
+                                    + "speed is the slowest, and LZO is between the two. Also note "
+                                    + "that this option is experimental and might be changed in the future.");
 
     /**
      * Boolean flag to enable/disable more detailed metrics about inbound/outbound network queue
diff --git a/flink-runtime/pom.xml b/flink-runtime/pom.xml
index 631a10f094a..86206ad4dd3 100644
--- a/flink-runtime/pom.xml
+++ b/flink-runtime/pom.xml
@@ -141,7 +141,7 @@ under the License.
 			<groupId>org.apache.commons</groupId>
 			<artifactId>commons-lang3</artifactId>
 		</dependency>
-		
+
 		<dependency>
 			<groupId>commons-cli</groupId>
 			<artifactId>commons-cli</artifactId>
@@ -165,6 +165,13 @@ under the License.
 			<artifactId>lz4-java</artifactId>
 		</dependency>
 
+		<!-- air compression library -->
+		<dependency>
+			<groupId>io.airlift</groupId>
+			<artifactId>aircompressor</artifactId>
+			<version>0.21</version>
+		</dependency>
+
 		<!-- test dependencies -->
 
 		<dependency>
@@ -303,6 +310,35 @@ under the License.
 					</execution>
 				</executions>
 			</plugin>
+
+			<plugin>
+				<groupId>org.apache.maven.plugins</groupId>
+				<artifactId>maven-shade-plugin</artifactId>
+				<executions>
+					<execution>
+						<id>shade-flink</id>
+						<phase>package</phase>
+						<goals>
+							<goal>shade</goal>
+						</goals>
+						<configuration>
+							<artifactSet>
+								<includes>
+									<include>io.airlift:aircompressor</include>
+								</includes>
+							</artifactSet>
+							<relocations>
+								<relocation>
+									<pattern>io.airlift.compress</pattern>
+									<shadedPattern>
+										org.apache.flink.shaded.io.airlift.compress
+									</shadedPattern>
+								</relocation>
+							</relocations>
+						</configuration>
+					</execution>
+				</executions>
+			</plugin>
 		</plugins>
 	</build>
 </project>
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/AirBlockCompressor.java
similarity index 51%
copy from flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressor.java
copy to flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/AirBlockCompressor.java
index 86607c73d86..b4e1664e107 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/AirBlockCompressor.java
@@ -18,53 +18,44 @@
 
 package org.apache.flink.runtime.io.compression;
 
-import net.jpountz.lz4.LZ4Compressor;
-import net.jpountz.lz4.LZ4Exception;
-import net.jpountz.lz4.LZ4Factory;
+import io.airlift.compress.Compressor;
 
-import java.nio.BufferOverflowException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 
-import static org.apache.flink.runtime.io.compression.Lz4BlockCompressionFactory.HEADER_LENGTH;
+import static org.apache.flink.runtime.io.compression.CompressorUtils.HEADER_LENGTH;
+import static org.apache.flink.runtime.io.compression.CompressorUtils.writeIntLE;
 
-/**
- * Encode data into LZ4 format (not compatible with the LZ4 Frame format). It reads from and writes
- * to byte arrays provided from the outside, thus reducing copy time.
- *
- * <p>This class is copied and modified from {@link net.jpountz.lz4.LZ4BlockOutputStream}.
- */
-public class Lz4BlockCompressor implements BlockCompressor {
+/** Flink compressor that wraps {@link Compressor}. */
+public class AirBlockCompressor implements BlockCompressor {
+    private final Compressor internalCompressor;
 
-    private final LZ4Compressor compressor;
-
-    public Lz4BlockCompressor() {
-        this.compressor = LZ4Factory.fastestInstance().fastCompressor();
+    public AirBlockCompressor(Compressor internalCompressor) {
+        this.internalCompressor = internalCompressor;
     }
 
     @Override
     public int getMaxCompressedSize(int srcSize) {
-        return HEADER_LENGTH + compressor.maxCompressedLength(srcSize);
+        return HEADER_LENGTH + internalCompressor.maxCompressedLength(srcSize);
     }
 
     @Override
     public int compress(ByteBuffer src, int srcOff, int srcLen, ByteBuffer dst, int dstOff)
-            throws InsufficientBufferException {
+            throws BufferCompressionException {
         try {
+            if (dst.remaining() < dstOff + getMaxCompressedSize(srcLen)) {
+                throw new ArrayIndexOutOfBoundsException();
+            }
+
             final int prevSrcOff = src.position() + srcOff;
             final int prevDstOff = dst.position() + dstOff;
 
-            int maxCompressedSize = compressor.maxCompressedLength(srcLen);
-            int compressedLength =
-                    compressor.compress(
-                            src,
-                            prevSrcOff,
-                            srcLen,
-                            dst,
-                            prevDstOff + HEADER_LENGTH,
-                            maxCompressedSize);
+            src.position(prevSrcOff);
+            dst.position(prevDstOff + HEADER_LENGTH);
 
-            src.position(prevSrcOff + srcLen);
+            internalCompressor.compress(src, dst);
+
+            int compressedLength = dst.position() - prevDstOff - HEADER_LENGTH;
 
             dst.position(prevDstOff);
             dst.order(ByteOrder.LITTLE_ENDIAN);
@@ -73,29 +64,32 @@ public class Lz4BlockCompressor implements BlockCompressor {
             dst.position(prevDstOff + compressedLength + HEADER_LENGTH);
 
             return HEADER_LENGTH + compressedLength;
-        } catch (LZ4Exception | ArrayIndexOutOfBoundsException | BufferOverflowException e) {
-            throw new InsufficientBufferException(e);
+        } catch (Exception e) {
+            throw new BufferCompressionException(e);
         }
     }
 
     @Override
     public int compress(byte[] src, int srcOff, int srcLen, byte[] dst, int dstOff)
-            throws InsufficientBufferException {
+            throws BufferCompressionException {
         try {
+            if (dst.length < dstOff + getMaxCompressedSize(srcLen)) {
+                throw new ArrayIndexOutOfBoundsException();
+            }
+
             int compressedLength =
-                    compressor.compress(src, srcOff, srcLen, dst, dstOff + HEADER_LENGTH);
+                    internalCompressor.compress(
+                            src,
+                            srcOff,
+                            srcLen,
+                            dst,
+                            dstOff + HEADER_LENGTH,
+                            internalCompressor.maxCompressedLength(srcLen));
             writeIntLE(compressedLength, dst, dstOff);
             writeIntLE(srcLen, dst, dstOff + 4);
             return HEADER_LENGTH + compressedLength;
-        } catch (LZ4Exception | BufferOverflowException | ArrayIndexOutOfBoundsException e) {
-            throw new InsufficientBufferException(e);
+        } catch (Exception e) {
+            throw new BufferCompressionException(e);
         }
     }
-
-    private static void writeIntLE(int i, byte[] buf, int offset) {
-        buf[offset++] = (byte) i;
-        buf[offset++] = (byte) (i >>> 8);
-        buf[offset++] = (byte) (i >>> 16);
-        buf[offset] = (byte) (i >>> 24);
-    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/AirBlockDecompressor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/AirBlockDecompressor.java
new file mode 100644
index 00000000000..765b9f98875
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/AirBlockDecompressor.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.compression;
+
+import io.airlift.compress.Decompressor;
+import io.airlift.compress.MalformedInputException;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+import static org.apache.flink.runtime.io.compression.CompressorUtils.HEADER_LENGTH;
+import static org.apache.flink.runtime.io.compression.CompressorUtils.readIntLE;
+import static org.apache.flink.runtime.io.compression.CompressorUtils.validateLength;
+
+/** Flink decompressor that wraps {@link Decompressor}. */
+public class AirBlockDecompressor implements BlockDecompressor {
+    private final Decompressor internalDecompressor;
+
+    public AirBlockDecompressor(Decompressor internalDecompressor) {
+        this.internalDecompressor = internalDecompressor;
+    }
+
+    @Override
+    public int decompress(ByteBuffer src, int srcOff, int srcLen, ByteBuffer dst, int dstOff)
+            throws BufferDecompressionException {
+        final int prevSrcOff = src.position() + srcOff;
+        final int prevDstOff = dst.position() + dstOff;
+
+        src.position(prevSrcOff);
+        dst.position(prevDstOff);
+        src.order(ByteOrder.LITTLE_ENDIAN);
+        final int compressedLen = src.getInt();
+        final int originalLen = src.getInt();
+        validateLength(compressedLen, originalLen);
+
+        if (dst.capacity() - prevDstOff < originalLen) {
+            throw new BufferDecompressionException("Buffer length too small");
+        }
+
+        if (src.limit() - prevSrcOff - HEADER_LENGTH < compressedLen) {
+            throw new BufferDecompressionException(
+                    "Source data is not integral for decompression.");
+        }
+        src.limit(prevSrcOff + compressedLen + HEADER_LENGTH);
+        try {
+            internalDecompressor.decompress(src, dst);
+            if (originalLen != dst.position() - prevDstOff) {
+                throw new BufferDecompressionException(
+                        "Input is corrupted, unexpected original length.");
+            }
+        } catch (MalformedInputException e) {
+            throw new BufferDecompressionException("Input is corrupted", e);
+        }
+
+        return originalLen;
+    }
+
+    @Override
+    public int decompress(byte[] src, int srcOff, int srcLen, byte[] dst, int dstOff)
+            throws BufferDecompressionException {
+        int compressedLen = readIntLE(src, srcOff);
+        int originalLen = readIntLE(src, srcOff + 4);
+        validateLength(compressedLen, originalLen);
+
+        if (dst.length - dstOff < originalLen) {
+            throw new BufferDecompressionException("Buffer length too small");
+        }
+
+        if (src.length - srcOff - HEADER_LENGTH < compressedLen) {
+            throw new BufferDecompressionException(
+                    "Source data is not integral for decompression.");
+        }
+
+        try {
+            final int decompressedLen =
+                    internalDecompressor.decompress(
+                            src, srcOff + HEADER_LENGTH, compressedLen, dst, dstOff, originalLen);
+            if (originalLen != decompressedLen) {
+                throw new BufferDecompressionException("Input is corrupted");
+            }
+        } catch (MalformedInputException e) {
+            throw new BufferDecompressionException("Input is corrupted", e);
+        }
+
+        return originalLen;
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressionFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/AirCompressorFactory.java
similarity index 58%
copy from flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressionFactory.java
copy to flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/AirCompressorFactory.java
index dec76400c4d..ab68abcf11a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressionFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/AirCompressorFactory.java
@@ -18,22 +18,29 @@
 
 package org.apache.flink.runtime.io.compression;
 
-/** Implementation of {@link BlockCompressionFactory} for Lz4 codec. */
-public class Lz4BlockCompressionFactory implements BlockCompressionFactory {
+import io.airlift.compress.Compressor;
+import io.airlift.compress.Decompressor;
 
-    /**
-     * We put two integers before each compressed block, the first integer represents the compressed
-     * length of the block, and the second one represents the original length of the block.
-     */
-    public static final int HEADER_LENGTH = 8;
+/**
+ * {@link BlockCompressionFactory} to create wrapped {@link Compressor} and {@link Decompressor}.
+ */
+public class AirCompressorFactory implements BlockCompressionFactory {
+    private final Compressor internalCompressor;
+
+    private final Decompressor internalDecompressor;
+
+    public AirCompressorFactory(Compressor internalCompressor, Decompressor internalDecompressor) {
+        this.internalCompressor = internalCompressor;
+        this.internalDecompressor = internalDecompressor;
+    }
 
     @Override
     public BlockCompressor getCompressor() {
-        return new Lz4BlockCompressor();
+        return new AirBlockCompressor(internalCompressor);
     }
 
     @Override
     public BlockDecompressor getDecompressor() {
-        return new Lz4BlockDecompressor();
+        return new AirBlockDecompressor(internalDecompressor);
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockCompressionFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockCompressionFactory.java
index 587c17ffcd3..145578e036f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockCompressionFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockCompressionFactory.java
@@ -20,6 +20,11 @@ package org.apache.flink.runtime.io.compression;
 
 import org.apache.flink.configuration.IllegalConfigurationException;
 
+import io.airlift.compress.lzo.LzoCompressor;
+import io.airlift.compress.lzo.LzoDecompressor;
+import io.airlift.compress.zstd.ZstdCompressor;
+import io.airlift.compress.zstd.ZstdDecompressor;
+
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -34,7 +39,9 @@ public interface BlockCompressionFactory {
 
     /** Name of {@link BlockCompressionFactory}. */
     enum CompressionFactoryName {
-        LZ4
+        LZ4,
+        LZO,
+        ZSTD
     }
 
     /**
@@ -54,12 +61,20 @@ public interface BlockCompressionFactory {
             compressionName = null;
         }
 
-        BlockCompressionFactory blockCompressionFactory = null;
+        BlockCompressionFactory blockCompressionFactory;
         if (compressionName != null) {
             switch (compressionName) {
                 case LZ4:
                     blockCompressionFactory = new Lz4BlockCompressionFactory();
                     break;
+                case LZO:
+                    blockCompressionFactory =
+                            new AirCompressorFactory(new LzoCompressor(), new LzoDecompressor());
+                    break;
+                case ZSTD:
+                    blockCompressionFactory =
+                            new AirCompressorFactory(new ZstdCompressor(), new ZstdDecompressor());
+                    break;
                 default:
                     throw new IllegalStateException("Unknown CompressionMethod " + compressionName);
             }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockCompressor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockCompressor.java
index edcc5468c0f..1d83f0ab8d2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockCompressor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockCompressor.java
@@ -39,10 +39,10 @@ public interface BlockCompressor {
      * @param dst The target to write compressed data
      * @param dstOff The start offset to write the compressed data
      * @return Length of compressed data
-     * @throws InsufficientBufferException if the target does not have sufficient space
+     * @throws BufferCompressionException if exception thrown when compressing
      */
     int compress(ByteBuffer src, int srcOff, int srcLen, ByteBuffer dst, int dstOff)
-            throws InsufficientBufferException;
+            throws BufferCompressionException;
 
     /**
      * Compress data read from src, and write the compressed data to dst.
@@ -53,8 +53,8 @@ public interface BlockCompressor {
      * @param dst The target to write compressed data
      * @param dstOff The start offset to write the compressed data
      * @return Length of compressed data
-     * @throws InsufficientBufferException if the target does not have sufficient space
+     * @throws BufferCompressionException if exception thrown when compressing
      */
     int compress(byte[] src, int srcOff, int srcLen, byte[] dst, int dstOff)
-            throws InsufficientBufferException;
+            throws BufferCompressionException;
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockDecompressor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockDecompressor.java
index a52d787db81..ba6b79f4eaf 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockDecompressor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BlockDecompressor.java
@@ -33,11 +33,10 @@ public interface BlockDecompressor {
      * @param dst The target to write decompressed data
      * @param dstOff The start offset to write the decompressed data
      * @return Length of decompressed data
-     * @throws DataCorruptionException if data corruption found when decompressing
-     * @throws InsufficientBufferException if the target does not have sufficient space
+     * @throws BufferDecompressionException if exception thrown when decompressing
      */
     int decompress(ByteBuffer src, int srcOff, int srcLen, ByteBuffer dst, int dstOff)
-            throws DataCorruptionException, InsufficientBufferException;
+            throws BufferDecompressionException;
 
     /**
      * Decompress source data read from src and write the decompressed data to dst.
@@ -48,9 +47,8 @@ public interface BlockDecompressor {
      * @param dst The target to write decompressed data
      * @param dstOff The start offset to write the decompressed data
      * @return Length of decompressed data
-     * @throws DataCorruptionException if data corruption found when decompressing
-     * @throws InsufficientBufferException if the target does not have sufficient space
+     * @throws BufferDecompressionException if exception thrown when decompressing
      */
     int decompress(byte[] src, int srcOff, int srcLen, byte[] dst, int dstOff)
-            throws DataCorruptionException, InsufficientBufferException;
+            throws BufferDecompressionException;
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/DataCorruptionException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BufferCompressionException.java
similarity index 68%
rename from flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/DataCorruptionException.java
rename to flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BufferCompressionException.java
index 9169cebfd98..dc2461cc06a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/DataCorruptionException.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BufferCompressionException.java
@@ -19,24 +19,24 @@
 package org.apache.flink.runtime.io.compression;
 
 /**
- * A {@code DataCorruptionException} is thrown when the decompressed data is corrupted and cannot be
- * decompressed.
+ * A {@code BufferCompressionException} is thrown when the target data cannot be compressed, such as
+ * insufficient target buffer space for compression, etc.
  */
-public class DataCorruptionException extends RuntimeException {
+public class BufferCompressionException extends RuntimeException {
 
-    public DataCorruptionException() {
+    public BufferCompressionException() {
         super();
     }
 
-    public DataCorruptionException(String message) {
+    public BufferCompressionException(String message) {
         super(message);
     }
 
-    public DataCorruptionException(String message, Throwable e) {
+    public BufferCompressionException(String message, Throwable e) {
         super(message, e);
     }
 
-    public DataCorruptionException(Throwable e) {
+    public BufferCompressionException(Throwable e) {
         super(e);
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/InsufficientBufferException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BufferDecompressionException.java
similarity index 65%
rename from flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/InsufficientBufferException.java
rename to flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BufferDecompressionException.java
index b67c7f4362b..65be2b1aacb 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/InsufficientBufferException.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/BufferDecompressionException.java
@@ -19,25 +19,24 @@
 package org.apache.flink.runtime.io.compression;
 
 /**
- * An {@code InsufficientBufferException} is thrown when there is no enough buffer to serialize or
- * deserialize a buffer to another buffer. When such exception being caught, user may enlarge the
- * output buffer and try again.
+ * A {@code BufferDecompressionException} is thrown when the target data cannot be decompressed,
+ * such as data corruption, insufficient target buffer space for decompression, etc.
  */
-public class InsufficientBufferException extends RuntimeException {
+public class BufferDecompressionException extends RuntimeException {
 
-    public InsufficientBufferException() {
+    public BufferDecompressionException() {
         super();
     }
 
-    public InsufficientBufferException(String message) {
+    public BufferDecompressionException(String message) {
         super(message);
     }
 
-    public InsufficientBufferException(String message, Throwable e) {
+    public BufferDecompressionException(String message, Throwable e) {
         super(message, e);
     }
 
-    public InsufficientBufferException(Throwable e) {
+    public BufferDecompressionException(Throwable e) {
         super(e);
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressionFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/CompressorUtils.java
similarity index 54%
copy from flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressionFactory.java
copy to flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/CompressorUtils.java
index dec76400c4d..f531bbed4f2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressionFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/CompressorUtils.java
@@ -18,22 +18,35 @@
 
 package org.apache.flink.runtime.io.compression;
 
-/** Implementation of {@link BlockCompressionFactory} for Lz4 codec. */
-public class Lz4BlockCompressionFactory implements BlockCompressionFactory {
-
+/** Utils for {@link BlockCompressor}. */
+public class CompressorUtils {
     /**
      * We put two integers before each compressed block, the first integer represents the compressed
      * length of the block, and the second one represents the original length of the block.
      */
     public static final int HEADER_LENGTH = 8;
 
-    @Override
-    public BlockCompressor getCompressor() {
-        return new Lz4BlockCompressor();
+    public static void writeIntLE(int i, byte[] buf, int offset) {
+        buf[offset++] = (byte) i;
+        buf[offset++] = (byte) (i >>> 8);
+        buf[offset++] = (byte) (i >>> 16);
+        buf[offset] = (byte) (i >>> 24);
+    }
+
+    public static int readIntLE(byte[] buf, int i) {
+        return (buf[i] & 0xFF)
+                | ((buf[i + 1] & 0xFF) << 8)
+                | ((buf[i + 2] & 0xFF) << 16)
+                | ((buf[i + 3] & 0xFF) << 24);
     }
 
-    @Override
-    public BlockDecompressor getDecompressor() {
-        return new Lz4BlockDecompressor();
+    public static void validateLength(int compressedLen, int originalLen)
+            throws BufferDecompressionException {
+        if (originalLen < 0
+                || compressedLen < 0
+                || (originalLen == 0 && compressedLen != 0)
+                || (originalLen != 0 && compressedLen == 0)) {
+            throw new BufferDecompressionException("Input is corrupted, invalid length.");
+        }
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressionFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressionFactory.java
index dec76400c4d..7e21979be82 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressionFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressionFactory.java
@@ -20,13 +20,6 @@ package org.apache.flink.runtime.io.compression;
 
 /** Implementation of {@link BlockCompressionFactory} for Lz4 codec. */
 public class Lz4BlockCompressionFactory implements BlockCompressionFactory {
-
-    /**
-     * We put two integers before each compressed block, the first integer represents the compressed
-     * length of the block, and the second one represents the original length of the block.
-     */
-    public static final int HEADER_LENGTH = 8;
-
     @Override
     public BlockCompressor getCompressor() {
         return new Lz4BlockCompressor();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressor.java
index 86607c73d86..158afafeefd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockCompressor.java
@@ -19,14 +19,13 @@
 package org.apache.flink.runtime.io.compression;
 
 import net.jpountz.lz4.LZ4Compressor;
-import net.jpountz.lz4.LZ4Exception;
 import net.jpountz.lz4.LZ4Factory;
 
-import java.nio.BufferOverflowException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 
-import static org.apache.flink.runtime.io.compression.Lz4BlockCompressionFactory.HEADER_LENGTH;
+import static org.apache.flink.runtime.io.compression.CompressorUtils.HEADER_LENGTH;
+import static org.apache.flink.runtime.io.compression.CompressorUtils.writeIntLE;
 
 /**
  * Encode data into LZ4 format (not compatible with the LZ4 Frame format). It reads from and writes
@@ -49,7 +48,7 @@ public class Lz4BlockCompressor implements BlockCompressor {
 
     @Override
     public int compress(ByteBuffer src, int srcOff, int srcLen, ByteBuffer dst, int dstOff)
-            throws InsufficientBufferException {
+            throws BufferCompressionException {
         try {
             final int prevSrcOff = src.position() + srcOff;
             final int prevDstOff = dst.position() + dstOff;
@@ -73,29 +72,22 @@ public class Lz4BlockCompressor implements BlockCompressor {
             dst.position(prevDstOff + compressedLength + HEADER_LENGTH);
 
             return HEADER_LENGTH + compressedLength;
-        } catch (LZ4Exception | ArrayIndexOutOfBoundsException | BufferOverflowException e) {
-            throw new InsufficientBufferException(e);
+        } catch (Exception e) {
+            throw new BufferCompressionException(e);
         }
     }
 
     @Override
     public int compress(byte[] src, int srcOff, int srcLen, byte[] dst, int dstOff)
-            throws InsufficientBufferException {
+            throws BufferCompressionException {
         try {
             int compressedLength =
                     compressor.compress(src, srcOff, srcLen, dst, dstOff + HEADER_LENGTH);
             writeIntLE(compressedLength, dst, dstOff);
             writeIntLE(srcLen, dst, dstOff + 4);
             return HEADER_LENGTH + compressedLength;
-        } catch (LZ4Exception | BufferOverflowException | ArrayIndexOutOfBoundsException e) {
-            throw new InsufficientBufferException(e);
+        } catch (Exception e) {
+            throw new BufferCompressionException(e);
         }
     }
-
-    private static void writeIntLE(int i, byte[] buf, int offset) {
-        buf[offset++] = (byte) i;
-        buf[offset++] = (byte) (i >>> 8);
-        buf[offset++] = (byte) (i >>> 16);
-        buf[offset] = (byte) (i >>> 24);
-    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockDecompressor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockDecompressor.java
index f60014d8fd0..fd73ddce12d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockDecompressor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/compression/Lz4BlockDecompressor.java
@@ -21,12 +21,13 @@ package org.apache.flink.runtime.io.compression;
 import net.jpountz.lz4.LZ4Exception;
 import net.jpountz.lz4.LZ4Factory;
 import net.jpountz.lz4.LZ4FastDecompressor;
-import net.jpountz.util.SafeUtils;
 
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 
-import static org.apache.flink.runtime.io.compression.Lz4BlockCompressionFactory.HEADER_LENGTH;
+import static org.apache.flink.runtime.io.compression.CompressorUtils.HEADER_LENGTH;
+import static org.apache.flink.runtime.io.compression.CompressorUtils.readIntLE;
+import static org.apache.flink.runtime.io.compression.CompressorUtils.validateLength;
 
 /**
  * Decode data written with {@link Lz4BlockCompressor}. It reads from and writes to byte arrays
@@ -44,7 +45,7 @@ public class Lz4BlockDecompressor implements BlockDecompressor {
 
     @Override
     public int decompress(ByteBuffer src, int srcOff, int srcLen, ByteBuffer dst, int dstOff)
-            throws DataCorruptionException {
+            throws BufferDecompressionException {
         final int prevSrcOff = src.position() + srcOff;
         final int prevDstOff = dst.position() + dstOff;
 
@@ -54,11 +55,12 @@ public class Lz4BlockDecompressor implements BlockDecompressor {
         validateLength(compressedLen, originalLen);
 
         if (dst.capacity() - prevDstOff < originalLen) {
-            throw new InsufficientBufferException("Buffer length too small");
+            throw new BufferDecompressionException("Buffer length too small");
         }
 
         if (src.limit() - prevSrcOff - HEADER_LENGTH < compressedLen) {
-            throw new DataCorruptionException("Source data is not integral for decompression.");
+            throw new BufferDecompressionException(
+                    "Source data is not integral for decompression.");
         }
 
         try {
@@ -66,13 +68,13 @@ public class Lz4BlockDecompressor implements BlockDecompressor {
                     decompressor.decompress(
                             src, prevSrcOff + HEADER_LENGTH, dst, prevDstOff, originalLen);
             if (compressedLen != compressedLen2) {
-                throw new DataCorruptionException(
+                throw new BufferDecompressionException(
                         "Input is corrupted, unexpected compressed length.");
             }
             src.position(prevSrcOff + compressedLen + HEADER_LENGTH);
             dst.position(prevDstOff + originalLen);
         } catch (LZ4Exception e) {
-            throw new DataCorruptionException("Input is corrupted", e);
+            throw new BufferDecompressionException("Input is corrupted", e);
         }
 
         return originalLen;
@@ -80,38 +82,30 @@ public class Lz4BlockDecompressor implements BlockDecompressor {
 
     @Override
     public int decompress(byte[] src, int srcOff, int srcLen, byte[] dst, int dstOff)
-            throws InsufficientBufferException, DataCorruptionException {
-        final int compressedLen = SafeUtils.readIntLE(src, srcOff);
-        final int originalLen = SafeUtils.readIntLE(src, srcOff + 4);
+            throws BufferDecompressionException {
+        final int compressedLen = readIntLE(src, srcOff);
+        final int originalLen = readIntLE(src, srcOff + 4);
         validateLength(compressedLen, originalLen);
 
         if (dst.length - dstOff < originalLen) {
-            throw new InsufficientBufferException("Buffer length too small");
+            throw new BufferDecompressionException("Buffer length too small");
         }
 
         if (src.length - srcOff - HEADER_LENGTH < compressedLen) {
-            throw new DataCorruptionException("Source data is not integral for decompression.");
+            throw new BufferDecompressionException(
+                    "Source data is not integral for decompression.");
         }
 
         try {
             final int compressedLen2 =
                     decompressor.decompress(src, srcOff + HEADER_LENGTH, dst, dstOff, originalLen);
             if (compressedLen != compressedLen2) {
-                throw new DataCorruptionException("Input is corrupted");
+                throw new BufferDecompressionException("Input is corrupted");
             }
         } catch (LZ4Exception e) {
-            throw new DataCorruptionException("Input is corrupted", e);
+            throw new BufferDecompressionException("Input is corrupted", e);
         }
 
         return originalLen;
     }
-
-    private void validateLength(int compressedLen, int originalLen) throws DataCorruptionException {
-        if (originalLen < 0
-                || compressedLen < 0
-                || (originalLen == 0 && compressedLen != 0)
-                || (originalLen != 0 && compressedLen == 0)) {
-            throw new DataCorruptionException("Input is corrupted, invalid length.");
-        }
-    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferCompressor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferCompressor.java
index 4a58ad2c08c..07b245f2b74 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferCompressor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferCompressor.java
@@ -36,16 +36,20 @@ public class BufferCompressor {
     /** The intermediate buffer for the compressed data. */
     private final NetworkBuffer internalBuffer;
 
+    /** The backing array of the intermediate buffer. */
+    private final byte[] internalBufferArray;
+
     public BufferCompressor(int bufferSize, String factoryName) {
         checkArgument(bufferSize > 0);
         checkNotNull(factoryName);
         // the size of this intermediate heap buffer will be gotten from the
         // plugin configuration in the future, and currently, double size of
-        // the input buffer is enough for lz4-java compression library.
-        final byte[] heapBuffer = new byte[2 * bufferSize];
+        // the input buffer is enough for the compression libraries used.
+        this.internalBufferArray = new byte[2 * bufferSize];
         this.internalBuffer =
                 new NetworkBuffer(
-                        MemorySegmentFactory.wrap(heapBuffer), FreeingBufferRecycler.INSTANCE);
+                        MemorySegmentFactory.wrap(internalBufferArray),
+                        FreeingBufferRecycler.INSTANCE);
         this.blockCompressor =
                 BlockCompressionFactory.createBlockCompressionFactory(factoryName).getCompressor();
     }
@@ -86,7 +90,7 @@ public class BufferCompressor {
         // copy the compressed data back
         int memorySegmentOffset = buffer.getMemorySegmentOffset();
         MemorySegment segment = buffer.getMemorySegment();
-        segment.put(memorySegmentOffset, internalBuffer.array(), 0, compressedLen);
+        segment.put(memorySegmentOffset, internalBufferArray, 0, compressedLen);
 
         return new ReadOnlySlicedNetworkBuffer(
                 buffer.asByteBuf(), 0, compressedLen, memorySegmentOffset, true);
@@ -107,15 +111,34 @@ public class BufferCompressor {
                 "Illegal reference count, buffer need to be released.");
 
         try {
+            int compressedLen;
             int length = buffer.getSize();
-            // compress the given buffer into the internal heap buffer
-            int compressedLen =
-                    blockCompressor.compress(
-                            buffer.getNioBuffer(0, length),
-                            0,
-                            length,
-                            internalBuffer.getNioBuffer(0, internalBuffer.capacity()),
-                            0);
+            MemorySegment memorySegment = buffer.getMemorySegment();
+            // If buffer is on-heap, manipulate the underlying array directly. There are two main
+            // reasons why an NIO buffer is not used directly here: one is that some compression
+            // libraries access the underlying array of a heap buffer, but our input buffer may be
+            // a read-only ByteBuffer, whose internal array is illegal to access. The other reason
+            // is that for an on-heap buffer, operating directly on the underlying array avoids
+            // the additional overhead of generating an NIO buffer.
+            if (!memorySegment.isOffHeap()) {
+                compressedLen =
+                        blockCompressor.compress(
+                                memorySegment.getArray(),
+                                buffer.getMemorySegmentOffset(),
+                                length,
+                                internalBufferArray,
+                                0);
+            } else {
+                // compress the given buffer into the internal heap buffer
+                compressedLen =
+                        blockCompressor.compress(
+                                buffer.getNioBuffer(0, length),
+                                0,
+                                length,
+                                internalBuffer.getNioBuffer(0, internalBuffer.capacity()),
+                                0);
+            }
+
             return compressedLen < length ? compressedLen : 0;
         } catch (Throwable throwable) {
             // return the original buffer if failed to compress
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferDecompressor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferDecompressor.java
index c6dc89a407a..b99f5d05413 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferDecompressor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferDecompressor.java
@@ -37,15 +37,19 @@ public class BufferDecompressor {
     /** The intermediate buffer for the decompressed data. */
     private final NetworkBuffer internalBuffer;
 
+    /** The backing array of the intermediate buffer. */
+    private final byte[] internalBufferArray;
+
     public BufferDecompressor(int bufferSize, String factoryName) {
         checkArgument(bufferSize > 0);
         checkNotNull(factoryName);
 
         // the decompressed data size should be never larger than the configured buffer size
-        final byte[] heapBuffer = new byte[bufferSize];
+        this.internalBufferArray = new byte[bufferSize];
         this.internalBuffer =
                 new NetworkBuffer(
-                        MemorySegmentFactory.wrap(heapBuffer), FreeingBufferRecycler.INSTANCE);
+                        MemorySegmentFactory.wrap(internalBufferArray),
+                        FreeingBufferRecycler.INSTANCE);
         this.blockDecompressor =
                 BlockCompressionFactory.createBlockCompressionFactory(factoryName)
                         .getDecompressor();
@@ -82,7 +86,7 @@ public class BufferDecompressor {
         // copy the decompressed data back
         int memorySegmentOffset = buffer.getMemorySegmentOffset();
         MemorySegment segment = buffer.getMemorySegment();
-        segment.put(memorySegmentOffset, internalBuffer.array(), 0, decompressedLen);
+        segment.put(memorySegmentOffset, internalBufferArray, 0, decompressedLen);
 
         return new ReadOnlySlicedNetworkBuffer(
                 buffer.asByteBuf(), 0, decompressedLen, memorySegmentOffset, false);
@@ -103,12 +107,28 @@ public class BufferDecompressor {
                 "Illegal reference count, buffer need to be released.");
 
         int length = buffer.getSize();
-        // decompress the given buffer into the internal heap buffer
-        return blockDecompressor.decompress(
-                buffer.getNioBuffer(0, length),
-                0,
-                length,
-                internalBuffer.getNioBuffer(0, internalBuffer.capacity()),
-                0);
+        MemorySegment memorySegment = buffer.getMemorySegment();
+        // If buffer is on-heap, manipulate the underlying array directly. There are two main
+        // reasons why an NIO buffer is not used directly here: one is that some compression
+        // libraries access the underlying array of a heap buffer, but our input buffer may be
+        // a read-only ByteBuffer, whose internal array is illegal to access. The other reason
+        // is that for an on-heap buffer, operating directly on the underlying array avoids
+        // the additional overhead of generating an NIO buffer.
+        if (!memorySegment.isOffHeap()) {
+            return blockDecompressor.decompress(
+                    memorySegment.getArray(),
+                    buffer.getMemorySegmentOffset(),
+                    length,
+                    internalBufferArray,
+                    0);
+        } else {
+            // decompress the given buffer into the internal heap buffer
+            return blockDecompressor.decompress(
+                    buffer.getNioBuffer(0, length),
+                    0,
+                    length,
+                    internalBuffer.getNioBuffer(0, internalBuffer.capacity()),
+                    0);
+        }
     }
 }
diff --git a/flink-runtime/src/main/resources/META-INF/NOTICE b/flink-runtime/src/main/resources/META-INF/NOTICE
new file mode 100644
index 00000000000..80ac2a27de1
--- /dev/null
+++ b/flink-runtime/src/main/resources/META-INF/NOTICE
@@ -0,0 +1,9 @@
+flink-runtime
+Copyright 2014-2021 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+This project bundles the following dependencies under the Apache Software License 2.0. (http://www.apache.org/licenses/LICENSE-2.0.txt)
+
+- io.airlift:aircompressor:0.21
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java
index fd8a05db6f9..efa59aad27f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/compression/BlockCompressionTest.java
@@ -18,20 +18,28 @@
 
 package org.apache.flink.runtime.io.compression;
 
-import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
 
 import java.nio.ByteBuffer;
+import java.util.stream.Stream;
 
-import static org.apache.flink.runtime.io.compression.Lz4BlockCompressionFactory.HEADER_LENGTH;
+import static org.apache.flink.runtime.io.compression.CompressorUtils.HEADER_LENGTH;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Tests for block compression. */
 class BlockCompressionTest {
+    private static Stream<BlockCompressionFactory> compressCodecGenerator() {
+        return Stream.of(
+                BlockCompressionFactory.createBlockCompressionFactory("LZ4"),
+                BlockCompressionFactory.createBlockCompressionFactory("LZO"),
+                BlockCompressionFactory.createBlockCompressionFactory("ZSTD"));
+    }
 
-    @Test
-    void testLz4() {
-        BlockCompressionFactory factory = new Lz4BlockCompressionFactory();
+    @ParameterizedTest
+    @MethodSource("compressCodecGenerator")
+    void testBlockCompression(BlockCompressionFactory factory) {
         runArrayTest(factory, 32768);
         runArrayTest(factory, 16);
 
@@ -63,7 +71,7 @@ class BlockCompressionTest {
                                         originalLen,
                                         insufficientCompressArray,
                                         compressedOff))
-                .isInstanceOf(InsufficientBufferException.class);
+                .isInstanceOf(BufferCompressionException.class);
 
         // 2. test normal compress
         byte[] compressedData =
@@ -83,7 +91,7 @@ class BlockCompressionTest {
                                         compressedLen,
                                         insufficientDecompressArray,
                                         decompressedOff))
-                .isInstanceOf(InsufficientBufferException.class);
+                .isInstanceOf(BufferDecompressionException.class);
 
         // 4. test normal decompress
         byte[] decompressedData = new byte[decompressedOff + originalLen];
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferCompressionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferCompressionTest.java
index c7dbe04ff5e..d269151205b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferCompressionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferCompressionTest.java
@@ -68,6 +68,18 @@ public class BufferCompressionTest {
                     {false, "LZ4", true, false},
                     {false, "LZ4", false, true},
                     {false, "LZ4", false, false},
+                    {true, "ZSTD", true, false},
+                    {true, "ZSTD", false, true},
+                    {true, "ZSTD", false, false},
+                    {false, "ZSTD", true, false},
+                    {false, "ZSTD", false, true},
+                    {false, "ZSTD", false, false},
+                    {true, "LZO", true, false},
+                    {true, "LZO", false, true},
+                    {true, "LZO", false, false},
+                    {false, "LZO", true, false},
+                    {false, "LZO", false, true},
+                    {false, "LZO", false, false}
                 });
     }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
index 8ad892b1d7a..1ea783d3e34 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
@@ -58,6 +58,8 @@ import org.apache.flink.shaded.netty4.io.netty.channel.unix.Errors;
 
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 
 import java.io.IOException;
 
@@ -197,10 +199,10 @@ class CreditBasedPartitionRequestClientHandlerTest {
     /**
      * Verifies that {@link BufferResponse} of compressed {@link Buffer} can be handled correctly.
      */
-    @Test
-    void testReceiveCompressedBuffer() throws Exception {
+    @ParameterizedTest
+    @ValueSource(strings = {"LZ4", "LZO", "ZSTD"})
+    void testReceiveCompressedBuffer(final String compressionCodec) throws Exception {
         int bufferSize = 1024;
-        String compressionCodec = "LZ4";
         BufferCompressor compressor = new BufferCompressor(bufferSize, compressionCodec);
         BufferDecompressor decompressor = new BufferDecompressor(bufferSize, compressionCodec);
         NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, bufferSize);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java
index 925a1444c32..b6927e5c1ec 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageClientSideSerializationTest.java
@@ -38,6 +38,8 @@ import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 
 import java.io.IOException;
 import java.util.Random;
@@ -63,12 +65,10 @@ class NettyMessageClientSideSerializationTest {
 
     private static final int BUFFER_SIZE = 1024;
 
-    private static final BufferCompressor COMPRESSOR = new BufferCompressor(BUFFER_SIZE, "LZ4");
-
-    private static final BufferDecompressor DECOMPRESSOR =
-            new BufferDecompressor(BUFFER_SIZE, "LZ4");
-
     private final Random random = new Random();
+    private static BufferCompressor compressor;
+
+    private static BufferDecompressor decompressor;
 
     private EmbeddedChannel channel;
 
@@ -143,8 +143,11 @@ class NettyMessageClientSideSerializationTest {
         testBufferResponse(true, false);
     }
 
-    @Test
-    void testCompressedBufferResponse() {
+    @ParameterizedTest
+    @ValueSource(strings = {"LZ4", "LZO", "ZSTD"})
+    void testCompressedBufferResponse(final String codecFactoryName) {
+        compressor = new BufferCompressor(BUFFER_SIZE, codecFactoryName);
+        decompressor = new BufferDecompressor(BUFFER_SIZE, codecFactoryName);
         testBufferResponse(false, true);
     }
 
@@ -178,7 +181,7 @@ class NettyMessageClientSideSerializationTest {
         if (testReadOnlyBuffer) {
             testBuffer = buffer.readOnlySlice();
         } else if (testCompressedBuffer) {
-            testBuffer = COMPRESSOR.compressToOriginalBuffer(buffer);
+            testBuffer = compressor.compressToOriginalBuffer(buffer);
         }
 
         BufferResponse expected =
@@ -221,6 +224,6 @@ class NettyMessageClientSideSerializationTest {
         Buffer compressedBuffer = new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE);
         buffer.asByteBuf().readBytes(compressedBuffer.asByteBuf(), buffer.readableBytes());
         compressedBuffer.setCompressed(true);
-        return DECOMPRESSOR.decompressToOriginalBuffer(compressedBuffer);
+        return decompressor.decompressToOriginalBuffer(compressedBuffer);
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 44249a57bfa..c5757cba4b2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -73,6 +73,8 @@ import org.apache.flink.util.CompressedSerializedValue;
 import org.apache.flink.shaded.guava30.com.google.common.io.Closer;
 
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -326,10 +328,10 @@ public class SingleInputGateTest extends InputGateTestBase {
      * Tests that the compressed buffer will be decompressed after calling {@link
      * SingleInputGate#getNext()}.
      */
-    @Test
-    void testGetCompressedBuffer() throws Exception {
+    @ParameterizedTest
+    @ValueSource(strings = {"LZ4", "LZO", "ZSTD"})
+    void testGetCompressedBuffer(final String compressionCodec) throws Exception {
         int bufferSize = 1024;
-        String compressionCodec = "LZ4";
         BufferCompressor compressor = new BufferCompressor(bufferSize, compressionCodec);
         BufferDecompressor decompressor = new BufferDecompressor(bufferSize, compressionCodec);
 
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/io/CompressedHeaderlessChannelTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/io/CompressedHeaderlessChannelTest.java
index 916f560d33e..bb1ba9beba9 100644
--- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/io/CompressedHeaderlessChannelTest.java
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/io/CompressedHeaderlessChannelTest.java
@@ -19,7 +19,6 @@
 package org.apache.flink.table.runtime.io;
 
 import org.apache.flink.runtime.io.compression.BlockCompressionFactory;
-import org.apache.flink.runtime.io.compression.Lz4BlockCompressionFactory;
 import org.apache.flink.runtime.io.disk.iomanager.BufferFileWriter;
 import org.apache.flink.runtime.io.disk.iomanager.FileIOChannel;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -27,6 +26,8 @@ import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
 
 import org.junit.After;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import java.io.IOException;
 import java.util.Random;
@@ -37,12 +38,22 @@ import static org.assertj.core.api.Assertions.assertThat;
  * Tests for {@link CompressedHeaderlessChannelReaderInputView} and {@link
  * CompressedHeaderlessChannelWriterOutputView}.
  */
+@RunWith(Parameterized.class)
 public class CompressedHeaderlessChannelTest {
     private static final int BUFFER_SIZE = 256;
 
     private IOManager ioManager;
 
-    private BlockCompressionFactory compressionFactory = new Lz4BlockCompressionFactory();
+    @Parameterized.Parameter public static BlockCompressionFactory compressionFactory;
+
+    @Parameterized.Parameters(name = "compressionFactory = {0}")
+    public static BlockCompressionFactory[] compressionFactory() {
+        return new BlockCompressionFactory[] {
+            BlockCompressionFactory.createBlockCompressionFactory("LZ4"),
+            BlockCompressionFactory.createBlockCompressionFactory("LZO"),
+            BlockCompressionFactory.createBlockCompressionFactory("ZSTD")
+        };
+    }
 
     public CompressedHeaderlessChannelTest() {
         ioManager = new IOManagerAsync();