This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 083e11a22ca KAFKA-14650: Synchronize access to tasks inside task 
manager (#13167)
083e11a22ca is described below

commit 083e11a22ca9966ed010acdd5705351ab4300b52
Author: Guozhang Wang <wangg...@gmail.com>
AuthorDate: Thu Feb 9 10:33:19 2023 -0800

    KAFKA-14650: Synchronize access to tasks inside task manager (#13167)
    
    1. The major fix: synchronize access to tasks inside the task manager; this 
fixes a regression introduced in #12397.
    2. Clarify the names of StreamThread functions that may be called from outside 
the StreamThread.
    3. Minor cleanups.
    
    Reviewers: Lucas Brutschy <lucas...@users.noreply.github.com>
---
 .../org/apache/kafka/streams/KafkaStreams.java     | 16 ++++++-------
 .../streams/processor/internals/StreamThread.java  | 24 ++++++++++++--------
 .../streams/processor/internals/TaskManager.java   |  6 +++++
 .../kafka/streams/processor/internals/Tasks.java   | 22 +++++++++---------
 .../KafkaStreamsNamedTopologyWrapper.java          |  2 +-
 .../internals/StreamThreadStateStoreProvider.java  |  2 +-
 .../org/apache/kafka/streams/KafkaStreamsTest.java |  5 ++---
 .../processor/internals/StreamThreadTest.java      | 26 +++++++++++-----------
 .../StreamThreadStateStoreProviderTest.java        |  4 ++--
 9 files changed, 59 insertions(+), 48 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java 
b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
index bc3a1243ce2..ee9f57b0680 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
@@ -1826,7 +1826,7 @@ public class KafkaStreams implements AutoCloseable {
      */
     public Map<String, Map<Integer, LagInfo>> allLocalStorePartitionLags() {
         final List<Task> allTasks = new ArrayList<>();
-        processStreamThread(thread -> 
allTasks.addAll(thread.allTasks().values()));
+        processStreamThread(thread -> 
allTasks.addAll(thread.readyOnlyAllTasks()));
         return allLocalStorePartitionLags(allTasks);
     }
 
@@ -1917,21 +1917,19 @@ public class KafkaStreams implements AutoCloseable {
             );
         } else {
             for (final StreamThread thread : threads) {
-                final Map<TaskId, Task> tasks = thread.allTasks();
-                for (final Entry<TaskId, Task> entry : tasks.entrySet()) {
+                final Set<Task> tasks = thread.readyOnlyAllTasks();
+                for (final Task task : tasks) {
 
-                    final TaskId taskId = entry.getKey();
+                    final TaskId taskId = task.id();
                     final int partition = taskId.partition();
-                    if (request.isAllPartitions()
-                        || request.getPartitions().contains(partition)) {
-                        final Task task = entry.getValue();
+                    if (request.isAllPartitions() || 
request.getPartitions().contains(partition)) {
                         final StateStore store = task.getStore(storeName);
                         if (store != null) {
                             final StreamThread.State state = thread.state();
                             final boolean active = task.isActive();
                             if (request.isRequireActive()
-                                && (state != StreamThread.State.RUNNING
-                                || !active)) {
+                                && (state != StreamThread.State.RUNNING || 
!active)) {
+
                                 result.addResult(
                                     partition,
                                     QueryResult.forFailure(
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 2d433eb3c82..d4be6a83af9 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -60,7 +60,6 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
@@ -1269,16 +1268,23 @@ public class StreamThread extends Thread {
         );
     }
 
-    public Map<TaskId, Task> activeTaskMap() {
-        return taskManager.activeTaskMap();
-    }
-
-    public List<Task> activeTasks() {
-        return taskManager.activeTaskIterable();
+    /**
+     * Getting the list of current active tasks of the thread;
+     * Note that the returned list may be used by other thread than the 
StreamThread itself,
+     * and hence need to be read-only
+     */
+    public Set<Task> readOnlyActiveTasks() {
+        return readyOnlyAllTasks().stream()
+            .filter(Task::isActive).collect(Collectors.toSet());
     }
 
-    public Map<TaskId, Task> allTasks() {
-        return taskManager.allTasks();
+    /**
+     * Getting the list of all owned tasks of the thread, including both 
active and standby;
+     * Note that the returned list may be used by other thread than the 
StreamThread itself,
+     * and hence need to be read-only
+     */
+    public Set<Task> readyOnlyAllTasks() {
+        return taskManager.readOnlyAllTasks();
     }
 
     /**
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 583662f0924..2a5b3d2bacf 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -1537,6 +1537,12 @@ public class TaskManager {
         return tasks.allTasksPerId();
     }
 
+    Set<Task> readOnlyAllTasks() {
+        // need to make sure the returned set is unmodifiable as it could be 
accessed
+        // by other thread than the StreamThread owning this task manager;
+        return Collections.unmodifiableSet(tasks.allTasks());
+    }
+
     Map<TaskId, Task> notPausedTasks() {
         return Collections.unmodifiableMap(tasks.allTasks()
             .stream()
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
index 14de0e794fd..7b3f7860fb7 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
@@ -45,6 +45,8 @@ class Tasks implements TasksRegistry {
     private final Logger log;
 
     // TODO: convert to Stream/StandbyTask when we remove 
TaskManager#StateMachineTask with mocks
+    // note that these two maps may be accessed by concurrent threads and hence
+    // should be synchronized when accessed
     private final Map<TaskId, Task> activeTasksPerId = new TreeMap<>();
     private final Map<TaskId, Task> standbyTasksPerId = new TreeMap<>();
 
@@ -203,7 +205,7 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public void addTask(final Task task) {
+    public synchronized void addTask(final Task task) {
         final TaskId taskId = task.id();
         if (activeTasksPerId.containsKey(taskId)) {
             throw new IllegalStateException("Attempted to create an active 
task that we already own: " + taskId);
@@ -225,7 +227,7 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public void removeTask(final Task taskToRemove) {
+    public synchronized void removeTask(final Task taskToRemove) {
         final TaskId taskId = taskToRemove.id();
 
         if (taskToRemove.state() != Task.State.CLOSED) {
@@ -245,7 +247,7 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public void replaceActiveWithStandby(final StandbyTask standbyTask) {
+    public synchronized void replaceActiveWithStandby(final StandbyTask 
standbyTask) {
         final TaskId taskId = standbyTask.id();
         if (activeTasksPerId.remove(taskId) == null) {
             throw new IllegalStateException("Attempted to replace unknown 
active task with standby task: " + taskId);
@@ -256,7 +258,7 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public void replaceStandbyWithActive(final StreamTask activeTask) {
+    public synchronized void replaceStandbyWithActive(final StreamTask 
activeTask) {
         final TaskId taskId = activeTask.id();
         if (standbyTasksPerId.remove(taskId) == null) {
             throw new IllegalStateException("Attempted to convert unknown 
standby task to stream task: " + taskId);
@@ -295,7 +297,7 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public void clear() {
+    public synchronized void clear() {
         activeTasksPerId.clear();
         standbyTasksPerId.clear();
         activeTasksPerPartition.clear();
@@ -307,7 +309,7 @@ class Tasks implements TasksRegistry {
         return activeTasksPerPartition.get(partition);
     }
 
-    private Task getTask(final TaskId taskId) {
+    private synchronized Task getTask(final TaskId taskId) {
         if (activeTasksPerId.containsKey(taskId)) {
             return activeTasksPerId.get(taskId);
         }
@@ -337,7 +339,7 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public Collection<Task> activeTasks() {
+    public synchronized Collection<Task> activeTasks() {
         return Collections.unmodifiableCollection(activeTasksPerId.values());
     }
 
@@ -346,17 +348,17 @@ class Tasks implements TasksRegistry {
      * and the returned task could be modified by other threads concurrently
      */
     @Override
-    public Set<Task> allTasks() {
+    public synchronized Set<Task> allTasks() {
         return union(HashSet::new, new HashSet<>(activeTasksPerId.values()), 
new HashSet<>(standbyTasksPerId.values()));
     }
 
     @Override
-    public Set<TaskId> allTaskIds() {
+    public synchronized Set<TaskId> allTaskIds() {
         return union(HashSet::new, activeTasksPerId.keySet(), 
standbyTasksPerId.keySet());
     }
 
     @Override
-    public Map<TaskId, Task> allTasksPerId() {
+    public synchronized Map<TaskId, Task> allTasksPerId() {
         final Map<TaskId, Task> ret = new HashMap<>();
         ret.putAll(activeTasksPerId);
         ret.putAll(standbyTasksPerId);
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java
index 3d22c583373..4704d1d4df7 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java
@@ -436,7 +436,7 @@ public class KafkaStreamsNamedTopologyWrapper extends 
KafkaStreams {
         }
         final List<Task> allTopologyTasks = new ArrayList<>();
         processStreamThread(thread -> allTopologyTasks.addAll(
-            thread.allTasks().values().stream()
+            thread.readyOnlyAllTasks().stream()
                 .filter(t -> topologyName.equals(t.id().topologyName()))
                 .collect(Collectors.toList())));
         return allLocalStorePartitionLags(allTopologyTasks);
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java
index cf033252442..a962fe0a677 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java
@@ -56,7 +56,7 @@ public class StreamThreadStateStoreProvider {
 
         if (storeQueryParams.staleStoresEnabled() ? state.isAlive() : state == 
StreamThread.State.RUNNING) {
             final Collection<Task> tasks = 
storeQueryParams.staleStoresEnabled() ?
-                    streamThread.allTasks().values() : 
streamThread.activeTasks();
+                    streamThread.readyOnlyAllTasks() : 
streamThread.readOnlyActiveTasks();
 
             if (storeQueryParams.partition() != null) {
                 for (final Task task : tasks) {
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java 
b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
index 1c0331824e7..fef0353b90b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
@@ -93,7 +93,6 @@ import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 
-import static java.util.Collections.emptyList;
 import static java.util.Collections.singletonList;
 import static 
org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName;
 import static 
org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForApplicationState;
@@ -363,8 +362,8 @@ public class KafkaStreamsTest {
             }).when(thread).join();
         }
 
-        when(thread.activeTasks()).thenReturn(emptyList());
-        when(thread.allTasks()).thenReturn(Collections.emptyMap());
+        when(thread.readOnlyActiveTasks()).thenReturn(Collections.emptySet());
+        when(thread.readyOnlyAllTasks()).thenReturn(Collections.emptySet());
     }
 
     @Test
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 76574cccb70..dc9259a1f91 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -1038,7 +1038,7 @@ public class StreamThreadTest {
 
         assertEquals(1, clientSupplier.producers.size());
         final Producer<byte[], byte[]> globalProducer = 
clientSupplier.producers.get(0);
-        for (final Task task : thread.activeTasks()) {
+        for (final Task task : thread.readOnlyActiveTasks()) {
             assertSame(globalProducer, ((RecordCollectorImpl) ((StreamTask) 
task).recordCollector()).producer());
         }
         assertSame(clientSupplier.consumer, thread.mainConsumer());
@@ -1078,7 +1078,7 @@ public class StreamThreadTest {
 
         thread.runOnce();
 
-        assertEquals(thread.activeTasks().size(), 
clientSupplier.producers.size());
+        assertEquals(thread.readOnlyActiveTasks().size(), 
clientSupplier.producers.size());
         assertSame(clientSupplier.consumer, thread.mainConsumer());
         assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer());
     }
@@ -1195,7 +1195,7 @@ public class StreamThreadTest {
             10 * 1000,
             "Thread never shut down.");
 
-        for (final Task task : thread.activeTasks()) {
+        for (final Task task : thread.readOnlyActiveTasks()) {
             assertTrue(((MockProducer<byte[], byte[]>) ((RecordCollectorImpl) 
((StreamTask) task).recordCollector()).producer()).closed());
         }
     }
@@ -1388,7 +1388,7 @@ public class StreamThreadTest {
         thread.rebalanceListener().onPartitionsAssigned(assignedPartitions);
 
         thread.runOnce();
-        assertThat(thread.activeTasks().size(), equalTo(1));
+        assertThat(thread.readOnlyActiveTasks().size(), equalTo(1));
         final MockProducer<byte[], byte[]> producer = 
clientSupplier.producers.get(0);
 
         // change consumer subscription from "pattern" to "manual" to be able 
to call .addRecords()
@@ -1415,7 +1415,7 @@ public class StreamThreadTest {
         } catch (final KafkaException expected) {
             assertTrue(expected instanceof TaskMigratedException);
             assertTrue("StreamsThread removed the fenced zombie task already, 
should wait for rebalance to close all zombies together.",
-                thread.activeTasks().stream().anyMatch(task -> 
task.id().equals(task1)));
+                thread.readOnlyActiveTasks().stream().anyMatch(task -> 
task.id().equals(task1)));
         }
 
         assertThat(producer.commitCount(), equalTo(1L));
@@ -1447,7 +1447,7 @@ public class StreamThreadTest {
 
         thread.runOnce();
 
-        assertThat(thread.activeTasks().size(), equalTo(1));
+        assertThat(thread.readOnlyActiveTasks().size(), equalTo(1));
 
         // need to process a record to enable committing
         addRecord(mockConsumer, 0L);
@@ -1457,7 +1457,7 @@ public class StreamThreadTest {
         assertThrows(TaskMigratedException.class, () -> 
thread.rebalanceListener().onPartitionsRevoked(assignedPartitions));
         assertFalse(clientSupplier.producers.get(0).transactionCommitted());
         assertFalse(clientSupplier.producers.get(0).closed());
-        assertEquals(1, thread.activeTasks().size());
+        assertEquals(1, thread.readOnlyActiveTasks().size());
     }
 
     @Test
@@ -1523,7 +1523,7 @@ public class StreamThreadTest {
 
         // the first iteration completes the restoration
         thread.runOnce();
-        assertThat(thread.activeTasks().size(), equalTo(1));
+        assertThat(thread.readOnlyActiveTasks().size(), equalTo(1));
 
         // the second transits to running and unpause the input
         thread.runOnce();
@@ -1578,7 +1578,7 @@ public class StreamThreadTest {
         thread.rebalanceListener().onPartitionsAssigned(assignedPartitions);
 
         thread.runOnce();
-        assertThat(thread.activeTasks().size(), equalTo(1));
+        assertThat(thread.readOnlyActiveTasks().size(), equalTo(1));
         final MockProducer<byte[], byte[]> producer = 
clientSupplier.producers.get(0);
 
         producer.commitTransactionException = new 
ProducerFencedException("Producer is fenced");
@@ -1590,7 +1590,7 @@ public class StreamThreadTest {
         } catch (final KafkaException expected) {
             assertTrue(expected instanceof TaskMigratedException);
             assertTrue("StreamsThread removed the fenced zombie task already, 
should wait for rebalance to close all zombies together.",
-                thread.activeTasks().stream().anyMatch(task -> 
task.id().equals(task1)));
+                thread.readOnlyActiveTasks().stream().anyMatch(task -> 
task.id().equals(task1)));
         }
 
         assertThat(producer.commitCount(), equalTo(0L));
@@ -1598,7 +1598,7 @@ public class StreamThreadTest {
         assertTrue(clientSupplier.producers.get(0).transactionInFlight());
         assertFalse(clientSupplier.producers.get(0).transactionCommitted());
         assertFalse(clientSupplier.producers.get(0).closed());
-        assertEquals(1, thread.activeTasks().size());
+        assertEquals(1, thread.readOnlyActiveTasks().size());
     }
 
     @Test
@@ -1627,7 +1627,7 @@ public class StreamThreadTest {
 
         thread.runOnce();
 
-        assertThat(thread.activeTasks().size(), equalTo(1));
+        assertThat(thread.readOnlyActiveTasks().size(), equalTo(1));
 
         // need to process a record to enable committing
         addRecord(mockConsumer, 0L);
@@ -1636,7 +1636,7 @@ public class StreamThreadTest {
         thread.rebalanceListener().onPartitionsRevoked(assignedPartitions);
         assertTrue(clientSupplier.producers.get(0).transactionCommitted());
         assertFalse(clientSupplier.producers.get(0).closed());
-        assertEquals(1, thread.activeTasks().size());
+        assertEquals(1, thread.readOnlyActiveTasks().size());
     }
 
     @Test
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
index 3075d8e78bc..92a2c069195 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
@@ -69,10 +69,10 @@ import org.mockito.junit.MockitoJUnitRunner;
 import java.io.File;
 import java.io.IOException;
 import java.time.Duration;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
@@ -464,7 +464,7 @@ public class StreamThreadStateStoreProviderTest {
     }
 
     private void mockThread(final boolean initialized) {
-        when(threadMock.activeTasks()).thenReturn(new 
ArrayList<>(tasks.values()));
+        when(threadMock.readOnlyActiveTasks()).thenReturn(new 
HashSet<>(tasks.values()));
         when(threadMock.state()).thenReturn(
             initialized ? StreamThread.State.RUNNING : 
StreamThread.State.PARTITIONS_ASSIGNED
         );

Reply via email to