Repository: kafka
Updated Branches:
  refs/heads/trunk b2b529522 -> 3e69ce801


http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 465f8be..4752227 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
-import org.apache.kafka.clients.consumer.CommitFailedException;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
@@ -39,28 +38,23 @@ import 
org.apache.kafka.streams.kstream.internals.InternalStreamsBuilder;
 import org.apache.kafka.streams.kstream.internals.InternalStreamsBuilderTest;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.processor.TopologyBuilder;
 import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
 import org.apache.kafka.streams.state.HostInfo;
 import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.test.MockClientSupplier;
-import org.apache.kafka.test.MockInternalTopicManager;
 import org.apache.kafka.test.MockProcessorSupplier;
 import org.apache.kafka.test.MockStateRestoreListener;
-import org.apache.kafka.test.MockStateStoreSupplier;
 import org.apache.kafka.test.MockTimestampExtractor;
 import org.apache.kafka.test.TestCondition;
 import org.apache.kafka.test.TestUtils;
 import org.easymock.EasyMock;
+import org.easymock.IAnswer;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import java.io.File;
-import java.io.IOException;
 import java.lang.reflect.Field;
 import java.nio.ByteBuffer;
-import java.nio.file.Files;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -77,11 +71,9 @@ import java.util.regex.Pattern;
 
 import static java.util.Collections.EMPTY_SET;
 import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
