This is an automated email from the ASF dual-hosted git repository.

lucasbru pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new e04fd9d8bd0 KAFKA-19683: Final test replacements [8/N] (#20944)
e04fd9d8bd0 is described below

commit e04fd9d8bd0d8bc806675e2a58765a0b673e267b
Author: Shashank <[email protected]>
AuthorDate: Wed Nov 26 02:05:25 2025 -0800

    KAFKA-19683: Final test replacements [8/N] (#20944)
    
    Last set of test replacements in `TaskManagerTest.java`
    
    Reviewers: Lucas Brutschy <[email protected]>
---
 .../processor/internals/TaskManagerTest.java       | 366 ++++++++++++---------
 1 file changed, 211 insertions(+), 155 deletions(-)

diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index aa2e95623e3..7cfa428d337 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -263,7 +263,6 @@ public class TaskManagerTest {
         return taskManager;
     }
 
-
     @Test
     public void shouldLockAllTasksOnCorruptionWithProcessingThreads() {
         final StreamTask activeTask1 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
@@ -2409,7 +2408,6 @@ public class TaskManagerTest {
         verify(consumer, never()).commitSync(emptyMap());
     }
 
-    @SuppressWarnings("removal")
     @Test
     public void 
shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringHandleCorruptedWithEOS()
 {
         final StreamTask corruptedActive = statefulTask(taskId00, 
taskId00ChangelogPartitions)
@@ -2433,7 +2431,8 @@ public class TaskManagerTest {
 
         final StreamsProducer producer = mock(StreamsProducer.class);
         when(activeTaskCreator.streamsProducer()).thenReturn(producer);
-        final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
+        final ConsumerGroupMetadata groupMetadata = 
mock(ConsumerGroupMetadata.class);
+        
         when(consumer.groupMetadata()).thenReturn(groupMetadata);
         when(consumer.assignment()).thenReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions));
 
@@ -2649,7 +2648,6 @@ public class TaskManagerTest {
         verify(unrevokedActiveTaskWithoutCommit, 
never()).addPartitionsForOffsetReset(any());
     }
 
-    @SuppressWarnings("removal")
     @Test
     public void 
shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringRevocationWithEOS()
 {
         // task being revoked - needs commit
@@ -2676,7 +2674,7 @@ public class TaskManagerTest {
 
         final StreamsProducer producer = mock(StreamsProducer.class);
         when(activeTaskCreator.streamsProducer()).thenReturn(producer);
-        final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
+        final ConsumerGroupMetadata groupMetadata = 
mock(ConsumerGroupMetadata.class);
         when(consumer.groupMetadata()).thenReturn(groupMetadata);
         when(consumer.assignment()).thenReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions, taskId02Partitions));
 
@@ -2973,7 +2971,6 @@ public class TaskManagerTest {
         verify(task00).suspend();
     }
 
-    @SuppressWarnings("removal")
     @Test
     public void 
shouldCommitAllActiveTasksThatNeedCommittingOnHandleRevocationWithEosV2() {
         // task being revoked, needs commit
@@ -3006,7 +3003,7 @@ public class TaskManagerTest {
 
         final StreamsProducer producer = mock(StreamsProducer.class);
         when(activeTaskCreator.streamsProducer()).thenReturn(producer);
-        final ConsumerGroupMetadata groupMetadata = new 
ConsumerGroupMetadata("appId");
+        final ConsumerGroupMetadata groupMetadata = 
mock(ConsumerGroupMetadata.class);
         when(consumer.groupMetadata()).thenReturn(groupMetadata);
 
         final Map<TopicPartition, OffsetAndMetadata> offsets00 = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
@@ -3396,34 +3393,42 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void 
shouldOnlyCommitRevokedStandbyTaskAndPropagatePrepareCommitException() {
-        setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, 
null, false);
+    public void shouldPropagateSuspendExceptionWhenRevokingStandbyTask() {
+        final StandbyTask task00 = standbyTask(taskId00, 
taskId00ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId00Partitions)
+            .build();
+        final StandbyTask task01 = standbyTask(taskId01, 
taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId01Partitions)
+            .build();
 
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
false, stateManager);
+        doThrow(new RuntimeException("task 0_1 suspend 
boom!")).when(task01).suspend();
 
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager) {
-            @Override
-            public Map<TopicPartition, OffsetAndMetadata> prepareCommit(final 
boolean clean) {
-                throw new RuntimeException("task 0_1 prepare commit boom!");
-            }
-        };
-        task01.setCommitNeeded();
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks);
+
+        when(stateUpdater.tasks()).thenReturn(Set.of(task00, task01));
 
-        taskManager.addTask(task00);
-        taskManager.addTask(task01);
+        // task01 is revoked, task00 stays
+        final CompletableFuture<StateUpdater.RemovedTaskResult> futureTask01 = 
new CompletableFuture<>();
+        when(stateUpdater.remove(task01.id())).thenReturn(futureTask01);
+        futureTask01.complete(new StateUpdater.RemovedTaskResult(task01));
 
         final RuntimeException thrown = assertThrows(RuntimeException.class,
             () -> taskManager.handleAssignment(
                 Collections.emptyMap(),
                 singletonMap(taskId00, taskId00Partitions)
             ));
-        assertThat(thrown.getCause().getMessage(), is("task 0_1 prepare commit 
boom!"));
-
-        assertThat(task00.state(), is(Task.State.CREATED));
-        assertThat(task01.state(), is(Task.State.CLOSED));
+        assertThat(thrown.getCause().getMessage(), is("task 0_1 suspend 
boom!"));
 
-        // All the tasks involving in the commit should already be removed.
-        assertThat(taskManager.allTasks(), 
is(Collections.singletonMap(taskId00, task00)));
+        verify(task01, times(2)).suspend();
+        verify(task01).closeDirty();
+        verify(stateUpdater, never()).remove(task00.id());
+        verify(task00, never()).suspend();
+        verify(task00, never()).prepareCommit(anyBoolean());
+        verify(task00, never()).closeClean();
+        verify(task00, never()).closeDirty();
     }
 
     @Test
@@ -3462,70 +3467,50 @@ public class TaskManagerTest {
 
     @Test
     public void shouldCloseActiveTasksAndIgnoreExceptionsOnUncleanShutdown() {
-        final TopicPartition changelog = new TopicPartition("changelog", 0);
-        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
-            mkEntry(taskId00, taskId00Partitions),
-            mkEntry(taskId01, taskId01Partitions),
-            mkEntry(taskId02, taskId02Partitions)
-        );
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
true, stateManager) {
-            @Override
-            public Set<TopicPartition> changelogPartitions() {
-                return singleton(changelog);
-            }
-        };
-        final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, 
true, stateManager) {
-            @Override
-            public void suspend() {
-                super.suspend();
-                throw new TaskMigratedException("migrated", new 
RuntimeException("cause"));
-            }
-        };
-        final Task task02 = new StateMachineTask(taskId02, taskId02Partitions, 
true, stateManager) {
-            @Override
-            public void suspend() {
-                super.suspend();
-                throw new RuntimeException("oops");
-            }
-        };
-
-        when(activeTaskCreator.createTasks(any(), 
eq(assignment))).thenReturn(asList(task00, task01, task02));
-        doThrow(new 
RuntimeException("whatever")).when(activeTaskCreator).close();
-
-        taskManager.handleAssignment(assignment, emptyMap());
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId00Partitions)
+            .build();
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId01Partitions)
+            .build();
+        final StreamTask task02 = statefulTask(taskId02, 
taskId02ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId02Partitions)
+            .build();
 
-        assertThat(task00.state(), is(Task.State.CREATED));
-        assertThat(task01.state(), is(Task.State.CREATED));
-        assertThat(task02.state(), is(Task.State.CREATED));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), null);
+        doThrow(new TaskMigratedException("migrated", new 
RuntimeException("cause")))
+            .when(task01).suspend();
+        doThrow(new RuntimeException("oops"))
+            .when(task02).suspend();
+        doThrow(new 
RuntimeException("whatever")).when(activeTaskCreator).close();
 
-        assertThat(task00.state(), is(Task.State.RESTORING));
-        assertThat(task01.state(), is(Task.State.RUNNING));
-        assertThat(task02.state(), is(Task.State.RUNNING));
-        assertThat(
-            taskManager.activeTaskMap(),
-            Matchers.equalTo(
-                mkMap(
-                    mkEntry(taskId00, task00),
-                    mkEntry(taskId01, task01),
-                    mkEntry(taskId02, task02)
-                )
-            )
-        );
-        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
-        verify(changeLogReader).enforceRestoreActive();
-        verify(changeLogReader).completedChangelogs();
+        when(tasks.allTasks()).thenReturn(Set.of(task00, task01, task02));
+        when(tasks.activeTasks()).thenReturn(Set.of(task00, task01, task02));
 
         taskManager.shutdown(false);
 
-        assertThat(task00.state(), is(Task.State.CLOSED));
-        assertThat(task01.state(), is(Task.State.CLOSED));
-        assertThat(task02.state(), is(Task.State.CLOSED));
-        assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
-        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+        verify(task00).prepareCommit(false);
+        verify(task00).suspend();
+        verify(task00).closeDirty();
+        verify(task00, never()).closeClean();
+        verify(task01).prepareCommit(false);
+        verify(task01).suspend();
+        verify(task01).closeDirty();
+        verify(task01, never()).closeClean();
+        verify(task02).prepareCommit(false);
+        verify(task02).suspend();
+        verify(task02).closeDirty();
+        verify(task02, never()).closeClean();
+        verify(tasks).clear();
+
         // the active task creator should also get closed (so that it closes 
the thread producer if applicable)
         verify(activeTaskCreator).close();
+        verify(stateUpdater).shutdown(Duration.ofMinutes(1L));
     }
 
     @Test
@@ -3892,7 +3877,6 @@ public class TaskManagerTest {
         verify(consumer).commitSync(offsets);
     }
 
-    @SuppressWarnings("removal")
     @Test
     public void shouldCommitViaProducerIfEosV2Enabled() {
         final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
@@ -3925,13 +3909,14 @@ public class TaskManagerTest {
         when(task02.prepareCommit(true)).thenReturn(offsetsT02);
         doNothing().when(task02).postCommit(false);
 
-        when(consumer.groupMetadata()).thenReturn(new 
ConsumerGroupMetadata("appId"));
+        final ConsumerGroupMetadata groupMetadata = 
mock(ConsumerGroupMetadata.class);
+        when(consumer.groupMetadata()).thenReturn(groupMetadata);
 
         final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks);
 
         taskManager.commitAll();
 
-        verify(producer).commitTransaction(allOffsets, new 
ConsumerGroupMetadata("appId"));
+        verify(producer).commitTransaction(allOffsets, groupMetadata);
         verify(task01, times(2)).commitNeeded();
         verify(task01).prepareCommit(true);
         verify(task01).postCommit(false);
@@ -4431,23 +4416,33 @@ public class TaskManagerTest {
 
     @Test
     public void 
shouldThrowRuntimeExceptionWhenEncounteredUnknownExceptionDuringTaskClose() {
-        final StateMachineTask migratedTask01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager) {
-            @Override
-            public void suspend() {
-                super.suspend();
-                throw new TaskMigratedException("t1 close exception", new 
RuntimeException());
-            }
-        };
+        final StandbyTask migratedTask01 = standbyTask(taskId01, 
taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId01Partitions)
+            .build();
+        final StandbyTask migratedTask02 = standbyTask(taskId02, 
taskId02ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId02Partitions)
+            .build();
 
-        final StateMachineTask migratedTask02 = new StateMachineTask(taskId02, 
taskId02Partitions, false, stateManager) {
-            @Override
-            public void suspend() {
-                super.suspend();
-                throw new IllegalStateException("t2 illegal state exception", 
new RuntimeException());
-            }
-        };
-        taskManager.addTask(migratedTask01);
-        taskManager.addTask(migratedTask02);
+        doThrow(new TaskMigratedException("t1 close exception", new 
RuntimeException()))
+            .when(migratedTask01).suspend();
+        doThrow(new IllegalStateException("t2 illegal state exception", new 
RuntimeException()))
+            .when(migratedTask02).suspend();
+
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
+
+        when(stateUpdater.tasks()).thenReturn(Set.of(migratedTask01, 
migratedTask02));
+
+        // mock futures for removing tasks from StateUpdater
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future01 = new 
CompletableFuture<>();
+        when(stateUpdater.remove(taskId01)).thenReturn(future01);
+        future01.complete(new StateUpdater.RemovedTaskResult(migratedTask01));
+
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future02 = new 
CompletableFuture<>();
+        when(stateUpdater.remove(taskId02)).thenReturn(future02);
+        future02.complete(new StateUpdater.RemovedTaskResult(migratedTask02));
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
@@ -4457,27 +4452,42 @@ public class TaskManagerTest {
         assertThat(thrown.getMessage(), equalTo("Encounter unexpected fatal 
error for task 0_2"));
 
         assertThat(thrown.getCause().getMessage(), equalTo("t2 illegal state 
exception"));
+
+        verify(migratedTask01, times(2)).suspend();
+        verify(migratedTask02, times(2)).suspend();
+        verify(stateUpdater).remove(taskId01);
+        verify(stateUpdater).remove(taskId02);
     }
 
     @Test
     public void shouldThrowSameKafkaExceptionWhenEncounteredDuringTaskClose() {
-        final StateMachineTask migratedTask01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager) {
-            @Override
-            public void suspend() {
-                super.suspend();
-                throw new TaskMigratedException("t1 close exception", new 
RuntimeException());
-            }
-        };
+        final StandbyTask migratedTask01 = standbyTask(taskId01, 
taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId01Partitions)
+            .build();
+        final StandbyTask migratedTask02 = standbyTask(taskId02, 
taskId02ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId02Partitions)
+            .build();
 
-        final StateMachineTask migratedTask02 = new StateMachineTask(taskId02, 
taskId02Partitions, false, stateManager) {
-            @Override
-            public void suspend() {
-                super.suspend();
-                throw new KafkaException("Kaboom for t2!", new 
RuntimeException());
-            }
-        };
-        taskManager.addTask(migratedTask01);
-        taskManager.addTask(migratedTask02);
+        doThrow(new TaskMigratedException("t1 close exception", new 
RuntimeException()))
+            .when(migratedTask01).suspend();
+        doThrow(new KafkaException("Kaboom for t2!", new RuntimeException()))
+            .when(migratedTask02).suspend();
+
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
+
+        when(stateUpdater.tasks()).thenReturn(Set.of(migratedTask01, 
migratedTask02));
+
+        // mock futures for removing tasks from StateUpdater
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future01 = new 
CompletableFuture<>();
+        when(stateUpdater.remove(taskId01)).thenReturn(future01);
+        future01.complete(new StateUpdater.RemovedTaskResult(migratedTask01));
+
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future02 = new 
CompletableFuture<>();
+        when(stateUpdater.remove(taskId02)).thenReturn(future02);
+        future02.complete(new StateUpdater.RemovedTaskResult(migratedTask02));
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
@@ -4489,6 +4499,11 @@ public class TaskManagerTest {
 
         // Expecting the original Kafka exception wrapped in the 
StreamsException.
         assertThat(thrown.getCause().getMessage(), equalTo("Kaboom for t2!"));
+
+        verify(migratedTask01, times(2)).suspend();
+        verify(migratedTask02, times(2)).suspend();
+        verify(stateUpdater).remove(taskId01);
+        verify(stateUpdater).remove(taskId02);
     }
 
     @Test
@@ -4640,26 +4655,40 @@ public class TaskManagerTest {
 
     @Test
     public void 
shouldThrowTaskCorruptedExceptionForTimeoutExceptionOnCommitWithEosV2() {
-        final TaskManager taskManager = 
setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, null, 
false);
-
-        final StreamsProducer producer = mock(StreamsProducer.class);
-        when(activeTaskCreator.streamsProducer()).thenReturn(producer);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
+        final StreamTask task02 = statefulTask(taskId02, 
taskId02ChangelogPartitions)
+            .withInputPartitions(taskId02Partitions)
+            .inState(State.RUNNING)
+            .build();
 
         final Map<TopicPartition, OffsetAndMetadata> offsetsT00 = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         final Map<TopicPartition, OffsetAndMetadata> offsetsT01 = 
singletonMap(t1p1, new OffsetAndMetadata(1L, null));
         final Map<TopicPartition, OffsetAndMetadata> allOffsets = new 
HashMap<>(offsetsT00);
         allOffsets.putAll(offsetsT01);
 
-        doThrow(new 
TimeoutException("KABOOM!")).doNothing().when(producer).commitTransaction(allOffsets,
 null);
+        when(task00.commitNeeded()).thenReturn(true);
+        when(task00.prepareCommit(true)).thenReturn(offsetsT00);
+        when(task01.commitNeeded()).thenReturn(true);
+        when(task01.prepareCommit(true)).thenReturn(offsetsT01);
+        when(task02.commitNeeded()).thenReturn(false);
 
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
-        task00.setCommittableOffsetsAndMetadata(offsetsT00);
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
-        task01.setCommittableOffsetsAndMetadata(offsetsT01);
-        final StateMachineTask task02 = new StateMachineTask(taskId02, 
taskId02Partitions, true, stateManager);
+        final StreamsProducer producer = mock(StreamsProducer.class);
+        when(activeTaskCreator.streamsProducer()).thenReturn(producer);
+        final ConsumerGroupMetadata groupMetadata = 
mock(ConsumerGroupMetadata.class);
+        when(consumer.groupMetadata()).thenReturn(groupMetadata);
+
+        doThrow(new 
TimeoutException("KABOOM!")).when(producer).commitTransaction(allOffsets, 
groupMetadata);
 
-        task00.setCommitNeeded();
-        task01.setCommitNeeded();
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks);
 
         final TaskCorruptedException exception = assertThrows(
             TaskCorruptedException.class,
@@ -4675,67 +4704,94 @@ public class TaskManagerTest {
 
     @Test
     public void shouldStreamsExceptionOnCommitError() {
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
+
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
-        task01.setCommittableOffsetsAndMetadata(offsets);
-        task01.setCommitNeeded();
-        taskManager.addTask(task01);
+
+        when(task01.commitNeeded()).thenReturn(true);
+        when(task01.prepareCommit(true)).thenReturn(offsets);
+
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(Set.of(task01));
+
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         doThrow(new KafkaException()).when(consumer).commitSync(offsets);
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
-            () -> taskManager.commitAll()
+            taskManager::commitAll
         );
 
         assertThat(thrown.getCause(), instanceOf(KafkaException.class));
         assertThat(thrown.getMessage(), equalTo("Error encountered committing 
offsets via consumer"));
-        assertThat(task01.state(), is(Task.State.CREATED));
+
+        verify(task01).commitNeeded();
+        verify(task01).prepareCommit(true);
     }
 
     @Test
     public void shouldFailOnCommitFatal() {
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
+
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
-        task01.setCommittableOffsetsAndMetadata(offsets);
-        task01.setCommitNeeded();
-        taskManager.addTask(task01);
+
+        when(task01.commitNeeded()).thenReturn(true);
+        when(task01.prepareCommit(true)).thenReturn(offsets);
+
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(Set.of(task01));
+
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         doThrow(new 
RuntimeException("KABOOM")).when(consumer).commitSync(offsets);
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
-            () -> taskManager.commitAll()
+            taskManager::commitAll
         );
 
         assertThat(thrown.getMessage(), equalTo("KABOOM"));
-        assertThat(task01.state(), is(Task.State.CREATED));
+
+        verify(task01).commitNeeded();
+        verify(task01).prepareCommit(true);
     }
 
     @Test
     public void 
shouldSuspendAllTasksButSkipCommitIfSuspendingFailsDuringRevocation() {
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager) {
-            @Override
-            public void suspend() {
-                super.suspend();
-                throw new RuntimeException("KABOOM!");
-            }
-        };
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId00Partitions)
+            .build();
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId01Partitions)
+            .build();
 
-        final Map<TaskId, Set<TopicPartition>> assignment = new 
HashMap<>(taskId00Assignment);
-        assignment.putAll(taskId01Assignment);
-        when(activeTaskCreator.createTasks(any(), 
eq(assignment))).thenReturn(asList(task00, task01));
+        doThrow(new RuntimeException("KABOOM!")).when(task00).suspend();
 
-        taskManager.handleAssignment(assignment, Collections.emptyMap());
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+
+        when(tasks.allTasks()).thenReturn(Set.of(task00, task01));
+
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
-            () -> taskManager.handleRevocation(asList(t1p0, t1p1)));
+            () -> taskManager.handleRevocation(union(HashSet::new, 
taskId00Partitions, taskId01Partitions)));
 
         assertThat(thrown.getCause().getMessage(), is("KABOOM!"));
-        assertThat(task00.state(), is(Task.State.SUSPENDED));
-        assertThat(task01.state(), is(Task.State.SUSPENDED));
+
+        // verify both tasks had suspend called
+        verify(task00).suspend();
+        verify(task01).suspend();
+
         verifyNoInteractions(consumer);
     }
 

Reply via email to