This is an automated email from the ASF dual-hosted git repository.

cadonna pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new a33c47ea4dd KAFKA-14133: Move consumer mock in TaskManagerTest to Mockito - part 2 (#15261)
a33c47ea4dd is described below

commit a33c47ea4ddc810d66b6ed17ab74e40c5b7668fb
Author: Christo Lolov <lol...@amazon.com>
AuthorDate: Thu Mar 7 09:33:31 2024 +0000

    KAFKA-14133: Move consumer mock in TaskManagerTest to Mockito - part 2 (#15261)
    
    The previous pull request in this series was #15112.
    
    This pull request continues migrating the consumer mock in TaskManagerTest to Mockito, one test at a time, to make reviews easier.
    
    I envision there will be at least one more pull request to clean things up. For example, all calls to taskManager.setMainConsumer should be removed.
    
    Reviewer: Bruno Cadonna <cado...@apache.org>
---
 .../processor/internals/TaskManagerTest.java       | 379 ++++++++++-----------
 1 file changed, 177 insertions(+), 202 deletions(-)

diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index ba1c91e7f71..681e69d3004 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -58,7 +58,6 @@ import java.nio.file.Files;
 import java.time.Duration;
 import java.util.ArrayList;
 
-import org.easymock.EasyMock;
 import org.easymock.EasyMockRunner;
 import org.easymock.Mock;
 import org.easymock.MockType;
@@ -112,7 +111,6 @@ import static org.easymock.EasyMock.eq;
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.expectLastCall;
 import static org.easymock.EasyMock.replay;
-import static org.easymock.EasyMock.reset;
 import static org.easymock.EasyMock.verify;
 import static org.hamcrest.CoreMatchers.hasItem;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -133,6 +131,7 @@ import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.when;
@@ -185,6 +184,7 @@ public class TaskManagerTest {
     private final TaskId taskId10 = new TaskId(1, 0);
     private final TopicPartition t2p0 = new TopicPartition(topic2, 0);
     private final Set<TopicPartition> taskId10Partitions = mkSet(t2p0);
+    private final Set<TopicPartition> assignment = singleton(new 
TopicPartition("assignment", 0));
 
     final java.util.function.Consumer<Set<TopicPartition>> noOpResetter = 
partitions -> { };
 
@@ -2016,13 +2016,12 @@ public class TaskManagerTest {
         assertThat(taskManager.lockedTaskDirectories(), is(mkSet(taskId00, 
taskId01, taskId02)));
 
         handleAssignment(taskId00Assignment, taskId01Assignment, emptyMap());
-        reset(consumer);
-        expectConsumerAssignmentPaused(consumer);
-        replay(consumer);
 
         taskManager.handleRebalanceComplete();
         assertThat(taskManager.lockedTaskDirectories(), is(mkSet(taskId00, 
taskId01)));
         verify(stateDirectory);
+
+        Mockito.verify(mockitoConsumer).pause(assignment);
     }
 
     @Test
@@ -2332,19 +2331,11 @@ public class TaskManagerTest {
         task00.setCommittableOffsetsAndMetadata(offsets);
 
         // first `handleAssignment`
-        expectRestoreToBeCompleted(consumer);
-        when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        expectLastCall();
-
-        // `handleRevocation`
-        consumer.commitSync(offsets);
-        expectLastCall();
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
 
-        // second `handleAssignment`
-        consumer.commitSync(offsets);
-        expectLastCall();
+        when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -2369,12 +2360,9 @@ public class TaskManagerTest {
             }
         };
 
-        // first `handleAssignment`
-        expectRestoreToBeCompleted(consumer);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        expectLastCall();
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         taskManager.handleRevocation(taskId00Partitions);
@@ -2399,7 +2387,7 @@ public class TaskManagerTest {
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager);
 
         // `handleAssignment`
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
         
when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01));
 
@@ -2412,11 +2400,11 @@ public class TaskManagerTest {
         expectLockObtainedFor();
         replay(stateDirectory);
 
+        taskManager.setMainConsumer(mockitoConsumer);
+
         taskManager.handleRebalanceStart(emptySet());
         assertThat(taskManager.lockedTaskDirectories(), 
Matchers.is(mkSet(taskId00, taskId01)));
 
-        replay(consumer);
-
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -2455,15 +2443,13 @@ public class TaskManagerTest {
         task00.setCommittableOffsetsAndMetadata(offsets);
 
         // `handleAssignment`
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
         // `handleAssignment`
-        consumer.commitSync(offsets);
-        expectLastCall();
         doThrow(new 
RuntimeException("KABOOM!")).when(activeTaskCreator).closeAndRemoveTaskProducerIfNeeded(taskId00);
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -2529,10 +2515,12 @@ public class TaskManagerTest {
         };
 
         // `handleAssignment`
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment())
+            .thenReturn(assignment)
+            .thenReturn(taskId00Partitions);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        expect(consumer.assignment()).andReturn(taskId00Partitions);
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
tp -> assertThat(tp, is(empty()))), is(true));
@@ -2548,7 +2536,6 @@ public class TaskManagerTest {
         assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, 
task00)));
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
 
