http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
new file mode 100644
index 0000000..4bdad9a
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -0,0 +1,524 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.CommitFailedException;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.ProducerFencedException;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.processor.TaskId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static java.util.Collections.singleton;
+
/**
 * Owns the lifecycle of a stream thread's tasks: active tasks, standby tasks,
 * and the suspended variants of both that are kept alive across rebalances in
 * the hope of being re-assigned. Created/closed tasks are tracked both by
 * {@link TaskId} and by {@link TopicPartition} for fast lookup.
 *
 * Thread-safety: only {@code activeTasks} is a concurrent map, because it can be
 * read by Queryable State from other threads; all other state is owned by the
 * stream thread.
 */
class TaskManager {
    // initialize the task list
    // activeTasks needs to be concurrent as it can be accessed
    // by QueryableState
    private static final Logger log = LoggerFactory.getLogger(TaskManager.class);
    private final Map<TaskId, Task> activeTasks = new ConcurrentHashMap<>();
    private final Map<TaskId, Task> standbyTasks = new HashMap<>();
    private final Map<TopicPartition, Task> activeTasksByPartition = new HashMap<>();
    private final Map<TopicPartition, Task> standbyTasksByPartition = new HashMap<>();
    // ids of the active tasks held before the most recent removeStreamTasks()
    private final Set<TaskId> prevActiveTasks = new TreeSet<>();
    private final Map<TaskId, Task> suspendedTasks = new HashMap<>();
    private final Map<TaskId, Task> suspendedStandbyTasks = new HashMap<>();
    private final ChangelogReader changelogReader;
    private final Time time;
    private final String logPrefix;
    // dedicated consumer used only to restore state-store changelogs
    private final Consumer<byte[], byte[]> restoreConsumer;
    private final StreamThread.AbstractTaskCreator taskCreator;
    private final StreamThread.AbstractTaskCreator standbyTaskCreator;
    // set after construction via setThreadMetadataProvider(); supplies the task assignment
    private ThreadMetadataProvider threadMetadataProvider;
    // set after construction via setConsumer(); the thread's main consumer
    private Consumer<byte[], byte[]> consumer;

    TaskManager(final ChangelogReader changelogReader,
                final Time time,
                final String logPrefix,
                final Consumer<byte[], byte[]> restoreConsumer,
                final StreamThread.AbstractTaskCreator taskCreator,
                final StreamThread.AbstractTaskCreator standbyTaskCreator) {
        this.changelogReader = changelogReader;
        this.time = time;
        this.logPrefix = logPrefix;
        this.restoreConsumer = restoreConsumer;
        this.taskCreator = taskCreator;
        this.standbyTaskCreator = standbyTaskCreator;
    }

    /**
     * (Re-)creates tasks for the given partition assignment: closes suspended
     * tasks that were not re-assigned, resumes or creates active tasks, restores
     * their changelogs, then resumes or creates standby tasks.
     *
     * @param assignment the partitions newly assigned to this thread's consumer
     * @throws IllegalStateException if the metadata provider or consumer has not been set
     */
    void createTasks(final Collection<TopicPartition> assignment) {
        if (threadMetadataProvider == null) {
            // NOTE(review): message says "taskIdProvider" but the field checked is
            // threadMetadataProvider — looks like a stale name in the message; confirm.
            throw new IllegalStateException(logPrefix + " taskIdProvider has not been initialized while adding stream tasks. This should not happen.");
        }
        if (consumer == null) {
            throw new IllegalStateException(logPrefix + " consumer has not been initialized while adding stream tasks. This should not happen.");
        }

        final long start = time.milliseconds();
        changelogReader.clear();
        // do this first as we may have suspended standby tasks that
        // will become active or vice versa
        closeNonAssignedSuspendedStandbyTasks();
        Map<TaskId, Set<TopicPartition>> assignedActiveTasks = threadMetadataProvider.activeTasks();
        closeNonAssignedSuspendedTasks(assignedActiveTasks);
        addStreamTasks(assignment, assignedActiveTasks, start);
        changelogReader.restore();
        addStandbyTasks(start);
    }

    /** Injects the provider of the current task assignment (set by the partition assignor). */
    void setThreadMetadataProvider(final ThreadMetadataProvider threadMetadataProvider) {
        this.threadMetadataProvider = threadMetadataProvider;
    }

    /**
     * Closes every suspended standby task whose id is no longer in the new
     * standby assignment. Close failures are logged and the task is removed anyway.
     */
    private void closeNonAssignedSuspendedStandbyTasks() {
        final Set<TaskId> currentSuspendedTaskIds = threadMetadataProvider.standbyTasks().keySet();
        final Iterator<Map.Entry<TaskId, Task>> standByTaskIterator = suspendedStandbyTasks.entrySet().iterator();
        while (standByTaskIterator.hasNext()) {
            final Map.Entry<TaskId, Task> suspendedTask = standByTaskIterator.next();
            if (!currentSuspendedTaskIds.contains(suspendedTask.getKey())) {
                final Task task = suspendedTask.getValue();
                log.debug("{} Closing suspended and not re-assigned standby task {}", logPrefix, task.id());
                try {
                    task.close(true);
                } catch (final Exception e) {
                    log.error("{} Failed to remove suspended standby task {} due to the following error:", logPrefix, task.id(), e);
                } finally {
                    // remove even on failure so we don't retry a broken task
                    standByTaskIterator.remove();
                }
            }
        }
    }

    /**
     * Closes every suspended active task whose partition set does not exactly
     * match its entry in the new assignment (a changed partition set means the
     * task cannot simply be resumed).
     *
     * @param newTaskAssignment the new active-task assignment, keyed by task id
     */
    private void closeNonAssignedSuspendedTasks(final Map<TaskId, Set<TopicPartition>> newTaskAssignment) {
        final Iterator<Map.Entry<TaskId, Task>> suspendedTaskIterator = suspendedTasks.entrySet().iterator();
        while (suspendedTaskIterator.hasNext()) {
            final Map.Entry<TaskId, Task> next = suspendedTaskIterator.next();
            final Task task = next.getValue();
            final Set<TopicPartition> assignedPartitionsForTask = newTaskAssignment.get(next.getKey());
            if (!task.partitions().equals(assignedPartitionsForTask)) {
                log.debug("{} Closing suspended and not re-assigned task {}", logPrefix, task.id());
                try {
                    task.closeSuspended(true, null);
                } catch (final Exception e) {
                    log.error("{} Failed to close suspended task {} due to the following error:", logPrefix, next.getKey(), e);
                } finally {
                    suspendedTaskIterator.remove();
                }
            }
        }
    }

