You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by da...@apache.org on 2017/07/06 09:05:22 UTC

kafka git commit: KAFKA-5372; fixes to state transitions

Repository: kafka
Updated Branches:
  refs/heads/trunk f53f5eaa1 -> 907855423


KAFKA-5372; fixes to state transitions

Several fixes to state transition logic:
- Kafka Streams will now be in the ERROR state when all stream threads are DEAD or when the global stream thread stops unexpectedly
- Fixed transition logic in corner cases when a thread is already dead or Kafka Streams is already closed
- Fixed incorrect transition diagram in StreamThread
- Unit tests to verify transitions

Also:
- re-enabled throwing an exception when an unexpected state change happens
- fixed a whole bunch of EoS tests that did not start a thread
- added more comments.

Author: Eno Thereska <en...@gmail.com>

Reviewers: Matthias J. Sax <ma...@confluent.io>, Guozhang Wang <wa...@gmail.com>, Damian Guy <da...@gmail.com>

Closes #3432 from enothereska/KAFKA-5372-state-transitions


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/90785542
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/90785542
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/90785542

Branch: refs/heads/trunk
Commit: 9078554232f9de56849b5c35bf45b81a67a748c4
Parents: f53f5ea
Author: Eno Thereska <en...@gmail.com>
Authored: Thu Jul 6 10:05:16 2017 +0100
Committer: Damian Guy <da...@gmail.com>
Committed: Thu Jul 6 10:05:16 2017 +0100

----------------------------------------------------------------------
 .../org/apache/kafka/streams/KafkaStreams.java  | 332 +++++++++++++------
 .../processor/internals/GlobalStreamThread.java | 135 +++++++-
 .../processor/internals/StreamThread.java       | 120 ++++---
 .../ThreadStateTransitionValidator.java         |  24 ++
 .../apache/kafka/streams/KafkaStreamsTest.java  | 149 ++++++++-
 .../internals/GlobalStreamThreadTest.java       |  48 ++-
 .../processor/internals/StreamThreadTest.java   | 247 +++++++++-----
 streams/src/test/resources/log4j.properties     |   4 +-
 8 files changed, 808 insertions(+), 251 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/90785542/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
index c16f379..67535d6 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
@@ -46,6 +46,7 @@ import org.apache.kafka.streams.processor.internals.StateDirectory;
 import org.apache.kafka.streams.processor.internals.StreamThread;
 import org.apache.kafka.streams.processor.internals.StreamsKafkaClient;
 import org.apache.kafka.streams.processor.internals.StreamsMetadataState;
+import org.apache.kafka.streams.processor.internals.ThreadStateTransitionValidator;
 import org.apache.kafka.streams.state.HostInfo;
 import org.apache.kafka.streams.state.QueryableStoreType;
 import org.apache.kafka.streams.state.StreamsMetadata;
@@ -73,6 +74,9 @@ import java.util.concurrent.TimeUnit;
 
 import static org.apache.kafka.common.utils.Utils.getHost;
 import static org.apache.kafka.common.utils.Utils.getPort;
+import static org.apache.kafka.streams.KafkaStreams.State.ERROR;
+import static org.apache.kafka.streams.KafkaStreams.State.NOT_RUNNING;
+import static org.apache.kafka.streams.KafkaStreams.State.PENDING_SHUTDOWN;
 import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE;
 import static org.apache.kafka.streams.StreamsConfig.PROCESSING_GUARANTEE_CONFIG;
 
