You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text rendering.
Posted to commits@flink.apache.org by fp...@apache.org on 2022/01/03 12:09:10 UTC

[flink] branch master updated: [FLINK-24857][test][FileSource][Kafka] Upgrade SourceReaderTestBase to JUnit 5

This is an automated email from the ASF dual-hosted git repository.

fpaul pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 1b02f02  [FLINK-24857][test][FileSource][Kafka] Upgrade SourceReaderTestBase to JUnit 5
1b02f02 is described below

commit 1b02f022cdcb4af449ed5dfe88acd3f9e4926dbe
Author: Yufei Zhang <af...@gmail.com>
AuthorDate: Mon Dec 6 12:11:50 2021 +0800

    [FLINK-24857][test][FileSource][Kafka] Upgrade SourceReaderTestBase to JUnit 5
---
 .../base/source/reader/SourceReaderBaseTest.java   | 126 +++++++++++----------
 .../kafka/source/reader/KafkaSourceReaderTest.java |  82 +++++++-------
 .../flink-connector-test-utils/pom.xml             |   6 +
 .../source/reader/SourceReaderTestBase.java        |  75 ++++++------
 4 files changed, 147 insertions(+), 142 deletions(-)

diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/reader/SourceReaderBaseTest.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/reader/SourceReaderBaseTest.java
index 07353c2..4a5544d 100644
--- a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/reader/SourceReaderBaseTest.java
+++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/reader/SourceReaderBaseTest.java
@@ -38,9 +38,7 @@ import org.apache.flink.connector.testutils.source.reader.TestingReaderContext;
 import org.apache.flink.connector.testutils.source.reader.TestingReaderOutput;
 import org.apache.flink.core.io.InputStatus;
 
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ExpectedException;
+import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -50,75 +48,81 @@ import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.function.Supplier;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertTrue;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** A unit test class for {@link SourceReaderBase}. */
 public class SourceReaderBaseTest extends SourceReaderTestBase<MockSourceSplit> {
 
-    @Rule public ExpectedException expectedException = ExpectedException.none();
-
     @Test
-    public void testExceptionInSplitReader() throws Exception {
-        expectedException.expect(RuntimeException.class);
-        expectedException.expectMessage("One or more fetchers have encountered exception");
-        final String errMsg = "Testing Exception";
-
-        FutureCompletingBlockingQueue<RecordsWithSplitIds<int[]>> elementsQueue =
-                new FutureCompletingBlockingQueue<>();
-        // We have to handle split changes first, otherwise fetch will not be called.
-        try (MockSourceReader reader =
-                new MockSourceReader(
-                        elementsQueue,
-                        () ->
-                                new SplitReader<int[], MockSourceSplit>() {
-                                    @Override
-                                    public RecordsWithSplitIds<int[]> fetch() {
-                                        throw new RuntimeException(errMsg);
-                                    }
-
-                                    @Override
-                                    public void handleSplitsChanges(
-                                            SplitsChange<MockSourceSplit> splitsChanges) {}
-
-                                    @Override
-                                    public void wakeUp() {}
-
-                                    @Override
-                                    public void close() {}
-                                },
-                        getConfig(),
-                        new TestingReaderContext())) {
-            ValidatingSourceOutput output = new ValidatingSourceOutput();
-            reader.addSplits(
-                    Collections.singletonList(
-                            getSplit(0, NUM_RECORDS_PER_SPLIT, Boundedness.CONTINUOUS_UNBOUNDED)));
-            reader.notifyNoMoreSplits();
-            // This is not a real infinite loop, it is supposed to throw exception after two polls.
-            while (true) {
-                InputStatus inputStatus = reader.pollNext(output);
-                assertNotEquals(InputStatus.END_OF_INPUT, inputStatus);
-                // Add a sleep to avoid tight loop.
-                Thread.sleep(1);
-            }
-        }
+    void testExceptionInSplitReader() {
+        assertThatThrownBy(
+                        () -> {
+                            final String errMsg = "Testing Exception";
+
+                            FutureCompletingBlockingQueue<RecordsWithSplitIds<int[]>>
+                                    elementsQueue = new FutureCompletingBlockingQueue<>();
+                            // We have to handle split changes first, otherwise fetch will not be
+                            // called.
+                            try (MockSourceReader reader =
+                                    new MockSourceReader(
+                                            elementsQueue,
+                                            () ->
+                                                    new SplitReader<int[], MockSourceSplit>() {
+                                                        @Override
+                                                        public RecordsWithSplitIds<int[]> fetch() {
+                                                            throw new RuntimeException(errMsg);
+                                                        }
+
+                                                        @Override
+                                                        public void handleSplitsChanges(
+                                                                SplitsChange<MockSourceSplit>
+                                                                        splitsChanges) {}
+
+                                                        @Override
+                                                        public void wakeUp() {}
+
+                                                        @Override
+                                                        public void close() {}
+                                                    },
+                                            getConfig(),
+                                            new TestingReaderContext())) {
+                                ValidatingSourceOutput output = new ValidatingSourceOutput();
+                                reader.addSplits(
+                                        Collections.singletonList(
+                                                getSplit(
+                                                        0,
+                                                        NUM_RECORDS_PER_SPLIT,
+                                                        Boundedness.CONTINUOUS_UNBOUNDED)));
+                                reader.notifyNoMoreSplits();
+                                // This is not a real infinite loop, it is supposed to throw
+                                // exception after
+                                // two polls.
+                                while (true) {
+                                    InputStatus inputStatus = reader.pollNext(output);
+                                    assertThat(inputStatus).isNotEqualTo(InputStatus.END_OF_INPUT);
+                                    // Add a sleep to avoid tight loop.
+                                    Thread.sleep(1);
+                                }
+                            }
+                        })
+                .isInstanceOf(RuntimeException.class)
+                .hasMessage("One or more fetchers have encountered exception");
     }
 
     @Test
-    public void testRecordsWithSplitsNotRecycledWhenRecordsLeft() throws Exception {
+    void testRecordsWithSplitsNotRecycledWhenRecordsLeft() throws Exception {
         final TestingRecordsWithSplitIds<String> records =
                 new TestingRecordsWithSplitIds<>("test-split", "value1", "value2");
         final SourceReader<?, ?> reader = createReaderAndAwaitAvailable("test-split", records);
 
         reader.pollNext(new TestingReaderOutput<>());
 
-        assertFalse(records.isRecycled());
+        assertThat(records.isRecycled()).isFalse();
     }
 
     @Test
-    public void testRecordsWithSplitsRecycledWhenEmpty() throws Exception {
+    void testRecordsWithSplitsRecycledWhenEmpty() throws Exception {
         final TestingRecordsWithSplitIds<String> records =
                 new TestingRecordsWithSplitIds<>("test-split", "value1", "value2");
         final SourceReader<?, ?> reader = createReaderAndAwaitAvailable("test-split", records);
@@ -129,11 +133,11 @@ public class SourceReaderBaseTest extends SourceReaderTestBase<MockSourceSplit>
         reader.pollNext(new TestingReaderOutput<>());
         reader.pollNext(new TestingReaderOutput<>());
 
-        assertTrue(records.isRecycled());
+        assertThat(records.isRecycled()).isTrue();
     }
 
     @Test
-    public void testMultipleSplitsWithDifferentFinishingMoments() throws Exception {
+    void testMultipleSplitsWithDifferentFinishingMoments() throws Exception {
         FutureCompletingBlockingQueue<RecordsWithSplitIds<int[]>> elementsQueue =
                 new FutureCompletingBlockingQueue<>();
         MockSplitReader mockSplitReader =
@@ -169,7 +173,7 @@ public class SourceReaderBaseTest extends SourceReaderTestBase<MockSourceSplit>
     }
 
     @Test
-    public void testMultipleSplitsWithSeparatedFinishedRecord() throws Exception {
+    void testMultipleSplitsWithSeparatedFinishedRecord() throws Exception {
         FutureCompletingBlockingQueue<RecordsWithSplitIds<int[]>> elementsQueue =
                 new FutureCompletingBlockingQueue<>();
         MockSplitReader mockSplitReader =
@@ -205,7 +209,7 @@ public class SourceReaderBaseTest extends SourceReaderTestBase<MockSourceSplit>
     }
 
     @Test
-    public void testPollNextReturnMoreAvailableWhenAllSplitFetcherCloseWithLeftoverElementInQueue()
+    void testPollNextReturnMoreAvailableWhenAllSplitFetcherCloseWithLeftoverElementInQueue()
             throws Exception {
 
         FutureCompletingBlockingQueue<RecordsWithSplitIds<int[]>> elementsQueue =
@@ -231,8 +235,8 @@ public class SourceReaderBaseTest extends SourceReaderTestBase<MockSourceSplit>
 
         // Add the last record to the split when the splitFetcherManager shutting down SplitFetchers
         splitFetcherManager.getInShutdownSplitFetcherFuture().thenRun(() -> split.addRecord(1));
-        assertEquals(
-                InputStatus.MORE_AVAILABLE, sourceReader.pollNext(new TestingReaderOutput<>()));
+        assertThat(sourceReader.pollNext(new TestingReaderOutput<>()))
+                .isEqualTo(InputStatus.MORE_AVAILABLE);
     }
 
     // ---------------- helper methods -----------------
diff --git a/flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/reader/KafkaSourceReaderTest.java b/flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/reader/KafkaSourceReaderTest.java
index e671520..d5c9abb 100644
--- a/flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/reader/KafkaSourceReaderTest.java
+++ b/flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/reader/KafkaSourceReaderTest.java
@@ -52,9 +52,9 @@ import org.apache.kafka.common.serialization.IntegerSerializer;
 import org.apache.kafka.common.serialization.StringSerializer;
 import org.hamcrest.MatcherAssert;
 import org.hamcrest.Matchers;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
 
 import java.time.Duration;
 import java.util.ArrayList;
@@ -80,14 +80,13 @@ import static org.apache.flink.connector.kafka.source.metrics.KafkaSourceReaderM
 import static org.apache.flink.connector.kafka.source.metrics.KafkaSourceReaderMetrics.TOPIC_GROUP;
 import static org.apache.flink.connector.kafka.source.testutils.KafkaSourceTestEnv.NUM_PARTITIONS;
 import static org.apache.flink.core.testutils.CommonTestUtils.waitUtil;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Unit tests for {@link KafkaSourceReader}. */
 public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSplit> {
     private static final String TOPIC = "KafkaSourceReaderTest";
 
-    @BeforeClass
+    @BeforeAll
     public static void setup() throws Throwable {
         KafkaSourceTestEnv.setup();
         try (AdminClient adminClient = KafkaSourceTestEnv.getAdminClient()) {
@@ -114,7 +113,7 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
                 getRecords(), StringSerializer.class, IntegerSerializer.class);
     }
 
-    @AfterClass
+    @AfterAll
     public static void tearDown() throws Exception {
         KafkaSourceTestEnv.tearDown();
     }
@@ -126,7 +125,7 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
     // -----------------------------------------
 
     @Test
-    public void testCommitOffsetsWithoutAliveFetchers() throws Exception {
+    void testCommitOffsetsWithoutAliveFetchers() throws Exception {
         final String groupId = "testCommitOffsetsWithoutAliveFetchers";
         try (KafkaSourceReader<Integer> reader =
                 (KafkaSourceReader<Integer>)
@@ -169,15 +168,15 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
                             .listConsumerGroupOffsets(groupId)
                             .partitionsToOffsetAndMetadata()
                             .get();
-            assertEquals(1, committedOffsets.size());
-            committedOffsets.forEach(
-                    (tp, offsetAndMetadata) ->
-                            assertEquals(NUM_RECORDS_PER_SPLIT, offsetAndMetadata.offset()));
+            assertThat(committedOffsets).hasSize(1);
+            assertThat(committedOffsets.values())
+                    .extracting(OffsetAndMetadata::offset)
+                    .allMatch(offset -> offset == NUM_RECORDS_PER_SPLIT);
         }
     }
 
     @Test
-    public void testCommitEmptyOffsets() throws Exception {
+    void testCommitEmptyOffsets() throws Exception {
         final String groupId = "testCommitEmptyOffsets";
         try (KafkaSourceReader<Integer> reader =
                 (KafkaSourceReader<Integer>)
@@ -192,12 +191,12 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
                             .listConsumerGroupOffsets(groupId)
                             .partitionsToOffsetAndMetadata()
                             .get();
-            assertTrue(committedOffsets.isEmpty());
+            assertThat(committedOffsets).isEmpty();
         }
     }
 
     @Test
-    public void testOffsetCommitOnCheckpointComplete() throws Exception {
+    void testOffsetCommitOnCheckpointComplete() throws Exception {
         final String groupId = "testOffsetCommitOnCheckpointComplete";
         try (KafkaSourceReader<Integer> reader =
                 (KafkaSourceReader<Integer>)
@@ -214,7 +213,7 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
             } while (output.count() < totalNumRecords);
 
             // The completion of the last checkpoint should subsume all the previous checkpoitns.
-            assertEquals(checkpointId, reader.getOffsetsToCommit().size());
+            assertThat(reader.getOffsetsToCommit()).hasSize((int) checkpointId);
 
             long lastCheckpointId = checkpointId;
             waitUtil(
@@ -240,15 +239,15 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
                             .listConsumerGroupOffsets(groupId)
                             .partitionsToOffsetAndMetadata()
                             .get();
-            assertEquals(numSplits, committedOffsets.size());
-            committedOffsets.forEach(
-                    (tp, offsetAndMetadata) ->
-                            assertEquals(NUM_RECORDS_PER_SPLIT, offsetAndMetadata.offset()));
+            assertThat(committedOffsets).hasSize(numSplits);
+            assertThat(committedOffsets.values())
+                    .extracting(OffsetAndMetadata::offset)
+                    .allMatch(offset -> offset == NUM_RECORDS_PER_SPLIT);
         }
     }
 
     @Test
-    public void testNotCommitOffsetsForUninitializedSplits() throws Exception {
+    void testNotCommitOffsetsForUninitializedSplits() throws Exception {
         final long checkpointId = 1234L;
         try (KafkaSourceReader<Integer> reader = (KafkaSourceReader<Integer>) createReader()) {
             KafkaPartitionSplit split =
@@ -256,13 +255,13 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
                             new TopicPartition(TOPIC, 0), KafkaPartitionSplit.EARLIEST_OFFSET);
             reader.addSplits(Collections.singletonList(split));
             reader.snapshotState(checkpointId);
-            assertEquals(1, reader.getOffsetsToCommit().size());
-            assertTrue(reader.getOffsetsToCommit().get(checkpointId).isEmpty());
+            assertThat(reader.getOffsetsToCommit()).hasSize(1);
+            assertThat(reader.getOffsetsToCommit().get(checkpointId)).isEmpty();
         }
     }
 
     @Test
-    public void testDisableOffsetCommit() throws Exception {
+    void testDisableOffsetCommit() throws Exception {
         final Properties properties = new Properties();
         properties.setProperty(KafkaSourceOptions.COMMIT_OFFSETS_ON_CHECKPOINT.key(), "false");
         try (KafkaSourceReader<Integer> reader =
@@ -282,13 +281,13 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
                 // Create a checkpoint for each message consumption, but not complete them.
                 reader.snapshotState(checkpointId);
                 // Offsets to commit should be always empty because offset commit is disabled
-                assertEquals(0, reader.getOffsetsToCommit().size());
+                assertThat(reader.getOffsetsToCommit()).isEmpty();
             } while (output.count() < totalNumRecords);
         }
     }
 
     @Test
-    public void testKafkaSourceMetrics() throws Exception {
+    void testKafkaSourceMetrics() throws Exception {
         final MetricListener metricListener = new MetricListener();
         final String groupId = "testKafkaSourceMetrics";
         final TopicPartition tp0 = new TopicPartition(TOPIC, 0);
@@ -316,17 +315,18 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
                             "Failed to poll %d records until timeout", NUM_RECORDS_PER_SPLIT * 2));
 
             // Metric "records-consumed-total" of KafkaConsumer should be NUM_RECORDS_PER_SPLIT
-            assertEquals(
-                    NUM_RECORDS_PER_SPLIT * 2,
-                    getKafkaConsumerMetric("records-consumed-total", metricListener));
+            assertThat(getKafkaConsumerMetric("records-consumed-total", metricListener))
+                    .isEqualTo(NUM_RECORDS_PER_SPLIT * 2);
 
             // Current consuming offset should be NUM_RECORD_PER_SPLIT - 1
-            assertEquals(NUM_RECORDS_PER_SPLIT - 1, getCurrentOffsetMetric(tp0, metricListener));
-            assertEquals(NUM_RECORDS_PER_SPLIT - 1, getCurrentOffsetMetric(tp1, metricListener));
+            assertThat(getCurrentOffsetMetric(tp0, metricListener))
+                    .isEqualTo(NUM_RECORDS_PER_SPLIT - 1);
+            assertThat(getCurrentOffsetMetric(tp1, metricListener))
+                    .isEqualTo(NUM_RECORDS_PER_SPLIT - 1);
 
             // No offset is committed till now
-            assertEquals(INITIAL_OFFSET, getCommittedOffsetMetric(tp0, metricListener));
-            assertEquals(INITIAL_OFFSET, getCommittedOffsetMetric(tp1, metricListener));
+            assertThat(getCommittedOffsetMetric(tp0, metricListener)).isEqualTo(INITIAL_OFFSET);
+            assertThat(getCommittedOffsetMetric(tp1, metricListener)).isEqualTo(INITIAL_OFFSET);
 
             // Trigger offset commit
             final long checkpointId = 15213L;
@@ -354,20 +354,22 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
                     Matchers.greaterThan(0L));
 
             // Committed offset should be NUM_RECORD_PER_SPLIT
-            assertEquals(NUM_RECORDS_PER_SPLIT, getCommittedOffsetMetric(tp0, metricListener));
-            assertEquals(NUM_RECORDS_PER_SPLIT, getCommittedOffsetMetric(tp1, metricListener));
+            assertThat(getCommittedOffsetMetric(tp0, metricListener))
+                    .isEqualTo(NUM_RECORDS_PER_SPLIT);
+            assertThat(getCommittedOffsetMetric(tp1, metricListener))
+                    .isEqualTo(NUM_RECORDS_PER_SPLIT);
 
             // Number of successful commits should be greater than 0
             final Optional<Counter> commitsSucceeded =
                     metricListener.getCounter(
                             KAFKA_SOURCE_READER_METRIC_GROUP, COMMITS_SUCCEEDED_METRIC_COUNTER);
-            assertTrue(commitsSucceeded.isPresent());
+            assertThat(commitsSucceeded).isPresent();
             MatcherAssert.assertThat(commitsSucceeded.get().getCount(), Matchers.greaterThan(0L));
         }
     }
 
     @Test
-    public void testAssigningEmptySplits() throws Exception {
+    void testAssigningEmptySplits() throws Exception {
         // Normal split with NUM_RECORDS_PER_SPLIT records
         final KafkaPartitionSplit normalSplit =
                 new KafkaPartitionSplit(
@@ -511,7 +513,7 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
         final Optional<Gauge<Object>> kafkaConsumerGauge =
                 listener.getGauge(
                         KAFKA_SOURCE_READER_METRIC_GROUP, KAFKA_CONSUMER_METRIC_GROUP, name);
-        assertTrue(kafkaConsumerGauge.isPresent());
+        assertThat(kafkaConsumerGauge).isPresent();
         return ((Double) kafkaConsumerGauge.get().getValue()).longValue();
     }
 
@@ -524,7 +526,7 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
                         PARTITION_GROUP,
                         String.valueOf(tp.partition()),
                         CURRENT_OFFSET_METRIC_GAUGE);
-        assertTrue(currentOffsetGauge.isPresent());
+        assertThat(currentOffsetGauge).isPresent();
         return (long) currentOffsetGauge.get().getValue();
     }
 
@@ -537,7 +539,7 @@ public class KafkaSourceReaderTest extends SourceReaderTestBase<KafkaPartitionSp
                         PARTITION_GROUP,
                         String.valueOf(tp.partition()),
                         COMMITTED_OFFSET_METRIC_GAUGE);
-        assertTrue(committedOffsetGauge.isPresent());
+        assertThat(committedOffsetGauge).isPresent();
         return (long) committedOffsetGauge.get().getValue();
     }
 
diff --git a/flink-test-utils-parent/flink-connector-test-utils/pom.xml b/flink-test-utils-parent/flink-connector-test-utils/pom.xml
index 89e0011..842b01d 100644
--- a/flink-test-utils-parent/flink-connector-test-utils/pom.xml
+++ b/flink-test-utils-parent/flink-connector-test-utils/pom.xml
@@ -55,6 +55,12 @@
 		</dependency>
 
 		<dependency>
+			<groupId>org.assertj</groupId>
+			<artifactId>assertj-core</artifactId>
+			<scope>compile</scope>
+		</dependency>
+
+		<dependency>
 			<groupId>org.junit.vintage</groupId>
 			<artifactId>junit-vintage-engine</artifactId>
 			<scope>compile</scope>
diff --git a/flink-test-utils-parent/flink-connector-test-utils/src/main/java/org/apache/flink/connector/testutils/source/reader/SourceReaderTestBase.java b/flink-test-utils-parent/flink-connector-test-utils/src/main/java/org/apache/flink/connector/testutils/source/reader/SourceReaderTestBase.java
index 462c5b2..6a83c5e 100644
--- a/flink-test-utils-parent/flink-connector-test-utils/src/main/java/org/apache/flink/connector/testutils/source/reader/SourceReaderTestBase.java
+++ b/flink-test-utils-parent/flink-connector-test-utils/src/main/java/org/apache/flink/connector/testutils/source/reader/SourceReaderTestBase.java
@@ -27,10 +27,9 @@ import org.apache.flink.api.connector.source.SourceSplit;
 import org.apache.flink.core.io.InputStatus;
 import org.apache.flink.util.TestLogger;
 
-import org.junit.After;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ExpectedException;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -39,8 +38,7 @@ import java.util.List;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /**
  * An abstract test class for all the unit tests of {@link SourceReader} to inherit.
@@ -48,7 +46,6 @@ import static org.junit.Assert.assertFalse;
  * @param <SplitT> the type of the splits.
  */
 public abstract class SourceReaderTestBase<SplitT extends SourceSplit> extends TestLogger {
-
     protected final int numSplits;
     protected final int totalNumRecords;
     protected static final int NUM_RECORDS_PER_SPLIT = 10;
@@ -62,9 +59,7 @@ public abstract class SourceReaderTestBase<SplitT extends SourceSplit> extends T
         return 10;
     }
 
-    @Rule public ExpectedException expectedException = ExpectedException.none();
-
-    @After
+    @AfterEach
     public void ensureNoDangling() {
         for (Thread t : Thread.getAllStackTraces().keySet()) {
             if (t.getName().equals("SourceFetcher")) {
@@ -75,7 +70,7 @@ public abstract class SourceReaderTestBase<SplitT extends SourceSplit> extends T
 
     /** Simply test the reader reads all the splits fine. */
     @Test
-    public void testRead() throws Exception {
+    void testRead() throws Exception {
         try (SourceReader<Integer, SplitT> reader = createReader()) {
             reader.addSplits(getSplits(numSplits, NUM_RECORDS_PER_SPLIT, Boundedness.BOUNDED));
             ValidatingSourceOutput output = new ValidatingSourceOutput();
@@ -87,7 +82,7 @@ public abstract class SourceReaderTestBase<SplitT extends SourceSplit> extends T
     }
 
     @Test
-    public void testAddSplitToExistingFetcher() throws Exception {
+    void testAddSplitToExistingFetcher() throws Exception {
         Thread.sleep(10);
         ValidatingSourceOutput output = new ValidatingSourceOutput();
         // Add a split to start the fetcher.
@@ -108,8 +103,9 @@ public abstract class SourceReaderTestBase<SplitT extends SourceSplit> extends T
         }
     }
 
-    @Test(timeout = 30000L)
-    public void testPollingFromEmptyQueue() throws Exception {
+    @Test
+    @Timeout(30)
+    void testPollingFromEmptyQueue() throws Exception {
         ValidatingSourceOutput output = new ValidatingSourceOutput();
         List<SplitT> splits =
                 Collections.singletonList(getSplit(0, NUM_RECORDS_PER_SPLIT, Boundedness.BOUNDED));
@@ -117,19 +113,19 @@ public abstract class SourceReaderTestBase<SplitT extends SourceSplit> extends T
         try (SourceReader<Integer, SplitT> reader =
                 consumeRecords(splits, output, NUM_RECORDS_PER_SPLIT)) {
             // Now let the main thread poll again.
-            assertEquals(
-                    "The status should be ",
-                    InputStatus.NOTHING_AVAILABLE,
-                    reader.pollNext(output));
+            assertThat(reader.pollNext(output))
+                    .as("The status should be %s", InputStatus.NOTHING_AVAILABLE)
+                    .isEqualTo(InputStatus.NOTHING_AVAILABLE);
         }
     }
 
-    @Test(timeout = 30000L)
-    public void testAvailableOnEmptyQueue() throws Exception {
+    @Test
+    @Timeout(30)
+    void testAvailableOnEmptyQueue() throws Exception {
         // Consumer all the records in the split.
         try (SourceReader<Integer, SplitT> reader = createReader()) {
             CompletableFuture<?> future = reader.isAvailable();
-            assertFalse("There should be no records ready for poll.", future.isDone());
+            assertThat(future.isDone()).as("There should be no records ready for poll.").isFalse();
             // Add a split to the reader so there are more records to be read.
             reader.addSplits(
                     Collections.singletonList(
@@ -140,8 +136,9 @@ public abstract class SourceReaderTestBase<SplitT extends SourceSplit> extends T
         }
     }
 
-    @Test(timeout = 30000L)
-    public void testSnapshot() throws Exception {
+    @Test
+    @Timeout(30)
+    void testSnapshot() throws Exception {
         ValidatingSourceOutput output = new ValidatingSourceOutput();
         // Add a split to start the fetcher.
         List<SplitT> splits =
@@ -149,12 +146,11 @@ public abstract class SourceReaderTestBase<SplitT extends SourceSplit> extends T
         try (SourceReader<Integer, SplitT> reader =
                 consumeRecords(splits, output, totalNumRecords)) {
             List<SplitT> state = reader.snapshotState(1L);
-            assertEquals("The snapshot should only have 10 splits. ", numSplits, state.size());
+            assertThat(state).as("The snapshot should only have 10 splits. ").hasSize(numSplits);
             for (int i = 0; i < numSplits; i++) {
-                assertEquals(
-                        "The first four splits should have been fully consumed.",
-                        NUM_RECORDS_PER_SPLIT,
-                        getNextRecordIndex(state.get(i)));
+                assertThat(getNextRecordIndex(state.get(i)))
+                        .as("The first four splits should have been fully consumed.")
+                        .isEqualTo(NUM_RECORDS_PER_SPLIT);
             }
         }
     }
@@ -207,19 +203,16 @@ public abstract class SourceReaderTestBase<SplitT extends SourceSplit> extends T
 
         public void validate() {
 
-            assertEquals(
-                    String.format("Should be %d distinct elements in total", totalNumRecords),
-                    totalNumRecords,
-                    consumedValues.size());
-            assertEquals(
-                    String.format("Should be %d elements in total", totalNumRecords),
-                    totalNumRecords,
-                    count);
-            assertEquals("The min value should be 0", 0, min);
-            assertEquals(
-                    String.format("The max value should be %d", totalNumRecords - 1),
-                    totalNumRecords - 1,
-                    max);
+            assertThat(consumedValues)
+                    .as("Should be %d distinct elements in total", totalNumRecords)
+                    .hasSize(totalNumRecords);
+            assertThat(count)
+                    .as("Should be %d elements in total", totalNumRecords)
+                    .isEqualTo(totalNumRecords);
+            assertThat(min).as("The min value should be 0", totalNumRecords).isZero();
+            assertThat(max)
+                    .as("The max value should be %d", totalNumRecords - 1)
+                    .isEqualTo(totalNumRecords - 1);
         }
 
         public int count() {