Repository: kafka
Updated Branches:
  refs/heads/trunk f9d7808ba -> ea42d6535


KAFKA-3637: Added initial states

Author: Eno Thereska <eno.there...@gmail.com>

Reviewers: Ismael Juma, Dan Norwood, Xavier Léauté, Damian Guy, Michael G. Noll, Matthias J. Sax, Guozhang Wang

Closes #2135 from enothereska/KAFKA-3637-streams-state


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/ea42d653
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/ea42d653
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/ea42d653

Branch: refs/heads/trunk
Commit: ea42d65354b5905668d45dedae1cd1f7f39c888c
Parents: f9d7808
Author: Eno Thereska <eno.there...@gmail.com>
Authored: Wed Nov 30 22:23:31 2016 -0800
Committer: Guozhang Wang <wangg...@gmail.com>
Committed: Wed Nov 30 22:23:31 2016 -0800

----------------------------------------------------------------------
 .../org/apache/kafka/streams/KafkaStreams.java  | 144 ++++++++++++++++---
 .../processor/internals/StreamThread.java       | 140 +++++++++++++++---
 .../apache/kafka/streams/KafkaStreamsTest.java  |  28 +++-
 .../QueryableStateIntegrationTest.java          |   2 +
 .../processor/internals/StreamThreadTest.java   |  36 ++++-
 5 files changed, 309 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/ea42d653/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java 
