This is an automated email from the ASF dual-hosted git repository.

lucasbru pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 904b459dca8 KAFKA-19683: More cleanup and rewrite [4/N] (#20777)
904b459dca8 is described below

commit 904b459dca89d69f27d3801ad90e2477cefef0bf
Author: Shashank <[email protected]>
AuthorDate: Mon Nov 3 03:22:31 2025 -0800

    KAFKA-19683: More cleanup and rewrite [4/N] (#20777)
    
    - Removed tests that were tagged for removal
    - Rewrote more tests
    
    Reviewers: Lucas Brutschy <[email protected]>
---
 .../processor/internals/TaskManagerTest.java       | 975 ++++++++++++---------
 1 file changed, 562 insertions(+), 413 deletions(-)

diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 63cbc441f8a..0156d101ed8 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -73,7 +73,6 @@ import java.nio.file.Path;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.Deque;
 import java.util.HashMap;
@@ -103,7 +102,6 @@ import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.standbyTask;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statefulTask;
 import static org.hamcrest.CoreMatchers.hasItem;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.core.IsEqual.equalTo;
@@ -113,7 +111,6 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.junit.jupiter.api.Assertions.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyLong;
@@ -2188,72 +2185,65 @@ public class TaskManagerTest {
 
     @Test
     public void shouldReviveCorruptTasks() {
-        final ProcessorStateManager stateManager = 
mock(ProcessorStateManager.class);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        final AtomicBoolean enforcedCheckpoint = new AtomicBoolean(false);
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager) {
-            @Override
-            public void postCommit(final boolean enforceCheckpoint) {
-                if (enforceCheckpoint) {
-                    enforcedCheckpoint.set(true);
-                }
-                super.postCommit(enforceCheckpoint);
-            }
-        };
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.task(taskId00)).thenReturn(task00);
+        when(tasks.allTasksPerId()).thenReturn(singletonMap(taskId00, task00));
+        when(tasks.activeTaskIds()).thenReturn(Set.of(taskId00));
 
-        // `handleAssignment`
-        when(consumer.assignment())
-            .thenReturn(assignment)
-            .thenReturn(taskId00Partitions);
-        when(activeTaskCreator.createTasks(any(), 
eq(taskId00Assignment))).thenReturn(singletonList(task00));
+        when(task00.prepareCommit(false)).thenReturn(emptyMap());
+        doNothing().when(task00).postCommit(anyBoolean());
+        
when(task00.changelogPartitions()).thenReturn(taskId00ChangelogPartitions);
 
-        taskManager.handleAssignment(taskId00Assignment, emptyMap());
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
tp -> assertThat(tp, is(empty()))), is(true));
-        assertThat(task00.state(), is(Task.State.RUNNING));
+        when(consumer.assignment()).thenReturn(taskId00Partitions);
 
-        task00.setChangelogOffsets(singletonMap(t1p0, 0L));
-        taskManager.handleCorruption(singleton(taskId00));
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
-        assertThat(task00.commitPrepared, is(true));
-        assertThat(task00.state(), is(Task.State.CREATED));
-        assertThat(task00.partitionsForOffsetReset, 
equalTo(taskId00Partitions));
-        assertThat(enforcedCheckpoint.get(), is(true));
-        assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, 
task00)));
-        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+        taskManager.handleCorruption(singleton(taskId00));
 
-        verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
+        verify(task00).prepareCommit(false);
+        verify(task00).postCommit(true);
+        verify(task00).addPartitionsForOffsetReset(taskId00Partitions);
+        verify(task00).changelogPartitions();
+        verify(task00).closeDirty();
+        verify(task00).revive();
+        verify(tasks).removeTask(task00);
+        verify(tasks).addPendingTasksToInit(Set.of(task00));
+        verify(consumer, never()).commitSync(emptyMap());
     }
 
     @Test
     public void shouldReviveCorruptTasksEvenIfTheyCannotCloseClean() {
-        final ProcessorStateManager stateManager = 
mock(ProcessorStateManager.class);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager) {
-            @Override
-            public void suspend() {
-                super.suspend();
-                throw new RuntimeException("oops");
-            }
-        };
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.task(taskId00)).thenReturn(task00);
+        when(tasks.allTasksPerId()).thenReturn(singletonMap(taskId00, task00));
+        when(tasks.activeTaskIds()).thenReturn(Set.of(taskId00));
 
-        when(consumer.assignment())
-            .thenReturn(assignment)
-            .thenReturn(taskId00Partitions);
-        when(activeTaskCreator.createTasks(any(), 
eq(taskId00Assignment))).thenReturn(singletonList(task00));
+        when(task00.prepareCommit(false)).thenReturn(emptyMap());
+        
when(task00.changelogPartitions()).thenReturn(taskId00ChangelogPartitions);
+        doThrow(new RuntimeException("oops")).when(task00).suspend();
 
-        taskManager.handleAssignment(taskId00Assignment, emptyMap());
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
tp -> assertThat(tp, is(empty()))), is(true));
-        assertThat(task00.state(), is(Task.State.RUNNING));
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
-        task00.setChangelogOffsets(singletonMap(t1p0, 0L));
         taskManager.handleCorruption(singleton(taskId00));
-        assertThat(task00.commitPrepared, is(true));
-        assertThat(task00.state(), is(Task.State.CREATED));
-        assertThat(task00.partitionsForOffsetReset, 
equalTo(taskId00Partitions));
-        assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, 
task00)));
-        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
 
-        verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
+        verify(task00).prepareCommit(false);
+        verify(task00).suspend();
+        verify(task00, never()).postCommit(anyBoolean()); // postCommit is NOT 
called
+        verify(task00).closeDirty();
+        verify(task00).revive();
+        verify(tasks).removeTask(task00);
+        verify(tasks).addPendingTasksToInit(Set.of(task00));
+        verify(task00).addPartitionsForOffsetReset(emptySet());
     }
 
     @Test
@@ -2326,431 +2316,558 @@ public class TaskManagerTest {
 
     @Test
     public void 
shouldCleanAndReviveCorruptedStandbyTasksBeforeCommittingNonCorruptedTasks() {
-        final ProcessorStateManager stateManager = 
mock(ProcessorStateManager.class);
-
-        final StateMachineTask corruptedStandby = new 
StateMachineTask(taskId00, taskId00Partitions, false, stateManager);
-        final StateMachineTask runningNonCorruptedActive = new 
StateMachineTask(taskId01, taskId01Partitions, true, stateManager) {
-            @Override
-            public Map<TopicPartition, OffsetAndMetadata> prepareCommit(final 
boolean clean) {
-                throw new TaskMigratedException("You dropped out of the 
group!", new RuntimeException());
-            }
-        };
-
-        // handleAssignment
-        when(activeTaskCreator.createTasks(any(), eq(taskId01Assignment)))
-            .thenReturn(singleton(runningNonCorruptedActive));
-        
when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singleton(corruptedStandby));
+        final StandbyTask corruptedStandby = standbyTask(taskId00, 
taskId00ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId00Partitions).build();
+        final StreamTask runningNonCorruptedActive = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId01Partitions).build();
 
-        when(consumer.assignment()).thenReturn(assignment);
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.task(taskId00)).thenReturn(corruptedStandby);
+        when(tasks.allTasksPerId()).thenReturn(mkMap(
+            mkEntry(taskId00, corruptedStandby),
+            mkEntry(taskId01, runningNonCorruptedActive)
+        ));
+        when(tasks.activeTaskIds()).thenReturn(Set.of(taskId01));
 
