You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by ni...@apache.org on 2016/04/05 02:27:11 UTC

[1/2] samza git commit: SAMZA-906: Host Affinity - Minimize task reassignment when container count changes

Repository: samza
Updated Branches:
  refs/heads/master 3dce4935c -> 2a531b0bb


http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
new file mode 100644
index 0000000..7f83494
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.container.grouper.task;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.coordinator.stream.MockCoordinatorStreamSystemFactory;
+import org.apache.samza.coordinator.stream.MockCoordinatorStreamSystemFactory.MockCoordinatorStreamSystemConsumer;
+import org.apache.samza.coordinator.stream.MockCoordinatorStreamSystemFactory.MockCoordinatorStreamSystemProducer;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class TestTaskAssignmentManager {
+
+  private static final String EXPECTED_REGISTERED_SOURCE = "SamzaTaskAssignmentManager";
+
+  private final MockCoordinatorStreamSystemFactory mockCoordinatorStreamSystemFactory =
+      new MockCoordinatorStreamSystemFactory();
+  private final Config config = new MapConfig(createConfigMap());
+
+  private MockCoordinatorStreamSystemProducer producer;
+  private MockCoordinatorStreamSystemConsumer consumer;
+  private TaskAssignmentManager taskAssignmentManager;
+
+  /**
+   * Builds the minimal job config used to create the mock coordinator stream producer/consumer.
+   * A plain map is used instead of double-brace initialization, which would create an anonymous
+   * HashMap subclass holding a reference to the enclosing test instance.
+   *
+   * @return a fresh mutable config map
+   */
+  private static Map<String, String> createConfigMap() {
+    Map<String, String> map = new HashMap<>();
+    map.put("job.name", "test-job");
+    map.put("job.coordinator.system", "test-kafka");
+    return map;
+  }
+
+  /**
+   * Enables the mock consumer cache, then creates the producer, consumer and
+   * TaskAssignmentManager that each test exercises.
+   */
+  @Before
+  public void setup() throws Exception {
+    MockCoordinatorStreamSystemFactory.enableMockConsumerCache();
+    producer = mockCoordinatorStreamSystemFactory.getCoordinatorStreamSystemProducer(config, null);
+    consumer = mockCoordinatorStreamSystemFactory.getCoordinatorStreamSystemConsumer(config, null);
+    taskAssignmentManager = new TaskAssignmentManager(producer, consumer);
+  }
+
+  @After
+  public void tearDown() {
+    MockCoordinatorStreamSystemFactory.disableMockConsumerCache();
+  }
+
+  @Test
+  public void testTaskAssignmentManager() throws Exception {
+    registerAndStart();
+
+    Map<String, Integer> expectedMap = new HashMap<>();
+    expectedMap.put("Task0", 0);
+    expectedMap.put("Task1", 1);
+    expectedMap.put("Task2", 2);
+    expectedMap.put("Task3", 0);
+    expectedMap.put("Task4", 1);
+
+    for (Map.Entry<String, Integer> entry : expectedMap.entrySet()) {
+      taskAssignmentManager.writeTaskContainerMapping(entry.getKey(), entry.getValue());
+    }
+
+    // Everything written must be readable back as the same task -> container mapping.
+    assertEquals(expectedMap, taskAssignmentManager.readTaskAssignment());
+
+    stopAndVerify();
+  }
+
+  @Test
+  public void testTaskAssignmentManagerEmptyCoordinatorStream() throws Exception {
+    registerAndStart();
+
+    assertEquals("Expected no task assignments on an empty coordinator stream",
+        new HashMap<String, Integer>(), taskAssignmentManager.readTaskAssignment());
+
+    stopAndVerify();
+  }
+
+  /**
+   * Registers and starts the TaskAssignmentManager, verifying that the underlying producer and
+   * consumer are registered (under the expected source name) and started as a result.
+   */
+  private void registerAndStart() throws Exception {
+    // The TaskName passed here is a placeholder, as its name "ignoredTaskName" suggests.
+    taskAssignmentManager.register(new TaskName("ignoredTaskName"));
+    assertTrue(producer.isRegistered());
+    assertEquals(EXPECTED_REGISTERED_SOURCE, producer.getRegisteredSource());
+    assertTrue(consumer.isRegistered());
+
+    taskAssignmentManager.start();
+    assertTrue(producer.isStarted());
+    assertTrue(consumer.isStarted());
+  }
+
+  /** Stops the TaskAssignmentManager and verifies both coordinator stream endpoints stopped. */
+  private void stopAndVerify() throws Exception {
+    taskAssignmentManager.stop();
+    assertTrue(producer.isStopped());
+    assertTrue(consumer.isStopped());
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
new file mode 100644
index 0000000..3b184d3
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.mock;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.Partition;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.system.SystemStreamPartition;
+
+
+public class ContainerMocks {
+  /**
+   * Note: This is artificial. It makes assumptions about task grouping, changelog partition mapping, etc.
+   * that will likely not be true at runtime, but are not the focus of this test.
+   */
+  public static Set<ContainerModel> generateContainerModels(int numContainers, int taskCount) {
+    Set<ContainerModel> models = new HashSet<>(numContainers);
+    int[] taskCountPerContainer = calculateTaskCountPerContainer(taskCount, numContainers);
+    int j = 0;
+    for (int i = 0; i < numContainers; i++) {
+      int[] partitions = new int[taskCountPerContainer[i]];
+      for (int k = 0; k < taskCountPerContainer[i]; k++) {
+        partitions[k] = j + k;
+      }
+      j += taskCountPerContainer[i];
+
+      models.add(createContainerModel(i, partitions));
+    }
+    return models;
+  }
+
+  public static Map<String, Integer> generateTaskAssignments(int numContainers, int taskCount) {
+    Map<String, Integer> mapping = new HashMap<>(taskCount);
+    Set<ContainerModel> containers = generateContainerModels(numContainers, taskCount);
+    for (ContainerModel container : containers) {
+      for (TaskName taskName : container.getTasks().keySet()) {
+        mapping.put(taskName.getTaskName(), container.getContainerId());
+      }
+    }
+    return mapping;
+  }
+
+  public static int[] calculateTaskCountPerContainer(int taskCount, int currentContainerCount) {
+    int[] newTaskCountPerContainer = new int[currentContainerCount];
+    for (int i = 0; i < currentContainerCount; i++) {
+      newTaskCountPerContainer[i] = taskCount / currentContainerCount;
+      if (taskCount % currentContainerCount > i) {
+        newTaskCountPerContainer[i]++;
+      }
+    }
+    return newTaskCountPerContainer;
+  }
+
+  public static ContainerModel createContainerModel(int containerId, int[] partitions) {
+    Map<TaskName, TaskModel> tasks = new HashMap<>();
+    for (int partition : partitions) {
+      tasks.put(getTaskName(partition), getTaskModel(partition));
+    }
+    return new ContainerModel(containerId, tasks);
+  }
+
+  public static Set<TaskModel> generateTaskModels(int[] partitions) {
+    Set<TaskModel> models = new HashSet<>(partitions.length);
+    for (int partition : partitions) {
+      models.add(getTaskModel(partition));
+    }
+    return models;
+  }
+
+  public static Set<TaskModel> generateTaskModels(int count) {
+    Set<TaskModel> taskModels = new HashSet<>();
+    for (int i = 0; i < count; i++) {
+      taskModels.add(getTaskModel(i));
+    }
+    return taskModels;
+  }
+
+  public static TaskModel getTaskModel(int partitionId) {
+    return new TaskModel(getTaskName(partitionId),
+        new HashSet<>(
+            Arrays.asList(new SystemStreamPartition[]{new SystemStreamPartition("System", "Stream", new Partition(partitionId))})),
+        new Partition(partitionId));
+  }
+
+  public static TaskName getTaskName(int partitionId) {
+    return new TaskName("Partition " + partitionId);
+  }
+
+  // Inclusive of both indices
+  public static int[] range(int from, int to) {
+    int[] values = new int[to - from + 1];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = from++;
+    }
+    return values;
+  }
+
+  public static Map<String, Integer> generateTaskContainerMapping(Set<ContainerModel> containers) {
+    Map<String, Integer> taskMapping = new HashMap<>();
+    for (ContainerModel container : containers) {
+      for (TaskName taskName : container.getTasks().keySet()) {
+        taskMapping.put(taskName.getTaskName(), container.getContainerId());
+      }
+    }
+    return taskMapping;
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/test/java/org/apache/samza/coordinator/stream/MockCoordinatorStreamSystemFactory.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/stream/MockCoordinatorStreamSystemFactory.java b/samza-core/src/test/java/org/apache/samza/coordinator/stream/MockCoordinatorStreamSystemFactory.java
index e0d4aa1..662c737 100644
--- a/samza-core/src/test/java/org/apache/samza/coordinator/stream/MockCoordinatorStreamSystemFactory.java
+++ b/samza-core/src/test/java/org/apache/samza/coordinator/stream/MockCoordinatorStreamSystemFactory.java
@@ -137,11 +137,13 @@ public class MockCoordinatorStreamSystemFactory implements SystemFactory {
 
   public static final class MockCoordinatorStreamSystemConsumer extends CoordinatorStreamSystemConsumer {
     private final MockCoordinatorStreamWrappedConsumer consumer;
+    private final SystemStream stream;
     private boolean isRegistered = false;
     private boolean isStarted = false;
 
     public MockCoordinatorStreamSystemConsumer(SystemStream stream, SystemConsumer consumer, SystemAdmin admin) {
       super(stream, consumer, admin);
+      // Kept so that register() can build the partition-0 SSP of this coordinator stream.
+      this.stream = stream;
       this.consumer = (MockCoordinatorStreamWrappedConsumer) consumer;
     }
 
@@ -150,6 +152,8 @@ public class MockCoordinatorStreamSystemFactory implements SystemFactory {
     }
 
     public void register() {
+      // Register partition 0 of the coordinator stream with the wrapped mock consumer, starting
+      // from an empty offset, before flagging this consumer as registered.
+      SystemStreamPartition ssp = new SystemStreamPartition(stream, new Partition(0));
+      consumer.register(ssp, "");
       isRegistered = true;
     }
 

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/test/java/org/apache/samza/coordinator/stream/MockCoordinatorStreamWrappedConsumer.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/stream/MockCoordinatorStreamWrappedConsumer.java b/samza-core/src/test/java/org/apache/samza/coordinator/stream/MockCoordinatorStreamWrappedConsumer.java
index 429573b..d7e8654 100644
--- a/samza-core/src/test/java/org/apache/samza/coordinator/stream/MockCoordinatorStreamWrappedConsumer.java
+++ b/samza-core/src/test/java/org/apache/samza/coordinator/stream/MockCoordinatorStreamWrappedConsumer.java
@@ -55,6 +55,12 @@ public class MockCoordinatorStreamWrappedConsumer extends BlockingEnvelopeMap {
     this.systemStreamPartition = systemStreamPartition;
   }
 
+  /**
+   * Registers the partition with the parent {@code BlockingEnvelopeMap} and then immediately
+   * marks it as being at head.
+   * NOTE(review): presumably this lets code that waits for the coordinator stream to be caught up
+   * proceed, since this mock produces its messages locally on start() — confirm.
+   *
+   * @param systemStreamPartition the partition to register
+   * @param offset                the starting offset, passed through to the parent unchanged
+   */
+  @Override
+  public void register(SystemStreamPartition systemStreamPartition, String offset) {
+    super.register(systemStreamPartition, offset);
+    setIsAtHead(systemStreamPartition, true);
+  }
+
   public void start() {
     convertConfigToCoordinatorMessage(config);
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java b/samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java
index 53207ad..13f4fa9 100644
--- a/samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java
+++ b/samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java
@@ -24,14 +24,12 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
-
 import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.coordinator.stream.MockCoordinatorStreamSystemFactory;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SystemAdmin;
-import org.apache.samza.system.SystemConsumer;
 import org.apache.samza.system.SystemStreamMetadata;
 import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata;
 import org.apache.samza.system.SystemStreamPartition;
@@ -39,21 +37,21 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class TestStorageRecovery {
 
-  public static SystemConsumer systemConsumer1 = null;
-  public static SystemConsumer systemConsumer2 = null;
   public static SystemAdmin systemAdmin = null;
-  public static Config config = null;
-  public static SystemStreamMetadata systemStreamMetadata = null;
-  public static SystemStreamMetadata inputSystemStreamMetadata = null;
-  public final static String SYSTEM_STREAM_NAME = "changelog";
-  public final static String INPUT_STREAM = "input";
+  public Config config = null;
+  public SystemStreamMetadata systemStreamMetadata = null;
+  public SystemStreamMetadata inputSystemStreamMetadata = null;
+  private static final String SYSTEM_STREAM_NAME = "changelog";
+  private static final String INPUT_STREAM = "input";
+  private static final String STORE_NAME = "testStore";
   public static SystemStreamPartition ssp = new SystemStreamPartition("mockSystem", SYSTEM_STREAM_NAME, new Partition(0));
-  public static IncomingMessageEnvelope msg = new IncomingMessageEnvelope(TestStorageRecovery.ssp, "0", "test", "test");
+  public static IncomingMessageEnvelope msg = new IncomingMessageEnvelope(ssp, "0", "test", "test");
 
   @Before
   public void setup() throws InterruptedException {
@@ -86,18 +84,20 @@ public class TestStorageRecovery {
 
     // because the stream has two partitions
     assertEquals(2, MockStorageEngine.incomingMessageEnvelopes.size());
-    assertEquals(TestStorageRecovery.msg, MockStorageEngine.incomingMessageEnvelopes.get(0));
-    assertEquals(TestStorageRecovery.msg, MockStorageEngine.incomingMessageEnvelopes.get(1));
+    assertEquals(msg, MockStorageEngine.incomingMessageEnvelopes.get(0));
+    assertEquals(msg, MockStorageEngine.incomingMessageEnvelopes.get(1));
     // correct path is passed to the store engine
-    assertEquals(path + "/state/testStore/Partition_1", MockStorageEngine.storeDir.toString());
+    String expectedStoreDir = String.format("%s/state/%s/Partition_", path, STORE_NAME);
+    String actualStoreDir = MockStorageEngine.storeDir.toString();
+    assertEquals(expectedStoreDir, actualStoreDir.substring(0, actualStoreDir.length() - 1));
   }
 
   private void putConfig() {
     Map<String, String> map = new HashMap<String, String>();
     map.put("job.name", "changelogTest");
     map.put("systems.mockSystem.samza.factory", MockSystemFactory.class.getCanonicalName());
-    map.put("stores.testStore.factory", MockStorageEngineFactory.class.getCanonicalName());
-    map.put("stores.testStore.changelog", "mockSystem." + SYSTEM_STREAM_NAME);
+    map.put(String.format("stores.%s.factory", STORE_NAME), MockStorageEngineFactory.class.getCanonicalName());
+    map.put(String.format("stores.%s.changelog", STORE_NAME), "mockSystem." + SYSTEM_STREAM_NAME);
     map.put("task.inputs", "mockSystem.input");
     map.put("job.coordinator.system", "coordinator");
     map.put("systems.coordinator.samza.factory", MockCoordinatorStreamSystemFactory.class.getCanonicalName());

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/test/scala/org/apache/samza/container/grouper/task/TestGroupByContainerCount.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/grouper/task/TestGroupByContainerCount.scala b/samza-core/src/test/scala/org/apache/samza/container/grouper/task/TestGroupByContainerCount.scala
deleted file mode 100644
index 6e9c6fa..0000000
--- a/samza-core/src/test/scala/org/apache/samza/container/grouper/task/TestGroupByContainerCount.scala
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.container.grouper.task
-
-import java.util
-
-import org.apache.samza.container.TaskName
-import org.apache.samza.system.SystemStreamPartition
-import org.junit.Assert._
-import org.junit.Test
-import org.apache.samza.job.model.TaskModel
-import org.apache.samza.Partition
-import org.scalatest.Assertions.intercept
-import scala.collection.JavaConversions._
-
-class TestGroupByContainerCount {
-  @Test
-  def testEmptyTasks {
-    intercept[IllegalArgumentException] { new GroupByContainerCount(1).group(new util.HashSet()) }
-  }
-
-  @Test
-  def testFewerTasksThanContainers {
-    val taskModels = new util.HashSet[TaskModel]()
-    taskModels.add(getTaskModel("1", 1))
-    intercept[IllegalArgumentException] { new GroupByContainerCount(2).group(taskModels) }
-  }
-
-  @Test
-  def testHappyPath {
-    val taskModels = Set(
-      getTaskModel("1", 1),
-      getTaskModel("2", 2),
-      getTaskModel("3", 3),
-      getTaskModel("4", 4),
-      getTaskModel("5", 5))
-    val containers = asScalaSet(new GroupByContainerCount(2)
-      .group(setAsJavaSet(taskModels)))
-      .map(containerModel => containerModel.getContainerId -> containerModel)
-      .toMap
-    assertEquals(2, containers.size)
-    val container0 = containers(0)
-    val container1 = containers(1)
-    assertNotNull(container0)
-    assertNotNull(container1)
-    assertEquals(0, container0.getContainerId)
-    assertEquals(1, container1.getContainerId)
-    assertEquals(3, container0.getTasks.size)
-    assertEquals(2, container1.getTasks.size)
-    assertTrue(container0.getTasks.containsKey(new TaskName("1")))
-    assertTrue(container0.getTasks.containsKey(new TaskName("3")))
-    assertTrue(container0.getTasks.containsKey(new TaskName("5")))
-    assertTrue(container1.getTasks.containsKey(new TaskName("2")))
-    assertTrue(container1.getTasks.containsKey(new TaskName("4")))
-  }
-
-  private def getTaskModel(name: String, partitionId: Int) = {
-    new TaskModel(new TaskName(name), Set[SystemStreamPartition](), new Partition(partitionId))
-  }
-}


[2/2] samza git commit: SAMZA-906: Host Affinity - Minimize task reassignment when container count changes

Posted by ni...@apache.org.
SAMZA-906: Host Affinity - Minimize task reassignment when container count changes


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/2a531b0b
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/2a531b0b
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/2a531b0b

Branch: refs/heads/master
Commit: 2a531b0bbfb11bfc1409f365d2d2fc920a7825ac
Parents: 3dce493
Author: Jacob Maes <ja...@gmail.com>
Authored: Mon Apr 4 17:26:25 2016 -0700
Committer: Yi Pan (Data Infrastructure) <ni...@gmail.com>
Committed: Mon Apr 4 17:26:25 2016 -0700

----------------------------------------------------------------------
 checkstyle/import-control.xml                   |   9 +-
 .../apache/samza/container/LocalityManager.java |  11 +-
 .../grouper/task/BalancingTaskNameGrouper.java  |  58 ++
 .../grouper/task/GroupByContainerCount.java     | 355 ++++++++
 .../task/GroupByContainerCountFactory.java      |  33 +
 .../grouper/task/TaskAssignmentManager.java     | 131 +++
 .../AbstractCoordinatorStreamManager.java       |   2 +-
 .../messages/SetContainerHostMapping.java       |   2 +-
 .../messages/SetTaskContainerMapping.java       |  72 ++
 .../grouper/task/GroupByContainerCount.scala    |  55 --
 .../task/GroupByContainerCountFactory.scala     |  30 -
 .../samza/coordinator/JobCoordinator.scala      |  13 +-
 .../grouper/task/TestGroupByContainerCount.java | 837 +++++++++++++++++++
 .../grouper/task/TestTaskAssignmentManager.java | 124 +++
 .../samza/container/mock/ContainerMocks.java    | 129 +++
 .../MockCoordinatorStreamSystemFactory.java     |   4 +
 .../MockCoordinatorStreamWrappedConsumer.java   |   6 +
 .../samza/storage/TestStorageRecovery.java      |  34 +-
 .../task/TestGroupByContainerCount.scala        |  76 --
 19 files changed, 1793 insertions(+), 188 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/checkstyle/import-control.xml
----------------------------------------------------------------------
diff --git a/checkstyle/import-control.xml b/checkstyle/import-control.xml
index b5bd365..c15b8e7 100644
--- a/checkstyle/import-control.xml
+++ b/checkstyle/import-control.xml
@@ -116,13 +116,12 @@
 
     <subpackage name="container">
         <allow pkg="org.apache.samza.config" />
+        <allow pkg="org.apache.samza.container" />
         <allow pkg="org.apache.samza.coordinator.stream" />
         <allow class="org.apache.samza.coordinator.stream.AbstractCoordinatorStreamManager" />
         <subpackage name="grouper">
             <subpackage name="stream">
-                <allow pkg="org.apache.samza.container" />
                 <allow pkg="org.apache.samza.system" />
-
                 <allow class="org.apache.samza.Partition" />
             </subpackage>
 
@@ -130,6 +129,12 @@
                 <allow pkg="org.apache.samza.job" />
             </subpackage>
         </subpackage>
+        <subpackage name="mock">
+            <allow class="org.apache.samza.container.TaskName" />
+            <allow pkg="org.apache.samza.job.model" />
+            <allow class="org.apache.samza.Partition" />
+            <allow pkg="org.apache.samza.system" />
+        </subpackage>
     </subpackage>
 
     <subpackage name="coordinator">

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/java/org/apache/samza/container/LocalityManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/LocalityManager.java b/samza-core/src/main/java/org/apache/samza/container/LocalityManager.java
index acf9352..a3281c2 100644
--- a/samza-core/src/main/java/org/apache/samza/container/LocalityManager.java
+++ b/samza-core/src/main/java/org/apache/samza/container/LocalityManager.java
@@ -19,6 +19,7 @@
 
 package org.apache.samza.container;
 
+import org.apache.samza.container.grouper.task.TaskAssignmentManager;
 import org.apache.samza.coordinator.stream.messages.CoordinatorStreamMessage;
 import org.apache.samza.coordinator.stream.CoordinatorStreamSystemConsumer;
 import org.apache.samza.coordinator.stream.CoordinatorStreamSystemProducer;
@@ -36,7 +37,8 @@ import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
  * */
 public class LocalityManager extends AbstractCoordinatorStreamManager {
   private static final Logger log = LoggerFactory.getLogger(LocalityManager.class);
-  private Map<Integer, Map<String, String>> containerToHostMapping;
+  private Map<Integer, Map<String, String>> containerToHostMapping = new HashMap<>();
+  private final TaskAssignmentManager taskAssignmentManager;
   private final boolean writeOnly;
 
   /**
@@ -48,8 +50,8 @@ public class LocalityManager extends AbstractCoordinatorStreamManager {
   public LocalityManager(CoordinatorStreamSystemProducer coordinatorStreamProducer,
                          CoordinatorStreamSystemConsumer coordinatorStreamConsumer) {
     super(coordinatorStreamProducer, coordinatorStreamConsumer, "SamzaContainer-");
-    this.containerToHostMapping = new HashMap<>();
     this.writeOnly = coordinatorStreamConsumer == null;
+    this.taskAssignmentManager = new TaskAssignmentManager(coordinatorStreamProducer, coordinatorStreamConsumer);
   }
 
   /**
@@ -67,6 +69,7 @@ public class LocalityManager extends AbstractCoordinatorStreamManager {
    *
    * @throws UnsupportedOperationException in the case if a {@link TaskName} is passed
    */
+  @Override
   public void register(TaskName taskName) {
     throw new UnsupportedOperationException("TaskName cannot be registered with LocalityManager");
   }
@@ -136,4 +139,8 @@ public class LocalityManager extends AbstractCoordinatorStreamManager {
     mappings.put(SetContainerHostMapping.JMX_TUNNELING_URL_KEY, jmxTunnelingAddress);
     containerToHostMapping.put(containerId, mappings);
   }
+
+  /**
+   * Returns the {@link TaskAssignmentManager} created by this LocalityManager's constructor.
+   * It was built from the same coordinator stream producer and consumer passed to this manager.
+   *
+   * @return the task-to-container assignment manager
+   */
+  public TaskAssignmentManager getTaskAssignmentManager() {
+    return taskAssignmentManager;
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/java/org/apache/samza/container/grouper/task/BalancingTaskNameGrouper.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/BalancingTaskNameGrouper.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/BalancingTaskNameGrouper.java
new file mode 100644
index 0000000..f8295c8
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/BalancingTaskNameGrouper.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.grouper.task;
+
+import java.util.Set;
+import org.apache.samza.container.LocalityManager;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
+
+
+/**
+ * Extends {@link TaskNameGrouper} with the ability to balance (redistribute) the tasks
+ * starting from a persisted task-to-container mapping. The goal of balancing is typically
+ * to minimize the changes to the ContainerModels across job runs.
+ *
+ * <p>The balance method returns an equivalent set of {@link ContainerModel} as the group method,
+ * but it derives that set from the persisted mapping rather than from scratch. Thus balance is
+ * called in lieu of group when the mapping is available and minimal changes are desired.
+ */
+public interface BalancingTaskNameGrouper extends TaskNameGrouper {
+
+  /**
+   * Rebalances the tasks using the provided {@link LocalityManager}. The goal is typically
+   * to minimize changes to the ContainerModels, e.g. when the container count changes.
+   * This helps maximize the consistency of task-container locality, which is useful for optimization.
+   * Each time balance() is called, locality information is read and used to balance the tasks,
+   * and then the new locality information is saved.
+   *
+   * <p>If balancing cannot be applied, the implementation should itself fall back to
+   * {@link TaskNameGrouper#group(Set)} to produce an appropriate set of ContainerModels.
+   * In other words, this method is a complete replacement for {@link TaskNameGrouper#group(Set)}.
+   *
+   * <p>Implementations should prefer to use the previous mapping rather than calling
+   * {@link TaskNameGrouper#group(Set)}, so that external custom task assignments are honored.
+   *
+   * @param tasks           the tasks to group.
+   * @param localityManager provides a persisted task to container map to use as a baseline
+   * @return                the grouped tasks in the form of ContainerModels
+   */
+  Set<ContainerModel> balance(Set<TaskModel> tasks, LocalityManager localityManager);
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java
new file mode 100644
index 0000000..286ea1b
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java
@@ -0,0 +1,355 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.grouper.task;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Groups tasks into a fixed number of containers.
+ *
+ * <p>{@link #group(Set)} sorts the tasks by task name and deals them out to the containers in
+ * round-robin order, so the task counts of any two containers differ by at most one. No
+ * consideration is given to locality or to the distribution of SSPs within a container.
+ *
+ * <p>{@link #balance(Set, LocalityManager)} starts from the task-to-container mapping previously
+ * saved through the {@link TaskAssignmentManager}. When the container count changes, it moves only
+ * the tasks needed to even out the containers. It falls back to {@link #group(Set)} whenever no
+ * usable previous mapping exists.
+ */
+public class GroupByContainerCount implements BalancingTaskNameGrouper {
+  private static final Logger log = LoggerFactory.getLogger(GroupByContainerCount.class);
+  private final int containerCount;
+
+  /**
+   * @param containerCount the number of containers to group the tasks into.
+   * @throws IllegalArgumentException if {@code containerCount} is not positive.
+   */
+  public GroupByContainerCount(int containerCount) {
+    if (containerCount <= 0) throw new IllegalArgumentException("Must have at least one container");
+    this.containerCount = containerCount;
+  }
+
+  @Override
+  public Set<ContainerModel> group(Set<TaskModel> tasks) {
+
+    validateTasks(tasks);
+
+    // Sort tasks by taskName so the result does not depend on the Set's iteration order.
+    List<TaskModel> sortedTasks = new ArrayList<>(tasks);
+    Collections.sort(sortedTasks);
+
+    // Map every task to a container in round-robin fashion.
+    // A List is used instead of a generic array to avoid an unchecked raw-type array.
+    List<Map<TaskName, TaskModel>> taskGroups = new ArrayList<>(containerCount);
+    for (int i = 0; i < containerCount; i++) {
+      taskGroups.add(new HashMap<TaskName, TaskModel>());
+    }
+    for (int i = 0; i < sortedTasks.size(); i++) {
+      TaskModel tm = sortedTasks.get(i);
+      taskGroups.get(i % containerCount).put(tm.getTaskName(), tm);
+    }
+
+    // Convert to a Set of ContainerModel
+    Set<ContainerModel> containerModels = new HashSet<>();
+    for (int i = 0; i < containerCount; i++) {
+      containerModels.add(new ContainerModel(i, taskGroups.get(i)));
+    }
+
+    return Collections.unmodifiableSet(containerModels);
+  }
+
+  @Override
+  public Set<ContainerModel> balance(Set<TaskModel> tasks, LocalityManager localityManager) {
+
+    validateTasks(tasks);
+
+    if (localityManager == null) {
+      // Without a LocalityManager there is no previous mapping to read and nowhere to save a new one.
+      log.info("Locality manager is null. Cannot read or write task assignments. Invoking grouper.");
+      return group(tasks);
+    }
+
+    TaskAssignmentManager taskAssignmentManager = localityManager.getTaskAssignmentManager();
+    List<TaskGroup> containers = getPreviousContainers(taskAssignmentManager, tasks);
+    if (containers == null || containers.size() == 1 || containerCount == 1) {
+      log.info("Balancing does not apply. Invoking grouper.");
+      Set<ContainerModel> models = group(tasks);
+      saveTaskAssignments(models, taskAssignmentManager);
+      return models;
+    }
+
+    int prevContainerCount = containers.size();
+    int containerDelta = containerCount - prevContainerCount;
+    if (containerDelta == 0) {
+      log.info("Container count has not changed. Reusing previous container models.");
+      return buildContainerModels(tasks, containers);
+    }
+    log.info("Container count changed from {} to {}. Balancing tasks.", prevContainerCount, containerCount);
+
+    // Calculate the expected task count per container
+    int[] expectedTaskCountPerContainer = calculateTaskCountPerContainer(tasks.size(), prevContainerCount, containerCount);
+
+    // Collect excess tasks from over-assigned containers. Containers beyond the new count have an
+    // expected count of 0, so all of their tasks are collected here.
+    List<String> taskNamesToReassign = new LinkedList<>();
+    for (int i = 0; i < prevContainerCount; i++) {
+      TaskGroup taskGroup = containers.get(i);
+      while (taskGroup.size() > expectedTaskCountPerContainer[i]) {
+        taskNamesToReassign.add(taskGroup.removeTask());
+      }
+    }
+
+    // Assign tasks to the under-assigned containers
+    if (containerDelta > 0) {
+      List<TaskGroup> newContainers = createContainers(prevContainerCount, containerCount);
+      containers.addAll(newContainers);
+    } else {
+      containers = containers.subList(0, containerCount);
+    }
+    assignTasksToContainers(expectedTaskCountPerContainer, taskNamesToReassign, containers);
+
+    // Transform containers to containerModel
+    Set<ContainerModel> models = buildContainerModels(tasks, containers);
+
+    // Save the results
+    saveTaskAssignments(models, taskAssignmentManager);
+
+    return models;
+  }
+
+  /**
+   * Reads the task-container mapping from the provided {@link TaskAssignmentManager} and returns a
+   * list of TaskGroups, ordered ascending by containerId.
+   *
+   * <p>If the saved mapping does not cover exactly the current set of task names, it is deleted.
+   * Such a mapping cannot serve as a baseline and would otherwise linger in the coordinator stream.
+   *
+   * @param taskAssignmentManager the {@link TaskAssignmentManager} that will be used to retrieve the previous mapping.
+   * @param tasks                 the current tasks, for validation against the persisted tasks.
+   * @return                      a list of TaskGroups, ordered ascending by containerId or {@code null}
+   *                              if the previous mapping doesn't exist or isn't usable.
+   */
+  private List<TaskGroup> getPreviousContainers(TaskAssignmentManager taskAssignmentManager, Set<TaskModel> tasks) {
+    Map<String, Integer> taskToContainerId = taskAssignmentManager.readTaskAssignment();
+    if (taskToContainerId.isEmpty()) {
+      log.info("No task assignment map was saved.");
+      return null;
+    } else if (tasks.size() != taskToContainerId.size()) {
+      log.warn(
+          "Current task count {} does not match saved task count {}. Stateful jobs may observe misalignment of keys!",
+          tasks.size(), taskToContainerId.size());
+      // If the tasks changed, then the partition-task grouping is also likely changed and we can't handle that
+      // without a much more complicated mapping. Further, the partition count may have changed, which means
+      // input message keys are likely reshuffled w.r.t. partitions, so the local state may not contain necessary
+      // data associated with the incoming keys. Warn the user and default to grouper
+      // In this scenario the tasks may have been reduced, so we need to delete all the existing messages
+      deletePreviousMappings(taskAssignmentManager, taskToContainerId);
+      return null;
+    } else if (!getTaskNames(tasks).equals(taskToContainerId.keySet())) {
+      log.warn("Current task names do not match saved task names. Stateful jobs may observe misalignment of keys!");
+      // Same count but different tasks: the saved mapping cannot be used as a baseline, and
+      // buildContainerModels() would find no TaskModel for the stale task names.
+      deletePreviousMappings(taskAssignmentManager, taskToContainerId);
+      return null;
+    }
+
+    List<TaskGroup> containers;
+    try {
+      containers = getOrderedContainers(taskToContainerId);
+    } catch (Exception e) {
+      log.error("Exception while parsing task mapping", e);
+      return null;
+    }
+    return containers;
+  }
+
+  /**
+   * Deletes every task mapping in {@code previousMappings}.
+   *
+   * <p>The task names are copied first. The map returned by
+   * {@link TaskAssignmentManager#readTaskAssignment()} may be backed by state that each delete
+   * modifies, and iterating such a view while deleting would fail.
+   *
+   * @param taskAssignmentManager the manager through which the mappings are deleted.
+   * @param previousMappings      the previously saved task-to-container mappings.
+   */
+  private void deletePreviousMappings(TaskAssignmentManager taskAssignmentManager,
+      Map<String, Integer> previousMappings) {
+    taskAssignmentManager.deleteTaskContainerMappings(new ArrayList<String>(previousMappings.keySet()));
+  }
+
+  /**
+   * Collects the task names of the given tasks.
+   *
+   * @param tasks the tasks whose names will be collected.
+   * @return      a mutable set of the task names as strings.
+   */
+  private static Set<String> getTaskNames(Set<TaskModel> tasks) {
+    Set<String> taskNames = new HashSet<>();
+    for (TaskModel task : tasks) {
+      taskNames.add(task.getTaskName().getTaskName());
+    }
+    return taskNames;
+  }
+
+  /**
+   * Saves the task assignments specified by containers using the provided TaskAssignmentManager.
+   *
+   * @param containers            the set of containers from which the task assignments will be saved.
+   * @param taskAssignmentManager the {@link TaskAssignmentManager} that will be used to save the mappings.
+   */
+  private void saveTaskAssignments(Set<ContainerModel> containers, TaskAssignmentManager taskAssignmentManager) {
+    taskAssignmentManager.register(null);
+    for (ContainerModel container : containers) {
+      for (TaskName taskName : container.getTasks().keySet()) {
+        taskAssignmentManager.writeTaskContainerMapping(taskName.getTaskName(), container.getContainerId());
+      }
+    }
+  }
+
+  /**
+   * Verifies the input tasks argument and throws {@link IllegalArgumentException} if it is invalid.
+   *
+   * @param tasks the tasks to validate.
+   * @throws IllegalArgumentException if there are no tasks or fewer tasks than containers.
+   */
+  private void validateTasks(Set<TaskModel> tasks) {
+    if (tasks.size() <= 0)
+      throw new IllegalArgumentException("No tasks found. Likely due to no input partitions. Can't run a job with no tasks.");
+
+    if (tasks.size() < containerCount)
+      throw new IllegalArgumentException(String.format(
+          "Your container count (%s) is larger than your task count (%s). Can't have containers with nothing to do, so aborting.",
+          containerCount,
+          tasks.size()));
+  }
+
+  /**
+   * Creates a list of empty {@link TaskGroup} instances for a range of container id's
+   * from the start(inclusive) to end(exclusive) container id.
+   *
+   * @param startContainerId  the first container id for which a TaskGroup is needed.
+   * @param endContainerId    the first container id AFTER the last TaskGroup that is needed.
+   * @return                  a list of empty TaskGroup instances corresponding to the range
+   *                          [startContainerId, endContainerId)
+   */
+  private List<TaskGroup> createContainers(int startContainerId, int endContainerId) {
+    List<TaskGroup> containers = new ArrayList<>(endContainerId - startContainerId);
+    for (int i = startContainerId; i < endContainerId; i++) {
+      TaskGroup taskGroup = new TaskGroup(i, new ArrayList<String>());
+      containers.add(taskGroup);
+    }
+    return containers;
+  }
+
+  /**
+   * Assigns tasks from the specified list to containers that have fewer tasks than indicated
+   * in taskCountPerContainer. Tasks are consumed from the front of {@code taskNamesToAssign}.
+   *
+   * @param taskCountPerContainer the expected number of tasks for each container.
+   * @param taskNamesToAssign     the list of tasks to assign to the containers.
+   * @param containers            the containers (as {@link TaskGroup}) to which the tasks will be assigned.
+   */
+  private void assignTasksToContainers(int[] taskCountPerContainer, List<String> taskNamesToAssign,
+      List<TaskGroup> containers) {
+    for (TaskGroup taskGroup : containers) {
+      for (int j = taskGroup.size(); j < taskCountPerContainer[taskGroup.getContainerId()]; j++) {
+        String taskName = taskNamesToAssign.remove(0);
+        taskGroup.addTaskName(taskName);
+        log.info("Assigned task {} to container {}", taskName, taskGroup.getContainerId());
+      }
+    }
+  }
+
+  /**
+   * Calculates the expected number of tasks for each container. The count is generated for
+   * max(oldContainerCount, newContainerCount) s.t. if the container count has decreased,
+   * the excess containers will have a count == 0, indicating that any tasks assigned to
+   * them should be reassigned.
+   *
+   * @param taskCount             the number of tasks to divide among the containers.
+   * @param prevContainerCount    the previous number of containers.
+   * @param currentContainerCount the current number of containers.
+   * @return                      the expected number of tasks for each container.
+   */
+  private int[] calculateTaskCountPerContainer(int taskCount, int prevContainerCount, int currentContainerCount) {
+    int[] newTaskCountPerContainer = new int[Math.max(currentContainerCount, prevContainerCount)];
+    Arrays.fill(newTaskCountPerContainer, 0);
+
+    // The first (taskCount % currentContainerCount) containers each take one extra task.
+    for (int i = 0; i < currentContainerCount; i++) {
+      newTaskCountPerContainer[i] = taskCount / currentContainerCount;
+      if (taskCount % currentContainerCount > i) {
+        newTaskCountPerContainer[i]++;
+      }
+    }
+    return newTaskCountPerContainer;
+  }
+
+  /**
+   * Translates the list of TaskGroup instances to a set of ContainerModel instances, using the
+   * set of TaskModel instances.
+   *
+   * @param tasks             the TaskModels to assign to the ContainerModels.
+   * @param containerTasks    the TaskGroups defining how the tasks should be grouped.
+   * @return                  an unmodifiable set of ContainerModels.
+   * @throws IllegalStateException if a TaskGroup references a task name absent from {@code tasks}.
+   */
+  private Set<ContainerModel> buildContainerModels(Set<TaskModel> tasks, List<TaskGroup> containerTasks) {
+    // Map task names to models
+    Map<String, TaskModel> taskNameToModel = new HashMap<>();
+    for (TaskModel model : tasks) {
+      taskNameToModel.put(model.getTaskName().getTaskName(), model);
+    }
+
+    // Build container models
+    Set<ContainerModel> containerModels = new HashSet<>();
+    for (TaskGroup container : containerTasks) {
+      Map<TaskName, TaskModel> containerTaskModels = new HashMap<>();
+      for (String taskName : container.taskNames) {
+        TaskModel model = taskNameToModel.get(taskName);
+        if (model == null) {
+          throw new IllegalStateException("Saved task mapping references unknown task: " + taskName);
+        }
+        containerTaskModels.put(model.getTaskName(), model);
+      }
+      containerModels.add(new ContainerModel(container.containerId, containerTaskModels));
+    }
+    return Collections.unmodifiableSet(containerModels);
+  }
+
+  /**
+   * Converts the task->containerId map to an ordered list of {@link TaskGroup} instances.
+   *
+   * @param taskToContainerId a map from each task name to the containerId to which it is assigned.
+   * @return                  a list of TaskGroups ordered ascending by containerId.
+   * @throws IllegalStateException if the container ids are not exactly 0..(n-1).
+   */
+  private List<TaskGroup> getOrderedContainers(Map<String, Integer> taskToContainerId) {
+    log.debug("Got task to container map: {}", taskToContainerId);
+
+    // Group tasks by container Id
+    HashMap<Integer, List<String>> containerIdToTaskNames = new HashMap<>();
+    for (Map.Entry<String, Integer> entry : taskToContainerId.entrySet()) {
+      String taskName = entry.getKey();
+      Integer containerId = entry.getValue();
+      List<String> taskNames = containerIdToTaskNames.get(containerId);
+      if (taskNames == null) {
+        taskNames = new ArrayList<>();
+        containerIdToTaskNames.put(containerId, taskNames);
+      }
+      taskNames.add(taskName);
+    }
+
+    // Build container tasks; container ids must be contiguous starting at 0.
+    List<TaskGroup> containerTasks = new ArrayList<>(containerIdToTaskNames.size());
+    for (int i = 0; i < containerIdToTaskNames.size(); i++) {
+      if (containerIdToTaskNames.get(i) == null) throw new IllegalStateException("Task mapping is missing container: " + i);
+      containerTasks.add(new TaskGroup(i, containerIdToTaskNames.get(i)));
+    }
+
+    return containerTasks;
+  }
+
+  /**
+   * A mutable group of tasks and an associated container id.
+   *
+   * Used as a temporary mutable container until the final ContainerModel is known.
+   */
+  private static class TaskGroup {
+    private final List<String> taskNames = new LinkedList<>();
+    private final Integer containerId;
+
+    private TaskGroup(Integer containerId, List<String> taskNames) {
+      this.containerId = containerId;
+      Collections.sort(taskNames);        // For consistency because the taskNames came from a Map
+      this.taskNames.addAll(taskNames);
+    }
+
+    public Integer getContainerId() {
+      return containerId;
+    }
+
+    public void addTaskName(String taskName) {
+      taskNames.add(taskName);
+    }
+
+    /** Removes and returns the last (highest-sorted, unless tasks were appended) task name. */
+    public String removeTask() {
+      return taskNames.remove(taskNames.size() - 1);
+    }
+
+    public int size() {
+      return taskNames.size();
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.java
new file mode 100644
index 0000000..f0e9686
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.grouper.task;
+
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+
+
+/**
+ * Factory to build the GroupByContainerCount class.
+ */
+public class GroupByContainerCountFactory implements TaskNameGrouperFactory {
+  @Override
+  public TaskNameGrouper build(Config config) {
+    return new GroupByContainerCount(new JobConfig(config).getContainerCount());
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java
new file mode 100644
index 0000000..ec5cf3d
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.grouper.task;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.coordinator.stream.AbstractCoordinatorStreamManager;
+import org.apache.samza.coordinator.stream.CoordinatorStreamSystemConsumer;
+import org.apache.samza.coordinator.stream.CoordinatorStreamSystemProducer;
+import org.apache.samza.coordinator.stream.messages.CoordinatorStreamMessage;
+import org.apache.samza.coordinator.stream.messages.Delete;
+import org.apache.samza.coordinator.stream.messages.SetTaskContainerMapping;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Task assignment Manager is used to persist and read the task-to-container
+ * assignment information from the coordinator stream
+ * */
+public class TaskAssignmentManager extends AbstractCoordinatorStreamManager {
+  private static final Logger log = LoggerFactory.getLogger(TaskAssignmentManager.class);
+  private Map<String, Integer> taskNameToContainerId = new HashMap<>();
+
+  /**
+   * Default constructor that creates a read-write manager
+   *
+   * @param coordinatorStreamProducer producer to the coordinator stream
+   * @param coordinatorStreamConsumer consumer for the coordinator stream
+   */
+  public TaskAssignmentManager(CoordinatorStreamSystemProducer coordinatorStreamProducer,
+                         CoordinatorStreamSystemConsumer coordinatorStreamConsumer) {
+    super(coordinatorStreamProducer, coordinatorStreamConsumer, "SamzaTaskAssignmentManager");
+  }
+
+  /**
+   * Special constructor that creates a write-only {@link TaskAssignmentManager} that only writes
+   * to coordinator stream in SamzaContainer
+   *
+   * @param coordinatorStreamSystemProducer producer to the coordinator stream
+   */
+  public TaskAssignmentManager(CoordinatorStreamSystemProducer coordinatorStreamSystemProducer) {
+    this(coordinatorStreamSystemProducer, null);
+  }
+
+  @Override
+  public void register(TaskName taskName) {
+    // taskName will not be used. This producer is global scope.
+    registerCoordinatorStreamConsumer();
+    registerCoordinatorStreamProducer(getSource());
+  }
+
+  /**
+   * Method to allow read container task information from coordinator stream. This method is used
+   * in {@link org.apache.samza.coordinator.JobCoordinator}.
+   *
+   * @return the map of taskName: containerId
+   */
+  public Map<String, Integer> readTaskAssignment() {
+    Map<String, Integer> allMappings = new HashMap<>();
+    for (CoordinatorStreamMessage message: getBootstrappedStream(SetTaskContainerMapping.TYPE)) {
+      if (message.isDelete()) {
+        allMappings.remove(message.getKey());
+        log.debug("Got TaskContainerMapping delete message: {}", message);
+      } else {
+        SetTaskContainerMapping mapping = new SetTaskContainerMapping(message);
+        allMappings.put(mapping.getKey(), mapping.getTaskAssignment());
+        log.debug("Got TaskContainerMapping message: {}", mapping);
+      }
+    }
+    taskNameToContainerId = allMappings;
+
+    for (Map.Entry<String, Integer> entry : taskNameToContainerId.entrySet()) {
+      log.debug("Assignment for task \"{}\": {}", entry.getKey(), entry.getValue());
+    }
+
+    return Collections.unmodifiableMap(allMappings);
+  }
+
+  /**
+   * Method to write task container info to coordinator stream.
+   *
+   * @param taskName    the task name
+   * @param containerId the SamzaContainer ID or {@code null} to delete the mapping
+   */
+  public void writeTaskContainerMapping(String taskName, Integer containerId) {
+    Integer existingContainerId = taskNameToContainerId.get(taskName);
+    if (existingContainerId != null && !existingContainerId.equals(containerId)) {
+      log.info("Task \"{}\" moved from container {} to container {}", new Object[]{taskName, existingContainerId, containerId});
+    } else {
+      log.debug("Task \"{}\" assigned to container {}", taskName, containerId);
+    }
+
+    if (containerId == null) {
+      send(new Delete(getSource(), taskName, SetTaskContainerMapping.TYPE));
+      taskNameToContainerId.remove(taskName);
+    } else {
+      send(new SetTaskContainerMapping(getSource(), taskName, String.valueOf(containerId)));
+      taskNameToContainerId.put(taskName, containerId);
+    }
+  }
+
+  /**
+   * Deletes the task container info from the coordinator stream for each of the specified task names.
+   *
+   * @param taskNames the task names for which the mapping will be deleted.
+   */
+  public void deleteTaskContainerMappings(Iterable<String> taskNames) {
+    for (String taskName : taskNames) {
+      writeTaskContainerMapping(taskName, null);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/java/org/apache/samza/coordinator/stream/AbstractCoordinatorStreamManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/coordinator/stream/AbstractCoordinatorStreamManager.java b/samza-core/src/main/java/org/apache/samza/coordinator/stream/AbstractCoordinatorStreamManager.java
index 211b642..813234b 100644
--- a/samza-core/src/main/java/org/apache/samza/coordinator/stream/AbstractCoordinatorStreamManager.java
+++ b/samza-core/src/main/java/org/apache/samza/coordinator/stream/AbstractCoordinatorStreamManager.java
@@ -109,7 +109,7 @@ public abstract class AbstractCoordinatorStreamManager {
   }
 
   /**
-   * Registers a consumer and a produces. Every subclass should implement it's logic for registration.<br><br>
+   * Registers a consumer and a producer. Every subclass should implement its logic for registration.<br><br>
    * Registering a single consumer and a single producer can be done with {@link AbstractCoordinatorStreamManager#registerCoordinatorStreamConsumer()}
    * and {@link AbstractCoordinatorStreamManager#registerCoordinatorStreamProducer(String)} methods respectively.<br>
    * These methods can be used in the concrete implementation of this register method.

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetContainerHostMapping.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetContainerHostMapping.java b/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetContainerHostMapping.java
index 4d093b5..da67346 100644
--- a/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetContainerHostMapping.java
+++ b/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetContainerHostMapping.java
@@ -30,7 +30,7 @@ package org.apache.samza.coordinator.stream.messages;
  *     Source: "SamzaContainer-$ContainerId"
  *     MessageMap:
  *     {
- *         hostname: Name of the host
+ *         host: Name of the host
  *         jmx-url: jmxAddressString
  *         jmx-tunneling-url: jmxTunnelingAddressString
  *     }

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetTaskContainerMapping.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetTaskContainerMapping.java b/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetTaskContainerMapping.java
new file mode 100644
index 0000000..431c05d
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetTaskContainerMapping.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.coordinator.stream.messages;
+
+/**
+ * SetTaskContainerMapping is a {@link CoordinatorStreamMessage} used internally
+ * by the Samza framework to persist the task-to-container mappings.
+ *
+ * Structure of the message looks like:
+ *
+ * <pre>
+ * key =&gt; [1, "set-task-container-assignment", $TaskName]
+ *
+ * message =&gt; {
+ *     "host": "192.168.0.1",
+ *     "source": "SamzaTaskAssignmentManager",
+ *     "username":"app",
+ *     "timestamp": 1456177487325,
+ *     "values": {
+ *         "containerId": "139"
+ *     }
+ * }
+ * </pre>
+ * */
+public class SetTaskContainerMapping extends CoordinatorStreamMessage {
+  public static final String TYPE = "set-task-container-assignment";
+  public static final String CONTAINER_KEY = "containerId";
+
+  /**
+   * SteContainerToHostMapping is used to set the container to host mapping information.
+   * @param message which holds the container to host information.
+   */
+  public SetTaskContainerMapping(CoordinatorStreamMessage message) {
+    super(message.getKeyArray(), message.getMessageMap());
+  }
+
+  /**
+   * SteContainerToHostMapping is used to set the container to host mapping information.
+   * @param source              the source of the message
+   * @param taskName                 the taskName which is used to persist the message
+   * @param containerId            the hostname of the container
+   */
+  public SetTaskContainerMapping(String source, String taskName, String containerId) {
+    super(source);
+    setType(TYPE);
+    setKey(taskName);
+    putMessageValue(CONTAINER_KEY, containerId);
+  }
+
+  public Integer getTaskAssignment() {
+    return Integer.parseInt(getMessageValue(CONTAINER_KEY));
+  }
+
+
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/scala/org/apache/samza/container/grouper/task/GroupByContainerCount.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/grouper/task/GroupByContainerCount.scala b/samza-core/src/main/scala/org/apache/samza/container/grouper/task/GroupByContainerCount.scala
deleted file mode 100644
index cb0a3bd..0000000
--- a/samza-core/src/main/scala/org/apache/samza/container/grouper/task/GroupByContainerCount.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.container.grouper.task
-
-import org.apache.samza.job.model.TaskModel
-import org.apache.samza.job.model.ContainerModel
-import scala.collection.JavaConversions._
-import java.util
-
-/**
- * Group the SSP taskNames by dividing the number of taskNames into the number
- * of containers (n) and assigning n taskNames to each container as returned by
- * iterating over the keys in the map of taskNames (whatever that ordering
- * happens to be). No consideration is given towards locality, even distribution
- * of aggregate SSPs within a container, even distribution of the number of
- * taskNames between containers, etc.
- */
-class GroupByContainerCount(numContainers: Int) extends TaskNameGrouper {
-  require(numContainers > 0, "Must have at least one container")
-
-  override def group(tasks: util.Set[TaskModel]): util.Set[ContainerModel] = {
-    require(tasks.size > 0, "No tasks found. Likely due to no input partitions. Can't run a job with no tasks.")
-    require(tasks.size >= numContainers, "Your container count (%s) is larger than your task count (%s). Can't have containers with nothing to do, so aborting." format (numContainers, tasks.size))
-    setAsJavaSet(tasks
-      .toList
-      // Sort tasks by taskName.
-      .sortWith { case (task1, task2) => task1.compareTo(task2) < 0 }
-      // Assign every task an ID.
-      .zip(0 until tasks.size)
-      // Map every task to a container using its task ID.
-      .groupBy(_._2 % numContainers)
-      // Take just TaskModel and remove task IDs.
-      .mapValues(_.map { case (task, taskId) => (task.getTaskName, task) }.toMap)
-      .map { case (containerId, taskModels) => new ContainerModel(containerId, taskModels) }
-      .toSet)
-  }
-}
-

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/scala/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.scala b/samza-core/src/main/scala/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.scala
deleted file mode 100644
index 8bbfd63..0000000
--- a/samza-core/src/main/scala/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.scala
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.container.grouper.task
-
-import org.apache.samza.config.Config
-import org.apache.samza.config.JobConfig.Config2Job
-/**
- * Factory to build the GroupByContainerCount class.
- */
-class GroupByContainerCountFactory extends TaskNameGrouperFactory {
-  override def build(config: Config): TaskNameGrouper = {
-    new GroupByContainerCount(config.getContainerCount)
-  }
-}

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala b/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
index cd7daa2..384b2e7 100644
--- a/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
+++ b/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
@@ -28,7 +28,7 @@ import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.config.TaskConfig.Config2Task
 import org.apache.samza.config.{Config, StorageConfig}
 import org.apache.samza.container.grouper.stream.SystemStreamPartitionGrouperFactory
-import org.apache.samza.container.grouper.task.TaskNameGrouperFactory
+import org.apache.samza.container.grouper.task.{BalancingTaskNameGrouper, TaskNameGrouperFactory}
 import org.apache.samza.container.{LocalityManager, TaskName}
 import org.apache.samza.coordinator.server.{HttpServer, JobServlet}
 import org.apache.samza.coordinator.stream.CoordinatorStreamSystemFactory
@@ -247,13 +247,18 @@ object JobCoordinator extends Logging {
 
     // Here is where we should put in a pluggable option for the
     // SSPTaskNameGrouper for locality, load-balancing, etc.
-
     val containerGrouperFactory = Util.getObj[TaskNameGrouperFactory](config.getTaskNameGrouperFactory)
     val containerGrouper = containerGrouperFactory.build(config)
-    val containerModels = asScalaSet(containerGrouper.group(setAsJavaSet(taskModels))).map
+    val containerModels = {
+      if (containerGrouper.isInstanceOf[BalancingTaskNameGrouper])
+        containerGrouper.asInstanceOf[BalancingTaskNameGrouper].balance(taskModels, localityManager)
+      else
+        containerGrouper.group(taskModels)
+    }
+    val containerMap = asScalaSet(containerModels).map
             { case (containerModel) => Integer.valueOf(containerModel.getContainerId) -> containerModel }.toMap
 
-    new JobModel(config, containerModels, localityManager)
+    new JobModel(config, containerMap, localityManager)
   }
 
   private def createChangeLogStreams(config: StorageConfig, changeLogPartitions: Int, streamMetadataCache: StreamMetadataCache) {

http://git-wip-us.apache.org/repos/asf/samza/blob/2a531b0b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
new file mode 100644
index 0000000..3fd39d7
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
@@ -0,0 +1,837 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.grouper.task;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.container.LocalityManager;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.apache.samza.container.mock.ContainerMocks.generateTaskContainerMapping;
+import static org.apache.samza.container.mock.ContainerMocks.generateTaskModels;
+import static org.apache.samza.container.mock.ContainerMocks.getTaskModel;
+import static org.apache.samza.container.mock.ContainerMocks.getTaskName;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.*;
+
+public class TestGroupByContainerCount {
+  private TaskAssignmentManager taskAssignmentManager;
+  private LocalityManager localityManager;
+
+  /**
+   * Creates fresh mocks before every test and wires the locality manager so that
+   * {@code getTaskAssignmentManager()} hands back the mocked task assignment manager.
+   */
+  @Before
+  public void setup() {
+    localityManager = mock(LocalityManager.class);
+    taskAssignmentManager = mock(TaskAssignmentManager.class);
+    doReturn(taskAssignmentManager).when(localityManager).getTaskAssignmentManager();
+  }
+
+  /**
+   * Grouping an empty task set is expected to be rejected with {@link IllegalArgumentException}.
+   */
+  @Test(expected = IllegalArgumentException.class)
+  public void testGroupEmptyTasks() {
+    // Use a typed set rather than a raw HashSet to avoid an unchecked raw-type conversion.
+    new GroupByContainerCount(1).group(new HashSet<TaskModel>());
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testGroupFewerTasksThanContainers() {
+    Set<TaskModel> taskModels = new HashSet<>();
+    taskModels.add(getTaskModel(1));
+    new GroupByContainerCount(2).group(taskModels);
+  }
+
+  @Test(expected = UnsupportedOperationException.class)
+  public void testGrouperResultImmutable() {
+    Set<TaskModel> taskModels = generateTaskModels(3);
+    Set<ContainerModel> containers = new GroupByContainerCount(3).group(taskModels);
+    containers.remove(containers.iterator().next());
+  }
+
+  /**
+   * Five tasks grouped into two containers: container 0 is expected to hold tasks 0, 2 and 4,
+   * and container 1 tasks 1 and 3.
+   */
+  @Test
+  public void testGroupHappyPath() {
+    Set<TaskModel> taskModels = generateTaskModels(5);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels);
+
+    Map<Integer, ContainerModel> byId = new HashMap<>();
+    for (ContainerModel model : containers) {
+      byId.put(model.getContainerId(), model);
+    }
+    assertEquals(2, containers.size());
+
+    ContainerModel first = byId.get(0);
+    ContainerModel second = byId.get(1);
+    assertNotNull(first);
+    assertNotNull(second);
+    assertEquals(0, first.getContainerId());
+    assertEquals(1, second.getContainerId());
+    assertEquals(3, first.getTasks().size());
+    assertEquals(2, second.getTasks().size());
+
+    for (int taskId : new int[] {0, 2, 4}) {
+      assertTrue(first.getTasks().containsKey(getTaskName(taskId)));
+    }
+    for (int taskId : new int[] {1, 3}) {
+      assertTrue(second.getTasks().containsKey(getTaskName(taskId)));
+    }
+  }
+
+  @Test
+  public void testGroupManyTasks() {
+    Set<TaskModel> taskModels = generateTaskModels(21);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels);
+
+    Map<Integer, ContainerModel> containersMap = new HashMap<>();
+    for (ContainerModel container : containers) {
+      containersMap.put(container.getContainerId(), container);
+    }
+
+    assertEquals(2, containers.size());
+    ContainerModel container0 = containersMap.get(0);
+    ContainerModel container1 = containersMap.get(1);
+    assertNotNull(container0);
+    assertNotNull(container1);
+    assertEquals(0, container0.getContainerId());
+    assertEquals(1, container1.getContainerId());
+    assertEquals(11, container0.getTasks().size());
+    assertEquals(10, container1.getTasks().size());
+
+    // NOTE: tasks are sorted lexicographically, so the container assignment
+    // can seem odd, but the consistency is the key focus
+    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(10)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(12)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(14)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(16)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(18)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(2)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(3)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(5)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(7)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(9)));
+
+    assertTrue(container1.getTasks().containsKey(getTaskName(1)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(11)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(13)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(15)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(17)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(19)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(20)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(4)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(6)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(8)));
+  }
+
+  /**
+   * Balances 9 tasks from 2 containers up to 4. Tasks 0-4 are expected to stay on their
+   * original containers, tasks 5-8 to move to the new containers C2 and C3, and all nine
+   * resulting assignments to be written through the TaskAssignmentManager.
+   *
+   * Before:
+   *  C0  C1
+   * --------
+   *  T0  T1
+   *  T2  T3
+   *  T4  T5
+   *  T6  T7
+   *  T8
+   *
+   * After:
+   *  C0  C1  C2  C3
+   * ----------------
+   *  T0  T1  T6  T5
+   *  T2  T3  T8  T7
+   *  T4
+   *
+   *  NOTE: for host affinity it would help to have additional logic that reassigns tasks
+   *  from C0 and C1 to containers on the same respective hosts. It was not implemented
+   *  because the scenario is infrequent, the benefits are not guaranteed, and the extra code
+   *  complexity was not worth it. It could be implemented in the future.
+   */
+  @Test
+  public void testBalancerAfterContainerIncrease() {
+    Set<TaskModel> taskModels = generateTaskModels(9);
+    // Seed the "previous" assignment with a plain 2-container grouping.
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels);
+    Map<String, Integer> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(4).balance(taskModels, localityManager);
+
+    Map<Integer, ContainerModel> containersMap = new HashMap<>();
+    for (ContainerModel container : containers) {
+      containersMap.put(container.getContainerId(), container);
+    }
+
+    assertEquals(4, containers.size());
+    ContainerModel container0 = containersMap.get(0);
+    ContainerModel container1 = containersMap.get(1);
+    ContainerModel container2 = containersMap.get(2);
+    ContainerModel container3 = containersMap.get(3);
+    assertNotNull(container0);
+    assertNotNull(container1);
+    assertNotNull(container2);
+    assertNotNull(container3);
+    assertEquals(0, container0.getContainerId());
+    assertEquals(1, container1.getContainerId());
+    assertEquals(3, container0.getTasks().size());
+    assertEquals(2, container1.getTasks().size());
+    assertEquals(2, container2.getTasks().size());
+    assertEquals(2, container3.getTasks().size());
+
+    // Tasks 0-4 should stay on the same original containers
+    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(2)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(4)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(1)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(3)));
+    // Tasks 5-8 should be reassigned to the new containers.
+    // Consistency is the goal with these reassignments
+    assertTrue(container2.getTasks().containsKey(getTaskName(8)));
+    assertTrue(container2.getTasks().containsKey(getTaskName(6)));
+    assertTrue(container3.getTasks().containsKey(getTaskName(5)));
+    assertTrue(container3.getTasks().containsKey(getTaskName(7)));
+
+    // Verify task mappings are saved (all nine, including tasks that did not move)
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), 0);
+
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), 1);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), 1);
+
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), 2);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), 2);
+
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), 3);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), 3);
+
+    // The task set is the same before and after, so no mappings are expected to be deleted.
+    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
+  }
+
+  /**
+   * Balances 9 tasks from 4 containers down to 2. Tasks already on C0 and C1 are expected to
+   * stay, tasks from the removed containers C2 and C3 to move onto C0 and C1, and all nine
+   * resulting assignments to be written through the TaskAssignmentManager.
+   *
+   * Before:
+   *  C0  C1  C2  C3
+   * ----------------
+   *  T0  T1  T2  T3
+   *  T4  T5  T6  T7
+   *  T8
+   *
+   * After:
+   *  C0  C1
+   * --------
+   *  T0  T1
+   *  T4  T5
+   *  T8  T7
+   *  T6  T3
+   *  T2
+   *
+   *  NOTE: for host affinity it would help to have additional logic that reassigns tasks
+   *  from C2 and C3 to containers on the same respective hosts. It was not implemented
+   *  because the scenario is infrequent, the benefits are not guaranteed, and the extra code
+   *  complexity was not worth it. It could be implemented in the future.
+   */
+  @Test
+  public void testBalancerAfterContainerDecrease() {
+    Set<TaskModel> taskModels = generateTaskModels(9);
+    // Seed the "previous" assignment with a plain 4-container grouping.
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels);
+    Map<String, Integer> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(2).balance(taskModels, localityManager);
+
+    Map<Integer, ContainerModel> containersMap = new HashMap<>();
+    for (ContainerModel container : containers) {
+      containersMap.put(container.getContainerId(), container);
+    }
+
+    assertEquals(2, containers.size());
+    ContainerModel container0 = containersMap.get(0);
+    ContainerModel container1 = containersMap.get(1);
+    assertNotNull(container0);
+    assertNotNull(container1);
+    assertEquals(0, container0.getContainerId());
+    assertEquals(1, container1.getContainerId());
+    assertEquals(5, container0.getTasks().size());
+    assertEquals(4, container1.getTasks().size());
+
+    // Tasks 0,4,8 and 1,5 should stay on the same original containers
+    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(4)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(8)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(1)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(5)));
+
+    // Tasks 2,6 and 3,7 (from the removed containers C2 and C3) should be reassigned to the
+    // remaining containers. Consistency is the goal with these reassignments
+    assertTrue(container0.getTasks().containsKey(getTaskName(6)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(2)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(7)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(3)));
+
+    // Verify task mappings are saved (all nine, including tasks that did not move)
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), 0);
+
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), 1);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), 1);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), 1);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), 1);
+
+    // The task set is the same before and after, so no mappings are expected to be deleted.
+    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
+  }
+
+  /**
+   * Runs two balance() calls over the same 9 tasks (first to 2 containers, then to 3), each
+   * against its own TaskAssignmentManager mock, and checks the layout and saved mappings of each.
+   *
+   * NOTE(review): the second balance is seeded from {@code prevContainers} (the 4-container
+   * "Before" layout), not from the 2-container result of the first balance. The "Intermediate"
+   * stage drawn below therefore does not feed the second call. Confirm whether that is intended.
+   *
+   * Before:
+   *  C0  C1  C2  C3
+   * ----------------
+   *  T0  T1  T2  T3
+   *  T4  T5  T6  T7
+   *  T8
+   *
+   * Intermediate:
+   *  C0  C1
+   * --------
+   *  T0  T1
+   *  T4  T5
+   *  T8  T7
+   *  T6  T3
+   *  T2
+   *
+   *  After:
+   *  C0  C1  C2
+   * ------------
+   *  T0  T1  T6
+   *  T4  T5  T2
+   *  T8  T7  T3
+   */
+  @Test
+  public void testBalancerMultipleReblances() {
+    // Before
+    Set<TaskModel> taskModels = generateTaskModels(9);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels);
+    Map<String, Integer> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    // First balance
+    Set<ContainerModel> containers = new GroupByContainerCount(2).balance(taskModels, localityManager);
+
+    Map<Integer, ContainerModel> containersMap = new HashMap<>();
+    for (ContainerModel container : containers) {
+      containersMap.put(container.getContainerId(), container);
+    }
+
+    assertEquals(2, containers.size());
+    ContainerModel container0 = containersMap.get(0);
+    ContainerModel container1 = containersMap.get(1);
+    assertNotNull(container0);
+    assertNotNull(container1);
+    assertEquals(0, container0.getContainerId());
+    assertEquals(1, container1.getContainerId());
+    assertEquals(5, container0.getTasks().size());
+    assertEquals(4, container1.getTasks().size());
+
+    // Tasks 0,4,8 and 1,5 should stay on the same original containers
+    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(4)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(8)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(1)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(5)));
+
+    // Tasks 2,6 and 3,7 (from the removed containers C2 and C3) should be reassigned to the
+    // remaining containers. Consistency is the goal with these reassignments
+    assertTrue(container0.getTasks().containsKey(getTaskName(6)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(2)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(7)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(3)));
+
+    // Verify task mappings are saved
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), 0);
+
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), 1);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), 1);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), 1);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), 1);
+
+    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
+
+
+    // Second balance
+    // NOTE(review): this regenerates the mapping from prevContainers (the 4-container "Before"
+    // layout), not from `containers` (the first balance's result). The assertions below thus
+    // describe a 4 -> 3 rebalance. Confirm whether `containers` was intended here.
+    prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+
+    // Fresh mocks so the verify() calls below only see writes from the second balance.
+    TaskAssignmentManager taskAssignmentManager2 = mock(TaskAssignmentManager.class);
+    when(taskAssignmentManager2.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    LocalityManager localityManager2 = mock(LocalityManager.class);
+    when(localityManager2.getTaskAssignmentManager()).thenReturn(taskAssignmentManager2);
+
+    containers = new GroupByContainerCount(3).balance(taskModels, localityManager2);
+
+    containersMap = new HashMap<>();
+    for (ContainerModel container : containers) {
+      containersMap.put(container.getContainerId(), container);
+    }
+
+    assertEquals(3, containers.size());
+    container0 = containersMap.get(0);
+    container1 = containersMap.get(1);
+    ContainerModel container2 = containersMap.get(2);
+    assertNotNull(container0);
+    assertNotNull(container1);
+    assertNotNull(container2);
+    assertEquals(0, container0.getContainerId());
+    assertEquals(1, container1.getContainerId());
+    assertEquals(2, container2.getContainerId());
+    assertEquals(3, container0.getTasks().size());
+    assertEquals(3, container1.getTasks().size());
+    assertEquals(3, container2.getTasks().size());
+
+    // Tasks 0,4,8 and 1,5,7 should stay on the same original containers
+    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(4)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(8)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(1)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(5)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(7)));
+
+    // Tasks 2,6 and 3 should be reassigned to the new container.
+    // Consistency is the goal with these reassignments
+    assertTrue(container2.getTasks().containsKey(getTaskName(6)));
+    assertTrue(container2.getTasks().containsKey(getTaskName(2)));
+    assertTrue(container2.getTasks().containsKey(getTaskName(3)));
+
+    // Verify task mappings are saved
+    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(0).getTaskName(), 0);
+    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(4).getTaskName(), 0);
+    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(8).getTaskName(), 0);
+
+    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(1).getTaskName(), 1);
+    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(5).getTaskName(), 1);
+    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(7).getTaskName(), 1);
+
+    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(6).getTaskName(), 2);
+    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(2).getTaskName(), 2);
+    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(3).getTaskName(), 2);
+
+    verify(taskAssignmentManager2, never()).deleteTaskContainerMappings(anyCollection());
+  }
+
+  /**
+   * When the container count does not change, balance() is expected to reproduce the previous
+   * layout exactly and to neither write nor delete any task-container mapping.
+   *
+   * Before:
+   *  C0  C1
+   * --------
+   *  T0  T1
+   *  T2  T3
+   *  T4  T5
+   *  T6  T7
+   *  T8
+   *
+   *  After:
+   *  C0  C1
+   * --------
+   *  T0  T1
+   *  T2  T3
+   *  T4  T5
+   *  T6  T7
+   *  T8
+   */
+  @Test
+  public void testBalancerAfterContainerSame() {
+    Set<TaskModel> taskModels = generateTaskModels(9);
+    Set<ContainerModel> previousLayout = new GroupByContainerCount(2).group(taskModels);
+    Map<String, Integer> previousMapping = generateTaskContainerMapping(previousLayout);
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(previousMapping);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(2).balance(taskModels, localityManager);
+
+    Map<Integer, ContainerModel> byId = new HashMap<>();
+    for (ContainerModel model : containers) {
+      byId.put(model.getContainerId(), model);
+    }
+    assertEquals(2, containers.size());
+
+    ContainerModel first = byId.get(0);
+    ContainerModel second = byId.get(1);
+    assertNotNull(first);
+    assertNotNull(second);
+    assertEquals(0, first.getContainerId());
+    assertEquals(1, second.getContainerId());
+    assertEquals(5, first.getTasks().size());
+    assertEquals(4, second.getTasks().size());
+
+    // Even task ids are expected on container 0, odd ones on container 1.
+    for (int taskId = 0; taskId < 9; taskId++) {
+      ContainerModel owner = (taskId % 2 == 0) ? first : second;
+      assertTrue(owner.getTasks().containsKey(getTaskName(taskId)));
+    }
+
+    verify(taskAssignmentManager, never()).writeTaskContainerMapping(anyString(), anyInt());
+    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
+  }
+
+  /**
+   * A custom, *deliberately* unbalanced task-container mapping must be preserved as-is when the
+   * container count does not change; no mappings are expected to be written or deleted.
+   *
+   * Before:
+   *  C0  C1
+   * --------
+   *  T0  T6
+   *  T1  T7
+   *  T2  T8
+   *  T3
+   *  T4
+   *  T5
+   *
+   *  After:
+   *  C0  C1
+   * --------
+   *  T0  T6
+   *  T1  T7
+   *  T2  T8
+   *  T3
+   *  T4
+   *  T5
+   */
+  @Test
+  public void testBalancerAfterContainerSameCustomAssignment() {
+    Set<TaskModel> taskModels = generateTaskModels(9);
+
+    // Tasks 0-5 are pinned to container 0 and tasks 6-8 to container 1.
+    Map<String, Integer> prevTaskToContainerMapping = new HashMap<>();
+    for (int taskId = 0; taskId < 9; taskId++) {
+      prevTaskToContainerMapping.put(getTaskName(taskId).getTaskName(), taskId < 6 ? 0 : 1);
+    }
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(2).balance(taskModels, localityManager);
+
+    Map<Integer, ContainerModel> byId = new HashMap<>();
+    for (ContainerModel model : containers) {
+      byId.put(model.getContainerId(), model);
+    }
+    assertEquals(2, containers.size());
+
+    ContainerModel first = byId.get(0);
+    ContainerModel second = byId.get(1);
+    assertNotNull(first);
+    assertNotNull(second);
+    assertEquals(0, first.getContainerId());
+    assertEquals(1, second.getContainerId());
+    assertEquals(6, first.getTasks().size());
+    assertEquals(3, second.getTasks().size());
+
+    for (int taskId = 0; taskId < 9; taskId++) {
+      ContainerModel owner = taskId < 6 ? first : second;
+      assertTrue(owner.getTasks().containsKey(getTaskName(taskId)));
+    }
+
+    verify(taskAssignmentManager, never()).writeTaskContainerMapping(anyString(), anyInt());
+    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
+  }
+
+  /**
+   * Verifies that a *deliberately* unbalanced custom task-container mapping is rebalanced evenly
+   * when the container count increases, and that every resulting assignment is written back.
+   *
+   * Before:
+   *  C0  C1
+   * --------
+   *  T0  T1
+   *      T2
+   *      T3
+   *      T4
+   *      T5
+   *
+   *  After:
+   *  C0  C1  C2
+   * ------------
+   *  T0  T1  T4
+   *  T5  T2  T3
+   *
+   *  The key here is that C0, which is not one of the new containers, was under-allocated.
+   *  This is an important case because this scenario, while impossible with
+   *  GroupByContainerCount.group(), could occur when the grouper class is switched or if there
+   *  is a custom mapping.
+   */
+  @Test
+  public void testBalancerAfterContainerSameCustomAssignmentAndContainerIncrease() {
+    Set<TaskModel> taskModels = generateTaskModels(6);
+
+    Map<String, Integer> prevTaskToContainerMapping = new HashMap<>();
+    prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), 0);
+    prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), 1);
+    prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), 1);
+    prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), 1);
+    prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), 1);
+    prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), 1);
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(3).balance(taskModels, localityManager);
+
+    Map<Integer, ContainerModel> containersMap = new HashMap<>();
+    for (ContainerModel container : containers) {
+      containersMap.put(container.getContainerId(), container);
+    }
+
+    assertEquals(3, containers.size());
+    ContainerModel container0 = containersMap.get(0);
+    ContainerModel container1 = containersMap.get(1);
+    ContainerModel container2 = containersMap.get(2);
+    assertNotNull(container0);
+    assertNotNull(container1);
+    assertNotNull(container2);
+    assertEquals(0, container0.getContainerId());
+    assertEquals(1, container1.getContainerId());
+    assertEquals(2, container2.getContainerId());
+    // All three containers are expected to end up with an equal share of the 6 tasks.
+    assertEquals(2, container0.getTasks().size());
+    assertEquals(2, container1.getTasks().size());
+    assertEquals(2, container2.getTasks().size());
+
+    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
+    assertTrue(container0.getTasks().containsKey(getTaskName(5)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(1)));
+    assertTrue(container1.getTasks().containsKey(getTaskName(2)));
+    assertTrue(container2.getTasks().containsKey(getTaskName(4)));
+    assertTrue(container2.getTasks().containsKey(getTaskName(3)));
+
+    // Every task's mapping is written, including tasks 0, 1 and 2 that kept their container.
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), 0);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), 1);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), 1);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), 2);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), 2);
+    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), 0);
+
+    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
+  }
+
+  /**
+   * Growing from a single container to three should give the same layout as {@code group()}:
+   * task i lands on container i, every mapping is written, and nothing is deleted.
+   */
+  @Test
+  public void testBalancerOldContainerCountOne() {
+    final int containerCount = 3;
+    Set<TaskModel> taskModels = generateTaskModels(3);
+    Map<String, Integer> prevTaskToContainerMapping =
+        generateTaskContainerMapping(new GroupByContainerCount(1).group(taskModels));
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    Set<ContainerModel> containers =
+        new GroupByContainerCount(containerCount).balance(taskModels, localityManager);
+
+    // Look containers up by id so each one can be checked individually.
+    Map<Integer, ContainerModel> byId = new HashMap<>();
+    for (ContainerModel model : containers) {
+      byId.put(model.getContainerId(), model);
+    }
+    assertEquals(containerCount, containers.size());
+
+    for (int id = 0; id < containerCount; id++) {
+      assertNotNull(byId.get(id));
+    }
+    for (int id = 0; id < containerCount; id++) {
+      assertEquals(id, byId.get(id).getContainerId());
+    }
+    for (int id = 0; id < containerCount; id++) {
+      assertEquals(1, byId.get(id).getTasks().size());
+    }
+
+    // One task per container: task i must sit on container i.
+    for (int id = 0; id < containerCount; id++) {
+      assertTrue(byId.get(id).getTasks().containsKey(getTaskName(id)));
+    }
+
+    // Every task's new placement must be persisted.
+    for (int id = 0; id < containerCount; id++) {
+      verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(id).getTaskName(), id);
+    }
+
+    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
+  }
+
+  /**
+   * Shrinking from three containers to one should put every task on container 0, write each
+   * task's new mapping, and delete no mappings.
+   */
+  @Test
+  public void testBalancerNewContainerCountOne() {
+    final int taskCount = 3;
+    Set<TaskModel> taskModels = generateTaskModels(taskCount);
+    Map<String, Integer> prevTaskToContainerMapping =
+        generateTaskContainerMapping(new GroupByContainerCount(3).group(taskModels));
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(1).balance(taskModels, localityManager);
+
+    // Expected to match what group() would produce for a single container.
+    Map<Integer, ContainerModel> byId = new HashMap<>();
+    for (ContainerModel model : containers) {
+      byId.put(model.getContainerId(), model);
+    }
+
+    assertEquals(1, containers.size());
+    ContainerModel onlyContainer = byId.get(0);
+    assertNotNull(onlyContainer);
+    assertEquals(0, onlyContainer.getContainerId());
+    assertEquals(taskCount, onlyContainer.getTasks().size());
+
+    for (int task = 0; task < taskCount; task++) {
+      assertTrue(onlyContainer.getTasks().containsKey(getTaskName(task)));
+    }
+    for (int task = 0; task < taskCount; task++) {
+      verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(task).getTaskName(), 0);
+    }
+
+    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
+  }
+
+  /**
+   * With no stored task assignment, balancing into one container should match {@code group()}:
+   * all three tasks on container 0, each mapping written, nothing deleted.
+   */
+  @Test
+  public void testBalancerEmptyTaskMapping() {
+    Set<TaskModel> taskModels = generateTaskModels(3);
+    Map<String, Integer> noStoredAssignment = new HashMap<>();
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(noStoredAssignment);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(1).balance(taskModels, localityManager);
+
+    Map<Integer, ContainerModel> modelsById = new HashMap<>();
+    for (ContainerModel candidate : containers) {
+      modelsById.put(candidate.getContainerId(), candidate);
+    }
+
+    assertEquals(1, containers.size());
+    ContainerModel sole = modelsById.get(0);
+    assertNotNull(sole);
+    assertEquals(0, sole.getContainerId());
+    assertEquals(3, sole.getTasks().size());
+
+    for (int i = 0; i < 3; i++) {
+      assertTrue(sole.getTasks().containsKey(getTaskName(i)));
+    }
+    for (int i = 0; i < 3; i++) {
+      verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(i).getTaskName(), 0);
+    }
+
+    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
+  }
+
+  /**
+   * The stored assignment covers one task fewer than are now present (two tasks over two
+   * containers). Balancing three tasks into one container should match {@code group()}, write a
+   * mapping for every task, and invoke deleteTaskContainerMappings exactly once.
+   */
+  @Test
+  public void testGroupTaskCountIncrease() {
+    int taskCount = 3;
+    Set<TaskModel> taskModels = generateTaskModels(taskCount);
+    Set<TaskModel> previousTaskModels = generateTaskModels(taskCount - 1);
+    Map<String, Integer> prevTaskToContainerMapping =
+        generateTaskContainerMapping(new GroupByContainerCount(2).group(previousTaskModels));
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(1).balance(taskModels, localityManager);
+
+    Map<Integer, ContainerModel> containersById = new HashMap<>();
+    for (ContainerModel cm : containers) {
+      containersById.put(cm.getContainerId(), cm);
+    }
+
+    assertEquals(1, containers.size());
+    ContainerModel container = containersById.get(0);
+    assertNotNull(container);
+    assertEquals(0, container.getContainerId());
+    assertEquals(taskCount, container.getTasks().size());
+
+    for (int t = 0; t < taskCount; t++) {
+      assertTrue(container.getTasks().containsKey(getTaskName(t)));
+    }
+    for (int t = 0; t < taskCount; t++) {
+      verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(t).getTaskName(), 0);
+    }
+
+    verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection());
+  }
+
+  /**
+   * The stored assignment covers one task more than are now present (four tasks over three
+   * containers). Balancing three tasks into one container should match {@code group()}, write a
+   * mapping for every task, and invoke deleteTaskContainerMappings exactly once.
+   */
+  @Test
+  public void testGroupTaskCountDecrease() {
+    int taskCount = 3;
+    Set<TaskModel> taskModels = generateTaskModels(taskCount);
+    Set<TaskModel> previousTaskModels = generateTaskModels(taskCount + 1);
+    Map<String, Integer> prevTaskToContainerMapping =
+        generateTaskContainerMapping(new GroupByContainerCount(3).group(previousTaskModels));
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(1).balance(taskModels, localityManager);
+
+    Map<Integer, ContainerModel> idToContainer = new HashMap<>();
+    for (ContainerModel each : containers) {
+      idToContainer.put(each.getContainerId(), each);
+    }
+
+    assertEquals(1, containers.size());
+    ContainerModel target = idToContainer.get(0);
+    assertNotNull(target);
+    assertEquals(0, target.getContainerId());
+    assertEquals(taskCount, target.getTasks().size());
+
+    for (int n = 0; n < taskCount; n++) {
+      assertTrue(target.getTasks().containsKey(getTaskName(n)));
+    }
+    for (int n = 0; n < taskCount; n++) {
+      verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(n).getTaskName(), 0);
+    }
+
+    verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection());
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testBalancerNewContainerCountGreaterThanTasks() {
+    Set<TaskModel> taskModels = generateTaskModels(3);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Map<String, Integer> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    new GroupByContainerCount(5).balance(taskModels, localityManager);     // Should throw
+  }
+
+  /**
+   * Balancing an empty task set is expected to fail with IllegalArgumentException.
+   * NOTE(review): the requested container count (5) also exceeds the task count (0) here, so this
+   * case overlaps with testBalancerNewContainerCountGreaterThanTasks.
+   */
+  @Test(expected = IllegalArgumentException.class)
+  public void testBalancerEmptyTasks() {
+    Set<TaskModel> taskModels = generateTaskModels(3);
+    Map<String, Integer> prevTaskToContainerMapping =
+        generateTaskContainerMapping(new GroupByContainerCount(3).group(taskModels));
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    Set<TaskModel> noTasks = new HashSet<TaskModel>();
+    new GroupByContainerCount(5).balance(noTasks, localityManager);     // expected to throw
+  }
+
+  /**
+   * The set returned by balance() is expected to be unmodifiable: removing an element should
+   * throw UnsupportedOperationException.
+   */
+  @Test(expected = UnsupportedOperationException.class)
+  public void testBalancerResultImmutable() {
+    Set<TaskModel> taskModels = generateTaskModels(3);
+    Map<String, Integer> prevTaskToContainerMapping =
+        generateTaskContainerMapping(new GroupByContainerCount(3).group(taskModels));
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+
+    Set<ContainerModel> containers = new GroupByContainerCount(2).balance(taskModels, localityManager);
+    ContainerModel first = containers.iterator().next();
+    containers.remove(first);
+  }
+}