-        verify(consumer);
         
Mockito.verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
     }
 
@@ -2564,10 +2551,12 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment())
+            .thenReturn(assignment)
+            .thenReturn(taskId00Partitions);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        expect(consumer.assignment()).andReturn(taskId00Partitions);
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
tp -> assertThat(tp, is(empty()))), is(true));
@@ -2581,7 +2570,6 @@ public class TaskManagerTest {
         assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, 
task00)));
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
 
-        verify(consumer);
         
Mockito.verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
     }
 
@@ -2592,20 +2580,20 @@ public class TaskManagerTest {
         final StateMachineTask corruptedTask = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
         final StateMachineTask nonCorruptedTask = new 
StateMachineTask(taskId01, taskId01Partitions, true, stateManager);
 
-        final Map<TaskId, Set<TopicPartition>> assignment = new 
HashMap<>(taskId00Assignment);
-        assignment.putAll(taskId01Assignment);
+        final Map<TaskId, Set<TopicPartition>> firstAssignment = new 
HashMap<>(taskId00Assignment);
+        firstAssignment.putAll(taskId01Assignment);
 
         // `handleAssignment`
-        when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment)))
+        when(activeTaskCreator.createTasks(any(), Mockito.eq(firstAssignment)))
             .thenReturn(asList(corruptedTask, nonCorruptedTask));
-        expectRestoreToBeCompleted(consumer);
-        expect(consumer.assignment()).andReturn(taskId00Partitions);
-        // check that we should not commit empty map either
-        consumer.commitSync(eq(emptyMap()));
-        expectLastCall().andStubThrow(new AssertionError("should not invoke 
commitSync when offset map is empty"));
-        replay(consumer);
 
-        taskManager.handleAssignment(assignment, emptyMap());
+        when(mockitoConsumer.assignment())
+            .thenReturn(assignment)
+            .thenReturn(taskId00Partitions);
+
+        taskManager.setMainConsumer(mockitoConsumer);
+
+        taskManager.handleAssignment(firstAssignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
tp -> assertThat(tp, is(empty()))), is(true));
 
         assertThat(nonCorruptedTask.state(), is(Task.State.RUNNING));
@@ -2618,7 +2606,8 @@ public class TaskManagerTest {
         assertThat(nonCorruptedTask.partitionsForOffsetReset, 
equalTo(Collections.emptySet()));
         assertThat(corruptedTask.partitionsForOffsetReset, 
equalTo(taskId00Partitions));
 
-        verify(consumer);
+        // check that we should not commit empty map either
+        Mockito.verify(mockitoConsumer, never()).commitSync(emptyMap());
         
Mockito.verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
     }
 
@@ -2637,8 +2626,9 @@ public class TaskManagerTest {
         // `handleAssignment`
         when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment)))
             .thenReturn(asList(corruptedTask, nonRunningNonCorruptedTask));
