You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by ja...@apache.org on 2018/12/05 18:57:03 UTC
[2/3] samza git commit: SAMZA-1973: Unify the TaskNameGrouper
interface for yarn and standalone.
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
index 600b7a1..5fb71f3 100644
--- a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
+++ b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
@@ -19,31 +19,33 @@
package org.apache.samza.coordinator
-
import java.util
import java.util.concurrent.atomic.AtomicReference
-
import org.apache.samza.config._
import org.apache.samza.config.JobConfig.Config2Job
import org.apache.samza.config.SystemConfig.Config2System
import org.apache.samza.config.TaskConfig.Config2Task
import org.apache.samza.config.Config
import org.apache.samza.container.grouper.stream.SystemStreamPartitionGrouperFactory
-import org.apache.samza.container.grouper.task.BalancingTaskNameGrouper
-import org.apache.samza.container.grouper.task.TaskNameGrouperFactory
+import org.apache.samza.container.grouper.task._
import org.apache.samza.container.LocalityManager
import org.apache.samza.container.TaskName
import org.apache.samza.coordinator.server.HttpServer
import org.apache.samza.coordinator.server.JobServlet
+import org.apache.samza.job.model.ContainerModel
import org.apache.samza.job.model.JobModel
import org.apache.samza.job.model.TaskModel
+import org.apache.samza.metrics.MetricsRegistry
import org.apache.samza.metrics.MetricsRegistryMap
import org.apache.samza.system._
import org.apache.samza.util.Logging
import org.apache.samza.util.Util
import org.apache.samza.Partition
+import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping
+import org.apache.samza.runtime.LocationId
import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
/**
* Helper companion object that is responsible for wiring up a JobModelManager
@@ -51,66 +53,145 @@ import scala.collection.JavaConverters._
*/
object JobModelManager extends Logging {
- val SOURCE = "JobModelManager"
/**
* a volatile value to store the current instantiated <code>JobModelManager</code>
*/
- @volatile var currentJobModelManager: JobModelManager = null
+ @volatile var currentJobModelManager: JobModelManager = _
val jobModelRef: AtomicReference[JobModel] = new AtomicReference[JobModel]()
/**
-   * Does the following actions for a job.
+   * Currently used only in the ApplicationMaster for yarn deployment model.
+   * Does the following:
* a) Reads the jobModel from coordinator stream using the job's configuration.
-   * b) Recomputes changelog partition mapping based on jobModel and job's configuration.
+   * b) Recomputes the changelog partition mapping based on jobModel and job's configuration.
* c) Builds JobModelManager using the jobModel read from coordinator stream.
-   * @param config Config from the coordinator stream.
-   * @param changelogPartitionMapping The changelog partition-to-task mapping.
-   * @return JobModelManager
+   * @param config config from the coordinator stream.
+   * @param changelogPartitionMapping changelog partition-to-task mapping of the samza job.
+   * @param metricsRegistry the registry for reporting metrics.
+   * @return the instantiated {@see JobModelManager}, holding the locality-aware JobModel published in jobModelRef.
*/
-  def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, Integer]) = {
-    val localityManager = new LocalityManager(config, new MetricsRegistryMap())
-
-    // Map the name of each system to the corresponding SystemAdmin
+  def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, Integer], metricsRegistry: MetricsRegistry = new MetricsRegistryMap()): JobModelManager = {
+    val localityManager = new LocalityManager(config, metricsRegistry)
+    val taskAssignmentManager = new TaskAssignmentManager(config, metricsRegistry)
val systemAdmins = new SystemAdmins(config)
-    val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0)
+    try {
+      systemAdmins.start()
+      val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0)
+      val grouperMetadata: GrouperMetadata = getGrouperMetadata(config, localityManager, taskAssignmentManager)
-    val containerCount = new JobConfig(config).getContainerCount
-    val processorList = List.range(0, containerCount).map(c => c.toString)
+      val jobModel: JobModel = readJobModel(config, changelogPartitionMapping, streamMetadataCache, grouperMetadata)
+      jobModelRef.set(new JobModel(jobModel.getConfig, jobModel.getContainers, localityManager))
-    systemAdmins.start()
-    val jobModelManager = getJobModelManager(config, changelogPartitionMapping, localityManager, streamMetadataCache, processorList.asJava)
-    systemAdmins.stop()
+      updateTaskAssignments(jobModel, taskAssignmentManager, grouperMetadata)
-    jobModelManager
+      val server = new HttpServer
+      server.addServlet("/", new JobServlet(jobModelRef))
+      // Hand out the locality-aware model stored in jobModelRef, not the bare one returned by readJobModel.
+      currentJobModelManager = new JobModelManager(jobModelRef.get(), server, localityManager)
+      currentJobModelManager
+    } finally {
+      taskAssignmentManager.close()
+      systemAdmins.stop()
+      // Not closing localityManager, since {@code ClusterBasedJobCoordinator} uses it to read container locality through {@code JobModel}.
+    }
}
/**
- * Build a JobModelManager using a Samza job's configuration.
- */
- private def getJobModelManager(config: Config,
- changeLogMapping: util.Map[TaskName, Integer],
- localityManager: LocalityManager,
- streamMetadataCache: StreamMetadataCache,
- containerIds: java.util.List[String]) = {
- val jobModel: JobModel = readJobModel(config, changeLogMapping, localityManager, streamMetadataCache, containerIds)
- jobModelRef.set(jobModel)
-
- val server = new HttpServer
- server.addServlet("/", new JobServlet(jobModelRef))
- currentJobModelManager = new JobModelManager(jobModel, server, localityManager)
- currentJobModelManager
+   * Builds the {@see GrouperMetadataImpl} for the samza job.
+   * @param config represents the configurations defined by the user.
+   * @param localityManager provides the processor to host mapping persisted to the metadata store.
+   * @param taskAssignmentManager provides the processor to task assignments persisted to the metadata store.
+   * @return the instantiated {@see GrouperMetadata}; the previous task-to-partition assignment is left empty.
+   */
+  def getGrouperMetadata(config: Config, localityManager: LocalityManager, taskAssignmentManager: TaskAssignmentManager): GrouperMetadataImpl = {
+    val processorLocality: util.Map[String, LocationId] = getProcessorLocality(config, localityManager)
+    val taskAssignment: util.Map[String, String] = taskAssignmentManager.readTaskAssignment()
+    val taskNameToProcessorId: util.Map[TaskName, String] = new util.HashMap[TaskName, String]()
+    val taskLocality: util.Map[TaskName, LocationId] = new util.HashMap[TaskName, LocationId]()
+
+    // One pass: record each task's processor, plus the task's location when that processor is still known.
+    for ((taskName, processorId) <- taskAssignment) {
+      val task = new TaskName(taskName)
+      taskNameToProcessorId.put(task, processorId)
+      if (processorLocality.containsKey(processorId)) {
+        taskLocality.put(task, processorLocality.get(processorId))
+      }
+    }
+    new GrouperMetadataImpl(processorLocality, taskLocality, new util.HashMap[TaskName, util.List[SystemStreamPartition]](), taskNameToProcessorId)
}
/**
- * For each input stream specified in config, exactly determine its
- * partitions, returning a set of SystemStreamPartitions containing them all.
- */
- private def getInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache) = {
+   * Retrieves and returns the processor locality of a samza job using provided {@see Config} and {@see LocalityManager}.
+   * @param config provides the configurations defined by the user; its container count bounds the processor ids.
+   * @param localityManager provides the processor to host mapping persisted to the metadata store.
+   * @return the processor locality, with one entry per processor id from "0" until the container count (exclusive).
+   */
+  def getProcessorLocality(config: Config, localityManager: LocalityManager): util.Map[String, LocationId] = {
+    val containerToLocationId: util.Map[String, LocationId] = new util.HashMap[String, LocationId]()
+    val existingContainerLocality = localityManager.readContainerLocality()
+
+    for (containerId <- 0 until config.getContainerCount) {
+      val localityMapping = existingContainerLocality.get(containerId.toString)
+      // Containers with no persisted host (e.g. added since the previous run of a samza-yarn job) get ANY_HOST.
+      // The range is exclusive (ids 0..count-1), matching the List.range(0, containerCount) processor list it replaces.
+      val locationId: LocationId = Option(localityMapping)
+        .filter(mapping => mapping.containsKey(SetContainerHostMapping.HOST_KEY))
+        .map(mapping => new LocationId(mapping.get(SetContainerHostMapping.HOST_KEY)))
+        .getOrElse(new LocationId("ANY_HOST"))
+      containerToLocationId.put(containerId.toString, locationId)
+    }
+
+    containerToLocationId
+  }
+
+  /**
+   * This method does the following:
+   * 1. Deletes the existing task assignments if the partition-task grouping has changed from the previous run of the job.
+   * 2. Saves the newly generated task assignments to the storage layer through the {@code taskAssignmentManager}.
+   *
+   * @param jobModel represents the {@see JobModel} of the samza job.
+   * @param taskAssignmentManager required to persist the processor to task assignments to the storage layer.
+   * @param grouperMetadata provides the historical metadata of the application.
+   */
+  def updateTaskAssignments(jobModel: JobModel, taskAssignmentManager: TaskAssignmentManager, grouperMetadata: GrouperMetadata): Unit = {
+    val taskNames: util.Set[String] = new util.HashSet[String]()
+    for (container <- jobModel.getContainers.values()) {
+      for (taskModel <- container.getTasks.values()) {
+        taskNames.add(taskModel.getTaskName.getTaskName)
+      }
+    }
+    val taskToContainerId = grouperMetadata.getPreviousTaskToProcessorAssignment
+    if (taskNames.size() != taskToContainerId.size()) {
+      warn("Current task count %s does not match saved task count %s. Stateful jobs may observe misalignment of keys!"
+        format (taskNames.size(), taskToContainerId.size()))
+      // If the tasks changed, then the partition-task grouping is also likely changed and we can't handle that
+      // without a much more complicated mapping. Further, the partition count may have changed, which means
+      // input message keys are likely reshuffled w.r.t. partitions, so the local state may not contain necessary
+      // data associated with the incoming keys. Warn the user and fall back to the freshly grouped assignment.
+      // The task set may also have shrunk, so delete every previously saved mapping, including stale task names.
+      taskAssignmentManager.deleteTaskContainerMappings(taskToContainerId.keySet.asScala.map(_.getTaskName).asJava)
+    }
+
+    for (container <- jobModel.getContainers.values()) {
+      for (taskName <- container.getTasks.keySet) {
+        taskAssignmentManager.writeTaskContainerMapping(taskName.getTaskName, container.getId)
+      }
+    }
+  }
+
+ /**
+ * Computes the input system stream partitions of a samza job using the provided {@param config}
+ * and {@param streamMetadataCache}.
+ * @param config the configuration of the job.
+ * @param streamMetadataCache to query the partition metadata of the input streams.
+ * @return the input {@see SystemStreamPartition} of the samza job.
+ */
+ private def getInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache): Set[SystemStreamPartition] = {
val inputSystemStreams = config.getInputStreams
// Get the set of partitions for each SystemStream from the stream metadata
streamMetadataCache
- .getStreamMetadata(inputSystemStreams, true)
+ .getStreamMetadata(inputSystemStreams, partitionsMetadataOnly = true)
.flatMap {
case (systemStream, metadata) =>
metadata
@@ -121,55 +202,69 @@ object JobModelManager extends Logging {
}.toSet
}
+ /**
+ * Builds the input {@see SystemStreamPartition} based upon the {@param config} defined by the user.
+ * @param config configuration to fetch the metadata of the input streams.
+ * @param streamMetadataCache required to query the partition metadata of the input streams.
+ * @return the input SystemStreamPartitions of the job.
+ */
private def getMatchedInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache): Set[SystemStreamPartition] = {
val allSystemStreamPartitions = getInputStreamPartitions(config, streamMetadataCache)
config.getSSPMatcherClass match {
- case Some(s) => {
+ case Some(s) =>
val jfr = config.getSSPMatcherConfigJobFactoryRegex.r
config.getStreamJobFactoryClass match {
- case Some(jfr(_*)) => {
- info("before match: allSystemStreamPartitions.size = %s" format (allSystemStreamPartitions.size))
+ case Some(jfr(_*)) =>
+ info("before match: allSystemStreamPartitions.size = %s" format allSystemStreamPartitions.size)
val sspMatcher = Util.getObj(s, classOf[SystemStreamPartitionMatcher])
val matchedPartitions = sspMatcher.filter(allSystemStreamPartitions.asJava, config).asScala.toSet
// Usually a small set hence ok to log at info level
- info("after match: matchedPartitions = %s" format (matchedPartitions))
+ info("after match: matchedPartitions = %s" format matchedPartitions)
matchedPartitions
- }
case _ => allSystemStreamPartitions
}
- }
case _ => allSystemStreamPartitions
}
}
/**
- * Gets a SystemStreamPartitionGrouper object from the configuration.
- */
+ * Finds the {@see SystemStreamPartitionGrouperFactory} from the {@param config}. Instantiates the {@see SystemStreamPartitionGrouper}
+ * object through the factory.
+ * @param config the configuration of the samza job.
+ * @return the instantiated {@see SystemStreamPartitionGrouper}.
+ */
private def getSystemStreamPartitionGrouper(config: Config) = {
val factoryString = config.getSystemStreamPartitionGrouperFactory
val factory = Util.getObj(factoryString, classOf[SystemStreamPartitionGrouperFactory])
factory.getSystemStreamPartitionGrouper(config)
}
+
/**
- * The function reads the latest checkpoint from the underlying coordinator stream and
- * builds a new JobModel.
- */
+ * Does the following:
+ * 1. Fetches metadata of the input streams defined in configuration through {@param streamMetadataCache}.
+ * 2. Applies the {@see SystemStreamPartitionGrouper}, {@see TaskNameGrouper} defined in the configuration
+ * to build the {@see JobModel}.
+ * @param config the configuration of the job.
+ * @param changeLogPartitionMapping the task to changelog partition mapping of the job.
+ * @param streamMetadataCache the cache that holds the partition metadata of the input streams.
+ * @param grouperMetadata provides the historical metadata of the application.
+ * @return the built {@see JobModel}.
+ */
def readJobModel(config: Config,
changeLogPartitionMapping: util.Map[TaskName, Integer],
- localityManager: LocalityManager,
streamMetadataCache: StreamMetadataCache,
- containerIds: java.util.List[String]): JobModel = {
+ grouperMetadata: GrouperMetadata): JobModel = {
// Do grouping to fetch TaskName to SSP mapping
val allSystemStreamPartitions = getMatchedInputStreamPartitions(config, streamMetadataCache)
// processor list is required by some of the groupers. So, let's pass them as part of the config.
// Copy the config and add the processor list to the config copy.
val configMap = new util.HashMap[String, String](config)
- configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", containerIds))
+ configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", grouperMetadata.getProcessorLocality.keySet()))
val grouper = getSystemStreamPartitionGrouper(new MapConfig(configMap))
- val groups = grouper.group(allSystemStreamPartitions.asJava)
+ val groups = grouper.group(allSystemStreamPartitions)
info("SystemStreamPartitionGrouper %s has grouped the SystemStreamPartitions into %d tasks with the following taskNames: %s" format(grouper, groups.size(), groups.keySet()))
val isHostAffinityEnabled = new ClusterManagerConfig(config).getHostAffinityEnabled
@@ -200,22 +295,18 @@ object JobModelManager extends Logging {
// SSPTaskNameGrouper for locality, load-balancing, etc.
val containerGrouperFactory = Util.getObj(config.getTaskNameGrouperFactory, classOf[TaskNameGrouperFactory])
val containerGrouper = containerGrouperFactory.build(config)
- val containerModels = {
- containerGrouper match {
- case grouper: BalancingTaskNameGrouper if isHostAffinityEnabled => grouper.balance(taskModels.asJava, localityManager)
- case _ => containerGrouper.group(taskModels.asJava, containerIds)
- }
- }
- val containerMap = containerModels.asScala.map { case (containerModel) => containerModel.getId -> containerModel }.toMap
-
- if (isHostAffinityEnabled) {
- new JobModel(config, containerMap.asJava, localityManager)
+ var containerModels: util.Set[ContainerModel] = null
+ if(isHostAffinityEnabled) {
+ containerModels = containerGrouper.group(taskModels, grouperMetadata)
} else {
- new JobModel(config, containerMap.asJava)
+ containerModels = containerGrouper.group(taskModels, new util.ArrayList[String](grouperMetadata.getProcessorLocality.keySet()))
}
+ val containerMap = containerModels.asScala.map(containerModel => containerModel.getId -> containerModel).toMap
+
+ new JobModel(config, containerMap.asJava)
}
- private def getSystemNames(config: Config) = config.getSystemNames.toSet
+ private def getSystemNames(config: Config) = config.getSystemNames().toSet
}
/**
@@ -248,7 +339,7 @@ class JobModelManager(
debug("Got job model: %s." format jobModel)
- def start {
+ def start() {
if (server != null) {
debug("Starting HTTP server.")
server.start
@@ -256,7 +347,7 @@ class JobModelManager(
}
}
- def stop {
+ def stop() {
if (server != null) {
debug("Stopping HTTP server.")
server.stop
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
index 64f516b..d16c294 100644
--- a/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
+++ b/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
@@ -50,7 +50,7 @@ class ProcessJobFactory extends StreamJobFactory with Logging {
coordinatorStreamManager.bootstrap
val changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager)
- val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping())
+ val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping(), metricsRegistry)
val jobModel = coordinator.jobModel
val taskPartitionMappings: util.Map[TaskName, Integer] = new util.HashMap[TaskName, Integer]
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
index 5a8d2f8..e4a7838 100644
--- a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
+++ b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
@@ -52,7 +52,7 @@ class ThreadJobFactory extends StreamJobFactory with Logging {
coordinatorStreamManager.bootstrap
val changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager)
- val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping())
+ val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping(), metricsRegistry)
val jobModel = coordinator.jobModel
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
index 0c2f2fb..9e6e8d0 100644
--- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
@@ -24,54 +24,34 @@ import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
-import org.apache.samza.SamzaException;
-import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
-import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.LocalityManager;
+
+import org.apache.samza.container.TaskName;
import org.apache.samza.job.model.ContainerModel;
import org.apache.samza.job.model.TaskModel;
-import org.junit.Before;
+import org.apache.samza.SamzaException;
import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mockito;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
import static org.apache.samza.container.mock.ContainerMocks.*;
import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class})
public class TestGroupByContainerCount {
- private TaskAssignmentManager taskAssignmentManager;
- private LocalityManager localityManager;
- @Before
- public void setup() throws Exception {
- taskAssignmentManager = mock(TaskAssignmentManager.class);
- localityManager = mock(LocalityManager.class);
- PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager);
- Mockito.doNothing().when(taskAssignmentManager).init();
- }
@Test(expected = IllegalArgumentException.class)
public void testGroupEmptyTasks() {
- new GroupByContainerCount(getConfig(1)).group(new HashSet());
+ new GroupByContainerCount(1).group(new HashSet<>());
}
@Test(expected = IllegalArgumentException.class)
public void testGroupFewerTasksThanContainers() {
Set<TaskModel> taskModels = new HashSet<>();
taskModels.add(getTaskModel(1));
- new GroupByContainerCount(getConfig(2)).group(taskModels);
+ new GroupByContainerCount(2).group(taskModels);
}
@Test(expected = UnsupportedOperationException.class)
public void testGrouperResultImmutable() {
Set<TaskModel> taskModels = generateTaskModels(3);
- Set<ContainerModel> containers = new GroupByContainerCount(getConfig(3)).group(taskModels);
+ Set<ContainerModel> containers = new GroupByContainerCount(3).group(taskModels);
containers.remove(containers.iterator().next());
}
@@ -79,7 +59,7 @@ public class TestGroupByContainerCount {
public void testGroupHappyPath() {
Set<TaskModel> taskModels = generateTaskModels(5);
- Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).group(taskModels);
+ Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels);
Map<String, ContainerModel> containersMap = new HashMap<>();
for (ContainerModel container : containers) {
@@ -106,7 +86,7 @@ public class TestGroupByContainerCount {
public void testGroupManyTasks() {
Set<TaskModel> taskModels = generateTaskModels(21);
- Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).group(taskModels);
+ Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels);
Map<String, ContainerModel> containersMap = new HashMap<>();
for (ContainerModel container : containers) {
@@ -174,11 +154,11 @@ public class TestGroupByContainerCount {
@Test
public void testBalancerAfterContainerIncrease() {
Set<TaskModel> taskModels = generateTaskModels(9);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(taskModels);
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels);
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
- Set<ContainerModel> containers = new GroupByContainerCount(getConfig(4)).balance(taskModels, localityManager);
+ Set<ContainerModel> containers = new GroupByContainerCount(4).group(taskModels, grouperMetadata);
Map<String, ContainerModel> containersMap = new HashMap<>();
for (ContainerModel container : containers) {
@@ -213,22 +193,6 @@ public class TestGroupByContainerCount {
assertTrue(container2.getTasks().containsKey(getTaskName(6)));
assertTrue(container3.getTasks().containsKey(getTaskName(5)));
assertTrue(container3.getTasks().containsKey(getTaskName(7)));
-
- // Verify task mappings are saved
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
-
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1");
-
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "2");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "2");
-
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "3");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "3");
-
- verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
}
/**
@@ -256,11 +220,11 @@ public class TestGroupByContainerCount {
@Test
public void testBalancerAfterContainerDecrease() {
Set<TaskModel> taskModels = generateTaskModels(9);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(4)).group(taskModels);
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels);
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
- Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+ Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
Map<String, ContainerModel> containersMap = new HashMap<>();
for (ContainerModel container : containers) {
@@ -290,20 +254,6 @@ public class TestGroupByContainerCount {
assertTrue(container0.getTasks().containsKey(getTaskName(2)));
assertTrue(container1.getTasks().containsKey(getTaskName(7)));
assertTrue(container1.getTasks().containsKey(getTaskName(3)));
-
- // Verify task mappings are saved
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1");
-
- verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
}
/**
@@ -331,15 +281,15 @@ public class TestGroupByContainerCount {
* T8 T7 T3
*/
@Test
- public void testBalancerMultipleReblances() throws Exception {
+ public void testBalancerMultipleReblances() {
// Before
Set<TaskModel> taskModels = generateTaskModels(9);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(4)).group(taskModels);
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels);
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
// First balance
- Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+ Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
Map<String, ContainerModel> containersMap = new HashMap<>();
for (ContainerModel container : containers) {
@@ -370,30 +320,11 @@ public class TestGroupByContainerCount {
assertTrue(container1.getTasks().containsKey(getTaskName(7)));
assertTrue(container1.getTasks().containsKey(getTaskName(3)));
- // Verify task mappings are saved
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1");
-
- verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
-
-
// Second balance
prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- TaskAssignmentManager taskAssignmentManager2 = mock(TaskAssignmentManager.class);
- when(taskAssignmentManager2.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
- LocalityManager localityManager2 = mock(LocalityManager.class);
- PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager2);
-
- containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager2);
+ GrouperMetadataImpl grouperMetadata1 = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+ containers = new GroupByContainerCount(3).group(taskModels, grouperMetadata1);
containersMap = new HashMap<>();
for (ContainerModel container : containers) {
@@ -427,21 +358,6 @@ public class TestGroupByContainerCount {
assertTrue(container2.getTasks().containsKey(getTaskName(6)));
assertTrue(container2.getTasks().containsKey(getTaskName(2)));
assertTrue(container2.getTasks().containsKey(getTaskName(3)));
-
- // Verify task mappings are saved
- verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
- verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0");
- verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0");
-
- verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
- verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1");
- verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1");
-
- verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(6).getTaskName(), "2");
- verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(2).getTaskName(), "2");
- verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(3).getTaskName(), "2");
-
- verify(taskAssignmentManager2, never()).deleteTaskContainerMappings(anyCollection());
}
/**
@@ -466,11 +382,11 @@ public class TestGroupByContainerCount {
@Test
public void testBalancerAfterContainerSame() {
Set<TaskModel> taskModels = generateTaskModels(9);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(taskModels);
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels);
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+ Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
Map<String, ContainerModel> containersMap = new HashMap<>();
for (ContainerModel container : containers) {
@@ -496,9 +412,6 @@ public class TestGroupByContainerCount {
assertTrue(container1.getTasks().containsKey(getTaskName(3)));
assertTrue(container1.getTasks().containsKey(getTaskName(5)));
assertTrue(container1.getTasks().containsKey(getTaskName(7)));
-
- verify(taskAssignmentManager, never()).writeTaskContainerMapping(anyString(), anyString());
- verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
}
/**
@@ -528,19 +441,19 @@ public class TestGroupByContainerCount {
public void testBalancerAfterContainerSameCustomAssignment() {
Set<TaskModel> taskModels = generateTaskModels(9);
- Map<String, String> prevTaskToContainerMapping = new HashMap<>();
- prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), "0");
- prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), "0");
- prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), "0");
- prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), "0");
- prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), "0");
- prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "0");
- prevTaskToContainerMapping.put(getTaskName(6).getTaskName(), "1");
- prevTaskToContainerMapping.put(getTaskName(7).getTaskName(), "1");
- prevTaskToContainerMapping.put(getTaskName(8).getTaskName(), "1");
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Map<TaskName, String> prevTaskToContainerMapping = new HashMap<>();
+ prevTaskToContainerMapping.put(getTaskName(0), "0");
+ prevTaskToContainerMapping.put(getTaskName(1), "0");
+ prevTaskToContainerMapping.put(getTaskName(2), "0");
+ prevTaskToContainerMapping.put(getTaskName(3), "0");
+ prevTaskToContainerMapping.put(getTaskName(4), "0");
+ prevTaskToContainerMapping.put(getTaskName(5), "0");
+ prevTaskToContainerMapping.put(getTaskName(6), "1");
+ prevTaskToContainerMapping.put(getTaskName(7), "1");
+ prevTaskToContainerMapping.put(getTaskName(8), "1");
- Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+ Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
Map<String, ContainerModel> containersMap = new HashMap<>();
for (ContainerModel container : containers) {
@@ -566,9 +479,6 @@ public class TestGroupByContainerCount {
assertTrue(container1.getTasks().containsKey(getTaskName(6)));
assertTrue(container1.getTasks().containsKey(getTaskName(7)));
assertTrue(container1.getTasks().containsKey(getTaskName(8)));
-
- verify(taskAssignmentManager, never()).writeTaskContainerMapping(anyString(), anyString());
- verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
}
/**
@@ -597,16 +507,16 @@ public class TestGroupByContainerCount {
public void testBalancerAfterContainerSameCustomAssignmentAndContainerIncrease() {
Set<TaskModel> taskModels = generateTaskModels(6);
- Map<String, String> prevTaskToContainerMapping = new HashMap<>();
- prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), "0");
- prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), "1");
- prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), "1");
- prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), "1");
- prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), "1");
- prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "1");
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Map<TaskName, String> prevTaskToContainerMapping = new HashMap<>();
+ prevTaskToContainerMapping.put(getTaskName(0), "0");
+ prevTaskToContainerMapping.put(getTaskName(1), "1");
+ prevTaskToContainerMapping.put(getTaskName(2), "1");
+ prevTaskToContainerMapping.put(getTaskName(3), "1");
+ prevTaskToContainerMapping.put(getTaskName(4), "1");
+ prevTaskToContainerMapping.put(getTaskName(5), "1");
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
- Set<ContainerModel> containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager);
+ Set<ContainerModel> containers = new GroupByContainerCount(3).group(taskModels, grouperMetadata);
Map<String, ContainerModel> containersMap = new HashMap<>();
for (ContainerModel container : containers) {
@@ -633,146 +543,106 @@ public class TestGroupByContainerCount {
assertTrue(container1.getTasks().containsKey(getTaskName(2)));
assertTrue(container2.getTasks().containsKey(getTaskName(4)));
assertTrue(container2.getTasks().containsKey(getTaskName(3)));
-
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "1");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "2");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "2");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "0");
-
- verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
}
@Test
public void testBalancerOldContainerCountOne() {
+ // The previous assignment came from 1 container; grouping onto 3 containers with that history should match a fresh group().
Set<TaskModel> taskModels = generateTaskModels(3);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(1).group(taskModels);
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
- Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
- Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager);
+ Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels);
+ Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).group(taskModels, grouperMetadata);
// Results should be the same as calling group()
assertEquals(groupContainers, balanceContainers);
-
- // Verify task mappings are saved
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "2");
-
- verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
}
@Test
public void testBalancerNewContainerCountOne() {
+ // The previous assignment came from 3 containers; grouping onto a single container with that history should match a fresh group().
Set<TaskModel> taskModels = generateTaskModels(3);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
- Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
- Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+ Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+ Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
// Results should be the same as calling group()
assertEquals(groupContainers, balanceContainers);
-
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
- verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
}
@Test
public void testBalancerEmptyTaskMapping() {
+ // Every grouper metadata map is empty (no previous task assignment); the result should match a fresh group().
Set<TaskModel> taskModels = generateTaskModels(3);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(new HashMap<>());
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), new HashMap<>());
- Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
- Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+ Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+ Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
// Results should be the same as calling group()
assertEquals(groupContainers, balanceContainers);
-
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
- verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection());
}
@Test
public void testGroupTaskCountIncrease() {
+ // The previous mapping only covers taskCount - 1 tasks (one task is new); the result should match a fresh group().
int taskCount = 3;
Set<TaskModel> taskModels = generateTaskModels(taskCount);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(generateTaskModels(taskCount - 1)); // Here's the key step
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(generateTaskModels(taskCount - 1)); // Here's the key step
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
- Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
- Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+ Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+ Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
// Results should be the same as calling group()
assertEquals(groupContainers, balanceContainers);
-
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
- verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection());
}
@Test
public void testGroupTaskCountDecrease() {
+ // The previous mapping covers taskCount + 1 tasks (one previously mapped task is gone); the result should match a fresh group().
int taskCount = 3;
Set<TaskModel> taskModels = generateTaskModels(taskCount);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(generateTaskModels(taskCount + 1)); // Here's the key step
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(generateTaskModels(taskCount + 1)); // Here's the key step
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
- Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
- Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+ Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+ Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata);
// Results should be the same as calling group()
assertEquals(groupContainers, balanceContainers);
-
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
- verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0");
-
- verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection());
}
@Test(expected = IllegalArgumentException.class)
public void testBalancerNewContainerCountGreaterThanTasks() {
+ // Asking for 5 containers when there are only 3 tasks is expected to throw IllegalArgumentException.
Set<TaskModel> taskModels = generateTaskModels(3);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
- new GroupByContainerCount(getConfig(5)).balance(taskModels, localityManager); // Should throw
+ new GroupByContainerCount(5).group(taskModels, grouperMetadata); // Should throw
}
@Test(expected = IllegalArgumentException.class)
public void testBalancerEmptyTasks() {
+ // A previous mapping exists, but the task set handed to group() is empty; expected to throw IllegalArgumentException.
Set<TaskModel> taskModels = generateTaskModels(3);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
- new GroupByContainerCount(getConfig(5)).balance(new HashSet<>(), localityManager); // Should throw
+ new GroupByContainerCount(5).group(new HashSet<>(), grouperMetadata); // Should throw
}
@Test(expected = UnsupportedOperationException.class)
public void testBalancerResultImmutable() {
+ // The set returned by group(tasks, grouperMetadata) must be unmodifiable: remove() is expected to throw UnsupportedOperationException.
Set<TaskModel> taskModels = generateTaskModels(3);
- Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+ Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
- Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+ Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata);
containers.remove(containers.iterator().next());
}
@@ -780,32 +650,20 @@ public class TestGroupByContainerCount {
public void testBalancerThrowsOnNonIntegerContainerIds() {
+ // Each previous container gets a random UUID (non-integer) as its ID; per the test name, group() should then throw.
+ // NOTE(review): the @Test(expected = ...) annotation is outside this hunk; confirm the expected exception type there.
Set<TaskModel> taskModels = generateTaskModels(3);
Set<ContainerModel> prevContainers = new HashSet<>();
- taskModels.forEach(model -> {
- prevContainers.add(
- new ContainerModel(UUID.randomUUID().toString(), Collections.singletonMap(model.getTaskName(), model)));
- });
- Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
- when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
-
- new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager); //Should throw
-
+ taskModels.forEach(model -> prevContainers.add(new ContainerModel(UUID.randomUUID().toString(), Collections.singletonMap(model.getTaskName(), model))));
+ Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+ new GroupByContainerCount(3).group(taskModels, grouperMetadata); //Should throw
}
@Test
public void testBalancerWithNullLocalityManager() {
Set<TaskModel> taskModels = generateTaskModels(3);
- Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
- Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(3)).balance(taskModels, null);
+ Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels);
+ // NOTE(review): unlike the other tests migrated in this patch, this one still calls balance(tasks, null)
+ // instead of group(tasks, grouperMetadata); confirm balance() is meant to stay covered here.
+ Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).balance(taskModels, null);
// Results should be the same as calling group()
assertEquals(groupContainers, balanceContainers);
}
-
-
- Config getConfig(int containerCount) {
- Map<String, String> config = new HashMap<>();
- config.put(JobConfig.JOB_CONTAINER_COUNT(), String.valueOf(containerCount));
- return new MapConfig(config);
- }
}
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
index 5bb78e8..12b6b1e 100644
--- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
@@ -20,6 +20,7 @@
package org.apache.samza.container.grouper.task;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collections;
@@ -29,35 +30,24 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
+
+import org.apache.samza.Partition;
import org.apache.samza.config.Config;
import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.LocalityManager;
import org.apache.samza.container.TaskName;
import org.apache.samza.job.model.ContainerModel;
import org.apache.samza.job.model.TaskModel;
-import org.junit.Before;
+import org.apache.samza.runtime.LocationId;
import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
-
-import static org.apache.samza.container.mock.ContainerMocks.*;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
+import static org.apache.samza.container.mock.ContainerMocks.generateTaskModels;
+import static org.apache.samza.container.mock.ContainerMocks.getTaskName;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({TaskAssignmentManager.class, GroupByContainerIds.class})
public class TestGroupByContainerIds {
- @Before
- public void setup() throws Exception {
- TaskAssignmentManager taskAssignmentManager = mock(TaskAssignmentManager.class);
- LocalityManager localityManager = mock(LocalityManager.class);
- PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager);
- }
-
private Config buildConfigForContainerCount(int count) {
Map<String, String> map = new HashMap<>();
map.put("job.container.count", String.valueOf(count));
@@ -67,6 +57,7 @@ public class TestGroupByContainerIds {
+ /** Builds a grouper for a single container (delegates to buildSimpleGrouper(1)). */
private TaskNameGrouper buildSimpleGrouper() {
return buildSimpleGrouper(1);
}
+
+ /** Builds a grouper via GroupByContainerIdsFactory from a config whose job.container.count is containerCount. */
private TaskNameGrouper buildSimpleGrouper(int containerCount) {
return new GroupByContainerIdsFactory().build(buildConfigForContainerCount(containerCount));
}
@@ -114,7 +105,8 @@ public class TestGroupByContainerIds {
public void testGroupWithNullContainerIds() {
Set<TaskModel> taskModels = generateTaskModels(5);
- Set<ContainerModel> containers = buildSimpleGrouper(2).group(taskModels, null);
+ List<String> containerIds = null;
+ Set<ContainerModel> containers = buildSimpleGrouper(2).group(taskModels, containerIds);
Map<String, ContainerModel> containersMap = new HashMap<>();
for (ContainerModel container : containers) {
@@ -251,4 +243,264 @@ public class TestGroupByContainerIds {
assertEquals(1, actualContainerModels.size());
assertEquals(ImmutableSet.of(expectedContainerModel), actualContainerModels);
}
+
+ /**
+ * Each task's previous location matches exactly one processor's location, so each task is
+ * expected to be placed in the container of the processor that shares its location.
+ */
+ @Test
+ public void testShouldUseTaskLocalityWhenGeneratingContainerModels() {
+ TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3);
+
+ String testProcessorId1 = "testProcessorId1";
+ String testProcessorId2 = "testProcessorId2";
+ String testProcessorId3 = "testProcessorId3";
+
+ LocationId testLocationId1 = new LocationId("testLocationId1");
+ LocationId testLocationId2 = new LocationId("testLocationId2");
+ LocationId testLocationId3 = new LocationId("testLocationId3");
+
+ TaskName testTaskName1 = new TaskName("testTaskId1");
+ TaskName testTaskName2 = new TaskName("testTaskId2");
+ TaskName testTaskName3 = new TaskName("testTaskId3");
+
+ TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+ TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+ TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+ Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+ testProcessorId2, testLocationId2,
+ testProcessorId3, testLocationId3);
+
+ Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1,
+ testTaskName2, testLocationId2,
+ testTaskName3, testLocationId3);
+
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+ Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+ Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)),
+ new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+ new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+ Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+ assertEquals(expectedContainerModels, actualContainerModels);
+ }
+
+ /**
+ * With a single processor available, all three tasks are expected in that processor's
+ * container, even though two of the tasks were previously at other locations.
+ */
+ @Test
+ public void testGenerateContainerModelForSingleContainer() {
+ TaskNameGrouper taskNameGrouper = buildSimpleGrouper(1);
+
+ String testProcessorId1 = "testProcessorId1";
+
+ LocationId testLocationId1 = new LocationId("testLocationId1");
+ LocationId testLocationId2 = new LocationId("testLocationId2");
+ LocationId testLocationId3 = new LocationId("testLocationId3");
+
+ TaskName testTaskName1 = new TaskName("testTaskId1");
+ TaskName testTaskName2 = new TaskName("testTaskId2");
+ TaskName testTaskName3 = new TaskName("testTaskId3");
+
+ TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+ TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+ TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+ Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1);
+
+ Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1,
+ testTaskName2, testLocationId2,
+ testTaskName3, testLocationId3);
+
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+ Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+ Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1,
+ testTaskName2, testTaskModel2,
+ testTaskName3, testTaskModel3)));
+
+ Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+ assertEquals(expectedContainerModels, actualContainerModels);
+ }
+
+ /**
+ * Only task 1 has a known previous location. The expected result still places exactly one task
+ * per processor, with task 1 on the processor that shares its location.
+ */
+ @Test
+ public void testShouldGenerateCorrectContainerModelWhenTaskLocalityIsEmpty() {
+ TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3);
+
+ String testProcessorId1 = "testProcessorId1";
+ String testProcessorId2 = "testProcessorId2";
+ String testProcessorId3 = "testProcessorId3";
+
+ LocationId testLocationId1 = new LocationId("testLocationId1");
+ LocationId testLocationId2 = new LocationId("testLocationId2");
+ LocationId testLocationId3 = new LocationId("testLocationId3");
+
+ TaskName testTaskName1 = new TaskName("testTaskId1");
+ TaskName testTaskName2 = new TaskName("testTaskId2");
+ TaskName testTaskName3 = new TaskName("testTaskId3");
+
+ TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+ TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+ TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+ Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+ testProcessorId2, testLocationId2,
+ testProcessorId3, testLocationId3);
+
+ Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1);
+
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+ Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+ Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)),
+ new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+ new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+ Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+ assertEquals(expectedContainerModels, actualContainerModels);
+ }
+
+ /**
+ * Grouping with empty metadata is expected to throw IllegalArgumentException.
+ * NOTE(review): both the processor locality and the task set passed to group() are empty here,
+ * so this does not isolate which condition triggers the exception; the test name attributes it
+ * to processor locality — confirm against GroupByContainerIds and consider passing a non-empty
+ * task set.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testShouldFailWhenProcessorLocalityIsEmpty() {
+ TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3);
+
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), new HashMap<>());
+
+ taskNameGrouper.group(new HashSet<>(), grouperMetadata);
+ }
+
+ /**
+ * Grouping the same tasks twice with the same metadata should yield the same container models
+ * both times. The previous-task-assignment map (last constructor argument) is empty, so this
+ * checks that grouping is deterministic rather than that it sticks to an earlier assignment.
+ */
+ @Test
+ public void testShouldGenerateIdenticalTaskDistributionWhenNoChangeInProcessorGroup() {
+ TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3);
+
+ String testProcessorId1 = "testProcessorId1";
+ String testProcessorId2 = "testProcessorId2";
+ String testProcessorId3 = "testProcessorId3";
+
+ LocationId testLocationId1 = new LocationId("testLocationId1");
+ LocationId testLocationId2 = new LocationId("testLocationId2");
+ LocationId testLocationId3 = new LocationId("testLocationId3");
+
+ TaskName testTaskName1 = new TaskName("testTaskId1");
+ TaskName testTaskName2 = new TaskName("testTaskId2");
+ TaskName testTaskName3 = new TaskName("testTaskId3");
+
+ TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+ TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+ TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+ Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+ testProcessorId2, testLocationId2,
+ testProcessorId3, testLocationId3);
+
+ Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1,
+ testTaskName2, testLocationId2,
+ testTaskName3, testLocationId3);
+
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+ Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+
+ Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)),
+ new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+ new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+ Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+ assertEquals(expectedContainerModels, actualContainerModels);
+
+ // Second run with unchanged inputs must produce the same distribution.
+ actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+ assertEquals(expectedContainerModels, actualContainerModels);
+ }
+
+ @Test
+ public void testShouldMinimizeTaskShuffleWhenAvailableProcessorInGroupChanges() {
+ TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3);
+ // Verifies that when a processor leaves, tasks on the surviving processors stay put and only the orphaned task moves.
+ String testProcessorId1 = "testProcessorId1";
+ String testProcessorId2 = "testProcessorId2";
+ String testProcessorId3 = "testProcessorId3";
+
+ LocationId testLocationId1 = new LocationId("testLocationId1");
+ LocationId testLocationId2 = new LocationId("testLocationId2");
+ LocationId testLocationId3 = new LocationId("testLocationId3");
+
+ TaskName testTaskName1 = new TaskName("testTasKId1");
+ TaskName testTaskName2 = new TaskName("testTaskId2");
+ TaskName testTaskName3 = new TaskName("testTaskId3");
+
+ TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0));
+ TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1));
+ TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2));
+
+ Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+ testProcessorId2, testLocationId2,
+ testProcessorId3, testLocationId3);
+
+ Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1,
+ testTaskName2, testLocationId2,
+ testTaskName3, testLocationId3);
+
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+ Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3);
+ // Initial layout: each task on the processor that shares its location.
+ Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)),
+ new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)),
+ new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3)));
+
+ Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+
+ assertEquals(expectedContainerModels, actualContainerModels);
+ // testProcessorId3 leaves the group; taskLocality (task 3 -> testLocationId3) is reused unchanged.
+ processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+ testProcessorId2, testLocationId2);
+
+ grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+ // Regroup the same tasks against the two remaining processors.
+ actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata);
+ // Tasks 1 and 2 keep their processors; the orphaned task 3 is expected on testProcessorId1.
+ expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1, testTaskName3, testTaskModel3)),
+ new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)));
+
+ assertEquals(expectedContainerModels, actualContainerModels);
+ }
+
+ @Test
+ public void testMoreTasksThanProcessors() {
+ // NOTE(review): despite its name, this test groups ONE task over TWO processors (fewer tasks than processors).
+ String testProcessorId1 = "testProcessorId1";
+ String testProcessorId2 = "testProcessorId2";
+ LocationId testLocationId1 = new LocationId("testLocationId1");
+ LocationId testLocationId2 = new LocationId("testLocationId2");
+ LocationId testLocationId3 = new LocationId("testLocationId3");
+
+ TaskName testTaskName1 = new TaskName("testTasKId1");
+ TaskName testTaskName2 = new TaskName("testTaskId2");
+ TaskName testTaskName3 = new TaskName("testTaskId3");
+
+ Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1,
+ testProcessorId2, testLocationId2);
+
+ Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1,
+ testTaskName2, testLocationId2,
+ testTaskName3, testLocationId3);
+
+ GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+ // A single task; candidate processors are taken from grouperMetadata (no separate id list is passed to group()).
+ Set<TaskModel> taskModels = generateTaskModels(1);
+
+ // Expected: the lone task lands on testProcessorId1 and no ContainerModel is produced for testProcessorId2.
+ Map<TaskName, TaskModel> expectedTasks = taskModels.stream()
+ .collect(Collectors.toMap(TaskModel::getTaskName, x -> x));
+ ContainerModel expectedContainerModel = new ContainerModel(testProcessorId1, expectedTasks);
+
+ Set<ContainerModel> actualContainerModels = buildSimpleGrouper().group(taskModels, grouperMetadata);
+
+ assertEquals(1, actualContainerModels.size());
+ assertEquals(ImmutableSet.of(expectedContainerModel), actualContainerModels);
+ }
}
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
index fcdbf08..60164b2 100644
--- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
+++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
@@ -68,7 +68,6 @@ public class TestTaskAssignmentManager {
@Test
public void testTaskAssignmentManager() {
TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
- taskAssignmentManager.init();
Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", "1", "Task2", "2", "Task3", "0", "Task4", "1");
@@ -86,7 +85,6 @@ public class TestTaskAssignmentManager {
@Test
public void testDeleteMappings() {
TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
- taskAssignmentManager.init();
Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", "1");
@@ -108,7 +106,6 @@ public class TestTaskAssignmentManager {
@Test
public void testTaskAssignmentManagerEmptyCoordinatorStream() {
TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
- taskAssignmentManager.init();
Map<String, String> expectedMap = new HashMap<>();
Map<String, String> localMap = taskAssignmentManager.readTaskAssignment();
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
index ca9def2..be240b1 100644
--- a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
+++ b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
@@ -117,11 +117,11 @@ public class ContainerMocks {
return values;
}
- public static Map<String, String> generateTaskContainerMapping(Set<ContainerModel> containers) {
- Map<String, String> taskMapping = new HashMap<>();
+ public static Map<TaskName, String> generateTaskContainerMapping(Set<ContainerModel> containers) {
+ Map<TaskName, String> taskMapping = new HashMap<>();
for (ContainerModel container : containers) {
for (TaskName taskName : container.getTasks().keySet()) {
- taskMapping.put(taskName.getTaskName(), container.getId());
+ taskMapping.put(taskName, container.getId());
}
}
return taskMapping;
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
index ea25ec1..02aaaa7 100644
--- a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
+++ b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
@@ -19,15 +19,15 @@
package org.apache.samza.coordinator;
-import java.util.ArrayList;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
import org.apache.samza.config.Config;
import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
import org.apache.samza.coordinator.server.HttpServer;
import org.apache.samza.job.model.ContainerModel;
import org.apache.samza.job.model.JobModel;
+import org.apache.samza.runtime.LocationId;
import org.apache.samza.system.StreamMetadataCache;
/**
@@ -49,15 +49,8 @@ public class JobModelManagerTestUtil {
return new JobModelManager(jobModel, server, null);
}
- public static JobModelManager getJobModelManagerUsingReadModel(Config config, int containerCount, StreamMetadataCache streamMetadataCache,
- LocalityManager locManager, HttpServer server) {
- List<String> containerIds = new ArrayList<>();
- for (int i = 0; i < containerCount; i++) {
- containerIds.add(String.valueOf(i));
- }
- JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), locManager, streamMetadataCache, containerIds);
- return new JobModelManager(jobModel, server, null);
+ public static JobModelManager getJobModelManagerUsingReadModel(Config config, StreamMetadataCache streamMetadataCache, HttpServer server, LocalityManager localityManager, Map<String, LocationId> processorLocality) {
+ JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), streamMetadataCache, new GrouperMetadataImpl(processorLocality, new HashMap<>(), new HashMap<>(), new HashMap<>())); // grouper sees only the supplied processor locality; the other metadata maps are empty
+ return new JobModelManager(new JobModel(jobModel.getConfig(), jobModel.getContainers(), localityManager), server, localityManager); // re-wrap so the JobModel carries localityManager; presumably what getAllContainerLocality() reads in tests -- confirm
}
-
-
}
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
index 1dbf132..6048466 100644
--- a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
+++ b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
@@ -19,20 +19,30 @@
package org.apache.samza.coordinator;
+import com.google.common.collect.ImmutableMap;
+import java.util.HashSet;
+import java.util.Set;
import org.apache.samza.Partition;
import org.apache.samza.config.Config;
import org.apache.samza.config.MapConfig;
import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.TaskName;
import org.apache.samza.container.grouper.task.GroupByContainerCount;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
import org.apache.samza.container.grouper.task.TaskAssignmentManager;
import org.apache.samza.coordinator.server.HttpServer;
import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.JobModel;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.runtime.LocationId;
import org.apache.samza.system.StreamMetadataCache;
import org.apache.samza.system.SystemStream;
import org.apache.samza.system.SystemStreamMetadata;
import org.apache.samza.testUtils.MockHttpServer;
import org.eclipse.jetty.servlet.DefaultServlet;
import org.eclipse.jetty.servlet.ServletHolder;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -40,14 +50,15 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Collections;
+import static org.apache.samza.coordinator.JobModelManager.*;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.argThat;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.*;
import org.junit.runner.RunWith;
import org.mockito.ArgumentMatcher;
+import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
@@ -60,7 +71,6 @@ import scala.collection.JavaConversions;
@PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class})
public class TestJobModelManager {
private final TaskAssignmentManager mockTaskManager = mock(TaskAssignmentManager.class);
- private final LocalityManager mockLocalityManager = mock(LocalityManager.class);
private final Map<String, Map<String, String>> localityMappings = new HashMap<>();
private final HttpServer server = new MockHttpServer("/", 7777, null, new ServletHolder(DefaultServlet.class));
private final SystemStream inputStream = new SystemStream("test-system", "test-stream");
@@ -75,7 +85,6 @@ public class TestJobModelManager {
@Before
public void setup() throws Exception {
- when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings);
when(mockStreamMetadataCache.getStreamMetadata(argThat(new ArgumentMatcher<scala.collection.immutable.Set<SystemStream>>() {
@Override
public boolean matches(Object argument) {
@@ -105,11 +114,15 @@ public class TestJobModelManager {
put("job.host-affinity.enabled", "true");
}
});
+ LocalityManager mockLocalityManager = mock(LocalityManager.class);
- this.localityMappings.put("0", new HashMap<String, String>() { {
+ localityMappings.put("0", new HashMap<String, String>() { {
put(SetContainerHostMapping.HOST_KEY, "abc-affinity");
} });
- this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, mockStreamMetadataCache, mockLocalityManager, server);
+ when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings);
+
+ Map<String, LocationId> containerLocality = ImmutableMap.of("0", new LocationId("abc-affinity"));
+ this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, mockStreamMetadataCache, server, mockLocalityManager, containerLocality);
assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new HashMap<String, String>() { { this.put("0", "abc-affinity"); } });
}
@@ -132,11 +145,96 @@ public class TestJobModelManager {
}
});
- this.localityMappings.put("0", new HashMap<String, String>() { {
+ LocalityManager mockLocalityManager = mock(LocalityManager.class);
+
+ localityMappings.put("0", new HashMap<String, String>() { {
put(SetContainerHostMapping.HOST_KEY, "abc-affinity");
} });
- this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, mockStreamMetadataCache, mockLocalityManager, server);
+ when(mockLocalityManager.readContainerLocality()).thenReturn(new HashMap<>());
+
+ Map<String, LocationId> containerLocality = ImmutableMap.of("0", new LocationId("abc-affinity"));
+
+ this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, mockStreamMetadataCache, server, mockLocalityManager, containerLocality);
assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new HashMap<String, String>() { { this.put("0", null); } });
}
+
+ @Test
+ public void testGetGrouperMetadata() {
+ // Mocking setup.
+ LocalityManager mockLocalityManager = mock(LocalityManager.class);
+ TaskAssignmentManager mockTaskAssignmentManager = Mockito.mock(TaskAssignmentManager.class);
+ // Stored container locality: container "0" last ran on host "abc-affinity".
+ Map<String, Map<String, String>> localityMappings = new HashMap<>();
+ localityMappings.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity"));
+ // Stored task assignment (task name -> container id): "task-0" was last on container "0".
+ Map<String, String> taskAssignment = ImmutableMap.of("task-0", "0");
+
+ // Mock the container locality assignment.
+ when(mockLocalityManager.readContainerLocality()).thenReturn(localityMappings);
+
+ // Mock the stored task -> container assignment.
+ when(mockTaskAssignmentManager.readTaskAssignment()).thenReturn(taskAssignment);
+ // Under test: build the grouper metadata from both stores.
+ GrouperMetadataImpl grouperMetadata = JobModelManager.getGrouperMetadata(new MapConfig(), mockLocalityManager, mockTaskAssignmentManager);
+ // Each store is read exactly once (Mockito.verify without a mode means times(1)).
+ Mockito.verify(mockLocalityManager).readContainerLocality();
+ Mockito.verify(mockTaskAssignmentManager).readTaskAssignment();
+ // NOTE(review): with an empty MapConfig the expected map also holds "1" -> ANY_HOST; confirm an entry for processor "1" is intended (possible off-by-one in the container-id range).
+ Assert.assertEquals(ImmutableMap.of("0", new LocationId("abc-affinity"), "1", new LocationId("ANY_HOST")), grouperMetadata.getProcessorLocality());
+ Assert.assertEquals(ImmutableMap.of(new TaskName("task-0"), new LocationId("abc-affinity")), grouperMetadata.getTaskLocality()); // task-0 inherits the host of its last container "0"
+ }
+
+ @Test
+ public void testGetProcessorLocality() {
+ // Mock the dependencies.
+ LocalityManager mockLocalityManager = mock(LocalityManager.class);
+ // Only container "0" has a stored host mapping.
+ Map<String, Map<String, String>> localityMappings = new HashMap<>();
+ localityMappings.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity"));
+
+ // Mock the container locality assignment.
+ when(mockLocalityManager.readContainerLocality()).thenReturn(localityMappings);
+ // Under test: resolve a LocationId for each processor id from the stored container locality.
+ Map<String, LocationId> processorLocality = JobModelManager.getProcessorLocality(new MapConfig(), mockLocalityManager);
+ // The stored locality is read exactly once; id "0" keeps its host and any other id falls back to ANY_HOST.
+ Mockito.verify(mockLocalityManager).readContainerLocality();
+ Assert.assertEquals(ImmutableMap.of("0", new LocationId("abc-affinity"), "1", new LocationId("ANY_HOST")), processorLocality); // NOTE(review): confirm the "1" entry is intended with an empty MapConfig (possible off-by-one in the container-id range)
+ }
+
+ @Test
+ public void testUpdateTaskAssignments() {
+ // Mocking setup.
+ JobModel mockJobModel = Mockito.mock(JobModel.class);
+ GrouperMetadataImpl mockGrouperMetadata = Mockito.mock(GrouperMetadataImpl.class);
+ TaskAssignmentManager mockTaskAssignmentManager = Mockito.mock(TaskAssignmentManager.class);
+ // The new JobModel: a single container "test-container-id" that owns four tasks.
+ Map<TaskName, TaskModel> taskModelMap = new HashMap<>();
+ taskModelMap.put(new TaskName("task-1"), new TaskModel(new TaskName("task-1"), new HashSet<>(), new Partition(0)));
+ taskModelMap.put(new TaskName("task-2"), new TaskModel(new TaskName("task-2"), new HashSet<>(), new Partition(1)));
+ taskModelMap.put(new TaskName("task-3"), new TaskModel(new TaskName("task-3"), new HashSet<>(), new Partition(2)));
+ taskModelMap.put(new TaskName("task-4"), new TaskModel(new TaskName("task-4"), new HashSet<>(), new Partition(3)));
+ ContainerModel containerModel = new ContainerModel("test-container-id", taskModelMap);
+ Map<String, ContainerModel> containerMapping = ImmutableMap.of("test-container-id", containerModel);
+ // The grouper metadata reports no previous task -> processor assignment.
+ when(mockJobModel.getContainers()).thenReturn(containerMapping);
+ when(mockGrouperMetadata.getPreviousTaskToProcessorAssignment()).thenReturn(new HashMap<>());
+ Mockito.doNothing().when(mockTaskAssignmentManager).writeTaskContainerMapping(Mockito.any(), Mockito.any()); // NOTE(review): void methods on a Mockito mock already do nothing by default, so this stub is redundant
+ // Under test: persist the task -> container mapping of the new JobModel.
+ JobModelManager.updateTaskAssignments(mockJobModel, mockTaskAssignmentManager, mockGrouperMetadata);
+ // Task names expected to be handed to deleteTaskContainerMappings (Mockito matches raw arguments with equals()).
+ Set<String> taskNames = new HashSet<String>();
+ taskNames.add("task-4");
+ taskNames.add("task-2");
+ taskNames.add("task-3");
+ taskNames.add("task-1");
+
+ // Verifications
+ Mockito.verify(mockJobModel, atLeast(1)).getContainers();
+ Mockito.verify(mockTaskAssignmentManager).deleteTaskContainerMappings((Iterable<String>) taskNames);
+ Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-1", "test-container-id");
+ Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-2", "test-container-id");
+ Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-3", "test-container-id");
+ Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-4", "test-container-id");
+ }
}