This is an automated email from the ASF dual-hosted git repository.

cadonna pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 8b72a2c72f0 KAFKA-14133: Move consumer mock in TaskManagerTest to Mockito - part 3 (#15497)
8b72a2c72f0 is described below

commit 8b72a2c72f09838fdd2e7416c98d30fe876b4078
Author: Christo Lolov <[email protected]>
AuthorDate: Mon Mar 11 11:51:20 2024 +0000

    KAFKA-14133: Move consumer mock in TaskManagerTest to Mockito - part 3 (#15497)
    
    The previous pull request in this series was #15261.
    
    This pull request continues migrating the consumer mock in TaskManagerTest to Mockito, one test at a time, for easier reviews.
    
    The next pull request in the series will be #15254, which ought to complete the Mockito migration for the TaskManagerTest class.
    
    Reviewer: Bruno Cadonna <[email protected]>
---
 .../processor/internals/TaskManagerTest.java       | 318 +++++++--------------
 1 file changed, 96 insertions(+), 222 deletions(-)

diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 681e69d3004..36d2a3e3786 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -194,10 +194,8 @@ public class TaskManagerTest {
     private StateDirectory stateDirectory;
     @org.mockito.Mock
     private ChangelogReader changeLogReader;
-    @Mock(type = MockType.STRICT)
-    private Consumer<byte[], byte[]> consumer;
     @org.mockito.Mock
-    private Consumer<byte[], byte[]> mockitoConsumer;
+    private Consumer<byte[], byte[]> consumer;
     @org.mockito.Mock
     private ActiveTaskCreator activeTaskCreator;
     @org.mockito.Mock
@@ -311,7 +309,6 @@ public class TaskManagerTest {
             .withInputPartitions(taskId00Partitions).build();
         final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
-        taskManager.setMainConsumer(mockitoConsumer);
         when(tasks.activeTaskIds()).thenReturn(mkSet(taskId00, taskId01));
         when(tasks.task(taskId00)).thenReturn(activeTask1);
         final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
@@ -319,7 +316,7 @@ public class TaskManagerTest {
 
         taskManager.handleCorruption(mkSet(taskId00));
 
-        Mockito.verify(mockitoConsumer).assignment();
+        Mockito.verify(consumer).assignment();
         Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00, 
taskId01));
         Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00, 
taskId01));
     }
@@ -1252,13 +1249,12 @@ public class TaskManagerTest {
         when(stateUpdater.hasRemovedTasks()).thenReturn(true);
         when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(statefulTask));
         taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, 
true);
-        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         Mockito.verify(statefulTask).suspend();
         Mockito.verify(tasks).addTask(statefulTask);
-        Mockito.verifyNoInteractions(mockitoConsumer);
+        Mockito.verifyNoInteractions(consumer);
     }
 
     @Test
@@ -1284,7 +1280,7 @@ public class TaskManagerTest {
         when(stateUpdater.drainRemovedTasks())
             .thenReturn(mkSet(taskToRecycle0, taskToRecycle1, taskToClose, 
taskToUpdateInputPartitions, taskToCloseReviveAndUpdateInputPartitions));
         when(stateUpdater.restoresActiveTasks()).thenReturn(true);
-        when(activeTaskCreator.createActiveTaskFromStandby(taskToRecycle1, 
taskId01Partitions, mockitoConsumer))
+        when(activeTaskCreator.createActiveTaskFromStandby(taskToRecycle1, 
taskId01Partitions, consumer))
             .thenReturn(convertedTask1);
         when(standbyTaskCreator.createStandbyTaskFromActive(taskToRecycle0, 
taskId00Partitions))
             .thenReturn(convertedTask0);
@@ -1302,7 +1298,6 @@ public class TaskManagerTest {
             argThat(taskId -> 
!taskId.equals(taskToCloseReviveAndUpdateInputPartitions.id()))
         )).thenReturn(null);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { 
});
 
@@ -1320,7 +1315,7 @@ public class TaskManagerTest {
         
Mockito.verify(taskToCloseReviveAndUpdateInputPartitions).updateInputPartitions(Mockito.eq(taskId05Partitions),
 anyMap());
         
Mockito.verify(taskToCloseReviveAndUpdateInputPartitions).initializeIfNeeded();
         
Mockito.verify(stateUpdater).add(taskToCloseReviveAndUpdateInputPartitions);
-        Mockito.verifyNoInteractions(mockitoConsumer);
+        Mockito.verifyNoInteractions(consumer);
     }
 
     @Test
@@ -1456,14 +1451,13 @@ public class TaskManagerTest {
             .withInputPartitions(taskId00Partitions).build();
         final TasksRegistry tasks = mock(TasksRegistry.class);
         final TaskManager taskManager = 
setUpTransitionToRunningOfRestoredTask(task, tasks);
-        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         Mockito.verify(task).completeRestoration(noOpResetter);
         Mockito.verify(task).clearTaskTimeout();
         Mockito.verify(tasks).addTask(task);
-        Mockito.verify(mockitoConsumer).resume(task.inputPartitions());
+        Mockito.verify(consumer).resume(task.inputPartitions());
     }
 
     @Test