b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
index 6b35d24..df6da21 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
@@ -46,6 +46,9 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Properties;
 import java.util.UUID;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Arrays;
 import java.util.concurrent.TimeUnit;
 
 /**
@@ -92,10 +95,6 @@ public class KafkaStreams {
     private static final Logger log = 
LoggerFactory.getLogger(KafkaStreams.class);
     private static final String JMX_PREFIX = "kafka.streams";
     public static final int DEFAULT_CLOSE_TIMEOUT = 0;
-
-    private enum StreamsState { created, running, stopped }
-    private StreamsState state = StreamsState.created;
-
     private final StreamThread[] threads;
     private final Metrics metrics;
     private final QueryableStoreProvider queryableStoreProvider;
@@ -109,6 +108,110 @@ public class KafkaStreams {
 
     private final StreamsConfig config;
 
+    // container states
+    /**
+     * Kafka Streams states are the possible state that a Kafka Streams 
instance can be in.
+     * An instance must only be in one state at a time.
+     * Note this instance will be in "Rebalancing" state if any of its threads 
is rebalancing
+     * The expected state transition with the following defined states is:
+     *
+     *                 +-----------+
+     *         +<------|Created    |
+     *         |       +-----+-----+
+     *         |             |   +--+
+     *         |             v   |  |
+     *         |       +-----+---v--+--+
+     *         +<----- | Rebalancing   |<--------+
+     *         |       +-----+---------+         ^
+     *         |                 +--+            |
+     *         |                 |  |            |
+     *         |       +-----+---v--+-----+      |
+     *         +------>|Running           |------+
+     *         |       +-----+------------+
+     *         |             |
+     *         |             v
+     *         |     +-------+--------+
+     *         +---->|Pending         |
+     *               |Shutdown        |
+     *               +-------+--------+
+     *                       |
+     *                       v
+     *                 +-----+-----+
+     *                 |Not Running|
+     *                 +-----------+
+     */
+    public enum State {
+        CREATED(1, 2, 3), RUNNING(2, 3), REBALANCING(1, 2, 3), 
PENDING_SHUTDOWN(4), NOT_RUNNING;
+
+        private final Set<Integer> validTransitions = new HashSet<>();
+
+        State(final Integer...validTransitions) {
+            this.validTransitions.addAll(Arrays.asList(validTransitions));
+        }
+
+        public boolean isRunning() {
+            return this.equals(RUNNING) || this.equals(REBALANCING);
+        }
+        public boolean isCreatedOrRunning() {
+            return isRunning() || this.equals(CREATED);
+        }
+        public boolean isValidTransition(final State newState) {
+            return validTransitions.contains(newState.ordinal());
+        }
+    }
+    private volatile State state = KafkaStreams.State.CREATED;
+    private StateListener stateListener = null;
+    private final StreamStateListener streamStateListener = new 
StreamStateListener();
+
+    /**
+     * Listen to state change events
+     */
+    public interface StateListener {
+
+        /**
+         * Called when state changes
+         * @param newState     current state
+         * @param oldState     previous state
+         */
+        void onChange(final State newState, final State oldState);
+    }
+
+    /**
+     * An app can set {@link StateListener} so that the app is notified when 
state changes
+     * @param listener
+     */
+    public void setStateListener(final StateListener listener) {
+        this.stateListener = listener;
+    }
+
+    private synchronized void setState(State newState) {
+        State oldState = state;
+        if (!state.isValidTransition(newState)) {
+            throw new IllegalStateException("Incorrect state transition from " 
+ state + " to " + newState);
+        }
+        state = newState;
+        if (stateListener != null) {
+            stateListener.onChange(state, oldState);
+        }
+    }
+
+
+    /**
+     * @return The state this instance is in
+     */
+    public synchronized State state() {
+        return state;
+    }
+
+    private class StreamStateListener implements StreamThread.StateListener {
+        @Override
+        public void onChange(final StreamThread.State newState, final 
StreamThread.State oldState) {
+            if (newState == StreamThread.State.PARTITIONS_REVOKED ||
+                newState == StreamThread.State.ASSIGNING_PARTITIONS) {
+                setState(KafkaStreams.State.REBALANCING);
+            }
+        }
+    }
     /**
      * Construct the stream instance.
      *
@@ -140,7 +243,6 @@ public class KafkaStreams {
     public KafkaStreams(final TopologyBuilder builder, final StreamsConfig 
config, final KafkaClientSupplier clientSupplier) {
         // create the metrics
         final Time time = new SystemTime();
-
         processId = UUID.randomUUID();
 
         this.config = config;
@@ -177,6 +279,7 @@ public class KafkaStreams {
                 metrics,
                 time,
                 streamsMetadataState);
+            threads[i].setStateListener(streamStateListener);
             storeProviders.add(new StreamThreadStateStoreProvider(threads[i]));
         }
         queryableStoreProvider = new QueryableStoreProvider(storeProviders);
@@ -190,16 +293,15 @@ public class KafkaStreams {
     public synchronized void start() {
         log.debug("Starting Kafka Stream process");
 
-        if (state == StreamsState.created) {
-            for (final StreamThread thread : threads) {
+        if (state == KafkaStreams.State.CREATED) {
+            for (final StreamThread thread : threads)
                 thread.start();
-            }
-            state = StreamsState.running;
+
+            setState(KafkaStreams.State.RUNNING);
+
             log.info("Started Kafka Stream process");
-        } else if (state == StreamsState.running) {
-            throw new IllegalStateException("This process was already 
started.");
         } else {
-            throw new IllegalStateException("Cannot restart after closing.");
+            throw new IllegalStateException("Cannot start again.");
         }
     }
 
@@ -225,7 +327,8 @@ public class KafkaStreams {
      */
     public synchronized boolean close(final long timeout, final TimeUnit 
timeUnit) {
         log.debug("Stopping Kafka Stream process");
-        if (state == StreamsState.running) {
+        if (state.isCreatedOrRunning()) {
+            setState(KafkaStreams.State.PENDING_SHUTDOWN);
             // save the current thread so that if it is a stream thread
             // we don't attempt to join it and cause a deadlock
             final Thread shutdown = new Thread(new Runnable() {
@@ -233,6 +336,9 @@ public class KafkaStreams {
                 public void run() {
                         // signal the threads to stop and wait
                         for (final StreamThread thread : threads) {
+                            // avoid deadlocks by stopping any further state 
reports
+                            // from the thread since we're shutting down
+                            thread.setStateListener(null);
                             thread.close();
                         }
 
@@ -247,7 +353,7 @@ public class KafkaStreams {
                         }
 
                         metrics.close();
-                        log.info("Stopped Kafka Stream process");
+                        log.info("Stopped Kafka Streams process");
                     }
             }, "kafka-streams-close-thread");
             shutdown.setDaemon(true);
@@ -257,11 +363,10 @@ public class KafkaStreams {
             } catch (InterruptedException e) {
                 Thread.interrupted();
             }
-            state = StreamsState.stopped;
+            setState(KafkaStreams.State.NOT_RUNNING);
             return !shutdown.isAlive();
         }
         return true;
-
     }
 
     /**
@@ -288,7 +393,7 @@ public class KafkaStreams {
      * @throws IllegalStateException if instance is currently running
      */
     public void cleanUp() {
-        if (state == StreamsState.running) {
+        if (state.isRunning()) {
             throw new IllegalStateException("Cannot clean up while running.");
         }
 
@@ -404,9 +509,8 @@ public class KafkaStreams {
     }
 
     private void validateIsRunning() {
-        if (state != StreamsState.running) {
-            throw new IllegalStateException("KafkaStreams is not running");
+        if (!state.isRunning()) {
+            throw new IllegalStateException("KafkaStreams is not running. 
State is " + state);
         }
     }
-
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/ea42d653/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index a135a15..0c42521 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -51,6 +51,7 @@ import org.slf4j.LoggerFactory;
 
 import java.io.File;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -60,7 +61,6 @@ import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.regex.Pattern;
 
@@ -71,6 +71,102 @@ public class StreamThread extends Thread {
     private static final Logger log = 
LoggerFactory.getLogger(StreamThread.class);
     private static final AtomicInteger STREAM_THREAD_ID_SEQUENCE = new 
AtomicInteger(1);
 
+    /**
+     * Stream thread states are the possible states that a stream thread can 
be in.
+     * A thread must only be in one state at a time
+     * The expected state transitions with the following defined states is:
+     *
+     *                +-----------+
+     *                |Not Running|<---------------+
+     *                +-----+-----+                |
+     *                      |                      |
+     *                      v                      |
+     *                +-----+-----+                |
+     *          +-----| Running   |<------------+  |
+     *          |     +-----+-----+             |  |
+     *          |           |                   |  |
+     *          |           v                   |  |
+     *          |     +-----+------------+      |  |
+     *          <---- |Partitions        |      |  |
+     *          |     |Revoked           |      |  |
+     *          |     +-----+------------+      |  |
+     *          |           |                   |  |
+     *          |           v                   |  |
+     *          |     +-----+------------+      |  |
+     *          |     |Assigning         |      |  |
+     *          |     |Partitions        |------+  |
+     *          |     +-----+------------+         |
+     *          |                                  |
+     *          |                                  |
+     *          |    +-----+----------+            |
+     *          +--->|Pending         |------------+
+     *               |Shutdown        |
+     *               +-----+----------+
+     *
+     */
+    public enum State {
+        NOT_RUNNING(1), RUNNING(1, 2, 4), PARTITIONS_REVOKED(3, 4), 
ASSIGNING_PARTITIONS(1), PENDING_SHUTDOWN(0);
+
+        private final Set<Integer> validTransitions = new HashSet<>();
+
+        State(final Integer...validTransitions) {
+            this.validTransitions.addAll(Arrays.asList(validTransitions));
+        }
+
+        public boolean isRunning() {
+            return !this.equals(PENDING_SHUTDOWN) && !this.equals(NOT_RUNNING);
+        }
+
+        public boolean isValidTransition(final State newState) {
+            return validTransitions.contains(newState.ordinal());
+        }
+    }
+    private volatile State state = State.NOT_RUNNING;
+    private StateListener stateListener = null;
+
+    /**
+     * Listen to state change events
+     */
+    public interface StateListener {
+
+        /**
+         * Called when state changes
+         * @param newState     current state
+         * @param oldState     previous state
+         */
+        void onChange(final State newState, final State oldState);
+    }
+
+    /**
+     * Set the {@link StateListener} to be notified when state changes.
+     * Note this API is internal to Kafka Streams and is not intended to be 
used by an
+     * external application.
+     * @param listener
+     */
+    public void setStateListener(final StateListener listener) {
+        this.stateListener = listener;
+    }
+
+    /**
+     * @return The state this instance is in
+     */
+    public synchronized State state() {
+        return state;
+    }
+
+    private synchronized void setState(State newState) {
+        State oldState = state;
+        if (!state.isValidTransition(newState)) {
+            throw new IllegalStateException("Incorrect state transition from " 
+ state + " to " + newState);
+        }
+        state = newState;
+        if (stateListener != null) {
+            synchronized (stateListener) {
+                stateListener.onChange(state, oldState);
+            }
+        }
+    }
+
     public final PartitionGrouper partitionGrouper;
     private final StreamsMetadataState streamsMetadataState;
     public final String applicationId;
@@ -87,7 +183,6 @@ public class StreamThread extends Thread {
 
     private final String logPrefix;
     private final String threadClientId;
-    private final AtomicBoolean running;
     private final Map<TaskId, StreamTask> activeTasks;
     private final Map<TaskId, StandbyTask> standbyTasks;
     private final Map<TopicPartition, StreamTask> activeTasksByPartition;
@@ -111,7 +206,6 @@ public class StreamThread extends Thread {
 
     private Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> 
standbyRecords;
     private boolean processStandbyRecords = false;
-    private AtomicBoolean initialized = new AtomicBoolean(false);
 
     private ThreadCache cache;
 
@@ -119,13 +213,20 @@ public class StreamThread extends Thread {
         @Override
         public void onPartitionsAssigned(Collection<TopicPartition> 
assignment) {
             try {
-                log.info("stream-thread [{}] New partitions [{}] assigned at 
the end of consumer rebalance.",
+                if (state == State.PENDING_SHUTDOWN) {
+                    log.info("stream-thread [{}] New partitions [{}] assigned 
while shutting down.",
                         StreamThread.this.getName(), assignment);
+                    return;
+                }
+                log.info("stream-thread [{}] New partitions [{}] assigned at 
the end of consumer rebalance.",
+                    StreamThread.this.getName(), assignment);
+
+                setState(State.ASSIGNING_PARTITIONS);
                 addStreamTasks(assignment);
                 addStandbyTasks();
                 lastCleanMs = time.milliseconds(); // start the cleaning cycle
                 
streamsMetadataState.onChange(partitionAssignor.getPartitionsByHostState(), 
partitionAssignor.clusterMetadata());
-                initialized.set(true);
+                setState(State.RUNNING);
             } catch (Throwable t) {
                 rebalanceException = t;
                 throw t;
@@ -135,9 +236,14 @@ public class StreamThread extends Thread {
         @Override
         public void onPartitionsRevoked(Collection<TopicPartition> assignment) 
{
             try {
+                if (state == State.PENDING_SHUTDOWN) {
+                    log.info("stream-thread [{}] New partitions [{}] revoked 
while shutting down.",
+                        StreamThread.this.getName(), assignment);
+                    return;
+                }
                 log.info("stream-thread [{}] partitions [{}] revoked at the 
beginning of consumer rebalance.",
                         StreamThread.this.getName(), assignment);
-                initialized.set(false);
+                setState(State.PARTITIONS_REVOKED);
                 lastCleanMs = Long.MAX_VALUE; // stop the cleaning cycle until 
partitions are assigned
                 // suspend active tasks
                 suspendTasksAndState(true);
@@ -152,8 +258,8 @@ public class StreamThread extends Thread {
         }
     };
 
-    public boolean isInitialized() {
-        return initialized.get();
+    public synchronized boolean isInitialized() {
+        return state == State.RUNNING;
     }
 
     public StreamThread(TopologyBuilder builder,
@@ -166,7 +272,6 @@ public class StreamThread extends Thread {
                         Time time,
                         StreamsMetadataState streamsMetadataState) {
         super("StreamThread-" + STREAM_THREAD_ID_SEQUENCE.getAndIncrement());
-
         this.applicationId = applicationId;
         String threadName = getName();
         this.config = config;
@@ -220,9 +325,7 @@ public class StreamThread extends Thread {
         this.timerStartedMs = time.milliseconds();
         this.lastCleanMs = Long.MAX_VALUE; // the cleaning cycle won't start 
until partition assignment
         this.lastCommitMs = timerStartedMs;
-
-
-        this.running = new AtomicBoolean(true);
+        setState(state.RUNNING);
     }
 
     public void partitionAssignor(StreamPartitionAssignor partitionAssignor) {
@@ -256,8 +359,9 @@ public class StreamThread extends Thread {
     /**
      * Shutdown this stream thread.
      */
-    public void close() {
-        running.set(false);
+    public synchronized void close() {
+        log.info("{} Informed thread to shut down", logPrefix);
+        setState(State.PENDING_SHUTDOWN);
     }
 
     public Map<TaskId, StreamTask> tasks() {
@@ -290,7 +394,7 @@ public class StreamThread extends Thread {
         removeStandbyTasks();
 
         log.info("{} Stream thread shutdown complete", logPrefix);
-        running.set(false);
+        setState(State.NOT_RUNNING);
     }
 
     private void unAssignChangeLogPartitions(final boolean rethrowExceptions) {
@@ -493,7 +597,7 @@ public class StreamThread extends Thread {
 
             maybeClean();
         }
-        log.debug("{} Shutting down at user request", logPrefix);
+        log.info("{} Shutting down at user request", logPrefix);
     }
 
     private void maybeUpdateStandbyTasks() {
@@ -540,8 +644,8 @@ public class StreamThread extends Thread {
         }
     }
 
-    public boolean stillRunning() {
-        return running.get();
+    public synchronized boolean stillRunning() {
+        return state.isRunning();
     }
 
     private void maybePunctuate(StreamTask task) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/ea42d653/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java 
b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
index e17e89f..e8b46cc 100644
--- a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
@@ -59,8 +59,17 @@ public class KafkaStreamsTest {
 
         final KStreamBuilder builder = new KStreamBuilder();
         final KafkaStreams streams = new KafkaStreams(builder, props);
+        StateListenerStub stateListener = new StateListenerStub();
+        streams.setStateListener(stateListener);
+        Assert.assertEquals(streams.state(), KafkaStreams.State.CREATED);
+        Assert.assertEquals(stateListener.numChanges, 0);
 
         streams.start();
+        Assert.assertEquals(streams.state(), KafkaStreams.State.RUNNING);
+        Assert.assertEquals(stateListener.numChanges, 1);
+        Assert.assertEquals(stateListener.oldState, 
KafkaStreams.State.CREATED);
+        Assert.assertEquals(stateListener.newState, 
KafkaStreams.State.RUNNING);
+
         final int newInitCount = MockMetricsReporter.INIT_COUNT.get();
         final int initCountDifference = newInitCount - oldInitCount;
         assertTrue("some reporters should be initialized by calling start()", 
initCountDifference > 0);
@@ -68,6 +77,7 @@ public class KafkaStreamsTest {
         assertTrue(streams.close(15, TimeUnit.SECONDS));
         Assert.assertEquals("each reporter initialized should also be closed",
             oldCloseCount + initCountDifference, 
MockMetricsReporter.CLOSE_COUNT.get());
+        Assert.assertEquals(streams.state(), KafkaStreams.State.NOT_RUNNING);
     }
 
     @Test
@@ -100,7 +110,7 @@ public class KafkaStreamsTest {
         try {
             streams.start();
         } catch (final IllegalStateException e) {
-            Assert.assertEquals("Cannot restart after closing.", 
e.getMessage());
+            Assert.assertEquals("Cannot start again.", e.getMessage());
             throw e;
         } finally {
             streams.close();
@@ -120,7 +130,7 @@ public class KafkaStreamsTest {
         try {
             streams.start();
         } catch (final IllegalStateException e) {
-            Assert.assertEquals("This process was already started.", 
e.getMessage());
+            Assert.assertEquals("Cannot start again.", e.getMessage());
             throw e;
         } finally {
             streams.close();
@@ -246,4 +256,18 @@ public class KafkaStreamsTest {
             streams.close();
         }
     }
+
+
+    public static class StateListenerStub implements 
KafkaStreams.StateListener {
+        public int numChanges = 0;
+        public KafkaStreams.State oldState;
+        public KafkaStreams.State newState;
+
+        @Override
+        public void onChange(final KafkaStreams.State newState, final 
KafkaStreams.State oldState) {
+            this.numChanges++;
+            this.oldState = oldState;
+            this.newState = newState;
+        }
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/ea42d653/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java
index d89c33a..0c5c79c 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java
@@ -276,6 +276,7 @@ public class QueryableStateIntegrationTest {
                         return false;
                     } catch (final InvalidStateStoreException e) {
                         // rebalance
+                        assertEquals(streams.state(), 
KafkaStreams.State.REBALANCING);
                         return false;
                     }
 
@@ -306,6 +307,7 @@ public class QueryableStateIntegrationTest {
                         return false;
                     } catch (InvalidStateStoreException e) {
                         // rebalance
+                        assertEquals(streams.state(), 
KafkaStreams.State.REBALANCING);
                         return false;
                     }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/ea42d653/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index e3aaab8..70bea14 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -40,6 +40,7 @@ import org.apache.kafka.streams.processor.TopologyBuilder;
 import org.apache.kafka.test.MockClientSupplier;
 import org.apache.kafka.test.MockProcessorSupplier;
 import org.apache.kafka.test.MockTimestampExtractor;
+import org.junit.Assert;
 import org.junit.Test;
 
 import java.io.File;
@@ -159,6 +160,8 @@ public class StreamThreadTest {
     @Test
     public void testPartitionAssignmentChange() throws Exception {
         StreamsConfig config = new StreamsConfig(configProps());
+        StateListenerStub stateListener = new StateListenerStub();
+
 
         TopologyBuilder builder = new TopologyBuilder().setApplicationId("X");
         builder.addSource("source1", "topic1");
@@ -173,7 +176,8 @@ public class StreamThreadTest {
                 return new TestStreamTask(id, applicationId, 
partitionsForTask, topology, consumer, producer, restoreConsumer, config, 
stateDirectory);
             }
         };
-
+        thread.setStateListener(stateListener);
+        assertEquals(thread.state(), StreamThread.State.RUNNING);
         initPartitionGrouper(config, thread);
 
         ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
@@ -190,7 +194,15 @@ public class StreamThreadTest {
         expectedGroup1 = new HashSet<>(Arrays.asList(t1p1));
 
         rebalanceListener.onPartitionsRevoked(revokedPartitions);
+        assertEquals(thread.state(), StreamThread.State.PARTITIONS_REVOKED);
+        Assert.assertEquals(stateListener.numChanges, 1);
+        Assert.assertEquals(stateListener.oldState, 
StreamThread.State.RUNNING);
+        Assert.assertEquals(stateListener.newState, 
StreamThread.State.PARTITIONS_REVOKED);
         rebalanceListener.onPartitionsAssigned(assignedPartitions);
+        assertEquals(thread.state(), StreamThread.State.RUNNING);
+        Assert.assertEquals(stateListener.numChanges, 3);
+        Assert.assertEquals(stateListener.oldState, 
StreamThread.State.ASSIGNING_PARTITIONS);
+        Assert.assertEquals(stateListener.newState, 
StreamThread.State.RUNNING);
 
         assertTrue(thread.tasks().containsKey(task1));
         assertEquals(expectedGroup1, thread.tasks().get(task1).partitions());
@@ -272,6 +284,10 @@ public class StreamThreadTest {
         rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
         assertTrue(thread.tasks().isEmpty());
+
+        thread.close();
+        assertTrue((thread.state() == StreamThread.State.PENDING_SHUTDOWN) ||
+            (thread.state() == StreamThread.State.NOT_RUNNING));
     }
 
 
@@ -510,4 +526,22 @@ public class StreamThreadTest {
 
         partitionAssignor.onAssignment(assignments.get("client"));
     }
+
+    public static class StateListenerStub implements 
StreamThread.StateListener {
+        public int numChanges = 0;
+        public StreamThread.State oldState = null;
+        public StreamThread.State newState = null;
+
+        @Override
+        public void onChange(final StreamThread.State newState, final 
StreamThread.State oldState) {
+            this.numChanges++;
+            if (this.newState != null) {
+                if (this.newState != oldState) {
+                    throw new RuntimeException("State mismatch " + oldState + 
" different from " + this.newState);
+                }
+            }
+            this.oldState = oldState;
+            this.newState = newState;
+        }
+    }
 }

Reply via email to