You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@kafka.apache.org by di...@apache.org on 2024/01/10 13:34:13 UTC

(kafka) branch trunk updated: KAFKA-14133: Migrate consumer mock in TaskManagerTest to Mockito (#15112)

This is an automated email from the ASF dual-hosted git repository.

divijv pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new ee96935c604 KAFKA-14133: Migrate consumer mock in TaskManagerTest to Mockito (#15112)
ee96935c604 is described below

commit ee96935c604067c69cc7ddc73f130269d9d0708e
Author: Christo Lolov <lo...@amazon.com>
AuthorDate: Wed Jan 10 13:34:03 2024 +0000

    KAFKA-14133: Migrate consumer mock in TaskManagerTest to Mockito (#15112)
    
    Reviewers: Divij Vaidya <di...@amazon.com>
---
 .../processor/internals/TaskManagerTest.java       | 195 +++++++++------------
 1 file changed, 83 insertions(+), 112 deletions(-)

diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 1f2fa5cc649..31ecf4c836b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -108,7 +108,6 @@ import static org.apache.kafka.streams.processor.internals.TopologyMetadata.UNNA
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.standbyTask;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statefulTask;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statelessTask;
-import static org.easymock.EasyMock.anyObject;
 import static org.easymock.EasyMock.eq;
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.expectLastCall;
@@ -198,6 +197,8 @@ public class TaskManagerTest {
     @Mock(type = MockType.STRICT)
     private Consumer<byte[], byte[]> consumer;
     @org.mockito.Mock
+    private Consumer<byte[], byte[]> mockitoConsumer;
+    @org.mockito.Mock
     private ActiveTaskCreator activeTaskCreator;
     @org.mockito.Mock
     private StandbyTaskCreator standbyTaskCreator;
@@ -310,16 +311,15 @@ public class TaskManagerTest {
             .withInputPartitions(taskId00Partitions).build();
         final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
+        taskManager.setMainConsumer(mockitoConsumer);
         when(tasks.activeTaskIds()).thenReturn(mkSet(taskId00, taskId01));
         when(tasks.task(taskId00)).thenReturn(activeTask1);
         final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
         when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
-        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
-        replay(consumer);
 
         taskManager.handleCorruption(mkSet(taskId00));
 
-        verify(consumer);
+        Mockito.verify(mockitoConsumer).assignment();
         Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00, taskId01));
         Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00, taskId01));
     }
@@ -1231,13 +1231,13 @@ public class TaskManagerTest {
         when(stateUpdater.hasRemovedTasks()).thenReturn(true);
         when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(statefulTask));
         taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
-        verify(consumer);
         Mockito.verify(statefulTask).suspend();
         Mockito.verify(tasks).addTask(statefulTask);
+        Mockito.verifyNoInteractions(mockitoConsumer);
     }
 
     @Test
@@ -1263,13 +1263,10 @@ public class TaskManagerTest {
         when(stateUpdater.drainRemovedTasks())
             .thenReturn(mkSet(taskToRecycle0, taskToRecycle1, taskToClose, taskToUpdateInputPartitions, taskToCloseReviveAndUpdateInputPartitions));
         when(stateUpdater.restoresActiveTasks()).thenReturn(true);
-        when(activeTaskCreator.createActiveTaskFromStandby(taskToRecycle1, taskId01Partitions, consumer))
+        when(activeTaskCreator.createActiveTaskFromStandby(taskToRecycle1, taskId01Partitions, mockitoConsumer))
             .thenReturn(convertedTask1);
         when(standbyTaskCreator.createStandbyTaskFromActive(taskToRecycle0, taskId00Partitions))
             .thenReturn(convertedTask0);
-        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
-        consumer.resume(anyObject());
-        expectLastCall().anyTimes();
         final TasksRegistry tasks = mock(TasksRegistry.class);
         when(tasks.removePendingTaskToCloseClean(taskToClose.id())).thenReturn(true);
         when(tasks.removePendingTaskToCloseClean(argThat(taskId -> !taskId.equals(taskToClose.id())))).thenReturn(false);
@@ -1284,12 +1281,11 @@ public class TaskManagerTest {
             argThat(taskId -> !taskId.equals(taskToCloseReviveAndUpdateInputPartitions.id()))
         )).thenReturn(null);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        taskManager.setMainConsumer(consumer);
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
 
