Posted to commits@kafka.apache.org by gu...@apache.org on 2015/10/26 21:28:40 UTC

[1/2] kafka git commit: KAFKA-2652: integrate new group protocol into partition grouping

Repository: kafka
Updated Branches:
  refs/heads/trunk 939c4244e -> 71399ffe4


http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
index 8fdbfff..4dfa9c2 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
@@ -47,14 +47,14 @@ import java.util.Set;
  * <p>
  * <h2>Basic usage</h2>
  * This component can be used to help test a {@link KeyValueStore}'s ability to read and write entries.
- * 
+ *
  * <pre>
  * // Create the test driver ...
  * KeyValueStoreTestDriver&lt;Integer, String> driver = KeyValueStoreTestDriver.create();
  * KeyValueStore&lt;Integer, String> store = Stores.create("my-store", driver.context())
  *                                              .withIntegerKeys().withStringKeys()
  *                                              .inMemory().build();
- * 
+ *
  * // Verify that the store reads and writes correctly ...
  * store.put(0, "zero");
  * store.put(1, "one");
@@ -69,7 +69,7 @@ import java.util.Set;
  * assertEquals("five", store.get(5));
  * assertNull(store.get(3));
  * store.delete(5);
- * 
+ *
  * // Flush the store and verify all current entries were properly flushed ...
  * store.flush();
  * assertEquals("zero", driver.flushedEntryStored(0));
@@ -77,14 +77,14 @@ import java.util.Set;
  * assertEquals("two", driver.flushedEntryStored(2));
  * assertEquals("four", driver.flushedEntryStored(4));
  * assertEquals(null, driver.flushedEntryStored(5));
- * 
+ *
  * assertEquals(false, driver.flushedEntryRemoved(0));
  * assertEquals(false, driver.flushedEntryRemoved(1));
  * assertEquals(false, driver.flushedEntryRemoved(2));
  * assertEquals(false, driver.flushedEntryRemoved(4));
  * assertEquals(true, driver.flushedEntryRemoved(5));
  * </pre>
- * 
+ *
  * <p>
  * <h2>Restoring a store</h2>
  * This component can be used to test whether a {@link KeyValueStore} implementation properly
@@ -94,30 +94,30 @@ import java.util.Set;
  * To do this, create an instance of this driver component, {@link #addEntryToRestoreLog(Object, Object) add entries} that will be
  * passed to the store upon creation (simulating the entries that were previously flushed to the topic), and then create the store
  * using this driver's {@link #context() ProcessorContext}:
- * 
+ *
  * <pre>
  * // Create the test driver ...
  * KeyValueStoreTestDriver&lt;Integer, String> driver = KeyValueStoreTestDriver.create(Integer.class, String.class);
- * 
+ *
  * // Add any entries that will be restored to any store that uses the driver's context ...
  * driver.addRestoreEntry(0, "zero");
  * driver.addRestoreEntry(1, "one");
  * driver.addRestoreEntry(2, "two");
  * driver.addRestoreEntry(4, "four");
- * 
+ *
  * // Create the store, which should register with the context and automatically
  * // receive the restore entries ...
  * KeyValueStore&lt;Integer, String> store = Stores.create("my-store", driver.context())
  *                                              .withIntegerKeys().withStringKeys()
  *                                              .inMemory().build();
- * 
+ *
  * // Verify that the store's contents were properly restored ...
  * assertEquals(0, driver.checkForRestoredEntries(store));
- * 
+ *
  * // and there are no other entries ...
  * assertEquals(4, driver.sizeOf(store));
  * </pre>
- * 
+ *
  * @param <K> the type of keys placed in the store
  * @param <V> the type of values placed in the store
  */
@@ -163,7 +163,7 @@ public class KeyValueStoreTestDriver<K, V> {
      * value serializers and deserializers. This can be used when the actual serializers and deserializers are supplied to the
      * store during creation, which should eliminate the need for a store to depend on the ProcessorContext's default key and
      * value serializers and deserializers.
-     * 
+     *
      * @return the test driver; never null
      */
     public static <K, V> KeyValueStoreTestDriver<K, V> create() {
@@ -181,7 +181,7 @@ public class KeyValueStoreTestDriver<K, V> {
      * deserializers for the given built-in key and value types (e.g., {@code String.class}, {@code Integer.class},
      * {@code Long.class}, and {@code byte[].class}). This can be used when store is created to rely upon the
      * ProcessorContext's default key and value serializers and deserializers.
-     * 
+     *
      * @param keyClass the class for the keys; must be one of {@code String.class}, {@code Integer.class},
      *            {@code Long.class}, or {@code byte[].class}
      * @param valueClass the class for the values; must be one of {@code String.class}, {@code Integer.class},
@@ -198,7 +198,7 @@ public class KeyValueStoreTestDriver<K, V> {
      * {@link ProcessorContext#forward(Object, Object) forwarded} by the store and that provides the specified serializers and
      * deserializers. This can be used when store is created to rely upon the ProcessorContext's default key and value serializers
      * and deserializers.
-     * 
+     *
      * @param keySerializer the key serializer for the {@link ProcessorContext}; may not be null
      * @param keyDeserializer the key deserializer for the {@link ProcessorContext}; may not be null
      * @param valueSerializer the value serializer for the {@link ProcessorContext}; may not be null
@@ -283,7 +283,7 @@ public class KeyValueStoreTestDriver<K, V> {
 
     /**
      * Set the directory that should be used by the store for local disk storage.
-     * 
+     *
      * @param dir the directory; may be null if no local storage is allowed
      */
     public void useStateDir(File dir) {
@@ -320,25 +320,25 @@ public class KeyValueStoreTestDriver<K, V> {
      * <p>
      * To create such a test, create the test driver, call this method one or more times, and then create the
      * {@link KeyValueStore}. Your tests can then check whether the store contains the entries from the log.
-     * 
+     *
      * <pre>
      * // Set up the driver and pre-populate the log ...
      * KeyValueStoreTestDriver&lt;Integer, String> driver = KeyValueStoreTestDriver.create();
      * driver.addRestoreEntry(1,"value1");
      * driver.addRestoreEntry(2,"value2");
      * driver.addRestoreEntry(3,"value3");
-     * 
+     *
      * // Create the store using the driver's context ...
      * ProcessorContext context = driver.context();
      * KeyValueStore&lt;Integer, String> store = ...
-     * 
+     *
      * // Verify that the store's contents were properly restored from the log ...
      * assertEquals(0, driver.checkForRestoredEntries(store));
-     * 
+     *
      * // and there are no other entries ...
      * assertEquals(3, driver.sizeOf(store));
      * </pre>
-     * 
+     *
      * @param key the key for the entry
      * @param value the value for the entry
      * @see #checkForRestoredEntries(KeyValueStore)
@@ -354,7 +354,7 @@ public class KeyValueStoreTestDriver<K, V> {
      * <p>
      * If the {@link KeyValueStore}'s are to be restored upon its startup, be sure to {@link #addEntryToRestoreLog(Object, Object)
      * add the restore entries} before creating the store with the {@link ProcessorContext} returned by this method.
-     * 
+     *
      * @return the processing context; never null
      * @see #addEntryToRestoreLog(Object, Object)
      */
@@ -365,7 +365,7 @@ public class KeyValueStoreTestDriver<K, V> {
     /**
      * Get the entries that are restored to a KeyValueStore when it is constructed with this driver's {@link #context()
      * ProcessorContext}.
-     * 
+     *
      * @return the restore entries; never null but possibly a null iterator
      */
     public Iterable<Entry<K, V>> restoredEntries() {
@@ -375,7 +375,7 @@ public class KeyValueStoreTestDriver<K, V> {
     /**
      * Utility method that will count the number of {@link #addEntryToRestoreLog(Object, Object) restore entries} missing from the
      * supplied store.
-     * 
+     *
      * @param store the store that is to have all of the {@link #restoredEntries() restore entries}
      * @return the number of restore entries missing from the store, or 0 if all restore entries were found
      * @see #addEntryToRestoreLog(Object, Object)
@@ -395,7 +395,7 @@ public class KeyValueStoreTestDriver<K, V> {
 
     /**
      * Utility method to compute the number of entries within the store.
-     * 
+     *
      * @param store the key value store using this {@link #context()}.
      * @return the number of entries
      */
@@ -410,7 +410,7 @@ public class KeyValueStoreTestDriver<K, V> {
 
     /**
      * Retrieve the value that the store {@link KeyValueStore#flush() flushed} with the given key.
-     * 
+     *
      * @param key the key
      * @return the value that was flushed with the key, or {@code null} if no such key was flushed or if the entry with this
      *         key was {@link #flushedEntryStored(Object) removed} upon flush
@@ -421,7 +421,7 @@ public class KeyValueStoreTestDriver<K, V> {
 
     /**
      * Determine whether the store {@link KeyValueStore#flush() flushed} the removal of the given key.
-     * 
+     *
      * @param key the key
      * @return {@code true} if the entry with the given key was removed when flushed, or {@code false} if the entry was not
      *         removed when last flushed
@@ -438,4 +438,4 @@ public class KeyValueStoreTestDriver<K, V> {
         flushedEntries.clear();
         flushedRemovals.clear();
     }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
index 761f5ce..16df9c5 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
@@ -83,12 +83,7 @@ public class MockProcessorContext implements ProcessorContext, RecordCollector.S
 
     @Override
     public int id() {
-        return -1;
-    }
-
-    @Override
-    public boolean joinable() {
-        return true;
+        return 0;
     }
 
     @Override
@@ -174,4 +169,4 @@ public class MockProcessorContext implements ProcessorContext, RecordCollector.S
         return this.timestamp;
     }
 
-}
\ No newline at end of file
+}


[2/2] kafka git commit: KAFKA-2652: integrate new group protocol into partition grouping

Posted by gu...@apache.org.
KAFKA-2652: integrate new group protocol into partition grouping

guozhangwang

* added ```PartitionGrouper``` (abstract class)
 * This class is responsible for grouping partitions. Each group forms a task.
 * Users may subclass it for custom grouping.
* added ```DefaultPartitionGrouper```
 * our default implementation of ```PartitionGrouper```
* added ```KafkaStreamingPartitionAssignor```
 * We always use this as the ```PartitionAssignor``` for stream consumers.
 * The actual grouping is delegated to the ```PartitionGrouper```.
* ```TopologyBuilder```
 * added ```topicGroups()```
   * This returns groups of related topics according to the topology.
 * added ```copartitionSources(sourceNodes...)```
   * This is used by the DSL layer. It declares that the specified source nodes must be copartitioned.
 * added ```copartitionGroups()```
   * This returns groups of copartitioned topics.
* KStream layer
 * keeps track of source nodes to determine copartition sources when streams are joined
 * source nodes are set to null when the partitioning property is not preserved (e.g. ```map()```, ```transform()```); this marks the stream as no longer joinable (see the wiring sketch after this list)
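
Putting the pieces together: ```TopologyBuilder``` derives topic groups from the node graph, a ```PartitionGrouper``` turns each group/partition pair into a task, and ```KafkaStreamingPartitionAssignor``` spreads the tasks over the consumers. Below is a minimal, hedged wiring sketch; the ```Cluster``` metadata is fabricated here purely for illustration (in a running application it comes from the consumer, and the grouper is wired up by the stream thread).

```java
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.processor.DefaultPartitionGrouper;
import org.apache.kafka.streams.processor.PartitionGrouper;
import org.apache.kafka.streams.processor.TopologyBuilder;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class GroupingFlowSketch {
    public static void main(String[] args) {
        // Two unconnected sources yield two topic groups: {topic1} and {topic2}.
        TopologyBuilder builder = new TopologyBuilder();
        builder.addSource("source1", "topic1");
        builder.addSource("source2", "topic2");

        // Hand the topology's topic groups to the grouper before assignment runs.
        PartitionGrouper grouper = new DefaultPartitionGrouper();
        grouper.topicGroups(builder.topicGroups());

        // Fabricated metadata: topic1 has two partitions, topic2 has one.
        Node node = new Node(0, "localhost", 9092);
        List<PartitionInfo> infos = Arrays.asList(
                new PartitionInfo("topic1", 0, node, new Node[0], new Node[0]),
                new PartitionInfo("topic1", 1, node, new Node[0], new Node[0]),
                new PartitionInfo("topic2", 0, node, new Node[0], new Node[0]));
        Cluster metadata = new Cluster(Arrays.asList(node), infos);

        // Each group/partition pair becomes a task:
        // 0 -> [topic1-0], 1 -> [topic1-1], 2 -> [topic2-0]
        Map<Integer, List<TopicPartition>> tasks = grouper.partitionGroups(metadata);
        System.out.println(tasks);
    }
}
```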

Author: Yasuhiro Matsuda <ya...@confluent.io>

Reviewers: Guozhang Wang

Closes #353 from ymatsuda/grouping


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/71399ffe
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/71399ffe
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/71399ffe

Branch: refs/heads/trunk
Commit: 71399ffe4c52e2539a5794a17852c8c5b3d5fe72
Parents: 939c424
Author: Yasuhiro Matsuda <ya...@confluent.io>
Authored: Mon Oct 26 13:33:51 2015 -0700
Committer: Guozhang Wang <wa...@gmail.com>
Committed: Mon Oct 26 13:33:51 2015 -0700

----------------------------------------------------------------------
 .../org/apache/kafka/common/utils/Utils.java    |  35 ++++-
 .../apache/kafka/streams/StreamingConfig.java   |  40 ++++--
 .../kafka/streams/kstream/KStreamBuilder.java   |  10 +-
 .../streams/kstream/SlidingWindowSupplier.java  |   4 +-
 .../streams/kstream/internals/KStreamImpl.java  |  28 ++--
 .../streams/kstream/internals/KStreamJoin.java  |   4 -
 .../kstream/internals/KStreamWindowedImpl.java  |  19 ++-
 .../processor/DefaultPartitionGrouper.java      |  97 ++++++++++++++
 .../streams/processor/PartitionGrouper.java     |  55 ++++++++
 .../streams/processor/ProcessorContext.java     |   9 +-
 .../streams/processor/TopologyBuilder.java      |  85 +++++++++++-
 .../KafkaStreamingPartitionAssignor.java        | 133 +++++++++++++++++++
 .../internals/ProcessorContextImpl.java         |  35 -----
 .../internals/ProcessorStateManager.java        |  14 +-
 .../streams/processor/internals/QuickUnion.java |  67 ++++++++++
 .../streams/processor/internals/StreamTask.java |   2 +-
 .../processor/internals/StreamThread.java       |  84 +++++++++---
 .../streams/state/MeteredKeyValueStore.java     |   4 +-
 .../streams/state/RocksDBKeyValueStore.java     |   6 +-
 .../kstream/internals/KStreamJoinTest.java      |  33 ++++-
 .../processor/DefaultPartitionGrouperTest.java  |  76 +++++++++++
 .../streams/processor/TopologyBuilderTest.java  |  41 +++++-
 .../processor/internals/QuickUnionTest.java     |  97 ++++++++++++++
 .../processor/internals/StreamThreadTest.java   | 110 ++++++++++++---
 .../streams/state/KeyValueStoreTestDriver.java  |  54 ++++----
 .../apache/kafka/test/MockProcessorContext.java |   9 +-
 26 files changed, 980 insertions(+), 171 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
index bc0e645..974cf1e 100755
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -32,6 +32,7 @@ import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 import java.util.Properties;
@@ -574,10 +575,40 @@ public class Utils {
      * @param <T> the type of element
      * @return Set
      */
-    public static <T> HashSet<T> mkSet(T... elems) {
+    public static <T> Set<T> mkSet(T... elems) {
         return new HashSet<>(Arrays.asList(elems));
     }
-    
+
+    /**
+     * Creates a list.
+     * @param elems the elements
+     * @param <T> the type of element
+     * @return List
+     */
+    public static <T> List<T> mkList(T... elems) {
+        return Arrays.asList(elems);
+    }
+
+
+    /**
+     * Creates a string from a collection, with elements separated by the given separator.
+     * @param coll the collection
+     * @param separator the separator
+     * @return the joined character sequence
+     */
+    public static <T> CharSequence mkString(Collection<T> coll, String separator) {
+        StringBuilder sb = new StringBuilder();
+        Iterator<T> iter = coll.iterator();
+        if (iter.hasNext()) {
+            sb.append(iter.next().toString());
+
+            while (iter.hasNext()) {
+                sb.append(separator);
+                sb.append(iter.next().toString());
+            }
+        }
+        return sb;
+    }
+
     /**
      * Recursively delete the given file/directory and any subfiles (if any exist)
      *
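
For reference, a quick usage sketch of the three helpers added above; note that ```mkList``` returns the fixed-size list produced by ```Arrays.asList```, so it cannot grow:

```java
import org.apache.kafka.common.utils.Utils;

import java.util.List;
import java.util.Set;

public class UtilsUsageSketch {
    public static void main(String[] args) {
        Set<String> topics = Utils.mkSet("topic1", "topic2"); // a mutable HashSet
        List<Integer> ids = Utils.mkList(1, 2, 3);            // fixed-size, backed by the varargs array
        CharSequence joined = Utils.mkString(ids, ", ");      // "1, 2, 3"
        System.out.println(topics + " | " + ids + " | " + joined);
    }
}
```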

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java b/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
index 93df4c2..a0aef48 100644
--- a/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
+++ b/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
@@ -24,6 +24,9 @@ import org.apache.kafka.common.config.AbstractConfig;
 import org.apache.kafka.common.config.ConfigDef;
 import org.apache.kafka.common.config.ConfigDef.Importance;
 import org.apache.kafka.common.config.ConfigDef.Type;
+import org.apache.kafka.streams.processor.DefaultPartitionGrouper;
+import org.apache.kafka.streams.processor.PartitionGrouper;
+import org.apache.kafka.streams.processor.internals.KafkaStreamingPartitionAssignor;
 
 import java.util.Map;
 
@@ -70,6 +73,10 @@ public class StreamingConfig extends AbstractConfig {
     public static final String TIMESTAMP_EXTRACTOR_CLASS_CONFIG = "timestamp.extractor";
     private static final String TIMESTAMP_EXTRACTOR_CLASS_DOC = "Timestamp extractor class that implements the <code>TimestampExtractor</code> interface.";
 
+    /** <code>partition.grouper</code> */
+    public static final String PARTITION_GROUPER_CLASS_CONFIG = "partition.grouper";
+    private static final String PARTITION_GROUPER_CLASS_DOC = "Partition grouper class that implements the <code>PartitionGrouper</code> interface.";
+
     /** <code>client.id</code> */
     public static final String CLIENT_ID_CONFIG = CommonClientConfigs.CLIENT_ID_CONFIG;
 
@@ -108,15 +115,15 @@ public class StreamingConfig extends AbstractConfig {
                                         Importance.MEDIUM,
                                         CommonClientConfigs.CLIENT_ID_DOC)
                                 .define(STATE_DIR_CONFIG,
-                                    Type.STRING,
-                                    SYSTEM_TEMP_DIRECTORY,
-                                    Importance.MEDIUM,
-                                    STATE_DIR_DOC)
+                                        Type.STRING,
+                                        SYSTEM_TEMP_DIRECTORY,
+                                        Importance.MEDIUM,
+                                        STATE_DIR_DOC)
                                 .define(COMMIT_INTERVAL_MS_CONFIG,
-                                    Type.LONG,
-                                    30000,
-                                    Importance.HIGH,
-                                    COMMIT_INTERVAL_MS_DOC)
+                                        Type.LONG,
+                                        30000,
+                                        Importance.HIGH,
+                                        COMMIT_INTERVAL_MS_DOC)
                                 .define(POLL_MS_CONFIG,
                                         Type.LONG,
                                         100,
@@ -167,6 +174,11 @@ public class StreamingConfig extends AbstractConfig {
                                         Type.CLASS,
                                         Importance.HIGH,
                                         TIMESTAMP_EXTRACTOR_CLASS_DOC)
+                                .define(PARTITION_GROUPER_CLASS_CONFIG,
+                                        Type.CLASS,
+                                        DefaultPartitionGrouper.class,
+                                        Importance.HIGH,
+                                        PARTITION_GROUPER_CLASS_DOC)
                                 .define(BOOTSTRAP_SERVERS_CONFIG,
                                         Type.STRING,
                                         Importance.HIGH,
@@ -190,16 +202,26 @@ public class StreamingConfig extends AbstractConfig {
                                         CommonClientConfigs.METRICS_NUM_SAMPLES_DOC);
     }
 
+    public static class InternalConfig {
+        public static final String PARTITION_GROUPER_INSTANCE = "__partition.grouper.instance__";
+    }
+
     public StreamingConfig(Map<?, ?> props) {
         super(CONFIG, props);
     }
 
+    public Map<String, Object> getConsumerConfigs(PartitionGrouper partitionGrouper) {
+        Map<String, Object> props = getConsumerConfigs();
+        props.put(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper);
+        props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, KafkaStreamingPartitionAssignor.class.getName());
+        return props;
+    }
+
     public Map<String, Object> getConsumerConfigs() {
         Map<String, Object> props = this.originals();
 
         // set consumer default property values
         props.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false");
-        props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, "range");
 
         // remove properties that are not required for consumers
         props.remove(StreamingConfig.KEY_SERIALIZER_CLASS_CONFIG);
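
The handoff above works in two steps: ```getConsumerConfigs(partitionGrouper)``` stashes the live grouper instance under the internal key and selects the streaming assignor by class name; the consumer later instantiates the assignor and calls ```configure()```, which pulls the same instance back out. A self-contained sketch of just that contract (the base consumer settings are omitted):

```java
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.streams.StreamingConfig;
import org.apache.kafka.streams.processor.DefaultPartitionGrouper;
import org.apache.kafka.streams.processor.PartitionGrouper;
import org.apache.kafka.streams.processor.internals.KafkaStreamingPartitionAssignor;

import java.util.HashMap;
import java.util.Map;

public class ConsumerConfigSketch {
    public static void main(String[] args) {
        PartitionGrouper grouper = new DefaultPartitionGrouper();

        // What getConsumerConfigs(partitionGrouper) adds on top of the base consumer configs:
        Map<String, Object> props = new HashMap<>();
        props.put(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, grouper);
        props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG,
                  KafkaStreamingPartitionAssignor.class.getName());

        // The consumer calls configure() on the assignor it instantiated by class name;
        // configure() retrieves the very same grouper instance from the map.
        KafkaStreamingPartitionAssignor assignor = new KafkaStreamingPartitionAssignor();
        assignor.configure(props);
        System.out.println("assignor name: " + assignor.name()); // "streaming"
    }
}
```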

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/kstream/KStreamBuilder.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/KStreamBuilder.java b/streams/src/main/java/org/apache/kafka/streams/kstream/KStreamBuilder.java
index 2d4dcc7..5b3feb6 100644
--- a/streams/src/main/java/org/apache/kafka/streams/kstream/KStreamBuilder.java
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/KStreamBuilder.java
@@ -21,6 +21,8 @@ import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.streams.kstream.internals.KStreamImpl;
 import org.apache.kafka.streams.processor.TopologyBuilder;
 
+import java.util.Collections;
+
 /**
  * KStreamBuilder is the class to create KStream instances.
  */
@@ -31,7 +33,7 @@ public class KStreamBuilder extends TopologyBuilder {
     }
 
     /**
-     * Creates a KStream instance for the specified topic. The stream is added to the default synchronization group.
+     * Creates a KStream instance for the specified topic.
      * The default deserializers specified in the config are used.
      *
      * @param topics          the topic names, if empty default to all the topics in the config
@@ -42,11 +44,11 @@ public class KStreamBuilder extends TopologyBuilder {
 
         addSource(name, topics);
 
-        return new KStreamImpl<>(this, name);
+        return new KStreamImpl<>(this, name, Collections.singleton(name));
     }
 
     /**
-     * Creates a KStream instance for the specified topic. The stream is added to the default synchronization group.
+     * Creates a KStream instance for the specified topic.
      *
      * @param keyDeserializer key deserializer used to read this source KStream,
      *                        if not specified the default deserializer defined in the configs will be used
@@ -60,6 +62,6 @@ public class KStreamBuilder extends TopologyBuilder {
 
         addSource(name, keyDeserializer, valDeserializer, topics);
 
-        return new KStreamImpl<>(this, name);
+        return new KStreamImpl<>(this, name, Collections.singleton(name));
     }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/kstream/SlidingWindowSupplier.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/SlidingWindowSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/SlidingWindowSupplier.java
index bf6b4dc..1d53123 100644
--- a/streams/src/main/java/org/apache/kafka/streams/kstream/SlidingWindowSupplier.java
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/SlidingWindowSupplier.java
@@ -75,6 +75,7 @@ public class SlidingWindowSupplier<K, V> implements WindowSupplier<K, V> {
     public class SlidingWindow extends WindowSupport implements Window<K, V> {
         private final Object lock = new Object();
         private ProcessorContext context;
+        private int partition;
         private int slotNum; // used as a key for Kafka log compaction
         private LinkedList<K> list = new LinkedList<K>();
         private HashMap<K, ValueList<V>> map = new HashMap<>();
@@ -82,6 +83,7 @@ public class SlidingWindowSupplier<K, V> implements WindowSupplier<K, V> {
         @Override
         public void init(ProcessorContext context) {
             this.context = context;
+            this.partition = context.id();
             SlidingWindowRegistryCallback restoreFunc = new SlidingWindowRegistryCallback();
             context.register(this, restoreFunc);
 
@@ -210,7 +212,7 @@ public class SlidingWindowSupplier<K, V> implements WindowSupplier<K, V> {
                         if (offset != combined.length)
                             throw new IllegalStateException("serialized length does not match");
 
-                        collector.send(new ProducerRecord<>(name, context.id(), slot, combined), byteArraySerializer, byteArraySerializer);
+                        collector.send(new ProducerRecord<>(name, partition, slot, combined), byteArraySerializer, byteArraySerializer);
                     }
                     values.clearDirtyValues();
                 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImpl.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImpl.java
index 8f56e09..404193a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImpl.java
@@ -32,6 +32,8 @@ import org.apache.kafka.streams.processor.ProcessorSupplier;
 import org.apache.kafka.streams.processor.TopologyBuilder;
 
 import java.lang.reflect.Array;
+import java.util.Collections;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 
 public class KStreamImpl<K, V> implements KStream<K, V> {
@@ -72,10 +74,12 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
     protected final TopologyBuilder topology;
     protected final String name;
+    protected final Set<String> sourceNodes;
 
-    public KStreamImpl(TopologyBuilder topology, String name) {
+    public KStreamImpl(TopologyBuilder topology, String name, Set<String> sourceNodes) {
         this.topology = topology;
         this.name = name;
+        this.sourceNodes = sourceNodes;
     }
 
     @Override
@@ -84,7 +88,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
         topology.addProcessor(name, new KStreamFilter<>(predicate, false), this.name);
 
-        return new KStreamImpl<>(topology, name);
+        return new KStreamImpl<>(topology, name, sourceNodes);
     }
 
     @Override
@@ -93,7 +97,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
         topology.addProcessor(name, new KStreamFilter<>(predicate, true), this.name);
 
-        return new KStreamImpl<>(topology, name);
+        return new KStreamImpl<>(topology, name, sourceNodes);
     }
 
     @Override
@@ -102,7 +106,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
         topology.addProcessor(name, new KStreamMap<>(mapper), this.name);
 
-        return new KStreamImpl<>(topology, name);
+        return new KStreamImpl<>(topology, name, null);
     }
 
     @Override
@@ -111,7 +115,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
         topology.addProcessor(name, new KStreamMapValues<>(mapper), this.name);
 
-        return new KStreamImpl<>(topology, name);
+        return new KStreamImpl<>(topology, name, sourceNodes);
     }
 
     @Override
@@ -120,7 +124,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
         topology.addProcessor(name, new KStreamFlatMap<>(mapper), this.name);
 
-        return new KStreamImpl<>(topology, name);
+        return new KStreamImpl<>(topology, name, null);
     }
 
     @Override
@@ -129,7 +133,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
         topology.addProcessor(name, new KStreamFlatMapValues<>(mapper), this.name);
 
-        return new KStreamImpl<>(topology, name);
+        return new KStreamImpl<>(topology, name, sourceNodes);
     }
 
     @Override
@@ -138,7 +142,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
         topology.addProcessor(name, new KStreamWindow<>(windowSupplier), this.name);
 
-        return new KStreamWindowedImpl<>(topology, name, windowSupplier);
+        return new KStreamWindowedImpl<>(topology, name, sourceNodes, windowSupplier);
     }
 
     @Override
@@ -154,7 +158,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
             topology.addProcessor(childName, new KStreamPassThrough<K, V>(), branchName);
 
-            branchChildren[i] = new KStreamImpl<>(topology, childName);
+            branchChildren[i] = new KStreamImpl<>(topology, childName, sourceNodes);
         }
 
         return branchChildren;
@@ -174,7 +178,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
         topology.addSource(sourceName, keyDeserializer, valDeserializer, topic);
 
-        return new KStreamImpl<>(topology, sourceName);
+        return new KStreamImpl<>(topology, sourceName, Collections.<String>emptySet());
     }
 
     @Override
@@ -202,7 +206,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
         topology.addProcessor(name, new KStreamTransform<>(transformerSupplier), this.name);
 
-        return new KStreamImpl<>(topology, name);
+        return new KStreamImpl<>(topology, name, null);
     }
 
     @Override
@@ -211,7 +215,7 @@ public class KStreamImpl<K, V> implements KStream<K, V> {
 
         topology.addProcessor(name, new KStreamTransformValues<>(valueTransformerSupplier), this.name);
 
-        return new KStreamImpl<>(topology, name);
+        return new KStreamImpl<>(topology, name, sourceNodes);
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamJoin.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamJoin.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamJoin.java
index 997953f..5e8186e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamJoin.java
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamJoin.java
@@ -59,10 +59,6 @@ class KStreamJoin<K, V, V1, V2> implements ProcessorSupplier<K, V1> {
         public void init(ProcessorContext context) {
             super.init(context);
 
-            // check if these two streams are joinable
-            if (!context.joinable())
-                throw new IllegalStateException("Streams are not joinable.");
-
             final Window<K, V2> window = (Window<K, V2>) context.getStateStore(windowName);
 
             this.finder = new Finder<K, V2>() {

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamWindowedImpl.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamWindowedImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamWindowedImpl.java
index 9316012..4e9f4c6 100644
--- a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamWindowedImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamWindowedImpl.java
@@ -17,18 +17,22 @@
 
 package org.apache.kafka.streams.kstream.internals;
 
+import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.kstream.KStreamWindowed;
 import org.apache.kafka.streams.kstream.ValueJoiner;
 import org.apache.kafka.streams.kstream.WindowSupplier;
 import org.apache.kafka.streams.processor.TopologyBuilder;
 
+import java.util.HashSet;
+import java.util.Set;
+
 public final class KStreamWindowedImpl<K, V> extends KStreamImpl<K, V> implements KStreamWindowed<K, V> {
 
     private final WindowSupplier<K, V> windowSupplier;
 
-    public KStreamWindowedImpl(TopologyBuilder topology, String name, WindowSupplier<K, V> windowSupplier) {
-        super(topology, name);
+    public KStreamWindowedImpl(TopologyBuilder topology, String name, Set<String> sourceNodes, WindowSupplier<K, V> windowSupplier) {
+        super(topology, name, sourceNodes);
         this.windowSupplier = windowSupplier;
     }
 
@@ -36,6 +40,14 @@ public final class KStreamWindowedImpl<K, V> extends KStreamImpl<K, V> implement
     public <V1, V2> KStream<K, V2> join(KStreamWindowed<K, V1> other, ValueJoiner<V, V1, V2> valueJoiner) {
         String thisWindowName = this.windowSupplier.name();
         String otherWindowName = ((KStreamWindowedImpl<K, V1>) other).windowSupplier.name();
+        Set<String> thisSourceNodes = this.sourceNodes;
+        Set<String> otherSourceNodes = ((KStreamWindowedImpl<K, V1>) other).sourceNodes;
+
+        if (thisSourceNodes == null || otherSourceNodes == null)
+            throw new KafkaException("Streams are not joinable");
+
+        Set<String> allSourceNodes = new HashSet<>(thisSourceNodes);
+        allSourceNodes.addAll(otherSourceNodes);
 
         KStreamJoin<K, V2, V, V1> joinThis = new KStreamJoin<>(otherWindowName, valueJoiner);
         KStreamJoin<K, V2, V1, V> joinOther = new KStreamJoin<>(thisWindowName, KStreamJoin.reverseJoiner(valueJoiner));
@@ -48,7 +60,8 @@ public final class KStreamWindowedImpl<K, V> extends KStreamImpl<K, V> implement
         topology.addProcessor(joinThisName, joinThis, this.name);
         topology.addProcessor(joinOtherName, joinOther, ((KStreamImpl) other).name);
         topology.addProcessor(joinMergeName, joinMerge, joinThisName, joinOtherName);
+        topology.copartitionSources(allSourceNodes);
 
-        return new KStreamImpl<>(topology, joinMergeName);
+        return new KStreamImpl<>(topology, joinMergeName, allSourceNodes);
     }
 }
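
The join above is where the source-node bookkeeping pays off. The guard is easiest to read in isolation; the hypothetical helper below mirrors the first lines of ```KStreamWindowedImpl.join()``` (names and values are illustrative):

```java
import org.apache.kafka.common.KafkaException;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class JoinabilitySketch {
    // Mirrors the guard in KStreamWindowedImpl.join(): a null source-node set
    // marks a stream whose partitioning may have changed, hence "not joinable".
    static Set<String> unionSourceNodes(Set<String> thisSources, Set<String> otherSources) {
        if (thisSources == null || otherSources == null)
            throw new KafkaException("Streams are not joinable");
        Set<String> all = new HashSet<>(thisSources);
        all.addAll(otherSources);
        return all;
    }

    public static void main(String[] args) {
        Set<String> left = new HashSet<>(Arrays.asList("source1"));
        Set<String> afterMapValues = left;  // value-only operations keep the source set
        Set<String> afterMap = null;        // key-changing operations (map, transform) drop it

        System.out.println(unionSourceNodes(left, afterMapValues)); // [source1]
        unionSourceNodes(left, afterMap);   // throws KafkaException
    }
}
```

On success, the union of the two source sets is registered through ```copartitionSources()```, so the partition assignor can later verify that the underlying topics are copartitioned.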

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/processor/DefaultPartitionGrouper.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/DefaultPartitionGrouper.java b/streams/src/main/java/org/apache/kafka/streams/processor/DefaultPartitionGrouper.java
new file mode 100644
index 0000000..f87cfa8
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/DefaultPartitionGrouper.java
@@ -0,0 +1,97 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor;
+
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+public class DefaultPartitionGrouper extends PartitionGrouper {
+
+    public Map<Integer, List<TopicPartition>> partitionGroups(Cluster metadata) {
+        Map<Integer, List<TopicPartition>> groups = new HashMap<>();
+        List<List<String>> sortedTopicGroups = sort(topicGroups);
+
+        int taskId = 0;
+        for (List<String> topicGroup : sortedTopicGroups) {
+            int maxNumPartitions = maxNumPartitions(metadata, topicGroup);
+
+            for (int partitionId = 0; partitionId < maxNumPartitions; partitionId++) {
+                List<TopicPartition> group = new ArrayList<>(topicGroup.size());
+
+                for (String topic : topicGroup) {
+                    if (partitionId < metadata.partitionsForTopic(topic).size()) {
+                        group.add(new TopicPartition(topic, partitionId));
+                    }
+                }
+                groups.put(taskId++, group);
+            }
+        }
+
+        // make the data unmodifiable, then return
+        Map<Integer, List<TopicPartition>> unmodifiableGroups = new HashMap<>();
+        for (Map.Entry<Integer, List<TopicPartition>> entry : groups.entrySet()) {
+            unmodifiableGroups.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
+        }
+        return Collections.unmodifiableMap(unmodifiableGroups);
+    }
+
+    protected int maxNumPartitions(Cluster metadata, List<String> topics) {
+        int maxNumPartitions = 0;
+        for (String topic : topics) {
+            List<PartitionInfo> infos = metadata.partitionsForTopic(topic);
+
+            if (infos == null)
+                throw new KafkaException("Topic not found: " + topic);
+
+            int numPartitions = infos.size();
+            if (numPartitions > maxNumPartitions)
+                maxNumPartitions = numPartitions;
+        }
+        return maxNumPartitions;
+    }
+
+    protected List<List<String>> sort(Collection<Set<String>> topicGroups) {
+        TreeMap<String, String[]> sortedMap = new TreeMap<>();
+
+        for (Set<String> group : topicGroups) {
+            String[] arr = group.toArray(new String[group.size()]);
+            Arrays.sort(arr);
+            sortedMap.put(arr[0], arr);
+        }
+
+        ArrayList<List<String>> list = new ArrayList<>(sortedMap.size());
+        for (String[] arr : sortedMap.values()) {
+            list.add(Arrays.asList(arr));
+        }
+
+        return list;
+    }
+
+}
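
To make the default grouping concrete: with a single topic group {topic1, topic2}, where topic1 has three partitions and topic2 has two, the grouper creates one task per partition id up to the larger count, and the extra task simply carries no topic2 partition. A self-contained run (the ```Cluster``` is constructed directly here only for illustration):

```java
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.processor.DefaultPartitionGrouper;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

public class DefaultGroupingSketch {
    public static void main(String[] args) {
        Node node = new Node(0, "localhost", 9092);
        List<PartitionInfo> infos = Arrays.asList(
                new PartitionInfo("topic1", 0, node, new Node[0], new Node[0]),
                new PartitionInfo("topic1", 1, node, new Node[0], new Node[0]),
                new PartitionInfo("topic1", 2, node, new Node[0], new Node[0]),
                new PartitionInfo("topic2", 0, node, new Node[0], new Node[0]),
                new PartitionInfo("topic2", 1, node, new Node[0], new Node[0]));
        Cluster metadata = new Cluster(Arrays.asList(node), infos);

        DefaultPartitionGrouper grouper = new DefaultPartitionGrouper();
        // One topic group containing both topics, as topicGroups() would produce
        // for topics read by the same task.
        grouper.topicGroups(Collections.singleton(
                new HashSet<>(Arrays.asList("topic1", "topic2"))));

        // Task count follows the max partition count (3); partitions missing
        // from the shorter topic are simply skipped:
        //   0 -> [topic1-0, topic2-0]
        //   1 -> [topic1-1, topic2-1]
        //   2 -> [topic1-2]
        Map<Integer, List<TopicPartition>> tasks = grouper.partitionGroups(metadata);
        System.out.println(tasks);
    }
}
```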

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java b/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
new file mode 100644
index 0000000..82bb36a
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor;
+
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.processor.internals.KafkaStreamingPartitionAssignor;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public abstract class PartitionGrouper {
+
+    protected Collection<Set<String>> topicGroups;
+
+    private KafkaStreamingPartitionAssignor partitionAssignor = null;
+
+    /**
+     * Returns a map of task ids to groups of partitions.
+     *
+     * @param metadata the cluster metadata used to look up partition counts
+     * @return a map of task ids to groups of partitions
+     */
+    public abstract Map<Integer, List<TopicPartition>> partitionGroups(Cluster metadata);
+
+    public void topicGroups(Collection<Set<String>> topicGroups) {
+        this.topicGroups = topicGroups;
+    }
+
+    public void partitionAssignor(KafkaStreamingPartitionAssignor partitionAssignor) {
+        this.partitionAssignor = partitionAssignor;
+    }
+
+    public Set<Integer> taskIds(TopicPartition partition) {
+        return partitionAssignor.taskIds(partition);
+    }
+
+}
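
Since ```PartitionGrouper``` is the user-facing extension point, a custom policy is a small subclass. Below is a hypothetical example that creates one task per topic group, covering all partitions of that group (whereas ```DefaultPartitionGrouper``` creates one task per group/partition pair); it assumes every topic is present in the metadata:

```java
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.processor.PartitionGrouper;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class WholeGroupPartitionGrouper extends PartitionGrouper {

    @Override
    public Map<Integer, List<TopicPartition>> partitionGroups(Cluster metadata) {
        Map<Integer, List<TopicPartition>> groups = new HashMap<>();
        int taskId = 0;
        // topicGroups is the protected field populated by the framework via topicGroups(...)
        for (Set<String> topicGroup : topicGroups) {
            List<TopicPartition> partitions = new ArrayList<>();
            for (String topic : topicGroup) {
                // assumes the topic exists in the metadata; a robust version would null-check
                for (PartitionInfo info : metadata.partitionsForTopic(topic))
                    partitions.add(new TopicPartition(info.topic(), info.partition()));
            }
            groups.put(taskId++, partitions);
        }
        return groups;
    }
}
```

Such a class would be plugged in through the new ```partition.grouper``` config defined in ```StreamingConfig```.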

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorContext.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorContext.java
index adffe0e..e7cf257 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorContext.java
@@ -26,9 +26,9 @@ import java.io.File;
 public interface ProcessorContext {
 
     /**
-     * Returns the partition group id
+     * Returns the task id
      *
-     * @return partition group id
+     * @return the task id
      */
     int id();
 
@@ -75,11 +75,6 @@ public interface ProcessorContext {
     StreamingMetrics metrics();
 
     /**
-     * Check if this process's incoming streams are joinable
-     */
-    boolean joinable();
-
-    /**
      * Registers and possibly restores the specified storage engine.
      *
      * @param store the storage engine

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/processor/TopologyBuilder.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/TopologyBuilder.java b/streams/src/main/java/org/apache/kafka/streams/processor/TopologyBuilder.java
index 833e29b..a475e1e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/TopologyBuilder.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/TopologyBuilder.java
@@ -22,10 +22,13 @@ import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.streams.processor.internals.ProcessorNode;
 import org.apache.kafka.streams.processor.internals.ProcessorTopology;
+import org.apache.kafka.streams.processor.internals.QuickUnion;
 import org.apache.kafka.streams.processor.internals.SinkNode;
 import org.apache.kafka.streams.processor.internals.SourceNode;
 
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -45,10 +48,14 @@ import java.util.Set;
 public class TopologyBuilder {
 
     // list of node factories in a topological order
-    private ArrayList<NodeFactory> nodeFactories = new ArrayList<>();
+    private final ArrayList<NodeFactory> nodeFactories = new ArrayList<>();
 
-    private Set<String> nodeNames = new HashSet<>();
-    private Set<String> sourceTopicNames = new HashSet<>();
+    private final Set<String> nodeNames = new HashSet<>();
+    private final Set<String> sourceTopicNames = new HashSet<>();
+
+    private final QuickUnion<String> nodeGroups = new QuickUnion<>();
+    private final List<Set<String>> copartitionSourceGroups = new ArrayList<>();
+    private final HashMap<String, String[]> nodeToTopics = new HashMap<>();
 
     private interface NodeFactory {
         ProcessorNode build();
@@ -158,6 +165,9 @@ public class TopologyBuilder {
 
         nodeNames.add(name);
         nodeFactories.add(new SourceNodeFactory(name, topics, keyDeserializer, valDeserializer));
+        nodeToTopics.put(name, topics.clone());
+        nodeGroups.add(name);
+
         return this;
     }
 
@@ -237,10 +247,79 @@ public class TopologyBuilder {
 
         nodeNames.add(name);
         nodeFactories.add(new ProcessorNodeFactory(name, parentNames, supplier));
+        nodeGroups.add(name);
+        nodeGroups.unite(name, parentNames);
         return this;
     }
 
     /**
+     * Returns the topic groups.
+     * A topic group is a group of topics in the same task.
+     *
+     * @return groups of topic names
+     */
+    public Collection<Set<String>> topicGroups() {
+        List<Set<String>> topicGroups = new ArrayList<>();
+
+        for (Set<String> nodeGroup : generateNodeGroups(nodeGroups)) {
+            Set<String> topicGroup = new HashSet<>();
+            for (String node : nodeGroup) {
+                String[] topics = nodeToTopics.get(node);
+                if (topics != null)
+                    topicGroup.addAll(Arrays.asList(topics));
+            }
+            topicGroups.add(Collections.unmodifiableSet(topicGroup));
+        }
+
+        return Collections.unmodifiableList(topicGroups);
+    }
+
+    private Collection<Set<String>> generateNodeGroups(QuickUnion<String> grouping) {
+        HashMap<String, Set<String>> nodeGroupMap = new HashMap<>();
+
+        for (String nodeName : nodeNames) {
+            String root = grouping.root(nodeName);
+            Set<String> nodeGroup = nodeGroupMap.get(root);
+            if (nodeGroup == null) {
+                nodeGroup = new HashSet<>();
+                nodeGroupMap.put(root, nodeGroup);
+            }
+            nodeGroup.add(nodeName);
+        }
+
+        return nodeGroupMap.values();
+    }
+
+    /**
+     * Declares that the streams from the specified source nodes must be copartitioned.
+     *
+     * @param sourceNodes a set of source node names
+     */
+    public void copartitionSources(Collection<String> sourceNodes) {
+        copartitionSourceGroups.add(Collections.unmodifiableSet(new HashSet<>(sourceNodes)));
+    }
+
+    /**
+     * Returns the copartition groups.
+     * A copartition group is a group of topics that are required to be copartitioned.
+     *
+     * @return groups of topic names
+     */
+    public Collection<Set<String>> copartitionGroups() {
+        List<Set<String>> list = new ArrayList<>(copartitionSourceGroups.size());
+        for (Set<String> nodeNames : copartitionSourceGroups) {
+            Set<String> copartitionGroup = new HashSet<>();
+            for (String node : nodeNames) {
+                String[] topics = nodeToTopics.get(node);
+                if (topics != null)
+                    copartitionGroup.addAll(Arrays.asList(topics));
+            }
+            list.add(Collections.unmodifiableSet(copartitionGroup));
+        }
+        return Collections.unmodifiableList(list);
+    }
+
+    /**
      * Build the topology. This is typically called automatically when passing this builder into the
      * {@link KafkaStreaming#KafkaStreaming(TopologyBuilder, StreamingConfig)} constructor.
      *
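
A short sketch of the two new query methods; a stand-in supplier replaces a real processor since ```build()``` is never called, and printed set ordering is unspecified:

```java
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.streams.processor.Processor;
import org.apache.kafka.streams.processor.ProcessorSupplier;
import org.apache.kafka.streams.processor.TopologyBuilder;

public class TopologySketch {
    public static void main(String[] args) {
        TopologyBuilder builder = new TopologyBuilder();
        builder.addSource("source1", "topic1");
        builder.addSource("source2", "topic2");
        builder.addSource("source3", "topic3");

        // A stand-in supplier; build() is never called here, so get() can return null.
        ProcessorSupplier supplier = new ProcessorSupplier() {
            @Override
            public Processor get() { return null; }
        };

        // Connecting source1 and source2 through one processor unites their node groups.
        builder.addProcessor("processor", supplier, "source1", "source2");

        // Two topic groups: {topic1, topic2} and {topic3}.
        System.out.println(builder.topicGroups());

        // Joins declare copartitioning; the declared groups come back as topic sets.
        builder.copartitionSources(Utils.mkList("source1", "source2"));
        System.out.println(builder.copartitionGroups()); // [{topic1, topic2}]
    }
}
```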

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
new file mode 100644
index 0000000..ee5bb93
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
@@ -0,0 +1,133 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.internals.PartitionAssignor;
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Configurable;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.StreamingConfig;
+import org.apache.kafka.streams.processor.PartitionGrouper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Configurable {
+
+    private static final Logger log = LoggerFactory.getLogger(KafkaStreamingPartitionAssignor.class);
+
+    private PartitionGrouper partitionGrouper;
+    private Map<TopicPartition, Set<Integer>> partitionToTaskIds;
+
+    @Override
+    public void configure(Map<String, ?> configs) {
+        Object o = configs.get(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE);
+        if (o == null)
+            throw new KafkaException("PartitionGrouper is not specified");
+
+        if (!PartitionGrouper.class.isInstance(o))
+            throw new KafkaException(o.getClass().getName() + " is not an instance of " + PartitionGrouper.class.getName());
+
+        partitionGrouper = (PartitionGrouper) o;
+        partitionGrouper.partitionAssignor(this);
+    }
+
+    @Override
+    public String name() {
+        return "streaming";
+    }
+
+    @Override
+    public Subscription subscription(Set<String> topics) {
+        return new Subscription(new ArrayList<>(topics));
+    }
+
+    @Override
+    public Map<String, Assignment> assign(Cluster metadata, Map<String, Subscription> subscriptions) {
+        Map<Integer, List<TopicPartition>> partitionGroups = partitionGrouper.partitionGroups(metadata);
+
+        String[] clientIds = subscriptions.keySet().toArray(new String[subscriptions.size()]);
+        Integer[] taskIds = partitionGroups.keySet().toArray(new Integer[partitionGroups.size()]);
+
+        Map<String, Assignment> assignment = new HashMap<>();
+
+        for (int i = 0; i < clientIds.length; i++) {
+            List<TopicPartition> partitions = new ArrayList<>();
+            List<Integer> ids = new ArrayList<>();
+            for (int j = i; j < taskIds.length; j += clientIds.length) {
+                Integer taskId = taskIds[j];
+                for (TopicPartition partition : partitionGroups.get(taskId)) {
+                    partitions.add(partition);
+                    ids.add(taskId);
+                }
+            }
+            ByteBuffer buf = ByteBuffer.allocate(4 + ids.size() * 4);
+            // encode the version
+            buf.putInt(1);
+            // encode task ids
+            for (Integer id : ids) {
+                buf.putInt(id);
+            }
+            buf.rewind();
+            assignment.put(clientIds[i], new Assignment(partitions, buf));
+        }
+
+        return assignment;
+    }
+
+    @Override
+    public void onAssignment(Assignment assignment) {
+        List<TopicPartition> partitions = assignment.partitions();
+        ByteBuffer data = assignment.userData();
+        data.rewind();
+
+        Map<TopicPartition, Set<Integer>> partitionToTaskIds = new HashMap<>();
+
+        // check version
+        int version = data.getInt();
+        if (version == 1) {
+            for (TopicPartition partition : partitions) {
+                Set<Integer> taskIds = partitionToTaskIds.get(partition);
+                if (taskIds == null) {
+                    taskIds = new HashSet<>();
+                    partitionToTaskIds.put(partition, taskIds);
+                }
+                // decode a task id
+                taskIds.add(data.getInt());
+            }
+        } else {
+            KafkaException ex = new KafkaException("unknown assignment data version: " + version);
+            log.error(ex.getMessage(), ex);
+            throw ex;
+        }
+        this.partitionToTaskIds = partitionToTaskIds;
+    }
+
+    public Set<Integer> taskIds(TopicPartition partition) {
+        return partitionToTaskIds.get(partition);
+    }
+
+}
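
The userData layout used by ```assign()``` and ```onAssignment()``` is deliberately simple: a 4-byte version followed by one 4-byte task id per assigned partition, positionally aligned with the partition list. A standalone round-trip of that encoding:

```java
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;

public class AssignmentDataSketch {
    public static void main(String[] args) {
        // Encode, as assign() does: a version int, then one task id per assigned partition.
        List<Integer> taskIds = Arrays.asList(0, 0, 2); // parallel to the partition list
        ByteBuffer buf = ByteBuffer.allocate(4 + taskIds.size() * 4);
        buf.putInt(1);                       // version
        for (int id : taskIds)
            buf.putInt(id);                  // task id of the i-th partition
        buf.rewind();

        // Decode, as onAssignment() does: check the version, then read ids positionally.
        int version = buf.getInt();
        if (version != 1)
            throw new IllegalStateException("unknown assignment data version: " + version);
        for (int i = 0; i < taskIds.size(); i++)
            System.out.println("partition #" + i + " -> task " + buf.getInt());
    }
}
```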

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
index 5cb53a4..dfc838c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
@@ -18,7 +18,6 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.KafkaException;
-import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.streams.StreamingConfig;
@@ -31,11 +30,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.File;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
 
 public class ProcessorContextImpl implements ProcessorContext, RecordCollector.Supplier {
 
@@ -85,35 +79,6 @@ public class ProcessorContextImpl implements ProcessorContext, RecordCollector.S
     }
 
     @Override
-    public boolean joinable() {
-        Set<TopicPartition> partitions = this.task.partitions();
-        Map<Integer, List<String>> partitionsById = new HashMap<>();
-        int firstId = -1;
-        for (TopicPartition partition : partitions) {
-            if (!partitionsById.containsKey(partition.partition())) {
-                partitionsById.put(partition.partition(), new ArrayList<String>());
-            }
-            partitionsById.get(partition.partition()).add(partition.topic());
-
-            if (firstId < 0)
-                firstId = partition.partition();
-        }
-
-        List<String> topics = partitionsById.get(firstId);
-        for (List<String> topicsPerPartition : partitionsById.values()) {
-            if (topics.size() != topicsPerPartition.size())
-                return false;
-
-            for (String topic : topicsPerPartition) {
-                if (!topics.contains(topic))
-                    return false;
-            }
-        }
-
-        return true;
-    }
-
-    @Override
     public int id() {
         return id;
     }

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index 59a6394..3cb9cea 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -44,7 +44,7 @@ public class ProcessorStateManager {
     public static final String CHECKPOINT_FILE_NAME = ".checkpoint";
     public static final String LOCK_FILE_NAME = ".lock";
 
-    private final int id;
+    private final int partition;
     private final File baseDir;
     private final FileLock directoryLock;
     private final Map<String, StateStore> stores;
@@ -52,8 +52,8 @@ public class ProcessorStateManager {
     private final Map<TopicPartition, Long> restoredOffsets;
     private final Map<TopicPartition, Long> checkpointedOffsets;
 
-    public ProcessorStateManager(int id, File baseDir, Consumer<byte[], byte[]> restoreConsumer) throws IOException {
-        this.id = id;
+    public ProcessorStateManager(int partition, File baseDir, Consumer<byte[], byte[]> restoreConsumer) throws IOException {
+        this.partition = partition;
         this.baseDir = baseDir;
         this.stores = new HashMap<>();
         this.restoreConsumer = restoreConsumer;
@@ -109,14 +109,14 @@ public class ProcessorStateManager {
         if (restoreConsumer.listTopics().containsKey(store.name())) {
             boolean partitionNotFound = true;
             for (PartitionInfo partitionInfo : restoreConsumer.partitionsFor(store.name())) {
-                if (partitionInfo.partition() == id) {
+                if (partitionInfo.partition() == partition) {
                     partitionNotFound = false;
                     break;
                 }
             }
 
             if (partitionNotFound)
-                throw new IllegalStateException("Store " + store.name() + "'s change log does not contain the partition for group " + id);
+                throw new IllegalStateException("Store " + store.name() + "'s change log does not contain partition " + partition);
 
         } else {
             throw new IllegalStateException("Change log topic for store " + store.name() + " does not exist yet");
@@ -127,7 +127,7 @@ public class ProcessorStateManager {
         // ---- try to restore the state from change-log ---- //
 
         // subscribe to the store's partition
-        TopicPartition storePartition = new TopicPartition(store.name(), id);
+        TopicPartition storePartition = new TopicPartition(store.name(), partition);
         if (!restoreConsumer.subscription().isEmpty()) {
             throw new IllegalStateException("Restore consumer should not have subscribed to any partitions beforehand");
         }
@@ -201,7 +201,7 @@ public class ProcessorStateManager {
 
             Map<TopicPartition, Long> checkpointOffsets = new HashMap<>();
             for (String storeName : stores.keySet()) {
-                TopicPartition part = new TopicPartition(storeName, id);
+                TopicPartition part = new TopicPartition(storeName, partition);
 
                 // only checkpoint the offset to the offsets file if the store is persistent
                 if (stores.get(storeName).persistent()) {
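
The id -> partition rename here is more than cosmetic: the value was always used as a changelog
partition number (see the TopicPartition constructions above), and the new name says so. A sketch
of the resulting contract, assuming a baseDir and restore consumer prepared the way StreamTask
prepares them:

    // every store registered with this manager restores from
    // new TopicPartition(store.name(), 1) of its changelog topic
    ProcessorStateManager stateMgr = new ProcessorStateManager(1, baseDir, restoreConsumer);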

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/processor/internals/QuickUnion.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/QuickUnion.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/QuickUnion.java
new file mode 100644
index 0000000..087cbd2
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/QuickUnion.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals;
+
+import java.util.HashMap;
+import java.util.NoSuchElementException;
+
+public class QuickUnion<T> {
+
+    private HashMap<T, T> ids = new HashMap<>();
+
+    public void add(T id) {
+        ids.put(id, id);
+    }
+
+    public boolean exists(T id) {
+        return ids.containsKey(id);
+    }
+
+    public T root(T id) {
+        T current = id;
+        T parent = ids.get(current);
+
+        if (parent == null)
+            throw new NoSuchElementException("id: " + id.toString());
+
+        while (!parent.equals(current)) {
+            // do the path compression
+            T grandparent = ids.get(parent);
+            ids.put(current, grandparent);
+
+            current = parent;
+            parent = grandparent;
+        }
+        return current;
+    }
+
+    public void unite(T id1, T... idList) {
+        for (T id2 : idList) {
+            unitePair(id1, id2);
+        }
+    }
+
+    private void unitePair(T id1, T id2) {
+        T root1 = root(id1);
+        T root2 = root(id2);
+
+        if (!root1.equals(root2))
+            ids.put(root1, root2);
+    }
+
+}
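
QuickUnion is a small union-find keyed by arbitrary objects rather than array indices, with path
compression applied in root(). A usage sketch with the source-node names it is evidently meant
for (see the TopologyBuilder tests later in this patch):

    QuickUnion<String> qu = new QuickUnion<>();
    qu.add("source-1");
    qu.add("source-2");
    qu.add("source-3");

    qu.unite("source-1", "source-2");   // varargs form: unites each later id with the first

    qu.root("source-1").equals(qu.root("source-2")); // true, one component now
    qu.root("source-3");                // untouched, still its own root
    qu.exists("source-4");              // false; root("source-4") would throw NoSuchElementException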

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 0ceec52..1de6f9b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -112,7 +112,7 @@ public class StreamTask implements Punctuator {
         // create the record recordCollector that maintains the produced offsets
         this.recordCollector = new RecordCollector(producer);
 
-        log.info("Creating restoration consumer client for stream task [" + id + "]");
+        log.info("Creating restoration consumer client for stream task #" + id());
 
         // create the processor state manager
         try {

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 7d935eb..e3803a1 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -25,6 +25,7 @@ import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.metrics.MeasurableStat;
 import org.apache.kafka.common.metrics.Metrics;
@@ -39,6 +40,7 @@ import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamingConfig;
 import org.apache.kafka.streams.StreamingMetrics;
+import org.apache.kafka.streams.processor.PartitionGrouper;
 import org.apache.kafka.streams.processor.TopologyBuilder;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -47,12 +49,15 @@ import java.io.File;
 import java.io.IOException;
 import java.nio.channels.FileLock;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -65,6 +70,7 @@ public class StreamThread extends Thread {
 
     protected final StreamingConfig config;
     protected final TopologyBuilder builder;
+    protected final PartitionGrouper partitionGrouper;
     protected final Producer<byte[], byte[]> producer;
     protected final Consumer<byte[], byte[]> consumer;
     protected final Consumer<byte[], byte[]> restoreConsumer;
@@ -119,6 +125,8 @@ public class StreamThread extends Thread {
         this.config = config;
         this.builder = builder;
         this.clientId = clientId;
+        this.partitionGrouper = config.getConfiguredInstance(StreamingConfig.PARTITION_GROUPER_CLASS_CONFIG, PartitionGrouper.class);
+        this.partitionGrouper.topicGroups(builder.topicGroups());
 
         // set the producer and consumer clients
         this.producer = (producer != null) ? producer : createProducer();
@@ -155,7 +163,7 @@ public class StreamThread extends Thread {
 
     private Consumer<byte[], byte[]> createConsumer() {
         log.info("Creating consumer client for stream thread [" + this.getName() + "]");
-        return new KafkaConsumer<>(config.getConsumerConfigs(),
+        return new KafkaConsumer<>(config.getConsumerConfigs(partitionGrouper),
                 new ByteArrayDeserializer(),
                 new ByteArrayDeserializer());
     }
@@ -233,6 +241,8 @@ public class StreamThread extends Thread {
             int totalNumBuffered = 0;
             boolean requiresPoll = true;
 
+            ensureCopartitioning(builder.copartitionGroups());
+
             consumer.subscribe(new ArrayList<>(builder.sourceTopics()), rebalanceListener);
 
             while (stillRunning()) {
@@ -365,7 +375,7 @@ public class StreamThread extends Thread {
             if (stateDirs != null) {
                 for (File dir : stateDirs) {
                     try {
-                        Integer id = Integer.parseInt(dir.getName());
+                        int id = Integer.parseInt(dir.getName());
 
                         // try to acquire the exclusive lock on the state directory
                         FileLock directoryLock = null;
@@ -404,27 +414,28 @@ public class StreamThread extends Thread {
     }
 
     private void addPartitions(Collection<TopicPartition> assignment) {
-        HashSet<TopicPartition> partitions = new HashSet<>(assignment);
-
-        // TODO: change this hard-coded co-partitioning behavior
-        for (TopicPartition partition : partitions) {
-            final Integer id = partition.partition();
-            StreamTask task = tasks.get(id);
-            if (task == null) {
-                // get the partitions for the task
-                HashSet<TopicPartition> partitionsForTask = new HashSet<>();
-                for (TopicPartition part : partitions)
-                    if (part.partition() == id)
-                        partitionsForTask.add(part);
-
-                // create the task
-                try {
-                    task = createStreamTask(id, partitionsForTask);
-                } catch (Exception e) {
-                    log.error("Failed to create a task #" + id + " in thread [" + this.getName() + "]: ", e);
-                    throw e;
+
+        HashMap<Integer, Set<TopicPartition>> partitionsForTask = new HashMap<>();
+
+        for (TopicPartition partition : assignment) {
+            Set<Integer> taskIds = partitionGrouper.taskIds(partition);
+            for (Integer taskId : taskIds) {
+                Set<TopicPartition> partitions = partitionsForTask.get(taskId);
+                if (partitions == null) {
+                    partitions = new HashSet<>();
+                    partitionsForTask.put(taskId, partitions);
                 }
-                tasks.put(id, task);
+                partitions.add(partition);
+            }
+        }
+
+        // create the tasks
+        for (Integer taskId : partitionsForTask.keySet()) {
+            try {
+                tasks.put(taskId, createStreamTask(taskId, partitionsForTask.get(taskId)));
+            } catch (Exception e) {
+                log.error("Failed to create task #" + taskId + " in thread [" + this.getName() + "]: ", e);
+                throw e;
             }
         }
 
@@ -447,6 +458,35 @@ public class StreamThread extends Thread {
         tasks.clear();
     }
 
+    public PartitionGrouper partitionGrouper() {
+        return partitionGrouper;
+    }
+
+    private void ensureCopartitioning(Collection<Set<String>> copartitionGroups) {
+        for (Set<String> copartitionGroup : copartitionGroups) {
+            ensureCopartitioning(copartitionGroup);
+        }
+    }
+
+    private void ensureCopartitioning(Set<String> copartitionGroup) {
+        int numPartitions = -1;
+
+        for (String topic : copartitionGroup) {
+            List<PartitionInfo> infos = consumer.partitionsFor(topic);
+
+            if (infos == null)
+                throw new KafkaException("topic not found: " + topic);
+
+            if (numPartitions == -1) {
+                numPartitions = infos.size();
+            } else if (numPartitions != infos.size()) {
+                String[] topics = copartitionGroup.toArray(new String[copartitionGroup.size()]);
+                Arrays.sort(topics);
+                throw new KafkaException("topics not copartitioned: [" + Utils.mkString(Arrays.asList(topics), ",") + "]");
+            }
+        }
+    }
+
     private class StreamingMetricsImpl implements StreamingMetrics {
         final Metrics metrics;
         final String metricGrpName;
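
Worth noting how the grouper instance reaches the assignor: the consumer is now constructed with
config.getConsumerConfigs(partitionGrouper), and judging from initPartitionGrouper() in the
StreamThreadTest changes below, that call plants the live grouper in the consumer configs under
StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, where
KafkaStreamingPartitionAssignor.configure() retrieves it. Hand-wired, the same path looks like
this:

    KafkaStreamingPartitionAssignor assignor = new KafkaStreamingPartitionAssignor();
    assignor.configure(Collections.singletonMap(
            StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper));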

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/state/MeteredKeyValueStore.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/MeteredKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/MeteredKeyValueStore.java
index 9a652ac..779bc75 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/MeteredKeyValueStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/MeteredKeyValueStore.java
@@ -171,7 +171,7 @@ public class MeteredKeyValueStore<K, V> implements KeyValueStore<K, V> {
     /**
      * Called when the underlying {@link #inner} {@link KeyValueStore} removes an entry in response to a call from this
      * store other than {@link #delete(Object)}.
-     * 
+     *
      * @param key the key for the entry that the inner store removed
      */
     protected void removed(K key) {
@@ -267,4 +267,4 @@ public class MeteredKeyValueStore<K, V> implements KeyValueStore<K, V> {
 
     }
 
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/main/java/org/apache/kafka/streams/state/RocksDBKeyValueStore.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/RocksDBKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/RocksDBKeyValueStore.java
index 32897ea..7393bb1 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/RocksDBKeyValueStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/RocksDBKeyValueStore.java
@@ -41,7 +41,7 @@ import java.util.NoSuchElementException;
  *
  * @param <K> the type of keys
  * @param <V> the type of values
- * 
+ *
  * @see Stores#create(String, ProcessorContext)
  */
 public class RocksDBKeyValueStore<K, V> extends MeteredKeyValueStore<K, V> {
@@ -166,7 +166,7 @@ public class RocksDBKeyValueStore<K, V> extends MeteredKeyValueStore<K, V> {
             for (Entry<K, V> entry : entries)
                 put(entry.key(), entry.value());
         }
-        
+
         @Override
         public V delete(K key) {
             V value = get(key);
@@ -281,4 +281,4 @@ public class RocksDBKeyValueStore<K, V> extends MeteredKeyValueStore<K, V> {
         }
 
     }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamJoinTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamJoinTest.java
index 58899fa..12bed17 100644
--- a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamJoinTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamJoinTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.kafka.streams.kstream.internals;
 
+import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.serialization.IntegerDeserializer;
 import org.apache.kafka.common.serialization.StringDeserializer;
 import org.apache.kafka.common.utils.Utils;
@@ -32,12 +33,18 @@ import org.apache.kafka.test.MockProcessorSupplier;
 import org.apache.kafka.test.UnlimitedWindowDef;
 import org.junit.Test;
 
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Set;
+
 import static org.junit.Assert.assertEquals;
 
 public class KStreamJoinTest {
 
     private String topic1 = "topic1";
     private String topic2 = "topic2";
+    private String dummyTopic = "dummyTopic";
 
     private IntegerDeserializer keyDeserializer = new IntegerDeserializer();
     private StringDeserializer valDeserializer = new StringDeserializer();
@@ -88,6 +95,7 @@ public class KStreamJoinTest {
 
         KStream<Integer, String> stream1;
         KStream<Integer, String> stream2;
+        KStream<Integer, String> dummyStream;
         KStreamWindowed<Integer, String> windowed1;
         KStreamWindowed<Integer, String> windowed2;
         MockProcessorSupplier<Integer, String> processor;
@@ -96,11 +104,17 @@ public class KStreamJoinTest {
         processor = new MockProcessorSupplier<>();
         stream1 = builder.from(keyDeserializer, valDeserializer, topic1);
         stream2 = builder.from(keyDeserializer, valDeserializer, topic2);
+        dummyStream = builder.from(keyDeserializer, valDeserializer, dummyTopic);
         windowed1 = stream1.with(new UnlimitedWindowDef<Integer, String>("window1"));
         windowed2 = stream2.with(new UnlimitedWindowDef<Integer, String>("window2"));
 
         windowed1.join(windowed2, joiner).process(processor);
 
+        Collection<Set<String>> copartitionGroups = builder.copartitionGroups();
+
+        assertEquals(1, copartitionGroups.size());
+        assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next());
+
         KStreamTestDriver driver = new KStreamTestDriver(builder);
         driver.setTime(0L);
 
@@ -160,5 +174,22 @@ public class KStreamJoinTest {
         }
     }
 
-    // TODO: test for joinability
+    @Test(expected = KafkaException.class)
+    public void testNotJoinable() {
+        KStreamBuilder builder = new KStreamBuilder();
+
+        KStream<Integer, String> stream1;
+        KStream<Integer, String> stream2;
+        KStreamWindowed<Integer, String> windowed1;
+        KStreamWindowed<Integer, String> windowed2;
+        MockProcessorSupplier<Integer, String> processor;
+
+        processor = new MockProcessorSupplier<>();
+        stream1 = builder.from(keyDeserializer, valDeserializer, topic1).map(keyValueMapper);
+        stream2 = builder.from(keyDeserializer, valDeserializer, topic2);
+        windowed1 = stream1.with(new UnlimitedWindowDef<Integer, String>("window1"));
+        windowed2 = stream2.with(new UnlimitedWindowDef<Integer, String>("window2"));
+
+        windowed1.join(windowed2, joiner).process(processor);
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/test/java/org/apache/kafka/streams/processor/DefaultPartitionGrouperTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/DefaultPartitionGrouperTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/DefaultPartitionGrouperTest.java
new file mode 100644
index 0000000..388955e
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/DefaultPartitionGrouperTest.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor;
+
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import static org.apache.kafka.common.utils.Utils.mkList;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+public class DefaultPartitionGrouperTest {
+
+    private List<PartitionInfo> infos = Arrays.asList(
+            new PartitionInfo("topic1", 0, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic1", 1, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic1", 2, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic2", 0, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic2", 1, Node.noNode(), new Node[0], new Node[0])
+    );
+
+    private Cluster metadata = new Cluster(Arrays.asList(Node.noNode()), infos);
+
+    @Test
+    public void testGrouping() {
+        PartitionGrouper grouper = new DefaultPartitionGrouper();
+        int taskId;
+        Map<Integer, List<TopicPartition>> expected;
+
+        grouper.topicGroups(mkList(mkSet("topic1"), mkSet("topic2")));
+
+        expected = new HashMap<>();
+        taskId = 0;
+        expected.put(taskId++, mkList(new TopicPartition("topic1", 0)));
+        expected.put(taskId++, mkList(new TopicPartition("topic1", 1)));
+        expected.put(taskId++, mkList(new TopicPartition("topic1", 2)));
+        expected.put(taskId++, mkList(new TopicPartition("topic2", 0)));
+        expected.put(taskId,   mkList(new TopicPartition("topic2", 1)));
+
+        assertEquals(expected, grouper.partitionGroups(metadata));
+
+        grouper.topicGroups(mkList(mkSet("topic1", "topic2")));
+
+        expected = new HashMap<>();
+        taskId = 0;
+        expected.put(taskId++, mkList(new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)));
+        expected.put(taskId++, mkList(new TopicPartition("topic1", 1), new TopicPartition("topic2", 1)));
+        expected.put(taskId,   mkList(new TopicPartition("topic1", 2)));
+
+        assertEquals(expected, grouper.partitionGroups(metadata));
+    }
+
+}
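
These expectations pin down the default grouping rule: task ids are handed out sequentially
across topic groups, and within a group all partitions sharing a partition number collapse into
one task. A sketch of that rule, consistent with this test and using its metadata Cluster and
topicGroups (the DefaultPartitionGrouper implementation itself is not part of this hunk):

    int taskId = 0;
    Map<Integer, List<TopicPartition>> groups = new HashMap<>();
    for (Set<String> topicGroup : topicGroups) {
        // the group spans as many tasks as its widest topic has partitions
        int maxPartitions = 0;
        for (String topic : topicGroup)
            maxPartitions = Math.max(maxPartitions, metadata.partitionsForTopic(topic).size());

        for (int partitionId = 0; partitionId < maxPartitions; partitionId++) {
            List<TopicPartition> group = new ArrayList<>();
            for (String topic : topicGroup)
                if (partitionId < metadata.partitionsForTopic(topic).size())
                    group.add(new TopicPartition(topic, partitionId));
            groups.put(taskId++, group);
        }
    }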

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/test/java/org/apache/kafka/streams/processor/TopologyBuilderTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/TopologyBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/TopologyBuilderTest.java
index 00522d5..05d24d3 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/TopologyBuilderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/TopologyBuilderTest.java
@@ -19,9 +19,16 @@ package org.apache.kafka.streams.processor;
 
 import static org.junit.Assert.assertEquals;
 
+import static org.apache.kafka.common.utils.Utils.mkSet;
 import org.apache.kafka.test.MockProcessorSupplier;
 import org.junit.Test;
 
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
 public class TopologyBuilderTest {
 
     @Test(expected = TopologyException.class)
@@ -94,6 +101,38 @@ public class TopologyBuilderTest {
         builder.addSource("source-2", "topic-2");
         builder.addSource("source-3", "topic-3");
 
-        assertEquals(builder.sourceTopics().size(), 3);
+        assertEquals(3, builder.sourceTopics().size());
+    }
+
+    @Test
+    public void testTopicGroups() {
+        final TopologyBuilder builder = new TopologyBuilder();
+
+        builder.addSource("source-1", "topic-1", "topic-1x");
+        builder.addSource("source-2", "topic-2");
+        builder.addSource("source-3", "topic-3");
+        builder.addSource("source-4", "topic-4");
+        builder.addSource("source-5", "topic-5");
+
+        builder.addProcessor("processor-1", new MockProcessorSupplier(), "source-1");
+
+        builder.addProcessor("processor-2", new MockProcessorSupplier(), "source-2", "processor-1");
+        builder.copartitionSources(list("source-1", "source-2"));
+
+        builder.addProcessor("processor-3", new MockProcessorSupplier(), "source-3", "source-4");
+
+        Collection<Set<String>> topicGroups = builder.topicGroups();
+
+        assertEquals(3, topicGroups.size());
+        assertEquals(mkSet(mkSet("topic-1", "topic-1x", "topic-2"), mkSet("topic-3", "topic-4"), mkSet("topic-5")), new HashSet<>(topicGroups));
+
+        Collection<Set<String>> copartitionGroups = builder.copartitionGroups();
+
+        assertEquals(mkSet(mkSet("topic-1", "topic-1x", "topic-2")), new HashSet<>(copartitionGroups));
+    }
+
+    private <T> List<T> list(T... elems) {
+        return Arrays.asList(elems);
     }
+
 }
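
The three expected groups fall straight out of connectivity: processor-2 ties source-2 to
processor-1 and hence to source-1 (pulling topic-1, topic-1x and topic-2 together), processor-3
ties source-3 to source-4, and source-5 stands alone. Expressed with the QuickUnion added earlier
in this patch (a sketch, not necessarily TopologyBuilder's exact code):

    QuickUnion<String> grouping = new QuickUnion<>();
    for (String node : new String[] {"source-1", "source-2", "source-3", "source-4", "source-5"})
        grouping.add(node);

    grouping.unite("source-1", "source-2"); // processor-2 consumes both
    grouping.unite("source-3", "source-4"); // processor-3 consumes both

    // nodes sharing a root share a topic group:
    //   {source-1, source-2} -> {topic-1, topic-1x, topic-2}
    //   {source-3, source-4} -> {topic-3, topic-4}
    //   {source-5}           -> {topic-5}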

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/test/java/org/apache/kafka/streams/processor/internals/QuickUnionTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/QuickUnionTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/QuickUnionTest.java
new file mode 100644
index 0000000..c40e881
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/QuickUnionTest.java
@@ -0,0 +1,97 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals;
+
+import org.junit.Test;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+
+public class QuickUnionTest {
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testUnite() {
+        QuickUnion<Long> qu = new QuickUnion<>();
+
+        long[] ids = {
+            1L, 2L, 3L, 4L, 5L
+        };
+
+        for (long id : ids) {
+            qu.add(id);
+        }
+
+        assertEquals(5, roots(qu, ids).size());
+
+        qu.unite(1L, 2L);
+        assertEquals(4, roots(qu, ids).size());
+        assertEquals(qu.root(1L), qu.root(2L));
+
+        qu.unite(3L, 4L);
+        assertEquals(3, roots(qu, ids).size());
+        assertEquals(qu.root(1L), qu.root(2L));
+        assertEquals(qu.root(3L), qu.root(4L));
+
+        qu.unite(1L, 5L);
+        assertEquals(2, roots(qu, ids).size());
+        assertEquals(qu.root(1L), qu.root(2L));
+        assertEquals(qu.root(2L), qu.root(5L));
+        assertEquals(qu.root(3L), qu.root(4L));
+
+        qu.unite(3L, 5L);
+        assertEquals(1, roots(qu, ids).size());
+        assertEquals(qu.root(1L), qu.root(2L));
+        assertEquals(qu.root(2L), qu.root(3L));
+        assertEquals(qu.root(3L), qu.root(4L));
+        assertEquals(qu.root(4L), qu.root(5L));
+    }
+
+    @Test
+    public void testUniteMany() {
+        QuickUnion<Long> qu = new QuickUnion<>();
+
+        long[] ids = {
+            1L, 2L, 3L, 4L, 5L
+        };
+
+        for (long id : ids) {
+            qu.add(id);
+        }
+
+        assertEquals(5, roots(qu, ids).size());
+
+        qu.unite(1L, 2L, 3L, 4L);
+        assertEquals(2, roots(qu, ids).size());
+        assertEquals(qu.root(1L), qu.root(2L));
+        assertEquals(qu.root(2L), qu.root(3L));
+        assertEquals(qu.root(3L), qu.root(4L));
+        assertNotEquals(qu.root(1L), qu.root(5L));
+    }
+
+    private Set<Long> roots(QuickUnion<Long> qu, long... ids) {
+        HashSet<Long> roots = new HashSet<>();
+        for (long id : ids) {
+            roots.add(qu.root(id));
+        }
+        return roots;
+    }
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/71399ffe/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index a7e707e..cbb2558 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -25,8 +25,12 @@ import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.MockConsumer;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.clients.consumer.internals.PartitionAssignor;
 import org.apache.kafka.clients.producer.MockProducer;
 import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.serialization.ByteArraySerializer;
@@ -34,7 +38,9 @@ import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.SystemTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamingConfig;
+import org.apache.kafka.streams.processor.PartitionGrouper;
 import org.apache.kafka.streams.processor.TopologyBuilder;
+import org.apache.kafka.test.MockProcessorSupplier;
 import org.junit.Test;
 
 import java.io.File;
@@ -55,6 +61,31 @@ public class StreamThreadTest {
     private TopicPartition t1p2 = new TopicPartition("topic1", 2);
     private TopicPartition t2p1 = new TopicPartition("topic2", 1);
     private TopicPartition t2p2 = new TopicPartition("topic2", 2);
+    private TopicPartition t3p1 = new TopicPartition("topic3", 1);
+    private TopicPartition t3p2 = new TopicPartition("topic3", 2);
+
+    private List<PartitionInfo> infos = Arrays.asList(
+            new PartitionInfo("topic1", 0, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic1", 1, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic1", 2, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic2", 0, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic2", 1, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic2", 2, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic3", 0, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic3", 1, Node.noNode(), new Node[0], new Node[0]),
+            new PartitionInfo("topic3", 2, Node.noNode(), new Node[0], new Node[0])
+    );
+
+    private Cluster metadata = new Cluster(Arrays.asList(Node.noNode()), infos);
+
+    PartitionAssignor.Subscription subscription = new PartitionAssignor.Subscription(Arrays.asList("topic1", "topic2", "topic3"));
+
+    // task0 is unused
+    private final int task1 = 1;
+    private final int task2 = 2;
+    // task3 is unused
+    private final int task4 = 4;
+    private final int task5 = 5;
 
     private Properties configProps() {
         return new Properties() {
@@ -104,6 +135,8 @@ public class StreamThreadTest {
         TopologyBuilder builder = new TopologyBuilder();
         builder.addSource("source1", "topic1");
         builder.addSource("source2", "topic2");
+        builder.addSource("source3", "topic3");
+        builder.addProcessor("processor", new MockProcessorSupplier(), "source2", "source3");
 
         StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", new Metrics(), new SystemTime()) {
             @Override
@@ -112,6 +145,8 @@ public class StreamThreadTest {
             }
         };
 
+        initPartitionGrouper(thread);
+
         ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
 
         assertTrue(thread.tasks().isEmpty());
@@ -128,8 +163,8 @@ public class StreamThreadTest {
         rebalanceListener.onPartitionsRevoked(revokedPartitions);
         rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
-        assertTrue(thread.tasks().containsKey(1));
-        assertEquals(expectedGroup1, thread.tasks().get(1).partitions());
+        assertTrue(thread.tasks().containsKey(task1));
+        assertEquals(expectedGroup1, thread.tasks().get(task1).partitions());
         assertEquals(1, thread.tasks().size());
 
         revokedPartitions = assignedPartitions;
@@ -139,8 +174,8 @@ public class StreamThreadTest {
         rebalanceListener.onPartitionsRevoked(revokedPartitions);
         rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
-        assertTrue(thread.tasks().containsKey(2));
-        assertEquals(expectedGroup2, thread.tasks().get(2).partitions());
+        assertTrue(thread.tasks().containsKey(task2));
+        assertEquals(expectedGroup2, thread.tasks().get(task2).partitions());
         assertEquals(1, thread.tasks().size());
 
         revokedPartitions = assignedPartitions;
@@ -151,24 +186,38 @@ public class StreamThreadTest {
         rebalanceListener.onPartitionsRevoked(revokedPartitions);
         rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
-        assertTrue(thread.tasks().containsKey(1));
-        assertTrue(thread.tasks().containsKey(2));
-        assertEquals(expectedGroup1, thread.tasks().get(1).partitions());
-        assertEquals(expectedGroup2, thread.tasks().get(2).partitions());
+        assertTrue(thread.tasks().containsKey(task1));
+        assertTrue(thread.tasks().containsKey(task2));
+        assertEquals(expectedGroup1, thread.tasks().get(task1).partitions());
+        assertEquals(expectedGroup2, thread.tasks().get(task2).partitions());
+        assertEquals(2, thread.tasks().size());
+
+        revokedPartitions = assignedPartitions;
+        assignedPartitions = Arrays.asList(t2p1, t2p2, t3p1, t3p2);
+        expectedGroup1 = new HashSet<>(Arrays.asList(t2p1, t3p1));
+        expectedGroup2 = new HashSet<>(Arrays.asList(t2p2, t3p2));
+
+        rebalanceListener.onPartitionsRevoked(revokedPartitions);
+        rebalanceListener.onPartitionsAssigned(assignedPartitions);
+
+        assertTrue(thread.tasks().containsKey(task4));
+        assertTrue(thread.tasks().containsKey(task5));
+        assertEquals(expectedGroup1, thread.tasks().get(task4).partitions());
+        assertEquals(expectedGroup2, thread.tasks().get(task5).partitions());
         assertEquals(2, thread.tasks().size());
 
         revokedPartitions = assignedPartitions;
-        assignedPartitions = Arrays.asList(t1p1, t1p2, t2p1, t2p2);
-        expectedGroup1 = new HashSet<>(Arrays.asList(t1p1, t2p1));
-        expectedGroup2 = new HashSet<>(Arrays.asList(t1p2, t2p2));
+        assignedPartitions = Arrays.asList(t1p1, t2p1, t3p1);
+        expectedGroup1 = new HashSet<>(Arrays.asList(t1p1));
+        expectedGroup2 = new HashSet<>(Arrays.asList(t2p1, t3p1));
 
         rebalanceListener.onPartitionsRevoked(revokedPartitions);
         rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
-        assertTrue(thread.tasks().containsKey(1));
-        assertTrue(thread.tasks().containsKey(2));
-        assertEquals(expectedGroup1, thread.tasks().get(1).partitions());
-        assertEquals(expectedGroup2, thread.tasks().get(2).partitions());
+        assertTrue(thread.tasks().containsKey(task1));
+        assertTrue(thread.tasks().containsKey(task4));
+        assertEquals(expectedGroup1, thread.tasks().get(task1).partitions());
+        assertEquals(expectedGroup2, thread.tasks().get(task4).partitions());
         assertEquals(2, thread.tasks().size());
 
         revokedPartitions = assignedPartitions;
@@ -213,12 +262,15 @@ public class StreamThreadTest {
                 public void maybeClean() {
                     super.maybeClean();
                 }
+
                 @Override
                 protected StreamTask createStreamTask(int id, Collection<TopicPartition> partitionsForTask) {
                     return new TestStreamTask(id, consumer, producer, mockRestoreConsumer, partitionsForTask, builder.build(), config);
                 }
             };
 
+            initPartitionGrouper(thread);
+
             ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
 
             assertTrue(thread.tasks().isEmpty());
@@ -235,7 +287,7 @@ public class StreamThreadTest {
             Map<Integer, StreamTask> prevTasks;
 
             //
-            // Assign t1p1 and t1p2. This should create Task 1 & 2
+            // Assign t1p1 and t1p2. This should create task1 & task2
             //
             revokedPartitions = Collections.emptyList();
             assignedPartitions = Arrays.asList(t1p1, t1p2);
@@ -258,7 +310,7 @@ public class StreamThreadTest {
             assertTrue(stateDir3.exists());
             assertTrue(extraDir.exists());
 
-            // all state directories except for task 1 & 2 will be removed. the extra directory should still exists
+            // all state directories except for task1 & task2 will be removed. the extra directory should still exist
             mockTime.sleep(11L);
             thread.maybeClean();
             assertTrue(stateDir1.exists());
@@ -267,7 +319,7 @@ public class StreamThreadTest {
             assertTrue(extraDir.exists());
 
             //
-            // Revoke t1p1 and t1p2. This should remove Task 1 & 2
+            // Revoke t1p1 and t1p2. This should remove task1 & task2
             //
             revokedPartitions = assignedPartitions;
             assignedPartitions = Collections.emptyList();
@@ -286,7 +338,7 @@ public class StreamThreadTest {
             // no task
             assertTrue(thread.tasks().isEmpty());
 
-            // all state directories for task 1 & 2 still exist before the cleanup delay time
+            // all state directories for task1 & task2 still exist before the cleanup delay time
             mockTime.sleep(cleanupDelay - 10L);
             thread.maybeClean();
             assertTrue(stateDir1.exists());
@@ -294,7 +346,7 @@ public class StreamThreadTest {
             assertFalse(stateDir3.exists());
             assertTrue(extraDir.exists());
 
-            // all state directories for task 1 & 2 are removed
+            // all state directories for task1 & task2 are removed
             mockTime.sleep(11L);
             thread.maybeClean();
             assertFalse(stateDir1.exists());
@@ -331,12 +383,15 @@ public class StreamThreadTest {
                 public void maybeCommit() {
                     super.maybeCommit();
                 }
+
                 @Override
                 protected StreamTask createStreamTask(int id, Collection<TopicPartition> partitionsForTask) {
                     return new TestStreamTask(id, consumer, producer, mockRestoreConsumer, partitionsForTask, builder.build(), config);
                 }
             };
 
+            initPartitionGrouper(thread);
+
             ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
 
             List<TopicPartition> revokedPartitions;
@@ -387,4 +442,19 @@ public class StreamThreadTest {
             Utils.delete(baseDir);
         }
     }
+
+    private void initPartitionGrouper(StreamThread thread) {
+        PartitionGrouper partitionGrouper = thread.partitionGrouper();
+
+        KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor();
+
+        partitionAssignor.configure(
+                Collections.singletonMap(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper)
+        );
+
+        Map<String, PartitionAssignor.Assignment> assignments =
+                partitionAssignor.assign(metadata, Collections.singletonMap("client", subscription));
+
+        partitionAssignor.onAssignment(assignments.get("client"));
+    }
 }