    /**
     * Resumes suspended tasks that match the new assignment and creates the rest
     * via the task creator (which retries with backoff to guard against races
     * with another thread that still holds the task's state locks).
     *
     * @param assignment    all partitions assigned to this thread
     * @param assignedTasks the active-task assignment, keyed by task id
     * @param start         timestamp used as the retry deadline base
     */
    private void addStreamTasks(final Collection<TopicPartition> assignment, final Map<TaskId, Set<TopicPartition>> assignedTasks, final long start) {
        final Map<TaskId, Set<TopicPartition>> newTasks = new HashMap<>();

        // collect newly assigned tasks and reopen re-assigned tasks
        log.debug("{} Adding assigned tasks as active: {}", logPrefix, assignedTasks);
        for (final Map.Entry<TaskId, Set<TopicPartition>> entry : assignedTasks.entrySet()) {
            final TaskId taskId = entry.getKey();
            final Set<TopicPartition> partitions = entry.getValue();

            if (assignment.containsAll(partitions)) {
                try {
                    final Task task = findMatchingSuspendedTask(taskId, partitions);
                    if (task != null) {
                        suspendedTasks.remove(taskId);
                        task.resume();

                        activeTasks.put(taskId, task);

                        for (final TopicPartition partition : partitions) {
                            activeTasksByPartition.put(partition, task);
                        }
                    } else {
                        newTasks.put(taskId, partitions);
                    }
                } catch (final StreamsException e) {
                    log.error("{} Failed to create an active task {} due to the following error:", logPrefix, taskId, e);
                    throw e;
                }
            } else {
                // assignment and task metadata disagree; skip rather than create a broken task
                log.warn("{} Task {} owned partitions {} are not contained in the assignment {}", logPrefix, taskId, partitions, assignment);
            }
        }

        // create all newly assigned tasks (guard against race condition with other thread via backoff and retry)
        // -> other thread will call removeSuspendedTasks(); eventually
        log.trace("{} New active tasks to be created: {}", logPrefix, newTasks);

        if (!newTasks.isEmpty()) {
            final Map<Task, Set<TopicPartition>> createdTasks = taskCreator.retryWithBackoff(consumer, newTasks, start);
            for (final Map.Entry<Task, Set<TopicPartition>> entry : createdTasks.entrySet()) {
                final Task task = entry.getKey();
                activeTasks.put(task.id(), task);
                for (final TopicPartition partition : entry.getValue()) {
                    activeTasksByPartition.put(partition, task);
                }
            }
        }
    }

    /**
     * Resumes matching suspended standby tasks and creates the rest, then points
     * the restore consumer at all checkpointed changelog offsets (seeking to the
     * beginning where no checkpoint exists, i.e. offset &lt; 0).
     *
     * @param start timestamp used as the retry deadline base
     */
    private void addStandbyTasks(final long start) {
        final Map<TopicPartition, Long> checkpointedOffsets = new HashMap<>();

        final Map<TaskId, Set<TopicPartition>> newStandbyTasks = new HashMap<>();

        Map<TaskId, Set<TopicPartition>> assignedStandbyTasks = threadMetadataProvider.standbyTasks();
        log.debug("{} Adding assigned standby tasks {}", logPrefix, assignedStandbyTasks);
        // collect newly assigned standby tasks and reopen re-assigned standby tasks
        for (final Map.Entry<TaskId, Set<TopicPartition>> entry : assignedStandbyTasks.entrySet()) {
            final TaskId taskId = entry.getKey();
            final Set<TopicPartition> partitions = entry.getValue();
            final Task task = findMatchingSuspendedStandbyTask(taskId, partitions);

            if (task != null) {
                suspendedStandbyTasks.remove(taskId);
                task.resume();
            } else {
                newStandbyTasks.put(taskId, partitions);
            }

            // no-op when task is null (newly created tasks are registered below)
            updateStandByTasks(checkpointedOffsets, taskId, partitions, task);
        }

        // create all newly assigned standby tasks (guard against race condition with other thread via backoff and retry)
        // -> other thread will call removeSuspendedStandbyTasks(); eventually
        log.trace("{} New standby tasks to be created: {}", logPrefix, newStandbyTasks);
        if (!newStandbyTasks.isEmpty()) {
            final Map<Task, Set<TopicPartition>> createdStandbyTasks = standbyTaskCreator.retryWithBackoff(consumer, newStandbyTasks, start);
            for (Map.Entry<Task, Set<TopicPartition>> entry : createdStandbyTasks.entrySet()) {
                final Task task = entry.getKey();
                updateStandByTasks(checkpointedOffsets, task.id(), entry.getValue(), task);
            }
        }

        restoreConsumer.assign(checkpointedOffsets.keySet());

        for (final Map.Entry<TopicPartition, Long> entry : checkpointedOffsets.entrySet()) {
            final TopicPartition partition = entry.getKey();
            final long offset = entry.getValue();
            if (offset >= 0) {
                restoreConsumer.seek(partition, offset);
            } else {
                // no checkpoint recorded: replay the changelog from the beginning
                restoreConsumer.seekToBeginning(singleton(partition));
            }
        }
    }

    /**
     * Registers a (non-null) standby task in the lookup maps and accumulates its
     * checkpointed offsets so the restore consumer can be positioned.
     *
     * @param checkpointedOffsets out-parameter collecting changelog offsets per partition
     * @param taskId              id of the standby task
     * @param partitions          the task's assigned input partitions
     * @param task                the task, or null to skip (nothing to register)
     */
    private void updateStandByTasks(final Map<TopicPartition, Long> checkpointedOffsets,
                                    final TaskId taskId,
                                    final Set<TopicPartition> partitions,
                                    final Task task) {
        if (task != null) {
            standbyTasks.put(taskId, task);
            for (final TopicPartition partition : partitions) {
                standbyTasksByPartition.put(partition, task);
            }
            // collect checkpointed offsets to position the restore consumer
            // this includes all partitions from which we restore states
            for (final TopicPartition partition : task.checkpointedOffsets().keySet()) {
                standbyTasksByPartition.put(partition, task);
            }
            checkpointedOffsets.putAll(task.checkpointedOffsets());
        }
    }

    /** Returns a new list of all tasks: active, standby, and both suspended kinds. */
    List<Task> allTasks() {
        final List<Task> tasks = activeAndStandbytasks();
        tasks.addAll(suspendedAndSuspendedStandbytasks());
        return tasks;
    }