-        expect(consumer.assignment()).andReturn(taskId00Partitions);
-        replay(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(taskId00Partitions);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(assignment, emptyMap());
 
@@ -2650,7 +2640,6 @@ public class TaskManagerTest {
         assertThat(corruptedTask.partitionsForOffsetReset, 
equalTo(taskId00Partitions));
 
         assertFalse(nonRunningNonCorruptedTask.commitPrepared);
-        verify(consumer);
         
Mockito.verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
     }
 
@@ -2732,9 +2721,9 @@ public class TaskManagerTest {
             .thenReturn(singleton(runningNonCorruptedActive));
         
when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singleton(corruptedStandby));
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId01Assignment, taskId00Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -2751,7 +2740,6 @@ public class TaskManagerTest {
 
         assertThat(corruptedStandby.commitPrepared, is(true));
         assertThat(corruptedStandby.state(), is(Task.State.CREATED));
-        verify(consumer);
         
Mockito.verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
     }
 
@@ -2768,21 +2756,23 @@ public class TaskManagerTest {
         uncorruptedActive.setCommitNeeded();
 
         // handleAssignment
-        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
-        assignment.putAll(taskId00Assignment);
-        assignment.putAll(taskId01Assignment);
-        when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment)))
+        final Map<TaskId, Set<TopicPartition>> firstAssignement = new 
HashMap<>();
+        firstAssignement.putAll(taskId00Assignment);
+        firstAssignement.putAll(taskId01Assignment);
+        when(activeTaskCreator.createTasks(any(), 
Mockito.eq(firstAssignement)))
             .thenReturn(asList(corruptedActive, uncorruptedActive));
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment())
+            .thenReturn(assignment)
+            .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
 
-        expect(consumer.assignment()).andStubReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions));
-
-        replay(consumer, stateDirectory);
+        replay(stateDirectory);
 
         uncorruptedActive.setCommittableOffsetsAndMetadata(offsets);
 
-        taskManager.handleAssignment(assignment, emptyMap());
+        taskManager.setMainConsumer(mockitoConsumer);
+
+        taskManager.handleAssignment(firstAssignement, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
         assertThat(uncorruptedActive.state(), is(Task.State.RUNNING));
@@ -2800,7 +2790,6 @@ public class TaskManagerTest {
         assertThat(uncorruptedActive.commitCompleted, is(false));
 
         assertThat(uncorruptedActive.state(), is(State.RUNNING));
-        verify(consumer);
     }
 
     @Test
@@ -2818,22 +2807,21 @@ public class TaskManagerTest {
         uncorruptedActive.setCommittableOffsetsAndMetadata(offsets);
 
         // handleAssignment
-        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
-        assignment.putAll(taskId00Assignment);
-        assignment.putAll(taskId01Assignment);
-        when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment)))
+        final Map<TaskId, Set<TopicPartition>> firstAssignment = new 
HashMap<>();
+        firstAssignment.putAll(taskId00Assignment);
+        firstAssignment.putAll(taskId01Assignment);
+        when(activeTaskCreator.createTasks(any(), Mockito.eq(firstAssignment)))
             .thenReturn(asList(corruptedActive, uncorruptedActive));
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment())
+            .thenReturn(assignment)
+            .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
 
-        consumer.commitSync(offsets);
-        expectLastCall().andThrow(new TimeoutException());
+        doThrow(new 
TimeoutException()).when(mockitoConsumer).commitSync(offsets);
 
-        expect(consumer.assignment()).andStubReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions));
-
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
-        taskManager.handleAssignment(assignment, emptyMap());
+        taskManager.handleAssignment(firstAssignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
         assertThat(uncorruptedActive.state(), is(Task.State.RUNNING));
@@ -2861,7 +2849,6 @@ public class TaskManagerTest {
 
         assertThat(corruptedActive.state(), is(Task.State.CREATED));
         assertThat(uncorruptedActive.state(), is(Task.State.CREATED));
-        verify(consumer);
         
Mockito.verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
     }
 
@@ -2893,24 +2880,24 @@ public class TaskManagerTest {
         uncorruptedActiveTask.setCommittableOffsetsAndMetadata(offsets);
 
         // handleAssignment
-        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
-        assignment.putAll(taskId00Assignment);
-        assignment.putAll(taskId01Assignment);
-        when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment)))
+        final Map<TaskId, Set<TopicPartition>> firstAssignment = new 
HashMap<>();
+        firstAssignment.putAll(taskId00Assignment);
+        firstAssignment.putAll(taskId01Assignment);
+        when(activeTaskCreator.createTasks(any(), Mockito.eq(firstAssignment)))
             .thenReturn(asList(corruptedActiveTask, uncorruptedActiveTask));
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment())
+            .thenReturn(assignment)
+            .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
 
         final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