-        taskManager.handleAssignment(taskId01Assignment, taskId00Assignment);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
+        when(runningNonCorruptedActive.commitNeeded()).thenReturn(true);
+        when(runningNonCorruptedActive.prepareCommit(true))
+            .thenThrow(new TaskMigratedException("You dropped out of the 
group!", new RuntimeException()));
 
-        // make sure this will be committed and throw
-        assertThat(runningNonCorruptedActive.state(), is(Task.State.RUNNING));
-        assertThat(corruptedStandby.state(), is(Task.State.RUNNING));
+        
when(corruptedStandby.changelogPartitions()).thenReturn(taskId00ChangelogPartitions);
+        when(corruptedStandby.prepareCommit(false)).thenReturn(emptyMap());
+        doNothing().when(corruptedStandby).suspend();
+        doNothing().when(corruptedStandby).postCommit(anyBoolean());
 
-        runningNonCorruptedActive.setCommitNeeded();
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
-        corruptedStandby.setChangelogOffsets(singletonMap(t1p0, 0L));
         assertThrows(TaskMigratedException.class, () -> 
taskManager.handleCorruption(singleton(taskId00)));
 
+        // verifying the entire task lifecycle
+        final InOrder taskOrder = inOrder(corruptedStandby, 
runningNonCorruptedActive);
+        taskOrder.verify(corruptedStandby).prepareCommit(false);
+        taskOrder.verify(corruptedStandby).suspend();
+        taskOrder.verify(corruptedStandby).postCommit(true);
+        taskOrder.verify(corruptedStandby).closeDirty();
+        taskOrder.verify(corruptedStandby).revive();
+        taskOrder.verify(runningNonCorruptedActive).prepareCommit(true);
 
-        assertThat(corruptedStandby.commitPrepared, is(true));
-        assertThat(corruptedStandby.state(), is(Task.State.CREATED));
-        verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
+        verify(tasks).removeTask(corruptedStandby);
+        verify(tasks).addPendingTasksToInit(Set.of(corruptedStandby));
     }
 
     @Test
     public void shouldNotAttemptToCommitInHandleCorruptedDuringARebalance() {
-        final ProcessorStateManager stateManager = 
mock(ProcessorStateManager.class);
-        when(stateDirectory.listNonEmptyTaskDirectories()).thenReturn(new 
ArrayList<>());
-
-        final StateMachineTask corruptedActive = new 
StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
-
-        // make sure this will attempt to be committed and throw
-        final StateMachineTask uncorruptedActive = new 
StateMachineTask(taskId01, taskId01Partitions, true, stateManager);
-        final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p1, new OffsetAndMetadata(0L, null));
-        uncorruptedActive.setCommitNeeded();
+        final StreamTask corruptedActive = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        // handleAssignment
-        final Map<TaskId, Set<TopicPartition>> firstAssignement = new 
HashMap<>();
-        firstAssignement.putAll(taskId00Assignment);
-        firstAssignement.putAll(taskId01Assignment);
-        when(activeTaskCreator.createTasks(any(), eq(firstAssignement)))
-            .thenReturn(asList(corruptedActive, uncorruptedActive));
+        final StreamTask uncorruptedActive = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        when(consumer.assignment())
-            .thenReturn(assignment)
-            .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.task(taskId00)).thenReturn(corruptedActive);
+        when(tasks.allTasksPerId()).thenReturn(mkMap(
+            mkEntry(taskId00, corruptedActive),
+            mkEntry(taskId01, uncorruptedActive)
+        ));
+        when(tasks.activeTaskIds()).thenReturn(Set.of(taskId00, taskId01));
 
-        uncorruptedActive.setCommittableOffsetsAndMetadata(offsets);
+        when(uncorruptedActive.commitNeeded()).thenReturn(true);
+        when(uncorruptedActive.prepareCommit(true)).thenReturn(emptyMap());
 
-        taskManager.handleAssignment(firstAssignement, emptyMap());
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
+        when(corruptedActive.prepareCommit(false)).thenReturn(emptyMap());
+        doNothing().when(corruptedActive).postCommit(anyBoolean());
 
-        assertThat(uncorruptedActive.state(), is(Task.State.RUNNING));
+        when(consumer.assignment()).thenReturn(taskId00Partitions);
 
-        assertThat(uncorruptedActive.commitPrepared, is(false));
-        assertThat(uncorruptedActive.commitNeeded, is(true));
-        assertThat(uncorruptedActive.commitCompleted, is(false));
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         taskManager.handleRebalanceStart(singleton(topic1));
         assertThat(taskManager.rebalanceInProgress(), is(true));
+
         taskManager.handleCorruption(singleton(taskId00));
 
-        assertThat(uncorruptedActive.commitPrepared, is(false));
-        assertThat(uncorruptedActive.commitNeeded, is(true));
-        assertThat(uncorruptedActive.commitCompleted, is(false));
+        verify(uncorruptedActive, never()).prepareCommit(anyBoolean());
+        verify(uncorruptedActive, never()).postCommit(anyBoolean());
 
-        assertThat(uncorruptedActive.state(), is(State.RUNNING));
+        verify(corruptedActive).changelogPartitions();
+        verify(corruptedActive).postCommit(true);
+        
verify(corruptedActive).addPartitionsForOffsetReset(taskId00Partitions);
+        verify(consumer, never()).commitSync(emptyMap());
     }
 
+    @SuppressWarnings("removal")
     @Test
-    public void 
shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitWithAlos()
 {
-        final ProcessorStateManager stateManager = 
mock(ProcessorStateManager.class);
-
-        final StateMachineTask corruptedActive = new 
StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
-        final StateMachineTask uncorruptedActive = new 
StateMachineTask(taskId01, taskId01Partitions, true, stateManager) {
-            @Override
-            public void markChangelogAsCorrupted(final 
Collection<TopicPartition> partitions) {
-                fail("Should not try to mark changelogs as corrupted for 
uncorrupted task");
-            }
-        };
-        final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p1, new OffsetAndMetadata(0L, null));
-        uncorruptedActive.setCommittableOffsetsAndMetadata(offsets);
-
-        // handleAssignment
-        final Map<TaskId, Set<TopicPartition>> firstAssignment = new 
HashMap<>();
-        firstAssignment.putAll(taskId00Assignment);
-        firstAssignment.putAll(taskId01Assignment);
-        when(activeTaskCreator.createTasks(any(), eq(firstAssignment)))
-            .thenReturn(asList(corruptedActive, uncorruptedActive));
+    public void 
shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringHandleCorruptedWithEOS()
 {
+        final StreamTask corruptedActive = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        when(consumer.assignment())
-            .thenReturn(assignment)
-            .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
+        // this task will time out during commit
+        final StreamTask uncorruptedActive = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        doThrow(new TimeoutException()).when(consumer).commitSync(offsets);
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.task(taskId00)).thenReturn(corruptedActive);
+        when(tasks.allTasksPerId()).thenReturn(mkMap(
+            mkEntry(taskId00, corruptedActive),
+            mkEntry(taskId01, uncorruptedActive)
+        ));
+        when(tasks.activeTaskIds()).thenReturn(Set.of(taskId00, taskId01));
 
-        taskManager.handleAssignment(firstAssignment, emptyMap());
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
+        final StreamsProducer producer = mock(StreamsProducer.class);
+        when(activeTaskCreator.streamsProducer()).thenReturn(producer);
+        final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
+        when(consumer.groupMetadata()).thenReturn(groupMetadata);
+        when(consumer.assignment()).thenReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions));
 