    /** Returns a new list of active tasks followed by standby tasks. */
    private List<Task> activeAndStandbytasks() {
        final List<Task> tasks = new ArrayList<>(activeTasks.values());
        tasks.addAll(standbyTasks.values());
        return tasks;
    }

    /** Returns a new list of suspended tasks followed by suspended standby tasks. */
    private List<Task> suspendedAndSuspendedStandbytasks() {
        final List<Task> tasks = new ArrayList<>(suspendedTasks.values());
        tasks.addAll(suspendedStandbyTasks.values());
        return tasks;
    }

    /**
     * Returns the suspended active task with the given id only if its partition
     * set matches exactly; otherwise null (the task cannot be resumed as-is).
     */
    private Task findMatchingSuspendedTask(final TaskId taskId, final Set<TopicPartition> partitions) {
        if (suspendedTasks.containsKey(taskId)) {
            final Task task = suspendedTasks.get(taskId);
            if (task.partitions().equals(partitions)) {
                return task;
            }
        }
        return null;
    }

    /**
     * Returns the suspended standby task with the given id only if its partition
     * set matches exactly; otherwise null.
     */
    private Task findMatchingSuspendedStandbyTask(final TaskId taskId, final Set<TopicPartition> partitions) {
        if (suspendedStandbyTasks.containsKey(taskId)) {
            final Task task = suspendedStandbyTasks.get(taskId);
            if (task.partitions().equals(partitions)) {
                return task;
            }
        }
        return null;
    }

    /** Ids of the currently active tasks (unmodifiable view). */
    Set<TaskId> activeTaskIds() {
        return Collections.unmodifiableSet(activeTasks.keySet());
    }

    /** Ids of the current standby tasks (unmodifiable view). */
    Set<TaskId> standbyTaskIds() {
        return Collections.unmodifiableSet(standbyTasks.keySet());
    }

    /** Ids of the active tasks held before the last removal (unmodifiable view). */
    Set<TaskId> prevActiveTaskIds() {
        return Collections.unmodifiableSet(prevActiveTasks);
    }

    /**
     * Similar to shutdownTasksAndState, however does not close the task managers, in the hope that
     * soon the tasks will be assigned again.
     *
     * All suspend failures are collected and only the first one is rethrown at
     * the end (wrapped in a StreamsException), so every task gets a suspend attempt.
     *
     * @throws StreamsException if suspending any task or un-assigning changelog partitions failed
     */
    void suspendTasksAndState()  {
        log.debug("{} Suspending all active tasks {} and standby tasks {}",
                  logPrefix, activeTasks.keySet(), standbyTasks.keySet());

        final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);

        firstException.compareAndSet(null, performOnActiveTasks(new TaskAction() {
            @Override
            public String name() {
                return "suspend";
            }

            @Override
            public void apply(final Task task) {
                try {
                    task.suspend();
                } catch (final CommitFailedException e) {
                    // commit failed during suspension. Just log it.
                    log.warn("{} Failed to commit task {} state when suspending due to CommitFailedException", logPrefix, task.id());
                } catch (final Exception e) {
                    log.error("{} Suspending task {} failed due to the following error:", logPrefix, task.id(), e);
                    try {
                        // suspension failed: close dirty so we don't leak the task
                        task.close(false);
                    } catch (final Exception f) {
                        log.error("{} After suspending failed, closing the same task {} failed again due to the following error:", logPrefix, task.id(), f);
                    }
                    throw e;
                }
            }
        }));

        for (final Task task : standbyTasks.values()) {
            try {
                try {
                    task.suspend();
                } catch (final Exception e) {
                    log.error("{} Suspending standby task {} failed due to the following error:", logPrefix, task.id(), e);
                    try {
                        task.close(false);
                    } catch (final Exception f) {
                        log.error("{} After suspending failed, closing the same standby task {} failed again due to the following error:", logPrefix, task.id(), f);
                    }
                    throw e;
                }
            } catch (final RuntimeException e) {
                // remember the first failure but keep suspending the remaining standbys
                firstException.compareAndSet(null, e);
            }
        }

        // remove the changelog partitions from restore consumer
        firstException.compareAndSet(null, unAssignChangeLogPartitions());

        updateSuspendedTasks();