-        verify(consumer);
+        Mockito.verify(activeTaskCreator, times(3)).closeAndRemoveTaskProducerIfNeeded(any());
         Mockito.verify(activeTaskCreator, times(3)).closeAndRemoveTaskProducerIfNeeded(any());
         Mockito.verify(convertedTask0).initializeIfNeeded();
         Mockito.verify(convertedTask1).initializeIfNeeded();
@@ -1303,6 +1299,7 @@ public class TaskManagerTest {
         Mockito.verify(taskToCloseReviveAndUpdateInputPartitions).updateInputPartitions(Mockito.eq(taskId05Partitions), anyMap());
         Mockito.verify(taskToCloseReviveAndUpdateInputPartitions).initializeIfNeeded();
         Mockito.verify(stateUpdater).add(taskToCloseReviveAndUpdateInputPartitions);
+        Mockito.verifyNoInteractions(mockitoConsumer);
     }
 
     @Test
@@ -1438,15 +1435,14 @@ public class TaskManagerTest {
             .withInputPartitions(taskId00Partitions).build();
         final TasksRegistry tasks = mock(TasksRegistry.class);
         final TaskManager taskManager = setUpTransitionToRunningOfRestoredTask(task, tasks);
-        consumer.resume(task.inputPartitions());
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         Mockito.verify(task).completeRestoration(noOpResetter);
         Mockito.verify(task).clearTaskTimeout();
         Mockito.verify(tasks).addTask(task);
-        verify(consumer);
+        Mockito.verify(mockitoConsumer).resume(task.inputPartitions());
     }
 
     @Test