-        expect(consumer.groupMetadata()).andReturn(groupMetadata);
+        when(mockitoConsumer.groupMetadata()).thenReturn(groupMetadata);
 
         doThrow(new 
TimeoutException()).when(producer).commitTransaction(offsets, groupMetadata);
 
-        expect(consumer.assignment()).andStubReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions));
-
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
-        taskManager.handleAssignment(assignment, emptyMap());
+        taskManager.handleAssignment(firstAssignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
         assertThat(uncorruptedActiveTask.state(), is(Task.State.RUNNING));
@@ -2944,7 +2931,6 @@ public class TaskManagerTest {
         assertThat(uncorruptedActiveTask.state(), is(Task.State.CREATED));
         assertThat(corruptedTaskChangelogMarkedAsCorrupted.get(), is(true));
         assertThat(uncorruptedTaskChangelogMarkedAsCorrupted.get(), is(true));
-        verify(consumer);
         
Mockito.verify(stateManager).markChangelogAsCorrupted(taskId00ChangelogPartitions);
         
Mockito.verify(stateManager).markChangelogAsCorrupted(taskId01ChangelogPartitions);
     }
@@ -2978,16 +2964,16 @@ public class TaskManagerTest {
             mkEntry(taskId02, taskId02Partitions)
         );
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment())
+            .thenReturn(assignment)
+            .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions, taskId02Partitions));
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(asList(revokedActiveTask, 
unrevokedActiveTaskWithCommitNeeded, unrevokedActiveTaskWithoutCommitNeeded));
-        expectLastCall();
-        consumer.commitSync(expectedCommittedOffsets);
-        expectLastCall().andThrow(new TimeoutException());
-        expect(consumer.assignment()).andStubReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions, taskId02Partitions));
 
-        replay(consumer);
+        doThrow(new 
TimeoutException()).when(mockitoConsumer).commitSync(expectedCommittedOffsets);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(assignmentActive, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3038,19 +3024,19 @@ public class TaskManagerTest {
             mkEntry(taskId02, taskId02Partitions)
             );
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment())
+            .thenReturn(assignment)
+            .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions, taskId02Partitions));
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(asList(revokedActiveTask, unrevokedActiveTask, 
unrevokedActiveTaskWithoutCommitNeeded));
 
         final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
-        expect(consumer.groupMetadata()).andReturn(groupMetadata);
+        when(mockitoConsumer.groupMetadata()).thenReturn(groupMetadata);
 
         doThrow(new 
TimeoutException()).when(producer).commitTransaction(expectedCommittedOffsets, 
groupMetadata);
 
-        expect(consumer.assignment()).andStubReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions, taskId02Partitions));
-
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(assignmentActive, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3077,11 +3063,10 @@ public class TaskManagerTest {
     public void shouldCloseStandbyUnassignedTasksWhenCreatingNewTasks() {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
false, stateManager);
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         
when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singletonList(task00));
-        consumer.commitSync(Collections.emptyMap());
-        expectLastCall();
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(emptyMap(), taskId00Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3098,12 +3083,11 @@ public class TaskManagerTest {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
true, stateManager);
         final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, 
false, stateManager);
 
-        expectRestoreToBeCompleted(consumer);
-        // expect these calls twice (because we're going to 
tryToCompleteRestoration twice)
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
         
when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3115,18 +3099,20 @@ public class TaskManagerTest {
         assertThat(task00.state(), is(Task.State.RUNNING));
         assertThat(task01.state(), is(Task.State.RUNNING));
 
+        // expect these calls twice (because we're going to 
tryToCompleteRestoration twice)
         Mockito.verify(activeTaskCreator).createTasks(any(), 
Mockito.eq(emptyMap()));
+        Mockito.verify(mockitoConsumer, times(2)).assignment();
+        Mockito.verify(mockitoConsumer, times(2)).resume(assignment);
     }
 
     @Test
     public void shouldUpdateInputPartitionsAfterRebalance() {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
true, stateManager);
 
