You are viewing a plain text version of this content. The link to the canonical version ("here") was not preserved in this plain-text copy.
Posted to commits@flink.apache.org by ma...@apache.org on 2022/12/09 16:08:28 UTC

[flink] branch master updated: [FLINK-30165][runtime][JUnit5 Migration] Migrate unaligned checkpoint related tests under flink-runtime module to junit5 (#21368)

This is an automated email from the ASF dual-hosted git repository.

mapohl pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f924bc8422 [FLINK-30165][runtime][JUnit5 Migration] Migrate unaligned checkpoint related tests under flink-runtime module to junit5 (#21368)
5f924bc8422 is described below

commit 5f924bc84227a3a6c67b44e82c45fe444393f577
Author: 1996fanrui <19...@gmail.com>
AuthorDate: Sat Dec 10 00:08:05 2022 +0800

    [FLINK-30165][runtime][JUnit5 Migration] Migrate unaligned checkpoint related tests under flink-runtime module to junit5 (#21368)
---
 .../channel/ChannelStateCheckpointWriterTest.java  |  84 +++---
 .../channel/ChannelStateChunkReaderTest.java       |  54 ++--
 .../channel/ChannelStateSerializerImplTest.java    |  36 ++-
 ...ChannelStateWriteRequestDispatcherImplTest.java |  25 +-
 .../ChannelStateWriteRequestDispatcherTest.java    | 111 ++++----
 .../ChannelStateWriteRequestExecutorImplTest.java  |  57 ++--
 .../channel/ChannelStateWriterImplTest.java        | 303 ++++++++++-----------
 .../channel/CheckpointInProgressRequestTest.java   |  12 +-
 .../InputChannelRecoveredStateHandlerTest.java     |  32 ++-
 .../channel/RecordingChannelStateWriter.java       |  11 +-
 .../channel/RecoveredChannelStateHandlerTest.java  |   8 +-
 ...esultSubpartitionRecoveredStateHandlerTest.java |  22 +-
 .../SequentialChannelStateReaderImplTest.java      |  95 +++----
 13 files changed, 420 insertions(+), 430 deletions(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
index 29cd71f0ff0..d135cbdd500 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
@@ -28,16 +28,17 @@ import org.apache.flink.runtime.state.InputChannelStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsCheckpointStreamFactory;
 import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory.MemoryCheckpointOutputStream;
+import org.apache.flink.testutils.junit.utils.TempDirUtils;
 import org.apache.flink.util.function.RunnableWithException;
 
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
 
 import java.io.ByteArrayOutputStream;
 import java.io.DataOutputStream;
 import java.io.File;
 import java.io.IOException;
+import java.nio.file.Path;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Random;
@@ -48,20 +49,18 @@ import static org.apache.flink.core.fs.Path.fromLocalFile;
 import static org.apache.flink.core.fs.local.LocalFileSystem.getSharedInstance;
 import static org.apache.flink.core.memory.MemorySegmentFactory.wrap;
 import static org.apache.flink.runtime.state.CheckpointedStateScope.EXCLUSIVE;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.fail;
 
 /** {@link ChannelStateCheckpointWriter} test. */
-public class ChannelStateCheckpointWriterTest {
+class ChannelStateCheckpointWriterTest {
     private static final RunnableWithException NO_OP_RUNNABLE = () -> {};
     private final Random random = new Random();
 
-    @Rule public final TemporaryFolder temporaryFolder = new TemporaryFolder();
+    @TempDir private Path temporaryFolder;
 
     @Test
-    public void testFileHandleSize() throws Exception {
+    void testFileHandleSize() throws Exception {
         int numChannels = 3;
         int numWritesPerChannel = 4;
         int numBytesPerWrite = 5;
@@ -71,8 +70,12 @@ public class ChannelStateCheckpointWriterTest {
                         result,
                         new FsCheckpointStreamFactory(
                                         getSharedInstance(),
-                                        fromLocalFile(temporaryFolder.newFolder("checkpointsDir")),
-                                        fromLocalFile(temporaryFolder.newFolder("sharedStateDir")),
+                                        fromLocalFile(
+                                                TempDirUtils.newFolder(
+                                                        temporaryFolder, "checkpointsDir")),
+                                        fromLocalFile(
+                                                TempDirUtils.newFolder(
+                                                        temporaryFolder, "sharedStateDir")),
                                         numBytesPerWrite - 1,
                                         numBytesPerWrite - 1)
                                 .createCheckpointStateOutputStream(EXCLUSIVE));
@@ -90,18 +93,17 @@ public class ChannelStateCheckpointWriterTest {
         writer.completeOutput();
 
         for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) {
-            assertEquals(
-                    (Integer.BYTES + numBytesPerWrite) * numWritesPerChannel,
-                    handle.getStateSize());
+            assertThat(handle.getStateSize())
+                    .isEqualTo((Integer.BYTES + numBytesPerWrite) * numWritesPerChannel);
         }
     }
 
     @Test
     @SuppressWarnings("ConstantConditions")
-    public void testSmallFilesNotWritten() throws Exception {
+    void testSmallFilesNotWritten() throws Exception {
         int threshold = 100;
-        File checkpointsDir = temporaryFolder.newFolder("checkpointsDir");
-        File sharedStateDir = temporaryFolder.newFolder("sharedStateDir");
+        File checkpointsDir = TempDirUtils.newFolder(temporaryFolder, "checkpointsDir");
+        File sharedStateDir = TempDirUtils.newFolder(temporaryFolder, "sharedStateDir");
         FsCheckpointStreamFactory checkpointStreamFactory =
                 new FsCheckpointStreamFactory(
                         getSharedInstance(),
@@ -121,13 +123,13 @@ public class ChannelStateCheckpointWriterTest {
         writer.writeInput(new InputChannelInfo(1, 2), buffer);
         writer.completeOutput();
         writer.completeInput();
-        assertTrue(result.isDone());
-        assertEquals(0, checkpointsDir.list().length);
-        assertEquals(0, sharedStateDir.list().length);
+        assertThat(result.isDone()).isTrue();
+        assertThat(checkpointsDir).isEmptyDirectory();
+        assertThat(sharedStateDir).isEmptyDirectory();
     }
 
     @Test
-    public void testEmptyState() throws Exception {
+    void testEmptyState() throws Exception {
         MemoryCheckpointOutputStream stream =
                 new MemoryCheckpointOutputStream(1000) {
                     @Override
@@ -139,22 +141,22 @@ public class ChannelStateCheckpointWriterTest {
         ChannelStateCheckpointWriter writer = createWriter(new ChannelStateWriteResult(), stream);
         writer.completeOutput();
         writer.completeInput();
-        assertTrue(stream.isClosed());
+        assertThat(stream.isClosed()).isTrue();
     }
 
     @Test
-    public void testRecyclingBuffers() throws Exception {
+    void testRecyclingBuffers() {
         ChannelStateCheckpointWriter writer = createWriter(new ChannelStateWriteResult());
         NetworkBuffer buffer =
                 new NetworkBuffer(
                         MemorySegmentFactory.allocateUnpooledSegment(10, null),
                         FreeingBufferRecycler.INSTANCE);
         writer.writeInput(new InputChannelInfo(1, 2), buffer);
-        assertTrue(buffer.isRecycled());
+        assertThat(buffer.isRecycled()).isTrue();
     }
 
     @Test
-    public void testFlush() throws Exception {
+    void testFlush() throws Exception {
         class FlushRecorder extends DataOutputStream {
             private boolean flushed = false;
 
@@ -184,21 +186,21 @@ public class ChannelStateCheckpointWriterTest {
         writer.completeInput();
         writer.completeOutput();
 
-        assertTrue(dataStream.flushed);
+        assertThat(dataStream.flushed).isTrue();
     }
 
     @Test
-    public void testResultCompletion() throws Exception {
+    void testResultCompletion() throws Exception {
         ChannelStateWriteResult result = new ChannelStateWriteResult();
         ChannelStateCheckpointWriter writer = createWriter(result);
         writer.completeInput();
-        assertFalse(result.isDone());
+        assertThat(result.isDone()).isFalse();
         writer.completeOutput();
-        assertTrue(result.isDone());
+        assertThat(result.isDone()).isTrue();
     }
 
     @Test
-    public void testRecordingOffsets() throws Exception {
+    void testRecordingOffsets() throws Exception {
         Map<InputChannelInfo, Integer> offsetCounts = new HashMap<>();
         offsetCounts.put(new InputChannelInfo(1, 1), 1);
         offsetCounts.put(new InputChannelInfo(1, 2), 2);
@@ -218,12 +220,14 @@ public class ChannelStateCheckpointWriterTest {
         for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) {
             int headerSize = Integer.BYTES;
             int lengthSize = Integer.BYTES;
-            assertEquals(singletonList((long) headerSize), handle.getOffsets());
-            assertEquals(
-                    headerSize + lengthSize + numBytes * offsetCounts.remove(handle.getInfo()),
-                    handle.getDelegate().getStateSize());
+            assertThat(handle.getOffsets()).isEqualTo(singletonList((long) headerSize));
+            assertThat(handle.getDelegate().getStateSize())
+                    .isEqualTo(
+                            headerSize
+                                    + lengthSize
+                                    + numBytes * offsetCounts.remove(handle.getInfo()));
         }
-        assertTrue(offsetCounts.isEmpty());
+        assertThat(offsetCounts).isEmpty();
     }
 
     private byte[] getData(int len) {
@@ -233,8 +237,7 @@ public class ChannelStateCheckpointWriterTest {
     }
 
     private void write(
-            ChannelStateCheckpointWriter writer, InputChannelInfo channelInfo, byte[] data)
-            throws Exception {
+            ChannelStateCheckpointWriter writer, InputChannelInfo channelInfo, byte[] data) {
         MemorySegment segment = wrap(data);
         NetworkBuffer buffer =
                 new NetworkBuffer(
@@ -245,13 +248,12 @@ public class ChannelStateCheckpointWriterTest {
         writer.writeInput(channelInfo, buffer);
     }
 
-    private ChannelStateCheckpointWriter createWriter(ChannelStateWriteResult result)
-            throws Exception {
+    private ChannelStateCheckpointWriter createWriter(ChannelStateWriteResult result) {
         return createWriter(result, new MemoryCheckpointOutputStream(1000));
     }
 
     private ChannelStateCheckpointWriter createWriter(
-            ChannelStateWriteResult result, CheckpointStateOutputStream stream) throws Exception {
+            ChannelStateWriteResult result, CheckpointStateOutputStream stream) {
         return new ChannelStateCheckpointWriter(
                 "dummy task",
                 0,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateChunkReaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateChunkReaderTest.java
index 0e88489708e..5d58011a3d9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateChunkReaderTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateChunkReaderTest.java
@@ -24,7 +24,7 @@ import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.io.ByteArrayOutputStream;
 import java.io.DataOutputStream;
@@ -35,31 +35,39 @@ import java.util.List;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkState;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.Assertions.fail;
 
 /** {@link ChannelStateChunkReader} test. */
-public class ChannelStateChunkReaderTest {
+class ChannelStateChunkReaderTest {
 
-    @Test(expected = TestException.class)
-    public void testBufferRecycledOnFailure() throws IOException, InterruptedException {
+    @Test
+    void testBufferRecycledOnFailure() {
         FailingChannelStateSerializer serializer = new FailingChannelStateSerializer();
         TestRecoveredChannelStateHandler handler = new TestRecoveredChannelStateHandler();
 
-        try (FSDataInputStream stream = getStream(serializer, 10)) {
-            new ChannelStateChunkReader(serializer)
-                    .readChunk(stream, serializer.getHeaderLength(), handler, "channelInfo", 0);
-        } finally {
-            checkState(serializer.failed);
-            checkState(!handler.requestedBuffers.isEmpty());
-            assertTrue(
-                    handler.requestedBuffers.stream()
-                            .allMatch(TestChannelStateByteBuffer::isRecycled));
-        }
+        assertThatThrownBy(
+                        () -> {
+                            try (FSDataInputStream stream = getStream(serializer, 10)) {
+                                new ChannelStateChunkReader(serializer)
+                                        .readChunk(
+                                                stream,
+                                                serializer.getHeaderLength(),
+                                                handler,
+                                                "channelInfo",
+                                                0);
+                            } finally {
+                                checkState(serializer.failed);
+                                checkState(!handler.requestedBuffers.isEmpty());
+                            }
+                        })
+                .isInstanceOf(TestException.class);
+        assertThat(handler.requestedBuffers).allMatch(TestChannelStateByteBuffer::isRecycled);
     }
 
     @Test
-    public void testBufferRecycledOnSuccess() throws IOException, InterruptedException {
+    void testBufferRecycledOnSuccess() throws IOException, InterruptedException {
         ChannelStateSerializer serializer = new ChannelStateSerializerImpl();
         TestRecoveredChannelStateHandler handler = new TestRecoveredChannelStateHandler();
 
@@ -68,14 +76,12 @@ public class ChannelStateChunkReaderTest {
                     .readChunk(stream, serializer.getHeaderLength(), handler, "channelInfo", 0);
         } finally {
             checkState(!handler.requestedBuffers.isEmpty());
-            assertTrue(
-                    handler.requestedBuffers.stream()
-                            .allMatch(TestChannelStateByteBuffer::isRecycled));
+            assertThat(handler.requestedBuffers).allMatch(TestChannelStateByteBuffer::isRecycled);
         }
     }
 
     @Test
-    public void testBuffersNotRequestedForEmptyStream() throws IOException, InterruptedException {
+    void testBuffersNotRequestedForEmptyStream() throws IOException, InterruptedException {
         ChannelStateSerializer serializer = new ChannelStateSerializerImpl();
         TestRecoveredChannelStateHandler handler = new TestRecoveredChannelStateHandler();
 
@@ -83,12 +89,12 @@ public class ChannelStateChunkReaderTest {
             new ChannelStateChunkReader(serializer)
                     .readChunk(stream, serializer.getHeaderLength(), handler, "channelInfo", 0);
         } finally {
-            assertTrue(handler.requestedBuffers.isEmpty());
+            assertThat(handler.requestedBuffers).isEmpty();
         }
     }
 
     @Test
-    public void testNoSeekUnnecessarily() throws IOException, InterruptedException {
+    void testNoSeekUnnecessarily() throws IOException, InterruptedException {
         final int offset = 123;
         final FSDataInputStream stream =
                 new FSDataInputStream() {
@@ -99,7 +105,7 @@ public class ChannelStateChunkReaderTest {
 
                     @Override
                     public void seek(long ignored) {
-                        fail();
+                        fail("It shouldn't be called.");
                     }
 
                     @Override
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializerImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializerImplTest.java
index 3590dbb5d01..6d6688919f2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializerImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializerImplTest.java
@@ -24,7 +24,7 @@ import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
@@ -35,17 +35,13 @@ import java.util.Arrays;
 import java.util.Random;
 
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateByteBuffer.wrap;
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** {@link ChannelStateSerializerImpl} test. */
-public class ChannelStateSerializerImplTest {
-
-    private final Random random = new Random();
+class ChannelStateSerializerImplTest {
 
     @Test
-    public void testReadWrite() throws IOException {
+    void testReadWrite() throws IOException {
         byte[] data = generateData(123);
         ChannelStateSerializerImpl serializer = new ChannelStateSerializerImpl();
         try (ByteArrayOutputStream baos = new ByteArrayOutputStream(data.length)) {
@@ -55,7 +51,7 @@ public class ChannelStateSerializerImplTest {
     }
 
     @Test
-    public void testReadWriteWithMultipleBuffers() throws IOException {
+    void testReadWriteWithMultipleBuffers() throws IOException {
         int bufSize = 10;
         int[] numBuffersToWriteAtOnce = {0, 1, 2, 3};
         byte[] data = generateData(bufSize);
@@ -75,18 +71,18 @@ public class ChannelStateSerializerImplTest {
         d.readHeader(is);
         for (int count : numBuffersToWriteAtOnce) {
             int expected = bufSize * count;
-            assertEquals(expected, d.readLength(is));
+            assertThat(d.readLength(is)).isEqualTo(expected);
             byte[] readBuf = new byte[expected];
-            assertEquals(expected, d.readData(is, wrap(readBuf), Integer.MAX_VALUE));
+            assertThat(d.readData(is, wrap(readBuf), Integer.MAX_VALUE)).isEqualTo(expected);
             for (int i = 0; i < count; i++) {
-                assertArrayEquals(
-                        data, Arrays.copyOfRange(readBuf, i * bufSize, (i + 1) * bufSize));
+                assertThat(Arrays.copyOfRange(readBuf, i * bufSize, (i + 1) * bufSize))
+                        .isEqualTo(data);
             }
         }
     }
 
     @Test
-    public void testReadToBufferBuilder() throws IOException {
+    void testReadToBufferBuilder() throws IOException {
         byte[] data = generateData(100);
         BufferBuilder bufferBuilder =
                 new BufferBuilder(
@@ -97,15 +93,15 @@ public class ChannelStateSerializerImplTest {
         new ChannelStateSerializerImpl()
                 .readData(new ByteArrayInputStream(data), wrap(bufferBuilder), Integer.MAX_VALUE);
 
-        assertFalse(bufferBuilder.isFinished());
+        assertThat(bufferBuilder.isFinished()).isFalse();
 
         bufferBuilder.finish();
         Buffer buffer = bufferConsumer.build();
 
-        assertEquals(data.length, buffer.readableBytes());
+        assertThat(buffer.readableBytes()).isEqualTo(data.length);
         byte[] actual = new byte[buffer.readableBytes()];
         buffer.asByteBuf().readBytes(actual);
-        assertArrayEquals(data, actual);
+        assertThat(actual).isEqualTo(data);
     }
 
     private NetworkBuffer getBuffer(byte[] data) {
@@ -145,15 +141,15 @@ public class ChannelStateSerializerImplTest {
             throws IOException {
         serializer.readHeader(is);
         int size = serializer.readLength(is);
-        assertEquals(data.length, size);
+        assertThat(data).hasSize(size);
         NetworkBuffer buffer =
                 new NetworkBuffer(
                         MemorySegmentFactory.allocateUnpooledSegment(data.length),
                         FreeingBufferRecycler.INSTANCE);
         try {
             int read = serializer.readData(is, wrap(buffer), size);
-            assertEquals(size, read);
-            assertArrayEquals(data, readBytes(buffer));
+            assertThat(read).isEqualTo(size);
+            assertThat(readBytes(buffer)).isEqualTo(data);
         } finally {
             buffer.release();
         }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
index 0a52f32e3e0..3d80f290856 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
@@ -26,20 +26,19 @@ import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorageAccess;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.util.function.Function;
 
 import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
 import static org.apache.flink.util.CloseableIterator.ofElements;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** {@link ChannelStateWriteRequestDispatcherImpl} test. */
-public class ChannelStateWriteRequestDispatcherImplTest {
+class ChannelStateWriteRequestDispatcherImplTest {
 
     @Test
-    public void testPartialInputChannelStateWrite() throws Exception {
+    void testPartialInputChannelStateWrite() throws Exception {
         testBuffersRecycled(
                 buffers ->
                         ChannelStateWriteRequest.write(
@@ -49,7 +48,7 @@ public class ChannelStateWriteRequestDispatcherImplTest {
     }
 
     @Test
-    public void testPartialResultSubpartitionStateWrite() throws Exception {
+    void testPartialResultSubpartitionStateWrite() throws Exception {
         testBuffersRecycled(
                 buffers ->
                         ChannelStateWriteRequest.write(
@@ -57,7 +56,7 @@ public class ChannelStateWriteRequestDispatcherImplTest {
     }
 
     @Test
-    public void testConcurrentUnalignedCheckpoint() throws Exception {
+    void testConcurrentUnalignedCheckpoint() throws Exception {
         ChannelStateWriteRequestDispatcher processor =
                 new ChannelStateWriteRequestDispatcherImpl(
                         "dummy task",
@@ -68,16 +67,16 @@ public class ChannelStateWriteRequestDispatcherImplTest {
         processor.dispatch(
                 ChannelStateWriteRequest.start(
                         1L, result, CheckpointStorageLocationReference.getDefault()));
-        assertFalse(result.isDone());
+        assertThat(result.isDone()).isFalse();
 
         processor.dispatch(
                 ChannelStateWriteRequest.start(
                         2L,
                         new ChannelStateWriteResult(),
                         CheckpointStorageLocationReference.getDefault()));
-        assertTrue(result.isDone());
-        assertTrue(result.getInputChannelStateHandles().isCompletedExceptionally());
-        assertTrue(result.getResultSubpartitionStateHandles().isCompletedExceptionally());
+        assertThat(result.isDone()).isTrue();
+        assertThat(result.getInputChannelStateHandles()).isCompletedExceptionally();
+        assertThat(result.getResultSubpartitionStateHandles()).isCompletedExceptionally();
     }
 
     private void testBuffersRecycled(
@@ -98,9 +97,7 @@ public class ChannelStateWriteRequestDispatcherImplTest {
 
         NetworkBuffer[] buffers = new NetworkBuffer[] {buffer(), buffer()};
         dispatcher.dispatch(requestBuilder.apply(buffers));
-        for (NetworkBuffer buffer : buffers) {
-            assertTrue(buffer.isRecycled());
-        }
+        assertThat(buffers).allMatch(NetworkBuffer::isRecycled);
     }
 
     private NetworkBuffer buffer() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
index f425d057c86..4b59e1235e6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
@@ -23,13 +23,15 @@ import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
+import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
 import org.apache.flink.util.CloseableIterator;
 
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
-import org.junit.runners.Parameterized.Parameters;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
@@ -42,56 +44,59 @@ import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteReque
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeOutput;
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.write;
 import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.fail;
 
 /** {@link ChannelStateWriteRequestDispatcherImpl} tests. */
-@RunWith(Parameterized.class)
+@ExtendWith(ParameterizedTestExtension.class)
 public class ChannelStateWriteRequestDispatcherTest {
 
-    private final List<ChannelStateWriteRequest> requests;
-    private final Optional<Class<?>> expectedException;
-    public static final long CHECKPOINT_ID = 42L;
-
-    @Parameters
-    public static Object[][] data() {
-
-        return new Object[][] {
-            // valid calls
-            new Object[] {empty(), asList(start(), completeIn(), completeOut())},
-            new Object[] {empty(), asList(start(), writeIn(), completeIn())},
-            new Object[] {empty(), asList(start(), writeOut(), completeOut())},
-            new Object[] {empty(), asList(start(), writeOutFuture(), completeOut())},
-            new Object[] {empty(), asList(start(), completeIn(), writeOut())},
-            new Object[] {empty(), asList(start(), completeIn(), writeOutFuture())},
-            new Object[] {empty(), asList(start(), completeOut(), writeIn())},
-            // invalid without start
-            new Object[] {of(IllegalArgumentException.class), singletonList(writeIn())},
-            new Object[] {of(IllegalArgumentException.class), singletonList(writeOut())},
-            new Object[] {of(IllegalArgumentException.class), singletonList(writeOutFuture())},
-            new Object[] {of(IllegalArgumentException.class), singletonList(completeIn())},
-            new Object[] {of(IllegalArgumentException.class), singletonList(completeOut())},
-            // invalid double complete
-            new Object[] {
-                of(IllegalArgumentException.class), asList(start(), completeIn(), completeIn())
-            },
-            new Object[] {
-                of(IllegalArgumentException.class), asList(start(), completeOut(), completeOut())
-            },
-            // invalid write after complete
-            new Object[] {
-                of(IllegalStateException.class), asList(start(), completeIn(), writeIn())
-            },
-            new Object[] {
-                of(IllegalStateException.class), asList(start(), completeOut(), writeOut())
-            },
-            new Object[] {
-                of(IllegalStateException.class), asList(start(), completeOut(), writeOutFuture())
-            },
-            // invalid double start
-            new Object[] {of(IllegalStateException.class), asList(start(), start())}
-        };
+    @Parameters(name = "expectedException={0} requests={1}")
+    public static List<Object[]> data() {
+        return Arrays.asList(
+                // valid calls
+                new Object[] {empty(), asList(start(), completeIn(), completeOut())},
+                new Object[] {empty(), asList(start(), writeIn(), completeIn())},
+                new Object[] {empty(), asList(start(), writeOut(), completeOut())},
+                new Object[] {empty(), asList(start(), writeOutFuture(), completeOut())},
+                new Object[] {empty(), asList(start(), completeIn(), writeOut())},
+                new Object[] {empty(), asList(start(), completeIn(), writeOutFuture())},
+                new Object[] {empty(), asList(start(), completeOut(), writeIn())},
+                // invalid without start
+                new Object[] {of(IllegalArgumentException.class), singletonList(writeIn())},
+                new Object[] {of(IllegalArgumentException.class), singletonList(writeOut())},
+                new Object[] {of(IllegalArgumentException.class), singletonList(writeOutFuture())},
+                new Object[] {of(IllegalArgumentException.class), singletonList(completeIn())},
+                new Object[] {of(IllegalArgumentException.class), singletonList(completeOut())},
+                // invalid double complete
+                new Object[] {
+                    of(IllegalArgumentException.class), asList(start(), completeIn(), completeIn())
+                },
+                new Object[] {
+                    of(IllegalArgumentException.class),
+                    asList(start(), completeOut(), completeOut())
+                },
+                // invalid write after complete
+                new Object[] {
+                    of(IllegalStateException.class), asList(start(), completeIn(), writeIn())
+                },
+                new Object[] {
+                    of(IllegalStateException.class), asList(start(), completeOut(), writeOut())
+                },
+                new Object[] {
+                    of(IllegalStateException.class),
+                    asList(start(), completeOut(), writeOutFuture())
+                },
+                // invalid double start
+                new Object[] {of(IllegalStateException.class), asList(start(), start())});
     }
 
+    @Parameter public Optional<Class<Exception>> expectedException;
+
+    @Parameter(value = 1)
+    public List<ChannelStateWriteRequest> requests;
+
+    private static final long CHECKPOINT_ID = 42L;
+
     private static CheckpointInProgressRequest completeOut() {
         return completeOutput(CHECKPOINT_ID);
     }
@@ -139,14 +144,8 @@ public class ChannelStateWriteRequestDispatcherTest {
                 new CheckpointStorageLocationReference(new byte[] {1}));
     }
 
-    public ChannelStateWriteRequestDispatcherTest(
-            Optional<Class<?>> expectedException, List<ChannelStateWriteRequest> requests) {
-        this.requests = requests;
-        this.expectedException = expectedException;
-    }
-
-    @Test
-    public void doRun() {
+    @TestTemplate
+    void doRun() {
         ChannelStateWriteRequestDispatcher processor =
                 new ChannelStateWriteRequestDispatcherImpl(
                         "dummy task",
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
index 1d199985d96..b7aa7afb682 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
@@ -21,44 +21,49 @@ import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.function.BiConsumerWithException;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import javax.annotation.Nonnull;
 
 import java.io.IOException;
-import java.util.Arrays;
+import java.util.Collections;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.LinkedBlockingDeque;
 
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestDispatcher.NO_OP;
 import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
 import static org.apache.flink.util.ExceptionUtils.findThrowable;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.Assertions.fail;
 
 /** {@link ChannelStateWriteRequestExecutorImpl} test. */
-public class ChannelStateWriteRequestExecutorImplTest {
+class ChannelStateWriteRequestExecutorImplTest {
 
     private static final String TASK_NAME = "test task";
 
-    @Test(expected = IllegalStateException.class)
-    public void testCloseAfterSubmit() throws Exception {
-        testCloseAfterSubmit(ChannelStateWriteRequestExecutor::submit);
+    @Test
+    void testCloseAfterSubmit() {
+        assertThatThrownBy(() -> testCloseAfterSubmit(ChannelStateWriteRequestExecutor::submit))
+                .isInstanceOf(IllegalStateException.class);
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testCloseAfterSubmitPriority() throws Exception {
-        testCloseAfterSubmit(ChannelStateWriteRequestExecutor::submitPriority);
+    @Test
+    void testCloseAfterSubmitPriority() {
+        assertThatThrownBy(
+                        () ->
+                                testCloseAfterSubmit(
+                                        ChannelStateWriteRequestExecutor::submitPriority))
+                .isInstanceOf(IllegalStateException.class);
     }
 
     @Test
-    public void testSubmitFailure() throws Exception {
+    void testSubmitFailure() throws Exception {
         testSubmitFailure(ChannelStateWriteRequestExecutor::submit);
     }
 
     @Test
-    public void testSubmitPriorityFailure() throws Exception {
+    void testSubmitPriorityFailure() throws Exception {
         testSubmitFailure(ChannelStateWriteRequestExecutor::submitPriority);
     }
 
@@ -73,8 +78,8 @@ public class ChannelStateWriteRequestExecutorImplTest {
         closingDeque.setWorker(worker);
         TestWriteRequest request = new TestWriteRequest();
         requestFun.accept(worker, request);
-        assertTrue(closingDeque.isEmpty());
-        assertFalse(request.isCancelled());
+        assertThat(closingDeque).isEmpty();
+        assertThat(request.isCancelled()).isFalse();
     }
 
     private void testSubmitFailure(
@@ -91,15 +96,15 @@ public class ChannelStateWriteRequestExecutorImplTest {
             // expected: executor not started;
             return;
         } finally {
-            assertTrue(request.cancelled);
-            assertTrue(deque.isEmpty());
+            assertThat(request.cancelled).isTrue();
+            assertThat(deque).isEmpty();
         }
         throw new RuntimeException("expected exception not thrown");
     }
 
     @Test
     @SuppressWarnings("CallToThreadRun")
-    public void testCleanup() throws IOException {
+    void testCleanup() throws IOException {
         TestWriteRequest request = new TestWriteRequest();
         LinkedBlockingDeque<ChannelStateWriteRequest> deque = new LinkedBlockingDeque<>();
         deque.add(request);
@@ -110,13 +115,13 @@ public class ChannelStateWriteRequestExecutorImplTest {
         worker.close();
         worker.run();
 
-        assertTrue(requestProcessor.isStopped());
-        assertTrue(deque.isEmpty());
-        assertTrue(request.isCancelled());
+        assertThat(requestProcessor.isStopped()).isTrue();
+        assertThat(deque).isEmpty();
+        assertThat(request.isCancelled()).isTrue();
     }
 
     @Test
-    public void testIgnoresInterruptsWhileRunning() throws Exception {
+    void testIgnoresInterruptsWhileRunning() throws Exception {
         TestRequestDispatcher requestProcessor = new TestRequestDispatcher();
         LinkedBlockingDeque<ChannelStateWriteRequest> deque = new LinkedBlockingDeque<>();
         try (ChannelStateWriteRequestExecutorImpl worker =
@@ -132,7 +137,7 @@ public class ChannelStateWriteRequestExecutorImplTest {
     }
 
     @Test
-    public void testCanBeClosed() throws Exception {
+    void testCanBeClosed() throws Exception {
         long checkpointId = 1L;
         ChannelStateWriteRequestDispatcher processor =
                 new ChannelStateWriteRequestDispatcherImpl(
@@ -162,7 +167,7 @@ public class ChannelStateWriteRequestExecutorImplTest {
     }
 
     @Test
-    public void testRecordsException() throws IOException {
+    void testRecordsException() throws IOException {
         TestException testException = new TestException();
         TestRequestDispatcher throwingRequestProcessor =
                 new TestRequestDispatcher() {
@@ -172,7 +177,7 @@ public class ChannelStateWriteRequestExecutorImplTest {
                     }
                 };
         LinkedBlockingDeque<ChannelStateWriteRequest> deque =
-                new LinkedBlockingDeque<>(Arrays.asList(new TestWriteRequest()));
+                new LinkedBlockingDeque<>(Collections.singletonList(new TestWriteRequest()));
         ChannelStateWriteRequestExecutorImpl worker =
                 new ChannelStateWriteRequestExecutorImpl(
                         TASK_NAME, throwingRequestProcessor, deque);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
index 0aa76af9648..fcf9ad65f2a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
@@ -24,9 +24,8 @@ import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import org.apache.flink.util.function.BiConsumerWithException;
-import org.apache.flink.util.function.RunnableWithException;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.io.IOException;
 import java.util.ArrayDeque;
@@ -36,78 +35,74 @@ import java.util.function.Consumer;
 
 import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
 import static org.apache.flink.util.CloseableIterator.ofElements;
-import static org.apache.flink.util.ExceptionUtils.findThrowable;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** {@link ChannelStateWriterImpl} lifecycle tests. */
-public class ChannelStateWriterImplTest {
+class ChannelStateWriterImplTest {
     private static final long CHECKPOINT_ID = 42L;
     private static final String TASK_NAME = "test";
 
-    @Test(expected = IllegalArgumentException.class)
-    public void testAddEventBuffer() throws Exception {
+    @Test
+    void testAddEventBuffer() throws Exception {
 
         NetworkBuffer dataBuf = getBuffer();
         NetworkBuffer eventBuf = getBuffer();
         eventBuf.setDataType(Buffer.DataType.EVENT_BUFFER);
-        try {
-            runWithSyncWorker(
-                    writer -> {
-                        callStart(writer);
-                        writer.addInputData(
-                                CHECKPOINT_ID,
-                                new InputChannelInfo(1, 1),
-                                1,
-                                ofElements(Buffer::recycleBuffer, eventBuf, dataBuf));
-                    });
-        } finally {
-            assertTrue(dataBuf.isRecycled());
-        }
+
+        executeCallbackWithSyncWorker(
+                (writer, worker) -> {
+                    callStart(writer);
+                    callAddInputData(writer, eventBuf, dataBuf);
+                    assertThatThrownBy(worker::processAllRequests)
+                            .isInstanceOf(IllegalArgumentException.class);
+                });
+        assertThat(dataBuf.isRecycled()).isTrue();
     }
 
     @Test
-    public void testResultCompletion() throws IOException {
+    void testResultCompletion() throws IOException {
         ChannelStateWriteResult result;
         try (ChannelStateWriterImpl writer = openWriter()) {
             callStart(writer);
             result = writer.getAndRemoveWriteResult(CHECKPOINT_ID);
-            assertFalse(result.resultSubpartitionStateHandles.isDone());
-            assertFalse(result.inputChannelStateHandles.isDone());
+            assertThat(result.resultSubpartitionStateHandles).isNotDone();
+            assertThat(result.inputChannelStateHandles).isNotDone();
         }
-        assertTrue(result.inputChannelStateHandles.isDone());
-        assertTrue(result.resultSubpartitionStateHandles.isDone());
+        assertThat(result.inputChannelStateHandles).isDone();
+        assertThat(result.resultSubpartitionStateHandles).isDone();
     }
 
     @Test
-    public void testAbort() throws Exception {
+    void testAbort() throws Exception {
         NetworkBuffer buffer = getBuffer();
-        runWithSyncWorker(
+        executeCallbackWithSyncWorker(
                 (writer, worker) -> {
                     callStart(writer);
                     ChannelStateWriteResult result = writer.getAndRemoveWriteResult(CHECKPOINT_ID);
                     callAddInputData(writer, buffer);
                     callAbort(writer);
                     worker.processAllRequests();
-                    assertTrue(result.isDone());
-                    assertTrue(buffer.isRecycled());
+                    assertThat(result.isDone()).isTrue();
+                    assertThat(buffer.isRecycled()).isTrue();
                 });
     }
 
-    @Test(expected = IllegalArgumentException.class)
-    public void testAbortClearsResults() throws Exception {
-        runWithSyncWorker(
+    @Test
+    void testAbortClearsResults() throws Exception {
+        executeCallbackWithSyncWorker(
                 (writer, worker) -> {
                     callStart(writer);
                     writer.abort(CHECKPOINT_ID, new TestException(), true);
-                    writer.getAndRemoveWriteResult(CHECKPOINT_ID);
+
+                    assertThatThrownBy(() -> writer.getAndRemoveWriteResult(CHECKPOINT_ID))
+                            .isInstanceOf(IllegalArgumentException.class);
                 });
     }
 
     @Test
-    public void testAbortDoesNotClearsResults() throws Exception {
-        runWithSyncWorker(
+    void testAbortDoesNotClearsResults() throws Exception {
+        executeCallbackWithSyncWorker(
                 (writer, worker) -> {
                     callStart(writer);
                     callAbort(writer);
@@ -117,13 +112,13 @@ public class ChannelStateWriterImplTest {
     }
 
     @Test
-    public void testAbortIgnoresMissing() throws Exception {
-        runWithSyncWorker(this::callAbort);
+    void testAbortIgnoresMissing() throws Exception {
+        executeCallbackAndProcessWithSyncWorker(this::callAbort);
     }
 
     @Test
-    public void testAbortOldAndStartNewCheckpoint() throws Exception {
-        runWithSyncWorker(
+    void testAbortOldAndStartNewCheckpoint() throws Exception {
+        executeCallbackWithSyncWorker(
                 (writer, worker) -> {
                     int checkpoint42 = 42;
                     int checkpoint43 = 43;
@@ -135,100 +130,96 @@ public class ChannelStateWriterImplTest {
                     worker.processAllRequests();
 
                     ChannelStateWriteResult result42 = writer.getAndRemoveWriteResult(checkpoint42);
-                    assertTrue(result42.isDone());
-                    try {
-                        result42.getInputChannelStateHandles().get();
-                        fail("The result should have failed.");
-                    } catch (Throwable throwable) {
-                        assertTrue(findThrowable(throwable, TestException.class).isPresent());
-                    }
+                    assertThat(result42.isDone()).isTrue();
+                    assertThatThrownBy(() -> result42.getInputChannelStateHandles().get())
+                            .as("The result should have failed.")
+                            .hasCauseInstanceOf(TestException.class);
 
                     ChannelStateWriteResult result43 = writer.getAndRemoveWriteResult(checkpoint43);
-                    assertFalse(result43.isDone());
+                    assertThat(result43.isDone()).isFalse();
                 });
     }
 
-    @Test(expected = TestException.class)
-    public void testBuffersRecycledOnError() throws Exception {
-        unwrappingError(
-                TestException.class,
-                () -> {
-                    NetworkBuffer buffer = getBuffer();
-                    try (ChannelStateWriterImpl writer =
-                            new ChannelStateWriterImpl(
-                                    TASK_NAME, new ConcurrentHashMap<>(), failingWorker(), 5)) {
-                        writer.open();
-                        callAddInputData(writer, buffer);
-                    } finally {
-                        assertTrue(buffer.isRecycled());
-                    }
-                });
+    @Test
+    void testBuffersRecycledOnError() throws IOException {
+        NetworkBuffer buffer = getBuffer();
+        try (ChannelStateWriterImpl writer =
+                new ChannelStateWriterImpl(
+                        TASK_NAME, new ConcurrentHashMap<>(), failingWorker(), 5)) {
+            writer.open();
+            assertThatThrownBy(() -> callAddInputData(writer, buffer))
+                    .isInstanceOf(RuntimeException.class)
+                    .hasCauseInstanceOf(TestException.class);
+            assertThat(buffer.isRecycled()).isTrue();
+        }
     }
 
     @Test
-    public void testBuffersRecycledOnClose() throws Exception {
+    void testBuffersRecycledOnClose() throws Exception {
         NetworkBuffer buffer = getBuffer();
-        runWithSyncWorker(
+        executeCallbackAndProcessWithSyncWorker(
                 writer -> {
                     callStart(writer);
                     callAddInputData(writer, buffer);
-                    assertFalse(buffer.isRecycled());
+                    assertThat(buffer.isRecycled()).isFalse();
                 });
-        assertTrue(buffer.isRecycled());
-    }
-
-    @Test(expected = IllegalArgumentException.class)
-    public void testNoAddDataAfterFinished() throws Exception {
-        unwrappingError(
-                IllegalArgumentException.class,
-                () ->
-                        runWithSyncWorker(
-                                writer -> {
-                                    callStart(writer);
-                                    callFinish(writer);
-                                    callAddInputData(writer);
-                                }));
-    }
-
-    @Test(expected = IllegalArgumentException.class)
-    public void testAddDataNotStarted() throws Exception {
-        unwrappingError(
-                IllegalArgumentException.class,
-                () -> runWithSyncWorker((Consumer<ChannelStateWriter>) this::callAddInputData));
-    }
-
-    @Test(expected = IllegalArgumentException.class)
-    public void testFinishNotStarted() throws Exception {
-        unwrappingError(IllegalArgumentException.class, () -> runWithSyncWorker(this::callFinish));
-    }
-
-    @Test(expected = IllegalArgumentException.class)
-    public void testRethrowOnClose() throws Exception {
-        unwrappingError(
-                IllegalArgumentException.class,
-                () ->
-                        runWithSyncWorker(
-                                writer -> {
-                                    try {
-                                        callFinish(writer);
-                                    } catch (IllegalArgumentException e) {
-                                        // ignore here - should rethrow in close
-                                    }
-                                }));
-    }
-
-    @Test(expected = TestException.class)
-    public void testRethrowOnNextCall() throws Exception {
+        assertThat(buffer.isRecycled()).isTrue();
+    }
+
+    @Test
+    void testNoAddDataAfterFinished() throws Exception {
+        executeCallbackWithSyncWorker(
+                (writer, worker) -> {
+                    callStart(writer);
+                    callFinish(writer);
+                    worker.processAllRequests();
+
+                    callAddInputData(writer);
+                    assertThatThrownBy(worker::processAllRequests)
+                            .isInstanceOf(IllegalArgumentException.class);
+                });
+    }
+
+    @Test
+    void testAddDataNotStarted() {
+        assertThatThrownBy(() -> executeCallbackAndProcessWithSyncWorker(this::callAddInputData))
+                .isInstanceOf(IllegalArgumentException.class);
+    }
+
+    @Test
+    void testFinishNotStarted() {
+        assertThatThrownBy(() -> executeCallbackAndProcessWithSyncWorker(this::callFinish))
+                .isInstanceOf(IllegalArgumentException.class);
+    }
+
+    @Test
+    void testRethrowOnClose() {
+        assertThatThrownBy(
+                        () ->
+                                executeCallbackAndProcessWithSyncWorker(
+                                        writer -> {
+                                            try {
+                                                callFinish(writer);
+                                            } catch (IllegalArgumentException e) {
+                                                // ignore here - should rethrow in
+                                                // close
+                                            }
+                                        }))
+                .isInstanceOf(IllegalArgumentException.class);
+    }
+
+    @Test
+    void testRethrowOnNextCall() {
         SyncChannelStateWriteRequestExecutor worker = new SyncChannelStateWriteRequestExecutor();
         ChannelStateWriterImpl writer =
                 new ChannelStateWriterImpl(TASK_NAME, new ConcurrentHashMap<>(), worker, 5);
         writer.open();
         worker.setThrown(new TestException());
-        unwrappingError(TestException.class, () -> callStart(writer));
+        assertThatThrownBy(() -> callStart(writer)).hasCauseInstanceOf(TestException.class);
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testLimit() throws IOException {
+    @Test
+    void testLimit() throws IOException {
         int maxCheckpoints = 3;
         try (ChannelStateWriterImpl writer =
                 new ChannelStateWriterImpl(
@@ -237,52 +228,42 @@ public class ChannelStateWriterImplTest {
             for (int i = 0; i < maxCheckpoints; i++) {
                 writer.start(i, CheckpointOptions.forCheckpointWithDefaultLocation());
             }
-            writer.start(maxCheckpoints, CheckpointOptions.forCheckpointWithDefaultLocation());
+            assertThatThrownBy(
+                            () ->
+                                    writer.start(
+                                            maxCheckpoints,
+                                            CheckpointOptions.forCheckpointWithDefaultLocation()))
+                    .isInstanceOf(IllegalStateException.class);
         }
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testStartNotOpened() throws Exception {
-        unwrappingError(
-                IllegalStateException.class,
-                () -> {
-                    try (ChannelStateWriterImpl writer =
-                            new ChannelStateWriterImpl(TASK_NAME, 0, getStreamFactoryFactory())) {
-                        callStart(writer);
-                    }
-                });
-    }
-
-    @Test(expected = IllegalStateException.class)
-    public void testNoStartAfterClose() throws Exception {
-        unwrappingError(
-                IllegalStateException.class,
-                () -> {
-                    ChannelStateWriterImpl writer = openWriter();
-                    writer.close();
-                    writer.start(42, CheckpointOptions.forCheckpointWithDefaultLocation());
-                });
+    @Test
+    void testStartNotOpened() throws IOException {
+        try (ChannelStateWriterImpl writer =
+                new ChannelStateWriterImpl(TASK_NAME, 0, getStreamFactoryFactory())) {
+            assertThatThrownBy(() -> callStart(writer))
+                    .hasCauseInstanceOf(IllegalStateException.class);
+        }
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testNoAddDataAfterClose() throws Exception {
-        unwrappingError(
-                IllegalStateException.class,
-                () -> {
-                    ChannelStateWriterImpl writer = openWriter();
-                    callStart(writer);
-                    writer.close();
-                    callAddInputData(writer);
-                });
+    @Test
+    void testNoStartAfterClose() throws IOException {
+        ChannelStateWriterImpl writer = openWriter();
+        writer.close();
+        assertThatThrownBy(
+                        () ->
+                                writer.start(
+                                        42, CheckpointOptions.forCheckpointWithDefaultLocation()))
+                .hasCauseInstanceOf(IllegalStateException.class);
     }
 
-    private static <T extends Throwable> void unwrappingError(
-            Class<T> clazz, RunnableWithException r) throws Exception {
-        try {
-            r.run();
-        } catch (Exception e) {
-            throw findThrowable(e, clazz).map(te -> (Exception) te).orElse(e);
-        }
+    @Test
+    void testNoAddDataAfterClose() throws IOException {
+        ChannelStateWriterImpl writer = openWriter();
+        callStart(writer);
+        writer.close();
+        assertThatThrownBy(() -> callAddInputData(writer))
+                .hasCauseInstanceOf(IllegalStateException.class);
     }
 
     private NetworkBuffer getBuffer() {
@@ -311,13 +292,16 @@ public class ChannelStateWriterImplTest {
         };
     }
 
-    private void runWithSyncWorker(Consumer<ChannelStateWriter> writerConsumer) throws Exception {
-        runWithSyncWorker(
-                (channelStateWriter, syncChannelStateWriterWorker) ->
-                        writerConsumer.accept(channelStateWriter));
+    private void executeCallbackAndProcessWithSyncWorker(
+            Consumer<ChannelStateWriter> writerConsumer) throws Exception {
+        executeCallbackWithSyncWorker(
+                (channelStateWriter, syncChannelStateWriterWorker) -> {
+                    writerConsumer.accept(channelStateWriter);
+                    syncChannelStateWriterWorker.processAllRequests();
+                });
     }
 
-    private void runWithSyncWorker(
+    private void executeCallbackWithSyncWorker(
             BiConsumerWithException<
                             ChannelStateWriter, SyncChannelStateWriteRequestExecutor, Exception>
                     testFn)
@@ -329,7 +313,6 @@ public class ChannelStateWriterImplTest {
                                 TASK_NAME, new ConcurrentHashMap<>(), worker, 5)) {
             writer.open();
             testFn.accept(writer, worker);
-            worker.processAllRequests();
         }
     }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
index e37fbcb1404..12a87c14183 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
@@ -17,23 +17,23 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.atomic.AtomicInteger;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.fail;
 
 /** {@link CheckpointInProgressRequest} test. */
-public class CheckpointInProgressRequestTest {
+class CheckpointInProgressRequestTest {
 
     /**
      * Tests that a request can only be cancelled once. This is important for requests to write data
      * to prevent double recycling of their buffers.
      */
     @Test
-    public void testNoCancelTwice() throws Exception {
+    void testNoCancelTwice() throws Exception {
         AtomicInteger counter = new AtomicInteger();
         CyclicBarrier barrier = new CyclicBarrier(10);
         CheckpointInProgressRequest request = cancelCountingRequest(counter, barrier);
@@ -55,7 +55,7 @@ public class CheckpointInProgressRequestTest {
             threads[i].join();
         }
 
-        assertEquals(1, counter.get());
+        assertThat(counter).hasValue(1);
     }
 
     private CheckpointInProgressRequest cancelCountingRequest(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
index d499b908ff0..823183be99e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
@@ -27,23 +27,23 @@ import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
 
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
 
 import java.util.HashSet;
 
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Test of different implementation of {@link InputChannelRecoveredStateHandler}. */
-public class InputChannelRecoveredStateHandlerTest extends RecoveredChannelStateHandlerTest {
+class InputChannelRecoveredStateHandlerTest extends RecoveredChannelStateHandlerTest {
     private static final int preAllocatedSegments = 3;
     private NetworkBufferPool networkBufferPool;
     private SingleInputGate inputGate;
-    private InputChannelRecoveredStateHandler icsHander;
+    private InputChannelRecoveredStateHandler icsHandler;
     private InputChannelInfo channelInfo;
 
-    @Before
-    public void setUp() {
+    @BeforeEach
+    void setUp() {
         // given: Segment provider with defined number of allocated segments.
         networkBufferPool = new NetworkBufferPool(preAllocatedSegments, 1024);
 
@@ -54,7 +54,7 @@ public class InputChannelRecoveredStateHandlerTest extends RecoveredChannelState
                         .setSegmentProvider(networkBufferPool)
                         .build();
 
-        icsHander = buildInputChannelStateHandler(inputGate);
+        icsHandler = buildInputChannelStateHandler(inputGate);
 
         channelInfo = new InputChannelInfo(0, 0);
     }
@@ -78,10 +78,10 @@ public class InputChannelRecoveredStateHandlerTest extends RecoveredChannelState
     }
 
     @Test
-    public void testRecycleBufferBeforeRecoverWasCalled() throws Exception {
+    void testRecycleBufferBeforeRecoverWasCalled() throws Exception {
         // when: Request the buffer.
         RecoveredChannelStateHandler.BufferWithContext<Buffer> bufferWithContext =
-                icsHander.getBuffer(channelInfo);
+                icsHandler.getBuffer(channelInfo);
 
         // and: Recycle buffer outside.
         bufferWithContext.buffer.close();
@@ -90,22 +90,24 @@ public class InputChannelRecoveredStateHandlerTest extends RecoveredChannelState
         inputGate.close();
 
         // then: All pre-allocated segments should be successfully recycled.
-        assertEquals(preAllocatedSegments, networkBufferPool.getNumberOfAvailableMemorySegments());
+        assertThat(networkBufferPool.getNumberOfAvailableMemorySegments())
+                .isEqualTo(preAllocatedSegments);
     }
 
     @Test
-    public void testRecycleBufferAfterRecoverWasCalled() throws Exception {
+    void testRecycleBufferAfterRecoverWasCalled() throws Exception {
         // when: Request the buffer.
         RecoveredChannelStateHandler.BufferWithContext<Buffer> bufferWithContext =
-                icsHander.getBuffer(channelInfo);
+                icsHandler.getBuffer(channelInfo);
 
         // and: Recycle buffer outside.
-        icsHander.recover(channelInfo, 0, bufferWithContext);
+        icsHandler.recover(channelInfo, 0, bufferWithContext);
 
         // Close the gate for flushing the cached recycled buffers to the segment provider.
         inputGate.close();
 
         // then: All pre-allocated segments should be successfully recycled.
-        assertEquals(preAllocatedSegments, networkBufferPool.getNumberOfAvailableMemorySegments());
+        assertThat(networkBufferPool.getNumberOfAvailableMemorySegments())
+                .isEqualTo(preAllocatedSegments);
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecordingChannelStateWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecordingChannelStateWriter.java
index cc58f2d03fe..f89a7c1940e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecordingChannelStateWriter.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecordingChannelStateWriter.java
@@ -31,9 +31,9 @@ import static org.apache.flink.util.ExceptionUtils.rethrow;
 /** A simple {@link ChannelStateWriter} used to write unit tests. */
 public class RecordingChannelStateWriter extends MockChannelStateWriter {
     private long lastStartedCheckpointId = -1;
-    private long lastFinishedCheckpointId = -1;
-    private ListMultimap<InputChannelInfo, Buffer> addedInput = LinkedListMultimap.create();
-    private ListMultimap<ResultSubpartitionInfo, Buffer> addedOutput = LinkedListMultimap.create();
+    private final ListMultimap<InputChannelInfo, Buffer> addedInput = LinkedListMultimap.create();
+    private final ListMultimap<ResultSubpartitionInfo, Buffer> addedOutput =
+            LinkedListMultimap.create();
 
     public RecordingChannelStateWriter() {
         super(false);
@@ -41,7 +41,6 @@ public class RecordingChannelStateWriter extends MockChannelStateWriter {
 
     public void reset() {
         lastStartedCheckpointId = -1;
-        lastFinishedCheckpointId = -1;
         addedInput.values().forEach(Buffer::recycleBuffer);
         addedInput.clear();
         addedOutput.values().forEach(Buffer::recycleBuffer);
@@ -80,10 +79,6 @@ public class RecordingChannelStateWriter extends MockChannelStateWriter {
         return lastStartedCheckpointId;
     }
 
-    public long getLastFinishedCheckpointId() {
-        return lastFinishedCheckpointId;
-    }
-
     public ListMultimap<InputChannelInfo, Buffer> getAddedInput() {
         return addedInput;
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java
index a74c1282f58..01f7c43920a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java
@@ -18,17 +18,17 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 /**
  * Base class which contains all tests which should be implemented for every implementation of
  * {@link InputChannelRecoveredStateHandler}.
  */
-public abstract class RecoveredChannelStateHandlerTest {
+abstract class RecoveredChannelStateHandlerTest {
 
     @Test
-    public abstract void testRecycleBufferBeforeRecoverWasCalled() throws Exception;
+    abstract void testRecycleBufferBeforeRecoverWasCalled() throws Exception;
 
     @Test
-    public abstract void testRecycleBufferAfterRecoverWasCalled() throws Exception;
+    abstract void testRecycleBufferAfterRecoverWasCalled() throws Exception;
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java
index b1889277c6c..91d4800e673 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java
@@ -26,24 +26,24 @@ import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.ResultPartition;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionBuilder;
 
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
 
 import java.io.IOException;
 import java.util.HashSet;
 
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Test of different implementation of {@link ResultSubpartitionRecoveredStateHandler}. */
-public class ResultSubpartitionRecoveredStateHandlerTest extends RecoveredChannelStateHandlerTest {
+class ResultSubpartitionRecoveredStateHandlerTest extends RecoveredChannelStateHandlerTest {
     private static final int preAllocatedSegments = 3;
     private NetworkBufferPool networkBufferPool;
     private ResultPartition partition;
     private ResultSubpartitionRecoveredStateHandler rstHandler;
     private ResultSubpartitionInfo channelInfo;
 
-    @Before
-    public void setUp() throws IOException {
+    @BeforeEach
+    void setUp() throws IOException {
         // given: Segment provider with defined number of allocated segments.
         channelInfo = new ResultSubpartitionInfo(0, 0);
 
@@ -74,7 +74,7 @@ public class ResultSubpartitionRecoveredStateHandlerTest extends RecoveredChanne
     }
 
     @Test
-    public void testRecycleBufferBeforeRecoverWasCalled() throws Exception {
+    void testRecycleBufferBeforeRecoverWasCalled() throws Exception {
         // when: Request the buffer.
         RecoveredChannelStateHandler.BufferWithContext<BufferBuilder> bufferWithContext =
                 rstHandler.getBuffer(new ResultSubpartitionInfo(0, 0));
@@ -86,11 +86,12 @@ public class ResultSubpartitionRecoveredStateHandlerTest extends RecoveredChanne
         partition.close();
 
         // then: All pre-allocated segments should be successfully recycled.
-        assertEquals(preAllocatedSegments, networkBufferPool.getNumberOfAvailableMemorySegments());
+        assertThat(networkBufferPool.getNumberOfAvailableMemorySegments())
+                .isEqualTo(preAllocatedSegments);
     }
 
     @Test
-    public void testRecycleBufferAfterRecoverWasCalled() throws Exception {
+    void testRecycleBufferAfterRecoverWasCalled() throws Exception {
         // when: Request the buffer.
         RecoveredChannelStateHandler.BufferWithContext<BufferBuilder> bufferWithContext =
                 rstHandler.getBuffer(channelInfo);
@@ -102,6 +103,7 @@ public class ResultSubpartitionRecoveredStateHandlerTest extends RecoveredChanne
         partition.close();
 
         // then: All pre-allocated segments should be successfully recycled.
-        assertEquals(preAllocatedSegments, networkBufferPool.getNumberOfAvailableMemorySegments());
+        assertThat(networkBufferPool.getNumberOfAvailableMemorySegments())
+                .isEqualTo(preAllocatedSegments);
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java
index 72b82f3eb64..7fcb9c00c78 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java
@@ -40,18 +40,22 @@ import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.InputChannelStateHandle;
 import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
+import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
 import org.apache.flink.util.function.ThrowingConsumer;
 
 import org.apache.flink.shaded.guava30.com.google.common.io.Closer;
 
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
 
 import java.io.ByteArrayOutputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -65,49 +69,49 @@ import static java.util.function.Function.identity;
 import static java.util.stream.Collectors.toList;
 import static java.util.stream.Collectors.toMap;
 import static java.util.stream.IntStream.range;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** {@link SequentialChannelStateReaderImpl} Test. */
-@RunWith(Parameterized.class)
+@ExtendWith(ParameterizedTestExtension.class)
 public class SequentialChannelStateReaderImplTest {
 
-    @Parameterized.Parameters(
+    @Parameters(
             name =
                     "{0}: stateParLevel={1}, statePartsPerChannel={2}, stateBytesPerPart={3},  parLevel={4}, bufferSize={5}")
-    public static Object[][] parameters() {
-        return new Object[][] {
-            {"NoStateAndNoChannels", 0, 0, 0, 0, 0},
-            {"NoState", 0, 10, 10, 10, 10},
-            {"ReadPermutedStateWithEqualBuffer", 10, 10, 10, 10, 10},
-            {"ReadPermutedStateWithReducedBuffer", 10, 10, 10, 20, 10},
-            {"ReadPermutedStateWithIncreasedBuffer", 10, 10, 10, 10, 20},
-        };
+    public static List<Object[]> parameters() {
+        return Arrays.asList(
+                new Object[] {"NoStateAndNoChannels", 0, 0, 0, 0, 0},
+                new Object[] {"NoState", 0, 10, 10, 10, 10},
+                new Object[] {"ReadPermutedStateWithEqualBuffer", 10, 10, 10, 10, 10},
+                new Object[] {"ReadPermutedStateWithReducedBuffer", 10, 10, 10, 20, 10},
+                new Object[] {"ReadPermutedStateWithIncreasedBuffer", 10, 10, 10, 10, 20});
     }
 
-    private final ChannelStateSerializer serializer;
-    private final Random random;
-    private final int parLevel;
-    private final int statePartsPerChannel;
-    private final int stateBytesPerPart;
-    private final int bufferSize;
-    private final int stateParLevel;
-    private final int buffersPerChannel;
-
-    public SequentialChannelStateReaderImplTest(
-            String desc,
-            int stateParLevel,
-            int statePartsPerChannel,
-            int stateBytesPerPart,
-            int parLevel,
-            int bufferSize) {
+    @Parameter public String desc;
+
+    @Parameter(value = 1)
+    public int stateParLevel;
+
+    @Parameter(value = 2)
+    public int statePartsPerChannel;
+
+    @Parameter(value = 3)
+    public int stateBytesPerPart;
+
+    @Parameter(value = 4)
+    public int parLevel;
+
+    @Parameter(value = 5)
+    public int bufferSize;
+
+    private ChannelStateSerializer serializer;
+    private Random random;
+    private int buffersPerChannel;
+
+    @BeforeEach
+    void before() {
         serializer = new ChannelStateSerializerImpl();
         random = new Random();
-        this.parLevel = parLevel;
-        this.statePartsPerChannel = statePartsPerChannel;
-        this.stateBytesPerPart = stateBytesPerPart;
-        this.bufferSize = bufferSize;
-        this.stateParLevel = stateParLevel;
         // will read without waiting for consumption
         buffersPerChannel =
                 Math.max(
@@ -118,8 +122,8 @@ public class SequentialChannelStateReaderImplTest {
                                         : stateBytesPerPart / bufferSize));
     }
 
-    @Test
-    public void testReadPermutedState() throws Exception {
+    @TestTemplate
+    void testReadPermutedState() throws Exception {
         Map<InputChannelInfo, List<byte[]>> inputChannelsData =
                 generateState(InputChannelInfo::new);
         Map<ResultSubpartitionInfo, List<byte[]>> resultPartitionsData =
@@ -184,7 +188,7 @@ public class SequentialChannelStateReaderImplTest {
     private void assertConsumed(InputGate[] gates)
             throws InterruptedException, java.util.concurrent.ExecutionException {
         for (InputGate gate : gates) {
-            assertTrue(gate.getStateConsumedFuture().isDone());
+            assertThat(gate.getStateConsumedFuture()).isDone();
             gate.getStateConsumedFuture().get();
         }
     }
@@ -218,8 +222,8 @@ public class SequentialChannelStateReaderImplTest {
                 }
                 action.accept(gates);
             }
-            assertEquals(
-                    segmentsToAllocate, networkBufferPool.getNumberOfAvailableMemorySegments());
+            assertThat(networkBufferPool.getNumberOfAvailableMemorySegments())
+                    .isEqualTo(segmentsToAllocate);
         }
     }
 
@@ -247,8 +251,8 @@ public class SequentialChannelStateReaderImplTest {
                 resultPartition.close();
             }
             try {
-                assertEquals(
-                        segmentsToAllocate, networkBufferPool.getNumberOfAvailableMemorySegments());
+                assertThat(networkBufferPool.getNumberOfAvailableMemorySegments())
+                        .isEqualTo(segmentsToAllocate);
             } finally {
                 networkBufferPool.destroyAllBufferPools();
                 networkBufferPool.destroy();
@@ -363,9 +367,8 @@ public class SequentialChannelStateReaderImplTest {
     private <T> void assertBuffersEquals(
             Map<T, List<byte[]>> expected, Map<T, List<Buffer>> actual) {
         try {
-            assertEquals(
-                    mapValues(expected, this::concat),
-                    mapValues(actual, buffers -> concat(toBytes(buffers))));
+            assertThat(mapValues(actual, buffers -> concat(toBytes(buffers))))
+                    .isEqualTo(mapValues(expected, this::concat));
         } finally {
             actual.values().stream().flatMap(List::stream).forEach(Buffer::recycleBuffer);
         }