-        assertThat(uncorruptedActive.state(), is(Task.State.RUNNING));
-        assertThat(corruptedActive.state(), is(Task.State.RUNNING));
+        // mock uncorrupted task to indicate that it needs commit and will 
return offsets
+        final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p1, new OffsetAndMetadata(0L, null));
+        
when(tasks.tasks(singleton(taskId01))).thenReturn(Set.of(uncorruptedActive));
+        when(uncorruptedActive.commitNeeded()).thenReturn(true);
+        when(uncorruptedActive.prepareCommit(true)).thenReturn(offsets);
+        when(uncorruptedActive.prepareCommit(false)).thenReturn(emptyMap());
+        
when(uncorruptedActive.changelogPartitions()).thenReturn(taskId01ChangelogPartitions);
+        doNothing().when(uncorruptedActive).suspend();
+        doNothing().when(uncorruptedActive).closeDirty();
+        doNothing().when(uncorruptedActive).revive();
+        
doNothing().when(uncorruptedActive).markChangelogAsCorrupted(taskId01ChangelogPartitions);
+
+        // corrupted task doesn't need commit
+        when(corruptedActive.commitNeeded()).thenReturn(false);
+        when(corruptedActive.prepareCommit(false)).thenReturn(emptyMap());
+        
when(corruptedActive.changelogPartitions()).thenReturn(taskId00ChangelogPartitions);
+        doNothing().when(corruptedActive).suspend();
+        doNothing().when(corruptedActive).postCommit(true);
+        doNothing().when(corruptedActive).closeDirty();
+        doNothing().when(corruptedActive).revive();
 
-        // make sure this will be committed and throw
-        uncorruptedActive.setCommitNeeded();
-        corruptedActive.setChangelogOffsets(singletonMap(t1p0, 0L));
+        doThrow(new 
TimeoutException()).when(producer).commitTransaction(offsets, groupMetadata);
 
-        assertThat(uncorruptedActive.commitPrepared, is(false));
-        assertThat(uncorruptedActive.commitNeeded, is(true));
-        assertThat(uncorruptedActive.commitCompleted, is(false));
-        assertThat(corruptedActive.commitPrepared, is(false));
-        assertThat(corruptedActive.commitNeeded, is(false));
-        assertThat(corruptedActive.commitCompleted, is(false));
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks);
 
         taskManager.handleCorruption(singleton(taskId00));
 
-        assertThat(uncorruptedActive.commitPrepared, is(true));
-        assertThat(uncorruptedActive.commitNeeded, is(false));
-        assertThat(uncorruptedActive.commitCompleted, is(false)); //if not 
corrupted, we should close dirty without committing
-        assertThat(corruptedActive.commitPrepared, is(true));
-        assertThat(corruptedActive.commitNeeded, is(false));
-        assertThat(corruptedActive.commitCompleted, is(true)); //if corrupted, 
should enforce checkpoint with corrupted tasks removed
-
-        assertThat(corruptedActive.state(), is(Task.State.CREATED));
-        assertThat(uncorruptedActive.state(), is(Task.State.CREATED));
-        verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
+        // 1. verify corrupted task was closed dirty and revived
+        final InOrder corruptedOrder = inOrder(corruptedActive, tasks);
+        corruptedOrder.verify(corruptedActive).prepareCommit(false);
+        corruptedOrder.verify(corruptedActive).suspend();
+        corruptedOrder.verify(corruptedActive).postCommit(true);
+        corruptedOrder.verify(corruptedActive).closeDirty();
+        corruptedOrder.verify(tasks).removeTask(corruptedActive);
+        corruptedOrder.verify(corruptedActive).revive();
+        
corruptedOrder.verify(tasks).addPendingTasksToInit(Set.of(corruptedActive));
+
+        // 2. verify uncorrupted task attempted commit, failed with timeout, 
then was closed dirty and revived
+        final InOrder uncorruptedOrder = inOrder(uncorruptedActive, producer, 
tasks);
+        uncorruptedOrder.verify(uncorruptedActive).prepareCommit(true);
+        uncorruptedOrder.verify(producer).commitTransaction(offsets, 
groupMetadata); // tries to commit, throws TimeoutException
+        uncorruptedOrder.verify(uncorruptedActive).suspend();
+        uncorruptedOrder.verify(uncorruptedActive).postCommit(true);
+        uncorruptedOrder.verify(uncorruptedActive).closeDirty();
+        uncorruptedOrder.verify(tasks).removeTask(uncorruptedActive);
+        uncorruptedOrder.verify(uncorruptedActive).revive();
+        
uncorruptedOrder.verify(tasks).addPendingTasksToInit(Set.of(uncorruptedActive));
+
+        // verify both tasks had their input partitions reset
+        
verify(corruptedActive).addPartitionsForOffsetReset(taskId00Partitions);
+        
verify(uncorruptedActive).addPartitionsForOffsetReset(taskId01Partitions);
     }
 
-    @SuppressWarnings("removal")
     @Test
-    public void 
shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringHandleCorruptedWithEOS()
 {
-        final TaskManager taskManager = 
setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, null, 
false);
-        final StreamsProducer producer = mock(StreamsProducer.class);
-        when(activeTaskCreator.streamsProducer()).thenReturn(producer);
-        final ProcessorStateManager stateManager = 
mock(ProcessorStateManager.class);
-
-        final AtomicBoolean corruptedTaskChangelogMarkedAsCorrupted = new 
AtomicBoolean(false);
-        final StateMachineTask corruptedActiveTask = new 
StateMachineTask(taskId00, taskId00Partitions, true, stateManager) {
-            @Override
-            public void markChangelogAsCorrupted(final 
Collection<TopicPartition> partitions) {
-                super.markChangelogAsCorrupted(partitions);
-                corruptedTaskChangelogMarkedAsCorrupted.set(true);
-            }
-        };
-
-        final AtomicBoolean uncorruptedTaskChangelogMarkedAsCorrupted = new 
AtomicBoolean(false);
-        final StateMachineTask uncorruptedActiveTask = new 
StateMachineTask(taskId01, taskId01Partitions, true, stateManager) {
-            @Override
-            public void markChangelogAsCorrupted(final 
Collection<TopicPartition> partitions) {
-                super.markChangelogAsCorrupted(partitions);
-                uncorruptedTaskChangelogMarkedAsCorrupted.set(true);
-            }
-        };
-        final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p1, new OffsetAndMetadata(0L, null));
-        uncorruptedActiveTask.setCommittableOffsetsAndMetadata(offsets);
-
-        // handleAssignment
-        final Map<TaskId, Set<TopicPartition>> firstAssignment = new 
HashMap<>();
-        firstAssignment.putAll(taskId00Assignment);
-        firstAssignment.putAll(taskId01Assignment);
-        when(activeTaskCreator.createTasks(any(), eq(firstAssignment)))
-            .thenReturn(asList(corruptedActiveTask, uncorruptedActiveTask));
-
-        when(consumer.assignment())
-            .thenReturn(assignment)
-            .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
-
-        final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
-        when(consumer.groupMetadata()).thenReturn(groupMetadata);
-
-        doThrow(new 
TimeoutException()).when(producer).commitTransaction(offsets, groupMetadata);
+    public void 
shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitWithAlos()
 {
+        final StreamTask corruptedActive = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        taskManager.handleAssignment(firstAssignment, emptyMap());
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
+        // this task will time out during commit
+        final StreamTask uncorruptedActive = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        assertThat(uncorruptedActiveTask.state(), is(Task.State.RUNNING));
-        assertThat(corruptedActiveTask.state(), is(Task.State.RUNNING));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.task(taskId00)).thenReturn(corruptedActive);
+        when(tasks.allTasksPerId()).thenReturn(mkMap(
+            mkEntry(taskId00, corruptedActive),
+            mkEntry(taskId01, uncorruptedActive)
+        ));
+        when(tasks.activeTaskIds()).thenReturn(Set.of(taskId00, taskId01));
+        when(tasks.activeTasks()).thenReturn(Set.of(corruptedActive, 
uncorruptedActive));
 