-        expectRestoreToBeCompleted(consumer);
-        // expect these calls twice (because we're going to 
tryToCompleteRestoration twice)
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3138,7 +3124,9 @@ public class TaskManagerTest {
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
         assertEquals(newPartitionsSet, task00.inputPartitions());
-        verify(consumer);
+        // expect these calls twice (because we're going to 
tryToCompleteRestoration twice)
+        Mockito.verify(mockitoConsumer, times(2)).resume(assignment);
+        Mockito.verify(mockitoConsumer, times(2)).assignment();
         Mockito.verify(activeTaskCreator).createTasks(any(), 
Mockito.eq(emptyMap()));
     }
 
@@ -3244,12 +3232,10 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task00.setCommittableOffsetsAndMetadata(offsets);
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        consumer.commitSync(offsets);
-        expectLastCall();
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3293,7 +3279,7 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignmentStandby = mkMap(
             mkEntry(taskId10, taskId10Partitions)
         );
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(asList(task00, task01, task02));
@@ -3303,20 +3289,14 @@ public class TaskManagerTest {
             .thenReturn(singletonList(task10));
 
         final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
-        expect(consumer.groupMetadata()).andReturn(groupMetadata);
-        producer.commitTransaction(expectedCommittedOffsets, groupMetadata);
-        expectLastCall();
+        when(mockitoConsumer.groupMetadata()).thenReturn(groupMetadata);
 
         task00.committedOffsets();
-        EasyMock.expectLastCall();
         task01.committedOffsets();
-        EasyMock.expectLastCall();
         task02.committedOffsets();
-        EasyMock.expectLastCall();
         task10.committedOffsets();
-        EasyMock.expectLastCall();
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3331,6 +3311,8 @@ public class TaskManagerTest {
         assertThat(task01.commitNeeded, is(false));
         assertThat(task02.commitPrepared, is(false));
         assertThat(task10.commitPrepared, is(false));
+
+        Mockito.verify(producer).commitTransaction(expectedCommittedOffsets, 
groupMetadata);
     }
 
     @Test
@@ -3364,16 +3346,14 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignmentStandby = mkMap(
             mkEntry(taskId10, taskId10Partitions)
         );
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(asList(task00, task01, task02));
         when(standbyTaskCreator.createTasks(assignmentStandby))
             .thenReturn(singletonList(task10));
-        consumer.commitSync(expectedCommittedOffsets);
-        expectLastCall();
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3390,6 +3370,8 @@ public class TaskManagerTest {
         assertThat(task01.commitPrepared, is(true));
         assertThat(task02.commitPrepared, is(false));
         assertThat(task10.commitPrepared, is(false));
+
+        Mockito.verify(mockitoConsumer).commitSync(expectedCommittedOffsets);
     }
 
     @Test
@@ -3404,12 +3386,12 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignmentActive = 
singletonMap(taskId00, taskId00Partitions);
         final Map<TaskId, Set<TopicPartition>> assignmentStandby = 
singletonMap(taskId10, taskId10Partitions);
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive))).thenReturn(singleton(task00));
         
when(standbyTaskCreator.createTasks(assignmentStandby)).thenReturn(singletonList(task10));
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3434,12 +3416,12 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignmentActive = 
singletonMap(taskId00, taskId00Partitions);
         final Map<TaskId, Set<TopicPartition>> assignmentStandby = 
singletonMap(taskId10, taskId10Partitions);
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
 
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive))).thenReturn(singleton(task00));
         
when(standbyTaskCreator.createTasks(assignmentStandby)).thenReturn(singletonList(task10));
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3478,17 +3460,15 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
 
         assertThrows(RuntimeException.class, () -> 
taskManager.handleRevocation(taskId00Partitions));
         assertThat(task00.state(), is(Task.State.SUSPENDED));
-
-        verify(consumer);
     }
 
     @Test