@@ -1456,16 +1452,16 @@ public class TaskManagerTest {
             .withInputPartitions(taskId00Partitions).build();
         final TasksRegistry tasks = mock(TasksRegistry.class);
         final TaskManager taskManager = setUpTransitionToRunningOfRestoredTask(task, tasks);
+        taskManager.setMainConsumer(mockitoConsumer);
         final TimeoutException timeoutException = new TimeoutException();
         doThrow(timeoutException).when(task).completeRestoration(noOpResetter);
-        replay(consumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         Mockito.verify(task).maybeInitTaskTimeoutOrThrow(anyLong(), Mockito.eq(timeoutException));
         Mockito.verify(tasks, never()).addTask(task);
         Mockito.verify(task, never()).clearTaskTimeout();
-        verify(consumer);
+        Mockito.verifyNoInteractions(mockitoConsumer);
     }
 
     private TaskManager setUpTransitionToRunningOfRestoredTask(final StreamTask statefulTask,
@@ -1655,12 +1651,11 @@ public class TaskManagerTest {
         when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(statefulTask));
         when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        consumer.resume(statefulTask.inputPartitions());
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
-        verify(consumer);
+        Mockito.verify(mockitoConsumer).resume(statefulTask.inputPartitions());
         Mockito.verify(statefulTask).updateInputPartitions(Mockito.eq(taskId01Partitions), anyMap());
         Mockito.verify(statefulTask).completeRestoration(noOpResetter);
         Mockito.verify(statefulTask).clearTaskTimeout();
@@ -1701,13 +1696,13 @@ public class TaskManagerTest {
         when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(statefulTask));
         when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
-        verify(consumer);
         Mockito.verify(statefulTask).suspend();
         Mockito.verify(tasks).addTask(statefulTask);
+        Mockito.verifyNoInteractions(mockitoConsumer);
     }
 
     @Test
@@ -1958,13 +1953,12 @@ public class TaskManagerTest {
     @Test
     public void shouldPauseAllTopicsWithoutStateUpdaterOnRebalanceComplete() {
         final Set<TopicPartition> assigned = mkSet(t1p0, t1p1);
-        expect(consumer.assignment()).andReturn(assigned);
-        consumer.pause(assigned);
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
+        when(mockitoConsumer.assignment()).thenReturn(assigned);
 
         taskManager.handleRebalanceComplete();
 
-        verify(consumer);
+        Mockito.verify(mockitoConsumer).pause(assigned);
     }
 
     @Test
@@ -1974,15 +1968,14 @@ public class TaskManagerTest {
             .withInputPartitions(taskId00Partitions).build();
         final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        taskManager.setMainConsumer(mockitoConsumer);
         when(tasks.allTasks()).thenReturn(mkSet(statefulTask0));
         final Set<TopicPartition> assigned = mkSet(t1p0, t1p1);
-        expect(consumer.assignment()).andReturn(assigned);
-        consumer.pause(mkSet(t1p1));
-        replay(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assigned);
 
         taskManager.handleRebalanceComplete();
 
-        verify(consumer);
+        Mockito.verify(mockitoConsumer).pause(mkSet(t1p1));
     }
 
     @Test
@@ -2027,6 +2020,7 @@ public class TaskManagerTest {
             .withInputPartitions(taskId03Partitions).build();
         final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        taskManager.setMainConsumer(mockitoConsumer);
         when(tasks.allTasksPerId()).thenReturn(mkMap(mkEntry(taskId00, runningStatefulTask)));
         when(stateUpdater.getTasks()).thenReturn(mkSet(standbyTask, restoringStatefulTask));
         when(tasks.allTasks()).thenReturn(mkSet(runningStatefulTask));
@@ -2042,14 +2036,12 @@ public class TaskManagerTest {
         replay(stateDirectory);
 
         final Set<TopicPartition> assigned = mkSet(t1p0, t1p1, t1p2);
-        expect(consumer.assignment()).andReturn(assigned);
-        consumer.pause(mkSet(t1p1, t1p2));
-        replay(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assigned);
 
         taskManager.handleRebalanceStart(singleton("topic"));
         taskManager.handleRebalanceComplete();
 
-        verify(consumer);
+        Mockito.verify(mockitoConsumer).pause(mkSet(t1p1, t1p2));
         verify(stateDirectory);
         assertThat(taskManager.lockedTaskDirectories(), is(mkSet(taskId00, taskId01, taskId02)));
     }
@@ -2481,10 +2473,9 @@ public class TaskManagerTest {
             .withInputPartitions(taskId02Partitions).build();
         final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        taskManager.setMainConsumer(mockitoConsumer);
         when(tasks.task(taskId03)).thenReturn(corruptedActiveTask);
         when(tasks.task(taskId02)).thenReturn(corruptedStandbyTask);
-        expect(consumer.assignment()).andReturn(emptySet());
-        replay(consumer);
 
         taskManager.handleCorruption(mkSet(corruptedActiveTask.id(), corruptedStandbyTask.id()));
 
@@ -2498,6 +2489,7 @@ public class TaskManagerTest {
         Mockito.verify(tasks).removeTask(corruptedStandbyTask);
         Mockito.verify(tasks).addPendingTasksToInit(mkSet(corruptedActiveTask));
         Mockito.verify(tasks).addPendingTasksToInit(mkSet(corruptedStandbyTask));
+        Mockito.verify(mockitoConsumer).assignment();
     }
 
     @Test
@@ -2656,9 +2648,9 @@ public class TaskManagerTest {
         when(tasks.allTasksPerId()).thenReturn(mkMap(mkEntry(taskId02, corruptedTask)));
         when(tasks.task(taskId02)).thenReturn(corruptedTask);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        taskManager.setMainConsumer(mockitoConsumer);
         when(stateUpdater.getTasks()).thenReturn(mkSet(activeRestoringTask, standbyTask));
-        expect(consumer.assignment()).andReturn(intersection(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions));
-        replay(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(intersection(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions));
 
         taskManager.handleCorruption(mkSet(taskId02));
 
@@ -2668,7 +2660,6 @@ public class TaskManagerTest {
         Mockito.verify(standbyTask, never()).commitNeeded();
         Mockito.verify(standbyTask, never()).prepareCommit();
         Mockito.verify(standbyTask, never()).postCommit(Mockito.anyBoolean());
-        verify(consumer);
     }
 
     @Test
@@ -2691,8 +2682,8 @@ public class TaskManagerTest {
         ));
         when(tasks.task(taskId02)).thenReturn(corruptedTask);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, false);
-        expect(consumer.assignment()).andReturn(intersection(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions));
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
+        when(mockitoConsumer.assignment()).thenReturn(intersection(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions));
 
         taskManager.handleCorruption(mkSet(taskId02));
 
@@ -3135,11 +3126,9 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignment = taskId00Assignment;
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
 
-        expect(consumer.assignment()).andReturn(emptySet());
-        consumer.resume(eq(emptySet()));
-        expectLastCall();
+        taskManager.setMainConsumer(mockitoConsumer);
+
         when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
 
         taskManager.handleAssignment(assignment, emptyMap());
 
@@ -3151,6 +3140,8 @@ public class TaskManagerTest {
         assertThat(taskManager.activeTaskMap(), Matchers.equalTo(singletonMap(taskId00, task00)));
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         Mockito.verify(changeLogReader).enforceRestoreActive();
+        Mockito.verify(mockitoConsumer).assignment();
+        Mockito.verify(mockitoConsumer).resume(Mockito.eq(emptySet()));
     }
 
     @Test
@@ -3172,13 +3163,9 @@ public class TaskManagerTest {
             }
         };
 
-        consumer.commitSync(Collections.emptyMap());
-        expectLastCall();
-        expect(consumer.assignment()).andReturn(emptySet());
-        consumer.resume(eq(emptySet()));
-        expectLastCall();
+        taskManager.setMainConsumer(mockitoConsumer);
+
         when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment))).thenReturn(asList(task00, task01));
