You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by ja...@apache.org on 2018/12/05 18:57:03 UTC

[2/3] samza git commit: SAMZA-1973: Unify the TaskNameGrouper interface for yarn and standalone.

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
index 600b7a1..5fb71f3 100644
--- a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
+++ b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
@@ -19,31 +19,33 @@
 
 package org.apache.samza.coordinator
 
-
 import java.util
 import java.util.concurrent.atomic.AtomicReference
-
 import org.apache.samza.config._
 import org.apache.samza.config.JobConfig.Config2Job
 import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.config.TaskConfig.Config2Task
 import org.apache.samza.config.Config
 import org.apache.samza.container.grouper.stream.SystemStreamPartitionGrouperFactory
-import org.apache.samza.container.grouper.task.BalancingTaskNameGrouper
-import org.apache.samza.container.grouper.task.TaskNameGrouperFactory
+import org.apache.samza.container.grouper.task._
 import org.apache.samza.container.LocalityManager
 import org.apache.samza.container.TaskName
 import org.apache.samza.coordinator.server.HttpServer
 import org.apache.samza.coordinator.server.JobServlet
+import org.apache.samza.job.model.ContainerModel
 import org.apache.samza.job.model.JobModel
 import org.apache.samza.job.model.TaskModel
+import org.apache.samza.metrics.MetricsRegistry
 import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.system._
 import org.apache.samza.util.Logging
 import org.apache.samza.util.Util
 import org.apache.samza.Partition
+import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping
+import org.apache.samza.runtime.LocationId
 
 import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
 
 /**
  * Helper companion object that is responsible for wiring up a JobModelManager
@@ -51,66 +53,145 @@ import scala.collection.JavaConverters._
  */
 object JobModelManager extends Logging {
 
-  val SOURCE = "JobModelManager"
   /**
    * a volatile value to store the current instantiated <code>JobModelManager</code>
    */
-  @volatile var currentJobModelManager: JobModelManager = null
+  @volatile var currentJobModelManager: JobModelManager = _
   val jobModelRef: AtomicReference[JobModel] = new AtomicReference[JobModel]()
 
   /**
-   * Does the following actions for a job.
+   * Currently used only in the ApplicationMaster for yarn deployment model.
+   * Does the following:
    * a) Reads the jobModel from coordinator stream using the job's configuration.
-   * b) Recomputes changelog partition mapping based on jobModel and job's configuration.
+   * b) Recomputes the changelog partition mapping based on jobModel and job's configuration.
    * c) Builds JobModelManager using the jobModel read from coordinator stream.
-   * @param config Config from the coordinator stream.
-   * @param changelogPartitionMapping The changelog partition-to-task mapping.
-   * @return JobModelManager
+   * @param config config from the coordinator stream.
+   * @param changelogPartitionMapping changelog partition-to-task mapping of the samza job.
+   * @param metricsRegistry the registry for reporting metrics.
+   * @return the instantiated {@see JobModelManager}.
    */
-  def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, Integer]) = {
-    val localityManager = new LocalityManager(config, new MetricsRegistryMap())
-
-    // Map the name of each system to the corresponding SystemAdmin
+  def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, Integer], metricsRegistry: MetricsRegistry = new MetricsRegistryMap()): JobModelManager = {
+    val localityManager = new LocalityManager(config, metricsRegistry)
+    val taskAssignmentManager = new TaskAssignmentManager(config, metricsRegistry)
     val systemAdmins = new SystemAdmins(config)
-    val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0)
+    try {
+      systemAdmins.start()
+      val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0)
+      val grouperMetadata: GrouperMetadata = getGrouperMetadata(config, localityManager, taskAssignmentManager)
 
-    val containerCount = new JobConfig(config).getContainerCount
-    val processorList = List.range(0, containerCount).map(c => c.toString)
+      val jobModel: JobModel = readJobModel(config, changelogPartitionMapping, streamMetadataCache, grouperMetadata)
+      jobModelRef.set(new JobModel(jobModel.getConfig, jobModel.getContainers, localityManager))
 
-    systemAdmins.start()
-    val jobModelManager = getJobModelManager(config, changelogPartitionMapping, localityManager, streamMetadataCache, processorList.asJava)
-    systemAdmins.stop()
+      updateTaskAssignments(jobModel, taskAssignmentManager, grouperMetadata)
 