        if (firstException.get() != null) {
            throw new StreamsException(logPrefix + " failed to suspend stream tasks", firstException.get());
        }
    }

    /**
     * Clears the restore consumer's assignment.
     *
     * @return the failure, or null on success (caller folds it into firstException)
     */
    private RuntimeException unAssignChangeLogPartitions() {
        try {
            // un-assign the change log partitions
            restoreConsumer.assign(Collections.<TopicPartition>emptyList());
        } catch (final RuntimeException e) {
            log.error("{} Failed to un-assign change log partitions due to the following error:", logPrefix, e);
            return e;
        }
        return null;
    }

    /**
     * Records the current active/standby tasks as suspended.
     * NOTE(review): suspendedTasks is cleared before putAll but
     * suspendedStandbyTasks is not — verify the asymmetry is intentional
     * (e.g. standby entries may be expected to accumulate until closed).
     */
    private void updateSuspendedTasks() {
        suspendedTasks.clear();
        suspendedTasks.putAll(activeTasks);
        suspendedStandbyTasks.putAll(standbyTasks);
    }

    /** Forgets all active tasks, remembering their ids in prevActiveTasks. */
    private void removeStreamTasks() {
        log.debug("{} Removing all active tasks {}", logPrefix, activeTasks.keySet());

        try {
            prevActiveTasks.clear();
            prevActiveTasks.addAll(activeTasks.keySet());

            activeTasks.clear();
            activeTasksByPartition.clear();
        } catch (final Exception e) {
            log.error("{} Failed to remove stream tasks due to the following error:", logPrefix, e);
        }
    }

    /**
     * Closes (uncleanly) a task whose producer was fenced by another instance
     * with the same transactional id, and drops it from the active set.
     * Close failures are logged and ignored — the task is a zombie either way.
     */
    void closeZombieTask(final Task task) {
        log.warn("{} Producer of task {} fenced; closing zombie task", logPrefix, task.id());
        try {
            task.close(false);
        } catch (final Exception e) {
            log.warn("{} Failed to close zombie task due to {}, ignore and proceed", logPrefix, e);
        }
        activeTasks.remove(task.id());
    }


    /** Applies the action to every active task; see {@link #performOnTasks}. */
    RuntimeException performOnActiveTasks(final TaskAction action) {
        return performOnTasks(action, activeTasks, "stream task");
    }

    /** Applies the action to every standby task; see {@link #performOnTasks}. */
    RuntimeException performOnStandbyTasks(final TaskAction action) {
        return performOnTasks(action, standbyTasks, "standby task");
    }

    /**
     * Applies the action to every task in the map. A fenced producer removes the
     * task (via closeZombieTask) without counting as a failure; other runtime
     * exceptions are logged, and the first one is returned after all tasks have
     * been attempted.
     *
     * @return the first non-fencing failure, or null if all succeeded
     */
    private RuntimeException performOnTasks(final TaskAction action, final Map<TaskId, Task> tasks, final String taskType) {
        RuntimeException firstException = null;
        final Iterator<Map.Entry<TaskId, Task>> it = tasks.entrySet().iterator();
        while (it.hasNext()) {
            final Task task = it.next().getValue();
            try {
                action.apply(task);
            } catch (final ProducerFencedException e) {
                closeZombieTask(task);
                it.remove();
            } catch (final RuntimeException t) {
                log.error("{} Failed to {} " + taskType + " {} due to the following error:",
                          logPrefix,
                          action.name(),
                          task.id(),
                          t);
                if (firstException == null) {
                    firstException = t;
                }
            }
        }

        return firstException;
    }



    /**
     * Closes every known task (active, standby, suspended), the metadata
     * provider, and un-assigns the restore consumer. All failures are logged
     * and swallowed so shutdown always runs to completion.
     *
     * @param clean whether tasks should commit/flush before closing
     */
    void shutdown(final boolean clean) {
        log.debug("{} Shutting down all active tasks {}, standby tasks {}, suspended tasks {}, and suspended standby tasks {}",
                  logPrefix, activeTasks.keySet(), standbyTasks.keySet(),
                  suspendedTasks.keySet(), suspendedStandbyTasks.keySet());

        for (final Task task : allTasks()) {
            try {
                task.close(clean);
            } catch (final RuntimeException e) {
                log.error("{} Failed while closing {} {} due to the following error:",
                          logPrefix,
                          task.getClass().getSimpleName(),
                          task.id(),
                          e);
            }
        }
        try {
            threadMetadataProvider.close();
        } catch (final Throwable e) {
            log.error("{} Failed to close KafkaStreamClient due to the following error:", logPrefix, e);
        }
        // remove the changelog partitions from restore consumer
        unAssignChangeLogPartitions();

    }

    /** Ids of the suspended active tasks (unmodifiable view). */
    Set<TaskId> suspendedActiveTaskIds() {
        return Collections.unmodifiableSet(suspendedTasks.keySet());
    }

    /** Ids of the suspended standby tasks (unmodifiable view). */
    Set<TaskId> suspendedStandbyTaskIds() {
        return Collections.unmodifiableSet(suspendedStandbyTasks.keySet());
    }

    /** Forgets all active and standby tasks (does not close them). */
    void removeTasks() {
        removeStreamTasks();
        removeStandbyTasks();
    }

    private void removeStandbyTasks() {
        log.debug("{} Removing all standby tasks {}", logPrefix, standbyTasks.keySet());
        standbyTasks.clear();
        standbyTasksByPartition.clear();
    }

    /** The active task processing the given partition, or null. */
    Task activeTask(final TopicPartition partition) {
        return activeTasksByPartition.get(partition);
    }

    boolean hasStandbyTasks() {
        return !standbyTasks.isEmpty();
    }

    /** The standby task covering the given partition, or null. */
    Task standbyTask(final TopicPartition partition) {
        return standbyTasksByPartition.get(partition);
    }

    /**
     * Direct (mutable, concurrent) view of the active tasks by id.
     * NOTE(review): exposes internal state; callers must not structurally modify it.
     */
    public Map<TaskId, Task> activeTasks() {
        return activeTasks;
    }

    boolean hasActiveTasks() {
        return !activeTasks.isEmpty();
    }

    /** Injects the thread's main consumer, required before createTasks(). */
    void setConsumer(final Consumer<byte[], byte[]> consumer) {
        this.consumer = consumer;
    }

    /** Closes the producer(s) owned by the active-task creator. */
    public void closeProducer() {
        taskCreator.close();
    }




    /**
     * A named operation applied uniformly to tasks by performOnTasks();
     * name() is used only for log messages.
     */
    interface TaskAction {
        String name();
        void apply(final Task task);
    }
}

http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadDataProvider.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadDataProvider.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadDataProvider.java
new file mode 100644
index 0000000..ded98f7
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadDataProvider.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.processor.PartitionGrouper;
+import org.apache.kafka.streams.processor.TaskId;
+
+import java.util.Set;
+import java.util.UUID;
+
/**
 * Interface exposing a {@code StreamThread}'s own data to collaborators
 * (notably the partition assignor), without depending on the concrete thread class.
 */
interface ThreadDataProvider {
    /** The topology builder this thread runs. */
    InternalTopologyBuilder builder();
    /** The thread's name. */
    String name();
    /** Ids of the active tasks this thread held before the last rebalance. */
    Set<TaskId> prevActiveTasks();
    /** Ids of tasks for which this thread has cached state. */
    Set<TaskId> cachedTasks();
    /** The process id of the enclosing KafkaStreams instance. */
    UUID processId();
    /** The streams configuration this thread was created with. */
    StreamsConfig config();
    /** The grouper that maps partitions to tasks. */
    PartitionGrouper partitionGrouper();
    /** Hands the thread the provider from which it will read its assignment. */
    void setThreadMetadataProvider(final ThreadMetadataProvider provider);
}

http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataProvider.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataProvider.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataProvider.java
new file mode 100644
index 0000000..f185045
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataProvider.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.state.HostInfo;
+
+import java.util.Map;
+import java.util.Set;
+
/**
 * Interface used by a <code>StreamThread</code> to get metadata from the <code>StreamPartitionAssignor</code>
 */
public interface ThreadMetadataProvider {
    /** Standby tasks assigned to the thread, keyed by task id. */
    Map<TaskId, Set<TopicPartition>> standbyTasks();
    /** Active tasks assigned to the thread, keyed by task id. */
    Map<TaskId, Set<TopicPartition>> activeTasks();
    /** Partitions hosted by each application instance (for interactive queries). */
    Map<HostInfo, Set<TopicPartition>> getPartitionsByHostState();
    /** Cluster metadata captured during the last rebalance. */
    Cluster clusterMetadata();
    /** Releases any resources held by the provider. */
    void close();
}