@@ -1473,7 +1467,6 @@ public class TaskManagerTest {
             .withInputPartitions(taskId00Partitions).build();
         final TasksRegistry tasks = mock(TasksRegistry.class);
         final TaskManager taskManager = 
setUpTransitionToRunningOfRestoredTask(task, tasks);
-        taskManager.setMainConsumer(mockitoConsumer);
         final TimeoutException timeoutException = new TimeoutException();
         doThrow(timeoutException).when(task).completeRestoration(noOpResetter);
 
@@ -1482,7 +1475,7 @@ public class TaskManagerTest {
         Mockito.verify(task).maybeInitTaskTimeoutOrThrow(anyLong(), 
Mockito.eq(timeoutException));
         Mockito.verify(tasks, never()).addTask(task);
         Mockito.verify(task, never()).clearTaskTimeout();
-        Mockito.verifyNoInteractions(mockitoConsumer);
+        Mockito.verifyNoInteractions(consumer);
     }
 
     private TaskManager setUpTransitionToRunningOfRestoredTask(final 
StreamTask statefulTask,
@@ -1672,11 +1665,10 @@ public class TaskManagerTest {
         
when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(statefulTask));
         when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
-        Mockito.verify(mockitoConsumer).resume(statefulTask.inputPartitions());
+        Mockito.verify(consumer).resume(statefulTask.inputPartitions());
         
Mockito.verify(statefulTask).updateInputPartitions(Mockito.eq(taskId01Partitions),
 anyMap());
         Mockito.verify(statefulTask).completeRestoration(noOpResetter);
         Mockito.verify(statefulTask).clearTaskTimeout();
@@ -1717,13 +1709,12 @@ public class TaskManagerTest {
         
when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(statefulTask));
         when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         Mockito.verify(statefulTask).suspend();
         Mockito.verify(tasks).addTask(statefulTask);
-        Mockito.verifyNoInteractions(mockitoConsumer);
+        Mockito.verifyNoInteractions(consumer);
     }
 
     @Test
@@ -1974,12 +1965,11 @@ public class TaskManagerTest {
     @Test
     public void shouldPauseAllTopicsWithoutStateUpdaterOnRebalanceComplete() {
         final Set<TopicPartition> assigned = mkSet(t1p0, t1p1);
-        taskManager.setMainConsumer(mockitoConsumer);
-        when(mockitoConsumer.assignment()).thenReturn(assigned);
+        when(consumer.assignment()).thenReturn(assigned);
 
         taskManager.handleRebalanceComplete();
 
-        Mockito.verify(mockitoConsumer).pause(assigned);
+        Mockito.verify(consumer).pause(assigned);
     }
 
     @Test
@@ -1989,14 +1979,13 @@ public class TaskManagerTest {
             .withInputPartitions(taskId00Partitions).build();
         final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        taskManager.setMainConsumer(mockitoConsumer);
         when(tasks.allTasks()).thenReturn(mkSet(statefulTask0));
         final Set<TopicPartition> assigned = mkSet(t1p0, t1p1);
-        when(mockitoConsumer.assignment()).thenReturn(assigned);
+        when(consumer.assignment()).thenReturn(assigned);
 
         taskManager.handleRebalanceComplete();
 
-        Mockito.verify(mockitoConsumer).pause(mkSet(t1p1));
+        Mockito.verify(consumer).pause(mkSet(t1p1));
     }
 
     @Test
@@ -2021,7 +2010,7 @@ public class TaskManagerTest {
         assertThat(taskManager.lockedTaskDirectories(), is(mkSet(taskId00, 
taskId01)));
         verify(stateDirectory);
 
-        Mockito.verify(mockitoConsumer).pause(assignment);
+        Mockito.verify(consumer).pause(assignment);
     }
 
     @Test
@@ -2040,7 +2029,6 @@ public class TaskManagerTest {
             .withInputPartitions(taskId03Partitions).build();
         final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        taskManager.setMainConsumer(mockitoConsumer);
         when(tasks.allTasksPerId()).thenReturn(mkMap(mkEntry(taskId00, 
runningStatefulTask)));
         when(stateUpdater.getTasks()).thenReturn(mkSet(standbyTask, 
restoringStatefulTask));
         when(tasks.allTasks()).thenReturn(mkSet(runningStatefulTask));
@@ -2056,12 +2044,12 @@ public class TaskManagerTest {
         replay(stateDirectory);
 
         final Set<TopicPartition> assigned = mkSet(t1p0, t1p1, t1p2);
-        when(mockitoConsumer.assignment()).thenReturn(assigned);
+        when(consumer.assignment()).thenReturn(assigned);
 
         taskManager.handleRebalanceStart(singleton("topic"));
         taskManager.handleRebalanceComplete();
 
-        Mockito.verify(mockitoConsumer).pause(mkSet(t1p1, t1p2));
+        Mockito.verify(consumer).pause(mkSet(t1p1, t1p2));
         verify(stateDirectory);
         assertThat(taskManager.lockedTaskDirectories(), is(mkSet(taskId00, 
taskId01, taskId02)));
     }
@@ -2331,12 +2319,10 @@ public class TaskManagerTest {
         task00.setCommittableOffsetsAndMetadata(offsets);
 
         // first `handleAssignment`
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -2362,8 +2348,6 @@ public class TaskManagerTest {
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         taskManager.handleRevocation(taskId00Partitions);
 
@@ -2387,7 +2371,7 @@ public class TaskManagerTest {
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager);
 
         // `handleAssignment`
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
         
when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01));
 
@@ -2400,8 +2384,6 @@ public class TaskManagerTest {
         expectLockObtainedFor();
         replay(stateDirectory);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleRebalanceStart(emptySet());
         assertThat(taskManager.lockedTaskDirectories(), 
Matchers.is(mkSet(taskId00, taskId01)));
 
@@ -2443,14 +2425,12 @@ public class TaskManagerTest {
         task00.setCommittableOffsetsAndMetadata(offsets);
 
         // `handleAssignment`
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
         // `handleAssignment`
         doThrow(new 
RuntimeException("KABOOM!")).when(activeTaskCreator).closeAndRemoveTaskProducerIfNeeded(taskId00);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -2480,7 +2460,6 @@ public class TaskManagerTest {
             .withInputPartitions(taskId02Partitions).build();
         final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        taskManager.setMainConsumer(mockitoConsumer);
         when(tasks.task(taskId03)).thenReturn(corruptedActiveTask);
         when(tasks.task(taskId02)).thenReturn(corruptedStandbyTask);
 
@@ -2496,7 +2475,7 @@ public class TaskManagerTest {
         Mockito.verify(tasks).removeTask(corruptedStandbyTask);
         
Mockito.verify(tasks).addPendingTasksToInit(mkSet(corruptedActiveTask));
         
Mockito.verify(tasks).addPendingTasksToInit(mkSet(corruptedStandbyTask));
-        Mockito.verify(mockitoConsumer).assignment();
+        Mockito.verify(consumer).assignment();
     }
 
     @Test
@@ -2515,13 +2494,11 @@ public class TaskManagerTest {
         };
 
         // `handleAssignment`
-        when(mockitoConsumer.assignment())
+        when(consumer.assignment())
             .thenReturn(assignment)
             .thenReturn(taskId00Partitions);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
tp -> assertThat(tp, is(empty()))), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -2551,13 +2528,11 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment())
+        when(consumer.assignment())
             .thenReturn(assignment)
             .thenReturn(taskId00Partitions);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
tp -> assertThat(tp, is(empty()))), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -2587,12 +2562,10 @@ public class TaskManagerTest {
         when(activeTaskCreator.createTasks(any(), Mockito.eq(firstAssignment)))
             .thenReturn(asList(corruptedTask, nonCorruptedTask));
 
-        when(mockitoConsumer.assignment())
+        when(consumer.assignment())
             .thenReturn(assignment)
             .thenReturn(taskId00Partitions);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(firstAssignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
tp -> assertThat(tp, is(empty()))), is(true));
 
@@ -2607,7 +2580,7 @@ public class TaskManagerTest {
         assertThat(corruptedTask.partitionsForOffsetReset, 
equalTo(taskId00Partitions));
 
         // check that we should not commit empty map either
-        Mockito.verify(mockitoConsumer, never()).commitSync(emptyMap());
+        Mockito.verify(consumer, never()).commitSync(emptyMap());
         
Mockito.verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
     }
 
@@ -2626,9 +2599,7 @@ public class TaskManagerTest {
         // `handleAssignment`
         when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment)))
             .thenReturn(asList(corruptedTask, nonRunningNonCorruptedTask));
-        when(mockitoConsumer.assignment()).thenReturn(taskId00Partitions);
-
-        taskManager.setMainConsumer(mockitoConsumer);
+        when(consumer.assignment()).thenReturn(taskId00Partitions);
 
         taskManager.handleAssignment(assignment, emptyMap());
 
@@ -2658,9 +2629,8 @@ public class TaskManagerTest {
         when(tasks.allTasksPerId()).thenReturn(mkMap(mkEntry(taskId02, 
corruptedTask)));
         when(tasks.task(taskId02)).thenReturn(corruptedTask);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        taskManager.setMainConsumer(mockitoConsumer);
         when(stateUpdater.getTasks()).thenReturn(mkSet(activeRestoringTask, 
standbyTask));
-        
when(mockitoConsumer.assignment()).thenReturn(intersection(HashSet::new, 
taskId00Partitions, taskId01Partitions, taskId02Partitions));
+        when(consumer.assignment()).thenReturn(intersection(HashSet::new, 
taskId00Partitions, taskId01Partitions, taskId02Partitions));
 
         taskManager.handleCorruption(mkSet(taskId02));
 
@@ -2692,8 +2662,7 @@ public class TaskManagerTest {
         ));
         when(tasks.task(taskId02)).thenReturn(corruptedTask);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, false);