-    jobModelManager
+      val server = new HttpServer
+      server.addServlet("/", new JobServlet(jobModelRef))
+
+      currentJobModelManager = new JobModelManager(jobModel, server, localityManager)
+      currentJobModelManager
+    } finally {
+      taskAssignmentManager.close()
+      systemAdmins.stop()
+      // Not closing localityManager, since {@code ClusterBasedJobCoordinator} uses it to read container locality through {@code JobModel}.
+    }
   }
 
   /**
-   * Build a JobModelManager using a Samza job's configuration.
-   */
-  private def getJobModelManager(config: Config,
-                                changeLogMapping: util.Map[TaskName, Integer],
-                                localityManager: LocalityManager,
-                                streamMetadataCache: StreamMetadataCache,
-                                containerIds: java.util.List[String]) = {
-    val jobModel: JobModel = readJobModel(config, changeLogMapping, localityManager, streamMetadataCache, containerIds)
-    jobModelRef.set(jobModel)
-
-    val server = new HttpServer
-    server.addServlet("/", new JobServlet(jobModelRef))
-    currentJobModelManager = new JobModelManager(jobModel, server, localityManager)
-    currentJobModelManager
+    * Builds the {@see GrouperMetadataImpl} for the samza job.
+    * @param config represents the configurations defined by the user.
+    * @param localityManager provides the processor to host mapping persisted to the metadata store.
+    * @param taskAssignmentManager provides the processor to task assignments persisted to the metadata store.
+    * @return the instantiated {@see GrouperMetadata}.
+    */
+  def getGrouperMetadata(config: Config, localityManager: LocalityManager, taskAssignmentManager: TaskAssignmentManager) = {
+    val processorLocality: util.Map[String, LocationId] = getProcessorLocality(config, localityManager)
+    val taskNameToProcessorId: util.Map[TaskName, String] = new util.HashMap[TaskName, String]()
+    val taskLocality: util.Map[TaskName, LocationId] = new util.HashMap[TaskName, LocationId]()
+
+    // Single pass over the persisted assignment: every task gets a processor entry, but a
+    // locality entry is recorded only when the locality of its processor is known.
+    taskAssignmentManager.readTaskAssignment().asScala.foreach { case (rawTaskName, processorId) =>
+      taskNameToProcessorId.put(new TaskName(rawTaskName), processorId)
+      if (processorLocality.containsKey(processorId)) {
+        taskLocality.put(new TaskName(rawTaskName), processorLocality.get(processorId))
+      }
+    }
+    // The previous task-to-partition assignment is not read here, so an empty map is supplied.
+    new GrouperMetadataImpl(processorLocality, taskLocality, new util.HashMap[TaskName, util.List[SystemStreamPartition]](), taskNameToProcessorId)
+  }
 
   /**
-   * For each input stream specified in config, exactly determine its
-   * partitions, returning a set of SystemStreamPartitions containing them all.
-   */
-  private def getInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache) = {
+    * Retrieves and returns the processor locality of a samza job using provided {@see Config} and {@see LocalityManager}.
+    * @param config provides the configurations defined by the user. Required to connect to the storage layer.
+    * @param localityManager provides the processor to host mapping persisted to the metadata store.
+    * @return the processor locality.
+    */
+  def getProcessorLocality(config: Config, localityManager: LocalityManager) = {
+    val containerToLocationId: util.Map[String, LocationId] = new util.HashMap[String, LocationId]()
+    val existingContainerLocality = localityManager.readContainerLocality()
+
+    // Container ids are 0-based, so iterate over [0, containerCount); an inclusive range would add a phantom container.
+    for (containerId <- 0 until config.getContainerCount) {
+      val localityMapping = existingContainerLocality.get(containerId.toString)
+      // To handle the case when the container count is increased between two different runs of a samza-yarn job,
+      // set the locality of newly added containers to any_host.
+      val locationId: LocationId =
+        if (localityMapping != null && localityMapping.containsKey(SetContainerHostMapping.HOST_KEY)) {
+          new LocationId(localityMapping.get(SetContainerHostMapping.HOST_KEY))
+        } else new LocationId("ANY_HOST")
+      containerToLocationId.put(containerId.toString, locationId)
+    }
+    containerToLocationId
+  }
+
+  /**
+    * This method does the following:
+    * 1. Deletes the existing task assignments if the partition-task grouping has changed from the previous run of the job.
+    * 2. Saves the newly generated task assignments to the storage layer through the {@code taskAssignmentManager}.
+    *
+    * @param jobModel              represents the {@see JobModel} of the samza job.
+    * @param taskAssignmentManager required to persist the processor to task assignments to the storage layer.
+    * @param grouperMetadata       provides the historical metadata of the application.
+    */
+  def updateTaskAssignments(jobModel: JobModel, taskAssignmentManager: TaskAssignmentManager, grouperMetadata: GrouperMetadata): Unit = {
+    // Names of every task present in the newly generated job model.
+    val taskNames: util.Set[String] = new util.HashSet[String]()
+    for (container <- jobModel.getContainers.values().asScala; taskModel <- container.getTasks.values().asScala) {
+      taskNames.add(taskModel.getTaskName.getTaskName)
+    }
+    val taskToContainerId = grouperMetadata.getPreviousTaskToProcessorAssignment
+    if (taskNames.size() != taskToContainerId.size()) {
+      // Interpolate with `format`, like the other log statements in this file, instead of passing slf4j-style varargs.
+      warn("Current task count %d does not match saved task count %d. Stateful jobs may observe misalignment of keys!"
+        format (taskNames.size(), taskToContainerId.size()))
+      // If the tasks changed, then the partition-task grouping is also likely changed and we can't handle that
+      // without a much more complicated mapping. Further, the partition count may have changed, which means
+      // input message keys are likely reshuffled w.r.t. partitions, so the local state may not contain necessary
+      // data associated with the incoming keys. Warn the user and default to grouper
+      // In this scenario the tasks may have been reduced, so we need to delete all the existing messages
+      // Delete the previously saved task names (not the new ones) so mappings of removed tasks do not linger.
+      val previousTaskNames: util.Set[String] = new util.HashSet[String]()
+      for (previousTaskName <- taskToContainerId.keySet().asScala) previousTaskNames.add(previousTaskName.getTaskName)
+      taskAssignmentManager.deleteTaskContainerMappings(previousTaskNames)
+    }
+
+    for (container <- jobModel.getContainers.values().asScala; taskName <- container.getTasks.keySet.asScala) {
+      taskAssignmentManager.writeTaskContainerMapping(taskName.getTaskName, container.getId)
+    }
+  }
+
+  /**
+    * Computes the input system stream partitions of a samza job using the provided {@param config}
+    * and {@param streamMetadataCache}.
+    * @param config the configuration of the job.
+    * @param streamMetadataCache to query the partition metadata of the input streams.
+    * @return the input {@see SystemStreamPartition} of the samza job.
+    */
+  private def getInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache): Set[SystemStreamPartition] = {
     val inputSystemStreams = config.getInputStreams
 
     // Get the set of partitions for each SystemStream from the stream metadata
     streamMetadataCache
-      .getStreamMetadata(inputSystemStreams, true)
+      .getStreamMetadata(inputSystemStreams, partitionsMetadataOnly = true)
       .flatMap {
         case (systemStream, metadata) =>
           metadata
@@ -121,55 +202,69 @@ object JobModelManager extends Logging {
       }.toSet
   }
 
+  /**
+    * Builds the input {@see SystemStreamPartition} based upon the {@param config} defined by the user.
+    * @param config configuration to fetch the metadata of the input streams.
+    * @param streamMetadataCache required to query the partition metadata of the input streams.
+    * @return the input SystemStreamPartitions of the job.
+    */
   private def getMatchedInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache): Set[SystemStreamPartition] = {
     val allSystemStreamPartitions = getInputStreamPartitions(config, streamMetadataCache)
     config.getSSPMatcherClass match {
-      case Some(s) => {
+      case Some(s) =>
         val jfr = config.getSSPMatcherConfigJobFactoryRegex.r
         config.getStreamJobFactoryClass match {
-          case Some(jfr(_*)) => {
-            info("before match: allSystemStreamPartitions.size = %s" format (allSystemStreamPartitions.size))
+          case Some(jfr(_*)) =>
+            info("before match: allSystemStreamPartitions.size = %s" format allSystemStreamPartitions.size)
             val sspMatcher = Util.getObj(s, classOf[SystemStreamPartitionMatcher])
             val matchedPartitions = sspMatcher.filter(allSystemStreamPartitions.asJava, config).asScala.toSet
             // Usually a small set hence ok to log at info level
-            info("after match: matchedPartitions = %s" format (matchedPartitions))
+            info("after match: matchedPartitions = %s" format matchedPartitions)
             matchedPartitions
-          }
           case _ => allSystemStreamPartitions
         }
-      }
       case _ => allSystemStreamPartitions
     }
   }
 
   /**
-   * Gets a SystemStreamPartitionGrouper object from the configuration.
-   */
+    * Finds the {@see SystemStreamPartitionGrouperFactory} from the {@param config}. Instantiates the {@see SystemStreamPartitionGrouper}
+    * object through the factory.
+    * @param config the configuration of the samza job.
+    * @return the instantiated {@see SystemStreamPartitionGrouper}.
+    */
   private def getSystemStreamPartitionGrouper(config: Config) = {
     val factoryString = config.getSystemStreamPartitionGrouperFactory
     val factory = Util.getObj(factoryString, classOf[SystemStreamPartitionGrouperFactory])
     factory.getSystemStreamPartitionGrouper(config)
   }
 
