This is an automated email from the ASF dual-hosted git repository. guozhang pushed a commit to branch trunk in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/trunk by this push: new 083e11a22ca KAFKA-14650: Synchronize access to tasks inside task manager (#13167) 083e11a22ca is described below commit 083e11a22ca9966ed010acdd5705351ab4300b52 Author: Guozhang Wang <wangg...@gmail.com> AuthorDate: Thu Feb 9 10:33:19 2023 -0800 KAFKA-14650: Synchronize access to tasks inside task manager (#13167) 1. The major fix: synchronize access to tasks inside the task manager; this fixes a regression introduced in #12397. 2. Clarify the names of StreamThread functions that may be triggered from outside the StreamThread. 3. Minor cleanups. Reviewers: Lucas Brutschy <lucas...@users.noreply.github.com> --- .../org/apache/kafka/streams/KafkaStreams.java | 16 ++++++------- .../streams/processor/internals/StreamThread.java | 24 ++++++++++++-------- .../streams/processor/internals/TaskManager.java | 6 +++++ .../kafka/streams/processor/internals/Tasks.java | 22 +++++++++--------- .../KafkaStreamsNamedTopologyWrapper.java | 2 +- .../internals/StreamThreadStateStoreProvider.java | 2 +- .../org/apache/kafka/streams/KafkaStreamsTest.java | 5 ++--- .../processor/internals/StreamThreadTest.java | 26 +++++++++++----------- .../StreamThreadStateStoreProviderTest.java | 4 ++-- 9 files changed, 59 insertions(+), 48 deletions(-) diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java index bc3a1243ce2..ee9f57b0680 100644 --- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java +++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java @@ -1826,7 +1826,7 @@ public class KafkaStreams implements AutoCloseable { */ public Map<String, Map<Integer, LagInfo>> allLocalStorePartitionLags() { final List<Task> allTasks = new ArrayList<>(); - processStreamThread(thread -> allTasks.addAll(thread.allTasks().values())); + processStreamThread(thread -> allTasks.addAll(thread.readyOnlyAllTasks())); return
allLocalStorePartitionLags(allTasks); } @@ -1917,21 +1917,19 @@ public class KafkaStreams implements AutoCloseable { ); } else { for (final StreamThread thread : threads) { - final Map<TaskId, Task> tasks = thread.allTasks(); - for (final Entry<TaskId, Task> entry : tasks.entrySet()) { + final Set<Task> tasks = thread.readyOnlyAllTasks(); + for (final Task task : tasks) { - final TaskId taskId = entry.getKey(); + final TaskId taskId = task.id(); final int partition = taskId.partition(); - if (request.isAllPartitions() - || request.getPartitions().contains(partition)) { - final Task task = entry.getValue(); + if (request.isAllPartitions() || request.getPartitions().contains(partition)) { final StateStore store = task.getStore(storeName); if (store != null) { final StreamThread.State state = thread.state(); final boolean active = task.isActive(); if (request.isRequireActive() - && (state != StreamThread.State.RUNNING - || !active)) { + && (state != StreamThread.State.RUNNING || !active)) { + result.addResult( partition, QueryResult.forFailure( diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java index 2d433eb3c82..d4be6a83af9 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java @@ -60,7 +60,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -1269,16 +1268,23 @@ public class StreamThread extends Thread { ); } - public Map<TaskId, Task> activeTaskMap() { - return taskManager.activeTaskMap(); - } - - public List<Task> activeTasks() { - return taskManager.activeTaskIterable(); + /** + * Getting the list of current active tasks of the thread; + * 
Note that the returned list may be used by other thread than the StreamThread itself, + * and hence need to be read-only + */ + public Set<Task> readOnlyActiveTasks() { + return readyOnlyAllTasks().stream() + .filter(Task::isActive).collect(Collectors.toSet()); } - public Map<TaskId, Task> allTasks() { - return taskManager.allTasks(); + /** + * Getting the list of all owned tasks of the thread, including both active and standby; + * Note that the returned list may be used by other thread than the StreamThread itself, + * and hence need to be read-only + */ + public Set<Task> readyOnlyAllTasks() { + return taskManager.readOnlyAllTasks(); } /** diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java index 583662f0924..2a5b3d2bacf 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java @@ -1537,6 +1537,12 @@ public class TaskManager { return tasks.allTasksPerId(); } + Set<Task> readOnlyAllTasks() { + // need to make sure the returned set is unmodifiable as it could be accessed + // by other thread than the StreamThread owning this task manager; + return Collections.unmodifiableSet(tasks.allTasks()); + } + Map<TaskId, Task> notPausedTasks() { return Collections.unmodifiableMap(tasks.allTasks() .stream() diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java index 14de0e794fd..7b3f7860fb7 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java @@ -45,6 +45,8 @@ class Tasks implements TasksRegistry { private final Logger log; // TODO: convert to Stream/StandbyTask when we remove 
TaskManager#StateMachineTask with mocks + // note that these two maps may be accessed by concurrent threads and hence + // should be synchronized when accessed private final Map<TaskId, Task> activeTasksPerId = new TreeMap<>(); private final Map<TaskId, Task> standbyTasksPerId = new TreeMap<>(); @@ -203,7 +205,7 @@ class Tasks implements TasksRegistry { } @Override - public void addTask(final Task task) { + public synchronized void addTask(final Task task) { final TaskId taskId = task.id(); if (activeTasksPerId.containsKey(taskId)) { throw new IllegalStateException("Attempted to create an active task that we already own: " + taskId); @@ -225,7 +227,7 @@ class Tasks implements TasksRegistry { } @Override - public void removeTask(final Task taskToRemove) { + public synchronized void removeTask(final Task taskToRemove) { final TaskId taskId = taskToRemove.id(); if (taskToRemove.state() != Task.State.CLOSED) { @@ -245,7 +247,7 @@ class Tasks implements TasksRegistry { } @Override - public void replaceActiveWithStandby(final StandbyTask standbyTask) { + public synchronized void replaceActiveWithStandby(final StandbyTask standbyTask) { final TaskId taskId = standbyTask.id(); if (activeTasksPerId.remove(taskId) == null) { throw new IllegalStateException("Attempted to replace unknown active task with standby task: " + taskId); @@ -256,7 +258,7 @@ class Tasks implements TasksRegistry { } @Override - public void replaceStandbyWithActive(final StreamTask activeTask) { + public synchronized void replaceStandbyWithActive(final StreamTask activeTask) { final TaskId taskId = activeTask.id(); if (standbyTasksPerId.remove(taskId) == null) { throw new IllegalStateException("Attempted to convert unknown standby task to stream task: " + taskId); @@ -295,7 +297,7 @@ class Tasks implements TasksRegistry { } @Override - public void clear() { + public synchronized void clear() { activeTasksPerId.clear(); standbyTasksPerId.clear(); activeTasksPerPartition.clear(); @@ -307,7 +309,7 @@ class 
Tasks implements TasksRegistry { return activeTasksPerPartition.get(partition); } - private Task getTask(final TaskId taskId) { + private synchronized Task getTask(final TaskId taskId) { if (activeTasksPerId.containsKey(taskId)) { return activeTasksPerId.get(taskId); } @@ -337,7 +339,7 @@ class Tasks implements TasksRegistry { } @Override - public Collection<Task> activeTasks() { + public synchronized Collection<Task> activeTasks() { return Collections.unmodifiableCollection(activeTasksPerId.values()); } @@ -346,17 +348,17 @@ class Tasks implements TasksRegistry { * and the returned task could be modified by other threads concurrently */ @Override - public Set<Task> allTasks() { + public synchronized Set<Task> allTasks() { return union(HashSet::new, new HashSet<>(activeTasksPerId.values()), new HashSet<>(standbyTasksPerId.values())); } @Override - public Set<TaskId> allTaskIds() { + public synchronized Set<TaskId> allTaskIds() { return union(HashSet::new, activeTasksPerId.keySet(), standbyTasksPerId.keySet()); } @Override - public Map<TaskId, Task> allTasksPerId() { + public synchronized Map<TaskId, Task> allTasksPerId() { final Map<TaskId, Task> ret = new HashMap<>(); ret.putAll(activeTasksPerId); ret.putAll(standbyTasksPerId); diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java index 3d22c583373..4704d1d4df7 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java @@ -436,7 +436,7 @@ public class KafkaStreamsNamedTopologyWrapper extends KafkaStreams { } final List<Task> allTopologyTasks = new ArrayList<>(); processStreamThread(thread -> allTopologyTasks.addAll( - 
thread.allTasks().values().stream() + thread.readyOnlyAllTasks().stream() .filter(t -> topologyName.equals(t.id().topologyName())) .collect(Collectors.toList()))); return allLocalStorePartitionLags(allTopologyTasks); diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java index cf033252442..a962fe0a677 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java @@ -56,7 +56,7 @@ public class StreamThreadStateStoreProvider { if (storeQueryParams.staleStoresEnabled() ? state.isAlive() : state == StreamThread.State.RUNNING) { final Collection<Task> tasks = storeQueryParams.staleStoresEnabled() ? - streamThread.allTasks().values() : streamThread.activeTasks(); + streamThread.readyOnlyAllTasks() : streamThread.readOnlyActiveTasks(); if (storeQueryParams.partition() != null) { for (final Task task : tasks) { diff --git a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java index 1c0331824e7..fef0353b90b 100644 --- a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java @@ -93,7 +93,6 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForApplicationState; @@ -363,8 +362,8 @@ public class KafkaStreamsTest { }).when(thread).join(); } - 
when(thread.activeTasks()).thenReturn(emptyList()); - when(thread.allTasks()).thenReturn(Collections.emptyMap()); + when(thread.readOnlyActiveTasks()).thenReturn(Collections.emptySet()); + when(thread.readyOnlyAllTasks()).thenReturn(Collections.emptySet()); } @Test diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java index 76574cccb70..dc9259a1f91 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java @@ -1038,7 +1038,7 @@ public class StreamThreadTest { assertEquals(1, clientSupplier.producers.size()); final Producer<byte[], byte[]> globalProducer = clientSupplier.producers.get(0); - for (final Task task : thread.activeTasks()) { + for (final Task task : thread.readOnlyActiveTasks()) { assertSame(globalProducer, ((RecordCollectorImpl) ((StreamTask) task).recordCollector()).producer()); } assertSame(clientSupplier.consumer, thread.mainConsumer()); @@ -1078,7 +1078,7 @@ public class StreamThreadTest { thread.runOnce(); - assertEquals(thread.activeTasks().size(), clientSupplier.producers.size()); + assertEquals(thread.readOnlyActiveTasks().size(), clientSupplier.producers.size()); assertSame(clientSupplier.consumer, thread.mainConsumer()); assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer()); } @@ -1195,7 +1195,7 @@ public class StreamThreadTest { 10 * 1000, "Thread never shut down."); - for (final Task task : thread.activeTasks()) { + for (final Task task : thread.readOnlyActiveTasks()) { assertTrue(((MockProducer<byte[], byte[]>) ((RecordCollectorImpl) ((StreamTask) task).recordCollector()).producer()).closed()); } } @@ -1388,7 +1388,7 @@ public class StreamThreadTest { thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); thread.runOnce(); - 
assertThat(thread.activeTasks().size(), equalTo(1)); + assertThat(thread.readOnlyActiveTasks().size(), equalTo(1)); final MockProducer<byte[], byte[]> producer = clientSupplier.producers.get(0); // change consumer subscription from "pattern" to "manual" to be able to call .addRecords() @@ -1415,7 +1415,7 @@ public class StreamThreadTest { } catch (final KafkaException expected) { assertTrue(expected instanceof TaskMigratedException); assertTrue("StreamsThread removed the fenced zombie task already, should wait for rebalance to close all zombies together.", - thread.activeTasks().stream().anyMatch(task -> task.id().equals(task1))); + thread.readOnlyActiveTasks().stream().anyMatch(task -> task.id().equals(task1))); } assertThat(producer.commitCount(), equalTo(1L)); @@ -1447,7 +1447,7 @@ public class StreamThreadTest { thread.runOnce(); - assertThat(thread.activeTasks().size(), equalTo(1)); + assertThat(thread.readOnlyActiveTasks().size(), equalTo(1)); // need to process a record to enable committing addRecord(mockConsumer, 0L); @@ -1457,7 +1457,7 @@ public class StreamThreadTest { assertThrows(TaskMigratedException.class, () -> thread.rebalanceListener().onPartitionsRevoked(assignedPartitions)); assertFalse(clientSupplier.producers.get(0).transactionCommitted()); assertFalse(clientSupplier.producers.get(0).closed()); - assertEquals(1, thread.activeTasks().size()); + assertEquals(1, thread.readOnlyActiveTasks().size()); } @Test @@ -1523,7 +1523,7 @@ public class StreamThreadTest { // the first iteration completes the restoration thread.runOnce(); - assertThat(thread.activeTasks().size(), equalTo(1)); + assertThat(thread.readOnlyActiveTasks().size(), equalTo(1)); // the second transits to running and unpause the input thread.runOnce(); @@ -1578,7 +1578,7 @@ public class StreamThreadTest { thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); thread.runOnce(); - assertThat(thread.activeTasks().size(), equalTo(1)); + 
assertThat(thread.readOnlyActiveTasks().size(), equalTo(1)); final MockProducer<byte[], byte[]> producer = clientSupplier.producers.get(0); producer.commitTransactionException = new ProducerFencedException("Producer is fenced"); @@ -1590,7 +1590,7 @@ public class StreamThreadTest { } catch (final KafkaException expected) { assertTrue(expected instanceof TaskMigratedException); assertTrue("StreamsThread removed the fenced zombie task already, should wait for rebalance to close all zombies together.", - thread.activeTasks().stream().anyMatch(task -> task.id().equals(task1))); + thread.readOnlyActiveTasks().stream().anyMatch(task -> task.id().equals(task1))); } assertThat(producer.commitCount(), equalTo(0L)); @@ -1598,7 +1598,7 @@ public class StreamThreadTest { assertTrue(clientSupplier.producers.get(0).transactionInFlight()); assertFalse(clientSupplier.producers.get(0).transactionCommitted()); assertFalse(clientSupplier.producers.get(0).closed()); - assertEquals(1, thread.activeTasks().size()); + assertEquals(1, thread.readOnlyActiveTasks().size()); } @Test @@ -1627,7 +1627,7 @@ public class StreamThreadTest { thread.runOnce(); - assertThat(thread.activeTasks().size(), equalTo(1)); + assertThat(thread.readOnlyActiveTasks().size(), equalTo(1)); // need to process a record to enable committing addRecord(mockConsumer, 0L); @@ -1636,7 +1636,7 @@ public class StreamThreadTest { thread.rebalanceListener().onPartitionsRevoked(assignedPartitions); assertTrue(clientSupplier.producers.get(0).transactionCommitted()); assertFalse(clientSupplier.producers.get(0).closed()); - assertEquals(1, thread.activeTasks().size()); + assertEquals(1, thread.readOnlyActiveTasks().size()); } @Test diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java index 3075d8e78bc..92a2c069195 100644 --- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java @@ -69,10 +69,10 @@ import org.mockito.junit.MockitoJUnitRunner; import java.io.File; import java.io.IOException; import java.time.Duration; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Properties; @@ -464,7 +464,7 @@ public class StreamThreadStateStoreProviderTest { } private void mockThread(final boolean initialized) { - when(threadMock.activeTasks()).thenReturn(new ArrayList<>(tasks.values())); + when(threadMock.readOnlyActiveTasks()).thenReturn(new HashSet<>(tasks.values())); when(threadMock.state()).thenReturn( initialized ? StreamThread.State.RUNNING : StreamThread.State.PARTITIONS_ASSIGNED );