-        taskManager.setMainConsumer(mockitoConsumer);
-        
when(mockitoConsumer.assignment()).thenReturn(intersection(HashSet::new, 
taskId00Partitions, taskId01Partitions, taskId02Partitions));
+        when(consumer.assignment()).thenReturn(intersection(HashSet::new, 
taskId00Partitions, taskId01Partitions, taskId02Partitions));
 
         taskManager.handleCorruption(mkSet(taskId02));
 
@@ -2721,9 +2690,7 @@ public class TaskManagerTest {
             .thenReturn(singleton(runningNonCorruptedActive));
         
when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singleton(corruptedStandby));
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
-
-        taskManager.setMainConsumer(mockitoConsumer);
+        when(consumer.assignment()).thenReturn(assignment);
 
         taskManager.handleAssignment(taskId01Assignment, taskId00Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -2762,7 +2729,7 @@ public class TaskManagerTest {
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(firstAssignement)))
             .thenReturn(asList(corruptedActive, uncorruptedActive));
 
-        when(mockitoConsumer.assignment())
+        when(consumer.assignment())
             .thenReturn(assignment)
             .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
 
@@ -2770,8 +2737,6 @@ public class TaskManagerTest {
 
         uncorruptedActive.setCommittableOffsetsAndMetadata(offsets);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(firstAssignement, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -2813,13 +2778,11 @@ public class TaskManagerTest {
         when(activeTaskCreator.createTasks(any(), Mockito.eq(firstAssignment)))
             .thenReturn(asList(corruptedActive, uncorruptedActive));
 
-        when(mockitoConsumer.assignment())
+        when(consumer.assignment())
             .thenReturn(assignment)
             .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
 
-        doThrow(new 
TimeoutException()).when(mockitoConsumer).commitSync(offsets);
-
-        taskManager.setMainConsumer(mockitoConsumer);
+        doThrow(new TimeoutException()).when(consumer).commitSync(offsets);
 
         taskManager.handleAssignment(firstAssignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -2886,17 +2849,15 @@ public class TaskManagerTest {
         when(activeTaskCreator.createTasks(any(), Mockito.eq(firstAssignment)))
             .thenReturn(asList(corruptedActiveTask, uncorruptedActiveTask));
 
-        when(mockitoConsumer.assignment())
+        when(consumer.assignment())
             .thenReturn(assignment)
             .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
 
         final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
-        when(mockitoConsumer.groupMetadata()).thenReturn(groupMetadata);
+        when(consumer.groupMetadata()).thenReturn(groupMetadata);
 
         doThrow(new 
TimeoutException()).when(producer).commitTransaction(offsets, groupMetadata);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(firstAssignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -2964,16 +2925,14 @@ public class TaskManagerTest {
             mkEntry(taskId02, taskId02Partitions)
         );
 
-        when(mockitoConsumer.assignment())
+        when(consumer.assignment())
             .thenReturn(assignment)
             .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions, taskId02Partitions));
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(asList(revokedActiveTask, 
unrevokedActiveTaskWithCommitNeeded, unrevokedActiveTaskWithoutCommitNeeded));
 
-        doThrow(new 
TimeoutException()).when(mockitoConsumer).commitSync(expectedCommittedOffsets);
-
-        taskManager.setMainConsumer(mockitoConsumer);
+        doThrow(new 
TimeoutException()).when(consumer).commitSync(expectedCommittedOffsets);
 
         taskManager.handleAssignment(assignmentActive, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3024,7 +2983,7 @@ public class TaskManagerTest {
             mkEntry(taskId02, taskId02Partitions)
             );
 
-        when(mockitoConsumer.assignment())
+        when(consumer.assignment())
             .thenReturn(assignment)
             .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions, taskId02Partitions));
 
@@ -3032,12 +2991,10 @@ public class TaskManagerTest {
             .thenReturn(asList(revokedActiveTask, unrevokedActiveTask, 
unrevokedActiveTaskWithoutCommitNeeded));
 
         final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
-        when(mockitoConsumer.groupMetadata()).thenReturn(groupMetadata);
+        when(consumer.groupMetadata()).thenReturn(groupMetadata);
 
         doThrow(new 
TimeoutException()).when(producer).commitTransaction(expectedCommittedOffsets, 
groupMetadata);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(assignmentActive, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(revokedActiveTask.state(), is(Task.State.RUNNING));
@@ -3063,11 +3020,9 @@ public class TaskManagerTest {
     public void shouldCloseStandbyUnassignedTasksWhenCreatingNewTasks() {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
false, stateManager);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         
when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(emptyMap(), taskId00Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -3083,12 +3038,10 @@ public class TaskManagerTest {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
true, stateManager);
         final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, 
false, stateManager);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
         
when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -3101,19 +3054,17 @@ public class TaskManagerTest {
 
         // expect these calls twice (because we're going to 
tryToCompleteRestoration twice)
         Mockito.verify(activeTaskCreator).createTasks(any(), 
Mockito.eq(emptyMap()));
-        Mockito.verify(mockitoConsumer, times(2)).assignment();
-        Mockito.verify(mockitoConsumer, times(2)).resume(assignment);
+        Mockito.verify(consumer, times(2)).assignment();
+        Mockito.verify(consumer, times(2)).resume(assignment);
     }
 
     @Test
     public void shouldUpdateInputPartitionsAfterRebalance() {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
true, stateManager);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -3125,8 +3076,8 @@ public class TaskManagerTest {
         assertThat(task00.state(), is(Task.State.RUNNING));
         assertEquals(newPartitionsSet, task00.inputPartitions());
         // expect these calls twice (because we're going to 
tryToCompleteRestoration twice)
-        Mockito.verify(mockitoConsumer, times(2)).resume(assignment);
-        Mockito.verify(mockitoConsumer, times(2)).assignment();
+        Mockito.verify(consumer, times(2)).resume(assignment);
+        Mockito.verify(consumer, times(2)).assignment();
         Mockito.verify(activeTaskCreator).createTasks(any(), 
Mockito.eq(emptyMap()));
     }
 
@@ -3135,8 +3086,6 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignment = taskId00Assignment;
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
true, stateManager);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignment))).thenReturn(singletonList(task00));
 
         taskManager.handleAssignment(assignment, emptyMap());