-        // make sure this will be committed and throw
-        uncorruptedActiveTask.setCommitNeeded();
+        // we need to mock uncorrupted task to indicate that it needs commit 
and will return offsets
+        final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p1, new OffsetAndMetadata(0L, null));
+        when(uncorruptedActive.commitNeeded()).thenReturn(true);
+        when(uncorruptedActive.prepareCommit(true)).thenReturn(offsets);
+        
when(uncorruptedActive.changelogPartitions()).thenReturn(taskId01ChangelogPartitions);
+        doNothing().when(uncorruptedActive).suspend();
+        doNothing().when(uncorruptedActive).closeDirty();
+        doNothing().when(uncorruptedActive).revive();
+
+        // corrupted task doesn't need commit
+        when(corruptedActive.commitNeeded()).thenReturn(false);
+        when(corruptedActive.prepareCommit(false)).thenReturn(emptyMap());
+        
when(corruptedActive.changelogPartitions()).thenReturn(taskId00ChangelogPartitions);
+        doNothing().when(corruptedActive).suspend();
+        doNothing().when(corruptedActive).postCommit(anyBoolean());
+        doNothing().when(corruptedActive).closeDirty();
+        doNothing().when(corruptedActive).revive();
 
-        final Map<TopicPartition, Long> corruptedActiveTaskChangelogOffsets = 
singletonMap(t1p0changelog, 0L);
-        
corruptedActiveTask.setChangelogOffsets(corruptedActiveTaskChangelogOffsets);
-        final Map<TopicPartition, Long> uncorruptedActiveTaskChangelogOffsets 
= singletonMap(t1p1changelog, 0L);
-        
uncorruptedActiveTask.setChangelogOffsets(uncorruptedActiveTaskChangelogOffsets);
+        doThrow(new TimeoutException()).when(consumer).commitSync(offsets);
+        when(consumer.assignment()).thenReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions));
 
-        assertThat(uncorruptedActiveTask.commitPrepared, is(false));
-        assertThat(uncorruptedActiveTask.commitNeeded, is(true));
-        assertThat(uncorruptedActiveTask.commitCompleted, is(false));
-        assertThat(corruptedActiveTask.commitPrepared, is(false));
-        assertThat(corruptedActiveTask.commitNeeded, is(false));
-        assertThat(corruptedActiveTask.commitCompleted, is(false));
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         taskManager.handleCorruption(singleton(taskId00));
 
-        assertThat(uncorruptedActiveTask.commitPrepared, is(true));
-        assertThat(uncorruptedActiveTask.commitNeeded, is(false));
-        assertThat(uncorruptedActiveTask.commitCompleted, is(true)); //if 
corrupted due to timeout on commit, should enforce checkpoint with corrupted 
tasks removed
-        assertThat(corruptedActiveTask.commitPrepared, is(true));
-        assertThat(corruptedActiveTask.commitNeeded, is(false));
-        assertThat(corruptedActiveTask.commitCompleted, is(true)); //if 
corrupted, should enforce checkpoint with corrupted tasks removed
-
-        assertThat(corruptedActiveTask.state(), is(Task.State.CREATED));
-        assertThat(uncorruptedActiveTask.state(), is(Task.State.CREATED));
-        assertThat(corruptedTaskChangelogMarkedAsCorrupted.get(), is(true));
-        assertThat(uncorruptedTaskChangelogMarkedAsCorrupted.get(), is(true));
-        
verify(stateManager).markChangelogAsCorrupted(taskId00ChangelogPartitions);
-        
verify(stateManager).markChangelogAsCorrupted(taskId01ChangelogPartitions);
+        // 1. verify corrupted task was closed dirty and revived
+        final InOrder corruptedOrder = inOrder(corruptedActive, tasks);
+        corruptedOrder.verify(corruptedActive).prepareCommit(false);
+        corruptedOrder.verify(corruptedActive).suspend();
+        corruptedOrder.verify(corruptedActive).postCommit(true);
+        corruptedOrder.verify(corruptedActive).closeDirty();
+        corruptedOrder.verify(tasks).removeTask(corruptedActive);
+        corruptedOrder.verify(corruptedActive).revive();
+        
corruptedOrder.verify(tasks).addPendingTasksToInit(Set.of(corruptedActive));
+
+        // 2. verify uncorrupted task attempted commit, failed with timeout, 
then was closed dirty and revived
+        final InOrder uncorruptedOrder = inOrder(uncorruptedActive, consumer, 
tasks);
+        uncorruptedOrder.verify(uncorruptedActive).prepareCommit(true);
+        uncorruptedOrder.verify(consumer).commitSync(offsets); // attempt 
commit, throws TimeoutException
+        uncorruptedOrder.verify(uncorruptedActive).prepareCommit(false);
+        uncorruptedOrder.verify(uncorruptedActive).suspend();
+        uncorruptedOrder.verify(uncorruptedActive).closeDirty();
+        uncorruptedOrder.verify(tasks).removeTask(uncorruptedActive);
+        uncorruptedOrder.verify(uncorruptedActive).revive();
+        
uncorruptedOrder.verify(tasks).addPendingTasksToInit(Set.of(uncorruptedActive));
+
+        // verify both tasks had their input partitions reset
+        
verify(corruptedActive).addPartitionsForOffsetReset(taskId00Partitions);
+        
verify(uncorruptedActive).addPartitionsForOffsetReset(taskId01Partitions);
     }
 
     @Test
     public void 
shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringRevocationWithAlos()
 {
-        final StateMachineTask revokedActiveTask = new 
StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
-        final Map<TopicPartition, OffsetAndMetadata> offsets00 = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
-        revokedActiveTask.setCommittableOffsetsAndMetadata(offsets00);
-        revokedActiveTask.setCommitNeeded();
+        // task being revoked - needs commit
+        final StreamTask revokedActiveTask = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        final StateMachineTask unrevokedActiveTaskWithCommitNeeded = new 
StateMachineTask(taskId01, taskId01Partitions, true, stateManager) {
-            @Override
-            public void markChangelogAsCorrupted(final 
Collection<TopicPartition> partitions) {
-                fail("Should not try to mark changelogs as corrupted for 
uncorrupted task");
-            }
-        };
-        final Map<TopicPartition, OffsetAndMetadata> offsets01 = 
singletonMap(t1p1, new OffsetAndMetadata(1L, null));
-        
unrevokedActiveTaskWithCommitNeeded.setCommittableOffsetsAndMetadata(offsets01);
-        unrevokedActiveTaskWithCommitNeeded.setCommitNeeded();
+        // unrevoked task that needs commit - this will also be affected by 
timeout
+        final StreamTask unrevokedActiveTaskWithCommit = 
statefulTask(taskId01, taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        final StateMachineTask unrevokedActiveTaskWithoutCommitNeeded = new 
StateMachineTask(taskId02, taskId02Partitions, true, stateManager);
+        // unrevoked task without commit needed - this should stay RUNNING
+        final StreamTask unrevokedActiveTaskWithoutCommit = 
statefulTask(taskId02, taskId02ChangelogPartitions)
+            .withInputPartitions(taskId02Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        final Map<TopicPartition, OffsetAndMetadata> expectedCommittedOffsets 
= new HashMap<>();
-        expectedCommittedOffsets.putAll(offsets00);
-        expectedCommittedOffsets.putAll(offsets01);
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(Set.of(revokedActiveTask, 
unrevokedActiveTaskWithCommit, unrevokedActiveTaskWithoutCommit));
 
-        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
-            mkEntry(taskId00, taskId00Partitions),
-            mkEntry(taskId01, taskId01Partitions),
-            mkEntry(taskId02, taskId02Partitions)
-        );
+        when(consumer.assignment()).thenReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions, taskId02Partitions));
 
-        when(consumer.assignment())
-            .thenReturn(assignment)
-            .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions, taskId02Partitions));
+        // revoked task needs commit
+        final Map<TopicPartition, OffsetAndMetadata> revokedTaskOffsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        when(revokedActiveTask.commitNeeded()).thenReturn(true);
+        
when(revokedActiveTask.prepareCommit(true)).thenReturn(revokedTaskOffsets);
+        
when(revokedActiveTask.changelogPartitions()).thenReturn(taskId00ChangelogPartitions);
+        doNothing().when(revokedActiveTask).suspend();
+        doNothing().when(revokedActiveTask).closeDirty();
+        doNothing().when(revokedActiveTask).revive();
 