-        replay(consumer);
 
         taskManager.handleAssignment(assignment, emptyMap());
 
@@ -3195,6 +3182,7 @@ public class TaskManagerTest {
         );
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         Mockito.verify(changeLogReader).enforceRestoreActive();
+        Mockito.verifyNoInteractions(mockitoConsumer);
     }
 
     @Test
@@ -3209,14 +3197,9 @@ public class TaskManagerTest {
             }
         };
 
-        consumer.commitSync(Collections.emptyMap());
-        expectLastCall();
-        expect(consumer.assignment()).andReturn(emptySet());
-        consumer.resume(eq(emptySet()));
-        expectLastCall();
-        expectLastCall();
+        taskManager.setMainConsumer(mockitoConsumer);
+
         when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
 
         taskManager.handleAssignment(assignment, emptyMap());
 
@@ -3231,6 +3214,7 @@ public class TaskManagerTest {
         );
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         Mockito.verify(changeLogReader).enforceRestoreActive();
+        Mockito.verifyNoInteractions(mockitoConsumer);
     }
 
     @Test
@@ -3451,7 +3435,6 @@ public class TaskManagerTest {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
 
         when(activeTaskCreator.createTasks(any(), Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(task00.state(), is(Task.State.CREATED));
@@ -3826,20 +3809,11 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignment = singletonMap(taskId00, taskId00Partitions);
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, false, stateManager);
 
+        taskManager.setMainConsumer(mockitoConsumer);
+
         // `handleAssignment`
         when(standbyTaskCreator.createTasks(assignment)).thenReturn(singletonList(task00));
 
-        // `tryToCompleteRestoration`
-        expect(consumer.assignment()).andReturn(emptySet());
-        consumer.resume(eq(emptySet()));
-        expectLastCall();
-
-        // `shutdown`
-        consumer.commitSync(Collections.emptyMap());
-        expectLastCall();
-
-        replay(consumer);
-
         taskManager.handleAssignment(emptyMap(), assignment);
         assertThat(task00.state(), is(Task.State.CREATED));
 
@@ -3854,6 +3828,9 @@ public class TaskManagerTest {
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         // the active task creator should also get closed (so that it closes the thread producer if applicable)
         Mockito.verify(activeTaskCreator).closeThreadProducerIfNeeded();
+        // `tryToCompleteRestoration`
+        Mockito.verify(mockitoConsumer).assignment();
+        Mockito.verify(mockitoConsumer).resume(Mockito.eq(emptySet()));
     }
 
     @Test
@@ -3975,16 +3952,16 @@ public class TaskManagerTest {
     @Test
     public void shouldHandleRebalanceEvents() {
         final Set<TopicPartition> assignment = singleton(new TopicPartition("assignment", 0));
-        expect(consumer.assignment()).andReturn(assignment);
-        consumer.pause(assignment);
-        expectLastCall();
+        taskManager.setMainConsumer(mockitoConsumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         expect(stateDirectory.listNonEmptyTaskDirectories()).andReturn(new ArrayList<>());
-        replay(consumer, stateDirectory);
+        replay(stateDirectory);
         assertThat(taskManager.rebalanceInProgress(), is(false));
         taskManager.handleRebalanceStart(emptySet());
         assertThat(taskManager.rebalanceInProgress(), is(true));
         taskManager.handleRebalanceComplete();
         assertThat(taskManager.rebalanceInProgress(), is(false));
+        Mockito.verify(mockitoConsumer).pause(assignment);
     }
 
     @Test
@@ -4132,15 +4109,12 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
+        taskManager.setMainConsumer(mockitoConsumer);
         taskManager.addTask(task01);
 
-        consumer.commitSync(offsets);
-        expectLastCall();
-        replay(consumer);
-
         taskManager.commitAll();
 
-        verify(consumer);
+        Mockito.verify(mockitoConsumer).commitSync(offsets);
     }
 
     @Test
@@ -4180,6 +4154,7 @@ public class TaskManagerTest {
                                                      final Map<TopicPartition, OffsetAndMetadata> offsetsT01,
                                                      final Map<TopicPartition, OffsetAndMetadata> offsetsT02) {
         final TaskManager taskManager = setUpTaskManager(processingMode, false);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager);
         task01.setCommittableOffsetsAndMetadata(offsetsT01);
@@ -4190,13 +4165,9 @@ public class TaskManagerTest {
         task02.setCommitNeeded();
         taskManager.addTask(task02);
 
-        reset(consumer);
-        expect(consumer.groupMetadata()).andStubReturn(new ConsumerGroupMetadata("appId"));
-        replay(consumer);
+        when(mockitoConsumer.groupMetadata()).thenReturn(new ConsumerGroupMetadata("appId"));
 
         taskManager.commitAll();
-
-        verify(consumer);
     }
 
     @Test
@@ -4652,15 +4623,14 @@ public class TaskManagerTest {
             }
         };
 
+        taskManager.setMainConsumer(mockitoConsumer);
+
         when(activeTaskCreator.createTasks(any(), Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(false));
         assertThat(task00.state(), is(Task.State.RESTORING));
-        // this could be a bit mysterious; we're verifying _no_ interactions on the consumer,
-        // since the taskManager should _not_ resume the assignment while we're still in RESTORING
-        verify(consumer);
+        Mockito.verifyNoInteractions(mockitoConsumer);
     }
 
     @Test
@@ -4891,11 +4861,10 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
+        taskManager.setMainConsumer(mockitoConsumer);
         taskManager.addTask(task01);
 
-        consumer.commitSync(offsets);
-        expectLastCall().andThrow(new CommitFailedException());
-        replay(consumer);
+        doThrow(new CommitFailedException()).when(mockitoConsumer).commitSync(offsets);
 
         final TaskMigratedException thrown = assertThrows(
             TaskMigratedException.class,
@@ -4920,11 +4889,9 @@ public class TaskManagerTest {
         task00.setCommittableOffsetsAndMetadata(taskId00Partitions.stream().collect(Collectors.toMap(p -> p, p -> new OffsetAndMetadata(0))));
         task01.setCommittableOffsetsAndMetadata(taskId00Partitions.stream().collect(Collectors.toMap(p -> p, p -> new OffsetAndMetadata(0))));
 
-        consumer.commitSync(anyObject(Map.class));
-        expectLastCall().andThrow(new TimeoutException("KABOOM!"));
-        consumer.commitSync(anyObject(Map.class));
-        expectLastCall();
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
+
+        doThrow(new TimeoutException("KABOOM!")).doNothing().when(mockitoConsumer).commitSync(any(Map.class));
 
         task00.setCommitNeeded();
 
@@ -4935,12 +4902,15 @@ public class TaskManagerTest {
         assertThat(taskManager.commit(mkSet(task00, task01)), equalTo(1));
         assertNull(task00.timeout);
         assertNull(task01.timeout);
+
+        Mockito.verify(mockitoConsumer, times(2)).commitSync(any(Map.class));
     }
 
     @Test
     public void shouldNotFailForTimeoutExceptionOnCommitWithEosAlpha() {
         final Tasks tasks = mock(Tasks.class);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.EXACTLY_ONCE_ALPHA, tasks, false);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         final StreamsProducer producer = mock(StreamsProducer.class);
         when(activeTaskCreator.streamsProducerForTask(any(TaskId.class))).thenReturn(producer);
@@ -4963,9 +4933,6 @@ public class TaskManagerTest {
         task01.setCommittableOffsetsAndMetadata(offsetsT01);
         final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager);
         when(tasks.allTasks()).thenReturn(mkSet(task00, task01, task02));
-        
-        expect(consumer.groupMetadata()).andStubReturn(null);
-        replay(consumer);
 
         task00.setCommitNeeded();
         task01.setCommitNeeded();
@@ -4978,11 +4945,14 @@ public class TaskManagerTest {
             exception.corruptedTasks(),
             equalTo(Collections.singleton(taskId00))
         );
+
+        Mockito.verify(mockitoConsumer, times(2)).groupMetadata();
     }
 
     @Test
     public void shouldThrowTaskCorruptedExceptionForTimeoutExceptionOnCommitWithEosV2() {
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2, false);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         final StreamsProducer producer = mock(StreamsProducer.class);
         when(activeTaskCreator.threadProducer()).thenReturn(producer);
@@ -5000,9 +4970,6 @@ public class TaskManagerTest {
         task01.setCommittableOffsetsAndMetadata(offsetsT01);
         final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager);
 
-        expect(consumer.groupMetadata()).andStubReturn(null);
-        replay(consumer);
-
         task00.setCommitNeeded();
         task01.setCommitNeeded();
 
@@ -5014,6 +4981,8 @@ public class TaskManagerTest {
             exception.corruptedTasks(),
             equalTo(mkSet(taskId00, taskId01))
         );
+
+        Mockito.verify(mockitoConsumer).groupMetadata();
     }
 
     @Test
@@ -5022,11 +4991,10 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
+        taskManager.setMainConsumer(mockitoConsumer);
         taskManager.addTask(task01);
 
-        consumer.commitSync(offsets);
-        expectLastCall().andThrow(new KafkaException());
-        replay(consumer);
+        doThrow(new KafkaException()).when(mockitoConsumer).commitSync(offsets);
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
@@ -5044,11 +5012,10 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
+        taskManager.setMainConsumer(mockitoConsumer);
         taskManager.addTask(task01);
 
-        consumer.commitSync(offsets);
-        expectLastCall().andThrow(new RuntimeException("KABOOM"));
-        replay(consumer);
+        doThrow(new RuntimeException("KABOOM")).when(mockitoConsumer).commitSync(offsets);
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
@@ -5073,7 +5040,8 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>(taskId00Assignment);
         assignment.putAll(taskId01Assignment);
         when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment))).thenReturn(asList(task00, task01));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(assignment, Collections.emptyMap());
 
