You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2015/11/12 01:08:24 UTC

kafka git commit: KAFKA-2763: better stream task assignment

Repository: kafka
Updated Branches:
  refs/heads/trunk c6b8de4e6 -> 124f73b17


KAFKA-2763: better stream task assignment

guozhangwang

When a rebalance happens, each consumer reports the following information to the coordinator:
* Client UUID (a unique id assigned to an instance of KafkaStreaming)
* Task ids of previously running tasks
* Task ids of valid local states on the client's state directory

TaskAssignor does the following:
* Assigns each task to a client that was running it previously. If there is no such client, assigns the task to a client that has a valid local state for it.
* Tries to balance the load among stream threads.
  * A client may have more than one stream thread. The assignor tries to assign tasks to each client in proportion to its number of threads.

Author: Yasuhiro Matsuda <ya...@confluent.io>

Reviewers: Guozhang Wang

Closes #497 from ymatsuda/task_assignment


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/124f73b1
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/124f73b1
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/124f73b1

Branch: refs/heads/trunk
Commit: 124f73b1747a574982e9ca491712e6758ddbacea
Parents: c6b8de4
Author: Yasuhiro Matsuda <ya...@confluent.io>
Authored: Wed Nov 11 16:14:27 2015 -0800
Committer: Guozhang Wang <wa...@gmail.com>
Committed: Wed Nov 11 16:14:27 2015 -0800

----------------------------------------------------------------------
 .../apache/kafka/streams/KafkaStreaming.java    |   5 +-
 .../apache/kafka/streams/StreamingConfig.java   |   8 +-
 .../streams/processor/PartitionGrouper.java     |   4 +
 .../apache/kafka/streams/processor/TaskId.java  |  23 +-
 .../KafkaStreamingPartitionAssignor.java        | 187 ++++++++----
 .../processor/internals/StreamThread.java       |  59 +++-
 .../internals/assignment/AssignmentInfo.java    | 125 ++++++++
 .../internals/assignment/ClientState.java       |  72 +++++
 .../internals/assignment/SubscriptionInfo.java  | 128 ++++++++
 .../assignment/TaskAssignmentException.java     |  32 ++
 .../internals/assignment/TaskAssignor.java      | 195 +++++++++++++
 .../KafkaStreamingPartitionAssignorTest.java    | 283 ++++++++++++++++++
 .../processor/internals/StreamThreadTest.java   |  33 ++-
 .../assignment/AssginmentInfoTest.java          |  45 +++
 .../assignment/SubscriptionInfoTest.java        |  46 +++
 .../internals/assignment/TaskAssignorTest.java  | 289 +++++++++++++++++++
 16 files changed, 1464 insertions(+), 70 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java b/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java
