Posted to commits@flink.apache.org by yi...@apache.org on 2022/08/08 14:58:41 UTC

[flink] 01/02: [hotfix][tests] Migrate tests relevant to FLINK-28663 to JUnit5/AssertJ

This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit b3be6bbd9c99fa988129497e13d7ae97de17c264
Author: kevin.cyj <ke...@alibaba-inc.com>
AuthorDate: Wed Jul 27 19:57:01 2022 +0800

    [hotfix][tests] Migrate tests relevant to FLINK-28663 to JUnit5/AssertJ
    
    The migrated tests are DefaultExecutionGraphConstructionTest,
    EdgeManagerBuildUtilTest, EdgeManagerTest, ExecutionJobVertexTest,
    IntermediateResultPartitionTest, RemoveCachedShuffleDescriptorTest,
    JobTaskVertexTest, DefaultExecutionTopologyTest,
    DefaultExecutionVertexTest, DefaultResultPartitionTest,
    AdaptiveBatchSchedulerTest, and ForwardGroupComputeUtilTest.
    
    This closes #20350.
---
 .../DefaultExecutionGraphConstructionTest.java     | 229 ++++++++++-----------
 .../executiongraph/EdgeManagerBuildUtilTest.java   |  28 +--
 .../runtime/executiongraph/EdgeManagerTest.java    |  24 ++-
 .../executiongraph/ExecutionJobVertexTest.java     | 128 +++++-------
 .../IntermediateResultPartitionTest.java           | 124 ++++++-----
 .../RemoveCachedShuffleDescriptorTest.java         |  82 ++++----
 .../flink/runtime/jobgraph/JobTaskVertexTest.java  | 154 ++++++--------
 .../adapter/DefaultExecutionTopologyTest.java      | 128 +++++-------
 .../adapter/DefaultExecutionVertexTest.java        |  25 ++-
 .../adapter/DefaultResultPartitionTest.java        |  29 ++-
 .../adaptivebatch/AdaptiveBatchSchedulerTest.java  |  38 ++--
 .../forwardgroup/ForwardGroupComputeUtilTest.java  |  37 ++--
 12 files changed, 462 insertions(+), 564 deletions(-)
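
Every file in this patch follows the same mechanical recipe, so a single condensed sketch may help when skimming the hunks below. The class and helper methods in the sketch (SomeMigratedTest, attachInvalidGraph, vertexNames) are illustrative placeholders rather than code from the patch; the TestingUtils/TestExecutorExtension utilities and the JUnit5/AssertJ imports are the ones the diff itself introduces:

    import org.apache.flink.runtime.JobException;
    import org.apache.flink.testutils.TestingUtils;
    import org.apache.flink.testutils.executor.TestExecutorExtension;

    import org.junit.jupiter.api.Test;
    import org.junit.jupiter.api.extension.RegisterExtension;

    import java.util.Arrays;
    import java.util.List;
    import java.util.concurrent.ScheduledExecutorService;

    import static org.assertj.core.api.Assertions.assertThat;
    import static org.assertj.core.api.Assertions.assertThatThrownBy;

    // JUnit5 test classes and test methods are package-private instead of public.
    class SomeMigratedTest {

        // Replaces the JUnit4 pair @ClassRule / TestExecutorResource.
        @RegisterExtension
        static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
                TestingUtils.defaultExecutorExtension();

        @Test
        void testSomething() {
            // Replaces the try { ...; fail(...); } catch (JobException e) { /* expected */ } idiom.
            assertThatThrownBy(this::attachInvalidGraph).isInstanceOf(JobException.class);

            // Replaces assertEquals(...) and Hamcrest assertThat(..., is(...)).
            assertThat(vertexNames()).hasSize(2);
        }

        // Placeholder helpers standing in for the real test logic in the patch.
        private void attachInvalidGraph() throws JobException {
            throw new JobException("wrong topological order");
        }

        private List<String> vertexNames() {
            return Arrays.asList("source", "sink");
        }
    }

The exception-handling change is the most visible simplification: an eight-line try/fail/catch block collapses into a single assertThatThrownBy chain that also verifies the exception type and, where needed, the message.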

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
index ee4383c8296..9245fa8f966 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.runtime.executiongraph;
 
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.io.InputSplit;
 import org.apache.flink.core.io.InputSplitAssigner;
 import org.apache.flink.core.io.InputSplitSource;
@@ -33,13 +31,12 @@ import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.scheduler.SchedulerBase;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
 import org.apache.flink.shaded.guava30.com.google.common.collect.Sets;
 
-import org.junit.ClassRule;
-import org.junit.Test;
-import org.mockito.Matchers;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -51,24 +48,18 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.hamcrest.Matchers.empty;
-import static org.hamcrest.Matchers.is;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /**
  * This class contains test concerning the correct conversion from {@link JobGraph} to {@link
  * ExecutionGraph} objects. It also tests that {@link EdgeManagerBuildUtil#connectVertexToResult}
  * builds {@link DistributionPattern#ALL_TO_ALL} connections correctly.
  */
-public class DefaultExecutionGraphConstructionTest {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+class DefaultExecutionGraphConstructionTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     private ExecutionGraph createDefaultExecutionGraph(List<JobVertex> vertices) throws Exception {
         return TestingDefaultExecutionGraphBuilder.newBuilder()
@@ -83,7 +74,7 @@ public class DefaultExecutionGraphConstructionTest {
     }
 
     @Test
-    public void testExecutionAttemptIdInTwoIdenticalJobsIsNotSame() throws Exception {
+    void testExecutionAttemptIdInTwoIdenticalJobsIsNotSame() throws Exception {
         JobVertex v1 = new JobVertex("vertex1");
         JobVertex v2 = new JobVertex("vertex2");
         JobVertex v3 = new JobVertex("vertex3");
@@ -104,10 +95,10 @@ public class DefaultExecutionGraphConstructionTest {
         eg2.attachJobGraph(ordered);
 
         assertThat(
-                Sets.intersection(
-                        eg1.getRegisteredExecutions().keySet(),
-                        eg2.getRegisteredExecutions().keySet()),
-                is(empty()));
+                        Sets.intersection(
+                                eg1.getRegisteredExecutions().keySet(),
+                                eg2.getRegisteredExecutions().keySet()))
+                .isEmpty();
     }
 
     /**
@@ -124,7 +115,7 @@ public class DefaultExecutionGraphConstructionTest {
      * </pre>
      */
     @Test
-    public void testCreateSimpleGraphBipartite() throws Exception {
+    void testCreateSimpleGraphBipartite() throws Exception {
         JobVertex v1 = new JobVertex("vertex1");
         JobVertex v2 = new JobVertex("vertex2");
         JobVertex v3 = new JobVertex("vertex3");
@@ -157,13 +148,7 @@ public class DefaultExecutionGraphConstructionTest {
         List<JobVertex> ordered = new ArrayList<JobVertex>(Arrays.asList(v1, v2, v3, v4, v5));
 
         ExecutionGraph eg = createDefaultExecutionGraph(ordered);
-        try {
-            eg.attachJobGraph(ordered);
-        } catch (JobException e) {
-            e.printStackTrace();
-            fail("Job failed with exception: " + e.getMessage());
-        }
-
+        eg.attachJobGraph(ordered);
         verifyTestGraph(eg, v1, v2, v3, v4, v5);
     }
 
@@ -187,7 +172,7 @@ public class DefaultExecutionGraphConstructionTest {
     }
 
     @Test
-    public void testCannotConnectWrongOrder() throws Exception {
+    void testCannotConnectWrongOrder() throws Exception {
         JobVertex v1 = new JobVertex("vertex1");
         JobVertex v2 = new JobVertex("vertex2");
         JobVertex v3 = new JobVertex("vertex3");
@@ -220,88 +205,64 @@ public class DefaultExecutionGraphConstructionTest {
         List<JobVertex> ordered = new ArrayList<JobVertex>(Arrays.asList(v1, v2, v3, v5, v4));
 
         ExecutionGraph eg = createDefaultExecutionGraph(ordered);
-        try {
-            eg.attachJobGraph(ordered);
-            fail("Attached wrong jobgraph");
-        } catch (JobException e) {
-            // expected
-        }
+        assertThatThrownBy(() -> eg.attachJobGraph(ordered)).isInstanceOf(JobException.class);
     }
 
     @Test
-    public void testSetupInputSplits() {
-        try {
-            final InputSplit[] emptySplits = new InputSplit[0];
-
-            InputSplitAssigner assigner1 = mock(InputSplitAssigner.class);
-            InputSplitAssigner assigner2 = mock(InputSplitAssigner.class);
-
-            @SuppressWarnings("unchecked")
-            InputSplitSource<InputSplit> source1 = mock(InputSplitSource.class);
-            @SuppressWarnings("unchecked")
-            InputSplitSource<InputSplit> source2 = mock(InputSplitSource.class);
-
-            when(source1.createInputSplits(Matchers.anyInt())).thenReturn(emptySplits);
-            when(source2.createInputSplits(Matchers.anyInt())).thenReturn(emptySplits);
-            when(source1.getInputSplitAssigner(emptySplits)).thenReturn(assigner1);
-            when(source2.getInputSplitAssigner(emptySplits)).thenReturn(assigner2);
-
-            final JobID jobId = new JobID();
-            final String jobName = "Test Job Sample Name";
-            final Configuration cfg = new Configuration();
-
-            JobVertex v1 = new JobVertex("vertex1");
-            JobVertex v2 = new JobVertex("vertex2");
-            JobVertex v3 = new JobVertex("vertex3");
-            JobVertex v4 = new JobVertex("vertex4");
-            JobVertex v5 = new JobVertex("vertex5");
-
-            v1.setParallelism(5);
-            v2.setParallelism(7);
-            v3.setParallelism(2);
-            v4.setParallelism(11);
-            v5.setParallelism(4);
-
-            v1.setInvokableClass(AbstractInvokable.class);
-            v2.setInvokableClass(AbstractInvokable.class);
-            v3.setInvokableClass(AbstractInvokable.class);
-            v4.setInvokableClass(AbstractInvokable.class);
-            v5.setInvokableClass(AbstractInvokable.class);
-
-            v2.connectNewDataSetAsInput(
-                    v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-            v4.connectNewDataSetAsInput(
-                    v2, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-            v4.connectNewDataSetAsInput(
-                    v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-            v5.connectNewDataSetAsInput(
-                    v4, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-            v5.connectNewDataSetAsInput(
-                    v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-
-            v3.setInputSplitSource(source1);
-            v5.setInputSplitSource(source2);
-
-            List<JobVertex> ordered = new ArrayList<JobVertex>(Arrays.asList(v1, v2, v3, v4, v5));
-
-            ExecutionGraph eg = createDefaultExecutionGraph(ordered);
-            try {
-                eg.attachJobGraph(ordered);
-            } catch (JobException e) {
-                e.printStackTrace();
-                fail("Job failed with exception: " + e.getMessage());
-            }
-
-            assertEquals(assigner1, eg.getAllVertices().get(v3.getID()).getSplitAssigner());
-            assertEquals(assigner2, eg.getAllVertices().get(v5.getID()).getSplitAssigner());
-        } catch (Exception e) {
-            e.printStackTrace();
-            fail(e.getMessage());
-        }
+    void testSetupInputSplits() throws Exception {
+        final InputSplit[] emptySplits = new InputSplit[0];
+
+        InputSplitAssigner assigner1 = new TestingInputSplitAssigner();
+        InputSplitAssigner assigner2 = new TestingInputSplitAssigner();
+
+        InputSplitSource<InputSplit> source1 =
+                new TestingInputSplitSource<>(emptySplits, assigner1);
+        InputSplitSource<InputSplit> source2 =
+                new TestingInputSplitSource<>(emptySplits, assigner2);
+
+        JobVertex v1 = new JobVertex("vertex1");
+        JobVertex v2 = new JobVertex("vertex2");
+        JobVertex v3 = new JobVertex("vertex3");
+        JobVertex v4 = new JobVertex("vertex4");
+        JobVertex v5 = new JobVertex("vertex5");
+
+        v1.setParallelism(5);
+        v2.setParallelism(7);
+        v3.setParallelism(2);
+        v4.setParallelism(11);
+        v5.setParallelism(4);
+
+        v1.setInvokableClass(AbstractInvokable.class);
+        v2.setInvokableClass(AbstractInvokable.class);
+        v3.setInvokableClass(AbstractInvokable.class);
+        v4.setInvokableClass(AbstractInvokable.class);
+        v5.setInvokableClass(AbstractInvokable.class);
+
+        v2.connectNewDataSetAsInput(
+                v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        v4.connectNewDataSetAsInput(
+                v2, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        v4.connectNewDataSetAsInput(
+                v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        v5.connectNewDataSetAsInput(
+                v4, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        v5.connectNewDataSetAsInput(
+                v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+
+        v3.setInputSplitSource(source1);
+        v5.setInputSplitSource(source2);
+
+        List<JobVertex> ordered = new ArrayList<>(Arrays.asList(v1, v2, v3, v4, v5));
+
+        ExecutionGraph eg = createDefaultExecutionGraph(ordered);
+        eg.attachJobGraph(ordered);
+
+        assertThat(eg.getAllVertices().get(v3.getID()).getSplitAssigner()).isEqualTo(assigner1);
+        assertThat(eg.getAllVertices().get(v5.getID()).getSplitAssigner()).isEqualTo(assigner2);
     }
 
     @Test
-    public void testRegisterConsumedPartitionGroupToEdgeManager() throws Exception {
+    void testRegisterConsumedPartitionGroupToEdgeManager() throws Exception {
         JobVertex v1 = new JobVertex("source");
         JobVertex v2 = new JobVertex("sink");
 
@@ -321,9 +282,8 @@ public class DefaultExecutionGraphConstructionTest {
         IntermediateResultPartition partition1 = result.getPartitions()[0];
         IntermediateResultPartition partition2 = result.getPartitions()[1];
 
-        assertEquals(
-                partition1.getConsumedPartitionGroups().get(0),
-                partition2.getConsumedPartitionGroups().get(0));
+        assertThat(partition2.getConsumedPartitionGroups().get(0))
+                .isEqualTo(partition1.getConsumedPartitionGroups().get(0));
 
         ConsumedPartitionGroup consumedPartitionGroup =
                 partition1.getConsumedPartitionGroups().get(0);
@@ -331,13 +291,13 @@ public class DefaultExecutionGraphConstructionTest {
         for (IntermediateResultPartitionID partitionId : consumedPartitionGroup) {
             partitionIds.add(partitionId);
         }
-        assertThat(
-                partitionIds,
-                containsInAnyOrder(partition1.getPartitionId(), partition2.getPartitionId()));
+        assertThat(partitionIds)
+                .containsExactlyInAnyOrder(
+                        partition1.getPartitionId(), partition2.getPartitionId());
     }
 
     @Test
-    public void testAttachToDynamicGraph() throws Exception {
+    void testAttachToDynamicGraph() throws Exception {
         JobVertex v1 = new JobVertex("source");
         JobVertex v2 = new JobVertex("sink");
 
@@ -351,9 +311,42 @@ public class DefaultExecutionGraphConstructionTest {
         ExecutionGraph eg = createDynamicExecutionGraph(ordered);
         eg.attachJobGraph(ordered);
 
-        assertThat(eg.getAllVertices().size(), is(2));
+        assertThat(eg.getAllVertices()).hasSize(2);
         Iterator<ExecutionJobVertex> jobVertices = eg.getVerticesTopologically().iterator();
-        assertThat(jobVertices.next().isInitialized(), is(false));
-        assertThat(jobVertices.next().isInitialized(), is(false));
+        assertThat(jobVertices.next().isInitialized()).isFalse();
+        assertThat(jobVertices.next().isInitialized()).isFalse();
+    }
+
+    private static final class TestingInputSplitAssigner implements InputSplitAssigner {
+
+        @Override
+        public InputSplit getNextInputSplit(String host, int taskId) {
+            return null;
+        }
+
+        @Override
+        public void returnInputSplit(List<InputSplit> splits, int taskId) {}
+    }
+
+    private static final class TestingInputSplitSource<T extends InputSplit>
+            implements InputSplitSource<T> {
+
+        private final T[] inputSplits;
+        private final InputSplitAssigner assigner;
+
+        private TestingInputSplitSource(T[] inputSplits, InputSplitAssigner assigner) {
+            this.inputSplits = inputSplits;
+            this.assigner = assigner;
+        }
+
+        @Override
+        public T[] createInputSplits(int minNumSplits) throws Exception {
+            return inputSplits;
+        }
+
+        @Override
+        public InputSplitAssigner getInputSplitAssigner(T[] inputSplits) {
+            return assigner;
+        }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
index 2cd5a8d1ce3..e3b603a8790 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
@@ -25,11 +25,11 @@ import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.scheduler.SchedulerBase;
 import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
 import org.apache.commons.lang3.tuple.Pair;
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -38,20 +38,20 @@ import java.util.concurrent.ScheduledExecutorService;
 
 import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL;
 import static org.apache.flink.runtime.jobgraph.DistributionPattern.POINTWISE;
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /**
  * Tests for {@link EdgeManagerBuildUtil} to verify the max number of connecting edges between
  * vertices for pattern of both {@link DistributionPattern#POINTWISE} and {@link
  * DistributionPattern#ALL_TO_ALL}.
  */
-public class EdgeManagerBuildUtilTest {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+class EdgeManagerBuildUtilTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     @Test
-    public void testGetMaxNumEdgesToTargetInPointwiseConnection() throws Exception {
+    void testGetMaxNumEdgesToTargetInPointwiseConnection() throws Exception {
         testGetMaxNumEdgesToTarget(17, 17, POINTWISE);
         testGetMaxNumEdgesToTarget(17, 23, POINTWISE);
         testGetMaxNumEdgesToTarget(17, 34, POINTWISE);
@@ -60,7 +60,7 @@ public class EdgeManagerBuildUtilTest {
     }
 
     @Test
-    public void testGetMaxNumEdgesToTargetInAllToAllConnection() throws Exception {
+    void testGetMaxNumEdgesToTargetInAllToAllConnection() throws Exception {
         testGetMaxNumEdgesToTarget(17, 17, ALL_TO_ALL);
         testGetMaxNumEdgesToTarget(17, 23, ALL_TO_ALL);
         testGetMaxNumEdgesToTarget(17, 34, ALL_TO_ALL);
@@ -81,7 +81,7 @@ public class EdgeManagerBuildUtilTest {
                         upstream, downstream, pattern);
         int actualMaxForUpstream = -1;
         for (ExecutionVertex ev : upstreamEJV.getTaskVertices()) {
-            assertEquals(1, ev.getProducedPartitions().size());
+            assertThat(ev.getProducedPartitions()).hasSize(1);
 
             IntermediateResultPartition partition =
                     ev.getProducedPartitions().values().iterator().next();
@@ -91,21 +91,21 @@ public class EdgeManagerBuildUtilTest {
                 actualMaxForUpstream = actual;
             }
         }
-        assertEquals(actualMaxForUpstream, calculatedMaxForUpstream);
+        assertThat(actualMaxForUpstream).isEqualTo(calculatedMaxForUpstream);
 
         int calculatedMaxForDownstream =
                 EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
                         downstream, upstream, pattern);
         int actualMaxForDownstream = -1;
         for (ExecutionVertex ev : downstreamEJV.getTaskVertices()) {
-            assertEquals(1, ev.getNumberOfInputs());
+            assertThat(ev.getNumberOfInputs()).isEqualTo(1);
 
             int actual = ev.getConsumedPartitionGroup(0).size();
             if (actual > actualMaxForDownstream) {
                 actualMaxForDownstream = actual;
             }
         }
-        assertEquals(actualMaxForDownstream, calculatedMaxForDownstream);
+        assertThat(actualMaxForDownstream).isEqualTo(calculatedMaxForDownstream);
     }
 
     private Pair<ExecutionJobVertex, ExecutionJobVertex> setupExecutionGraph(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
index 323a753a511..81165da6193 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
@@ -29,25 +29,25 @@ import org.apache.flink.runtime.scheduler.SchedulerTestingUtils;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Objects;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for {@link EdgeManager}. */
-public class EdgeManagerTest {
+class EdgeManagerTest {
 
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     @Test
-    public void testGetConsumedPartitionGroup() throws Exception {
+    void testGetConsumedPartitionGroup() throws Exception {
         JobVertex v1 = new JobVertex("source");
         JobVertex v2 = new JobVertex("sink");
 
@@ -82,7 +82,8 @@ public class EdgeManagerTest {
         ConsumedPartitionGroup groupRetrievedByIntermediateResultPartition =
                 consumedPartition.getConsumedPartitionGroups().get(0);
 
-        assertEquals(groupRetrievedByDownstreamVertex, groupRetrievedByIntermediateResultPartition);
+        assertThat(groupRetrievedByIntermediateResultPartition)
+                .isEqualTo(groupRetrievedByDownstreamVertex);
 
         ConsumedPartitionGroup groupRetrievedByScheduledResultPartition =
                 scheduler
@@ -92,6 +93,7 @@ public class EdgeManagerTest {
                         .getConsumedPartitionGroups()
                         .get(0);
 
-        assertEquals(groupRetrievedByDownstreamVertex, groupRetrievedByScheduledResultPartition);
+        assertThat(groupRetrievedByScheduledResultPartition)
+                .isEqualTo(groupRetrievedByDownstreamVertex);
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
index 8f2a35cece5..d847e9ede68 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
@@ -29,164 +29,128 @@ import org.apache.flink.runtime.scheduler.VertexParallelismInformation;
 import org.apache.flink.runtime.scheduler.VertexParallelismStore;
 import org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchScheduler;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.Assert;
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Collections;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static org.apache.flink.core.testutils.CommonTestUtils.assertThrows;
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.MatcherAssert.assertThat;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Test for {@link ExecutionJobVertex} */
-public class ExecutionJobVertexTest {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+class ExecutionJobVertexTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     @Test
-    public void testParallelismGreaterThanMaxParallelism() {
+    void testParallelismGreaterThanMaxParallelism() {
         JobVertex jobVertex = new JobVertex("testVertex");
         jobVertex.setInvokableClass(AbstractInvokable.class);
         // parallelism must be smaller than the max parallelism
         jobVertex.setParallelism(172);
         jobVertex.setMaxParallelism(4);
 
-        assertThrows(
-                "higher than the max parallelism",
-                JobException.class,
-                () -> ExecutionGraphTestUtils.getExecutionJobVertex(jobVertex));
+        assertThatThrownBy(() -> ExecutionGraphTestUtils.getExecutionJobVertex(jobVertex))
+                .isInstanceOf(JobException.class)
+                .hasMessageContaining("higher than the max parallelism");
     }
 
     @Test
-    public void testLazyInitialization() throws Exception {
+    void testLazyInitialization() throws Exception {
         final int parallelism = 3;
         final int configuredMaxParallelism = 12;
         final ExecutionJobVertex ejv =
                 createDynamicExecutionJobVertex(parallelism, configuredMaxParallelism, -1);
 
-        assertThat(ejv.getParallelism(), is(parallelism));
-        assertThat(ejv.getMaxParallelism(), is(configuredMaxParallelism));
-        assertThat(ejv.isInitialized(), is(false));
+        assertThat(ejv.getParallelism()).isEqualTo(parallelism);
+        assertThat(ejv.getMaxParallelism()).isEqualTo(configuredMaxParallelism);
+        assertThat(ejv.isInitialized()).isFalse();
 
-        assertThat(ejv.getTaskVertices().length, is(0));
+        assertThat(ejv.getTaskVertices()).isEmpty();
 
-        try {
-            ejv.getInputs();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::getInputs).isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.getProducedDataSets();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::getProducedDataSets).isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.getSplitAssigner();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::getSplitAssigner).isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.getOperatorCoordinators();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::getOperatorCoordinators).isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.connectToPredecessors(Collections.emptyMap());
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(() -> ejv.connectToPredecessors(Collections.emptyMap()))
+                .isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.executionVertexFinished();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::executionVertexFinished).isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.executionVertexUnFinished();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::executionVertexUnFinished)
+                .isInstanceOf(IllegalStateException.class);
 
         initializeVertex(ejv);
 
-        assertThat(ejv.isInitialized(), is(true));
-        assertThat(ejv.getTaskVertices().length, is(3));
-        assertThat(ejv.getInputs().size(), is(0));
-        assertThat(ejv.getProducedDataSets().length, is(1));
-        assertThat(ejv.getOperatorCoordinators().size(), is(0));
+        assertThat(ejv.isInitialized()).isTrue();
+        assertThat(ejv.getTaskVertices()).hasSize(3);
+        assertThat(ejv.getInputs()).isEmpty();
+        assertThat(ejv.getProducedDataSets()).hasSize(1);
+        assertThat(ejv.getOperatorCoordinators()).isEmpty();
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testErrorIfInitializationWithoutParallelismDecided() throws Exception {
+    @Test
+    void testErrorIfInitializationWithoutParallelismDecided() throws Exception {
         final ExecutionJobVertex ejv = createDynamicExecutionJobVertex();
 
-        initializeVertex(ejv);
+        assertThatThrownBy(() -> initializeVertex(ejv)).isInstanceOf(IllegalStateException.class);
     }
 
     @Test
-    public void testSetParallelismLazily() throws Exception {
+    void testSetParallelismLazily() throws Exception {
         final int parallelism = 3;
         final int defaultMaxParallelism = 13;
         final ExecutionJobVertex ejv =
                 createDynamicExecutionJobVertex(-1, -1, defaultMaxParallelism);
 
-        assertThat(ejv.isParallelismDecided(), is(false));
+        assertThat(ejv.isParallelismDecided()).isFalse();
 
         ejv.setParallelism(parallelism);
 
-        assertThat(ejv.isParallelismDecided(), is(true));
-        assertThat(ejv.getParallelism(), is(parallelism));
+        assertThat(ejv.isParallelismDecided()).isTrue();
+        assertThat(ejv.getParallelism()).isEqualTo(parallelism);
 
         initializeVertex(ejv);
 
-        assertThat(ejv.getTaskVertices().length, is(parallelism));
+        assertThat(ejv.getTaskVertices()).hasSize(parallelism);
     }
 
     @Test
-    public void testConfiguredMaxParallelismIsRespected() throws Exception {
+    void testConfiguredMaxParallelismIsRespected() throws Exception {
         final int configuredMaxParallelism = 12;
         final int defaultMaxParallelism = 13;
         final ExecutionJobVertex ejv =
                 createDynamicExecutionJobVertex(
                         -1, configuredMaxParallelism, defaultMaxParallelism);
 
-        assertThat(ejv.getMaxParallelism(), is(configuredMaxParallelism));
+        assertThat(ejv.getMaxParallelism()).isEqualTo(configuredMaxParallelism);
     }
 
     @Test
-    public void testComputingMaxParallelismFromConfiguredParallelism() throws Exception {
+    void testComputingMaxParallelismFromConfiguredParallelism() throws Exception {
         final int parallelism = 300;
         final int defaultMaxParallelism = 13;
         final ExecutionJobVertex ejv =
                 createDynamicExecutionJobVertex(parallelism, -1, defaultMaxParallelism);
 
-        assertThat(ejv.getMaxParallelism(), is(512));
+        assertThat(ejv.getMaxParallelism()).isEqualTo(512);
     }
 
     @Test
-    public void testFallingBackToDefaultMaxParallelism() throws Exception {
+    void testFallingBackToDefaultMaxParallelism() throws Exception {
         final int defaultMaxParallelism = 13;
         final ExecutionJobVertex ejv =
                 createDynamicExecutionJobVertex(-1, -1, defaultMaxParallelism);
 
-        assertThat(ejv.getMaxParallelism(), is(defaultMaxParallelism));
+        assertThat(ejv.getMaxParallelism()).isEqualTo(defaultMaxParallelism);
     }
 
     static void initializeVertex(ExecutionJobVertex vertex) throws Exception {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
index 456de025668..5ff5d9f201c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
@@ -34,11 +34,10 @@ import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.runtime.testutils.DirectScheduledExecutorService;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
-import org.apache.flink.util.TestLogger;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Arrays;
 import java.util.Iterator;
@@ -46,42 +45,37 @@ import java.util.List;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
 
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for {@link IntermediateResultPartition}. */
-public class IntermediateResultPartitionTest extends TestLogger {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+public class IntermediateResultPartitionTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     @Test
-    public void testPipelinedPartitionConsumable() throws Exception {
+    void testPipelinedPartitionConsumable() throws Exception {
         IntermediateResult result = createResult(ResultPartitionType.PIPELINED, 2);
         IntermediateResultPartition partition1 = result.getPartitions()[0];
         IntermediateResultPartition partition2 = result.getPartitions()[1];
 
         // Not consumable on init
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
 
         // Partition 1 consumable after data are produced
         partition1.markDataProduced();
-        assertTrue(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
+        assertThat(partition1.isConsumable()).isTrue();
+        assertThat(partition2.isConsumable()).isFalse();
 
         // Not consumable if failover happens
         result.resetForNewExecution();
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
     }
 
     @Test
-    public void testBlockingPartitionConsumable() throws Exception {
+    void testBlockingPartitionConsumable() throws Exception {
         IntermediateResult result = createResult(ResultPartitionType.BLOCKING, 2);
         IntermediateResultPartition partition1 = result.getPartitions()[0];
         IntermediateResultPartition partition2 = result.getPartitions()[1];
@@ -90,31 +84,31 @@ public class IntermediateResultPartitionTest extends TestLogger {
                 partition1.getConsumedPartitionGroups().get(0);
 
         // Not consumable on init
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
 
         // Not consumable if only one partition is FINISHED
         partition1.markFinished();
-        assertTrue(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(partition1.isConsumable()).isTrue();
+        assertThat(partition2.isConsumable()).isFalse();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
 
         // Consumable after all partitions are FINISHED
         partition2.markFinished();
-        assertTrue(partition1.isConsumable());
-        assertTrue(partition2.isConsumable());
-        assertTrue(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(partition1.isConsumable()).isTrue();
+        assertThat(partition2.isConsumable()).isTrue();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isTrue();
 
         // Not consumable if failover happens
         result.resetForNewExecution();
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
     }
 
     @Test
-    public void testBlockingPartitionResetting() throws Exception {
+    void testBlockingPartitionResetting() throws Exception {
         IntermediateResult result = createResult(ResultPartitionType.BLOCKING, 2);
         IntermediateResultPartition partition1 = result.getPartitions()[0];
         IntermediateResultPartition partition2 = result.getPartitions()[1];
@@ -123,71 +117,71 @@ public class IntermediateResultPartitionTest extends TestLogger {
                 partition1.getConsumedPartitionGroups().get(0);
 
         // Not consumable on init
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
 
         // Not consumable if partition1 is FINISHED
         partition1.markFinished();
-        assertEquals(1, consumedPartitionGroup.getNumberOfUnfinishedPartitions());
-        assertTrue(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(1);
+        assertThat(partition1.isConsumable()).isTrue();
+        assertThat(partition2.isConsumable()).isFalse();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
 
         // Reset the result and mark partition2 FINISHED, the result should still not be consumable
         result.resetForNewExecution();
-        assertEquals(2, consumedPartitionGroup.getNumberOfUnfinishedPartitions());
+        assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(2);
         partition2.markFinished();
-        assertEquals(1, consumedPartitionGroup.getNumberOfUnfinishedPartitions());
-        assertFalse(partition1.isConsumable());
-        assertTrue(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(1);
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isTrue();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
 
         // Consumable after all partitions are FINISHED
         partition1.markFinished();
-        assertEquals(0, consumedPartitionGroup.getNumberOfUnfinishedPartitions());
-        assertTrue(partition1.isConsumable());
-        assertTrue(partition2.isConsumable());
-        assertTrue(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(0);
+        assertThat(partition1.isConsumable()).isTrue();
+        assertThat(partition2.isConsumable()).isTrue();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isTrue();
 
         // Not consumable again if failover happens
         result.resetForNewExecution();
-        assertEquals(2, consumedPartitionGroup.getNumberOfUnfinishedPartitions());
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(2);
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsForNonDynamicAllToAllGraph() throws Exception {
+    void testGetNumberOfSubpartitionsForNonDynamicAllToAllGraph() throws Exception {
         testGetNumberOfSubpartitions(7, DistributionPattern.ALL_TO_ALL, false, Arrays.asList(7, 7));
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsForNonDynamicPointwiseGraph() throws Exception {
+    void testGetNumberOfSubpartitionsForNonDynamicPointwiseGraph() throws Exception {
         testGetNumberOfSubpartitions(7, DistributionPattern.POINTWISE, false, Arrays.asList(4, 3));
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsFromConsumerParallelismForDynamicAllToAllGraph()
+    void testGetNumberOfSubpartitionsFromConsumerParallelismForDynamicAllToAllGraph()
             throws Exception {
         testGetNumberOfSubpartitions(7, DistributionPattern.ALL_TO_ALL, true, Arrays.asList(7, 7));
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsFromConsumerParallelismForDynamicPointwiseGraph()
+    void testGetNumberOfSubpartitionsFromConsumerParallelismForDynamicPointwiseGraph()
             throws Exception {
         testGetNumberOfSubpartitions(7, DistributionPattern.POINTWISE, true, Arrays.asList(4, 4));
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsFromConsumerMaxParallelismForDynamicAllToAllGraph()
+    void testGetNumberOfSubpartitionsFromConsumerMaxParallelismForDynamicAllToAllGraph()
             throws Exception {
         testGetNumberOfSubpartitions(
                 -1, DistributionPattern.ALL_TO_ALL, true, Arrays.asList(13, 13));
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsFromConsumerMaxParallelismForDynamicPointwiseGraph()
+    void testGetNumberOfSubpartitionsFromConsumerMaxParallelismForDynamicPointwiseGraph()
             throws Exception {
         testGetNumberOfSubpartitions(-1, DistributionPattern.POINTWISE, true, Arrays.asList(7, 7));
     }
@@ -221,12 +215,12 @@ public class IntermediateResultPartitionTest extends TestLogger {
 
         final IntermediateResult result = producer.getProducedDataSets()[0];
 
-        assertThat(expectedNumSubpartitions.size(), is(producerParallelism));
+        assertThat(expectedNumSubpartitions).hasSize(producerParallelism);
         assertThat(
-                Arrays.stream(result.getPartitions())
-                        .map(IntermediateResultPartition::getNumberOfSubpartitions)
-                        .collect(Collectors.toList()),
-                equalTo(expectedNumSubpartitions));
+                        Arrays.stream(result.getPartitions())
+                                .map(IntermediateResultPartition::getNumberOfSubpartitions)
+                                .collect(Collectors.toList()))
+                .isEqualTo(expectedNumSubpartitions);
     }
 
     public static ExecutionGraph createExecutionGraph(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
index 73dd9cdeb3a..7c00da22a85 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
@@ -43,13 +43,12 @@ import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
 import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
-import org.apache.flink.util.TestLogger;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.After;
-import org.junit.Before;
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -62,28 +61,27 @@ import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeoutException;
 
 import static org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /**
  * Tests for removing cached {@link ShuffleDescriptor}s when the related partitions are no longer
  * valid. Currently, there are two scenarios as illustrated in {@link
  * IntermediateResult#clearCachedInformationForPartitionGroup}.
  */
-public class RemoveCachedShuffleDescriptorTest extends TestLogger {
+class RemoveCachedShuffleDescriptorTest {
 
     private static final int PARALLELISM = 4;
 
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     private ScheduledExecutorService scheduledExecutorService;
     private ComponentMainThreadExecutor mainThreadExecutor;
     private ManuallyTriggeredScheduledExecutorService ioExecutor;
 
-    @Before
-    public void setup() {
+    @BeforeEach
+    void setup() {
         scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
         mainThreadExecutor =
                 ComponentMainThreadExecutorServiceAdapter.forSingleThreadExecutor(
@@ -91,21 +89,21 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         ioExecutor = new ManuallyTriggeredScheduledExecutorService();
     }
 
-    @After
-    public void teardown() {
+    @AfterEach
+    void teardown() {
         if (scheduledExecutorService != null) {
             scheduledExecutorService.shutdownNow();
         }
     }
 
     @Test
-    public void testRemoveNonOffloadedCacheForAllToAllEdgeAfterFinished() throws Exception {
+    void testRemoveNonOffloadedCacheForAllToAllEdgeAfterFinished() throws Exception {
         // Here we expect no offloaded BLOB.
         testRemoveCacheForAllToAllEdgeAfterFinished(new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
     }
 
     @Test
-    public void testRemoveOffloadedCacheForAllToAllEdgeAfterFinished() throws Exception {
+    void testRemoveOffloadedCacheForAllToAllEdgeAfterFinished() throws Exception {
         // Here we expect 4 offloaded BLOBs:
         // JobInformation (1) + TaskInformation (2) + Cache of ShuffleDescriptors for the ALL-TO-ALL
         // edge (1).
@@ -129,8 +127,8 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         final ShuffleDescriptor[] shuffleDescriptors =
                 deserializeShuffleDescriptors(
                         getConsumedCachedShuffleDescriptor(executionGraph, v2), jobId, blobWriter);
-        assertEquals(PARALLELISM, shuffleDescriptors.length);
-        assertEquals(expectedBefore, blobWriter.numberOfBlobs());
+        assertThat(shuffleDescriptors).hasSize(PARALLELISM);
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
 
         // For the all-to-all edge, we transition all downstream tasks to finished
         CompletableFuture.runAsync(
@@ -140,17 +138,17 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         ioExecutor.triggerAll();
 
         // Cache should be removed since partitions are released
-        assertNull(getConsumedCachedShuffleDescriptor(executionGraph, v2));
-        assertEquals(expectedAfter, blobWriter.numberOfBlobs());
+        assertThat(getConsumedCachedShuffleDescriptor(executionGraph, v2)).isNull();
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
     }
 
     @Test
-    public void testRemoveNonOffloadedCacheForAllToAllEdgeAfterFailover() throws Exception {
+    void testRemoveNonOffloadedCacheForAllToAllEdgeAfterFailover() throws Exception {
         testRemoveCacheForAllToAllEdgeAfterFailover(new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
     }
 
     @Test
-    public void testRemoveOffloadedCacheForAllToAllEdgeAfterFailover() throws Exception {
+    void testRemoveOffloadedCacheForAllToAllEdgeAfterFailover() throws Exception {
         // Here we expect 4 offloaded BLOBs:
         // JobInformation (1) + TaskInformation (2) + Cache of ShuffleDescriptors for the ALL-TO-ALL
         // edge (1).
@@ -174,25 +172,25 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         final ShuffleDescriptor[] shuffleDescriptors =
                 deserializeShuffleDescriptors(
                         getConsumedCachedShuffleDescriptor(executionGraph, v2), jobId, blobWriter);
-        assertEquals(PARALLELISM, shuffleDescriptors.length);
-        assertEquals(expectedBefore, blobWriter.numberOfBlobs());
+        assertThat(shuffleDescriptors).hasSize(PARALLELISM);
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
 
         triggerGlobalFailoverAndComplete(scheduler, v1);
         ioExecutor.triggerAll();
 
         // Cache should be removed during ExecutionVertex#resetForNewExecution
-        assertNull(getConsumedCachedShuffleDescriptor(executionGraph, v2));
-        assertEquals(expectedAfter, blobWriter.numberOfBlobs());
+        assertThat(getConsumedCachedShuffleDescriptor(executionGraph, v2)).isNull();
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
     }
 
     @Test
-    public void testRemoveNonOffloadedCacheForPointwiseEdgeAfterFinished() throws Exception {
+    void testRemoveNonOffloadedCacheForPointwiseEdgeAfterFinished() throws Exception {
         testRemoveCacheForPointwiseEdgeAfterFinished(
                 new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
     }
 
     @Test
-    public void testRemoveOffloadedCacheForPointwiseEdgeAfterFinished() throws Exception {
+    void testRemoveOffloadedCacheForPointwiseEdgeAfterFinished() throws Exception {
         // Here we expect 7 offloaded BLOBs:
         // JobInformation (1) + TaskInformation (2) + Cache of ShuffleDescriptors for the POINTWISE
         // edges (4).
@@ -216,8 +214,8 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         final ShuffleDescriptor[] shuffleDescriptors =
                 deserializeShuffleDescriptors(
                         getConsumedCachedShuffleDescriptor(executionGraph, v2), jobId, blobWriter);
-        assertEquals(1, shuffleDescriptors.length);
-        assertEquals(expectedBefore, blobWriter.numberOfBlobs());
+        assertThat(shuffleDescriptors).hasSize(1);
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
 
         // For the pointwise edge, we just transition the first downstream task to FINISHED
         ExecutionVertex ev21 =
@@ -229,7 +227,7 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         ioExecutor.triggerAll();
 
         // The cache of the first upstream task should be removed since its partition is released
-        assertNull(getConsumedCachedShuffleDescriptor(executionGraph, v2, 0));
+        assertThat(getConsumedCachedShuffleDescriptor(executionGraph, v2, 0)).isNull();
 
         // The cache of the other upstream tasks should stay
         final ShuffleDescriptor[] shuffleDescriptorsForOtherVertex =
@@ -237,19 +235,19 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
                         getConsumedCachedShuffleDescriptor(executionGraph, v2, 1),
                         jobId,
                         blobWriter);
-        assertEquals(1, shuffleDescriptorsForOtherVertex.length);
+        assertThat(shuffleDescriptorsForOtherVertex).hasSize(1);
 
-        assertEquals(expectedAfter, blobWriter.numberOfBlobs());
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
     }
 
     @Test
-    public void testRemoveNonOffloadedCacheForPointwiseEdgeAfterFailover() throws Exception {
+    void testRemoveNonOffloadedCacheForPointwiseEdgeAfterFailover() throws Exception {
         testRemoveCacheForPointwiseEdgeAfterFailover(
                 new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
     }
 
     @Test
-    public void testRemoveOffloadedCacheForPointwiseEdgeAfterFailover() throws Exception {
+    void testRemoveOffloadedCacheForPointwiseEdgeAfterFailover() throws Exception {
         // Here we expect 7 offloaded BLOBs:
         // JobInformation (1) + TaskInformation (2) + Cache of ShuffleDescriptors for the POINTWISE
         // edges (4).
@@ -273,15 +271,15 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         final ShuffleDescriptor[] shuffleDescriptors =
                 deserializeShuffleDescriptors(
                         getConsumedCachedShuffleDescriptor(executionGraph, v2), jobId, blobWriter);
-        assertEquals(1, shuffleDescriptors.length);
-        assertEquals(expectedBefore, blobWriter.numberOfBlobs());
+        assertThat(shuffleDescriptors).hasSize(1);
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
 
         triggerExceptionAndComplete(executionGraph, v1, v2);
         ioExecutor.triggerAll();
 
         // The cache of the first upstream task should be removed during
         // ExecutionVertex#resetForNewExecution
-        assertNull(getConsumedCachedShuffleDescriptor(executionGraph, v2, 0));
+        assertThat(getConsumedCachedShuffleDescriptor(executionGraph, v2, 0)).isNull();
 
         // The cache of the other upstream tasks should stay
         final ShuffleDescriptor[] shuffleDescriptorsForOtherVertex =
@@ -289,9 +287,9 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
                         getConsumedCachedShuffleDescriptor(executionGraph, v2, 1),
                         jobId,
                         blobWriter);
-        assertEquals(1, shuffleDescriptorsForOtherVertex.length);
+        assertThat(shuffleDescriptorsForOtherVertex).hasSize(1);
 
-        assertEquals(expectedAfter, blobWriter.numberOfBlobs());
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
     }
 
     private DefaultScheduler createSchedulerAndDeploy(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
index 47cc722fa9b..66e097b34b1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
@@ -29,117 +29,89 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.operators.util.TaskConfig;
 import org.apache.flink.util.InstantiationUtil;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.io.IOException;
 import java.net.URL;
 import java.net.URLClassLoader;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 @SuppressWarnings("serial")
-public class JobTaskVertexTest {
+class JobTaskVertexTest {
 
     @Test
-    public void testConnectDirectly() {
+    void testConnectDirectly() {
         JobVertex source = new JobVertex("source");
         JobVertex target = new JobVertex("target");
         target.connectNewDataSetAsInput(
                 source, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);
 
-        assertTrue(source.isInputVertex());
-        assertFalse(source.isOutputVertex());
-        assertFalse(target.isInputVertex());
-        assertTrue(target.isOutputVertex());
+        assertThat(source.isInputVertex()).isTrue();
+        assertThat(source.isOutputVertex()).isFalse();
+        assertThat(target.isInputVertex()).isFalse();
+        assertThat(target.isOutputVertex()).isTrue();
 
-        assertEquals(1, source.getNumberOfProducedIntermediateDataSets());
-        assertEquals(1, target.getNumberOfInputs());
+        assertThat(source.getNumberOfProducedIntermediateDataSets()).isEqualTo(1);
+        assertThat(target.getNumberOfInputs()).isEqualTo(1);
 
-        assertEquals(target.getInputs().get(0).getSource(), source.getProducedDataSets().get(0));
+        assertThat(source.getProducedDataSets().get(0))
+                .isEqualTo(target.getInputs().get(0).getSource());
 
-        assertEquals(target, source.getProducedDataSets().get(0).getConsumer().getTarget());
+        assertThat(source.getProducedDataSets().get(0).getConsumer().getTarget()).isEqualTo(target);
     }
 
     @Test
-    public void testOutputFormat() {
-        try {
-            final InputOutputFormatVertex vertex = new InputOutputFormatVertex("Name");
-
-            OperatorID operatorID = new OperatorID();
-            Configuration parameters = new Configuration();
-            parameters.setString("test_key", "test_value");
-            new InputOutputFormatContainer(Thread.currentThread().getContextClassLoader())
-                    .addOutputFormat(operatorID, new TestingOutputFormat(parameters))
-                    .addParameters(operatorID, parameters)
-                    .write(new TaskConfig(vertex.getConfiguration()));
-
-            final ClassLoader cl = new TestClassLoader();
-
-            try {
-                vertex.initializeOnMaster(cl);
-                fail("Did not throw expected exception.");
-            } catch (TestException e) {
-                // all good
-            }
+    void testOutputFormat() throws Exception {
+        final InputOutputFormatVertex vertex = new InputOutputFormatVertex("Name");
 
-            InputOutputFormatVertex copy = InstantiationUtil.clone(vertex);
-            ClassLoader ctxCl = Thread.currentThread().getContextClassLoader();
-            try {
-                copy.initializeOnMaster(cl);
-                fail("Did not throw expected exception.");
-            } catch (TestException e) {
-                // all good
-            }
-            assertEquals(
-                    "Previous classloader was not restored.",
-                    ctxCl,
-                    Thread.currentThread().getContextClassLoader());
-
-            try {
-                copy.finalizeOnMaster(cl);
-                fail("Did not throw expected exception.");
-            } catch (TestException e) {
-                // all good
-            }
-            assertEquals(
-                    "Previous classloader was not restored.",
-                    ctxCl,
-                    Thread.currentThread().getContextClassLoader());
-        } catch (Exception e) {
-            e.printStackTrace();
-            fail(e.getMessage());
-        }
+        OperatorID operatorID = new OperatorID();
+        Configuration parameters = new Configuration();
+        parameters.setString("test_key", "test_value");
+        new InputOutputFormatContainer(Thread.currentThread().getContextClassLoader())
+                .addOutputFormat(operatorID, new TestingOutputFormat(parameters))
+                .addParameters(operatorID, parameters)
+                .write(new TaskConfig(vertex.getConfiguration()));
+
+        final ClassLoader cl = new TestClassLoader();
+
+        assertThatThrownBy(() -> vertex.initializeOnMaster(cl)).isInstanceOf(TestException.class);
+
+        InputOutputFormatVertex copy = InstantiationUtil.clone(vertex);
+        ClassLoader ctxCl = Thread.currentThread().getContextClassLoader();
+        assertThatThrownBy(() -> copy.initializeOnMaster(cl)).isInstanceOf(TestException.class);
+
+        assertThat(Thread.currentThread().getContextClassLoader())
+                .as("Previous classloader was not restored.")
+                .isEqualTo(ctxCl);
+
+        assertThatThrownBy(() -> copy.finalizeOnMaster(cl)).isInstanceOf(TestException.class);
+        assertThat(Thread.currentThread().getContextClassLoader())
+                .as("Previous classloader was not restored.")
+                .isEqualTo(ctxCl);
     }
 
     @Test
-    public void testInputFormat() {
-        try {
-            final InputOutputFormatVertex vertex = new InputOutputFormatVertex("Name");
-
-            OperatorID operatorID = new OperatorID();
-            Configuration parameters = new Configuration();
-            parameters.setString("test_key", "test_value");
-            new InputOutputFormatContainer(Thread.currentThread().getContextClassLoader())
-                    .addInputFormat(operatorID, new TestInputFormat(parameters))
-                    .addParameters(operatorID, "test_key", "test_value")
-                    .write(new TaskConfig(vertex.getConfiguration()));
-
-            final ClassLoader cl = new TestClassLoader();
-
-            vertex.initializeOnMaster(cl);
-            InputSplit[] splits = vertex.getInputSplitSource().createInputSplits(77);
-
-            assertNotNull(splits);
-            assertEquals(1, splits.length);
-            assertEquals(TestSplit.class, splits[0].getClass());
-        } catch (Exception e) {
-            e.printStackTrace();
-            fail(e.getMessage());
-        }
+    void testInputFormat() throws Exception {
+        final InputOutputFormatVertex vertex = new InputOutputFormatVertex("Name");
+
+        OperatorID operatorID = new OperatorID();
+        Configuration parameters = new Configuration();
+        parameters.setString("test_key", "test_value");
+        new InputOutputFormatContainer(Thread.currentThread().getContextClassLoader())
+                .addInputFormat(operatorID, new TestInputFormat(parameters))
+                .addParameters(operatorID, "test_key", "test_value")
+                .write(new TaskConfig(vertex.getConfiguration()));
+
+        final ClassLoader cl = new TestClassLoader();
+
+        vertex.initializeOnMaster(cl);
+        InputSplit[] splits = vertex.getInputSplitSource().createInputSplits(77);
+
+        assertThat(splits).isNotNull();
+        assertThat(splits).hasSize(1);
+        assertThat(splits[0].getClass()).isEqualTo(TestSplit.class);
     }
 
     // --------------------------------------------------------------------------------------------
@@ -191,8 +163,8 @@ public class JobTaskVertexTest {
                 throw new IllegalStateException("Context ClassLoader was not correctly switched.");
             }
             for (String key : expectedParameters.keySet()) {
-                assertEquals(
-                        expectedParameters.getString(key, null), parameters.getString(key, null));
+                assertThat(parameters.getString(key, null))
+                        .isEqualTo(expectedParameters.getString(key, null));
             }
             isConfigured = true;
         }
@@ -244,8 +216,8 @@ public class JobTaskVertexTest {
                 throw new IllegalStateException("Context ClassLoader was not correctly switched.");
             }
             for (String key : expectedParameters.keySet()) {
-                assertEquals(
-                        expectedParameters.getString(key, null), parameters.getString(key, null));
+                assertThat(parameters.getString(key, null))
+                        .isEqualTo(expectedParameters.getString(key, null));
             }
             isConfigured = true;
         }
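
The JobTaskVertexTest rewrite collapses each try { ...; fail(...); } catch (ExpectedException ignored) {} block into a single assertThatThrownBy statement, which also removes the outer try/catch that previously swallowed setup exceptions. A self-contained sketch of the shape (TestException and the throwing method are stand-ins, not the Flink classes above):

    import org.junit.jupiter.api.Test;

    import static org.assertj.core.api.Assertions.assertThatThrownBy;

    class ThrownBySketch {

        private static class TestException extends RuntimeException {}

        private static void initializeOnMaster() {
            throw new TestException();
        }

        @Test
        void expectedExceptionIsOneStatement() {
            // JUnit4:
            //   try {
            //       initializeOnMaster();
            //       fail("Did not throw expected exception.");
            //   } catch (TestException e) {
            //       // all good
            //   }
            assertThatThrownBy(ThrownBySketch::initializeOnMaster)
                    .isInstanceOf(TestException.class);
        }
    }
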
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
index 109484fc1a1..4e1d9cb76a4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
@@ -37,16 +37,14 @@ import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
 import org.apache.flink.runtime.scheduler.strategy.SchedulingPipelinedRegion;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 import org.apache.flink.util.IterableUtils;
-import org.apache.flink.util.TestLogger;
 
-import org.apache.flink.shaded.guava30.com.google.common.collect.Iterables;
 import org.apache.flink.shaded.guava30.com.google.common.collect.Sets;
 
-import org.junit.Before;
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -58,31 +56,26 @@ import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static junit.framework.TestCase.assertSame;
-import static junit.framework.TestCase.assertTrue;
 import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createExecutionGraph;
 import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createNoOpVertex;
 import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING;
 import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.PIPELINED;
 import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.core.Is.is;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Unit tests for {@link DefaultExecutionTopology}. */
-public class DefaultExecutionTopologyTest extends TestLogger {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+class DefaultExecutionTopologyTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     private DefaultExecutionGraph executionGraph;
 
     private DefaultExecutionTopology adapter;
 
-    @Before
-    public void setUp() throws Exception {
+    @BeforeEach
+    void setUp() throws Exception {
         JobVertex[] jobVertices = new JobVertex[2];
         int parallelism = 3;
         jobVertices[0] = createNoOpVertex(parallelism);
@@ -93,13 +86,13 @@ public class DefaultExecutionTopologyTest extends TestLogger {
     }
 
     @Test
-    public void testConstructor() {
+    void testConstructor() {
         // implicitly tests order constraint of getVertices()
         assertGraphEquals(executionGraph, adapter);
     }
 
     @Test
-    public void testGetResultPartition() {
+    void testGetResultPartition() {
         for (ExecutionVertex vertex : executionGraph.getAllExecutionVertices()) {
             for (Map.Entry<IntermediateResultPartitionID, IntermediateResultPartition> entry :
                     vertex.getProducedPartitions().entrySet()) {
@@ -113,7 +106,7 @@ public class DefaultExecutionTopologyTest extends TestLogger {
     }
 
     @Test
-    public void testResultPartitionStateSupplier() {
+    void testResultPartitionStateSupplier() {
         final IntermediateResultPartition intermediateResultPartition =
                 IterableUtils.toStream(executionGraph.getAllExecutionVertices())
                         .flatMap(v -> v.getProducedPartitions().values().stream())
@@ -123,41 +116,33 @@ public class DefaultExecutionTopologyTest extends TestLogger {
         final DefaultResultPartition schedulingResultPartition =
                 adapter.getResultPartition(intermediateResultPartition.getPartitionId());
 
-        assertEquals(ResultPartitionState.CREATED, schedulingResultPartition.getState());
+        assertThat(schedulingResultPartition.getState()).isEqualTo(ResultPartitionState.CREATED);
 
         intermediateResultPartition.markDataProduced();
-        assertEquals(ResultPartitionState.CONSUMABLE, schedulingResultPartition.getState());
+        assertThat(schedulingResultPartition.getState()).isEqualTo(ResultPartitionState.CONSUMABLE);
     }
 
     @Test
-    public void testGetVertexOrThrow() {
-        try {
-            adapter.getVertex(new ExecutionVertexID(new JobVertexID(), 0));
-            fail("get not exist vertex");
-        } catch (IllegalArgumentException exception) {
-            // expected
-        }
+    void testGetVertexOrThrow() {
+        assertThatThrownBy(() -> adapter.getVertex(new ExecutionVertexID(new JobVertexID(), 0)))
+                .isInstanceOf(IllegalArgumentException.class);
     }
 
     @Test
-    public void testResultPartitionOrThrow() {
-        try {
-            adapter.getResultPartition(new IntermediateResultPartitionID());
-            fail("get not exist result partition");
-        } catch (IllegalArgumentException exception) {
-            // expected
-        }
+    void testResultPartitionOrThrow() {
+        assertThatThrownBy(() -> adapter.getResultPartition(new IntermediateResultPartitionID()))
+                .isInstanceOf(IllegalArgumentException.class);
     }
 
     @Test
-    public void testGetAllPipelinedRegions() {
+    void testGetAllPipelinedRegions() {
         final Iterable<DefaultSchedulingPipelinedRegion> allPipelinedRegions =
                 adapter.getAllPipelinedRegions();
-        assertEquals(1, Iterables.size(allPipelinedRegions));
+        assertThat(allPipelinedRegions).hasSize(1);
     }
 
     @Test
-    public void testGetPipelinedRegionOfVertex() {
+    void testGetPipelinedRegionOfVertex() {
         for (DefaultExecutionVertex vertex : adapter.getVertices()) {
             final DefaultSchedulingPipelinedRegion pipelinedRegion =
                     adapter.getPipelinedRegionOfVertex(vertex.getId());
@@ -165,8 +150,8 @@ public class DefaultExecutionTopologyTest extends TestLogger {
         }
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testErrorIfCoLocatedTasksAreNotInSameRegion() throws Exception {
+    @Test
+    void testErrorIfCoLocatedTasksAreNotInSameRegion() throws Exception {
         int parallelism = 3;
         final JobVertex v1 = createNoOpVertex(parallelism);
         final JobVertex v2 = createNoOpVertex(parallelism);
@@ -176,13 +161,12 @@ public class DefaultExecutionTopologyTest extends TestLogger {
         v2.setSlotSharingGroup(slotSharingGroup);
         v1.setStrictlyCoLocatedWith(v2);
 
-        final DefaultExecutionGraph executionGraph =
-                createExecutionGraph(EXECUTOR_RESOURCE.getExecutor(), v1, v2);
-        DefaultExecutionTopology.fromExecutionGraph(executionGraph);
+        assertThatThrownBy(() -> createExecutionGraph(EXECUTOR_RESOURCE.getExecutor(), v1, v2))
+                .isInstanceOf(IllegalStateException.class);
     }
 
     @Test
-    public void testUpdateTopology() throws Exception {
+    void testUpdateTopology() throws Exception {
         final JobVertex[] jobVertices = createJobVertices(BLOCKING);
         executionGraph = createDynamicGraph(jobVertices);
         adapter = DefaultExecutionTopology.fromExecutionGraph(executionGraph);
@@ -192,18 +176,17 @@ public class DefaultExecutionTopologyTest extends TestLogger {
 
         executionGraph.initializeJobVertex(ejv1, 0L);
         adapter.notifyExecutionGraphUpdated(executionGraph, Collections.singletonList(ejv1));
-        assertThat(IterableUtils.toStream(adapter.getVertices()).count(), is(3L));
+        assertThat(adapter.getVertices()).hasSize(3);
 
         executionGraph.initializeJobVertex(ejv2, 0L);
         adapter.notifyExecutionGraphUpdated(executionGraph, Collections.singletonList(ejv2));
-        assertThat(IterableUtils.toStream(adapter.getVertices()).count(), is(6L));
+        assertThat(adapter.getVertices()).hasSize(6);
 
         assertGraphEquals(executionGraph, adapter);
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testErrorIfUpdateTopologyWithNewVertexPipelinedConnectedToOldOnes()
-            throws Exception {
+    @Test
+    void testErrorIfUpdateTopologyWithNewVertexPipelinedConnectedToOldOnes() throws Exception {
         final JobVertex[] jobVertices = createJobVertices(PIPELINED);
         executionGraph = createDynamicGraph(jobVertices);
         adapter = DefaultExecutionTopology.fromExecutionGraph(executionGraph);
@@ -215,11 +198,15 @@ public class DefaultExecutionTopologyTest extends TestLogger {
         adapter.notifyExecutionGraphUpdated(executionGraph, Collections.singletonList(ejv1));
 
         executionGraph.initializeJobVertex(ejv2, 0L);
-        adapter.notifyExecutionGraphUpdated(executionGraph, Collections.singletonList(ejv2));
+        assertThatThrownBy(
+                        () ->
+                                adapter.notifyExecutionGraphUpdated(
+                                        executionGraph, Collections.singletonList(ejv2)))
+                .isInstanceOf(IllegalStateException.class);
     }
 
     @Test
-    public void testExistingRegionsAreNotAffectedDuringTopologyUpdate() throws Exception {
+    void testExistingRegionsAreNotAffectedDuringTopologyUpdate() throws Exception {
         final JobVertex[] jobVertices = createJobVertices(BLOCKING);
         executionGraph = createDynamicGraph(jobVertices);
         adapter = DefaultExecutionTopology.fromExecutionGraph(executionGraph);
@@ -237,7 +224,7 @@ public class DefaultExecutionTopologyTest extends TestLogger {
         SchedulingPipelinedRegion regionNew =
                 adapter.getPipelinedRegionOfVertex(new ExecutionVertexID(ejv1.getJobVertexId(), 0));
 
-        assertSame(regionOld, regionNew);
+        assertThat(regionNew).isSameAs(regionOld);
     }
 
     private JobVertex[] createJobVertices(ResultPartitionType resultPartitionType) {
@@ -260,7 +247,7 @@ public class DefaultExecutionTopologyTest extends TestLogger {
             final DefaultSchedulingPipelinedRegion pipelinedRegionOfVertex) {
         final Set<DefaultExecutionVertex> allVertices =
                 Sets.newHashSet(pipelinedRegionOfVertex.getVertices());
-        assertEquals(Sets.newHashSet(adapter.getVertices()), allVertices);
+        assertThat(allVertices).isEqualTo(Sets.newHashSet(adapter.getVertices()));
     }
 
     private static void assertGraphEquals(
@@ -274,7 +261,7 @@ public class DefaultExecutionTopologyTest extends TestLogger {
             ExecutionVertex originalVertex = originalVertices.next();
             DefaultExecutionVertex adaptedVertex = adaptedVertices.next();
 
-            assertEquals(originalVertex.getID(), adaptedVertex.getId());
+            assertThat(adaptedVertex.getId()).isEqualTo(originalVertex.getID());
 
             List<IntermediateResultPartition> originalConsumedPartitions = new ArrayList<>();
             for (ConsumedPartitionGroup consumedPartitionGroup :
@@ -300,17 +287,16 @@ public class DefaultExecutionTopologyTest extends TestLogger {
             assertPartitionsEquals(originalProducedPartitions, adaptedProducedPartitions);
         }
 
-        assertFalse(
-                "Number of adapted vertices exceeds number of original vertices.",
-                adaptedVertices.hasNext());
+        assertThat(adaptedVertices)
+                .as("Number of adapted vertices exceeds number of original vertices.")
+                .isExhausted();
     }
 
     private static void assertPartitionsEquals(
             Iterable<IntermediateResultPartition> originalResultPartitions,
             Iterable<DefaultResultPartition> adaptedResultPartitions) {
 
-        assertEquals(
-                Iterables.size(originalResultPartitions), Iterables.size(adaptedResultPartitions));
+        assertThat(originalResultPartitions).hasSameSizeAs(adaptedResultPartitions);
 
         for (IntermediateResultPartition originalPartition : originalResultPartitions) {
             DefaultResultPartition adaptedPartition =
@@ -331,16 +317,14 @@ public class DefaultExecutionTopologyTest extends TestLogger {
             ConsumerVertexGroup consumerVertexGroup = originalPartition.getConsumerVertexGroup();
             Optional<ConsumerVertexGroup> adaptedConsumers =
                     adaptedPartition.getConsumerVertexGroup();
-            assertTrue(adaptedConsumers.isPresent());
+            assertThat(adaptedConsumers).isPresent();
             for (ExecutionVertexID originalId : consumerVertexGroup) {
                 // it is sufficient to verify that some vertex exists with the correct ID here,
                 // since deep equality is verified later in the main loop
                 // this DOES rely on an implicit assumption that the vertices objects returned by
                 // the topology are
                 // identical to those stored in the partition
-                assertTrue(
-                        IterableUtils.toStream(adaptedConsumers.get())
-                                .anyMatch(adaptedConsumer -> adaptedConsumer.equals(originalId)));
+                assertThat(adaptedConsumers.get()).contains(originalId);
             }
         }
     }
@@ -349,11 +333,11 @@ public class DefaultExecutionTopologyTest extends TestLogger {
             IntermediateResultPartition originalPartition,
             DefaultResultPartition adaptedPartition) {
 
-        assertEquals(originalPartition.getPartitionId(), adaptedPartition.getId());
-        assertEquals(
-                originalPartition.getIntermediateResult().getId(), adaptedPartition.getResultId());
-        assertEquals(originalPartition.getResultType(), adaptedPartition.getResultType());
-        assertEquals(
-                originalPartition.getProducer().getID(), adaptedPartition.getProducer().getId());
+        assertThat(adaptedPartition.getId()).isEqualTo(originalPartition.getPartitionId());
+        assertThat(adaptedPartition.getResultId())
+                .isEqualTo(originalPartition.getIntermediateResult().getId());
+        assertThat(adaptedPartition.getResultType()).isEqualTo(originalPartition.getResultType());
+        assertThat(adaptedPartition.getProducer().getId())
+                .isEqualTo(originalPartition.getProducer().getID());
     }
 }
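
DefaultExecutionTopologyTest additionally retires the JUnit4 @Test(expected = IllegalStateException.class) form, which passed if any statement in the method threw. Wrapping exactly one call in assertThatThrownBy pins the failure point, as in this minimal sketch (the throwing lambda is a placeholder for the co-location check above):

    import org.junit.jupiter.api.Test;

    import static org.assertj.core.api.Assertions.assertThatThrownBy;

    class ExpectedAnnotationSketch {

        @Test
        void onlyTheIntendedCallMayThrow() {
            // With @Test(expected = ...), an exception from setup lines also
            // passed the test; here only the wrapped call is allowed to throw.
            assertThatThrownBy(
                            () -> {
                                throw new IllegalStateException(
                                        "co-located tasks are not in the same region");
                            })
                    .isInstanceOf(IllegalStateException.class);
        }
    }
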
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
index f97116e58b2..fc2df697464 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
@@ -27,10 +27,9 @@ import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
 import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition;
 import org.apache.flink.util.IterableUtils;
-import org.apache.flink.util.TestLogger;
 
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
 
 import java.util.Collections;
 import java.util.List;
@@ -38,10 +37,10 @@ import java.util.Map;
 import java.util.function.Supplier;
 
 import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING;
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Unit tests for {@link DefaultExecutionVertex}. */
-public class DefaultExecutionVertexTest extends TestLogger {
+class DefaultExecutionVertexTest {
 
     private final TestExecutionStateSupplier stateSupplier = new TestExecutionStateSupplier();
 
@@ -51,8 +50,8 @@ public class DefaultExecutionVertexTest extends TestLogger {
 
     private IntermediateResultPartitionID intermediateResultPartitionId;
 
-    @Before
-    public void setUp() throws Exception {
+    @BeforeEach
+    void setUp() throws Exception {
 
         intermediateResultPartitionId = new IntermediateResultPartitionID();
 
@@ -97,15 +96,15 @@ public class DefaultExecutionVertexTest extends TestLogger {
     }
 
     @Test
-    public void testGetExecutionState() {
+    void testGetExecutionState() {
         for (ExecutionState state : ExecutionState.values()) {
             stateSupplier.setExecutionState(state);
-            assertEquals(state, producerVertex.getState());
+            assertThat(producerVertex.getState()).isEqualTo(state);
         }
     }
 
     @Test
-    public void testGetProducedResultPartitions() {
+    void testGetProducedResultPartitions() {
         IntermediateResultPartitionID partitionIds1 =
                 IterableUtils.toStream(producerVertex.getProducedResults())
                         .findAny()
@@ -114,11 +113,11 @@ public class DefaultExecutionVertexTest extends TestLogger {
                                 () ->
                                         new IllegalArgumentException(
                                                 "can not find result partition"));
-        assertEquals(partitionIds1, intermediateResultPartitionId);
+        assertThat(intermediateResultPartitionId).isEqualTo(partitionIds1);
     }
 
     @Test
-    public void testGetConsumedResultPartitions() {
+    void testGetConsumedResultPartitions() {
         IntermediateResultPartitionID partitionIds1 =
                 IterableUtils.toStream(consumerVertex.getConsumedResults())
                         .findAny()
@@ -127,7 +126,7 @@ public class DefaultExecutionVertexTest extends TestLogger {
                                 () ->
                                         new IllegalArgumentException(
                                                 "can not find result partition"));
-        assertEquals(partitionIds1, intermediateResultPartitionId);
+        assertThat(intermediateResultPartitionId).isEqualTo(partitionIds1);
     }
 
     /** A simple implementation of {@link Supplier} for testing. */
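
One detail worth noting in the DefaultExecutionVertexTest hunks: assertEquals takes (expected, actual) while assertThat takes the value under test first, so a faithful migration swaps the two arguments rather than copying them in order. A sketch with placeholder strings:

    import org.junit.jupiter.api.Test;

    import static org.assertj.core.api.Assertions.assertThat;

    class ArgumentOrderSketch {

        @Test
        void actualValueMovesToTheFront() {
            String expected = "partition-0";
            String actual = "partition-0";

            // JUnit4: assertEquals(expected, actual);
            // AssertJ: the actual value comes first, so a mismatch is reported
            // as "expected: <...> but was: <...>" with the roles the right way
            // around.
            assertThat(actual).isEqualTo(expected);
        }
    }
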
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
index 424dd469212..6d0024626c1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
@@ -24,24 +24,19 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
-import org.apache.flink.util.TestLogger;
 
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
 
 import java.util.HashMap;
 import java.util.Map;
 import java.util.function.Supplier;
 
 import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.contains;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Unit tests for {@link DefaultResultPartition}. */
-public class DefaultResultPartitionTest extends TestLogger {
+class DefaultResultPartitionTest {
 
     private static final TestResultPartitionStateSupplier resultPartitionState =
             new TestResultPartitionStateSupplier();
@@ -55,8 +50,8 @@ public class DefaultResultPartitionTest extends TestLogger {
     private final Map<IntermediateResultPartitionID, ConsumerVertexGroup> consumerVertexGroups =
             new HashMap<>();
 
-    @Before
-    public void setUp() {
+    @BeforeEach
+    void setUp() {
         resultPartition =
                 new DefaultResultPartition(
                         resultPartitionId,
@@ -70,24 +65,24 @@ public class DefaultResultPartitionTest extends TestLogger {
     }
 
     @Test
-    public void testGetPartitionState() {
+    void testGetPartitionState() {
         for (ResultPartitionState state : ResultPartitionState.values()) {
             resultPartitionState.setResultPartitionState(state);
-            assertEquals(state, resultPartition.getState());
+            assertThat(resultPartition.getState()).isEqualTo(state);
         }
     }
 
     @Test
-    public void testGetConsumerVertexGroup() {
+    void testGetConsumerVertexGroup() {
 
-        assertFalse(resultPartition.getConsumerVertexGroup().isPresent());
+        assertThat(resultPartition.getConsumerVertexGroup()).isNotPresent();
 
         // test update consumers
         ExecutionVertexID executionVertexId = new ExecutionVertexID(new JobVertexID(), 0);
         consumerVertexGroups.put(
                 resultPartition.getId(), ConsumerVertexGroup.fromSingleVertex(executionVertexId));
-        assertTrue(resultPartition.getConsumerVertexGroup().isPresent());
-        assertThat(resultPartition.getConsumerVertexGroup().get(), contains(executionVertexId));
+        assertThat(resultPartition.getConsumerVertexGroup()).isPresent();
+        assertThat(resultPartition.getConsumerVertexGroup().get()).contains(executionVertexId);
     }
 
     /** A test {@link ResultPartitionState} supplier. */
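
The Optional checks in DefaultResultPartitionTest move to AssertJ's dedicated Optional assertions, which replace the isPresent()/get() pairing. A self-contained sketch using a plain Optional<String> in place of the ConsumerVertexGroup:

    import java.util.Optional;

    import org.junit.jupiter.api.Test;

    import static org.assertj.core.api.Assertions.assertThat;

    class OptionalAssertionSketch {

        @Test
        void optionalsGetDedicatedAssertions() {
            Optional<String> consumerGroup = Optional.empty();

            // JUnit4: assertFalse(consumerGroup.isPresent());
            assertThat(consumerGroup).isNotPresent();

            consumerGroup = Optional.of("vertex-1");

            // JUnit4: assertTrue(consumerGroup.isPresent());
            assertThat(consumerGroup).isPresent();

            // Value check without calling get() first:
            assertThat(consumerGroup).contains("vertex-1");
        }
    }
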
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
index d6508d8b66a..3c9bed61354 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
@@ -40,11 +40,10 @@ import org.apache.flink.runtime.scheduler.SchedulerBase;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
-import org.apache.flink.util.TestLogger;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Arrays;
 import java.util.Iterator;
@@ -52,24 +51,23 @@ import java.util.List;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
 
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.is;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Test for {@link AdaptiveBatchScheduler}. */
-public class AdaptiveBatchSchedulerTest extends TestLogger {
+class AdaptiveBatchSchedulerTest {
 
     private static final int SOURCE_PARALLELISM_1 = 6;
     private static final int SOURCE_PARALLELISM_2 = 4;
 
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     private static final ComponentMainThreadExecutor mainThreadExecutor =
             ComponentMainThreadExecutorServiceAdapter.forMainThread();
 
     @Test
-    public void testAdaptiveBatchScheduler() throws Exception {
+    void testAdaptiveBatchScheduler() throws Exception {
         JobGraph jobGraph = createJobGraph(false);
         Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
         JobVertex source1 = jobVertexIterator.next();
@@ -82,22 +80,22 @@ public class AdaptiveBatchSchedulerTest extends TestLogger {
         final ExecutionJobVertex sinkExecutionJobVertex = graph.getJobVertex(sink.getID());
 
         scheduler.startScheduling();
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(-1));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);
 
         // trigger source1 finished.
         transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(-1));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);
 
         // trigger source2 finished.
         transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(10));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(10);
 
         // check that the jobGraph is updated
-        assertThat(sink.getParallelism(), is(10));
+        assertThat(sink.getParallelism()).isEqualTo(10);
     }
 
     @Test
-    public void testDecideParallelismForForwardTarget() throws Exception {
+    void testDecideParallelismForForwardTarget() throws Exception {
         JobGraph jobGraph = createJobGraph(true);
         Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
         JobVertex source1 = jobVertexIterator.next();
@@ -110,18 +108,18 @@ public class AdaptiveBatchSchedulerTest extends TestLogger {
         final ExecutionJobVertex sinkExecutionJobVertex = graph.getJobVertex(sink.getID());
 
         scheduler.startScheduling();
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(-1));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);
 
         // trigger source1 finished.
         transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(-1));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);
 
         // trigger source2 finished.
         transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(SOURCE_PARALLELISM_1));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(SOURCE_PARALLELISM_1);
 
         // check that the jobGraph is updated
-        assertThat(sink.getParallelism(), is(SOURCE_PARALLELISM_1));
+        assertThat(sink.getParallelism()).isEqualTo(SOURCE_PARALLELISM_1);
     }
 
     /** Transit the state of all executions. */
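
AdaptiveBatchSchedulerTest is a pure Hamcrest-to-AssertJ swap: the two-argument assertThat(value, is(expected)) becomes a fluent call, and the org.hamcrest imports drop out. In sketch form (the value is a placeholder):

    import org.junit.jupiter.api.Test;

    import static org.assertj.core.api.Assertions.assertThat;

    class HamcrestSwapSketch {

        @Test
        void matcherBecomesFluentCall() {
            int sinkParallelism = 10;

            // Hamcrest: assertThat(sinkParallelism, is(10));
            // AssertJ:  the matcher argument becomes a chained method, so the
            //           available assertions follow from the value's type.
            assertThat(sinkParallelism).isEqualTo(10);
        }
    }
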
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
index dbf6cb99062..b45ba3f5ebe 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
@@ -29,11 +29,10 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchScheduler;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
-import org.apache.flink.util.TestLogger;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Arrays;
 import java.util.HashSet;
@@ -41,14 +40,13 @@ import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
 
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Unit tests for {@link ForwardGroupComputeUtil}. */
-public class ForwardGroupComputeUtilTest extends TestLogger {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+class ForwardGroupComputeUtilTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     /**
      * Tests that the computation of the job graph with isolated vertices works correctly.
@@ -62,7 +60,7 @@ public class ForwardGroupComputeUtilTest extends TestLogger {
      * </pre>
      */
     @Test
-    public void testIsolatedVertices() throws Exception {
+    void testIsolatedVertices() throws Exception {
         JobVertex v1 = new JobVertex("v1");
         JobVertex v2 = new JobVertex("v2");
         JobVertex v3 = new JobVertex("v3");
@@ -83,14 +81,14 @@ public class ForwardGroupComputeUtilTest extends TestLogger {
      * </pre>
      */
     @Test
-    public void testVariousResultPartitionTypesBetweenVertices() throws Exception {
+    void testVariousResultPartitionTypesBetweenVertices() throws Exception {
         testThreeVerticesConnectSequentially(false, true, 1, 2);
         testThreeVerticesConnectSequentially(false, false, 0);
         testThreeVerticesConnectSequentially(true, true, 1, 3);
     }
 
     private void testThreeVerticesConnectSequentially(
-            boolean isForward1, boolean isForward2, int numOfGroups, int... groupSizes)
+            boolean isForward1, boolean isForward2, int numOfGroups, Integer... groupSizes)
             throws Exception {
         JobVertex v1 = new JobVertex("v1");
         JobVertex v2 = new JobVertex("v2");
@@ -129,7 +127,7 @@ public class ForwardGroupComputeUtilTest extends TestLogger {
      * </pre>
      */
     @Test
-    public void testTwoInputsMergesIntoOne() throws Exception {
+    void testTwoInputsMergesIntoOne() throws Exception {
         JobVertex v1 = new JobVertex("v1");
         JobVertex v2 = new JobVertex("v2");
         JobVertex v3 = new JobVertex("v3");
@@ -164,7 +162,7 @@ public class ForwardGroupComputeUtilTest extends TestLogger {
      * </pre>
      */
     @Test
-    public void testOneInputSplitsIntoTwo() throws Exception {
+    void testOneInputSplitsIntoTwo() throws Exception {
         JobVertex v1 = new JobVertex("v1");
         JobVertex v2 = new JobVertex("v2");
         JobVertex v3 = new JobVertex("v3");
@@ -193,10 +191,11 @@ public class ForwardGroupComputeUtilTest extends TestLogger {
                         .values());
     }
 
-    private static void checkGroupSize(Set<ForwardGroup> groups, int numOfGroups, int... sizes) {
-        assertEquals(numOfGroups, groups.size());
-        containsInAnyOrder(
-                groups.stream().map(ForwardGroup::size).collect(Collectors.toList()), sizes);
+    private static void checkGroupSize(
+            Set<ForwardGroup> groups, int numOfGroups, Integer... sizes) {
+        assertThat(groups.size()).isEqualTo(numOfGroups);
+        assertThat(groups.stream().map(ForwardGroup::size).collect(Collectors.toList()))
+                .contains(sizes);
     }
 
     private static DefaultExecutionGraph createDynamicGraph(JobVertex... vertices)
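
Finally, the int... to Integer... change in testThreeVerticesConnectSequentially and checkGroupSize is forced by the target API: the group sizes are collected into a List<Integer>, and AssertJ's contains(...) on that list takes Integer varargs, while Java does not convert an int[] to Integer[]. A compiling sketch of the boxed-varargs shape (names and sizes are placeholders):

    import java.util.Arrays;
    import java.util.List;

    import org.junit.jupiter.api.Test;

    import static org.assertj.core.api.Assertions.assertThat;

    class BoxedVarargsSketch {

        private static void checkSizes(List<Integer> actualSizes, Integer... expectedSizes) {
            // contains(ELEMENT...) on a List<Integer> needs Integer varargs; an
            // int[] neither widens to Integer[] nor autoboxes element-wise, so
            // an int... parameter here would not compile.
            assertThat(actualSizes).contains(expectedSizes);
        }

        @Test
        void boxedVarargsMatchTheElementType() {
            checkSizes(Arrays.asList(3, 1), 1, 3);
        }
    }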