@@ -3149,8 +3098,8 @@ public class TaskManagerTest {
         assertThat(taskManager.activeTaskMap(), 
Matchers.equalTo(singletonMap(taskId00, task00)));
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         Mockito.verify(changeLogReader).enforceRestoreActive();
-        Mockito.verify(mockitoConsumer).assignment();
-        Mockito.verify(mockitoConsumer).resume(Mockito.eq(emptySet()));
+        Mockito.verify(consumer).assignment();
+        Mockito.verify(consumer).resume(Mockito.eq(emptySet()));
     }
 
     @Test
@@ -3172,8 +3121,6 @@ public class TaskManagerTest {
             }
         };
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignment))).thenReturn(asList(task00, task01));
 
         taskManager.handleAssignment(assignment, emptyMap());
@@ -3191,7 +3138,7 @@ public class TaskManagerTest {
         );
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         Mockito.verify(changeLogReader).enforceRestoreActive();
-        Mockito.verifyNoInteractions(mockitoConsumer);
+        Mockito.verifyNoInteractions(consumer);
     }
 
     @Test
@@ -3206,8 +3153,6 @@ public class TaskManagerTest {
             }
         };
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignment))).thenReturn(singletonList(task00));
 
         taskManager.handleAssignment(assignment, emptyMap());