@@ -5084,6 +5052,7 @@ public class TaskManagerTest {
         assertThat(thrown.getCause().getMessage(), is("KABOOM!"));
         assertThat(task00.state(), is(Task.State.SUSPENDED));
         assertThat(task01.state(), is(Task.State.SUSPENDED));
+        Mockito.verifyNoInteractions(mockitoConsumer);
     }
 
     @Test
@@ -5099,7 +5068,7 @@ public class TaskManagerTest {
         when(activeTaskCreator.createTasks(any(), Mockito.eq(taskId00Assignment))).thenReturn(singletonList(activeTask));
         when(standbyTaskCreator.createStandbyTaskFromActive(Mockito.any(), Mockito.eq(taskId00Partitions))).thenReturn(standbyTask);
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, Collections.emptyMap());
         taskManager.handleAssignment(Collections.emptyMap(), taskId00Assignment);
@@ -5107,6 +5076,7 @@ public class TaskManagerTest {
         Mockito.verify(activeTaskCreator).closeAndRemoveTaskProducerIfNeeded(taskId00);
         Mockito.verify(activeTaskCreator).createTasks(any(), Mockito.eq(emptyMap()));
         Mockito.verify(standbyTaskCreator, times(2)).createTasks(Collections.emptyMap());
+        Mockito.verifyNoInteractions(mockitoConsumer);
     }
 
     @Test
@@ -5123,13 +5093,14 @@ public class TaskManagerTest {
         when(activeTaskCreator.createActiveTaskFromStandby(Mockito.eq(standbyTask), Mockito.eq(taskId00Partitions), any()))
             .thenReturn(activeTask);
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(Collections.emptyMap(), taskId00Assignment);
         taskManager.handleAssignment(taskId00Assignment, Collections.emptyMap());
 
         Mockito.verify(activeTaskCreator, times(2)).createTasks(any(), Mockito.eq(emptyMap()));
         Mockito.verify(standbyTaskCreator).createTasks(Collections.emptyMap());
+        Mockito.verifyNoInteractions(mockitoConsumer);
     }
 
     @Test