index d274fb9..fc1fdae 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java
@@ -29,6 +29,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.List;
+import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -85,11 +86,13 @@ public class KafkaStreaming {
     private final StreamThread[] threads;
 
     private String clientId;
+    private final UUID uuid;
     private final Metrics metrics;
 
     public KafkaStreaming(TopologyBuilder builder, StreamingConfig config) throws Exception {
         // create the metrics
         this.time = new SystemTime();
+        this.uuid = UUID.randomUUID();
 
         MetricConfig metricConfig = new MetricConfig().samples(config.getInt(StreamingConfig.METRICS_NUM_SAMPLES_CONFIG))
             .timeWindow(config.getLong(StreamingConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG),
@@ -104,7 +107,7 @@ public class KafkaStreaming {
 
         this.threads = new StreamThread[config.getInt(StreamingConfig.NUM_STREAM_THREADS_CONFIG)];
         for (int i = 0; i < this.threads.length; i++) {
-            this.threads[i] = new StreamThread(builder, config, this.clientId, this.metrics, this.time);
+            this.threads[i] = new StreamThread(builder, config, this.clientId, this.uuid, this.metrics, this.time);
         }
     }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java b/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
index 88bd844..693cb0c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
+++ b/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
@@ -27,8 +27,8 @@ import org.apache.kafka.common.config.ConfigDef.Type;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.streams.processor.DefaultPartitionGrouper;
-import org.apache.kafka.streams.processor.PartitionGrouper;
 import org.apache.kafka.streams.processor.internals.KafkaStreamingPartitionAssignor;
+import org.apache.kafka.streams.processor.internals.StreamThread;
 
 import java.util.Map;
 
@@ -205,16 +205,16 @@ public class StreamingConfig extends AbstractConfig {
     }
 
     public static class InternalConfig {
-        public static final String PARTITION_GROUPER_INSTANCE = "__partition.grouper.instance__";
+        public static final String STREAM_THREAD_INSTANCE = "__stream.thread.instance__";
     }
 
     public StreamingConfig(Map<?, ?> props) {
         super(CONFIG, props);
     }
 
-    public Map<String, Object> getConsumerConfigs(PartitionGrouper partitionGrouper) {
+    public Map<String, Object> getConsumerConfigs(StreamThread streamThread) {
         Map<String, Object> props = getConsumerConfigs();
-        props.put(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper);
+        props.put(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, streamThread);
         props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, KafkaStreamingPartitionAssignor.class.getName());
         return props;
     }

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java b/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
index 026ec89..00b56b3 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
@@ -50,4 +50,8 @@ public abstract class PartitionGrouper {
         return partitionAssignor.taskIds(partition);
     }
 
+    public Set<TaskId> standbyTasks() {
+        return partitionAssignor.standbyTasks();
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java b/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java
index 3d474fe..5344f6c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java
@@ -17,7 +17,9 @@
 
 package org.apache.kafka.streams.processor;
 
-public class TaskId {
+import java.nio.ByteBuffer;
+
+public class TaskId implements Comparable<TaskId> {
 
     public final int topicGroupId;
     public final int partition;
@@ -45,6 +47,15 @@ public class TaskId {
         }
     }
 
+    public void writeTo(ByteBuffer buf) {
+        buf.putInt(topicGroupId);
+        buf.putInt(partition);
+    }
+
+    public static TaskId readFrom(ByteBuffer buf) {
+        return new TaskId(buf.getInt(), buf.getInt());
+    }
+
     @Override
     public boolean equals(Object o) {
         if (o instanceof TaskId) {
@@ -61,6 +72,16 @@ public class TaskId {
         return (int) (n % 0xFFFFFFFFL);
     }
 
+    @Override
+    public int compareTo(TaskId other) {
+        return
+            this.topicGroupId < other.topicGroupId ? -1 :
+                (this.topicGroupId > other.topicGroupId ? 1 :
+                    (this.partition < other.partition ? -1 :
+                        (this.partition > other.partition ? 1 :
+                            0)));
+    }
+
     public static class TaskIdFormatException extends RuntimeException {
     }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
index f7b14ad..35ba0ec 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
@@ -23,37 +23,49 @@ import org.apache.kafka.common.Configurable;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.streams.StreamingConfig;
-import org.apache.kafka.streams.processor.PartitionGrouper;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
+import org.apache.kafka.streams.processor.internals.assignment.ClientState;
+import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
+import org.apache.kafka.streams.processor.internals.assignment.TaskAssignmentException;
+import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.UUID;
 
 public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Configurable {
 
     private static final Logger log = LoggerFactory.getLogger(KafkaStreamingPartitionAssignor.class);
 
-    private PartitionGrouper partitionGrouper;
+    private StreamThread streamThread;
     private Map<TopicPartition, Set<TaskId>> partitionToTaskIds;
+    private Set<TaskId> standbyTasks;
 
     @Override
     public void configure(Map<String, ?> configs) {
-        Object o = configs.get(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE);
-        if (o == null)
-            throw new KafkaException("PartitionGrouper is not specified");
+        Object o = configs.get(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE);
+        if (o == null) {
+            KafkaException ex = new KafkaException("StreamThread is not specified");
+            log.error(ex.getMessage(), ex);
+            throw ex;
+        }
 
-        if (!PartitionGrouper.class.isInstance(o))
-            throw new KafkaException(o.getClass().getName() + " is not an instance of " + PartitionGrouper.class.getName());
+        if (!(o instanceof StreamThread)) {
+            KafkaException ex = new KafkaException(o.getClass().getName() + " is not an instance of " + StreamThread.class.getName());
+            log.error(ex.getMessage(), ex);
+            throw ex;
+        }
 
-        partitionGrouper = (PartitionGrouper) o;
-        partitionGrouper.partitionAssignor(this);
+        streamThread = (StreamThread) o;
+        streamThread.partitionGrouper.partitionAssignor(this);
     }
 
     @Override
@@ -63,38 +75,110 @@ public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Confi
 
     @Override
     public Subscription subscription(Set<String> topics) {
-        return new Subscription(new ArrayList<>(topics));
+        // Adds the following information to subscription
+        // 1. Client UUID (a unique id assigned to an instance of KafkaStreaming)
+        // 2. Task ids of previously running tasks
+        // 3. Task ids of valid local states on the client's state directory.
+
+        Set<TaskId> prevTasks = streamThread.prevTasks();
+        Set<TaskId> standbyTasks = streamThread.cachedTasks();
+        standbyTasks.removeAll(prevTasks);
+        SubscriptionInfo data = new SubscriptionInfo(streamThread.clientUUID, prevTasks, standbyTasks);
+
+        return new Subscription(new ArrayList<>(topics), data.encode());
     }
 
     @Override
     public Map<String, Assignment> assign(Cluster metadata, Map<String, Subscription> subscriptions) {
-        Map<TaskId, Set<TopicPartition>> partitionGroups = partitionGrouper.partitionGroups(metadata);
+        // This assigns tasks to consumer clients in two steps.
+        // 1. Using TaskAssignor, tasks are assigned to streaming clients (one client per UUID).
+        //    - Assign a task to a client which was running it previously.
+        //      If there is no such client, assign the task to a client which has a valid local state for it.
+        //    - A client may have more than one stream thread (one consumer per thread).
+        //      The assignor tries to assign tasks to a client in proportion to its number of threads.
+        //    - We try not to assign the same set of tasks to two different clients.
+        //    The assignment is done in a single pass, so the result may not satisfy all of the above.
+        // 2. Within each client, tasks are distributed among its consumers in a round-robin manner.
+
+        Map<UUID, Set<String>> consumersByClient = new HashMap<>();
+        Map<UUID, ClientState<TaskId>> states = new HashMap<>();
+
+        // Decode subscription info
+        for (Map.Entry<String, Subscription> entry : subscriptions.entrySet()) {
+            String consumerId = entry.getKey();
+            Subscription subscription = entry.getValue();
+
+            SubscriptionInfo info = SubscriptionInfo.decode(subscription.userData());
+
+            Set<String> consumers = consumersByClient.get(info.clientUUID);
+            if (consumers == null) {
+                consumers = new HashSet<>();
+                consumersByClient.put(info.clientUUID, consumers);
+            }
+            consumers.add(consumerId);
+
+            ClientState<TaskId> state = states.get(info.clientUUID);
+            if (state == null) {
+                state = new ClientState<>();
+                states.put(info.clientUUID, state);
+            }
+
+            state.prevActiveTasks.addAll(info.prevTasks);
+            state.prevAssignedTasks.addAll(info.prevTasks);
+            state.prevAssignedTasks.addAll(info.standbyTasks);
+            state.capacity = state.capacity + 1d;
+        }
 
-        String[] clientIds = subscriptions.keySet().toArray(new String[subscriptions.size()]);
-        TaskId[] taskIds = partitionGroups.keySet().toArray(new TaskId[partitionGroups.size()]);
+        // Get partition groups from the partition grouper
+        Map<TaskId, Set<TopicPartition>> partitionGroups = streamThread.partitionGrouper.partitionGroups(metadata);
 
+        states = TaskAssignor.assign(states, partitionGroups.keySet(), 0); // TODO: enable standby tasks
         Map<String, Assignment> assignment = new HashMap<>();
 
-        for (int i = 0; i < clientIds.length; i++) {
-            List<TopicPartition> partitions = new ArrayList<>();
-            List<TaskId> ids = new ArrayList<>();
-            for (int j = i; j < taskIds.length; j += clientIds.length) {
-                TaskId taskId = taskIds[j];
-                for (TopicPartition partition : partitionGroups.get(taskId)) {
-                    partitions.add(partition);
-                    ids.add(taskId);
-                }
+        for (Map.Entry<UUID, Set<String>> entry : consumersByClient.entrySet()) {
+            UUID uuid = entry.getKey();
+            Set<String> consumers = entry.getValue();
+            ClientState<TaskId> state = states.get(uuid);
+            // order: active tasks first, then standby-only tasks, so index < numActiveTasks means "active"
+            ArrayList<TaskId> taskIds = new ArrayList<>(state.assignedTasks.size());
+            final int numActiveTasks = state.activeTasks.size();
+            for (TaskId id : state.activeTasks) {
+                taskIds.add(id);
             }
-            ByteBuffer buf = ByteBuffer.allocate(4 + ids.size() * 8);
-            //version
-            buf.putInt(1);
-            // encode task ids
-            for (TaskId id : ids) {
-                buf.putInt(id.topicGroupId);
-                buf.putInt(id.partition);
+            for (TaskId id : state.assignedTasks) {
+                if (!state.activeTasks.contains(id))
+                    taskIds.add(id);
+            }
+
+            final int numConsumers = consumers.size();
+            List<TaskId> active = new ArrayList<>();
+            Set<TaskId> standby = new HashSet<>();
+
+            int i = 0;
+            for (String consumer : consumers) {
+                List<TopicPartition> partitions = new ArrayList<>();
+
+                final int numTaskIds = taskIds.size();
+                for (int j = i; j < numTaskIds; j += numConsumers) {
+                    TaskId taskId = taskIds.get(j);
+                    if (j < numActiveTasks) {
+                        for (TopicPartition partition : partitionGroups.get(taskId)) {
+                            partitions.add(partition);
+                            active.add(taskId);
+                        }
+                    } else {
+                        // standby tasks are not assigned any partitions; only their ids are sent
+                        standby.add(taskId);
+                    }
+                }
+
+                AssignmentInfo data = new AssignmentInfo(active, standby);
+                assignment.put(consumer, new Assignment(partitions, data.encode()));
+                i++;
+                // encode() above copied the ids into a new buffer, so the lists can be reused for the next consumer
+                active.clear();
+                standby.clear();
             }
-            buf.rewind();
-            assignment.put(clientIds[i], new Assignment(partitions, buf));
         }
 
         return assignment;
@@ -103,27 +187,29 @@ public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Confi
     @Override
     public void onAssignment(Assignment assignment) {
         List<TopicPartition> partitions = assignment.partitions();
-        ByteBuffer data = assignment.userData();
-        data.rewind();
+
+        AssignmentInfo info = AssignmentInfo.decode(assignment.userData());
+        this.standbyTasks = info.standbyTasks;
 
         Map<TopicPartition, Set<TaskId>> partitionToTaskIds = new HashMap<>();
+        Iterator<TaskId> iter = info.activeTasks.iterator();
+        for (TopicPartition partition : partitions) {
+            Set<TaskId> taskIds = partitionToTaskIds.get(partition);
+            if (taskIds == null) {
+                taskIds = new HashSet<>();
+                partitionToTaskIds.put(partition, taskIds);
+            }
 
-        // check version
-        int version = data.getInt();
-        if (version == 1) {
-            for (TopicPartition partition : partitions) {
-                Set<TaskId> taskIds = partitionToTaskIds.get(partition);
-                if (taskIds == null) {
-                    taskIds = new HashSet<>();
-                    partitionToTaskIds.put(partition, taskIds);
-                }
-                // decode a task id
-                taskIds.add(new TaskId(data.getInt(), data.getInt()));
+            if (iter.hasNext()) {
+                taskIds.add(iter.next());
+            } else {
+                TaskAssignmentException ex = new TaskAssignmentException(
+                        "failed to find a task id for the partition=" + partition.toString() +
+                        ", partitions=" + partitions.size() + ", assignmentInfo=" + info.toString()
+                );
+                log.error(ex.getMessage(), ex);
+                throw ex;
             }
-        } else {
-            KafkaException ex = new KafkaException("unknown assignment data version: " + version);
-            log.error(ex.getMessage(), ex);
-            throw ex;
         }
         this.partitionToTaskIds = partitionToTaskIds;
     }
@@ -132,4 +218,7 @@ public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Confi
         return partitionToTaskIds.get(partition);
     }
 
+    public Set<TaskId> standbyTasks() {
+        return standbyTasks;
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index ba81421..06e5951 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -59,6 +59,7 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.UUID;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -67,16 +68,18 @@ public class StreamThread extends Thread {
     private static final Logger log = LoggerFactory.getLogger(StreamThread.class);
     private static final AtomicInteger STREAMING_THREAD_ID_SEQUENCE = new AtomicInteger(1);
 
-    private final AtomicBoolean running;
+    public final PartitionGrouper partitionGrouper;
+    public final UUID clientUUID;
 
     protected final StreamingConfig config;
     protected final TopologyBuilder builder;
-    protected final PartitionGrouper partitionGrouper;
     protected final Producer<byte[], byte[]> producer;
     protected final Consumer<byte[], byte[]> consumer;
     protected final Consumer<byte[], byte[]> restoreConsumer;
 
+    private final AtomicBoolean running;
     private final Map<TaskId, StreamTask> tasks;
+    private final Set<TaskId> prevTasks;
     private final String clientId;
     private final Time time;
     private final File stateDir;
@@ -108,9 +111,10 @@ public class StreamThread extends Thread {
     public StreamThread(TopologyBuilder builder,
                         StreamingConfig config,
                         String clientId,
+                        UUID clientUUID,
                         Metrics metrics,
                         Time time) throws Exception {
-        this(builder, config, null , null, null, clientId, metrics, time);
+        this(builder, config, null , null, null, clientId, clientUUID, metrics, time);
     }
 
     StreamThread(TopologyBuilder builder,
@@ -119,6 +123,7 @@ public class StreamThread extends Thread {
                  Consumer<byte[], byte[]> consumer,
                  Consumer<byte[], byte[]> restoreConsumer,
                  String clientId,
+                 UUID clientUUID,
                  Metrics metrics,
                  Time time) throws Exception {
         super("StreamThread-" + STREAMING_THREAD_ID_SEQUENCE.getAndIncrement());
@@ -126,6 +131,7 @@ public class StreamThread extends Thread {
         this.config = config;
         this.builder = builder;
         this.clientId = clientId;
+        this.clientUUID = clientUUID;
         this.partitionGrouper = config.getConfiguredInstance(StreamingConfig.PARTITION_GROUPER_CLASS_CONFIG, PartitionGrouper.class);
         this.partitionGrouper.topicGroups(builder.topicGroups());
 
@@ -136,6 +142,7 @@ public class StreamThread extends Thread {
 
         // initialize the task list
         this.tasks = new HashMap<>();
+        this.prevTasks = new HashSet<>();
 
         // read in task specific config values
         this.stateDir = new File(this.config.getString(StreamingConfig.STATE_DIR_CONFIG));
@@ -164,7 +171,7 @@ public class StreamThread extends Thread {
 
     private Consumer<byte[], byte[]> createConsumer() {
         log.info("Creating consumer client for stream thread [" + this.getName() + "]");
-        return new KafkaConsumer<>(config.getConsumerConfigs(partitionGrouper),
+        return new KafkaConsumer<>(config.getConsumerConfigs(this),
                 new ByteArrayDeserializer(),
                 new ByteArrayDeserializer());
     }
@@ -415,6 +422,43 @@ public class StreamThread extends Thread {
         }
     }
 
+    /**
+     * Returns ids of tasks that were being executed before the rebalance (the live internal set, not a copy).
+     */
+    public Set<TaskId> prevTasks() {
+        return prevTasks;
+    }
+
+    /**
+     * Returns ids of tasks whose state directory under the state dir holds a checkpoint file (a new set per call).
+     */
+    public Set<TaskId> cachedTasks() {
+        // A client could contain some inactive tasks whose states are still kept on the local storage in the following scenarios:
+        // 1) the client is actively maintaining standby tasks by keeping their states up to date from the changelog.
+        // 2) the client has just had some tasks migrated out of itself to other clients while these task states
+        //    have not been cleaned up yet (this can happen in a rolling bounce upgrade, for example).
+
+        HashSet<TaskId> tasks = new HashSet<>();
+
+        File[] stateDirs = stateDir.listFiles();
+        if (stateDirs != null) {
+            for (File dir : stateDirs) {
+                try {
+                    TaskId id = TaskId.parse(dir.getName());
+                    // if the checkpoint file exists, the state is considered valid.
+                    if (new File(dir, ProcessorStateManager.CHECKPOINT_FILE_NAME).exists())
+                        tasks.add(id);
+
+                } catch (TaskId.TaskIdFormatException e) {
+                    // there may be some unknown files that sit in the same directory;
+                    // we ignore them here rather than trying to delete them
+                }
+            }
+        }
+
+        return tasks;
+    }
+
     protected StreamTask createStreamTask(TaskId id, Collection<TopicPartition> partitionsForTask) {
         sensors.taskCreationSensor.record();
 
@@ -465,11 +509,10 @@ public class StreamThread extends Thread {
             }
             sensors.taskDestructionSensor.record();
         }
-        tasks.clear();
-    }
+        prevTasks.clear();
+        prevTasks.addAll(tasks.keySet());
 
-    public PartitionGrouper partitionGrouper() {
-        return partitionGrouper;
+        tasks.clear();
     }
 
     private void ensureCopartitioning(Collection<Set<String>> copartitionGroups) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java
new file mode 100644
index 0000000..d82dd7d
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java
@@ -0,0 +1,125 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class AssignmentInfo {
+
+    private static final Logger log = LoggerFactory.getLogger(AssignmentInfo.class);
+
+    public final int version;
+    public final List<TaskId> activeTasks; // each element corresponds to a partition
+    public final Set<TaskId> standbyTasks;
+
+    public AssignmentInfo(List<TaskId> activeTasks, Set<TaskId> standbyTasks) {
+        this(1, activeTasks, standbyTasks);
+    }
+
+    protected AssignmentInfo(int version, List<TaskId> activeTasks, Set<TaskId> standbyTasks) {
+        this.version = version;
+        this.activeTasks = activeTasks;
+        this.standbyTasks = standbyTasks;
+    }
+
+    public ByteBuffer encode() {
+        if (version == 1) {
+            ByteBuffer buf = ByteBuffer.allocate(4 + 4 + activeTasks.size() * 8 + 4 + standbyTasks.size() * 8);
+            // Encode version (layout: version, #active, active ids, #standby, standby ids; 8 bytes per TaskId)
+            buf.putInt(1);
+            // Encode active tasks
+            buf.putInt(activeTasks.size());
+            for (TaskId id : activeTasks) {
+                id.writeTo(buf);
+            }
+            // Encode standby tasks
+            buf.putInt(standbyTasks.size());
+            for (TaskId id : standbyTasks) {
+                id.writeTo(buf);
+            }
+            buf.rewind();
+
+            return buf;
+
+        } else {
+            TaskAssignmentException ex = new TaskAssignmentException("unable to encode assignment data: version=" + version);
+            log.error(ex.getMessage(), ex);
+            throw ex;
+        }
+    }
+
+    public static AssignmentInfo decode(ByteBuffer data) {
+        // ensure we are at the beginning of the ByteBuffer (note: this moves the caller's buffer position)
+        data.rewind();
+
+        // Decode version
+        int version = data.getInt();
+        if (version == 1) {
+            // Decode active tasks
+            int count = data.getInt();
+            List<TaskId> activeTasks = new ArrayList<>(count);
+            for (int i = 0; i < count; i++) {
+                activeTasks.add(TaskId.readFrom(data));
+            }
+            // Decode standby tasks
+            count = data.getInt();
+            Set<TaskId> standbyTasks = new HashSet<>(count);
+            for (int i = 0; i < count; i++) {
+                standbyTasks.add(TaskId.readFrom(data));
+            }
+
+            return new AssignmentInfo(activeTasks, standbyTasks);
+
+        } else {
+            TaskAssignmentException ex = new TaskAssignmentException("unknown assignment data version: " + version);
+            log.error(ex.getMessage(), ex);
+            throw ex;
+        }
+    }
+
+    @Override
+    public int hashCode() {
+        return version ^ activeTasks.hashCode() ^ standbyTasks.hashCode();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o instanceof AssignmentInfo) {
+            AssignmentInfo other = (AssignmentInfo) o;
+            return this.version == other.version &&
+                    this.activeTasks.equals(other.activeTasks) &&
+                    this.standbyTasks.equals(other.standbyTasks);
+        } else {
+            return false;
+        }
+    }
+
+    @Override
+    public String toString() {
+        return "[version=" + version + ", active tasks=" + activeTasks.size() + ", standby tasks=" + standbyTasks.size() + "]";
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
new file mode 100644
index 0000000..a0f6179
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * Mutable per-client bookkeeping used while computing a task assignment.
+ * <p>
+ * Tracks the tasks assigned so far (all of them, plus the active subset), the tasks the client
+ * previously held, the client's capacity, and the accumulated cost of the assignments made to it.
+ *
+ * @param <T> the task identifier type
+ */
+public class ClientState<T> {
+
+    /** Cost added by {@link #assign} when the task was previously active on this client. */
+    public static final double COST_ACTIVE = 0.1;
+    /** Cost added by {@link #assign} when the task was previously assigned to, but not active on, this client. */
+    public static final double COST_STANDBY = 0.2;
+    /** Cost added by {@link #assign} when this client has no previous record of the task. */
+    public static final double COST_LOAD = 0.5;
+
+    /** Tasks assigned as active; {@link #assign} also adds each of them to {@link #assignedTasks}. */
+    public final Set<T> activeTasks;
+    /** All tasks assigned to this client, active or not. */
+    public final Set<T> assignedTasks;
+    /** Tasks previously active on this client; {@link #assign} removes entries as they are re-assigned here. */
+    public final Set<T> prevActiveTasks;
+    /** Tasks previously assigned to this client; {@link #assign} removes entries as they are re-assigned here. */
+    public final Set<T> prevAssignedTasks;
+
+    /** Capacity of this client (0 when created with the no-arg constructor). */
+    public double capacity;
+    /** Sum of the costs added by every {@link #assign} call on this state. */
+    public double cost;
+
+    /** Creates an empty state with zero capacity. */
+    public ClientState() {
+        this(0d);
+    }
+
+    /** Creates an empty state with the given capacity. */
+    public ClientState(double capacity) {
+        this(new HashSet<T>(), new HashSet<T>(), new HashSet<T>(), new HashSet<T>(), capacity);
+    }
+
+    private ClientState(Set<T> activeTasks, Set<T> assignedTasks, Set<T> prevActiveTasks, Set<T> prevAssignedTasks, double capacity) {
+        this.activeTasks = activeTasks;
+        this.assignedTasks = assignedTasks;
+        this.prevActiveTasks = prevActiveTasks;
+        this.prevAssignedTasks = prevAssignedTasks;
+        this.capacity = capacity;
+        this.cost = 0d;
+    }
+
+    /**
+     * Returns an independent copy of this state: every task set is copied, and capacity and the
+     * accumulated cost are carried over, so mutating the copy never affects this instance.
+     */
+    public ClientState<T> copy() {
+        ClientState<T> copy = new ClientState<>(new HashSet<>(activeTasks), new HashSet<>(assignedTasks),
+                new HashSet<>(prevActiveTasks), new HashSet<>(prevAssignedTasks), capacity);
+        // The private constructor starts cost at 0; keep it consistent with the copied task sets.
+        copy.cost = this.cost;
+        return copy;
+    }
+
+    /**
+     * Assigns a task to this client and adds the corresponding cost.
+     *
+     * @param taskId the task to assign
+     * @param active whether the task is assigned as an active task (it is always added to {@link #assignedTasks})
+     */
+    public void assign(T taskId, boolean active) {
+        if (active)
+            activeTasks.add(taskId);
+
+        assignedTasks.add(taskId);
+
+        // Both removals must run so the task is cleared from both "previous" sets.
+        boolean wasAssigned = prevAssignedTasks.remove(taskId);
+        boolean wasActive = prevActiveTasks.remove(taskId);
+
+        double addedCost;
+        if (wasActive) {
+            addedCost = COST_ACTIVE;
+        } else if (wasAssigned) {
+            addedCost = COST_STANDBY;
+        } else {
+            addedCost = COST_LOAD;
+        }
+
+        this.cost += addedCost;
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
new file mode 100644
index 0000000..54042b9
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
@@ -0,0 +1,128 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.nio.ByteBuffer;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.UUID;
+
+/**
+ * Per-client metadata sent with a consumer subscription: the client UUID, the ids of the tasks the
+ * client was previously running, and the ids of the other tasks it reports (standby tasks).
+ * Serialized with a version-prefixed binary layout.
+ */
+public class SubscriptionInfo {
+
+    private static final Logger log = LoggerFactory.getLogger(SubscriptionInfo.class);
+
+    public final int version;
+    public final UUID clientUUID;
+    public final Set<TaskId> prevTasks;
+    public final Set<TaskId> standbyTasks;
+
+    public SubscriptionInfo(UUID clientUUID, Set<TaskId> prevTasks, Set<TaskId> standbyTasks) {
+        this(1, clientUUID, prevTasks, standbyTasks);
+    }
+
+    private SubscriptionInfo(int version, UUID clientUUID, Set<TaskId> prevTasks, Set<TaskId> standbyTasks) {
+        this.version = version;
+        this.clientUUID = clientUUID;
+        this.prevTasks = prevTasks;
+        this.standbyTasks = standbyTasks;
+    }
+
+    /**
+     * Serializes this info. Layout (version 1): int version, two longs for the client UUID, then the
+     * previously running task ids and the standby task ids, each as an int count followed by the ids.
+     *
+     * @return a buffer positioned at its start
+     * @throws TaskAssignmentException if the version is not 1
+     */
+    public ByteBuffer encode() {
+        if (version != 1) {
+            TaskAssignmentException ex = new TaskAssignmentException("unable to encode subscription data: version=" + version);
+            log.error(ex.getMessage(), ex);
+            throw ex;
+        }
+
+        // The allocation assumes each TaskId.writeTo call writes 8 bytes.
+        int headerBytes = 4 + 16;
+        int prevTaskBytes = 4 + prevTasks.size() * 8;
+        int standbyTaskBytes = 4 + standbyTasks.size() * 8;
+        ByteBuffer buf = ByteBuffer.allocate(headerBytes + prevTaskBytes + standbyTaskBytes);
+
+        buf.putInt(version);
+        buf.putLong(clientUUID.getMostSignificantBits());
+        buf.putLong(clientUUID.getLeastSignificantBits());
+        writeTaskIds(buf, prevTasks);
+        writeTaskIds(buf, standbyTasks);
+
+        buf.rewind();
+        return buf;
+    }
+
+    /** Writes a task-id set as an int count followed by each id. */
+    private static void writeTaskIds(ByteBuffer buf, Set<TaskId> taskIds) {
+        buf.putInt(taskIds.size());
+        for (TaskId taskId : taskIds) {
+            taskId.writeTo(buf);
+        }
+    }
+
+    /**
+     * Deserializes an info written by {@link #encode()}. The buffer is rewound before reading.
+     *
+     * @throws TaskAssignmentException if the encoded version is not 1
+     */
+    public static SubscriptionInfo decode(ByteBuffer data) {
+        data.rewind();
+
+        int version = data.getInt();
+        if (version != 1) {
+            TaskAssignmentException ex = new TaskAssignmentException("unable to decode subscription data: version=" + version);
+            log.error(ex.getMessage(), ex);
+            throw ex;
+        }
+
+        long mostSigBits = data.getLong();
+        long leastSigBits = data.getLong();
+        UUID clientUUID = new UUID(mostSigBits, leastSigBits);
+
+        Set<TaskId> prevTasks = readTaskIds(data);
+        Set<TaskId> standbyTasks = readTaskIds(data);
+
+        return new SubscriptionInfo(version, clientUUID, prevTasks, standbyTasks);
+    }
+
+    /** Reads a task-id set written by {@link #writeTaskIds}: an int count followed by the ids. */
+    private static Set<TaskId> readTaskIds(ByteBuffer data) {
+        Set<TaskId> taskIds = new HashSet<>();
+        int count = data.getInt();
+        for (int i = 0; i < count; i++) {
+            taskIds.add(TaskId.readFrom(data));
+        }
+        return taskIds;
+    }
+
+    @Override
+    public int hashCode() {
+        int hash = version;
+        hash ^= clientUUID.hashCode();
+        hash ^= prevTasks.hashCode();
+        hash ^= standbyTasks.hashCode();
+        return hash;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (!(o instanceof SubscriptionInfo))
+            return false;
+
+        SubscriptionInfo that = (SubscriptionInfo) o;
+        return this.version == that.version
+                && this.clientUUID.equals(that.clientUUID)
+                && this.prevTasks.equals(that.prevTasks)
+                && this.standbyTasks.equals(that.standbyTasks);
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java
new file mode 100644
index 0000000..839a6c2
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.common.KafkaException;
+
+/**
+ * The run time exception class for stream task assignments, thrown when tasks cannot be assigned
+ * or when assignment metadata cannot be encoded or decoded.
+ */
+public class TaskAssignmentException extends KafkaException {
+
+    private static final long serialVersionUID = 1L;
+
+    /**
+     * @param message description of the assignment failure
+     */
+    public TaskAssignmentException(String message) {
+        super(message);
+    }
+
+    /**
+     * Creates an exception that keeps the underlying failure as its cause.
+     *
+     * @param message description of the assignment failure
+     * @param cause   the lower-level exception that triggered it
+     */
+    public TaskAssignmentException(String message, Throwable cause) {
+        super(message, cause);
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
new file mode 100644
index 0000000..d1e0782
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
@@ -0,0 +1,195 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+/**
+ * Assigns tasks to clients, preferring clients that previously held a task and balancing load
+ * relative to each client's capacity.
+ *
+ * @param <C> the client identifier type
+ * @param <T> the task identifier type
+ */
+public class TaskAssignor<C, T extends Comparable<T>> {
+
+    private static final Logger log = LoggerFactory.getLogger(TaskAssignor.class);
+
+    /**
+     * Assigns every task as an active task and then, if requested, assigns standby replicas.
+     * <p>
+     * The given map and states are not modified: each state is copied and the assignment is
+     * recorded on the copies.
+     *
+     * @param states             the clients and their current state
+     * @param tasks              the tasks to assign
+     * @param numStandbyReplicas requested standby replicas per task; capped at the number of clients minus one
+     * @return a new map from client to a copy of its state with the assignment applied
+     * @throws TaskAssignmentException if a task cannot be placed on any client
+     */
+    public static <C, T extends Comparable<T>> Map<C, ClientState<T>> assign(Map<C, ClientState<T>> states, Set<T> tasks, int numStandbyReplicas) {
+        // The seed depends only on the set of client ids (the sum is order-independent),
+        // presumably so the shuffle is reproducible for the same group of clients.
+        long seed = 0L;
+        for (C client : states.keySet()) {
+            seed += client.hashCode();
+        }
+
+        TaskAssignor<C, T> assignor = new TaskAssignor<>(states, tasks, seed);
+        assignor.assignTasks();
+        if (numStandbyReplicas > 0)
+            assignor.assignStandbyTasks(numStandbyReplicas);
+
+        return assignor.states;
+    }
+
+    private final Random rand;
+    private final Map<C, ClientState<T>> states;
+    // Unordered task pairs that have already been placed together on some client.
+    private final Set<TaskPair<T>> taskPairs;
+    // Number of distinct unordered pairs among all tasks; kept as long to avoid int overflow.
+    private final long maxNumTaskPairs;
+    private final ArrayList<T> tasks;
+
+    private TaskAssignor(Map<C, ClientState<T>> states, Set<T> tasks, long randomSeed) {
+        this.rand = new Random(randomSeed);
+        this.states = new HashMap<>();
+        for (Map.Entry<C, ClientState<T>> entry : states.entrySet()) {
+            this.states.put(entry.getKey(), entry.getValue().copy());
+        }
+        this.tasks = new ArrayList<>(tasks);
+
+        int numTasks = tasks.size();
+        // numTasks * (numTasks - 1) overflows int once numTasks exceeds 46341, which used to make
+        // this bound negative; compute it in long arithmetic instead.
+        this.maxNumTaskPairs = (long) numTasks * (numTasks - 1) / 2;
+        // Do not presize with the quadratic maximum: it would allocate a very large table up front
+        // (or throw for a negative size), while pairs are only recorded between tasks that end up on
+        // the same client.
+        this.taskPairs = new HashSet<>();
+    }
+
+    /** Assigns every task once, as an active task. */
+    public void assignTasks() {
+        assignTasks(true);
+    }
+
+    /**
+     * Assigns standby replicas: one additional non-active round over all tasks per replica, with the
+     * number of rounds capped at the number of clients minus one.
+     */
+    public void assignStandbyTasks(int numStandbyReplicas) {
+        int numReplicas = Math.min(numStandbyReplicas, states.size() - 1);
+        for (int i = 0; i < numReplicas; i++) {
+            assignTasks(false);
+        }
+    }
+
+    /** Places each task (in shuffled order) on one client that does not already hold it. */
+    private void assignTasks(boolean active) {
+        Collections.shuffle(this.tasks, rand);
+
+        for (T task : tasks) {
+            ClientState<T> state = findClientFor(task);
+
+            if (state != null) {
+                state.assign(task, active);
+            } else {
+                TaskAssignmentException ex = new TaskAssignmentException("failed to find an assignable client");
+                log.error(ex.getMessage(), ex);
+                throw ex;
+            }
+        }
+    }
+
+    /**
+     * Chooses a client for the task. While unseen task pairs remain, clients on which the task would
+     * not form any new pair are skipped first; if that leaves no candidate, pairs are ignored.
+     *
+     * @return the chosen client, or {@code null} if every client already holds the task
+     */
+    private ClientState<T> findClientFor(T task) {
+        boolean checkTaskPairs = taskPairs.size() < maxNumTaskPairs;
+
+        ClientState<T> state = findClientByAdditionCost(task, checkTaskPairs);
+
+        if (state == null && checkTaskPairs)
+            state = findClientByAdditionCost(task, false);
+
+        if (state != null)
+            addTaskPairs(task, state);
+
+        return state;
+    }
+
+    /**
+     * Returns the client not yet holding the task with the lowest addition cost, breaking ties by the
+     * lower accumulated cost; {@code null} if no client qualifies.
+     */
+    private ClientState<T> findClientByAdditionCost(T task, boolean checkTaskPairs) {
+        ClientState<T> candidate = null;
+        double candidateAdditionCost = 0d;
+
+        for (ClientState<T> state : states.values()) {
+            if (!state.assignedTasks.contains(task)) {
+                // if checkTaskPairs flag is on, skip this client if this task doesn't introduce a new task combination
+                if (checkTaskPairs && !state.assignedTasks.isEmpty() && !hasNewTaskPair(task, state))
+                    continue;
+
+                double additionCost = computeAdditionCost(task, state);
+                if (candidate == null ||
+                        (additionCost < candidateAdditionCost ||
+                            (additionCost == candidateAdditionCost && state.cost < candidate.cost))) {
+                    candidate = state;
+                    candidateAdditionCost = additionCost;
+                }
+            }
+        }
+
+        return candidate;
+    }
+
+    /** Records the pairs formed by the task and each task already assigned to the client. */
+    private void addTaskPairs(T task, ClientState<T> state) {
+        for (T other : state.assignedTasks) {
+            taskPairs.add(pair(task, other));
+        }
+    }
+
+    /** Returns true if the task forms at least one not-yet-recorded pair with the client's tasks. */
+    private boolean hasNewTaskPair(T task, ClientState<T> state) {
+        for (T other : state.assignedTasks) {
+            if (!taskPairs.contains(pair(task, other)))
+                return true;
+        }
+        return false;
+    }
+
+    /**
+     * Computes the cost of adding the task to the client: its current load level (assigned tasks
+     * divided by capacity, rounded down) plus a per-task cost that is lowest when the client
+     * previously ran the task actively and highest when it has no previous record of it.
+     * NOTE(review): a capacity of 0 makes the load level NaN or Infinity.
+     */
+    private double computeAdditionCost(T task, ClientState<T> state) {
+        double cost = Math.floor((double) state.assignedTasks.size() / state.capacity);
+
+        if (state.prevAssignedTasks.contains(task)) {
+            if (state.prevActiveTasks.contains(task)) {
+                cost += ClientState.COST_ACTIVE;
+            } else {
+                cost += ClientState.COST_STANDBY;
+            }
+        } else {
+            cost += ClientState.COST_LOAD;
+        }
+
+        return cost;
+    }
+
+    /** Builds an order-independent pair by putting the smaller task first. */
+    private TaskPair<T> pair(T task1, T task2) {
+        if (task1.compareTo(task2) < 0) {
+            return new TaskPair<>(task1, task2);
+        } else {
+            return new TaskPair<>(task2, task1);
+        }
+    }
+
+    /** Value object for an ordered task pair; callers normalize the order via {@link #pair}. */
+    private static class TaskPair<T> {
+        public final T task1;
+        public final T task2;
+
+        public TaskPair(T task1, T task2) {
+            this.task1 = task1;
+            this.task2 = task2;
+        }
+
+        @Override
+        public int hashCode() {
+            return task1.hashCode() ^ task2.hashCode();
+        }
+
+        @SuppressWarnings("unchecked")
+        @Override
+        public boolean equals(Object o) {
+            if (o instanceof TaskPair) {
+                TaskPair<T> other = (TaskPair<T>) o;
+                return this.task1.equals(other.task1) && this.task2.equals(other.task2);
+            }
+            return false;
+        }
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java
new file mode 100644
index 0000000..86434fb
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java
@@ -0,0 +1,283 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.MockConsumer;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.clients.consumer.internals.PartitionAssignor;
+import org.apache.kafka.clients.producer.MockProducer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.serialization.ByteArraySerializer;
+import org.apache.kafka.common.utils.SystemTime;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.StreamingConfig;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.TopologyBuilder;
+import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
+import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
+import org.apache.kafka.test.MockProcessorSupplier;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+import java.util.UUID;
+
+import static org.junit.Assert.assertEquals;
+
+public class KafkaStreamingPartitionAssignorTest {
+
+    private final TopicPartition t1p0 = new TopicPartition("topic1", 0);
+    private final TopicPartition t1p1 = new TopicPartition("topic1", 1);
+    private final TopicPartition t1p2 = new TopicPartition("topic1", 2);
+    private final TopicPartition t2p0 = new TopicPartition("topic2", 0);
+    private final TopicPartition t2p1 = new TopicPartition("topic2", 1);
+    private final TopicPartition t2p2 = new TopicPartition("topic2", 2);
+    private final TopicPartition t2p3 = new TopicPartition("topic2", 3);
+
+    // Three partitions each for topic1 and topic2, with Node.noNode() as leader and no replicas.
+    private final List<PartitionInfo> infos = Arrays.asList(
+            partitionInfo("topic1", 0),
+            partitionInfo("topic1", 1),
+            partitionInfo("topic1", 2),
+            partitionInfo("topic2", 0),
+            partitionInfo("topic2", 1),
+            partitionInfo("topic2", 2)
+    );
+
+    private final Cluster metadata = new Cluster(Arrays.asList(Node.noNode()), infos, Collections.<String>emptySet());
+
+    private static PartitionInfo partitionInfo(String topic, int partition) {
+        return new PartitionInfo(topic, partition, Node.noNode(), new Node[0], new Node[0]);
+    }
+
+    /** Builds version-1 subscription user data for a random client UUID with no previous or cached tasks. */
+    private ByteBuffer subscriptionUserData() {
+        UUID clientId = UUID.randomUUID();
+        ByteBuffer buf = ByteBuffer.allocate(4 + 16 + 4 + 4);
+        buf.putInt(1);                                   // version
+        buf.putLong(clientId.getMostSignificantBits());  // client UUID
+        buf.putLong(clientId.getLeastSignificantBits());
+        buf.putInt(0);                                   // number of previously running tasks
+        buf.putInt(0);                                   // number of cached tasks
+        buf.rewind();
+        return buf;
+    }
+
+    private final TaskId task0 = new TaskId(0, 0);
+    private final TaskId task1 = new TaskId(0, 1);
+    private final TaskId task2 = new TaskId(0, 2);
+    private final TaskId task3 = new TaskId(0, 3);
+
+    private Properties configProps() {
+        Properties props = new Properties();
+        props.setProperty(StreamingConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer");
+        props.setProperty(StreamingConfig.KEY_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+        props.setProperty(StreamingConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer");
+        props.setProperty(StreamingConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+        props.setProperty(StreamingConfig.TIMESTAMP_EXTRACTOR_CLASS_CONFIG, "org.apache.kafka.test.MockTimestampExtractor");
+        props.setProperty(StreamingConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171");
+        props.setProperty(StreamingConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG, "3");
+        return props;
+    }
+
+    private static class TestStreamTask extends StreamTask {
+        public boolean committed = false;
+
+        public TestStreamTask(TaskId id,
+                              Consumer<byte[], byte[]> consumer,
+                              Producer<byte[], byte[]> producer,
+                              Consumer<byte[], byte[]> restoreConsumer,
+                              Collection<TopicPartition> partitions,
+                              ProcessorTopology topology,
+                              StreamingConfig config) {
+            super(id, consumer, producer, restoreConsumer, partitions, topology, config, null);
+        }
+
+        @Override
+        public void commit() {
+            super.commit();
+            committed = true;
+        }
+    }
+
+    private final ByteArraySerializer serializer = new ByteArraySerializer();
+
+    /** Topology shared by all tests: two sources (topic1, topic2) feeding one processor. */
+    @SuppressWarnings("unchecked")
+    private TopologyBuilder newTopologyBuilder() {
+        TopologyBuilder builder = new TopologyBuilder();
+        builder.addSource("source1", "topic1");
+        builder.addSource("source2", "topic2");
+        builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
+        return builder;
+    }
+
+    /** Creates an assignor configured with the given stream thread. */
+    private KafkaStreamingPartitionAssignor newPartitionAssignor(StreamThread thread) {
+        KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor();
+        partitionAssignor.configure(
+                Collections.singletonMap(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread)
+        );
+        return partitionAssignor;
+    }
+
+    /**
+     * Asserts that the active tasks encoded for a consumer are, in order, TaskId(0, p) for each
+     * partition p assigned to that consumer.
+     */
+    private static void assertActiveTasksMatchPartitions(Map<String, PartitionAssignor.Assignment> assignments, String consumerId) {
+        PartitionAssignor.Assignment assignment = assignments.get(consumerId);
+        List<TaskId> expected = new ArrayList<>();
+        for (TopicPartition partition : assignment.partitions()) {
+            expected.add(new TaskId(0, partition.partition()));
+        }
+        assertEquals(expected, AssignmentInfo.decode(assignment.userData()).activeTasks);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testSubscription() throws Exception {
+        StreamingConfig config = new StreamingConfig(configProps());
+
+        MockProducer<byte[], byte[]> producer = new MockProducer<>(true, serializer, serializer);
+        MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
+        MockConsumer<byte[], byte[]> mockRestoreConsumer = new MockConsumer<>(OffsetResetStrategy.LATEST);
+
+        TopologyBuilder builder = newTopologyBuilder();
+
+        final Set<TaskId> prevTasks = Utils.mkSet(
+                new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1));
+        final Set<TaskId> cachedTasks = Utils.mkSet(
+                new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1),
+                new TaskId(0, 2), new TaskId(1, 2), new TaskId(2, 2));
+
+        UUID uuid = UUID.randomUUID();
+        StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), new SystemTime()) {
+            @Override
+            public Set<TaskId> prevTasks() {
+                return prevTasks;
+            }
+            @Override
+            public Set<TaskId> cachedTasks() {
+                return cachedTasks;
+            }
+        };
+
+        KafkaStreamingPartitionAssignor partitionAssignor = newPartitionAssignor(thread);
+
+        PartitionAssignor.Subscription subscription = partitionAssignor.subscription(Utils.mkSet("topic1", "topic2"));
+
+        assertEquals(Utils.mkList("topic1", "topic2"), subscription.topics());
+
+        // standby tasks = cached tasks that were not previously running
+        Set<TaskId> standbyTasks = new HashSet<>(cachedTasks);
+        standbyTasks.removeAll(prevTasks);
+
+        SubscriptionInfo expectedInfo = new SubscriptionInfo(uuid, prevTasks, standbyTasks);
+        assertEquals(expectedInfo.encode(), subscription.userData());
+    }
+
+    @Test
+    public void testAssign() throws Exception {
+        StreamingConfig config = new StreamingConfig(configProps());
+
+        MockProducer<byte[], byte[]> producer = new MockProducer<>(true, serializer, serializer);
+        MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
+        MockConsumer<byte[], byte[]> mockRestoreConsumer = new MockConsumer<>(OffsetResetStrategy.LATEST);
+
+        TopologyBuilder builder = newTopologyBuilder();
+
+        final Set<TaskId> prevTasks10 = Utils.mkSet(task0);
+        final Set<TaskId> prevTasks11 = Utils.mkSet(task1);
+        final Set<TaskId> prevTasks20 = Utils.mkSet(task2);
+        final Set<TaskId> standbyTasks10 = Utils.mkSet(task1);
+        final Set<TaskId> standbyTasks11 = Utils.mkSet(task2);
+        final Set<TaskId> standbyTasks20 = Utils.mkSet(task0);
+
+        // consumer10 and consumer11 share client uuid1; consumer20 is the only consumer of uuid2
+        UUID uuid1 = UUID.randomUUID();
+        UUID uuid2 = UUID.randomUUID();
+
+        StreamThread thread10 = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid1, new Metrics(), new SystemTime());
+
+        KafkaStreamingPartitionAssignor partitionAssignor = newPartitionAssignor(thread10);
+
+        Map<String, PartitionAssignor.Subscription> subscriptions = new HashMap<>();
+        subscriptions.put("consumer10",
+                new PartitionAssignor.Subscription(Utils.mkList("topic1", "topic2"), new SubscriptionInfo(uuid1, prevTasks10, standbyTasks10).encode()));
+        subscriptions.put("consumer11",
+                new PartitionAssignor.Subscription(Utils.mkList("topic1", "topic2"), new SubscriptionInfo(uuid1, prevTasks11, standbyTasks11).encode()));
+        subscriptions.put("consumer20",
+                new PartitionAssignor.Subscription(Utils.mkList("topic1", "topic2"), new SubscriptionInfo(uuid2, prevTasks20, standbyTasks20).encode()));
+
+        Map<String, PartitionAssignor.Assignment> assignments = partitionAssignor.assign(metadata, subscriptions);
+
+        // check assigned partitions
+        Set<TopicPartition> partitions10 = new HashSet<>(assignments.get("consumer10").partitions());
+        Set<TopicPartition> partitions11 = new HashSet<>(assignments.get("consumer11").partitions());
+        assertEquals(Utils.mkSet(Utils.mkSet(t1p0, t2p0), Utils.mkSet(t1p1, t2p1)),
+                Utils.mkSet(partitions10, partitions11));
+        assertEquals(Utils.mkSet(t1p2, t2p2), new HashSet<>(assignments.get("consumer20").partitions()));
+
+        // check assignment info
+        assertActiveTasksMatchPartitions(assignments, "consumer10");
+        assertActiveTasksMatchPartitions(assignments, "consumer11");
+        assertActiveTasksMatchPartitions(assignments, "consumer20");
+    }
+
+    @Test
+    public void testOnAssignment() throws Exception {
+        StreamingConfig config = new StreamingConfig(configProps());
+
+        MockProducer<byte[], byte[]> producer = new MockProducer<>(true, serializer, serializer);
+        MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
+        MockConsumer<byte[], byte[]> mockRestoreConsumer = new MockConsumer<>(OffsetResetStrategy.LATEST);
+
+        TopologyBuilder builder = newTopologyBuilder();
+
+        UUID uuid = UUID.randomUUID();
+
+        StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), new SystemTime());
+
+        KafkaStreamingPartitionAssignor partitionAssignor = newPartitionAssignor(thread);
+
+        List<TaskId> activeTaskList = Utils.mkList(task0, task3);
+        Set<TaskId> standbyTasks = Utils.mkSet(task1, task2);
+        AssignmentInfo info = new AssignmentInfo(activeTaskList, standbyTasks);
+        PartitionAssignor.Assignment assignment = new PartitionAssignor.Assignment(Utils.mkList(t1p0, t2p3), info.encode());
+        partitionAssignor.onAssignment(assignment);
+
+        // the i-th assigned partition maps to the i-th active task
+        assertEquals(Utils.mkSet(task0), partitionAssignor.taskIds(t1p0));
+        assertEquals(Utils.mkSet(task3), partitionAssignor.taskIds(t2p3));
+        assertEquals(standbyTasks, partitionAssignor.standbyTasks());
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 909df13..54d0a18 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -38,13 +38,13 @@ import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.SystemTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamingConfig;
-import org.apache.kafka.streams.processor.PartitionGrouper;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.TopologyBuilder;
 import org.apache.kafka.test.MockProcessorSupplier;
 import org.junit.Test;
 
 import java.io.File;
+import java.nio.ByteBuffer;
 import java.nio.file.Files;
 import java.util.Arrays;
 import java.util.Collection;
@@ -55,9 +55,12 @@ import java.util.List;
 import java.util.Map;
 import java.util.Properties;
 import java.util.Set;
+import java.util.UUID;
 
 public class StreamThreadTest {
 
+    private UUID uuid = UUID.randomUUID();
+
     private TopicPartition t1p1 = new TopicPartition("topic1", 1);
     private TopicPartition t1p2 = new TopicPartition("topic1", 2);
     private TopicPartition t2p1 = new TopicPartition("topic2", 1);
@@ -79,7 +82,24 @@ public class StreamThreadTest {
 
     private Cluster metadata = new Cluster(Arrays.asList(Node.noNode()), infos, Collections.<String>emptySet());
 
-    PartitionAssignor.Subscription subscription = new PartitionAssignor.Subscription(Arrays.asList("topic1", "topic2", "topic3"));
+    private final PartitionAssignor.Subscription subscription =
+            new PartitionAssignor.Subscription(Arrays.asList("topic1", "topic2", "topic3"), subscriptionUserData());
+
+    /** Hand-encodes subscription user data: version 1, a new random client UUID, then two zero ints. */
+    private ByteBuffer subscriptionUserData() {
+        UUID clientUUID = UUID.randomUUID(); // independent of the uuid field passed to StreamThread
+        ByteBuffer buf = ByteBuffer.allocate(4 + 16 + 4 + 4); // int + two longs + int + int
+        buf.putInt(1); // version
+        // client UUID
+        buf.putLong(clientUUID.getMostSignificantBits());
+        buf.putLong(clientUUID.getLeastSignificantBits());
+        // previously running tasks (0, presumably an empty-set count)
+        buf.putInt(0);
+        // cached tasks (0, presumably an empty-set count)
+        buf.putInt(0);
+        buf.rewind();
+        return buf;
+    }
 
     // task0 is unused
     private final TaskId task1 = new TaskId(0, 1);
@@ -139,7 +159,7 @@ public class StreamThreadTest {
         builder.addSource("source3", "topic3");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source2", "source3");
 
-        StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", new Metrics(), new SystemTime()) {
+        StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), new SystemTime()) {
             @Override
             protected StreamTask createStreamTask(TaskId id, Collection<TopicPartition> partitionsForTask) {
                 ProcessorTopology topology = builder.build(id.topicGroupId);
@@ -259,7 +279,7 @@ public class StreamThreadTest {
             TopologyBuilder builder = new TopologyBuilder();
             builder.addSource("source1", "topic1");
 
-            StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", new Metrics(), mockTime) {
+            StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), mockTime) {
                 @Override
                 public void maybeClean() {
                     super.maybeClean();
@@ -381,7 +401,7 @@ public class StreamThreadTest {
             TopologyBuilder builder = new TopologyBuilder();
             builder.addSource("source1", "topic1");
 
-            StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", new Metrics(), mockTime) {
+            StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), mockTime) {
                 @Override
                 public void maybeCommit() {
                     super.maybeCommit();
@@ -448,12 +468,11 @@ public class StreamThreadTest {
     }
 
     private void initPartitionGrouper(StreamThread thread) {
-        PartitionGrouper partitionGrouper = thread.partitionGrouper();
 
         KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor();
 
         partitionAssignor.configure(
-                Collections.singletonMap(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper)
+                Collections.singletonMap(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread)
         );
 
         Map<String, PartitionAssignor.Assignment> assignments =

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java
new file mode 100644
index 0000000..58e0af9
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Round-trip tests for {@code AssignmentInfo.encode()} and {@code AssignmentInfo.decode(...)}.
+ * NOTE(review): the class name misspells "Assignment"; fixing it also requires renaming the file.
+ */
+public class AssginmentInfoTest {
+
+    @Test
+    public void testEncodeDecode() {
+        // TaskId(0, 0) repeats on purpose, presumably because active tasks are listed per assigned partition.
+        List<TaskId> activeTasks =
+                Arrays.asList(new TaskId(0, 0), new TaskId(0, 0), new TaskId(0, 1), new TaskId(1, 0));
+        Set<TaskId> standbyTasks =
+                new HashSet<>(Arrays.asList(new TaskId(1, 1), new TaskId(2, 0)));
+
+        assertRoundTrip(new AssignmentInfo(activeTasks, standbyTasks));
+    }
+
+    @Test
+    public void testEncodeDecodeEmptyTasks() {
+        assertRoundTrip(new AssignmentInfo(Collections.<TaskId>emptyList(), Collections.<TaskId>emptySet()));
+    }
+
+    /** Asserts that {@code AssignmentInfo.decode(info.encode())} equals {@code info}. */
+    private static void assertRoundTrip(AssignmentInfo info) {
+        assertEquals(info, AssignmentInfo.decode(info.encode()));
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java
new file mode 100644
index 0000000..acc9a9d
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.UUID;
+
+import static org.junit.Assert.assertEquals;
+
+/** Round-trip tests for {@code SubscriptionInfo.encode()} and {@code SubscriptionInfo.decode(...)}. */
+public class SubscriptionInfoTest {
+
+    @Test
+    public void testEncodeDecode() {
+        UUID clientUUID = UUID.randomUUID();
+        Set<TaskId> activeTasks =
+                new HashSet<>(Arrays.asList(new TaskId(0, 0), new TaskId(0, 1), new TaskId(1, 0)));
+        Set<TaskId> standbyTasks =
+                new HashSet<>(Arrays.asList(new TaskId(1, 1), new TaskId(2, 0)));
+
+        assertRoundTrip(new SubscriptionInfo(clientUUID, activeTasks, standbyTasks));
+    }
+
+    @Test
+    public void testEncodeDecodeEmptyTaskSets() {
+        Set<TaskId> none = Collections.emptySet();
+        assertRoundTrip(new SubscriptionInfo(UUID.randomUUID(), none, none));
+    }
+
+    /** Asserts that {@code SubscriptionInfo.decode(info.encode())} equals {@code info}. */
+    private static void assertRoundTrip(SubscriptionInfo info) {
+        assertEquals(info, SubscriptionInfo.decode(info.encode()));
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java
new file mode 100644
index 0000000..28364ab
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java
@@ -0,0 +1,164 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.apache.kafka.common.utils.Utils.mkList;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class TaskAssignorTest {
+
+    @Test
+    public void testAssignWithoutStandby() {
+        Map<Integer, ClientState<Integer>> states = newClientStates(6);
+
+        // # of clients and # of tasks are equal: exactly one task per client.
+        assertBalancedAssignment(states, mkSet(0, 1, 2, 3, 4, 5), 0, 1, 1, 1, 1);
+
+        // # of clients < # of tasks
+        assertBalancedAssignment(states, mkSet(0, 1, 2, 3, 4, 5, 6, 7), 0, 1, 2, 1, 2);
+
+        // # of clients > # of tasks
+        assertBalancedAssignment(states, mkSet(0, 1, 2, 3), 0, 0, 1, 0, 1);
+    }
+
+    @Test
+    public void testAssignWithStandby() {
+        Map<Integer, ClientState<Integer>> states = newClientStates(6);
+
+        // # of clients and # of tasks are equal, 1 standby replica.
+        assertBalancedAssignment(states, mkSet(0, 1, 2, 3, 4, 5), 1, 1, 1, 2, 2);
+
+        // # of clients < # of tasks, 1 standby replica.
+        assertBalancedAssignment(states, mkSet(0, 1, 2, 3, 4, 5, 6, 7), 1, 1, 2, 2, 3);
+
+        // # of clients > # of tasks, 1 standby replica.
+        assertBalancedAssignment(states, mkSet(0, 1, 2, 3), 1, 0, 1, 1, 2);
+
+        // # of clients >> # of tasks, with 1, 2 and 3 standby replicas.
+        Set<Integer> tasks = mkSet(0, 1);
+        assertBalancedAssignment(states, tasks, 1, 0, 1, 0, 1);
+        assertBalancedAssignment(states, tasks, 2, 0, 1, 1, 1);
+        assertBalancedAssignment(states, tasks, 3, 0, 1, 1, 2);
+    }
+
+    @Test
+    public void testStickiness() {
+        Map<Integer, ClientState<Integer>> states;
+        List<Set<Integer>> prevTasks;
+
+        // # of clients and # of tasks are equal: every client keeps exactly its previous task.
+        prevTasks = mkList(mkSet(0), mkSet(1), mkSet(2), mkSet(3), mkSet(4), mkSet(5));
+        Collections.shuffle(prevTasks);
+        states = statesWithPreviousTasks(prevTasks);
+        assertUnchanged(states, TaskAssignor.assign(states, mkSet(0, 1, 2, 3, 4, 5), 0));
+
+        // # of clients > # of tasks: the two clients without a previous task stay without tasks.
+        Set<Integer> none = Collections.emptySet();
+        prevTasks = mkList(mkSet(0), mkSet(1), mkSet(2), mkSet(3), none, none);
+        Collections.shuffle(prevTasks);
+        states = statesWithPreviousTasks(prevTasks);
+        assertUnchanged(states, TaskAssignor.assign(states, mkSet(0, 1, 2, 3), 0));
+
+        // # of clients < # of tasks: every client keeps at least one of its previous tasks.
+        prevTasks = mkList(mkSet(0, 1), mkSet(2, 3), mkSet(4, 5), mkSet(6, 7), mkSet(8, 9), mkSet(10, 11));
+        Collections.shuffle(prevTasks);
+        states = statesWithPreviousTasks(prevTasks);
+        Map<Integer, ClientState<Integer>> assignments =
+                TaskAssignor.assign(states, mkSet(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), 0);
+        for (int client : states.keySet()) {
+            ClientState<Integer> prev = states.get(client);
+            ClientState<Integer> next = assignments.get(client);
+            assertFalse("client " + client + " lost all previous active tasks",
+                    Collections.disjoint(prev.prevActiveTasks, next.activeTasks));
+            assertFalse("client " + client + " lost all previous assigned tasks",
+                    Collections.disjoint(prev.prevAssignedTasks, next.assignedTasks));
+        }
+    }
+
+    /** Returns {@code numClients} fresh clients keyed 0 .. numClients - 1, each {@code new ClientState<>(1d)}. */
+    private static Map<Integer, ClientState<Integer>> newClientStates(int numClients) {
+        Map<Integer, ClientState<Integer>> states = new HashMap<>();
+        for (int i = 0; i < numClients; i++) {
+            states.put(i, new ClientState<Integer>(1d));
+        }
+        return states;
+    }
+
+    /**
+     * Returns one client per element of {@code prevTasks}, keyed 0, 1, ... in list order, whose previously
+     * active tasks and previously assigned tasks are both set to that element.
+     */
+    private static Map<Integer, ClientState<Integer>> statesWithPreviousTasks(List<Set<Integer>> prevTasks) {
+        Map<Integer, ClientState<Integer>> states = new HashMap<>();
+        int client = 0;
+        for (Set<Integer> taskSet : prevTasks) {
+            ClientState<Integer> state = new ClientState<>(1d);
+            state.prevActiveTasks.addAll(taskSet);
+            state.prevAssignedTasks.addAll(taskSet);
+            states.put(client++, state);
+        }
+        return states;
+    }
+
+    /** Asserts that each client's new active/assigned tasks equal its previous active/assigned tasks. */
+    private static void assertUnchanged(Map<Integer, ClientState<Integer>> states,
+                                        Map<Integer, ClientState<Integer>> assignments) {
+        for (int client : states.keySet()) {
+            assertEquals(states.get(client).prevActiveTasks, assignments.get(client).activeTasks);
+            assertEquals(states.get(client).prevAssignedTasks, assignments.get(client).assignedTasks);
+        }
+    }
+
+    /**
+     * Assigns {@code tasks} to {@code states} with {@code numStandbyReplicas} standby replicas and checks that
+     * every client receives [minActive, maxActive] active tasks and [minAssigned, maxAssigned] assigned tasks,
+     * that {@code tasks.size()} active tasks are handed out in total, and that
+     * {@code tasks.size() * (numStandbyReplicas + 1)} tasks are assigned in total (active plus standby).
+     */
+    private static void assertBalancedAssignment(Map<Integer, ClientState<Integer>> states,
+                                                 Set<Integer> tasks,
+                                                 int numStandbyReplicas,
+                                                 int minActive, int maxActive,
+                                                 int minAssigned, int maxAssigned) {
+        Map<Integer, ClientState<Integer>> assignments = TaskAssignor.assign(states, tasks, numStandbyReplicas);
+        int numActiveTasks = 0;
+        int numAssignedTasks = 0;
+        for (ClientState<Integer> assignment : assignments.values()) {
+            int active = assignment.activeTasks.size();
+            int assigned = assignment.assignedTasks.size();
+            numActiveTasks += active;
+            numAssignedTasks += assigned;
+            assertTrue("active tasks per client: " + active, minActive <= active && active <= maxActive);
+            assertTrue("assigned tasks per client: " + assigned, minAssigned <= assigned && assigned <= maxAssigned);
+        }
+        assertEquals(tasks.size(), numActiveTasks);
+        assertEquals(tasks.size() * (numStandbyReplicas + 1), numAssignedTasks);
+    }
+
+}