@@ -3223,7 +3168,7 @@ public class TaskManagerTest {
         );
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         Mockito.verify(changeLogReader).enforceRestoreActive();
-        Mockito.verifyNoInteractions(mockitoConsumer);
+        Mockito.verifyNoInteractions(consumer);
     }
 
     @Test
@@ -3232,11 +3177,9 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task00.setCommittableOffsetsAndMetadata(offsets);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -3279,7 +3222,7 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignmentStandby = mkMap(
             mkEntry(taskId10, taskId10Partitions)
         );
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(asList(task00, task01, task02));
@@ -3289,15 +3232,13 @@ public class TaskManagerTest {
             .thenReturn(singletonList(task10));
 
         final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
-        when(mockitoConsumer.groupMetadata()).thenReturn(groupMetadata);
+        when(consumer.groupMetadata()).thenReturn(groupMetadata);
 
         task00.committedOffsets();
         task01.committedOffsets();
         task02.committedOffsets();
         task10.committedOffsets();
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -3346,15 +3287,13 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignmentStandby = mkMap(
             mkEntry(taskId10, taskId10Partitions)
         );
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(asList(task00, task01, task02));
         when(standbyTaskCreator.createTasks(assignmentStandby))
             .thenReturn(singletonList(task10));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -3371,7 +3310,7 @@ public class TaskManagerTest {
         assertThat(task02.commitPrepared, is(false));
         assertThat(task10.commitPrepared, is(false));
 
-        Mockito.verify(mockitoConsumer).commitSync(expectedCommittedOffsets);
+        Mockito.verify(consumer).commitSync(expectedCommittedOffsets);
     }
 
     @Test
@@ -3386,13 +3325,11 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignmentActive = 
singletonMap(taskId00, taskId00Partitions);
         final Map<TaskId, Set<TopicPartition>> assignmentStandby = 
singletonMap(taskId10, taskId10Partitions);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive))).thenReturn(singleton(task00));
         
when(standbyTaskCreator.createTasks(assignmentStandby)).thenReturn(singletonList(task10));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -3416,13 +3353,11 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignmentActive = 
singletonMap(taskId00, taskId00Partitions);
         final Map<TaskId, Set<TopicPartition>> assignmentStandby = 