+
   /**
-   * The function reads the latest checkpoint from the underlying coordinator stream and
-   * builds a new JobModel.
-   */
+    * Does the following:
+    * 1. Fetches metadata of the input streams defined in configuration through {@param streamMetadataCache}.
+    * 2. Applies the {@see SystemStreamPartitionGrouper}, {@see TaskNameGrouper} defined in the configuration
+    * to build the {@see JobModel}.
+    * @param config the configuration of the job.
+    * @param changeLogPartitionMapping the task to changelog partition mapping of the job.
+    * @param streamMetadataCache the cache that holds the partition metadata of the input streams.
+    * @param grouperMetadata provides the historical metadata of the application.
+    * @return the built {@see JobModel}.
+    */
   def readJobModel(config: Config,
                    changeLogPartitionMapping: util.Map[TaskName, Integer],
-                   localityManager: LocalityManager,
                    streamMetadataCache: StreamMetadataCache,
-                   containerIds: java.util.List[String]): JobModel = {
+                   grouperMetadata: GrouperMetadata): JobModel = {
     // Do grouping to fetch TaskName to SSP mapping
     val allSystemStreamPartitions = getMatchedInputStreamPartitions(config, streamMetadataCache)
 
     // processor list is required by some of the groupers. So, let's pass them as part of the config.
     // Copy the config and add the processor list to the config copy.
     val configMap = new util.HashMap[String, String](config)
-    configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", containerIds))
+    configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", grouperMetadata.getProcessorLocality.keySet()))
     val grouper = getSystemStreamPartitionGrouper(new MapConfig(configMap))
 
-    val groups = grouper.group(allSystemStreamPartitions.asJava)
+    val groups = grouper.group(allSystemStreamPartitions)
     info("SystemStreamPartitionGrouper %s has grouped the SystemStreamPartitions into %d tasks with the following taskNames: %s" format(grouper, groups.size(), groups.keySet()))
 
     val isHostAffinityEnabled = new ClusterManagerConfig(config).getHostAffinityEnabled
@@ -200,22 +295,18 @@ object JobModelManager extends Logging {
     // SSPTaskNameGrouper for locality, load-balancing, etc.
     val containerGrouperFactory = Util.getObj(config.getTaskNameGrouperFactory, classOf[TaskNameGrouperFactory])
     val containerGrouper = containerGrouperFactory.build(config)
-    val containerModels = {
-      containerGrouper match {
-        case grouper: BalancingTaskNameGrouper if isHostAffinityEnabled => grouper.balance(taskModels.asJava, localityManager)
-        case _ => containerGrouper.group(taskModels.asJava, containerIds)
-      }
-    }
-    val containerMap = containerModels.asScala.map { case (containerModel) => containerModel.getId -> containerModel }.toMap
-
-    if (isHostAffinityEnabled) {
-      new JobModel(config, containerMap.asJava, localityManager)
+    var containerModels: util.Set[ContainerModel] = null
+    if(isHostAffinityEnabled) {
+      containerModels = containerGrouper.group(taskModels, grouperMetadata)
     } else {
-      new JobModel(config, containerMap.asJava)
+      containerModels = containerGrouper.group(taskModels, new util.ArrayList[String](grouperMetadata.getProcessorLocality.keySet()))
     }
+    val containerMap = containerModels.asScala.map(containerModel => containerModel.getId -> containerModel).toMap
+
+    new JobModel(config, containerMap.asJava)
   }
 
-  private def getSystemNames(config: Config) = config.getSystemNames.toSet
+  private def getSystemNames(config: Config) = config.getSystemNames().toSet
 }
 
 /**
@@ -248,7 +339,7 @@ class JobModelManager(
 
   debug("Got job model: %s." format jobModel)
 
-  def start {
+  def start() {
     if (server != null) {
       debug("Starting HTTP server.")
       server.start
@@ -256,7 +347,7 @@ class JobModelManager(
     }
   }
 
-  def stop {
+  def stop() {
     if (server != null) {
       debug("Stopping HTTP server.")
       server.stop

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
index 64f516b..d16c294 100644
--- a/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
+++ b/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
@@ -50,7 +50,7 @@ class ProcessJobFactory extends StreamJobFactory with Logging {
     coordinatorStreamManager.bootstrap
     val changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager)
 
-    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping())
+    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping(), metricsRegistry)
     val jobModel = coordinator.jobModel
 
     val taskPartitionMappings: util.Map[TaskName, Integer] = new util.HashMap[TaskName, Integer]

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
index 5a8d2f8..e4a7838 100644
--- a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
+++ b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
@@ -52,7 +52,7 @@ class ThreadJobFactory extends StreamJobFactory with Logging {
     coordinatorStreamManager.bootstrap
     val changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager)
 
-    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping())
+    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping(), metricsRegistry)
 
     val jobModel = coordinator.jobModel
 

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
index 0c2f2fb..9e6e8d0 100644
--- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
@@ -24,54 +24,34 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
-import org.apache.samza.SamzaException;
-import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
-import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.LocalityManager;
+
+import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
-import org.junit.Before;
+import org.apache.samza.SamzaException;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mockito;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
 
 import static org.apache.samza.container.mock.ContainerMocks.*;
 import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
 
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class})
 public class TestGroupByContainerCount {
-  private TaskAssignmentManager taskAssignmentManager;
-  private LocalityManager localityManager;
-  @Before
-  public void setup() throws Exception {
-    taskAssignmentManager = mock(TaskAssignmentManager.class);
-    localityManager = mock(LocalityManager.class);
-    PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager);
-    Mockito.doNothing().when(taskAssignmentManager).init();
-  }
 
   @Test(expected = IllegalArgumentException.class)
   public void testGroupEmptyTasks() {
-    new GroupByContainerCount(getConfig(1)).group(new HashSet());
+    new GroupByContainerCount(1).group(new HashSet<>());
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testGroupFewerTasksThanContainers() {
     Set<TaskModel> taskModels = new HashSet<>();
     taskModels.add(getTaskModel(1));
-    new GroupByContainerCount(getConfig(2)).group(taskModels);
+    new GroupByContainerCount(2).group(taskModels);
   }
 
   @Test(expected = UnsupportedOperationException.class)
   public void testGrouperResultImmutable() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(3)).group(taskModels);
+    Set<ContainerModel> containers = new GroupByContainerCount(3).group(taskModels);
     containers.remove(containers.iterator().next());
   }
 
@@ -79,7 +59,7 @@ public class TestGroupByContainerCount {
   public void testGroupHappyPath() {
     Set<TaskModel> taskModels = generateTaskModels(5);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).group(taskModels);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -106,7 +86,7 @@ public class TestGroupByContainerCount {
   public void testGroupManyTasks() {
     Set<TaskModel> taskModels = generateTaskModels(21);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).group(taskModels);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -174,11 +154,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerIncrease() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(4)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(4).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -213,22 +193,6 @@ public class TestGroupByContainerCount {
     assertTrue(container2.getTasks().containsKey(getTaskName(6)));
     assertTrue(container3.getTasks().containsKey(getTaskName(5)));
     assertTrue(container3.getTasks().containsKey(getTaskName(7)));
-
-    // Verify task mappings are saved
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1");
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "2");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "2");
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "3");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "3");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -256,11 +220,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerDecrease() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(4)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -290,20 +254,6 @@ public class TestGroupByContainerCount {
     assertTrue(container0.getTasks().containsKey(getTaskName(2)));
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
     assertTrue(container1.getTasks().containsKey(getTaskName(3)));
-
-    // Verify task mappings are saved
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -331,15 +281,15 @@ public class TestGroupByContainerCount {
    *  T8  T7  T3
    */
   @Test
-  public void testBalancerMultipleReblances() throws Exception {
+  public void testBalancerMultipleReblances() {
     // Before
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(4)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
     // First balance
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -370,30 +320,11 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
     assertTrue(container1.getTasks().containsKey(getTaskName(3)));
 
-    // Verify task mappings are saved
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
-
-
     // Second balance
     prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
 
-    TaskAssignmentManager taskAssignmentManager2 = mock(TaskAssignmentManager.class);
-    when(taskAssignmentManager2.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
-    LocalityManager localityManager2 = mock(LocalityManager.class);
-    PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager2);
-
-    containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager2);
+    GrouperMetadataImpl grouperMetadata1 = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    containers = new GroupByContainerCount(3).group(taskModels, grouperMetadata1);
 
     containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -427,21 +358,6 @@ public class TestGroupByContainerCount {
     assertTrue(container2.getTasks().containsKey(getTaskName(6)));
     assertTrue(container2.getTasks().containsKey(getTaskName(2)));
     assertTrue(container2.getTasks().containsKey(getTaskName(3)));
-
-    // Verify task mappings are saved
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0");
-
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1");
-
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(6).getTaskName(), "2");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(2).getTaskName(), "2");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(3).getTaskName(), "2");
-
-    verify(taskAssignmentManager2, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -466,11 +382,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerSame() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -496,9 +412,6 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(3)));
     assertTrue(container1.getTasks().containsKey(getTaskName(5)));
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
-
-    verify(taskAssignmentManager, never()).writeTaskContainerMapping(anyString(), anyString());
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -528,19 +441,19 @@ public class TestGroupByContainerCount {
   public void testBalancerAfterContainerSameCustomAssignment() {
     Set<TaskModel> taskModels = generateTaskModels(9);
 
-    Map<String, String> prevTaskToContainerMapping = new HashMap<>();
-    prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(6).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(7).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(8).getTaskName(), "1");
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Map<TaskName, String> prevTaskToContainerMapping = new HashMap<>();
+    prevTaskToContainerMapping.put(getTaskName(0), "0");
+    prevTaskToContainerMapping.put(getTaskName(1), "0");
+    prevTaskToContainerMapping.put(getTaskName(2), "0");
+    prevTaskToContainerMapping.put(getTaskName(3), "0");
+    prevTaskToContainerMapping.put(getTaskName(4), "0");
+    prevTaskToContainerMapping.put(getTaskName(5), "0");
+    prevTaskToContainerMapping.put(getTaskName(6), "1");
+    prevTaskToContainerMapping.put(getTaskName(7), "1");
+    prevTaskToContainerMapping.put(getTaskName(8), "1");
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -566,9 +479,6 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(6)));
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
     assertTrue(container1.getTasks().containsKey(getTaskName(8)));
