This is an automated email from the ASF dual-hosted git repository.

bbejeck pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 8fa7e095a0d KAFKA-19434: Startup state stores initialization (#20749)
8fa7e095a0d is described below

commit 8fa7e095a0d43077aecde2e2691f457aaf7177e8
Author: Eduwer Camacaro <[email protected]>
AuthorDate: Fri Feb 20 09:36:25 2026 -0500

    KAFKA-19434: Startup state stores initialization (#20749)
    
    Instead of creating Standby tasks from the state directory, we open the
    local stores that exist in the state directory.
    
    This resolves the issue raised in
    [KAFKA-19434](https://issues.apache.org/jira/browse/KAFKA-19434), where
    store-specific metrics were being duplicated due to tasks being created
    in the main thread and then assigned to a StreamThread.
    
    Additionally, since we can now read the offsets from the store during
    instance initialization, this clears the way for the implementation of
    KIP-1035. As of now, the stores are loading the offsets from the
    checkpoint file, but in a later PR, we will read these offsets from the
    state store itself.
    
    This PR modifies the behavior of Kafka Streams when initializing. Now,
    for each pre-existing store in the state directory, we open the store,
    read offsets from the checkpoint file, and then close it again. The
    reason we open the store is that the store will be responsible for
    tracking the offsets once we deprecate the checkpoint file.
    
    Reviewers: Nikita Shuplestov <[email protected]>, Bill
     Bejeck <[email protected]>
---
 .../KafkaStreamsTelemetryIntegrationTest.java      |   9 +-
 .../org/apache/kafka/streams/KafkaStreams.java     |   7 +-
 .../processor/internals/ProcessorStateManager.java |  13 --
 .../processor/internals/StateDirectory.java        | 220 ++++++++++++++-------
 .../streams/processor/internals/TaskManager.java   |  74 ++++---
 .../org/apache/kafka/streams/KafkaStreamsTest.java |  23 +--
 .../processor/internals/StateDirectoryTest.java    |  99 +++-------
 .../processor/internals/TaskManagerTest.java       |  54 ++---
 8 files changed, 248 insertions(+), 251 deletions(-)

diff --git 
a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
 
b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
index df993ac81ad..efb36a66293 100644
--- 
a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
+++ 
b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
@@ -284,6 +284,11 @@ public class KafkaStreamsTelemetryIntegrationTest {
         streamsApplicationProperties = props(groupProtocol);
         final Topology topology = topologyType.equals("simple") ? 
simpleTopology(false) : complexTopology();
 
+        shouldPassMetrics(topology, FIRST_INSTANCE_CLIENT);
+        shouldPassMetrics(topology, SECOND_INSTANCE_CLIENT);
+    }
+
+    private void shouldPassMetrics(final Topology topology, final int 
clientInstance) throws Exception {
         try (final KafkaStreams streams = new KafkaStreams(topology, 
streamsApplicationProperties)) {
             IntegrationTestUtils.startApplicationAndWaitUntilRunning(streams);
 
@@ -295,8 +300,8 @@ public class KafkaStreamsTelemetryIntegrationTest {
 
 
 
-            final List<MetricName> consumerPassedStreamThreadMetricNames = 
INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT).passedMetrics().stream().map(KafkaMetric::metricName).toList();
-            final List<MetricName> adminPassedStreamClientMetricNames = 
INTERCEPTING_ADMIN_CLIENTS.get(FIRST_INSTANCE_CLIENT).passedMetrics.stream().map(KafkaMetric::metricName).toList();
+            final List<MetricName> consumerPassedStreamThreadMetricNames = 
INTERCEPTING_CONSUMERS.get(clientInstance).passedMetrics().stream().map(KafkaMetric::metricName).toList();
+            final List<MetricName> adminPassedStreamClientMetricNames = 
INTERCEPTING_ADMIN_CLIENTS.get(clientInstance).passedMetrics.stream().map(KafkaMetric::metricName).toList();
 
 
             assertEquals(streamsThreadMetrics.size(), 
consumerPassedStreamThreadMetricNames.size());
diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java 
b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
index 574dfec0733..ec21a704837 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
@@ -641,9 +641,6 @@ public class KafkaStreams implements AutoCloseable {
                 return;
             }
 
-            // all (alive) threads have received their assignment, close any 
remaining startup tasks, they're not needed
-            stateDirectory.closeStartupTasks();
-
             setState(State.RUNNING);
         }
 
@@ -1379,8 +1376,8 @@ public class KafkaStreams implements AutoCloseable {
      */
     public synchronized void start() throws IllegalStateException, 
StreamsException {
         if (setState(State.REBALANCING)) {
-            log.debug("Initializing STANDBY tasks for existing local state");
-            stateDirectory.initializeStartupTasks(topologyMetadata, 
streamsMetrics, logContext);
+            log.debug("Initializing store offsets for existing local state");
+            stateDirectory.initializeStartupStores(topologyMetadata, 
logContext, streamsMetrics);
 
             log.debug("Starting Streams client");
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index b4f671b79d6..f77a1f9632b 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -229,19 +229,6 @@ public class ProcessorStateManager implements StateManager 
{
         return new ProcessorStateManager(taskId, TaskType.STANDBY, eosEnabled, 
logContext, stateDirectory, storeToChangelogTopic, sourcePartitions);
     }
 
-    /**
-     * Standby tasks initialized for local state on-startup are only partially 
initialized, because they are not yet
-     * assigned to a StreamThread. Once assigned to a StreamThread, we 
complete their initialization here using the
-     * assigned StreamThread's context.
-     */
-    void assignToStreamThread(final LogContext logContext,
-                              final Collection<TopicPartition> 
sourcePartitions) {
-        this.sourcePartitions.clear();
-        this.log = logContext.logger(ProcessorStateManager.class);
-        this.logPrefix = logContext.logPrefix();
-        this.sourcePartitions.addAll(sourcePartitions);
-    }
-
     void registerStateStores(final List<StateStore> allStores, final 
InternalProcessorContext<?, ?> processorContext) {
         processorContext.uninitialize();
         for (final StateStore store : allStores) {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java
index bad10567ddc..ea2bb7a2f28 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java
@@ -17,16 +17,26 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.internals.StreamsConfigUtils;
+import org.apache.kafka.streams.processor.Cancellable;
+import org.apache.kafka.streams.processor.PunctuationType;
+import org.apache.kafka.streams.processor.Punctuator;
+import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.To;
+import org.apache.kafka.streams.processor.api.FixedKeyRecord;
+import org.apache.kafka.streams.processor.api.Record;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.state.internals.ThreadCache;
 
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
@@ -46,21 +56,21 @@ import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.StandardOpenOption;
 import java.nio.file.attribute.PosixFilePermission;
+import java.time.Duration;
+import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.EnumSet;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ConcurrentSkipListSet;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Predicate;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
@@ -112,7 +122,7 @@ public class StateDirectory implements AutoCloseable {
     private FileLock stateDirLock;
 
     private final StreamsConfig config;
-    private final ConcurrentMap<TaskId, Task> tasksForLocalState = new 
ConcurrentHashMap<>();
+    private final Set<TaskId> tasksInLocalState = new 
ConcurrentSkipListSet<>();
 
     /**
      * Ensures that the state base directory as well as the application's 
sub-directory are created.
@@ -206,30 +216,38 @@ public class StateDirectory implements AutoCloseable {
         return stateDirLock != null;
     }
 
-    public void initializeStartupTasks(final TopologyMetadata topologyMetadata,
-                                       final StreamsMetricsImpl streamsMetrics,
-                                       final LogContext logContext) {
+    /**
+     * @throws LockException if another process already holds a lock on the 
StateDirectory
+     * @throws ProcessorStateException if any task directory does not exist 
and could not be created
+    */
+    public void initializeStartupStores(final TopologyMetadata 
topologyMetadata,
+                                        final LogContext logContext,
+                                        final StreamsMetricsImpl metricsImpl) {
         final List<TaskDirectory> nonEmptyTaskDirectories = 
listNonEmptyTaskDirectories();
         if (hasPersistentStores && !nonEmptyTaskDirectories.isEmpty()) {
-            final ThreadCache dummyCache = new ThreadCache(logContext, 0, 
streamsMetrics);
             final boolean eosEnabled = StreamsConfigUtils.eosEnabled(config);
 
+            // Initialize thread-specific resources needed to open stores in 
the state directory
+            final String threadLogPrefix = String.format("[%s]", 
Thread.currentThread().getName());
+            final ThreadCache dummyCache = new ThreadCache(new 
LogContext(threadLogPrefix), 0L, metricsImpl);
+
             // discover all non-empty task directories in StateDirectory
             for (final TaskDirectory taskDirectory : nonEmptyTaskDirectories) {
                 final String dirName = taskDirectory.file().getName();
-                final TaskId id = parseTaskDirectoryName(dirName, 
taskDirectory.namedTopology());
-                final ProcessorTopology subTopology = 
topologyMetadata.buildSubtopology(id);
+                final TaskId task = parseTaskDirectoryName(dirName, 
taskDirectory.namedTopology());
+                final ProcessorTopology subTopology = 
topologyMetadata.buildSubtopology(task);
 
                 // we still check if the task's sub-topology is stateful, even 
though we know its directory contains state,
                 // because it's possible that the topology has changed since 
that data was written, and is now stateless
-                // this therefore prevents us from creating unnecessary Tasks 
just because of some left-over state
+                // this therefore prevents us from creating unnecessary stores 
just because of some left-over state
                 if (subTopology.hasStateWithChangelogs()) {
-                    final Set<TopicPartition> inputPartitions = 
topologyMetadata.nodeToSourceTopics(id).values().stream()
+                    final Set<TopicPartition> inputPartitions = 
topologyMetadata.nodeToSourceTopics(task).values().stream()
                             .flatMap(Collection::stream)
-                            .map(t -> new TopicPartition(t, id.partition()))
+                            .map(t -> new TopicPartition(t, task.partition()))
                             .collect(Collectors.toSet());
-                    final ProcessorStateManager stateManager = 
ProcessorStateManager.createStartupTaskStateManager(
-                        id,
+                    // Open a temporary state manager that will open the 
stores inside the subtopology
+                    final ProcessorStateManager temporaryStateManager = 
ProcessorStateManager.createStartupTaskStateManager(
+                        task,
                         eosEnabled,
                         logContext,
                         this,
@@ -237,73 +255,42 @@ public class StateDirectory implements AutoCloseable {
                         inputPartitions
                     );
 
-                    final InternalProcessorContext<Object, Object> context = 
new ProcessorContextImpl(
-                        id,
-                        config,
-                        stateManager,
-                        streamsMetrics,
-                        dummyCache
-                    );
-
-                    final Task task = new StandbyTask(
-                        id,
-                        inputPartitions,
-                        subTopology,
-                        topologyMetadata.taskConfig(id),
-                        streamsMetrics,
-                        stateManager,
-                        this,
-                        dummyCache,
-                        context
-                    );
-
+                    final StartupContext initContext = new 
StartupContext(task, config, temporaryStateManager, metricsImpl, dummyCache);
                     try {
-                        task.initializeIfNeeded();
-
-                        tasksForLocalState.put(id, task);
-                    } catch (final TaskCorruptedException e) {
-                        // Task is corrupt - wipe it out (under EOS) and don't 
initialize a Standby for it
-                        task.suspend();
-                        task.closeDirty();
+                        // We only handle TaskCorruptedException at this 
point. Any other exception is considered fatal.
+                        StateManagerUtil.registerStateStores(log, 
threadLogPrefix, subTopology, temporaryStateManager, this, initContext);
+                        temporaryStateManager.checkpoint();
+                    } catch (final TaskCorruptedException tce) {
+                        // At this point, we only log a warning and continue 
with the startup store initialization.
+                        // The task-corrupted exception will be handled in the 
first Task assignment phase.
+                        log.warn("Failed to register startup state stores for 
task {}: {}", task, tce.getMessage());
+                    } finally {
+                        // Make sure the state manager writes the local 
checkpoint file before closing the stores
+                        // This will be replaced in the future when removing 
the checkpoint file dependency.
+                        temporaryStateManager.close();
                     }
+                    tasksInLocalState.add(task);
                 }
             }
         }
     }
 
     public boolean hasStartupTasks() {
-        return !tasksForLocalState.isEmpty();
+        return !tasksInLocalState.isEmpty();
     }
 
-    public Task removeStartupTask(final TaskId taskId) {
-        final Task task = tasksForLocalState.remove(taskId);
-        if (task != null) {
-            lockedTasksToOwner.replace(taskId, Thread.currentThread());
+    public synchronized boolean removeStartupState(final TaskId taskId) {
+        final boolean removed = tasksInLocalState.remove(taskId);
+        if (removed) {
+            lockedTasksToOwner.put(taskId, Thread.currentThread());
         }
-        return task;
+        return removed;
     }
 
-    public void closeStartupTasks() {
-        closeStartupTasks(t -> true);
-    }
-
-    private void closeStartupTasks(final Predicate<Task> predicate) {
-        if (!tasksForLocalState.isEmpty()) {
-            // "drain" Tasks first to ensure that we don't try to close Tasks 
that another thread is attempting to close
-            final Set<Task> drainedTasks = new 
HashSet<>(tasksForLocalState.size());
-            for (final Map.Entry<TaskId, Task> entry : 
tasksForLocalState.entrySet()) {
-                if (predicate.test(entry.getValue()) && 
removeStartupTask(entry.getKey()) != null) {
-                    // only add to our list of drained Tasks if we exclusively 
"claimed" a Task from tasksForLocalState
-                    // to ensure we don't accidentally try to drain the same 
Task multiple times from concurrent threads
-                    drainedTasks.add(entry.getValue());
-                }
-            }
 
-            // now that we have exclusive ownership of the drained tasks, 
close them
-            for (final Task task : drainedTasks) {
-                task.suspend();
-                task.closeClean();
-            }
+    private void unlockStartupStores() {
+        for (final TaskId task : tasksInLocalState) {
+            unlock(task);
         }
     }
 
@@ -513,7 +500,7 @@ public class StateDirectory implements AutoCloseable {
     @Override
     public void close() {
         if (hasPersistentStores) {
-            closeStartupTasks();
+            unlockStartupStores();
             try {
                 stateDirLock.release();
                 stateDirLockChannel.close();
@@ -596,6 +583,7 @@ public class StateDirectory implements AutoCloseable {
                         if (now - cleanupDelayMs > lastModifiedMs) {
                             log.info("{} Deleting obsolete state directory {} 
for task {} as {}ms has elapsed (cleanup delay is {}ms).",
                                 logPrefix(), dirName, id, now - 
lastModifiedMs, cleanupDelayMs);
+                            removeStartupState(id);
                             Utils.delete(taskDir.file());
                         }
                     }
@@ -631,7 +619,6 @@ public class StateDirectory implements AutoCloseable {
         );
         if (namedTopologyDirs != null) {
             for (final File namedTopologyDir : namedTopologyDirs) {
-                closeStartupTasks(task -> 
task.id().topologyName().equals(parseNamedTopologyFromDirectory(namedTopologyDir.getName())));
                 final File[] contents = namedTopologyDir.listFiles();
                 if (contents != null && contents.length == 0) {
                     try {
@@ -669,7 +656,6 @@ public class StateDirectory implements AutoCloseable {
             log.debug("Tried to clear out the local state for NamedTopology {} 
but none was found", topologyName);
         }
         try {
-            closeStartupTasks(task -> 
task.id().topologyName().equals(topologyName));
             Utils.delete(namedTopologyDir);
         } catch (final IOException e) {
             log.error("Hit an unexpected error while clearing local state for 
topology " + topologyName, e);
@@ -813,4 +799,96 @@ public class StateDirectory implements AutoCloseable {
             return Objects.hash(file, namedTopology);
         }
     }
+
+    private static class StartupContext extends 
AbstractProcessorContext<Object, Object> {
+
+        private final StateManager stateManager;
+        final StreamsMetricsImpl metricsImpl;
+
+        public StartupContext(final TaskId taskId, final StreamsConfig config, 
final StateManager stateManager, final StreamsMetricsImpl metricsImpl, 
ThreadCache cache) {
+            super(taskId, config, metricsImpl, cache);
+            this.stateManager = stateManager;
+            this.metricsImpl = metricsImpl;
+        }
+
+        @Override
+        protected StateManager stateManager() {
+            return stateManager;
+        }
+
+        @Override
+        public void transitionToActive(final StreamTask streamTask, final 
RecordCollector recordCollector, final ThreadCache newCache) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public void transitionToStandby(final ThreadCache newCache) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public void registerCacheFlushListener(final String namespace, final 
ThreadCache.DirtyEntryFlushListener listener) {
+        }
+
+        @Override
+        public void logChange(final String storeName, final Bytes key, final 
byte[] value, final long timestamp, final Position position) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public <K, V> void forward(final K key, final V value) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public <K, V> void forward(final K key, final V value, final To to) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public void commit() {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public long currentStreamTimeMs() {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public <S extends StateStore> S getStateStore(final String name) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public Cancellable schedule(final Duration interval, final 
PunctuationType type, final Punctuator callback) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public Cancellable schedule(final Instant startTime, final Duration 
interval, final PunctuationType type, final Punctuator callback) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+
+        @Override
+        public <K, V> void forward(final FixedKeyRecord<K, V> record) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public <K, V> void forward(final FixedKeyRecord<K, V> record, final 
String childName) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public <K, V> void forward(final Record<K, V> record) {
+            throw new IllegalStateException("Should not be called");
+        }
+
+        @Override
+        public <K, V> void forward(final Record<K, V> record, final String 
childName) {
+            throw new IllegalStateException("Should not be called");
+        }
+    }
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 6e7ebf46d90..d4e0baef1ef 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -318,26 +318,34 @@ public class TaskManager {
         }
     }
 
-    private Map<Task, Set<TopicPartition>> assignStartupTasks(final 
Map<TaskId, Set<TopicPartition>> tasksToAssign,
-                                                              final String 
threadLogPrefix) {
+    private Collection<Task> assignActiveTaskFromStartupState(final 
Map<TaskId, Set<TopicPartition>> tasksToAssign) {
         if (stateDirectory.hasStartupTasks()) {
-            final Map<Task, Set<TopicPartition>> assignedTasks = new 
HashMap<>(tasksToAssign.size());
+            final Map<TaskId, Set<TopicPartition>> assignedTasks = new 
HashMap<>(tasksToAssign.size());
             for (final Map.Entry<TaskId, Set<TopicPartition>> entry : 
tasksToAssign.entrySet()) {
                 final TaskId taskId = entry.getKey();
-                final Task task = stateDirectory.removeStartupTask(taskId);
-                if (task != null) {
-                    // replace our dummy values with the real ones, now we 
know our thread and assignment
-                    final Set<TopicPartition> inputPartitions = 
entry.getValue();
-                    task.stateManager().assignToStreamThread(new 
LogContext(threadLogPrefix), inputPartitions);
-                    updateInputPartitionsOfStandbyTaskIfTheyChanged(task, 
inputPartitions);
-
-                    assignedTasks.put(task, inputPartitions);
+                if (stateDirectory.removeStartupState(taskId)) {
+                    assignedTasks.put(taskId, entry.getValue());
                 }
             }
+            return activeTaskCreator.createTasks(mainConsumer, assignedTasks);
+        } else {
+            return Collections.emptySet();
+        }
+    }
 
-            return assignedTasks;
+    private Collection<Task> assignStartupTasks(final Map<TaskId, 
Set<TopicPartition>> tasksToAssign) {
+        if (stateDirectory.hasStartupTasks()) {
+            final Map<TaskId, Set<TopicPartition>> assignedTasks = new 
HashMap<>(tasksToAssign.size());
+            for (final Map.Entry<TaskId, Set<TopicPartition>> entry : 
tasksToAssign.entrySet()) {
+                final TaskId taskId = entry.getKey();
+                if (stateDirectory.removeStartupState(taskId)) {
+                    final Set<TopicPartition> inputPartitions = 
entry.getValue();
+                    assignedTasks.put(taskId, inputPartitions);
+                }
+            }
+            return standbyTaskCreator.createTasks(assignedTasks);
         } else {
-            return Collections.emptyMap();
+            return Collections.emptySet();
         }
     }
 
@@ -484,7 +492,7 @@ public class TaskManager {
                              final Set<Task> tasksToCloseClean,
                              final Map<TaskId, RuntimeException> failedTasks) {
         handleTasksPendingInitialization();
-        handleStartupTaskReuse(activeTasksToCreate, standbyTasksToCreate, 
failedTasks);
+        handleExistingStateForTasks(activeTasksToCreate, standbyTasksToCreate);
         handleRestoringAndUpdatingTasks(activeTasksToCreate, 
standbyTasksToCreate, failedTasks);
         handleRunningAndSuspendedTasks(activeTasksToCreate, 
standbyTasksToCreate, tasksToRecycle, tasksToCloseClean);
     }
@@ -502,31 +510,21 @@ public class TaskManager {
         }
     }
 
-    private void handleStartupTaskReuse(final Map<TaskId, Set<TopicPartition>> 
activeTasksToCreate,
-                                        final Map<TaskId, Set<TopicPartition>> 
standbyTasksToCreate,
-                                        final Map<TaskId, RuntimeException> 
failedTasks) {
-        final Map<Task, Set<TopicPartition>> startupStandbyTasksToRecycle = 
assignStartupTasks(activeTasksToCreate, logPrefix);
-        final Map<Task, Set<TopicPartition>> startupStandbyTasksToUse = 
assignStartupTasks(standbyTasksToCreate, logPrefix);
-
-        // recycle the startup standbys to active, and remove them from the 
set of actives that need to be created
-        if (!startupStandbyTasksToRecycle.isEmpty()) {
-            final Set<Task> tasksToCloseDirty = new 
TreeSet<>(Comparator.comparing(Task::id));
-            for (final Map.Entry<Task, Set<TopicPartition>> entry : 
startupStandbyTasksToRecycle.entrySet()) {
-                final Task task = entry.getKey();
-                recycleTaskFromStateUpdater(task, entry.getValue(), 
tasksToCloseDirty, failedTasks);
-                activeTasksToCreate.remove(task.id());
-            }
-
-            // if any standby tasks failed to recycle, close them dirty
-            tasksToCloseDirty.forEach(task ->
-                closeTaskDirty(task, false)
-            );
+    private void handleExistingStateForTasks(final Map<TaskId, 
Set<TopicPartition>> activeTasksToCreate,
+                                             final Map<TaskId, 
Set<TopicPartition>> standbyTasksToCreate) {
+        final Collection<Task> activeTasks = 
assignActiveTaskFromStartupState(activeTasksToCreate);
+        for (final Task activeTask : activeTasks) {
+            activeTasksToCreate.remove(activeTask.id());
         }
-
-        // use startup Standbys as real Standby tasks
-        if (!startupStandbyTasksToUse.isEmpty()) {
-            tasks.addPendingTasksToInit(startupStandbyTasksToUse.keySet());
-            startupStandbyTasksToUse.keySet().forEach(task -> 
standbyTasksToCreate.remove(task.id()));
+        final Collection<Task> standbyTasks = 
assignStartupTasks(standbyTasksToCreate);
+        for (final Task standbyTask : standbyTasks) {
+            standbyTasksToCreate.remove(standbyTask.id());
+        }
+        if (!activeTasks.isEmpty()) {
+            tasks.addPendingTasksToInit(activeTasks);
+        }
+        if (!standbyTasks.isEmpty()) {
+            tasks.addPendingTasksToInit(standbyTasks);
         }
     }
 
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java 
b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
index 14d4cf8c21f..28ca83ddb2b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
@@ -416,32 +416,13 @@ public class KafkaStreamsTest {
             try (final KafkaStreams streams = new 
KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) {
                 assertEquals(1, constructed.constructed().size());
                 final StateDirectory stateDirectory = 
constructed.constructed().get(0);
-                verify(stateDirectory, times(0)).initializeStartupTasks(any(), 
any(), any());
+                verify(stateDirectory, 
times(0)).initializeStartupStores(any(), any(), any());
                 streams.start();
-                verify(stateDirectory, times(1)).initializeStartupTasks(any(), 
any(), any());
+                verify(stateDirectory, 
times(1)).initializeStartupStores(any(), any(), any());
             }
         }
     }
 
-    @Test
-    public void shouldCloseStartupTasksAfterFirstRebalance() throws Exception {
-        prepareStreams();
-        final AtomicReference<StreamThread.State> state1 = 
prepareStreamThread(streamThreadOne, 1);
-        final AtomicReference<StreamThread.State> state2 = 
prepareStreamThread(streamThreadTwo, 2);
-        prepareThreadState(streamThreadOne, state1);
-        prepareThreadState(streamThreadTwo, state2);
-        try (final MockedConstruction<StateDirectory> constructed = 
mockConstruction(StateDirectory.class,
-            (mock, context) -> 
when(mock.initializeProcessId()).thenReturn(UUID.randomUUID()))) {
-            try (final KafkaStreams streams = new 
KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) {
-                assertEquals(1, constructed.constructed().size());
-                final StateDirectory stateDirectory = 
constructed.constructed().get(0);
-                streams.setStateListener(streamsStateListener);
-                streams.start();
-                waitForCondition(() -> streams.state() == State.RUNNING, 
"Streams never started.");
-                verify(stateDirectory, times(1)).closeStartupTasks();
-            }
-        }
-    }
 
     @Test
     public void stateShouldTransitToRunningIfNonDeadThreadsBackToRunning() 
throws Exception {
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateDirectoryTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateDirectoryTest.java
index 616c397d711..e2f57d83422 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateDirectoryTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateDirectoryTest.java
@@ -17,7 +17,6 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.utils.LogCaptureAppender;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
@@ -28,7 +27,6 @@ import 
org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import 
org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory;
-import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
 import org.apache.kafka.test.MockKeyValueStore;
 import org.apache.kafka.test.TestUtils;
@@ -80,7 +78,6 @@ import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.endsWith;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.hasItem;
-import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -855,113 +852,65 @@ public class StateDirectoryTest {
     }
 
     @Test
-    public void shouldNotInitializeStandbyTasksWhenNoLocalState() {
+    public void shouldNotInitializeStartupStateWhenNoLocalState() {
         final TaskId taskId = new TaskId(0, 0);
-        initializeStartupTasks(new TaskId(0, 0), false);
+        initializeStartupStores(new TaskId(0, 0), false);
         assertFalse(directory.hasStartupTasks());
-        assertNull(directory.removeStartupTask(taskId));
+        assertFalse(directory.removeStartupState(taskId));
         assertFalse(directory.hasStartupTasks());
     }
 
     @Test
-    public void shouldInitializeStandbyTasksForLocalState() {
+    public void shouldInitializeStartupStateForLocalState() {
         final TaskId taskId = new TaskId(0, 0);
-        initializeStartupTasks(new TaskId(0, 0), true);
+        initializeStartupStores(new TaskId(0, 0), true);
         assertTrue(directory.hasStartupTasks());
-        assertNotNull(directory.removeStartupTask(taskId));
+        assertTrue(directory.removeStartupState(taskId));
         assertFalse(directory.hasStartupTasks());
-        assertNull(directory.removeStartupTask(taskId));
+        assertFalse(directory.removeStartupState(taskId));
     }
 
     @Test
-    public void shouldNotAssignStartupTasksWeDontHave() {
+    public void shouldNotAssignStartupStateWeDontHave() {
         final TaskId taskId = new TaskId(0, 0);
-        initializeStartupTasks(taskId, false);
-        final Task task = directory.removeStartupTask(taskId);
-        assertNull(task);
-    }
-
-    private class FakeStreamThread extends Thread {
-        private final TaskId taskId;
-        private final AtomicReference<Task> result;
-
-        private FakeStreamThread(final TaskId taskId, final 
AtomicReference<Task> result) {
-            this.taskId = taskId;
-            this.result = result;
-        }
-
-        @Override
-        public void run() {
-            result.set(directory.removeStartupTask(taskId));
-        }
+        initializeStartupStores(taskId, false);
+        assertFalse(directory.removeStartupState(taskId));
     }
 
     @Test
-    public void shouldAssignStartupTaskToStreamThread() throws 
InterruptedException {
+    public void shouldUnlockStartupStateOnClose() {
         final TaskId taskId = new TaskId(0, 0);
-
-        initializeStartupTasks(taskId, true);
-
-        // main thread owns the newly initialized tasks
-        assertThat(directory.lockOwner(taskId), is(Thread.currentThread()));
-
-        // spawn off a "fake" StreamThread, so we can verify the lock was 
updated to the correct thread
-        final AtomicReference<Task> result = new AtomicReference<>();
-        final Thread streamThread = new FakeStreamThread(taskId, result);
-        streamThread.start();
-        streamThread.join();
-        final Task task = result.get();
-
-        assertNotNull(task);
-        assertThat(task, instanceOf(StandbyTask.class));
-
-        // verify the owner of the task directory lock has been shifted over 
to our assigned StreamThread
-        assertThat(directory.lockOwner(taskId), 
is(instanceOf(FakeStreamThread.class)));
-    }
-
-    @Test
-    public void shouldUnlockStartupTasksOnClose() {
-        final TaskId taskId = new TaskId(0, 0);
-        initializeStartupTasks(taskId, true);
+        initializeStartupStores(taskId, true);
 
         assertEquals(Thread.currentThread(), directory.lockOwner(taskId));
-        directory.closeStartupTasks();
-        assertNull(directory.lockOwner(taskId));
-    }
-
-    @Test
-    public void shouldCloseStartupTasksOnDirectoryClose() {
-        final StateStore store = initializeStartupTasks(new TaskId(0, 0), 
true);
-
-        assertTrue(directory.hasStartupTasks());
-        assertTrue(store.isOpen());
-
         directory.close();
-
-        assertFalse(directory.hasStartupTasks());
-        assertFalse(store.isOpen());
+        assertNull(directory.lockOwner(taskId));
     }
 
     @Test
-    public void shouldNotCloseStartupTasksOnAutoCleanUp() {
+    public void shouldCloseStartupStateOnAutoCleanUp() {
         // we need to set this because the auto-cleanup uses the last-modified 
time from the filesystem,
         // which can't be mocked
         time.setCurrentTimeMs(System.currentTimeMillis());
+        TaskId taskId = new TaskId(0, 0);
 
-        final StateStore store = initializeStartupTasks(new TaskId(0, 0), 
true);
+        final StateStore store = initializeStartupStores(taskId, true);
 
         assertTrue(directory.hasStartupTasks());
-        assertTrue(store.isOpen());
+        assertFalse(store.isOpen());
 
         time.sleep(10000);
+        // We need to manually unlock the task because the cleanup process only
+        // cleans tasks that are no longer owned by the current thread
+        directory.unlock(taskId);
 
         directory.cleanRemovedTasks(1000);
 
-        assertTrue(directory.hasStartupTasks());
-        assertTrue(store.isOpen());
+        assertFalse(directory.hasStartupTasks());
+        assertFalse(store.isOpen());
     }
 
-    private StateStore initializeStartupTasks(final TaskId taskId, final 
boolean createTaskDir) {
+    private StateStore initializeStartupStores(final TaskId taskId, final 
boolean createTaskDir) {
         directory.initializeProcessId();
         final TopologyMetadata metadata = Mockito.mock(TopologyMetadata.class);
         final TopologyConfig topologyConfig = new TopologyConfig(config);
@@ -987,7 +936,7 @@ public class StateDirectoryTest {
         
Mockito.when(metadata.buildSubtopology(ArgumentMatchers.any())).thenReturn(processorTopology);
         
Mockito.when(metadata.taskConfig(ArgumentMatchers.any())).thenReturn(topologyConfig.getTaskConfig());
 
-        directory.initializeStartupTasks(metadata, new StreamsMetricsImpl(new 
Metrics(), "test", time), new LogContext("test"));
+        directory.initializeStartupStores(metadata, new LogContext("test"), 
null);
 
         return store;
     }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 586af7cc1ae..69e071dd2b3 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -111,6 +111,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.same;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doThrow;
@@ -2441,7 +2442,7 @@ public class TaskManagerTest {
         final StreamsProducer producer = mock(StreamsProducer.class);
         when(activeTaskCreator.streamsProducer()).thenReturn(producer);
         final ConsumerGroupMetadata groupMetadata = 
mock(ConsumerGroupMetadata.class);
-        
+
         when(consumer.groupMetadata()).thenReturn(groupMetadata);
         when(consumer.assignment()).thenReturn(union(HashSet::new, 
taskId00Partitions, taskId01Partitions));
 
@@ -4880,36 +4881,33 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldRecycleStartupTasksFromStateDirectoryAsActive() {
+    public void shouldCreateActiveTaskFromStartupStateStore() {
         final Tasks taskRegistry = new Tasks(new LogContext());
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, taskRegistry);
-        final StandbyTask startupTask = standbyTask(taskId00, 
taskId00ChangelogPartitions).build();
-
         final StreamTask activeTask = statefulTask(taskId00, 
taskId00ChangelogPartitions).build();
-        when(activeTaskCreator.createActiveTaskFromStandby(eq(startupTask), 
eq(taskId00Partitions), any()))
-                .thenReturn(activeTask);
-
+        when(activeTaskCreator.createTasks(consumer, 
taskId00Assignment)).thenReturn(singletonList(activeTask));
         when(stateDirectory.hasStartupTasks()).thenReturn(true, false);
-        
when(stateDirectory.removeStartupTask(taskId00)).thenReturn(startupTask, (Task) 
null);
+        when(stateDirectory.removeStartupState(taskId00)).thenReturn(true, 
false);
+
+        assertFalse(taskRegistry.hasPendingTasksToInit());
 
         taskManager.handleAssignment(taskId00Assignment, 
Collections.emptyMap());
 
-        // ensure we used our existing startup Task directly as a Standby
+        // ensure we used our existing startup state store to create our 
active task
         assertTrue(taskRegistry.hasPendingTasksToInit());
         assertEquals(Collections.singleton(activeTask), 
taskRegistry.drainPendingTasksToInit());
 
-        // we're using a mock StateUpdater here, so now that we've drained the 
task from the queue of startup tasks to init
+        // we're using a mock StateUpdater here, so now that we've created the 
task from the startup store
         // let's "add" it to our mock StateUpdater
         
when(stateUpdater.tasks()).thenReturn(Collections.singleton(activeTask));
         when(stateUpdater.standbyTasks()).thenReturn(Collections.emptySet());
 
-        // ensure we recycled our existing startup Standby into an Active task
-        verify(activeTaskCreator).createActiveTaskFromStandby(eq(startupTask), 
eq(taskId00Partitions), any());
+        InOrder inOrder = inOrder(activeTaskCreator);
+        inOrder.verify(activeTaskCreator).createTasks(same(consumer), 
eq(Map.of(taskId00, taskId00Partitions)));
+        inOrder.verify(activeTaskCreator).createTasks(consumer, emptyMap());
 
-        // ensure we didn't construct any new Tasks
-        verify(activeTaskCreator).createTasks(any(), 
eq(Collections.emptyMap()));
-        verify(standbyTaskCreator).createTasks(Collections.emptyMap());
-        verifyNoMoreInteractions(activeTaskCreator);
+        inOrder.verifyNoMoreInteractions();
+        verify(standbyTaskCreator).createTasks(Map.of());
         verifyNoMoreInteractions(standbyTaskCreator);
 
         // verify the recycled task is now being used as an assigned Active
@@ -4918,36 +4916,40 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldUseStartupTasksFromStateDirectoryAsStandby() {
+    public void shouldCreateStandbyTaskFromStartupStateStore() {
         final Tasks taskRegistry = new Tasks(new LogContext());
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, taskRegistry);
-        final StandbyTask startupTask = standbyTask(taskId00, 
taskId00ChangelogPartitions).build();
+        final StandbyTask standbyTask = standbyTask(taskId00, 
taskId00ChangelogPartitions).build();
+        when(standbyTaskCreator.createTasks(eq(Map.of(taskId00, 
taskId00Partitions)))).thenReturn(Set.of(standbyTask));
+
 
         when(stateDirectory.hasStartupTasks()).thenReturn(true, true, false);
-        
when(stateDirectory.removeStartupTask(taskId00)).thenReturn(startupTask, (Task) 
null);
+        when(stateDirectory.removeStartupState(taskId00)).thenReturn(true, 
false);
 
         assertFalse(taskRegistry.hasPendingTasksToInit());
 
         taskManager.handleAssignment(Collections.emptyMap(), 
taskId00Assignment);
 
-        // ensure we used our existing startup Task directly as a Standby
+        // ensure we used our existing startup state to create our standby task
         assertTrue(taskRegistry.hasPendingTasksToInit());
-        assertEquals(Collections.singleton(startupTask), 
taskRegistry.drainPendingTasksToInit());
+        assertEquals(Collections.singleton(standbyTask), 
taskRegistry.drainPendingTasksToInit());
 
         // we're using a mock StateUpdater here, so now that we've drained the 
task from the queue of startup tasks to init
         // let's "add" it to our mock StateUpdater
-        
when(stateUpdater.tasks()).thenReturn(Collections.singleton(startupTask));
-        
when(stateUpdater.standbyTasks()).thenReturn(Collections.singleton(startupTask));
+        
when(stateUpdater.tasks()).thenReturn(Collections.singleton(standbyTask));
+        
when(stateUpdater.standbyTasks()).thenReturn(Collections.singleton(standbyTask));
 
         // ensure we didn't construct any new Tasks, or recycle an existing 
Task; we only used the one we already have
-        verify(activeTaskCreator).createTasks(any(), 
eq(Collections.emptyMap()));
-        verify(standbyTaskCreator).createTasks(Collections.emptyMap());
+        verify(activeTaskCreator, times(2)).createTasks(any(), 
eq(Collections.emptyMap()));
+        InOrder inOrder = inOrder(standbyTaskCreator);
+        inOrder.verify(standbyTaskCreator).createTasks(Map.of(taskId00, 
taskId00Partitions));
+        inOrder.verify(standbyTaskCreator).createTasks(Collections.emptyMap());
         verifyNoMoreInteractions(activeTaskCreator);
         verifyNoMoreInteractions(standbyTaskCreator);
 
         // verify the startup Standby is now being used as an assigned Standby
         assertEquals(Collections.emptyMap(), taskManager.activeTaskMap());
-        assertEquals(Collections.singletonMap(taskId00, startupTask), 
taskManager.standbyTaskMap());
+        assertEquals(Collections.singletonMap(taskId00, standbyTask), 
taskManager.standbyTaskMap());
     }
 
     @Test


Reply via email to