-        when(activeTaskCreator.createTasks(any(), eq(assignmentActive)))
-            .thenReturn(asList(revokedActiveTask, 
unrevokedActiveTaskWithCommitNeeded, unrevokedActiveTaskWithoutCommitNeeded));
+        // unrevoked task with commit also takes part in commit
+        final Map<TopicPartition, OffsetAndMetadata> unrevokedTaskOffsets = 
singletonMap(t1p1, new OffsetAndMetadata(1L, null));
+        when(unrevokedActiveTaskWithCommit.commitNeeded()).thenReturn(true);
+        
when(unrevokedActiveTaskWithCommit.prepareCommit(true)).thenReturn(unrevokedTaskOffsets);
+        
when(unrevokedActiveTaskWithCommit.changelogPartitions()).thenReturn(taskId01ChangelogPartitions);
+        doNothing().when(unrevokedActiveTaskWithCommit).suspend();
+        doNothing().when(unrevokedActiveTaskWithCommit).closeDirty();
+        doNothing().when(unrevokedActiveTaskWithCommit).revive();
+
+        // unrevoked task without commit needed
+        
when(unrevokedActiveTaskWithoutCommit.commitNeeded()).thenReturn(false);
 
+        // mock timeout during commit - all offsets from tasks needing commit
+        final Map<TopicPartition, OffsetAndMetadata> expectedCommittedOffsets 
= new HashMap<>();
+        expectedCommittedOffsets.putAll(revokedTaskOffsets);
+        expectedCommittedOffsets.putAll(unrevokedTaskOffsets);
         doThrow(new 
TimeoutException()).when(consumer).commitSync(expectedCommittedOffsets);
 
-        taskManager.handleAssignment(assignmentActive, emptyMap());
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
-        assertThat(revokedActiveTask.state(), is(Task.State.RUNNING));
-        assertThat(unrevokedActiveTaskWithCommitNeeded.state(), 
is(State.RUNNING));
-        assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), 
is(Task.State.RUNNING));
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         taskManager.handleRevocation(taskId00Partitions);
 
-        assertThat(revokedActiveTask.state(), is(State.SUSPENDED));
-        assertThat(unrevokedActiveTaskWithCommitNeeded.state(), 
is(State.CREATED));
-        assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), 
is(State.RUNNING));
+        // 1. verify that the revoked task was suspended, closed dirty, and 
revived
+        final InOrder revokedOrder = inOrder(revokedActiveTask, tasks);
+        revokedOrder.verify(revokedActiveTask).prepareCommit(true);
+        revokedOrder.verify(revokedActiveTask).suspend();
+        revokedOrder.verify(revokedActiveTask).closeDirty();
+        revokedOrder.verify(tasks).removeTask(revokedActiveTask);
+        revokedOrder.verify(revokedActiveTask).revive();
+        revokedOrder.verify(tasks).addPendingTasksToInit(argThat(set -> 
set.contains(revokedActiveTask)));
+
+        // 2. verify that the unrevoked task with commit also tried to commit 
and was closed dirty due to timeout
+        final InOrder unrevokedOrder = inOrder(unrevokedActiveTaskWithCommit, 
consumer, tasks);
+        
unrevokedOrder.verify(unrevokedActiveTaskWithCommit).prepareCommit(true);
+        unrevokedOrder.verify(consumer).commitSync(expectedCommittedOffsets); 
// timeout thrown here
+        unrevokedOrder.verify(unrevokedActiveTaskWithCommit).suspend();
+        unrevokedOrder.verify(unrevokedActiveTaskWithCommit).closeDirty();
+        unrevokedOrder.verify(tasks).removeTask(unrevokedActiveTaskWithCommit);
+        unrevokedOrder.verify(unrevokedActiveTaskWithCommit).revive();
+        unrevokedOrder.verify(tasks).addPendingTasksToInit(argThat(set -> 
set.contains(unrevokedActiveTaskWithCommit)));
+
+        // 3. verify that the unrevoked task without commit needed was not 
affected
+        verify(unrevokedActiveTaskWithoutCommit, 
never()).prepareCommit(anyBoolean());
+        verify(unrevokedActiveTaskWithoutCommit, never()).suspend();
+        verify(unrevokedActiveTaskWithoutCommit, never()).closeDirty();
+
+        // input partitions were reset for affected tasks
+        
verify(revokedActiveTask).addPartitionsForOffsetReset(taskId00Partitions);
+        
verify(unrevokedActiveTaskWithCommit).addPartitionsForOffsetReset(taskId01Partitions);
+        verify(unrevokedActiveTaskWithoutCommit, 
never()).addPartitionsForOffsetReset(any());
     }
 
     @SuppressWarnings("removal")
     @Test
     public void 
shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringRevocationWithEOS()
 {
-        final TaskManager taskManager = 
setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, null, 
false);
-        final StreamsProducer producer = mock(StreamsProducer.class);
-        when(activeTaskCreator.streamsProducer()).thenReturn(producer);
-        final ProcessorStateManager stateManager = 
mock(ProcessorStateManager.class);
-
-        final StateMachineTask revokedActiveTask = new 
StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
-        final Map<TopicPartition, OffsetAndMetadata> revokedActiveTaskOffsets 
= singletonMap(t1p0, new OffsetAndMetadata(0L, null));
-        
revokedActiveTask.setCommittableOffsetsAndMetadata(revokedActiveTaskOffsets);
-        revokedActiveTask.setCommitNeeded();
-
-        final AtomicBoolean unrevokedTaskChangelogMarkedAsCorrupted = new 
AtomicBoolean(false);
-        final StateMachineTask unrevokedActiveTask = new 
StateMachineTask(taskId01, taskId01Partitions, true, stateManager) {
-            @Override
-            public void markChangelogAsCorrupted(final 
Collection<TopicPartition> partitions) {
-                super.markChangelogAsCorrupted(partitions);
-                unrevokedTaskChangelogMarkedAsCorrupted.set(true);
-            }
-        };
-        final Map<TopicPartition, OffsetAndMetadata> unrevokedTaskOffsets = 
singletonMap(t1p1, new OffsetAndMetadata(1L, null));
-        
unrevokedActiveTask.setCommittableOffsetsAndMetadata(unrevokedTaskOffsets);
-        unrevokedActiveTask.setCommitNeeded();
-
-        final StateMachineTask unrevokedActiveTaskWithoutCommitNeeded = new 
StateMachineTask(taskId02, taskId02Partitions, true, stateManager);
-
-        final Map<TopicPartition, OffsetAndMetadata> expectedCommittedOffsets 
= new HashMap<>();
-        expectedCommittedOffsets.putAll(revokedActiveTaskOffsets);
-        expectedCommittedOffsets.putAll(unrevokedTaskOffsets);
+        // task being revoked - needs commit
+        final StreamTask revokedActiveTask = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
-            mkEntry(taskId00, taskId00Partitions),
-            mkEntry(taskId01, taskId01Partitions),
-            mkEntry(taskId02, taskId02Partitions)
-            );
+        // unrevoked task that needs commit - this will also be affected by 
timeout
+        final StreamTask unrevokedActiveTaskWithCommit = 
statefulTask(taskId01, taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        when(consumer.assignment())
-            .thenReturn(assignment)
-            .thenReturn(union(HashSet::new, taskId00Partitions, 
taskId01Partitions, taskId02Partitions));
+        // unrevoked task without commit needed - this should remain RUNNING
+        final StreamTask unrevokedActiveTaskWithoutCommit = 
statefulTask(taskId02, taskId02ChangelogPartitions)
+            .withInputPartitions(taskId02Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        when(activeTaskCreator.createTasks(any(), eq(assignmentActive)))
-            .thenReturn(asList(revokedActiveTask, unrevokedActiveTask, 
unrevokedActiveTaskWithoutCommitNeeded));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(Set.of(revokedActiveTask, 
unrevokedActiveTaskWithCommit, unrevokedActiveTaskWithoutCommit));
+        when(tasks.tasks(Set.of(taskId00, 
taskId01))).thenReturn(Set.of(revokedActiveTask, 
unrevokedActiveTaskWithCommit));
 
+        final StreamsProducer producer = mock(StreamsProducer.class);
+        when(activeTaskCreator.streamsProducer()).thenReturn(producer);
         final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
         when(consumer.groupMetadata()).thenReturn(groupMetadata);
-
+        when(consumer.assignment()).thenReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions, taskId02Partitions));
+
+        // revoked task needs commit
+        final Map<TopicPartition, OffsetAndMetadata> revokedTaskOffsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        when(revokedActiveTask.commitNeeded()).thenReturn(true);
+        
when(revokedActiveTask.prepareCommit(true)).thenReturn(revokedTaskOffsets);
+        
when(revokedActiveTask.changelogPartitions()).thenReturn(taskId00ChangelogPartitions);
+        doNothing().when(revokedActiveTask).suspend();
+        doNothing().when(revokedActiveTask).closeDirty();
+        doNothing().when(revokedActiveTask).revive();
+        
doNothing().when(revokedActiveTask).markChangelogAsCorrupted(taskId00ChangelogPartitions);
+
+        // unrevoked task with commit also takes part in EOS-v2 commit
+        final Map<TopicPartition, OffsetAndMetadata> unrevokedTaskOffsets = 
singletonMap(t1p1, new OffsetAndMetadata(1L, null));
+        when(unrevokedActiveTaskWithCommit.commitNeeded()).thenReturn(true);
+        
when(unrevokedActiveTaskWithCommit.prepareCommit(true)).thenReturn(unrevokedTaskOffsets);
+        
when(unrevokedActiveTaskWithCommit.changelogPartitions()).thenReturn(taskId01ChangelogPartitions);
+        doNothing().when(unrevokedActiveTaskWithCommit).suspend();
+        doNothing().when(unrevokedActiveTaskWithCommit).closeDirty();
+        doNothing().when(unrevokedActiveTaskWithCommit).revive();
+        
doNothing().when(unrevokedActiveTaskWithCommit).markChangelogAsCorrupted(taskId01ChangelogPartitions);
+
+        // unrevoked task without commit needed
+        
when(unrevokedActiveTaskWithoutCommit.commitNeeded()).thenReturn(false);
+
+        // mock timeout during commit - all offsets from tasks needing commit
+        final Map<TopicPartition, OffsetAndMetadata> expectedCommittedOffsets 
= new HashMap<>();
+        expectedCommittedOffsets.putAll(revokedTaskOffsets);
+        expectedCommittedOffsets.putAll(unrevokedTaskOffsets);
         doThrow(new 
TimeoutException()).when(producer).commitTransaction(expectedCommittedOffsets, 
groupMetadata);
 
-        taskManager.handleAssignment(assignmentActive, emptyMap());
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
-        assertThat(revokedActiveTask.state(), is(Task.State.RUNNING));
-        assertThat(unrevokedActiveTask.state(), is(Task.State.RUNNING));
-        assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), 
is(State.RUNNING));
-
-        final Map<TopicPartition, Long> revokedActiveTaskChangelogOffsets = 
singletonMap(t1p0changelog, 0L);
-        
revokedActiveTask.setChangelogOffsets(revokedActiveTaskChangelogOffsets);
-        final Map<TopicPartition, Long> unrevokedActiveTaskChangelogOffsets = 
singletonMap(t1p1changelog, 0L);
-        
unrevokedActiveTask.setChangelogOffsets(unrevokedActiveTaskChangelogOffsets);
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks);
 
         taskManager.handleRevocation(taskId00Partitions);
 
-        assertThat(unrevokedTaskChangelogMarkedAsCorrupted.get(), is(true));
-        assertThat(revokedActiveTask.state(), is(State.SUSPENDED));
-        assertThat(unrevokedActiveTask.state(), is(State.CREATED));
-        assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), 
is(State.RUNNING));
-        
verify(stateManager).markChangelogAsCorrupted(taskId00ChangelogPartitions);
-        
verify(stateManager).markChangelogAsCorrupted(taskId01ChangelogPartitions);
+        // 1. verify that the revoked task was suspended, closed dirty, and 
revived
+        final InOrder revokedOrder = inOrder(revokedActiveTask, tasks);
+        revokedOrder.verify(revokedActiveTask).prepareCommit(true);
+        revokedOrder.verify(revokedActiveTask).suspend();
+        revokedOrder.verify(revokedActiveTask).closeDirty();
+        revokedOrder.verify(tasks).removeTask(revokedActiveTask);
+        revokedOrder.verify(revokedActiveTask).revive();
+        revokedOrder.verify(tasks).addPendingTasksToInit(argThat(set -> 
set.contains(revokedActiveTask)));
+
+        // 2. verify that the unrevoked task with commit also tried to commit 
and was closed dirty due to timeout
+        final InOrder unrevokedOrder = inOrder(unrevokedActiveTaskWithCommit, 
producer, tasks);
+        
unrevokedOrder.verify(unrevokedActiveTaskWithCommit).prepareCommit(true);
+        
unrevokedOrder.verify(producer).commitTransaction(expectedCommittedOffsets, 
groupMetadata); // timeout thrown here
+        unrevokedOrder.verify(unrevokedActiveTaskWithCommit).suspend();
+        unrevokedOrder.verify(unrevokedActiveTaskWithCommit).closeDirty();
+        unrevokedOrder.verify(tasks).removeTask(unrevokedActiveTaskWithCommit);
+        unrevokedOrder.verify(unrevokedActiveTaskWithCommit).revive();
+        unrevokedOrder.verify(tasks).addPendingTasksToInit(argThat(set -> 
set.contains(unrevokedActiveTaskWithCommit)));
+
+        // 3. verify that the unrevoked task without commit needed was not 
affected
+        verify(unrevokedActiveTaskWithoutCommit, 
never()).prepareCommit(anyBoolean());
+        verify(unrevokedActiveTaskWithoutCommit, never()).suspend();
+        verify(unrevokedActiveTaskWithoutCommit, never()).closeDirty();
+
+        // verify input partitions were reset for affected tasks
+        
verify(revokedActiveTask).addPartitionsForOffsetReset(taskId00Partitions);
+        
verify(unrevokedActiveTaskWithCommit).addPartitionsForOffsetReset(taskId01Partitions);
+        verify(unrevokedActiveTaskWithoutCommit, 
never()).addPartitionsForOffsetReset(any());
     }
 
     @Test
     public void shouldCloseStandbyUnassignedTasksWhenCreatingNewTasks() {
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
false, stateManager);
+        final StandbyTask task00 = standbyTask(taskId00, 
taskId00ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId00Partitions)
+            .build();
 
