This is an automated email from the ASF dual-hosted git repository.
mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/trunk by this push:
new 0158e1d MINOR: Add 'task container' class to KafkaStreams TaskManager
(#9835)
0158e1d is described below
commit 0158e1d7196748aa346da90f688048a32f75c2d1
Author: Matthias J. Sax <[email protected]>
AuthorDate: Tue Jan 19 19:28:46 2021 -0800
MINOR: Add 'task container' class to KafkaStreams TaskManager (#9835)
Kafka Streams' TaskManager is a central class that grew quite big. This
PR breaks out a new 'task container' class to descope what TaskManager
does. In follow-up PRs, we plan to move more methods from TaskManager
to the new 'Tasks.java' class and also improve task-type type safety.
Reviewers: A. Sophie Blee-Goldman <[email protected]>
---
.../processor/internals/ActiveTaskCreator.java | 2 +
.../processor/internals/StandbyTaskCreator.java | 2 +
.../streams/processor/internals/StreamThread.java | 1 +
.../streams/processor/internals/TaskManager.java | 197 ++++++--------
.../kafka/streams/processor/internals/Tasks.java | 295 +++++++++++++++++++++
.../processor/internals/StreamThreadTest.java | 2 +
.../processor/internals/TaskManagerTest.java | 40 +--
7 files changed, 397 insertions(+), 142 deletions(-)
diff --git
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
index 1bf0603..482a2c5 100644
---
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
+++
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -132,8 +132,10 @@ class ActiveTaskCreator {
return threadProducer;
}
+ // TODO: change return type to `StreamTask`
Collection<Task> createTasks(final Consumer<byte[], byte[]> consumer,
final Map<TaskId, Set<TopicPartition>>
tasksToBeCreated) {
+ // TODO: change type to `StreamTask`
final List<Task> createdTasks = new ArrayList<>();
for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions
: tasksToBeCreated.entrySet()) {
final TaskId taskId = newTaskAndPartitions.getKey();
diff --git
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
index 3576378..56d4220 100644
---
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
+++
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
@@ -66,7 +66,9 @@ class StandbyTaskCreator {
);
}
+ // TODO: change return type to `StandbyTask`
Collection<Task> createTasks(final Map<TaskId, Set<TopicPartition>>
tasksToBeCreated) {
+ // TODO: change type to `StandbyTask`
final List<Task> createdTasks = new ArrayList<>();
for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions
: tasksToBeCreated.entrySet()) {
final TaskId taskId = newTaskAndPartitions.getKey();
diff --git
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 12c24aa..e07b92f 100644
---
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -364,6 +364,7 @@ public class StreamThread extends Thread {
changelogReader,
processId,
logPrefix,
+ streamsMetrics,
activeTaskCreator,
standbyTaskCreator,
builder,
diff --git
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index f321dfc..fd57cab 100644
---
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -37,6 +37,7 @@ import org.apache.kafka.streams.errors.TaskMigratedException;
import org.apache.kafka.streams.errors.TaskTimeoutExceptions;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.Task.State;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
import org.slf4j.Logger;
@@ -54,7 +55,6 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.TreeMap;
import java.util.TreeSet;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
@@ -75,16 +75,11 @@ public class TaskManager {
private final ChangelogReader changelogReader;
private final UUID processId;
private final String logPrefix;
- private final ActiveTaskCreator activeTaskCreator;
- private final StandbyTaskCreator standbyTaskCreator;
private final InternalTopologyBuilder builder;
private final Admin adminClient;
private final StateDirectory stateDirectory;
private final StreamThread.ProcessingMode processingMode;
-
- private final Map<TaskId, Task> tasks = new TreeMap<>();
- // materializing this relationship because the lookup is on the hot path
- private final Map<TopicPartition, Task> partitionToTask = new HashMap<>();
+ private final Tasks tasks;
private Consumer<byte[], byte[]> mainConsumer;
@@ -100,6 +95,7 @@ public class TaskManager {
final ChangelogReader changelogReader,
final UUID processId,
final String logPrefix,
+ final StreamsMetricsImpl streamsMetrics,
final ActiveTaskCreator activeTaskCreator,
final StandbyTaskCreator standbyTaskCreator,
final InternalTopologyBuilder builder,
@@ -110,12 +106,11 @@ public class TaskManager {
this.changelogReader = changelogReader;
this.processId = processId;
this.logPrefix = logPrefix;
- this.activeTaskCreator = activeTaskCreator;
- this.standbyTaskCreator = standbyTaskCreator;
this.builder = builder;
this.adminClient = adminClient;
this.stateDirectory = stateDirectory;
this.processingMode = processingMode;
+ this.tasks = new Tasks(logPrefix, builder, streamsMetrics,
activeTaskCreator, standbyTaskCreator);
final LogContext logContext = new LogContext(logPrefix);
log = logContext.logger(getClass());
@@ -123,6 +118,7 @@ public class TaskManager {
void setMainConsumer(final Consumer<byte[], byte[]> mainConsumer) {
this.mainConsumer = mainConsumer;
+ tasks.setMainConsumer(mainConsumer);
}
public UUID processId() {
@@ -155,13 +151,16 @@ public class TaskManager {
rebalanceInProgress = false;
}
- void handleCorruption(final Map<TaskId, Collection<TopicPartition>>
tasksWithChangelogs) throws TaskMigratedException {
+ /**
+ * @throws TaskMigratedException
+ */
+ void handleCorruption(final Map<TaskId, Collection<TopicPartition>>
tasksWithChangelogs) {
final Map<Task, Collection<TopicPartition>> corruptedStandbyTasks =
new HashMap<>();
final Map<Task, Collection<TopicPartition>> corruptedActiveTasks = new
HashMap<>();
for (final Map.Entry<TaskId, Collection<TopicPartition>> taskEntry :
tasksWithChangelogs.entrySet()) {
final TaskId taskId = taskEntry.getKey();
- final Task task = tasks.get(taskId);
+ final Task task = tasks.task(taskId);
if (task.isActive()) {
corruptedActiveTasks.put(task, taskEntry.getValue());
} else {
@@ -212,7 +211,7 @@ public class TaskManager {
// For active tasks pause their input partitions so we won't poll
any more records
// for this task until it has been re-initialized;
- // Note, closeDirty already clears the partitiongroup for the task.
+ // Note, closeDirty already clears the partition-group for the
task.
if (task.isActive()) {
final Set<TopicPartition> currentAssignment =
mainConsumer.assignment();
final Set<TopicPartition> taskInputPartitions =
task.inputPartitions();
@@ -273,12 +272,12 @@ public class TaskManager {
final Set<Task> tasksToCloseDirty = new TreeSet<>(byId);
// first rectify all existing tasks
- for (final Task task : tasks.values()) {
+ for (final Task task : tasks.allTasks()) {
if (activeTasks.containsKey(task.id()) && task.isActive()) {
- updateInputPartitionsAndResume(task,
activeTasks.get(task.id()));
+ tasks.updateInputPartitionsAndResume(task,
activeTasks.get(task.id()));
activeTasksToCreate.remove(task.id());
} else if (standbyTasks.containsKey(task.id()) &&
!task.isActive()) {
- updateInputPartitionsAndResume(task,
standbyTasks.get(task.id()));
+ tasks.updateInputPartitionsAndResume(task,
standbyTasks.get(task.id()));
standbyTasksToCreate.remove(task.id());
} else if (activeTasks.containsKey(task.id()) ||
standbyTasks.containsKey(task.id())) {
// check for tasks that were owned previously but have changed
active/standby status
@@ -289,7 +288,14 @@ public class TaskManager {
}
// close and recycle those tasks
- handleCloseAndRecycle(tasksToRecycle, tasksToCloseClean,
tasksToCloseDirty, activeTasksToCreate, standbyTasksToCreate,
taskCloseExceptions);
+ handleCloseAndRecycle(
+ tasksToRecycle,
+ tasksToCloseClean,
+ tasksToCloseDirty,
+ activeTasksToCreate,
+ standbyTasksToCreate,
+ taskCloseExceptions
+ );
if (!taskCloseExceptions.isEmpty()) {
log.error("Hit exceptions while closing / recycling tasks: {}",
taskCloseExceptions);
@@ -313,17 +319,7 @@ public class TaskManager {
throw first.getValue();
}
- if (!activeTasksToCreate.isEmpty()) {
- for (final Task task : activeTaskCreator.createTasks(mainConsumer,
activeTasksToCreate)) {
- addNewTask(task);
- }
- }
-
- if (!standbyTasksToCreate.isEmpty()) {
- for (final Task task :
standbyTaskCreator.createTasks(standbyTasksToCreate)) {
- addNewTask(task);
- }
- }
+ tasks.createTasks(activeTasksToCreate, standbyTasksToCreate);
}
private void handleCloseAndRecycle(final Set<Task> tasksToRecycle,
@@ -377,8 +373,9 @@ public class TaskManager {
for (final Task task : tasksToCloseClean) {
try {
completeTaskCloseClean(task);
- cleanUpTaskProducer(task, taskCloseExceptions);
- tasks.remove(task.id());
+ if (task.isActive()) {
+ tasks.cleanUpTaskProducerAndRemoveTask(task.id(),
taskCloseExceptions);
+ }
} catch (final RuntimeException e) {
final String uncleanMessage = String.format(
"Failed to close task %s cleanly. Attempting to close
remaining tasks before re-throwing:",
@@ -390,71 +387,28 @@ public class TaskManager {
}
tasksToRecycle.removeAll(tasksToCloseDirty);
- for (final Task task : tasksToRecycle) {
+ for (final Task oldTask : tasksToRecycle) {
final Task newTask;
try {
- if (task.isActive()) {
- final Set<TopicPartition> partitions =
standbyTasksToCreate.remove(task.id());
- newTask =
standbyTaskCreator.createStandbyTaskFromActive((StreamTask) task, partitions);
- cleanUpTaskProducer(task, taskCloseExceptions);
+ if (oldTask.isActive()) {
+ final Set<TopicPartition> partitions =
standbyTasksToCreate.remove(oldTask.id());
+ tasks.convertActiveToStandby((StreamTask) oldTask,
partitions, taskCloseExceptions);
} else {
- final Set<TopicPartition> partitions =
activeTasksToCreate.remove(task.id());
- newTask =
activeTaskCreator.createActiveTaskFromStandby((StandbyTask) task, partitions,
mainConsumer);
+ final Set<TopicPartition> partitions =
activeTasksToCreate.remove(oldTask.id());
+ tasks.convertStandbyToActive((StandbyTask) oldTask,
partitions);
}
- tasks.remove(task.id());
- addNewTask(newTask);
} catch (final RuntimeException e) {
- final String uncleanMessage = String.format("Failed to recycle
task %s cleanly. Attempting to close remaining tasks before re-throwing:",
task.id());
+ final String uncleanMessage = String.format("Failed to recycle
task %s cleanly. Attempting to close remaining tasks before re-throwing:",
oldTask.id());
log.error(uncleanMessage, e);
- taskCloseExceptions.putIfAbsent(task.id(), e);
- tasksToCloseDirty.add(task);
+ taskCloseExceptions.putIfAbsent(oldTask.id(), e);
+ tasksToCloseDirty.add(oldTask);
}
}
// for tasks that cannot be cleanly closed or recycled, close them
dirty
for (final Task task : tasksToCloseDirty) {
closeTaskDirty(task);
- cleanUpTaskProducer(task, taskCloseExceptions);
- tasks.remove(task.id());
- }
- }
-
- private void cleanUpTaskProducer(final Task task,
- final Map<TaskId, RuntimeException>
taskCloseExceptions) {
- if (task.isActive()) {
- try {
-
activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
- } catch (final RuntimeException e) {
- final String uncleanMessage = String.format("Failed to close
task %s cleanly. Attempting to close remaining tasks before re-throwing:",
task.id());
- log.error(uncleanMessage, e);
- taskCloseExceptions.putIfAbsent(task.id(), e);
- }
- }
- }
-
- private void updateInputPartitionsAndResume(final Task task, final
Set<TopicPartition> topicPartitions) {
- final boolean requiresUpdate =
!task.inputPartitions().equals(topicPartitions);
- if (requiresUpdate) {
- log.debug("Update task {} inputPartitions: current {}, new {}",
task, task.inputPartitions(), topicPartitions);
- for (final TopicPartition inputPartition : task.inputPartitions())
{
- partitionToTask.remove(inputPartition);
- }
- for (final TopicPartition topicPartition : topicPartitions) {
- partitionToTask.put(topicPartition, task);
- }
- task.updateInputPartitions(topicPartitions,
builder.nodeToSourceTopics());
- }
- task.resume();
- }
-
- private void addNewTask(final Task task) {
- final Task previous = tasks.put(task.id(), task);
- if (previous != null) {
- throw new IllegalStateException("Attempted to create a task that
we already owned: " + task.id());
- }
-
- for (final TopicPartition topicPartition : task.inputPartitions()) {
- partitionToTask.put(topicPartition, task);
+ tasks.cleanUpTaskProducerAndRemoveTask(task.id(),
taskCloseExceptions);
}
}
@@ -469,7 +423,7 @@ public class TaskManager {
boolean allRunning = true;
final List<Task> activeTasks = new LinkedList<>();
- for (final Task task : tasks.values()) {
+ for (final Task task : tasks.allTasks()) {
try {
task.initializeIfNeeded();
task.clearTaskTimeout();
@@ -647,25 +601,19 @@ public class TaskManager {
void handleLostAll() {
log.debug("Closing lost active tasks as zombies.");
- final Iterator<Task> iterator = tasks.values().iterator();
- while (iterator.hasNext()) {
- final Task task = iterator.next();
+ final Set<Task> allTask = new HashSet<>(tasks.allTasks());
+ for (final Task task : allTask) {
// Even though we've apparently dropped out of the group, we can
continue safely to maintain our
// standby tasks while we rejoin.
if (task.isActive()) {
closeTaskDirty(task);
- iterator.remove();
- try {
-
activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
- } catch (final RuntimeException e) {
- log.warn("Error closing task producer for " + task.id() +
" while handling lostAll", e);
- }
+ tasks.cleanUpTaskProducerAndRemoveTask(task.id(), new
HashMap<>());
}
}
if (processingMode == EXACTLY_ONCE_BETA) {
- activeTaskCreator.reInitializeThreadProducer();
+ tasks.reInitializeThreadProducer();
}
}
@@ -680,8 +628,8 @@ public class TaskManager {
// Not all tasks will create directories, and there may be directories
for tasks we don't currently own,
// so we consider all tasks that are either owned or on disk. This
includes stateless tasks, which should
// just have an empty changelogOffsets map.
- for (final TaskId id : union(HashSet::new, lockedTaskDirectories,
tasks.keySet())) {
- final Task task = tasks.get(id);
+ for (final TaskId id : union(HashSet::new, lockedTaskDirectories,
tasks.tasksPerId().keySet())) {
+ final Task task = tasks.owned(id) ? tasks.task(id) : null;
// Closed and uninitialized tasks don't have any offsets so we
should read directly from the checkpoint
if (task != null && task.state() != State.CREATED && task.state()
!= State.CLOSED) {
final Map<TopicPartition, Long> changelogOffsets =
task.changelogOffsets();
@@ -722,7 +670,7 @@ public class TaskManager {
try {
if (stateDirectory.lock(id)) {
lockedTaskDirectories.add(id);
- if (!tasks.containsKey(id)) {
+ if (!tasks.owned(id)) {
log.debug("Temporarily locked unassigned task {}
for the upcoming rebalance", id);
}
}
@@ -746,7 +694,7 @@ public class TaskManager {
final Iterator<TaskId> taskIdIterator =
lockedTaskDirectories.iterator();
while (taskIdIterator.hasNext()) {
final TaskId id = taskIdIterator.next();
- if (!tasks.containsKey(id)) {
+ if (!tasks.owned(id)) {
try {
stateDirectory.unlock(id);
taskIdIterator.remove();
@@ -804,26 +752,22 @@ public class TaskManager {
} catch (final RuntimeException swallow) {
log.error("Error suspending dirty task {} ", task.id(), swallow);
}
- cleanupTask(task);
+ tasks.removeTaskBeforeClosing(task.id());
task.closeDirty();
}
private void completeTaskCloseClean(final Task task) {
- cleanupTask(task);
+ tasks.removeTaskBeforeClosing(task.id());
task.closeClean();
}
- // Note: this MUST be called *before* actually closing the task
- private void cleanupTask(final Task task) {
- for (final TopicPartition inputPartition : task.inputPartitions()) {
- partitionToTask.remove(inputPartition);
- }
- }
-
void shutdown(final boolean clean) {
final AtomicReference<RuntimeException> firstException = new
AtomicReference<>(null);
final Set<Task> tasksToCloseDirty = new HashSet<>();
+ // TODO: change type to `StreamTask`
+ final Set<Task> activeTasks = new
TreeSet<>(Comparator.comparing(Task::id));
+ activeTasks.addAll(tasks.activeTasks());
tasksToCloseDirty.addAll(tryCloseCleanAllActiveTasks(clean,
firstException));
tasksToCloseDirty.addAll(tryCloseCleanAllStandbyTasks(clean,
firstException));
@@ -831,18 +775,19 @@ public class TaskManager {
closeTaskDirty(task);
}
- for (final Task task : activeTaskIterable()) {
+ // TODO: change type to `StreamTask`
+ for (final Task activeTask : activeTasks) {
executeAndMaybeSwallow(
clean,
- () ->
activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id()),
+ () -> tasks.closeAndRemoveTaskProducerIfNeeded(activeTask),
e -> firstException.compareAndSet(null, e),
- e -> log.warn("Ignoring an exception while closing task " +
task.id() + " producer.", e)
+ e -> log.warn("Ignoring an exception while closing task " +
activeTask.id() + " producer.", e)
);
}
executeAndMaybeSwallow(
clean,
- activeTaskCreator::closeThreadProducerIfNeeded,
+ tasks::closeThreadProducerIfNeeded,
e -> firstException.compareAndSet(null, e),
e -> log.warn("Ignoring an exception while closing thread
producer.", e)
);
@@ -987,7 +932,7 @@ public class TaskManager {
Map<TaskId, Task> tasks() {
// not bothering with an unmodifiable map, since the tasks themselves
are mutable, but
// if any outside code modifies the map or the tasks, it would be a
severe transgression.
- return tasks;
+ return tasks.tasksPerId();
}
Map<TaskId, Task> activeTaskMap() {
@@ -999,7 +944,7 @@ public class TaskManager {
}
private Stream<Task> activeTaskStream() {
- return tasks.values().stream().filter(Task::isActive);
+ return tasks.allTasks().stream().filter(Task::isActive);
}
Map<TaskId, Task> standbyTaskMap() {
@@ -1011,12 +956,12 @@ public class TaskManager {
}
private Stream<Task> standbyTaskStream() {
- return tasks.values().stream().filter(t -> !t.isActive());
+ return tasks.allTasks().stream().filter(t -> !t.isActive());
}
// For testing only.
int commitAll() {
- return commit(new HashSet<>(tasks.values()));
+ return commit(new HashSet<>(tasks.allTasks()));
}
/**
@@ -1026,15 +971,16 @@ public class TaskManager {
*/
void addRecordsToTasks(final ConsumerRecords<byte[], byte[]> records) {
for (final TopicPartition partition : records.partitions()) {
- final Task task = partitionToTask.get(partition);
+ // TODO: change type to `StreamTask`
+ final Task activeTask =
tasks.activeTasksForInputPartition(partition);
- if (task == null) {
+ if (activeTask == null) {
log.error("Unable to locate active task for received-record
partition {}. Current tasks: {}",
partition, toString(">"));
throw new NullPointerException("Task was unexpectedly missing
for partition " + partition);
}
- task.addRecords(partition, records.records(partition));
+ activeTask.addRecords(partition, records.records(partition));
}
}
@@ -1106,7 +1052,7 @@ public class TaskManager {
}
}
- private void commitOffsetsOrTransaction(final Map<Task,
Map<TopicPartition, OffsetAndMetadata>> offsetsPerTask) throws
TaskTimeoutExceptions {
+ private void commitOffsetsOrTransaction(final Map<Task,
Map<TopicPartition, OffsetAndMetadata>> offsetsPerTask) {
log.debug("Committing task offsets {}", offsetsPerTask);
TaskTimeoutExceptions timeoutExceptions = null;
@@ -1116,7 +1062,7 @@ public class TaskManager {
for (final Map.Entry<Task, Map<TopicPartition,
OffsetAndMetadata>> taskToCommit : offsetsPerTask.entrySet()) {
final Task task = taskToCommit.getKey();
try {
- activeTaskCreator.streamsProducerForTask(task.id())
+ tasks.streamsProducerForTask(task.id())
.commitTransaction(taskToCommit.getValue(),
mainConsumer.groupMetadata());
} catch (final TimeoutException timeoutException) {
if (timeoutExceptions == null) {
@@ -1131,7 +1077,7 @@ public class TaskManager {
if (processingMode == EXACTLY_ONCE_BETA) {
try {
-
activeTaskCreator.threadProducer().commitTransaction(allOffsets,
mainConsumer.groupMetadata());
+ tasks.threadProducer().commitTransaction(allOffsets,
mainConsumer.groupMetadata());
} catch (final TimeoutException timeoutException) {
throw new TaskTimeoutExceptions(timeoutException);
}
@@ -1258,7 +1204,7 @@ public class TaskManager {
stringBuilder.append("TaskManager\n");
stringBuilder.append(indent).append("\tMetadataState:\n");
stringBuilder.append(indent).append("\tTasks:\n");
- for (final Task task : tasks.values()) {
+ for (final Task task : tasks.allTasks()) {
stringBuilder.append(indent)
.append("\t\t")
.append(task.id())
@@ -1272,11 +1218,11 @@ public class TaskManager {
}
Map<MetricName, Metric> producerMetrics() {
- return activeTaskCreator.producerMetrics();
+ return tasks.producerMetrics();
}
Set<String> producerClientIds() {
- return activeTaskCreator.producerClientIds();
+ return tasks.producerClientIds();
}
Set<TaskId> lockedTaskDirectories() {
@@ -1314,4 +1260,9 @@ public class TaskManager {
public void setPartitionResetter(final
java.util.function.Consumer<Set<TopicPartition>> resetter) {
this.resetter = resetter;
}
+
+ // for testing only
+ void addTask(final Task task) {
+ tasks.addTask(task);
+ }
}
diff --git
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
new file mode 100644
index 0000000..843db80
--- /dev/null
+++
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.slf4j.Logger;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+
+class Tasks {
+ private final Logger log;
+ private final InternalTopologyBuilder builder;
+ private final StreamsMetricsImpl streamsMetrics;
+
+ private final Map<TaskId, Task> allTasksPerId = new TreeMap<>();
+ private final Map<TaskId, Task> readOnlyTasksPerId =
Collections.unmodifiableMap(allTasksPerId);
+ private final Collection<Task> readOnlyTasks =
Collections.unmodifiableCollection(allTasksPerId.values());
+
+ // TODO: change type to `StreamTask`
+ private final Map<TaskId, Task> activeTasksPerId = new TreeMap<>();
+ // TODO: change type to `StreamTask`
+ private final Map<TopicPartition, Task> activeTasksPerPartition = new
HashMap<>();
+ // TODO: change type to `StreamTask`
+ private final Map<TaskId, Task> readOnlyActiveTasksPerId =
Collections.unmodifiableMap(activeTasksPerId);
+ private final Set<TaskId> readOnlyActiveTaskIds =
Collections.unmodifiableSet(activeTasksPerId.keySet());
+ // TODO: change type to `StreamTask`
+ private final Collection<Task> readOnlyActiveTasks =
Collections.unmodifiableCollection(activeTasksPerId.values());
+
+ // TODO: change type to `StandbyTask`
+ private final Map<TaskId, Task> standbyTasksPerId = new TreeMap<>();
+ // TODO: change type to `StandbyTask`
+ private final Map<TaskId, Task> readOnlyStandbyTasksPerId =
Collections.unmodifiableMap(standbyTasksPerId);
+ private final Set<TaskId> readOnlyStandbyTaskIds =
Collections.unmodifiableSet(standbyTasksPerId.keySet());
+
+ private final ActiveTaskCreator activeTaskCreator;
+ private final StandbyTaskCreator standbyTaskCreator;
+
+ private Consumer<byte[], byte[]> mainConsumer;
+
+ Tasks(final String logPrefix,
+ final InternalTopologyBuilder builder,
+ final StreamsMetricsImpl streamsMetrics,
+ final ActiveTaskCreator activeTaskCreator,
+ final StandbyTaskCreator standbyTaskCreator) {
+
+ final LogContext logContext = new LogContext(logPrefix);
+ log = logContext.logger(getClass());
+
+ this.builder = builder;
+ this.streamsMetrics = streamsMetrics;
+ this.activeTaskCreator = activeTaskCreator;
+ this.standbyTaskCreator = standbyTaskCreator;
+ }
+
+ void setMainConsumer(final Consumer<byte[], byte[]> mainConsumer) {
+ this.mainConsumer = mainConsumer;
+ }
+
+ void createTasks(final Map<TaskId, Set<TopicPartition>>
activeTasksToCreate,
+ final Map<TaskId, Set<TopicPartition>>
standbyTasksToCreate) {
+ for (final Map.Entry<TaskId, Set<TopicPartition>> taskToBeCreated :
activeTasksToCreate.entrySet()) {
+ final TaskId taskId = taskToBeCreated.getKey();
+
+ if (activeTasksPerId.containsKey(taskId)) {
+ throw new IllegalStateException("Attempted to create an active
task that we already own: " + taskId);
+ }
+ }
+
+ for (final Map.Entry<TaskId, Set<TopicPartition>> taskToBeCreated :
standbyTasksToCreate.entrySet()) {
+ final TaskId taskId = taskToBeCreated.getKey();
+
+ if (standbyTasksPerId.containsKey(taskId)) {
+ throw new IllegalStateException("Attempted to create a standby
task that we already own: " + taskId);
+ }
+ }
+
+ // keep this check to simplify testing (ie, no need to mock
`activeTaskCreator`)
+ if (!activeTasksToCreate.isEmpty()) {
+ // TODO: change type to `StreamTask`
+ for (final Task activeTask :
activeTaskCreator.createTasks(mainConsumer, activeTasksToCreate)) {
+ activeTasksPerId.put(activeTask.id(), activeTask);
+ allTasksPerId.put(activeTask.id(), activeTask);
+ for (final TopicPartition topicPartition :
activeTask.inputPartitions()) {
+ activeTasksPerPartition.put(topicPartition, activeTask);
+ }
+ }
+ }
+
+ // keep this check to simplify testing (ie, no need to mock
`standbyTaskCreator`)
+ if (!standbyTasksToCreate.isEmpty()) {
+ // TODO: change type to `StandbyTask`
+ for (final Task standbyTask :
standbyTaskCreator.createTasks(standbyTasksToCreate)) {
+ standbyTasksPerId.put(standbyTask.id(), standbyTask);
+ allTasksPerId.put(standbyTask.id(), standbyTask);
+ }
+ }
+ }
+
+ void convertActiveToStandby(final StreamTask activeTask,
+ final Set<TopicPartition> partitions,
+ final Map<TaskId, RuntimeException>
taskCloseExceptions) {
+ if (activeTasksPerId.remove(activeTask.id()) == null) {
+ throw new IllegalStateException("Attempted to convert unknown
active task to standby task: " + activeTask.id());
+ }
+ activeTasksPerPartition.entrySet().stream()
+ .filter(e -> e.getValue().id().equals(activeTask.id()))
+ .forEach(e -> activeTasksPerPartition.remove(e.getKey()));
+
+ cleanUpTaskProducerAndRemoveTask(activeTask.id(), taskCloseExceptions);
+
+ final StandbyTask standbyTask =
standbyTaskCreator.createStandbyTaskFromActive(activeTask, partitions);
+ standbyTasksPerId.put(standbyTask.id(), standbyTask);
+ allTasksPerId.put(standbyTask.id(), standbyTask);
+ }
+
+ void convertStandbyToActive(final StandbyTask standbyTask, final
Set<TopicPartition> partitions) {
+ if (standbyTasksPerId.remove(standbyTask.id()) == null) {
+ throw new IllegalStateException("Attempted to convert unknown
standby task to stream task: " + standbyTask.id());
+ }
+
+ final StreamTask activeTask =
activeTaskCreator.createActiveTaskFromStandby(standbyTask, partitions,
mainConsumer);
+ activeTasksPerId.put(activeTask.id(), activeTask);
+ for (final TopicPartition topicPartition :
activeTask.inputPartitions()) {
+ activeTasksPerPartition.put(topicPartition, activeTask);
+ }
+ allTasksPerId.put(activeTask.id(), activeTask);
+ }
+
+ void updateInputPartitionsAndResume(final Task task, final
Set<TopicPartition> topicPartitions) {
+ final boolean requiresUpdate =
!task.inputPartitions().equals(topicPartitions);
+ if (requiresUpdate) {
+ log.debug("Update task {} inputPartitions: current {}, new {}",
task, task.inputPartitions(), topicPartitions);
+ for (final TopicPartition inputPartition : task.inputPartitions())
{
+ activeTasksPerPartition.remove(inputPartition);
+ }
+ if (task.isActive()) {
+ for (final TopicPartition topicPartition : topicPartitions) {
+ activeTasksPerPartition.put(topicPartition, task);
+ }
+ }
+ task.updateInputPartitions(topicPartitions,
builder.nodeToSourceTopics());
+ }
+ task.resume();
+ }
+
+ void cleanUpTaskProducerAndRemoveTask(final TaskId taskId,
+ final Map<TaskId, RuntimeException>
taskCloseExceptions) {
+ try {
+ activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId);
+ } catch (final RuntimeException e) {
+ final String uncleanMessage = String.format("Failed to close task
%s cleanly. Attempting to close remaining tasks before re-throwing:", taskId);
+ log.error(uncleanMessage, e);
+ taskCloseExceptions.putIfAbsent(taskId, e);
+ }
+ removeTaskBeforeClosing(taskId);
+ }
+
+ void reInitializeThreadProducer() {
+ activeTaskCreator.reInitializeThreadProducer();
+ }
+
+ void closeThreadProducerIfNeeded() {
+ activeTaskCreator.closeThreadProducerIfNeeded();
+ }
+
+ // TODO: change type to `StreamTask`
+ void closeAndRemoveTaskProducerIfNeeded(final Task activeTask) {
+ activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(activeTask.id());
+ }
+
+ void removeTaskBeforeClosing(final TaskId taskId) {
+ activeTasksPerId.remove(taskId);
+ final Set<TopicPartition> toBeRemoved =
activeTasksPerPartition.entrySet().stream()
+ .filter(e -> e.getValue().id().equals(taskId))
+ .map(Map.Entry::getKey)
+ .collect(Collectors.toSet());
+ toBeRemoved.forEach(activeTasksPerPartition::remove);
+ standbyTasksPerId.remove(taskId);
+ allTasksPerId.remove(taskId);
+ }
+
+ void clear() {
+ activeTasksPerId.clear();
+ activeTasksPerPartition.clear();
+ standbyTasksPerId.clear();
+ allTasksPerId.clear();
+ }
+
+ // TODO: change return type to `StreamTask`
+ Task activeTasksForInputPartition(final TopicPartition partition) {
+ return activeTasksPerPartition.get(partition);
+ }
+
+ // TODO: change return type to `StandbyTask`
+ Task standbyTask(final TaskId taskId) {
+ if (!standbyTasksPerId.containsKey(taskId)) {
+ throw new IllegalStateException("Standby task unknown: " + taskId);
+ }
+ return standbyTasksPerId.get(taskId);
+ }
+
+ Task task(final TaskId taskId) {
+ if (!allTasksPerId.containsKey(taskId)) {
+ throw new IllegalStateException("Task unknown: " + taskId);
+ }
+ return allTasksPerId.get(taskId);
+ }
+
+ // TODO: change return type to `StreamTask`
+ Collection<Task> activeTasks() {
+ return readOnlyActiveTasks;
+ }
+
+ Collection<Task> allTasks() {
+ return readOnlyTasks;
+ }
+
+ Set<TaskId> activeTaskIds() {
+ return readOnlyActiveTaskIds;
+ }
+
+ Set<TaskId> standbyTaskIds() {
+ return readOnlyStandbyTaskIds;
+ }
+
+ // TODO: change return type to `StreamTask`
+ Map<TaskId, Task> activeTaskMap() {
+ return readOnlyActiveTasksPerId;
+ }
+
+ // TODO: change return type to `StandbyTask`
+ Map<TaskId, Task> standbyTaskMap() {
+ return readOnlyStandbyTasksPerId;
+ }
+
+ Map<TaskId, Task> tasksPerId() {
+ return readOnlyTasksPerId;
+ }
+
+ boolean owned(final TaskId taskId) {
+ return allTasksPerId.containsKey(taskId);
+ }
+
+ StreamsProducer streamsProducerForTask(final TaskId taskId) {
+ return activeTaskCreator.streamsProducerForTask(taskId);
+ }
+
+ StreamsProducer threadProducer() {
+ return activeTaskCreator.threadProducer();
+ }
+
+ Map<MetricName, Metric> producerMetrics() {
+ return activeTaskCreator.producerMetrics();
+ }
+
+ Set<String> producerClientIds() {
+ return activeTaskCreator.producerClientIds();
+ }
+
+ // for testing only
+ void addTask(final Task task) {
+ if (task.isActive()) {
+ activeTasksPerId.put(task.id(), task);
+ } else {
+ standbyTasksPerId.put(task.id(), task);
+ }
+ allTasksPerId.put(task.id(), task);
+ }
+}
diff --git
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 157f9e5..401a830 100644
---
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -767,6 +767,7 @@ public class StreamThreadTest {
null,
null,
null,
+ null,
null
) {
@Override
@@ -2564,6 +2565,7 @@ public class StreamThreadTest {
);
}
+ // TODO: change return type to `StandbyTask`
private Collection<Task> createStandbyTask() {
final LogContext logContext = new LogContext("test");
final Logger log = logContext.logger(StreamThreadTest.class);
diff --git
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 0eb7817..04bd252 100644
---
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -33,8 +33,10 @@ import org.apache.kafka.common.errors.TimeoutException;
import org.apache.kafka.common.internals.KafkaFutureImpl;
import org.apache.kafka.common.metrics.KafkaMetric;
import org.apache.kafka.common.metrics.Measurable;
+import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.errors.LockException;
import org.apache.kafka.streams.errors.StreamsException;
import org.apache.kafka.streams.errors.TaskMigratedException;
@@ -42,6 +44,7 @@ import org.apache.kafka.streams.processor.StateStore;
import org.apache.kafka.streams.processor.TaskId;
import
org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode;
import org.apache.kafka.streams.processor.internals.Task.State;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
import
org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
import org.easymock.EasyMock;
@@ -171,6 +174,7 @@ public class TaskManagerTest {
changeLogReader,
UUID.randomUUID(),
"taskManagerTest",
+ new StreamsMetricsImpl(new Metrics(), "clientId",
StreamsConfig.METRICS_LATEST, time),
activeTaskCreator,
standbyTaskCreator,
topologyBuilder,
@@ -1495,8 +1499,8 @@ public class TaskManagerTest {
};
task01.setCommitNeeded();
- taskManager.tasks().put(taskId00, task00);
- taskManager.tasks().put(taskId01, task01);
+ taskManager.addTask(task00);
+ taskManager.addTask(task01);
final RuntimeException thrown = assertThrows(RuntimeException.class,
() -> taskManager.handleAssignment(
@@ -1514,8 +1518,6 @@ public class TaskManagerTest {
@Test
public void
shouldSuspendAllRevokedActiveTasksAndPropagateSuspendException() {
- setUpTaskManager(StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA);
-
final Task task00 = new StateMachineTask(taskId00, taskId00Partitions,
true);
final StateMachineTask task01 = new StateMachineTask(taskId01,
taskId01Partitions, true) {
@@ -1528,9 +1530,9 @@ public class TaskManagerTest {
final StateMachineTask task02 = new StateMachineTask(taskId02,
taskId02Partitions, true);
- taskManager.tasks().put(taskId00, task00);
- taskManager.tasks().put(taskId01, task01);
- taskManager.tasks().put(taskId02, task02);
+ taskManager.addTask(task00);
+ taskManager.addTask(task01);
+ taskManager.addTask(task02);
replay(activeTaskCreator);
@@ -1857,7 +1859,7 @@ public class TaskManagerTest {
final Map<TopicPartition, OffsetAndMetadata> offsets =
singletonMap(t1p1, new OffsetAndMetadata(0L, null));
task01.setCommittableOffsetsAndMetadata(offsets);
task01.setCommitNeeded();
- taskManager.tasks().put(taskId01, task01);
+ taskManager.addTask(task01);
consumer.commitSync(offsets);
expectLastCall();
@@ -1912,11 +1914,11 @@ public class TaskManagerTest {
final StateMachineTask task01 = new StateMachineTask(taskId01,
taskId01Partitions, true);
task01.setCommittableOffsetsAndMetadata(offsetsT01);
task01.setCommitNeeded();
- taskManager.tasks().put(taskId01, task01);
+ taskManager.addTask(task01);
final StateMachineTask task02 = new StateMachineTask(taskId02,
taskId02Partitions, true);
task02.setCommittableOffsetsAndMetadata(offsetsT02);
task02.setCommitNeeded();
- taskManager.tasks().put(taskId02, task02);
+ taskManager.addTask(task02);
reset(consumer);
expect(consumer.groupMetadata()).andReturn(new
ConsumerGroupMetadata("appId")).anyTimes();
@@ -2400,8 +2402,8 @@ public class TaskManagerTest {
throw new TaskMigratedException("t2 close exception", new
RuntimeException());
}
};
- taskManager.tasks().put(taskId01, migratedTask01);
- taskManager.tasks().put(taskId02, migratedTask02);
+ taskManager.addTask(migratedTask01);
+ taskManager.addTask(migratedTask02);
final TaskMigratedException thrown = assertThrows(
TaskMigratedException.class,
@@ -2432,8 +2434,8 @@ public class TaskManagerTest {
throw new IllegalStateException("t2 illegal state exception",
new RuntimeException());
}
};
- taskManager.tasks().put(taskId01, migratedTask01);
- taskManager.tasks().put(taskId02, migratedTask02);
+ taskManager.addTask(migratedTask01);
+ taskManager.addTask(migratedTask02);
final RuntimeException thrown = assertThrows(
RuntimeException.class,
@@ -2463,8 +2465,8 @@ public class TaskManagerTest {
throw new KafkaException("Kaboom for t2!", new
RuntimeException());
}
};
- taskManager.tasks().put(taskId01, migratedTask01);
- taskManager.tasks().put(taskId02, migratedTask02);
+ taskManager.addTask(migratedTask01);
+ taskManager.addTask(migratedTask02);
final KafkaException thrown = assertThrows(
KafkaException.class,
@@ -2573,7 +2575,7 @@ public class TaskManagerTest {
final Map<TopicPartition, OffsetAndMetadata> offsets =
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
task01.setCommittableOffsetsAndMetadata(offsets);
task01.setCommitNeeded();
- taskManager.tasks().put(taskId01, task01);
+ taskManager.addTask(task01);
consumer.commitSync(offsets);
expectLastCall().andThrow(new CommitFailedException());
@@ -2712,7 +2714,7 @@ public class TaskManagerTest {
final Map<TopicPartition, OffsetAndMetadata> offsets =
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
task01.setCommittableOffsetsAndMetadata(offsets);
task01.setCommitNeeded();
- taskManager.tasks().put(taskId01, task01);
+ taskManager.addTask(task01);
consumer.commitSync(offsets);
expectLastCall().andThrow(new KafkaException());
@@ -2734,7 +2736,7 @@ public class TaskManagerTest {
final Map<TopicPartition, OffsetAndMetadata> offsets =
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
task01.setCommittableOffsetsAndMetadata(offsets);
task01.setCommitNeeded();
- taskManager.tasks().put(taskId01, task01);
+ taskManager.addTask(task01);
consumer.commitSync(offsets);
expectLastCall().andThrow(new RuntimeException("KABOOM"));