-
-    verify(taskAssignmentManager, never()).writeTaskContainerMapping(anyString(), anyString());
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -597,16 +507,16 @@ public class TestGroupByContainerCount {
   public void testBalancerAfterContainerSameCustomAssignmentAndContainerIncrease() {
     Set<TaskModel> taskModels = generateTaskModels(6);
 
-    Map<String, String> prevTaskToContainerMapping = new HashMap<>();
-    prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "1");
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Map<TaskName, String> prevTaskToContainerMapping = new HashMap<>();
+    prevTaskToContainerMapping.put(getTaskName(0), "0");
+    prevTaskToContainerMapping.put(getTaskName(1), "1");
+    prevTaskToContainerMapping.put(getTaskName(2), "1");
+    prevTaskToContainerMapping.put(getTaskName(3), "1");
+    prevTaskToContainerMapping.put(getTaskName(4), "1");
+    prevTaskToContainerMapping.put(getTaskName(5), "1");
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(3).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -633,146 +543,106 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(2)));
     assertTrue(container2.getTasks().containsKey(getTaskName(4)));
     assertTrue(container2.getTasks().containsKey(getTaskName(3)));
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "2");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "2");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "0");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testBalancerOldContainerCountOne() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(1).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    // Verify task mappings are saved
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "2");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testBalancerNewContainerCountOne() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testBalancerEmptyTaskMapping() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(new HashMap<>());
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), new HashMap<>());
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testGroupTaskCountIncrease() {
     int taskCount = 3;
     Set<TaskModel> taskModels = generateTaskModels(taskCount);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(generateTaskModels(taskCount - 1)); // Here's the key step
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(generateTaskModels(taskCount - 1)); // Key step: the previous grouping is built from one fewer task than taskModels
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testGroupTaskCountDecrease() {
     int taskCount = 3;
     Set<TaskModel> taskModels = generateTaskModels(taskCount);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(generateTaskModels(taskCount + 1)); // Here's the key step
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(generateTaskModels(taskCount + 1)); // Key step: the previous grouping is built from one more task than taskModels
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testBalancerNewContainerCountGreaterThanTasks() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    new GroupByContainerCount(getConfig(5)).balance(taskModels, localityManager);     // Should throw
+    new GroupByContainerCount(5).group(taskModels, grouperMetadata);     // Should throw
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testBalancerEmptyTasks() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    new GroupByContainerCount(getConfig(5)).balance(new HashSet<>(), localityManager);     // Should throw
+    new GroupByContainerCount(5).group(new HashSet<>(), grouperMetadata);     // Should throw
   }
 
   @Test(expected = UnsupportedOperationException.class)
   public void testBalancerResultImmutable() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
     containers.remove(containers.iterator().next());
   }
 