-        when(consumer.assignment()).thenReturn(assignment);
-        
when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singletonList(task00));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.drainPendingTasksToInit()).thenReturn(emptySet());
 
-        taskManager.handleAssignment(emptyMap(), taskId00Assignment);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
-        assertThat(task00.state(), is(Task.State.RUNNING));
+        taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
+
+        when(stateUpdater.tasks()).thenReturn(Set.of(task00));
+
+        // mock future for removing task from StateUpdater
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
+        when(stateUpdater.remove(task00.id())).thenReturn(future);
+        future.complete(new StateUpdater.RemovedTaskResult(task00));
 
         taskManager.handleAssignment(emptyMap(), emptyMap());
-        assertThat(task00.state(), is(Task.State.CLOSED));
-        assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
-        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+
+        verify(stateUpdater).remove(task00.id());
+        verify(task00).suspend();
+        verify(task00).closeClean();
+
+        verify(activeTaskCreator).createTasks(any(), eq(emptyMap()));
+        verify(standbyTaskCreator).createTasks(emptyMap());
     }
 
     @Test
     public void shouldAddNonResumedSuspendedTasks() {
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
true, stateManager);
-        final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, 
false, stateManager);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
+        final StandbyTask task01 = standbyTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        when(consumer.assignment()).thenReturn(assignment);
-        when(activeTaskCreator.createTasks(any(), 
eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        
when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
 
-        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
-        assertThat(task00.state(), is(Task.State.RUNNING));
-        assertThat(task01.state(), is(Task.State.RUNNING));
+        when(tasks.allNonFailedTasks()).thenReturn(Set.of(task00));
+
+        when(tasks.drainPendingTasksToInit()).thenReturn(emptySet());
+        when(tasks.hasPendingTasksToInit()).thenReturn(false);
+
+        taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
+
+        when(stateUpdater.tasks()).thenReturn(Set.of(task01));
+        when(stateUpdater.restoresActiveTasks()).thenReturn(false);
+        when(stateUpdater.hasExceptionsAndFailedTasks()).thenReturn(false);
 
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
-        assertThat(task00.state(), is(Task.State.RUNNING));
-        assertThat(task01.state(), is(Task.State.RUNNING));
 
-        // expect these calls twice (because we're going to 
tryToCompleteRestoration twice)
+        // checkStateUpdater should return true (all tasks ready, no pending 
work)
+        assertTrue(taskManager.checkStateUpdater(time.milliseconds(), 
noOpResetter));
+
+        verify(stateUpdater, never()).add(any(Task.class));
         verify(activeTaskCreator).createTasks(any(), eq(emptyMap()));
-        verify(consumer, times(2)).assignment();
-        verify(consumer, times(2)).resume(assignment);
+        verify(standbyTaskCreator).createTasks(emptyMap());
+
+        // verify idempotence
+        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
+        assertTrue(taskManager.checkStateUpdater(time.milliseconds(), 
noOpResetter));
+        verify(stateUpdater, never()).add(any(Task.class));
     }
 
     @Test
     public void shouldUpdateInputPartitionsAfterRebalance() {
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
true, stateManager);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        when(consumer.assignment()).thenReturn(assignment);
-        when(activeTaskCreator.createTasks(any(), 
eq(taskId00Assignment))).thenReturn(singletonList(task00));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final Set<TopicPartition> newPartitionsSet = Set.of(t1p1);
 
-        taskManager.handleAssignment(taskId00Assignment, emptyMap());
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
-        assertThat(task00.state(), is(Task.State.RUNNING));
+        when(tasks.allNonFailedTasks()).thenReturn(Set.of(task00));
+        when(tasks.drainPendingTasksToInit()).thenReturn(emptySet());
+        when(tasks.hasPendingTasksToInit()).thenReturn(false);
+        when(tasks.updateActiveTaskInputPartitions(task00, 
newPartitionsSet)).thenReturn(true);
+
+        taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
+
+        when(stateUpdater.tasks()).thenReturn(emptySet());
+        when(stateUpdater.restoresActiveTasks()).thenReturn(false);
+        when(stateUpdater.hasExceptionsAndFailedTasks()).thenReturn(false);
 
-        final Set<TopicPartition> newPartitionsSet = Set.of(t1p1);
         final Map<TaskId, Set<TopicPartition>> taskIdSetMap = 
singletonMap(taskId00, newPartitionsSet);
         taskManager.handleAssignment(taskIdSetMap, emptyMap());
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
+
+        verify(task00).updateInputPartitions(eq(newPartitionsSet), any());
+        assertTrue(taskManager.checkStateUpdater(time.milliseconds(), 
noOpResetter));
         assertThat(task00.state(), is(Task.State.RUNNING));
-        assertEquals(newPartitionsSet, task00.inputPartitions());
-        // expect these calls twice (because we're going to 
tryToCompleteRestoration twice)
-        verify(consumer, times(2)).resume(assignment);
-        verify(consumer, times(2)).assignment();
         verify(activeTaskCreator).createTasks(any(), eq(emptyMap()));
+        verify(standbyTaskCreator).createTasks(emptyMap());
     }
 
     @Test
     public void shouldAddNewActiveTasks() {
+        // task in created state
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .inState(State.CREATED)
+            .withInputPartitions(taskId00Partitions)
+            .build();
+
         final Map<TaskId, Set<TopicPartition>> assignment = taskId00Assignment;
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
true, stateManager);
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
+        // first, we need to handle assignment -- creates tasks and adds to 
pending initialization
         when(activeTaskCreator.createTasks(any(), 
eq(assignment))).thenReturn(singletonList(task00));
 
         taskManager.handleAssignment(assignment, emptyMap());
 
-        assertThat(task00.state(), is(Task.State.CREATED));
+        verify(tasks).addPendingTasksToInit(singletonList(task00));
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter 
-> { });
+        // next, drain pending tasks, initialize them, and then add to 
stateupdater
+        when(tasks.drainPendingTasksToInit()).thenReturn(Set.of(task00));
 
-        assertThat(task00.state(), is(Task.State.RUNNING));
-        assertThat(taskManager.activeTaskMap(), 
Matchers.equalTo(singletonMap(taskId00, task00)));
-        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
-        verify(changeLogReader).enforceRestoreActive();
-        verify(consumer).assignment();
-        verify(consumer).resume(eq(emptySet()));
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
+
+        verify(task00).initializeIfNeeded();
+        verify(stateUpdater).add(task00);
+
+        // last, drain the restored tasks from stateupdater and transition to 
running
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
+        
when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(Set.of(task00));
+
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
+
+        verifyTransitionToRunningOfRestoredTask(Set.of(task00), tasks);
     }
 
     @Test