singletonMap(taskId10, taskId10Partitions);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive))).thenReturn(singleton(task00));
         
when(standbyTaskCreator.createTasks(assignmentStandby)).thenReturn(singletonList(task10));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -3460,9 +3395,8 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        taskManager.setMainConsumer(mockitoConsumer);
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -3810,8 +3744,6 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignment = 
singletonMap(taskId00, taskId00Partitions);
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
false, stateManager);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         // `handleAssignment`
         
when(standbyTaskCreator.createTasks(assignment)).thenReturn(singletonList(task00));
 
@@ -3830,8 +3762,8 @@ public class TaskManagerTest {
         // the active task creator should also get closed (so that it closes 
the thread producer if applicable)
         Mockito.verify(activeTaskCreator).closeThreadProducerIfNeeded();
         // `tryToCompleteRestoration`
-        Mockito.verify(mockitoConsumer).assignment();
-        Mockito.verify(mockitoConsumer).resume(Mockito.eq(emptySet()));
+        Mockito.verify(consumer).assignment();
+        Mockito.verify(consumer).resume(Mockito.eq(emptySet()));
     }
 
     @Test
@@ -3918,13 +3850,11 @@ public class TaskManagerTest {
     @Test
     public void shouldInitializeNewActiveTasks() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment)))
             .thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -3932,18 +3862,16 @@ public class TaskManagerTest {
         assertThat(taskManager.activeTaskMap(), 
Matchers.equalTo(singletonMap(taskId00, task00)));
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         // verifies that we actually resume the assignment at the end of 
restoration.
-        Mockito.verify(mockitoConsumer).resume(assignment);
+        Mockito.verify(consumer).resume(assignment);
     }
 
     @Test
     public void shouldInitialiseNewStandbyTasks() {
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         
when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(emptyMap(), taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -3954,8 +3882,7 @@ public class TaskManagerTest {
 
     @Test
     public void shouldHandleRebalanceEvents() {
-        taskManager.setMainConsumer(mockitoConsumer);
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         expect(stateDirectory.listNonEmptyTaskDirectories()).andReturn(new 
ArrayList<>());
         replay(stateDirectory);
         assertThat(taskManager.rebalanceInProgress(), is(false));
@@ -3963,7 +3890,7 @@ public class TaskManagerTest {
         assertThat(taskManager.rebalanceInProgress(), is(true));
         taskManager.handleRebalanceComplete();
         assertThat(taskManager.rebalanceInProgress(), is(false));
-        Mockito.verify(mockitoConsumer).pause(assignment);
+        Mockito.verify(consumer).pause(assignment);
     }
 
     @Test
@@ -3973,14 +3900,12 @@ public class TaskManagerTest {
         task00.setCommittableOffsetsAndMetadata(offsets);
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment)))
             .thenReturn(singletonList(task00));
         when(standbyTaskCreator.createTasks(taskId01Assignment))
             .thenReturn(singletonList(task01));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -3994,7 +3919,7 @@ public class TaskManagerTest {
         assertThat(task00.commitNeeded, is(false));
         assertThat(task01.commitNeeded, is(false));
 
-        Mockito.verify(mockitoConsumer).commitSync(offsets);
+        Mockito.verify(consumer).commitSync(offsets);
     }
 
     @Test
@@ -4017,14 +3942,12 @@ public class TaskManagerTest {
             mkEntry(taskId05, taskId05Partitions)
         );
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(Arrays.asList(task00, task01, task02));
         when(standbyTaskCreator.createTasks(assignmentStandby))
             .thenReturn(Arrays.asList(task03, task04, task05));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4049,11 +3972,9 @@ public class TaskManagerTest {
     public void shouldNotCommitOffsetsIfOnlyStandbyTasksAssigned() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, false, stateManager);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         
when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(Collections.emptyMap(), 
taskId00Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4073,7 +3994,7 @@ public class TaskManagerTest {
         makeTaskFolders(taskId00.toString(), taskId01.toString());
         expectDirectoryNotEmpty(taskId00, taskId01);
         expectLockObtainedFor(taskId00, taskId01);
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment)))
             .thenReturn(singletonList(task00));
         when(standbyTaskCreator.createTasks(taskId01Assignment))
@@ -4081,8 +4002,6 @@ public class TaskManagerTest {
 
         replay(stateDirectory);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4111,12 +4030,11 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p1, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
-        taskManager.setMainConsumer(mockitoConsumer);
         taskManager.addTask(task01);
 
         taskManager.commitAll();
 
-        Mockito.verify(mockitoConsumer).commitSync(offsets);
+        Mockito.verify(consumer).commitSync(offsets);
     }
 
     @Test
@@ -4156,7 +4074,6 @@ public class TaskManagerTest {
                                                      final Map<TopicPartition, 
OffsetAndMetadata> offsetsT01,
                                                      final Map<TopicPartition, 
OffsetAndMetadata> offsetsT02) {
         final TaskManager taskManager = setUpTaskManager(processingMode, 
false);
-        taskManager.setMainConsumer(mockitoConsumer);
 
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
         task01.setCommittableOffsetsAndMetadata(offsetsT01);
@@ -4167,7 +4084,7 @@ public class TaskManagerTest {
         task02.setCommitNeeded();
         taskManager.addTask(task02);
 
-        when(mockitoConsumer.groupMetadata()).thenReturn(new 
ConsumerGroupMetadata("appId"));
+        when(consumer.groupMetadata()).thenReturn(new 
ConsumerGroupMetadata("appId"));
 
         taskManager.commitAll();
     }
@@ -4181,11 +4098,9 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4207,11 +4122,9 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         
when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(emptyMap(), taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4241,11 +4154,9 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4276,11 +4187,9 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4300,15 +4209,13 @@ public class TaskManagerTest {
     public void shouldIgnorePurgeDataErrors() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
 
         final KafkaFutureImpl<DeletedRecords> futureDeletedRecords = new 
KafkaFutureImpl<>();
         final DeleteRecordsResult deleteRecordsResult = new 
DeleteRecordsResult(singletonMap(t1p1, futureDeletedRecords));
         futureDeletedRecords.completeExceptionally(new Exception("KABOOM!"));
         when(adminClient.deleteRecords(any())).thenReturn(deleteRecordsResult);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.addTask(task00);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4349,14 +4256,12 @@ public class TaskManagerTest {
             mkEntry(taskId10, taskId10Partitions)
         );
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(asList(task00, task01, task02, task03));
         when(standbyTaskCreator.createTasks(assignmentStandby))
             .thenReturn(singletonList(task04));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4381,7 +4286,7 @@ public class TaskManagerTest {
 
         assertThat(taskManager.maybeCommitActiveTasksPerUserRequested(), 
equalTo(3));
 
-        Mockito.verify(mockitoConsumer).commitSync(expectedCommittedOffsets);
+        Mockito.verify(consumer).commitSync(expectedCommittedOffsets);
     }
 
     @Test
@@ -4393,12 +4298,10 @@ public class TaskManagerTest {
         firstAssignment.put(taskId00, taskId00Partitions);
         firstAssignment.put(taskId01, taskId01Partitions);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), Mockito.eq(firstAssignment)))
             .thenReturn(Arrays.asList(task00, task01));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(firstAssignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4507,11 +4410,9 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4532,12 +4433,10 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment)))
             .thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4561,11 +4460,9 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4583,11 +4480,9 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4610,11 +4505,9 @@ public class TaskManagerTest {
             }
         };
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
@@ -4633,14 +4526,12 @@ public class TaskManagerTest {
             }
         };
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(false));
         assertThat(task00.state(), is(Task.State.RESTORING));
