You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by ja...@apache.org on 2018/12/05 18:57:02 UTC

[1/3] samza git commit: SAMZA-1973: Unify the TaskNameGrouper interface for yarn and standalone.

Repository: samza
Updated Branches:
  refs/heads/master c7e5dcba4 -> 5ea72584f


http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/zk/TestZkUtils.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/zk/TestZkUtils.java b/samza-core/src/test/java/org/apache/samza/zk/TestZkUtils.java
index d0008b1..29e861b 100644
--- a/samza-core/src/test/java/org/apache/samza/zk/TestZkUtils.java
+++ b/samza-core/src/test/java/org/apache/samza/zk/TestZkUtils.java
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.zk;
 
+import com.google.common.collect.ImmutableMap;
 import java.lang.reflect.Field;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -37,15 +38,15 @@ import org.I0Itec.zkclient.exception.ZkNodeExistsException;
 import org.apache.commons.lang3.reflect.FieldUtils;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.MapConfig;
+import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
+import org.apache.samza.runtime.LocationId;
 import org.apache.samza.testUtils.EmbeddedZookeeper;
 import org.apache.samza.util.NoOpMetricsRegistry;
 import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.Before;
-import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -67,15 +68,11 @@ public class TestZkUtils {
   @Rule
   public Timeout testTimeOutInMillis = new Timeout(120000);
 
-  @BeforeClass
-  public static void setup() throws InterruptedException {
-    zkServer = new EmbeddedZookeeper();
-    zkServer.setup();
-  }
-
   @Before
   public void testSetup() {
     try {
+      zkServer = new EmbeddedZookeeper();
+      zkServer.setup();
       zkClient = new ZkClient(
           new ZkConnection("127.0.0.1:" + zkServer.getPort(), SESSION_TIMEOUT_MS),
           CONNECTION_TIMEOUT_MS);
@@ -89,14 +86,17 @@ public class TestZkUtils {
     }
 
     zkUtils = getZkUtils();
-
     zkUtils.connect();
   }
 
   @After
   public void testTeardown() {
     if (zkClient != null) {
-      zkUtils.close();
+      try {
+        zkUtils.close();
+      } finally {
+        zkServer.teardown();
+      }
     }
   }
 
@@ -105,11 +105,6 @@ public class TestZkUtils {
                        SESSION_TIMEOUT_MS, new NoOpMetricsRegistry());
   }
 
-  @AfterClass
-  public static void teardown() {
-    zkServer.teardown();
-  }
-
   @Test
   public void testRegisterProcessorId() {
     String assignedPath = zkUtils.registerProcessorAndGetId(new ProcessorData("host", "1"));
@@ -135,6 +130,54 @@ public class TestZkUtils {
 
 
   @Test
+  public void testReadAfterWriteTaskLocality() {
+    zkUtils.writeTaskLocality(new TaskName("task-1"), new LocationId("LocationId-1"));
+    zkUtils.writeTaskLocality(new TaskName("task-2"), new LocationId("LocationId-2"));
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(new TaskName("task-1"), new LocationId("LocationId-1"),
+                                                             new TaskName("task-2"), new LocationId("LocationId-2"));
+
+    Assert.assertEquals(taskLocality, zkUtils.readTaskLocality());
+  }
+
+  @Test
+  public void testReadWhenTaskLocalityDoesNotExist() {
+    Map<TaskName, LocationId> taskLocality = zkUtils.readTaskLocality();
+
+    Assert.assertEquals(0, taskLocality.size());
+  }
+
+  @Test
+  public void testWriteTaskLocalityShouldUpdateTheExistingValue() {
+    zkUtils.writeTaskLocality(new TaskName("task-1"), new LocationId("LocationId-1"));
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(new TaskName("task-1"), new LocationId("LocationId-1"));
+    Assert.assertEquals(taskLocality, zkUtils.readTaskLocality());
+
+    zkUtils.writeTaskLocality(new TaskName("task-1"), new LocationId("LocationId-2"));
+
+    taskLocality = ImmutableMap.of(new TaskName("task-1"), new LocationId("LocationId-2"));
+    Assert.assertEquals(taskLocality, zkUtils.readTaskLocality());
+  }
+
+  @Test
+  public void testReadTaskLocalityShouldReturnAllTheExistingLocalityValue() {
+    zkUtils.writeTaskLocality(new TaskName("task-1"), new LocationId("LocationId-1"));
+    zkUtils.writeTaskLocality(new TaskName("task-2"), new LocationId("LocationId-2"));
+    zkUtils.writeTaskLocality(new TaskName("task-3"), new LocationId("LocationId-3"));
+    zkUtils.writeTaskLocality(new TaskName("task-4"), new LocationId("LocationId-4"));
+    zkUtils.writeTaskLocality(new TaskName("task-5"), new LocationId("LocationId-5"));
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(new TaskName("task-1"), new LocationId("LocationId-1"),
+                                                             new TaskName("task-2"), new LocationId("LocationId-2"),
+                                                             new TaskName("task-3"), new LocationId("LocationId-3"),
+                                                             new TaskName("task-4"), new LocationId("LocationId-4"),
+                                                             new TaskName("task-5"), new LocationId("LocationId-5"));
+
+    Assert.assertEquals(taskLocality, zkUtils.readTaskLocality());
+  }
+
+  @Test
   public void testGetAllProcessorNodesShouldReturnEmptyForNonExistingZookeeperNodes() {
     List<ZkUtils.ProcessorNode> processorsIDs = zkUtils.getAllProcessorNodes();
 

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
index 760e358..49a4e84 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
@@ -22,25 +22,29 @@ package org.apache.samza.container
 import java.util
 import java.util.concurrent.atomic.AtomicReference
 
-import org.apache.samza.config.{Config, MapConfig}
-import org.apache.samza.context.{ApplicationContainerContext, ContainerContext}
+import org.apache.samza.config.{ClusterManagerConfig, Config, MapConfig}
+import org.apache.samza.context.{ApplicationContainerContext, ContainerContext, JobContext}
 import org.apache.samza.coordinator.JobModelManager
 import org.apache.samza.coordinator.server.{HttpServer, JobServlet}
 import org.apache.samza.job.model.{ContainerModel, JobModel, TaskModel}
-import org.apache.samza.metrics.{Gauge, Timer}
+import org.apache.samza.metrics.{Gauge, MetricsReporter, Timer}
 import org.apache.samza.storage.{ContainerStorageManager, TaskStorageManager}
 import org.apache.samza.system._
+import org.apache.samza.task.{StreamTaskFactory, TaskFactory}
 import org.apache.samza.{Partition, SamzaContainerStatus}
 import org.junit.Assert._
 import org.junit.{Before, Test}
 import org.mockito.Matchers.{any, notNull}
 import org.mockito.Mockito._
-import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.mockito.{ArgumentCaptor, Mock, Mockito, MockitoAnnotations}
 import org.scalatest.junit.AssertionsForJUnit
 import org.scalatest.mockito.MockitoSugar
 
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
   private val TASK_NAME = new TaskName("taskName")
@@ -258,6 +262,32 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
     assertEquals(Set(), SamzaContainer.getChangelogSSPsForContainer(containerModel, Map()))
   }
 
+  @Test
+  def testStoreContainerLocality():Unit = {
+    val localityManager: LocalityManager = Mockito.mock[LocalityManager](classOf[LocalityManager])
+    val containerContext: ContainerContext = Mockito.mock[ContainerContext](classOf[ContainerContext])
+    val containerModel: ContainerModel = Mockito.mock[ContainerModel](classOf[ContainerModel])
+    val testContainerId = "1"
+    Mockito.when(containerModel.getId).thenReturn(testContainerId)
+    Mockito.when(containerContext.getContainerModel).thenReturn(containerModel)
+
+    val samzaContainer: SamzaContainer = new SamzaContainer(
+      new MapConfig(Map(ClusterManagerConfig.JOB_HOST_AFFINITY_ENABLED -> "true")),
+      Map(TASK_NAME -> this.taskInstance),
+      this.runLoop,
+      this.systemAdmins,
+      this.consumerMultiplexer,
+      this.producerMultiplexer,
+      metrics,
+      containerContext = containerContext,
+      applicationContainerContextOption = null,
+      localityManager = localityManager,
+      containerStorageManager = Mockito.mock(classOf[ContainerStorageManager]))
+
+    samzaContainer.storeContainerLocality
+    Mockito.verify(localityManager).writeContainerToHostMapping(any(), any())
+  }
+
   private def setupSamzaContainer(applicationContainerContext: Option[ApplicationContainerContext]) {
     this.samzaContainer = new SamzaContainer(
       this.config,

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
----------------------------------------------------------------------
diff --git a/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java b/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
index c80ce1b..cada93d 100644
--- a/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
+++ b/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
@@ -113,7 +113,7 @@ public class TestRunner {
         new File(System.getProperty("java.io.tmpdir"), this.inMemoryScope + "-logged").getAbsolutePath());
     addConfig(JobConfig.JOB_DEFAULT_SYSTEM(), JOB_DEFAULT_SYSTEM);
     // Disabling host affinity since it requires reading locality information from a Kafka coordinator stream
-    addConfig(ClusterManagerConfig.CLUSTER_MANAGER_HOST_AFFINITY_ENABLED, Boolean.FALSE.toString());
+    addConfig(ClusterManagerConfig.JOB_HOST_AFFINITY_ENABLED, Boolean.FALSE.toString());
     addConfig(InMemorySystemConfig.INMEMORY_SCOPE, inMemoryScope);
     addConfig(new InMemorySystemDescriptor(JOB_DEFAULT_SYSTEM).withInMemoryScope(inMemoryScope).toConfig());
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-test/src/test/java/org/apache/samza/processor/TestZkStreamProcessorBase.java
----------------------------------------------------------------------
diff --git a/samza-test/src/test/java/org/apache/samza/processor/TestZkStreamProcessorBase.java b/samza-test/src/test/java/org/apache/samza/processor/TestZkStreamProcessorBase.java
index 585af0f..7bd99bb 100644
--- a/samza-test/src/test/java/org/apache/samza/processor/TestZkStreamProcessorBase.java
+++ b/samza-test/src/test/java/org/apache/samza/processor/TestZkStreamProcessorBase.java
@@ -67,7 +67,6 @@ import org.junit.Before;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-
 public class TestZkStreamProcessorBase extends StandaloneIntegrationTestHarness {
   private static final String TASK_SHUTDOWN_MS = "2000";
   private static final String JOB_DEBOUNCE_TIME_MS = "2000";
@@ -131,7 +130,6 @@ public class TestZkStreamProcessorBase extends StandaloneIntegrationTestHarness
   protected StreamProcessor createStreamProcessor(final String pId, Map<String, String> map, final CountDownLatch waitStart,
       final CountDownLatch waitStop) {
     map.put(ApplicationConfig.PROCESSOR_ID, pId);
-
     Config config = new MapConfig(map);
     String jobCoordinatorFactoryClassName = new JobCoordinatorConfig(config).getJobCoordinatorFactoryClassName();
     JobCoordinator jobCoordinator = Util.getObj(jobCoordinatorFactoryClassName, JobCoordinatorFactory.class).getJobCoordinator(config);

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java
----------------------------------------------------------------------
diff --git a/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java b/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java
index 78dad0d..6faf80b 100644
--- a/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java
+++ b/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java
@@ -42,19 +42,20 @@ import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
+import org.apache.samza.config.ClusterManagerConfig;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.JobCoordinatorConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.TaskConfig;
 import org.apache.samza.config.TaskConfigJava;
 import org.apache.samza.config.ZkConfig;
+import org.apache.samza.SamzaException;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.job.ApplicationStatus;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.runtime.ApplicationRunners;
-import org.apache.samza.SamzaException;
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.test.StandaloneIntegrationTestHarness;
 import org.apache.samza.test.StandaloneTestUtils;
@@ -209,6 +210,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
         .put(TaskConfig.DROP_PRODUCER_ERRORS(), "true")
         .put(JobConfig.JOB_DEBOUNCE_TIME_MS(), JOB_DEBOUNCE_TIME_MS)
         .put(JobConfig.MONITOR_PARTITION_CHANGE_FREQUENCY_MS(), "1000")
+        .put(ClusterManagerConfig.HOST_AFFINITY_ENABLED, "true")
         .build();
     Map<String, String> applicationConfig = Maps.newHashMap(samzaContainerConfig);
 
@@ -234,7 +236,8 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
 
     // Configuration, verification variables
     MapConfig testConfig = new MapConfig(ImmutableMap.of(JobConfig.SSP_GROUPER_FACTORY(),
-        "org.apache.samza.container.grouper.stream.GroupBySystemStreamPartitionFactory", JobConfig.JOB_DEBOUNCE_TIME_MS(), "10"));
+        "org.apache.samza.container.grouper.stream.GroupBySystemStreamPartitionFactory", JobConfig.JOB_DEBOUNCE_TIME_MS(), "10",
+            ClusterManagerConfig.JOB_HOST_AFFINITY_ENABLED, "true"));
     // Declared as final array to update it from streamApplication callback(Variable should be declared final to access in lambda block).
     final JobModel[] previousJobModel = new JobModel[1];
     final String[] previousJobModelVersion = new String[1];
@@ -697,9 +700,6 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
 
     // Validate that the input partition count is 100 in the new JobModel.
     Assert.assertEquals(100, ssps.size());
-    appRunner1.kill();
-    appRunner1.waitForFinish();
-    assertEquals(ApplicationStatus.SuccessfulFinish, appRunner1.status());
   }
 
   private static Set<SystemStreamPartition> getSystemStreamPartitions(JobModel jobModel) {

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-yarn/src/main/java/org/apache/samza/validation/YarnJobValidationTool.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/java/org/apache/samza/validation/YarnJobValidationTool.java b/samza-yarn/src/main/java/org/apache/samza/validation/YarnJobValidationTool.java
index 4adb93a..a32ea81 100644
--- a/samza-yarn/src/main/java/org/apache/samza/validation/YarnJobValidationTool.java
+++ b/samza-yarn/src/main/java/org/apache/samza/validation/YarnJobValidationTool.java
@@ -40,6 +40,7 @@ import org.apache.samza.coordinator.stream.CoordinatorStreamManager;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
 import org.apache.samza.job.yarn.ClientHelper;
 import org.apache.samza.metrics.JmxMetricsAccessor;
+import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.metrics.MetricsValidator;
 import org.apache.samza.storage.ChangelogStreamManager;
@@ -152,12 +153,13 @@ public class YarnJobValidationTool {
   }
 
   public void validateJmxMetrics() throws Exception {
-    CoordinatorStreamManager coordinatorStreamManager = new CoordinatorStreamManager(config, new MetricsRegistryMap());
+    MetricsRegistry metricsRegistry = new MetricsRegistryMap();
+    CoordinatorStreamManager coordinatorStreamManager = new CoordinatorStreamManager(config, metricsRegistry);
     coordinatorStreamManager.register(getClass().getSimpleName());
     coordinatorStreamManager.start();
     coordinatorStreamManager.bootstrap();
     ChangelogStreamManager changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager);
-    JobModelManager jobModelManager = JobModelManager.apply(coordinatorStreamManager.getConfig(), changelogStreamManager.readPartitionMapping());
+    JobModelManager jobModelManager = JobModelManager.apply(coordinatorStreamManager.getConfig(), changelogStreamManager.readPartitionMapping(), metricsRegistry);
     validator.init(config);
     Map<String, String> jmxUrls = jobModelManager.jobModel().getAllContainerToHostValues(SetContainerHostMapping.JMX_TUNNELING_URL_KEY);
     for (Map.Entry<String, String> entry : jmxUrls.entrySet()) {

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-yarn/src/test/java/org/apache/samza/webapp/TestApplicationMasterRestClient.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/java/org/apache/samza/webapp/TestApplicationMasterRestClient.java b/samza-yarn/src/test/java/org/apache/samza/webapp/TestApplicationMasterRestClient.java
index d19badc..0788b86 100644
--- a/samza-yarn/src/test/java/org/apache/samza/webapp/TestApplicationMasterRestClient.java
+++ b/samza-yarn/src/test/java/org/apache/samza/webapp/TestApplicationMasterRestClient.java
@@ -276,9 +276,7 @@ public class TestApplicationMasterRestClient {
         new TaskModel(new TaskName("task2"),
             ImmutableSet.of(new SystemStreamPartition(new SystemStream("system1", "stream1"), new Partition(1))),
             new Partition(1)));
-    Map<String, String> config = new HashMap<>();
-    config.put(JobConfig.JOB_CONTAINER_COUNT(), String.valueOf(2));
-    GroupByContainerCount grouper = new GroupByContainerCount(new MapConfig(config));
+    GroupByContainerCount grouper = new GroupByContainerCount(2);
     Set<ContainerModel> containerModels = grouper.group(taskModels);
     HashMap<String, ContainerModel> containers = new HashMap<>();
     for (ContainerModel containerModel : containerModels) {


[3/3] samza git commit: SAMZA-1973: Unify the TaskNameGrouper interface for yarn and standalone.

Posted by ja...@apache.org.
SAMZA-1973: Unify the TaskNameGrouper interface for yarn and standalone.

This patch consists of the following changes:
* Unify the different methods present in the TaskNameGrouper interface. This will enable us to have a single interface method usable for both the yarn and standalone models.
* Generate locationId aware task assignment to processors in standalone.
* Move the task assignment persistence logic from a custom `TaskNameGrouper` implementation to `JobModelManager`, so that this works for any kind of custom grouper.
* General code clean-up in `JobModelManager`, `TaskAssignmentManager` and in other samza internal classes.
* Read/write taskLocality of the processors in standalone.
* Updated the existing java docs and added java docs where they were missing.

Testing:
* Fixed the existing unit-tests due to the changes.
* Added new unit tests for the functionality changed/added as a part of this patch.
* Tested this patch with a sample job from `hello-samza` project and verified that it works as expected.

Please refer to [SEP-11](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75957309) for more details.

Author: Shanthoosh Venkataraman <sp...@usc.edu>
Author: Shanthoosh Venkataraman <sv...@linkedin.com>
Author: svenkata <sv...@linkedin.com>

Reviewers: Prateek M<pm...@linkedin.com>

Closes #790 from shanthoosh/task_name_grouper_changes


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/5ea72584
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/5ea72584
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/5ea72584

Branch: refs/heads/master
Commit: 5ea72584f6b92937ec130f486d6f70603b7188c2
Parents: c7e5dcb
Author: Shanthoosh Venkataraman <sp...@usc.edu>
Authored: Wed Dec 5 10:56:55 2018 -0800
Committer: Jagadish <jv...@linkedin.com>
Committed: Wed Dec 5 10:56:55 2018 -0800

----------------------------------------------------------------------
 .../org/apache/samza/job/model/TaskModel.java   |   1 -
 .../samza/coordinator/AzureJobCoordinator.java  |   7 +-
 .../ClusterBasedJobCoordinator.java             |   3 +-
 .../samza/config/ClusterManagerConfig.java      |  14 +-
 .../grouper/task/BalancingTaskNameGrouper.java  |   5 +-
 .../grouper/task/GroupByContainerCount.java     | 228 ++++---------
 .../task/GroupByContainerCountFactory.java      |   3 +-
 .../grouper/task/GroupByContainerIds.java       | 171 ++++++++--
 .../container/grouper/task/GrouperMetadata.java |  58 ++++
 .../grouper/task/GrouperMetadataImpl.java       |  72 +++++
 .../grouper/task/TaskAssignmentManager.java     |   3 -
 .../samza/container/grouper/task/TaskGroup.java |  85 +++++
 .../container/grouper/task/TaskNameGrouper.java |  39 ++-
 .../grouper/task/TaskNameGrouperFactory.java    |   2 +-
 .../samza/execution/ExecutionPlanner.java       |   2 +-
 .../apache/samza/processor/StreamProcessor.java |   3 +-
 .../samza/runtime/LocalContainerRunner.java     |  18 +-
 .../standalone/PassthroughJobCoordinator.java   |  30 +-
 .../apache/samza/storage/StorageRecovery.java   |   5 +-
 .../org/apache/samza/zk/ZkJobCoordinator.java   |  74 ++++-
 .../main/java/org/apache/samza/zk/ZkUtils.java  |  27 ++
 .../apache/samza/container/SamzaContainer.scala |  12 +-
 .../samza/coordinator/JobModelManager.scala     | 231 +++++++++----
 .../samza/job/local/ProcessJobFactory.scala     |   2 +-
 .../samza/job/local/ThreadJobFactory.scala      |   2 +-
 .../grouper/task/TestGroupByContainerCount.java | 320 ++++++-------------
 .../grouper/task/TestGroupByContainerIds.java   | 292 +++++++++++++++--
 .../grouper/task/TestTaskAssignmentManager.java |   3 -
 .../samza/container/mock/ContainerMocks.java    |   6 +-
 .../coordinator/JobModelManagerTestUtil.java    |  17 +-
 .../samza/coordinator/TestJobModelManager.java  | 114 ++++++-
 .../java/org/apache/samza/zk/TestZkUtils.java   |  73 ++++-
 .../samza/container/TestSamzaContainer.scala    |  38 ++-
 .../apache/samza/test/framework/TestRunner.java |   2 +-
 .../processor/TestZkStreamProcessorBase.java    |   2 -
 .../processor/TestZkLocalApplicationRunner.java |  10 +-
 .../samza/validation/YarnJobValidationTool.java |   6 +-
 .../webapp/TestApplicationMasterRestClient.java |   4 +-
 38 files changed, 1348 insertions(+), 636 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-api/src/main/java/org/apache/samza/job/model/TaskModel.java
----------------------------------------------------------------------
diff --git a/samza-api/src/main/java/org/apache/samza/job/model/TaskModel.java b/samza-api/src/main/java/org/apache/samza/job/model/TaskModel.java
index 7ee7609..36917cf 100644
--- a/samza-api/src/main/java/org/apache/samza/job/model/TaskModel.java
+++ b/samza-api/src/main/java/org/apache/samza/job/model/TaskModel.java
@@ -99,7 +99,6 @@ public class TaskModel implements Comparable<TaskModel> {
   }
 
   @Override
-
   public String toString() {
     return "TaskModel [taskName=" + taskName + ", systemStreamPartitions=" + systemStreamPartitions + ", changeLogPartition=" + changelogPartition + "]";
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-azure/src/main/java/org/apache/samza/coordinator/AzureJobCoordinator.java
----------------------------------------------------------------------
diff --git a/samza-azure/src/main/java/org/apache/samza/coordinator/AzureJobCoordinator.java b/samza-azure/src/main/java/org/apache/samza/coordinator/AzureJobCoordinator.java
index 96f628c..076ab54 100644
--- a/samza-azure/src/main/java/org/apache/samza/coordinator/AzureJobCoordinator.java
+++ b/samza-azure/src/main/java/org/apache/samza/coordinator/AzureJobCoordinator.java
@@ -30,6 +30,8 @@ import org.apache.samza.config.TaskConfig;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.container.grouper.stream.SystemStreamPartitionGrouper;
 import org.apache.samza.container.grouper.stream.SystemStreamPartitionGrouperFactory;
+import org.apache.samza.container.grouper.task.GrouperMetadata;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.coordinator.data.BarrierState;
 import org.apache.samza.coordinator.data.ProcessorEntity;
 import org.apache.samza.coordinator.scheduler.HeartbeatScheduler;
@@ -54,7 +56,6 @@ import org.apache.samza.util.Util;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.collection.JavaConverters;
-
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
@@ -365,8 +366,8 @@ public class AzureJobCoordinator implements JobCoordinator {
     }
 
     // Generate the new JobModel
-    JobModel newJobModel = JobModelManager.readJobModel(this.config, Collections.emptyMap(),
-        null, streamMetadataCache, currentProcessorIds);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
+    JobModel newJobModel = JobModelManager.readJobModel(this.config, Collections.emptyMap(), streamMetadataCache, grouperMetadata);
     LOG.info("pid=" + processorId + "Generated new Job Model. Version = " + nextJMVersion);
 
     // Publish the new job model

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java b/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
index 4c5a34b..0eddbf2 100644
--- a/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
+++ b/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
@@ -185,8 +185,7 @@ public class ClusterBasedJobCoordinator {
 
     // build a JobModelManager and ChangelogStreamManager and perform partition assignments.
     changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager);
-    jobModelManager =
-        JobModelManager.apply(coordinatorStreamManager.getConfig(), changelogStreamManager.readPartitionMapping());
+    jobModelManager = JobModelManager.apply(coordinatorStreamManager.getConfig(), changelogStreamManager.readPartitionMapping(), metrics);
 
     config = jobModelManager.jobModel().getConfig();
     hasDurableStores = new StorageConfig(config).hasDurableStores();

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/config/ClusterManagerConfig.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/config/ClusterManagerConfig.java b/samza-core/src/main/java/org/apache/samza/config/ClusterManagerConfig.java
index cb86a58..eda1be8 100644
--- a/samza-core/src/main/java/org/apache/samza/config/ClusterManagerConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/ClusterManagerConfig.java
@@ -52,10 +52,14 @@ public class ClusterManagerConfig extends MapConfig {
   private static final int DEFAULT_CONTAINER_REQUEST_TIMEOUT_MS = 5000;
 
   /**
-   * Flag to indicate if host-affinity is enabled for the job or not
+   * NOTE: This field is deprecated.
    */
   public static final String HOST_AFFINITY_ENABLED = "yarn.samza.host-affinity.enabled";
-  public static final String CLUSTER_MANAGER_HOST_AFFINITY_ENABLED = "job.host-affinity.enabled";
+
+  /**
+   * Flag to indicate if host-affinity is enabled for the job or not
+   */
+  public static final String JOB_HOST_AFFINITY_ENABLED = "job.host-affinity.enabled";
 
   /**
    * Number of CPU cores to request from the cluster manager per container
@@ -145,10 +149,10 @@ public class ClusterManagerConfig extends MapConfig {
   }
 
   public boolean getHostAffinityEnabled() {
-    if (containsKey(CLUSTER_MANAGER_HOST_AFFINITY_ENABLED)) {
-      return getBoolean(CLUSTER_MANAGER_HOST_AFFINITY_ENABLED);
+    if (containsKey(JOB_HOST_AFFINITY_ENABLED)) {
+      return getBoolean(JOB_HOST_AFFINITY_ENABLED);
     } else if (containsKey(HOST_AFFINITY_ENABLED)) {
-      log.info("Configuration {} is deprecated. Please use {}", HOST_AFFINITY_ENABLED, CLUSTER_MANAGER_HOST_AFFINITY_ENABLED);
+      log.warn("Configuration {} is deprecated. Please use {}", HOST_AFFINITY_ENABLED, JOB_HOST_AFFINITY_ENABLED);
       return getBoolean(HOST_AFFINITY_ENABLED);
     } else {
       return false;

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/container/grouper/task/BalancingTaskNameGrouper.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/BalancingTaskNameGrouper.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/BalancingTaskNameGrouper.java
index f8295c8..91eab54 100644
--- a/samza-core/src/main/java/org/apache/samza/container/grouper/task/BalancingTaskNameGrouper.java
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/BalancingTaskNameGrouper.java
@@ -54,5 +54,8 @@ public interface BalancingTaskNameGrouper extends TaskNameGrouper {
    * @param localityManager provides a persisted task to container map to use as a baseline
    * @return                the grouped tasks in the form of ContainerModels
    */
-  Set<ContainerModel> balance(Set<TaskModel> tasks, LocalityManager localityManager);
+  @Deprecated
+  default Set<ContainerModel> balance(Set<TaskModel> tasks, LocalityManager localityManager) {
+    return group(tasks);
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java
index 759f82e..8a741db 100644
--- a/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java
@@ -27,19 +27,13 @@ import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-
-import org.apache.samza.SamzaException;
-import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
-import org.apache.samza.container.LocalityManager;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
-import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.SamzaException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-
 /**
  * Group the SSP taskNames by dividing the number of taskNames into the number
  * of containers (n) and assigning n taskNames to each container as returned by
@@ -51,19 +45,21 @@ import org.slf4j.LoggerFactory;
  * TODO: SAMZA-1197 - need to modify balance to work with processorId strings
  */
 public class GroupByContainerCount implements BalancingTaskNameGrouper {
-  private static final Logger log = LoggerFactory.getLogger(GroupByContainerCount.class);
+  private static final Logger LOG = LoggerFactory.getLogger(GroupByContainerCount.class);
   private final int containerCount;
-  private final Config config;
 
-  public GroupByContainerCount(Config config) {
-    this.containerCount = new JobConfig(config).getContainerCount();
-    this.config = config;
-    if (containerCount <= 0) throw new IllegalArgumentException("Must have at least one container");
+  public GroupByContainerCount(int containerCount) {
+    if (containerCount <= 0) {
+      throw new IllegalArgumentException("Must have at least one container");
+    }
+    this.containerCount = containerCount;
   }
 
+  /**
+   * {@inheritDoc}
+   */
   @Override
   public Set<ContainerModel> group(Set<TaskModel> tasks) {
-
     validateTasks(tasks);
 
     // Sort tasks by taskName.
@@ -89,79 +85,63 @@ public class GroupByContainerCount implements BalancingTaskNameGrouper {
     return Collections.unmodifiableSet(containerModels);
   }
 
+  /**
+   * {@inheritDoc}
+   */
   @Override
-  public Set<ContainerModel> balance(Set<TaskModel> tasks, LocalityManager localityManager) {
-
+  public Set<ContainerModel> group(Set<TaskModel> tasks, GrouperMetadata grouperMetadata) {
     validateTasks(tasks);
 
-    if (localityManager == null) {
-      log.info("Locality manager is null. Cannot read or write task assignments. Invoking grouper.");
+    List<TaskGroup> containers = getPreviousContainers(grouperMetadata, tasks.size());
+    if (containers == null || containers.size() == 1 || containerCount == 1) {
+      LOG.info("Balancing does not apply. Invoking grouper.");
       return group(tasks);
     }
 
-    TaskAssignmentManager taskAssignmentManager =  new TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init();
-    try {
-      List<TaskGroup> containers = getPreviousContainers(taskAssignmentManager, tasks.size());
-      if (containers == null || containers.size() == 1 || containerCount == 1) {
-        log.info("Balancing does not apply. Invoking grouper.");
-        Set<ContainerModel> models = group(tasks);
-        saveTaskAssignments(models, taskAssignmentManager);
-        return models;
-      }
-
-      int prevContainerCount = containers.size();
-      int containerDelta = containerCount - prevContainerCount;
-      if (containerDelta == 0) {
-        log.info("Container count has not changed. Reusing previous container models.");
-        return buildContainerModels(tasks, containers);
-      }
-      log.info("Container count changed from {} to {}. Balancing tasks.", prevContainerCount, containerCount);
-
-      // Calculate the expected task count per container
-      int[] expectedTaskCountPerContainer = calculateTaskCountPerContainer(tasks.size(), prevContainerCount, containerCount);
+    int prevContainerCount = containers.size();
+    int containerDelta = containerCount - prevContainerCount;
+    if (containerDelta == 0) {
+      LOG.info("Container count has not changed. Reusing previous container models.");
+      return TaskGroup.buildContainerModels(tasks, containers);
+    }
+    LOG.info("Container count changed from {} to {}. Balancing tasks.", prevContainerCount, containerCount);
 
-      // Collect excess tasks from over-assigned containers
-      List<String> taskNamesToReassign = new LinkedList<>();
-      for (int i = 0; i < prevContainerCount; i++) {
-        TaskGroup taskGroup = containers.get(i);
-        while (taskGroup.size() > expectedTaskCountPerContainer[i]) {
-          taskNamesToReassign.add(taskGroup.removeTask());
-        }
-      }
+    // Calculate the expected task count per container
+    int[] expectedTaskCountPerContainer = calculateTaskCountPerContainer(tasks.size(), prevContainerCount, containerCount);
 
-      // Assign tasks to the under-assigned containers
-      if (containerDelta > 0) {
-        List<TaskGroup> newContainers = createContainers(prevContainerCount, containerCount);
-        containers.addAll(newContainers);
-      } else {
-        containers = containers.subList(0, containerCount);
+    // Collect excess tasks from over-assigned containers
+    List<String> taskNamesToReassign = new LinkedList<>();
+    for (int i = 0; i < prevContainerCount; i++) {
+      TaskGroup taskGroup = containers.get(i);
+      while (taskGroup.size() > expectedTaskCountPerContainer[i]) {
+        taskNamesToReassign.add(taskGroup.removeLastTaskName());
       }
-      assignTasksToContainers(expectedTaskCountPerContainer, taskNamesToReassign, containers);
+    }
 
-      // Transform containers to containerModel
-      Set<ContainerModel> models = buildContainerModels(tasks, containers);
+    // Assign tasks to the under-assigned containers
+    if (containerDelta > 0) {
+      List<TaskGroup> newContainers = createContainers(prevContainerCount, containerCount);
+      containers.addAll(newContainers);
+    } else {
+      containers = containers.subList(0, containerCount);
+    }
 
-      // Save the results
-      saveTaskAssignments(models, taskAssignmentManager);
+    assignTasksToContainers(expectedTaskCountPerContainer, taskNamesToReassign, containers);
 
-      return models;
-    } finally {
-      taskAssignmentManager.close();
-    }
+    return TaskGroup.buildContainerModels(tasks, containers);
   }
 
   /**
-   * Reads the task-container mapping from the provided {@link TaskAssignmentManager} and returns a
+   * Reads the task-container mapping from the provided {@link GrouperMetadata} and returns a
    * list of TaskGroups, ordered ascending by containerId.
    *
-   * @param taskAssignmentManager the {@link TaskAssignmentManager} that will be used to retrieve the previous mapping.
-   * @param taskCount             the number of tasks, for validation against the persisted tasks.
-   * @return                      a list of TaskGroups, ordered ascending by containerId or {@code null}
-   *                              if the previous mapping doesn't exist or isn't usable.
+   * @param grouperMetadata  the {@link GrouperMetadata} that will be used to retrieve the previous task to container assignments.
+   * @param taskCount       the number of tasks, for validation against the persisted tasks.
+   * @return                a list of TaskGroups, ordered ascending by containerId or {@code null}
+   *                        if the previous mapping doesn't exist or isn't usable.
    */
-  private List<TaskGroup> getPreviousContainers(TaskAssignmentManager taskAssignmentManager, int taskCount) {
-    Map<String, String> taskToContainerId = taskAssignmentManager.readTaskAssignment();
+  private List<TaskGroup> getPreviousContainers(GrouperMetadata grouperMetadata, int taskCount) {
+    Map<TaskName, String> taskToContainerId = grouperMetadata.getPreviousTaskToProcessorAssignment();
     taskToContainerId.values().forEach(id -> {
         try {
           int intId = Integer.parseInt(id);
@@ -169,19 +149,11 @@ public class GroupByContainerCount implements BalancingTaskNameGrouper {
           throw new SamzaException("GroupByContainerCount cannot handle non-integer processorIds!", nfe);
         }
       });
+
     if (taskToContainerId.isEmpty()) {
-      log.info("No task assignment map was saved.");
+      LOG.info("No task assignment map was saved.");
       return null;
     } else if (taskCount != taskToContainerId.size()) {
-      log.warn(
-          "Current task count {} does not match saved task count {}. Stateful jobs may observe misalignment of keys!",
-          taskCount, taskToContainerId.size());
-      // If the tasks changed, then the partition-task grouping is also likely changed and we can't handle that
-      // without a much more complicated mapping. Further, the partition count may have changed, which means
-      // input message keys are likely reshuffled w.r.t. partitions, so the local state may not contain necessary
-      // data associated with the incoming keys. Warn the user and default to grouper
-      // In this scenario the tasks may have been reduced, so we need to delete all the existing messages
-      taskAssignmentManager.deleteTaskContainerMappings(taskToContainerId.keySet());
       return null;
     }
 
@@ -189,27 +161,13 @@ public class GroupByContainerCount implements BalancingTaskNameGrouper {
     try {
       containers = getOrderedContainers(taskToContainerId);
     } catch (Exception e) {
-      log.error("Exception while parsing task mapping", e);
+      LOG.error("Exception while parsing task mapping", e);
       return null;
     }
     return containers;
   }
 
   /**
-   * Saves the task assignments specified by containers using the provided TaskAssignementManager.
-   *
-   * @param containers            the set of containers from which the task assignments will be saved.
-   * @param taskAssignmentManager the {@link TaskAssignmentManager} that will be used to save the mappings.
-   */
-  private void saveTaskAssignments(Set<ContainerModel> containers, TaskAssignmentManager taskAssignmentManager) {
-    for (ContainerModel container : containers) {
-      for (TaskName taskName : container.getTasks().keySet()) {
-        taskAssignmentManager.writeTaskContainerMapping(taskName.getTaskName(), container.getId());
-      }
-    }
-  }
-
-  /**
    * Verifies the input tasks argument and throws {@link IllegalArgumentException} if it is invalid.
    *
    * @param tasks the tasks to validate.
@@ -252,13 +210,12 @@ public class GroupByContainerCount implements BalancingTaskNameGrouper {
    * @param containers            the containers (as {@link TaskGroup}) to which the tasks will be assigned.
    */
   // TODO: Change logic from using int arrays to a Map<String, Integer> (id -> taskCount)
-  private void assignTasksToContainers(int[] taskCountPerContainer, List<String> taskNamesToAssign,
-      List<TaskGroup> containers) {
+  private void assignTasksToContainers(int[] taskCountPerContainer, List<String> taskNamesToAssign, List<TaskGroup> containers) {
     for (TaskGroup taskGroup : containers) {
       for (int j = taskGroup.size(); j < taskCountPerContainer[Integer.valueOf(taskGroup.getContainerId())]; j++) {
         String taskName = taskNamesToAssign.remove(0);
         taskGroup.addTaskName(taskName);
-        log.info("Assigned task {} to container {}", taskName, taskGroup.getContainerId());
+        LOG.info("Assigned task {} to container {}", taskName, taskGroup.getContainerId());
       }
     }
   }
@@ -288,53 +245,20 @@ public class GroupByContainerCount implements BalancingTaskNameGrouper {
   }
 
   /**
-   * Translates the list of TaskGroup instances to a set of ContainerModel instances, using the
-   * set of TaskModel instances.
-   *
-   * @param tasks             the TaskModels to assign to the ContainerModels.
-   * @param containerTasks    the TaskGroups defining how the tasks should be grouped.
-   * @return                  a mutable set of ContainerModels.
-   */
-  private Set<ContainerModel> buildContainerModels(Set<TaskModel> tasks, List<TaskGroup> containerTasks) {
-    // Map task names to models
-    Map<String, TaskModel> taskNameToModel = new HashMap<>();
-    for (TaskModel model : tasks) {
-      taskNameToModel.put(model.getTaskName().getTaskName(), model);
-    }
-
-    // Build container models
-    Set<ContainerModel> containerModels = new HashSet<>();
-    for (TaskGroup container : containerTasks) {
-      Map<TaskName, TaskModel> containerTaskModels = new HashMap<>();
-      for (String taskName : container.taskNames) {
-        TaskModel model = taskNameToModel.get(taskName);
-        containerTaskModels.put(model.getTaskName(), model);
-      }
-      containerModels.add(
-          new ContainerModel(container.containerId, containerTaskModels));
-    }
-    return Collections.unmodifiableSet(containerModels);
-  }
-
-  /**
    * Converts the task->containerId map to an ordered list of {@link TaskGroup} instances.
    *
    * @param taskToContainerId a map from each task name to the containerId to which it is assigned.
    * @return                  a list of TaskGroups ordered ascending by containerId.
    */
-  private List<TaskGroup> getOrderedContainers(Map<String, String> taskToContainerId) {
-    log.debug("Got task to container map: {}", taskToContainerId);
+  private List<TaskGroup> getOrderedContainers(Map<TaskName, String> taskToContainerId) {
+    LOG.debug("Got task to container map: {}", taskToContainerId);
 
     // Group tasks by container Id
-    HashMap<String, List<String>> containerIdToTaskNames = new HashMap<>();
-    for (Map.Entry<String, String> entry : taskToContainerId.entrySet()) {
-      String taskName = entry.getKey();
+    Map<String, List<String>> containerIdToTaskNames = new HashMap<>();
+    for (Map.Entry<TaskName, String> entry : taskToContainerId.entrySet()) {
+      String taskName = entry.getKey().getTaskName();
       String containerId = entry.getValue();
-      List<String> taskNames = containerIdToTaskNames.get(containerId);
-      if (taskNames == null) {
-        taskNames = new ArrayList<>();
-        containerIdToTaskNames.put(containerId, taskNames);
-      }
+      List<String> taskNames = containerIdToTaskNames.computeIfAbsent(containerId, k -> new ArrayList<>());
       taskNames.add(taskName);
     }
 
@@ -347,36 +271,4 @@ public class GroupByContainerCount implements BalancingTaskNameGrouper {
 
     return containerTasks;
   }
-
-  /**
-   * A mutable group of tasks and an associated container id.
-   *
-   * Used as a temporary mutable container until the final ContainerModel is known.
-   */
-  private static class TaskGroup {
-    private final List<String> taskNames = new LinkedList<>();
-    private final String containerId;
-
-    private TaskGroup(String containerId, List<String> taskNames) {
-      this.containerId = containerId;
-      Collections.sort(taskNames);        // For consistency because the taskNames came from a Map
-      this.taskNames.addAll(taskNames);
-    }
-
-    public String getContainerId() {
-      return containerId;
-    }
-
-    public void addTaskName(String taskName) {
-      taskNames.add(taskName);
-    }
-
-    public String removeTask() {
-      return taskNames.remove(taskNames.size() - 1);
-    }
-
-    public int size() {
-      return taskNames.size();
-    }
-  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.java
index 06aba33..5acf5b8 100644
--- a/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.java
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.java
@@ -19,6 +19,7 @@
 package org.apache.samza.container.grouper.task;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
 
 /**
  * Factory to build the GroupByContainerCount class.
@@ -26,6 +27,6 @@ import org.apache.samza.config.Config;
 public class GroupByContainerCountFactory implements TaskNameGrouperFactory {
   @Override
   public TaskNameGrouper build(Config config) {
-    return new GroupByContainerCount(config);
+    return new GroupByContainerCount(new JobConfig(config).getContainerCount());
   }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerIds.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerIds.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerIds.java
index 9dab943..7c11da4 100644
--- a/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerIds.java
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerIds.java
@@ -19,27 +19,34 @@
 
 package org.apache.samza.container.grouper.task;
 
-import java.util.Arrays;
-import java.util.stream.Collectors;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.job.model.ContainerModel;
-import org.apache.samza.job.model.TaskModel;
-
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterators;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.runtime.LocationId;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-
 /**
- * Simple grouper.
- * It exposes two group methods - one that assumes sequential container numbers and one that gets a set of container
- * IDs as an argument. Please note - this first implementation ignores locality information.
+ * A {@link TaskNameGrouper} implementation that provides two different grouping strategies:
+ *
+ * - One that assigns the tasks to the available containerIds in a round robin fashion.
+ * - The other that generates an equidistributed and locality-aware task to container assignment.
  */
 public class GroupByContainerIds implements TaskNameGrouper {
   private static final Logger LOG = LoggerFactory.getLogger(GroupByContainerIds.class);
@@ -49,6 +56,9 @@ public class GroupByContainerIds implements TaskNameGrouper {
     this.startContainerCount = count;
   }
 
+  /**
+   * {@inheritDoc}
+   */
   @Override
   public Set<ContainerModel> group(Set<TaskModel> tasks) {
     List<String> containerIds = new ArrayList<>(startContainerCount);
@@ -58,30 +68,40 @@ public class GroupByContainerIds implements TaskNameGrouper {
     return group(tasks, containerIds);
   }
 
-  public Set<ContainerModel> group(Set<TaskModel> tasks, List<String> containersIds) {
-    if (containersIds == null)
+  /**
+   * {@inheritDoc}
+   *
+   * When the number of taskModels is less than the number of available containerIds,
+   * then selects the lexicographically least `x` containerIds.
+   *
+   * Otherwise, assigns the tasks to the available containerIds in a round robin fashion
+   * preserving the containerId in the final assignment.
+   */
+  @Override
+  public Set<ContainerModel> group(Set<TaskModel> tasks, List<String> containerIds) {
+    if (containerIds == null)
       return this.group(tasks);
 
-    if (containersIds.isEmpty())
+    if (containerIds.isEmpty())
       throw new IllegalArgumentException("Must have at least one container");
 
     if (tasks.isEmpty())
-      throw new IllegalArgumentException("cannot group an empty set. containersIds=" + Arrays
-          .toString(containersIds.toArray()));
+      throw new IllegalArgumentException("cannot group an empty set. containerIds=" + Arrays
+          .toString(containerIds.toArray()));
 
-    if (containersIds.size() > tasks.size()) {
-      LOG.warn("Number of containers: {} is greater than number of tasks: {}.",  containersIds.size(), tasks.size());
+    if (containerIds.size() > tasks.size()) {
+      LOG.warn("Number of containers: {} is greater than number of tasks: {}.",  containerIds.size(), tasks.size());
       /**
        * Choose lexicographically least `x` containerIds(where x = tasks.size()).
        */
-      containersIds = containersIds.stream()
+      containerIds = containerIds.stream()
                                    .sorted()
                                    .limit(tasks.size())
                                    .collect(Collectors.toList());
-      LOG.info("Generating containerModel with containers: {}.", containersIds);
+      LOG.info("Generating containerModel with containers: {}.", containerIds);
     }
 
-    int containerCount = containersIds.size();
+    int containerCount = containerIds.size();
 
     // Sort tasks by taskName.
     List<TaskModel> sortedTasks = new ArrayList<>(tasks);
@@ -100,9 +120,118 @@ public class GroupByContainerIds implements TaskNameGrouper {
     // Convert to a Set of ContainerModel
     Set<ContainerModel> containerModels = new HashSet<>();
     for (int i = 0; i < containerCount; i++) {
-      containerModels.add(new ContainerModel(containersIds.get(i), taskGroups[i]));
+      containerModels.add(new ContainerModel(containerIds.get(i), taskGroups[i]));
     }
 
     return Collections.unmodifiableSet(containerModels);
   }
+
+  /**
+   * {@inheritDoc}
+   *
+   * When there are `t` tasks and `p` processors, where p &lt;= t, a fair task distribution should ideally assign
+   * (t / p) tasks to each processor. In addition to guaranteeing a fair distribution, this {@link TaskNameGrouper}
+   * implementation generates a locationId aware task assignment to processors where it makes best efforts in assigning
+   * the tasks to processors with the same locality.
+   *
+   * Task assignment to processors is accomplished through the following two phases:
+   *
+   * 1. In the first phase, each task(T) is assigned to a processor(P) that satisfies the following constraints:
+   *    A. The processor(P) should have the same locality of the task(T).
+   *    B. Number of tasks already assigned to the processor should be less than the (number of tasks / number of processors).
+   *
+   * 2. Each task left unassigned after phase 1 is then mapped to any processor with task count less than the
+   * (number of tasks / number of processors). When no such processor exists, then the unassigned
+   * task is mapped to any processor from available processors in a round robin fashion.
+   */
+  @Override
+  public Set<ContainerModel> group(Set<TaskModel> taskModels, GrouperMetadata grouperMetadata) {
+    // Validate that the task models are not empty.
+    Map<TaskName, LocationId> taskLocality = grouperMetadata.getTaskLocality();
+    Preconditions.checkArgument(!taskModels.isEmpty(), "No tasks found. Likely due to no input partitions. Can't run a job with no tasks.");
+
+    // Invoke the default grouper when the processor locality does not exist.
+    if (MapUtils.isEmpty(grouperMetadata.getProcessorLocality())) {
+      LOG.info("ProcessorLocality is empty. Generating with the default group method.");
+      return group(taskModels, new ArrayList<>());
+    }
+
+    Map<String, LocationId> processorLocality = new TreeMap<>(grouperMetadata.getProcessorLocality());
+    /**
+     * When there are more processors than task models, choose the lexicographically least `x` processors (where x = taskModels.size()).
+     */
+    if (processorLocality.size() > taskModels.size()) {
+      processorLocality = processorLocality.entrySet()
+                                           .stream()
+                                           .limit(taskModels.size())
+                                           .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+    }
+
+    Map<LocationId, List<String>> locationIdToProcessors = new HashMap<>();
+    Map<String, TaskGroup> processorIdToTaskGroup = new HashMap<>();
+
+    // Generate the {@see LocationId} to processors mapping and processorId to {@see TaskGroup} mapping.
+    processorLocality.forEach((processorId, locationId) -> {
+        List<String> processorIds = locationIdToProcessors.getOrDefault(locationId, new ArrayList<>());
+        processorIds.add(processorId);
+        locationIdToProcessors.put(locationId, processorIds);
+        processorIdToTaskGroup.put(processorId, new TaskGroup(processorId, new ArrayList<>()));
+      });
+
+    int numTasksPerProcessor = taskModels.size() / processorLocality.size();
+    Set<TaskName> assignedTasks = new HashSet<>();
+
+    /**
+     * A processor is considered under-assigned when number of tasks assigned to it is less than
+     * (number of tasks / number of processors).
+     * Map the tasks to the under-assigned processors with same locality.
+     */
+    for (TaskModel taskModel : taskModels) {
+      LocationId taskLocationId = taskLocality.get(taskModel.getTaskName());
+      if (taskLocationId != null) {
+        List<String> processorIds = locationIdToProcessors.getOrDefault(taskLocationId, new ArrayList<>());
+        for (String processorId : processorIds) {
+          TaskGroup taskGroup = processorIdToTaskGroup.get(processorId);
+          if (taskGroup.size() < numTasksPerProcessor) {
+            taskGroup.addTaskName(taskModel.getTaskName().getTaskName());
+            assignedTasks.add(taskModel.getTaskName());
+            break;
+          }
+        }
+      }
+    }
+
+    /**
+     * In some scenarios, the task either might not have any previous locality or might not have any
+     * processor that maps to its previous locality. This cyclic iterator over processorIds helps us in
+     * those scenarios to assign the processorIds to those kind of tasks in a round robin fashion.
+     */
+    Iterator<String> processorIdsCyclicIterator = Iterators.cycle(processorLocality.keySet());
+
+    // Order the taskGroups to choose a task group in a deterministic fashion for unassigned tasks.
+    List<TaskGroup> taskGroups = new ArrayList<>(processorIdToTaskGroup.values());
+    taskGroups.sort(Comparator.comparing(TaskGroup::getContainerId));
+
+    /**
+     * For the tasks left over from the previous stage, map them to any under-assigned processor.
+     * When an under-assigned processor doesn't exist, then map them to any processor from the
+     * available processors in a round robin manner.
+     */
+    for (TaskModel taskModel : taskModels) {
+      if (!assignedTasks.contains(taskModel.getTaskName())) {
+        Optional<TaskGroup> underAssignedTaskGroup = taskGroups.stream()
+                .filter(taskGroup -> taskGroup.size() < numTasksPerProcessor)
+                .findFirst();
+        if (underAssignedTaskGroup.isPresent()) {
+          underAssignedTaskGroup.get().addTaskName(taskModel.getTaskName().getTaskName());
+        } else {
+          TaskGroup taskGroup = processorIdToTaskGroup.get(processorIdsCyclicIterator.next());
+          taskGroup.addTaskName(taskModel.getTaskName().getTaskName());
+        }
+        assignedTasks.add(taskModel.getTaskName());
+      }
+    }
+
+    return TaskGroup.buildContainerModels(taskModels, taskGroups);
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/container/grouper/task/GrouperMetadata.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/GrouperMetadata.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GrouperMetadata.java
new file mode 100644
index 0000000..ef919d1
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GrouperMetadata.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.grouper.task;
+
+import org.apache.samza.annotation.InterfaceStability;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.runtime.LocationId;
+import org.apache.samza.system.SystemStreamPartition;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Provides the historical metadata of the samza application.
+ */
+@InterfaceStability.Evolving
+public interface GrouperMetadata {
+
+  /**
+   * Gets the current processor locality of the job.
+   * @return the processorId to the {@link LocationId} assignment.
+   */
+  Map<String, LocationId> getProcessorLocality();
+
+  /**
+   * Gets the current task locality of the job.
+   * @return the current {@link TaskName} to {@link LocationId} assignment.
+   */
+  Map<TaskName, LocationId> getTaskLocality();
+
+  /**
+   * Gets the previous {@link TaskName} to {@link SystemStreamPartition} assignment of the job.
+   * @return the previous {@link TaskName} to {@link SystemStreamPartition} assignment.
+   */
+  Map<TaskName, List<SystemStreamPartition>> getPreviousTaskToSSPAssignment();
+
+
+  /**
+   * Gets the previous {@link TaskName} to processorId assignments of the job.
+   * @return the previous task to processorId assignment.
+   */
+  Map<TaskName, String> getPreviousTaskToProcessorAssignment();
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/container/grouper/task/GrouperMetadataImpl.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/GrouperMetadataImpl.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GrouperMetadataImpl.java
new file mode 100644
index 0000000..bc40bc4
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/GrouperMetadataImpl.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.grouper.task;
+
+import org.apache.samza.container.TaskName;
+import org.apache.samza.runtime.LocationId;
+import org.apache.samza.system.SystemStreamPartition;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Implementation of {@link GrouperMetadata} holding job metadata used by a {@link TaskNameGrouper} to generate
+ * task assignments. Getters return unmodifiable views (not copies) of the maps supplied at construction.
+ */
+public class GrouperMetadataImpl implements GrouperMetadata {
+
+  // Current processorId to LocationId assignment (unmodifiable view of the constructor argument).
+  private final Map<String, LocationId> processorLocality;
+
+  // Current TaskName to LocationId assignment (unmodifiable view of the constructor argument).
+  private final Map<TaskName, LocationId> taskLocality;
+
+  // Previous TaskName to assigned input SystemStreamPartitions (unmodifiable view of the constructor argument).
+  private final Map<TaskName, List<SystemStreamPartition>> previousTaskToSSPAssignment;
+
+  // Previous TaskName to processorId assignment (unmodifiable view of the constructor argument).
+  private final Map<TaskName, String> previousTaskToProcessorAssignment;
+
+  public GrouperMetadataImpl(Map<String, LocationId> processorLocality, Map<TaskName, LocationId> taskLocality, Map<TaskName, List<SystemStreamPartition>> previousTaskToSSPAssignments, Map<TaskName, String> previousTaskToProcessorAssignment) {
+    this.processorLocality = Collections.unmodifiableMap(processorLocality);
+    this.taskLocality = Collections.unmodifiableMap(taskLocality);
+    this.previousTaskToSSPAssignment = Collections.unmodifiableMap(previousTaskToSSPAssignments);
+    this.previousTaskToProcessorAssignment = Collections.unmodifiableMap(previousTaskToProcessorAssignment);
+  }
+
+  @Override
+  public Map<String, LocationId> getProcessorLocality() {
+    return processorLocality;
+  }
+
+  @Override
+  public Map<TaskName, LocationId> getTaskLocality() {
+    return taskLocality;
+  }
+
+  @Override
+  public Map<TaskName, List<SystemStreamPartition>> getPreviousTaskToSSPAssignment() {
+    return previousTaskToSSPAssignment;
+  }
+
+  @Override
+  public Map<TaskName, String> getPreviousTaskToProcessorAssignment() {
+    return this.previousTaskToProcessorAssignment;
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java
index 32bbf29..b6e946c 100644
--- a/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java
@@ -81,9 +81,6 @@ public class TaskAssignmentManager {
     this.valueSerde = valueSerde;
     MetadataStoreFactory metadataStoreFactory = Util.getObj(new JobConfig(config).getMetadataStoreFactory(), MetadataStoreFactory.class);
     this.metadataStore = metadataStoreFactory.getMetadataStore(SetTaskContainerMapping.TYPE, config, metricsRegistry);
-  }
-
-  public void init() {
     this.metadataStore.init();
   }
 

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskGroup.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskGroup.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskGroup.java
new file mode 100644
index 0000000..1fe0f40
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskGroup.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.grouper.task;
+
+import org.apache.samza.container.TaskName;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
+
+import java.util.*;
+
+/**
+ * A mutable group of task names together with the container id they are assigned to.
+ *
+ * Serves as a temporary mutable accumulator while grouping; converted to immutable ContainerModels via buildContainerModels.
+ */
+class TaskGroup {
+  private final List<String> taskNames = new LinkedList<>();
+  private final String containerId;
+
+  TaskGroup(String containerId, List<String> taskNames) {
+    this.containerId = containerId;
+    this.taskNames.addAll(taskNames);
+    Collections.sort(this.taskNames); // Sorted for deterministic ordering, since the names came from an unordered Map
+  }
+
+  public String getContainerId() {
+    return containerId;
+  }
+
+  public void addTaskName(String taskName) {
+    taskNames.add(taskName);
+  }
+
+  public String removeLastTaskName() {
+    return taskNames.remove(taskNames.size() - 1); // throws IndexOutOfBoundsException when the group is empty
+  }
+
+  public int size() {
+    return taskNames.size();
+  }
+
+  /**
+   * Converts the {@link TaskGroup} list to an unmodifiable set of ContainerModel.
+   * NOTE(review): assumes every task name in taskGroups exists in taskModels; a missing one yields an NPE — confirm callers guarantee this.
+   * @param taskModels    the TaskModels to assign to the ContainerModels.
+   * @param taskGroups    the TaskGroups defining how the tasks should be grouped.
+   * @return              an unmodifiable set of ContainerModels.
+   */
+  public static Set<ContainerModel> buildContainerModels(Set<TaskModel> taskModels, Collection<TaskGroup> taskGroups) {
+    // Index the TaskModels by task name for O(1) lookup while building each container.
+    Map<String, TaskModel> taskNameToModel = new HashMap<>();
+    for (TaskModel model : taskModels) {
+      taskNameToModel.put(model.getTaskName().getTaskName(), model);
+    }
+
+    // Build one ContainerModel per TaskGroup, carrying over the group's container id.
+    Set<ContainerModel> containerModels = new HashSet<>();
+    for (TaskGroup container : taskGroups) {
+      Map<TaskName, TaskModel> containerTaskModels = new HashMap<>();
+      for (String taskName : container.taskNames) {
+        TaskModel model = taskNameToModel.get(taskName);
+        containerTaskModels.put(model.getTaskName(), model);
+      }
+      containerModels.add(new ContainerModel(container.containerId, containerTaskModels));
+    }
+
+    return Collections.unmodifiableSet(containerModels);
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskNameGrouper.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskNameGrouper.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskNameGrouper.java
index 71b80cc..2124dfc 100644
--- a/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskNameGrouper.java
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskNameGrouper.java
@@ -18,10 +18,9 @@
  */
 package org.apache.samza.container.grouper.task;
 
-import java.util.List;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
-
+import java.util.List;
 import java.util.Set;
 
 /**
@@ -44,15 +43,39 @@ import java.util.Set;
  * </p>
  */
 public interface TaskNameGrouper {
+
   /**
-   * Group tasks into the containers they will share.
+   * Groups the taskModels into a set of {@link ContainerModel} using the metadata of
+   * the job from {@link GrouperMetadata}.
    *
-   * @param tasks Set of tasks to group into containers.
-   * @return Set of containers, which contain the tasks that were passed in.
+   * @param taskModels the set of tasks to group into containers.
+   * @param grouperMetadata provides the historical metadata of the samza job.
+   * @return the grouped {@link ContainerModel} built from the provided taskModels.
    */
-  Set<ContainerModel> group(Set<TaskModel> tasks);
+  default Set<ContainerModel> group(Set<TaskModel> taskModels, GrouperMetadata grouperMetadata) {
+    return group(taskModels);
+  }
 
-  default Set<ContainerModel> group(Set<TaskModel> tasks, List<String> containersIds) {
-    return group(tasks);
+  /**
+   * Groups the taskModels into a set of {@link ContainerModel}.
+   *
+   * @param taskModels the set of {@link TaskModel} to group into containers.
+   * @return the grouped {@link ContainerModel} built from the provided taskModels.
+   */
+  @Deprecated
+  default Set<ContainerModel> group(Set<TaskModel> taskModels) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Groups the taskModels into a set of {@link ContainerModel}.
+   *
+   * @param taskModels the set of {@link TaskModel} to group into containers.
+   * @param containersIds the list of container ids that have to be used in the {@link ContainerModel}.
+   * @return the grouped {@link ContainerModel} built from the provided taskModels.
+   */
+  @Deprecated
+  default Set<ContainerModel> group(Set<TaskModel> taskModels, List<String> containersIds) {
+    return group(taskModels);
   }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskNameGrouperFactory.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskNameGrouperFactory.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskNameGrouperFactory.java
index 8b967b7..37684f4 100644
--- a/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskNameGrouperFactory.java
+++ b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskNameGrouperFactory.java
@@ -28,7 +28,7 @@ public interface TaskNameGrouperFactory {
   * Builds a {@link TaskNameGrouper}. The config can be used to read the necessary values which are needed in the
    * process of building the {@link TaskNameGrouper}
    *
-   * @param config configuration to which values can be used to build a {@link TaskNameGrouper}
+   * @param config configuration to use for building the {@link TaskNameGrouper}
    * @return a {@link TaskNameGrouper} implementation
    */
   TaskNameGrouper build(Config config);

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java b/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java
index 0110551..0c5e368 100644
--- a/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java
+++ b/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java
@@ -104,7 +104,7 @@ public class ExecutionPlanner {
     // currently we don't support host-affinity in batch mode
     if (appConfig.getAppMode() == ApplicationConfig.ApplicationMode.BATCH && clusterConfig.getHostAffinityEnabled()) {
       throw new SamzaException(String.format("Host affinity is not supported in batch mode. Please configure %s=false.",
-          ClusterManagerConfig.CLUSTER_MANAGER_HOST_AFFINITY_ENABLED));
+          ClusterManagerConfig.JOB_HOST_AFFINITY_ENABLED));
     }
   }
 

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
index 7328bc7..34d67cc 100644
--- a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
+++ b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
@@ -317,7 +317,8 @@ public class StreamProcessor {
     return SamzaContainer.apply(processorId, jobModel, ScalaJavaUtil.toScalaMap(this.customMetricsReporter),
         this.taskFactory, JobContextImpl.fromConfigWithDefaults(this.config),
         Option.apply(this.applicationDefinedContainerContextFactoryOptional.orElse(null)),
-        Option.apply(this.applicationDefinedTaskContextFactoryOptional.orElse(null)));
+        Option.apply(this.applicationDefinedTaskContextFactoryOptional.orElse(null)),
+        null);
   }
 
   private JobCoordinator createJobCoordinator() {

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java b/samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java
index c5c0d78..a5a45ba 100644
--- a/samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java
+++ b/samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java
@@ -22,6 +22,14 @@ package org.apache.samza.runtime;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Random;
+
+import org.apache.samza.container.ContainerHeartbeatClient;
+import org.apache.samza.container.ContainerHeartbeatMonitor;
+import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.SamzaContainer;
+import org.apache.samza.container.SamzaContainer$;
+import org.apache.samza.container.SamzaContainerListener;
+import org.apache.samza.metrics.MetricsRegistryMap;
 import org.slf4j.MDC;
 import org.apache.samza.SamzaException;
 import org.apache.samza.application.descriptors.ApplicationDescriptor;
@@ -31,11 +39,6 @@ import org.apache.samza.application.ApplicationUtil;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.ShellCommandConfig;
-import org.apache.samza.container.ContainerHeartbeatClient;
-import org.apache.samza.container.ContainerHeartbeatMonitor;
-import org.apache.samza.container.SamzaContainer;
-import org.apache.samza.container.SamzaContainer$;
-import org.apache.samza.container.SamzaContainerListener;
 import org.apache.samza.context.JobContextImpl;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.metrics.MetricsReporter;
@@ -47,7 +50,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.Option;
 
-
 /**
  * Launches and manages the lifecycle for {@link SamzaContainer}s in YARN.
  */
@@ -93,6 +95,7 @@ public class LocalContainerRunner {
   private static void run(ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc, String containerId,
       JobModel jobModel, Config config) {
     TaskFactory taskFactory = TaskFactoryUtil.getTaskFactory(appDesc);
+    LocalityManager localityManager = new LocalityManager(config, new MetricsRegistryMap());
     SamzaContainer container = SamzaContainer$.MODULE$.apply(
         containerId,
         jobModel,
@@ -100,7 +103,8 @@ public class LocalContainerRunner {
         taskFactory,
         JobContextImpl.fromConfigWithDefaults(config),
         Option.apply(appDesc.getApplicationContainerContextFactory().orElse(null)),
-        Option.apply(appDesc.getApplicationTaskContextFactory().orElse(null)));
+        Option.apply(appDesc.getApplicationTaskContextFactory().orElse(null)),
+        localityManager);
 
     ProcessorLifecycleListener listener = appDesc.getProcessorLifecycleListenerFactory()
         .createInstance(new ProcessorContext() { }, config);

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/standalone/PassthroughJobCoordinator.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/standalone/PassthroughJobCoordinator.java b/samza-core/src/main/java/org/apache/samza/standalone/PassthroughJobCoordinator.java
index 737ac3e..44fd811 100644
--- a/samza-core/src/main/java/org/apache/samza/standalone/PassthroughJobCoordinator.java
+++ b/samza-core/src/main/java/org/apache/samza/standalone/PassthroughJobCoordinator.java
@@ -18,16 +18,22 @@
  */
 package org.apache.samza.standalone;
 
+import com.google.common.collect.ImmutableMap;
 import org.apache.samza.checkpoint.CheckpointManager;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.ConfigException;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.TaskConfigJava;
+import org.apache.samza.container.grouper.task.GrouperMetadata;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.coordinator.JobCoordinator;
 import org.apache.samza.coordinator.JobModelManager;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.coordinator.JobCoordinatorListener;
+import org.apache.samza.runtime.LocationId;
+import org.apache.samza.runtime.LocationIdProvider;
+import org.apache.samza.runtime.LocationIdProviderFactory;
 import org.apache.samza.runtime.ProcessorIdGenerator;
 import org.apache.samza.storage.ChangelogStreamManager;
 import org.apache.samza.system.StreamMetadataCache;
@@ -35,7 +41,6 @@ import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.util.*;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-
 import java.util.Collections;
 
 /**
@@ -65,11 +70,15 @@ public class PassthroughJobCoordinator implements JobCoordinator {
   private static final Logger LOGGER = LoggerFactory.getLogger(PassthroughJobCoordinator.class);
   private final String processorId;
   private final Config config;
+  private final LocationId locationId;
   private JobCoordinatorListener coordinatorListener = null;
 
   public PassthroughJobCoordinator(Config config) {
     this.processorId = createProcessorId(config);
     this.config = config;
+    LocationIdProviderFactory locationIdProviderFactory = Util.getObj(new JobConfig(config).getLocationIdProviderFactory(), LocationIdProviderFactory.class);
+    LocationIdProvider locationIdProvider = locationIdProviderFactory.getLocationIdProvider(config);
+    this.locationId = locationIdProvider.getLocationId();
   }
 
   @Override
@@ -119,18 +128,13 @@ public class PassthroughJobCoordinator implements JobCoordinator {
     SystemAdmins systemAdmins = new SystemAdmins(config);
     StreamMetadataCache streamMetadataCache = new StreamMetadataCache(systemAdmins, 5000, SystemClock.instance());
     systemAdmins.start();
-    String containerId = Integer.toString(config.getInt(JobConfig.PROCESSOR_ID()));
-
-    /** TODO:
-     Locality Manager seems to be required in JC for reading locality info and grouping tasks intelligently and also,
-     in SamzaContainer for writing locality info to the coordinator stream. This closely couples together
-     TaskNameGrouper with the LocalityManager! Hence, groupers should be a property of the jobcoordinator
-     (job.coordinator.task.grouper, instead of task.systemstreampartition.grouper)
-     */
-    JobModel jobModel = JobModelManager.readJobModel(this.config, Collections.emptyMap(), null, streamMetadataCache,
-        Collections.singletonList(containerId));
-    systemAdmins.stop();
-    return jobModel;
+    try {
+      String containerId = Integer.toString(config.getInt(JobConfig.PROCESSOR_ID()));
+      GrouperMetadata grouperMetadata = new GrouperMetadataImpl(ImmutableMap.of(String.valueOf(containerId), locationId), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
+      return JobModelManager.readJobModel(this.config, Collections.emptyMap(), streamMetadataCache, grouperMetadata);
+    } finally {
+      systemAdmins.stop();
+    }
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java b/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
index 64ae310..5442d6e 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
@@ -125,12 +125,13 @@ public class StorageRecovery extends CommandLine {
    * map
    */
   private void getContainerModels() {
-    CoordinatorStreamManager coordinatorStreamManager = new CoordinatorStreamManager(jobConfig, new MetricsRegistryMap());
+    MetricsRegistryMap metricsRegistryMap = new MetricsRegistryMap();
+    CoordinatorStreamManager coordinatorStreamManager = new CoordinatorStreamManager(jobConfig, metricsRegistryMap);
     coordinatorStreamManager.register(getClass().getSimpleName());
     coordinatorStreamManager.start();
     coordinatorStreamManager.bootstrap();
     ChangelogStreamManager changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager);
-    JobModel jobModel = JobModelManager.apply(coordinatorStreamManager.getConfig(), changelogStreamManager.readPartitionMapping()).jobModel();
+    JobModel jobModel = JobModelManager.apply(coordinatorStreamManager.getConfig(), changelogStreamManager.readPartitionMapping(), metricsRegistryMap).jobModel();
     containers = jobModel.getContainers();
     coordinatorStreamManager.stop();
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/zk/ZkJobCoordinator.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/zk/ZkJobCoordinator.java b/samza-core/src/main/java/org/apache/samza/zk/ZkJobCoordinator.java
index 81c0465..8c5a3ba 100644
--- a/samza-core/src/main/java/org/apache/samza/zk/ZkJobCoordinator.java
+++ b/samza-core/src/main/java/org/apache/samza/zk/ZkJobCoordinator.java
@@ -20,6 +20,7 @@ package org.apache.samza.zk;
 
 import com.google.common.annotations.VisibleForTesting;
 
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.HashMap;
 import java.util.List;
@@ -40,6 +41,8 @@ import org.apache.samza.config.TaskConfigJava;
 import org.apache.samza.config.StorageConfig;
 import org.apache.samza.config.ZkConfig;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.container.grouper.task.GrouperMetadata;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.coordinator.JobCoordinator;
 import org.apache.samza.coordinator.JobCoordinatorListener;
 import org.apache.samza.coordinator.JobModelManager;
@@ -47,17 +50,23 @@ import org.apache.samza.coordinator.LeaderElectorListener;
 import org.apache.samza.coordinator.StreamPartitionCountMonitor;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
+import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.metrics.MetricsReporter;
 import org.apache.samza.metrics.ReadableMetricsRegistry;
+import org.apache.samza.runtime.LocationId;
+import org.apache.samza.runtime.LocationIdProvider;
+import org.apache.samza.runtime.LocationIdProviderFactory;
 import org.apache.samza.runtime.ProcessorIdGenerator;
 import org.apache.samza.storage.ChangelogStreamManager;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.util.MetricsReporterLoader;
 import org.apache.samza.util.SystemClock;
 import org.apache.samza.util.Util;
+import org.apache.samza.zk.ZkUtils.ProcessorNode;
 import org.apache.zookeeper.Watcher;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -98,6 +107,7 @@ public class ZkJobCoordinator implements JobCoordinator {
   private final SystemAdmins systemAdmins;
   private final int debounceTimeMs;
   private final Map<TaskName, Integer> changeLogPartitionMap = new HashMap<>();
+  private final LocationId locationId;
 
   private JobCoordinatorListener coordinatorListener = null;
   private JobModel newJobModel;
@@ -131,14 +141,16 @@ public class ZkJobCoordinator implements JobCoordinator {
     this.barrier =  new ZkBarrierForVersionUpgrade(zkUtils.getKeyBuilder().getJobModelVersionBarrierPrefix(), zkUtils, new ZkBarrierListenerImpl(), debounceTimer);
     systemAdmins = new SystemAdmins(config);
     streamMetadataCache = new StreamMetadataCache(systemAdmins, METADATA_CACHE_TTL_MS, SystemClock.instance());
+    LocationIdProviderFactory locationIdProviderFactory = Util.getObj(new JobConfig(config).getLocationIdProviderFactory(), LocationIdProviderFactory.class);
+    LocationIdProvider locationIdProvider = locationIdProviderFactory.getLocationIdProvider(config);
+    this.locationId = locationIdProvider.getLocationId();
   }
 
   @Override
   public void start() {
     ZkKeyBuilder keyBuilder = zkUtils.getKeyBuilder();
     zkUtils.validateZkVersion();
-    zkUtils.validatePaths(new String[]{keyBuilder.getProcessorsPath(), keyBuilder.getJobModelVersionPath(), keyBuilder
-        .getJobModelPathPrefix()});
+    zkUtils.validatePaths(new String[]{keyBuilder.getProcessorsPath(), keyBuilder.getJobModelVersionPath(), keyBuilder.getJobModelPathPrefix(), keyBuilder.getTaskLocalityPath()});
 
     startMetrics();
     systemAdmins.start();
@@ -246,7 +258,13 @@ public class ZkJobCoordinator implements JobCoordinator {
   }
 
   void doOnProcessorChange() {
-    List<String> currentProcessorIds = zkUtils.getSortedActiveProcessorsIDs();
+    List<ProcessorNode> processorNodes = zkUtils.getAllProcessorNodes();
+
+    List<String> currentProcessorIds = new ArrayList<>();
+    for (ProcessorNode processorNode : processorNodes) {
+      currentProcessorIds.add(processorNode.getProcessorData().getProcessorId());
+    }
+
     Set<String> uniqueProcessorIds = new HashSet<>(currentProcessorIds);
 
     if (currentProcessorIds.size() != uniqueProcessorIds.size()) {
@@ -256,7 +274,7 @@ public class ZkJobCoordinator implements JobCoordinator {
 
     // Generate the JobModel
     LOG.info("Generating new JobModel with processors: {}.", currentProcessorIds);
-    JobModel jobModel = generateNewJobModel(currentProcessorIds);
+    JobModel jobModel = generateNewJobModel(processorNodes);
 
     // Create checkpoint and changelog streams if they don't exist
     if (!hasCreatedStreams) {
@@ -308,7 +326,7 @@ public class ZkJobCoordinator implements JobCoordinator {
   /**
    * Generate new JobModel when becoming a leader or the list of processor changed.
    */
-  private JobModel generateNewJobModel(List<String> processors) {
+  private JobModel generateNewJobModel(List<ProcessorNode> processorNodes) {
     String zkJobModelVersion = zkUtils.getJobModelVersion();
     // If JobModel exists in zookeeper && cached JobModel version is unequal to JobModel version stored in zookeeper.
     if (zkJobModelVersion != null && !Objects.equals(cachedJobModelVersion, zkJobModelVersion)) {
@@ -318,11 +336,9 @@ public class ZkJobCoordinator implements JobCoordinator {
       }
       cachedJobModelVersion = zkJobModelVersion;
     }
-    /**
-     * Host affinity is not supported in standalone. Hence, LocalityManager(which is responsible for container
-     * to host mapping) is passed in as null when building the jobModel.
-     */
-    JobModel model = JobModelManager.readJobModel(this.config, changeLogPartitionMap, null, streamMetadataCache, processors);
+
+    GrouperMetadata grouperMetadata = getGrouperMetadata(zkJobModelVersion, processorNodes);
+    JobModel model = JobModelManager.readJobModel(config, changeLogPartitionMap, streamMetadataCache, grouperMetadata);
     return new JobModel(new MapConfig(), model.getContainers());
   }
 
@@ -343,6 +359,39 @@ public class ZkJobCoordinator implements JobCoordinator {
       });
   }
 
+  /**
+   * Builds the {@link GrouperMetadataImpl} based upon the provided {@code jobModelVersion}
+   * and {@code processorNodes}.
+   * @param jobModelVersion the most recent jobModelVersion available in the zookeeper.
+   * @param processorNodes the list of live processors in the zookeeper.
+   * @return the built grouper metadata.
+   */
+  private GrouperMetadataImpl getGrouperMetadata(String jobModelVersion, List<ProcessorNode> processorNodes) {
+    Map<TaskName, String> taskToProcessorId = new HashMap<>();
+    Map<TaskName, List<SystemStreamPartition>> taskToSSPs = new HashMap<>();
+    if (jobModelVersion != null) {
+      JobModel jobModel = zkUtils.getJobModel(jobModelVersion);
+      for (ContainerModel containerModel : jobModel.getContainers().values()) {
+        for (TaskModel taskModel : containerModel.getTasks().values()) {
+          taskToProcessorId.put(taskModel.getTaskName(), containerModel.getId());
+          for (SystemStreamPartition partition : taskModel.getSystemStreamPartitions()) {
+            taskToSSPs.computeIfAbsent(taskModel.getTaskName(), k -> new ArrayList<>());
+            taskToSSPs.get(taskModel.getTaskName()).add(partition);
+          }
+        }
+      }
+    }
+
+    Map<String, LocationId> processorLocality = new HashMap<>();
+    for (ProcessorNode processorNode : processorNodes) {
+      ProcessorData processorData = processorNode.getProcessorData();
+      processorLocality.put(processorData.getProcessorId(), processorData.getLocationId());
+    }
+
+    Map<TaskName, LocationId> taskLocality = zkUtils.readTaskLocality();
+    return new GrouperMetadataImpl(processorLocality, taskLocality, taskToSSPs, taskToProcessorId);
+  }
+
   class LeaderElectorListenerImpl implements LeaderElectorListener {
     @Override
     public void onBecomingLeader() {
@@ -390,6 +439,11 @@ public class ZkJobCoordinator implements JobCoordinator {
             JobModel jobModel = getJobModel();
             // start the container with the new model
             if (coordinatorListener != null) {
+              for (ContainerModel containerModel : jobModel.getContainers().values()) {
+                for (TaskName taskName : containerModel.getTasks().keySet()) {
+                  zkUtils.writeTaskLocality(taskName, locationId);
+                }
+              }
               coordinatorListener.onNewJobModel(processorId, jobModel);
             }
           });

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java b/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java
index 6349432..56ea577 100644
--- a/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java
+++ b/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java
@@ -24,6 +24,8 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
+import java.util.Map;
+import java.util.HashMap;
 import java.util.Objects;
 import java.util.TreeSet;
 import java.util.concurrent.TimeUnit;
@@ -37,8 +39,10 @@ import org.I0Itec.zkclient.exception.ZkInterruptedException;
 import org.I0Itec.zkclient.exception.ZkNodeExistsException;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.samza.SamzaException;
+import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.runtime.LocationId;
 import org.apache.samza.serializers.model.SamzaObjectMapper;
 import org.apache.zookeeper.data.Stat;
 import org.codehaus.jackson.map.ObjectMapper;
@@ -239,6 +243,29 @@ public class ZkUtils {
     return processorNodes;
   }
 
+  public void writeTaskLocality(TaskName taskName, LocationId locationId) {
+    String taskLocalityPath = String.format("%s/%s", keyBuilder.getTaskLocalityPath(), taskName);
+    validatePaths(new String[] {taskLocalityPath});
+    writeData(taskLocalityPath, locationId.getId());
+  }
+
+  public Map<TaskName, LocationId> readTaskLocality() {
+    Map<TaskName, LocationId> taskLocality = new HashMap<>();
+    String taskLocalityPath = keyBuilder.getTaskLocalityPath();
+    List<String> tasks = new ArrayList<>();
+    if (zkClient.exists(taskLocalityPath)) {
+      tasks = zkClient.getChildren(taskLocalityPath);
+    }
+    for (String taskName : tasks) {
+      String taskPath = String.format("%s/%s", keyBuilder.getTaskLocalityPath(), taskName);
+      String locationId = zkClient.readData(taskPath, true);
+      if (locationId != null) {
+        taskLocality.put(new TaskName(taskName), new LocationId(locationId));
+      }
+    }
+    return taskLocality;
+  }
+
   /**
    * Method is used to get the <i>sorted</i> list of currently active/registered processors (znodes)
    *

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index 03effe6..865658f 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -127,8 +127,8 @@ object SamzaContainer extends Logging {
     taskFactory: TaskFactory[_],
     jobContext: JobContext,
     applicationContainerContextFactoryOption: Option[ApplicationContainerContextFactory[ApplicationContainerContext]],
-    applicationTaskContextFactoryOption: Option[ApplicationTaskContextFactory[ApplicationTaskContext]]
-  ) = {
+    applicationTaskContextFactoryOption: Option[ApplicationTaskContextFactory[ApplicationTaskContext]],
+    localityManager: LocalityManager = null) = {
     val config = jobContext.getConfig
     val containerModel = jobModel.getContainers.get(containerId)
     val containerName = "samza-container-%s" format containerId
@@ -735,6 +735,7 @@ object SamzaContainer extends Logging {
       systemAdmins = systemAdmins,
       consumerMultiplexer = consumerMultiplexer,
       producerMultiplexer = producerMultiplexer,
+      localityManager = localityManager,
       offsetManager = offsetManager,
       securityManager = securityManager,
       metrics = samzaContainerMetrics,
@@ -990,16 +991,13 @@ class SamzaContainer(
 
   def storeContainerLocality {
     val isHostAffinityEnabled: Boolean = new ClusterManagerConfig(config).getHostAffinityEnabled
-    if (isHostAffinityEnabled) {
-      val localityManager: LocalityManager = new LocalityManager(config, containerContext.getContainerMetricsRegistry)
+    if (isHostAffinityEnabled && localityManager != null) {
       val containerId = containerContext.getContainerModel.getId
       val containerName = "SamzaContainer-" + containerId
       info("Registering %s with metadata store" format containerName)
       try {
         val hostInet = Util.getLocalHost
-        val jmxUrl = if (jmxServer != null) jmxServer.getJmxUrl else ""
-        val jmxTunnelingUrl = if (jmxServer != null) jmxServer.getTunnelingJmxUrl else ""
-        info("Writing container locality and JMX address to metadata store")
+        info("Writing container locality to metadata store")
         localityManager.writeContainerToHostMapping(containerId, hostInet.getHostName)
       } catch {
         case uhe: UnknownHostException =>


[2/3] samza git commit: SAMZA-1973: Unify the TaskNameGrouper interface for yarn and standalone.

Posted by ja...@apache.org.
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
index 600b7a1..5fb71f3 100644
--- a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
+++ b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
@@ -19,31 +19,33 @@
 
 package org.apache.samza.coordinator
 
-
 import java.util
 import java.util.concurrent.atomic.AtomicReference
-
 import org.apache.samza.config._
 import org.apache.samza.config.JobConfig.Config2Job
 import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.config.TaskConfig.Config2Task
 import org.apache.samza.config.Config
 import org.apache.samza.container.grouper.stream.SystemStreamPartitionGrouperFactory
-import org.apache.samza.container.grouper.task.BalancingTaskNameGrouper
-import org.apache.samza.container.grouper.task.TaskNameGrouperFactory
+import org.apache.samza.container.grouper.task._
 import org.apache.samza.container.LocalityManager
 import org.apache.samza.container.TaskName
 import org.apache.samza.coordinator.server.HttpServer
 import org.apache.samza.coordinator.server.JobServlet
+import org.apache.samza.job.model.ContainerModel
 import org.apache.samza.job.model.JobModel
 import org.apache.samza.job.model.TaskModel
+import org.apache.samza.metrics.MetricsRegistry
 import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.system._
 import org.apache.samza.util.Logging
 import org.apache.samza.util.Util
 import org.apache.samza.Partition
+import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping
+import org.apache.samza.runtime.LocationId
 
 import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
 
 /**
  * Helper companion object that is responsible for wiring up a JobModelManager
@@ -51,66 +53,145 @@ import scala.collection.JavaConverters._
  */
 object JobModelManager extends Logging {
 
-  val SOURCE = "JobModelManager"
   /**
    * a volatile value to store the current instantiated <code>JobModelManager</code>
    */
-  @volatile var currentJobModelManager: JobModelManager = null
+  @volatile var currentJobModelManager: JobModelManager = _
   val jobModelRef: AtomicReference[JobModel] = new AtomicReference[JobModel]()
 
   /**
-   * Does the following actions for a job.
+   * Currently used only in the ApplicationMaster for the yarn deployment model.
+   * Does the following:
    * a) Reads the jobModel from coordinator stream using the job's configuration.
-   * b) Recomputes changelog partition mapping based on jobModel and job's configuration.
+   * b) Recomputes the changelog partition mapping based on jobModel and job's configuration.
    * c) Builds JobModelManager using the jobModel read from coordinator stream.
-   * @param config Config from the coordinator stream.
-   * @param changelogPartitionMapping The changelog partition-to-task mapping.
-   * @return JobModelManager
+   * @param config config from the coordinator stream.
+   * @param changelogPartitionMapping changelog partition-to-task mapping of the samza job.
+   * @param metricsRegistry the registry for reporting metrics.
+   * @return the instantiated {@see JobModelManager}.
    */
-  def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, Integer]) = {
-    val localityManager = new LocalityManager(config, new MetricsRegistryMap())
-
-    // Map the name of each system to the corresponding SystemAdmin
+  def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, Integer], metricsRegistry: MetricsRegistry = new MetricsRegistryMap()): JobModelManager = {
+    val localityManager = new LocalityManager(config, metricsRegistry)
+    val taskAssignmentManager = new TaskAssignmentManager(config, metricsRegistry)
     val systemAdmins = new SystemAdmins(config)
-    val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0)
+    try {
+      systemAdmins.start()
+      val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0)
+      val grouperMetadata: GrouperMetadata = getGrouperMetadata(config, localityManager, taskAssignmentManager)
 
-    val containerCount = new JobConfig(config).getContainerCount
-    val processorList = List.range(0, containerCount).map(c => c.toString)
+      val jobModel: JobModel = readJobModel(config, changelogPartitionMapping, streamMetadataCache, grouperMetadata)
+      jobModelRef.set(new JobModel(jobModel.getConfig, jobModel.getContainers, localityManager))
 
-    systemAdmins.start()
-    val jobModelManager = getJobModelManager(config, changelogPartitionMapping, localityManager, streamMetadataCache, processorList.asJava)
-    systemAdmins.stop()
+      updateTaskAssignments(jobModel, taskAssignmentManager, grouperMetadata)
 
-    jobModelManager
+      val server = new HttpServer
+      server.addServlet("/", new JobServlet(jobModelRef))
+
+      currentJobModelManager = new JobModelManager(jobModel, server, localityManager)
+      currentJobModelManager
+    } finally {
+      taskAssignmentManager.close()
+      systemAdmins.stop()
+      // Not closing localityManager, since {@code ClusterBasedJobCoordinator} uses it to read container locality through {@code JobModel}.
+    }
   }
 
   /**
-   * Build a JobModelManager using a Samza job's configuration.
-   */
-  private def getJobModelManager(config: Config,
-                                changeLogMapping: util.Map[TaskName, Integer],
-                                localityManager: LocalityManager,
-                                streamMetadataCache: StreamMetadataCache,
-                                containerIds: java.util.List[String]) = {
-    val jobModel: JobModel = readJobModel(config, changeLogMapping, localityManager, streamMetadataCache, containerIds)
-    jobModelRef.set(jobModel)
-
-    val server = new HttpServer
-    server.addServlet("/", new JobServlet(jobModelRef))
-    currentJobModelManager = new JobModelManager(jobModel, server, localityManager)
-    currentJobModelManager
+    * Builds the {@see GrouperMetadataImpl} for the samza job.
+    * @param config represents the configurations defined by the user.
+    * @param localityManager provides the processor to host mapping persisted to the metadata store.
+    * @param taskAssignmentManager provides the processor to task assignments persisted to the metadata store.
+    * @return the instantiated {@see GrouperMetadata}.
+    */
+  def getGrouperMetadata(config: Config, localityManager: LocalityManager, taskAssignmentManager: TaskAssignmentManager) = {
+    val processorLocality: util.Map[String, LocationId] = getProcessorLocality(config, localityManager)
+    val taskAssignment: util.Map[String, String] = taskAssignmentManager.readTaskAssignment()
+    val taskNameToProcessorId: util.Map[TaskName, String] = new util.HashMap[TaskName, String]()
+    for ((taskName, processorId) <- taskAssignment) {
+      taskNameToProcessorId.put(new TaskName(taskName), processorId)
+    }
+
+    val taskLocality:util.Map[TaskName, LocationId] = new util.HashMap[TaskName, LocationId]()
+    for ((taskName, processorId) <- taskAssignment) {
+      if (processorLocality.containsKey(processorId)) {
+        taskLocality.put(new TaskName(taskName), processorLocality.get(processorId))
+      }
+    }
+    new GrouperMetadataImpl(processorLocality, taskLocality, new util.HashMap[TaskName, util.List[SystemStreamPartition]](), taskNameToProcessorId)
   }
 
   /**
-   * For each input stream specified in config, exactly determine its
-   * partitions, returning a set of SystemStreamPartitions containing them all.
-   */
-  private def getInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache) = {
+    * Retrieves and returns the processor locality of a samza job using provided {@see Config} and {@see LocalityManager}.
+    * @param config provides the configurations defined by the user. Required to connect to the storage layer.
+    * @param localityManager provides the processor to host mapping persisted to the metadata store.
+    * @return the processor locality.
+    */
+  def getProcessorLocality(config: Config, localityManager: LocalityManager) = {
+    val containerToLocationId: util.Map[String, LocationId] = new util.HashMap[String, LocationId]()
+    val existingContainerLocality = localityManager.readContainerLocality()
+
+    for (containerId <- 0 to config.getContainerCount) {
+      val localityMapping = existingContainerLocality.get(containerId.toString)
+      // To handle the case when the container count is increased between two different runs of a samza-yarn job,
+      // set the locality of newly added containers to any_host.
+      var locationId: LocationId = new LocationId("ANY_HOST")
+      if (localityMapping != null && localityMapping.containsKey(SetContainerHostMapping.HOST_KEY)) {
+        locationId = new LocationId(localityMapping.get(SetContainerHostMapping.HOST_KEY))
+      }
+      containerToLocationId.put(containerId.toString, locationId)
+    }
+
+    containerToLocationId
+  }
+
+  /**
+    * This method does the following:
+    * 1. Deletes the existing task assignments if the partition-task grouping has changed from the previous run of the job.
+    * 2. Saves the newly generated task assignments to the storage layer through the {@param taskAssignmentManager}.
+    *
+    * @param jobModel              represents the {@see JobModel} of the samza job.
+    * @param taskAssignmentManager required to persist the processor to task assignments to the storage layer.
+    * @param grouperMetadata       provides the historical metadata of the application.
+    */
+  def updateTaskAssignments(jobModel: JobModel, taskAssignmentManager: TaskAssignmentManager, grouperMetadata: GrouperMetadata): Unit = {
+    val taskNames: util.Set[String] = new util.HashSet[String]()
+    for (container <- jobModel.getContainers.values()) {
+      for (taskModel <- container.getTasks.values()) {
+        taskNames.add(taskModel.getTaskName.getTaskName)
+      }
+    }
+    val taskToContainerId = grouperMetadata.getPreviousTaskToProcessorAssignment
+    if (taskNames.size() != taskToContainerId.size()) {
+      warn("Current task count {} does not match saved task count {}. Stateful jobs may observe misalignment of keys!",
+           taskNames.size(), taskToContainerId.size())
+      // If the tasks changed, then the partition-task grouping is also likely changed and we can't handle that
+      // without a much more complicated mapping. Further, the partition count may have changed, which means
+      // input message keys are likely reshuffled w.r.t. partitions, so the local state may not contain necessary
+      // data associated with the incoming keys. Warn the user and fall back to the grouper's new assignment.
+      // In this scenario the task count may also have shrunk, so we delete all the existing task-container mappings.
+      taskAssignmentManager.deleteTaskContainerMappings(taskNames)
+    }
+
+    for (container <- jobModel.getContainers.values()) {
+      for (taskName <- container.getTasks.keySet) {
+        taskAssignmentManager.writeTaskContainerMapping(taskName.getTaskName, container.getId)
+      }
+    }
+  }
+
+  /**
+    * Computes the input system stream partitions of a samza job using the provided {@param config}
+    * and {@param streamMetadataCache}.
+    * @param config the configuration of the job.
+    * @param streamMetadataCache to query the partition metadata of the input streams.
+    * @return the input {@see SystemStreamPartition} of the samza job.
+    */
+  private def getInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache): Set[SystemStreamPartition] = {
     val inputSystemStreams = config.getInputStreams
 
     // Get the set of partitions for each SystemStream from the stream metadata
     streamMetadataCache
-      .getStreamMetadata(inputSystemStreams, true)
+      .getStreamMetadata(inputSystemStreams, partitionsMetadataOnly = true)
       .flatMap {
         case (systemStream, metadata) =>
           metadata
@@ -121,55 +202,69 @@ object JobModelManager extends Logging {
       }.toSet
   }
 
+  /**
+    * Builds the input {@see SystemStreamPartition} based upon the {@param config} defined by the user.
+    * @param config configuration to fetch the metadata of the input streams.
+    * @param streamMetadataCache required to query the partition metadata of the input streams.
+    * @return the input SystemStreamPartitions of the job.
+    */
   private def getMatchedInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache): Set[SystemStreamPartition] = {
     val allSystemStreamPartitions = getInputStreamPartitions(config, streamMetadataCache)
     config.getSSPMatcherClass match {
-      case Some(s) => {
+      case Some(s) =>
         val jfr = config.getSSPMatcherConfigJobFactoryRegex.r
         config.getStreamJobFactoryClass match {
-          case Some(jfr(_*)) => {
-            info("before match: allSystemStreamPartitions.size = %s" format (allSystemStreamPartitions.size))
+          case Some(jfr(_*)) =>
+            info("before match: allSystemStreamPartitions.size = %s" format allSystemStreamPartitions.size)
             val sspMatcher = Util.getObj(s, classOf[SystemStreamPartitionMatcher])
             val matchedPartitions = sspMatcher.filter(allSystemStreamPartitions.asJava, config).asScala.toSet
             // Usually a small set hence ok to log at info level
-            info("after match: matchedPartitions = %s" format (matchedPartitions))
+            info("after match: matchedPartitions = %s" format matchedPartitions)
             matchedPartitions
-          }
           case _ => allSystemStreamPartitions
         }
-      }
       case _ => allSystemStreamPartitions
     }
   }
 
   /**
-   * Gets a SystemStreamPartitionGrouper object from the configuration.
-   */
+    * Finds the {@see SystemStreamPartitionGrouperFactory} from the {@param config}. Instantiates the {@see SystemStreamPartitionGrouper}
+    * object through the factory.
+    * @param config the configuration of the samza job.
+    * @return the instantiated {@see SystemStreamPartitionGrouper}.
+    */
   private def getSystemStreamPartitionGrouper(config: Config) = {
     val factoryString = config.getSystemStreamPartitionGrouperFactory
     val factory = Util.getObj(factoryString, classOf[SystemStreamPartitionGrouperFactory])
     factory.getSystemStreamPartitionGrouper(config)
   }
 
+
   /**
-   * The function reads the latest checkpoint from the underlying coordinator stream and
-   * builds a new JobModel.
-   */
+    * Does the following:
+    * 1. Fetches metadata of the input streams defined in configuration through {@param streamMetadataCache}.
+    * 2. Applies the {@see SystemStreamPartitionGrouper}, {@see TaskNameGrouper} defined in the configuration
+    * to build the {@see JobModel}.
+    * @param config the configuration of the job.
+    * @param changeLogPartitionMapping the task to changelog partition mapping of the job.
+    * @param streamMetadataCache the cache that holds the partition metadata of the input streams.
+    * @param grouperMetadata provides the historical metadata of the application.
+    * @return the built {@see JobModel}.
+    */
   def readJobModel(config: Config,
                    changeLogPartitionMapping: util.Map[TaskName, Integer],
-                   localityManager: LocalityManager,
                    streamMetadataCache: StreamMetadataCache,
-                   containerIds: java.util.List[String]): JobModel = {
+                   grouperMetadata: GrouperMetadata): JobModel = {
     // Do grouping to fetch TaskName to SSP mapping
     val allSystemStreamPartitions = getMatchedInputStreamPartitions(config, streamMetadataCache)
 
     // processor list is required by some of the groupers. So, let's pass them as part of the config.
     // Copy the config and add the processor list to the config copy.
     val configMap = new util.HashMap[String, String](config)
-    configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", containerIds))
+    configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", grouperMetadata.getProcessorLocality.keySet()))
     val grouper = getSystemStreamPartitionGrouper(new MapConfig(configMap))
 
-    val groups = grouper.group(allSystemStreamPartitions.asJava)
+    val groups = grouper.group(allSystemStreamPartitions)
     info("SystemStreamPartitionGrouper %s has grouped the SystemStreamPartitions into %d tasks with the following taskNames: %s" format(grouper, groups.size(), groups.keySet()))
 
     val isHostAffinityEnabled = new ClusterManagerConfig(config).getHostAffinityEnabled
@@ -200,22 +295,18 @@ object JobModelManager extends Logging {
     // SSPTaskNameGrouper for locality, load-balancing, etc.
     val containerGrouperFactory = Util.getObj(config.getTaskNameGrouperFactory, classOf[TaskNameGrouperFactory])
     val containerGrouper = containerGrouperFactory.build(config)
-    val containerModels = {
-      containerGrouper match {
-        case grouper: BalancingTaskNameGrouper if isHostAffinityEnabled => grouper.balance(taskModels.asJava, localityManager)
-        case _ => containerGrouper.group(taskModels.asJava, containerIds)
-      }
-    }
-    val containerMap = containerModels.asScala.map { case (containerModel) => containerModel.getId -> containerModel }.toMap
-
-    if (isHostAffinityEnabled) {
-      new JobModel(config, containerMap.asJava, localityManager)
+    var containerModels: util.Set[ContainerModel] = null
+    if(isHostAffinityEnabled) {
+      containerModels = containerGrouper.group(taskModels, grouperMetadata)
     } else {
-      new JobModel(config, containerMap.asJava)
+      containerModels = containerGrouper.group(taskModels, new util.ArrayList[String](grouperMetadata.getProcessorLocality.keySet()))
     }
+    val containerMap = containerModels.asScala.map(containerModel => containerModel.getId -> containerModel).toMap
+
+    new JobModel(config, containerMap.asJava)
   }
 
-  private def getSystemNames(config: Config) = config.getSystemNames.toSet
+  private def getSystemNames(config: Config) = config.getSystemNames().toSet
 }
 
 /**
@@ -248,7 +339,7 @@ class JobModelManager(
 
   debug("Got job model: %s." format jobModel)
 
-  def start {
+  def start() {
     if (server != null) {
       debug("Starting HTTP server.")
       server.start
@@ -256,7 +347,7 @@ class JobModelManager(
     }
   }
 
-  def stop {
+  def stop() {
     if (server != null) {
       debug("Stopping HTTP server.")
       server.stop

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
index 64f516b..d16c294 100644
--- a/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
+++ b/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
@@ -50,7 +50,7 @@ class ProcessJobFactory extends StreamJobFactory with Logging {
     coordinatorStreamManager.bootstrap
     val changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager)
 
-    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping())
+    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping(), metricsRegistry)
     val jobModel = coordinator.jobModel
 
     val taskPartitionMappings: util.Map[TaskName, Integer] = new util.HashMap[TaskName, Integer]

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
index 5a8d2f8..e4a7838 100644
--- a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
+++ b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
@@ -52,7 +52,7 @@ class ThreadJobFactory extends StreamJobFactory with Logging {
     coordinatorStreamManager.bootstrap
     val changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager)
 
-    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping())
+    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping(), metricsRegistry)
 
     val jobModel = coordinator.jobModel
 

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
index 0c2f2fb..9e6e8d0 100644
--- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
@@ -24,54 +24,34 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
-import org.apache.samza.SamzaException;
-import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
-import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.LocalityManager;
+
+import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
-import org.junit.Before;
+import org.apache.samza.SamzaException;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mockito;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
 
 import static org.apache.samza.container.mock.ContainerMocks.*;
 import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
 
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class})
 public class TestGroupByContainerCount {
-  private TaskAssignmentManager taskAssignmentManager;
-  private LocalityManager localityManager;
-  @Before
-  public void setup() throws Exception {
-    taskAssignmentManager = mock(TaskAssignmentManager.class);
-    localityManager = mock(LocalityManager.class);
-    PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager);
-    Mockito.doNothing().when(taskAssignmentManager).init();
-  }
 
   @Test(expected = IllegalArgumentException.class)
   public void testGroupEmptyTasks() {
-    new GroupByContainerCount(getConfig(1)).group(new HashSet());
+    new GroupByContainerCount(1).group(new HashSet<>());
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testGroupFewerTasksThanContainers() {
     Set<TaskModel> taskModels = new HashSet<>();
     taskModels.add(getTaskModel(1));
-    new GroupByContainerCount(getConfig(2)).group(taskModels);
+    new GroupByContainerCount(2).group(taskModels);
   }
 
   @Test(expected = UnsupportedOperationException.class)
   public void testGrouperResultImmutable() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(3)).group(taskModels);
+    Set<ContainerModel> containers = new GroupByContainerCount(3).group(taskModels);
     containers.remove(containers.iterator().next());
   }
 
@@ -79,7 +59,7 @@ public class TestGroupByContainerCount {
   public void testGroupHappyPath() {
     Set<TaskModel> taskModels = generateTaskModels(5);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).group(taskModels);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -106,7 +86,7 @@ public class TestGroupByContainerCount {
   public void testGroupManyTasks() {
     Set<TaskModel> taskModels = generateTaskModels(21);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).group(taskModels);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -174,11 +154,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerIncrease() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(4)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(4).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -213,22 +193,6 @@ public class TestGroupByContainerCount {
     assertTrue(container2.getTasks().containsKey(getTaskName(6)));
     assertTrue(container3.getTasks().containsKey(getTaskName(5)));
     assertTrue(container3.getTasks().containsKey(getTaskName(7)));
-
-    // Verify task mappings are saved
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1");
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "2");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "2");
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "3");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "3");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -256,11 +220,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerDecrease() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(4)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -290,20 +254,6 @@ public class TestGroupByContainerCount {
     assertTrue(container0.getTasks().containsKey(getTaskName(2)));
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
     assertTrue(container1.getTasks().containsKey(getTaskName(3)));
-
-    // Verify task mappings are saved
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -331,15 +281,15 @@ public class TestGroupByContainerCount {
    *  T8  T7  T3
    */
   @Test
-  public void testBalancerMultipleReblances() throws Exception {
+  public void testBalancerMultipleRebalances() {
     // Before
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(4)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
     // First balance
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -370,30 +320,11 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
     assertTrue(container1.getTasks().containsKey(getTaskName(3)));
 
-    // Verify task mappings are saved
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
-
-
     // Second balance
     prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
 
-    TaskAssignmentManager taskAssignmentManager2 = mock(TaskAssignmentManager.class);
-    when(taskAssignmentManager2.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
-    LocalityManager localityManager2 = mock(LocalityManager.class);
-    PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager2);
-
-    containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager2);
+    GrouperMetadataImpl grouperMetadata1 = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    containers = new GroupByContainerCount(3).group(taskModels, grouperMetadata1);
 
     containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -427,21 +358,6 @@ public class TestGroupByContainerCount {
     assertTrue(container2.getTasks().containsKey(getTaskName(6)));
     assertTrue(container2.getTasks().containsKey(getTaskName(2)));
     assertTrue(container2.getTasks().containsKey(getTaskName(3)));
-
-    // Verify task mappings are saved
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0");
-
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1");
-
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(6).getTaskName(), "2");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(2).getTaskName(), "2");
-    verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(3).getTaskName(), "2");
-
-    verify(taskAssignmentManager2, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -466,11 +382,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerSame() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -496,9 +412,6 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(3)));
     assertTrue(container1.getTasks().containsKey(getTaskName(5)));
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
-
-    verify(taskAssignmentManager, never()).writeTaskContainerMapping(anyString(), anyString());
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -528,19 +441,19 @@ public class TestGroupByContainerCount {
   public void testBalancerAfterContainerSameCustomAssignment() {
     Set<TaskModel> taskModels = generateTaskModels(9);
 
-    Map<String, String> prevTaskToContainerMapping = new HashMap<>();
-    prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(6).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(7).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(8).getTaskName(), "1");
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Map<TaskName, String> prevTaskToContainerMapping = new HashMap<>();
+    prevTaskToContainerMapping.put(getTaskName(0), "0");
+    prevTaskToContainerMapping.put(getTaskName(1), "0");
+    prevTaskToContainerMapping.put(getTaskName(2), "0");
+    prevTaskToContainerMapping.put(getTaskName(3), "0");
+    prevTaskToContainerMapping.put(getTaskName(4), "0");
+    prevTaskToContainerMapping.put(getTaskName(5), "0");
+    prevTaskToContainerMapping.put(getTaskName(6), "1");
+    prevTaskToContainerMapping.put(getTaskName(7), "1");
+    prevTaskToContainerMapping.put(getTaskName(8), "1");
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -566,9 +479,6 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(6)));
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
     assertTrue(container1.getTasks().containsKey(getTaskName(8)));
-
-    verify(taskAssignmentManager, never()).writeTaskContainerMapping(anyString(), anyString());
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -597,16 +507,16 @@ public class TestGroupByContainerCount {
   public void testBalancerAfterContainerSameCustomAssignmentAndContainerIncrease() {
     Set<TaskModel> taskModels = generateTaskModels(6);
 
-    Map<String, String> prevTaskToContainerMapping = new HashMap<>();
-    prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "1");
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Map<TaskName, String> prevTaskToContainerMapping = new HashMap<>();
+    prevTaskToContainerMapping.put(getTaskName(0), "0");
+    prevTaskToContainerMapping.put(getTaskName(1), "1");
+    prevTaskToContainerMapping.put(getTaskName(2), "1");
+    prevTaskToContainerMapping.put(getTaskName(3), "1");
+    prevTaskToContainerMapping.put(getTaskName(4), "1");
+    prevTaskToContainerMapping.put(getTaskName(5), "1");
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(3).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -633,146 +543,106 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(2)));
     assertTrue(container2.getTasks().containsKey(getTaskName(4)));
     assertTrue(container2.getTasks().containsKey(getTaskName(3)));
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "2");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "2");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "0");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testBalancerOldContainerCountOne() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(1).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    // Verify task mappings are saved
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "2");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testBalancerNewContainerCountOne() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testBalancerEmptyTaskMapping() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(new HashMap<>());
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), new HashMap<>());
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testGroupTaskCountIncrease() {
     int taskCount = 3;
     Set<TaskModel> taskModels = generateTaskModels(taskCount);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(generateTaskModels(taskCount - 1)); // Here's the key step
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(generateTaskModels(taskCount - 1)); // Here's the key step
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testGroupTaskCountDecrease() {
     int taskCount = 3;
     Set<TaskModel> taskModels = generateTaskModels(taskCount);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(generateTaskModels(taskCount + 1)); // Here's the key step
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(generateTaskModels(taskCount + 1)); // Here's the key step
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
-    verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
-    verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testBalancerNewContainerCountGreaterThanTasks() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    new GroupByContainerCount(getConfig(5)).balance(taskModels, localityManager);     // Should throw
+    new GroupByContainerCount(5).group(taskModels, grouperMetadata);     // Should throw
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testBalancerEmptyTasks() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    new GroupByContainerCount(getConfig(5)).balance(new HashSet<>(), localityManager);     // Should throw
+    new GroupByContainerCount(5).group(new HashSet<>(), grouperMetadata);     // Should throw
   }
 
   @Test(expected = UnsupportedOperationException.class)
   public void testBalancerResultImmutable() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
     containers.remove(containers.iterator().next());
   }
 
@@ -780,32 +650,20 @@ public class TestGroupByContainerCount {
   public void testBalancerThrowsOnNonIntegerContainerIds() {
     Set<TaskModel> taskModels = generateTaskModels(3);
     Set<ContainerModel> prevContainers = new HashSet<>();
-    taskModels.forEach(model -> {
-        prevContainers.add(
-          new ContainerModel(UUID.randomUUID().toString(), Collections.singletonMap(model.getTaskName(), model)));
-      });
-    Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
-
-    new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager); //Should throw
-
+    taskModels.forEach(model -> prevContainers.add(new ContainerModel(UUID.randomUUID().toString(), Collections.singletonMap(model.getTaskName(), model))));
+    Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    new GroupByContainerCount(3).group(taskModels, grouperMetadata); //Should throw
   }
 
   @Test
   public void testBalancerWithNullLocalityManager() {
     Set<TaskModel> taskModels = generateTaskModels(3);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(3)).balance(taskModels, null);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).balance(taskModels, null);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
   }
-
-
-  Config getConfig(int containerCount) {
-    Map<String, String> config = new HashMap<>();
-    config.put(JobConfig.JOB_CONTAINER_COUNT(), String.valueOf(containerCount));
-    return new MapConfig(config);
-  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
index 5bb78e8..12b6b1e 100644
--- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
@@ -20,6 +20,7 @@
 package org.apache.samza.container.grouper.task;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -29,35 +30,24 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
+
+import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.LocalityManager;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
-import org.junit.Before;
+import org.apache.samza.runtime.LocationId;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
-
-import static org.apache.samza.container.mock.ContainerMocks.*;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
 
+import static org.apache.samza.container.mock.ContainerMocks.generateTaskModels;
+import static org.apache.samza.container.mock.ContainerMocks.getTaskName;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({TaskAssignmentManager.class, GroupByContainerIds.class})
 public class TestGroupByContainerIds {
 
-  @Before
-  public void setup() throws Exception {
-    TaskAssignmentManager taskAssignmentManager = mock(TaskAssignmentManager.class);
-    LocalityManager localityManager = mock(LocalityManager.class);
-    PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager);
-  }
-
   private Config buildConfigForContainerCount(int count) {
     Map<String, String> map = new HashMap<>();
     map.put("job.container.count", String.valueOf(count));
@@ -67,6 +57,7 @@ public class TestGroupByContainerIds {
   private TaskNameGrouper buildSimpleGrouper() {
     return buildSimpleGrouper(1);
   }
+
   private TaskNameGrouper buildSimpleGrouper(int containerCount) {
     return new GroupByContainerIdsFactory().build(buildConfigForContainerCount(containerCount));
   }
@@ -114,7 +105,8 @@ public class TestGroupByContainerIds {
   public void testGroupWithNullContainerIds() {
     Set<TaskModel> taskModels = generateTaskModels(5);
 
-    Set<ContainerModel> containers = buildSimpleGrouper(2).group(taskModels, null);
+    List<String> containerIds = null;
+    Set<ContainerModel> containers = buildSimpleGrouper(2).group(taskModels, containerIds);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -251,4 +243,264 @@ public class TestGroupByContainerIds {
     assertEquals(1, actualContainerModels.size());
     assertEquals(ImmutableSet.of(expectedContainerModel), actualContainerModels);
   }
+
+  @Test
+  public void testShouldUseTaskLocalityWhenGeneratingContainerModels() {
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3);
+
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+    String testProcessorId3 = "testProcessorId3";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+    TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+    TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+                                                                testProcessorId2, testLocationId2,
+                                                                testProcessorId3, testLocationId3);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1,
+                                                             testTaskName2, testLocationId2,
+                                                             testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+    Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)),
+                                                                  new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+                                                                  new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+    Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+  }
+
+  @Test
+  public void testGenerateContainerModelForSingleContainer() {
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(1);
+
+    String testProcessorId1 = "testProcessorId1";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+    TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+    TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1,
+                                                             testTaskName2, testLocationId2,
+                                                             testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+    Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1,
+                                                                                                                       testTaskName2, testTaskModel2,
+                                                                                                                       testTaskName3, testTaskModel3)));
+
+    Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+  }
+
+  @Test
+  public void testShouldGenerateCorrectContainerModelWhenTaskLocalityIsEmpty() {
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3);
+
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+    String testProcessorId3 = "testProcessorId3";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+    TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+    TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+                                                                testProcessorId2, testLocationId2,
+                                                                testProcessorId3, testLocationId3);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+    Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)),
+                                                                  new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+                                                                  new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+    Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testShouldFailWhenProcessorLocalityIsEmpty() {
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), new HashMap<>());
+
+    taskNameGrouper.group(new HashSet<>(), grouperMetadata);
+  }
+
+  @Test
+  public void testShouldGenerateIdenticalTaskDistributionWhenNoChangeInProcessorGroup() {
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3);
+
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+    String testProcessorId3 = "testProcessorId3";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+    TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+    TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+            testProcessorId2, testLocationId2,
+            testProcessorId3, testLocationId3);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1,
+            testTaskName2, testLocationId2,
+            testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+    Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)),
+            new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+            new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+    Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+
+    actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+  }
+
+  @Test
+  public void testShouldMinimizeTaskShuffleWhenAvailableProcessorInGroupChanges() {
+    TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3);
+
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+    String testProcessorId3 = "testProcessorId3";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+    TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+    TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+            testProcessorId2, testLocationId2,
+            testProcessorId3, testLocationId3);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1,
+            testTaskName2, testLocationId2,
+            testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+    Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)),
+            new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+            new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+    Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+
+    processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+                                        testProcessorId2, testLocationId2);
+
+    grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+    expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1, testTaskName3, testTaskModel3)),
+                                              new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)));
+
+    assertEquals(expectedContainerModels, actualContainerModels);
+  }
+
+  @Test
+  public void testMoreTasksThanProcessors() {
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+        testProcessorId2, testLocationId2);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1,
+        testTaskName2, testLocationId2,
+        testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+
+    Set<TaskModel> taskModels = generateTaskModels(1);
+    List<String> containerIds = ImmutableList.of(testProcessorId1, testProcessorId2);
+
+    Map<TaskName, TaskModel> expectedTasks = taskModels.stream()
+        .collect(Collectors.toMap(TaskModel::getTaskName, x -> x));
+    ContainerModel expectedContainerModel = new ContainerModel(testProcessorId1, expectedTasks);
+
+    Set<ContainerModel> actualContainerModels = buildSimpleGrouper().group(taskModels, grouperMetadata);
+
+    assertEquals(1, actualContainerModels.size());
+    assertEquals(ImmutableSet.of(expectedContainerModel), actualContainerModels);
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
index fcdbf08..60164b2 100644
--- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
@@ -68,7 +68,6 @@ public class TestTaskAssignmentManager {
   @Test
   public void testTaskAssignmentManager() {
     TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", "1", "Task2", "2", "Task3", "0", "Task4", "1");
 
@@ -86,7 +85,6 @@ public class TestTaskAssignmentManager {
   @Test
   public void testDeleteMappings() {
     TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", "1");
 
@@ -108,7 +106,6 @@ public class TestTaskAssignmentManager {
   @Test
   public void testTaskAssignmentManagerEmptyCoordinatorStream() {
     TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = new HashMap<>();
     Map<String, String> localMap = taskAssignmentManager.readTaskAssignment();

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
index ca9def2..be240b1 100644
--- a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
+++ b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
@@ -117,11 +117,11 @@ public class ContainerMocks {
     return values;
   }
 
-  public static Map<String, String> generateTaskContainerMapping(Set<ContainerModel> containers) {
-    Map<String, String> taskMapping = new HashMap<>();
+  public static Map<TaskName, String> generateTaskContainerMapping(Set<ContainerModel> containers) {
+    Map<TaskName, String> taskMapping = new HashMap<>();
     for (ContainerModel container : containers) {
       for (TaskName taskName : container.getTasks().keySet()) {
-        taskMapping.put(taskName.getTaskName(), container.getId());
+        taskMapping.put(taskName, container.getId());
       }
     }
     return taskMapping;

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
index ea25ec1..02aaaa7 100644
--- a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
+++ b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
@@ -19,15 +19,15 @@
 
 package org.apache.samza.coordinator;
 
-import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import org.apache.samza.config.Config;
 import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.coordinator.server.HttpServer;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
+import org.apache.samza.runtime.LocationId;
 import org.apache.samza.system.StreamMetadataCache;
 
 /**
@@ -49,15 +49,8 @@ public class JobModelManagerTestUtil {
     return new JobModelManager(jobModel, server, null);
   }
 
-  public static JobModelManager getJobModelManagerUsingReadModel(Config config, int containerCount, StreamMetadataCache streamMetadataCache,
-    LocalityManager locManager, HttpServer server) {
-    List<String> containerIds = new ArrayList<>();
-    for (int i = 0; i < containerCount; i++) {
-      containerIds.add(String.valueOf(i));
-    }
-    JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), locManager, streamMetadataCache, containerIds);
-    return new JobModelManager(jobModel, server, null);
+  public static JobModelManager getJobModelManagerUsingReadModel(Config config, StreamMetadataCache streamMetadataCache, HttpServer server, LocalityManager localityManager, Map<String, LocationId> processorLocality) {
+    JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), streamMetadataCache, new GrouperMetadataImpl(processorLocality, new HashMap<>(), new HashMap<>(), new HashMap<>()));
+    return new JobModelManager(new JobModel(jobModel.getConfig(), jobModel.getContainers(), localityManager), server, localityManager);
   }
-
-
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
index 1dbf132..6048466 100644
--- a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
+++ b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
@@ -19,20 +19,30 @@
 
 package org.apache.samza.coordinator;
 
+import com.google.common.collect.ImmutableMap;
+import java.util.HashSet;
+import java.util.Set;
 import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.TaskName;
 import org.apache.samza.container.grouper.task.GroupByContainerCount;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.container.grouper.task.TaskAssignmentManager;
 import org.apache.samza.coordinator.server.HttpServer;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.JobModel;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.runtime.LocationId;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
 import org.apache.samza.testUtils.MockHttpServer;
 import org.eclipse.jetty.servlet.DefaultServlet;
 import org.eclipse.jetty.servlet.ServletHolder;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -40,14 +50,15 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Collections;
 
+import static org.apache.samza.coordinator.JobModelManager.*;
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.anyBoolean;
 import static org.mockito.Matchers.argThat;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.*;
 
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentMatcher;
+import org.mockito.Mockito;
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
@@ -60,7 +71,6 @@ import scala.collection.JavaConversions;
 @PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class})
 public class TestJobModelManager {
   private final TaskAssignmentManager mockTaskManager = mock(TaskAssignmentManager.class);
-  private final LocalityManager mockLocalityManager = mock(LocalityManager.class);
   private final Map<String, Map<String, String>> localityMappings = new HashMap<>();
   private final HttpServer server = new MockHttpServer("/", 7777, null, new ServletHolder(DefaultServlet.class));
   private final SystemStream inputStream = new SystemStream("test-system", "test-stream");
@@ -75,7 +85,6 @@ public class TestJobModelManager {
 
   @Before
   public void setup() throws Exception {
-    when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings);
     when(mockStreamMetadataCache.getStreamMetadata(argThat(new ArgumentMatcher<scala.collection.immutable.Set<SystemStream>>() {
       @Override
       public boolean matches(Object argument) {
@@ -105,11 +114,15 @@ public class TestJobModelManager {
         put("job.host-affinity.enabled", "true");
       }
     });
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
 
-    this.localityMappings.put("0", new HashMap<String, String>() { {
+    localityMappings.put("0", new HashMap<String, String>() { {
         put(SetContainerHostMapping.HOST_KEY, "abc-affinity");
       } });
-    this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, mockStreamMetadataCache, mockLocalityManager, server);
+    when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings);
+
+    Map<String, LocationId> containerLocality = ImmutableMap.of("0", new LocationId("abc-affinity"));
+    this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, mockStreamMetadataCache, server, mockLocalityManager, containerLocality);
 
     assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new HashMap<String, String>() { { this.put("0", "abc-affinity"); } });
   }
@@ -132,11 +145,96 @@ public class TestJobModelManager {
       }
     });
 
-    this.localityMappings.put("0", new HashMap<String, String>() { {
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
+
+    localityMappings.put("0", new HashMap<String, String>() { {
         put(SetContainerHostMapping.HOST_KEY, "abc-affinity");
       } });
-    this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, mockStreamMetadataCache, mockLocalityManager, server);
+    when(mockLocalityManager.readContainerLocality()).thenReturn(new HashMap<>());
+
+    Map<String, LocationId> containerLocality = ImmutableMap.of("0", new LocationId("abc-affinity"));
+
+    this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, mockStreamMetadataCache, server, mockLocalityManager, containerLocality);
 
     assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new HashMap<String, String>() { { this.put("0", null); } });
   }
+
+  @Test
+  public void testGetGrouperMetadata() {
+    // Mocking setup.
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
+    TaskAssignmentManager mockTaskAssignmentManager = Mockito.mock(TaskAssignmentManager.class);
+
+    Map<String, Map<String, String>> localityMappings = new HashMap<>();
+    localityMappings.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity"));
+
+    Map<String, String> taskAssignment = ImmutableMap.of("task-0", "0");
+
+    // Mock the container locality assignment.
+    when(mockLocalityManager.readContainerLocality()).thenReturn(localityMappings);
+
+    // Mock the container to task assignment.
+    when(mockTaskAssignmentManager.readTaskAssignment()).thenReturn(taskAssignment);
+
+    GrouperMetadataImpl grouperMetadata = JobModelManager.getGrouperMetadata(new MapConfig(), mockLocalityManager, mockTaskAssignmentManager);
+
+    Mockito.verify(mockLocalityManager).readContainerLocality();
+    Mockito.verify(mockTaskAssignmentManager).readTaskAssignment();
+
+    Assert.assertEquals(ImmutableMap.of("0", new LocationId("abc-affinity"), "1", new LocationId("ANY_HOST")), grouperMetadata.getProcessorLocality());
+    Assert.assertEquals(ImmutableMap.of(new TaskName("task-0"), new LocationId("abc-affinity")), grouperMetadata.getTaskLocality());
+  }
+
+  @Test
+  public void testGetProcessorLocality() {
+    // Mock the dependencies.
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
+
+    Map<String, Map<String, String>> localityMappings = new HashMap<>();
+    localityMappings.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity"));
+
+    // Mock the container locality assignment.
+    when(mockLocalityManager.readContainerLocality()).thenReturn(localityMappings);
+
+    Map<String, LocationId> processorLocality = JobModelManager.getProcessorLocality(new MapConfig(), mockLocalityManager);
+
+    Mockito.verify(mockLocalityManager).readContainerLocality();
+    Assert.assertEquals(ImmutableMap.of("0", new LocationId("abc-affinity"), "1", new LocationId("ANY_HOST")), processorLocality);
+  }
+
+  @Test
+  public void testUpdateTaskAssignments() {
+    // Mocking setup.
+    JobModel mockJobModel = Mockito.mock(JobModel.class);
+    GrouperMetadataImpl mockGrouperMetadata = Mockito.mock(GrouperMetadataImpl.class);
+    TaskAssignmentManager mockTaskAssignmentManager = Mockito.mock(TaskAssignmentManager.class);
+
+    Map<TaskName, TaskModel> taskModelMap = new HashMap<>();
+    taskModelMap.put(new TaskName("task-1"), new TaskModel(new TaskName("task-1"), new HashSet<>(), new Partition(0)));
+    taskModelMap.put(new TaskName("task-2"), new TaskModel(new TaskName("task-2"), new HashSet<>(), new Partition(1)));
+    taskModelMap.put(new TaskName("task-3"), new TaskModel(new TaskName("task-3"), new HashSet<>(), new Partition(2)));
+    taskModelMap.put(new TaskName("task-4"), new TaskModel(new TaskName("task-4"), new HashSet<>(), new Partition(3)));
+    ContainerModel containerModel = new ContainerModel("test-container-id", taskModelMap);
+    Map<String, ContainerModel> containerMapping = ImmutableMap.of("test-container-id", containerModel);
+
+    when(mockJobModel.getContainers()).thenReturn(containerMapping);
+    when(mockGrouperMetadata.getPreviousTaskToProcessorAssignment()).thenReturn(new HashMap<>());
+    Mockito.doNothing().when(mockTaskAssignmentManager).writeTaskContainerMapping(Mockito.any(), Mockito.any());
+
+    JobModelManager.updateTaskAssignments(mockJobModel, mockTaskAssignmentManager, mockGrouperMetadata);
+
+    Set<String> taskNames = new HashSet<String>();
+    taskNames.add("task-4");
+    taskNames.add("task-2");
+    taskNames.add("task-3");
+    taskNames.add("task-1");
+
+    // Verifications
+    Mockito.verify(mockJobModel, atLeast(1)).getContainers();
+    Mockito.verify(mockTaskAssignmentManager).deleteTaskContainerMappings((Iterable<String>) taskNames);
+    Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-1", "test-container-id");
+    Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-2", "test-container-id");
+    Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-3", "test-container-id");
+    Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-4", "test-container-id");
+  }
 }