@@ -134,7 +138,6 @@ public class KafkaStreams {
     private final UUID processId;
     private final String logPrefix;
     private final StreamsMetadataState streamsMetadataState;
-
     private final StreamsConfig config;
 
     // container states
@@ -150,9 +153,9 @@ public class KafkaStreams {
      *         |       +-----+--------+
      *         |             |
      *         |             v
-     *         |       +-----+--------+
-     *         +<----- | Rebalancing  | <----+
-     *         |       +--------------+      |
+     *         |       +-----+--------+ <-+
+     *         +<----- | Rebalancing  | --+
+     *         |       +--------------+ <----+
      *         |                             |
      *         |                             |
      *         |       +--------------+      |
@@ -162,17 +165,30 @@ public class KafkaStreams {
      *         |             v
      *         |       +-----+--------+
      *         +-----> | Pending      |
-     *                 | Shutdown     |
-     *                 +-----+--------+
-     *                       |
-     *                       v
-     *                 +-----+--------+
-     *                 | Not Running  |
+     *         |       | Shutdown     |
+     *         |       +-----+--------+
+     *         |             |
+     *         |             v
+     *         |       +-----+--------+
+     *         +-----> | Not Running  |
+     *         |       +--------------+
+     *         |
+     *         |       +--------------+
+     *         +-----> | Error        |
      *                 +--------------+
+     *
+     *
      * </pre>
+     * Note the following:
+     * - Any state can go to PENDING_SHUTDOWN (during clean shutdown) or NOT_RUNNING (e.g., during an exception).
+     * - It is theoretically possible for a thread to always be in the PARTITIONS_REVOKED state
+     * (see {@code StreamThread} state diagram) and hence it is possible that this instance is always
+     * in the REBALANCING state.
+     * - Of special importance: If the global stream thread dies, or all stream threads die (or both) then
+     * the instance will be in the ERROR state. The user will need to close it.
      */
     public enum State {
-        CREATED(1, 2, 3), RUNNING(2, 3), REBALANCING(1, 2, 3), PENDING_SHUTDOWN(4), NOT_RUNNING;
+        CREATED(1, 2, 3, 5), REBALANCING(1, 2, 3, 4, 5), RUNNING(1, 3, 4, 5), PENDING_SHUTDOWN(4), NOT_RUNNING, ERROR;
 
         private final Set<Integer> validTransitions = new HashSet<>();
 
@@ -192,9 +208,7 @@ public class KafkaStreams {
     }
 
     private final Object stateLock = new Object();
-
     private volatile State state = State.CREATED;
-
     private KafkaStreams.StateListener stateListener = null;
 
 
@@ -220,11 +234,27 @@ public class KafkaStreams {
         stateListener = listener;
     }
 
+    /**
+     * Sets the state
+     * @param newState New state
+     */
     private void setState(final State newState) {
+
         synchronized (stateLock) {
+
+            // there are cases when we shouldn't check if a transition is valid, e.g.,
+            // when, for testing, Kafka Streams is closed multiple times. We could either
+            // check here and immediately return for those cases, or add them to the transition
+            // diagram (but then the diagram would be confusing and have transitions like
+            // NOT_RUNNING->NOT_RUNNING).
+            if (newState != NOT_RUNNING && (state == State.NOT_RUNNING || state == PENDING_SHUTDOWN)) {
+                return;
+            }
+
             final State oldState = state;
             if (!state.isValidTransition(newState)) {
                 log.warn("{} Unexpected state transition from {} to {}.", logPrefix, oldState, newState);
+                throw new StreamsException(logPrefix + " Unexpected state transition from " + oldState + " to " + newState);
             } else {
                 log.info("{} State transition from {} to {}.", logPrefix, oldState, newState);
             }
@@ -240,8 +270,10 @@ public class KafkaStreams {
      *
      * @return the currnt state of this Kafka Streams instance
      */
-    public synchronized State state() {
-        return state;
+    public State state() {
+        synchronized (stateLock) {
+            return state;
+        }
     }
 
     /**
@@ -253,33 +285,104 @@ public class KafkaStreams {
         return Collections.unmodifiableMap(metrics.metrics());
     }
 
-    private final class StreamStateListener implements StreamThread.StateListener {
 
+    /**
+     * Class that handles stream thread transitions
+     */
+    final class StreamStateListener implements StreamThread.StateListener {
         private final Map<Long, StreamThread.State> threadState;
+        private GlobalStreamThread.State globalThreadState;
 
-        StreamStateListener(Map<Long, StreamThread.State> threadState) {
+        StreamStateListener(final Map<Long, StreamThread.State> threadState,
+                            final GlobalStreamThread.State globalThreadState) {
             this.threadState = threadState;
+            this.globalThreadState = globalThreadState;
+        }
+
+        /**
+         * If all threads are dead set to ERROR
+         */
+        private void checkAllThreadsDeadAndSetError() {
+
+            synchronized (stateLock) {
+                // if we are pending a shutdown, it's ok for all threads to die, in fact
+                // it is expected. Otherwise, it is an error
+                if (state != PENDING_SHUTDOWN) {
+                    // one thread died, check if we have enough threads running
+                    for (final StreamThread.State state : threadState.values()) {
+                        if (state != StreamThread.State.DEAD) {
+                            return;
+                        }
+                    }
+                    log.warn("{} All stream threads have died. The Kafka Streams instance will be in an error state and should be closed.",
+                            logPrefix);
+                    setState(ERROR);
+                }
+            }
         }
 
+        /**
+         * If the global thread is DEAD, set to ERROR (unless a shutdown is pending)
+         */
+        private void maybeSetErrorSinceGlobalStreamThreadIsDead() {
+
+            synchronized (stateLock) {
+                // if we are pending a shutdown, it's ok for all threads to die, in fact
+                // it is expected. Otherwise, it is an error
+                if (state != PENDING_SHUTDOWN) {
+                    log.warn("{} Global Stream thread has died. The Kafka Streams instance will be in an error state and should be closed.",
+                            logPrefix);
+                    setState(ERROR);
+                }
+            }
+        }
+
+        /**
+         * If all threads are up, including the global thread, set to RUNNING
+         */
+        private void maybeSetRunning() {
+            // one thread is running, check others, including global thread
+            for (final StreamThread.State state : threadState.values()) {
+                if (state != StreamThread.State.RUNNING) {
+                    return;
+                }
+            }
+            // the global state thread is relevant only if it is started. There are cases
+            // when we don't have a global state thread at all, e.g., when we don't have global KTables
+            if (globalThreadState != null && globalThreadState != GlobalStreamThread.State.RUNNING) {
+                return;
+            }
+
+            setState(State.RUNNING);
+        }
+
+
         @Override
-        public synchronized void onChange(final StreamThread thread,
-                                          final StreamThread.State newState,
-                                          final StreamThread.State oldState) {
-            if (newState != StreamThread.State.DEAD) {
+        public synchronized void onChange(final Thread thread,
+                                          final ThreadStateTransitionValidator abstractNewState,
+                                          final ThreadStateTransitionValidator abstractOldState) {
+            // StreamThreads first
+            if (thread instanceof StreamThread) {
+                StreamThread.State newState = (StreamThread.State) abstractNewState;
                 threadState.put(thread.getId(), newState);
-            } else {
-                threadState.remove(thread.getId());
-            }
-            if (newState == StreamThread.State.PARTITIONS_REVOKED ||
-                newState == StreamThread.State.ASSIGNING_PARTITIONS) {
-                setState(State.REBALANCING);
-            } else if (newState == StreamThread.State.RUNNING) {
-                for (final StreamThread.State state : threadState.values()) {
-                    if (state != StreamThread.State.RUNNING) {
-                        return;
-                    }
+
+                if (newState == StreamThread.State.PARTITIONS_REVOKED ||
+                        newState == StreamThread.State.ASSIGNING_PARTITIONS) {
+                    setState(State.REBALANCING);
+                } else if (newState == StreamThread.State.RUNNING && state() != State.RUNNING) {
+                    maybeSetRunning();
+                } else if (newState == StreamThread.State.DEAD) {
+                    checkAllThreadsDeadAndSetError();
+                }
+            } else if (thread instanceof GlobalStreamThread) {
+                // global stream thread has different invariants
+                GlobalStreamThread.State newState = (GlobalStreamThread.State) abstractNewState;
+                globalThreadState = newState;
+
+                // special case when global thread is dead
+                if (newState == GlobalStreamThread.State.DEAD) {
+                    maybeSetErrorSinceGlobalStreamThreadIsDead();
                 }
-                setState(State.RUNNING);
             }
         }
     }
@@ -346,6 +449,8 @@ public class KafkaStreams {
 
         threads = new StreamThread[config.getInt(StreamsConfig.NUM_STREAM_THREADS_CONFIG)];
         final Map<Long, StreamThread.State> threadState = new HashMap<>(threads.length);
+        GlobalStreamThread.State globalThreadState = null;
+
         final ArrayList<StateStoreProvider> storeProviders = new ArrayList<>();
         streamsMetadataState = new StreamsMetadataState(builder, parseHostInfo(config.getString(StreamsConfig.APPLICATION_SERVER_CONFIG)));
 
@@ -358,7 +463,6 @@ public class KafkaStreams {
         final long cacheSizeBytes = Math.max(0, config.getLong(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG) /
                 (config.getInt(StreamsConfig.NUM_STREAM_THREADS_CONFIG) + (globalTaskTopology == null ? 0 : 1)));
 
-
         if (globalTaskTopology != null) {
             final String globalThreadId = clientId + "-GlobalStreamThread";
             globalStreamThread = new GlobalStreamThread(globalTaskTopology,
@@ -368,9 +472,9 @@ public class KafkaStreams {
                                                         metrics,
                                                         time,
                                                         globalThreadId);
+            globalThreadState = globalStreamThread.state();
         }
 
-        final StreamStateListener streamStateListener = new StreamStateListener(threadState);
         for (int i = 0; i < threads.length; i++) {
             threads[i] = new StreamThread(builder,
                                           config,
@@ -382,10 +486,17 @@ public class KafkaStreams {
                                           time,
                                           streamsMetadataState,
                                           cacheSizeBytes);
-            threads[i].setStateListener(streamStateListener);
             threadState.put(threads[i].getId(), threads[i].state());
             storeProviders.add(new StreamThreadStateStoreProvider(threads[i]));
         }
+        final StreamStateListener streamStateListener = new StreamStateListener(threadState, globalThreadState);
+        if (globalTaskTopology != null) {
+            globalStreamThread.setStateListener(streamStateListener);
+        }
+        for (int i = 0; i < threads.length; i++) {
+            threads[i].setStateListener(streamStateListener);
+        }
+
         final GlobalStateStoreProvider globalStateStoreProvider = new GlobalStateStoreProvider(builder.globalStateStores());
         queryableStoreProvider = new QueryableStoreProvider(storeProviders, globalStateStoreProvider);
     }
@@ -425,6 +536,16 @@ public class KafkaStreams {
 
     }
 
+    private void validateStartOnce() {
+        synchronized (stateLock) {
+            if (state == State.CREATED) {
+                state = State.RUNNING;
+            } else {
+                throw new IllegalStateException("Cannot start again.");
+            }
+        }
+    }
+
     /**
      * Start the {@code KafkaStreams} instance by starting all its threads.
      * <p>
@@ -437,23 +558,18 @@ public class KafkaStreams {
      */
     public synchronized void start() throws IllegalStateException, StreamsException {
         log.debug("{} Starting Kafka Stream process.", logPrefix);
+        validateStartOnce();
+        checkBrokerVersionCompatibility();
 
-        if (state == State.CREATED) {
-            checkBrokerVersionCompatibility();
-            setState(State.RUNNING);
-
-            if (globalStreamThread != null) {
-                globalStreamThread.start();
-            }
-
-            for (final StreamThread thread : threads) {
-                thread.start();
-            }
+        if (globalStreamThread != null) {
+            globalStreamThread.start();
+        }
 
-            log.info("{} Started Kafka Stream process", logPrefix);
-        } else {
-            throw new IllegalStateException("Cannot start again.");
+        for (final StreamThread thread : threads) {
+            thread.start();
         }
+
+        log.info("{} Started Kafka Stream process", logPrefix);
     }
 
     /**
@@ -464,6 +580,31 @@ public class KafkaStreams {
         close(DEFAULT_CLOSE_TIMEOUT, TimeUnit.SECONDS);
     }
 
+
+    private boolean checkFirstTimeClosing() {
+        synchronized (stateLock) {
+            if (state.isCreatedOrRunning() || state == ERROR) {
+                state = PENDING_SHUTDOWN;
+                return true;
+            }
+            return false;
+        }
+    }
+
+    private void closeGlobalStreamThread() {
+        if (globalStreamThread != null) {
+            globalStreamThread.close();
+            if (!globalStreamThread.stillRunning()) {
+                try {
+                    globalStreamThread.join();
+                } catch (final InterruptedException e) {
+                    Thread.interrupted();
+                }
+            }
+            globalStreamThread = null;
+        }
+    }
+
     /**
      * Shutdown this {@code KafkaStreams} by signaling all the threads to stop, and then wait up to the timeout for the
      * threads to join.
@@ -476,55 +617,48 @@ public class KafkaStreams {
      */
     public synchronized boolean close(final long timeout, final TimeUnit timeUnit) {
         log.debug("{} Stopping Kafka Stream process.", logPrefix);
-        if (state.isCreatedOrRunning()) {
-            setState(State.PENDING_SHUTDOWN);
-            // save the current thread so that if it is a stream thread
-            // we don't attempt to join it and cause a deadlock
-            final Thread shutdown = new Thread(new Runnable() {
-                @Override
-                public void run() {
-                    // signal the threads to stop and wait
-                    for (final StreamThread thread : threads) {
-                        // avoid deadlocks by stopping any further state reports
-                        // from the thread since we're shutting down
-                        thread.setStateListener(null);
-                        thread.close();
-                    }
-                    if (globalStreamThread != null) {
-                        globalStreamThread.close();
-                        if (!globalStreamThread.stillRunning()) {
-                            try {
-                                globalStreamThread.join();
-                            } catch (final InterruptedException e) {
-                                Thread.interrupted();
-                            }
-                        }
-                    }
-                    for (final StreamThread thread : threads) {
-                        try {
-                            if (!thread.stillRunning()) {
-                                thread.join();
-                            }
-                        } catch (final InterruptedException ex) {
-                            Thread.interrupted();
+
+        // only clean up once
+        if (!checkFirstTimeClosing()) {
+            return true;
+        }
+
+        // save the current thread so that if it is a stream thread
+        // we don't attempt to join it and cause a deadlock
+        final Thread shutdown = new Thread(new Runnable() {
+            @Override
+            public void run() {
+                // signal the threads to stop and wait
+                for (final StreamThread thread : threads) {
+                    // avoid deadlocks by stopping any further state reports
+                    // from the thread since we're shutting down
+                    thread.setStateListener(null);
+                    thread.close();
+                }
+                closeGlobalStreamThread();
+                for (final StreamThread thread : threads) {
+                    try {
+                        if (!thread.stillRunning()) {
+                            thread.join();
                         }
+                    } catch (final InterruptedException ex) {
+                        Thread.interrupted();
                     }
-
-                    metrics.close();
-                    log.info("{} Stopped Kafka Streams process.", logPrefix);
                 }
-            }, "kafka-streams-close-thread");
-            shutdown.setDaemon(true);
-            shutdown.start();
-            try {
-                shutdown.join(TimeUnit.MILLISECONDS.convert(timeout, timeUnit));
-            } catch (final InterruptedException e) {
-                Thread.interrupted();
+
+                metrics.close();
+                log.info("{} Stopped Kafka Streams process.", logPrefix);
             }
-            setState(State.NOT_RUNNING);
-            return !shutdown.isAlive();
+        }, "kafka-streams-close-thread");
+        shutdown.setDaemon(true);
+        shutdown.start();
+        try {
+            shutdown.join(TimeUnit.MILLISECONDS.convert(timeout, timeUnit));
+        } catch (final InterruptedException e) {
+            Thread.interrupted();
         }
-        return true;
+        setState(State.NOT_RUNNING);
+        return !shutdown.isAlive();
     }
 
     /**
@@ -561,6 +695,12 @@ public class KafkaStreams {
         return sb.toString();
     }
 
+    private boolean isRunning() {
+        synchronized (stateLock) {
+            return state.isRunning();
+        }
+    }
+
     /**
      * Do a clean up of the local {@link StateStore} directory ({@link StreamsConfig#STATE_DIR_CONFIG}) by deleting all
      * data with regard to the {@link StreamsConfig#APPLICATION_ID_CONFIG application ID}.
@@ -573,7 +713,7 @@ public class KafkaStreams {
      * @throws IllegalStateException if the instance is currently running
      */
     public void cleanUp() {
-        if (state.isRunning()) {
+        if (isRunning()) {
             throw new IllegalStateException("Cannot clean up while running.");
         }
 
@@ -729,7 +869,7 @@ public class KafkaStreams {
     }
 
     private void validateIsRunning() {
-        if (!state.isRunning()) {
+        if (!isRunning()) {
             throw new IllegalStateException("KafkaStreams is not running. State is " + state + ".");
         }
     }

http://git-wip-us.apache.org/repos/asf/kafka/blob/90785542/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
index 36a248e..11b89df 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
@@ -31,8 +31,14 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
-import java.util.Collections;
 import java.util.Map;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.DEAD;
+import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.PENDING_SHUTDOWN;
 
 /**
  * This is the thread responsible for keeping all Global State Stores updated.
@@ -49,9 +55,114 @@ public class GlobalStreamThread extends Thread {
     private final ThreadCache cache;
     private final StreamsMetrics streamsMetrics;
     private final ProcessorTopology topology;
-    private volatile boolean running = false;
     private volatile StreamsException startupException;
 
+    /**
+     * The states that the global stream thread can be in
+     *
+     * <pre>
+     *                +-------------+
+     *          +<--- | Created     |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
+     *          +<--- | Running     |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
+     *          +---> | Pending     |
+     *          |     | Shutdown    |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
+     *          +---> | Dead        |
+     *                +-------------+
+     * </pre>
+     *
+     * Note the following:
+     * - Any state can go to PENDING_SHUTDOWN. That is because streams can be closed at any time.
+     * - Any state can go to DEAD. That is because exceptions can happen at any other state,
+     *   leading to the stream thread terminating.
+     *
+     */
+    public enum State implements ThreadStateTransitionValidator {
+        CREATED(1, 2, 3), RUNNING(2, 3), PENDING_SHUTDOWN(3), DEAD;
+
+        private final Set<Integer> validTransitions = new HashSet<>();
+
+        State(final Integer... validTransitions) {
+            this.validTransitions.addAll(Arrays.asList(validTransitions));
+        }
+
+        public boolean isRunning() {
+            return !equals(PENDING_SHUTDOWN) && !equals(CREATED) && !equals(DEAD);
+        }
+
+        public boolean isValidTransition(final ThreadStateTransitionValidator newState) {
+            State tmpState = (State) newState;
+            return validTransitions.contains(tmpState.ordinal());
+        }
+    }
+
+    private volatile State state = State.CREATED;
+    private final Object stateLock = new Object();
+    private StreamThread.StateListener stateListener = null;
+    private final String logPrefix;
+
+
+    /**
+     * Set the {@link StreamThread.StateListener} to be notified when state changes. Note this API is internal to
+     * Kafka Streams and is not intended to be used by an external application.
+     */
+    public void setStateListener(final StreamThread.StateListener listener) {
+        stateListener = listener;
+    }
+
+    /**
+     * @return The state this instance is in
+     */
+    public State state() {
+        synchronized (stateLock) {
+            return state;
+        }
+    }
+
+    /**
+     * Sets the state
+     * @param newState New state
+     * @param ignoreWhenShuttingDownOrDead        if true, then we'll first check if the state is
+     *                                            PENDING_SHUTDOWN or DEAD, and if it is,
+     *                                            we immediately return. Effectively this enables
+     *                                            a conditional set, under the stateLock lock.
+     */
+    void setState(final State newState, boolean ignoreWhenShuttingDownOrDead) {
+        synchronized (stateLock) {
+            final State oldState = state;
+
+            if (ignoreWhenShuttingDownOrDead) {
+                if (state == PENDING_SHUTDOWN || state == DEAD) {
+                    return;
+                }
+            }
+
+            if (!state.isValidTransition(newState)) {
+                log.warn("{} Unexpected state transition from {} to {}.", logPrefix, oldState, newState);
+                throw new StreamsException(logPrefix + " Unexpected state transition from " + oldState + " to " + newState);
+            } else {
+                log.info("{} State transition from {} to {}.", logPrefix, oldState, newState);
+            }
+
+            state = newState;
+            if (stateListener != null) {
+                stateListener.onChange(this, state, oldState);
+            }
+        }
+    }
+
+
     public GlobalStreamThread(final ProcessorTopology topology,
                               final StreamsConfig config,
                               final Consumer<byte[], byte[]> globalConsumer,
@@ -69,6 +180,7 @@ public class GlobalStreamThread extends Thread {
                 (config.getInt(StreamsConfig.NUM_STREAM_THREADS_CONFIG) + 1));
         this.streamsMetrics = new StreamsMetricsImpl(metrics, threadClientId, Collections.singletonMap("client-id", threadClientId));
         this.cache = new ThreadCache(threadClientId, cacheSizeBytes, streamsMetrics);
+        this.logPrefix = String.format("global-stream-thread [%s]", threadClientId);
     }
 
     static class StateConsumer {
@@ -114,6 +226,7 @@ public class GlobalStreamThread extends Thread {
         }
 
         public void close() throws IOException {
+
             // just log an error if the consumer throws an exception during close
             // so we can always attempt to close the state stores.
             try {
@@ -127,20 +240,25 @@ public class GlobalStreamThread extends Thread {
         }
     }
 
+
     @Override
     public void run() {
         final StateConsumer stateConsumer = initialize();
+
         if (stateConsumer == null) {
             return;
         }
+        // one could kill the thread before it had a chance to actually start
+        setState(State.RUNNING, true);
 
         try {
-            while (running) {
+            while (stillRunning()) {
                 stateConsumer.pollAndUpdate();
             }
             log.debug("Shutting down GlobalStreamThread at user request");
         } finally {
             try {
+                setState(DEAD, false);
                 stateConsumer.close();
             } catch (IOException e) {
                 log.error("Failed to cleanly shutdown GlobalStreamThread", e);
@@ -164,7 +282,6 @@ public class GlobalStreamThread extends Thread {
                                         config.getLong(StreamsConfig.POLL_MS_CONFIG),
                                         config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG));
             stateConsumer.initialize();
-            running = true;
             return stateConsumer;
         } catch (StreamsException e) {
             startupException = e;
@@ -177,7 +294,7 @@ public class GlobalStreamThread extends Thread {
     @Override
     public synchronized void start() {
         super.start();
-        while (!running) {
+        while (!stillRunning()) {
             Utils.sleep(1);
             if (startupException != null) {
                 throw startupException;
@@ -187,11 +304,15 @@ public class GlobalStreamThread extends Thread {
 
 
     public void close() {
-        running = false;
+        // one could call close() multiple times, so ignore subsequent calls
+        // if already shutting down or dead
+        setState(PENDING_SHUTDOWN, true);
     }
 
     public boolean stillRunning() {
-        return running;
+        synchronized (stateLock) {
+            return state.isRunning();
+        }
     }
 
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/90785542/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index efdbeeb..9dc5640 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -82,40 +82,48 @@ public class StreamThread extends Thread {
      *
      * <pre>
      *                +-------------+
-     *                | Created     |
-     *                +-----+-------+
-     *                      |
-     *                      v
-     *                +-----+-------+
+     *          +<--- | Created     |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
      *          +<--- | Running     | <----+
      *          |     +-----+-------+      |
      *          |           |              |
      *          |           v              |
      *          |     +-----+-------+      |
-     *          +<--- | Partitions  |      |
-     *          |     | Revoked     |      |
+     *          +<--- | Partitions  | <-+  |
+     *          |     | Revoked     | --+  |
      *          |     +-----+-------+      |
      *          |           |              |
      *          |           v              |
      *          |     +-----+-------+      |
-     *          |     | Assigning   |      |
+     *          +<--- | Assigning   |      |
      *          |     | Partitions  | ---->+
      *          |     +-----+-------+
      *          |           |
      *          |           v
      *          |     +-----+-------+
      *          +---> | Pending     |
-     *                | Shutdown    |
-     *                +-----+-------+
-     *                      |
-     *                      v
-     *                +-----+-------+
-     *                | Dead        |
+     *          |     | Shutdown    |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
+     *          +---> | Dead        |
      *                +-------------+
      * </pre>
+     *
+     * Note the following:
+     * - Any state can go to PENDING_SHUTDOWN. That is because streams can be closed at any time.
+     * - Any state can go to DEAD. That is because exceptions can happen in any other state,
+     *   leading to the stream thread terminating.
+     * - A stream thread can stay in PARTITIONS_REVOKED indefinitely, in the corner case when
+     *   the coordinator repeatedly fails in between revoking partitions and assigning new partitions.
+     *
      */
-    public enum State {
-        CREATED(1), RUNNING(1, 2, 4), PARTITIONS_REVOKED(3, 4), ASSIGNING_PARTITIONS(1, 4), PENDING_SHUTDOWN(5), DEAD;
+    public enum State implements ThreadStateTransitionValidator {
+        CREATED(1, 4, 5), RUNNING(2, 4, 5), PARTITIONS_REVOKED(2, 3, 4, 5), ASSIGNING_PARTITIONS(1, 4, 5), PENDING_SHUTDOWN(5), DEAD;
 
         private final Set<Integer> validTransitions = new HashSet<>();
 
@@ -127,8 +135,10 @@ public class StreamThread extends Thread {
             return !equals(PENDING_SHUTDOWN) && !equals(CREATED) && !equals(DEAD);
         }
 
-        public boolean isValidTransition(final State newState) {
-            return validTransitions.contains(newState.ordinal());
+        @Override
+        public boolean isValidTransition(final ThreadStateTransitionValidator newState) {
+            State tmpState = (State) newState;
+            return validTransitions.contains(tmpState.ordinal());
         }
     }
 
@@ -143,7 +153,7 @@ public class StreamThread extends Thread {
          * @param newState     current state
          * @param oldState     previous state
          */
-        void onChange(final StreamThread thread, final State newState, final State oldState);
+        void onChange(final Thread thread, final ThreadStateTransitionValidator newState, final ThreadStateTransitionValidator oldState);
     }
 
     private class RebalanceListener implements ConsumerRebalanceListener {
@@ -175,7 +185,7 @@ public class StreamThread extends Thread {
             final long start = time.milliseconds();
             try {
                 storeChangelogReader = new StoreChangelogReader(getName(), restoreConsumer, time, requestTimeOut);
-                setStateWhenNotInPendingShutdown(State.ASSIGNING_PARTITIONS);
+                setState(State.ASSIGNING_PARTITIONS);
                 // do this first as we may have suspended standby tasks that
                 // will become active or vice versa
                 closeNonAssignedSuspendedStandbyTasks();
@@ -185,7 +195,7 @@ public class StreamThread extends Thread {
                 addStandbyTasks(start);
                 streamsMetadataState.onChange(partitionAssignor.getPartitionsByHostState(), partitionAssignor.clusterMetadata());
                 lastCleanMs = time.milliseconds(); // start the cleaning cycle
-                setStateWhenNotInPendingShutdown(State.RUNNING);
+                setState(State.RUNNING);
             } catch (final Throwable t) {
                 rebalanceException = t;
                 throw t;
@@ -214,7 +224,7 @@ public class StreamThread extends Thread {
 
             final long start = time.milliseconds();
             try {
-                setStateWhenNotInPendingShutdown(State.PARTITIONS_REVOKED);
+                setState(State.PARTITIONS_REVOKED);
                 lastCleanMs = Long.MAX_VALUE; // stop the cleaning cycle until partitions are assigned
                 // suspend active tasks
                 suspendTasksAndState();
@@ -378,6 +388,8 @@ public class StreamThread extends Thread {
 
 
     private volatile State state = State.CREATED;
+    private final Object stateLock = new Object();
+
     private StreamThread.StateListener stateListener = null;
     final PartitionGrouper partitionGrouper;
     private final StreamsMetadataState streamsMetadataState;
@@ -500,8 +512,6 @@ public class StreamThread extends Thread {
         lastCleanMs = Long.MAX_VALUE; // the cleaning cycle won't start until partition assignment
         lastCommitMs = timerStartedMs;
         rebalanceListener = new RebalanceListener(time, config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG));
-        setState(State.RUNNING);
-
     }
 
     /**
@@ -513,7 +523,7 @@ public class StreamThread extends Thread {
     @Override
     public void run() {
         log.info("{} Starting", logPrefix);
-
+        setState(State.RUNNING);
         boolean cleanRun = false;
         try {
             runLoop();
@@ -599,7 +609,6 @@ public class StreamThread extends Thread {
                 addToResetList(partition, seekToEnd, "{} Setting topic '{}' to consume from {} offset", "latest", loggedTopics);
             } else {
                 if (originalReset == null || (!originalReset.equals("earliest") && !originalReset.equals("latest"))) {
-                    setState(State.PENDING_SHUTDOWN);
                     final String errorMessage = "No valid committed offset found for input topic %s (partition %s) and no valid reset policy configured." +
                         " You need to set configuration parameter \"auto.offset.reset\" or specify a topic specific reset " +
                         "policy via KStreamBuilder#stream(StreamsConfig.AutoOffsetReset offsetReset, ...) or KStreamBuilder#table(StreamsConfig.AutoOffsetReset offsetReset, ...)";
@@ -928,6 +937,8 @@ public class StreamThread extends Thread {
 
     /**
      * Shutdown this stream thread.
+     * Note that there is nothing to prevent this function from being called multiple times
+     * (e.g., in testing); hence, the state is set only the first time.
      */
     public synchronized void close() {
         log.info("{} Informed thread to shut down", logPrefix);
@@ -935,11 +946,15 @@ public class StreamThread extends Thread {
     }
 
     public synchronized boolean isInitialized() {
-        return state == State.RUNNING;
+        synchronized (stateLock) {
+            return state == State.RUNNING;
+        }
     }
 
     public synchronized boolean stillRunning() {
-        return state.isRunning();
+        synchronized (stateLock) {
+            return state.isRunning();
+        }
     }
 
     public Map<TaskId, StreamTask> tasks() {
@@ -994,28 +1009,41 @@ public class StreamThread extends Thread {
     /**
      * @return The state this instance is in
      */
-    public synchronized State state() {
-        return state;
-    }
-
-    private synchronized void setStateWhenNotInPendingShutdown(final State newState) {
-        if (state == State.PENDING_SHUTDOWN) {
-            return;
+    public State state() {
+        synchronized (stateLock) {
+            return state;
         }
-        setState(newState);
     }
 
-    private synchronized void setState(final State newState) {
-        final State oldState = state;
-        if (!state.isValidTransition(newState)) {
-            log.warn("{} Unexpected state transition from {} to {}.", logPrefix, oldState, newState);
-        } else {
-            log.info("{} State transition from {} to {}.", logPrefix, oldState, newState);
-        }
 
-        state = newState;
-        if (stateListener != null) {
-            stateListener.onChange(this, state, oldState);
+    /**
+     * Sets the state
+     * @param newState New state
+     */
+    void setState(final State newState) {
+        synchronized (stateLock) {
+            final State oldState = state;
+
+            // there are cases when we shouldn't check if a transition is valid, e.g.,
+            // when, for testing, a thread is closed multiple times. We could either
+            // check here and immediately return for those cases, or add them to the transition
+            // diagram (but then the diagram would be confusing and have transitions like
+            // PENDING_SHUTDOWN->PENDING_SHUTDOWN).
+            if (newState != State.DEAD && (state == State.PENDING_SHUTDOWN || state == State.DEAD)) {
+                return;
+            }
+
+            if (!state.isValidTransition(newState)) {
+                log.warn("{} Unexpected state transition from {} to {}.", logPrefix, oldState, newState);
+                throw new StreamsException(logPrefix + " Unexpected state transition from " + oldState + " to " + newState);
+            } else {
+                log.info("{} State transition from {} to {}.", logPrefix, oldState, newState);
+            }
+
+            state = newState;
+            if (stateListener != null) {
+                stateListener.onChange(this, state, oldState);
+            }
         }
     }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/90785542/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java
new file mode 100644
index 0000000..4197c71
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+/**
+ * Basic interface for keeping track of the state of a thread.
+ */
+public interface ThreadStateTransitionValidator {
+    boolean isValidTransition(final ThreadStateTransitionValidator newState);
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/90785542/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
index 4ebc42b..e526da4 100644
--- a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
@@ -28,8 +28,11 @@ import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.kstream.ForeachAction;
 import org.apache.kafka.streams.kstream.KStreamBuilder;
 import org.apache.kafka.streams.processor.StreamPartitioner;
+import org.apache.kafka.streams.processor.internals.GlobalStreamThread;
+import org.apache.kafka.streams.processor.internals.StreamThread;
 import org.apache.kafka.test.IntegrationTest;
 import org.apache.kafka.test.MockMetricsReporter;
+import org.apache.kafka.test.TestCondition;
 import org.apache.kafka.test.TestUtils;
 import org.junit.Assert;
 import org.junit.Before;
@@ -53,6 +56,7 @@ import static org.junit.Assert.assertTrue;
 public class KafkaStreamsTest {
 
     private static final int NUM_BROKERS = 1;
+    private static final int NUM_THREADS = 2;
     // We need this to avoid the KafkaConsumer hanging on poll (this may occur if the test doesn't complete
     // quick enough)
     @ClassRule
@@ -68,17 +72,14 @@ public class KafkaStreamsTest {
         props.setProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
         props.setProperty(StreamsConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName());
         props.setProperty(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
+        props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, NUM_THREADS);
         streams = new KafkaStreams(builder, props);
     }
 
     @Test
-    public void testInitializesAndDestroysMetricsReporters() throws Exception {
-        final int oldInitCount = MockMetricsReporter.INIT_COUNT.get();
+    public void testStateChanges() throws Exception {
         final KStreamBuilder builder = new KStreamBuilder();
         final KafkaStreams streams = new KafkaStreams(builder, props);
-        final int newInitCount = MockMetricsReporter.INIT_COUNT.get();
-        final int initDiff = newInitCount - oldInitCount;
-        assertTrue("some reporters should be initialized by calling on construction", initDiff > 0);
 
         StateListenerStub stateListener = new StateListenerStub();
         streams.setStateListener(stateListener);
@@ -86,17 +87,133 @@ public class KafkaStreamsTest {
         Assert.assertEquals(stateListener.numChanges, 0);
 
         streams.start();
-        Assert.assertEquals(streams.state(), KafkaStreams.State.RUNNING);
-        Assert.assertEquals(stateListener.numChanges, 1);
-        Assert.assertEquals(stateListener.oldState, KafkaStreams.State.CREATED);
-        Assert.assertEquals(stateListener.newState, KafkaStreams.State.RUNNING);
-        Assert.assertEquals(stateListener.mapStates.get(KafkaStreams.State.RUNNING).longValue(), 1L);
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.RUNNING;
+            }
+        }, 10 * 1000, "Streams never started.");
+        streams.close();
+        Assert.assertEquals(streams.state(), KafkaStreams.State.NOT_RUNNING);
+    }
+
+    @Test
+    public void testStateCloseAfterCreate() throws Exception {
+        final KStreamBuilder builder = new KStreamBuilder();
+        final KafkaStreams streams = new KafkaStreams(builder, props);
+
+        StateListenerStub stateListener = new StateListenerStub();
+        streams.setStateListener(stateListener);
+        streams.close();
+        Assert.assertEquals(streams.state(), KafkaStreams.State.NOT_RUNNING);
+    }
+
+    @Test
+    public void testStateThreadClose() throws Exception {
+        final int numThreads = 2;
+        final KStreamBuilder builder = new KStreamBuilder();
+        // make sure we have the global state thread running too
+        builder.globalTable("anyTopic");
+        props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numThreads);
+        final KafkaStreams streams = new KafkaStreams(builder, props);
+
+
+
+        final java.lang.reflect.Field threadsField = streams.getClass().getDeclaredField("threads");
+        threadsField.setAccessible(true);
+        final StreamThread[] threads = (StreamThread[]) threadsField.get(streams);
+
+        assertEquals(numThreads, threads.length);
+        assertEquals(streams.state(), KafkaStreams.State.CREATED);
+
+        streams.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.RUNNING;
+            }
+        }, 10 * 1000, "Streams never started.");
+
+        for (int i = 0; i < numThreads; i++) {
+            final StreamThread tmpThread = threads[i];
+            tmpThread.close();
+            TestUtils.waitForCondition(new TestCondition() {
+                @Override
+                public boolean conditionMet() {
+                    return tmpThread.state() == StreamThread.State.DEAD;
+                }
+            }, 10 * 1000, "Thread never stopped.");
+            threads[i].join();
+        }
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.ERROR;
+            }
+        }, 10 * 1000, "Streams never stopped.");
+        streams.close();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.NOT_RUNNING;
+            }
+        }, 10 * 1000, "Streams never stopped.");
+
+        final java.lang.reflect.Field globalThreadField = streams.getClass().getDeclaredField("globalStreamThread");
+        globalThreadField.setAccessible(true);
+        GlobalStreamThread globalStreamThread = (GlobalStreamThread) globalThreadField.get(streams);
+        assertEquals(globalStreamThread, null);
+    }
+
+    @Test
+    public void testStateGlobalThreadClose() throws Exception {
+        final int numThreads = 2;
+        final KStreamBuilder builder = new KStreamBuilder();
+        // make sure we have the global state thread running too
+        builder.globalTable("anyTopic");
+        props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numThreads);
+        final KafkaStreams streams = new KafkaStreams(builder, props);
+
+
+        streams.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.RUNNING;
+            }
+        }, 10 * 1000, "Streams never started.");
+        final java.lang.reflect.Field globalThreadField = streams.getClass().getDeclaredField("globalStreamThread");
+        globalThreadField.setAccessible(true);
+        final GlobalStreamThread globalStreamThread = (GlobalStreamThread) globalThreadField.get(streams);
+        globalStreamThread.close();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return globalStreamThread.state() == GlobalStreamThread.State.DEAD;
+            }
+        }, 10 * 1000, "Thread never stopped.");
+        globalStreamThread.join();
+        assertEquals(streams.state(), KafkaStreams.State.ERROR);
+
+        streams.close();
+        assertEquals(streams.state(), KafkaStreams.State.NOT_RUNNING);
+
+    }
+
+
+    @Test
+    public void testInitializesAndDestroysMetricsReporters() throws Exception {
+        final int oldInitCount = MockMetricsReporter.INIT_COUNT.get();
+        final KStreamBuilder builder = new KStreamBuilder();
+        final KafkaStreams streams = new KafkaStreams(builder, props);
+        final int newInitCount = MockMetricsReporter.INIT_COUNT.get();
+        final int initDiff = newInitCount - oldInitCount;
+        assertTrue("some reporters should be initialized by calling on construction", initDiff > 0);
+
+        streams.start();
         final int oldCloseCount = MockMetricsReporter.CLOSE_COUNT.get();
         streams.close();
         assertEquals(oldCloseCount + initDiff, MockMetricsReporter.CLOSE_COUNT.get());
-        Assert.assertEquals(streams.state(), KafkaStreams.State.NOT_RUNNING);
-        Assert.assertEquals(stateListener.mapStates.get(KafkaStreams.State.RUNNING).longValue(), 1L);
-        Assert.assertEquals(stateListener.mapStates.get(KafkaStreams.State.NOT_RUNNING).longValue(), 1L);
     }
 
     @Test
@@ -279,6 +396,12 @@ public class KafkaStreamsTest {
         final KafkaStreams streams = new KafkaStreams(builder, props);
 
         streams.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.RUNNING;
+            }
+        }, 10 * 1000, "Streams never started.");
         try {
             streams.cleanUp();
         } catch (final IllegalStateException e) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/90785542/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
index 30582ed..07be6e8 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
@@ -27,6 +27,7 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.kstream.KStreamBuilder;
 import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.test.TestCondition;
 import org.apache.kafka.test.TestUtils;
 import org.junit.Before;
 import org.junit.Test;
@@ -35,11 +36,13 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 
+import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.RUNNING;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.IsInstanceOf.instanceOf;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
 
 public class GlobalStreamThreadTest {
@@ -107,7 +110,7 @@ public class GlobalStreamThreadTest {
 
 
     @Test
-    public void shouldBeRunningAfterSuccesulStart() throws Exception {
+    public void shouldBeRunningAfterSuccessfulStart() throws Exception {
         initializeConsumer();
         globalStreamThread.start();
         assertTrue(globalStreamThread.stillRunning());
@@ -119,6 +122,7 @@ public class GlobalStreamThreadTest {
         globalStreamThread.start();
         globalStreamThread.close();
         globalStreamThread.join();
+        assertEquals(GlobalStreamThread.State.DEAD, globalStreamThread.state());
     }
 
     @Test
@@ -132,6 +136,48 @@ public class GlobalStreamThreadTest {
         assertFalse(globalStore.isOpen());
     }
 
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldTransitionToDeadOnClose() throws InterruptedException {
+
+        initializeConsumer();
+        globalStreamThread.start();
+        globalStreamThread.close();
+        globalStreamThread.join();
+
+        assertEquals(GlobalStreamThread.State.DEAD, globalStreamThread.state());
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldStayDeadAfterTwoCloses() throws InterruptedException {
+
+        initializeConsumer();
+        globalStreamThread.start();
+        globalStreamThread.close();
+        globalStreamThread.join();
+        globalStreamThread.close();
+
+        assertEquals(GlobalStreamThread.State.DEAD, globalStreamThread.state());
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldTransitiontoRunningOnStart() throws InterruptedException {
+
+        initializeConsumer();
+        globalStreamThread.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return globalStreamThread.state() == RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
+        globalStreamThread.close();
+    }
+
+
+
     private void initializeConsumer() {
         mockConsumer.updatePartitions("foo", Collections.singletonList(new PartitionInfo("foo",
                                                                                          0,

http://git-wip-us.apache.org/repos/asf/kafka/blob/90785542/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index a0882cf..2056954 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -52,6 +52,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import java.io.File;
+import java.io.IOException;
 import java.lang.reflect.Field;
 import java.nio.ByteBuffer;
 import java.nio.file.Files;
@@ -265,10 +266,10 @@ public class StreamThreadTest {
 
         final StateListenerStub stateListener = new StateListenerStub();
         thread.setStateListener(stateListener);
-        assertEquals(thread.state(), StreamThread.State.RUNNING);
+        assertEquals(thread.state(), StreamThread.State.CREATED);
 
         final ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
-
+        thread.setState(StreamThread.State.RUNNING);
         assertTrue(thread.tasks().isEmpty());
 
         List<TopicPartition> revokedPartitions;
@@ -281,9 +282,6 @@ public class StreamThreadTest {
         rebalanceListener.onPartitionsRevoked(revokedPartitions);
 
         assertEquals(thread.state(), StreamThread.State.PARTITIONS_REVOKED);
-        Assert.assertEquals(stateListener.numChanges, 1);
-        Assert.assertEquals(stateListener.oldState, StreamThread.State.RUNNING);
-        Assert.assertEquals(stateListener.newState, StreamThread.State.PARTITIONS_REVOKED);
 
         // assign single partition
         assignedPartitions = Collections.singletonList(t1p1);
@@ -292,9 +290,8 @@ public class StreamThreadTest {
         rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
         assertEquals(thread.state(), StreamThread.State.RUNNING);
-        Assert.assertEquals(stateListener.numChanges, 3);
+        Assert.assertEquals(stateListener.numChanges, 4);
         Assert.assertEquals(stateListener.oldState, StreamThread.State.ASSIGNING_PARTITIONS);
-        Assert.assertEquals(stateListener.newState, StreamThread.State.RUNNING);
         assertTrue(thread.tasks().containsKey(task1));
         assertEquals(expectedGroup1, thread.tasks().get(task1).partitions());
         assertEquals(1, thread.tasks().size());
@@ -343,8 +340,7 @@ public class StreamThreadTest {
         assertTrue(thread.tasks().isEmpty());
 
         thread.close();
-        assertTrue((thread.state() == StreamThread.State.PENDING_SHUTDOWN) ||
-            (thread.state() == StreamThread.State.CREATED));
+        assertTrue(thread.state() == StreamThread.State.PENDING_SHUTDOWN);
     }
 
     @SuppressWarnings("unchecked")
@@ -367,10 +363,10 @@ public class StreamThreadTest {
 
         final StateListenerStub stateListener = new StateListenerStub();
         thread.setStateListener(stateListener);
-        assertEquals(thread.state(), StreamThread.State.RUNNING);
+        assertEquals(thread.state(), StreamThread.State.CREATED);
 
         final ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
-
+        thread.setState(StreamThread.State.RUNNING);
         assertTrue(thread.tasks().isEmpty());
 
         List<TopicPartition> revokedPartitions;
@@ -383,9 +379,6 @@ public class StreamThreadTest {
         rebalanceListener.onPartitionsRevoked(revokedPartitions);
 
         assertEquals(thread.state(), StreamThread.State.PARTITIONS_REVOKED);
-        Assert.assertEquals(stateListener.numChanges, 1);
-        Assert.assertEquals(stateListener.oldState, StreamThread.State.RUNNING);
-        Assert.assertEquals(stateListener.newState, StreamThread.State.PARTITIONS_REVOKED);
 
         // assign four new partitions of second subtopology
         assignedPartitions = Arrays.asList(t2p1, t2p2, t3p1, t3p2);
@@ -441,8 +434,44 @@ public class StreamThreadTest {
         assertTrue(thread.tasks().isEmpty());
 
         thread.close();
-        assertTrue((thread.state() == StreamThread.State.PENDING_SHUTDOWN) ||
-            (thread.state() == StreamThread.State.CREATED));
+        assertEquals(thread.state(), StreamThread.State.PENDING_SHUTDOWN);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testStateChangeStartClose() throws InterruptedException {
+
+        final StreamThread thread = new StreamThread(
+            builder,
+            config,
+            clientSupplier,
+            applicationId,
+            clientId,
+            processId,
+            metrics,
+            Time.SYSTEM,
+            new StreamsMetadataState(builder, StreamsMetadataState.UNKNOWN_HOST),
+            0);
+
+        final StateListenerStub stateListener = new StateListenerStub();
+        thread.setStateListener(stateListener);
+        thread.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread.state() == StreamThread.State.RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
+        thread.close();
+        assertEquals(thread.state(), StreamThread.State.PENDING_SHUTDOWN);
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread.state() == StreamThread.State.DEAD;
+            }
+        }, 10 * 1000, "Thread never shut down.");
+        thread.close();
+        assertEquals(thread.state(), StreamThread.State.DEAD);
     }
 
     private final static String TOPIC = "topic";
@@ -451,7 +480,7 @@ public class StreamThreadTest {
 
     @SuppressWarnings("unchecked")
     @Test
-    public void testHandingOverTaskFromOneToAnotherThread() throws Exception {
+    public void testHandingOverTaskFromOneToAnotherThread() throws InterruptedException {
         builder.addStateStore(
             Stores
                 .create("store")
@@ -462,7 +491,7 @@ public class StreamThreadTest {
         );
         builder.addSource("source", TOPIC);
 
-        clientSupplier.consumer.assign(Arrays.asList(new TopicPartition(TOPIC, 0), new TopicPartition(TOPIC, 1)));
+        //clientSupplier.consumer.assign(Arrays.asList(new TopicPartition(TOPIC, 0), new TopicPartition(TOPIC, 1)));
 
         final StreamThread thread1 = new StreamThread(
             builder,
@@ -496,6 +525,23 @@ public class StreamThreadTest {
         thread1.setPartitionAssignor(new MockStreamsPartitionAssignor(thread1Assignment));
         thread2.setPartitionAssignor(new MockStreamsPartitionAssignor(thread2Assignment));
 
+
+        thread1.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread1.state() == StreamThread.State.RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
+
+        thread2.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread2.state() == StreamThread.State.RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
+
         // revoke (to get threads in correct state)
         thread1.rebalanceListener.onPartitionsRevoked(EMPTY_SET);
         thread2.rebalanceListener.onPartitionsRevoked(EMPTY_SET);
@@ -569,7 +615,7 @@ public class StreamThreadTest {
     }
 
     @Test
-    public void testMetrics() throws Exception {
+    public void testMetrics() {
         final StreamThread thread = new StreamThread(
             builder,
             config,
@@ -610,8 +656,9 @@ public class StreamThreadTest {
         assertNotNull(metrics.metrics().get(metrics.metricName("skipped-records-rate", defaultGroupName, "The average per-second number of skipped records.", defaultTags)));
     }
 
+
     @Test
-    public void testMaybeClean() throws Exception {
+    public void testMaybeClean() throws IOException, InterruptedException, java.nio.file.NoSuchFileException {
         final File baseDir = Files.createTempDirectory("test").toFile();
         try {
             final long cleanupDelay = 1000L;
@@ -629,9 +676,7 @@ public class StreamThreadTest {
             stateDir2.mkdir();
             stateDir3.mkdir();
             extraDir.mkdir();
-
             builder.addSource("source1", "topic1");
-
             final StreamThread thread = new StreamThread(
                 builder,
                 config,
@@ -643,12 +688,10 @@ public class StreamThreadTest {
                 mockTime,
                 new StreamsMetadataState(builder, StreamsMetadataState.UNKNOWN_HOST),
                 0) {
-
                 @Override
                 public void maybeClean(final long now) {
                     super.maybeClean(now);
                 }
-
                 @Override
                 protected StreamTask createStreamTask(final TaskId id, final Collection<TopicPartition> partitionsForTask) {
                     final ProcessorTopology topology = builder.build(id.topicGroupId);
@@ -680,9 +723,7 @@ public class StreamThreadTest {
             List<TopicPartition> assignedPartitions;
             Map<TaskId, StreamTask> prevTasks;
 
-            //
             // Assign t1p1 and t1p2. This should create task1 & task2
-            //
             final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
             activeTasks.put(task1, Collections.singleton(t1p1));
             activeTasks.put(task2, Collections.singleton(t1p2));
@@ -693,6 +734,7 @@ public class StreamThreadTest {
             prevTasks = new HashMap<>(thread.tasks());
 
             final ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
+            thread.setState(StreamThread.State.RUNNING);
             rebalanceListener.onPartitionsRevoked(revokedPartitions);
             rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
@@ -718,11 +760,8 @@ public class StreamThreadTest {
             assertFalse(stateDir3.exists());
             assertTrue(extraDir.exists());
 
-            //
             // Revoke t1p1 and t1p2. This should remove task1 & task2
-            //
             activeTasks.clear();
-
             revokedPartitions = assignedPartitions;
             assignedPartitions = Collections.emptyList();
             prevTasks = new HashMap<>(thread.tasks());
@@ -756,12 +795,14 @@ public class StreamThreadTest {
             assertFalse(stateDir3.exists());
             assertTrue(extraDir.exists());
         } finally {
+            // note that this call can throw a java.nio.file.NoSuchFileException
+            // since some of the subdirs might be deleted already during cleanup
             Utils.delete(baseDir);
         }
     }
 
     @Test
-    public void testMaybeCommit() throws Exception {
+    public void testMaybeCommit() throws IOException, InterruptedException {
         final File baseDir = Files.createTempDirectory("test").toFile();
         try {
             final long commitInterval = 1000L;
@@ -820,6 +861,8 @@ public class StreamThreadTest {
             revokedPartitions = Collections.emptyList();
             assignedPartitions = Arrays.asList(t1p1, t1p2);
 
+            thread.setState(StreamThread.State.RUNNING);
+            thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
             rebalanceListener.onPartitionsRevoked(revokedPartitions);
             rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
@@ -860,7 +903,7 @@ public class StreamThreadTest {
     }
 
     @Test
-    public void shouldInjectSharedProducerForAllTasksUsingClientSupplierOnCreateIfEosDisabled() {
+    public void shouldInjectSharedProducerForAllTasksUsingClientSupplierOnCreateIfEosDisabled() throws InterruptedException {
         builder.addSource("source1", "someTopic");
 
         final StreamThread thread = new StreamThread(
@@ -880,6 +923,8 @@ public class StreamThreadTest {
         assignment.put(new TaskId(0, 1), Collections.singleton(new TopicPartition("someTopic", 1)));
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(assignment));
 
+        thread.setState(StreamThread.State.RUNNING);
+        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Collections.singleton(new TopicPartition("someTopic", 0)));
 
         assertEquals(1, clientSupplier.producers.size());
@@ -893,7 +938,7 @@ public class StreamThreadTest {
     }
 
     @Test
-    public void shouldInjectProducerPerTaskUsingClientSupplierOnCreateIfEosEnable() {
+    public void shouldInjectProducerPerTaskUsingClientSupplierOnCreateIfEosEnable() throws InterruptedException {
         builder.addSource("source1", "someTopic");
 
         final MockClientSupplier clientSupplier = new MockClientSupplier(applicationId);
@@ -917,6 +962,8 @@ public class StreamThreadTest {
 
         final Set<TopicPartition> assignedPartitions = new HashSet<>();
         Collections.addAll(assignedPartitions, new TopicPartition("someTopic", 0), new TopicPartition("someTopic", 2));
+        thread.setState(StreamThread.State.RUNNING);
+        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
         assertNull(thread.threadProducer);
@@ -930,7 +977,7 @@ public class StreamThreadTest {
     }
 
     @Test
-    public void shouldCloseAllTaskProducersOnCloseIfEosEnabled() {
+    public void shouldCloseAllTaskProducersOnCloseIfEosEnabled() throws InterruptedException {
         builder.addSource("source1", "someTopic");
 
         final MockClientSupplier clientSupplier = new MockClientSupplier(applicationId);
@@ -951,6 +998,8 @@ public class StreamThreadTest {
         assignment.put(new TaskId(0, 1), Collections.singleton(new TopicPartition("someTopic", 1)));
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(assignment));
 
+        thread.setState(StreamThread.State.RUNNING);
+        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Collections.singleton(new TopicPartition("someTopic", 0)));
 
         thread.close();
@@ -962,7 +1011,7 @@ public class StreamThreadTest {
     }
 
     @Test
-    public void shouldCloseThreadProducerOnCloseIfEosDisabled() {
+    public void shouldCloseThreadProducerOnCloseIfEosDisabled() throws InterruptedException {
         builder.addSource("source1", "someTopic");
 
         final StreamThread thread = new StreamThread(
@@ -982,6 +1031,8 @@ public class StreamThreadTest {
         assignment.put(new TaskId(0, 1), Collections.singleton(new TopicPartition("someTopic", 1)));
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(assignment));
 
+        thread.setState(StreamThread.State.RUNNING);
+        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Collections.singleton(new TopicPartition("someTopic", 0)));
 
         thread.close();
@@ -991,7 +1042,7 @@ public class StreamThreadTest {
     }
 
     @Test
-    public void shouldNotNullPointerWhenStandbyTasksAssignedAndNoStateStoresForTopology() throws Exception {
+    public void shouldNotNullPointerWhenStandbyTasksAssignedAndNoStateStoresForTopology() throws InterruptedException {
         builder.addSource("name", "topic").addSink("out", "output");
 
         final StreamThread thread = new StreamThread(
@@ -1013,6 +1064,7 @@ public class StreamThreadTest {
             }
         });
 
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
     }
@@ -1022,28 +1074,28 @@ public class StreamThreadTest {
         builder.addSource("name", "topic").addSink("out", "output");
 
         final TestStreamTask testStreamTask = new TestStreamTask(
-            new TaskId(0, 0),
-            applicationId,
-            Utils.mkSet(new TopicPartition("topic", 0)),
-            builder.build(0),
-            clientSupplier.consumer,
-            clientSupplier.getProducer(new HashMap<String, Object>()),
-            clientSupplier.restoreConsumer,
-            config,
-            new MockStreamsMetrics(new Metrics()),
-            new StateDirectory(applicationId, config.getString(StreamsConfig.STATE_DIR_CONFIG), mockTime));
+                new TaskId(0, 0),
+                applicationId,
+                Utils.mkSet(new TopicPartition("topic", 0)),
+                builder.build(0),
+                clientSupplier.consumer,
+                clientSupplier.getProducer(new HashMap<String, Object>()),
+                clientSupplier.restoreConsumer,
+                config,
+                new MockStreamsMetrics(new Metrics()),
+                new StateDirectory(applicationId, config.getString(StreamsConfig.STATE_DIR_CONFIG), mockTime));
 
         final StreamThread thread = new StreamThread(
-            builder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(builder, StreamsMetadataState.UNKNOWN_HOST),
-            0) {
+                builder,
+                config,
+                clientSupplier,
+                applicationId,
+                clientId,
+                processId,
+                metrics,
+                mockTime,
+                new StreamsMetadataState(builder, StreamsMetadataState.UNKNOWN_HOST),
+                0) {
 
             @Override
             protected StreamTask createStreamTask(final TaskId id, final Collection<TopicPartition> partitionsForTask) {
@@ -1064,7 +1116,8 @@ public class StreamThreadTest {
                 };
             }
         });
-
+        thread.setState(StreamThread.State.RUNNING);
+        thread.setState(StreamThread.State.PARTITIONS_REVOKED);
         thread.rebalanceListener.onPartitionsAssigned(activeTasks);
         thread.rebalanceListener.onPartitionsRevoked(activeTasks);
 
@@ -1077,39 +1130,39 @@ public class StreamThreadTest {
 
         assertTrue(testStreamTask.closed);
     }
-    @Test
 
-    public void shouldInitializeRestoreConsumerWithOffsetsFromStandbyTasks() throws Exception {
+    @Test
+    public void shouldInitializeRestoreConsumerWithOffsetsFromStandbyTasks()  {
         final KStreamBuilder builder = new KStreamBuilder();
         builder.setApplicationId(applicationId);
         builder.stream("t1").groupByKey().count("count-one");
         builder.stream("t2").groupByKey().count("count-two");
 
         final StreamThread thread = new StreamThread(
-            builder,
-            config,
-            clientSupplier,
-            applicationId,
-            clientId,
-            processId,
-            metrics,
-            mockTime,
-            new StreamsMetadataState(builder, StreamsMetadataState.UNKNOWN_HOST),
-            0);
+                builder,
+                config,
+                clientSupplier,
+                applicationId,
+                clientId,
+                processId,
+                metrics,
+                mockTime,
+                new StreamsMetadataState(builder, StreamsMetadataState.UNKNOWN_HOST),
+                0);
 
         final MockConsumer<byte[], byte[]> restoreConsumer = clientSupplier.restoreConsumer;
         restoreConsumer.updatePartitions("stream-thread-test-count-one-changelog",
-                                         Collections.singletonList(new PartitionInfo("stream-thread-test-count-one-changelog",
-                                                                                     0,
-                                                                                     null,
-                                                                                     new Node[0],
-                                                                                     new Node[0])));
+                Collections.singletonList(new PartitionInfo("stream-thread-test-count-one-changelog",
+                        0,
+                        null,
+                        new Node[0],
+                        new Node[0])));
         restoreConsumer.updatePartitions("stream-thread-test-count-two-changelog",
-                                         Collections.singletonList(new PartitionInfo("stream-thread-test-count-two-changelog",
-                                                                                     0,
-                                                                                     null,
-                                                                                     new Node[0],
-                                                                                     new Node[0])));
+                Collections.singletonList(new PartitionInfo("stream-thread-test-count-two-changelog",
+                        0,
+                        null,
+                        new Node[0],
+                        new Node[0])));
 
         final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
         final TopicPartition t1 = new TopicPartition("t1", 0);
@@ -1122,6 +1175,7 @@ public class StreamThreadTest {
             }
         });
 
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
 
@@ -1133,9 +1187,12 @@ public class StreamThreadTest {
         thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
 
         assertThat(restoreConsumer.assignment(), equalTo(Utils.mkSet(new TopicPartition("stream-thread-test-count-one-changelog", 0),
-                                                                     new TopicPartition("stream-thread-test-count-two-changelog", 0))));
+                new TopicPartition("stream-thread-test-count-two-changelog", 0))));
     }
 
+
+
+
     @Test
     public void shouldCloseSuspendedTasksThatAreNoLongerAssignedToThisStreamThreadBeforeCreatingNewTasks() throws Exception {
         final KStreamBuilder builder = new KStreamBuilder();
@@ -1195,6 +1252,7 @@ public class StreamThreadTest {
             }
         });
 
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Utils.mkSet(t2));
 
@@ -1269,6 +1327,8 @@ public class StreamThreadTest {
         builder.updateSubscriptions(subscriptionUpdates, null);
 
         // should create task for id 0_0 with a single partition
+        thread.setState(StreamThread.State.RUNNING);
+
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(task00Partitions);
 
@@ -1314,13 +1374,18 @@ public class StreamThreadTest {
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
 
+        thread.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread.state() == StreamThread.State.RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
         thread.rebalanceListener.onPartitionsRevoked(null);
         thread.rebalanceListener.onPartitionsAssigned(task0Assignment);
         assertThat(thread.tasks().size(), equalTo(1));
         final MockProducer producer = clientSupplier.producers.get(0);
 
-        thread.start();
-
         TestUtils.waitForCondition(
             new TestCondition() {
                 @Override
@@ -1402,6 +1467,7 @@ public class StreamThreadTest {
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
 
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(null);
         thread.rebalanceListener.onPartitionsAssigned(task0Assignment);
         assertThat(thread.tasks().size(), equalTo(1));
@@ -1463,10 +1529,10 @@ public class StreamThreadTest {
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
 
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
 
-        thread.start();
         thread.close();
         thread.join();
         assertFalse("task shouldn't have been committed as there was an exception during shutdown", testStreamTask.committed);
@@ -1519,12 +1585,18 @@ public class StreamThreadTest {
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
 
+        thread.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread.state() == StreamThread.State.RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
         // store should have been opened
         assertTrue(stateStore.isOpen());
 
-        thread.start();
         thread.close();
         thread.join();
         assertFalse("task shouldn't have been committed as there was an exception during shutdown", testStreamTask.committed);
@@ -1581,6 +1653,7 @@ public class StreamThreadTest {
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
 
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
         try {
@@ -1640,6 +1713,7 @@ public class StreamThreadTest {
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
 
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
         try {
@@ -1947,11 +2021,12 @@ public class StreamThreadTest {
 
     private static class StateListenerStub implements StreamThread.StateListener {
         int numChanges = 0;
-        StreamThread.State oldState = null;
-        StreamThread.State newState = null;
+        ThreadStateTransitionValidator oldState = null;
+        ThreadStateTransitionValidator newState = null;
 
         @Override
-        public void onChange(final StreamThread thread, final StreamThread.State newState, final StreamThread.State oldState) {
+        public void onChange(final Thread thread, final ThreadStateTransitionValidator newState,
+                             final ThreadStateTransitionValidator oldState) {
             ++numChanges;
             if (this.newState != null) {
                 if (this.newState != oldState) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/90785542/streams/src/test/resources/log4j.properties
----------------------------------------------------------------------
diff --git a/streams/src/test/resources/log4j.properties b/streams/src/test/resources/log4j.properties
index 69f3d1c..91c909b 100644
--- a/streams/src/test/resources/log4j.properties
+++ b/streams/src/test/resources/log4j.properties
@@ -12,10 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-log4j.rootLogger=WARN, stdout
+log4j.rootLogger=INFO, stdout
 
 log4j.appender.stdout=org.apache.log4j.ConsoleAppender
 log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
 log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n
 
-log4j.logger.org.apache.kafka=WARN
\ No newline at end of file
+log4j.logger.org.apache.kafka=INFO
\ No newline at end of file