@@ -3938,10 +3918,12 @@ public class TaskManagerTest {
     @Test
     public void shouldInitializeNewActiveTasks() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
+
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment)))
             .thenReturn(singletonList(task00));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3950,17 +3932,17 @@ public class TaskManagerTest {
         assertThat(taskManager.activeTaskMap(), 
Matchers.equalTo(singletonMap(taskId00, task00)));
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         // verifies that we actually resume the assignment at the end of 
restoration.
-        verify(consumer);
+        Mockito.verify(mockitoConsumer).resume(assignment);
     }
 
     @Test
     public void shouldInitialiseNewStandbyTasks() {
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager);
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         
when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01));
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(emptyMap(), taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -3972,7 +3954,6 @@ public class TaskManagerTest {
 
     @Test
     public void shouldHandleRebalanceEvents() {
-        final Set<TopicPartition> assignment = singleton(new 
TopicPartition("assignment", 0));
         taskManager.setMainConsumer(mockitoConsumer);
         when(mockitoConsumer.assignment()).thenReturn(assignment);
         expect(stateDirectory.listNonEmptyTaskDirectories()).andReturn(new 
ArrayList<>());
@@ -3992,15 +3973,13 @@ public class TaskManagerTest {
         task00.setCommittableOffsetsAndMetadata(offsets);
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager);
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment)))
             .thenReturn(singletonList(task00));
         when(standbyTaskCreator.createTasks(taskId01Assignment))
             .thenReturn(singletonList(task01));
-        consumer.commitSync(offsets);
-        expectLastCall();
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4014,6 +3993,8 @@ public class TaskManagerTest {
         assertThat(taskManager.commitAll(), equalTo(2));
         assertThat(task00.commitNeeded, is(false));
         assertThat(task01.commitNeeded, is(false));
+
+        Mockito.verify(mockitoConsumer).commitSync(offsets);
     }
 
     @Test
@@ -4036,15 +4017,13 @@ public class TaskManagerTest {
             mkEntry(taskId05, taskId05Partitions)
         );
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(Arrays.asList(task00, task01, task02));
         when(standbyTaskCreator.createTasks(assignmentStandby))
             .thenReturn(Arrays.asList(task03, task04, task05));
 
-        consumer.commitSync(eq(emptyMap()));
-
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4070,10 +4049,10 @@ public class TaskManagerTest {
     public void shouldNotCommitOffsetsIfOnlyStandbyTasksAssigned() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, false, stateManager);
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         
when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singletonList(task00));
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(Collections.emptyMap(), 
taskId00Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4094,13 +4073,15 @@ public class TaskManagerTest {
         makeTaskFolders(taskId00.toString(), taskId01.toString());
         expectDirectoryNotEmpty(taskId00, taskId01);
         expectLockObtainedFor(taskId00, taskId01);
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment)))
             .thenReturn(singletonList(task00));
         when(standbyTaskCreator.createTasks(taskId01Assignment))
             .thenReturn(singletonList(task01));
 
-        replay(stateDirectory, consumer);
+        replay(stateDirectory);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4200,9 +4181,10 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4225,10 +4207,10 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         
when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01));
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(emptyMap(), taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4259,10 +4241,10 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4294,9 +4276,10 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4317,14 +4300,14 @@ public class TaskManagerTest {
     public void shouldIgnorePurgeDataErrors() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
 
         final KafkaFutureImpl<DeletedRecords> futureDeletedRecords = new 
KafkaFutureImpl<>();
         final DeleteRecordsResult deleteRecordsResult = new 
DeleteRecordsResult(singletonMap(t1p1, futureDeletedRecords));
         futureDeletedRecords.completeExceptionally(new Exception("KABOOM!"));
         when(adminClient.deleteRecords(any())).thenReturn(deleteRecordsResult);
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.addTask(task00);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4366,15 +4349,13 @@ public class TaskManagerTest {
             mkEntry(taskId10, taskId10Partitions)
         );
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(assignmentActive)))
             .thenReturn(asList(task00, task01, task02, task03));
         when(standbyTaskCreator.createTasks(assignmentStandby))
             .thenReturn(singletonList(task04));