-        Mockito.verifyNoInteractions(mockitoConsumer);
+        Mockito.verifyNoInteractions(consumer);
     }
 
     @Test
@@ -4649,11 +4540,9 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task00.setCommittableOffsetsAndMetadata(offsets);
 
-        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(consumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         try (final LogCaptureAppender appender = 
LogCaptureAppender.createAndRegister(TaskManager.class)) {
             appender.setClassLoggerToDebug(TaskManager.class);
             taskManager.handleAssignment(taskId00Assignment, emptyMap());
@@ -4808,9 +4697,7 @@ public class TaskManagerTest {
         
when(standbyTaskCreator.createTasks(standbyAssignment)).thenReturn(standbyTasks);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(allActiveTasksAssignment))).thenReturn(allActiveTasks);
 
-        lenient().when(mockitoConsumer.assignment()).thenReturn(assignment);
-
-        taskManager.setMainConsumer(mockitoConsumer);
+        lenient().when(consumer.assignment()).thenReturn(assignment);
 
         taskManager.handleAssignment(allActiveTasksAssignment, 
standbyAssignment);
         taskManager.tryToCompleteRestoration(time.milliseconds(), null);
@@ -4864,10 +4751,9 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
-        taskManager.setMainConsumer(mockitoConsumer);
         taskManager.addTask(task01);
 
-        doThrow(new 
CommitFailedException()).when(mockitoConsumer).commitSync(offsets);
+        doThrow(new 
CommitFailedException()).when(consumer).commitSync(offsets);
 
         final TaskMigratedException thrown = assertThrows(
             TaskMigratedException.class,
@@ -4892,9 +4778,7 @@ public class TaskManagerTest {
         
task00.setCommittableOffsetsAndMetadata(taskId00Partitions.stream().collect(Collectors.toMap(p
 -> p, p -> new OffsetAndMetadata(0))));
         
task01.setCommittableOffsetsAndMetadata(taskId00Partitions.stream().collect(Collectors.toMap(p
 -> p, p -> new OffsetAndMetadata(0))));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
-        doThrow(new 
TimeoutException("KABOOM!")).doNothing().when(mockitoConsumer).commitSync(any(Map.class));
+        doThrow(new 
TimeoutException("KABOOM!")).doNothing().when(consumer).commitSync(any(Map.class));
 
         task00.setCommitNeeded();
 
@@ -4906,14 +4790,13 @@ public class TaskManagerTest {
         assertNull(task00.timeout);
         assertNull(task01.timeout);
 
-        Mockito.verify(mockitoConsumer, times(2)).commitSync(any(Map.class));
+        Mockito.verify(consumer, times(2)).commitSync(any(Map.class));
     }
 
     @Test
     public void shouldNotFailForTimeoutExceptionOnCommitWithEosAlpha() {
         final Tasks tasks = mock(Tasks.class);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.EXACTLY_ONCE_ALPHA, tasks, false);
-        taskManager.setMainConsumer(mockitoConsumer);
 
         final StreamsProducer producer = mock(StreamsProducer.class);
         
when(activeTaskCreator.streamsProducerForTask(any(TaskId.class))).thenReturn(producer);
@@ -4949,13 +4832,12 @@ public class TaskManagerTest {
             equalTo(Collections.singleton(taskId00))
         );
 
-        Mockito.verify(mockitoConsumer, times(2)).groupMetadata();
+        Mockito.verify(consumer, times(2)).groupMetadata();
     }
 
     @Test
     public void 
shouldThrowTaskCorruptedExceptionForTimeoutExceptionOnCommitWithEosV2() {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2, false);
-        taskManager.setMainConsumer(mockitoConsumer);
 
         final StreamsProducer producer = mock(StreamsProducer.class);
         when(activeTaskCreator.threadProducer()).thenReturn(producer);
@@ -4985,7 +4867,7 @@ public class TaskManagerTest {
             equalTo(mkSet(taskId00, taskId01))
         );
 
-        Mockito.verify(mockitoConsumer).groupMetadata();
+        Mockito.verify(consumer).groupMetadata();
     }
 
     @Test
@@ -4994,10 +4876,9 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
-        taskManager.setMainConsumer(mockitoConsumer);
         taskManager.addTask(task01);
 
-        doThrow(new 
KafkaException()).when(mockitoConsumer).commitSync(offsets);
+        doThrow(new KafkaException()).when(consumer).commitSync(offsets);
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
@@ -5015,10 +4896,9 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
-        taskManager.setMainConsumer(mockitoConsumer);
         taskManager.addTask(task01);
 
-        doThrow(new 
RuntimeException("KABOOM")).when(mockitoConsumer).commitSync(offsets);
+        doThrow(new 
RuntimeException("KABOOM")).when(consumer).commitSync(offsets);
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
@@ -5044,8 +4924,6 @@ public class TaskManagerTest {
         assignment.putAll(taskId01Assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignment))).thenReturn(asList(task00, task01));
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(assignment, Collections.emptyMap());
 
         final RuntimeException thrown = assertThrows(
@@ -5055,7 +4933,7 @@ public class TaskManagerTest {
         assertThat(thrown.getCause().getMessage(), is("KABOOM!"));
         assertThat(task00.state(), is(Task.State.SUSPENDED));
         assertThat(task01.state(), is(Task.State.SUSPENDED));
-        Mockito.verifyNoInteractions(mockitoConsumer);
+        Mockito.verifyNoInteractions(consumer);
     }
 
     @Test
@@ -5071,15 +4949,13 @@ public class TaskManagerTest {
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(activeTask));
         when(standbyTaskCreator.createStandbyTaskFromActive(Mockito.any(), 
Mockito.eq(taskId00Partitions))).thenReturn(standbyTask);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(taskId00Assignment, 
Collections.emptyMap());
         taskManager.handleAssignment(Collections.emptyMap(), 
taskId00Assignment);
 
         
Mockito.verify(activeTaskCreator).closeAndRemoveTaskProducerIfNeeded(taskId00);
         Mockito.verify(activeTaskCreator).createTasks(any(), 
Mockito.eq(emptyMap()));
         Mockito.verify(standbyTaskCreator, 
times(2)).createTasks(Collections.emptyMap());
-        Mockito.verifyNoInteractions(mockitoConsumer);
+        Mockito.verifyNoInteractions(consumer);
     }
 
     @Test
@@ -5096,14 +4972,12 @@ public class TaskManagerTest {
         
when(activeTaskCreator.createActiveTaskFromStandby(Mockito.eq(standbyTask), 
Mockito.eq(taskId00Partitions), any()))
             .thenReturn(activeTask);
 
-        taskManager.setMainConsumer(mockitoConsumer);
-
         taskManager.handleAssignment(Collections.emptyMap(), 
taskId00Assignment);
         taskManager.handleAssignment(taskId00Assignment, 
Collections.emptyMap());
 
         Mockito.verify(activeTaskCreator, times(2)).createTasks(any(), 
Mockito.eq(emptyMap()));
         Mockito.verify(standbyTaskCreator).createTasks(Collections.emptyMap());
-        Mockito.verifyNoInteractions(mockitoConsumer);
+        Mockito.verifyNoInteractions(consumer);
     }
 
     @Test

Reply via email to