Posted to commits@kafka.apache.org by gu...@apache.org on 2017/11/28 17:37:32 UTC

[1/3] kafka git commit: KAFKA-6170; KIP-220 Part 2: Break dependency of Assignor on StreamThread

Repository: kafka
Updated Branches:
  refs/heads/trunk 8f6a372ee -> 5df1eee7d
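
Editor's note: this patch (KIP-220 part 2) breaks the assignor's dependency on StreamThread. In the tests below, StreamThread.setThreadMetadataProvider() and the MockStreamsPartitionAssignor subclasses disappear; tests instead hand the active/standby assignment maps directly to the thread's TaskManager. A condensed sketch of the new test idiom, assuming the StreamThreadTest fixtures shown below (thread, task1, t1p1) and the class's existing imports:

    // Drive a rebalance by feeding assignment metadata straight to the
    // TaskManager instead of installing a mock assignor on the thread.
    final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
    activeTasks.put(task1, Collections.singleton(t1p1));

    thread.setState(StreamThread.State.RUNNING);
    thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
    thread.taskManager().setAssignmentMetadata(activeTasks, Collections.<TaskId, Set<TopicPartition>>emptyMap());
    thread.rebalanceListener.onPartitionsAssigned(Collections.singletonList(t1p1));
    thread.runOnce(-1);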


http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index a3d7523..c3a372c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -20,13 +20,13 @@ import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.MockConsumer;
-import org.apache.kafka.clients.consumer.internals.PartitionAssignor;
 import org.apache.kafka.clients.producer.MockProducer;
 import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamsConfig;
@@ -37,11 +37,7 @@ import org.apache.kafka.streams.kstream.internals.InternalStreamsBuilderTest;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.TaskMetadata;
 import org.apache.kafka.streams.processor.ThreadMetadata;
-import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
-import org.apache.kafka.streams.state.HostInfo;
-import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.test.MockClientSupplier;
-import org.apache.kafka.test.MockProcessorSupplier;
 import org.apache.kafka.test.MockStateRestoreListener;
 import org.apache.kafka.test.MockTimestampExtractor;
 import org.apache.kafka.test.TestCondition;
@@ -51,21 +47,16 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import java.lang.reflect.Field;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
 import java.util.Set;
 import java.util.UUID;
-import java.util.regex.Pattern;
 
-import static java.util.Collections.EMPTY_SET;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -100,18 +91,14 @@ public class StreamThreadTest {
         streamsMetadataState = new StreamsMetadataState(internalTopologyBuilder, StreamsMetadataState.UNKNOWN_HOST);
     }
 
-    private final TopicPartition t1p1 = new TopicPartition("topic1", 1);
-    private final TopicPartition t1p2 = new TopicPartition("topic1", 2);
-    private final TopicPartition t2p1 = new TopicPartition("topic2", 1);
-    private final TopicPartition t2p2 = new TopicPartition("topic2", 2);
-    private final TopicPartition t3p1 = new TopicPartition("topic3", 1);
-    private final TopicPartition t3p2 = new TopicPartition("topic3", 2);
+    private final String topic1 = "topic1";
+
+    private final TopicPartition t1p1 = new TopicPartition(topic1, 1);
+    private final TopicPartition t1p2 = new TopicPartition(topic1, 2);
 
     // task0 is unused
     private final TaskId task1 = new TaskId(0, 1);
     private final TaskId task2 = new TaskId(0, 2);
-    private final TaskId task3 = new TaskId(1, 1);
-    private final TaskId task4 = new TaskId(1, 2);
 
     private Properties configProps(final boolean enableEos) {
         return new Properties() {
@@ -132,30 +119,19 @@ public class StreamThreadTest {
     @SuppressWarnings("unchecked")
     @Test
     public void testPartitionAssignmentChangeForSingleGroup() {
-        internalTopologyBuilder.addSource(null, "source1", null, null, null, "topic1");
+        internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1);
 
         final StreamThread thread = getStreamThread();
 
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        thread.setThreadMetadataProvider(new StreamPartitionAssignor() {
-            @Override
-            public Map<TaskId, Set<TopicPartition>> activeTasks() {
-                return activeTasks;
-            }
-        });
-
         final StateListenerStub stateListener = new StateListenerStub();
         thread.setStateListener(stateListener);
         assertEquals(thread.state(), StreamThread.State.CREATED);
 
         final ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
         thread.setState(StreamThread.State.RUNNING);
-        assertTrue(thread.tasks().isEmpty());
 
         List<TopicPartition> revokedPartitions;
         List<TopicPartition> assignedPartitions;
-        Set<TopicPartition> expectedGroup1;
-        Set<TopicPartition> expectedGroup2;
 
         // revoke nothing
         revokedPartitions = Collections.emptyList();
@@ -165,59 +141,12 @@ public class StreamThreadTest {
 
         // assign single partition
         assignedPartitions = Collections.singletonList(t1p1);
-        expectedGroup1 = new HashSet<>(Collections.singleton(t1p1));
-        activeTasks.put(new TaskId(0, 1), expectedGroup1);
+        thread.taskManager().setAssignmentMetadata(Collections.<TaskId, Set<TopicPartition>>emptyMap(), Collections.<TaskId, Set<TopicPartition>>emptyMap());
         rebalanceListener.onPartitionsAssigned(assignedPartitions);
         thread.runOnce(-1);
         assertEquals(thread.state(), StreamThread.State.RUNNING);
         Assert.assertEquals(4, stateListener.numChanges);
         Assert.assertEquals(StreamThread.State.PARTITIONS_ASSIGNED, stateListener.oldState);
-        assertTrue(thread.tasks().containsKey(task1));
-        assertEquals(expectedGroup1, thread.tasks().get(task1).partitions());
-        assertEquals(1, thread.tasks().size());
-
-        // revoke single partition
-        revokedPartitions = assignedPartitions;
-        activeTasks.clear();
-        rebalanceListener.onPartitionsRevoked(revokedPartitions);
-
-        assertFalse(thread.tasks().containsKey(task1));
-        assertEquals(0, thread.tasks().size());
-
-        // assign different single partition
-        assignedPartitions = Collections.singletonList(t1p2);
-        expectedGroup2 = new HashSet<>(Collections.singleton(t1p2));
-        activeTasks.put(new TaskId(0, 2), expectedGroup2);
-        rebalanceListener.onPartitionsAssigned(assignedPartitions);
-        thread.runOnce(-1);
-        assertTrue(thread.tasks().containsKey(task2));
-        assertEquals(expectedGroup2, thread.tasks().get(task2).partitions());
-        assertEquals(1, thread.tasks().size());
-
-        // revoke different single partition and assign both partitions
-        revokedPartitions = assignedPartitions;
-        activeTasks.clear();
-        rebalanceListener.onPartitionsRevoked(revokedPartitions);
-        assignedPartitions = Arrays.asList(t1p1, t1p2);
-        expectedGroup1 = new HashSet<>(Collections.singleton(t1p1));
-        expectedGroup2 = new HashSet<>(Collections.singleton(t1p2));
-        activeTasks.put(new TaskId(0, 1), expectedGroup1);
-        activeTasks.put(new TaskId(0, 2), expectedGroup2);
-        rebalanceListener.onPartitionsAssigned(assignedPartitions);
-        thread.runOnce(-1);
-        assertTrue(thread.tasks().containsKey(task1));
-        assertTrue(thread.tasks().containsKey(task2));
-        assertEquals(expectedGroup1, thread.tasks().get(task1).partitions());
-        assertEquals(expectedGroup2, thread.tasks().get(task2).partitions());
-        assertEquals(2, thread.tasks().size());
-
-        // revoke all partitions and assign nothing
-        revokedPartitions = assignedPartitions;
-        rebalanceListener.onPartitionsRevoked(revokedPartitions);
-        assignedPartitions = Collections.emptyList();
-        rebalanceListener.onPartitionsAssigned(assignedPartitions);
-        thread.runOnce(-1);
-        assertTrue(thread.tasks().isEmpty());
 
         thread.shutdown();
         assertTrue(thread.state() == StreamThread.State.PENDING_SHUTDOWN);
@@ -225,104 +154,6 @@ public class StreamThreadTest {
 
     @SuppressWarnings("unchecked")
     @Test
-    public void testPartitionAssignmentChangeForMultipleGroups() {
-        internalTopologyBuilder.addSource(null, "source1", null, null, null, "topic1");
-        internalTopologyBuilder.addSource(null, "source2", null, null, null, "topic2");
-        internalTopologyBuilder.addSource(null, "source3", null, null, null, "topic3");
-        internalTopologyBuilder.addProcessor("processor", new MockProcessorSupplier(), "source2", "source3");
-
-        final StreamThread thread = getStreamThread();
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        thread.setThreadMetadataProvider(new StreamPartitionAssignor() {
-            @Override
-            public Map<TaskId, Set<TopicPartition>> activeTasks() {
-                return activeTasks;
-            }
-        });
-
-        final StateListenerStub stateListener = new StateListenerStub();
-        thread.setStateListener(stateListener);
-        assertEquals(thread.state(), StreamThread.State.CREATED);
-
-        final ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
-        thread.setState(StreamThread.State.RUNNING);
-        assertTrue(thread.tasks().isEmpty());
-
-        List<TopicPartition> revokedPartitions;
-        List<TopicPartition> assignedPartitions;
-        Set<TopicPartition> expectedGroup1;
-        Set<TopicPartition> expectedGroup2;
-
-        // revoke nothing
-        revokedPartitions = Collections.emptyList();
-        rebalanceListener.onPartitionsRevoked(revokedPartitions);
-
-        assertEquals(thread.state(), StreamThread.State.PARTITIONS_REVOKED);
-
-        // assign four new partitions of second subtopology
-        assignedPartitions = Arrays.asList(t2p1, t2p2, t3p1, t3p2);
-        expectedGroup1 = new HashSet<>(Arrays.asList(t2p1, t3p1));
-        expectedGroup2 = new HashSet<>(Arrays.asList(t2p2, t3p2));
-        activeTasks.put(new TaskId(1, 1), expectedGroup1);
-        activeTasks.put(new TaskId(1, 2), expectedGroup2);
-        rebalanceListener.onPartitionsAssigned(assignedPartitions);
-        thread.runOnce(-1);
-
-        assertTrue(thread.tasks().containsKey(task3));
-        assertTrue(thread.tasks().containsKey(task4));
-        assertEquals(expectedGroup1, thread.tasks().get(task3).partitions());
-        assertEquals(expectedGroup2, thread.tasks().get(task4).partitions());
-        assertEquals(2, thread.tasks().size());
-
-        // revoke four partitions and assign three partitions of both subtopologies
-        revokedPartitions = assignedPartitions;
-        rebalanceListener.onPartitionsRevoked(revokedPartitions);
-
-        assignedPartitions = Arrays.asList(t1p1, t2p1, t3p1);
-        expectedGroup1 = new HashSet<>(Collections.singleton(t1p1));
-        expectedGroup2 = new HashSet<>(Arrays.asList(t2p1, t3p1));
-        activeTasks.put(new TaskId(0, 1), expectedGroup1);
-        activeTasks.put(new TaskId(1, 1), expectedGroup2);
-        rebalanceListener.onPartitionsAssigned(assignedPartitions);
-        thread.runOnce(-1);
-
-        assertTrue(thread.tasks().containsKey(task1));
-        assertTrue(thread.tasks().containsKey(task3));
-        assertEquals(expectedGroup1, thread.tasks().get(task1).partitions());
-        assertEquals(expectedGroup2, thread.tasks().get(task3).partitions());
-        assertEquals(2, thread.tasks().size());
-
-        // revoke all three partitions and reassign the same three partitions (from different subtopologies)
-        revokedPartitions = assignedPartitions;
-        rebalanceListener.onPartitionsRevoked(revokedPartitions);
-        assignedPartitions = Arrays.asList(t1p1, t2p1, t3p1);
-        expectedGroup1 = new HashSet<>(Collections.singleton(t1p1));
-        expectedGroup2 = new HashSet<>(Arrays.asList(t2p1, t3p1));
-        rebalanceListener.onPartitionsAssigned(assignedPartitions);
-        thread.runOnce(-1);
-
-        assertTrue(thread.tasks().containsKey(task1));
-        assertTrue(thread.tasks().containsKey(task3));
-        assertEquals(expectedGroup1, thread.tasks().get(task1).partitions());
-        assertEquals(expectedGroup2, thread.tasks().get(task3).partitions());
-        assertEquals(2, thread.tasks().size());
-
-        // revoke all partitions and assign nothing
-        revokedPartitions = assignedPartitions;
-        rebalanceListener.onPartitionsRevoked(revokedPartitions);
-        assignedPartitions = Collections.emptyList();
-        rebalanceListener.onPartitionsAssigned(assignedPartitions);
-        thread.runOnce(-1);
-
-        assertTrue(thread.tasks().isEmpty());
-
-        thread.shutdown();
-        assertEquals(thread.state(), StreamThread.State.PENDING_SHUTDOWN);
-    }
-
-    @SuppressWarnings("unchecked")
-    @Test
     public void testStateChangeStartClose() throws InterruptedException {
 
         final StreamThread thread = createStreamThread(clientId, config, false);
@@ -367,128 +198,12 @@ public class StreamThreadTest {
                                    new MockStateRestoreListener());
     }
 
-    private final static String TOPIC = "topic";
-    private final Set<TopicPartition> task0Assignment = Collections.singleton(new TopicPartition(TOPIC, 0));
-    private final Set<TopicPartition> task1Assignment = Collections.singleton(new TopicPartition(TOPIC, 1));
-
-    @SuppressWarnings("unchecked")
-    @Test
-    public void testHandingOverTaskFromOneToAnotherThread() throws InterruptedException {
-        internalTopologyBuilder.addStateStore(
-            Stores
-                .create("store")
-                .withByteArrayKeys()
-                .withByteArrayValues()
-                .persistent()
-                .build()
-        );
-        internalTopologyBuilder.addSource(null, "source", null, null, null, TOPIC);
-
-        TopicPartition tp0 = new TopicPartition(TOPIC, 0);
-        TopicPartition tp1 = new TopicPartition(TOPIC, 1);
-        clientSupplier.consumer.assign(Arrays.asList(tp0, tp1));
-        final Map<TopicPartition, Long> offsets = new HashMap<>();
-        offsets.put(tp0, 0L);
-        offsets.put(tp1, 0L);
-        clientSupplier.consumer.updateBeginningOffsets(offsets);
-
-        final StreamThread thread1 = createStreamThread(clientId + 1, config, false);
-        final StreamThread thread2 = createStreamThread(clientId + 2, config, false);
-
-
-        final Map<TaskId, Set<TopicPartition>> task0 = Collections.singletonMap(new TaskId(0, 0), task0Assignment);
-        final Map<TaskId, Set<TopicPartition>> task1 = Collections.singletonMap(new TaskId(0, 1), task1Assignment);
-
-        final Map<TaskId, Set<TopicPartition>> thread1Assignment = new HashMap<>(task0);
-        final Map<TaskId, Set<TopicPartition>> thread2Assignment = new HashMap<>(task1);
-
-        thread1.setThreadMetadataProvider(new MockStreamsPartitionAssignor(thread1Assignment));
-        thread2.setThreadMetadataProvider(new MockStreamsPartitionAssignor(thread2Assignment));
-
-        // revoke (to get threads in correct state)
-        thread1.setState(StreamThread.State.RUNNING);
-        thread2.setState(StreamThread.State.RUNNING);
-        thread1.rebalanceListener.onPartitionsRevoked(EMPTY_SET);
-        thread2.rebalanceListener.onPartitionsRevoked(EMPTY_SET);
-
-        // assign
-        thread1.rebalanceListener.onPartitionsAssigned(task0Assignment);
-        thread1.runOnce(-1);
-        thread2.rebalanceListener.onPartitionsAssigned(task1Assignment);
-        thread2.runOnce(-1);
-
-        final Set<TaskId> originalTaskAssignmentThread1 = new HashSet<>();
-        originalTaskAssignmentThread1.addAll(thread1.tasks().keySet());
-        final Set<TaskId> originalTaskAssignmentThread2 = new HashSet<>();
-        originalTaskAssignmentThread2.addAll(thread2.tasks().keySet());
-
-        // revoke (task will be suspended)
-        thread1.rebalanceListener.onPartitionsRevoked(task0Assignment);
-        thread2.rebalanceListener.onPartitionsRevoked(task1Assignment);
-
-        assertThat(thread1.prevActiveTasks(), equalTo(originalTaskAssignmentThread1));
-        assertThat(thread2.prevActiveTasks(), equalTo(originalTaskAssignmentThread2));
-
-        // assign reverted
-        thread1Assignment.clear();
-        thread1Assignment.putAll(task1);
-
-        thread2Assignment.clear();
-        thread2Assignment.putAll(task0);
-
-        final Thread runIt = new Thread(new Runnable() {
-            @Override
-            public void run() {
-                thread1.rebalanceListener.onPartitionsAssigned(task1Assignment);
-                thread1.runOnce(-1);
-            }
-        });
-        runIt.start();
-
-        thread2.rebalanceListener.onPartitionsAssigned(task0Assignment);
-        thread2.runOnce(-1);
-
-        runIt.join();
-
-        assertThat(thread1.tasks().keySet(), equalTo(originalTaskAssignmentThread2));
-        assertThat(thread2.tasks().keySet(), equalTo(originalTaskAssignmentThread1));
-    }
-
-    private class MockStreamsPartitionAssignor extends StreamPartitionAssignor {
-
-        private final Map<TaskId, Set<TopicPartition>> activeTaskAssignment;
-        private final Map<TaskId, Set<TopicPartition>> standbyTaskAssignment;
-
-        MockStreamsPartitionAssignor(final Map<TaskId, Set<TopicPartition>> activeTaskAssignment) {
-            this(activeTaskAssignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
-        }
-
-        MockStreamsPartitionAssignor(final Map<TaskId, Set<TopicPartition>> activeTaskAssignment,
-                                     final Map<TaskId, Set<TopicPartition>> standbyTaskAssignment) {
-            this.activeTaskAssignment = activeTaskAssignment;
-            this.standbyTaskAssignment = standbyTaskAssignment;
-        }
-
-        @Override
-        public Map<TaskId, Set<TopicPartition>> activeTasks() {
-            return activeTaskAssignment;
-        }
-
-        @Override
-        public Map<TaskId, Set<TopicPartition>> standbyTasks() {
-            return standbyTaskAssignment;
-        }
-
-        @Override
-        public void close() {}
-    }
-
     @Test
     public void testMetrics() {
         final StreamThread thread = createStreamThread(clientId, config, false);
         final String defaultGroupName = "stream-metrics";
-        final String defaultPrefix = "thread." + thread.threadClientId();
-        final Map<String, String> defaultTags = Collections.singletonMap("client-id", thread.threadClientId());
+        final String defaultPrefix = "thread." + thread.getName();
+        final Map<String, String> defaultTags = Collections.singletonMap("client-id", thread.getName());
 
         assertNotNull(metrics.getSensor(defaultPrefix + ".commit-latency"));
         assertNotNull(metrics.getSensor(defaultPrefix + ".poll-latency"));
@@ -529,19 +244,18 @@ public class StreamThreadTest {
         final TaskManager taskManager = mockTaskManagerCommit(consumer, 1, 1);
 
         StreamThread.StreamsMetricsThreadImpl streamsMetrics = new StreamThread.StreamsMetricsThreadImpl(metrics, "", "", Collections.<String, String>emptyMap());
-        final StreamThread thread = new StreamThread(internalTopologyBuilder,
-                                                     clientId,
-                                                     "",
-                                                     config,
-                                                     processId,
-                                                     mockTime,
-                                                     streamsMetadataState,
-                                                     taskManager,
-                                                     streamsMetrics,
-                                                     clientSupplier,
-                                                     consumer,
-                                                     clientSupplier.getAdminClient(config.getAdminConfigs(clientId)),
-                                                     stateDirectory);
+        final StreamThread thread = new StreamThread(mockTime,
+                config,
+                consumer,
+                consumer,
+                null,
+                clientSupplier.getAdminClient(config.getAdminConfigs(clientId)),
+                taskManager,
+                streamsMetrics,
+                internalTopologyBuilder,
+                clientId,
+                new LogContext("")
+        );
         thread.maybeCommit(mockTime.milliseconds());
         mockTime.sleep(commitInterval - 10L);
         thread.maybeCommit(mockTime.milliseconds());
@@ -562,19 +276,17 @@ public class StreamThreadTest {
         final TaskManager taskManager = mockTaskManagerCommit(consumer, 1, 0);
 
         StreamThread.StreamsMetricsThreadImpl streamsMetrics = new StreamThread.StreamsMetricsThreadImpl(metrics, "", "", Collections.<String, String>emptyMap());
-        final StreamThread thread = new StreamThread(internalTopologyBuilder,
-                                                     clientId,
-                                                     "",
-                                                     config,
-                                                     processId,
-                                                     mockTime,
-                                                     streamsMetadataState,
-                                                     taskManager,
-                                                     streamsMetrics,
-                                                     clientSupplier,
-                                                     consumer,
-                                                     clientSupplier.getAdminClient(config.getAdminConfigs(clientId)),
-                                                     stateDirectory);
+        final StreamThread thread = new StreamThread(mockTime,
+                config,
+                consumer,
+                consumer,
+                null,
+                clientSupplier.getAdminClient(config.getAdminConfigs(clientId)),
+                taskManager,
+                streamsMetrics,
+                internalTopologyBuilder,
+                clientId,
+                new LogContext(""));
         thread.maybeCommit(mockTime.milliseconds());
         mockTime.sleep(commitInterval - 10L);
         thread.maybeCommit(mockTime.milliseconds());
@@ -596,19 +308,17 @@ public class StreamThreadTest {
         final TaskManager taskManager = mockTaskManagerCommit(consumer, 2, 1);
 
         StreamThread.StreamsMetricsThreadImpl streamsMetrics = new StreamThread.StreamsMetricsThreadImpl(metrics, "", "", Collections.<String, String>emptyMap());
-        final StreamThread thread = new StreamThread(internalTopologyBuilder,
-                                                     clientId,
-                                                     "",
-                                                     config,
-                                                     processId,
-                                                     mockTime,
-                                                     streamsMetadataState,
-                                                     taskManager,
-                                                     streamsMetrics,
-                                                     clientSupplier,
-                                                     consumer,
-                                                     clientSupplier.getAdminClient(config.getAdminConfigs(clientId)),
-                                                     stateDirectory);
+        final StreamThread thread = new StreamThread(mockTime,
+                config,
+                consumer,
+                consumer,
+                null,
+                clientSupplier.getAdminClient(config.getAdminConfigs(clientId)),
+                taskManager,
+                streamsMetrics,
+                internalTopologyBuilder,
+                clientId,
+                new LogContext(""));
         thread.maybeCommit(mockTime.milliseconds());
         mockTime.sleep(commitInterval + 1);
         thread.maybeCommit(mockTime.milliseconds());
@@ -619,8 +329,6 @@ public class StreamThreadTest {
     @SuppressWarnings({"ThrowableNotThrown", "unchecked"})
     private TaskManager mockTaskManagerCommit(final Consumer<byte[], byte[]> consumer, final int numberOfCommits, final int commits) {
         final TaskManager taskManager = EasyMock.createMock(TaskManager.class);
-        taskManager.setConsumer(EasyMock.anyObject(Consumer.class));
-        EasyMock.expectLastCall();
         EasyMock.expect(taskManager.commitAll()).andReturn(commits).times(numberOfCommits);
         EasyMock.replay(taskManager, consumer);
         return taskManager;
@@ -628,18 +336,26 @@ public class StreamThreadTest {
 
     @Test
     public void shouldInjectSharedProducerForAllTasksUsingClientSupplierOnCreateIfEosDisabled() throws InterruptedException {
-        internalTopologyBuilder.addSource(null, "source1", null, null, null, "someTopic");
+        internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1);
 
         final StreamThread thread = createStreamThread(clientId, config, false);
 
-        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
-        assignment.put(new TaskId(0, 0), Collections.singleton(new TopicPartition("someTopic", 0)));
-        assignment.put(new TaskId(0, 1), Collections.singleton(new TopicPartition("someTopic", 1)));
-        thread.setThreadMetadataProvider(new MockStreamsPartitionAssignor(assignment));
-
         thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        thread.rebalanceListener.onPartitionsAssigned(Collections.singleton(new TopicPartition("someTopic", 0)));
+
+        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
+        final List<TopicPartition> assignedPartitions = new ArrayList<>();
+
+        // assign two partitions, one per task
+        assignedPartitions.add(t1p1);
+        assignedPartitions.add(t1p2);
+        activeTasks.put(task1, Collections.singleton(t1p1));
+        activeTasks.put(task2, Collections.singleton(t1p2));
+
+        thread.taskManager().setAssignmentMetadata(activeTasks, Collections.<TaskId, Set<TopicPartition>>emptyMap());
+        thread.taskManager().createTasks(assignedPartitions);
+
+        thread.rebalanceListener.onPartitionsAssigned(new HashSet<>(assignedPartitions));
 
         assertEquals(1, clientSupplier.producers.size());
         final Producer globalProducer = clientSupplier.producers.get(0);
@@ -652,46 +368,55 @@ public class StreamThreadTest {
 
     @Test
     public void shouldInjectProducerPerTaskUsingClientSupplierOnCreateIfEosEnable() throws InterruptedException {
-        internalTopologyBuilder.addSource(null, "source1", null, null, null, "someTopic");
+        internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1);
 
         final StreamThread thread = createStreamThread(clientId, new StreamsConfig(configProps(true)), true);
 
-        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
-        assignment.put(new TaskId(0, 0), Collections.singleton(new TopicPartition("someTopic", 0)));
-        assignment.put(new TaskId(0, 1), Collections.singleton(new TopicPartition("someTopic", 1)));
-        assignment.put(new TaskId(0, 2), Collections.singleton(new TopicPartition("someTopic", 2)));
-        thread.setThreadMetadataProvider(new MockStreamsPartitionAssignor(assignment));
-
-        final Set<TopicPartition> assignedPartitions = new HashSet<>();
-        Collections.addAll(assignedPartitions, new TopicPartition("someTopic", 0), new TopicPartition("someTopic", 2));
         thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
+
+        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
+        final List<TopicPartition> assignedPartitions = new ArrayList<>();
+
+        // assign two partitions, one per task
+        assignedPartitions.add(t1p1);
+        assignedPartitions.add(t1p2);
+        activeTasks.put(task1, Collections.singleton(t1p1));
+        activeTasks.put(task2, Collections.singleton(t1p2));
+
+        thread.taskManager().setAssignmentMetadata(activeTasks, Collections.<TaskId, Set<TopicPartition>>emptyMap());
+
+        thread.rebalanceListener.onPartitionsAssigned(new HashSet<>(assignedPartitions));
+
         thread.runOnce(-1);
 
         assertEquals(thread.tasks().size(), clientSupplier.producers.size());
-        final Iterator it = clientSupplier.producers.iterator();
-        for (final Task task : thread.tasks().values()) {
-            assertSame(it.next(), ((RecordCollectorImpl) ((StreamTask) task).recordCollector()).producer());
-        }
         assertSame(clientSupplier.consumer, thread.consumer);
         assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer);
     }
 
     @Test
     public void shouldCloseAllTaskProducersOnCloseIfEosEnabled() throws InterruptedException {
-        internalTopologyBuilder.addSource(null, "source1", null, null, null, "someTopic");
+        internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1);
 
         final StreamThread thread = createStreamThread(clientId, new StreamsConfig(configProps(true)), true);
 
-        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
-        assignment.put(new TaskId(0, 0), Collections.singleton(new TopicPartition("someTopic", 0)));
-        assignment.put(new TaskId(0, 1), Collections.singleton(new TopicPartition("someTopic", 1)));
-        thread.setThreadMetadataProvider(new MockStreamsPartitionAssignor(assignment));
-
         thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        thread.rebalanceListener.onPartitionsAssigned(Collections.singleton(new TopicPartition("someTopic", 0)));
+
+        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
+        final List<TopicPartition> assignedPartitions = new ArrayList<>();
+
+        // assign two partitions, one per task
+        assignedPartitions.add(t1p1);
+        assignedPartitions.add(t1p2);
+        activeTasks.put(task1, Collections.singleton(t1p1));
+        activeTasks.put(task2, Collections.singleton(t1p2));
+
+        thread.taskManager().setAssignmentMetadata(activeTasks, Collections.<TaskId, Set<TopicPartition>>emptyMap());
+        thread.taskManager().createTasks(assignedPartitions);
+
+        thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
         thread.shutdown();
         thread.run();
@@ -706,26 +431,24 @@ public class StreamThreadTest {
     public void shouldShutdownTaskManagerOnClose() throws InterruptedException {
         final Consumer<byte[], byte[]> consumer = EasyMock.createNiceMock(Consumer.class);
         final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class);
-        taskManager.setConsumer(EasyMock.anyObject(Consumer.class));
-        EasyMock.expectLastCall();
+        EasyMock.expect(taskManager.activeTasks()).andReturn(Collections.<TaskId, StreamTask>emptyMap());
+        EasyMock.expect(taskManager.standbyTasks()).andReturn(Collections.<TaskId, StandbyTask>emptyMap());
         taskManager.shutdown(true);
         EasyMock.expectLastCall();
         EasyMock.replay(taskManager, consumer);
 
         StreamThread.StreamsMetricsThreadImpl streamsMetrics = new StreamThread.StreamsMetricsThreadImpl(metrics, "", "", Collections.<String, String>emptyMap());
-        final StreamThread thread = new StreamThread(internalTopologyBuilder,
-                                                     clientId,
-                                                     "",
-                                                     config,
-                                                     processId,
-                                                     mockTime,
-                                                     streamsMetadataState,
-                                                     taskManager,
-                                                     streamsMetrics,
-                                                     clientSupplier,
-                                                     consumer,
-                                                     clientSupplier.getAdminClient(config.getAdminConfigs(clientId)),
-                                                     stateDirectory);
+        final StreamThread thread = new StreamThread(mockTime,
+                config,
+                consumer,
+                consumer,
+                null,
+                clientSupplier.getAdminClient(config.getAdminConfigs(clientId)),
+                taskManager,
+                streamsMetrics,
+                internalTopologyBuilder,
+                clientId,
+                new LogContext(""));
         thread.setState(StreamThread.State.RUNNING);
         thread.shutdown();
         thread.run();
@@ -739,112 +462,55 @@ public class StreamThreadTest {
 
         final StreamThread thread = createStreamThread(clientId, config, false);
 
-        thread.setThreadMetadataProvider(new StreamPartitionAssignor() {
-            @Override
-            public Map<TaskId, Set<TopicPartition>> standbyTasks() {
-                return Collections.singletonMap(new TaskId(0, 0), Utils.mkSet(new TopicPartition("topic", 0)));
-            }
-        });
-
         thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
-    }
-
-    @Test
-    public void shouldCloseSuspendedTasksThatAreNoLongerAssignedToThisStreamThreadBeforeCreatingNewTasks() {
-        internalStreamsBuilder.stream(Collections.singleton("t1"), consumed).groupByKey().count("count-one");
-        internalStreamsBuilder.stream(Collections.singleton("t2"), consumed).groupByKey().count("count-two");
-
-        final StreamThread thread = createStreamThread(clientId, config, false);
-        final MockConsumer<byte[], byte[]> restoreConsumer = clientSupplier.restoreConsumer;
-        restoreConsumer.updatePartitions("stream-thread-test-count-one-changelog",
-                                         Collections.singletonList(new PartitionInfo("stream-thread-test-count-one-changelog",
-                                                                                     0,
-                                                                                     null,
-                                                                                     new Node[0],
-                                                                                     new Node[0])));
-        restoreConsumer.updatePartitions("stream-thread-test-count-two-changelog",
-                                         Collections.singletonList(new PartitionInfo("stream-thread-test-count-two-changelog",
-                                                                                     0,
-                                                                                     null,
-                                                                                     new Node[0],
-                                                                                     new Node[0])));
-
-
-        final HashMap<TopicPartition, Long> offsets = new HashMap<>();
-        offsets.put(new TopicPartition("stream-thread-test-count-one-changelog", 0), 0L);
-        offsets.put(new TopicPartition("stream-thread-test-count-two-changelog", 0), 0L);
-        restoreConsumer.updateEndOffsets(offsets);
-        restoreConsumer.updateBeginningOffsets(offsets);
 
         final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
-        final TopicPartition t1 = new TopicPartition("t1", 0);
-        Set<TopicPartition> partitionsT1 = Utils.mkSet(t1);
-        standbyTasks.put(new TaskId(0, 0), partitionsT1);
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        final TopicPartition t2 = new TopicPartition("t2", 0);
-        Set<TopicPartition> partitionsT2 = Utils.mkSet(t2);
-        activeTasks.put(new TaskId(1, 0), partitionsT2);
-        clientSupplier.consumer.updateBeginningOffsets(Collections.singletonMap(t2, 0L));
 
-        thread.setThreadMetadataProvider(new StreamPartitionAssignor() {
-            @Override
-            public Map<TaskId, Set<TopicPartition>> standbyTasks() {
-                return standbyTasks;
-            }
-
-            @Override
-            public Map<TaskId, Set<TopicPartition>> activeTasks() {
-                return activeTasks;
-            }
-        });
+        // assign single standby partition
+        standbyTasks.put(task1, Collections.singleton(t1p1));
 
-        thread.setState(StreamThread.State.RUNNING);
-        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        clientSupplier.consumer.assign(partitionsT2);
-        thread.rebalanceListener.onPartitionsAssigned(Utils.mkSet(t2));
-        thread.runOnce(-1);
-        // swap the assignment around and make sure we don't get any exceptions
-        standbyTasks.clear();
-        activeTasks.clear();
-        standbyTasks.put(new TaskId(1, 0), Utils.mkSet(t2));
-        activeTasks.put(new TaskId(0, 0), Utils.mkSet(t1));
+        thread.taskManager().setAssignmentMetadata(Collections.<TaskId, Set<TopicPartition>>emptyMap(), standbyTasks);
+        thread.taskManager().createTasks(Collections.<TopicPartition>emptyList());
 
-        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        clientSupplier.consumer.assign(partitionsT1);
-        thread.rebalanceListener.onPartitionsAssigned(Utils.mkSet(t1));
+        thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
     }
 
     @Test
     public void shouldCloseTaskAsZombieAndRemoveFromActiveTasksIfProducerWasFencedWhileProcessing() throws InterruptedException {
-        internalTopologyBuilder.addSource(null, "source", null, null, null, TOPIC);
+        internalTopologyBuilder.addSource(null, "source", null, null, null, topic1);
         internalTopologyBuilder.addSink("sink", "dummyTopic", null, null, null, "source");
 
         final StreamThread thread = createStreamThread(clientId, new StreamsConfig(configProps(true)), true);
 
         final MockConsumer<byte[], byte[]> consumer = clientSupplier.consumer;
-        consumer.updatePartitions(TOPIC, Collections.singletonList(new PartitionInfo(TOPIC, 0, null, null, null)));
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        activeTasks.put(task1, task0Assignment);
 
-        thread.setThreadMetadataProvider(new MockStreamsPartitionAssignor(activeTasks));
+        consumer.updatePartitions(topic1, Collections.singletonList(new PartitionInfo(topic1, 1, null, null, null)));
 
         thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(null);
-        thread.rebalanceListener.onPartitionsAssigned(task0Assignment);
+
+        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
+        final List<TopicPartition> assignedPartitions = new ArrayList<>();
+
+        // assign single partition
+        assignedPartitions.add(t1p1);
+        activeTasks.put(task1, Collections.singleton(t1p1));
+
+        thread.taskManager().setAssignmentMetadata(activeTasks, Collections.<TaskId, Set<TopicPartition>>emptyMap());
+
+        thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
+
         thread.runOnce(-1);
         assertThat(thread.tasks().size(), equalTo(1));
         final MockProducer producer = clientSupplier.producers.get(0);
 
         // change consumer subscription from "pattern" to "manual" to be able to call .addRecords()
-        consumer.updateBeginningOffsets(Collections.singletonMap(task0Assignment.iterator().next(), 0L));
+        consumer.updateBeginningOffsets(Collections.singletonMap(assignedPartitions.iterator().next(), 0L));
         consumer.unsubscribe();
-        consumer.assign(task0Assignment);
+        consumer.assign(new HashSet<>(assignedPartitions));
 
-        consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 0, new byte[0], new byte[0]));
+        consumer.addRecord(new ConsumerRecord<>(topic1, 1, 0, new byte[0], new byte[0]));
         mockTime.sleep(config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG) + 1);
         thread.runOnce(-1);
         assertThat(producer.history().size(), equalTo(1));
@@ -862,7 +528,7 @@ public class StreamThreadTest {
 
         producer.fenceProducer();
         mockTime.sleep(config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG) + 1L);
-        consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 0, new byte[0], new byte[0]));
+        consumer.addRecord(new ConsumerRecord<>(topic1, 1, 0, new byte[0], new byte[0]));
         try {
             thread.runOnce(-1);
             fail("Should have thrown TaskMigratedException");
@@ -881,26 +547,31 @@ public class StreamThreadTest {
 
     @Test
     public void shouldCloseTaskAsZombieAndRemoveFromActiveTasksIfProducerGotFencedAtBeginTransactionWhenTaskIsResumed() {
-        internalTopologyBuilder.addSource(null, "name", null, null, null, "topic");
+        internalTopologyBuilder.addSource(null, "name", null, null, null, topic1);
         internalTopologyBuilder.addSink("out", "output", null, null, null);
 
         final StreamThread thread = createStreamThread(clientId, new StreamsConfig(configProps(true)), true);
 
+        thread.setState(StreamThread.State.RUNNING);
+        thread.rebalanceListener.onPartitionsRevoked(null);
+
         final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        activeTasks.put(task1, task0Assignment);
+        final List<TopicPartition> assignedPartitions = new ArrayList<>();
 
-        thread.setThreadMetadataProvider(new MockStreamsPartitionAssignor(activeTasks));
+        // assign single partition
+        assignedPartitions.add(t1p1);
+        activeTasks.put(task1, Collections.singleton(t1p1));
+
+        thread.taskManager().setAssignmentMetadata(activeTasks, Collections.<TaskId, Set<TopicPartition>>emptyMap());
+        thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
-        thread.setState(StreamThread.State.RUNNING);
-        thread.rebalanceListener.onPartitionsRevoked(null);
-        thread.rebalanceListener.onPartitionsAssigned(task0Assignment);
         thread.runOnce(-1);
 
         assertThat(thread.tasks().size(), equalTo(1));
 
         thread.rebalanceListener.onPartitionsRevoked(null);
         clientSupplier.producers.get(0).fenceProducer();
-        thread.rebalanceListener.onPartitionsAssigned(task0Assignment);
+        thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
         try {
             thread.runOnce(-1);
             fail("Should have thrown TaskMigratedException");
@@ -909,80 +580,6 @@ public class StreamThreadTest {
         assertTrue(thread.tasks().isEmpty());
     }
 
-    @Test
-    @SuppressWarnings("unchecked")
-    public void shouldAlwaysUpdateWithLatestTopicsFromStreamPartitionAssignor() throws Exception {
-        internalTopologyBuilder.addSource(null, "source", null, null, null, Pattern.compile("t.*"));
-        internalTopologyBuilder.addProcessor("processor", new MockProcessorSupplier(), "source");
-
-        final StreamThread thread = createStreamThread(clientId, config, false);
-
-        final StreamPartitionAssignor partitionAssignor = new StreamPartitionAssignor();
-        final Map<String, Object> configurationMap = new HashMap<>();
-
-        configurationMap.put(StreamsConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread);
-        configurationMap.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 0);
-        partitionAssignor.configure(configurationMap);
-
-        thread.setThreadMetadataProvider(partitionAssignor);
-
-        final Field nodeToSourceTopicsField =
-            internalTopologyBuilder.getClass().getDeclaredField("nodeToSourceTopics");
-        nodeToSourceTopicsField.setAccessible(true);
-        final Map<String, List<String>>
-            nodeToSourceTopics =
-            (Map<String, List<String>>) nodeToSourceTopicsField.get(internalTopologyBuilder);
-        final List<TopicPartition> topicPartitions = new ArrayList<>();
-
-        final TopicPartition topicPartition1 = new TopicPartition("topic-1", 0);
-        final TopicPartition topicPartition2 = new TopicPartition("topic-2", 0);
-        final TopicPartition topicPartition3 = new TopicPartition("topic-3", 0);
-
-        final TaskId taskId1 = new TaskId(0, 0);
-        final TaskId taskId2 = new TaskId(0, 0);
-        final TaskId taskId3 = new TaskId(0, 0);
-
-        List<TaskId> activeTasks = Utils.mkList(taskId1);
-
-        final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
-
-        AssignmentInfo info = new AssignmentInfo(activeTasks, standbyTasks, new HashMap<HostInfo, Set<TopicPartition>>());
-
-        topicPartitions.addAll(Utils.mkList(topicPartition1));
-        PartitionAssignor.Assignment assignment = new PartitionAssignor.Assignment(topicPartitions, info.encode());
-        partitionAssignor.onAssignment(assignment);
-
-        assertTrue(nodeToSourceTopics.get("source").size() == 1);
-        assertTrue(nodeToSourceTopics.get("source").contains("topic-1"));
-
-        topicPartitions.clear();
-
-        activeTasks = Arrays.asList(taskId1, taskId2);
-        info = new AssignmentInfo(activeTasks, standbyTasks, new HashMap<HostInfo, Set<TopicPartition>>());
-        topicPartitions.addAll(Arrays.asList(topicPartition1, topicPartition2));
-        assignment = new PartitionAssignor.Assignment(topicPartitions, info.encode());
-        partitionAssignor.onAssignment(assignment);
-
-        assertTrue(nodeToSourceTopics.get("source").size() == 2);
-        assertTrue(nodeToSourceTopics.get("source").contains("topic-1"));
-        assertTrue(nodeToSourceTopics.get("source").contains("topic-2"));
-
-        topicPartitions.clear();
-
-        activeTasks = Arrays.asList(taskId1, taskId2, taskId3);
-        info = new AssignmentInfo(activeTasks, standbyTasks,
-                               new HashMap<HostInfo, Set<TopicPartition>>());
-        topicPartitions.addAll(Arrays.asList(topicPartition1, topicPartition2, topicPartition3));
-        assignment = new PartitionAssignor.Assignment(topicPartitions, info.encode());
-        partitionAssignor.onAssignment(assignment);
-
-        assertTrue(nodeToSourceTopics.get("source").size() == 3);
-        assertTrue(nodeToSourceTopics.get("source").contains("topic-1"));
-        assertTrue(nodeToSourceTopics.get("source").contains("topic-2"));
-        assertTrue(nodeToSourceTopics.get("source").contains("topic-3"));
-
-    }
-
     private static class StateListenerStub implements StreamThread.StateListener {
         int numChanges = 0;
         ThreadStateTransitionValidator oldState = null;
@@ -1006,33 +603,39 @@ public class StreamThreadTest {
         return createStreamThread(clientId, config, false);
     }
 
-
     @Test
     public void shouldReturnActiveTaskMetadataWhileRunningState() throws InterruptedException {
-        internalTopologyBuilder.addSource(null, "source", null, null, null, TOPIC);
+        internalTopologyBuilder.addSource(null, "source", null, null, null, topic1);
 
-        final TaskId taskId = new TaskId(0, 0);
         final StreamThread thread = createStreamThread(clientId, config, false);
 
-        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
-        assignment.put(taskId, task0Assignment);
-        thread.setThreadMetadataProvider(new MockStreamsPartitionAssignor(assignment));
-
         thread.setState(StreamThread.State.RUNNING);
 
         thread.rebalanceListener.onPartitionsRevoked(null);
-        thread.rebalanceListener.onPartitionsAssigned(task0Assignment);
+
+        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
+        final List<TopicPartition> assignedPartitions = new ArrayList<>();
+
+        // assign single partition
+        assignedPartitions.add(t1p1);
+        activeTasks.put(task1, Collections.singleton(t1p1));
+
+        thread.taskManager().setAssignmentMetadata(activeTasks, Collections.<TaskId, Set<TopicPartition>>emptyMap());
+        thread.taskManager().createTasks(assignedPartitions);
+
+        thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
+
         thread.runOnce(-1);
 
         ThreadMetadata threadMetadata = thread.threadMetadata();
         assertEquals(StreamThread.State.RUNNING.name(), threadMetadata.threadState());
-        assertTrue(threadMetadata.activeTasks().contains(new TaskMetadata(taskId.toString(), task0Assignment)));
+        assertTrue(threadMetadata.activeTasks().contains(new TaskMetadata(task1.toString(), Utils.mkSet(t1p1))));
         assertTrue(threadMetadata.standbyTasks().isEmpty());
     }
 
     @Test
     public void shouldReturnStandbyTaskMetadataWhileRunningState() throws InterruptedException {
-        internalStreamsBuilder.stream(Collections.singleton("t1"), consumed).groupByKey().count("count-one");
+        internalStreamsBuilder.stream(Collections.singleton(topic1), consumed).groupByKey().count("count-one");
 
         final StreamThread thread = createStreamThread(clientId, config, false);
         final MockConsumer<byte[], byte[]> restoreConsumer = clientSupplier.restoreConsumer;
@@ -1044,30 +647,28 @@ public class StreamThreadTest {
                         new Node[0])));
 
         final HashMap<TopicPartition, Long> offsets = new HashMap<>();
-        offsets.put(new TopicPartition("stream-thread-test-count-one-changelog", 0), 0L);
+        offsets.put(new TopicPartition("stream-thread-test-count-one-changelog", 1), 0L);
         restoreConsumer.updateEndOffsets(offsets);
         restoreConsumer.updateBeginningOffsets(offsets);
 
-        final TaskId taskId = new TaskId(0, 0);
+        thread.setState(StreamThread.State.RUNNING);
+
+        thread.rebalanceListener.onPartitionsRevoked(null);
 
         final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
-        final TopicPartition t1 = new TopicPartition("t1", 0);
-        Set<TopicPartition> partitionsT1 = Utils.mkSet(t1);
-        standbyTasks.put(taskId, partitionsT1);
 
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
+        // assign single standby partition
+        standbyTasks.put(task1, Collections.singleton(t1p1));
 
-        thread.setThreadMetadataProvider(new MockStreamsPartitionAssignor(activeTasks, standbyTasks));
+        thread.taskManager().setAssignmentMetadata(Collections.<TaskId, Set<TopicPartition>>emptyMap(), standbyTasks);
 
-        thread.setState(StreamThread.State.RUNNING);
+        thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
 
-        thread.rebalanceListener.onPartitionsRevoked(task0Assignment);
-        thread.rebalanceListener.onPartitionsAssigned(null);
         thread.runOnce(-1);
 
         ThreadMetadata threadMetadata = thread.threadMetadata();
         assertEquals(StreamThread.State.RUNNING.name(), threadMetadata.threadState());
-        assertTrue(threadMetadata.standbyTasks().contains(new TaskMetadata(taskId.toString(), partitionsT1)));
+        assertTrue(threadMetadata.standbyTasks().contains(new TaskMetadata(task1.toString(), Utils.mkSet(t1p1))));
         assertTrue(threadMetadata.activeTasks().isEmpty());
     }
 
@@ -1084,7 +685,7 @@ public class StreamThreadTest {
 
     @Test
     public void shouldAlwaysReturnEmptyTasksMetadataWhileRebalancingStateAndTasksNotRunning() throws InterruptedException {
-        internalStreamsBuilder.stream(Collections.singleton("t1"), consumed).groupByKey().count("count-one");
+        internalStreamsBuilder.stream(Collections.singleton(topic1), consumed).groupByKey().count("count-one");
 
         final StreamThread thread = createStreamThread(clientId, config, false);
         final MockConsumer<byte[], byte[]> restoreConsumer = clientSupplier.restoreConsumer;
@@ -1107,48 +708,28 @@ public class StreamThreadTest {
         restoreConsumer.updateEndOffsets(offsets);
         restoreConsumer.updateBeginningOffsets(offsets);
 
-        final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
-        final TopicPartition t1p0 = new TopicPartition("t1", 0);
-        Set<TopicPartition> partitionsT1P0 = Utils.mkSet(t1p0);
-        standbyTasks.put(new TaskId(0, 0), partitionsT1P0);
-
-        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        final TopicPartition t1p1 = new TopicPartition("t1", 1);
-        Set<TopicPartition> partitionsT1P1 = Utils.mkSet(t1p1);
-        activeTasks.put(new TaskId(0, 1), partitionsT1P1);
         clientSupplier.consumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L));
-        thread.setThreadMetadataProvider(new StreamPartitionAssignor() {
-            @Override
-            public Map<TaskId, Set<TopicPartition>> standbyTasks() {
-                return standbyTasks;
-            }
-
-            @Override
-            public Map<TaskId, Set<TopicPartition>> activeTasks() {
-                return activeTasks;
-            }
-        });
 
         thread.setState(StreamThread.State.RUNNING);
 
-        thread.rebalanceListener.onPartitionsRevoked(partitionsT1P0);
+        final List<TopicPartition> assignedPartitions = new ArrayList<>();
+
+        thread.rebalanceListener.onPartitionsRevoked(assignedPartitions);
         assertThreadMetadataHasEmptyTasksWithState(thread.threadMetadata(), StreamThread.State.PARTITIONS_REVOKED);
 
-        clientSupplier.consumer.assign(partitionsT1P1);
-        thread.rebalanceListener.onPartitionsAssigned(partitionsT1P1);
-        assertThreadMetadataHasEmptyTasksWithState(thread.threadMetadata(), StreamThread.State.PARTITIONS_ASSIGNED);
-        thread.runOnce(-1);
+        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
+        final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
 
-        standbyTasks.clear();
-        activeTasks.clear();
-        standbyTasks.put(new TaskId(0, 1), Utils.mkSet(t1p1));
-        activeTasks.put(new TaskId(0, 0), Utils.mkSet(t1p0));
+        // assign one active partition plus one standby task
+        assignedPartitions.add(t1p1);
+        activeTasks.put(task1, Collections.singleton(t1p1));
+        standbyTasks.put(task2, Collections.singleton(t1p2));
 
-        assertFalse(thread.threadMetadata().activeTasks().isEmpty());
-        assertFalse(thread.threadMetadata().standbyTasks().isEmpty());
+        thread.taskManager().setAssignmentMetadata(activeTasks, standbyTasks);
 
-        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
-        assertThreadMetadataHasEmptyTasksWithState(thread.threadMetadata(), StreamThread.State.PARTITIONS_REVOKED);
+        thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
+
+        assertThreadMetadataHasEmptyTasksWithState(thread.threadMetadata(), StreamThread.State.PARTITIONS_ASSIGNED);
     }
 
     private void assertThreadMetadataHasEmptyTasksWithState(ThreadMetadata metadata, StreamThread.State state) {
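
The companion change to StreamsKafkaClientTest below tracks a signature change in StreamsKafkaClient: Config.fromStreamsConfig() and create() now accept the raw properties map rather than a pre-built StreamsConfig. A minimal sketch of the updated calls, assuming the Map-based `config` fixture and the `expectedMechanism` loop variable used in the test:

    // Before: the test wrapped the map first, e.g.
    //   StreamsKafkaClient.Config.fromStreamsConfig(new StreamsConfig(config));
    // After: the raw map is passed straight through.
    final AbstractConfig abstractConfig = StreamsKafkaClient.Config.fromStreamsConfig(config);
    assertEquals(expectedMechanism, abstractConfig.getString(SaslConfigs.SASL_MECHANISM));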

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsKafkaClientTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsKafkaClientTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsKafkaClientTest.java
index a399dd4..660a622 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsKafkaClientTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsKafkaClientTest.java
@@ -25,11 +25,8 @@ import org.apache.kafka.common.config.SaslConfigs;
 import org.apache.kafka.common.config.TopicConfig;
 import org.apache.kafka.common.metrics.KafkaMetric;
 import org.apache.kafka.common.metrics.MetricsReporter;
-import org.apache.kafka.common.protocol.ApiKeys;
-import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.requests.AbstractRequest;
 import org.apache.kafka.common.requests.ApiError;
-import org.apache.kafka.common.requests.ApiVersionsResponse;
 import org.apache.kafka.common.requests.CreateTopicsRequest;
 import org.apache.kafka.common.requests.CreateTopicsResponse;
 import org.apache.kafka.common.requests.MetadataResponse;
@@ -79,10 +76,9 @@ public class StreamsKafkaClientTest {
     public void testConfigFromStreamsConfig() {
         for (final String expectedMechanism : asList("PLAIN", "SCRAM-SHA-512")) {
             config.put(SaslConfigs.SASL_MECHANISM, expectedMechanism);
-            final StreamsConfig streamsConfig = new StreamsConfig(config);
-            final AbstractConfig config = StreamsKafkaClient.Config.fromStreamsConfig(streamsConfig);
-            assertEquals(expectedMechanism, config.values().get(SaslConfigs.SASL_MECHANISM));
-            assertEquals(expectedMechanism, config.getString(SaslConfigs.SASL_MECHANISM));
+            final AbstractConfig abstractConfig = StreamsKafkaClient.Config.fromStreamsConfig(config);
+            assertEquals(expectedMechanism, abstractConfig.values().get(SaslConfigs.SASL_MECHANISM));
+            assertEquals(expectedMechanism, abstractConfig.getString(SaslConfigs.SASL_MECHANISM));
         }
     }
 
@@ -138,7 +134,7 @@ public class StreamsKafkaClientTest {
     public void metricsShouldBeTaggedWithClientId() {
         config.put(StreamsConfig.CLIENT_ID_CONFIG, "some_client_id");
         config.put(StreamsConfig.METRIC_REPORTER_CLASSES_CONFIG, TestMetricsReporter.class.getName());
-        StreamsKafkaClient.create(new StreamsConfig(config));
+        StreamsKafkaClient.create(config);
         assertFalse(TestMetricsReporter.METRICS.isEmpty());
         for (KafkaMetric kafkaMetric : TestMetricsReporter.METRICS.values()) {
             assertEquals("some_client_id", kafkaMetric.metricName().tags().get("client-id"));
@@ -146,34 +142,6 @@ public class StreamsKafkaClientTest {
     }
 
     @Test(expected = StreamsException.class)
-    public void shouldThrowStreamsExceptionOnEmptyBrokerCompatibilityResponse() {
-        kafkaClient.prepareResponse(null);
-        final StreamsKafkaClient streamsKafkaClient = createStreamsKafkaClient();
-        streamsKafkaClient.checkBrokerCompatibility(false);
-    }
-
-    @Test(expected = StreamsException.class)
-    public void shouldThrowStreamsExceptionWhenBrokerCompatibilityResponseInconsistent() {
-        kafkaClient.prepareResponse(new ProduceResponse(Collections.<TopicPartition, ProduceResponse.PartitionResponse>emptyMap()));
-        final StreamsKafkaClient streamsKafkaClient = createStreamsKafkaClient();
-        streamsKafkaClient.checkBrokerCompatibility(false);
-    }
-
-    @Test(expected = StreamsException.class)
-    public void shouldRequireBrokerVersion0101OrHigherWhenEosDisabled() {
-        kafkaClient.prepareResponse(new ApiVersionsResponse(Errors.NONE, Collections.singletonList(new ApiVersionsResponse.ApiVersion(ApiKeys.PRODUCE))));
-        final StreamsKafkaClient streamsKafkaClient = createStreamsKafkaClient();
-        streamsKafkaClient.checkBrokerCompatibility(false);
-    }
-
-    @Test(expected = StreamsException.class)
-    public void shouldRequireBrokerVersions0110OrHigherWhenEosEnabled() {
-        kafkaClient.prepareResponse(new ApiVersionsResponse(Errors.NONE, Collections.singletonList(new ApiVersionsResponse.ApiVersion(ApiKeys.CREATE_TOPICS))));
-        final StreamsKafkaClient streamsKafkaClient = createStreamsKafkaClient();
-        streamsKafkaClient.checkBrokerCompatibility(true);
-    }
-
-    @Test(expected = StreamsException.class)
     public void shouldThrowStreamsExceptionOnEmptyFetchMetadataResponse() {
         kafkaClient.prepareResponse(null);
         final StreamsKafkaClient streamsKafkaClient = createStreamsKafkaClient();
@@ -213,8 +181,7 @@ public class StreamsKafkaClientTest {
     }
 
     private StreamsKafkaClient createStreamsKafkaClient() {
-        final StreamsConfig streamsConfig = new StreamsConfig(config);
-        return new StreamsKafkaClient(StreamsKafkaClient.Config.fromStreamsConfig(streamsConfig),
+        return new StreamsKafkaClient(StreamsKafkaClient.Config.fromStreamsConfig(config),
                                       kafkaClient,
                                       reporters,
                                       new LogContext());

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 1640f9e..55dcf79 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -22,19 +22,28 @@ import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.TaskId;
+
 import org.easymock.EasyMock;
 import org.easymock.EasyMockRunner;
 import org.easymock.Mock;
 import org.easymock.MockType;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.UUID;
+import java.util.regex.Pattern;
 
 import static org.easymock.EasyMock.checkOrder;
 import static org.easymock.EasyMock.expect;
@@ -54,9 +63,17 @@ public class TaskManagerTest {
     private final Set<TopicPartition> taskId0Partitions = Utils.mkSet(t1p0);
     private final Map<TaskId, Set<TopicPartition>> taskId0Assignment = Collections.singletonMap(taskId0, taskId0Partitions);
 
+    @Mock(type = MockType.STRICT)
+    private InternalTopologyBuilder.SubscriptionUpdates subscriptionUpdates;
+    @Mock(type = MockType.STRICT)
+    private InternalTopologyBuilder topologyBuilder;
+    @Mock(type = MockType.STRICT)
+    private StateDirectory stateDirectory;
     @Mock(type = MockType.NICE)
     private ChangelogReader changeLogReader;
     @Mock(type = MockType.NICE)
+    private StreamsMetadataState streamsMetadataState;
+    @Mock(type = MockType.NICE)
     private Consumer<byte[], byte[]> restoreConsumer;
     @Mock(type = MockType.NICE)
     private Consumer<byte[], byte[]> consumer;
@@ -65,7 +82,7 @@ public class TaskManagerTest {
     @Mock(type = MockType.NICE)
     private StreamThread.AbstractTaskCreator<StandbyTask> standbyTaskCreator;
     @Mock(type = MockType.NICE)
-    private ThreadMetadataProvider threadMetadataProvider;
+    private StreamsKafkaClient streamsKafkaClient;
     @Mock(type = MockType.NICE)
     private StreamTask streamTask;
     @Mock(type = MockType.NICE)
@@ -77,17 +94,35 @@ public class TaskManagerTest {
 
     private TaskManager taskManager;
 
+    private final String topic1 = "topic1";
+    private final String topic2 = "topic2";
+    private final TopicPartition t1p1 = new TopicPartition(topic1, 1);
+    private final TopicPartition t1p2 = new TopicPartition(topic1, 2);
+    private final TopicPartition t1p3 = new TopicPartition(topic1, 3);
+    private final TopicPartition t2p1 = new TopicPartition(topic2, 1);
+    private final TopicPartition t2p2 = new TopicPartition(topic2, 2);
+    private final TopicPartition t2p3 = new TopicPartition(topic2, 3);
+
+    private final TaskId task01 = new TaskId(0, 1);
+    private final TaskId task02 = new TaskId(0, 2);
+    private final TaskId task03 = new TaskId(0, 3);
+    private final TaskId task11 = new TaskId(1, 1);
+
+    @Rule
+    public final TemporaryFolder testFolder = new TemporaryFolder();
 
     @Before
     public void setUp() throws Exception {
         taskManager = new TaskManager(changeLogReader,
+                                      UUID.randomUUID(),
                                       "",
                                       restoreConsumer,
+                                      streamsMetadataState,
                                       activeTaskCreator,
                                       standbyTaskCreator,
+                                      streamsKafkaClient,
                                       active,
                                       standby);
-        taskManager.setThreadMetadataProvider(threadMetadataProvider);
         taskManager.setConsumer(consumer);
     }
 
@@ -97,18 +132,110 @@ public class TaskManagerTest {
                         consumer,
                         activeTaskCreator,
                         standbyTaskCreator,
-                        threadMetadataProvider,
                         active,
                         standby);
     }
 
     @Test
+    public void shouldUpdateSubscriptionFromAssignment() {
+        mockTopologyBuilder();
+        expect(subscriptionUpdates.getUpdates()).andReturn(Utils.mkSet(topic1));
+        topologyBuilder.updateSubscribedTopics(EasyMock.eq(Utils.mkSet(topic1, topic2)), EasyMock.anyString());
+        expectLastCall().once();
+
+        EasyMock.replay(activeTaskCreator,
+                        topologyBuilder,
+                        subscriptionUpdates);
+
+        taskManager.updateSubscriptionsFromAssignment(Utils.mkList(t1p1, t2p1));
+
+        EasyMock.verify(activeTaskCreator,
+                        topologyBuilder,
+                        subscriptionUpdates);
+    }
+
+    @Test
+    public void shouldNotUpdateSubscriptionFromAssignment() {
+        mockTopologyBuilder();
+        expect(subscriptionUpdates.getUpdates()).andReturn(Utils.mkSet(topic1, topic2));
+
+        EasyMock.replay(activeTaskCreator,
+                        topologyBuilder,
+                        subscriptionUpdates);
+
+        taskManager.updateSubscriptionsFromAssignment(Utils.mkList(t1p1));
+
+        EasyMock.verify(activeTaskCreator,
+                        topologyBuilder,
+                        subscriptionUpdates);
+    }
+
+    @Test
+    public void shouldUpdateSubscriptionFromMetadata() {
+        mockTopologyBuilder();
+        expect(subscriptionUpdates.getUpdates()).andReturn(Utils.mkSet(topic1));
+        topologyBuilder.updateSubscribedTopics(EasyMock.eq(Utils.mkSet(topic1, topic2)), EasyMock.anyString());
+        expectLastCall().once();
+
+        EasyMock.replay(activeTaskCreator,
+                topologyBuilder,
+                subscriptionUpdates);
+
+        taskManager.updateSubscriptionsFromMetadata(Utils.mkSet(topic1, topic2));
+
+        EasyMock.verify(activeTaskCreator,
+                topologyBuilder,
+                subscriptionUpdates);
+    }
+
+    @Test
+    public void shouldNotUpdateSubscriptionFromMetadata() {
+        mockTopologyBuilder();
+        expect(subscriptionUpdates.getUpdates()).andReturn(Utils.mkSet(topic1));
+
+        EasyMock.replay(activeTaskCreator,
+                topologyBuilder,
+                subscriptionUpdates);
+
+        taskManager.updateSubscriptionsFromMetadata(Utils.mkSet(topic1));
+
+        EasyMock.verify(activeTaskCreator,
+                topologyBuilder,
+                subscriptionUpdates);
+    }
+
+    @Test
+    public void shouldReturnCachedTaskIdsFromDirectory() throws IOException {
+        File[] taskFolders = Utils.mkList(testFolder.newFolder("0_1"),
+                testFolder.newFolder("0_2"),
+                testFolder.newFolder("0_3"),
+                testFolder.newFolder("1_1"),
+                testFolder.newFolder("dummy")).toArray(new File[0]);
+
+        assertTrue((new File(taskFolders[0], ProcessorStateManager.CHECKPOINT_FILE_NAME)).createNewFile());
+        assertTrue((new File(taskFolders[1], ProcessorStateManager.CHECKPOINT_FILE_NAME)).createNewFile());
+        assertTrue((new File(taskFolders[3], ProcessorStateManager.CHECKPOINT_FILE_NAME)).createNewFile());
+
+        expect(activeTaskCreator.stateDirectory()).andReturn(stateDirectory).once();
+        expect(stateDirectory.listTaskDirectories()).andReturn(taskFolders).once();
+
+        EasyMock.replay(activeTaskCreator, stateDirectory);
+
+        Set<TaskId> tasks = taskManager.cachedTasksIds();
+
+        EasyMock.verify(activeTaskCreator, stateDirectory);
+
+        assertThat(tasks, equalTo(Utils.mkSet(task01, task02, task11)));
+    }
+
+    @Test
     public void shouldCloseActiveUnAssignedSuspendedTasksWhenCreatingNewTasks() {
         mockSingleActiveTask();
         active.closeNonAssignedSuspendedTasks(taskId0Assignment);
         expectLastCall();
         replay();
 
+        taskManager.setAssignmentMetadata(taskId0Assignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
         taskManager.createTasks(taskId0Partitions);
 
         verify(active);
@@ -121,6 +248,7 @@ public class TaskManagerTest {
         expectLastCall();
         replay();
 
+        taskManager.setAssignmentMetadata(taskId0Assignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
         taskManager.createTasks(taskId0Partitions);
 
         verify(active);
@@ -133,6 +261,7 @@ public class TaskManagerTest {
         EasyMock.expectLastCall();
         replay();
 
+        taskManager.setAssignmentMetadata(taskId0Assignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
         taskManager.createTasks(taskId0Partitions);
         verify(changeLogReader);
     }
@@ -144,6 +273,7 @@ public class TaskManagerTest {
         active.addNewTask(EasyMock.same(streamTask));
         replay();
 
+        taskManager.setAssignmentMetadata(taskId0Assignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
         taskManager.createTasks(taskId0Partitions);
 
         verify(activeTaskCreator, active);
@@ -152,10 +282,10 @@ public class TaskManagerTest {
     @Test
     public void shouldNotAddResumedActiveTasks() {
         checkOrder(active, true);
-        mockThreadMetadataProvider(Collections.<TaskId, Set<TopicPartition>>emptyMap(), taskId0Assignment);
         EasyMock.expect(active.maybeResumeSuspendedTask(taskId0, taskId0Partitions)).andReturn(true);
         replay();
 
+        taskManager.setAssignmentMetadata(taskId0Assignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
         taskManager.createTasks(taskId0Partitions);
 
         // should be no calls to activeTaskCreator and no calls to active.addNewTasks(..)
@@ -169,6 +299,7 @@ public class TaskManagerTest {
         standby.addNewTask(EasyMock.same(standbyTask));
         replay();
 
+        taskManager.setAssignmentMetadata(Collections.<TaskId, Set<TopicPartition>>emptyMap(), taskId0Assignment);
         taskManager.createTasks(taskId0Partitions);
 
         verify(standbyTaskCreator, active);
@@ -177,10 +308,10 @@ public class TaskManagerTest {
     @Test
     public void shouldNotAddResumedStandbyTasks() {
         checkOrder(active, true);
-        mockThreadMetadataProvider(taskId0Assignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
         EasyMock.expect(standby.maybeResumeSuspendedTask(taskId0, taskId0Partitions)).andReturn(true);
         replay();
 
+        taskManager.setAssignmentMetadata(Collections.<TaskId, Set<TopicPartition>>emptyMap(), taskId0Assignment);
         taskManager.createTasks(taskId0Partitions);
 
         // should be no calls to standbyTaskCreator and no calls to standby.addNewTasks(..)
@@ -196,6 +327,7 @@ public class TaskManagerTest {
         EasyMock.expectLastCall();
         replay();
 
+        taskManager.setAssignmentMetadata(taskId0Assignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
         taskManager.createTasks(taskId0Partitions);
         verify(consumer);
     }
@@ -276,25 +408,6 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldCloseThreadMetadataProviderOnShutdown() {
-        threadMetadataProvider.close();
-        EasyMock.expectLastCall();
-        replay();
-
-        taskManager.shutdown(true);
-        verify(threadMetadataProvider);
-    }
-
-    @Test
-    public void shouldNotPropagateExceptionsOnShutdown() {
-        threadMetadataProvider.close();
-        EasyMock.expectLastCall().andThrow(new RuntimeException());
-        replay();
-
-        taskManager.shutdown(false);
-    }
-
-    @Test
     public void shouldInitializeNewActiveTasks() {
         EasyMock.expect(active.initializeNewTasks()).andReturn(new HashSet<TopicPartition>());
         EasyMock.expect(active.updateRestored(EasyMock.<Collection<TopicPartition>>anyObject())).
@@ -471,6 +584,26 @@ public class TaskManagerTest {
         EasyMock.verify(consumer);
     }
 
+    @Test
+    public void shouldUpdateTasksFromPartitionAssignment() {
+        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
+        final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
+
+        taskManager.setAssignmentMetadata(activeTasks, standbyTasks);
+        assertTrue(taskManager.assignedActiveTasks().isEmpty());
+
+        // assign two active tasks with two partitions each
+        activeTasks.put(task01, new HashSet<>(Arrays.asList(t1p1, t2p1)));
+        activeTasks.put(task02, new HashSet<>(Arrays.asList(t1p2, t2p2)));
+
+        // assign one standby task with two partitions
+        standbyTasks.put(task03, new HashSet<>(Arrays.asList(t1p3, t2p3)));
+        taskManager.setAssignmentMetadata(activeTasks, standbyTasks);
+
+        assertThat(taskManager.assignedActiveTasks(), equalTo(activeTasks));
+        assertThat(taskManager.assignedStandbyTasks(), equalTo(standbyTasks));
+    }
+
     private void mockAssignStandbyPartitions(final long offset) {
         final StandbyTask task = EasyMock.createNiceMock(StandbyTask.class);
         EasyMock.expect(active.initializeNewTasks()).andReturn(new HashSet<TopicPartition>());
@@ -486,31 +619,22 @@ public class TaskManagerTest {
     }
 
     private void mockStandbyTaskExpectations() {
-        mockThreadMetadataProvider(taskId0Assignment, Collections.<TaskId, Set<TopicPartition>>emptyMap());
         expect(standbyTaskCreator.createTasks(EasyMock.<Consumer<byte[], byte[]>>anyObject(),
                                                    EasyMock.eq(taskId0Assignment)))
                 .andReturn(Collections.singletonList(standbyTask));
 
     }
 
-    @SuppressWarnings("unchecked")
     private void mockSingleActiveTask() {
-        mockThreadMetadataProvider(Collections.<TaskId, Set<TopicPartition>>emptyMap(), taskId0Assignment);
-
-        expect(activeTaskCreator.createTasks(EasyMock.anyObject(Consumer.class),
+        expect(activeTaskCreator.createTasks(EasyMock.<Consumer<byte[], byte[]>>anyObject(),
                                                   EasyMock.eq(taskId0Assignment)))
                 .andReturn(Collections.singletonList(streamTask));
 
     }
 
-    private void mockThreadMetadataProvider(final Map<TaskId, Set<TopicPartition>> standbyAssignment,
-                                            final Map<TaskId, Set<TopicPartition>> activeAssignment) {
-        expect(threadMetadataProvider.standbyTasks())
-                .andReturn(standbyAssignment)
-                .anyTimes();
-        expect(threadMetadataProvider.activeTasks())
-                .andReturn(activeAssignment)
-                .anyTimes();
+    private void mockTopologyBuilder() {
+        expect(activeTaskCreator.builder()).andReturn(topologyBuilder).anyTimes();
+        expect(topologyBuilder.sourceTopicPattern()).andReturn(Pattern.compile("abc"));
+        expect(topologyBuilder.subscriptionUpdates()).andReturn(subscriptionUpdates);
     }
-
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/test/MockInternalTopicManager.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/test/MockInternalTopicManager.java b/streams/src/test/java/org/apache/kafka/test/MockInternalTopicManager.java
index 3908305..598ca8d 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockInternalTopicManager.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockInternalTopicManager.java
@@ -38,7 +38,7 @@ public class MockInternalTopicManager extends InternalTopicManager {
     private MockConsumer<byte[], byte[]> restoreConsumer;
 
     public MockInternalTopicManager(StreamsConfig streamsConfig, MockConsumer<byte[], byte[]> restoreConsumer) {
-        super(StreamsKafkaClient.create(streamsConfig), 0, 0, new MockTime());
+        super(StreamsKafkaClient.create(streamsConfig.originals()), 0, 0, new MockTime());
 
         this.restoreConsumer = restoreConsumer;
     }


[3/3] kafka git commit: KAFKA-6170; KIP-220 Part 2: Break dependency of Assignor on StreamThread

Posted by gu...@apache.org.
KAFKA-6170; KIP-220 Part 2: Break dependency of Assignor on StreamThread

This refactoring is discussed in https://github.com/apache/kafka/pull/3624#discussion_r132614639. More specifically:

1. Moved the access of `StreamThread` in `StreamPartitionAssignor` to `TaskManager`, removed the fields stored in `StreamThread`, such as `processId` and `clientId`, that are used only by `StreamPartitionAssignor`, and passed them to `TaskManager` where necessary.
2. Moved the in-memory state (`metadataWithInternalTopics`, `partitionsByHostState`, `standbyTasks`, `activeTasks`) to `TaskManager`, so that `StreamPartitionAssignor` becomes a stateless thin layer that accesses `TaskManager` directly.
3. Removed the reference to `StreamPartitionAssignor` in `StreamThread` and instead consolidated all related functionality, such as `cachedTasksIds`, in `TaskManager`, where it can be accessed by both `StreamThread` and `StreamPartitionAssignor` directly.
4. Removed the two interfaces (`ThreadDataProvider` and `ThreadMetadataProvider`) that previously mediated between `StreamThread` and `StreamPartitionAssignor`.
5. Some minor fixes to log prefixes, etc.

Future work: when replacing `StreamsKafkaClient`, we will let `StreamPartitionAssignor` retrieve it from `TaskManager` directly, and the assignor will no longer need to close it (`KafkaStreams` will be responsible for closing it).
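
For readers skimming the diff below: the key mechanism is that the consumer now hands the assignor a TaskManager through an internal config entry rather than a StreamThread reference. A minimal, self-contained sketch of that handshake follows; everything except the `__task.manager.instance__` key (visible in the StreamsConfig hunk) is an illustrative stand-in, not the real Streams code.

    import java.util.HashMap;
    import java.util.Map;

    // Sketch of the configure() handshake added in this commit: the assignor
    // pulls its collaborator out of the consumer configs and fails fast when
    // it is missing or of the wrong type. TaskManagerStub stands in for the
    // real org.apache.kafka.streams.processor.internals.TaskManager.
    public class AssignorWiringSketch {
        static final String TASK_MANAGER_FOR_PARTITION_ASSIGNOR = "__task.manager.instance__";

        static class TaskManagerStub { }

        static class StatelessAssignor {
            private TaskManagerStub taskManager;

            void configure(final Map<String, ?> configs) {
                final Object o = configs.get(TASK_MANAGER_FOR_PARTITION_ASSIGNOR);
                if (!(o instanceof TaskManagerStub)) {
                    throw new IllegalStateException("TaskManager is not specified");
                }
                // all assignment state now lives behind this single reference
                taskManager = (TaskManagerStub) o;
            }
        }

        public static void main(final String[] args) {
            final Map<String, Object> configs = new HashMap<>();
            configs.put(TASK_MANAGER_FOR_PARTITION_ASSIGNOR, new TaskManagerStub());
            new StatelessAssignor().configure(configs);
            System.out.println("assignor wired to TaskManager");
        }
    }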

Author: Guozhang Wang <wa...@gmail.com>

Reviewers: Bill Bejeck <bi...@confluent.io>, Damian Guy <da...@gmail.com>, Matthias J. Sax <ma...@confluent.io>

Closes #4224 from guozhangwang/K6170-refactor-assignor


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/5df1eee7
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/5df1eee7
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/5df1eee7

Branch: refs/heads/trunk
Commit: 5df1eee7d689e18ac2f7b74410e7a30159d3afdc
Parents: 8f6a372
Author: Guozhang Wang <wa...@gmail.com>
Authored: Tue Nov 28 09:37:27 2017 -0800
Committer: Guozhang Wang <wa...@gmail.com>
Committed: Tue Nov 28 09:37:27 2017 -0800

----------------------------------------------------------------------
 .../kafka/streams/KafkaClientSupplier.java      |   3 +-
 .../org/apache/kafka/streams/KafkaStreams.java  |  88 +--
 .../org/apache/kafka/streams/StreamsConfig.java |   9 +-
 .../streams/processor/TopologyBuilder.java      |   7 +-
 .../internals/DefaultKafkaClientSupplier.java   |   1 +
 .../processor/internals/GlobalStreamThread.java |   3 +-
 .../internals/InternalTopicManager.java         |   4 -
 .../internals/InternalTopologyBuilder.java      |  45 +-
 .../internals/StreamPartitionAssignor.java      | 213 ++---
 .../processor/internals/StreamThread.java       | 402 ++++------
 .../processor/internals/StreamsKafkaClient.java |  64 +-
 .../processor/internals/TaskManager.java        | 140 +++-
 .../processor/internals/ThreadDataProvider.java |  36 -
 .../internals/ThreadMetadataProvider.java       |  36 -
 .../apache/kafka/streams/StreamsConfigTest.java |  20 +-
 .../QueryableStateIntegrationTest.java          |   2 +
 .../streams/processor/TopologyBuilderTest.java  |   6 +-
 .../internals/GlobalStreamThreadTest.java       |   2 +
 .../internals/InternalTopicManagerTest.java     |   2 +-
 .../internals/InternalTopologyBuilderTest.java  |   5 +-
 .../internals/StreamPartitionAssignorTest.java  | 295 +++----
 .../processor/internals/StreamThreadTest.java   | 769 +++++--------------
 .../internals/StreamsKafkaClientTest.java       |  43 +-
 .../processor/internals/TaskManagerTest.java    | 200 ++++-
 .../kafka/test/MockInternalTopicManager.java    |   2 +-
 25 files changed, 912 insertions(+), 1485 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/KafkaClientSupplier.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaClientSupplier.java b/streams/src/main/java/org/apache/kafka/streams/KafkaClientSupplier.java
index 5561bd1..2ea5218 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KafkaClientSupplier.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KafkaClientSupplier.java
@@ -20,7 +20,6 @@ import org.apache.kafka.clients.admin.AdminClient;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.streams.processor.StateStore;
-import org.apache.kafka.streams.processor.internals.StreamThread;
 
 import java.util.Map;
 
@@ -50,7 +49,7 @@ public interface KafkaClientSupplier {
     /**
      * Create a {@link Consumer} which is used to read records of source topics.
      *
-     * @param config {@link StreamsConfig#getConsumerConfigs(StreamThread, String, String) consumer config} which is
+     * @param config {@link StreamsConfig#getConsumerConfigs(String, String) consumer config} which is
      *               supplied by the {@link StreamsConfig} given to the {@link KafkaStreams} instance
      * @return an instance of Kafka consumer
      */

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
index 9e67f54..c7dfe71 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
@@ -50,7 +50,6 @@ import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder;
 import org.apache.kafka.streams.processor.internals.ProcessorTopology;
 import org.apache.kafka.streams.processor.internals.StateDirectory;
 import org.apache.kafka.streams.processor.internals.StreamThread;
-import org.apache.kafka.streams.processor.internals.StreamsKafkaClient;
 import org.apache.kafka.streams.processor.internals.StreamsMetadataState;
 import org.apache.kafka.streams.processor.internals.ThreadStateTransitionValidator;
 import org.apache.kafka.streams.state.HostInfo;
@@ -80,8 +79,6 @@ import java.util.concurrent.TimeUnit;
 
 import static org.apache.kafka.common.utils.Utils.getHost;
 import static org.apache.kafka.common.utils.Utils.getPort;
-import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE;
-import static org.apache.kafka.streams.StreamsConfig.PROCESSING_GUARANTEE_CONFIG;
 
 /**
  * A Kafka client that allows for performing continuous computation on input coming from one or more input topics and
@@ -133,6 +130,7 @@ public class KafkaStreams {
     // in userData of the subscription request to allow assignor be aware
     // of the co-location of stream thread's consumers. It is for internal
     // usage only and should not be exposed to users at all.
+    private final Time time;
     private final Logger log;
     private final UUID processId;
     private final String clientId;
@@ -214,7 +212,7 @@ public class KafkaStreams {
     private volatile State state = State.CREATED;
 
     private boolean waitOnState(final State targetState, final long waitMs) {
-        long begin = System.currentTimeMillis();
+        long begin = time.milliseconds();
         synchronized (stateLock) {
             long elapsedMs = 0L;
             while (state != State.NOT_RUNNING) {
@@ -235,7 +233,7 @@ public class KafkaStreams {
                     log.debug("Cannot transit to {} within {}ms", targetState, waitMs);
                     return false;
                 }
-                elapsedMs = System.currentTimeMillis() - begin;
+                elapsedMs = time.milliseconds() - begin;
             }
             return true;
         }
@@ -587,62 +585,66 @@ public class KafkaStreams {
                          final StreamsConfig config,
                          final KafkaClientSupplier clientSupplier) throws StreamsException {
         this.config = config;
+        time = Time.SYSTEM;
 
         // The application ID is a required config and hence should always have value
         processId = UUID.randomUUID();
-        final String clientId = config.getString(StreamsConfig.CLIENT_ID_CONFIG);
+        final String userClientId = config.getString(StreamsConfig.CLIENT_ID_CONFIG);
         final String applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG);
-        if (clientId.length() <= 0) {
-            this.clientId = applicationId + "-" + processId;
+        if (userClientId.length() <= 0) {
+            clientId = applicationId + "-" + processId;
         } else {
-            this.clientId = clientId;
+            clientId = userClientId;
         }
 
         final LogContext logContext = new LogContext(String.format("stream-client [%s] ", clientId));
         this.log = logContext.logger(getClass());
 
-        internalTopologyBuilder.setApplicationId(applicationId);
-        // sanity check to fail-fast in case we cannot build a ProcessorTopology due to an exception
-        internalTopologyBuilder.build(null);
-
-        long cacheSize = config.getLong(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG);
-        if (cacheSize < 0) {
-            cacheSize = 0;
-            log.warn("Negative cache size passed in. Reverting to cache size of 0 bytes.");
-        }
-
-        final StateRestoreListener delegatingStateRestoreListener = new DelegatingStateRestoreListener();
-
-        threads = new StreamThread[config.getInt(StreamsConfig.NUM_STREAM_THREADS_CONFIG)];
         try {
-            stateDirectory = new StateDirectory(
-                config,
-                Time.SYSTEM);
+            stateDirectory = new StateDirectory(config, time);
         } catch (final ProcessorStateException fatal) {
             throw new StreamsException(fatal);
         }
-        streamsMetadataState = new StreamsMetadataState(
-            internalTopologyBuilder,
-            parseHostInfo(config.getString(StreamsConfig.APPLICATION_SERVER_CONFIG)));
 
         final MetricConfig metricConfig = new MetricConfig().samples(config.getInt(StreamsConfig.METRICS_NUM_SAMPLES_CONFIG))
             .recordLevel(Sensor.RecordingLevel.forName(config.getString(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG)))
             .timeWindow(config.getLong(StreamsConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG), TimeUnit.MILLISECONDS);
         final List<MetricsReporter> reporters = config.getConfiguredInstances(StreamsConfig.METRIC_REPORTER_CLASSES_CONFIG,
-            MetricsReporter.class);
+                MetricsReporter.class);
         reporters.add(new JmxReporter(JMX_PREFIX));
-        metrics = new Metrics(metricConfig, reporters, Time.SYSTEM);
+        metrics = new Metrics(metricConfig, reporters, time);
 
-        GlobalStreamThread.State globalThreadState = null;
+        internalTopologyBuilder.setApplicationId(applicationId);
+
+        // sanity check to fail-fast in case we cannot build a ProcessorTopology due to an exception
+        internalTopologyBuilder.build();
+
+        streamsMetadataState = new StreamsMetadataState(
+                internalTopologyBuilder,
+                parseHostInfo(config.getString(StreamsConfig.APPLICATION_SERVER_CONFIG)));
+
+        // create the stream thread, global update thread, and cleanup thread
+        threads = new StreamThread[config.getInt(StreamsConfig.NUM_STREAM_THREADS_CONFIG)];
+
+        long totalCacheSize = config.getLong(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG);
+        if (totalCacheSize < 0) {
+            totalCacheSize = 0;
+            log.warn("Negative cache size passed in. Reverting to cache size of 0 bytes.");
+        }
         final ProcessorTopology globalTaskTopology = internalTopologyBuilder.buildGlobalStateTopology();
+        final long cacheSizePerThread = totalCacheSize / (threads.length + (globalTaskTopology == null ? 0 : 1));
+
+        final StateRestoreListener delegatingStateRestoreListener = new DelegatingStateRestoreListener();
+        GlobalStreamThread.State globalThreadState = null;
         if (globalTaskTopology != null) {
             final String globalThreadId = clientId + "-GlobalStreamThread";
             globalStreamThread = new GlobalStreamThread(globalTaskTopology,
                                                         config,
                                                         clientSupplier.getRestoreConsumer(config.getRestoreConsumerConfigs(clientId + "-global")),
                                                         stateDirectory,
+                                                        cacheSizePerThread,
                                                         metrics,
-                                                        Time.SYSTEM,
+                                                        time,
                                                         globalThreadId,
                                                         delegatingStateRestoreListener);
             globalThreadState = globalStreamThread.state();
@@ -661,9 +663,9 @@ public class KafkaStreams {
                                              processId,
                                              clientId,
                                              metrics,
-                                             Time.SYSTEM,
+                                             time,
                                              streamsMetadataState,
-                                             cacheSize / (threads.length + (globalTaskTopology == null ? 0 : 1)),
+                                             cacheSizePerThread,
                                              stateDirectory,
                                              delegatingStateRestoreListener);
             threadState.put(threads[i].getId(), threads[i].state());
@@ -706,22 +708,6 @@ public class KafkaStreams {
     }
 
     /**
-     * Check if the used brokers have version 0.10.1.x or higher.
-     * <p>
-     * Note, for <em>pre</em> 0.10.x brokers the broker version cannot be checked and the client will hang and retry
-     * until it {@link StreamsConfig#REQUEST_TIMEOUT_MS_CONFIG times out}.
-     *
-     * @throws StreamsException if brokers have version 0.10.0.x
-     */
-    private void checkBrokerVersionCompatibility() throws StreamsException {
-        final StreamsKafkaClient client = StreamsKafkaClient.create(config);
-
-        client.checkBrokerCompatibility(EXACTLY_ONCE.equals(config.getString(PROCESSING_GUARANTEE_CONFIG)));
-
-        client.close();
-    }
-
-    /**
      * Start the {@code KafkaStreams} instance by starting all its threads.
      * This function is expected to be called only once during the life cycle of the client.
      * <p>
@@ -745,8 +731,6 @@ public class KafkaStreams {
         // first set state to RUNNING before kicking off the threads,
         // making sure the state will always transit to RUNNING before REBALANCING
         if (setRunningFromCreated()) {
-            checkBrokerVersionCompatibility();
-
             if (globalStreamThread != null) {
                 globalStreamThread.start();
             }
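
Worked example for the cacheSizePerThread formula in the hunk above: with CACHE_MAX_BYTES_BUFFERING_CONFIG = 10485760 (10 MB), NUM_STREAM_THREADS_CONFIG = 3, and a global task topology present, the divisor is 3 + 1 = 4, so

    cacheSizePerThread = 10485760 / 4 = 2621440 bytes (2.5 MB per thread)

The same per-thread budget is now passed into GlobalStreamThread instead of being recomputed there. (Numbers are hypothetical, chosen so the integer division is exact; any remainder would simply be truncated.)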

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
index 941437c..e1f2b09 100644
--- a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
+++ b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
@@ -36,7 +36,6 @@ import org.apache.kafka.streams.processor.DefaultPartitionGrouper;
 import org.apache.kafka.streams.processor.FailOnInvalidTimestamp;
 import org.apache.kafka.streams.processor.TimestampExtractor;
 import org.apache.kafka.streams.processor.internals.StreamPartitionAssignor;
-import org.apache.kafka.streams.processor.internals.StreamThread;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -582,7 +581,7 @@ public class StreamsConfig extends AbstractConfig {
     }
 
     public static class InternalConfig {
-        public static final String STREAM_THREAD_INSTANCE = "__stream.thread.instance__";
+        public static final String TASK_MANAGER_FOR_PARTITION_ASSIGNOR = "__task.manager.instance__";
     }
 
     /**
@@ -722,22 +721,20 @@ public class StreamsConfig extends AbstractConfig {
      * except in the case of {@link ConsumerConfig#BOOTSTRAP_SERVERS_CONFIG} where we always use the non-prefixed
      * version as we only support reading/writing from/to the same Kafka Cluster.
      *
-     * @param streamThread the {@link StreamThread} creating a consumer
      * @param groupId      consumer groupId
      * @param clientId     clientId
      * @return Map of the consumer configuration.
      */
-    public Map<String, Object> getConsumerConfigs(final StreamThread streamThread,
-                                                  final String groupId,
+    public Map<String, Object> getConsumerConfigs(final String groupId,
                                                   final String clientId) {
         final Map<String, Object> consumerProps = getCommonConsumerConfigs();
 
         // add client id with stream client id prefix, and group id
+        consumerProps.put(APPLICATION_ID_CONFIG, groupId);
         consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, groupId);
         consumerProps.put(CommonClientConfigs.CLIENT_ID_CONFIG, clientId + "-consumer");
 
         // add configs required for stream partition assignor
-        consumerProps.put(InternalConfig.STREAM_THREAD_INSTANCE, streamThread);
         consumerProps.put(REPLICATION_FACTOR_CONFIG, getInt(REPLICATION_FACTOR_CONFIG));
         consumerProps.put(NUM_STANDBY_REPLICAS_CONFIG, getInt(NUM_STANDBY_REPLICAS_CONFIG));
         consumerProps.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, StreamPartitionAssignor.class.getName());

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/TopologyBuilder.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/TopologyBuilder.java b/streams/src/main/java/org/apache/kafka/streams/processor/TopologyBuilder.java
index 66dfa27..6f34e25 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/TopologyBuilder.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/TopologyBuilder.java
@@ -34,6 +34,7 @@ import org.apache.kafka.streams.state.KeyValueStore;
 
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -932,7 +933,9 @@ public class TopologyBuilder {
      * for the high-level DSL parsing functionalities.
      */
     public SubscriptionUpdates subscriptionUpdates() {
-        return internalTopologyBuilder.subscriptionUpdates();
+        SubscriptionUpdates clonedSubscriptionUpdates = new SubscriptionUpdates();
+        clonedSubscriptionUpdates.updateTopics(internalTopologyBuilder.subscriptionUpdates().getUpdates());
+        return clonedSubscriptionUpdates;
     }
 
     /**
@@ -949,7 +952,7 @@ public class TopologyBuilder {
      */
     public synchronized void updateSubscriptions(final SubscriptionUpdates subscriptionUpdates,
                                                  final String threadId) {
-        internalTopologyBuilder.updateSubscriptions(subscriptionUpdates, threadId);
+        internalTopologyBuilder.updateSubscribedTopics(new HashSet<>(subscriptionUpdates.getUpdates()), "stream-thread [" + threadId + "] ");
     }
 
 }
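
Note on the subscriptionUpdates() change above: callers now receive a freshly populated copy rather than the builder's internal instance, so the returned object is a point-in-time snapshot and subsequent subscription updates inside InternalTopologyBuilder are not visible through it.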

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultKafkaClientSupplier.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultKafkaClientSupplier.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultKafkaClientSupplier.java
index f3038f3..6f01e2f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultKafkaClientSupplier.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultKafkaClientSupplier.java
@@ -30,6 +30,7 @@ import org.apache.kafka.streams.KafkaClientSupplier;
 public class DefaultKafkaClientSupplier implements KafkaClientSupplier {
     @Override
     public AdminClient getAdminClient(final Map<String, Object> config) {
+        // create a new client upon each call; we expect this to be triggered only once, so this should be fine
         return AdminClient.create(config);
     }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
index 24cec25..9d202d1 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
@@ -176,6 +176,7 @@ public class GlobalStreamThread extends Thread {
                               final StreamsConfig config,
                               final Consumer<byte[], byte[]> globalConsumer,
                               final StateDirectory stateDirectory,
+                              final long cacheSizeBytes,
                               final Metrics metrics,
                               final Time time,
                               final String threadClientId,
@@ -186,8 +187,6 @@ public class GlobalStreamThread extends Thread {
         this.topology = topology;
         this.globalConsumer = globalConsumer;
         this.stateDirectory = stateDirectory;
-        long cacheSizeBytes = Math.max(0, config.getLong(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG) /
-                (config.getInt(StreamsConfig.NUM_STREAM_THREADS_CONFIG) + 1));
         this.streamsMetrics = new StreamsMetricsImpl(metrics, threadClientId, Collections.singletonMap("client-id", threadClientId));
         this.logPrefix = String.format("global-stream-thread [%s] ", threadClientId);
         this.logContext = new LogContext(logPrefix);

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java
index ae2b375..f8d4eec 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java
@@ -108,10 +108,6 @@ public class InternalTopicManager {
         throw new StreamsException("Could not get number of partitions.");
     }
 
-    public void close() {
-        streamsKafkaClient.close();
-    }
-
     /**
      * Check the existing topics to have correct number of partitions; and return the non existing topics to be created
      */

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java
index f2cbf51..881ecd1 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java
@@ -25,7 +25,6 @@ import org.apache.kafka.streams.processor.ProcessorSupplier;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.processor.TimestampExtractor;
-import org.apache.kafka.streams.processor.internals.StreamPartitionAssignor.SubscriptionUpdates;
 import org.apache.kafka.streams.state.KeyValueStore;
 import org.apache.kafka.streams.state.StoreBuilder;
 import org.apache.kafka.streams.state.internals.WindowStoreBuilder;
@@ -842,6 +841,10 @@ public class InternalTopologyBuilder {
         return nodeGroups;
     }
 
+    public synchronized ProcessorTopology build() {
+        return build((Integer) null);
+    }
+
     public synchronized ProcessorTopology build(final Integer topicGroupId) {
         final Set<String> nodeGroup;
         if (topicGroupId != null) {
@@ -1246,9 +1249,9 @@ public class InternalTopologyBuilder {
     }
 
     public synchronized void updateSubscriptions(final SubscriptionUpdates subscriptionUpdates,
-                                                 final String threadId) {
-        log.debug("stream-thread [{}] updating builder with {} topic(s) with possible matching regex subscription(s)",
-            threadId, subscriptionUpdates);
+                                                 final String logPrefix) {
+        log.debug("{}updating builder with {} topic(s) with possible matching regex subscription(s)",
+                logPrefix, subscriptionUpdates);
         this.subscriptionUpdates = subscriptionUpdates;
         setRegexMatchedTopicsToSourceNodes();
         setRegexMatchedTopicToStateStore();
@@ -1811,4 +1814,38 @@ public class InternalTopologyBuilder {
         return sb.toString();
     }
 
+    /**
+     * Used to capture subscribed topic via Patterns discovered during the
+     * partition assignment process.
+     */
+    public static class SubscriptionUpdates {
+
+        private final Set<String> updatedTopicSubscriptions = new HashSet<>();
+
+        private void updateTopics(final Collection<String> topicNames) {
+            updatedTopicSubscriptions.clear();
+            updatedTopicSubscriptions.addAll(topicNames);
+        }
+
+        public Collection<String> getUpdates() {
+            return Collections.unmodifiableSet(updatedTopicSubscriptions);
+        }
+
+        boolean hasUpdates() {
+            return !updatedTopicSubscriptions.isEmpty();
+        }
+
+        @Override
+        public String toString() {
+            return String.format("SubscriptionUpdates{updatedTopicSubscriptions=%s}", updatedTopicSubscriptions);
+        }
+    }
+
+    public void updateSubscribedTopics(final Set<String> topics, final String logPrefix) {
+        final SubscriptionUpdates subscriptionUpdates = new SubscriptionUpdates();
+        log.debug("{}found {} topics possibly matching regex", topics, logPrefix);
+        // update the topic groups with the returned subscription set for regex pattern subscriptions
+        subscriptionUpdates.updateTopics(topics);
+        updateSubscriptions(subscriptionUpdates, logPrefix);
+    }
 }
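
To make the replace-don't-accumulate semantics of SubscriptionUpdates concrete, here is a self-contained toy version (all names are stand-ins for the classes added above; only the clear-then-addAll behavior is the point):

    import java.util.Arrays;
    import java.util.Collection;
    import java.util.Collections;
    import java.util.HashSet;
    import java.util.Set;

    // Toy model of SubscriptionUpdates: each assignment replaces the previous
    // regex-discovered topic set instead of accumulating across rebalances.
    public class SubscriptionUpdatesSketch {
        private final Set<String> updatedTopicSubscriptions = new HashSet<>();

        void updateTopics(final Collection<String> topicNames) {
            updatedTopicSubscriptions.clear();              // drop the old snapshot
            updatedTopicSubscriptions.addAll(topicNames);   // install the new one
        }

        Collection<String> getUpdates() {
            return Collections.unmodifiableSet(updatedTopicSubscriptions);
        }

        public static void main(final String[] args) {
            final SubscriptionUpdatesSketch updates = new SubscriptionUpdatesSketch();
            updates.updateTopics(Arrays.asList("clicks", "clicks-eu"));  // first assignment
            updates.updateTopics(Collections.singletonList("clicks"));   // later rebalance
            System.out.println(updates.getUpdates());                    // prints [clicks]
        }
    }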

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignor.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignor.java
index 9e505a1..ec42a86 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignor.java
@@ -30,6 +30,7 @@ import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.TaskAssignmentException;
+import org.apache.kafka.streams.processor.PartitionGrouper;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
 import org.apache.kafka.streams.processor.internals.assignment.ClientState;
@@ -52,9 +53,8 @@ import java.util.UUID;
 
 import static org.apache.kafka.common.utils.Utils.getHost;
 import static org.apache.kafka.common.utils.Utils.getPort;
-import static org.apache.kafka.streams.processor.internals.InternalTopicManager.WINDOW_CHANGE_LOG_ADDITIONAL_RETENTION_DEFAULT;
 
-public class StreamPartitionAssignor implements PartitionAssignor, Configurable, ThreadMetadataProvider {
+public class StreamPartitionAssignor implements PartitionAssignor, Configurable {
 
     private Time time = Time.SYSTEM;
     private final static int UNKNOWN = -1;
@@ -164,21 +164,16 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
             if (result != 0) {
                 return result;
             } else {
-                return p1.partition() < p2.partition() ? UNKNOWN : (p1.partition() > p2.partition() ? 1 : 0);
+                return Integer.compare(p1.partition(), p2.partition());
             }
         }
     };
 
-    private ThreadDataProvider threadDataProvider;
-
     private String userEndPoint;
     private int numStandbyReplicas;
 
-    private Cluster metadataWithInternalTopics;
-    private Map<HostInfo, Set<TopicPartition>> partitionsByHostState;
-
-    private Map<TaskId, Set<TopicPartition>> standbyTasks;
-    private Map<TaskId, Set<TopicPartition>> activeTasks;
+    private TaskManager taskManager;
+    private PartitionGrouper partitionGrouper;
 
     private InternalTopicManager internalTopicManager;
     private CopartitionedTopicsValidator copartitionedTopicsValidator;
@@ -199,30 +194,33 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
      */
     @Override
     public void configure(Map<String, ?> configs) {
-        numStandbyReplicas = (Integer) configs.get(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG);
+        final StreamsConfig streamsConfig = new StreamsConfig(configs);
 
         // Setting the logger with the passed in client thread name
-        logPrefix = String.format("stream-thread [%s] ", configs.get(CommonClientConfigs.CLIENT_ID_CONFIG));
+        logPrefix = String.format("stream-thread [%s] ", streamsConfig.getString(CommonClientConfigs.CLIENT_ID_CONFIG));
         final LogContext logContext = new LogContext(logPrefix);
-        this.log = logContext.logger(getClass());
+        log = logContext.logger(getClass());
 
-        Object o = configs.get(StreamsConfig.InternalConfig.STREAM_THREAD_INSTANCE);
+        final Object o = configs.get(StreamsConfig.InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR);
         if (o == null) {
-            KafkaException ex = new KafkaException("StreamThread is not specified");
+            KafkaException ex = new KafkaException("TaskManager is not specified");
             log.error(ex.getMessage(), ex);
             throw ex;
         }
 
-        if (!(o instanceof ThreadDataProvider)) {
-            KafkaException ex = new KafkaException(String.format("%s is not an instance of %s", o.getClass().getName(), ThreadDataProvider.class.getName()));
+        if (!(o instanceof TaskManager)) {
+            KafkaException ex = new KafkaException(String.format("%s is not an instance of %s", o.getClass().getName(), TaskManager.class.getName()));
             log.error(ex.getMessage(), ex);
             throw ex;
         }
 
-        threadDataProvider = (ThreadDataProvider) o;
-        threadDataProvider.setThreadMetadataProvider(this);
+        taskManager = (TaskManager) o;
+
+        numStandbyReplicas = streamsConfig.getInt(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG);
 
-        String userEndPoint = (String) configs.get(StreamsConfig.APPLICATION_SERVER_CONFIG);
+        partitionGrouper = streamsConfig.getConfiguredInstance(StreamsConfig.PARTITION_GROUPER_CLASS_CONFIG, PartitionGrouper.class);
+
+        final String userEndPoint = streamsConfig.getString(StreamsConfig.APPLICATION_SERVER_CONFIG);
         if (userEndPoint != null && !userEndPoint.isEmpty()) {
             try {
                 String host = getHost(userEndPoint);
@@ -241,13 +239,12 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
         }
 
         internalTopicManager = new InternalTopicManager(
-                StreamsKafkaClient.create(this.threadDataProvider.config()),
-                configs.containsKey(StreamsConfig.REPLICATION_FACTOR_CONFIG) ? (Integer) configs.get(StreamsConfig.REPLICATION_FACTOR_CONFIG) : 1,
-                configs.containsKey(StreamsConfig.WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG) ?
-                        (Long) configs.get(StreamsConfig.WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG)
-                        : WINDOW_CHANGE_LOG_ADDITIONAL_RETENTION_DEFAULT, time);
+                taskManager.streamsKafkaClient,
+                streamsConfig.getInt(StreamsConfig.REPLICATION_FACTOR_CONFIG),
+                streamsConfig.getLong(StreamsConfig.WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG),
+                time);
 
-        this.copartitionedTopicsValidator = new CopartitionedTopicsValidator(threadDataProvider.name());
+        copartitionedTopicsValidator = new CopartitionedTopicsValidator(logPrefix);
     }
 
     @Override
@@ -262,27 +259,16 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
         // 2. Task ids of previously running tasks
         // 3. Task ids of valid local states on the client's state directory.
 
-        final Set<TaskId> previousActiveTasks = threadDataProvider.prevActiveTasks();
-        Set<TaskId> standbyTasks = threadDataProvider.cachedTasks();
+        final Set<TaskId> previousActiveTasks = taskManager.prevActiveTaskIds();
+        final Set<TaskId> standbyTasks = taskManager.cachedTasksIds();
         standbyTasks.removeAll(previousActiveTasks);
-        SubscriptionInfo data = new SubscriptionInfo(threadDataProvider.processId(), previousActiveTasks, standbyTasks, this.userEndPoint);
+        final SubscriptionInfo data = new SubscriptionInfo(taskManager.processId(), previousActiveTasks, standbyTasks, this.userEndPoint);
 
-        if (threadDataProvider.builder().sourceTopicPattern() != null &&
-            !threadDataProvider.builder().subscriptionUpdates().getUpdates().equals(topics)) {
-            updateSubscribedTopics(topics);
-        }
+        taskManager.updateSubscriptionsFromMetadata(topics);
 
         return new Subscription(new ArrayList<>(topics), data.encode());
     }
 
-    private void updateSubscribedTopics(Set<String> topics) {
-        SubscriptionUpdates subscriptionUpdates = new SubscriptionUpdates();
-        log.debug("found {} topics possibly matching regex", topics);
-        // update the topic groups with the returned subscription set for regex pattern subscriptions
-        subscriptionUpdates.updateTopics(topics);
-        threadDataProvider.builder().updateSubscriptions(subscriptionUpdates, threadDataProvider.name());
-    }
-
     /*
      * This assigns tasks to consumer clients in the following steps.
      *
@@ -333,9 +319,9 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
         // parse the topology to determine the repartition source topics,
         // making sure they are created with the number of partitions as
         // the maximum of the depending sub-topologies source topics' number of partitions
-        Map<Integer, InternalTopologyBuilder.TopicsInfo> topicGroups = threadDataProvider.builder().topicGroups();
+        final Map<Integer, InternalTopologyBuilder.TopicsInfo> topicGroups = taskManager.builder().topicGroups();
 
-        Map<String, InternalTopicMetadata> repartitionTopicMetadata = new HashMap<>();
+        final Map<String, InternalTopicMetadata> repartitionTopicMetadata = new HashMap<>();
         for (InternalTopologyBuilder.TopicsInfo topicsInfo : topicGroups.values()) {
             for (InternalTopicConfig topic: topicsInfo.repartitionSourceTopics.values()) {
                 repartitionTopicMetadata.put(topic.name(), new InternalTopicMetadata(topic));
@@ -353,7 +339,7 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
                     // try set the number of partitions for this repartition topic if it is not set yet
                     if (numPartitions == UNKNOWN) {
                         for (InternalTopologyBuilder.TopicsInfo otherTopicsInfo : topicGroups.values()) {
-                            Set<String> otherSinkTopics = otherTopicsInfo.sinkTopics;
+                            final Set<String> otherSinkTopics = otherTopicsInfo.sinkTopics;
 
                             if (otherSinkTopics.contains(topicName)) {
                                 // if this topic is one of the sink topics of this topology,
@@ -391,10 +377,10 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
 
         // augment the metadata with the newly computed number of partitions for all the
         // repartition source topics
-        Map<TopicPartition, PartitionInfo> allRepartitionTopicPartitions = new HashMap<>();
+        final Map<TopicPartition, PartitionInfo> allRepartitionTopicPartitions = new HashMap<>();
         for (Map.Entry<String, InternalTopicMetadata> entry : repartitionTopicMetadata.entrySet()) {
-            String topic = entry.getKey();
-            Integer numPartitions = entry.getValue().numPartitions;
+            final String topic = entry.getKey();
+            final Integer numPartitions = entry.getValue().numPartitions;
 
             for (int partition = 0; partition < numPartitions; partition++) {
                 allRepartitionTopicPartitions.put(new TopicPartition(topic, partition),
@@ -405,34 +391,34 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
         // ensure the co-partitioning topics within the group have the same number of partitions,
         // and enforce the number of partitions for those repartition topics to be the same if they
         // are co-partitioned as well.
-        ensureCopartitioning(threadDataProvider.builder().copartitionGroups(), repartitionTopicMetadata, metadata);
+        ensureCopartitioning(taskManager.builder().copartitionGroups(), repartitionTopicMetadata, metadata);
 
         // make sure the repartition source topics exist with the right number of partitions,
         // create these topics if necessary
         prepareTopic(repartitionTopicMetadata);
 
-        metadataWithInternalTopics = metadata.withPartitions(allRepartitionTopicPartitions);
+        final Cluster fullMetadata = metadata.withPartitions(allRepartitionTopicPartitions);
+        taskManager.setClusterMetadata(fullMetadata);
 
         log.debug("Created repartition topics {} from the parsed topology.", allRepartitionTopicPartitions.values());
 
         // ---------------- Step One ---------------- //
 
         // get the tasks as partition groups from the partition grouper
-        Set<String> allSourceTopics = new HashSet<>();
-        Map<Integer, Set<String>> sourceTopicsByGroup = new HashMap<>();
+        final Set<String> allSourceTopics = new HashSet<>();
+        final Map<Integer, Set<String>> sourceTopicsByGroup = new HashMap<>();
         for (Map.Entry<Integer, InternalTopologyBuilder.TopicsInfo> entry : topicGroups.entrySet()) {
             allSourceTopics.addAll(entry.getValue().sourceTopics);
             sourceTopicsByGroup.put(entry.getKey(), entry.getValue().sourceTopics);
         }
 
-        Map<TaskId, Set<TopicPartition>> partitionsForTask = threadDataProvider.partitionGrouper().partitionGroups(
-                sourceTopicsByGroup, metadataWithInternalTopics);
+        final Map<TaskId, Set<TopicPartition>> partitionsForTask = partitionGrouper.partitionGroups(sourceTopicsByGroup, fullMetadata);
 
         // check if all partitions are assigned, and there are no duplicates of partitions in multiple tasks
-        Set<TopicPartition> allAssignedPartitions = new HashSet<>();
-        Map<Integer, Set<TaskId>> tasksByTopicGroup = new HashMap<>();
+        final Set<TopicPartition> allAssignedPartitions = new HashSet<>();
+        final Map<Integer, Set<TaskId>> tasksByTopicGroup = new HashMap<>();
         for (Map.Entry<TaskId, Set<TopicPartition>> entry : partitionsForTask.entrySet()) {
-            Set<TopicPartition> partitions = entry.getValue();
+            final Set<TopicPartition> partitions = entry.getValue();
             for (TopicPartition partition : partitions) {
                 if (allAssignedPartitions.contains(partition)) {
                     log.warn("Partition {} is assigned to more than one tasks: {}", partition, partitionsForTask);
@@ -440,7 +426,7 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
             }
             allAssignedPartitions.addAll(partitions);
 
-            TaskId id = entry.getKey();
+            final TaskId id = entry.getKey();
             Set<TaskId> ids = tasksByTopicGroup.get(id.topicGroupId);
             if (ids == null) {
                 ids = new HashSet<>();
@@ -449,10 +435,10 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
             ids.add(id);
         }
         for (String topic : allSourceTopics) {
-            List<PartitionInfo> partitionInfoList = metadataWithInternalTopics.partitionsForTopic(topic);
+            final List<PartitionInfo> partitionInfoList = fullMetadata.partitionsForTopic(topic);
             if (!partitionInfoList.isEmpty()) {
                 for (PartitionInfo partitionInfo : partitionInfoList) {
-                    TopicPartition partition = new TopicPartition(partitionInfo.topic(), partitionInfo.partition());
+                    final TopicPartition partition = new TopicPartition(partitionInfo.topic(), partitionInfo.partition());
                     if (!allAssignedPartitions.contains(partition)) {
                         log.warn("Partition {} is not assigned to any tasks: {}", partition, partitionsForTask);
                     }
@@ -463,7 +449,7 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
         }
 
         // add tasks to state change log topic subscribers
-        Map<String, InternalTopicMetadata> changelogTopicMetadata = new HashMap<>();
+        final Map<String, InternalTopicMetadata> changelogTopicMetadata = new HashMap<>();
         for (Map.Entry<Integer, InternalTopologyBuilder.TopicsInfo> entry : topicGroups.entrySet()) {
             final int topicGroupId = entry.getKey();
             final Map<String, InternalTopicConfig> stateChangelogTopics = entry.getValue().stateChangelogTopics;
@@ -476,7 +462,7 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
                         if (numPartitions < task.partition + 1)
                             numPartitions = task.partition + 1;
                     }
-                    InternalTopicMetadata topicMetadata = new InternalTopicMetadata(topicConfig);
+                    final InternalTopicMetadata topicMetadata = new InternalTopicMetadata(topicConfig);
                     topicMetadata.numPartitions = numPartitions;
 
                     changelogTopicMetadata.put(topicConfig.name(), topicMetadata);
@@ -493,7 +479,7 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
         // ---------------- Step Two ---------------- //
 
         // assign tasks to clients
-        Map<UUID, ClientState> states = new HashMap<>();
+        final Map<UUID, ClientState> states = new HashMap<>();
         for (Map.Entry<UUID, ClientMetadata> entry : clientsMetadata.entrySet()) {
             states.put(entry.getKey(), entry.getValue().state);
         }
@@ -509,9 +495,9 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
         // ---------------- Step Three ---------------- //
 
         // construct the global partition assignment per host map
-        partitionsByHostState = new HashMap<>();
+        final Map<HostInfo, Set<TopicPartition>> partitionsByHostState = new HashMap<>();
         for (Map.Entry<UUID, ClientMetadata> entry : clientsMetadata.entrySet()) {
-            HostInfo hostInfo = entry.getValue().hostInfo;
+            final HostInfo hostInfo = entry.getValue().hostInfo;
 
             if (hostInfo != null) {
                 final Set<TopicPartition> topicPartitions = new HashSet<>();
@@ -524,9 +510,10 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
                 partitionsByHostState.put(hostInfo, topicPartitions);
             }
         }
+        taskManager.setPartitionsByHostState(partitionsByHostState);
 
         // within the client, distribute tasks to its owned consumers
-        Map<String, Assignment> assignment = new HashMap<>();
+        final Map<String, Assignment> assignment = new HashMap<>();
         for (Map.Entry<UUID, ClientMetadata> entry : clientsMetadata.entrySet()) {
             final Set<String> consumers = entry.getValue().consumers;
             final ClientState state = entry.getValue().state;
@@ -541,12 +528,12 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
 
             int i = 0;
             for (String consumer : consumers) {
-                Map<TaskId, Set<TopicPartition>> standby = new HashMap<>();
-                ArrayList<AssignedPartition> assignedPartitions = new ArrayList<>();
+                final Map<TaskId, Set<TopicPartition>> standby = new HashMap<>();
+                final ArrayList<AssignedPartition> assignedPartitions = new ArrayList<>();
 
                 final int numTaskIds = taskIds.size();
                 for (int j = i; j < numTaskIds; j += numConsumers) {
-                    TaskId taskId = taskIds.get(j);
+                    final TaskId taskId = taskIds.get(j);
                     if (j < numActiveTasks) {
                         for (TopicPartition partition : partitionsForTask.get(taskId)) {
                             assignedPartitions.add(new AssignedPartition(taskId, partition));
@@ -562,8 +549,8 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
                 }
 
                 Collections.sort(assignedPartitions);
-                List<TaskId> active = new ArrayList<>();
-                List<TopicPartition> activePartitions = new ArrayList<>();
+                final List<TaskId> active = new ArrayList<>();
+                final List<TopicPartition> activePartitions = new ArrayList<>();
                 for (AssignedPartition partition : assignedPartitions) {
                     active.add(partition.taskId);
                     activePartitions.add(partition.partition);
@@ -588,8 +575,7 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
 
         AssignmentInfo info = AssignmentInfo.decode(assignment.userData());
 
-        this.standbyTasks = info.standbyTasks;
-        this.activeTasks = new HashMap<>();
+        Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
 
         // the number of assigned partitions should be the same as the number of active tasks, which
         // could be duplicated if one task has more than one assigned partition
@@ -612,35 +598,22 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
             assignedPartitions.add(partition);
         }
 
-        this.partitionsByHostState = info.partitionsByHost;
-
-        final Collection<Set<TopicPartition>> values = partitionsByHostState.values();
         final Map<TopicPartition, PartitionInfo> topicToPartitionInfo = new HashMap<>();
-        for (Set<TopicPartition> value : values) {
+        for (Set<TopicPartition> value : info.partitionsByHost.values()) {
             for (TopicPartition topicPartition : value) {
                 topicToPartitionInfo.put(topicPartition, new PartitionInfo(topicPartition.topic(),
-                                                                           topicPartition.partition(),
-                                                                           null,
-                                                                           new Node[0],
-                                                                           new Node[0]));
+                        topicPartition.partition(),
+                        null,
+                        new Node[0],
+                        new Node[0]));
             }
         }
-        metadataWithInternalTopics = Cluster.empty().withPartitions(topicToPartitionInfo);
 
-        checkForNewTopicAssignments(assignment);
-    }
+        taskManager.setClusterMetadata(Cluster.empty().withPartitions(topicToPartitionInfo));
+        taskManager.setPartitionsByHostState(info.partitionsByHost);
+        taskManager.setAssignmentMetadata(activeTasks, info.standbyTasks);
 
-    private void checkForNewTopicAssignments(Assignment assignment) {
-        if (threadDataProvider.builder().sourceTopicPattern() != null) {
-            final Set<String> assignedTopics = new HashSet<>();
-            for (final TopicPartition topicPartition : assignment.partitions()) {
-                assignedTopics.add(topicPartition.topic());
-            }
-            if (!threadDataProvider.builder().subscriptionUpdates().getUpdates().containsAll(assignedTopics)) {
-                assignedTopics.addAll(threadDataProvider.builder().subscriptionUpdates().getUpdates());
-                updateSubscribedTopics(assignedTopics);
-            }
-        }
+        taskManager.updateSubscriptionsFromAssignment(partitions);
     }
 
     /**
@@ -706,53 +679,24 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
         }
     }
 
-    public Map<HostInfo, Set<TopicPartition>> getPartitionsByHostState() {
-        if (partitionsByHostState == null) {
-            return Collections.emptyMap();
-        }
-        return Collections.unmodifiableMap(partitionsByHostState);
-    }
-
-    public Cluster clusterMetadata() {
-        if (metadataWithInternalTopics == null) {
-            return Cluster.empty();
-        }
-        return metadataWithInternalTopics;
-    }
-
-    public Map<TaskId, Set<TopicPartition>> activeTasks() {
-        if (activeTasks == null) {
-            return Collections.emptyMap();
-        }
-        return Collections.unmodifiableMap(activeTasks);
-    }
-
-    public Map<TaskId, Set<TopicPartition>> standbyTasks() {
-        if (standbyTasks == null) {
-            return Collections.emptyMap();
-        }
-        return Collections.unmodifiableMap(standbyTasks);
-    }
-
-    void setInternalTopicManager(InternalTopicManager internalTopicManager) {
-        this.internalTopicManager = internalTopicManager;
-    }
-
     /**
      * Used to capture subscribed topics via Patterns discovered during the
      * partition assignment process.
+     *
 +     * // TODO: this is a duplicate of InternalTopologyBuilder#SubscriptionUpdates
 +     *          and is maintained only for compatibility with the deprecated TopologyBuilder API
      */
     public static class SubscriptionUpdates {
 
         private final Set<String> updatedTopicSubscriptions = new HashSet<>();
 
-        private  void updateTopics(Collection<String> topicNames) {
+        public void updateTopics(Collection<String> topicNames) {
             updatedTopicSubscriptions.clear();
             updatedTopicSubscriptions.addAll(topicNames);
         }
 
         public Collection<String> getUpdates() {
-            return Collections.unmodifiableSet(new HashSet<>(updatedTopicSubscriptions));
+            return Collections.unmodifiableSet(updatedTopicSubscriptions);
         }
 
         public boolean hasUpdates() {
@@ -767,15 +711,11 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
         }
     }
 
-    public void close() {
-        internalTopicManager.close();
-    }
-
     static class CopartitionedTopicsValidator {
         private final String logPrefix;
 
-        CopartitionedTopicsValidator(final String threadName) {
-            this.logPrefix = String.format("stream-thread [%s]", threadName);
+        CopartitionedTopicsValidator(final String logPrefix) {
+            this.logPrefix = logPrefix;
         }
 
         @SuppressWarnings("deprecation")
@@ -826,4 +766,9 @@ public class StreamPartitionAssignor implements PartitionAssignor, Configurable,
 
         }
     }
+
+    // following functions are for test only
+    void setInternalTopicManager(InternalTopicManager internalTopicManager) {
+        this.internalTopicManager = internalTopicManager;
+    }
 }
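
For readers following the refactor: the assignor can no longer reach into StreamThread, so StreamThread.create(...) plants the TaskManager into the consumer configs under StreamsConfig.InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR, and the configure(...) override above fishes it back out. Below is a minimal, self-contained sketch of that handoff pattern; the class names and the literal key are simplified stand-ins, not the real Streams internals.

import java.util.HashMap;
import java.util.Map;

// The consumer instantiates its PartitionAssignor reflectively, so the config
// map is the only channel through which a live object can be handed over.
interface Configurable {
    void configure(Map<String, ?> configs);
}

class TaskManager { }  // placeholder for the real TaskManager

class SketchAssignor implements Configurable {
    // hypothetical key name, standing in for the internal config constant
    static final String TASK_MANAGER_KEY = "__task.manager.instance__";
    private TaskManager taskManager;

    @Override
    public void configure(final Map<String, ?> configs) {
        final Object o = configs.get(TASK_MANAGER_KEY);
        // mirrors the null / instanceof checks in the patch above
        if (!(o instanceof TaskManager)) {
            throw new IllegalStateException("TaskManager is not specified");
        }
        taskManager = (TaskManager) o;
    }
}

public class HandoffDemo {
    public static void main(final String[] args) {
        final Map<String, Object> consumerConfigs = new HashMap<>();
        consumerConfigs.put(SketchAssignor.TASK_MANAGER_KEY, new TaskManager());

        final SketchAssignor assignor = new SketchAssignor();
        assignor.configure(consumerConfigs);  // assignor now holds the TaskManager
        System.out.println("assignor configured");
    }
}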

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 1514e26..14b912e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -42,9 +42,7 @@ import org.apache.kafka.streams.KafkaClientSupplier;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsMetrics;
 import org.apache.kafka.streams.errors.StreamsException;
-import org.apache.kafka.streams.errors.TaskIdFormatException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
-import org.apache.kafka.streams.processor.PartitionGrouper;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.TaskMetadata;
@@ -52,7 +50,6 @@ import org.apache.kafka.streams.processor.ThreadMetadata;
 import org.apache.kafka.streams.state.internals.ThreadCache;
 import org.slf4j.Logger;
 
-import java.io.File;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -67,9 +64,9 @@ import java.util.concurrent.atomic.AtomicInteger;
 
 import static java.util.Collections.singleton;
 
-public class StreamThread extends Thread implements ThreadDataProvider {
+public class StreamThread extends Thread {
 
-    private final Logger log;
+    private final static int UNLIMITED_RECORDS = -1;
     private static final AtomicInteger STREAM_THREAD_ID_SEQUENCE = new AtomicInteger(1);
 
     /**
@@ -204,7 +201,7 @@ public class StreamThread extends Thread implements ThreadDataProvider {
             if (newState == State.RUNNING) {
                 updateThreadMetadata(taskManager.activeTasks(), taskManager.standbyTasks());
             } else {
-                updateThreadMetadata(null, null);
+                updateThreadMetadata(Collections.<TaskId, StreamTask>emptyMap(), Collections.<TaskId, StandbyTask>emptyMap());
             }
         }
 
@@ -258,7 +255,6 @@ public class StreamThread extends Thread implements ThreadDataProvider {
                     return;
                 }
                 taskManager.createTasks(assignment);
-                streamThread.refreshMetadataState();
             } catch (final Throwable t) {
                 log.error("Error caught during partition assignment, " +
                         "will abort the current process and re-throw at the end of rebalance: {}", t.getMessage());
@@ -295,7 +291,6 @@ public class StreamThread extends Thread implements ThreadDataProvider {
                               "will abort the current process and re-throw at the end of rebalance: {}", t.getMessage());
                     streamThread.setRebalanceException(t);
                 } finally {
-                    streamThread.refreshMetadataState();
                     streamThread.clearStandbyRecords();
 
                     log.info("partition revocation took {} ms.\n" +
@@ -340,6 +335,14 @@ public class StreamThread extends Thread implements ThreadDataProvider {
             this.log = log;
         }
 
+        public InternalTopologyBuilder builder() {
+            return builder;
+        }
+
+        public StateDirectory stateDirectory() {
+            return stateDirectory;
+        }
+
         /**
          * @throws TaskMigratedException if the task producer got fenced (EOS only)
          */
@@ -555,15 +558,12 @@ public class StreamThread extends Thread implements ThreadDataProvider {
     private final long pollTimeMs;
     private final long commitTimeMs;
     private final Object stateLock;
-    private final UUID processId;
-    private final String clientId;
+    private final Logger log;
     private final String logPrefix;
-    private final StreamsConfig config;
 +    // TODO: adminClient will be passed to taskManager to be accessed in StreamPartitionAssignor
+    private final AdminClient adminClient;
     private final TaskManager taskManager;
-    private final StateDirectory stateDirectory;
-    private final PartitionGrouper partitionGrouper;
     private final StreamsMetricsThreadImpl streamsMetrics;
-    private final StreamsMetadataState streamsMetadataState;
 
     private long lastCommitMs;
     private long timerStartedMs;
@@ -571,75 +571,16 @@ public class StreamThread extends Thread implements ThreadDataProvider {
     private Throwable rebalanceException = null;
     private boolean processStandbyRecords = false;
     private volatile State state = State.CREATED;
+    private volatile ThreadMetadata threadMetadata;
     private StreamThread.StateListener stateListener;
-    private ThreadMetadataProvider metadataProvider;
     private Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> standbyRecords;
 
-    private final AdminClient adminClient;
-
     // package-private for testing
     final ConsumerRebalanceListener rebalanceListener;
     final Consumer<byte[], byte[]> restoreConsumer;
+    final Consumer<byte[], byte[]> consumer;
+    final InternalTopologyBuilder builder;
 
-    protected final Consumer<byte[], byte[]> consumer;
-    protected final InternalTopologyBuilder builder;
-
-    public final String applicationId;
-
-    private volatile ThreadMetadata threadMetadata;
-
-    private final static int UNLIMITED_RECORDS = -1;
-
-    public StreamThread(final InternalTopologyBuilder builder,
-                        final String clientId,
-                        final String threadClientId,
-                        final StreamsConfig config,
-                        final UUID processId,
-                        final Time time,
-                        final StreamsMetadataState streamsMetadataState,
-                        final TaskManager taskManager,
-                        final StreamsMetricsThreadImpl streamsMetrics,
-                        final KafkaClientSupplier clientSupplier,
-                        final Consumer<byte[], byte[]> restoreConsumer,
-                        final AdminClient adminClient,
-                        final StateDirectory stateDirectory) {
-        super(threadClientId);
-        this.builder = builder;
-        this.clientId = clientId;
-        this.applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG);
-        this.pollTimeMs = config.getLong(StreamsConfig.POLL_MS_CONFIG);
-        this.commitTimeMs = config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG);
-        this.processId = processId;
-        this.time = time;
-        this.streamsMetadataState = streamsMetadataState;
-        this.taskManager = taskManager;
-        this.logPrefix = String.format("stream-thread [%s] ", threadClientId);
-        this.streamsMetrics = streamsMetrics;
-        this.restoreConsumer = restoreConsumer;
-        this.adminClient = adminClient;
-        this.stateDirectory = stateDirectory;
-        this.config = config;
-        this.stateLock = new Object();
-        this.standbyRecords = new HashMap<>();
-        this.partitionGrouper = config.getConfiguredInstance(StreamsConfig.PARTITION_GROUPER_CLASS_CONFIG, PartitionGrouper.class);
-        final LogContext logContext = new LogContext(this.logPrefix);
-        this.log = logContext.logger(StreamThread.class);
-        this.rebalanceListener = new RebalanceListener(time, taskManager, this, this.log);
-
-        log.info("Creating consumer client");
-        final Map<String, Object> consumerConfigs = config.getConsumerConfigs(this, applicationId, threadClientId);
-
-        if (!builder.latestResetTopicsPattern().pattern().equals("") || !builder.earliestResetTopicsPattern().pattern().equals("")) {
-            originalReset = (String) consumerConfigs.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG);
-            log.info("Custom offset resets specified updating configs original auto offset reset {}", originalReset);
-            consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none");
-        }
-        this.consumer = clientSupplier.getConsumer(consumerConfigs);
-        taskManager.setConsumer(consumer);
-        updateThreadMetadata(null, null);
-    }
-
-    @SuppressWarnings("ConstantConditions")
     public static StreamThread create(final InternalTopologyBuilder builder,
                                       final StreamsConfig config,
                                       final KafkaClientSupplier clientSupplier,
@@ -652,82 +593,123 @@ public class StreamThread extends Thread implements ThreadDataProvider {
                                       final long cacheSizeBytes,
                                       final StateDirectory stateDirectory,
                                       final StateRestoreListener userStateRestoreListener) {
-
         final String threadClientId = clientId + "-StreamThread-" + STREAM_THREAD_ID_SEQUENCE.getAndIncrement();
-        final StreamsMetricsThreadImpl streamsMetrics = new StreamsMetricsThreadImpl(metrics,
-                                                                                     "stream-metrics",
-                                                                                     "thread." + threadClientId,
-                                                                                     Collections.singletonMap("client-id",
-                                                                                                              threadClientId));
 
         final String logPrefix = String.format("stream-thread [%s] ", threadClientId);
         final LogContext logContext = new LogContext(logPrefix);
         final Logger log = logContext.logger(StreamThread.class);
 
-        if (config.getLong(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG) < 0) {
-            log.warn("Negative cache size passed in thread. Reverting to cache size of 0 bytes");
-        }
-        final ThreadCache cache = new ThreadCache(logContext, cacheSizeBytes, streamsMetrics);
-
-        final boolean eosEnabled = StreamsConfig.EXACTLY_ONCE.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG));
-
         log.info("Creating restore consumer client");
-        final Map<String, Object> consumerConfigs = config.getRestoreConsumerConfigs(threadClientId);
-        final Consumer<byte[], byte[]> restoreConsumer = clientSupplier.getRestoreConsumer(consumerConfigs);
-        final StoreChangelogReader changelogReader = new StoreChangelogReader(restoreConsumer,
-                                                                              userStateRestoreListener,
-                                                                              logContext);
+        final Map<String, Object> restoreConsumerConfigs = config.getRestoreConsumerConfigs(threadClientId);
+        final Consumer<byte[], byte[]> restoreConsumer = clientSupplier.getRestoreConsumer(restoreConsumerConfigs);
+        final StoreChangelogReader changelogReader = new StoreChangelogReader(restoreConsumer, userStateRestoreListener, logContext);
 
         Producer<byte[], byte[]> threadProducer = null;
+        final boolean eosEnabled = StreamsConfig.EXACTLY_ONCE.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG));
         if (!eosEnabled) {
             final Map<String, Object> producerConfigs = config.getProducerConfigs(threadClientId);
             log.info("Creating shared producer client");
             threadProducer = clientSupplier.getProducer(producerConfigs);
         }
 
-        final AbstractTaskCreator activeTaskCreator = new TaskCreator(builder,
-                                                                      config,
-                                                                      streamsMetrics,
-                                                                      stateDirectory,
-                                                                      streamsMetrics.taskCreatedSensor,
-                                                                      changelogReader,
-                                                                      cache,
-                                                                      time,
-                                                                      clientSupplier,
-                                                                      threadProducer,
-                                                                      threadClientId,
-                                                                      log);
-        final AbstractTaskCreator standbyTaskCreator = new StandbyTaskCreator(builder,
-                                                                              config,
-                                                                              streamsMetrics,
-                                                                              stateDirectory,
-                                                                              streamsMetrics.taskCreatedSensor,
-                                                                              changelogReader,
-                                                                              time,
-                                                                              log);
-        final TaskManager taskManager = new TaskManager(changelogReader,
-                                                        logPrefix,
-                                                        restoreConsumer,
-                                                        activeTaskCreator,
-                                                        standbyTaskCreator,
-                                                        new AssignedStreamsTasks(logContext),
-                                                        new AssignedStandbyTasks(logContext));
-
-        return new StreamThread(builder,
-                                clientId,
-                                threadClientId,
-                                config,
-                                processId,
-                                time,
-                                streamsMetadataState,
-                                taskManager,
-                                streamsMetrics,
-                                clientSupplier,
-                                restoreConsumer,
-                                adminClient,
-                                stateDirectory);
+        StreamsMetricsThreadImpl streamsMetrics = new StreamsMetricsThreadImpl(
+                metrics,
+                "stream-metrics",
+                "thread." + threadClientId,
+                Collections.singletonMap("client-id", threadClientId));
+
+        final ThreadCache cache = new ThreadCache(logContext, cacheSizeBytes, streamsMetrics);
 
+        final StreamsKafkaClient streamsKafkaClient = StreamsKafkaClient.create(config.originals());
+
+        final AbstractTaskCreator<StreamTask> activeTaskCreator = new TaskCreator(builder,
+                                                                                  config,
+                                                                                  streamsMetrics,
+                                                                                  stateDirectory,
+                                                                                  streamsMetrics.taskCreatedSensor,
+                                                                                  changelogReader,
+                                                                                  cache,
+                                                                                  time,
+                                                                                  clientSupplier,
+                                                                                  threadProducer,
+                                                                                  threadClientId,
+                                                                                  log);
+        final AbstractTaskCreator<StandbyTask> standbyTaskCreator = new StandbyTaskCreator(builder,
+                                                                                           config,
+                                                                                           streamsMetrics,
+                                                                                           stateDirectory,
+                                                                                           streamsMetrics.taskCreatedSensor,
+                                                                                           changelogReader,
+                                                                                           time,
+                                                                                           log);
+        TaskManager taskManager = new TaskManager(changelogReader,
+                                                  processId,
+                                                  logPrefix,
+                                                  restoreConsumer,
+                                                  streamsMetadataState,
+                                                  activeTaskCreator,
+                                                  standbyTaskCreator,
+                                                  streamsKafkaClient,
+                                                  new AssignedStreamsTasks(logContext),
+                                                  new AssignedStandbyTasks(logContext));
 
+        log.info("Creating consumer client");
+        final String applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG);
+        final Map<String, Object> consumerConfigs = config.getConsumerConfigs(applicationId, threadClientId);
+        consumerConfigs.put(StreamsConfig.InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR, taskManager);
+        String originalReset = null;
+        if (!builder.latestResetTopicsPattern().pattern().equals("") || !builder.earliestResetTopicsPattern().pattern().equals("")) {
+            originalReset = (String) consumerConfigs.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG);
+            consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none");
+        }
+        final Consumer<byte[], byte[]> consumer = clientSupplier.getConsumer(consumerConfigs);
+        taskManager.setConsumer(consumer);
+
+        return new StreamThread(time,
+                config,
+                restoreConsumer,
+                consumer,
+                originalReset,
+                adminClient,
+                taskManager,
+                streamsMetrics,
+                builder,
+                threadClientId,
+                logContext);
+    }
+
+    public StreamThread(final Time time,
+                        final StreamsConfig config,
+                        final Consumer<byte[], byte[]> restoreConsumer,
+                        final Consumer<byte[], byte[]> consumer,
+                        final String originalReset,
+                        final AdminClient adminClient,
+                        final TaskManager taskManager,
+                        final StreamsMetricsThreadImpl streamsMetrics,
+                        final InternalTopologyBuilder builder,
+                        final String threadClientId,
+                        final LogContext logContext) {
+        super(threadClientId);
+
+        this.stateLock = new Object();
+        this.standbyRecords = new HashMap<>();
+
+        this.time = time;
+        this.builder = builder;
+        this.streamsMetrics = streamsMetrics;
+        this.logPrefix = logContext.logPrefix();
+        this.log = logContext.logger(StreamThread.class);
+        this.rebalanceListener = new RebalanceListener(time, taskManager, this, this.log);
+        this.taskManager = taskManager;
+        this.restoreConsumer = restoreConsumer;
+        this.consumer = consumer;
+        this.originalReset = originalReset;
+        this.adminClient = adminClient;
+
+        this.pollTimeMs = config.getLong(StreamsConfig.POLL_MS_CONFIG);
+        this.commitTimeMs = config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG);
+
+        updateThreadMetadata(Collections.<TaskId, StreamTask>emptyMap(), Collections.<TaskId, StandbyTask>emptyMap());
     }
 
     /**
@@ -1102,107 +1084,6 @@ public class StreamThread extends Thread implements ThreadDataProvider {
         setState(State.PENDING_SHUTDOWN);
     }
 
-    public Map<TaskId, StreamTask> tasks() {
-        return taskManager.activeTasks();
-    }
-
-    /**
-     * Returns ids of tasks that were being executed before the rebalance.
-     */
-    public Set<TaskId> prevActiveTasks() {
-        return taskManager.prevActiveTaskIds();
-    }
-
-    @Override
-    public InternalTopologyBuilder builder() {
-        return builder;
-    }
-
-    @Override
-    public String name() {
-        return getName();
-    }
-
-    /**
-     * Returns ids of tasks whose states are kept on the local storage.
-     */
-    public Set<TaskId> cachedTasks() {
-        // A client could contain some inactive tasks whose states are still kept on the local storage in the following scenarios:
-        // 1) the client is actively maintaining standby tasks by maintaining their states from the change log.
-        // 2) the client has just got some tasks migrated out of itself to other clients while these task states
-        //    have not been cleaned up yet (this can happen in a rolling bounce upgrade, for example).
-
-        final HashSet<TaskId> tasks = new HashSet<>();
-
-        final File[] stateDirs = stateDirectory.listTaskDirectories();
-        if (stateDirs != null) {
-            for (final File dir : stateDirs) {
-                try {
-                    final TaskId id = TaskId.parse(dir.getName());
-                    // if the checkpoint file exists, the state is valid.
-                    if (new File(dir, ProcessorStateManager.CHECKPOINT_FILE_NAME).exists()) {
-                        tasks.add(id);
-                    }
-                } catch (final TaskIdFormatException e) {
-                    // there may be some unknown files that sits in the same directory,
-                    // we should ignore these files instead trying to delete them as well
-                }
-            }
-        }
-
-        return tasks;
-    }
-
-    @Override
-    public UUID processId() {
-        return processId;
-    }
-
-    @Override
-    public StreamsConfig config() {
-        return config;
-    }
-
-    @Override
-    public PartitionGrouper partitionGrouper() {
-        return partitionGrouper;
-    }
-
-    /**
-     * Produces a string representation containing useful information about a StreamThread.
-     * This is useful in debugging scenarios.
-     * @return A string representation of the StreamThread instance.
-     */
-    @Override
-    public String toString() {
-        return toString("");
-    }
-
-    /**
-     * Produces a string representation containing useful information about a StreamThread, starting with the given indent.
-     * This is useful in debugging scenarios.
-     * @return A string representation of the StreamThread instance.
-     */
-    @SuppressWarnings("ThrowableNotThrown")
-    public String toString(final String indent) {
-        final StringBuilder sb = new StringBuilder()
-            .append(indent).append("StreamsThread appId: ").append(applicationId).append("\n")
-            .append(indent).append("\tStreamsThread clientId: ").append(clientId).append("\n")
-            .append(indent).append("\tStreamsThread threadId: ").append(getName()).append("\n");
-
-        sb.append(taskManager.toString(indent));
-        return sb.toString();
-    }
-
-    String threadClientId() {
-        return getName();
-    }
-
-    public void setThreadMetadataProvider(final ThreadMetadataProvider metadataProvider) {
-        this.metadataProvider = metadataProvider;
-        taskManager.setThreadMetadataProvider(metadataProvider);
-    }
-
     private void completeShutdown(final boolean cleanRun) {
         // set the state to pending shutdown first as it may be called due to error;
         // its state may already be PENDING_SHUTDOWN so it will return false but we
@@ -1236,10 +1117,6 @@ public class StreamThread extends Thread implements ThreadDataProvider {
         standbyRecords.clear();
     }
 
-    private void refreshMetadataState() {
-        streamsMetadataState.onChange(metadataProvider.getPartitionsByHostState(), metadataProvider.clusterMetadata());
-    }
-
     /**
      * Return information about the current {@link StreamThread}.
      *
@@ -1251,17 +1128,46 @@ public class StreamThread extends Thread implements ThreadDataProvider {
 
     private void updateThreadMetadata(final Map<TaskId, StreamTask> activeTasks, final Map<TaskId, StandbyTask> standbyTasks) {
         final Set<TaskMetadata> activeTasksMetadata = new HashSet<>();
-        if (activeTasks != null) {
-            for (Map.Entry<TaskId, StreamTask> task : activeTasks.entrySet()) {
-                activeTasksMetadata.add(new TaskMetadata(task.getKey().toString(), task.getValue().partitions()));
-            }
+        for (Map.Entry<TaskId, StreamTask> task : activeTasks.entrySet()) {
+            activeTasksMetadata.add(new TaskMetadata(task.getKey().toString(), task.getValue().partitions()));
         }
         final Set<TaskMetadata> standbyTasksMetadata = new HashSet<>();
-        if (standbyTasks != null) {
-            for (Map.Entry<TaskId, StandbyTask> task : standbyTasks.entrySet()) {
-                standbyTasksMetadata.add(new TaskMetadata(task.getKey().toString(), task.getValue().partitions()));
-            }
+        for (Map.Entry<TaskId, StandbyTask> task : standbyTasks.entrySet()) {
+            standbyTasksMetadata.add(new TaskMetadata(task.getKey().toString(), task.getValue().partitions()));
         }
+
         threadMetadata = new ThreadMetadata(this.getName(), this.state().name(), activeTasksMetadata, standbyTasksMetadata);
     }
+
+    public Map<TaskId, StreamTask> tasks() {
+        return taskManager.activeTasks();
+    }
+
+    /**
+     * Produces a string representation containing useful information about a StreamThread.
+     * This is useful in debugging scenarios.
+     * @return A string representation of the StreamThread instance.
+     */
+    @Override
+    public String toString() {
+        return toString("");
+    }
+
+    /**
+     * Produces a string representation containing useful information about a StreamThread, starting with the given indent.
+     * This is useful in debugging scenarios.
+     * @return A string representation of the StreamThread instance.
+     */
+    public String toString(final String indent) {
+        final StringBuilder sb = new StringBuilder()
+                .append(indent).append("\tStreamsThread threadId: ").append(getName()).append("\n");
+
+        sb.append(taskManager.toString(indent));
+        return sb.toString();
+    }
+
+    // this is for testing only
+    TaskManager taskManager() {
+        return taskManager;
+    }
 }
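
A small but recurring cleanup in this file is replacing updateThreadMetadata(null, null) with explicit empty maps, which lets the method iterate unconditionally instead of guarding against null. A compilable sketch of the idiom, with plain strings standing in for the real StreamTask/StandbyTask types:

import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class EmptyMapDemo {
    // stand-in for updateThreadMetadata: no null guard is needed because
    // callers pass Collections.emptyMap() rather than null
    static Set<String> taskNames(final Map<String, String> tasks) {
        final Set<String> names = new HashSet<>();
        for (final Map.Entry<String, String> task : tasks.entrySet()) {
            names.add(task.getKey() + "-" + task.getValue());
        }
        return names;
    }

    public static void main(final String[] args) {
        // before the patch, a call site like this passed null and the loop
        // above had to be wrapped in an "if (tasks != null)" check
        System.out.println(taskNames(Collections.<String, String>emptyMap()));
    }
}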

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsKafkaClient.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsKafkaClient.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsKafkaClient.java
index 075f445..1e21878 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsKafkaClient.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsKafkaClient.java
@@ -37,11 +37,8 @@ import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.MetricsReporter;
 import org.apache.kafka.common.network.ChannelBuilder;
 import org.apache.kafka.common.network.Selector;
-import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.requests.ApiError;
-import org.apache.kafka.common.requests.ApiVersionsRequest;
-import org.apache.kafka.common.requests.ApiVersionsResponse;
 import org.apache.kafka.common.requests.CreateTopicsRequest;
 import org.apache.kafka.common.requests.CreateTopicsResponse;
 import org.apache.kafka.common.requests.MetadataRequest;
@@ -64,9 +61,6 @@ import java.util.Map;
 import java.util.Properties;
 import java.util.concurrent.TimeUnit;
 
-import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE;
-import static org.apache.kafka.streams.StreamsConfig.PROCESSING_GUARANTEE_CONFIG;
-
 public class StreamsKafkaClient {
 
     private static final ConfigDef CONFIG = StreamsConfig.configDef()
@@ -75,8 +69,8 @@ public class StreamsKafkaClient {
 
     public static class Config extends AbstractConfig {
 
-        static Config fromStreamsConfig(StreamsConfig streamsConfig) {
-            return new Config(streamsConfig.originals());
+        static Config fromStreamsConfig(Map<String, ?> props) {
+            return new Config(props);
         }
 
         Config(Map<?, ?> originals) {
@@ -166,8 +160,8 @@ public class StreamsKafkaClient {
         return new LogContext("[StreamsKafkaClient clientId=" + clientId + "] ");
     }
 
-    public static StreamsKafkaClient create(final StreamsConfig streamsConfig) {
-        return create(Config.fromStreamsConfig(streamsConfig));
+    public static StreamsKafkaClient create(final Map<String, ?> props) {
+        return create(Config.fromStreamsConfig(props));
     }
 
     public void close() {
@@ -357,55 +351,7 @@ public class StreamsKafkaClient {
             throw new StreamsException("Inconsistent response type for internal topic metadata request. " +
                 "Expected MetadataResponse but received " + clientResponse.responseBody().getClass().getName());
         }
-        final MetadataResponse metadataResponse = (MetadataResponse) clientResponse.responseBody();
-        return metadataResponse;
-    }
-
-    /**
-     * Check if the used brokers have version 0.10.1.x or higher.
-     * <p>
-     * Note, for <em>pre</em> 0.10.x brokers the broker version cannot be checked and the client will hang and retry
-     * until it {@link StreamsConfig#REQUEST_TIMEOUT_MS_CONFIG times out}.
-     *
-     * @throws BrokerNotFoundException if connecting failed within {@code request.timeout.ms}
-     * @throws TimeoutException if there was no response within {@code request.timeout.ms}
-     * @throws StreamsException if brokers have version 0.10.0.x
-     * @throws StreamsException for any other fatal error
-     */
-    public void checkBrokerCompatibility(final boolean eosEnabled) throws StreamsException {
-        final ClientRequest clientRequest = kafkaClient.newClientRequest(
-            getAnyReadyBrokerId(),
-            new ApiVersionsRequest.Builder(),
-            Time.SYSTEM.milliseconds(),
-            true);
-
-        final ClientResponse clientResponse = sendRequestSync(clientRequest);
-        if (!clientResponse.hasResponse()) {
-            throw new StreamsException("Empty response for client request.");
-        }
-        if (!(clientResponse.responseBody() instanceof ApiVersionsResponse)) {
-            throw new StreamsException("Inconsistent response type for API versions request. " +
-                "Expected ApiVersionsResponse but received " + clientResponse.responseBody().getClass().getName());
-        }
-
-        final ApiVersionsResponse apiVersionsResponse =  (ApiVersionsResponse) clientResponse.responseBody();
-
-        if (apiVersionsResponse.apiVersion(ApiKeys.CREATE_TOPICS.id) == null) {
-            throw new StreamsException("Kafka Streams requires broker version 0.10.1.x or higher.");
-        }
-
-        if (eosEnabled && !brokerSupportsTransactions(apiVersionsResponse)) {
-            throw new StreamsException("Setting " + PROCESSING_GUARANTEE_CONFIG + "=" + EXACTLY_ONCE + " requires broker version 0.11.0.x or higher.");
-        }
-    }
-
-    private boolean brokerSupportsTransactions(final ApiVersionsResponse apiVersionsResponse) {
-        return apiVersionsResponse.apiVersion(ApiKeys.INIT_PRODUCER_ID.id) != null
-            && apiVersionsResponse.apiVersion(ApiKeys.ADD_PARTITIONS_TO_TXN.id) != null
-            && apiVersionsResponse.apiVersion(ApiKeys.ADD_OFFSETS_TO_TXN.id) != null
-            && apiVersionsResponse.apiVersion(ApiKeys.END_TXN.id) != null
-            && apiVersionsResponse.apiVersion(ApiKeys.WRITE_TXN_MARKERS.id) != null
-            && apiVersionsResponse.apiVersion(ApiKeys.TXN_OFFSET_COMMIT.id) != null;
+        return (MetadataResponse) clientResponse.responseBody();
     }
 
 }
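
Note that StreamsKafkaClient.create(...) now accepts the raw property map (the StreamThread changes above call it as StreamsKafkaClient.create(config.originals())) rather than a StreamsConfig instance, so any component holding only a map can construct the client. A simplified sketch of the loosened factory, where Config stands in for StreamsKafkaClient.Config:

import java.util.HashMap;
import java.util.Map;

public class ClientConfigDemo {
    // simplified stand-in for StreamsKafkaClient.Config
    static final class Config {
        final Map<String, ?> originals;

        private Config(final Map<String, ?> originals) {
            this.originals = originals;
        }

        // before: the factory required a full StreamsConfig instance;
        // after: any plain property map is accepted
        static Config fromProps(final Map<String, ?> props) {
            return new Config(props);
        }
    }

    public static void main(final String[] args) {
        final Map<String, Object> props = new HashMap<>();
        props.put("bootstrap.servers", "localhost:9092");

        final Config config = Config.fromProps(props);
        System.out.println(config.originals);  // {bootstrap.servers=localhost:9092}
    }
}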


http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 0238615..1eecb73 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -17,17 +17,24 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskIdFormatException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.state.HostInfo;
 import org.slf4j.Logger;
 
+import java.io.File;
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.UUID;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static java.util.Collections.singleton;
@@ -37,6 +44,7 @@ class TaskManager {
     // activeTasks needs to be concurrent as it can be accessed
     // by QueryableState
     private final Logger log;
+    private final UUID processId;
     private final AssignedStreamsTasks active;
     private final AssignedStandbyTasks standby;
     private final ChangelogReader changelogReader;
@@ -44,18 +52,33 @@ class TaskManager {
     private final Consumer<byte[], byte[]> restoreConsumer;
     private final StreamThread.AbstractTaskCreator<StreamTask> taskCreator;
     private final StreamThread.AbstractTaskCreator<StandbyTask> standbyTaskCreator;
-    private ThreadMetadataProvider threadMetadataProvider;
+    private final StreamsMetadataState streamsMetadataState;
+
+    // TODO: this is going to be replaced by AdminClient
+    final StreamsKafkaClient streamsKafkaClient;
+
+    // following information is updated during rebalance phase by the partition assignor
+    private Cluster cluster;
+    private Map<TaskId, Set<TopicPartition>> assignedActiveTasks;
+    private Map<TaskId, Set<TopicPartition>> assignedStandbyTasks;
+    private Map<HostInfo, Set<TopicPartition>> partitionsByHostState;
+
     private Consumer<byte[], byte[]> consumer;
 
     TaskManager(final ChangelogReader changelogReader,
+                final UUID processId,
                 final String logPrefix,
                 final Consumer<byte[], byte[]> restoreConsumer,
+                final StreamsMetadataState streamsMetadataState,
                 final StreamThread.AbstractTaskCreator<StreamTask> taskCreator,
                 final StreamThread.AbstractTaskCreator<StandbyTask> standbyTaskCreator,
+                final StreamsKafkaClient streamsKafkaClient,
                 final AssignedStreamsTasks active,
                 final AssignedStandbyTasks standby) {
         this.changelogReader = changelogReader;
+        this.processId = processId;
         this.logPrefix = logPrefix;
+        this.streamsMetadataState = streamsMetadataState;
         this.restoreConsumer = restoreConsumer;
         this.taskCreator = taskCreator;
         this.standbyTaskCreator = standbyTaskCreator;
@@ -65,15 +88,14 @@ class TaskManager {
         final LogContext logContext = new LogContext(logPrefix);
 
         this.log = logContext.logger(getClass());
+
+        this.streamsKafkaClient = streamsKafkaClient;
     }
 
     /**
      * @throws TaskMigratedException if the task producer got fenced (EOS only)
      */
     void createTasks(final Collection<TopicPartition> assignment) {
-        if (threadMetadataProvider == null) {
-            throw new IllegalStateException(logPrefix + "taskIdProvider has not been initialized while adding stream tasks. This should not happen.");
-        }
         if (consumer == null) {
             throw new IllegalStateException(logPrefix + "consumer has not been initialized while adding stream tasks. This should not happen.");
         }
@@ -81,8 +103,7 @@ class TaskManager {
         changelogReader.reset();
         // do this first as we may have suspended standby tasks that
         // will become active or vice versa
-        standby.closeNonAssignedSuspendedTasks(threadMetadataProvider.standbyTasks());
-        Map<TaskId, Set<TopicPartition>> assignedActiveTasks = threadMetadataProvider.activeTasks();
+        standby.closeNonAssignedSuspendedTasks(assignedStandbyTasks);
         active.closeNonAssignedSuspendedTasks(assignedActiveTasks);
         addStreamTasks(assignment);
         addStandbyTasks();
@@ -91,22 +112,17 @@ class TaskManager {
         consumer.pause(partitions);
     }
 
-    void setThreadMetadataProvider(final ThreadMetadataProvider threadMetadataProvider) {
-        this.threadMetadataProvider = threadMetadataProvider;
-    }
-
     /**
      * @throws TaskMigratedException if the task producer got fenced (EOS only)
      */
     private void addStreamTasks(final Collection<TopicPartition> assignment) {
-        Map<TaskId, Set<TopicPartition>> assignedTasks = threadMetadataProvider.activeTasks();
-        if (assignedTasks.isEmpty()) {
+        if (assignedActiveTasks.isEmpty()) {
             return;
         }
         final Map<TaskId, Set<TopicPartition>> newTasks = new HashMap<>();
         // collect newly assigned tasks and reopen re-assigned tasks
-        log.debug("Adding assigned tasks as active: {}", assignedTasks);
-        for (final Map.Entry<TaskId, Set<TopicPartition>> entry : assignedTasks.entrySet()) {
+        log.debug("Adding assigned tasks as active: {}", assignedActiveTasks);
+        for (final Map.Entry<TaskId, Set<TopicPartition>> entry : assignedActiveTasks.entrySet()) {
             final TaskId taskId = entry.getKey();
             final Set<TopicPartition> partitions = entry.getValue();
 
@@ -142,7 +158,7 @@ class TaskManager {
      * @throws TaskMigratedException if the task producer got fenced (EOS only)
      */
     private void addStandbyTasks() {
-        final Map<TaskId, Set<TopicPartition>> assignedStandbyTasks = threadMetadataProvider.standbyTasks();
+        final Map<TaskId, Set<TopicPartition>> assignedStandbyTasks = this.assignedStandbyTasks;
         if (assignedStandbyTasks.isEmpty()) {
             return;
         }
@@ -184,6 +200,44 @@ class TaskManager {
     }
 
     /**
+     * Returns the ids of tasks whose states are kept on local storage.
+     */
+    Set<TaskId> cachedTasksIds() {
+        // A client could contain some inactive tasks whose states are still kept on local storage in the following scenarios:
+        // 1) the client is actively maintaining standby tasks by restoring their states from the changelog.
+        // 2) the client has just had some tasks migrated out to other clients while those task states
+        //    have not been cleaned up yet (this can happen during a rolling bounce upgrade, for example).
+
+        final HashSet<TaskId> tasks = new HashSet<>();
+
+        final File[] stateDirs = taskCreator.stateDirectory().listTaskDirectories();
+        if (stateDirs != null) {
+            for (final File dir : stateDirs) {
+                try {
+                    final TaskId id = TaskId.parse(dir.getName());
+                    // if the checkpoint file exists, the state is valid.
+                    if (new File(dir, ProcessorStateManager.CHECKPOINT_FILE_NAME).exists()) {
+                        tasks.add(id);
+                    }
+                } catch (final TaskIdFormatException e) {
+                    // there may be some unknown files that sit in the same directory;
+                    // we should ignore these files instead of trying to delete them
+                }
+            }
+        }
+
+        return tasks;
+    }
+
+    UUID processId() {
+        return processId;
+    }
+
+    InternalTopologyBuilder builder() {
+        return taskCreator.builder();
+    }
+
+    /**
     * Similar to shutdownTasksAndState, but does not close the task managers, in the hope that
     * the tasks will soon be assigned again
      * @throws TaskMigratedException if the task producer got fenced (EOS only)
@@ -216,16 +270,14 @@ class TaskManager {
             firstException.compareAndSet(null, fatalException);
         }
         standby.close(clean);
-        try {
-            threadMetadataProvider.close();
-        } catch (final Throwable e) {
-            log.error("Failed to close KafkaStreamClient due to the following error:", e);
-        }
+
         // remove the changelog partitions from restore consumer
         restoreConsumer.unsubscribe();
         taskCreator.close();
         standbyTaskCreator.close();
 
+        streamsKafkaClient.close();
+
         final RuntimeException fatalException = firstException.get();
         if (fatalException != null) {
             throw fatalException;
@@ -311,6 +363,45 @@ class TaskManager {
         }
     }
 
+    void setClusterMetadata(final Cluster cluster) {
+        this.cluster = cluster;
+    }
+
+    void setPartitionsByHostState(final Map<HostInfo, Set<TopicPartition>> partitionsByHostState) {
+        this.partitionsByHostState = partitionsByHostState;
+        this.streamsMetadataState.onChange(partitionsByHostState, cluster);
+    }
+
+    void setAssignmentMetadata(final Map<TaskId, Set<TopicPartition>> activeTasks,
+                               final Map<TaskId, Set<TopicPartition>> standbyTasks) {
+        this.assignedActiveTasks = activeTasks;
+        this.assignedStandbyTasks = standbyTasks;
+    }
+
+    void updateSubscriptionsFromAssignment(List<TopicPartition> partitions) {
+        if (builder().sourceTopicPattern() != null) {
+            final Set<String> assignedTopics = new HashSet<>();
+            for (final TopicPartition topicPartition : partitions) {
+                assignedTopics.add(topicPartition.topic());
+            }
+
+            final Collection<String> existingTopics = builder().subscriptionUpdates().getUpdates();
+            if (!existingTopics.containsAll(assignedTopics)) {
+                assignedTopics.addAll(existingTopics);
+                builder().updateSubscribedTopics(assignedTopics, logPrefix);
+            }
+        }
+    }
+
+    void updateSubscriptionsFromMetadata(Set<String> topics) {
+        if (builder().sourceTopicPattern() != null) {
+            final Collection<String> existingTopics = builder().subscriptionUpdates().getUpdates();
+            if (!existingTopics.equals(topics)) {
+                builder().updateSubscribedTopics(topics, logPrefix);
+            }
+        }
+    }
+
     /**
      * @throws TaskMigratedException if committing offsets failed (non-EOS)
      *                               or if the task producer got fenced (EOS)
@@ -350,4 +441,13 @@ class TaskManager {
         builder.append(standby.toString(indent + "\t\t"));
         return builder.toString();
     }
+
+    // the following functions are for testing only
+    Map<TaskId, Set<TopicPartition>> assignedActiveTasks() {
+        return assignedActiveTasks;
+    }
+
+    Map<TaskId, Set<TopicPartition>> assignedStandbyTasks() {
+        return assignedStandbyTasks;
+    }
 }
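
A note on the wiring above: the assignor no longer pulls metadata through a
ThreadMetadataProvider. Instead, the StreamThread hands its TaskManager to the
assignor via an internal entry in the consumer config map, and the assignor
pushes the decoded assignment metadata back through setClusterMetadata(),
setPartitionsByHostState() and setAssignmentMetadata() on each rebalance (see
testOnAssignment in StreamPartitionAssignorTest below). What follows is a
minimal self-contained sketch of that inversion; all names are hypothetical
stand-ins, since the real TaskManager and StreamPartitionAssignor are
package-private internals, and the config key string here is a placeholder.

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-in for the package-private TaskManager; only the
// setter relevant to this sketch is modeled.
interface TaskMetadataSink {
    void setAssignmentMetadata(Map<String, Set<Integer>> activeTasks,
                               Map<String, Set<Integer>> standbyTasks);
}

// Hypothetical stand-in for StreamPartitionAssignor showing the new wiring.
public class AssignorWiringSketch {

    // plays the role of StreamsConfig.InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR
    static final String TASK_MANAGER_CONFIG = "task.manager.instance"; // placeholder key

    private TaskMetadataSink taskManager;

    // the StreamThread injects its TaskManager through the consumer config map,
    // so the assignor has no compile-time dependency on the thread itself
    void configure(final Map<String, ?> configs) {
        final Object o = configs.get(TASK_MANAGER_CONFIG);
        if (!(o instanceof TaskMetadataSink)) {
            throw new IllegalStateException(TASK_MANAGER_CONFIG + " is not a TaskMetadataSink instance");
        }
        taskManager = (TaskMetadataSink) o;
    }

    // after a rebalance the assignor pushes the decoded metadata back into the sink
    void onAssignment(final Map<String, Set<Integer>> active,
                      final Map<String, Set<Integer>> standby) {
        taskManager.setAssignmentMetadata(active, standby);
    }

    public static void main(final String[] args) {
        final Map<String, Object> configs = new HashMap<>();
        configs.put(TASK_MANAGER_CONFIG, new TaskMetadataSink() {
            @Override
            public void setAssignmentMetadata(final Map<String, Set<Integer>> active,
                                              final Map<String, Set<Integer>> standby) {
                System.out.println("active=" + active + " standby=" + standby);
            }
        });

        final AssignorWiringSketch assignor = new AssignorWiringSketch();
        assignor.configure(configs);
        assignor.onAssignment(
            Collections.singletonMap("0_0", Collections.singleton(1)),
            Collections.<String, Set<Integer>>emptyMap());
    }
}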

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadDataProvider.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadDataProvider.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadDataProvider.java
deleted file mode 100644
index ded98f7..0000000
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadDataProvider.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.streams.processor.internals;
-
-import org.apache.kafka.streams.StreamsConfig;
-import org.apache.kafka.streams.processor.PartitionGrouper;
-import org.apache.kafka.streams.processor.TaskId;
-
-import java.util.Set;
-import java.util.UUID;
-
-// interface to get info about the StreamThread
-interface ThreadDataProvider {
-    InternalTopologyBuilder builder();
-    String name();
-    Set<TaskId> prevActiveTasks();
-    Set<TaskId> cachedTasks();
-    UUID processId();
-    StreamsConfig config();
-    PartitionGrouper partitionGrouper();
-    void setThreadMetadataProvider(final ThreadMetadataProvider provider);
-}

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataProvider.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataProvider.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataProvider.java
deleted file mode 100644
index f185045..0000000
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataProvider.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.streams.processor.internals;
-
-import org.apache.kafka.common.Cluster;
-import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.state.HostInfo;
-
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Interface used by a <code>StreamThread</code> to get metadata from the <code>StreamPartitionAssignor</code>
- */
-public interface ThreadMetadataProvider {
-    Map<TaskId, Set<TopicPartition>> standbyTasks();
-    Map<TaskId, Set<TopicPartition>> activeTasks();
-    Map<HostInfo, Set<TopicPartition>> getPartitionsByHostState();
-    Cluster clusterMetadata();
-    void close();
-}

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java b/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java
index 3774a8e..1a4cfb1 100644
--- a/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java
@@ -94,7 +94,7 @@ public class StreamsConfigTest {
     public void testGetConsumerConfigs() {
         final String groupId = "example-application";
         final String clientId = "client";
-        final Map<String, Object> returnedProps = streamsConfig.getConsumerConfigs(null, groupId, clientId);
+        final Map<String, Object> returnedProps = streamsConfig.getConsumerConfigs(groupId, clientId);
         assertEquals(returnedProps.get(ConsumerConfig.CLIENT_ID_CONFIG), clientId + "-consumer");
         assertEquals(returnedProps.get(ConsumerConfig.GROUP_ID_CONFIG), groupId);
         assertEquals(returnedProps.get(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), "1000");
@@ -147,7 +147,7 @@ public class StreamsConfigTest {
         props.put(consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "earliest");
         props.put(consumerPrefix(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG), 1);
         final StreamsConfig streamsConfig = new StreamsConfig(props);
-        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs(null, "groupId", "clientId");
+        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs("groupId", "clientId");
         assertEquals("earliest", consumerConfigs.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG));
         assertEquals(1, consumerConfigs.get(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG));
     }
@@ -166,7 +166,7 @@ public class StreamsConfigTest {
     public void shouldSupportPrefixedPropertiesThatAreNotPartOfConsumerConfig() {
         final StreamsConfig streamsConfig = new StreamsConfig(props);
         props.put(consumerPrefix("interceptor.statsd.host"), "host");
-        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs(null, "groupId", "clientId");
+        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs("groupId", "clientId");
         assertEquals("host", consumerConfigs.get("interceptor.statsd.host"));
     }
 
@@ -202,7 +202,7 @@ public class StreamsConfigTest {
         props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
         props.put(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG, 1);
         final StreamsConfig streamsConfig = new StreamsConfig(props);
-        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs(null, "groupId", "clientId");
+        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs("groupId", "clientId");
         assertEquals("earliest", consumerConfigs.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG));
         assertEquals(1, consumerConfigs.get(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG));
     }
@@ -248,7 +248,7 @@ public class StreamsConfigTest {
         props.put(StreamsConfig.consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "latest");
         props.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), "10");
         final StreamsConfig streamsConfig = new StreamsConfig(props);
-        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs(null, "groupId", "clientId");
+        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs("groupId", "clientId");
         assertEquals("latest", consumerConfigs.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG));
         assertEquals("10", consumerConfigs.get(ConsumerConfig.MAX_POLL_RECORDS_CONFIG));
     }
@@ -275,7 +275,7 @@ public class StreamsConfigTest {
     public void shouldResetToDefaultIfConsumerAutoCommitIsOverridden() {
         props.put(StreamsConfig.consumerPrefix(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG), "true");
         final StreamsConfig streamsConfig = new StreamsConfig(props);
-        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs(null, "a", "b");
+        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs("a", "b");
         assertEquals("false", consumerConfigs.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG));
     }
 
@@ -290,7 +290,7 @@ public class StreamsConfigTest {
     @Test
     public void shouldSetInternalLeaveGroupOnCloseConfigToFalseInConsumer() {
         final StreamsConfig streamsConfig = new StreamsConfig(props);
-        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs(null, "groupId", "clientId");
+        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs("groupId", "clientId");
         assertThat(consumerConfigs.get("internal.leave.group.on.close"), CoreMatchers.<Object>equalTo(false));
     }
 
@@ -319,7 +319,7 @@ public class StreamsConfigTest {
         props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE);
         props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "anyValue");
         final StreamsConfig streamsConfig = new StreamsConfig(props);
-        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs(null, "groupId", "clientId");
+        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs("groupId", "clientId");
         assertThat((String) consumerConfigs.get(ConsumerConfig.ISOLATION_LEVEL_CONFIG), equalTo(READ_COMMITTED.name().toLowerCase(Locale.ROOT)));
     }
 
@@ -327,7 +327,7 @@ public class StreamsConfigTest {
     public void shouldAllowSettingConsumerIsolationLevelIfEosDisabled() {
         props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, READ_UNCOMMITTED.name().toLowerCase(Locale.ROOT));
         final StreamsConfig streamsConfig = new StreamsConfig(props);
-        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs(null, "groupId", "clientrId");
+        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs("groupId", "clientrId");
         assertThat((String) consumerConfigs.get(ConsumerConfig.ISOLATION_LEVEL_CONFIG), equalTo(READ_UNCOMMITTED.name().toLowerCase(Locale.ROOT)));
     }
 
@@ -371,7 +371,7 @@ public class StreamsConfigTest {
         props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE);
         final StreamsConfig streamsConfig = new StreamsConfig(props);
 
-        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs(null, "groupId", "clientId");
+        final Map<String, Object> consumerConfigs = streamsConfig.getConsumerConfigs("groupId", "clientId");
         final Map<String, Object> producerConfigs = streamsConfig.getProducerConfigs("clientId");
 
         assertThat((String) consumerConfigs.get(ConsumerConfig.ISOLATION_LEVEL_CONFIG), equalTo(READ_COMMITTED.name().toLowerCase(Locale.ROOT)));
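
The test updates above track the simplified signature: getConsumerConfigs()
now takes only the group id and client id, with the StreamThread argument
removed. A short usage sketch, assuming the method remains public as the
tests exercise it (the application id and bootstrap server values are
illustrative):

import java.util.Map;
import java.util.Properties;

import org.apache.kafka.streams.StreamsConfig;

public class ConsumerConfigsSketch {
    public static void main(final String[] args) {
        final Properties props = new Properties();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "example-app");        // illustrative
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // illustrative
        final StreamsConfig config = new StreamsConfig(props);

        // old: config.getConsumerConfigs(streamThread, groupId, clientId)
        // new: no thread reference; the TaskManager is injected separately
        // through an internal config entry before the consumer is created
        final Map<String, Object> consumerConfigs =
            config.getConsumerConfigs("example-app", "client-1");

        System.out.println(consumerConfigs.get("group.id"));  // prints: example-app
    }
}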

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java
index 9c8244a..e9df495 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java
@@ -262,12 +262,14 @@ public class QueryableStateIntegrationTest {
                 public boolean conditionMet() {
                     try {
                         final StreamsMetadata metadata = streams.metadataForKey(storeName, key, new StringSerializer());
+
                         if (metadata == null || metadata.equals(StreamsMetadata.NOT_AVAILABLE)) {
                             return false;
                         }
                         final int index = metadata.hostInfo().port();
                         final KafkaStreams streamsWithKey = streamRunnables[index].getStream();
                         final ReadOnlyKeyValueStore<String, Long> store = streamsWithKey.store(storeName, QueryableStoreTypes.<String, Long>keyValueStore());
+
                         return store != null && store.get(key) != null;
                     } catch (final IllegalStateException e) {
                         // Kafka Streams instance may have closed but rebalance hasn't happened

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/streams/processor/TopologyBuilderTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/TopologyBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/TopologyBuilderTest.java
index 42e5ccf..d21d2e3 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/TopologyBuilderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/TopologyBuilderTest.java
@@ -27,7 +27,7 @@ import org.apache.kafka.streams.processor.internals.InternalTopicManager;
 import org.apache.kafka.streams.processor.internals.ProcessorNode;
 import org.apache.kafka.streams.processor.internals.ProcessorStateManager;
 import org.apache.kafka.streams.processor.internals.ProcessorTopology;
-import org.apache.kafka.streams.processor.internals.StreamPartitionAssignor.SubscriptionUpdates;
+import org.apache.kafka.streams.processor.internals.StreamPartitionAssignor;
 import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.streams.state.internals.RocksDBWindowStoreSupplier;
 import org.apache.kafka.test.MockProcessorSupplier;
@@ -681,7 +681,7 @@ public class TopologyBuilderTest {
         builder.addSource("source-2", Pattern.compile("topic-[A-C]"));
         builder.addSource("source-3", Pattern.compile("topic-\\d"));
 
-        SubscriptionUpdates subscriptionUpdates = new SubscriptionUpdates();
+        StreamPartitionAssignor.SubscriptionUpdates subscriptionUpdates = new StreamPartitionAssignor.SubscriptionUpdates();
         Field updatedTopicsField  = subscriptionUpdates.getClass().getDeclaredField("updatedTopicSubscriptions");
         updatedTopicsField.setAccessible(true);
 
@@ -761,7 +761,7 @@ public class TopologyBuilderTest {
                 .addProcessor("my-processor", new MockProcessorSupplier(), "ingest")
                 .addStateStore(new MockStateStoreSupplier("testStateStore", false), "my-processor");
 
-        final SubscriptionUpdates subscriptionUpdates = new SubscriptionUpdates();
+        final StreamPartitionAssignor.SubscriptionUpdates subscriptionUpdates = new StreamPartitionAssignor.SubscriptionUpdates();
         final Field updatedTopicsField  = subscriptionUpdates.getClass().getDeclaredField("updatedTopicSubscriptions");
         updatedTopicsField.setAccessible(true);
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
index 418b0ba..2bd2d42 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
@@ -66,6 +66,7 @@ public class GlobalStreamThreadTest {
                                                     config,
                                                     mockConsumer,
                                                     new StateDirectory(config, time),
+                                                    0,
                                                     new Metrics(),
                                                     new MockTime(),
                                                     "clientId",
@@ -98,6 +99,7 @@ public class GlobalStreamThreadTest {
                                                     config,
                                                     mockConsumer,
                                                     new StateDirectory(config, time),
+                                                    0,
                                                     new Metrics(),
                                                     new MockTime(),
                                                     "clientId",

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicManagerTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicManagerTest.java
index 7d032a1..e914f9e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicManagerTest.java
@@ -161,7 +161,7 @@ public class InternalTopicManagerTest {
         Map<String, Integer> replicationFactorPerTopic = new HashMap<>();
 
         MockStreamKafkaClient(final StreamsConfig streamsConfig) {
-            super(StreamsKafkaClient.Config.fromStreamsConfig(streamsConfig),
+            super(StreamsKafkaClient.Config.fromStreamsConfig(streamsConfig.originals()),
                   new MockClient(new MockTime()),
                   Collections.<MetricsReporter>emptyList(),
                   new LogContext());
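
The mock's constructor change mirrors the production one:
StreamsKafkaClient.Config.fromStreamsConfig() now accepts the raw properties
map (streamsConfig.originals()) rather than a StreamsConfig instance, so code
that only holds a Map<String, Object>, such as the assignor inside
configure(), can build the client. A hedged sketch of the new call site,
assuming it runs in a context where StreamsKafkaClient.Config is accessible:

import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.processor.internals.StreamsKafkaClient;

public class KafkaClientConfigSketch {
    public static void main(final String[] args) {
        final Map<String, Object> props = new HashMap<>();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "example-app");        // illustrative
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // illustrative

        // old: Config.fromStreamsConfig(new StreamsConfig(props))
        // new: the factory takes the raw map, matching what the assignor
        // receives in configure(Map<String, ?>)
        final StreamsKafkaClient.Config clientConfig =
            StreamsKafkaClient.Config.fromStreamsConfig(props);
        System.out.println(clientConfig);
    }
}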

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilderTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilderTest.java
index fa83a71..0fdb575 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilderTest.java
@@ -29,7 +29,6 @@ import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.ProcessorSupplier;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreSupplier;
-import org.apache.kafka.streams.processor.internals.StreamPartitionAssignor.SubscriptionUpdates;
 import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.streams.state.internals.RocksDBWindowStoreSupplier;
 import org.apache.kafka.test.MockProcessorSupplier;
@@ -633,7 +632,7 @@ public class InternalTopologyBuilderTest {
         builder.addSource(null, "source-2", null, null, null, Pattern.compile("topic-[A-C]"));
         builder.addSource(null, "source-3", null, null, null, Pattern.compile("topic-\\d"));
 
-        final SubscriptionUpdates subscriptionUpdates = new SubscriptionUpdates();
+        final InternalTopologyBuilder.SubscriptionUpdates subscriptionUpdates = new InternalTopologyBuilder.SubscriptionUpdates();
         final Field updatedTopicsField  = subscriptionUpdates.getClass().getDeclaredField("updatedTopicSubscriptions");
         updatedTopicsField.setAccessible(true);
 
@@ -721,7 +720,7 @@ public class InternalTopologyBuilderTest {
         builder.addProcessor("my-processor", new MockProcessorSupplier(), "ingest");
         builder.addStateStore(new MockStateStoreSupplier("testStateStore", false), "my-processor");
 
-        final SubscriptionUpdates subscriptionUpdates = new SubscriptionUpdates();
+        final InternalTopologyBuilder.SubscriptionUpdates subscriptionUpdates = new InternalTopologyBuilder.SubscriptionUpdates();
         final Field updatedTopicsField  = subscriptionUpdates.getClass().getDeclaredField("updatedTopicSubscriptions");
         updatedTopicsField.setAccessible(true);
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/5df1eee7/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignorTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignorTest.java
index cd37fab..99bb56d 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamPartitionAssignorTest.java
@@ -34,8 +34,6 @@ import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.kstream.KeyValueMapper;
 import org.apache.kafka.streams.kstream.Materialized;
 import org.apache.kafka.streams.kstream.ValueJoiner;
-import org.apache.kafka.streams.processor.DefaultPartitionGrouper;
-import org.apache.kafka.streams.processor.PartitionGrouper;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
 import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
@@ -45,7 +43,7 @@ import org.apache.kafka.test.MockClientSupplier;
 import org.apache.kafka.test.MockInternalTopicManager;
 import org.apache.kafka.test.MockProcessorSupplier;
 import org.apache.kafka.test.MockStateStoreSupplier;
-import org.apache.kafka.test.MockTimestampExtractor;
+import org.easymock.Capture;
 import org.easymock.EasyMock;
 import org.junit.Assert;
 import org.junit.Test;
@@ -57,7 +55,6 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Properties;
 import java.util.Set;
 import java.util.UUID;
 
@@ -65,7 +62,6 @@ import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.not;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertThat;
 
 public class StreamPartitionAssignorTest {
@@ -110,43 +106,34 @@ public class StreamPartitionAssignorTest {
     private final MockClientSupplier mockClientSupplier = new MockClientSupplier();
     private final InternalTopologyBuilder builder = new InternalTopologyBuilder();
     private final StreamsConfig config = new StreamsConfig(configProps());
-    private final ThreadDataProvider threadDataProvider = EasyMock.createNiceMock(ThreadDataProvider.class);
-    private final Map<String, Object> configurationMap = new HashMap<>();
-    private final DefaultPartitionGrouper defaultPartitionGrouper = new DefaultPartitionGrouper();
-    private final SingleGroupPartitionGrouperStub stubPartitionGrouper = new SingleGroupPartitionGrouperStub();
     private final String userEndPoint = "localhost:8080";
+    private final String applicationId = "stream-partition-assignor-test";
 
-    private Properties configProps() {
-        return new Properties() {
-            {
-                setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "stream-partition-assignor-test");
-                setProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, userEndPoint);
-                setProperty(StreamsConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG, "3");
-                setProperty(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, MockTimestampExtractor.class.getName());
-            }
-        };
+    private final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class);
+
+    private Map<String, Object> configProps() {
+        Map<String, Object> configurationMap = new HashMap<>();
+        configurationMap.put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId);
+        configurationMap.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, userEndPoint);
+        configurationMap.put(StreamsConfig.InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR, taskManager);
+        return configurationMap;
     }
 
-    private void configurePartitionAssignor(final int standbyReplicas, final String endPoint) {
-        configurationMap.put(StreamsConfig.InternalConfig.STREAM_THREAD_INSTANCE, threadDataProvider);
-        configurationMap.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, standbyReplicas);
-        configurationMap.put(StreamsConfig.APPLICATION_SERVER_CONFIG, endPoint);
+    private void configurePartitionAssignor(final Map<String, Object> props) {
+        Map<String, Object> configurationMap = configProps();
+        configurationMap.putAll(props);
         partitionAssignor.configure(configurationMap);
     }
 
-    private void mockThreadDataProvider(final Set<TaskId> prevTasks,
-                                        final Set<TaskId> cachedTasks,
-                                        final UUID processId,
-                                        final PartitionGrouper partitionGrouper,
-                                        final InternalTopologyBuilder builder) throws NoSuchFieldException, IllegalAccessException {
-        EasyMock.expect(threadDataProvider.name()).andReturn("name").anyTimes();
-        EasyMock.expect(threadDataProvider.prevActiveTasks()).andReturn(prevTasks).anyTimes();
-        EasyMock.expect(threadDataProvider.cachedTasks()).andReturn(cachedTasks).anyTimes();
-        EasyMock.expect(threadDataProvider.config()).andReturn(config).anyTimes();
-        EasyMock.expect(threadDataProvider.builder()).andReturn(builder).anyTimes();
-        EasyMock.expect(threadDataProvider.processId()).andReturn(processId).anyTimes();
-        EasyMock.expect(threadDataProvider.partitionGrouper()).andReturn(partitionGrouper).anyTimes();
-        EasyMock.replay(threadDataProvider);
+    private void mockTaskManager(final Set<TaskId> prevTasks,
+                                 final Set<TaskId> cachedTasks,
+                                 final UUID processId,
+                                 final InternalTopologyBuilder builder) throws NoSuchFieldException, IllegalAccessException {
+        EasyMock.expect(taskManager.builder()).andReturn(builder).anyTimes();
+        EasyMock.expect(taskManager.prevActiveTaskIds()).andReturn(prevTasks).anyTimes();
+        EasyMock.expect(taskManager.cachedTasksIds()).andReturn(cachedTasks).anyTimes();
+        EasyMock.expect(taskManager.processId()).andReturn(processId).anyTimes();
+        EasyMock.replay(taskManager);
     }
 
 
@@ -163,9 +150,9 @@ public class StreamPartitionAssignorTest {
                 new TaskId(0, 2), new TaskId(1, 2), new TaskId(2, 2));
 
         final UUID processId = UUID.randomUUID();
-        mockThreadDataProvider(prevTasks, cachedTasks, processId, stubPartitionGrouper, builder);
+        mockTaskManager(prevTasks, cachedTasks, processId, builder);
 
-        configurePartitionAssignor(0, null);
+        configurePartitionAssignor(Collections.<String, Object>emptyMap());
         PartitionAssignor.Subscription subscription = partitionAssignor.subscription(Utils.mkSet("topic1", "topic2"));
 
         Collections.sort(subscription.topics());
@@ -197,8 +184,8 @@ public class StreamPartitionAssignorTest {
         UUID uuid1 = UUID.randomUUID();
         UUID uuid2 = UUID.randomUUID();
 
-        mockThreadDataProvider(prevTasks10, standbyTasks10, uuid1, stubPartitionGrouper, builder);
-        configurePartitionAssignor(0, null);
+        mockTaskManager(prevTasks10, standbyTasks10, uuid1, builder);
+        configurePartitionAssignor(Collections.<String, Object>emptyMap());
 
         partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(config, mockClientSupplier.restoreConsumer));
 
@@ -245,10 +232,6 @@ public class StreamPartitionAssignorTest {
 
     @Test
     public void testAssignWithPartialTopology() throws Exception {
-        Properties props = configProps();
-        props.put(StreamsConfig.PARTITION_GROUPER_CLASS_CONFIG, SingleGroupPartitionGrouperStub.class);
-        StreamsConfig config = new StreamsConfig(props);
-
         builder.addSource(null, "source1", null, null, null, "topic1");
         builder.addProcessor("processor1", new MockProcessorSupplier(), "source1");
         builder.addStateStore(new MockStateStoreSupplier("store1", false), "processor1");
@@ -260,10 +243,10 @@ public class StreamPartitionAssignorTest {
 
         UUID uuid1 = UUID.randomUUID();
 
-        mockThreadDataProvider(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), uuid1, stubPartitionGrouper, builder);
-        configurePartitionAssignor(0, null);
+        mockTaskManager(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), uuid1, builder);
+        configurePartitionAssignor(Collections.singletonMap(StreamsConfig.PARTITION_GROUPER_CLASS_CONFIG, (Object) SingleGroupPartitionGrouperStub.class));
 
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(config, mockClientSupplier.restoreConsumer));
+        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(new StreamsConfig(configProps()), mockClientSupplier.restoreConsumer));
         Map<String, PartitionAssignor.Subscription> subscriptions = new HashMap<>();
         subscriptions.put("consumer10",
             new PartitionAssignor.Subscription(topics, new SubscriptionInfo(uuid1, Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), userEndPoint).encode()));
@@ -298,8 +281,8 @@ public class StreamPartitionAssignorTest {
             Collections.<String>emptySet());
         UUID uuid1 = UUID.randomUUID();
 
-        mockThreadDataProvider(prevTasks10, standbyTasks10, uuid1, stubPartitionGrouper, builder);
-        configurePartitionAssignor(0, null);
+        mockTaskManager(prevTasks10, standbyTasks10, uuid1, builder);
+        configurePartitionAssignor(Collections.<String, Object>emptyMap());
 
         Map<String, PartitionAssignor.Subscription> subscriptions = new HashMap<>();
         subscriptions.put("consumer10",
@@ -353,8 +336,9 @@ public class StreamPartitionAssignorTest {
 
         UUID uuid1 = UUID.randomUUID();
         UUID uuid2 = UUID.randomUUID();
-        mockThreadDataProvider(prevTasks10, Collections.<TaskId>emptySet(), uuid1, stubPartitionGrouper, builder);
-        configurePartitionAssignor(0, null);
+        mockTaskManager(prevTasks10, Collections.<TaskId>emptySet(), uuid1, builder);
+        configurePartitionAssignor(Collections.<String, Object>emptyMap());
+
         partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(config, mockClientSupplier.restoreConsumer));
 
         Map<String, PartitionAssignor.Subscription> subscriptions = new HashMap<>();
@@ -392,7 +376,6 @@ public class StreamPartitionAssignorTest {
 
     @Test
     public void testAssignWithStates() throws Exception {
-        String applicationId = "test";
         builder.setApplicationId(applicationId);
         builder.addSource(null, "source1", null, null, null, "topic1");
         builder.addSource(null, "source2", null, null, null, "topic2");
@@ -417,11 +400,11 @@ public class StreamPartitionAssignorTest {
         UUID uuid1 = UUID.randomUUID();
         UUID uuid2 = UUID.randomUUID();
 
-        mockThreadDataProvider(Collections.<TaskId>emptySet(),
+        mockTaskManager(Collections.<TaskId>emptySet(),
                                Collections.<TaskId>emptySet(),
                                uuid1,
-                               defaultPartitionGrouper, builder);
-        configurePartitionAssignor(0, null);
+                builder);
+        configurePartitionAssignor(Collections.<String, Object>emptyMap());
 
         partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(config, mockClientSupplier.restoreConsumer));
 
@@ -481,8 +464,8 @@ public class StreamPartitionAssignorTest {
 
     @Test
     public void testAssignWithStandbyReplicas() throws Exception {
-        Properties props = configProps();
-        props.setProperty(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, "1");
+        Map<String, Object> props = configProps();
+        props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, "1");
         StreamsConfig config = new StreamsConfig(props);
 
         builder.addSource(null, "source1", null, null, null, "topic1");
@@ -502,9 +485,10 @@ public class StreamPartitionAssignorTest {
         UUID uuid1 = UUID.randomUUID();
         UUID uuid2 = UUID.randomUUID();
 
-        mockThreadDataProvider(prevTasks00, standbyTasks01, uuid1, defaultPartitionGrouper, builder);
+        mockTaskManager(prevTasks00, standbyTasks01, uuid1, builder);
+
+        configurePartitionAssignor(Collections.<String, Object>singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1));
 
-        configurePartitionAssignor(1, null);
         partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(config, mockClientSupplier.restoreConsumer));
 
         Map<String, PartitionAssignor.Subscription> subscriptions = new HashMap<>();
@@ -552,35 +536,41 @@ public class StreamPartitionAssignorTest {
 
     @Test
     public void testOnAssignment() throws Exception {
-        TopicPartition t2p3 = new TopicPartition("topic2", 3);
-
-        builder.addSource(null, "source1", null, null, null, "topic1");
-        builder.addSource(null, "source2", null, null, null, "topic2");
-        builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
-
-        UUID uuid = UUID.randomUUID();
-        mockThreadDataProvider(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), uuid, defaultPartitionGrouper, builder);
-        configurePartitionAssignor(0, null);
-
-        List<TaskId> activeTaskList = Utils.mkList(task0, task3);
-        Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
-        Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
-        activeTasks.put(task0, Utils.mkSet(t1p0));
-        activeTasks.put(task3, Utils.mkSet(t2p3));
-        standbyTasks.put(task1, Utils.mkSet(t1p0));
-        standbyTasks.put(task2, Utils.mkSet(t2p0));
+        configurePartitionAssignor(Collections.<String, Object>emptyMap());
+
+        final List<TaskId> activeTaskList = Utils.mkList(task0, task3);
+        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
+        final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
+        final Map<HostInfo, Set<TopicPartition>> hostState = Collections.singletonMap(
+                new HostInfo("localhost", 9090),
+                Utils.mkSet(t3p0, t3p3));
+        activeTasks.put(task0, Utils.mkSet(t3p0));
+        activeTasks.put(task3, Utils.mkSet(t3p3));
+        standbyTasks.put(task1, Utils.mkSet(t3p1));
+        standbyTasks.put(task2, Utils.mkSet(t3p2));
+
+        final AssignmentInfo info = new AssignmentInfo(activeTaskList, standbyTasks, hostState);
+        final PartitionAssignor.Assignment assignment = new PartitionAssignor.Assignment(Utils.mkList(t3p0, t3p3), info.encode());
+
+        Capture<Cluster> capturedCluster = EasyMock.newCapture();
+        taskManager.setPartitionsByHostState(hostState);
+        EasyMock.expectLastCall();
+        taskManager.setAssignmentMetadata(activeTasks, standbyTasks);
+        EasyMock.expectLastCall();
+        taskManager.setClusterMetadata(EasyMock.capture(capturedCluster));
+        EasyMock.expectLastCall();
+        EasyMock.replay(taskManager);
 
-        AssignmentInfo info = new AssignmentInfo(activeTaskList, standbyTasks, new HashMap<HostInfo, Set<TopicPartition>>());
-        PartitionAssignor.Assignment assignment = new PartitionAssignor.Assignment(Utils.mkList(t1p0, t2p3), info.encode());
         partitionAssignor.onAssignment(assignment);
 
-        assertEquals(activeTasks, partitionAssignor.activeTasks());
-        assertEquals(standbyTasks, partitionAssignor.standbyTasks());
+        EasyMock.verify(taskManager);
+
+        assertEquals(Collections.singleton(t3p0.topic()), capturedCluster.getValue().topics());
+        assertEquals(2, capturedCluster.getValue().partitionsForTopic(t3p0.topic()).size());
     }
 
     @Test
     public void testAssignWithInternalTopics() throws Exception {
-        String applicationId = "test";
         builder.setApplicationId(applicationId);
         builder.addInternalTopic("topicX");
         builder.addSource(null, "source1", null, null, null, "topic1");
@@ -588,12 +578,12 @@ public class StreamPartitionAssignorTest {
         builder.addSink("sink1", "topicX", null, null, null, "processor1");
         builder.addSource(null, "source2", null, null, null, "topicX");
         builder.addProcessor("processor2", new MockProcessorSupplier(), "source2");
-        List<String> topics = Utils.mkList("topic1", "test-topicX");
+        List<String> topics = Utils.mkList("topic1", applicationId + "-topicX");
         Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2);
 
         UUID uuid1 = UUID.randomUUID();
-        mockThreadDataProvider(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), uuid1, defaultPartitionGrouper, builder);
-        configurePartitionAssignor(0, null);
+        mockTaskManager(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), uuid1, builder);
+        configurePartitionAssignor(Collections.<String, Object>emptyMap());
         MockInternalTopicManager internalTopicManager = new MockInternalTopicManager(config, mockClientSupplier.restoreConsumer);
         partitionAssignor.setInternalTopicManager(internalTopicManager);
 
@@ -606,7 +596,7 @@ public class StreamPartitionAssignorTest {
 
         // check prepared internal topics
         assertEquals(1, internalTopicManager.readyTopics.size());
-        assertEquals(allTasks.size(), (long) internalTopicManager.readyTopics.get("test-topicX"));
+        assertEquals(allTasks.size(), (long) internalTopicManager.readyTopics.get(applicationId + "-topicX"));
     }
 
     @Test
@@ -626,9 +616,9 @@ public class StreamPartitionAssignorTest {
         Set<TaskId> allTasks = Utils.mkSet(task0, task1, task2);
 
         UUID uuid1 = UUID.randomUUID();
-        mockThreadDataProvider(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), uuid1, defaultPartitionGrouper, builder);
+        mockTaskManager(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), uuid1, builder);
 
-        configurePartitionAssignor(0, null);
+        configurePartitionAssignor(Collections.<String, Object>emptyMap());
         MockInternalTopicManager internalTopicManager = new MockInternalTopicManager(config, mockClientSupplier.restoreConsumer);
         partitionAssignor.setInternalTopicManager(internalTopicManager);
 
@@ -646,18 +636,17 @@ public class StreamPartitionAssignorTest {
 
     @Test
     public void shouldAddUserDefinedEndPointToSubscription() throws Exception {
-        final String applicationId = "application-id";
         builder.setApplicationId(applicationId);
         builder.addSource(null, "source", null, null, null, "input");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source");
         builder.addSink("sink", "output", null, null, null, "processor");
 
         final UUID uuid1 = UUID.randomUUID();
-        mockThreadDataProvider(Collections.<TaskId>emptySet(),
+        mockTaskManager(Collections.<TaskId>emptySet(),
                                Collections.<TaskId>emptySet(),
                                uuid1,
-                               defaultPartitionGrouper, builder);
-        configurePartitionAssignor(0, userEndPoint);
+                builder);
+        configurePartitionAssignor(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, (Object) userEndPoint));
         final PartitionAssignor.Subscription subscription = partitionAssignor.subscription(Utils.mkSet("input"));
         final SubscriptionInfo subscriptionInfo = SubscriptionInfo.decode(subscription.userData());
         assertEquals("localhost:8080", subscriptionInfo.userEndPoint);
@@ -665,7 +654,6 @@ public class StreamPartitionAssignorTest {
 
     @Test
     public void shouldMapUserEndPointToTopicPartitions() throws Exception {
-        final String applicationId = "application-id";
         builder.setApplicationId(applicationId);
         builder.addSource(null, "source", null, null, null, "topic1");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source");
@@ -675,8 +663,9 @@ public class StreamPartitionAssignorTest {
 
         final UUID uuid1 = UUID.randomUUID();
 
-        mockThreadDataProvider(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), uuid1, defaultPartitionGrouper, builder);
-        configurePartitionAssignor(0, userEndPoint);
+        mockTaskManager(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), uuid1, builder);
+        configurePartitionAssignor(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, (Object) userEndPoint));
+
         partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(config, mockClientSupplier.restoreConsumer));
 
         final Map<String, PartitionAssignor.Subscription> subscriptions = new HashMap<>();
@@ -695,15 +684,13 @@ public class StreamPartitionAssignorTest {
 
     @Test
     public void shouldThrowExceptionIfApplicationServerConfigIsNotHostPortPair() throws Exception {
-        final String myEndPoint = "localhost";
-        final String applicationId = "application-id";
         builder.setApplicationId(applicationId);
 
-        mockThreadDataProvider(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), UUID.randomUUID(), defaultPartitionGrouper, builder);
+        mockTaskManager(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), UUID.randomUUID(), builder);
         partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(config, mockClientSupplier.restoreConsumer));
 
         try {
-            configurePartitionAssignor(0, myEndPoint);
+            configurePartitionAssignor(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, (Object) "localhost"));
             Assert.fail("expected to an exception due to invalid config");
         } catch (ConfigException e) {
             // pass
@@ -712,12 +699,10 @@ public class StreamPartitionAssignorTest {
 
     @Test
     public void shouldThrowExceptionIfApplicationServerConfigPortIsNotAnInteger() {
-        final String myEndPoint = "localhost:j87yhk";
-        final String applicationId = "application-id";
         builder.setApplicationId(applicationId);
 
         try {
-            configurePartitionAssignor(0, myEndPoint);
+            configurePartitionAssignor(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, (Object) "localhost:j87yhk"));
             Assert.fail("expected to an exception due to invalid config");
         } catch (ConfigException e) {
             // pass
@@ -725,53 +710,7 @@ public class StreamPartitionAssignorTest {
     }
 
     @Test
-    public void shouldExposeHostStateToTopicPartitionsOnAssignment() throws Exception {
-        List<TopicPartition> topic = Collections.singletonList(new TopicPartition("topic", 0));
-        final Map<HostInfo, Set<TopicPartition>> hostState =
-                Collections.singletonMap(new HostInfo("localhost", 80),
-                        Collections.singleton(new TopicPartition("topic", 0)));
-        AssignmentInfo assignmentInfo = new AssignmentInfo(Collections.singletonList(new TaskId(0, 0)),
-                Collections.<TaskId, Set<TopicPartition>>emptyMap(),
-                hostState);
-        mockThreadDataProvider(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), UUID.randomUUID(), defaultPartitionGrouper, builder);
-        configurePartitionAssignor(0, null);
-
-        partitionAssignor.onAssignment(new PartitionAssignor.Assignment(topic, assignmentInfo.encode()));
-        assertEquals(hostState, partitionAssignor.getPartitionsByHostState());
-    }
-
-    @Test
-    public void shouldSetClusterMetadataOnAssignment() throws Exception {
-        final List<TopicPartition> topic = Collections.singletonList(new TopicPartition("topic", 0));
-        final Map<HostInfo, Set<TopicPartition>> hostState =
-                Collections.singletonMap(new HostInfo("localhost", 80),
-                        Collections.singleton(new TopicPartition("topic", 0)));
-        final AssignmentInfo assignmentInfo = new AssignmentInfo(Collections.singletonList(new TaskId(0, 0)),
-                Collections.<TaskId, Set<TopicPartition>>emptyMap(),
-                hostState);
-
-
-        mockThreadDataProvider(Collections.<TaskId>emptySet(), Collections.<TaskId>emptySet(), UUID.randomUUID(), defaultPartitionGrouper, builder);
-        configurePartitionAssignor(0, null);
-        partitionAssignor.onAssignment(new PartitionAssignor.Assignment(topic, assignmentInfo.encode()));
-        final Cluster cluster = partitionAssignor.clusterMetadata();
-        final List<PartitionInfo> partitionInfos = cluster.partitionsForTopic("topic");
-        final PartitionInfo partitionInfo = partitionInfos.get(0);
-        assertEquals(1, partitionInfos.size());
-        assertEquals("topic", partitionInfo.topic());
-        assertEquals(0, partitionInfo.partition());
-    }
-
-    @Test
-    public void shouldReturnEmptyClusterMetadataIfItHasntBeenBuilt() {
-        final Cluster cluster = partitionAssignor.clusterMetadata();
-        assertNotNull(cluster);
-    }
-
-    @Test
     public void shouldNotLoopInfinitelyOnMissingMetadataAndShouldNotCreateRelatedTasks() throws Exception {
-        final String applicationId = "application-id";
-
         final StreamsBuilder builder = new StreamsBuilder();
 
         final InternalTopologyBuilder internalTopologyBuilder = StreamsBuilderTest.internalTopologyBuilder(builder);
@@ -833,12 +772,11 @@ public class StreamPartitionAssignorTest {
         final UUID uuid = UUID.randomUUID();
         final String client = "client1";
 
-        mockThreadDataProvider(Collections.<TaskId>emptySet(),
+        mockTaskManager(Collections.<TaskId>emptySet(),
                                Collections.<TaskId>emptySet(),
                                UUID.randomUUID(),
-                               defaultPartitionGrouper,
-                               internalTopologyBuilder);
-        configurePartitionAssignor(0, null);
+                internalTopologyBuilder);
+        configurePartitionAssignor(Collections.<String, Object>emptyMap());
 
         final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager(
             config,
@@ -874,53 +812,25 @@ public class StreamPartitionAssignorTest {
     }
 
     @Test
-    public void shouldUpdatePartitionHostInfoMapOnAssignment() throws Exception {
+    public void shouldUpdateClusterMetadataAndHostInfoOnAssignment() throws Exception {
         final TopicPartition partitionOne = new TopicPartition("topic", 1);
         final TopicPartition partitionTwo = new TopicPartition("topic", 2);
-        final Map<HostInfo, Set<TopicPartition>> firstHostState = Collections.singletonMap(
+        final Map<HostInfo, Set<TopicPartition>> hostState = Collections.singletonMap(
                 new HostInfo("localhost", 9090), Utils.mkSet(partitionOne, partitionTwo));
 
-        final Map<HostInfo, Set<TopicPartition>> secondHostState = new HashMap<>();
-        secondHostState.put(new HostInfo("localhost", 9090), Utils.mkSet(partitionOne));
-        secondHostState.put(new HostInfo("other", 9090), Utils.mkSet(partitionTwo));
+        configurePartitionAssignor(Collections.<String, Object>emptyMap());
 
-        mockThreadDataProvider(Collections.<TaskId>emptySet(),
-                               Collections.<TaskId>emptySet(),
-                               UUID.randomUUID(),
-                               defaultPartitionGrouper,
-                               builder);
-        configurePartitionAssignor(0, null);
-        partitionAssignor.onAssignment(createAssignment(firstHostState));
-        assertEquals(firstHostState, partitionAssignor.getPartitionsByHostState());
-        partitionAssignor.onAssignment(createAssignment(secondHostState));
-        assertEquals(secondHostState, partitionAssignor.getPartitionsByHostState());
-    }
-
-    @Test
-    public void shouldUpdateClusterMetadataOnAssignment() throws Exception {
-        final TopicPartition topicOne = new TopicPartition("topic", 1);
-        final TopicPartition topicTwo = new TopicPartition("topic2", 2);
-        final Map<HostInfo, Set<TopicPartition>> firstHostState = Collections.singletonMap(
-                new HostInfo("localhost", 9090), Utils.mkSet(topicOne));
+        taskManager.setPartitionsByHostState(hostState);
+        EasyMock.expectLastCall();
+        EasyMock.replay(taskManager);
 
-        final Map<HostInfo, Set<TopicPartition>> secondHostState = Collections.singletonMap(
-                new HostInfo("localhost", 9090), Utils.mkSet(topicOne, topicTwo));
+        partitionAssignor.onAssignment(createAssignment(hostState));
 
-        mockThreadDataProvider(Collections.<TaskId>emptySet(),
-                               Collections.<TaskId>emptySet(),
-                               UUID.randomUUID(),
-                               defaultPartitionGrouper,
-                               builder);
-        configurePartitionAssignor(0, null);
-        partitionAssignor.onAssignment(createAssignment(firstHostState));
-        assertEquals(Utils.mkSet("topic"), partitionAssignor.clusterMetadata().topics());
-        partitionAssignor.onAssignment(createAssignment(secondHostState));
-        assertEquals(Utils.mkSet("topic", "topic2"), partitionAssignor.clusterMetadata().topics());
+        EasyMock.verify(taskManager);
     }
 
     @Test
     public void shouldNotAddStandbyTaskPartitionsToPartitionsForHost() throws Exception {
-        final String applicationId = "appId";
         final StreamsBuilder builder = new StreamsBuilder();
 
         final InternalTopologyBuilder internalTopologyBuilder = StreamsBuilderTest.internalTopologyBuilder(builder);
@@ -929,13 +839,15 @@ public class StreamPartitionAssignorTest {
         builder.stream("topic1").groupByKey().count();
 
         final UUID uuid = UUID.randomUUID();
-        mockThreadDataProvider(Collections.<TaskId>emptySet(),
+        mockTaskManager(Collections.<TaskId>emptySet(),
                                Collections.<TaskId>emptySet(),
                                uuid,
-                               defaultPartitionGrouper,
-                               internalTopologyBuilder);
+                internalTopologyBuilder);
 
-        configurePartitionAssignor(1, userEndPoint);
+        Map<String, Object> props = new HashMap<>();
+        props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
+        props.put(StreamsConfig.APPLICATION_SERVER_CONFIG, userEndPoint);
+        configurePartitionAssignor(props);
         partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(
             config,
             mockClientSupplier.restoreConsumer));
@@ -979,7 +891,7 @@ public class StreamPartitionAssignorTest {
     public void shouldThrowKafkaExceptionIfStreamThreadConfigIsNotThreadDataProviderInstance() {
         final Map<String, Object> config = new HashMap<>();
         config.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
-        config.put(StreamsConfig.InternalConfig.STREAM_THREAD_INSTANCE, "i am not a stream thread");
+        config.put(StreamsConfig.InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR, "i am not a stream thread");
 
         partitionAssignor.configure(config);
     }
@@ -1035,5 +947,4 @@ public class StreamPartitionAssignorTest {
 
         return info;
     }
-
 }