-        consumer.commitSync(expectedCommittedOffsets);
-        expectLastCall();
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4399,6 +4380,8 @@ public class TaskManagerTest {
         task04.setCommitRequested();
 
         assertThat(taskManager.maybeCommitActiveTasksPerUserRequested(), 
equalTo(3));
+
+        Mockito.verify(mockitoConsumer).commitSync(expectedCommittedOffsets);
     }
 
     @Test
@@ -4406,16 +4389,17 @@ public class TaskManagerTest {
         final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
 
-        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
-        assignment.put(taskId00, taskId00Partitions);
-        assignment.put(taskId01, taskId01Partitions);
+        final Map<TaskId, Set<TopicPartition>> firstAssignment = new 
HashMap<>();
+        firstAssignment.put(taskId00, taskId00Partitions);
+        firstAssignment.put(taskId01, taskId01Partitions);
 
-        expectRestoreToBeCompleted(consumer);
-        when(activeTaskCreator.createTasks(any(), Mockito.eq(assignment)))
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
+        when(activeTaskCreator.createTasks(any(), Mockito.eq(firstAssignment)))
             .thenReturn(Arrays.asList(task00, task01));
-        replay(consumer);
 
-        taskManager.handleAssignment(assignment, emptyMap());
+        taskManager.setMainConsumer(mockitoConsumer);
+
+        taskManager.handleAssignment(firstAssignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
 
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -4523,9 +4507,10 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4547,10 +4532,11 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment)))
             .thenReturn(singletonList(task00));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4575,9 +4561,10 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4596,9 +4583,10 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4622,9 +4610,10 @@ public class TaskManagerTest {
             }
         };
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        replay(consumer);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
@@ -4660,12 +4649,10 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task00.setCommittableOffsetsAndMetadata(offsets);
 
-        expectRestoreToBeCompleted(consumer);
+        when(mockitoConsumer.assignment()).thenReturn(assignment);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        consumer.commitSync(offsets);
-        expectLastCall();
 
-        replay(consumer);
+        taskManager.setMainConsumer(mockitoConsumer);
 
         try (final LogCaptureAppender appender = 
LogCaptureAppender.createAndRegister(TaskManager.class)) {
             appender.setClassLoggerToDebug(TaskManager.class);
@@ -4821,8 +4808,9 @@ public class TaskManagerTest {
         
when(standbyTaskCreator.createTasks(standbyAssignment)).thenReturn(standbyTasks);
         when(activeTaskCreator.createTasks(any(), 
Mockito.eq(allActiveTasksAssignment))).thenReturn(allActiveTasks);
 
-        expectRestoreToBeCompleted(consumer);
-        replay(consumer);
+        lenient().when(mockitoConsumer.assignment()).thenReturn(assignment);
+
+        taskManager.setMainConsumer(mockitoConsumer);
 
         taskManager.handleAssignment(allActiveTasksAssignment, 
standbyAssignment);
         taskManager.tryToCompleteRestoration(time.milliseconds(), null);
@@ -4870,12 +4858,6 @@ public class TaskManagerTest {
         }
     }
 
-    private static void expectConsumerAssignmentPaused(final Consumer<byte[], 
byte[]> consumer) {
-        final Set<TopicPartition> assignment = singleton(new 
TopicPartition("assignment", 0));
-        expect(consumer.assignment()).andReturn(assignment);
-        consumer.pause(assignment);
-    }
-
     @Test
     public void shouldThrowTaskMigratedExceptionOnCommitFailed() {
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
@@ -5135,13 +5117,6 @@ public class TaskManagerTest {
         assertEquals(taskManager.notPausedTasks().size(), 0);
     }
 
-    private static void expectRestoreToBeCompleted(final Consumer<byte[], 
byte[]> consumer) {
-        final Set<TopicPartition> assignment = singleton(new 
TopicPartition("assignment", 0));
-        expect(consumer.assignment()).andReturn(assignment);
-        consumer.resume(assignment);
-        expectLastCall();
-    }
-
     private static KafkaFutureImpl<DeletedRecords> completedFuture() {
         final KafkaFutureImpl<DeletedRecords> futureDeletedRecords = new 
KafkaFutureImpl<>();
         futureDeletedRecords.complete(null);


Reply via email to