This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 0158e1d  MINOR: Add 'task container' class to KafkaStreams TaskManager 
(#9835)
0158e1d is described below

commit 0158e1d7196748aa346da90f688048a32f75c2d1
Author: Matthias J. Sax <[email protected]>
AuthorDate: Tue Jan 19 19:28:46 2021 -0800

    MINOR: Add 'task container' class to KafkaStreams TaskManager (#9835)
    
    Kafka Streams' TaskManager is a central class that grew quite big. This
    PR breaks out a new 'task container' class to reduce the scope of what
    TaskManager does. In follow-up PRs, we plan to move more methods from
    TaskManager to the new 'Tasks.java' class and also to improve type
    safety across the different task types.
    
    Reviewers: A. Sophie Blee-Goldman <[email protected]>
---
 .../processor/internals/ActiveTaskCreator.java     |   2 +
 .../processor/internals/StandbyTaskCreator.java    |   2 +
 .../streams/processor/internals/StreamThread.java  |   1 +
 .../streams/processor/internals/TaskManager.java   | 197 ++++++--------
 .../kafka/streams/processor/internals/Tasks.java   | 295 +++++++++++++++++++++
 .../processor/internals/StreamThreadTest.java      |   2 +
 .../processor/internals/TaskManagerTest.java       |  40 +--
 7 files changed, 397 insertions(+), 142 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
index 1bf0603..482a2c5 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -132,8 +132,10 @@ class ActiveTaskCreator {
         return threadProducer;
     }
 
+    // TODO: change return type to `StreamTask`
     Collection<Task> createTasks(final Consumer<byte[], byte[]> consumer,
                                  final Map<TaskId, Set<TopicPartition>> 
tasksToBeCreated) {
+        // TODO: change type to `StreamTask`
         final List<Task> createdTasks = new ArrayList<>();
         for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions 
: tasksToBeCreated.entrySet()) {
             final TaskId taskId = newTaskAndPartitions.getKey();
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
index 3576378..56d4220 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
@@ -66,7 +66,9 @@ class StandbyTaskCreator {
         );
     }
 
+    // TODO: change return type to `StandbyTask`
     Collection<Task> createTasks(final Map<TaskId, Set<TopicPartition>> 
tasksToBeCreated) {
+        // TODO: change type to `StandbyTask`
         final List<Task> createdTasks = new ArrayList<>();
         for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions 
: tasksToBeCreated.entrySet()) {
             final TaskId taskId = newTaskAndPartitions.getKey();
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 12c24aa..e07b92f 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -364,6 +364,7 @@ public class StreamThread extends Thread {
             changelogReader,
             processId,
             logPrefix,
+            streamsMetrics,
             activeTaskCreator,
             standbyTaskCreator,
             builder,
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index f321dfc..fd57cab 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -37,6 +37,7 @@ import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.errors.TaskTimeoutExceptions;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.Task.State;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
 import org.slf4j.Logger;
 
@@ -54,7 +55,6 @@ import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.TreeMap;
 import java.util.TreeSet;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicReference;
@@ -75,16 +75,11 @@ public class TaskManager {
     private final ChangelogReader changelogReader;
     private final UUID processId;
     private final String logPrefix;
-    private final ActiveTaskCreator activeTaskCreator;
-    private final StandbyTaskCreator standbyTaskCreator;
     private final InternalTopologyBuilder builder;
     private final Admin adminClient;
     private final StateDirectory stateDirectory;
     private final StreamThread.ProcessingMode processingMode;
-
-    private final Map<TaskId, Task> tasks = new TreeMap<>();
-    // materializing this relationship because the lookup is on the hot path
-    private final Map<TopicPartition, Task> partitionToTask = new HashMap<>();
+    private final Tasks tasks;
 
     private Consumer<byte[], byte[]> mainConsumer;
 
@@ -100,6 +95,7 @@ public class TaskManager {
                 final ChangelogReader changelogReader,
                 final UUID processId,
                 final String logPrefix,
+                final StreamsMetricsImpl streamsMetrics,
                 final ActiveTaskCreator activeTaskCreator,
                 final StandbyTaskCreator standbyTaskCreator,
                 final InternalTopologyBuilder builder,
@@ -110,12 +106,11 @@ public class TaskManager {
         this.changelogReader = changelogReader;
         this.processId = processId;
         this.logPrefix = logPrefix;
-        this.activeTaskCreator = activeTaskCreator;
-        this.standbyTaskCreator = standbyTaskCreator;
         this.builder = builder;
         this.adminClient = adminClient;
         this.stateDirectory = stateDirectory;
         this.processingMode = processingMode;
+        this.tasks = new Tasks(logPrefix, builder,  streamsMetrics, 
activeTaskCreator, standbyTaskCreator);
 
         final LogContext logContext = new LogContext(logPrefix);
         log = logContext.logger(getClass());
@@ -123,6 +118,7 @@ public class TaskManager {
 
     void setMainConsumer(final Consumer<byte[], byte[]> mainConsumer) {
         this.mainConsumer = mainConsumer;
+        tasks.setMainConsumer(mainConsumer);
     }
 
     public UUID processId() {
@@ -155,13 +151,16 @@ public class TaskManager {
         rebalanceInProgress = false;
     }
 
-    void handleCorruption(final Map<TaskId, Collection<TopicPartition>> 
tasksWithChangelogs) throws TaskMigratedException {
+    /**
+     * @throws TaskMigratedException
+     */
+    void handleCorruption(final Map<TaskId, Collection<TopicPartition>> 
tasksWithChangelogs) {
         final Map<Task, Collection<TopicPartition>> corruptedStandbyTasks = 
new HashMap<>();
         final Map<Task, Collection<TopicPartition>> corruptedActiveTasks = new 
HashMap<>();
 
         for (final Map.Entry<TaskId, Collection<TopicPartition>> taskEntry : 
tasksWithChangelogs.entrySet()) {
             final TaskId taskId = taskEntry.getKey();
-            final Task task = tasks.get(taskId);
+            final Task task = tasks.task(taskId);
             if (task.isActive()) {
                 corruptedActiveTasks.put(task, taskEntry.getValue());
             } else {
@@ -212,7 +211,7 @@ public class TaskManager {
 
             // For active tasks pause their input partitions so we won't poll 
any more records
             // for this task until it has been re-initialized;
-            // Note, closeDirty already clears the partitiongroup for the task.
+            // Note, closeDirty already clears the partition-group for the 
task.
             if (task.isActive()) {
                 final Set<TopicPartition> currentAssignment = 
mainConsumer.assignment();
                 final Set<TopicPartition> taskInputPartitions = 
task.inputPartitions();
@@ -273,12 +272,12 @@ public class TaskManager {
         final Set<Task> tasksToCloseDirty = new TreeSet<>(byId);
 
         // first rectify all existing tasks
-        for (final Task task : tasks.values()) {
+        for (final Task task : tasks.allTasks()) {
             if (activeTasks.containsKey(task.id()) && task.isActive()) {
-                updateInputPartitionsAndResume(task, 
activeTasks.get(task.id()));
+                tasks.updateInputPartitionsAndResume(task, 
activeTasks.get(task.id()));
                 activeTasksToCreate.remove(task.id());
             } else if (standbyTasks.containsKey(task.id()) && 
!task.isActive()) {
-                updateInputPartitionsAndResume(task, 
standbyTasks.get(task.id()));
+                tasks.updateInputPartitionsAndResume(task, 
standbyTasks.get(task.id()));
                 standbyTasksToCreate.remove(task.id());
             } else if (activeTasks.containsKey(task.id()) || 
standbyTasks.containsKey(task.id())) {
                 // check for tasks that were owned previously but have changed 
active/standby status
@@ -289,7 +288,14 @@ public class TaskManager {
         }
 
         // close and recycle those tasks
-        handleCloseAndRecycle(tasksToRecycle, tasksToCloseClean, 
tasksToCloseDirty, activeTasksToCreate, standbyTasksToCreate, 
taskCloseExceptions);
+        handleCloseAndRecycle(
+            tasksToRecycle,
+            tasksToCloseClean,
+            tasksToCloseDirty,
+            activeTasksToCreate,
+            standbyTasksToCreate,
+            taskCloseExceptions
+        );
 
         if (!taskCloseExceptions.isEmpty()) {
             log.error("Hit exceptions while closing / recycling tasks: {}", 
taskCloseExceptions);
@@ -313,17 +319,7 @@ public class TaskManager {
             throw first.getValue();
         }
 
-        if (!activeTasksToCreate.isEmpty()) {
-            for (final Task task : activeTaskCreator.createTasks(mainConsumer, 
activeTasksToCreate)) {
-                addNewTask(task);
-            }
-        }
-
-        if (!standbyTasksToCreate.isEmpty()) {
-            for (final Task task : 
standbyTaskCreator.createTasks(standbyTasksToCreate)) {
-                addNewTask(task);
-            }
-        }
+        tasks.createTasks(activeTasksToCreate, standbyTasksToCreate);
     }
 
     private void handleCloseAndRecycle(final Set<Task> tasksToRecycle,
@@ -377,8 +373,9 @@ public class TaskManager {
         for (final Task task : tasksToCloseClean) {
             try {
                 completeTaskCloseClean(task);
-                cleanUpTaskProducer(task, taskCloseExceptions);
-                tasks.remove(task.id());
+                if (task.isActive()) {
+                    tasks.cleanUpTaskProducerAndRemoveTask(task.id(), 
taskCloseExceptions);
+                }
             } catch (final RuntimeException e) {
                 final String uncleanMessage = String.format(
                         "Failed to close task %s cleanly. Attempting to close 
remaining tasks before re-throwing:",
@@ -390,71 +387,28 @@ public class TaskManager {
         }
 
         tasksToRecycle.removeAll(tasksToCloseDirty);
-        for (final Task task : tasksToRecycle) {
+        for (final Task oldTask : tasksToRecycle) {
             final Task newTask;
             try {
-                if (task.isActive()) {
-                    final Set<TopicPartition> partitions = 
standbyTasksToCreate.remove(task.id());
-                    newTask = 
standbyTaskCreator.createStandbyTaskFromActive((StreamTask) task, partitions);
-                    cleanUpTaskProducer(task, taskCloseExceptions);
+                if (oldTask.isActive()) {
+                    final Set<TopicPartition> partitions = 
standbyTasksToCreate.remove(oldTask.id());
+                    tasks.convertActiveToStandby((StreamTask) oldTask, 
partitions, taskCloseExceptions);
                 } else {
-                    final Set<TopicPartition> partitions = 
activeTasksToCreate.remove(task.id());
-                    newTask = 
activeTaskCreator.createActiveTaskFromStandby((StandbyTask) task, partitions, 
mainConsumer);
+                    final Set<TopicPartition> partitions = 
activeTasksToCreate.remove(oldTask.id());
+                    tasks.convertStandbyToActive((StandbyTask) oldTask, 
partitions);
                 }
-                tasks.remove(task.id());
-                addNewTask(newTask);
             } catch (final RuntimeException e) {
-                final String uncleanMessage = String.format("Failed to recycle 
task %s cleanly. Attempting to close remaining tasks before re-throwing:", 
task.id());
+                final String uncleanMessage = String.format("Failed to recycle 
task %s cleanly. Attempting to close remaining tasks before re-throwing:", 
oldTask.id());
                 log.error(uncleanMessage, e);
-                taskCloseExceptions.putIfAbsent(task.id(), e);
-                tasksToCloseDirty.add(task);
+                taskCloseExceptions.putIfAbsent(oldTask.id(), e);
+                tasksToCloseDirty.add(oldTask);
             }
         }
 
         // for tasks that cannot be cleanly closed or recycled, close them 
dirty
         for (final Task task : tasksToCloseDirty) {
             closeTaskDirty(task);
-            cleanUpTaskProducer(task, taskCloseExceptions);
-            tasks.remove(task.id());
-        }
-    }
-
-    private void cleanUpTaskProducer(final Task task,
-                                     final Map<TaskId, RuntimeException> 
taskCloseExceptions) {
-        if (task.isActive()) {
-            try {
-                
activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
-            } catch (final RuntimeException e) {
-                final String uncleanMessage = String.format("Failed to close 
task %s cleanly. Attempting to close remaining tasks before re-throwing:", 
task.id());
-                log.error(uncleanMessage, e);
-                taskCloseExceptions.putIfAbsent(task.id(), e);
-            }
-        }
-    }
-
-    private void updateInputPartitionsAndResume(final Task task, final 
Set<TopicPartition> topicPartitions) {
-        final boolean requiresUpdate = 
!task.inputPartitions().equals(topicPartitions);
-        if (requiresUpdate) {
-            log.debug("Update task {} inputPartitions: current {}, new {}", 
task, task.inputPartitions(), topicPartitions);
-            for (final TopicPartition inputPartition : task.inputPartitions()) 
{
-                partitionToTask.remove(inputPartition);
-            }
-            for (final TopicPartition topicPartition : topicPartitions) {
-                partitionToTask.put(topicPartition, task);
-            }
-            task.updateInputPartitions(topicPartitions, 
builder.nodeToSourceTopics());
-        }
-        task.resume();
-    }
-
-    private void addNewTask(final Task task) {
-        final Task previous = tasks.put(task.id(), task);
-        if (previous != null) {
-            throw new IllegalStateException("Attempted to create a task that 
we already owned: " + task.id());
-        }
-
-        for (final TopicPartition topicPartition : task.inputPartitions()) {
-            partitionToTask.put(topicPartition, task);
+            tasks.cleanUpTaskProducerAndRemoveTask(task.id(), 
taskCloseExceptions);
         }
     }
 
@@ -469,7 +423,7 @@ public class TaskManager {
         boolean allRunning = true;
 
         final List<Task> activeTasks = new LinkedList<>();
-        for (final Task task : tasks.values()) {
+        for (final Task task : tasks.allTasks()) {
             try {
                 task.initializeIfNeeded();
                 task.clearTaskTimeout();
@@ -647,25 +601,19 @@ public class TaskManager {
     void handleLostAll() {
         log.debug("Closing lost active tasks as zombies.");
 
-        final Iterator<Task> iterator = tasks.values().iterator();
-        while (iterator.hasNext()) {
-            final Task task = iterator.next();
+        final Set<Task> allTask = new HashSet<>(tasks.allTasks());
+        for (final Task task : allTask) {
             // Even though we've apparently dropped out of the group, we can 
continue safely to maintain our
             // standby tasks while we rejoin.
             if (task.isActive()) {
                 closeTaskDirty(task);
-                iterator.remove();
 
-                try {
-                    
activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
-                } catch (final RuntimeException e) {
-                    log.warn("Error closing task producer for " + task.id() + 
" while handling lostAll", e);
-                }
+                tasks.cleanUpTaskProducerAndRemoveTask(task.id(), new 
HashMap<>());
             }
         }
 
         if (processingMode == EXACTLY_ONCE_BETA) {
-            activeTaskCreator.reInitializeThreadProducer();
+            tasks.reInitializeThreadProducer();
         }
     }
 
@@ -680,8 +628,8 @@ public class TaskManager {
         // Not all tasks will create directories, and there may be directories 
for tasks we don't currently own,
         // so we consider all tasks that are either owned or on disk. This 
includes stateless tasks, which should
         // just have an empty changelogOffsets map.
-        for (final TaskId id : union(HashSet::new, lockedTaskDirectories, 
tasks.keySet())) {
-            final Task task = tasks.get(id);
+        for (final TaskId id : union(HashSet::new, lockedTaskDirectories, 
tasks.tasksPerId().keySet())) {
+            final Task task = tasks.owned(id) ? tasks.task(id) : null;
             // Closed and uninitialized tasks don't have any offsets so we 
should read directly from the checkpoint
             if (task != null && task.state() != State.CREATED && task.state() 
!= State.CLOSED) {
                 final Map<TopicPartition, Long> changelogOffsets = 
task.changelogOffsets();
@@ -722,7 +670,7 @@ public class TaskManager {
                 try {
                     if (stateDirectory.lock(id)) {
                         lockedTaskDirectories.add(id);
-                        if (!tasks.containsKey(id)) {
+                        if (!tasks.owned(id)) {
                             log.debug("Temporarily locked unassigned task {} 
for the upcoming rebalance", id);
                         }
                     }
@@ -746,7 +694,7 @@ public class TaskManager {
         final Iterator<TaskId> taskIdIterator = 
lockedTaskDirectories.iterator();
         while (taskIdIterator.hasNext()) {
             final TaskId id = taskIdIterator.next();
-            if (!tasks.containsKey(id)) {
+            if (!tasks.owned(id)) {
                 try {
                     stateDirectory.unlock(id);
                     taskIdIterator.remove();
@@ -804,26 +752,22 @@ public class TaskManager {
         } catch (final RuntimeException swallow) {
             log.error("Error suspending dirty task {} ", task.id(), swallow);
         }
-        cleanupTask(task);
+        tasks.removeTaskBeforeClosing(task.id());
         task.closeDirty();
     }
 
     private void completeTaskCloseClean(final Task task) {
-        cleanupTask(task);
+        tasks.removeTaskBeforeClosing(task.id());
         task.closeClean();
     }
 
-    // Note: this MUST be called *before* actually closing the task
-    private void cleanupTask(final Task task) {
-        for (final TopicPartition inputPartition : task.inputPartitions()) {
-            partitionToTask.remove(inputPartition);
-        }
-    }
-
     void shutdown(final boolean clean) {
         final AtomicReference<RuntimeException> firstException = new 
AtomicReference<>(null);
 
         final Set<Task> tasksToCloseDirty = new HashSet<>();
+        // TODO: change type to `StreamTask`
+        final Set<Task> activeTasks = new 
TreeSet<>(Comparator.comparing(Task::id));
+        activeTasks.addAll(tasks.activeTasks());
         tasksToCloseDirty.addAll(tryCloseCleanAllActiveTasks(clean, 
firstException));
         tasksToCloseDirty.addAll(tryCloseCleanAllStandbyTasks(clean, 
firstException));
 
@@ -831,18 +775,19 @@ public class TaskManager {
             closeTaskDirty(task);
         }
 
-        for (final Task task : activeTaskIterable()) {
+        // TODO: change type to `StreamTask`
+        for (final Task activeTask : activeTasks) {
             executeAndMaybeSwallow(
                 clean,
-                () -> 
activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id()),
+                () -> tasks.closeAndRemoveTaskProducerIfNeeded(activeTask),
                 e -> firstException.compareAndSet(null, e),
-                e -> log.warn("Ignoring an exception while closing task " + 
task.id() + " producer.", e)
+                e -> log.warn("Ignoring an exception while closing task " + 
activeTask.id() + " producer.", e)
             );
         }
 
         executeAndMaybeSwallow(
             clean,
-            activeTaskCreator::closeThreadProducerIfNeeded,
+            tasks::closeThreadProducerIfNeeded,
             e -> firstException.compareAndSet(null, e),
             e -> log.warn("Ignoring an exception while closing thread 
producer.", e)
         );
@@ -987,7 +932,7 @@ public class TaskManager {
     Map<TaskId, Task> tasks() {
         // not bothering with an unmodifiable map, since the tasks themselves 
are mutable, but
         // if any outside code modifies the map or the tasks, it would be a 
severe transgression.
-        return tasks;
+        return tasks.tasksPerId();
     }
 
     Map<TaskId, Task> activeTaskMap() {
@@ -999,7 +944,7 @@ public class TaskManager {
     }
 
     private Stream<Task> activeTaskStream() {
-        return tasks.values().stream().filter(Task::isActive);
+        return tasks.allTasks().stream().filter(Task::isActive);
     }
 
     Map<TaskId, Task> standbyTaskMap() {
@@ -1011,12 +956,12 @@ public class TaskManager {
     }
 
     private Stream<Task> standbyTaskStream() {
-        return tasks.values().stream().filter(t -> !t.isActive());
+        return tasks.allTasks().stream().filter(t -> !t.isActive());
     }
 
     // For testing only.
     int commitAll() {
-        return commit(new HashSet<>(tasks.values()));
+        return commit(new HashSet<>(tasks.allTasks()));
     }
 
     /**
@@ -1026,15 +971,16 @@ public class TaskManager {
      */
     void addRecordsToTasks(final ConsumerRecords<byte[], byte[]> records) {
         for (final TopicPartition partition : records.partitions()) {
-            final Task task = partitionToTask.get(partition);
+            // TODO: change type to `StreamTask`
+            final Task activeTask = 
tasks.activeTasksForInputPartition(partition);
 
-            if (task == null) {
+            if (activeTask == null) {
                 log.error("Unable to locate active task for received-record 
partition {}. Current tasks: {}",
                     partition, toString(">"));
                 throw new NullPointerException("Task was unexpectedly missing 
for partition " + partition);
             }
 
-            task.addRecords(partition, records.records(partition));
+            activeTask.addRecords(partition, records.records(partition));
         }
     }
 
@@ -1106,7 +1052,7 @@ public class TaskManager {
         }
     }
 
-    private void commitOffsetsOrTransaction(final Map<Task, 
Map<TopicPartition, OffsetAndMetadata>> offsetsPerTask) throws 
TaskTimeoutExceptions {
+    private void commitOffsetsOrTransaction(final Map<Task, 
Map<TopicPartition, OffsetAndMetadata>> offsetsPerTask) {
         log.debug("Committing task offsets {}", offsetsPerTask);
 
         TaskTimeoutExceptions timeoutExceptions = null;
@@ -1116,7 +1062,7 @@ public class TaskManager {
                 for (final Map.Entry<Task, Map<TopicPartition, 
OffsetAndMetadata>> taskToCommit : offsetsPerTask.entrySet()) {
                     final Task task = taskToCommit.getKey();
                     try {
-                        activeTaskCreator.streamsProducerForTask(task.id())
+                        tasks.streamsProducerForTask(task.id())
                             .commitTransaction(taskToCommit.getValue(), 
mainConsumer.groupMetadata());
                     } catch (final TimeoutException timeoutException) {
                         if (timeoutExceptions == null) {
@@ -1131,7 +1077,7 @@ public class TaskManager {
 
                 if (processingMode == EXACTLY_ONCE_BETA) {
                     try {
-                        
activeTaskCreator.threadProducer().commitTransaction(allOffsets, 
mainConsumer.groupMetadata());
+                        tasks.threadProducer().commitTransaction(allOffsets, 
mainConsumer.groupMetadata());
                     } catch (final TimeoutException timeoutException) {
                         throw new TaskTimeoutExceptions(timeoutException);
                     }
@@ -1258,7 +1204,7 @@ public class TaskManager {
         stringBuilder.append("TaskManager\n");
         stringBuilder.append(indent).append("\tMetadataState:\n");
         stringBuilder.append(indent).append("\tTasks:\n");
-        for (final Task task : tasks.values()) {
+        for (final Task task : tasks.allTasks()) {
             stringBuilder.append(indent)
                          .append("\t\t")
                          .append(task.id())
@@ -1272,11 +1218,11 @@ public class TaskManager {
     }
 
     Map<MetricName, Metric> producerMetrics() {
-        return activeTaskCreator.producerMetrics();
+        return tasks.producerMetrics();
     }
 
     Set<String> producerClientIds() {
-        return activeTaskCreator.producerClientIds();
+        return tasks.producerClientIds();
     }
 
     Set<TaskId> lockedTaskDirectories() {
@@ -1314,4 +1260,9 @@ public class TaskManager {
     public void setPartitionResetter(final 
java.util.function.Consumer<Set<TopicPartition>> resetter) {
         this.resetter = resetter;
     }
+
+    // for testing only
+    void addTask(final Task task) {
+        tasks.addTask(task);
+    }
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
new file mode 100644
index 0000000..843db80
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.slf4j.Logger;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+
+class Tasks {
+    private final Logger log;
+    private final InternalTopologyBuilder builder;
+    private final StreamsMetricsImpl streamsMetrics;
+
+    private final Map<TaskId, Task> allTasksPerId = new TreeMap<>();
+    private final Map<TaskId, Task> readOnlyTasksPerId = 
Collections.unmodifiableMap(allTasksPerId);
+    private final Collection<Task> readOnlyTasks = 
Collections.unmodifiableCollection(allTasksPerId.values());
+
+    // TODO: change type to `StreamTask`
+    private final Map<TaskId, Task> activeTasksPerId = new TreeMap<>();
+    // TODO: change type to `StreamTask`
+    private final Map<TopicPartition, Task> activeTasksPerPartition = new 
HashMap<>();
+    // TODO: change type to `StreamTask`
+    private final Map<TaskId, Task> readOnlyActiveTasksPerId = 
Collections.unmodifiableMap(activeTasksPerId);
+    private final Set<TaskId> readOnlyActiveTaskIds = 
Collections.unmodifiableSet(activeTasksPerId.keySet());
+    // TODO: change type to `StreamTask`
+    private final Collection<Task> readOnlyActiveTasks = 
Collections.unmodifiableCollection(activeTasksPerId.values());
+
+    // TODO: change type to `StandbyTask`
+    private final Map<TaskId, Task> standbyTasksPerId = new TreeMap<>();
+    // TODO: change type to `StandbyTask`
+    private final Map<TaskId, Task> readOnlyStandbyTasksPerId = 
Collections.unmodifiableMap(standbyTasksPerId);
+    private final Set<TaskId> readOnlyStandbyTaskIds = 
Collections.unmodifiableSet(standbyTasksPerId.keySet());
+
+    private final ActiveTaskCreator activeTaskCreator;
+    private final StandbyTaskCreator standbyTaskCreator;
+
+    private Consumer<byte[], byte[]> mainConsumer;
+
+    Tasks(final String logPrefix,
+          final InternalTopologyBuilder builder,
+          final StreamsMetricsImpl streamsMetrics,
+          final ActiveTaskCreator activeTaskCreator,
+          final StandbyTaskCreator standbyTaskCreator) {
+
+        final LogContext logContext = new LogContext(logPrefix);
+        log = logContext.logger(getClass());
+
+        this.builder = builder;
+        this.streamsMetrics = streamsMetrics;
+        this.activeTaskCreator = activeTaskCreator;
+        this.standbyTaskCreator = standbyTaskCreator;
+    }
+
+    void setMainConsumer(final Consumer<byte[], byte[]> mainConsumer) {
+        this.mainConsumer = mainConsumer;
+    }
+
+    void createTasks(final Map<TaskId, Set<TopicPartition>> 
activeTasksToCreate,
+                     final Map<TaskId, Set<TopicPartition>> 
standbyTasksToCreate) {
+        for (final Map.Entry<TaskId, Set<TopicPartition>> taskToBeCreated : 
activeTasksToCreate.entrySet()) {
+            final TaskId taskId = taskToBeCreated.getKey();
+
+            if (activeTasksPerId.containsKey(taskId)) {
+                throw new IllegalStateException("Attempted to create an active 
task that we already own: " + taskId);
+            }
+        }
+
+        for (final Map.Entry<TaskId, Set<TopicPartition>> taskToBeCreated : 
standbyTasksToCreate.entrySet()) {
+            final TaskId taskId = taskToBeCreated.getKey();
+
+            if (standbyTasksPerId.containsKey(taskId)) {
+                throw new IllegalStateException("Attempted to create a standby 
task that we already own: " + taskId);
+            }
+        }
+
+        // keep this check to simplify testing (ie, no need to mock 
`activeTaskCreator`)
+        if (!activeTasksToCreate.isEmpty()) {
+            // TODO: change type to `StreamTask`
+            for (final Task activeTask : 
activeTaskCreator.createTasks(mainConsumer, activeTasksToCreate)) {
+                activeTasksPerId.put(activeTask.id(), activeTask);
+                allTasksPerId.put(activeTask.id(), activeTask);
+                for (final TopicPartition topicPartition : 
activeTask.inputPartitions()) {
+                    activeTasksPerPartition.put(topicPartition, activeTask);
+                }
+            }
+        }
+
+        // keep this check to simplify testing (ie, no need to mock 
`standbyTaskCreator`)
+        if (!standbyTasksToCreate.isEmpty()) {
+            // TODO: change type to `StandbyTask`
+            for (final Task standbyTask : 
standbyTaskCreator.createTasks(standbyTasksToCreate)) {
+                standbyTasksPerId.put(standbyTask.id(), standbyTask);
+                allTasksPerId.put(standbyTask.id(), standbyTask);
+            }
+        }
+    }
+
+    void convertActiveToStandby(final StreamTask activeTask,
+                                final Set<TopicPartition> partitions,
+                                final Map<TaskId, RuntimeException> 
taskCloseExceptions) {
+        if (activeTasksPerId.remove(activeTask.id()) == null) {
+            throw new IllegalStateException("Attempted to convert unknown 
active task to standby task: " + activeTask.id());
+        }
+        activeTasksPerPartition.entrySet().stream()
+            .filter(e -> e.getValue().id().equals(activeTask.id()))
+            .forEach(e -> activeTasksPerPartition.remove(e.getKey()));
+
+        cleanUpTaskProducerAndRemoveTask(activeTask.id(), taskCloseExceptions);
+
+        final StandbyTask standbyTask = 
standbyTaskCreator.createStandbyTaskFromActive(activeTask, partitions);
+        standbyTasksPerId.put(standbyTask.id(), standbyTask);
+        allTasksPerId.put(standbyTask.id(), standbyTask);
+    }
+
+    void convertStandbyToActive(final StandbyTask standbyTask, final 
Set<TopicPartition> partitions) {
+        if (standbyTasksPerId.remove(standbyTask.id()) == null) {
+            throw new IllegalStateException("Attempted to convert unknown 
standby task to stream task: " + standbyTask.id());
+        }
+
+        final StreamTask activeTask = 
activeTaskCreator.createActiveTaskFromStandby(standbyTask, partitions, 
mainConsumer);
+        activeTasksPerId.put(activeTask.id(), activeTask);
+        for (final TopicPartition topicPartition : 
activeTask.inputPartitions()) {
+            activeTasksPerPartition.put(topicPartition, activeTask);
+        }
+        allTasksPerId.put(activeTask.id(), activeTask);
+    }
+
+    void updateInputPartitionsAndResume(final Task task, final 
Set<TopicPartition> topicPartitions) {
+        final boolean requiresUpdate = 
!task.inputPartitions().equals(topicPartitions);
+        if (requiresUpdate) {
+            log.debug("Update task {} inputPartitions: current {}, new {}", 
task, task.inputPartitions(), topicPartitions);
+            for (final TopicPartition inputPartition : task.inputPartitions()) 
{
+                activeTasksPerPartition.remove(inputPartition);
+            }
+            if (task.isActive()) {
+                for (final TopicPartition topicPartition : topicPartitions) {
+                    activeTasksPerPartition.put(topicPartition, task);
+                }
+            }
+            task.updateInputPartitions(topicPartitions, 
builder.nodeToSourceTopics());
+        }
+        task.resume();
+    }
+
+    void cleanUpTaskProducerAndRemoveTask(final TaskId taskId,
+                                          final Map<TaskId, RuntimeException> 
taskCloseExceptions) {
+        try {
+            activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId);
+        } catch (final RuntimeException e) {
+            final String uncleanMessage = String.format("Failed to close task 
%s cleanly. Attempting to close remaining tasks before re-throwing:", taskId);
+            log.error(uncleanMessage, e);
+            taskCloseExceptions.putIfAbsent(taskId, e);
+        }
+        removeTaskBeforeClosing(taskId);
+    }
+
+    void reInitializeThreadProducer() {
+        activeTaskCreator.reInitializeThreadProducer();
+    }
+
+    void closeThreadProducerIfNeeded() {
+        activeTaskCreator.closeThreadProducerIfNeeded();
+    }
+
+    // TODO: change type to `StreamTask`
+    void closeAndRemoveTaskProducerIfNeeded(final Task activeTask) {
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(activeTask.id());
+    }
+
+    void removeTaskBeforeClosing(final TaskId taskId) {
+        activeTasksPerId.remove(taskId);
+        final Set<TopicPartition> toBeRemoved = 
activeTasksPerPartition.entrySet().stream()
+            .filter(e -> e.getValue().id().equals(taskId))
+            .map(Map.Entry::getKey)
+            .collect(Collectors.toSet());
+        toBeRemoved.forEach(activeTasksPerPartition::remove);
+        standbyTasksPerId.remove(taskId);
+        allTasksPerId.remove(taskId);
+    }
+
+    void clear() {
+        activeTasksPerId.clear();
+        activeTasksPerPartition.clear();
+        standbyTasksPerId.clear();
+        allTasksPerId.clear();
+    }
+
+    // TODO: change return type to `StreamTask`
+    Task activeTasksForInputPartition(final TopicPartition partition) {
+        return activeTasksPerPartition.get(partition);
+    }
+
+    // TODO: change return type to `StandbyTask`
+    Task standbyTask(final TaskId taskId) {
+        if (!standbyTasksPerId.containsKey(taskId)) {
+            throw new IllegalStateException("Standby task unknown: " + taskId);
+        }
+        return standbyTasksPerId.get(taskId);
+    }
+
+    Task task(final TaskId taskId) {
+        if (!allTasksPerId.containsKey(taskId)) {
+            throw new IllegalStateException("Task unknown: " + taskId);
+        }
+        return allTasksPerId.get(taskId);
+    }
+
+    // TODO: change return type to `StreamTask`
+    Collection<Task> activeTasks() {
+        return readOnlyActiveTasks;
+    }
+
+    Collection<Task> allTasks() {
+        return readOnlyTasks;
+    }
+
+    Set<TaskId> activeTaskIds() {
+        return readOnlyActiveTaskIds;
+    }
+
+    Set<TaskId> standbyTaskIds() {
+        return readOnlyStandbyTaskIds;
+    }
+
+    // TODO: change return type to `StreamTask`
+    Map<TaskId, Task> activeTaskMap() {
+        return readOnlyActiveTasksPerId;
+    }
+
+    // TODO: change return type to `StandbyTask`
+    Map<TaskId, Task> standbyTaskMap() {
+        return readOnlyStandbyTasksPerId;
+    }
+
+    Map<TaskId, Task> tasksPerId() {
+        return readOnlyTasksPerId;
+    }
+
+    boolean owned(final TaskId taskId) {
+        return allTasksPerId.containsKey(taskId);
+    }
+
+    StreamsProducer streamsProducerForTask(final TaskId taskId) {
+        return activeTaskCreator.streamsProducerForTask(taskId);
+    }
+
+    StreamsProducer threadProducer() {
+        return activeTaskCreator.threadProducer();
+    }
+
+    Map<MetricName, Metric> producerMetrics() {
+        return activeTaskCreator.producerMetrics();
+    }
+
+    Set<String> producerClientIds() {
+        return activeTaskCreator.producerClientIds();
+    }
+
+    // for testing only
+    void addTask(final Task task) {
+        if (task.isActive()) {
+            activeTasksPerId.put(task.id(), task);
+        } else {
+            standbyTasksPerId.put(task.id(), task);
+        }
+        allTasksPerId.put(task.id(), task);
+    }
+}
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 157f9e5..401a830 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -767,6 +767,7 @@ public class StreamThreadTest {
             null,
             null,
             null,
+            null,
             null
         ) {
             @Override
@@ -2564,6 +2565,7 @@ public class StreamThreadTest {
         );
     }
 
+    // TODO: change return type to `StandbyTask`
     private Collection<Task> createStandbyTask() {
         final LogContext logContext = new LogContext("test");
         final Logger log = logContext.logger(StreamThreadTest.class);
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 0eb7817..04bd252 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -33,8 +33,10 @@ import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.internals.KafkaFutureImpl;
 import org.apache.kafka.common.metrics.KafkaMetric;
 import org.apache.kafka.common.metrics.Measurable;
+import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
@@ -42,6 +44,7 @@ import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import 
org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode;
 import org.apache.kafka.streams.processor.internals.Task.State;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import 
org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
 import org.easymock.EasyMock;
@@ -171,6 +174,7 @@ public class TaskManagerTest {
             changeLogReader,
             UUID.randomUUID(),
             "taskManagerTest",
+            new StreamsMetricsImpl(new Metrics(), "clientId", 
StreamsConfig.METRICS_LATEST, time),
             activeTaskCreator,
             standbyTaskCreator,
             topologyBuilder,
@@ -1495,8 +1499,8 @@ public class TaskManagerTest {
         };
         task01.setCommitNeeded();
 
-        taskManager.tasks().put(taskId00, task00);
-        taskManager.tasks().put(taskId01, task01);
+        taskManager.addTask(task00);
+        taskManager.addTask(task01);
 
         final RuntimeException thrown = assertThrows(RuntimeException.class,
             () -> taskManager.handleAssignment(
@@ -1514,8 +1518,6 @@ public class TaskManagerTest {
 
     @Test
     public void 
shouldSuspendAllRevokedActiveTasksAndPropagateSuspendException() {
-        setUpTaskManager(StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA);
-
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, 
true);
 
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true) {
@@ -1528,9 +1530,9 @@ public class TaskManagerTest {
 
         final StateMachineTask task02 = new StateMachineTask(taskId02, 
taskId02Partitions, true);
 
-        taskManager.tasks().put(taskId00, task00);
-        taskManager.tasks().put(taskId01, task01);
-        taskManager.tasks().put(taskId02, task02);
+        taskManager.addTask(task00);
+        taskManager.addTask(task01);
+        taskManager.addTask(task02);
 
         replay(activeTaskCreator);
 
@@ -1857,7 +1859,7 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p1, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
-        taskManager.tasks().put(taskId01, task01);
+        taskManager.addTask(task01);
 
         consumer.commitSync(offsets);
         expectLastCall();
@@ -1912,11 +1914,11 @@ public class TaskManagerTest {
         final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true);
         task01.setCommittableOffsetsAndMetadata(offsetsT01);
         task01.setCommitNeeded();
-        taskManager.tasks().put(taskId01, task01);
+        taskManager.addTask(task01);
         final StateMachineTask task02 = new StateMachineTask(taskId02, 
taskId02Partitions, true);
         task02.setCommittableOffsetsAndMetadata(offsetsT02);
         task02.setCommitNeeded();
-        taskManager.tasks().put(taskId02, task02);
+        taskManager.addTask(task02);
 
         reset(consumer);
         expect(consumer.groupMetadata()).andReturn(new 
ConsumerGroupMetadata("appId")).anyTimes();
@@ -2400,8 +2402,8 @@ public class TaskManagerTest {
                 throw new TaskMigratedException("t2 close exception", new 
RuntimeException());
             }
         };
-        taskManager.tasks().put(taskId01, migratedTask01);
-        taskManager.tasks().put(taskId02, migratedTask02);
+        taskManager.addTask(migratedTask01);
+        taskManager.addTask(migratedTask02);
 
         final TaskMigratedException thrown = assertThrows(
             TaskMigratedException.class,
@@ -2432,8 +2434,8 @@ public class TaskManagerTest {
                 throw new IllegalStateException("t2 illegal state exception", 
new RuntimeException());
             }
         };
-        taskManager.tasks().put(taskId01, migratedTask01);
-        taskManager.tasks().put(taskId02, migratedTask02);
+        taskManager.addTask(migratedTask01);
+        taskManager.addTask(migratedTask02);
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
@@ -2463,8 +2465,8 @@ public class TaskManagerTest {
                 throw new KafkaException("Kaboom for t2!", new 
RuntimeException());
             }
         };
-        taskManager.tasks().put(taskId01, migratedTask01);
-        taskManager.tasks().put(taskId02, migratedTask02);
+        taskManager.addTask(migratedTask01);
+        taskManager.addTask(migratedTask02);
 
         final KafkaException thrown = assertThrows(
             KafkaException.class,
@@ -2573,7 +2575,7 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
-        taskManager.tasks().put(taskId01, task01);
+        taskManager.addTask(task01);
 
         consumer.commitSync(offsets);
         expectLastCall().andThrow(new CommitFailedException());
@@ -2712,7 +2714,7 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
-        taskManager.tasks().put(taskId01, task01);
+        taskManager.addTask(task01);
 
         consumer.commitSync(offsets);
         expectLastCall().andThrow(new KafkaException());
@@ -2734,7 +2736,7 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task01.setCommittableOffsetsAndMetadata(offsets);
         task01.setCommitNeeded();
-        taskManager.tasks().put(taskId01, task01);
+        taskManager.addTask(task01);
 
         consumer.commitSync(offsets);
         expectLastCall().andThrow(new RuntimeException("KABOOM"));

Reply via email to