http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/main/java/org/apache/kafka/streams/state/HostInfo.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/HostInfo.java 
b/streams/src/main/java/org/apache/kafka/streams/state/HostInfo.java
index 6f48543..8fcdd03 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/HostInfo.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/HostInfo.java
@@ -19,6 +19,7 @@ package org.apache.kafka.streams.state;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.streams.KafkaStreams;
 import org.apache.kafka.streams.processor.StreamPartitioner;
+import org.apache.kafka.streams.processor.internals.StreamPartitionAssignor;
 
 /**
  * Represents a user defined endpoint in a {@link 
org.apache.kafka.streams.KafkaStreams} application.
@@ -29,7 +30,7 @@ import org.apache.kafka.streams.processor.StreamPartitioner;
  *  {@link KafkaStreams#metadataForKey(String, Object, Serializer)}
  *
  *  The HostInfo is constructed during Partition Assignment
- *  see {@link 
org.apache.kafka.streams.processor.internals.StreamPartitionAssignor}
+ *  see {@link StreamPartitionAssignor}
  *  It is extracted from the config {@link 
org.apache.kafka.streams.StreamsConfig#APPLICATION_SERVER_CONFIG}
  *
  *  If developers wish to expose an endpoint in their KafkaStreams 
applications they should provide the above

http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java
index 45d9898..19a898e 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java
@@ -18,8 +18,8 @@ package org.apache.kafka.streams.state.internals;
 
 import org.apache.kafka.streams.errors.InvalidStateStoreException;
 import org.apache.kafka.streams.processor.StateStore;
-import org.apache.kafka.streams.processor.internals.StreamTask;
 import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.processor.internals.Task;
 import org.apache.kafka.streams.state.QueryableStoreType;
 
 import java.util.ArrayList;
@@ -48,7 +48,7 @@ public class StreamThreadStateStoreProvider implements 
StateStoreProvider {
             throw new InvalidStateStoreException("the state store, " + 
storeName + ", may have migrated to another instance.");
         }
         final List<T> stores = new ArrayList<>();
-        for (StreamTask streamTask : streamThread.tasks().values()) {
+        for (Task streamTask : streamThread.tasks().values()) {
             final StateStore store = streamTask.getStore(storeName);
             if (store != null && queryableStoreType.accepts(store)) {
                 if (!store.isOpen()) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/test/java/org/apache/kafka/streams/integration/RegexSourceIntegrationTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/integration/RegexSourceIntegrationTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/integration/RegexSourceIntegrationTest.java
index b5b6e4f..d67588a 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/integration/RegexSourceIntegrationTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/RegexSourceIntegrationTest.java
@@ -17,31 +17,25 @@
 package org.apache.kafka.streams.integration;
 
 import kafka.utils.MockTime;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
+import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.serialization.StringDeserializer;
 import org.apache.kafka.common.serialization.StringSerializer;
-import org.apache.kafka.common.utils.Time;
-import org.apache.kafka.streams.KafkaClientSupplier;
 import org.apache.kafka.streams.KafkaStreams;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsBuilder;
-import org.apache.kafka.streams.StreamsBuilderTest;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.processor.ProcessorSupplier;
-import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.TopologyBuilder;
 import org.apache.kafka.streams.processor.internals.DefaultKafkaClientSupplier;
-import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder;
-import org.apache.kafka.streams.processor.internals.StateDirectory;
-import org.apache.kafka.streams.processor.internals.StreamTask;
-import org.apache.kafka.streams.processor.internals.StreamThread;
-import org.apache.kafka.streams.processor.internals.StreamsMetadataState;
 import org.apache.kafka.test.IntegrationTest;
 import org.apache.kafka.test.MockProcessorSupplier;
 import org.apache.kafka.test.MockStateStoreSupplier;
@@ -55,7 +49,6 @@ import org.junit.ClassRule;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
 
-import java.lang.reflect.Field;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -63,7 +56,6 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
-import java.util.UUID;
 import java.util.regex.Pattern;
 
 import static org.hamcrest.CoreMatchers.equalTo;
@@ -96,6 +88,7 @@ public class RegexSourceIntegrationTest {
     private static final String STRING_SERDE_CLASSNAME = 
Serdes.String().getClass().getName();
     private Properties streamsConfiguration;
     private static final String STREAM_TASKS_NOT_UPDATED = "Stream tasks not 
updated";
+    private KafkaStreams streams;
 
 
     @BeforeClass
@@ -127,6 +120,9 @@ public class RegexSourceIntegrationTest {
 
     @After
     public void tearDown() throws Exception {
+        if (streams != null) {
+            streams.close();
+        }
         // Remove any state from previous test runs
         IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration);
     }
@@ -147,48 +143,38 @@ public class RegexSourceIntegrationTest {
         final KStream<String, String> pattern1Stream = 
builder.stream(Pattern.compile("TEST-TOPIC-\\d"));
 
         pattern1Stream.to(stringSerde, stringSerde, DEFAULT_OUTPUT_TOPIC);
+        final List<String> assignedTopics = new ArrayList<>();
+        streams = new KafkaStreams(builder.build(), streamsConfig, new 
DefaultKafkaClientSupplier() {
+            @Override
+            public Consumer<byte[], byte[]> getConsumer(final Map<String, 
Object> config) {
+                return new KafkaConsumer<byte[], byte[]>(config, new 
ByteArrayDeserializer(), new ByteArrayDeserializer()) {
+                    @Override
+                    public void subscribe(final Pattern topics, final 
ConsumerRebalanceListener listener) {
+                        super.subscribe(topics, new 
TheConsumerRebalanceListener(assignedTopics, listener));
+                    }
+                };
 
-        final KafkaStreams streams = new KafkaStreams(builder.build(), 
streamsConfiguration);
-
-        final Field streamThreadsField = 
streams.getClass().getDeclaredField("threads");
-        streamThreadsField.setAccessible(true);
-        final StreamThread[] streamThreads = (StreamThread[]) 
streamThreadsField.get(streams);
-        final StreamThread originalThread = streamThreads[0];
+            }
+        });
 
-        final TestStreamThread testStreamThread = new TestStreamThread(
-            StreamsBuilderTest.internalTopologyBuilder(builder),
-            streamsConfig,
-            new DefaultKafkaClientSupplier(),
-            originalThread.applicationId,
-            originalThread.clientId,
-            originalThread.processId,
-            new Metrics(),
-            Time.SYSTEM);
 
-        final TestCondition oneTopicAdded = new TestCondition() {
+        streams.start();
+        TestUtils.waitForCondition(new TestCondition() {
             @Override
             public boolean conditionMet() {
-                return 
testStreamThread.assignedTopicPartitions.equals(expectedFirstAssignment);
+                return assignedTopics.equals(expectedFirstAssignment);
             }
-        };
-
-        streamThreads[0] = testStreamThread;
-        streams.start();
-
-        TestUtils.waitForCondition(oneTopicAdded, STREAM_TASKS_NOT_UPDATED);
+        }, STREAM_TASKS_NOT_UPDATED);
 
         CLUSTER.createTopic("TEST-TOPIC-2");
 
-        final TestCondition secondTopicAdded = new TestCondition() {
+        TestUtils.waitForCondition(new TestCondition() {
             @Override
             public boolean conditionMet() {
-                return 
testStreamThread.assignedTopicPartitions.equals(expectedSecondAssignment);
+                return assignedTopics.equals(expectedSecondAssignment);
             }
-        };
-
-        TestUtils.waitForCondition(secondTopicAdded, STREAM_TASKS_NOT_UPDATED);
+        }, STREAM_TASKS_NOT_UPDATED);
 
-        streams.close();
     }
 
     @Test
@@ -208,49 +194,40 @@ public class RegexSourceIntegrationTest {
 
         pattern1Stream.to(stringSerde, stringSerde, DEFAULT_OUTPUT_TOPIC);
 
-        final KafkaStreams streams = new KafkaStreams(builder.build(), 
streamsConfiguration);
-
-        final Field streamThreadsField = 
streams.getClass().getDeclaredField("threads");
-        streamThreadsField.setAccessible(true);
-        final StreamThread[] streamThreads = (StreamThread[]) 
streamThreadsField.get(streams);
-        final StreamThread originalThread = streamThreads[0];
+        final List<String> assignedTopics = new ArrayList<>();
+        streams = new KafkaStreams(builder.build(), streamsConfig, new 
DefaultKafkaClientSupplier() {
+            @Override
+            public Consumer<byte[], byte[]> getConsumer(final Map<String, 
Object> config) {
+                return new KafkaConsumer<byte[], byte[]>(config, new 
ByteArrayDeserializer(), new ByteArrayDeserializer()) {
+                    @Override
+                    public void subscribe(final Pattern topics, final 
ConsumerRebalanceListener listener) {
+                        super.subscribe(topics, new 
TheConsumerRebalanceListener(assignedTopics, listener));
+                    }
+                };
 
-        final TestStreamThread testStreamThread = new TestStreamThread(
-            StreamsBuilderTest.internalTopologyBuilder(builder),
-            streamsConfig,
-            new DefaultKafkaClientSupplier(),
-            originalThread.applicationId,
-            originalThread.clientId,
-            originalThread.processId,
-            new Metrics(),
-            Time.SYSTEM);
+            }
+        });
 
-        streamThreads[0] = testStreamThread;
 
-        final TestCondition bothTopicsAdded = new TestCondition() {
+        streams.start();
+        TestUtils.waitForCondition(new TestCondition() {
             @Override
             public boolean conditionMet() {
-                return 
testStreamThread.assignedTopicPartitions.equals(expectedFirstAssignment);
+                return assignedTopics.equals(expectedFirstAssignment);
             }
-        };
-        streams.start();
-
-        TestUtils.waitForCondition(bothTopicsAdded, STREAM_TASKS_NOT_UPDATED);
+        }, STREAM_TASKS_NOT_UPDATED);
 
         CLUSTER.deleteTopic("TEST-TOPIC-A");
 
-        final TestCondition oneTopicRemoved = new TestCondition() {
+        TestUtils.waitForCondition(new TestCondition() {
             @Override
             public boolean conditionMet() {
-                return 
testStreamThread.assignedTopicPartitions.equals(expectedSecondAssignment);
+                return assignedTopics.equals(expectedSecondAssignment);
             }
-        };
-
-        TestUtils.waitForCondition(oneTopicRemoved, STREAM_TASKS_NOT_UPDATED);
-
-        streams.close();
+        }, STREAM_TASKS_NOT_UPDATED);
     }
 
+    @SuppressWarnings("deprecation")
     @Test
     public void shouldAddStateStoreToRegexDefinedSource() throws Exception {
 
@@ -264,7 +241,7 @@ public class RegexSourceIntegrationTest {
                 .addStateStore(stateStoreSupplier, "my-processor");
 
 
-        final KafkaStreams streams = new KafkaStreams(builder, 
streamsConfiguration);
+        streams = new KafkaStreams(builder, streamsConfiguration);
         try {
             streams.start();
 
@@ -308,7 +285,7 @@ public class RegexSourceIntegrationTest {
         pattern2Stream.to(stringSerde, stringSerde, DEFAULT_OUTPUT_TOPIC);
         namedTopicsStream.to(stringSerde, stringSerde, DEFAULT_OUTPUT_TOPIC);
 
-        final KafkaStreams streams = new KafkaStreams(builder.build(), 
streamsConfiguration);
+        streams = new KafkaStreams(builder.build(), streamsConfiguration);
         streams.start();
 
         final Properties producerConfig = 
TestUtils.producerConfig(CLUSTER.bootstrapServers(), StringSerializer.class, 
StringSerializer.class);
@@ -330,7 +307,6 @@ public class RegexSourceIntegrationTest {
             actualValues.add(receivedKeyValue.value);
         }
 
-        streams.close();
         Collections.sort(actualValues);
         Collections.sort(expectedReceivedValues);
         assertThat(actualValues, equalTo(expectedReceivedValues));
@@ -339,84 +315,67 @@ public class RegexSourceIntegrationTest {
     @Test
     public void testMultipleConsumersCanReadFromPartitionedTopic() throws 
Exception {
 
-        final Serde<String> stringSerde = Serdes.String();
-        final StreamsBuilder builderLeader = new StreamsBuilder();
-        final StreamsBuilder builderFollower = new StreamsBuilder();
-        final List<String> expectedAssignment = 
Arrays.asList(PARTITIONED_TOPIC_1,  PARTITIONED_TOPIC_2);
-
-        final KStream<String, String> partitionedStreamLeader = 
builderLeader.stream(Pattern.compile("partitioned-\\d"));
-        final KStream<String, String> partitionedStreamFollower = 
builderFollower.stream(Pattern.compile("partitioned-\\d"));
+        KafkaStreams partitionedStreamsLeader = null;
+        KafkaStreams partitionedStreamsFollower = null;
+        try {
+            final Serde<String> stringSerde = Serdes.String();
+            final StreamsBuilder builderLeader = new StreamsBuilder();
+            final StreamsBuilder builderFollower = new StreamsBuilder();
+            final List<String> expectedAssignment = 
Arrays.asList(PARTITIONED_TOPIC_1,  PARTITIONED_TOPIC_2);
 
+            final KStream<String, String> partitionedStreamLeader = 
builderLeader.stream(Pattern.compile("partitioned-\\d"));
+            final KStream<String, String> partitionedStreamFollower = 
builderFollower.stream(Pattern.compile("partitioned-\\d"));
 
-        partitionedStreamLeader.to(stringSerde, stringSerde, 
DEFAULT_OUTPUT_TOPIC);
-        partitionedStreamFollower.to(stringSerde, stringSerde, 
DEFAULT_OUTPUT_TOPIC);
 
-        final KafkaStreams partitionedStreamsLeader  = new 
KafkaStreams(builderLeader.build(), streamsConfiguration);
-        final KafkaStreams partitionedStreamsFollower  = new 
KafkaStreams(builderFollower.build(), streamsConfiguration);
+            partitionedStreamLeader.to(stringSerde, stringSerde, 
DEFAULT_OUTPUT_TOPIC);
+            partitionedStreamFollower.to(stringSerde, stringSerde, 
DEFAULT_OUTPUT_TOPIC);
 
-        final StreamsConfig streamsConfig = new 
StreamsConfig(streamsConfiguration);
+            final List<String> leaderAssignment = new ArrayList<>();
+            final List<String> followerAssignment = new ArrayList<>();
+            StreamsConfig config = new StreamsConfig(streamsConfiguration);
 
+            partitionedStreamsLeader  = new 
KafkaStreams(builderLeader.build(), config, new DefaultKafkaClientSupplier() {
+                @Override
+                public Consumer<byte[], byte[]> getConsumer(final Map<String, 
Object> config) {
+                    return new KafkaConsumer<byte[], byte[]>(config, new 
ByteArrayDeserializer(), new ByteArrayDeserializer()) {
+                        @Override
+                        public void subscribe(final Pattern topics, final 
ConsumerRebalanceListener listener) {
+                            super.subscribe(topics, new 
TheConsumerRebalanceListener(leaderAssignment, listener));
+                        }
+                    };
 
-        final Field leaderStreamThreadsField = 
partitionedStreamsLeader.getClass().getDeclaredField("threads");
-        leaderStreamThreadsField.setAccessible(true);
-        final StreamThread[] leaderStreamThreads = (StreamThread[]) 
leaderStreamThreadsField.get(partitionedStreamsLeader);
-        final StreamThread originalLeaderThread = leaderStreamThreads[0];
+                }
+            });
+            partitionedStreamsFollower  = new 
KafkaStreams(builderFollower.build(), config, new DefaultKafkaClientSupplier() {
+                @Override
+                public Consumer<byte[], byte[]> getConsumer(final Map<String, 
Object> config) {
+                    return new KafkaConsumer<byte[], byte[]>(config, new 
ByteArrayDeserializer(), new ByteArrayDeserializer()) {
+                        @Override
+                        public void subscribe(final Pattern topics, final 
ConsumerRebalanceListener listener) {
+                            super.subscribe(topics, new 
TheConsumerRebalanceListener(followerAssignment, listener));
+                        }
+                    };
 
-        final TestStreamThread leaderTestStreamThread = new TestStreamThread(
-            StreamsBuilderTest.internalTopologyBuilder(builderLeader),
-            streamsConfig,
-            new DefaultKafkaClientSupplier(),
-            originalLeaderThread.applicationId,
-            originalLeaderThread.clientId,
-            originalLeaderThread.processId,
-            new Metrics(),
-            Time.SYSTEM);
+                }
+            });
 
-        leaderStreamThreads[0] = leaderTestStreamThread;
 
-        final TestCondition bothTopicsAddedToLeader = new TestCondition() {
-            @Override
-            public boolean conditionMet() {
-                return 
leaderTestStreamThread.assignedTopicPartitions.equals(expectedAssignment);
+            partitionedStreamsLeader.start();
+            partitionedStreamsFollower.start();
+            TestUtils.waitForCondition(new TestCondition() {
+                @Override
+                public boolean conditionMet() {
+                    return followerAssignment.equals(expectedAssignment) && 
leaderAssignment.equals(expectedAssignment);
+                }
+            }, "topic assignment not completed");
+        } finally {
+            if (partitionedStreamsLeader != null) {
+                partitionedStreamsLeader.close();
             }
-        };
-
-
-
-        final Field followerStreamThreadsField = 
partitionedStreamsFollower.getClass().getDeclaredField("threads");
-        followerStreamThreadsField.setAccessible(true);
-        final StreamThread[] followerStreamThreads = (StreamThread[]) 
followerStreamThreadsField.get(partitionedStreamsFollower);
-        final StreamThread originalFollowerThread = followerStreamThreads[0];
-
-        final TestStreamThread followerTestStreamThread = new TestStreamThread(
-            StreamsBuilderTest.internalTopologyBuilder(builderFollower),
-            streamsConfig,
-            new DefaultKafkaClientSupplier(),
-            originalFollowerThread.applicationId,
-            originalFollowerThread.clientId,
-            originalFollowerThread.processId,
-            new Metrics(),
-            Time.SYSTEM);
-
-        followerStreamThreads[0] = followerTestStreamThread;
-
-
-        final TestCondition bothTopicsAddedToFollower = new TestCondition() {
-            @Override
-            public boolean conditionMet() {
-                return 
followerTestStreamThread.assignedTopicPartitions.equals(expectedAssignment);
+            if (partitionedStreamsFollower != null) {
+                partitionedStreamsFollower.close();
             }
-        };
-
-        partitionedStreamsLeader.start();
-        TestUtils.waitForCondition(bothTopicsAddedToLeader, "Topics never 
assigned to leader stream");
-
-
-        partitionedStreamsFollower.start();
-        TestUtils.waitForCondition(bothTopicsAddedToFollower, "Topics never 
assigned to follower stream");
-
-        partitionedStreamsLeader.close();
-        partitionedStreamsFollower.close();
+        }
 
     }
 
@@ -443,7 +402,7 @@ public class RegexSourceIntegrationTest {
         pattern1Stream.to(stringSerde, stringSerde, DEFAULT_OUTPUT_TOPIC);
         pattern2Stream.to(stringSerde, stringSerde, DEFAULT_OUTPUT_TOPIC);
 
-        final KafkaStreams streams = new KafkaStreams(builder.build(), 
streamsConfiguration);
+        streams = new KafkaStreams(builder.build(), streamsConfiguration);
         streams.start();
 
         final Properties producerConfig = 
TestUtils.producerConfig(CLUSTER.bootstrapServers(), StringSerializer.class, 
StringSerializer.class);
@@ -453,34 +412,33 @@ public class RegexSourceIntegrationTest {
 
         final Properties consumerConfig = 
TestUtils.consumerConfig(CLUSTER.bootstrapServers(), StringDeserializer.class, 
StringDeserializer.class);
 
-        try {
-            
IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(consumerConfig, 
DEFAULT_OUTPUT_TOPIC, 2, 5000);
-            fail("Should not get here");
-        } finally {
-            streams.close();
-        }
-
+        
IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(consumerConfig, 
DEFAULT_OUTPUT_TOPIC, 2, 5000);
+        fail("Should not get here");
     }
 
-    private class TestStreamThread extends StreamThread {
-        public volatile List<String> assignedTopicPartitions = new 
ArrayList<>();
+    private static class TheConsumerRebalanceListener implements 
ConsumerRebalanceListener {
+        private final List<String> assignedTopics;
+        private final ConsumerRebalanceListener listener;
+
+        TheConsumerRebalanceListener(final List<String> assignedTopics, final 
ConsumerRebalanceListener listener) {
+            this.assignedTopics = assignedTopics;
+            this.listener = listener;
+        }
 
-        public TestStreamThread(final InternalTopologyBuilder builder, final 
StreamsConfig config, final KafkaClientSupplier clientSupplier, final String 
applicationId, final String clientId, final UUID processId, final Metrics 
metrics, final Time time) {
-            super(builder, config, clientSupplier, applicationId, clientId, 
processId, metrics, time, new StreamsMetadataState(builder, 
StreamsMetadataState.UNKNOWN_HOST),
-                  0, new StateDirectory(applicationId, 
config.getString(StreamsConfig.STATE_DIR_CONFIG), time));
+        @Override
+        public void onPartitionsRevoked(final Collection<TopicPartition> 
partitions) {
+            assignedTopics.clear();
+            listener.onPartitionsRevoked(partitions);
         }
 
         @Override
-        public StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitions) {
-            final List<String> topicPartitions = new ArrayList<>();
+        public void onPartitionsAssigned(final Collection<TopicPartition> 
partitions) {
             for (final TopicPartition partition : partitions) {
-                topicPartitions.add(partition.topic());
+                assignedTopics.add(partition.topic());
             }
-            Collections.sort(topicPartitions);
-
-            assignedTopicPartitions = topicPartitions;
-            return super.createStreamTask(id, partitions);
+            Collections.sort(assignedTopics);
+            listener.onPartitionsAssigned(partitions);
         }
-
     }
+
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
index f78b43a..353f740 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.MockConsumer;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
@@ -35,6 +36,8 @@ import org.apache.kafka.test.TestUtils;
 import org.junit.Test;
 
 import java.util.Collections;
+import java.util.List;
+import java.util.Map;
 import java.util.Properties;
 
 public class AbstractTaskTest {
@@ -91,6 +94,46 @@ public class AbstractTaskTest {
 
             @Override
             public void close(final boolean clean) {}
+
+            @Override
+            public void closeSuspended(final boolean clean, final 
RuntimeException e) {
+
+            }
+
+            @Override
+            public Map<TopicPartition, Long> checkpointedOffsets() {
+                return null;
+            }
+
+            @Override
+            public boolean process() {
+                return false;
+            }
+
+            @Override
+            public boolean maybePunctuateStreamTime() {
+                return false;
+            }
+
+            @Override
+            public boolean maybePunctuateSystemTime() {
+                return false;
+            }
+
+            @Override
+            public List<ConsumerRecord<byte[], byte[]>> update(final 
TopicPartition partition, final List<ConsumerRecord<byte[], byte[]>> remaining) 
{
+                return null;
+            }
+
+            @Override
+            public int addRecords(final TopicPartition partition, final 
Iterable<ConsumerRecord<byte[], byte[]>> records) {
+                return 0;
+            }
+
+            @Override
+            public boolean commitNeeded() {
+                return false;
+            }
         };
     }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index fad7116..e232316 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -32,6 +32,7 @@ import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.kstream.internals.InternalStreamsBuilder;
 import org.apache.kafka.streams.kstream.internals.InternalStreamsBuilderTest;
 import org.apache.kafka.streams.processor.StateStore;
@@ -57,12 +58,15 @@ import java.util.List;
 import java.util.Map;
 import java.util.Properties;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 import static java.util.Collections.singleton;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 public class StandbyTaskTest {
 
@@ -383,6 +387,48 @@ public class StandbyTaskTest {
 
     }
 
+    @Test
+    public void shouldCloseStateMangerOnTaskCloseWhenCommitFailed() throws 
Exception {
+        consumer.assign(Utils.mkList(ktable));
+        final Map<TopicPartition, OffsetAndMetadata> committedOffsets = new 
HashMap<>();
+        committedOffsets.put(new TopicPartition(ktable.topic(), 
ktable.partition()), new OffsetAndMetadata(100L));
+        consumer.commitSync(committedOffsets);
+
+        restoreStateConsumer.updatePartitions("ktable1", Utils.mkList(
+                new PartitionInfo("ktable1", 0, Node.noNode(), new Node[0], 
new Node[0])));
+
+        final StreamsConfig config = createConfig(baseDir);
+        final AtomicBoolean closedStateManager = new AtomicBoolean(false);
+        final StandbyTask task = new StandbyTask(taskId,
+                                                 applicationId,
+                                                 ktablePartitions,
+                                                 ktableTopology,
+                                                 consumer,
+                                                 changelogReader,
+                                                 config,
+                                                 null,
+                                                 stateDirectory
+        ) {
+            @Override
+            public void commit() {
+                throw new RuntimeException("KABOOM!");
+            }
+
+            @Override
+            void closeStateManager(final boolean writeCheckpoint) throws 
ProcessorStateException {
+                closedStateManager.set(true);
+            }
+        };
+
+        try {
+            task.close(true);
+            fail("should have thrown exception");
+        } catch (Exception e) {
+            // expected
+        }
+        assertTrue(closedStateManager.get());
+    }
+
     private List<ConsumerRecord<byte[], byte[]>> 
records(ConsumerRecord<byte[], byte[]>... recs) {
         return Arrays.asList(recs);
     }

Reply via email to