@@ -780,32 +650,20 @@ public class TestGroupByContainerCount {
   public void testBalancerThrowsOnNonIntegerContainerIds() {
     Set<TaskModel> taskModels = generateTaskModels(3);
     Set<ContainerModel> prevContainers = new HashSet<>();
-    taskModels.forEach(model -> {
-        prevContainers.add(
-          new ContainerModel(UUID.randomUUID().toString(), Collections.singletonMap(model.getTaskName(), model)));
-      });
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
-
-    new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager); //Should throw
-
+    taskModels.forEach(model -> prevContainers.add(new ContainerModel(UUID.randomUUID().toString(), Collections.singletonMap(model.getTaskName(), model))));
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    new GroupByContainerCount(3).group(taskModels, grouperMetadata); // Should throw
   }
 
   @Test
   public void testBalancerWithNullLocalityManager() {
     Set<TaskModel> taskModels = generateTaskModels(3);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(3)).balance(taskModels, null);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).balance(taskModels, null);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
   }
-
-
-  Config getConfig(int containerCount) {
-    Map<String, String> config = new HashMap<>();
-    config.put(JobConfig.JOB_CONTAINER_COUNT(), String.valueOf(containerCount));
-    return new MapConfig(config);
-  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
index 5bb78e8..12b6b1e 100644
--- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
@@ -20,6 +20,7 @@
 package org.apache.samza.container.grouper.task;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -29,35 +30,24 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
+
+import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.LocalityManager;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
-import org.junit.Before;
+import org.apache.samza.runtime.LocationId;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
-
-import static org.apache.samza.container.mock.ContainerMocks.*;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
 
+import static org.apache.samza.container.mock.ContainerMocks.generateTaskModels;
+import static org.apache.samza.container.mock.ContainerMocks.getTaskName;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({TaskAssignmentManager.class, GroupByContainerIds.class})
 public class TestGroupByContainerIds {
 
-  @Before
-  public void setup() throws Exception {
-    TaskAssignmentManager taskAssignmentManager = mock(TaskAssignmentManager.class);
-    LocalityManager localityManager = mock(LocalityManager.class);
-    PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager);
-  }
-
   private Config buildConfigForContainerCount(int count) {
     Map<String, String> map = new HashMap<>();
     map.put("job.container.count", String.valueOf(count));
@@ -67,6 +57,7 @@ public class TestGroupByContainerIds {
   private TaskNameGrouper buildSimpleGrouper() {
     return buildSimpleGrouper(1);
   }
+
   private TaskNameGrouper buildSimpleGrouper(int containerCount) {
     return new GroupByContainerIdsFactory().build(buildConfigForContainerCount(containerCount));
   }
@@ -114,7 +105,8 @@ public class TestGroupByContainerIds {
   public void testGroupWithNullContainerIds() {
     Set<TaskModel> taskModels = generateTaskModels(5);
 
-    Set<ContainerModel> containers = buildSimpleGrouper(2).group(taskModels, null);
+    List<String> containerIds = null;
+    Set<ContainerModel> containers = buildSimpleGrouper(2).group(taskModels, containerIds);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -251,4 +243,264 @@ public class TestGroupByContainerIds {
     assertEquals(1, actualContainerModels.size());
     assertEquals(ImmutableSet.of(expectedContainerModel), actualContainerModels);
   }
+
+  @Test
+  public void testShouldUseTaskLocalityWhenGeneratingContainerModels() { // three processors and three tasks; every task's recorded location equals exactly one processor's location
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3); // grouper built from a config with job.container.count=3
+
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+    String testProcessorId3 = "testProcessorId3";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1"); // NOTE(review): capital 'K' differs from the sibling names - looks like a typo; confirm nothing depends on the exact name
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0)); // task models carry no system-stream partitions, only a distinct Partition id
+    TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+    TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, // processor id -> location of that processor
+                                                                testProcessorId2, testLocationId2,
+                                                                testProcessorId3, testLocationId3);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1, // task name -> location recorded for that task
+                                                             testTaskName2, testLocationId2,
+                                                             testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); // the two remaining constructor arguments are passed as empty maps
+
+    Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+    Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)), // expectation: each task sits on the processor that shares its location
+                                                                  new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+                                                                  new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+    Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata); // exercises the group(...) overload that takes grouper metadata
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+  }
+
+  @Test
+  public void testGenerateContainerModelForSingleContainer() { // one processor, three tasks: all tasks are expected in the single container
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(1); // grouper built from a config with job.container.count=1
+
+    String testProcessorId1 = "testProcessorId1";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+    TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+    TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1); // only one processor is available
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1, // tasks 2 and 3 are recorded at locations that no available processor has
+                                                             testTaskName2, testLocationId2,
+                                                             testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); // the two remaining constructor arguments are passed as empty maps
+
+    Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+    Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1, // expectation: a single container holding all three tasks
+                                                                                                                       testTaskName2, testTaskModel2,
+                                                                                                                       testTaskName3, testTaskModel3)));
+
+    Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+  }
+
+  @Test
+  public void testShouldGenerateCorrectContainerModelWhenTaskLocalityIsEmpty() { // three processors, three tasks, but a location is recorded for only one task
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3); // grouper built from a config with job.container.count=3
+
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+    String testProcessorId3 = "testProcessorId3";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+    TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+    TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, // processor id -> location of that processor
+                                                                testProcessorId2, testLocationId2,
+                                                                testProcessorId3, testLocationId3);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1); // NOTE(review): partially populated (task 1 only), not empty as the method name suggests - confirm intent
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); // the two remaining constructor arguments are passed as empty maps
+
+    Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+    Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)), // expectation: still one task per processor, with task 1 on the processor sharing its location
+                                                                  new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+                                                                  new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+    Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testShouldFailWhenProcessorLocalityIsEmpty() { // passes only if an IllegalArgumentException escapes the method body
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3); // grouper built from a config with job.container.count=3
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), new HashMap<>()); // all four metadata maps are empty, including processor locality
+
+    taskNameGrouper.group(new HashSet<>(), grouperMetadata); // the task-model set is empty as well; presumably this is the call that throws - confirm against the grouper
+  }
+
+  @Test
+  public void testShouldGenerateIdenticalTaskDistributionWhenNoChangeInProcessorGroup() { // groups twice with identical inputs and expects the identical result both times
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3); // grouper built from a config with job.container.count=3
+
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+    String testProcessorId3 = "testProcessorId3";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+    TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+    TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, // processor id -> location of that processor
+            testProcessorId2, testLocationId2,
+            testProcessorId3, testLocationId3);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1, // task name -> location recorded for that task
+            testTaskName2, testLocationId2,
+            testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); // the two remaining constructor arguments are passed as empty maps
+
+    Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+    Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)), // expectation: each task on the processor that shares its location
+            new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+            new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+    Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata); // first grouping round
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+
+    actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata); // second round: same grouper instance, same task models, same metadata
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+  }
+
+  @Test
+  public void testShouldMinimizeTaskShuffleWhenAvailableProcessorInGroupChanges() { // groups with three processors, then again with processor 3 removed
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3); // grouper built from a config with job.container.count=3
+
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+    String testProcessorId3 = "testProcessorId3";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+    TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+    TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, // round 1: all three processors are available
+            testProcessorId2, testLocationId2,
+            testProcessorId3, testLocationId3);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1, // task name -> location recorded for that task; reused unchanged in round 2
+            testTaskName2, testLocationId2,
+            testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); // the two remaining constructor arguments are passed as empty maps
+
+    Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+    Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)), // round 1 expectation: each task on the processor that shares its location
+            new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+            new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+    Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+
+    processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, // round 2: processor 3 is no longer part of the group
+                                        testProcessorId2, testLocationId2);
+
+    grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); // fresh metadata with the reduced processor set
+
+    actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1, testTaskName3, testTaskModel3)), // round 2 expectation: tasks 1 and 2 stay put, only task 3 moves (to processor 1)
+                                              new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)));
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+  }
+
+  @Test
+  public void testMoreTasksThanProcessors() { // NOTE(review): two processors vs. generateTaskModels(1) - presumably a single task, i.e. fewer tasks than processors despite the name; confirm
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, // two processors are available
+        testProcessorId2, testLocationId2);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1, // NOTE(review): keyed by the hand-made task names above; whether they match the generated task models' names is not visible here - confirm
+        testTaskName2, testLocationId2,
+        testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); // the two remaining constructor arguments are passed as empty maps
+
+
+    Set<TaskModel> taskModels = generateTaskModels(1); // task models come from the ContainerMocks helper rather than being built by hand
+    List<String> containerIds = ImmutableList.of(testProcessorId1, testProcessorId2); // NOTE(review): never referenced again in this method
+
+    Map<TaskName, TaskModel> expectedTasks = taskModels.stream() // index the generated task models by task name
+        .collect(Collectors.toMap(TaskModel::getTaskName, x -> x));
+    ContainerModel expectedContainerModel = new ContainerModel(testProcessorId1, expectedTasks); // expectation: every generated task lands on processor 1
+
+    Set<ContainerModel> actualContainerModels = buildSimpleGrouper().group(taskModels, grouperMetadata); // the no-arg buildSimpleGrouper() uses a container count of 1
+
+    assertEquals(1, actualContainerModels.size()); // processor 2 is expected to get no container model
+    assertEquals(ImmutableSet.of(expectedContainerModel), actualContainerModels);
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
index fcdbf08..60164b2 100644
--- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
@@ -68,7 +68,6 @@ public class TestTaskAssignmentManager {
   @Test
   public void testTaskAssignmentManager() {
     TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", "1", "Task2", "2", "Task3", "0", "Task4", "1");
 
@@ -86,7 +85,6 @@ public class TestTaskAssignmentManager {
   @Test
   public void testDeleteMappings() {
     TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", "1");
 
@@ -108,7 +106,6 @@ public class TestTaskAssignmentManager {
   @Test
   public void testTaskAssignmentManagerEmptyCoordinatorStream() {
     TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = new HashMap<>();
     Map<String, String> localMap = taskAssignmentManager.readTaskAssignment();

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
index ca9def2..be240b1 100644
--- a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
+++ b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
@@ -117,11 +117,11 @@ public class ContainerMocks {
     return values;
   }
 
-  public static Map<String, String> generateTaskContainerMapping(Set<ContainerModel> containers) {
-    Map<String, String> taskMapping = new HashMap<>();
+  public static Map<TaskName, String> generateTaskContainerMapping(Set<ContainerModel> containers) {
+    Map<TaskName, String> taskMapping = new HashMap<>();
     for (ContainerModel container : containers) {
       for (TaskName taskName : container.getTasks().keySet()) {
-        taskMapping.put(taskName.getTaskName(), container.getId());
+        taskMapping.put(taskName, container.getId());
       }
     }
     return taskMapping;

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
index ea25ec1..02aaaa7 100644
--- a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
+++ b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
@@ -19,15 +19,15 @@
 
 package org.apache.samza.coordinator;
 
-import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import org.apache.samza.config.Config;
 import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.coordinator.server.HttpServer;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
+import org.apache.samza.runtime.LocationId;
 import org.apache.samza.system.StreamMetadataCache;
 
 /**
@@ -49,15 +49,8 @@ public class JobModelManagerTestUtil {
     return new JobModelManager(jobModel, server, null);
   }
 
-  public static JobModelManager getJobModelManagerUsingReadModel(Config config, int containerCount, StreamMetadataCache streamMetadataCache,
-    LocalityManager locManager, HttpServer server) {
-    List<String> containerIds = new ArrayList<>();
-    for (int i = 0; i < containerCount; i++) {
-      containerIds.add(String.valueOf(i));
-    }
-    JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), locManager, streamMetadataCache, containerIds);
-    return new JobModelManager(jobModel, server, null);
+  public static JobModelManager getJobModelManagerUsingReadModel(Config config, StreamMetadataCache streamMetadataCache, HttpServer server, LocalityManager localityManager, Map<String, LocationId> processorLocality) { // test helper: reads a JobModel via JobModelManager.readJobModel and wraps it in a JobModelManager; callers now pass processor locality instead of a container count
+    JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), streamMetadataCache, new GrouperMetadataImpl(processorLocality, new HashMap<>(), new HashMap<>(), new HashMap<>())); // processorLocality is the only grouper metadata supplied; the other three maps are empty
+    return new JobModelManager(new JobModel(jobModel.getConfig(), jobModel.getContainers(), localityManager), server, localityManager); // re-wraps the read model's config and containers together with localityManager, which is also handed to the manager itself
   }
-
-
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
index 1dbf132..6048466 100644
--- a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
+++ b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
@@ -19,20 +19,30 @@
 
 package org.apache.samza.coordinator;
 
+import com.google.common.collect.ImmutableMap;
+import java.util.HashSet;
+import java.util.Set;
 import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.TaskName;
 import org.apache.samza.container.grouper.task.GroupByContainerCount;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.container.grouper.task.TaskAssignmentManager;
 import org.apache.samza.coordinator.server.HttpServer;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.JobModel;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.runtime.LocationId;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
 import org.apache.samza.testUtils.MockHttpServer;
 import org.eclipse.jetty.servlet.DefaultServlet;
 import org.eclipse.jetty.servlet.ServletHolder;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -40,14 +50,15 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Collections;
 
+import static org.apache.samza.coordinator.JobModelManager.*;
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.anyBoolean;
 import static org.mockito.Matchers.argThat;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.*;
 
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentMatcher;
+import org.mockito.Mockito;
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
@@ -60,7 +71,6 @@ import scala.collection.JavaConversions;
 @PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class})
 public class TestJobModelManager {
   private final TaskAssignmentManager mockTaskManager = mock(TaskAssignmentManager.class);
-  private final LocalityManager mockLocalityManager = mock(LocalityManager.class);
   private final Map<String, Map<String, String>> localityMappings = new HashMap<>();
   private final HttpServer server = new MockHttpServer("/", 7777, null, new ServletHolder(DefaultServlet.class));
   private final SystemStream inputStream = new SystemStream("test-system", "test-stream");
@@ -75,7 +85,6 @@ public class TestJobModelManager {
 
   @Before
   public void setup() throws Exception {
-    when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings);
     when(mockStreamMetadataCache.getStreamMetadata(argThat(new ArgumentMatcher<scala.collection.immutable.Set<SystemStream>>() {
       @Override
       public boolean matches(Object argument) {
@@ -105,11 +114,15 @@ public class TestJobModelManager {
         put("job.host-affinity.enabled", "true");
       }
     });
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
 
-    this.localityMappings.put("0", new HashMap<String, String>() { {
+    localityMappings.put("0", new HashMap<String, String>() { {
         put(SetContainerHostMapping.HOST_KEY, "abc-affinity");
       } });
-    this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, mockStreamMetadataCache, mockLocalityManager, server);
+    when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings);
+
+    Map<String, LocationId> containerLocality = ImmutableMap.of("0", new LocationId("abc-affinity"));
+    this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, mockStreamMetadataCache, server, mockLocalityManager, containerLocality);
 
     assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new HashMap<String, String>() { { this.put("0", "abc-affinity"); } });
   }
@@ -132,11 +145,96 @@ public class TestJobModelManager {
       }
     });
 
-    this.localityMappings.put("0", new HashMap<String, String>() { {
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
+
+    localityMappings.put("0", new HashMap<String, String>() { {
         put(SetContainerHostMapping.HOST_KEY, "abc-affinity");
       } });
-    this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, mockStreamMetadataCache, mockLocalityManager, server);
+    when(mockLocalityManager.readContainerLocality()).thenReturn(new HashMap<>());
+
+    Map<String, LocationId> containerLocality = ImmutableMap.of("0", new LocationId("abc-affinity"));
+
+    this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, mockStreamMetadataCache, server, mockLocalityManager, containerLocality);
 
     assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new HashMap<String, String>() { { this.put("0", null); } });
   }
+
+  @Test
+  public void testGetGrouperMetadata() { // checks the processor and task locality that JobModelManager.getGrouperMetadata builds from the two mocked managers
+    // Mocking setup.
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
+    TaskAssignmentManager mockTaskAssignmentManager = Mockito.mock(TaskAssignmentManager.class);
+
+    Map<String, Map<String, String>> localityMappings = new HashMap<>(); // local variable; shadows the test class field of the same name
+    localityMappings.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity")); // container "0" has host "abc-affinity" recorded
+
+    Map<String, String> taskAssignment = ImmutableMap.of("task-0", "0"); // task name -> container id
+
+    // Mock the container locality assignment.
+    when(mockLocalityManager.readContainerLocality()).thenReturn(localityMappings);
+
+    // Mock the container to task assignment.
+    when(mockTaskAssignmentManager.readTaskAssignment()).thenReturn(taskAssignment);
+
+    GrouperMetadataImpl grouperMetadata = JobModelManager.getGrouperMetadata(new MapConfig(), mockLocalityManager, mockTaskAssignmentManager); // the config is empty
+
+    Mockito.verify(mockLocalityManager).readContainerLocality(); // each manager must be read exactly once (Mockito's default verification mode)
+    Mockito.verify(mockTaskAssignmentManager).readTaskAssignment();
+
+    Assert.assertEquals(ImmutableMap.of("0", new LocationId("abc-affinity"), "1", new LocationId("ANY_HOST")), grouperMetadata.getProcessorLocality()); // NOTE(review): an entry for "1" is expected although only "0" has a stored locality; where the extra id comes from is not visible here - confirm
+    Assert.assertEquals(ImmutableMap.of(new TaskName("task-0"), new LocationId("abc-affinity")), grouperMetadata.getTaskLocality()); // task-0 is expected at the location of its assigned container "0"
+  }
+
+  @Test
+  public void testGetProcessorLocality() { // checks the processor id -> LocationId map that JobModelManager.getProcessorLocality builds from a mocked LocalityManager
+    // Mock the dependencies.
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
+
+    Map<String, Map<String, String>> localityMappings = new HashMap<>(); // local variable; shadows the test class field of the same name
+    localityMappings.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity")); // container "0" has host "abc-affinity" recorded
+
+    // Mock the container locality assignment.
+    when(mockLocalityManager.readContainerLocality()).thenReturn(localityMappings);
+
+    Map<String, LocationId> processorLocality = JobModelManager.getProcessorLocality(new MapConfig(), mockLocalityManager); // the config is empty
+
+    Mockito.verify(mockLocalityManager).readContainerLocality(); // the locality must be read exactly once (Mockito's default verification mode)
+    Assert.assertEquals(ImmutableMap.of("0", new LocationId("abc-affinity"), "1", new LocationId("ANY_HOST")), processorLocality); // NOTE(review): an entry for "1" is expected although only "0" has a stored locality; where the extra id comes from is not visible here - confirm
+  }
+
+  @Test
+  public void testUpdateTaskAssignments() { // checks which delete/write calls JobModelManager.updateTaskAssignments issues on the TaskAssignmentManager
+    // Mocking setup.
+    JobModel mockJobModel = Mockito.mock(JobModel.class);
+    GrouperMetadataImpl mockGrouperMetadata = Mockito.mock(GrouperMetadataImpl.class);
+    TaskAssignmentManager mockTaskAssignmentManager = Mockito.mock(TaskAssignmentManager.class);
+
+    Map<TaskName, TaskModel> taskModelMap = new HashMap<>(); // four tasks, all placed in a single container below
+    taskModelMap.put(new TaskName("task-1"), new TaskModel(new TaskName("task-1"), new HashSet<>(), new Partition(0)));
+    taskModelMap.put(new TaskName("task-2"), new TaskModel(new TaskName("task-2"), new HashSet<>(), new Partition(1)));
+    taskModelMap.put(new TaskName("task-3"), new TaskModel(new TaskName("task-3"), new HashSet<>(), new Partition(2)));
+    taskModelMap.put(new TaskName("task-4"), new TaskModel(new TaskName("task-4"), new HashSet<>(), new Partition(3)));
+    ContainerModel containerModel = new ContainerModel("test-container-id", taskModelMap);
+    Map<String, ContainerModel> containerMapping = ImmutableMap.of("test-container-id", containerModel);
+
+    when(mockJobModel.getContainers()).thenReturn(containerMapping);
+    when(mockGrouperMetadata.getPreviousTaskToProcessorAssignment()).thenReturn(new HashMap<>()); // no previous task-to-processor assignment is reported
+    Mockito.doNothing().when(mockTaskAssignmentManager).writeTaskContainerMapping(Mockito.any(), Mockito.any());
+
+    JobModelManager.updateTaskAssignments(mockJobModel, mockTaskAssignmentManager, mockGrouperMetadata); // call under test
+
+    Set<String> taskNames = new HashSet<String>(); // names of all four tasks, as plain strings
+    taskNames.add("task-4");
+    taskNames.add("task-2");
+    taskNames.add("task-3");
+    taskNames.add("task-1");
+
+    // Verifications
+    Mockito.verify(mockJobModel, atLeast(1)).getContainers();
+    Mockito.verify(mockTaskAssignmentManager).deleteTaskContainerMappings((Iterable<String>) taskNames); // one delete call carrying all four names; the cast presumably selects an Iterable overload - confirm
+    Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-1", "test-container-id"); // each task is then written against the single container id
+    Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-2", "test-container-id");
+    Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-3", "test-container-id");
+    Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-4", "test-container-id");
+  }
 }