@@ -2842,70 +2959,84 @@ public class TaskManagerTest {
     @SuppressWarnings("removal")
     @Test
     public void 
shouldCommitAllActiveTasksThatNeedCommittingOnHandleRevocationWithEosV2() {
+        // task being revoked, needs commit
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
+
+        // unrevoked task that needs commit, this should also be committed 
with EOS-v2
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
+
+        // unrevoked task that doesn't need commit, should not be committed
+        final StreamTask task02 = statefulTask(taskId02, 
taskId02ChangelogPartitions)
+            .withInputPartitions(taskId02Partitions)
+            .inState(State.RUNNING)
+            .build();
+
+        // standby task should not be committed
+        final StandbyTask task10 = standbyTask(taskId10, emptySet())
+            .withInputPartitions(taskId10Partitions)
+            .inState(State.RUNNING)
+            .build();
+
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+
+        when(tasks.allTasks()).thenReturn(Set.of(task00, task01, task02, 
task10));
+
         final StreamsProducer producer = mock(StreamsProducer.class);
-        final TaskManager taskManager = 
setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, null, 
false);
+        when(activeTaskCreator.streamsProducer()).thenReturn(producer);
+        final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
+        when(consumer.groupMetadata()).thenReturn(groupMetadata);
 
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
         final Map<TopicPartition, OffsetAndMetadata> offsets00 = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
-        task00.setCommittableOffsetsAndMetadata(offsets00);
-        task00.setCommitNeeded();
+        when(task00.commitNeeded()).thenReturn(true);
+        when(task00.prepareCommit(true)).thenReturn(offsets00);
+        doNothing().when(task00).postCommit(anyBoolean());
+        doNothing().when(task00).suspend();
 
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
         final Map<TopicPartition, OffsetAndMetadata> offsets01 = 
singletonMap(t1p1, new OffsetAndMetadata(1L, null));
-        task01.setCommittableOffsetsAndMetadata(offsets01);
-        task01.setCommitNeeded();
+        when(task01.commitNeeded()).thenReturn(true);
+        when(task01.prepareCommit(true)).thenReturn(offsets01);
+        doNothing().when(task01).postCommit(anyBoolean());
 
-        final StateMachineTask task02 = new StateMachineTask(taskId02, 
taskId02Partitions, true, stateManager);
-        final Map<TopicPartition, OffsetAndMetadata> offsets02 = 
singletonMap(t1p2, new OffsetAndMetadata(2L, null));
-        task02.setCommittableOffsetsAndMetadata(offsets02);
+        // task02 does not need commit
+        when(task02.commitNeeded()).thenReturn(false);
 
-        final StateMachineTask task10 = new StateMachineTask(taskId10, 
taskId10Partitions, false, stateManager);
+        // standby task should not take part in commit
+        when(task10.commitNeeded()).thenReturn(false);
 
+        // expected committed offsets, only task00 and task01 (both need 
commit)
         final Map<TopicPartition, OffsetAndMetadata> expectedCommittedOffsets 
= new HashMap<>();
         expectedCommittedOffsets.putAll(offsets00);
         expectedCommittedOffsets.putAll(offsets01);
 
-        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
-            mkEntry(taskId00, taskId00Partitions),
-            mkEntry(taskId01, taskId01Partitions),
-            mkEntry(taskId02, taskId02Partitions)
-        );
-
-        final Map<TaskId, Set<TopicPartition>> assignmentStandby = mkMap(
-            mkEntry(taskId10, taskId10Partitions)
-        );
-        when(consumer.assignment()).thenReturn(assignment);
-
-        when(activeTaskCreator.createTasks(any(), eq(assignmentActive)))
-            .thenReturn(asList(task00, task01, task02));
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks);
 
-        when(activeTaskCreator.streamsProducer()).thenReturn(producer);
-        when(standbyTaskCreator.createTasks(assignmentStandby))
-            .thenReturn(singletonList(task10));
-
-        final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
-        when(consumer.groupMetadata()).thenReturn(groupMetadata);
+        taskManager.handleRevocation(taskId00Partitions);
 
-        task00.committedOffsets();
-        task01.committedOffsets();
-        task02.committedOffsets();
-        task10.committedOffsets();
+        // Verify the commit transaction was called with offsets from task00 
and task01
+        verify(producer).commitTransaction(expectedCommittedOffsets, 
groupMetadata);
 
-        taskManager.handleAssignment(assignmentActive, assignmentStandby);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
-        assertThat(task00.state(), is(Task.State.RUNNING));
-        assertThat(task01.state(), is(Task.State.RUNNING));
-        assertThat(task02.state(), is(Task.State.RUNNING));
-        assertThat(task10.state(), is(Task.State.RUNNING));
+        // Verify task00 (revoked) was suspended and committed
+        verify(task00).prepareCommit(true);
+        verify(task00).postCommit(true);
+        verify(task00).suspend();
 
-        taskManager.handleRevocation(taskId00Partitions);
+        // Verify task01 (unrevoked but needs commit) was also committed
+        verify(task01).prepareCommit(true);
+        verify(task01).postCommit(false);
 
-        assertThat(task00.commitNeeded, is(false));
-        assertThat(task01.commitNeeded, is(false));
-        assertThat(task02.commitPrepared, is(false));
-        assertThat(task10.commitPrepared, is(false));
+        // Verify task02 (doesn't need commit) was not committed
+        verify(task02, never()).prepareCommit(anyBoolean());
+        verify(task02, never()).postCommit(anyBoolean());
 
-        verify(producer).commitTransaction(expectedCommittedOffsets, 
groupMetadata);
+        // Verify standby task10 was not committed
+        verify(task10, never()).prepareCommit(anyBoolean());
+        verify(task10, never()).postCommit(anyBoolean());
     }
 
     @Test
@@ -3772,6 +3903,19 @@ public class TaskManagerTest {
     @SuppressWarnings("removal")
     @Test
     public void shouldCommitViaProducerIfEosV2Enabled() {
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
+
+        final StreamTask task02 = statefulTask(taskId02, 
taskId02ChangelogPartitions)
+            .withInputPartitions(taskId02Partitions)
+            .inState(State.RUNNING)
+            .build();
+
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(Set.of(task01, task02));
+
         final StreamsProducer producer = mock(StreamsProducer.class);
         when(activeTaskCreator.streamsProducer()).thenReturn(producer);
 
@@ -3781,22 +3925,27 @@ public class TaskManagerTest {
         allOffsets.putAll(offsetsT01);
         allOffsets.putAll(offsetsT02);
 
-        final TaskManager taskManager = 
setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, null, 
false);
+        when(task01.commitNeeded()).thenReturn(true);
+        when(task01.prepareCommit(true)).thenReturn(offsetsT01);
+        doNothing().when(task01).postCommit(false);
 
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
-        task01.setCommittableOffsetsAndMetadata(offsetsT01);
-        task01.setCommitNeeded();
-        taskManager.addTask(task01);
-        final StateMachineTask task02 = new StateMachineTask(taskId02, 
taskId02Partitions, true, stateManager);
-        task02.setCommittableOffsetsAndMetadata(offsetsT02);
-        task02.setCommitNeeded();
-        taskManager.addTask(task02);
+        when(task02.commitNeeded()).thenReturn(true);
+        when(task02.prepareCommit(true)).thenReturn(offsetsT02);
+        doNothing().when(task02).postCommit(false);
 
         when(consumer.groupMetadata()).thenReturn(new 
ConsumerGroupMetadata("appId"));
 
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks);
+
         taskManager.commitAll();
 
         verify(producer).commitTransaction(allOffsets, new 
ConsumerGroupMetadata("appId"));
+        verify(task01, times(2)).commitNeeded();
+        verify(task01).prepareCommit(true);
+        verify(task01).postCommit(false);
+        verify(task02, times(2)).commitNeeded();
+        verify(task02).prepareCommit(true);
+        verify(task02).postCommit(false);
         verifyNoMoreInteractions(producer);
     }
 

Reply via email to