http://git-wip-us.apache.org/repos/asf/tez/blob/267fe737/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl2.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl2.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl2.java index b4064a0..352ad87 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl2.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl2.java @@ -28,17 +28,23 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.security.Credentials; import org.apache.hadoop.yarn.api.ApplicationConstants; import org.apache.hadoop.yarn.conf.YarnConfiguration; import org.apache.hadoop.yarn.event.EventHandler; import org.apache.hadoop.yarn.util.Clock; +import org.apache.tez.dag.api.DagTypeConverters; import org.apache.tez.dag.api.TaskLocationHint; import org.apache.tez.dag.api.TezConfiguration; import org.apache.tez.dag.api.TezConstants; +import org.apache.tez.dag.api.Vertex; +import org.apache.tez.dag.api.Vertex.VertexExecutionContext; import org.apache.tez.dag.api.VertexLocationHint; import org.apache.tez.dag.api.records.DAGProtos; +import org.apache.tez.dag.api.records.DAGProtos.VertexPlan; import org.apache.tez.dag.app.AppContext; import org.apache.tez.dag.app.ContainerContext; import org.apache.tez.dag.app.TaskAttemptListener; @@ -47,6 +53,7 @@ import org.apache.tez.dag.app.dag.DAG; import org.apache.tez.dag.app.dag.StateChangeNotifier; import org.apache.tez.dag.records.TezVertexID; import org.apache.tez.dag.utils.TaskSpecificLaunchCmdOption; +import org.apache.tez.runtime.api.ExecutionContext; import org.junit.Test; /** @@ -60,7 +67,8 @@ public class TestVertexImpl2 { Configuration conf = new 
TezConfiguration(); conf.set(TezConfiguration.TEZ_TASK_LOG_LEVEL, "DEBUG;org.apache.hadoop.ipc=INFO;org.apache.hadoop.server=INFO"); - LogTestInfoHolder testInfo = new LogTestInfoHolder(conf); + LogTestInfoHolder testInfo = new LogTestInfoHolder(); + VertexWrapper vertexWrapper = createVertexWrapperForLogTests(testInfo, conf); List<String> expectedCommands = new LinkedList<String>(); expectedCommands.add("-Dlog4j.configuratorClass=org.apache.tez.common.TezLog4jConfigurator"); @@ -71,7 +79,8 @@ public class TestVertexImpl2 { TezConstants.TEZ_CONTAINER_LOGGER_NAME); for (int i = 0 ; i < testInfo.numTasks ; i++) { - ContainerContext containerContext = testInfo.vertex.getContainerContext(i); + ContainerContext containerContext = vertexWrapper + .vertex.getContainerContext(i); String javaOpts = containerContext.getJavaOpts(); assertTrue(javaOpts.contains(testInfo.initialJavaOpts)); for (String expectedCmd : expectedCommands) { @@ -92,7 +101,8 @@ public class TestVertexImpl2 { Configuration conf = new TezConfiguration(); conf.set(TezConfiguration.TEZ_TASK_LOG_LEVEL, "DEBUG"); - LogTestInfoHolder testInfo = new LogTestInfoHolder(conf); + LogTestInfoHolder testInfo = new LogTestInfoHolder(); + VertexWrapper vertexWrapper = createVertexWrapperForLogTests(testInfo, conf); List<String> expectedCommands = new LinkedList<String>(); expectedCommands.add("-Dlog4j.configuratorClass=org.apache.tez.common.TezLog4jConfigurator"); @@ -103,7 +113,7 @@ public class TestVertexImpl2 { TezConstants.TEZ_CONTAINER_LOGGER_NAME); for (int i = 0 ; i < testInfo.numTasks ; i++) { - ContainerContext containerContext = testInfo.vertex.getContainerContext(i); + ContainerContext containerContext = vertexWrapper.vertex.getContainerContext(i); String javaOpts = containerContext.getJavaOpts(); assertTrue(javaOpts.contains(testInfo.initialJavaOpts)); for (String expectedCmd : expectedCommands) { @@ -130,7 +140,8 @@ public class TestVertexImpl2 { conf.set(TezConfiguration.TEZ_TASK_SPECIFIC_LOG_LEVEL, 
"DEBUG;org.apache.tez=INFO"); conf.set(TezConfiguration.TEZ_TASK_SPECIFIC_LAUNCH_CMD_OPTS, customJavaOpts); - LogTestInfoHolder testInfo = new LogTestInfoHolder(conf); + LogTestInfoHolder testInfo = new LogTestInfoHolder(); + VertexWrapper vertexWrapper = createVertexWrapperForLogTests(testInfo, conf); // Expected command opts for regular tasks List<String> expectedCommands = new LinkedList<String>(); @@ -142,7 +153,7 @@ public class TestVertexImpl2 { TezConstants.TEZ_CONTAINER_LOGGER_NAME); for (int i = 3 ; i < testInfo.numTasks ; i++) { - ContainerContext containerContext = testInfo.vertex.getContainerContext(i); + ContainerContext containerContext = vertexWrapper.vertex.getContainerContext(i); String javaOpts = containerContext.getJavaOpts(); assertTrue(javaOpts.contains(testInfo.initialJavaOpts)); @@ -167,7 +178,7 @@ public class TestVertexImpl2 { TezConstants.TEZ_CONTAINER_LOGGER_NAME); for (int i = 0 ; i < 3 ; i++) { - ContainerContext containerContext = testInfo.vertex.getContainerContext(i); + ContainerContext containerContext = vertexWrapper.vertex.getContainerContext(i); String javaOpts = containerContext.getJavaOpts(); assertTrue(javaOpts.contains(testInfo.initialJavaOpts)); @@ -195,7 +206,8 @@ public class TestVertexImpl2 { conf.set(TezConfiguration.TEZ_TASK_SPECIFIC_LOG_LEVEL, "DEBUG"); conf.set(TezConfiguration.TEZ_TASK_SPECIFIC_LAUNCH_CMD_OPTS, customJavaOpts); - LogTestInfoHolder testInfo = new LogTestInfoHolder(conf); + LogTestInfoHolder testInfo = new LogTestInfoHolder(); + VertexWrapper vertexWrapper = createVertexWrapperForLogTests(testInfo, conf); // Expected command opts for regular tasks List<String> expectedCommands = new LinkedList<String>(); @@ -207,7 +219,7 @@ public class TestVertexImpl2 { TezConstants.TEZ_CONTAINER_LOGGER_NAME); for (int i = 3 ; i < testInfo.numTasks ; i++) { - ContainerContext containerContext = testInfo.vertex.getContainerContext(i); + ContainerContext containerContext = vertexWrapper.vertex.getContainerContext(i); 
String javaOpts = containerContext.getJavaOpts(); assertTrue(javaOpts.contains(testInfo.initialJavaOpts)); @@ -232,7 +244,7 @@ public class TestVertexImpl2 { TezConstants.TEZ_CONTAINER_LOGGER_NAME); for (int i = 0 ; i < 3 ; i++) { - ContainerContext containerContext = testInfo.vertex.getContainerContext(i); + ContainerContext containerContext = vertexWrapper.vertex.getContainerContext(i); String javaOpts = containerContext.getJavaOpts(); assertTrue(javaOpts.contains(testInfo.initialJavaOpts)); @@ -248,43 +260,224 @@ public class TestVertexImpl2 { } } + @Test(timeout = 5000) + public void testNullExecutionContexts() { - private static class LogTestInfoHolder { + ExecutionContextTestInfoHolder info = new ExecutionContextTestInfoHolder(null, null); + VertexWrapper vertexWrapper = createVertexWrapperForExecutionContextTest(info); - final AppContext mockAppContext; - final DAG mockDag; - final VertexImpl vertex; - final DAGProtos.VertexPlan vertexPlan; + assertEquals(0, vertexWrapper.vertex.taskSchedulerIdentifier); + assertEquals(0, vertexWrapper.vertex.containerLauncherIdentifier); + assertEquals(0, vertexWrapper.vertex.taskCommunicatorIdentifier); + } + + @Test(timeout = 5000) + public void testDefaultExecContextViaDag() { + VertexExecutionContext defaultExecContext = VertexExecutionContext.create( + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.TASK_SCHEDULER_NAME_BASE, 0), + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.CONTAINER_LAUNCHER_NAME_BASE, 2), + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.TASK_COMM_NAME_BASE, 2)); + ExecutionContextTestInfoHolder info = + new ExecutionContextTestInfoHolder(null, defaultExecContext, 3); + VertexWrapper vertexWrapper = createVertexWrapperForExecutionContextTest(info); + + assertEquals(0, vertexWrapper.vertex.taskSchedulerIdentifier); + assertEquals(2, vertexWrapper.vertex.containerLauncherIdentifier); + assertEquals(2, 
vertexWrapper.vertex.taskCommunicatorIdentifier); + } + + @Test(timeout = 5000) + public void testVertexExecutionContextOnly() { + VertexExecutionContext vertexExecutionContext = VertexExecutionContext.create( + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.TASK_SCHEDULER_NAME_BASE, 1), + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.CONTAINER_LAUNCHER_NAME_BASE, 1), + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.TASK_COMM_NAME_BASE, 1)); + ExecutionContextTestInfoHolder info = + new ExecutionContextTestInfoHolder(vertexExecutionContext, null, 3); + VertexWrapper vertexWrapper = createVertexWrapperForExecutionContextTest(info); + + assertEquals(1, vertexWrapper.vertex.taskSchedulerIdentifier); + assertEquals(1, vertexWrapper.vertex.containerLauncherIdentifier); + assertEquals(1, vertexWrapper.vertex.taskCommunicatorIdentifier); + } + + @Test(timeout = 5000) + public void testVertexExecutionContextOverride() { + VertexExecutionContext defaultExecContext = VertexExecutionContext.create( + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.TASK_SCHEDULER_NAME_BASE, 0), + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.CONTAINER_LAUNCHER_NAME_BASE, 2), + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.TASK_COMM_NAME_BASE, 2)); + + VertexExecutionContext vertexExecutionContext = VertexExecutionContext.create( + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.TASK_SCHEDULER_NAME_BASE, 1), + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.CONTAINER_LAUNCHER_NAME_BASE, 1), + ExecutionContextTestInfoHolder + .append(ExecutionContextTestInfoHolder.TASK_COMM_NAME_BASE, 1)); + ExecutionContextTestInfoHolder info = + new ExecutionContextTestInfoHolder(vertexExecutionContext, defaultExecContext, 3); + VertexWrapper vertexWrapper = createVertexWrapperForExecutionContextTest(info); + + 
assertEquals(1, vertexWrapper.vertex.taskSchedulerIdentifier); + assertEquals(1, vertexWrapper.vertex.containerLauncherIdentifier); + assertEquals(1, vertexWrapper.vertex.taskCommunicatorIdentifier); + } + + + private static class ExecutionContextTestInfoHolder { + + static final String TASK_SCHEDULER_NAME_BASE = "TASK_SCHEDULER"; + static final String CONTAINER_LAUNCHER_NAME_BASE = "CONTAINER_LAUNCHER"; + static final String TASK_COMM_NAME_BASE = "TASK_COMMUNICATOR"; + + static String append(String base, int index) { + return base + index; + } + + final String vertexName; + final VertexExecutionContext defaultExecutionContext; + final VertexExecutionContext vertexExecutionContext; + final BiMap<String, Integer> taskSchedulers = HashBiMap.create(); + final BiMap<String, Integer> containerLaunchers = HashBiMap.create(); + final BiMap<String, Integer> taskComms = HashBiMap.create(); + final AppContext appContext; + + public ExecutionContextTestInfoHolder(VertexExecutionContext vertexExecutionContext, + VertexExecutionContext defaultDagExecutionContext) { + this(vertexExecutionContext, defaultDagExecutionContext, 0); + } + + public ExecutionContextTestInfoHolder(VertexExecutionContext vertexExecutionContext, + VertexExecutionContext defaultDagExecitionContext, + int numPlugins) { + this.vertexName = "testvertex"; + this.vertexExecutionContext = vertexExecutionContext; + this.defaultExecutionContext = defaultDagExecitionContext; + if (numPlugins == 0) { + this.taskSchedulers.put(TezConstants.getTezYarnServicePluginName(), 0); + this.containerLaunchers.put(TezConstants.getTezYarnServicePluginName(), 0); + this.taskSchedulers.put(TezConstants.getTezYarnServicePluginName(), 0); + } else { + for (int i = 0; i < numPlugins; i++) { + this.taskSchedulers.put(append(TASK_SCHEDULER_NAME_BASE, i), i); + this.containerLaunchers.put(append(CONTAINER_LAUNCHER_NAME_BASE, i), i); + this.taskComms.put(append(TASK_COMM_NAME_BASE, i), i); + } + } + + this.appContext = 
createDefaultMockAppContext(); + DAG dag = appContext.getCurrentDAG(); + doReturn(defaultDagExecitionContext).when(dag).getDefaultExecutionContext(); + for (Map.Entry<String, Integer> entry : taskSchedulers.entrySet()) { + doReturn(entry.getKey()).when(appContext).getTaskSchedulerName(entry.getValue()); + doReturn(entry.getValue()).when(appContext).getTaskScheduerIdentifier(entry.getKey()); + } + for (Map.Entry<String, Integer> entry : containerLaunchers.entrySet()) { + doReturn(entry.getKey()).when(appContext).getContainerLauncherName(entry.getValue()); + doReturn(entry.getValue()).when(appContext).getContainerLauncherIdentifier(entry.getKey()); + } + for (Map.Entry<String, Integer> entry : taskComms.entrySet()) { + doReturn(entry.getKey()).when(appContext).getTaskCommunicatorName(entry.getValue()); + doReturn(entry.getValue()).when(appContext).getTaskCommunicatorIdentifier(entry.getKey()); + } + } + } + private VertexWrapper createVertexWrapperForExecutionContextTest( + ExecutionContextTestInfoHolder vertexInfo) { + VertexPlan vertexPlan = createVertexPlanForExeuctionContextTests(vertexInfo); + VertexWrapper vertexWrapper = + new VertexWrapper(vertexInfo.appContext, vertexPlan, new Configuration(false)); + return vertexWrapper; + } + + private VertexPlan createVertexPlanForExeuctionContextTests(ExecutionContextTestInfoHolder info) { + VertexPlan.Builder vertexPlanBuilder = VertexPlan.newBuilder() + .setName(info.vertexName) + .setTaskConfig(DAGProtos.PlanTaskConfiguration.newBuilder() + .setNumTasks(10) + .setJavaOpts("dontcare") + .setMemoryMb(1024) + .setVirtualCores(1) + .setTaskModule("taskmodule") + .build()) + .setType(DAGProtos.PlanVertexType.NORMAL); + if (info.vertexExecutionContext != null) { + vertexPlanBuilder + .setExecutionContext(DagTypeConverters.convertToProto(info.vertexExecutionContext)); + } + return vertexPlanBuilder.build(); + } + + private static class LogTestInfoHolder { final int numTasks = 10; final String initialJavaOpts = 
"initialJavaOpts"; final String envKey = "key1"; final String envVal = "val1"; + final String vertexName; + + public LogTestInfoHolder() { + this("testvertex"); + } - LogTestInfoHolder(Configuration conf) { - this(conf, "testvertex"); + public LogTestInfoHolder(String vertexName) { + this.vertexName = vertexName; } + } + + private VertexWrapper createVertexWrapperForLogTests(LogTestInfoHolder logTestInfoHolder, + Configuration conf) { + VertexPlan vertexPlan = createVertexPlanForLogTests(logTestInfoHolder); + VertexWrapper vertexWrapper = new VertexWrapper(vertexPlan, conf); + return vertexWrapper; + } + + private VertexPlan createVertexPlanForLogTests(LogTestInfoHolder logTestInfoHolder) { + VertexPlan vertexPlan = VertexPlan.newBuilder() + .setName(logTestInfoHolder.vertexName) + .setTaskConfig(DAGProtos.PlanTaskConfiguration.newBuilder() + .setJavaOpts(logTestInfoHolder.initialJavaOpts) + .setNumTasks(logTestInfoHolder.numTasks) + .setMemoryMb(1024) + .setVirtualCores(1) + .setTaskModule("taskmodule") + .addEnvironmentSetting(DAGProtos.PlanKeyValuePair.newBuilder() + .setKey(logTestInfoHolder.envKey) + .setValue(logTestInfoHolder.envVal) + .build()) + .build()) + .setType(DAGProtos.PlanVertexType.NORMAL).build(); + return vertexPlan; + } + + private static class VertexWrapper { - LogTestInfoHolder(Configuration conf, String vertexName) { - mockAppContext = mock(AppContext.class); - mockDag = mock(DAG.class); - doReturn(new Credentials()).when(mockDag).getCredentials(); - doReturn(mockDag).when(mockAppContext).getCurrentDAG(); - - vertexPlan = DAGProtos.VertexPlan.newBuilder() - .setName(vertexName) - .setTaskConfig(DAGProtos.PlanTaskConfiguration.newBuilder() - .setJavaOpts(initialJavaOpts) - .setNumTasks(numTasks) - .setMemoryMb(1024) - .setVirtualCores(1) - .setTaskModule("taskmodule") - .addEnvironmentSetting(DAGProtos.PlanKeyValuePair.newBuilder() - .setKey(envKey) - .setValue(envVal) - .build()) - .build()) - 
.setType(DAGProtos.PlanVertexType.NORMAL).build(); + final AppContext mockAppContext; + final VertexImpl vertex; + final VertexPlan vertexPlan; + + VertexWrapper(AppContext appContext, VertexPlan vertexPlan, Configuration conf) { + if (appContext == null) { + mockAppContext = createDefaultMockAppContext(); + DAG mockDag = mock(DAG.class); + doReturn(new Credentials()).when(mockDag).getCredentials(); + doReturn(mockDag).when(mockAppContext).getCurrentDAG(); + } else { + mockAppContext = appContext; + } + + + this.vertexPlan = vertexPlan; vertex = new VertexImpl(TezVertexID.fromString("vertex_1418197758681_0001_1_00"), vertexPlan, @@ -293,5 +486,17 @@ public class TestVertexImpl2 { VertexLocationHint.create(new LinkedList<TaskLocationHint>()), null, new TaskSpecificLaunchCmdOption(conf), mock(StateChangeNotifier.class)); } + + VertexWrapper(VertexPlan vertexPlan, Configuration conf) { + this(null, vertexPlan, conf); + } + } + + private static AppContext createDefaultMockAppContext() { + AppContext appContext = mock(AppContext.class); + DAG mockDag = mock(DAG.class); + doReturn(new Credentials()).when(mockDag).getCredentials(); + doReturn(mockDag).when(appContext).getCurrentDAG(); + return appContext; } }
http://git-wip-us.apache.org/repos/asf/tez/blob/267fe737/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherRouter.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherRouter.java b/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherRouter.java new file mode 100644 index 0000000..62a5f19 --- /dev/null +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherRouter.java @@ -0,0 +1,361 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.tez.dag.app.launcher; + + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.UnknownHostException; +import java.nio.ByteBuffer; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.records.Container; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; +import org.apache.tez.common.TezUtils; +import org.apache.tez.dag.api.NamedEntityDescriptor; +import org.apache.tez.dag.api.TezConstants; +import org.apache.tez.dag.api.UserPayload; +import org.apache.tez.dag.app.AppContext; +import org.apache.tez.dag.app.TaskAttemptListener; +import org.apache.tez.dag.app.rm.NMCommunicatorLaunchRequestEvent; +import org.apache.tez.serviceplugins.api.ContainerLaunchRequest; +import org.apache.tez.serviceplugins.api.ContainerLauncher; +import org.apache.tez.serviceplugins.api.ContainerLauncherContext; +import org.apache.tez.serviceplugins.api.ContainerStopRequest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +public class TestContainerLauncherRouter { + + @Before + @After + public void reset() { + ContainerLaucherRouterForMultipleLauncherTest.reset(); + } + + @Test(timeout = 5000) + public void testNoLaunchersSpecified() throws IOException { + + AppContext appContext = mock(AppContext.class); + TaskAttemptListener tal = mock(TaskAttemptListener.class); + + try { + + new ContainerLaucherRouterForMultipleLauncherTest(appContext, tal, null, null, + false); + 
fail("Expecting a failure without any launchers being specified"); + } catch (IllegalArgumentException e) { + + } + } + + @Test(timeout = 5000) + public void testCustomLauncherSpecified() throws IOException { + Configuration conf = new Configuration(false); + + AppContext appContext = mock(AppContext.class); + TaskAttemptListener tal = mock(TaskAttemptListener.class); + + String customLauncherName = "customLauncher"; + List<NamedEntityDescriptor> launcherDescriptors = new LinkedList<>(); + ByteBuffer bb = ByteBuffer.allocate(4); + bb.putInt(0, 3); + UserPayload customPayload = UserPayload.create(bb); + launcherDescriptors.add( + new NamedEntityDescriptor(customLauncherName, FakeContainerLauncher.class.getName()) + .setUserPayload(customPayload)); + + ContainerLaucherRouterForMultipleLauncherTest clr = + new ContainerLaucherRouterForMultipleLauncherTest(appContext, tal, null, + launcherDescriptors, + true); + try { + clr.init(conf); + clr.start(); + + assertEquals(1, clr.getNumContainerLaunchers()); + assertFalse(clr.getYarnContainerLauncherCreated()); + assertFalse(clr.getUberContainerLauncherCreated()); + assertEquals(customLauncherName, clr.getContainerLauncherName(0)); + assertEquals(bb, clr.getContainerLauncherContext(0).getInitialUserPayload().getPayload()); + } finally { + clr.stop(); + } + } + + @Test(timeout = 5000) + public void testMultipleContainerLaunchers() throws IOException { + Configuration conf = new Configuration(false); + conf.set("testkey", "testvalue"); + UserPayload userPayload = TezUtils.createUserPayloadFromConf(conf); + + AppContext appContext = mock(AppContext.class); + TaskAttemptListener tal = mock(TaskAttemptListener.class); + + String customLauncherName = "customLauncher"; + List<NamedEntityDescriptor> launcherDescriptors = new LinkedList<>(); + ByteBuffer bb = ByteBuffer.allocate(4); + bb.putInt(0, 3); + UserPayload customPayload = UserPayload.create(bb); + launcherDescriptors.add( + new NamedEntityDescriptor(customLauncherName, 
FakeContainerLauncher.class.getName()) + .setUserPayload(customPayload)); + launcherDescriptors + .add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null) + .setUserPayload(userPayload)); + + ContainerLaucherRouterForMultipleLauncherTest clr = + new ContainerLaucherRouterForMultipleLauncherTest(appContext, tal, null, + launcherDescriptors, + true); + try { + clr.init(conf); + clr.start(); + + assertEquals(2, clr.getNumContainerLaunchers()); + assertTrue(clr.getYarnContainerLauncherCreated()); + assertFalse(clr.getUberContainerLauncherCreated()); + assertEquals(customLauncherName, clr.getContainerLauncherName(0)); + assertEquals(bb, clr.getContainerLauncherContext(0).getInitialUserPayload().getPayload()); + + assertEquals(TezConstants.getTezYarnServicePluginName(), clr.getContainerLauncherName(1)); + Configuration confParsed = TezUtils + .createConfFromUserPayload(clr.getContainerLauncherContext(1).getInitialUserPayload()); + assertEquals("testvalue", confParsed.get("testkey")); + } finally { + clr.stop(); + } + } + + @Test(timeout = 5000) + public void testEventRouting() throws Exception { + Configuration conf = new Configuration(false); + UserPayload userPayload = TezUtils.createUserPayloadFromConf(conf); + + AppContext appContext = mock(AppContext.class); + TaskAttemptListener tal = mock(TaskAttemptListener.class); + + String customLauncherName = "customLauncher"; + List<NamedEntityDescriptor> launcherDescriptors = new LinkedList<>(); + ByteBuffer bb = ByteBuffer.allocate(4); + bb.putInt(0, 3); + UserPayload customPayload = UserPayload.create(bb); + launcherDescriptors.add( + new NamedEntityDescriptor(customLauncherName, FakeContainerLauncher.class.getName()) + .setUserPayload(customPayload)); + launcherDescriptors + .add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null) + .setUserPayload(userPayload)); + + ContainerLaucherRouterForMultipleLauncherTest clr = + new 
ContainerLaucherRouterForMultipleLauncherTest(appContext, tal, null, + launcherDescriptors, + true); + try { + clr.init(conf); + clr.start(); + + assertEquals(2, clr.getNumContainerLaunchers()); + assertTrue(clr.getYarnContainerLauncherCreated()); + assertFalse(clr.getUberContainerLauncherCreated()); + assertEquals(customLauncherName, clr.getContainerLauncherName(0)); + assertEquals(TezConstants.getTezYarnServicePluginName(), clr.getContainerLauncherName(1)); + + verify(clr.getTestContainerLauncher(0)).initialize(); + verify(clr.getTestContainerLauncher(0)).start(); + verify(clr.getTestContainerLauncher(1)).initialize(); + verify(clr.getTestContainerLauncher(1)).start(); + + ContainerLaunchContext clc1 = mock(ContainerLaunchContext.class); + Container container1 = mock(Container.class); + + ContainerLaunchContext clc2 = mock(ContainerLaunchContext.class); + Container container2 = mock(Container.class); + + NMCommunicatorLaunchRequestEvent launchRequestEvent1 = + new NMCommunicatorLaunchRequestEvent(clc1, container1, 0, 0, 0); + NMCommunicatorLaunchRequestEvent launchRequestEvent2 = + new NMCommunicatorLaunchRequestEvent(clc2, container2, 1, 0, 0); + + clr.handle(launchRequestEvent1); + + + ArgumentCaptor<ContainerLaunchRequest> captor = + ArgumentCaptor.forClass(ContainerLaunchRequest.class); + verify(clr.getTestContainerLauncher(0)).launchContainer(captor.capture()); + assertEquals(1, captor.getAllValues().size()); + ContainerLaunchRequest launchRequest1 = captor.getValue(); + assertEquals(clc1, launchRequest1.getContainerLaunchContext()); + + clr.handle(launchRequestEvent2); + captor = ArgumentCaptor.forClass(ContainerLaunchRequest.class); + verify(clr.getTestContainerLauncher(1)).launchContainer(captor.capture()); + assertEquals(1, captor.getAllValues().size()); + ContainerLaunchRequest launchRequest2 = captor.getValue(); + assertEquals(clc2, launchRequest2.getContainerLaunchContext()); + + } finally { + clr.stop(); + 
verify(clr.getTestContainerLauncher(0)).shutdown(); + verify(clr.getTestContainerLauncher(1)).shutdown(); + } + } + + private static class ContainerLaucherRouterForMultipleLauncherTest + extends ContainerLauncherRouter { + + // All variables setup as static since methods being overridden are invoked by the ContainerLauncherRouter ctor, + // and regular variables will not be initialized at this point. + private static final AtomicInteger numContainerLaunchers = new AtomicInteger(0); + private static final Set<Integer> containerLauncherIndices = new HashSet<>(); + private static final ContainerLauncher yarnContainerLauncher = mock(ContainerLauncher.class); + private static final ContainerLauncher uberContainerlauncher = mock(ContainerLauncher.class); + private static final AtomicBoolean yarnContainerLauncherCreated = new AtomicBoolean(false); + private static final AtomicBoolean uberContainerLauncherCreated = new AtomicBoolean(false); + + private static final List<ContainerLauncherContext> containerLauncherContexts = + new LinkedList<>(); + private static final List<String> containerLauncherNames = new LinkedList<>(); + private static final List<ContainerLauncher> testContainerLaunchers = new LinkedList<>(); + + + public static void reset() { + numContainerLaunchers.set(0); + containerLauncherIndices.clear(); + yarnContainerLauncherCreated.set(false); + uberContainerLauncherCreated.set(false); + containerLauncherContexts.clear(); + containerLauncherNames.clear(); + testContainerLaunchers.clear(); + } + + public ContainerLaucherRouterForMultipleLauncherTest(AppContext context, + TaskAttemptListener taskAttemptListener, + String workingDirectory, + List<NamedEntityDescriptor> containerLauncherDescriptors, + boolean isPureLocalMode) throws + UnknownHostException { + super(context, taskAttemptListener, workingDirectory, + containerLauncherDescriptors, isPureLocalMode); + } + + @Override + ContainerLauncher createContainerLauncher(NamedEntityDescriptor 
containerLauncherDescriptor, + AppContext context, + ContainerLauncherContext containerLauncherContext, + TaskAttemptListener taskAttemptListener, + String workingDirectory, + int containerLauncherIndex, + boolean isPureLocalMode) throws + UnknownHostException { + numContainerLaunchers.incrementAndGet(); + boolean added = containerLauncherIndices.add(containerLauncherIndex); + assertTrue("Cannot add multiple launchers with the same index", added); + containerLauncherNames.add(containerLauncherDescriptor.getEntityName()); + containerLauncherContexts.add(containerLauncherContext); + return super + .createContainerLauncher(containerLauncherDescriptor, context, containerLauncherContext, + taskAttemptListener, workingDirectory, containerLauncherIndex, isPureLocalMode); + } + + @Override + ContainerLauncher createYarnContainerLauncher( + ContainerLauncherContext containerLauncherContext) { + yarnContainerLauncherCreated.set(true); + testContainerLaunchers.add(yarnContainerLauncher); + return yarnContainerLauncher; + } + + @Override + ContainerLauncher createUberContainerLauncher(ContainerLauncherContext containerLauncherContext, + AppContext context, + TaskAttemptListener taskAttemptListener, + String workingDirectory, + boolean isPureLocalMode) throws + UnknownHostException { + uberContainerLauncherCreated.set(true); + testContainerLaunchers.add(uberContainerlauncher); + return uberContainerlauncher; + } + + @Override + ContainerLauncher createCustomContainerLauncher( + ContainerLauncherContext containerLauncherContext, + NamedEntityDescriptor containerLauncherDescriptor) { + ContainerLauncher spyLauncher = spy(super.createCustomContainerLauncher( + containerLauncherContext, containerLauncherDescriptor)); + testContainerLaunchers.add(spyLauncher); + return spyLauncher; + } + + public int getNumContainerLaunchers() { + return numContainerLaunchers.get(); + } + + public boolean getYarnContainerLauncherCreated() { + return yarnContainerLauncherCreated.get(); + } + + public 
boolean getUberContainerLauncherCreated() { + return uberContainerLauncherCreated.get(); + } + + public String getContainerLauncherName(int containerLauncherIndex) { + return containerLauncherNames.get(containerLauncherIndex); + } + + public ContainerLauncher getTestContainerLauncher(int containerLauncherIndex) { + return testContainerLaunchers.get(containerLauncherIndex); + } + + public ContainerLauncherContext getContainerLauncherContext(int containerLauncherIndex) { + return containerLauncherContexts.get(containerLauncherIndex); + } + } + + private static class FakeContainerLauncher extends ContainerLauncher { + + public FakeContainerLauncher( + ContainerLauncherContext containerLauncherContext) { + super(containerLauncherContext); + } + + @Override + public void launchContainer(ContainerLaunchRequest launchRequest) { + + } + + @Override + public void stopContainer(ContainerStopRequest stopRequest) { + + } + } + +} http://git-wip-us.apache.org/repos/asf/tez/blob/267fe737/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerEventHandler.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerEventHandler.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerEventHandler.java index f8aa1e2..3e68a4c 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerEventHandler.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerEventHandler.java @@ -19,22 +19,30 @@ package org.apache.tez.dag.app.rm; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.doReturn; import static 
org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; import java.net.InetSocketAddress; +import java.nio.ByteBuffer; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.security.Credentials; @@ -44,6 +52,7 @@ import org.apache.hadoop.yarn.api.records.ContainerExitStatus; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.api.records.ContainerStatus; import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.api.records.NodeId; import org.apache.hadoop.yarn.api.records.Priority; import org.apache.hadoop.yarn.api.records.Resource; import org.apache.hadoop.yarn.event.Event; @@ -53,12 +62,13 @@ import org.apache.tez.common.TezUtils; import org.apache.tez.dag.api.NamedEntityDescriptor; import org.apache.tez.dag.api.TaskLocationHint; import org.apache.tez.dag.api.TezConfiguration; -import org.apache.tez.dag.api.TezUncheckedException; +import org.apache.tez.dag.api.TezConstants; import org.apache.tez.dag.api.UserPayload; import org.apache.tez.dag.api.client.DAGClientServer; import org.apache.tez.dag.app.AppContext; import org.apache.tez.dag.app.ContainerContext; import org.apache.tez.dag.app.ServicePluginLifecycleAbstractService; +import org.apache.tez.dag.app.dag.TaskAttempt; import org.apache.tez.dag.app.dag.impl.TaskAttemptImpl; import org.apache.tez.dag.app.dag.impl.TaskImpl; import org.apache.tez.dag.app.dag.impl.VertexImpl; @@ -70,8 +80,14 @@ import org.apache.tez.dag.app.rm.container.AMContainerMap; import org.apache.tez.dag.app.rm.container.AMContainerState; import 
org.apache.tez.dag.app.web.WebUIService; import org.apache.tez.dag.records.TaskAttemptTerminationCause; +import org.apache.tez.dag.records.TezDAGID; import org.apache.tez.dag.records.TezTaskAttemptID; +import org.apache.tez.dag.records.TezTaskID; +import org.apache.tez.dag.records.TezVertexID; +import org.apache.tez.runtime.api.impl.TaskSpec; +import org.apache.tez.serviceplugins.api.TaskAttemptEndReason; import org.apache.tez.serviceplugins.api.TaskScheduler; +import org.apache.tez.serviceplugins.api.TaskSchedulerContext; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -95,10 +111,9 @@ public class TestTaskSchedulerEventHandler { public MockTaskSchedulerEventHandler(AppContext appContext, DAGClientServer clientService, EventHandler eventHandler, - ContainerSignatureMatcher containerSignatureMatcher, WebUIService webUI, - UserPayload defaultPayload) { + ContainerSignatureMatcher containerSignatureMatcher, WebUIService webUI) { super(appContext, clientService, eventHandler, containerSignatureMatcher, webUI, - new LinkedList<NamedEntityDescriptor>(), defaultPayload, false); + Lists.newArrayList(new NamedEntityDescriptor("FakeDescriptor", null)), false); } @Override @@ -140,14 +155,8 @@ public class TestTaskSchedulerEventHandler { when(mockAppContext.getAllContainers()).thenReturn(mockAMContainerMap); when(mockClientService.getBindAddress()).thenReturn(new InetSocketAddress(10000)); Configuration conf = new Configuration(false); - UserPayload userPayload; - try { - userPayload = TezUtils.createUserPayloadFromConf(conf); - } catch (IOException e) { - throw new TezUncheckedException(e); - } schedulerHandler = new MockTaskSchedulerEventHandler( - mockAppContext, mockClientService, mockEventHandler, mockSigMatcher, mockWebUIService, userPayload); + mockAppContext, mockClientService, mockEventHandler, mockSigMatcher, mockWebUIService); } @Test(timeout = 5000) @@ -272,7 +281,7 @@ public class TestTaskSchedulerEventHandler { 
when(mockAmContainer.getContainerLauncherIdentifier()).thenReturn(0); when(mockAmContainer.getTaskCommunicatorIdentifier()).thenReturn(0); ContainerId mockCId = mock(ContainerId.class); - verify(mockTaskScheduler, times(0)).deallocateContainer((ContainerId)any()); + verify(mockTaskScheduler, times(0)).deallocateContainer((ContainerId) any()); when(mockAMContainerMap.get(mockCId)).thenReturn(mockAmContainer); schedulerHandler.preemptContainer(0, mockCId); verify(mockTaskScheduler, times(1)).deallocateContainer(mockCId); @@ -400,5 +409,300 @@ public class TestTaskSchedulerEventHandler { } - // TODO TEZ-2003. Add tests with multiple schedulers, and ensuring that events go out with correct IDs. + @Test(timeout = 5000) + public void testNoSchedulerSpecified() throws IOException { + try { + TSEHForMultipleSchedulersTest tseh = + new TSEHForMultipleSchedulersTest(mockAppContext, mockClientService, mockEventHandler, + mockSigMatcher, mockWebUIService, null, false); + fail("Expecting an IllegalArgumentException with no schedulers specified"); + } catch (IllegalArgumentException expected) { // a null scheduler descriptor list must be rejected + } + } + + // Verified via the state recorded by TSEHForMultipleSchedulersTest + @Test(timeout = 5000) + public void testCustomTaskSchedulerSetup() throws IOException { + Configuration conf = new Configuration(false); + conf.set("testkey", "testval"); + UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf); + + String customSchedulerName = "fakeScheduler"; + List<NamedEntityDescriptor> taskSchedulers = new LinkedList<>(); + ByteBuffer bb = ByteBuffer.allocate(4); + bb.putInt(0, 3); + UserPayload userPayload = UserPayload.create(bb); + taskSchedulers.add( + new NamedEntityDescriptor(customSchedulerName, FakeTaskScheduler.class.getName()) + .setUserPayload(userPayload)); + taskSchedulers.add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null) + .setUserPayload(defaultPayload)); + + TSEHForMultipleSchedulersTest tseh = + new TSEHForMultipleSchedulersTest(mockAppContext, mockClientService,
mockEventHandler, + mockSigMatcher, mockWebUIService, taskSchedulers, false); + + tseh.init(conf); + tseh.start(); + + // Verify that the YARN task scheduler is installed (it is explicitly configured above) + assertTrue(tseh.getYarnSchedulerCreated()); + assertFalse(tseh.getUberSchedulerCreated()); + assertEquals(2, tseh.getNumCreateInvocations()); + + // Verify the order of the schedulers + assertEquals(customSchedulerName, tseh.getTaskSchedulerName(0)); + assertEquals(TezConstants.getTezYarnServicePluginName(), tseh.getTaskSchedulerName(1)); + + // Verify the payload setup for the custom task scheduler + assertNotNull(tseh.getTaskSchedulerContext(0)); + assertEquals(bb, tseh.getTaskSchedulerContext(0).getInitialUserPayload().getPayload()); + + // Verify the payload on the yarn scheduler + assertNotNull(tseh.getTaskSchedulerContext(1)); + Configuration parsed = TezUtils.createConfFromUserPayload(tseh.getTaskSchedulerContext(1).getInitialUserPayload()); + assertEquals("testval", parsed.get("testkey")); + } + + @Test(timeout = 5000) + public void testTaskSchedulerRouting() throws Exception { + Configuration conf = new Configuration(false); + UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf); + + String customSchedulerName = "fakeScheduler"; + List<NamedEntityDescriptor> taskSchedulers = new LinkedList<>(); + ByteBuffer bb = ByteBuffer.allocate(4); + bb.putInt(0, 3); + UserPayload userPayload = UserPayload.create(bb); + taskSchedulers.add( + new NamedEntityDescriptor(customSchedulerName, FakeTaskScheduler.class.getName()) + .setUserPayload(userPayload)); + taskSchedulers.add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null) + .setUserPayload(defaultPayload)); + + TSEHForMultipleSchedulersTest tseh = + new TSEHForMultipleSchedulersTest(mockAppContext, mockClientService, mockEventHandler, + mockSigMatcher, mockWebUIService, taskSchedulers, false); + + tseh.init(conf); + tseh.start(); + + // Verify that the YARN task scheduler is installed (it is explicitly configured above)
+ assertTrue(tseh.getYarnSchedulerCreated()); + assertFalse(tseh.getUberSchedulerCreated()); + assertEquals(2, tseh.getNumCreateInvocations()); + + // Verify the order of the schedulers + assertEquals(customSchedulerName, tseh.getTaskSchedulerName(0)); + assertEquals(TezConstants.getTezYarnServicePluginName(), tseh.getTaskSchedulerName(1)); + + verify(tseh.getTestTaskScheduler(0)).initialize(); + verify(tseh.getTestTaskScheduler(0)).start(); + + ApplicationId appId = ApplicationId.newInstance(1000, 1); + TezDAGID dagId = TezDAGID.getInstance(appId, 1); + TezVertexID vertexID = TezVertexID.getInstance(dagId, 1); + TezTaskID taskId1 = TezTaskID.getInstance(vertexID, 1); + TezTaskAttemptID attemptId11 = TezTaskAttemptID.getInstance(taskId1, 1); + TezTaskID taskId2 = TezTaskID.getInstance(vertexID, 2); + TezTaskAttemptID attemptId21 = TezTaskAttemptID.getInstance(taskId2, 1); + + Resource resource = Resource.newInstance(1024, 1); + + TaskAttempt mockTaskAttempt1 = mock(TaskAttempt.class); + TaskAttempt mockTaskAttempt2 = mock(TaskAttempt.class); + + AMSchedulerEventTALaunchRequest launchRequest1 = + new AMSchedulerEventTALaunchRequest(attemptId11, resource, mock(TaskSpec.class), + mockTaskAttempt1, mock(TaskLocationHint.class), 1, mock(ContainerContext.class), 0, 0, + 0); + + tseh.handle(launchRequest1); + + verify(tseh.getTestTaskScheduler(0)).allocateTask(eq(mockTaskAttempt1), eq(resource), + any(String[].class), any(String[].class), any(Priority.class), any(Object.class), + eq(launchRequest1)); + + AMSchedulerEventTALaunchRequest launchRequest2 = + new AMSchedulerEventTALaunchRequest(attemptId21, resource, mock(TaskSpec.class), + mockTaskAttempt2, mock(TaskLocationHint.class), 1, mock(ContainerContext.class), 1, 0, + 0); + tseh.handle(launchRequest2); + verify(tseh.getTestTaskScheduler(1)).allocateTask(eq(mockTaskAttempt2), eq(resource), + any(String[].class), any(String[].class), any(Priority.class), any(Object.class), + eq(launchRequest2)); + } + + private
static class TSEHForMultipleSchedulersTest extends TaskSchedulerEventHandler { + + private final TaskScheduler yarnTaskScheduler; + private final TaskScheduler uberTaskScheduler; + private final AtomicBoolean uberSchedulerCreated = new AtomicBoolean(false); + private final AtomicBoolean yarnSchedulerCreated = new AtomicBoolean(false); + private final AtomicInteger numCreateInvocations = new AtomicInteger(0); + private final Set<Integer> seenSchedulers = new HashSet<>(); + private final List<TaskSchedulerContext> taskSchedulerContexts = new LinkedList<>(); + private final List<String> taskSchedulerNames = new LinkedList<>(); + private final List<TaskScheduler> testTaskSchedulers = new LinkedList<>(); // indexed in creation order + + public TSEHForMultipleSchedulersTest(AppContext appContext, + DAGClientServer clientService, + EventHandler eventHandler, + ContainerSignatureMatcher containerSignatureMatcher, + WebUIService webUI, + List<NamedEntityDescriptor> schedulerDescriptors, + boolean isPureLocalMode) { + super(appContext, clientService, eventHandler, containerSignatureMatcher, webUI, + schedulerDescriptors, isPureLocalMode); + yarnTaskScheduler = mock(TaskScheduler.class); + uberTaskScheduler = mock(TaskScheduler.class); + } + + @Override + TaskScheduler createTaskScheduler(String host, int port, String trackingUrl, + AppContext appContext, + NamedEntityDescriptor taskSchedulerDescriptor, + long customAppIdIdentifier, + int schedulerId) { + + numCreateInvocations.incrementAndGet(); + boolean added = seenSchedulers.add(schedulerId); + assertTrue("Cannot add multiple schedulers with the same schedulerId", added); + taskSchedulerNames.add(taskSchedulerDescriptor.getEntityName()); + return super.createTaskScheduler(host, port, trackingUrl, appContext, taskSchedulerDescriptor, + customAppIdIdentifier, schedulerId); + } + + @Override + TaskSchedulerContext wrapTaskSchedulerContext(TaskSchedulerContext rawContext) { + // Avoid wrapping in threads + return rawContext; + } + + @Override +
TaskScheduler createYarnTaskScheduler(TaskSchedulerContext taskSchedulerContext, int schedulerId) { + taskSchedulerContexts.add(taskSchedulerContext); + testTaskSchedulers.add(yarnTaskScheduler); + yarnSchedulerCreated.set(true); + return yarnTaskScheduler; + } + + @Override + TaskScheduler createUberTaskScheduler(TaskSchedulerContext taskSchedulerContext, int schedulerId) { + taskSchedulerContexts.add(taskSchedulerContext); + uberSchedulerCreated.set(true); + testTaskSchedulers.add(uberTaskScheduler); // record the uber mock (not the YARN one) so getTestTaskScheduler returns what was created + return uberTaskScheduler; + } + + @Override + TaskScheduler createCustomTaskScheduler(TaskSchedulerContext taskSchedulerContext, + NamedEntityDescriptor taskSchedulerDescriptor, int schedulerId) { + taskSchedulerContexts.add(taskSchedulerContext); + TaskScheduler taskScheduler = spy(super.createCustomTaskScheduler(taskSchedulerContext, taskSchedulerDescriptor, schedulerId)); + testTaskSchedulers.add(taskScheduler); + return taskScheduler; + } + + @Override + // Inline handling of events.
+ public void handle(AMSchedulerEvent event) { + handleEvent(event); + } + + public boolean getUberSchedulerCreated() { + return uberSchedulerCreated.get(); + } + + public boolean getYarnSchedulerCreated() { + return yarnSchedulerCreated.get(); + } + + public int getNumCreateInvocations() { + return numCreateInvocations.get(); + } + + public TaskSchedulerContext getTaskSchedulerContext(int schedulerId) { + return taskSchedulerContexts.get(schedulerId); + } + + public String getTaskSchedulerName(int schedulerId) { + return taskSchedulerNames.get(schedulerId); + } + + public TaskScheduler getTestTaskScheduler(int schedulerId) { + return testTaskSchedulers.get(schedulerId); + } + } + + public static class FakeTaskScheduler extends TaskScheduler { + + public FakeTaskScheduler( + TaskSchedulerContext taskSchedulerContext) { + super(taskSchedulerContext); + } + + @Override + public Resource getAvailableResources() { + return null; + } + + @Override + public int getClusterNodeCount() { + return 0; + } + + @Override + public void dagComplete() { + + } + + @Override + public Resource getTotalResources() { + return null; + } + + @Override + public void blacklistNode(NodeId nodeId) { + + } + + @Override + public void unblacklistNode(NodeId nodeId) { + + } + + @Override + public void allocateTask(Object task, Resource capability, String[] hosts, String[] racks, + Priority priority, Object containerSignature, Object clientCookie) { + + } + + @Override + public void allocateTask(Object task, Resource capability, ContainerId containerId, + Priority priority, Object containerSignature, Object clientCookie) { + + } + + @Override + public boolean deallocateTask(Object task, boolean taskSucceeded, + TaskAttemptEndReason endReason) { + return false; + } + + @Override + public Object deallocateContainer(ContainerId containerId) { + return null; + } + + @Override + public void setShouldUnregister() { + + } + + @Override + public boolean hasUnregistered() { + return false; + } + } }
http://git-wip-us.apache.org/repos/asf/tez/blob/267fe737/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerHelpers.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerHelpers.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerHelpers.java index 59ab00a..0746507 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerHelpers.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerHelpers.java @@ -42,6 +42,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.service.AbstractService; @@ -138,7 +139,8 @@ class TestTaskSchedulerHelpers { ContainerSignatureMatcher containerSignatureMatcher, UserPayload defaultPayload) { super(appContext, null, eventHandler, containerSignatureMatcher, null, - new LinkedList<NamedEntityDescriptor>(), defaultPayload, false); + Lists.newArrayList(new NamedEntityDescriptor("FakeScheduler", null)), + false); this.amrmClientAsync = amrmClientAsync; this.containerSignatureMatcher = containerSignatureMatcher; this.defaultPayload = defaultPayload;