@@ -93,13 +85,14 @@ public class StreamThreadTest {
     private final String applicationId = "stream-thread-test";
     private final MockTime mockTime = new MockTime();
     private final Metrics metrics = new Metrics();
-    private final MockClientSupplier clientSupplier = new MockClientSupplier();
+    private MockClientSupplier clientSupplier = new MockClientSupplier();
     private UUID processId = UUID.randomUUID();
     private final InternalStreamsBuilder internalStreamsBuilder = new 
InternalStreamsBuilder(new InternalTopologyBuilder());
     private InternalTopologyBuilder internalTopologyBuilder;
     private final StreamsConfig config = new StreamsConfig(configProps(false));
     private final String stateDir = TestUtils.tempDirectory().getPath();
     private final StateDirectory stateDirectory  = new 
StateDirectory("applicationId", stateDir, mockTime);
+    private StreamsMetadataState streamsMetadataState;
 
     @Before
     public void setUp() throws Exception {
@@ -107,6 +100,7 @@ public class StreamThreadTest {
 
         internalTopologyBuilder = 
InternalStreamsBuilderTest.internalTopologyBuilder(internalStreamsBuilder);
         internalTopologyBuilder.setApplicationId(applicationId);
+        streamsMetadataState = new 
StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST);
     }
 
     private final TopicPartition t1p1 = new TopicPartition("topic1", 1);
@@ -267,9 +261,9 @@ public class StreamThreadTest {
         final StreamThread thread = getStreamThread();
 
         final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        thread.setPartitionAssignor(new StreamPartitionAssignor() {
+        thread.setThreadMetadataProvider(new StreamPartitionAssignor() {
             @Override
-            Map<TaskId, Set<TopicPartition>> activeTasks() {
+            public Map<TaskId, Set<TopicPartition>> activeTasks() {
                 return activeTasks;
             }
         });
@@ -364,9 +358,9 @@ public class StreamThreadTest {
         final StreamThread thread = getStreamThread();
 
         final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        thread.setPartitionAssignor(new StreamPartitionAssignor() {
+        thread.setThreadMetadataProvider(new StreamPartitionAssignor() {
             @Override
-            Map<TaskId, Set<TopicPartition>> activeTasks() {
+            public Map<TaskId, Set<TopicPartition>> activeTasks() {
                 return activeTasks;
             }
         });
@@ -451,19 +445,7 @@ public class StreamThreadTest {
     @Test
     public void testStateChangeStartClose() throws InterruptedException {
 
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            Time.SYSTEM,
-
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
+        final StreamThread thread = createStreamThread(clientId, config, 
false);
 
         final StateListenerStub stateListener = new StateListenerStub();
         thread.setStateListener(stateListener);
@@ -486,6 +468,23 @@ public class StreamThreadTest {
         assertEquals(thread.state(), StreamThread.State.DEAD);
     }
 
+    private StreamThread createStreamThread(final String clientId, final 
StreamsConfig config, final boolean eosEnabled) {
+        if (eosEnabled) {
+            clientSupplier = new MockClientSupplier(applicationId);
+        }
+        return StreamThread.create(internalTopologyBuilder,
+                                   config,
+                                   clientSupplier,
+                                   processId,
+                                   clientId,
+                                   metrics,
+                                   mockTime,
+                                   streamsMetadataState,
+                                   0,
+                                   stateDirectory,
+                                   new MockStateRestoreListener());
+    }
+
     private final static String TOPIC = "topic";
     private final Set<TopicPartition> task0Assignment = 
Collections.singleton(new TopicPartition(TOPIC, 0));
     private final Set<TopicPartition> task1Assignment = 
Collections.singleton(new TopicPartition(TOPIC, 1));
@@ -505,30 +504,9 @@ public class StreamThreadTest {
 
         //clientSupplier.consumer.assign(Arrays.asList(new 
TopicPartition(TOPIC, 0), new TopicPartition(TOPIC, 1)));
 
-        final StreamThread thread1 = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId + 1,
-            processId,
-            metrics,
-            Time.SYSTEM,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
-        final StreamThread thread2 = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId + 2,
-            processId,
-            metrics,
-            Time.SYSTEM,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
+        final StreamThread thread1 = createStreamThread(clientId + 1, config, 
false);
+        final StreamThread thread2 = createStreamThread(clientId + 2, config, 
false);
+
 
         final Map<TaskId, Set<TopicPartition>> task0 = 
Collections.singletonMap(new TaskId(0, 0), task0Assignment);
         final Map<TaskId, Set<TopicPartition>> task1 = 
Collections.singletonMap(new TaskId(0, 1), task1Assignment);
@@ -536,8 +514,8 @@ public class StreamThreadTest {
         final Map<TaskId, Set<TopicPartition>> thread1Assignment = new 
HashMap<>(task0);
         final Map<TaskId, Set<TopicPartition>> thread2Assignment = new 
HashMap<>(task1);
 
-        thread1.setPartitionAssignor(new 
MockStreamsPartitionAssignor(thread1Assignment));
-        thread2.setPartitionAssignor(new 
MockStreamsPartitionAssignor(thread2Assignment));
+        thread1.setThreadMetadataProvider(new 
MockStreamsPartitionAssignor(thread1Assignment));
+        thread2.setThreadMetadataProvider(new 
MockStreamsPartitionAssignor(thread2Assignment));
 
 
         thread1.start();
@@ -615,12 +593,12 @@ public class StreamThreadTest {
         }
 
         @Override
-        Map<TaskId, Set<TopicPartition>> activeTasks() {
+        public Map<TaskId, Set<TopicPartition>> activeTasks() {
             return activeTaskAssignment;
         }
 
         @Override
-        Map<TaskId, Set<TopicPartition>> standbyTasks() {
+        public Map<TaskId, Set<TopicPartition>> standbyTasks() {
             return standbyTaskAssignment;
         }
 
@@ -630,17 +608,7 @@ public class StreamThreadTest {
 
     @Test
     public void testMetrics() {
-        final StreamThread thread = new StreamThread(
-                internalTopologyBuilder,
-                config,
-                clientSupplier,
-                applicationId,
-                clientId,
-                processId,
-                metrics,
-                mockTime,
-                new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-                0, stateDirectory);
+        final StreamThread thread = createStreamThread(clientId, config, 
false);
         final String defaultGroupName = "stream-metrics";
         final String defaultPrefix = "thread." + thread.threadClientId();
         final Map<String, String> defaultTags = 
Collections.singletonMap("client-id", thread.threadClientId());
@@ -671,128 +639,105 @@ public class StreamThreadTest {
     }
 
 
+    @SuppressWarnings({"unchecked", "ThrowableNotThrown"})
     @Test
-    public void testMaybeCommit() throws IOException, InterruptedException {
-        final File baseDir = Files.createTempDirectory("test").toFile();
-        try {
-            final long commitInterval = 1000L;
-            final Properties props = configProps(false);
-            props.setProperty(StreamsConfig.STATE_DIR_CONFIG, 
baseDir.getCanonicalPath());
-            props.setProperty(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 
Long.toString(commitInterval));
-
-            final StreamsConfig config = new StreamsConfig(props);
-
-            internalTopologyBuilder.addSource(null, "source1", null, null, 
null, "topic1");
-
-            final StreamThread thread = new StreamThread(
-                    internalTopologyBuilder,
-                    config,
-                    clientSupplier,
-                    applicationId,
-                    clientId,
-                    processId,
-                    metrics,
-                    mockTime,
-                    new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-                    0, stateDirectory) {
-
-                @Override
-                public void maybeCommit(final long now) {
-                    super.maybeCommit(now);
-                }
-
-                @Override
-                protected StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitionsForTask) {
-                    final ProcessorTopology topology = 
builder.build(id.topicGroupId);
-                    return new TestStreamTask(
-                        id,
-                        applicationId,
-                        partitionsForTask,
-                        topology,
-                        consumer,
-                        clientSupplier.getProducer(new HashMap<String, 
Object>()),
-                        restoreConsumer,
-                        config,
-                        new MockStreamsMetrics(new Metrics()),
-                        stateDirectory);
-                }
-            };
-
-            initPartitionGrouper(config, thread, clientSupplier);
-
-            final ConsumerRebalanceListener rebalanceListener = 
thread.rebalanceListener;
+    public void shouldNotCommitBeforeTheCommitInterval() {
+        final long commitInterval = 1000L;
+        final Properties props = configProps(false);
+        props.setProperty(StreamsConfig.STATE_DIR_CONFIG, stateDir);
+        props.setProperty(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 
Long.toString(commitInterval));
 
-            final List<TopicPartition> revokedPartitions;
-            final List<TopicPartition> assignedPartitions;
+        final StreamsConfig config = new StreamsConfig(props);
+        final Consumer<byte[], byte[]> consumer = 
EasyMock.createNiceMock(Consumer.class);
+        final TaskManager taskManager = mockTaskMangerCommit(consumer, 1);
 
-            //
-            // Assign t1p1 and t1p2. This should create Task 1 & 2
-            //
-            revokedPartitions = Collections.emptyList();
-            assignedPartitions = Arrays.asList(t1p1, t1p2);
-
-            thread.setState(StreamThread.State.RUNNING);
-            
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-            rebalanceListener.onPartitionsRevoked(revokedPartitions);
-            rebalanceListener.onPartitionsAssigned(assignedPartitions);
+        StreamThread.StreamsMetricsThreadImpl streamsMetrics = new 
StreamThread.StreamsMetricsThreadImpl(metrics, "", "", Collections.<String, 
String>emptyMap());
+        final StreamThread thread = new StreamThread(internalTopologyBuilder,
+                                                     clientId,
+                                                     "",
+                                                     config,
+                                                     processId,
+                                                     mockTime,
+                                                     streamsMetadataState,
+                                                     taskManager,
+                                                     streamsMetrics,
+                                                     clientSupplier,
+                                                     consumer,
+                                                     stateDirectory);
+        thread.maybeCommit(mockTime.milliseconds());
+        mockTime.sleep(commitInterval - 10L);
+        thread.maybeCommit(mockTime.milliseconds());
+
+        EasyMock.verify(taskManager);
+    }
 
-            assertEquals(2, thread.tasks().size());
 
-            // no task is committed before the commit interval
-            mockTime.sleep(commitInterval - 10L);
-            thread.maybeCommit(mockTime.milliseconds());
-            for (final StreamTask task : thread.tasks().values()) {
-                assertFalse(((TestStreamTask) task).committed);
-            }
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldCommitAfterTheCommitInterval() {
+        final long commitInterval = 1000L;
+        final Properties props = configProps(false);
+        props.setProperty(StreamsConfig.STATE_DIR_CONFIG, stateDir);
+        props.setProperty(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 
Long.toString(commitInterval));
 
-            // all tasks are committed after the commit interval
-            mockTime.sleep(11L);
-            thread.maybeCommit(mockTime.milliseconds());
-            for (final StreamTask task : thread.tasks().values()) {
-                assertTrue(((TestStreamTask) task).committed);
-                ((TestStreamTask) task).committed = false;
-            }
+        final StreamsConfig config = new StreamsConfig(props);
+        final Consumer<byte[], byte[]> consumer = 
EasyMock.createNiceMock(Consumer.class);
+        final TaskManager taskManager = mockTaskMangerCommit(consumer, 2);
 
-            // no task is committed before the commit interval, again
-            mockTime.sleep(commitInterval - 10L);
-            thread.maybeCommit(mockTime.milliseconds());
-            for (final StreamTask task : thread.tasks().values()) {
-                assertFalse(((TestStreamTask) task).committed);
-            }
+        StreamThread.StreamsMetricsThreadImpl streamsMetrics = new 
StreamThread.StreamsMetricsThreadImpl(metrics, "", "", Collections.<String, 
String>emptyMap());
+        final StreamThread thread = new StreamThread(internalTopologyBuilder,
+                                                     clientId,
+                                                     "",
+                                                     config,
+                                                     processId,
+                                                     mockTime,
+                                                     streamsMetadataState,
+                                                     taskManager,
+                                                     streamsMetrics,
+                                                     clientSupplier,
+                                                     consumer,
+                                                     stateDirectory);
+        thread.maybeCommit(mockTime.milliseconds());
+        mockTime.sleep(commitInterval + 1);
+        thread.maybeCommit(mockTime.milliseconds());
+
+        EasyMock.verify(taskManager);
+    }
 
-            // all tasks are committed after the commit interval, again
-            mockTime.sleep(11L);
-            thread.maybeCommit(mockTime.milliseconds());
-            for (final StreamTask task : thread.tasks().values()) {
-                assertTrue(((TestStreamTask) task).committed);
-                ((TestStreamTask) task).committed = false;
+    @SuppressWarnings({"ThrowableNotThrown", "unchecked"})
+    private TaskManager mockTaskMangerCommit(final Consumer<byte[], byte[]> 
consumer, final int numberOfCommits) {
+        final TaskManager taskManager = EasyMock.createMock(TaskManager.class);
+        taskManager.setConsumer(EasyMock.anyObject(Consumer.class));
+        EasyMock.expectLastCall();
+        IAnswer<Object> checkCommitAction = new IAnswer<Object>() {
+            @Override
+            public Object answer() throws Throwable {
+                final Object[] currentArguments = 
EasyMock.getCurrentArguments();
+                TaskManager.TaskAction action = (TaskManager.TaskAction) 
currentArguments[0];
+                if (!action.name().equals("commit")) {
+                    throw new IllegalArgumentException("expected to get commit 
action but was:" + action.name());
+                }
+                return null;
             }
-        } finally {
-            Utils.delete(baseDir);
-        }
+        };
+        
taskManager.performOnActiveTasks(EasyMock.anyObject(TaskManager.TaskAction.class));
+        
EasyMock.expectLastCall().andAnswer(checkCommitAction).times(numberOfCommits);
+        
taskManager.performOnStandbyTasks(EasyMock.anyObject(TaskManager.TaskAction.class));
+        
EasyMock.expectLastCall().andAnswer(checkCommitAction).times(numberOfCommits);
+        EasyMock.replay(taskManager, consumer);
+        return taskManager;
     }
 
     @Test
     public void 
shouldInjectSharedProducerForAllTasksUsingClientSupplierOnCreateIfEosDisabled() 
throws InterruptedException {
         internalTopologyBuilder.addSource(null, "source1", null, null, null, 
"someTopic");
 
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
+        final StreamThread thread = createStreamThread(clientId, config, 
false);
 
         final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
         assignment.put(new TaskId(0, 0), Collections.singleton(new 
TopicPartition("someTopic", 0)));
         assignment.put(new TaskId(0, 1), Collections.singleton(new 
TopicPartition("someTopic", 1)));
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(assignment));
+        thread.setThreadMetadataProvider(new 
MockStreamsPartitionAssignor(assignment));
 
         thread.setState(StreamThread.State.RUNNING);
         
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
@@ -800,9 +745,8 @@ public class StreamThreadTest {
 
         assertEquals(1, clientSupplier.producers.size());
         final Producer globalProducer = clientSupplier.producers.get(0);
-        assertSame(globalProducer, thread.threadProducer);
-        for (final StreamTask task : thread.tasks().values()) {
-            assertSame(globalProducer, ((RecordCollectorImpl) 
task.recordCollector()).producer());
+        for (final Task task : thread.tasks().values()) {
+            assertSame(globalProducer, ((RecordCollectorImpl) ((StreamTask) 
task).recordCollector()).producer());
         }
         assertSame(clientSupplier.consumer, thread.consumer);
         assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer);
@@ -812,25 +756,13 @@ public class StreamThreadTest {
     public void 
shouldInjectProducerPerTaskUsingClientSupplierOnCreateIfEosEnable() throws 
InterruptedException {
         internalTopologyBuilder.addSource(null, "source1", null, null, null, 
"someTopic");
 
-        final MockClientSupplier clientSupplier = new 
MockClientSupplier(applicationId);
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            new StreamsConfig(configProps(true)),
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
+        final StreamThread thread = createStreamThread(clientId, new 
StreamsConfig(configProps(true)), true);
 
         final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
         assignment.put(new TaskId(0, 0), Collections.singleton(new 
TopicPartition("someTopic", 0)));
         assignment.put(new TaskId(0, 1), Collections.singleton(new 
TopicPartition("someTopic", 1)));
         assignment.put(new TaskId(0, 2), Collections.singleton(new 
TopicPartition("someTopic", 2)));
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(assignment));
+        thread.setThreadMetadataProvider(new 
MockStreamsPartitionAssignor(assignment));
 
         final Set<TopicPartition> assignedPartitions = new HashSet<>();
         Collections.addAll(assignedPartitions, new TopicPartition("someTopic", 
0), new TopicPartition("someTopic", 2));
@@ -838,11 +770,10 @@ public class StreamThreadTest {
         
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
-        assertNull(thread.threadProducer);
         assertEquals(thread.tasks().size(), clientSupplier.producers.size());
         final Iterator it = clientSupplier.producers.iterator();
-        for (final StreamTask task : thread.tasks().values()) {
-            assertSame(it.next(), ((RecordCollectorImpl) 
task.recordCollector()).producer());
+        for (final Task task : thread.tasks().values()) {
+            assertSame(it.next(), ((RecordCollectorImpl) ((StreamTask) 
task).recordCollector()).producer());
         }
         assertSame(clientSupplier.consumer, thread.consumer);
         assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer);
@@ -852,24 +783,12 @@ public class StreamThreadTest {
     public void shouldCloseAllTaskProducersOnCloseIfEosEnabled() throws 
InterruptedException {
         internalTopologyBuilder.addSource(null, "source1", null, null, null, 
"someTopic");
 
-        final MockClientSupplier clientSupplier = new 
MockClientSupplier(applicationId);
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            new StreamsConfig(configProps(true)),
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
+        final StreamThread thread = createStreamThread(clientId, new 
StreamsConfig(configProps(true)), true);
 
         final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
         assignment.put(new TaskId(0, 0), Collections.singleton(new 
TopicPartition("someTopic", 0)));
         assignment.put(new TaskId(0, 1), Collections.singleton(new 
TopicPartition("someTopic", 1)));
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(assignment));
+        thread.setThreadMetadataProvider(new 
MockStreamsPartitionAssignor(assignment));
 
         thread.setState(StreamThread.State.RUNNING);
         
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
@@ -878,41 +797,39 @@ public class StreamThreadTest {
         thread.close();
         thread.run();
 
-        for (final StreamTask task : thread.tasks().values()) {
-            assertTrue(((MockProducer) ((RecordCollectorImpl) 
task.recordCollector()).producer()).closed());
+        for (final Task task : thread.tasks().values()) {
+            assertTrue(((MockProducer) ((RecordCollectorImpl) ((StreamTask) 
task).recordCollector()).producer()).closed());
         }
     }
 
     @Test
     public void shouldCloseThreadProducerOnCloseIfEosDisabled() throws 
InterruptedException {
-        internalTopologyBuilder.addSource(null, "source1", null, null, null, 
"someTopic");
+        final Consumer<byte[], byte[]> consumer = 
EasyMock.createNiceMock(Consumer.class);
+        final TaskManager taskManager = 
EasyMock.createNiceMock(TaskManager.class);
 
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
-
-        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
-        assignment.put(new TaskId(0, 0), Collections.singleton(new 
TopicPartition("someTopic", 0)));
-        assignment.put(new TaskId(0, 1), Collections.singleton(new 
TopicPartition("someTopic", 1)));
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(assignment));
+        taskManager.setConsumer(EasyMock.anyObject(Consumer.class));
+        EasyMock.expectLastCall();
+        taskManager.closeProducer();
+        EasyMock.expectLastCall();
+        EasyMock.replay(taskManager, consumer);
 
+        StreamThread.StreamsMetricsThreadImpl streamsMetrics = new 
StreamThread.StreamsMetricsThreadImpl(metrics, "", "", Collections.<String, 
String>emptyMap());
+        final StreamThread thread = new StreamThread(internalTopologyBuilder,
+                                                     clientId,
+                                                     "",
+                                                     config,
+                                                     processId,
+                                                     mockTime,
+                                                     streamsMetadataState,
+                                                     taskManager,
+                                                     streamsMetrics,
+                                                     clientSupplier,
+                                                     consumer,
+                                                     stateDirectory);
         thread.setState(StreamThread.State.RUNNING);
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        
thread.rebalanceListener.onPartitionsAssigned(Collections.singleton(new 
TopicPartition("someTopic", 0)));
-
         thread.close();
         thread.run();
-
-        assertTrue(((MockProducer) thread.threadProducer).closed());
+        EasyMock.verify(taskManager);
     }
 
     @Test
@@ -920,150 +837,18 @@ public class StreamThreadTest {
         internalTopologyBuilder.addSource(null, "name", null, null, null, 
"topic");
         internalTopologyBuilder.addSink("out", "output", null, null, null);
 
-        final StreamThread thread = new StreamThread(internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
-
-        thread.setPartitionAssignor(new StreamPartitionAssignor() {
-            @Override
-            Map<TaskId, Set<TopicPartition>> standbyTasks() {
-                return Collections.singletonMap(new TaskId(0, 0), 
Utils.mkSet(new TopicPartition("topic", 0)));
-            }
-        });
-
-        thread.setState(StreamThread.State.RUNNING);
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        
thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
-    }
-
-    @Test
-    public void shouldNotCloseSuspendedTaskswice() throws Exception {
-        internalTopologyBuilder.addSource(null, "name", null, null, null, 
"topic");
-        internalTopologyBuilder.addSink("out", "output", null, null, null);
-
-        final TestStreamTask testStreamTask = new TestStreamTask(
-                new TaskId(0, 0),
-                applicationId,
-                Utils.mkSet(new TopicPartition("topic", 0)),
-                internalTopologyBuilder.build(0),
-                clientSupplier.consumer,
-                clientSupplier.getProducer(new HashMap<String, Object>()),
-                clientSupplier.restoreConsumer,
-                config,
-                new MockStreamsMetrics(new Metrics()),
-                new StateDirectory(applicationId, 
config.getString(StreamsConfig.STATE_DIR_CONFIG), mockTime));
-
-        final StreamThread thread = new StreamThread(
-                internalTopologyBuilder,
-                config,
-                clientSupplier,
-                applicationId,
-                clientId,
-                processId,
-                metrics,
-                mockTime,
-                new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-                0,
-                stateDirectory) {
-
-            @Override
-            protected StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitionsForTask) {
-                return testStreamTask;
-            }
-        };
-
-        final Set<TopicPartition> activeTasks = new HashSet<>();
-        activeTasks.add(new TopicPartition("topic", 0));
-
-        thread.setPartitionAssignor(new StreamPartitionAssignor() {
-            @Override
-            Map<TaskId, Set<TopicPartition>> activeTasks() {
-                return new HashMap<TaskId, Set<TopicPartition>>() {
-                    {
-                        put(new TaskId(0, 0), activeTasks);
-                    }
-                };
-            }
-        });
-        thread.setState(StreamThread.State.RUNNING);
-        thread.setState(StreamThread.State.PARTITIONS_REVOKED);
-        thread.rebalanceListener.onPartitionsAssigned(activeTasks);
-        thread.rebalanceListener.onPartitionsRevoked(activeTasks);
-
-        assertTrue(testStreamTask.suspended);
-        assertFalse(testStreamTask.closed);
-
-        activeTasks.clear();
-        // this should succeed without exception
-        
thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
-
-        assertTrue(testStreamTask.closed);
-    }
-
-    @Test
-    public void shouldInitializeRestoreConsumerWithOffsetsFromStandbyTasks()  {
-        internalStreamsBuilder.stream(null, null, null, null, 
"t1").groupByKey().count("count-one");
-        internalStreamsBuilder.stream(null, null, null, null, 
"t2").groupByKey().count("count-two");
-
-        final StreamThread thread = new StreamThread(
-                internalTopologyBuilder,
-                config,
-                clientSupplier,
-                applicationId,
-                clientId,
-                processId,
-                metrics,
-                mockTime,
-                new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-                0,
-                stateDirectory);
-
-        final MockConsumer<byte[], byte[]> restoreConsumer = 
clientSupplier.restoreConsumer;
-        
restoreConsumer.updatePartitions("stream-thread-test-count-one-changelog",
-                Collections.singletonList(new 
PartitionInfo("stream-thread-test-count-one-changelog",
-                        0,
-                        null,
-                        new Node[0],
-                        new Node[0])));
-        
restoreConsumer.updatePartitions("stream-thread-test-count-two-changelog",
-                Collections.singletonList(new 
PartitionInfo("stream-thread-test-count-two-changelog",
-                        0,
-                        null,
-                        new Node[0],
-                        new Node[0])));
+        final StreamThread thread = createStreamThread(clientId, config, 
false);
 
-        final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
-        final TopicPartition t1 = new TopicPartition("t1", 0);
-        standbyTasks.put(new TaskId(0, 0), Utils.mkSet(t1));
-
-        thread.setPartitionAssignor(new StreamPartitionAssignor() {
+        thread.setThreadMetadataProvider(new StreamPartitionAssignor() {
             @Override
-            Map<TaskId, Set<TopicPartition>> standbyTasks() {
-                return standbyTasks;
+            public Map<TaskId, Set<TopicPartition>> standbyTasks() {
+                return Collections.singletonMap(new TaskId(0, 0), 
Utils.mkSet(new TopicPartition("topic", 0)));
             }
         });
 
         thread.setState(StreamThread.State.RUNNING);
         
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         
thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
-
-        assertThat(restoreConsumer.assignment(), equalTo(Utils.mkSet(new 
TopicPartition("stream-thread-test-count-one-changelog", 0))));
-
-        // assign an existing standby plus a new one
-        standbyTasks.put(new TaskId(1, 0), Utils.mkSet(new 
TopicPartition("t2", 0)));
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        
thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
-
-        assertThat(restoreConsumer.assignment(), equalTo(Utils.mkSet(new 
TopicPartition("stream-thread-test-count-one-changelog", 0),
-                new TopicPartition("stream-thread-test-count-two-changelog", 
0))));
     }
 
     @Test
@@ -1071,18 +856,7 @@ public class StreamThreadTest {
         internalStreamsBuilder.stream(null, null, null, null, 
"t1").groupByKey().count("count-one");
         internalStreamsBuilder.stream(null, null, null, null, 
"t2").groupByKey().count("count-two");
 
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
+        final StreamThread thread = createStreamThread(clientId, config, 
false);
         final MockConsumer<byte[], byte[]> restoreConsumer = 
clientSupplier.restoreConsumer;
         
restoreConsumer.updatePartitions("stream-thread-test-count-one-changelog",
                                          Collections.singletonList(new 
PartitionInfo("stream-thread-test-count-one-changelog",
@@ -1112,14 +886,14 @@ public class StreamThreadTest {
         final TopicPartition t2 = new TopicPartition("t2", 0);
         activeTasks.put(new TaskId(1, 0), Utils.mkSet(t2));
 
-        thread.setPartitionAssignor(new StreamPartitionAssignor() {
+        thread.setThreadMetadataProvider(new StreamPartitionAssignor() {
             @Override
-            Map<TaskId, Set<TopicPartition>> standbyTasks() {
+            public Map<TaskId, Set<TopicPartition>> standbyTasks() {
                 return standbyTasks;
             }
 
             @Override
-            Map<TaskId, Set<TopicPartition>> activeTasks() {
+            public Map<TaskId, Set<TopicPartition>> activeTasks() {
                 return activeTasks;
             }
         });
@@ -1139,105 +913,11 @@ public class StreamThreadTest {
     }
 
     @Test
-    public void 
shouldCloseActiveTasksThatAreAssignedToThisStreamThreadButAssignmentHasChangedBeforeCreatingNewTasks()
 throws Exception {
-        internalStreamsBuilder.stream(null, null, null, null, 
Pattern.compile("t.*")).to("out");
-
-        final Map<Collection<TopicPartition>, TestStreamTask> createdTasks = 
new HashMap<>();
-
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory) {
-
-            @Override
-            protected StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitions) {
-                final ProcessorTopology topology = 
builder.build(id.topicGroupId);
-                final TestStreamTask task = new TestStreamTask(
-                    id,
-                    applicationId,
-                    partitions,
-                    topology,
-                    consumer,
-                    clientSupplier.getProducer(new HashMap<String, Object>()),
-                    restoreConsumer,
-                    config,
-                    new MockStreamsMetrics(new Metrics()),
-                    stateDirectory);
-                createdTasks.put(partitions, task);
-                return task;
-            }
-        };
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        final TopicPartition t1 = new TopicPartition("t1", 0);
-        final Set<TopicPartition> task00Partitions = new HashSet<>();
-        task00Partitions.add(t1);
-        final TaskId taskId = new TaskId(0, 0);
-        activeTasks.put(taskId, task00Partitions);
-
-        thread.setPartitionAssignor(new StreamPartitionAssignor() {
-            @Override
-            Map<TaskId, Set<TopicPartition>> activeTasks() {
-                return activeTasks;
-            }
-        });
-
-        final StreamPartitionAssignor.SubscriptionUpdates subscriptionUpdates 
= new StreamPartitionAssignor.SubscriptionUpdates();
-        final Field updatedTopicsField  = 
subscriptionUpdates.getClass().getDeclaredField("updatedTopicSubscriptions");
-        updatedTopicsField.setAccessible(true);
-        final Set<String> updatedTopics = (Set<String>) 
updatedTopicsField.get(subscriptionUpdates);
-        updatedTopics.add(t1.topic());
-        internalTopologyBuilder.updateSubscriptions(subscriptionUpdates, null);
-
-        // should create task for id 0_0 with a single partition
-        thread.setState(StreamThread.State.RUNNING);
-
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        thread.rebalanceListener.onPartitionsAssigned(task00Partitions);
-
-        final TestStreamTask firstTask = createdTasks.get(task00Partitions);
-        assertThat(firstTask.id(), is(taskId));
-
-        // update assignment for the task 0_0 so it now has 2 partitions
-        task00Partitions.add(new TopicPartition("t2", 0));
-        updatedTopics.add("t2");
-
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        thread.rebalanceListener.onPartitionsAssigned(task00Partitions);
-
-        // should close the first task as the assignment has changed
-        assertTrue("task should have been closed as assignment has changed", 
firstTask.closed);
-        assertTrue("tasks state manager should have been closed as assignment 
has changed", firstTask.closedStateManager);
-        // should have created a new task for 00
-        assertThat(createdTasks.get(task00Partitions).id(), is(taskId));
-    }
-
-    @Test
     public void 
shouldCloseTaskAsZombieAndRemoveFromActiveTasksIfProducerWasFencedWhileProcessing()
 throws Exception {
         internalTopologyBuilder.addSource(null, "source", null, null, null, 
TOPIC);
         internalTopologyBuilder.addSink("sink", "dummyTopic", null, null, 
null, "source");
 
-        final MockClientSupplier clientSupplier = new 
MockClientSupplier(applicationId);
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            new StreamsConfig(configProps(true)),
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
+        final StreamThread thread = createStreamThread(clientId, new 
StreamsConfig(configProps(true)), true);
 
         final MockConsumer consumer = clientSupplier.consumer;
         consumer.updatePartitions(TOPIC, Collections.singletonList(new 
PartitionInfo(TOPIC, 0, null, null, null)));
@@ -1245,7 +925,7 @@ public class StreamThreadTest {
         final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
         activeTasks.put(task1, task0Assignment);
 
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(activeTasks));
+        thread.setThreadMetadataProvider(new 
MockStreamsPartitionAssignor(activeTasks));
 
         thread.start();
         TestUtils.waitForCondition(new TestCondition() {
@@ -1323,24 +1003,12 @@ public class StreamThreadTest {
         internalTopologyBuilder.addSource(null, "name", null, null, null, 
"topic");
         internalTopologyBuilder.addSink("out", "output", null, null, null);
 
-        final MockClientSupplier clientSupplier = new 
MockClientSupplier(applicationId);
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            new StreamsConfig(configProps(true)),
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            new Metrics(),
-            new MockTime(),
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
+        final StreamThread thread = createStreamThread(clientId, new 
StreamsConfig(configProps(true)), true);
 
         final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
         activeTasks.put(task1, task0Assignment);
 
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(activeTasks));
+        thread.setThreadMetadataProvider(new 
MockStreamsPartitionAssignor(activeTasks));
 
         thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(null);
@@ -1358,320 +1026,12 @@ public class StreamThreadTest {
     }
 
     @Test
-    public void 
shouldNotViolateAtLeastOnceWhenAnExceptionOccursOnTaskCloseDuringShutdown() 
throws Exception {
-        internalStreamsBuilder.stream(null, null, null, null, 
"t1").groupByKey();
-
-        final TestStreamTask testStreamTask = new TestStreamTask(
-            new TaskId(0, 0),
-            applicationId,
-            Utils.mkSet(new TopicPartition("t1", 0)),
-            internalTopologyBuilder.build(0),
-            clientSupplier.consumer,
-            clientSupplier.getProducer(new HashMap<String, Object>()),
-            clientSupplier.restoreConsumer,
-            config,
-            new MockStreamsMetrics(new Metrics()),
-            new StateDirectory(applicationId, 
config.getString(StreamsConfig.STATE_DIR_CONFIG), mockTime)) {
-
-            @Override
-            public void close(final boolean clean) {
-                throw new RuntimeException("KABOOM!");
-            }
-        };
-
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory) {
-
-            @Override
-            protected StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitions) {
-                return testStreamTask;
-            }
-        };
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        activeTasks.put(testStreamTask.id(), testStreamTask.partitions);
-
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(activeTasks));
-
-        thread.setState(StreamThread.State.RUNNING);
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        
thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
-
-        thread.close();
-        thread.join();
-        assertFalse("task shouldn't have been committed as there was an 
exception during shutdown", testStreamTask.committed);
-    }
-
-    @Test
-    public void 
shouldNotViolateAtLeastOnceWhenAnExceptionOccursOnTaskFlushDuringShutdown() 
throws Exception {
-        final MockStateStoreSupplier.MockStateStore stateStore = new 
MockStateStoreSupplier.MockStateStore("foo", false);
-        internalStreamsBuilder.stream(null, null, null, null, 
"t1").groupByKey().count(new MockStateStoreSupplier(stateStore));
-        final TestStreamTask testStreamTask = new TestStreamTask(
-            new TaskId(0, 0),
-            applicationId,
-            Utils.mkSet(new TopicPartition("t1", 0)),
-            internalTopologyBuilder.build(0),
-            clientSupplier.consumer,
-            clientSupplier.getProducer(new HashMap<String, Object>()),
-            clientSupplier.restoreConsumer,
-            config,
-            new MockStreamsMetrics(metrics),
-            new StateDirectory(applicationId, 
config.getString(StreamsConfig.STATE_DIR_CONFIG), mockTime)) {
-
-            @Override
-            public void flushState() {
-                throw new RuntimeException("KABOOM!");
-            }
-        };
-
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory) {
-
-            @Override
-            protected StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitions) {
-                return testStreamTask;
-            }
-        };
-
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        activeTasks.put(testStreamTask.id(), testStreamTask.partitions);
-
-
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(activeTasks));
-
-        thread.start();
-        TestUtils.waitForCondition(new TestCondition() {
-            @Override
-            public boolean conditionMet() {
-                return thread.state() == StreamThread.State.RUNNING;
-            }
-        }, 10 * 1000, "Thread never started.");
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        
thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
-        // store should have been opened
-        assertTrue(stateStore.isOpen());
-
-        thread.close();
-        thread.join();
-        assertFalse("task shouldn't have been committed as there was an 
exception during shutdown", testStreamTask.committed);
-        // store should be closed even if we had an exception
-        assertFalse(stateStore.isOpen());
-    }
-
-    @Test
-    public void shouldCaptureCommitFailedExceptionOnTaskSuspension() throws 
Exception {
-        internalStreamsBuilder.stream(null, null, null, null, "t1");
-
-        final TestStreamTask testStreamTask = new TestStreamTask(
-                new TaskId(0, 0),
-                applicationId,
-                Utils.mkSet(new TopicPartition("t1", 0)),
-                internalTopologyBuilder.build(0),
-                clientSupplier.consumer,
-                clientSupplier.getProducer(new HashMap<String, Object>()),
-                clientSupplier.restoreConsumer,
-                config,
-                new MockStreamsMetrics(new Metrics()),
-                new StateDirectory(applicationId, 
config.getString(StreamsConfig.STATE_DIR_CONFIG), mockTime)) {
-
-            @Override
-            public void suspend() {
-                throw new CommitFailedException();
-            }
-        };
-
-        final StreamThread thread = new StreamThread(
-                internalTopologyBuilder,
-                config,
-                clientSupplier,
-                applicationId,
-                clientId,
-                processId,
-                metrics,
-                mockTime,
-                new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-                0,
-                stateDirectory) {
-
-            @Override
-            protected StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitions) {
-                return testStreamTask;
-            }
-        };
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        activeTasks.put(testStreamTask.id(), testStreamTask.partitions);
-
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(activeTasks));
-        thread.setState(StreamThread.State.RUNNING);
-
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        
thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
-
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-
-        assertFalse(testStreamTask.committed);
-    }
-
-
-    @Test
-    public void 
shouldNotViolateAtLeastOnceWhenExceptionOccursDuringTaskSuspension() throws 
Exception {
-        internalStreamsBuilder.stream(null, null, null, null, 
"t1").groupByKey();
-
-        final TestStreamTask testStreamTask = new TestStreamTask(
-            new TaskId(0, 0),
-            applicationId,
-            Utils.mkSet(new TopicPartition("t1", 0)),
-            internalTopologyBuilder.build(0),
-            clientSupplier.consumer,
-            clientSupplier.getProducer(new HashMap<String, Object>()),
-            clientSupplier.restoreConsumer,
-            config,
-            new MockStreamsMetrics(new Metrics()),
-            new StateDirectory(applicationId, 
config.getString(StreamsConfig.STATE_DIR_CONFIG), mockTime)) {
-
-            @Override
-            public void suspend() {
-                throw new RuntimeException("KABOOM!");
-            }
-        };
-
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory) {
-
-            @Override
-            protected StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitions) {
-                return testStreamTask;
-            }
-        };
-
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        activeTasks.put(testStreamTask.id(), testStreamTask.partitions);
-
-
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(activeTasks));
-
-        thread.setState(StreamThread.State.RUNNING);
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        
thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
-        try {
-            
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-            fail("should have thrown exception");
-        } catch (final Exception e) {
-            // expected
-        }
-        assertFalse(testStreamTask.committed);
-    }
-
-    @Test
-    public void 
shouldNotViolateAtLeastOnceWhenExceptionOccursDuringFlushStateWhileSuspendingState()
 throws Exception {
-        internalStreamsBuilder.stream(null, null, null, null, 
"t1").groupByKey();
-
-        final TestStreamTask testStreamTask = new TestStreamTask(
-            new TaskId(0, 0),
-            applicationId,
-            Utils.mkSet(new TopicPartition("t1", 0)),
-            internalTopologyBuilder.build(0),
-            clientSupplier.consumer,
-            clientSupplier.getProducer(new HashMap<String, Object>()),
-            clientSupplier.restoreConsumer,
-            config,
-            new MockStreamsMetrics(new Metrics()),
-            new StateDirectory(applicationId, 
config.getString(StreamsConfig.STATE_DIR_CONFIG), mockTime)) {
-
-            @Override
-            protected void flushState() {
-                throw new RuntimeException("KABOOM!");
-            }
-        };
-
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST), 0, stateDirectory) {
-
-            @Override
-            protected StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitions) {
-                return testStreamTask;
-            }
-        };
-
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        activeTasks.put(testStreamTask.id(), testStreamTask.partitions);
-
-
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(activeTasks));
-
-        thread.setState(StreamThread.State.RUNNING);
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        
thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
-        try {
-            
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-            fail("should have thrown exception");
-        } catch (final Exception e) {
-            // expected
-        }
-        assertFalse(testStreamTask.committed);
-    }
-
-
-    @Test
     @SuppressWarnings("unchecked")
     public void 
shouldAlwaysUpdateWithLatestTopicsFromStreamPartitionAssignor() throws 
Exception {
         internalTopologyBuilder.addSource(null, "source", null, null, null, 
Pattern.compile("t.*"));
         internalTopologyBuilder.addProcessor("processor", new 
MockProcessorSupplier(), "source");
 
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory);
+        final StreamThread thread = createStreamThread(clientId, config, 
false);
 
         final StreamPartitionAssignor partitionAssignor = new 
StreamPartitionAssignor();
         final Map<String, Object> configurationMap = new HashMap<>();
@@ -1680,7 +1040,7 @@ public class StreamThreadTest {
         configurationMap.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 0);
         partitionAssignor.configure(configurationMap);
 
-        thread.setPartitionAssignor(partitionAssignor);
+        thread.setThreadMetadataProvider(partitionAssignor);
 
         final Field nodeToSourceTopicsField =
             
internalTopologyBuilder.getClass().getDeclaredField("nodeToSourceTopics");
@@ -1739,240 +1099,6 @@ public class StreamThreadTest {
 
     }
 
-    @Test
-    public void shouldReleaseStateDirLockIfFailureOnTaskSuspend() throws 
Exception {
-        final TaskId taskId = new TaskId(0, 0);
-
-        final StateDirectory stateDirMock = mockStateDirInteractions(taskId);
-        final StreamThread thread = setupTest(taskId, stateDirMock);
-
-        try {
-            
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-            fail("Should have thrown exception");
-        } catch (final Exception e) {
-            //
-        } finally {
-            thread.close();
-        }
-
-        EasyMock.verify(stateDirMock);
-    }
-
-    @Test
-    public void 
shouldReleaseStateDirLockIfFailureOnTaskCloseForSuspendedTask() throws 
Exception {
-        final TaskId taskId = new TaskId(0, 0);
-
-        final StateDirectory stateDirMock = mockStateDirInteractions(taskId);
-
-        final StreamThread thread = setupTest(taskId, stateDirMock);
-        thread.close();
-        thread.join();
-        EasyMock.verify(stateDirMock);
-    }
-
-
-    private StreamThread setupTest(final TaskId taskId, final StateDirectory 
stateDirectory) throws InterruptedException {
-        final TopologyBuilder builder = new TopologyBuilder();
-        builder.setApplicationId(applicationId);
-        builder.addSource("source", "topic");
-
-        final MockClientSupplier clientSupplier = new MockClientSupplier();
-
-        final TestStreamTask testStreamTask = new TestStreamTask(taskId,
-            applicationId,
-            Utils.mkSet(new TopicPartition("topic", 0)),
-            builder.build(0),
-            clientSupplier.consumer,
-            clientSupplier.getProducer(new HashMap<String, Object>()),
-            clientSupplier.restoreConsumer,
-            config,
-            new MockStreamsMetrics(new Metrics()),
-            stateDirectory) {
-
-            @Override
-            public void suspend() {
-                throw new RuntimeException("KABOOM!!!");
-            }
-        };
-
-        final StreamThread thread = new StreamThread(
-            builder.internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            new Metrics(),
-            new MockTime(),
-            new StreamsMetadataState(builder.internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0, stateDirectory) {
-
-            @Override
-            protected StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitions) {
-                return testStreamTask;
-            }
-        };
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        activeTasks.put(testStreamTask.id, testStreamTask.partitions);
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(activeTasks));
-        thread.start();
-
-        TestUtils.waitForCondition(new TestCondition() {
-            @Override
-            public boolean conditionMet() {
-                return thread.state() == StreamThread.State.RUNNING;
-            }
-        }, "thread didn't transition to running");
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptySet());
-        
thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
-
-        return thread;
-    }
-
-    @Test
-    public void shouldReleaseStateDirLockIfFailureOnStandbyTaskSuspend() 
throws Exception {
-        final TaskId taskId = new TaskId(0, 0);
-
-        final StateDirectory stateDirMock = mockStateDirInteractions(taskId);
-        final StreamThread thread = setupStandbyTest(taskId, stateDirMock);
-
-        startThreadAndRebalance(thread);
-
-        try {
-            
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-            fail("Should have thrown exception");
-        } catch (final Exception e) {
-            // ok
-        } finally {
-            thread.close();
-        }
-        EasyMock.verify(stateDirMock);
-    }
-
-    private void startThreadAndRebalance(final StreamThread thread) throws 
InterruptedException {
-        thread.start();
-        TestUtils.waitForCondition(new TestCondition() {
-            @Override
-            public boolean conditionMet() {
-                return thread.state() == StreamThread.State.RUNNING;
-            }
-        }, "thread didn't transition to running");
-        
thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptySet());
-        
thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptySet());
-    }
-
-    @Test
-    public void 
shouldReleaseStateDirLockIfFailureOnStandbyTaskCloseForUnassignedSuspendedStandbyTask()
 throws Exception {
-        final TaskId taskId = new TaskId(0, 0);
-
-        final StateDirectory stateDirMock = mockStateDirInteractions(taskId);
-        final StreamThread thread = setupStandbyTest(taskId, stateDirMock);
-        startThreadAndRebalance(thread);
-
-
-        try {
-            thread.close();
-            thread.join();
-        } finally {
-            thread.close();
-        }
-        EasyMock.verify(stateDirMock);
-    }
-
-    private StateDirectory mockStateDirInteractions(final TaskId taskId) 
throws IOException {
-        final StateDirectory stateDirMock = 
EasyMock.createNiceMock(StateDirectory.class);
-        EasyMock.expect(stateDirMock.lock(EasyMock.eq(taskId), 
EasyMock.anyInt())).andReturn(true);
-        EasyMock.expect(stateDirMock.directoryForTask(taskId)).andReturn(new 
File(stateDir));
-        stateDirMock.unlock(taskId);
-        EasyMock.expectLastCall();
-        EasyMock.replay(stateDirMock);
-        return stateDirMock;
-    }
-
-    private StreamThread setupStandbyTest(final TaskId taskId, final 
StateDirectory stateDirectory) {
-        final String storeName = "store";
-        final String changelogTopic = applicationId + "-" + storeName + 
"-changelog";
-
-        internalStreamsBuilder.stream(null, null, null, null, 
"topic1").groupByKey().count(storeName);
-
-        final MockClientSupplier clientSupplier = new MockClientSupplier();
-        clientSupplier.restoreConsumer.updatePartitions(changelogTopic,
-            Collections.singletonList(new PartitionInfo(changelogTopic, 0, 
null, null, null)));
-        clientSupplier.restoreConsumer.updateBeginningOffsets(new 
HashMap<TopicPartition, Long>() {
-            {
-                put(new TopicPartition(changelogTopic, 0), 0L);
-            }
-        });
-        clientSupplier.restoreConsumer.updateEndOffsets(new 
HashMap<TopicPartition, Long>() {
-            {
-                put(new TopicPartition(changelogTopic, 0), 0L);
-            }
-        });
-
-        final StreamThread thread = new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            new Metrics(),
-            new MockTime(),
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory) {
-
-            @Override
-            protected StandbyTask createStandbyTask(final TaskId id, final 
Collection<TopicPartition> partitions) {
-                return new StandbyTask(
-                    taskId,
-                    applicationId,
-                    partitions,
-                    builder.build(0),
-                    clientSupplier.consumer,
-                    new StoreChangelogReader(getName(), 
clientSupplier.restoreConsumer, mockTime, 1000,
-                                             new MockStateRestoreListener()),
-                    StreamThreadTest.this.config,
-                    new StreamsMetricsImpl(new Metrics(), "groupName", 
Collections.<String, String>emptyMap()),
-                    stateDirectory) {
-
-                    @Override
-                    public void suspend() {
-                        throw new RuntimeException("KABOOM!!!");
-                    }
-
-                    @Override
-                    public void commit() {
-                        throw new RuntimeException("KABOOM!!!");
-                    }
-                };
-            }
-        };
-
-        final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
-        standbyTasks.put(taskId, Collections.singleton(new 
TopicPartition("topic", 0)));
-        thread.setPartitionAssignor(new 
MockStreamsPartitionAssignor(Collections.<TaskId, 
Set<TopicPartition>>emptyMap(), standbyTasks));
-
-        return thread;
-    }
-
-    private void initPartitionGrouper(final StreamsConfig config,
-                                      final StreamThread thread,
-                                      final MockClientSupplier clientSupplier) 
{
-
-        final StreamPartitionAssignor partitionAssignor = new 
StreamPartitionAssignor();
-
-        partitionAssignor.configure(config.getConsumerConfigs(thread, 
thread.applicationId, thread.clientId));
-        final MockInternalTopicManager internalTopicManager = new 
MockInternalTopicManager(thread.config, clientSupplier.restoreConsumer);
-        partitionAssignor.setInternalTopicManager(internalTopicManager);
-
-        final Map<String, PartitionAssignor.Assignment> assignments =
-            partitionAssignor.assign(metadata, 
Collections.singletonMap("client", subscription));
-
-        partitionAssignor.onAssignment(assignments.get("client"));
-    }
-
     private static class StateListenerStub implements 
StreamThread.StateListener {
         int numChanges = 0;
         ThreadStateTransitionValidator oldState = null;
@@ -1993,34 +1119,6 @@ public class StreamThreadTest {
     }
 
     private StreamThread getStreamThread() {
-        return new StreamThread(
-            internalTopologyBuilder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            Time.SYSTEM,
-            new StreamsMetadataState(internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-            0,
-            stateDirectory) {
-
-            @Override
-            protected StreamTask createStreamTask(final TaskId id, final 
Collection<TopicPartition> partitionsForTask) {
-                final ProcessorTopology topology = 
builder.build(id.topicGroupId);
-                return new TestStreamTask(
-                    id,
-                    applicationId,
-                    partitionsForTask,
-                    topology,
-                    consumer,
-                    clientSupplier.getProducer(new HashMap()),
-                    restoreConsumer,
-                    config,
-                    new MockStreamsMetrics(new Metrics()),
-                    stateDirectory);
-            }
-        };
+        return createStreamThread(clientId, config, false);
     }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
new file mode 100644
index 0000000..0c31a81
--- /dev/null
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.CommitFailedException;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.processor.TaskId;
+import org.easymock.EasyMock;
+import org.easymock.EasyMockRunner;
+import org.easymock.Mock;
+import org.easymock.MockType;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+
+import static org.easymock.EasyMock.checkOrder;
+import static org.easymock.EasyMock.expect;
+import static org.easymock.EasyMock.expectLastCall;
+import static org.easymock.EasyMock.replay;
+import static org.easymock.EasyMock.verify;
+import static org.junit.Assert.fail;
+
+@RunWith(EasyMockRunner.class)
+public class TaskManagerTest {
+
+    private final Time time = new MockTime();
+    private final TaskId taskId0 = new TaskId(0, 0);
+    private final TopicPartition t1p0 = new TopicPartition("t1", 0);
+    private final TopicPartition t1p1 = new TopicPartition("t1", 1);
+    private final Set<TopicPartition> taskId0Partitions = Utils.mkSet(t1p0);
+    private final Map<TaskId, Set<TopicPartition>> taskId0Assignment = 
Collections.singletonMap(taskId0, taskId0Partitions);
+
+    @Mock(type = MockType.NICE)
+    private ChangelogReader changeLogReader;
+    @Mock(type = MockType.NICE)
+    private Consumer<byte[], byte[]> restoreConsumer;
+    @Mock(type = MockType.NICE)
+    private Consumer<byte[], byte[]> consumer;
+    @Mock(type = MockType.NICE)
+    private StreamThread.AbstractTaskCreator activeTaskCreator;
+    @Mock(type = MockType.NICE)
+    private StreamThread.AbstractTaskCreator standbyTaskCreator;
+    @Mock(type = MockType.NICE)
+    private ThreadMetadataProvider threadMetadataProvider;
+    @Mock(type = MockType.NICE)
+    private Task firstTask;
+
+    private TaskManager taskManager;
+
+
+    @Before
+    public void setUp() throws Exception {
+        taskManager = new TaskManager(changeLogReader, time, "", 
restoreConsumer, activeTaskCreator, standbyTaskCreator);
+        taskManager.setThreadMetadataProvider(threadMetadataProvider);
+        taskManager.setConsumer(consumer);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void 
shouldCloseActiveUnAssignedSuspendedTasksBeforeCreatingNewTasksWhenTaskPartitionAssignmentHasChanged()
 {
+        final Set<TopicPartition> secondPartitionAssignment = 
Utils.mkSet(t1p0, t1p1);
+        final Map<TaskId, Set<TopicPartition>> secondAssignment = 
Collections.singletonMap(taskId0, secondPartitionAssignment);
+
+        mockSingleActiveTask();
+        
expect(activeTaskCreator.retryWithBackoff(EasyMock.anyObject(Consumer.class),
+                                                           
EasyMock.eq(secondAssignment),
+                                                           
EasyMock.eq(time.milliseconds())))
+                .andReturn(Collections.singletonMap(firstTask, 
secondPartitionAssignment));
+
+        expect(threadMetadataProvider.activeTasks())
+                .andReturn(secondAssignment);
+
+        firstTask.closeSuspended(true, null);
+        expectLastCall();
+
+        replay(threadMetadataProvider, activeTaskCreator, standbyTaskCreator, 
firstTask);
+
+        taskManager.createTasks(taskId0Partitions);
+        taskManager.suspendTasksAndState();
+
+        taskManager.createTasks(secondPartitionAssignment);
+
+        verify(activeTaskCreator, firstTask);
+    }
+
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldCloseStandbyTaskIfFailureOnSuspend() {
+        checkOrder(firstTask, true);
+        mockStandbyTaskExpectations(Collections.<TopicPartition, 
Long>emptyMap());
+        
verifyTaskIsClosedOnSuspendFailure(Collections.<TopicPartition>emptySet());
+    }
+
+    @Test
+    public void shouldCloseActiveTaskIfFailureOnSuspend() {
+        checkOrder(firstTask, true);
+        mockSingleActiveTask();
+        verifyTaskIsClosedOnSuspendFailure(taskId0Partitions);
+    }
+
+    @Test
+    public void shouldInitializeRestoreConsumerWithOffsetsFromStandbyTasks() {
+        mockStandbyTaskExpectations(Collections.singletonMap(t1p0, 0L));
+        restoreConsumer.assign(EasyMock.eq(taskId0Partitions));
+        expectLastCall();
+        replay(threadMetadataProvider, firstTask, activeTaskCreator, 
standbyTaskCreator, restoreConsumer);
+
+        taskManager.createTasks(Collections.<TopicPartition>emptySet());
+
+        EasyMock.verify(restoreConsumer);
+    }
+
+    @Test
+    public void shouldNotCloseSuspendedTasksTwice() {
+        mockSingleActiveTask();
+        expect(threadMetadataProvider.activeTasks())
+                .andReturn(Collections.<TaskId, 
Set<TopicPartition>>emptyMap());
+        firstTask.suspend();
+        expectLastCall();
+        firstTask.closeSuspended(true, null);
+        expectLastCall();
+
+        replay(threadMetadataProvider, activeTaskCreator, standbyTaskCreator, 
firstTask);
+
+        taskManager.createTasks(taskId0Partitions);
+        taskManager.suspendTasksAndState();
+
+        taskManager.createTasks(Collections.<TopicPartition>emptySet());
+
+        verify(firstTask);
+    }
+
+    @Test
+    public void 
shouldNotCloseActiveTaskOnCommitFailedExceptionDuringTaskSuspend() {
+        checkOrder(firstTask, true);
+        mockSingleActiveTask();
+        firstTask.suspend();
+        expectLastCall().andThrow(new CommitFailedException());
+
+        replay(threadMetadataProvider, firstTask, activeTaskCreator, 
standbyTaskCreator);
+
+        taskManager.createTasks(taskId0Partitions);
+
+        taskManager.suspendTasksAndState();
+        verify(firstTask);
+    }
+
+
+    @SuppressWarnings("unchecked")
+    private void mockStandbyTaskExpectations(final Map<TopicPartition, Long> 
checkpoint) {
+        expect(threadMetadataProvider.standbyTasks())
+                .andReturn(taskId0Assignment)
+                .anyTimes();
+        expect(threadMetadataProvider.activeTasks())
+                .andStubReturn(Collections.<TaskId, 
Set<TopicPartition>>emptyMap());
+
+        
expect(standbyTaskCreator.retryWithBackoff(EasyMock.anyObject(Consumer.class),
+                                                   
EasyMock.eq(taskId0Assignment),
+                                                   
EasyMock.eq(time.milliseconds())))
+                .andReturn(Collections.singletonMap(firstTask, 
taskId0Partitions));
+
+        stubTaskCreator(activeTaskCreator);
+
+        expect(firstTask.checkpointedOffsets())
+                .andReturn(checkpoint)
+                .anyTimes();
+    }
+
+    @SuppressWarnings("unchecked")
+    private void mockSingleActiveTask() {
+        expect(threadMetadataProvider.standbyTasks())
+                .andReturn(Collections.<TaskId, Set<TopicPartition>>emptyMap())
+                .anyTimes();
+        expect(threadMetadataProvider.activeTasks())
+                .andReturn(taskId0Assignment);
+
+        
expect(activeTaskCreator.retryWithBackoff(EasyMock.anyObject(Consumer.class),
+                                                  
EasyMock.eq(taskId0Assignment),
+                                                  
EasyMock.eq(time.milliseconds())))
+                .andReturn(Collections.singletonMap(firstTask, 
taskId0Partitions));
+
+        stubTaskCreator(standbyTaskCreator);
+
+        expect(firstTask.id()).andStubReturn(taskId0);
+        expect(firstTask.partitions()).andStubReturn(taskId0Partitions);
+    }
+
+    private void verifyTaskIsClosedOnSuspendFailure(final Set<TopicPartition> 
assignment) {
+        firstTask.suspend();
+        expectLastCall().andThrow(new RuntimeException("KABOOM!"));
+        firstTask.close(false);
+        expectLastCall();
+        replay(threadMetadataProvider, firstTask, activeTaskCreator, 
standbyTaskCreator);
+
+        taskManager.createTasks(assignment);
+
+        try {
+            taskManager.suspendTasksAndState();
+            fail("should have thrown StreamsException");
+        } catch (final StreamsException e) {
+            // pass
+        }
+        verify(firstTask);
+    }
+
+    @SuppressWarnings("unchecked")
+    private void stubTaskCreator(final StreamThread.AbstractTaskCreator 
taskCreator) {
+        expect(taskCreator.retryWithBackoff(EasyMock.anyObject(Consumer.class),
+                                                   
EasyMock.anyObject(Map.class),
+                                                   EasyMock.anyLong()))
+                .andReturn(Collections.<Task, Set<TopicPartition>>emptyMap())
+                .anyTimes();
+    }
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
index 62cb09c..ef24cf4 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
@@ -33,7 +33,7 @@ import 
org.apache.kafka.streams.processor.internals.StateDirectory;
 import org.apache.kafka.streams.processor.internals.StoreChangelogReader;
 import org.apache.kafka.streams.processor.internals.StreamTask;
 import org.apache.kafka.streams.processor.internals.StreamThread;
-import org.apache.kafka.streams.processor.internals.StreamsMetadataState;
+import org.apache.kafka.streams.processor.internals.Task;
 import org.apache.kafka.streams.state.QueryableStoreTypes;
 import org.apache.kafka.streams.state.ReadOnlyKeyValueStore;
 import org.apache.kafka.streams.state.ReadOnlyWindowStore;
@@ -42,6 +42,7 @@ import org.apache.kafka.test.MockClientSupplier;
 import org.apache.kafka.test.MockProcessorSupplier;
 import org.apache.kafka.test.MockStateRestoreListener;
 import org.apache.kafka.test.TestUtils;
+import org.easymock.EasyMock;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -54,7 +55,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
-import java.util.UUID;
 
 import static org.apache.kafka.streams.state.QueryableStoreTypes.windowStore;
 import static org.junit.Assert.assertEquals;
@@ -66,9 +66,11 @@ public class StreamThreadStateStoreProviderTest {
     private StreamThreadStateStoreProvider provider;
     private StateDirectory stateDirectory;
     private File stateDir;
-    private boolean storesAvailable;
     private final String topicName = "topic";
+    private StreamThread threadMock;
+    private Map<TaskId, Task> tasks;
 
+    @SuppressWarnings("deprecation")
     @Before
     public void before() throws IOException {
         final TopologyBuilder builder = new TopologyBuilder();
@@ -100,7 +102,7 @@ public class StreamThreadStateStoreProviderTest {
 
         builder.setApplicationId(applicationId);
         final ProcessorTopology topology = builder.build(null);
-        final Map<TaskId, StreamTask> tasks = new HashMap<>();
+        tasks = new HashMap<>();
         stateDirectory = new StateDirectory(applicationId, stateConfigDir, new 
MockTime());
         taskOne = createStreamsTask(applicationId, streamsConfig, 
clientSupplier, topology,
                                     new TaskId(0, 0));
@@ -111,30 +113,9 @@ public class StreamThreadStateStoreProviderTest {
         tasks.put(new TaskId(0, 1),
                   taskTwo);
 
-        storesAvailable = true;
-        provider = new StreamThreadStateStoreProvider(
-            new StreamThread(
-                    builder.internalTopologyBuilder,
-                    streamsConfig,
-                    clientSupplier,
-                    applicationId,
-                    "clientId",
-                    UUID.randomUUID(),
-                    new Metrics(),
-                    Time.SYSTEM,
-                    new StreamsMetadataState(builder.internalTopologyBuilder, 
StreamsMetadataState.UNKNOWN_HOST),
-                    0, stateDirectory) {
-
-                @Override
-                public Map<TaskId, StreamTask> tasks() {
-                    return tasks;
-                }
-
-                @Override
-                public boolean isInitialized() {
-                    return storesAvailable;
-                }
-            });
+        threadMock = EasyMock.createNiceMock(StreamThread.class);
+        provider = new StreamThreadStateStoreProvider(threadMock);
+
     }
 
     @After
@@ -144,6 +125,7 @@ public class StreamThreadStateStoreProviderTest {
     
     @Test
     public void shouldFindKeyValueStores() throws Exception {
+        mockThread(true);
         final List<ReadOnlyKeyValueStore<String, String>> kvStores =
             provider.stores("kv-store", QueryableStoreTypes.<String, 
String>keyValueStore());
         assertEquals(2, kvStores.size());
@@ -151,6 +133,7 @@ public class StreamThreadStateStoreProviderTest {
 
     @Test
     public void shouldFindWindowStores() throws Exception {
+        mockThread(true);
         final List<ReadOnlyWindowStore<Object, Object>>
             windowStores =
             provider.stores("window-store", windowStore());
@@ -159,18 +142,21 @@ public class StreamThreadStateStoreProviderTest {
 
     @Test(expected = InvalidStateStoreException.class)
     public void shouldThrowInvalidStoreExceptionIfWindowStoreClosed() throws 
Exception {
+        mockThread(true);
         taskOne.getStore("window-store").close();
         provider.stores("window-store", QueryableStoreTypes.windowStore());
     }
 
     @Test(expected = InvalidStateStoreException.class)
     public void shouldThrowInvalidStoreExceptionIfKVStoreClosed() throws 
Exception {
+        mockThread(true);
         taskOne.getStore("kv-store").close();
         provider.stores("kv-store", QueryableStoreTypes.keyValueStore());
     }
 
     @Test
     public void shouldReturnEmptyListIfNoStoresFoundWithName() throws 
Exception {
+        mockThread(true);
         assertEquals(Collections.emptyList(), provider.stores("not-a-store", 
QueryableStoreTypes
             .keyValueStore()));
     }
@@ -178,13 +164,14 @@ public class StreamThreadStateStoreProviderTest {
 
     @Test
     public void shouldReturnEmptyListIfStoreExistsButIsNotOfTypeValueStore() 
throws Exception {
+        mockThread(true);
         assertEquals(Collections.emptyList(), provider.stores("window-store",
                                                               
QueryableStoreTypes.keyValueStore()));
     }
 
     @Test(expected = InvalidStateStoreException.class)
     public void shouldThrowInvalidStoreExceptionIfNotAllStoresAvailable() 
throws Exception {
-        storesAvailable = false;
+        mockThread(false);
         provider.stores("kv-store", QueryableStoreTypes.keyValueStore());
     }
 
@@ -212,6 +199,12 @@ public class StreamThreadStateStoreProviderTest {
         };
     }
 
+    private void mockThread(final boolean initialized) {
+        EasyMock.expect(threadMock.isInitialized()).andReturn(initialized);
+        EasyMock.expect(threadMock.tasks()).andStubReturn(tasks);
+        EasyMock.replay(threadMock);
+    }
+
     private void configureRestoreConsumer(final MockClientSupplier 
clientSupplier,
                                           final String topic) {
         clientSupplier.restoreConsumer

http://git-wip-us.apache.org/repos/asf/kafka/blob/3e69ce80/streams/src/test/java/org/apache/kafka/test/MockChangelogReader.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/test/MockChangelogReader.java 
b/streams/src/test/java/org/apache/kafka/test/MockChangelogReader.java
index 4db10e7..1389bbe 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockChangelogReader.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockChangelogReader.java
@@ -48,6 +48,11 @@ public class MockChangelogReader implements ChangelogReader {
         return Collections.emptyMap();
     }
 
+    @Override
+    public void clear() {
+        registered.clear();
+    }
+
     public boolean wasRegistered(final TopicPartition partition) {
         return registered.contains(partition);
     }

Reply via email to