http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala index 600b7a1..5fb71f3 100644 --- a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala +++ b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala @@ -19,31 +19,33 @@ package org.apache.samza.coordinator - import java.util import java.util.concurrent.atomic.AtomicReference - import org.apache.samza.config._ import org.apache.samza.config.JobConfig.Config2Job import org.apache.samza.config.SystemConfig.Config2System import org.apache.samza.config.TaskConfig.Config2Task import org.apache.samza.config.Config import org.apache.samza.container.grouper.stream.SystemStreamPartitionGrouperFactory -import org.apache.samza.container.grouper.task.BalancingTaskNameGrouper -import org.apache.samza.container.grouper.task.TaskNameGrouperFactory +import org.apache.samza.container.grouper.task._ import org.apache.samza.container.LocalityManager import org.apache.samza.container.TaskName import org.apache.samza.coordinator.server.HttpServer import org.apache.samza.coordinator.server.JobServlet +import org.apache.samza.job.model.ContainerModel import org.apache.samza.job.model.JobModel import org.apache.samza.job.model.TaskModel +import org.apache.samza.metrics.MetricsRegistry import org.apache.samza.metrics.MetricsRegistryMap import org.apache.samza.system._ import org.apache.samza.util.Logging import org.apache.samza.util.Util import org.apache.samza.Partition +import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping +import org.apache.samza.runtime.LocationId import scala.collection.JavaConverters._ +import scala.collection.JavaConversions._ /** * Helper companion object that is responsible for wiring up a JobModelManager @@ -51,66 +53,145 @@ import scala.collection.JavaConverters._ */ object JobModelManager extends Logging { - val SOURCE = "JobModelManager" /** * a volatile value to store the current instantiated <code>JobModelManager</code> */ - @volatile var currentJobModelManager: JobModelManager = null + @volatile var currentJobModelManager: JobModelManager = _ val jobModelRef: AtomicReference[JobModel] = new AtomicReference[JobModel]() /** - * Does the following actions for a job. + * Currently used only in the ApplicationMaster for yarn deployment model. + * Does the following: * a) Reads the jobModel from coordinator stream using the job's configuration. - * b) Recomputes changelog partition mapping based on jobModel and job's configuration. + * b) Recomputes the changelog partition mapping based on jobModel and job's configuration. * c) Builds JobModelManager using the jobModel read from coordinator stream. - * @param config Config from the coordinator stream. - * @param changelogPartitionMapping The changelog partition-to-task mapping. - * @return JobModelManager + * @param config config from the coordinator stream. + * @param changelogPartitionMapping changelog partition-to-task mapping of the samza job. + * @param metricsRegistry the registry for reporting metrics. + * @return the instantiated {@see JobModelManager}. */ - def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, Integer]) = { - val localityManager = new LocalityManager(config, new MetricsRegistryMap()) - - // Map the name of each system to the corresponding SystemAdmin + def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, Integer], metricsRegistry: MetricsRegistry = new MetricsRegistryMap()): JobModelManager = { + val localityManager = new LocalityManager(config, metricsRegistry) + val taskAssignmentManager = new TaskAssignmentManager(config, metricsRegistry) val systemAdmins = new SystemAdmins(config) - val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0) + try { + systemAdmins.start() + val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0) + val grouperMetadata: GrouperMetadata = getGrouperMetadata(config, localityManager, taskAssignmentManager) - val containerCount = new JobConfig(config).getContainerCount - val processorList = List.range(0, containerCount).map(c => c.toString) + val jobModel: JobModel = readJobModel(config, changelogPartitionMapping, streamMetadataCache, grouperMetadata) + jobModelRef.set(new JobModel(jobModel.getConfig, jobModel.getContainers, localityManager)) - systemAdmins.start() - val jobModelManager = getJobModelManager(config, changelogPartitionMapping, localityManager, streamMetadataCache, processorList.asJava) - systemAdmins.stop() + updateTaskAssignments(jobModel, taskAssignmentManager, grouperMetadata) - jobModelManager + val server = new HttpServer + server.addServlet("/", new JobServlet(jobModelRef)) + + currentJobModelManager = new JobModelManager(jobModel, server, localityManager) + currentJobModelManager + } finally { + taskAssignmentManager.close() + systemAdmins.stop() + // Not closing localityManager, since {@code ClusterBasedJobCoordinator} uses it to read container locality through {@code JobModel}. + } } /** - * Build a JobModelManager using a Samza job's configuration. - */ - private def getJobModelManager(config: Config, - changeLogMapping: util.Map[TaskName, Integer], - localityManager: LocalityManager, - streamMetadataCache: StreamMetadataCache, - containerIds: java.util.List[String]) = { - val jobModel: JobModel = readJobModel(config, changeLogMapping, localityManager, streamMetadataCache, containerIds) - jobModelRef.set(jobModel) - - val server = new HttpServer - server.addServlet("/", new JobServlet(jobModelRef)) - currentJobModelManager = new JobModelManager(jobModel, server, localityManager) - currentJobModelManager + * Builds the {@see GrouperMetadataImpl} for the samza job. + * @param config represents the configurations defined by the user. + * @param localityManager provides the processor to host mapping persisted to the metadata store. + * @param taskAssignmentManager provides the processor to task assignments persisted to the metadata store. + * @return the instantiated {@see GrouperMetadata}. + */ + def getGrouperMetadata(config: Config, localityManager: LocalityManager, taskAssignmentManager: TaskAssignmentManager) = { + val processorLocality: util.Map[String, LocationId] = getProcessorLocality(config, localityManager) + val taskAssignment: util.Map[String, String] = taskAssignmentManager.readTaskAssignment() + val taskNameToProcessorId: util.Map[TaskName, String] = new util.HashMap[TaskName, String]() + for ((taskName, processorId) <- taskAssignment) { + taskNameToProcessorId.put(new TaskName(taskName), processorId) + } + + val taskLocality:util.Map[TaskName, LocationId] = new util.HashMap[TaskName, LocationId]() + for ((taskName, processorId) <- taskAssignment) { + if (processorLocality.containsKey(processorId)) { + taskLocality.put(new TaskName(taskName), processorLocality.get(processorId)) + } + } + new GrouperMetadataImpl(processorLocality, taskLocality, new util.HashMap[TaskName, util.List[SystemStreamPartition]](), taskNameToProcessorId) } /** - * For each input stream specified in config, exactly determine its - * partitions, returning a set of SystemStreamPartitions containing them all. - */ - private def getInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache) = { + * Retrieves and returns the processor locality of a samza job using provided {@see Config} and {@see LocalityManager}. + * @param config provides the configurations defined by the user. Required to connect to the storage layer. + * @param localityManager provides the processor to host mapping persisted to the metadata store. + * @return the processor locality. + */ + def getProcessorLocality(config: Config, localityManager: LocalityManager) = { + val containerToLocationId: util.Map[String, LocationId] = new util.HashMap[String, LocationId]() + val existingContainerLocality = localityManager.readContainerLocality() + + for (containerId <- 0 to config.getContainerCount) { + val localityMapping = existingContainerLocality.get(containerId.toString) + // To handle the case when the container count is increased between two different runs of a samza-yarn job, + // set the locality of newly added containers to any_host. + var locationId: LocationId = new LocationId("ANY_HOST") + if (localityMapping != null && localityMapping.containsKey(SetContainerHostMapping.HOST_KEY)) { + locationId = new LocationId(localityMapping.get(SetContainerHostMapping.HOST_KEY)) + } + containerToLocationId.put(containerId.toString, locationId) + } + + containerToLocationId + } + + /** + * This method does the following: + * 1. Deletes the existing task assignments if the partition-task grouping has changed from the previous run of the job. + * 2. Saves the newly generated task assignments to the storage layer through the {@param TaskAssignementManager}. + * + * @param jobModel represents the {@see JobModel} of the samza job. + * @param taskAssignmentManager required to persist the processor to task assignments to the storage layer. + * @param grouperMetadata provides the historical metadata of the application. + */ + def updateTaskAssignments(jobModel: JobModel, taskAssignmentManager: TaskAssignmentManager, grouperMetadata: GrouperMetadata): Unit = { + val taskNames: util.Set[String] = new util.HashSet[String]() + for (container <- jobModel.getContainers.values()) { + for (taskModel <- container.getTasks.values()) { + taskNames.add(taskModel.getTaskName.getTaskName) + } + } + val taskToContainerId = grouperMetadata.getPreviousTaskToProcessorAssignment + if (taskNames.size() != taskToContainerId.size()) { + warn("Current task count {} does not match saved task count {}. Stateful jobs may observe misalignment of keys!", + taskNames.size(), taskToContainerId.size()) + // If the tasks changed, then the partition-task grouping is also likely changed and we can't handle that + // without a much more complicated mapping. Further, the partition count may have changed, which means + // input message keys are likely reshuffled w.r.t. partitions, so the local state may not contain necessary + // data associated with the incoming keys. Warn the user and default to grouper + // In this scenario the tasks may have been reduced, so we need to delete all the existing messages + taskAssignmentManager.deleteTaskContainerMappings(taskNames) + } + + for (container <- jobModel.getContainers.values()) { + for (taskName <- container.getTasks.keySet) { + taskAssignmentManager.writeTaskContainerMapping(taskName.getTaskName, container.getId) + } + } + } + + /** + * Computes the input system stream partitions of a samza job using the provided {@param config} + * and {@param streamMetadataCache}. + * @param config the configuration of the job. + * @param streamMetadataCache to query the partition metadata of the input streams. + * @return the input {@see SystemStreamPartition} of the samza job. + */ + private def getInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache): Set[SystemStreamPartition] = { val inputSystemStreams = config.getInputStreams // Get the set of partitions for each SystemStream from the stream metadata streamMetadataCache - .getStreamMetadata(inputSystemStreams, true) + .getStreamMetadata(inputSystemStreams, partitionsMetadataOnly = true) .flatMap { case (systemStream, metadata) => metadata @@ -121,55 +202,69 @@ object JobModelManager extends Logging { }.toSet } + /** + * Builds the input {@see SystemStreamPartition} based upon the {@param config} defined by the user. + * @param config configuration to fetch the metadata of the input streams. + * @param streamMetadataCache required to query the partition metadata of the input streams. + * @return the input SystemStreamPartitions of the job. + */ private def getMatchedInputStreamPartitions(config: Config, streamMetadataCache: StreamMetadataCache): Set[SystemStreamPartition] = { val allSystemStreamPartitions = getInputStreamPartitions(config, streamMetadataCache) config.getSSPMatcherClass match { - case Some(s) => { + case Some(s) => val jfr = config.getSSPMatcherConfigJobFactoryRegex.r config.getStreamJobFactoryClass match { - case Some(jfr(_*)) => { - info("before match: allSystemStreamPartitions.size = %s" format (allSystemStreamPartitions.size)) + case Some(jfr(_*)) => + info("before match: allSystemStreamPartitions.size = %s" format allSystemStreamPartitions.size) val sspMatcher = Util.getObj(s, classOf[SystemStreamPartitionMatcher]) val matchedPartitions = sspMatcher.filter(allSystemStreamPartitions.asJava, config).asScala.toSet // Usually a small set hence ok to log at info level - info("after match: matchedPartitions = %s" format (matchedPartitions)) + info("after match: matchedPartitions = %s" format matchedPartitions) matchedPartitions - } case _ => allSystemStreamPartitions } - } case _ => allSystemStreamPartitions } } /** - * Gets a SystemStreamPartitionGrouper object from the configuration. - */ + * Finds the {@see SystemStreamPartitionGrouperFactory} from the {@param config}. Instantiates the {@see SystemStreamPartitionGrouper} + * object through the factory. + * @param config the configuration of the samza job. + * @return the instantiated {@see SystemStreamPartitionGrouper}. + */ private def getSystemStreamPartitionGrouper(config: Config) = { val factoryString = config.getSystemStreamPartitionGrouperFactory val factory = Util.getObj(factoryString, classOf[SystemStreamPartitionGrouperFactory]) factory.getSystemStreamPartitionGrouper(config) } + /** - * The function reads the latest checkpoint from the underlying coordinator stream and - * builds a new JobModel. - */ + * Does the following: + * 1. Fetches metadata of the input streams defined in configuration through {@param streamMetadataCache}. + * 2. Applies the {@see SystemStreamPartitionGrouper}, {@see TaskNameGrouper} defined in the configuration + * to build the {@see JobModel}. + * @param config the configuration of the job. + * @param changeLogPartitionMapping the task to changelog partition mapping of the job. + * @param streamMetadataCache the cache that holds the partition metadata of the input streams. + * @param grouperMetadata provides the historical metadata of the application. + * @return the built {@see JobModel}. + */ def readJobModel(config: Config, changeLogPartitionMapping: util.Map[TaskName, Integer], - localityManager: LocalityManager, streamMetadataCache: StreamMetadataCache, - containerIds: java.util.List[String]): JobModel = { + grouperMetadata: GrouperMetadata): JobModel = { // Do grouping to fetch TaskName to SSP mapping val allSystemStreamPartitions = getMatchedInputStreamPartitions(config, streamMetadataCache) // processor list is required by some of the groupers. So, let's pass them as part of the config. // Copy the config and add the processor list to the config copy. val configMap = new util.HashMap[String, String](config) - configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", containerIds)) + configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", grouperMetadata.getProcessorLocality.keySet())) val grouper = getSystemStreamPartitionGrouper(new MapConfig(configMap)) - val groups = grouper.group(allSystemStreamPartitions.asJava) + val groups = grouper.group(allSystemStreamPartitions) info("SystemStreamPartitionGrouper %s has grouped the SystemStreamPartitions into %d tasks with the following taskNames: %s" format(grouper, groups.size(), groups.keySet())) val isHostAffinityEnabled = new ClusterManagerConfig(config).getHostAffinityEnabled @@ -200,22 +295,18 @@ object JobModelManager extends Logging { // SSPTaskNameGrouper for locality, load-balancing, etc. val containerGrouperFactory = Util.getObj(config.getTaskNameGrouperFactory, classOf[TaskNameGrouperFactory]) val containerGrouper = containerGrouperFactory.build(config) - val containerModels = { - containerGrouper match { - case grouper: BalancingTaskNameGrouper if isHostAffinityEnabled => grouper.balance(taskModels.asJava, localityManager) - case _ => containerGrouper.group(taskModels.asJava, containerIds) - } - } - val containerMap = containerModels.asScala.map { case (containerModel) => containerModel.getId -> containerModel }.toMap - - if (isHostAffinityEnabled) { - new JobModel(config, containerMap.asJava, localityManager) + var containerModels: util.Set[ContainerModel] = null + if(isHostAffinityEnabled) { + containerModels = containerGrouper.group(taskModels, grouperMetadata) } else { - new JobModel(config, containerMap.asJava) + containerModels = containerGrouper.group(taskModels, new util.ArrayList[String](grouperMetadata.getProcessorLocality.keySet())) } + val containerMap = containerModels.asScala.map(containerModel => containerModel.getId -> containerModel).toMap + + new JobModel(config, containerMap.asJava) } - private def getSystemNames(config: Config) = config.getSystemNames.toSet + private def getSystemNames(config: Config) = config.getSystemNames().toSet } /** @@ -248,7 +339,7 @@ class JobModelManager( debug("Got job model: %s." format jobModel) - def start { + def start() { if (server != null) { debug("Starting HTTP server.") server.start @@ -256,7 +347,7 @@ class JobModelManager( } } - def stop { + def stop() { if (server != null) { debug("Stopping HTTP server.") server.stop
http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala index 64f516b..d16c294 100644 --- a/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala +++ b/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala @@ -50,7 +50,7 @@ class ProcessJobFactory extends StreamJobFactory with Logging { coordinatorStreamManager.bootstrap val changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager) - val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping()) + val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping(), metricsRegistry) val jobModel = coordinator.jobModel val taskPartitionMappings: util.Map[TaskName, Integer] = new util.HashMap[TaskName, Integer] http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala index 5a8d2f8..e4a7838 100644 --- a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala +++ b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala @@ -52,7 +52,7 @@ class ThreadJobFactory extends StreamJobFactory with Logging { coordinatorStreamManager.bootstrap val changelogStreamManager = new ChangelogStreamManager(coordinatorStreamManager) - val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping()) + val coordinator = JobModelManager(coordinatorStreamManager.getConfig, changelogStreamManager.readPartitionMapping(), metricsRegistry) val jobModel = coordinator.jobModel http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java index 0c2f2fb..9e6e8d0 100644 --- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java +++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java @@ -24,54 +24,34 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.UUID; -import org.apache.samza.SamzaException; -import org.apache.samza.config.Config; -import org.apache.samza.config.JobConfig; -import org.apache.samza.config.MapConfig; -import org.apache.samza.container.LocalityManager; + +import org.apache.samza.container.TaskName; import org.apache.samza.job.model.ContainerModel; import org.apache.samza.job.model.TaskModel; -import org.junit.Before; +import org.apache.samza.SamzaException; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mockito; -import org.powermock.api.mockito.PowerMockito; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; import static org.apache.samza.container.mock.ContainerMocks.*; import static org.junit.Assert.*; -import static org.mockito.Mockito.*; -@RunWith(PowerMockRunner.class) -@PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class}) public class TestGroupByContainerCount { - private TaskAssignmentManager taskAssignmentManager; - private LocalityManager localityManager; - @Before - public void setup() throws Exception { - taskAssignmentManager = mock(TaskAssignmentManager.class); - localityManager = mock(LocalityManager.class); - PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager); - Mockito.doNothing().when(taskAssignmentManager).init(); - } @Test(expected = IllegalArgumentException.class) public void testGroupEmptyTasks() { - new GroupByContainerCount(getConfig(1)).group(new HashSet()); + new GroupByContainerCount(1).group(new HashSet<>()); } @Test(expected = IllegalArgumentException.class) public void testGroupFewerTasksThanContainers() { Set<TaskModel> taskModels = new HashSet<>(); taskModels.add(getTaskModel(1)); - new GroupByContainerCount(getConfig(2)).group(taskModels); + new GroupByContainerCount(2).group(taskModels); } @Test(expected = UnsupportedOperationException.class) public void testGrouperResultImmutable() { Set<TaskModel> taskModels = generateTaskModels(3); - Set<ContainerModel> containers = new GroupByContainerCount(getConfig(3)).group(taskModels); + Set<ContainerModel> containers = new GroupByContainerCount(3).group(taskModels); containers.remove(containers.iterator().next()); } @@ -79,7 +59,7 @@ public class TestGroupByContainerCount { public void testGroupHappyPath() { Set<TaskModel> taskModels = generateTaskModels(5); - Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).group(taskModels); + Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels); Map<String, ContainerModel> containersMap = new HashMap<>(); for (ContainerModel container : containers) { @@ -106,7 +86,7 @@ public class TestGroupByContainerCount { public void testGroupManyTasks() { Set<TaskModel> taskModels = generateTaskModels(21); - Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).group(taskModels); + Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels); Map<String, ContainerModel> containersMap = new HashMap<>(); for (ContainerModel container : containers) { @@ -174,11 +154,11 @@ public class TestGroupByContainerCount { @Test public void testBalancerAfterContainerIncrease() { Set<TaskModel> taskModels = generateTaskModels(9); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(taskModels); - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels); + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); - Set<ContainerModel> containers = new GroupByContainerCount(getConfig(4)).balance(taskModels, localityManager); + Set<ContainerModel> containers = new GroupByContainerCount(4).group(taskModels, grouperMetadata); Map<String, ContainerModel> containersMap = new HashMap<>(); for (ContainerModel container : containers) { @@ -213,22 +193,6 @@ public class TestGroupByContainerCount { assertTrue(container2.getTasks().containsKey(getTaskName(6))); assertTrue(container3.getTasks().containsKey(getTaskName(5))); assertTrue(container3.getTasks().containsKey(getTaskName(7))); - - // Verify task mappings are saved - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0"); - - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1"); - - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "2"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "2"); - - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "3"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "3"); - - verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection()); } /** @@ -256,11 +220,11 @@ public class TestGroupByContainerCount { @Test public void testBalancerAfterContainerDecrease() { Set<TaskModel> taskModels = generateTaskModels(9); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(4)).group(taskModels); - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels); + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); - Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager); + Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata); Map<String, ContainerModel> containersMap = new HashMap<>(); for (ContainerModel container : containers) { @@ -290,20 +254,6 @@ public class TestGroupByContainerCount { assertTrue(container0.getTasks().containsKey(getTaskName(2))); assertTrue(container1.getTasks().containsKey(getTaskName(7))); assertTrue(container1.getTasks().containsKey(getTaskName(3))); - - // Verify task mappings are saved - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0"); - - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1"); - - verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection()); } /** @@ -331,15 +281,15 @@ public class TestGroupByContainerCount { * T8 T7 T3 */ @Test - public void testBalancerMultipleReblances() throws Exception { + public void testBalancerMultipleReblances() { // Before Set<TaskModel> taskModels = generateTaskModels(9); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(4)).group(taskModels); - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels); + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); // First balance - Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager); + Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata); Map<String, ContainerModel> containersMap = new HashMap<>(); for (ContainerModel container : containers) { @@ -370,30 +320,11 @@ public class TestGroupByContainerCount { assertTrue(container1.getTasks().containsKey(getTaskName(7))); assertTrue(container1.getTasks().containsKey(getTaskName(3))); - // Verify task mappings are saved - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0"); - - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "1"); - - verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection()); - - // Second balance prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - TaskAssignmentManager taskAssignmentManager2 = mock(TaskAssignmentManager.class); - when(taskAssignmentManager2.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); - LocalityManager localityManager2 = mock(LocalityManager.class); - PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager2); - - containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager2); + GrouperMetadataImpl grouperMetadata1 = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); + containers = new GroupByContainerCount(3).group(taskModels, grouperMetadata1); containersMap = new HashMap<>(); for (ContainerModel container : containers) { @@ -427,21 +358,6 @@ public class TestGroupByContainerCount { assertTrue(container2.getTasks().containsKey(getTaskName(6))); assertTrue(container2.getTasks().containsKey(getTaskName(2))); assertTrue(container2.getTasks().containsKey(getTaskName(3))); - - // Verify task mappings are saved - verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0"); - verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(4).getTaskName(), "0"); - verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(8).getTaskName(), "0"); - - verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1"); - verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(5).getTaskName(), "1"); - verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(7).getTaskName(), "1"); - - verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(6).getTaskName(), "2"); - verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(2).getTaskName(), "2"); - verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(3).getTaskName(), "2"); - - verify(taskAssignmentManager2, never()).deleteTaskContainerMappings(anyCollection()); } /** @@ -466,11 +382,11 @@ public class TestGroupByContainerCount { @Test public void testBalancerAfterContainerSame() { Set<TaskModel> taskModels = generateTaskModels(9); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(taskModels); - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels); + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); + Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata); Map<String, ContainerModel> containersMap = new HashMap<>(); for (ContainerModel container : containers) { @@ -496,9 +412,6 @@ public class TestGroupByContainerCount { assertTrue(container1.getTasks().containsKey(getTaskName(3))); assertTrue(container1.getTasks().containsKey(getTaskName(5))); assertTrue(container1.getTasks().containsKey(getTaskName(7))); - - verify(taskAssignmentManager, never()).writeTaskContainerMapping(anyString(), anyString()); - verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection()); } /** @@ -528,19 +441,19 @@ public class TestGroupByContainerCount { public void testBalancerAfterContainerSameCustomAssignment() { Set<TaskModel> taskModels = generateTaskModels(9); - Map<String, String> prevTaskToContainerMapping = new HashMap<>(); - prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), "0"); - prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), "0"); - prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), "0"); - prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), "0"); - prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), "0"); - prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "0"); - prevTaskToContainerMapping.put(getTaskName(6).getTaskName(), "1"); - prevTaskToContainerMapping.put(getTaskName(7).getTaskName(), "1"); - prevTaskToContainerMapping.put(getTaskName(8).getTaskName(), "1"); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Map<TaskName, String> prevTaskToContainerMapping = new HashMap<>(); + prevTaskToContainerMapping.put(getTaskName(0), "0"); + prevTaskToContainerMapping.put(getTaskName(1), "0"); + prevTaskToContainerMapping.put(getTaskName(2), "0"); + prevTaskToContainerMapping.put(getTaskName(3), "0"); + prevTaskToContainerMapping.put(getTaskName(4), "0"); + prevTaskToContainerMapping.put(getTaskName(5), "0"); + prevTaskToContainerMapping.put(getTaskName(6), "1"); + prevTaskToContainerMapping.put(getTaskName(7), "1"); + prevTaskToContainerMapping.put(getTaskName(8), "1"); - Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); + Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata); Map<String, ContainerModel> containersMap = new HashMap<>(); for (ContainerModel container : containers) { @@ -566,9 +479,6 @@ public class TestGroupByContainerCount { assertTrue(container1.getTasks().containsKey(getTaskName(6))); assertTrue(container1.getTasks().containsKey(getTaskName(7))); assertTrue(container1.getTasks().containsKey(getTaskName(8))); - - verify(taskAssignmentManager, never()).writeTaskContainerMapping(anyString(), anyString()); - verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection()); } /** @@ -597,16 +507,16 @@ public class TestGroupByContainerCount { public void testBalancerAfterContainerSameCustomAssignmentAndContainerIncrease() { Set<TaskModel> taskModels = generateTaskModels(6); - Map<String, String> prevTaskToContainerMapping = new HashMap<>(); - prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), "0"); - prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), "1"); - prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), "1"); - prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), "1"); - prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), "1"); - prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "1"); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Map<TaskName, String> prevTaskToContainerMapping = new HashMap<>(); + prevTaskToContainerMapping.put(getTaskName(0), "0"); + prevTaskToContainerMapping.put(getTaskName(1), "1"); + prevTaskToContainerMapping.put(getTaskName(2), "1"); + prevTaskToContainerMapping.put(getTaskName(3), "1"); + prevTaskToContainerMapping.put(getTaskName(4), "1"); + prevTaskToContainerMapping.put(getTaskName(5), "1"); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); - Set<ContainerModel> containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager); + Set<ContainerModel> containers = new GroupByContainerCount(3).group(taskModels, grouperMetadata); Map<String, ContainerModel> containersMap = new HashMap<>(); for (ContainerModel container : containers) { @@ -633,146 +543,106 @@ public class TestGroupByContainerCount { assertTrue(container1.getTasks().containsKey(getTaskName(2))); assertTrue(container2.getTasks().containsKey(getTaskName(4))); assertTrue(container2.getTasks().containsKey(getTaskName(3))); - - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "1"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(), "2"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(), "2"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(), "0"); - - verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection()); } @Test public void testBalancerOldContainerCountOne() { Set<TaskModel> taskModels = generateTaskModels(3); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(1)).group(taskModels); - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(1).group(taskModels); + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); - Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(3)).group(taskModels); - Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager); + Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels); + Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).group(taskModels, grouperMetadata); // Results should be the same as calling group() assertEquals(groupContainers, balanceContainers); - - // Verify task mappings are saved - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "1"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "2"); - - verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection()); } @Test public void testBalancerNewContainerCountOne() { Set<TaskModel> taskModels = generateTaskModels(3); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels); - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels); + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); - Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels); - Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager); + Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels); + Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata); // Results should be the same as calling group() assertEquals(groupContainers, balanceContainers); - - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0"); - - verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection()); } @Test public void testBalancerEmptyTaskMapping() { Set<TaskModel> taskModels = generateTaskModels(3); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(new HashMap<>()); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), new HashMap<>()); - Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels); - Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager); + Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels); + Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata); // Results should be the same as calling group() assertEquals(groupContainers, balanceContainers); - - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0"); - - verify(taskAssignmentManager, never()).deleteTaskContainerMappings(anyCollection()); } @Test public void testGroupTaskCountIncrease() { int taskCount = 3; Set<TaskModel> taskModels = generateTaskModels(taskCount); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(generateTaskModels(taskCount - 1)); // Here's the key step - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(generateTaskModels(taskCount - 1)); // Here's the key step + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); - Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels); - Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager); + Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels); + Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata); // Results should be the same as calling group() assertEquals(groupContainers, balanceContainers); - - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0"); - - verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection()); } @Test public void testGroupTaskCountDecrease() { int taskCount = 3; Set<TaskModel> taskModels = generateTaskModels(taskCount); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(generateTaskModels(taskCount + 1)); // Here's the key step - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(generateTaskModels(taskCount + 1)); // Here's the key step + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); - Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels); - Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager); + Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels); + Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).group(taskModels, grouperMetadata); // Results should be the same as calling group() assertEquals(groupContainers, balanceContainers); - - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0"); - verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(), "0"); - - verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection()); } @Test(expected = IllegalArgumentException.class) public void testBalancerNewContainerCountGreaterThanTasks() { Set<TaskModel> taskModels = generateTaskModels(3); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels); - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels); + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); - new GroupByContainerCount(getConfig(5)).balance(taskModels, localityManager); // Should throw + new GroupByContainerCount(5).group(taskModels, grouperMetadata); // Should throw } @Test(expected = IllegalArgumentException.class) public void testBalancerEmptyTasks() { Set<TaskModel> taskModels = generateTaskModels(3); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels); - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels); + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); - new GroupByContainerCount(getConfig(5)).balance(new HashSet<>(), localityManager); // Should throw + new GroupByContainerCount(5).group(new HashSet<>(), grouperMetadata); } @Test(expected = UnsupportedOperationException.class) public void testBalancerResultImmutable() { Set<TaskModel> taskModels = generateTaskModels(3); - Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels); - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); + Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels); + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); - Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager); + Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels, grouperMetadata); containers.remove(containers.iterator().next()); } @@ -780,32 +650,20 @@ public class TestGroupByContainerCount { public void testBalancerThrowsOnNonIntegerContainerIds() { Set<TaskModel> taskModels = generateTaskModels(3); Set<ContainerModel> prevContainers = new HashSet<>(); - taskModels.forEach(model -> { - prevContainers.add( - new ContainerModel(UUID.randomUUID().toString(), Collections.singletonMap(model.getTaskName(), model))); - }); - Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); - when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping); - - new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager); //Should throw - + taskModels.forEach(model -> prevContainers.add(new ContainerModel(UUID.randomUUID().toString(), Collections.singletonMap(model.getTaskName(), model)))); + Map<TaskName, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers); + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping); + new GroupByContainerCount(3).group(taskModels, grouperMetadata); //Should throw } @Test public void testBalancerWithNullLocalityManager() { Set<TaskModel> taskModels = generateTaskModels(3); - Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(3)).group(taskModels); - Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(3)).balance(taskModels, null); + Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels); + Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).balance(taskModels, null); // Results should be the same as calling group() assertEquals(groupContainers, balanceContainers); } - - - Config getConfig(int containerCount) { - Map<String, String> config = new HashMap<>(); - config.put(JobConfig.JOB_CONTAINER_COUNT(), String.valueOf(containerCount)); - return new MapConfig(config); - } } http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java index 5bb78e8..12b6b1e 100644 --- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java +++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java @@ -20,6 +20,7 @@ package org.apache.samza.container.grouper.task; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.Collections; @@ -29,35 +30,24 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; + +import org.apache.samza.Partition; import org.apache.samza.config.Config; import org.apache.samza.config.MapConfig; -import org.apache.samza.container.LocalityManager; import org.apache.samza.container.TaskName; import org.apache.samza.job.model.ContainerModel; import org.apache.samza.job.model.TaskModel; -import org.junit.Before; +import org.apache.samza.runtime.LocationId; import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.api.mockito.PowerMockito; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; - -import static org.apache.samza.container.mock.ContainerMocks.*; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; +import static org.apache.samza.container.mock.ContainerMocks.generateTaskModels; +import static org.apache.samza.container.mock.ContainerMocks.getTaskName; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; -@RunWith(PowerMockRunner.class) -@PrepareForTest({TaskAssignmentManager.class, GroupByContainerIds.class}) public class TestGroupByContainerIds { - @Before - public void setup() throws Exception { - TaskAssignmentManager taskAssignmentManager = mock(TaskAssignmentManager.class); - LocalityManager localityManager = mock(LocalityManager.class); - PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager); - } - private Config buildConfigForContainerCount(int count) { Map<String, String> map = new HashMap<>(); map.put("job.container.count", String.valueOf(count)); @@ -67,6 +57,7 @@ public class TestGroupByContainerIds { private TaskNameGrouper buildSimpleGrouper() { return buildSimpleGrouper(1); } + private TaskNameGrouper buildSimpleGrouper(int containerCount) { return new GroupByContainerIdsFactory().build(buildConfigForContainerCount(containerCount)); } @@ -114,7 +105,8 @@ public class TestGroupByContainerIds { public void testGroupWithNullContainerIds() { Set<TaskModel> taskModels = generateTaskModels(5); - Set<ContainerModel> containers = buildSimpleGrouper(2).group(taskModels, null); + List<String> containerIds = null; + Set<ContainerModel> containers = buildSimpleGrouper(2).group(taskModels, containerIds); Map<String, ContainerModel> containersMap = new HashMap<>(); for (ContainerModel container : containers) { @@ -251,4 +243,264 @@ public class TestGroupByContainerIds { assertEquals(1, actualContainerModels.size()); assertEquals(ImmutableSet.of(expectedContainerModel), actualContainerModels); } + + @Test + public void testShouldUseTaskLocalityWhenGeneratingContainerModels() { + TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3); + + String testProcessorId1 = "testProcessorId1"; + String testProcessorId2 = "testProcessorId2"; + String testProcessorId3 = "testProcessorId3"; + + LocationId testLocationId1 = new LocationId("testLocationId1"); + LocationId testLocationId2 = new LocationId("testLocationId2"); + LocationId testLocationId3 = new LocationId("testLocationId3"); + + TaskName testTaskName1 = new TaskName("testTasKId1"); + TaskName testTaskName2 = new TaskName("testTaskId2"); + TaskName testTaskName3 = new TaskName("testTaskId3"); + + TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0)); + TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1)); + TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2)); + + Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, + testProcessorId2, testLocationId2, + testProcessorId3, testLocationId3); + + Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1, + testTaskName2, testLocationId2, + testTaskName3, testLocationId3); + + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); + + Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3); + + Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)), + new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)), + new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3))); + + Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata); + + assertEquals(expectedContainerModels, actualContainerModels); + } + + @Test + public void testGenerateContainerModelForSingleContainer() { + TaskNameGrouper taskNameGrouper = buildSimpleGrouper(1); + + String testProcessorId1 = "testProcessorId1"; + + LocationId testLocationId1 = new LocationId("testLocationId1"); + LocationId testLocationId2 = new LocationId("testLocationId2"); + LocationId testLocationId3 = new LocationId("testLocationId3"); + + TaskName testTaskName1 = new TaskName("testTasKId1"); + TaskName testTaskName2 = new TaskName("testTaskId2"); + TaskName testTaskName3 = new TaskName("testTaskId3"); + + TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0)); + TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1)); + TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2)); + + Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1); + + Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1, + testTaskName2, testLocationId2, + testTaskName3, testLocationId3); + + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); + + Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3); + + Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1, + testTaskName2, testTaskModel2, + testTaskName3, testTaskModel3))); + + Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata); + + assertEquals(expectedContainerModels, actualContainerModels); + } + + @Test + public void testShouldGenerateCorrectContainerModelWhenTaskLocalityIsEmpty() { + TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3); + + String testProcessorId1 = "testProcessorId1"; + String testProcessorId2 = "testProcessorId2"; + String testProcessorId3 = "testProcessorId3"; + + LocationId testLocationId1 = new LocationId("testLocationId1"); + LocationId testLocationId2 = new LocationId("testLocationId2"); + LocationId testLocationId3 = new LocationId("testLocationId3"); + + TaskName testTaskName1 = new TaskName("testTasKId1"); + TaskName testTaskName2 = new TaskName("testTaskId2"); + TaskName testTaskName3 = new TaskName("testTaskId3"); + + TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0)); + TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1)); + TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2)); + + Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, + testProcessorId2, testLocationId2, + testProcessorId3, testLocationId3); + + Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1); + + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); + + Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3); + + Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)), + new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)), + new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3))); + + Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata); + + assertEquals(expectedContainerModels, actualContainerModels); + } + + @Test(expected = IllegalArgumentException.class) + public void testShouldFailWhenProcessorLocalityIsEmpty() { + TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3); + + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), new HashMap<>()); + + taskNameGrouper.group(new HashSet<>(), grouperMetadata); + } + + @Test + public void testShouldGenerateIdenticalTaskDistributionWhenNoChangeInProcessorGroup() { + TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3); + + String testProcessorId1 = "testProcessorId1"; + String testProcessorId2 = "testProcessorId2"; + String testProcessorId3 = "testProcessorId3"; + + LocationId testLocationId1 = new LocationId("testLocationId1"); + LocationId testLocationId2 = new LocationId("testLocationId2"); + LocationId testLocationId3 = new LocationId("testLocationId3"); + + TaskName testTaskName1 = new TaskName("testTasKId1"); + TaskName testTaskName2 = new TaskName("testTaskId2"); + TaskName testTaskName3 = new TaskName("testTaskId3"); + + TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0)); + TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1)); + TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2)); + + Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, + testProcessorId2, testLocationId2, + testProcessorId3, testLocationId3); + + Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1, + testTaskName2, testLocationId2, + testTaskName3, testLocationId3); + + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); + + Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3); + + Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)), + new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)), + new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3))); + + Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata); + + assertEquals(expectedContainerModels, actualContainerModels); + + actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata); + + assertEquals(expectedContainerModels, actualContainerModels); + } + + @Test + public void testShouldMinimizeTaskShuffleWhenAvailableProcessorInGroupChanges() { + TaskNameGrouper taskNameGrouper = buildSimpleGrouper(3); + + String testProcessorId1 = "testProcessorId1"; + String testProcessorId2 = "testProcessorId2"; + String testProcessorId3 = "testProcessorId3"; + + LocationId testLocationId1 = new LocationId("testLocationId1"); + LocationId testLocationId2 = new LocationId("testLocationId2"); + LocationId testLocationId3 = new LocationId("testLocationId3"); + + TaskName testTaskName1 = new TaskName("testTasKId1"); + TaskName testTaskName2 = new TaskName("testTaskId2"); + TaskName testTaskName3 = new TaskName("testTaskId3"); + + TaskModel testTaskModel1 = new TaskModel(testTaskName1, new HashSet<>(), new Partition(0)); + TaskModel testTaskModel2 = new TaskModel(testTaskName2, new HashSet<>(), new Partition(1)); + TaskModel testTaskModel3 = new TaskModel(testTaskName3, new HashSet<>(), new Partition(2)); + + Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, + testProcessorId2, testLocationId2, + testProcessorId3, testLocationId3); + + Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1, + testTaskName2, testLocationId2, + testTaskName3, testLocationId3); + + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); + + Set<TaskModel> taskModels = ImmutableSet.of(testTaskModel1, testTaskModel2, testTaskModel3); + + Set<ContainerModel> expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1)), + new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2)), + new ContainerModel(testProcessorId3, ImmutableMap.of(testTaskName3, testTaskModel3))); + + Set<ContainerModel> actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata); + + assertEquals(expectedContainerModels, actualContainerModels); + + processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, + testProcessorId2, testLocationId2); + + grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); + + actualContainerModels = taskNameGrouper.group(taskModels, grouperMetadata); + + expectedContainerModels = ImmutableSet.of(new ContainerModel(testProcessorId1, ImmutableMap.of(testTaskName1, testTaskModel1, testTaskName3, testTaskModel3)), + new ContainerModel(testProcessorId2, ImmutableMap.of(testTaskName2, testTaskModel2))); + + assertEquals(expectedContainerModels, actualContainerModels); + } + + @Test + public void testMoreTasksThanProcessors() { + String testProcessorId1 = "testProcessorId1"; + String testProcessorId2 = "testProcessorId2"; + + LocationId testLocationId1 = new LocationId("testLocationId1"); + LocationId testLocationId2 = new LocationId("testLocationId2"); + LocationId testLocationId3 = new LocationId("testLocationId3"); + + TaskName testTaskName1 = new TaskName("testTasKId1"); + TaskName testTaskName2 = new TaskName("testTaskId2"); + TaskName testTaskName3 = new TaskName("testTaskId3"); + + Map<String, LocationId> processorLocality = ImmutableMap.of(testProcessorId1, testLocationId1, + testProcessorId2, testLocationId2); + + Map<TaskName, LocationId> taskLocality = ImmutableMap.of(testTaskName1, testLocationId1, + testTaskName2, testLocationId2, + testTaskName3, testLocationId3); + + GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>()); + + + Set<TaskModel> taskModels = generateTaskModels(1); + List<String> containerIds = ImmutableList.of(testProcessorId1, testProcessorId2); + + Map<TaskName, TaskModel> expectedTasks = taskModels.stream() + .collect(Collectors.toMap(TaskModel::getTaskName, x -> x)); + ContainerModel expectedContainerModel = new ContainerModel(testProcessorId1, expectedTasks); + + Set<ContainerModel> actualContainerModels = buildSimpleGrouper().group(taskModels, grouperMetadata); + + assertEquals(1, actualContainerModels.size()); + assertEquals(ImmutableSet.of(expectedContainerModel), actualContainerModels); + } } http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java index fcdbf08..60164b2 100644 --- a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java +++ b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java @@ -68,7 +68,6 @@ public class TestTaskAssignmentManager { @Test public void testTaskAssignmentManager() { TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap()); - taskAssignmentManager.init(); Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", "1", "Task2", "2", "Task3", "0", "Task4", "1"); @@ -86,7 +85,6 @@ public class TestTaskAssignmentManager { @Test public void testDeleteMappings() { TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap()); - taskAssignmentManager.init(); Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", "1"); @@ -108,7 +106,6 @@ public class TestTaskAssignmentManager { @Test public void testTaskAssignmentManagerEmptyCoordinatorStream() { TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap()); - taskAssignmentManager.init(); Map<String, String> expectedMap = new HashMap<>(); Map<String, String> localMap = taskAssignmentManager.readTaskAssignment(); http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java index ca9def2..be240b1 100644 --- a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java +++ b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java @@ -117,11 +117,11 @@ public class ContainerMocks { return values; } - public static Map<String, String> generateTaskContainerMapping(Set<ContainerModel> containers) { - Map<String, String> taskMapping = new HashMap<>(); + public static Map<TaskName, String> generateTaskContainerMapping(Set<ContainerModel> containers) { + Map<TaskName, String> taskMapping = new HashMap<>(); for (ContainerModel container : containers) { for (TaskName taskName : container.getTasks().keySet()) { - taskMapping.put(taskName.getTaskName(), container.getId()); + taskMapping.put(taskName, container.getId()); } } return taskMapping; http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java index ea25ec1..02aaaa7 100644 --- a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java +++ b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java @@ -19,15 +19,15 @@ package org.apache.samza.coordinator; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; import org.apache.samza.config.Config; import org.apache.samza.container.LocalityManager; +import org.apache.samza.container.grouper.task.GrouperMetadataImpl; import org.apache.samza.coordinator.server.HttpServer; import org.apache.samza.job.model.ContainerModel; import org.apache.samza.job.model.JobModel; +import org.apache.samza.runtime.LocationId; import org.apache.samza.system.StreamMetadataCache; /** @@ -49,15 +49,8 @@ public class JobModelManagerTestUtil { return new JobModelManager(jobModel, server, null); } - public static JobModelManager getJobModelManagerUsingReadModel(Config config, int containerCount, StreamMetadataCache streamMetadataCache, - LocalityManager locManager, HttpServer server) { - List<String> containerIds = new ArrayList<>(); - for (int i = 0; i < containerCount; i++) { - containerIds.add(String.valueOf(i)); - } - JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), locManager, streamMetadataCache, containerIds); - return new JobModelManager(jobModel, server, null); + public static JobModelManager getJobModelManagerUsingReadModel(Config config, StreamMetadataCache streamMetadataCache, HttpServer server, LocalityManager localityManager, Map<String, LocationId> processorLocality) { + JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), streamMetadataCache, new GrouperMetadataImpl(processorLocality, new HashMap<>(), new HashMap<>(), new HashMap<>())); + return new JobModelManager(new JobModel(jobModel.getConfig(), jobModel.getContainers(), localityManager), server, localityManager); } - - } http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java index 1dbf132..6048466 100644 --- a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java +++ b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java @@ -19,20 +19,30 @@ package org.apache.samza.coordinator; +import com.google.common.collect.ImmutableMap; +import java.util.HashSet; +import java.util.Set; import org.apache.samza.Partition; import org.apache.samza.config.Config; import org.apache.samza.config.MapConfig; import org.apache.samza.container.LocalityManager; +import org.apache.samza.container.TaskName; import org.apache.samza.container.grouper.task.GroupByContainerCount; +import org.apache.samza.container.grouper.task.GrouperMetadataImpl; import org.apache.samza.container.grouper.task.TaskAssignmentManager; import org.apache.samza.coordinator.server.HttpServer; import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping; +import org.apache.samza.job.model.ContainerModel; +import org.apache.samza.job.model.JobModel; +import org.apache.samza.job.model.TaskModel; +import org.apache.samza.runtime.LocationId; import org.apache.samza.system.StreamMetadataCache; import org.apache.samza.system.SystemStream; import org.apache.samza.system.SystemStreamMetadata; import org.apache.samza.testUtils.MockHttpServer; import org.eclipse.jetty.servlet.DefaultServlet; import org.eclipse.jetty.servlet.ServletHolder; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -40,14 +50,15 @@ import java.util.HashMap; import java.util.Map; import java.util.Collections; +import static org.apache.samza.coordinator.JobModelManager.*; import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.argThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import org.junit.runner.RunWith; import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; @@ -60,7 +71,6 @@ import scala.collection.JavaConversions; @PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class}) public class TestJobModelManager { private final TaskAssignmentManager mockTaskManager = mock(TaskAssignmentManager.class); - private final LocalityManager mockLocalityManager = mock(LocalityManager.class); private final Map<String, Map<String, String>> localityMappings = new HashMap<>(); private final HttpServer server = new MockHttpServer("/", 7777, null, new ServletHolder(DefaultServlet.class)); private final SystemStream inputStream = new SystemStream("test-system", "test-stream"); @@ -75,7 +85,6 @@ public class TestJobModelManager { @Before public void setup() throws Exception { - when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings); when(mockStreamMetadataCache.getStreamMetadata(argThat(new ArgumentMatcher<scala.collection.immutable.Set<SystemStream>>() { @Override public boolean matches(Object argument) { @@ -105,11 +114,15 @@ public class TestJobModelManager { put("job.host-affinity.enabled", "true"); } }); + LocalityManager mockLocalityManager = mock(LocalityManager.class); - this.localityMappings.put("0", new HashMap<String, String>() { { + localityMappings.put("0", new HashMap<String, String>() { { put(SetContainerHostMapping.HOST_KEY, "abc-affinity"); } }); - this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, mockStreamMetadataCache, mockLocalityManager, server); + when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings); + + Map<String, LocationId> containerLocality = ImmutableMap.of("0", new LocationId("abc-affinity")); + this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, mockStreamMetadataCache, server, mockLocalityManager, containerLocality); assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new HashMap<String, String>() { { this.put("0", "abc-affinity"); } }); } @@ -132,11 +145,96 @@ public class TestJobModelManager { } }); - this.localityMappings.put("0", new HashMap<String, String>() { { + LocalityManager mockLocalityManager = mock(LocalityManager.class); + + localityMappings.put("0", new HashMap<String, String>() { { put(SetContainerHostMapping.HOST_KEY, "abc-affinity"); } }); - this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, mockStreamMetadataCache, mockLocalityManager, server); + when(mockLocalityManager.readContainerLocality()).thenReturn(new HashMap<>()); + + Map<String, LocationId> containerLocality = ImmutableMap.of("0", new LocationId("abc-affinity")); + + this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, mockStreamMetadataCache, server, mockLocalityManager, containerLocality); assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new HashMap<String, String>() { { this.put("0", null); } }); } + + @Test + public void testGetGrouperMetadata() { + // Mocking setup. + LocalityManager mockLocalityManager = mock(LocalityManager.class); + TaskAssignmentManager mockTaskAssignmentManager = Mockito.mock(TaskAssignmentManager.class); + + Map<String, Map<String, String>> localityMappings = new HashMap<>(); + localityMappings.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity")); + + Map<String, String> taskAssignment = ImmutableMap.of("task-0", "0"); + + // Mock the container locality assignment. + when(mockLocalityManager.readContainerLocality()).thenReturn(localityMappings); + + // Mock the container to task assignment. + when(mockTaskAssignmentManager.readTaskAssignment()).thenReturn(taskAssignment); + + GrouperMetadataImpl grouperMetadata = JobModelManager.getGrouperMetadata(new MapConfig(), mockLocalityManager, mockTaskAssignmentManager); + + Mockito.verify(mockLocalityManager).readContainerLocality(); + Mockito.verify(mockTaskAssignmentManager).readTaskAssignment(); + + Assert.assertEquals(ImmutableMap.of("0", new LocationId("abc-affinity"), "1", new LocationId("ANY_HOST")), grouperMetadata.getProcessorLocality()); + Assert.assertEquals(ImmutableMap.of(new TaskName("task-0"), new LocationId("abc-affinity")), grouperMetadata.getTaskLocality()); + } + + @Test + public void testGetProcessorLocality() { + // Mock the dependencies. + LocalityManager mockLocalityManager = mock(LocalityManager.class); + + Map<String, Map<String, String>> localityMappings = new HashMap<>(); + localityMappings.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity")); + + // Mock the container locality assignment. + when(mockLocalityManager.readContainerLocality()).thenReturn(localityMappings); + + Map<String, LocationId> processorLocality = JobModelManager.getProcessorLocality(new MapConfig(), mockLocalityManager); + + Mockito.verify(mockLocalityManager).readContainerLocality(); + Assert.assertEquals(ImmutableMap.of("0", new LocationId("abc-affinity"), "1", new LocationId("ANY_HOST")), processorLocality); + } + + @Test + public void testUpdateTaskAssignments() { + // Mocking setup. + JobModel mockJobModel = Mockito.mock(JobModel.class); + GrouperMetadataImpl mockGrouperMetadata = Mockito.mock(GrouperMetadataImpl.class); + TaskAssignmentManager mockTaskAssignmentManager = Mockito.mock(TaskAssignmentManager.class); + + Map<TaskName, TaskModel> taskModelMap = new HashMap<>(); + taskModelMap.put(new TaskName("task-1"), new TaskModel(new TaskName("task-1"), new HashSet<>(), new Partition(0))); + taskModelMap.put(new TaskName("task-2"), new TaskModel(new TaskName("task-2"), new HashSet<>(), new Partition(1))); + taskModelMap.put(new TaskName("task-3"), new TaskModel(new TaskName("task-3"), new HashSet<>(), new Partition(2))); + taskModelMap.put(new TaskName("task-4"), new TaskModel(new TaskName("task-4"), new HashSet<>(), new Partition(3))); + ContainerModel containerModel = new ContainerModel("test-container-id", taskModelMap); + Map<String, ContainerModel> containerMapping = ImmutableMap.of("test-container-id", containerModel); + + when(mockJobModel.getContainers()).thenReturn(containerMapping); + when(mockGrouperMetadata.getPreviousTaskToProcessorAssignment()).thenReturn(new HashMap<>()); + Mockito.doNothing().when(mockTaskAssignmentManager).writeTaskContainerMapping(Mockito.any(), Mockito.any()); + + JobModelManager.updateTaskAssignments(mockJobModel, mockTaskAssignmentManager, mockGrouperMetadata); + + Set<String> taskNames = new HashSet<String>(); + taskNames.add("task-4"); + taskNames.add("task-2"); + taskNames.add("task-3"); + taskNames.add("task-1"); + + // Verifications + Mockito.verify(mockJobModel, atLeast(1)).getContainers(); + Mockito.verify(mockTaskAssignmentManager).deleteTaskContainerMappings((Iterable<String>) taskNames); + Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-1", "test-container-id"); + Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-2", "test-container-id"); + Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-3", "test-container-id"); + Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-4", "test-container-id"); + } }
