You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tez.apache.org by ss...@apache.org on 2015/08/06 10:04:52 UTC
[1/2] tez git commit: TEZ-2126. Add unit tests for verifying multiple
schedulers, launchers, communicators. (sseth)
Repository: tez
Updated Branches:
refs/heads/TEZ-2003 0026ebecd -> bd6fcf95d
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl2.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl2.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl2.java
index b4064a0..352ad87 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl2.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl2.java
@@ -28,17 +28,23 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
+import com.google.common.collect.BiMap;
+import com.google.common.collect.HashBiMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.hadoop.yarn.util.Clock;
+import org.apache.tez.dag.api.DagTypeConverters;
import org.apache.tez.dag.api.TaskLocationHint;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.TezConstants;
+import org.apache.tez.dag.api.Vertex;
+import org.apache.tez.dag.api.Vertex.VertexExecutionContext;
import org.apache.tez.dag.api.VertexLocationHint;
import org.apache.tez.dag.api.records.DAGProtos;
+import org.apache.tez.dag.api.records.DAGProtos.VertexPlan;
import org.apache.tez.dag.app.AppContext;
import org.apache.tez.dag.app.ContainerContext;
import org.apache.tez.dag.app.TaskAttemptListener;
@@ -47,6 +53,7 @@ import org.apache.tez.dag.app.dag.DAG;
import org.apache.tez.dag.app.dag.StateChangeNotifier;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.dag.utils.TaskSpecificLaunchCmdOption;
+import org.apache.tez.runtime.api.ExecutionContext;
import org.junit.Test;
/**
@@ -60,7 +67,8 @@ public class TestVertexImpl2 {
Configuration conf = new TezConfiguration();
conf.set(TezConfiguration.TEZ_TASK_LOG_LEVEL, "DEBUG;org.apache.hadoop.ipc=INFO;org.apache.hadoop.server=INFO");
- LogTestInfoHolder testInfo = new LogTestInfoHolder(conf);
+ LogTestInfoHolder testInfo = new LogTestInfoHolder();
+ VertexWrapper vertexWrapper = createVertexWrapperForLogTests(testInfo, conf);
List<String> expectedCommands = new LinkedList<String>();
expectedCommands.add("-Dlog4j.configuratorClass=org.apache.tez.common.TezLog4jConfigurator");
@@ -71,7 +79,8 @@ public class TestVertexImpl2 {
TezConstants.TEZ_CONTAINER_LOGGER_NAME);
for (int i = 0 ; i < testInfo.numTasks ; i++) {
- ContainerContext containerContext = testInfo.vertex.getContainerContext(i);
+ ContainerContext containerContext = vertexWrapper
+ .vertex.getContainerContext(i);
String javaOpts = containerContext.getJavaOpts();
assertTrue(javaOpts.contains(testInfo.initialJavaOpts));
for (String expectedCmd : expectedCommands) {
@@ -92,7 +101,8 @@ public class TestVertexImpl2 {
Configuration conf = new TezConfiguration();
conf.set(TezConfiguration.TEZ_TASK_LOG_LEVEL, "DEBUG");
- LogTestInfoHolder testInfo = new LogTestInfoHolder(conf);
+ LogTestInfoHolder testInfo = new LogTestInfoHolder();
+ VertexWrapper vertexWrapper = createVertexWrapperForLogTests(testInfo, conf);
List<String> expectedCommands = new LinkedList<String>();
expectedCommands.add("-Dlog4j.configuratorClass=org.apache.tez.common.TezLog4jConfigurator");
@@ -103,7 +113,7 @@ public class TestVertexImpl2 {
TezConstants.TEZ_CONTAINER_LOGGER_NAME);
for (int i = 0 ; i < testInfo.numTasks ; i++) {
- ContainerContext containerContext = testInfo.vertex.getContainerContext(i);
+ ContainerContext containerContext = vertexWrapper.vertex.getContainerContext(i);
String javaOpts = containerContext.getJavaOpts();
assertTrue(javaOpts.contains(testInfo.initialJavaOpts));
for (String expectedCmd : expectedCommands) {
@@ -130,7 +140,8 @@ public class TestVertexImpl2 {
conf.set(TezConfiguration.TEZ_TASK_SPECIFIC_LOG_LEVEL, "DEBUG;org.apache.tez=INFO");
conf.set(TezConfiguration.TEZ_TASK_SPECIFIC_LAUNCH_CMD_OPTS, customJavaOpts);
- LogTestInfoHolder testInfo = new LogTestInfoHolder(conf);
+ LogTestInfoHolder testInfo = new LogTestInfoHolder();
+ VertexWrapper vertexWrapper = createVertexWrapperForLogTests(testInfo, conf);
// Expected command opts for regular tasks
List<String> expectedCommands = new LinkedList<String>();
@@ -142,7 +153,7 @@ public class TestVertexImpl2 {
TezConstants.TEZ_CONTAINER_LOGGER_NAME);
for (int i = 3 ; i < testInfo.numTasks ; i++) {
- ContainerContext containerContext = testInfo.vertex.getContainerContext(i);
+ ContainerContext containerContext = vertexWrapper.vertex.getContainerContext(i);
String javaOpts = containerContext.getJavaOpts();
assertTrue(javaOpts.contains(testInfo.initialJavaOpts));
@@ -167,7 +178,7 @@ public class TestVertexImpl2 {
TezConstants.TEZ_CONTAINER_LOGGER_NAME);
for (int i = 0 ; i < 3 ; i++) {
- ContainerContext containerContext = testInfo.vertex.getContainerContext(i);
+ ContainerContext containerContext = vertexWrapper.vertex.getContainerContext(i);
String javaOpts = containerContext.getJavaOpts();
assertTrue(javaOpts.contains(testInfo.initialJavaOpts));
@@ -195,7 +206,8 @@ public class TestVertexImpl2 {
conf.set(TezConfiguration.TEZ_TASK_SPECIFIC_LOG_LEVEL, "DEBUG");
conf.set(TezConfiguration.TEZ_TASK_SPECIFIC_LAUNCH_CMD_OPTS, customJavaOpts);
- LogTestInfoHolder testInfo = new LogTestInfoHolder(conf);
+ LogTestInfoHolder testInfo = new LogTestInfoHolder();
+ VertexWrapper vertexWrapper = createVertexWrapperForLogTests(testInfo, conf);
// Expected command opts for regular tasks
List<String> expectedCommands = new LinkedList<String>();
@@ -207,7 +219,7 @@ public class TestVertexImpl2 {
TezConstants.TEZ_CONTAINER_LOGGER_NAME);
for (int i = 3 ; i < testInfo.numTasks ; i++) {
- ContainerContext containerContext = testInfo.vertex.getContainerContext(i);
+ ContainerContext containerContext = vertexWrapper.vertex.getContainerContext(i);
String javaOpts = containerContext.getJavaOpts();
assertTrue(javaOpts.contains(testInfo.initialJavaOpts));
@@ -232,7 +244,7 @@ public class TestVertexImpl2 {
TezConstants.TEZ_CONTAINER_LOGGER_NAME);
for (int i = 0 ; i < 3 ; i++) {
- ContainerContext containerContext = testInfo.vertex.getContainerContext(i);
+ ContainerContext containerContext = vertexWrapper.vertex.getContainerContext(i);
String javaOpts = containerContext.getJavaOpts();
assertTrue(javaOpts.contains(testInfo.initialJavaOpts));
@@ -248,43 +260,224 @@ public class TestVertexImpl2 {
}
}
+ /**
+  * Builds a vertex with neither a vertex-level nor a DAG-level default execution context
+  * (both constructor arguments are null) and asserts that the task scheduler, container
+  * launcher and task communicator identifiers on the vertex are all 0.
+  */
+ @Test(timeout = 5000)
+ public void testNullExecutionContexts() {
- private static class LogTestInfoHolder {
+ ExecutionContextTestInfoHolder info = new ExecutionContextTestInfoHolder(null, null);
+ VertexWrapper vertexWrapper = createVertexWrapperForExecutionContextTest(info);
- final AppContext mockAppContext;
- final DAG mockDag;
- final VertexImpl vertex;
- final DAGProtos.VertexPlan vertexPlan;
+ assertEquals(0, vertexWrapper.vertex.taskSchedulerIdentifier);
+ assertEquals(0, vertexWrapper.vertex.containerLauncherIdentifier);
+ assertEquals(0, vertexWrapper.vertex.taskCommunicatorIdentifier);
+ }
+
+  /**
+   * Supplies only a DAG-level default execution context (scheduler 0, launcher 2,
+   * communicator 2, with 3 plugins of each kind registered) and no vertex-level context,
+   * then asserts that the vertex ends up with exactly those identifiers.
+   */
+  @Test(timeout = 5000)
+  public void testDefaultExecContextViaDag() {
+    String schedulerName = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.TASK_SCHEDULER_NAME_BASE, 0);
+    String launcherName = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.CONTAINER_LAUNCHER_NAME_BASE, 2);
+    String communicatorName = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.TASK_COMM_NAME_BASE, 2);
+    VertexExecutionContext dagDefault =
+        VertexExecutionContext.create(schedulerName, launcherName, communicatorName);
+
+    ExecutionContextTestInfoHolder holder =
+        new ExecutionContextTestInfoHolder(null, dagDefault, 3);
+    VertexWrapper wrapper = createVertexWrapperForExecutionContextTest(holder);
+
+    assertEquals(0, wrapper.vertex.taskSchedulerIdentifier);
+    assertEquals(2, wrapper.vertex.containerLauncherIdentifier);
+    assertEquals(2, wrapper.vertex.taskCommunicatorIdentifier);
+  }
+
+  /**
+   * Supplies only a vertex-level execution context (plugin index 1 for scheduler, launcher
+   * and communicator, with 3 plugins of each kind registered) and no DAG default, then
+   * asserts that all three identifiers on the vertex are 1.
+   */
+  @Test(timeout = 5000)
+  public void testVertexExecutionContextOnly() {
+    String schedulerName = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.TASK_SCHEDULER_NAME_BASE, 1);
+    String launcherName = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.CONTAINER_LAUNCHER_NAME_BASE, 1);
+    String communicatorName = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.TASK_COMM_NAME_BASE, 1);
+    VertexExecutionContext vertexContext =
+        VertexExecutionContext.create(schedulerName, launcherName, communicatorName);
+
+    ExecutionContextTestInfoHolder holder =
+        new ExecutionContextTestInfoHolder(vertexContext, null, 3);
+    VertexWrapper wrapper = createVertexWrapperForExecutionContextTest(holder);
+
+    assertEquals(1, wrapper.vertex.taskSchedulerIdentifier);
+    assertEquals(1, wrapper.vertex.containerLauncherIdentifier);
+    assertEquals(1, wrapper.vertex.taskCommunicatorIdentifier);
+  }
+
+  /**
+   * Supplies both a DAG-level default context (indices 0/2/2) and a vertex-level context
+   * (indices 1/1/1) and asserts that the vertex reports the vertex-level identifiers,
+   * i.e. the vertex-level context takes precedence over the DAG default.
+   */
+  @Test(timeout = 5000)
+  public void testVertexExecutionContextOverride() {
+    String dagScheduler = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.TASK_SCHEDULER_NAME_BASE, 0);
+    String dagLauncher = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.CONTAINER_LAUNCHER_NAME_BASE, 2);
+    String dagCommunicator = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.TASK_COMM_NAME_BASE, 2);
+    VertexExecutionContext dagDefault =
+        VertexExecutionContext.create(dagScheduler, dagLauncher, dagCommunicator);
+
+    String vertexScheduler = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.TASK_SCHEDULER_NAME_BASE, 1);
+    String vertexLauncher = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.CONTAINER_LAUNCHER_NAME_BASE, 1);
+    String vertexCommunicator = ExecutionContextTestInfoHolder.append(
+        ExecutionContextTestInfoHolder.TASK_COMM_NAME_BASE, 1);
+    VertexExecutionContext vertexContext =
+        VertexExecutionContext.create(vertexScheduler, vertexLauncher, vertexCommunicator);
+
+    ExecutionContextTestInfoHolder holder =
+        new ExecutionContextTestInfoHolder(vertexContext, dagDefault, 3);
+    VertexWrapper wrapper = createVertexWrapperForExecutionContextTest(holder);
+
+    assertEquals(1, wrapper.vertex.taskSchedulerIdentifier);
+    assertEquals(1, wrapper.vertex.containerLauncherIdentifier);
+    assertEquals(1, wrapper.vertex.taskCommunicatorIdentifier);
+  }
+
+
+  /**
+   * Fixture for the execution-context tests. Holds the vertex-level and DAG-level execution
+   * contexts under test, registers named task schedulers, container launchers and task
+   * communicators, and stubs a mocked {@link AppContext} so that name-to-identifier and
+   * identifier-to-name lookups resolve for every registered plugin.
+   */
+  private static class ExecutionContextTestInfoHolder {
+
+    static final String TASK_SCHEDULER_NAME_BASE = "TASK_SCHEDULER";
+    static final String CONTAINER_LAUNCHER_NAME_BASE = "CONTAINER_LAUNCHER";
+    static final String TASK_COMM_NAME_BASE = "TASK_COMMUNICATOR";
+
+    /** Returns the plugin name formed by appending {@code index} to {@code base}. */
+    static String append(String base, int index) {
+      return base + index;
+    }
+
+    final String vertexName;
+    final VertexExecutionContext defaultExecutionContext;
+    final VertexExecutionContext vertexExecutionContext;
+    final BiMap<String, Integer> taskSchedulers = HashBiMap.create();
+    final BiMap<String, Integer> containerLaunchers = HashBiMap.create();
+    final BiMap<String, Integer> taskComms = HashBiMap.create();
+    final AppContext appContext;
+
+    /**
+     * Creates a holder with no custom plugins; only the default YARN service plugin name is
+     * registered (at index 0) for each plugin kind.
+     */
+    public ExecutionContextTestInfoHolder(VertexExecutionContext vertexExecutionContext,
+                                          VertexExecutionContext defaultDagExecutionContext) {
+      this(vertexExecutionContext, defaultDagExecutionContext, 0);
+    }
+
+    /**
+     * @param vertexExecutionContext     context set on the vertex plan, may be null
+     * @param defaultDagExecutionContext context returned by the mocked DAG as its default,
+     *                                   may be null
+     * @param numPlugins                 number of custom plugins of each kind to register
+     *                                   (named base + index); 0 registers only the default
+     *                                   YARN service plugin name at index 0
+     */
+    public ExecutionContextTestInfoHolder(VertexExecutionContext vertexExecutionContext,
+                                          VertexExecutionContext defaultDagExecutionContext,
+                                          int numPlugins) {
+      this.vertexName = "testvertex";
+      this.vertexExecutionContext = vertexExecutionContext;
+      this.defaultExecutionContext = defaultDagExecutionContext;
+      if (numPlugins == 0) {
+        String defaultPluginName = TezConstants.getTezYarnServicePluginName();
+        this.taskSchedulers.put(defaultPluginName, 0);
+        this.containerLaunchers.put(defaultPluginName, 0);
+        // Previously this line re-registered the scheduler and left taskComms empty, so the
+        // communicator lookups below were never stubbed for the default plugin.
+        this.taskComms.put(defaultPluginName, 0);
+      } else {
+        for (int i = 0; i < numPlugins; i++) {
+          this.taskSchedulers.put(append(TASK_SCHEDULER_NAME_BASE, i), i);
+          this.containerLaunchers.put(append(CONTAINER_LAUNCHER_NAME_BASE, i), i);
+          this.taskComms.put(append(TASK_COMM_NAME_BASE, i), i);
+        }
+      }
+
+      this.appContext = createDefaultMockAppContext();
+      DAG dag = appContext.getCurrentDAG();
+      doReturn(defaultDagExecutionContext).when(dag).getDefaultExecutionContext();
+      // Stub lookups in both directions (identifier -> name and name -> identifier).
+      for (Map.Entry<String, Integer> entry : taskSchedulers.entrySet()) {
+        doReturn(entry.getKey()).when(appContext).getTaskSchedulerName(entry.getValue());
+        doReturn(entry.getValue()).when(appContext).getTaskScheduerIdentifier(entry.getKey());
+      }
+      for (Map.Entry<String, Integer> entry : containerLaunchers.entrySet()) {
+        doReturn(entry.getKey()).when(appContext).getContainerLauncherName(entry.getValue());
+        doReturn(entry.getValue()).when(appContext).getContainerLauncherIdentifier(entry.getKey());
+      }
+      for (Map.Entry<String, Integer> entry : taskComms.entrySet()) {
+        doReturn(entry.getKey()).when(appContext).getTaskCommunicatorName(entry.getValue());
+        doReturn(entry.getValue()).when(appContext).getTaskCommunicatorIdentifier(entry.getKey());
+      }
+    }
+  }
+  /**
+   * Builds a {@link VertexWrapper} from the holder's mocked app context and a vertex plan
+   * derived from the holder, using a Configuration created without default resources.
+   */
+  private VertexWrapper createVertexWrapperForExecutionContextTest(
+      ExecutionContextTestInfoHolder vertexInfo) {
+    VertexPlan plan = createVertexPlanForExeuctionContextTests(vertexInfo);
+    return new VertexWrapper(vertexInfo.appContext, plan, new Configuration(false));
+  }
+
+  /**
+   * Creates a NORMAL vertex plan named after the holder's vertex with a fixed task
+   * configuration (10 tasks, 1024 MB, 1 vcore). The execution context is set on the plan
+   * only when the holder carries a vertex-level context.
+   */
+  private VertexPlan createVertexPlanForExeuctionContextTests(ExecutionContextTestInfoHolder info) {
+    DAGProtos.PlanTaskConfiguration taskConfig = DAGProtos.PlanTaskConfiguration.newBuilder()
+        .setNumTasks(10)
+        .setJavaOpts("dontcare")
+        .setMemoryMb(1024)
+        .setVirtualCores(1)
+        .setTaskModule("taskmodule")
+        .build();
+
+    VertexPlan.Builder planBuilder = VertexPlan.newBuilder();
+    planBuilder.setName(info.vertexName);
+    planBuilder.setTaskConfig(taskConfig);
+    planBuilder.setType(DAGProtos.PlanVertexType.NORMAL);
+
+    VertexExecutionContext vertexContext = info.vertexExecutionContext;
+    if (vertexContext != null) {
+      planBuilder.setExecutionContext(DagTypeConverters.convertToProto(vertexContext));
+    }
+    return planBuilder.build();
+  }
+
+ /**
+  * Plain value holder for the log-level tests: fixed task count, initial java opts and one
+  * environment key/value pair, plus the vertex name. After this change it no longer builds
+  * the mocks or the vertex itself.
+  */
+ private static class LogTestInfoHolder {
final int numTasks = 10;
final String initialJavaOpts = "initialJavaOpts";
final String envKey = "key1";
final String envVal = "val1";
+ final String vertexName;
+
+ // Uses the default vertex name "testvertex".
+ public LogTestInfoHolder() {
+ this("testvertex");
+ }
- LogTestInfoHolder(Configuration conf) {
- this(conf, "testvertex");
+ public LogTestInfoHolder(String vertexName) {
+ this.vertexName = vertexName;
}
+ }
+
+  /**
+   * Builds a {@link VertexWrapper} for the log-level tests from a plan derived from the
+   * holder and the supplied configuration; no app context is passed in.
+   */
+  private VertexWrapper createVertexWrapperForLogTests(LogTestInfoHolder logTestInfoHolder,
+                                                       Configuration conf) {
+    return new VertexWrapper(createVertexPlanForLogTests(logTestInfoHolder), conf);
+  }
+
+  /**
+   * Creates a NORMAL vertex plan from the holder's values: vertex name, initial java opts,
+   * task count and a single environment setting (1024 MB, 1 vcore are fixed).
+   */
+  private VertexPlan createVertexPlanForLogTests(LogTestInfoHolder logTestInfoHolder) {
+    DAGProtos.PlanKeyValuePair envSetting = DAGProtos.PlanKeyValuePair.newBuilder()
+        .setKey(logTestInfoHolder.envKey)
+        .setValue(logTestInfoHolder.envVal)
+        .build();
+
+    DAGProtos.PlanTaskConfiguration taskConfig = DAGProtos.PlanTaskConfiguration.newBuilder()
+        .setJavaOpts(logTestInfoHolder.initialJavaOpts)
+        .setNumTasks(logTestInfoHolder.numTasks)
+        .setMemoryMb(1024)
+        .setVirtualCores(1)
+        .setTaskModule("taskmodule")
+        .addEnvironmentSetting(envSetting)
+        .build();
+
+    return VertexPlan.newBuilder()
+        .setName(logTestInfoHolder.vertexName)
+        .setTaskConfig(taskConfig)
+        .setType(DAGProtos.PlanVertexType.NORMAL)
+        .build();
+  }
+
+ private static class VertexWrapper {
- LogTestInfoHolder(Configuration conf, String vertexName) {
- mockAppContext = mock(AppContext.class);
- mockDag = mock(DAG.class);
- doReturn(new Credentials()).when(mockDag).getCredentials();
- doReturn(mockDag).when(mockAppContext).getCurrentDAG();
-
- vertexPlan = DAGProtos.VertexPlan.newBuilder()
- .setName(vertexName)
- .setTaskConfig(DAGProtos.PlanTaskConfiguration.newBuilder()
- .setJavaOpts(initialJavaOpts)
- .setNumTasks(numTasks)
- .setMemoryMb(1024)
- .setVirtualCores(1)
- .setTaskModule("taskmodule")
- .addEnvironmentSetting(DAGProtos.PlanKeyValuePair.newBuilder()
- .setKey(envKey)
- .setValue(envVal)
- .build())
- .build())
- .setType(DAGProtos.PlanVertexType.NORMAL).build();
+ final AppContext mockAppContext;
+ final VertexImpl vertex;
+ final VertexPlan vertexPlan;
+
+ VertexWrapper(AppContext appContext, VertexPlan vertexPlan, Configuration conf) {
+ if (appContext == null) {
+ mockAppContext = createDefaultMockAppContext();
+ DAG mockDag = mock(DAG.class);
+ doReturn(new Credentials()).when(mockDag).getCredentials();
+ doReturn(mockDag).when(mockAppContext).getCurrentDAG();
+ } else {
+ mockAppContext = appContext;
+ }
+
+
+ this.vertexPlan = vertexPlan;
vertex =
new VertexImpl(TezVertexID.fromString("vertex_1418197758681_0001_1_00"), vertexPlan,
@@ -293,5 +486,17 @@ public class TestVertexImpl2 {
VertexLocationHint.create(new LinkedList<TaskLocationHint>()), null,
new TaskSpecificLaunchCmdOption(conf), mock(StateChangeNotifier.class));
}
+
+ VertexWrapper(VertexPlan vertexPlan, Configuration conf) {
+ this(null, vertexPlan, conf);
+ }
+ }
+
+  /**
+   * Returns a mocked {@link AppContext} whose current DAG is a mocked {@link DAG} that
+   * returns a fresh, empty {@link Credentials} instance.
+   */
+  private static AppContext createDefaultMockAppContext() {
+    DAG dag = mock(DAG.class);
+    doReturn(new Credentials()).when(dag).getCredentials();
+
+    AppContext context = mock(AppContext.class);
+    doReturn(dag).when(context).getCurrentDAG();
+    return context;
}
}
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherRouter.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherRouter.java b/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherRouter.java
new file mode 100644
index 0000000..62a5f19
--- /dev/null
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherRouter.java
@@ -0,0 +1,361 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.launcher;
+
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+
+import java.io.IOException;
+import java.net.UnknownHostException;
+import java.nio.ByteBuffer;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.records.Container;
+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.NamedEntityDescriptor;
+import org.apache.tez.dag.api.TezConstants;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.app.AppContext;
+import org.apache.tez.dag.app.TaskAttemptListener;
+import org.apache.tez.dag.app.rm.NMCommunicatorLaunchRequestEvent;
+import org.apache.tez.serviceplugins.api.ContainerLaunchRequest;
+import org.apache.tez.serviceplugins.api.ContainerLauncher;
+import org.apache.tez.serviceplugins.api.ContainerLauncherContext;
+import org.apache.tez.serviceplugins.api.ContainerStopRequest;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+
+public class TestContainerLauncherRouter {
+
+ /**
+  * Runs both before and after every test: delegates to the router test double's static
+  * reset so state kept in its static fields does not carry over between tests.
+  */
+ @Before
+ @After
+ public void reset() {
+ ContainerLaucherRouterForMultipleLauncherTest.reset();
+ }
+
+  /**
+   * Constructing the router with a null launcher descriptor list must be rejected with an
+   * {@link IllegalArgumentException}.
+   */
+  @Test(timeout = 5000)
+  public void testNoLaunchersSpecified() throws IOException {
+    AppContext mockContext = mock(AppContext.class);
+    TaskAttemptListener mockListener = mock(TaskAttemptListener.class);
+
+    try {
+      new ContainerLaucherRouterForMultipleLauncherTest(mockContext, mockListener, null, null,
+          false);
+      fail("Expecting a failure without any launchers being specified");
+    } catch (IllegalArgumentException expected) {
+      // Intentionally empty: this exception is the outcome the test is looking for.
+    }
+  }
+
+  /**
+   * Registers a single custom launcher descriptor and checks that exactly one launcher is
+   * created, that neither the YARN nor the uber launcher is created, and that the custom
+   * launcher's name and user payload are handed through unchanged.
+   */
+  @Test(timeout = 5000)
+  public void testCustomLauncherSpecified() throws IOException {
+    Configuration conf = new Configuration(false);
+
+    AppContext mockContext = mock(AppContext.class);
+    TaskAttemptListener mockListener = mock(TaskAttemptListener.class);
+
+    String launcherName = "customLauncher";
+    List<NamedEntityDescriptor> descriptors = new LinkedList<>();
+    ByteBuffer payloadBytes = ByteBuffer.allocate(4);
+    payloadBytes.putInt(0, 3);
+    UserPayload launcherPayload = UserPayload.create(payloadBytes);
+    descriptors.add(
+        new NamedEntityDescriptor(launcherName, FakeContainerLauncher.class.getName())
+            .setUserPayload(launcherPayload));
+
+    ContainerLaucherRouterForMultipleLauncherTest router =
+        new ContainerLaucherRouterForMultipleLauncherTest(mockContext, mockListener, null,
+            descriptors, true);
+    try {
+      router.init(conf);
+      router.start();
+
+      assertEquals(1, router.getNumContainerLaunchers());
+      assertFalse(router.getYarnContainerLauncherCreated());
+      assertFalse(router.getUberContainerLauncherCreated());
+      assertEquals(launcherName, router.getContainerLauncherName(0));
+      assertEquals(payloadBytes,
+          router.getContainerLauncherContext(0).getInitialUserPayload().getPayload());
+    } finally {
+      router.stop();
+    }
+  }
+
+  /**
+   * Registers a custom launcher followed by the default YARN service plugin and checks that
+   * two launchers are created in that order (custom at index 0, YARN at index 1), that the
+   * uber launcher is not created, and that each launcher context carries its own payload.
+   */
+  @Test(timeout = 5000)
+  public void testMultipleContainerLaunchers() throws IOException {
+    Configuration conf = new Configuration(false);
+    conf.set("testkey", "testvalue");
+    UserPayload yarnPayload = TezUtils.createUserPayloadFromConf(conf);
+
+    AppContext mockContext = mock(AppContext.class);
+    TaskAttemptListener mockListener = mock(TaskAttemptListener.class);
+
+    String launcherName = "customLauncher";
+    List<NamedEntityDescriptor> descriptors = new LinkedList<>();
+    ByteBuffer payloadBytes = ByteBuffer.allocate(4);
+    payloadBytes.putInt(0, 3);
+    UserPayload launcherPayload = UserPayload.create(payloadBytes);
+    descriptors.add(
+        new NamedEntityDescriptor(launcherName, FakeContainerLauncher.class.getName())
+            .setUserPayload(launcherPayload));
+    descriptors.add(
+        new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
+            .setUserPayload(yarnPayload));
+
+    ContainerLaucherRouterForMultipleLauncherTest router =
+        new ContainerLaucherRouterForMultipleLauncherTest(mockContext, mockListener, null,
+            descriptors, true);
+    try {
+      router.init(conf);
+      router.start();
+
+      assertEquals(2, router.getNumContainerLaunchers());
+      assertTrue(router.getYarnContainerLauncherCreated());
+      assertFalse(router.getUberContainerLauncherCreated());
+
+      // Index 0: the custom launcher with its raw byte payload.
+      assertEquals(launcherName, router.getContainerLauncherName(0));
+      assertEquals(payloadBytes,
+          router.getContainerLauncherContext(0).getInitialUserPayload().getPayload());
+
+      // Index 1: the YARN launcher, whose payload round-trips the configuration.
+      assertEquals(TezConstants.getTezYarnServicePluginName(),
+          router.getContainerLauncherName(1));
+      UserPayload yarnContextPayload =
+          router.getContainerLauncherContext(1).getInitialUserPayload();
+      Configuration confParsed = TezUtils.createConfFromUserPayload(yarnContextPayload);
+      assertEquals("testvalue", confParsed.get("testkey"));
+    } finally {
+      router.stop();
+    }
+  }
+
+  /**
+   * Registers a custom launcher (index 0) and the YARN launcher (index 1), then checks that
+   * both are initialized and started, that a launch request event addressed to launcher 0
+   * or 1 reaches only the matching launcher with the original launch context, and that both
+   * launchers are shut down when the router stops.
+   */
+  @Test(timeout = 5000)
+  public void testEventRouting() throws Exception {
+    Configuration conf = new Configuration(false);
+    UserPayload userPayload = TezUtils.createUserPayloadFromConf(conf);
+
+    AppContext appContext = mock(AppContext.class);
+    TaskAttemptListener tal = mock(TaskAttemptListener.class);
+
+    String customLauncherName = "customLauncher";
+    List<NamedEntityDescriptor> launcherDescriptors = new LinkedList<>();
+    ByteBuffer bb = ByteBuffer.allocate(4);
+    bb.putInt(0, 3);
+    UserPayload customPayload = UserPayload.create(bb);
+    launcherDescriptors.add(
+        new NamedEntityDescriptor(customLauncherName, FakeContainerLauncher.class.getName())
+            .setUserPayload(customPayload));
+    launcherDescriptors
+        .add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
+            .setUserPayload(userPayload));
+
+    ContainerLaucherRouterForMultipleLauncherTest clr =
+        new ContainerLaucherRouterForMultipleLauncherTest(appContext, tal, null,
+            launcherDescriptors,
+            true);
+    try {
+      clr.init(conf);
+      clr.start();
+
+      assertEquals(2, clr.getNumContainerLaunchers());
+      assertTrue(clr.getYarnContainerLauncherCreated());
+      assertFalse(clr.getUberContainerLauncherCreated());
+      assertEquals(customLauncherName, clr.getContainerLauncherName(0));
+      assertEquals(TezConstants.getTezYarnServicePluginName(), clr.getContainerLauncherName(1));
+
+      verify(clr.getTestContainerLauncher(0)).initialize();
+      verify(clr.getTestContainerLauncher(0)).start();
+      verify(clr.getTestContainerLauncher(1)).initialize();
+      verify(clr.getTestContainerLauncher(1)).start();
+
+      ContainerLaunchContext clc1 = mock(ContainerLaunchContext.class);
+      Container container1 = mock(Container.class);
+
+      ContainerLaunchContext clc2 = mock(ContainerLaunchContext.class);
+      Container container2 = mock(Container.class);
+
+      // The third constructor argument selects the target launcher (0 = custom, 1 = YARN).
+      NMCommunicatorLaunchRequestEvent launchRequestEvent1 =
+          new NMCommunicatorLaunchRequestEvent(clc1, container1, 0, 0, 0);
+      NMCommunicatorLaunchRequestEvent launchRequestEvent2 =
+          new NMCommunicatorLaunchRequestEvent(clc2, container2, 1, 0, 0);
+
+      clr.handle(launchRequestEvent1);
+
+      ArgumentCaptor<ContainerLaunchRequest> captor =
+          ArgumentCaptor.forClass(ContainerLaunchRequest.class);
+      verify(clr.getTestContainerLauncher(0)).launchContainer(captor.capture());
+      assertEquals(1, captor.getAllValues().size());
+      ContainerLaunchRequest launchRequest1 = captor.getValue();
+      assertEquals(clc1, launchRequest1.getContainerLaunchContext());
+
+      clr.handle(launchRequestEvent2);
+      captor = ArgumentCaptor.forClass(ContainerLaunchRequest.class);
+      verify(clr.getTestContainerLauncher(1)).launchContainer(captor.capture());
+      assertEquals(1, captor.getAllValues().size());
+      ContainerLaunchRequest launchRequest2 = captor.getValue();
+      assertEquals(clc2, launchRequest2.getContainerLaunchContext());
+
+    } finally {
+      clr.stop();
+    }
+
+    // Shutdown is verified after the try/finally rather than inside the finally block: a
+    // verification failure (or a missing launcher if init failed) raised from finally would
+    // replace, and therefore hide, the original failure from the try block.
+    verify(clr.getTestContainerLauncher(0)).shutdown();
+    verify(clr.getTestContainerLauncher(1)).shutdown();
+  }
+
+  /**
+   * Router subclass used by the tests: records which launchers were created (count, index,
+   * name, context), replaces the YARN and uber launchers with mocks, and wraps custom
+   * launchers in Mockito spies so calls on them can be verified.
+   */
+  private static class ContainerLaucherRouterForMultipleLauncherTest
+      extends ContainerLauncherRouter {
+
+    // All variables setup as static since methods being overridden are invoked by the ContainerLauncherRouter ctor,
+    // and regular variables will not be initialized at this point.
+    private static final AtomicInteger numContainerLaunchers = new AtomicInteger(0);
+    private static final Set<Integer> containerLauncherIndices = new HashSet<>();
+    private static final ContainerLauncher yarnContainerLauncher = mock(ContainerLauncher.class);
+    private static final ContainerLauncher uberContainerlauncher = mock(ContainerLauncher.class);
+    private static final AtomicBoolean yarnContainerLauncherCreated = new AtomicBoolean(false);
+    private static final AtomicBoolean uberContainerLauncherCreated = new AtomicBoolean(false);
+
+    private static final List<ContainerLauncherContext> containerLauncherContexts =
+        new LinkedList<>();
+    private static final List<String> containerLauncherNames = new LinkedList<>();
+    private static final List<ContainerLauncher> testContainerLaunchers = new LinkedList<>();
+
+
+    /** Restores all static bookkeeping, including the shared mocks, to a pristine state. */
+    public static void reset() {
+      numContainerLaunchers.set(0);
+      containerLauncherIndices.clear();
+      yarnContainerLauncherCreated.set(false);
+      uberContainerLauncherCreated.set(false);
+      containerLauncherContexts.clear();
+      containerLauncherNames.clear();
+      testContainerLaunchers.clear();
+      // The YARN/uber mocks are static and shared by every test. Without clearing their
+      // recorded invocations, a verify(...) in one test could also count calls made by an
+      // earlier test, making results depend on test execution order.
+      org.mockito.Mockito.reset(yarnContainerLauncher, uberContainerlauncher);
+    }
+
+    public ContainerLaucherRouterForMultipleLauncherTest(AppContext context,
+                                                         TaskAttemptListener taskAttemptListener,
+                                                         String workingDirectory,
+                                                         List<NamedEntityDescriptor> containerLauncherDescriptors,
+                                                         boolean isPureLocalMode) throws
+        UnknownHostException {
+      super(context, taskAttemptListener, workingDirectory,
+          containerLauncherDescriptors, isPureLocalMode);
+    }
+
+    /**
+     * Records the launcher's index, name and context before delegating to the parent, and
+     * fails if the same index is used for more than one launcher.
+     */
+    @Override
+    ContainerLauncher createContainerLauncher(NamedEntityDescriptor containerLauncherDescriptor,
+                                              AppContext context,
+                                              ContainerLauncherContext containerLauncherContext,
+                                              TaskAttemptListener taskAttemptListener,
+                                              String workingDirectory,
+                                              int containerLauncherIndex,
+                                              boolean isPureLocalMode) throws
+        UnknownHostException {
+      numContainerLaunchers.incrementAndGet();
+      boolean added = containerLauncherIndices.add(containerLauncherIndex);
+      assertTrue("Cannot add multiple launchers with the same index", added);
+      containerLauncherNames.add(containerLauncherDescriptor.getEntityName());
+      containerLauncherContexts.add(containerLauncherContext);
+      return super
+          .createContainerLauncher(containerLauncherDescriptor, context, containerLauncherContext,
+              taskAttemptListener, workingDirectory, containerLauncherIndex, isPureLocalMode);
+    }
+
+    /** Returns the shared YARN launcher mock instead of a real launcher. */
+    @Override
+    ContainerLauncher createYarnContainerLauncher(
+        ContainerLauncherContext containerLauncherContext) {
+      yarnContainerLauncherCreated.set(true);
+      testContainerLaunchers.add(yarnContainerLauncher);
+      return yarnContainerLauncher;
+    }
+
+    /** Returns the shared uber launcher mock instead of a real launcher. */
+    @Override
+    ContainerLauncher createUberContainerLauncher(ContainerLauncherContext containerLauncherContext,
+                                                  AppContext context,
+                                                  TaskAttemptListener taskAttemptListener,
+                                                  String workingDirectory,
+                                                  boolean isPureLocalMode) throws
+        UnknownHostException {
+      uberContainerLauncherCreated.set(true);
+      testContainerLaunchers.add(uberContainerlauncher);
+      return uberContainerlauncher;
+    }
+
+    /** Creates the real custom launcher via the parent and wraps it in a spy. */
+    @Override
+    ContainerLauncher createCustomContainerLauncher(
+        ContainerLauncherContext containerLauncherContext,
+        NamedEntityDescriptor containerLauncherDescriptor) {
+      ContainerLauncher spyLauncher = spy(super.createCustomContainerLauncher(
+          containerLauncherContext, containerLauncherDescriptor));
+      testContainerLaunchers.add(spyLauncher);
+      return spyLauncher;
+    }
+
+    public int getNumContainerLaunchers() {
+      return numContainerLaunchers.get();
+    }
+
+    public boolean getYarnContainerLauncherCreated() {
+      return yarnContainerLauncherCreated.get();
+    }
+
+    public boolean getUberContainerLauncherCreated() {
+      return uberContainerLauncherCreated.get();
+    }
+
+    /** Name of the launcher in creation order (not looked up by launcher index value). */
+    public String getContainerLauncherName(int containerLauncherIndex) {
+      return containerLauncherNames.get(containerLauncherIndex);
+    }
+
+    /** Mock or spy launcher in creation order. */
+    public ContainerLauncher getTestContainerLauncher(int containerLauncherIndex) {
+      return testContainerLaunchers.get(containerLauncherIndex);
+    }
+
+    /** Launcher context in creation order. */
+    public ContainerLauncherContext getContainerLauncherContext(int containerLauncherIndex) {
+      return containerLauncherContexts.get(containerLauncherIndex);
+    }
+  }
+
+ /**
+  * Minimal custom launcher referenced by class name in the test descriptors. Both request
+  * handlers are no-ops; the tests wrap instances in spies and only verify that calls arrive.
+  */
+ private static class FakeContainerLauncher extends ContainerLauncher {
+
+ public FakeContainerLauncher(
+ ContainerLauncherContext containerLauncherContext) {
+ super(containerLauncherContext);
+ }
+
+ @Override
+ public void launchContainer(ContainerLaunchRequest launchRequest) {
+ // Intentionally empty.
+
+ }
+
+ @Override
+ public void stopContainer(ContainerStopRequest stopRequest) {
+ // Intentionally empty.
+
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerEventHandler.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerEventHandler.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerEventHandler.java
index f8aa1e2..3e68a4c 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerEventHandler.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerEventHandler.java
@@ -19,22 +19,30 @@
package org.apache.tez.dag.app.rm;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.IOException;
import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
+import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.Credentials;
@@ -44,6 +52,7 @@ import org.apache.hadoop.yarn.api.records.ContainerExitStatus;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.event.Event;
@@ -53,12 +62,13 @@ import org.apache.tez.common.TezUtils;
import org.apache.tez.dag.api.NamedEntityDescriptor;
import org.apache.tez.dag.api.TaskLocationHint;
import org.apache.tez.dag.api.TezConfiguration;
-import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.dag.api.TezConstants;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.client.DAGClientServer;
import org.apache.tez.dag.app.AppContext;
import org.apache.tez.dag.app.ContainerContext;
import org.apache.tez.dag.app.ServicePluginLifecycleAbstractService;
+import org.apache.tez.dag.app.dag.TaskAttempt;
import org.apache.tez.dag.app.dag.impl.TaskAttemptImpl;
import org.apache.tez.dag.app.dag.impl.TaskImpl;
import org.apache.tez.dag.app.dag.impl.VertexImpl;
@@ -70,8 +80,14 @@ import org.apache.tez.dag.app.rm.container.AMContainerMap;
import org.apache.tez.dag.app.rm.container.AMContainerState;
import org.apache.tez.dag.app.web.WebUIService;
import org.apache.tez.dag.records.TaskAttemptTerminationCause;
+import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.impl.TaskSpec;
+import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
import org.apache.tez.serviceplugins.api.TaskScheduler;
+import org.apache.tez.serviceplugins.api.TaskSchedulerContext;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -95,10 +111,9 @@ public class TestTaskSchedulerEventHandler {
public MockTaskSchedulerEventHandler(AppContext appContext,
DAGClientServer clientService, EventHandler eventHandler,
- ContainerSignatureMatcher containerSignatureMatcher, WebUIService webUI,
- UserPayload defaultPayload) {
+ ContainerSignatureMatcher containerSignatureMatcher, WebUIService webUI) {
super(appContext, clientService, eventHandler, containerSignatureMatcher, webUI,
- new LinkedList<NamedEntityDescriptor>(), defaultPayload, false);
+ Lists.newArrayList(new NamedEntityDescriptor("FakeDescriptor", null)), false);
}
@Override
@@ -140,14 +155,8 @@ public class TestTaskSchedulerEventHandler {
when(mockAppContext.getAllContainers()).thenReturn(mockAMContainerMap);
when(mockClientService.getBindAddress()).thenReturn(new InetSocketAddress(10000));
Configuration conf = new Configuration(false);
- UserPayload userPayload;
- try {
- userPayload = TezUtils.createUserPayloadFromConf(conf);
- } catch (IOException e) {
- throw new TezUncheckedException(e);
- }
schedulerHandler = new MockTaskSchedulerEventHandler(
- mockAppContext, mockClientService, mockEventHandler, mockSigMatcher, mockWebUIService, userPayload);
+ mockAppContext, mockClientService, mockEventHandler, mockSigMatcher, mockWebUIService);
}
@Test(timeout = 5000)
@@ -272,7 +281,7 @@ public class TestTaskSchedulerEventHandler {
when(mockAmContainer.getContainerLauncherIdentifier()).thenReturn(0);
when(mockAmContainer.getTaskCommunicatorIdentifier()).thenReturn(0);
ContainerId mockCId = mock(ContainerId.class);
- verify(mockTaskScheduler, times(0)).deallocateContainer((ContainerId)any());
+ verify(mockTaskScheduler, times(0)).deallocateContainer((ContainerId) any());
when(mockAMContainerMap.get(mockCId)).thenReturn(mockAmContainer);
schedulerHandler.preemptContainer(0, mockCId);
verify(mockTaskScheduler, times(1)).deallocateContainer(mockCId);
@@ -400,5 +409,300 @@ public class TestTaskSchedulerEventHandler {
}
- // TODO TEZ-2003. Add tests with multiple schedulers, and ensuring that events go out with correct IDs.
+ /**
+ * Constructing the handler with a null list of scheduler descriptors must fail with an
+ * {@link IllegalArgumentException}.
+ */
+ @Test(timeout = 5000)
+ public void testNoSchedulerSpecified() throws IOException {
+ try {
+ // The instance is never used; only the constructor's argument validation is under test.
+ new TSEHForMultipleSchedulersTest(mockAppContext, mockClientService, mockEventHandler,
+ mockSigMatcher, mockWebUIService, null, false);
+ fail("Expecting an IllegalArgumentException with no schedulers specified");
+ } catch (IllegalArgumentException expected) {
+ // Expected: this is the exception type the test requires for a missing descriptor list.
+ }
+ }
+
+ // Sets up a custom scheduler followed by the YARN scheduler and checks, through the recording
+ // accessors on TSEHForMultipleSchedulersTest, that both are created in descriptor order and
+ // that each receives the payload attached to its own descriptor.
+ @Test(timeout = 5000)
+ public void testCustomTaskSchedulerSetup() throws IOException {
+ Configuration conf = new Configuration(false);
+ conf.set("testkey", "testval");
+ UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
+
+ String customSchedulerName = "fakeScheduler";
+ List<NamedEntityDescriptor> taskSchedulers = new LinkedList<>();
+ // Custom payload: a 4-byte buffer holding the int 3, written with an absolute put.
+ ByteBuffer bb = ByteBuffer.allocate(4);
+ bb.putInt(0, 3);
+ UserPayload userPayload = UserPayload.create(bb);
+ taskSchedulers.add(
+ new NamedEntityDescriptor(customSchedulerName, FakeTaskScheduler.class.getName())
+ .setUserPayload(userPayload));
+ taskSchedulers.add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
+ .setUserPayload(defaultPayload));
+
+ TSEHForMultipleSchedulersTest tseh =
+ new TSEHForMultipleSchedulersTest(mockAppContext, mockClientService, mockEventHandler,
+ mockSigMatcher, mockWebUIService, taskSchedulers, false);
+
+ tseh.init(conf);
+ tseh.start();
+
+ // Verify that the YARN scheduler was created, the uber scheduler was not, and that exactly
+ // one create call happened per descriptor
+ assertTrue(tseh.getYarnSchedulerCreated());
+ assertFalse(tseh.getUberSchedulerCreated());
+ assertEquals(2, tseh.getNumCreateInvocations());
+
+ // Verify the order of the schedulers
+ assertEquals(customSchedulerName, tseh.getTaskSchedulerName(0));
+ assertEquals(TezConstants.getTezYarnServicePluginName(), tseh.getTaskSchedulerName(1));
+
+ // Verify the payload setup for the custom task scheduler
+ assertNotNull(tseh.getTaskSchedulerContext(0));
+ assertEquals(bb, tseh.getTaskSchedulerContext(0).getInitialUserPayload().getPayload());
+
+ // Verify the payload on the yarn scheduler
+ assertNotNull(tseh.getTaskSchedulerContext(1));
+ Configuration parsed = TezUtils.createConfFromUserPayload(tseh.getTaskSchedulerContext(1).getInitialUserPayload());
+ assertEquals("testval", parsed.get("testkey"));
+ }
+
+ /**
+ * Sets up a custom scheduler (index 0) and the YARN scheduler (index 1), then sends two launch
+ * requests and verifies that each one reaches allocateTask on the intended scheduler only.
+ * The two requests differ in the first of their three trailing int arguments (0 vs 1).
+ * NOTE(review): that argument is presumably the scheduler id - confirm against
+ * AMSchedulerEventTALaunchRequest.
+ */
+ @Test(timeout = 5000)
+ public void testTaskSchedulerRouting() throws Exception {
+ Configuration conf = new Configuration(false);
+ UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
+
+ String customSchedulerName = "fakeScheduler";
+ List<NamedEntityDescriptor> taskSchedulers = new LinkedList<>();
+ ByteBuffer bb = ByteBuffer.allocate(4);
+ bb.putInt(0, 3);
+ UserPayload userPayload = UserPayload.create(bb);
+ taskSchedulers.add(
+ new NamedEntityDescriptor(customSchedulerName, FakeTaskScheduler.class.getName())
+ .setUserPayload(userPayload));
+ taskSchedulers.add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
+ .setUserPayload(defaultPayload));
+
+ TSEHForMultipleSchedulersTest tseh =
+ new TSEHForMultipleSchedulersTest(mockAppContext, mockClientService, mockEventHandler,
+ mockSigMatcher, mockWebUIService, taskSchedulers, false);
+
+ tseh.init(conf);
+ tseh.start();
+
+ // Verify that the YARN scheduler was created, the uber scheduler was not, and that exactly
+ // one create call happened per descriptor
+ assertTrue(tseh.getYarnSchedulerCreated());
+ assertFalse(tseh.getUberSchedulerCreated());
+ assertEquals(2, tseh.getNumCreateInvocations());
+
+ // Verify the order of the schedulers
+ assertEquals(customSchedulerName, tseh.getTaskSchedulerName(0));
+ assertEquals(TezConstants.getTezYarnServicePluginName(), tseh.getTaskSchedulerName(1));
+
+ // The custom scheduler is a Mockito spy, so its lifecycle calls can be verified
+ verify(tseh.getTestTaskScheduler(0)).initialize();
+ verify(tseh.getTestTaskScheduler(0)).start();
+
+ ApplicationId appId = ApplicationId.newInstance(1000, 1);
+ TezDAGID dagId = TezDAGID.getInstance(appId, 1);
+ TezVertexID vertexID = TezVertexID.getInstance(dagId, 1);
+ TezTaskID taskId1 = TezTaskID.getInstance(vertexID, 1);
+ TezTaskAttemptID attemptId11 = TezTaskAttemptID.getInstance(taskId1, 1);
+ TezTaskID taskId2 = TezTaskID.getInstance(vertexID, 2);
+ TezTaskAttemptID attemptId21 = TezTaskAttemptID.getInstance(taskId2, 1);
+
+ Resource resource = Resource.newInstance(1024, 1);
+
+ TaskAttempt mockTaskAttempt1 = mock(TaskAttempt.class);
+ TaskAttempt mockTaskAttempt2 = mock(TaskAttempt.class);
+
+ // First request: expected to be routed to the scheduler at index 0 (the custom one)
+ AMSchedulerEventTALaunchRequest launchRequest1 =
+ new AMSchedulerEventTALaunchRequest(attemptId11, resource, mock(TaskSpec.class),
+ mockTaskAttempt1, mock(TaskLocationHint.class), 1, mock(ContainerContext.class), 0, 0,
+ 0);
+
+ tseh.handle(launchRequest1);
+
+ verify(tseh.getTestTaskScheduler(0)).allocateTask(eq(mockTaskAttempt1), eq(resource),
+ any(String[].class), any(String[].class), any(Priority.class), any(Object.class),
+ eq(launchRequest1));
+
+ // Second request: expected to be routed to the scheduler at index 1 (the YARN mock)
+ AMSchedulerEventTALaunchRequest launchRequest2 =
+ new AMSchedulerEventTALaunchRequest(attemptId21, resource, mock(TaskSpec.class),
+ mockTaskAttempt2, mock(TaskLocationHint.class), 1, mock(ContainerContext.class), 1, 0,
+ 0);
+ tseh.handle(launchRequest2);
+ verify(tseh.getTestTaskScheduler(1)).allocateTask(eq(mockTaskAttempt2), eq(resource),
+ any(String[].class), any(String[].class), any(Priority.class), any(Object.class),
+ eq(launchRequest2));
+ }
+
+ /**
+ * TaskSchedulerEventHandler used by the multiple-scheduler tests. It replaces the YARN and
+ * uber schedulers with Mockito mocks, wraps custom schedulers in spies, handles events inline,
+ * and records what was created (names, contexts, scheduler instances) so tests can inspect it.
+ */
+ private static class TSEHForMultipleSchedulersTest extends TaskSchedulerEventHandler {
+
+ private final TaskScheduler yarnTaskScheduler;
+ private final TaskScheduler uberTaskScheduler;
+ private final AtomicBoolean uberSchedulerCreated = new AtomicBoolean(false);
+ private final AtomicBoolean yarnSchedulerCreated = new AtomicBoolean(false);
+ private final AtomicInteger numCreateInvocations = new AtomicInteger(0);
+ // Scheduler ids passed to createTaskScheduler so far; used to detect duplicates.
+ private final Set<Integer> seenSchedulers = new HashSet<>();
+ // The three lists below are appended to as schedulers are created, so an element's position
+ // reflects creation order.
+ private final List<TaskSchedulerContext> taskSchedulerContexts = new LinkedList<>();
+ private final List<String> taskSchedulerNames = new LinkedList<>();
+ private final List<TaskScheduler> testTaskSchedulers = new LinkedList<>();
+
+ public TSEHForMultipleSchedulersTest(AppContext appContext,
+ DAGClientServer clientService,
+ EventHandler eventHandler,
+ ContainerSignatureMatcher containerSignatureMatcher,
+ WebUIService webUI,
+ List<NamedEntityDescriptor> schedulerDescriptors,
+ boolean isPureLocalMode) {
+ super(appContext, clientService, eventHandler, containerSignatureMatcher, webUI,
+ schedulerDescriptors, isPureLocalMode);
+ yarnTaskScheduler = mock(TaskScheduler.class);
+ uberTaskScheduler = mock(TaskScheduler.class);
+ }
+
+ /**
+ * Counts the invocation, asserts the scheduler id has not been seen before, records the
+ * descriptor's entity name, then delegates to the real implementation.
+ */
+ @Override
+ TaskScheduler createTaskScheduler(String host, int port, String trackingUrl,
+ AppContext appContext,
+ NamedEntityDescriptor taskSchedulerDescriptor,
+ long customAppIdIdentifier,
+ int schedulerId) {
+
+ numCreateInvocations.incrementAndGet();
+ boolean added = seenSchedulers.add(schedulerId);
+ assertTrue("Cannot add multiple schedulers with the same schedulerId", added);
+ taskSchedulerNames.add(taskSchedulerDescriptor.getEntityName());
+ return super.createTaskScheduler(host, port, trackingUrl, appContext, taskSchedulerDescriptor,
+ customAppIdIdentifier, schedulerId);
+ }
+
+ @Override
+ TaskSchedulerContext wrapTaskSchedulerContext(TaskSchedulerContext rawContext) {
+ // Avoid wrapping in threads
+ return rawContext;
+ }
+
+ /** Records the context and returns the YARN mock instead of a real YARN scheduler. */
+ @Override
+ TaskScheduler createYarnTaskScheduler(TaskSchedulerContext taskSchedulerContext, int schedulerId) {
+ taskSchedulerContexts.add(taskSchedulerContext);
+ testTaskSchedulers.add(yarnTaskScheduler);
+ yarnSchedulerCreated.set(true);
+ return yarnTaskScheduler;
+ }
+
+ /** Records the context and returns the uber mock instead of a real uber scheduler. */
+ @Override
+ TaskScheduler createUberTaskScheduler(TaskSchedulerContext taskSchedulerContext, int schedulerId) {
+ taskSchedulerContexts.add(taskSchedulerContext);
+ // Record the instance that is actually returned, so getTestTaskScheduler(index) hands back
+ // the uber mock rather than the YARN mock.
+ testTaskSchedulers.add(uberTaskScheduler);
+ uberSchedulerCreated.set(true);
+ return uberTaskScheduler;
+ }
+
+ /** Records the context and returns a spy around the scheduler built by the superclass. */
+ @Override
+ TaskScheduler createCustomTaskScheduler(TaskSchedulerContext taskSchedulerContext,
+ NamedEntityDescriptor taskSchedulerDescriptor, int schedulerId) {
+ taskSchedulerContexts.add(taskSchedulerContext);
+ TaskScheduler taskScheduler = spy(super.createCustomTaskScheduler(taskSchedulerContext, taskSchedulerDescriptor, schedulerId));
+ testTaskSchedulers.add(taskScheduler);
+ return taskScheduler;
+ }
+
+ @Override
+ // Inline handling of events.
+ public void handle(AMSchedulerEvent event) {
+ handleEvent(event);
+ }
+
+ public boolean getUberSchedulerCreated() {
+ return uberSchedulerCreated.get();
+ }
+
+ public boolean getYarnSchedulerCreated() {
+ return yarnSchedulerCreated.get();
+ }
+
+ public int getNumCreateInvocations() {
+ return numCreateInvocations.get();
+ }
+
+ // The accessors below use schedulerId as a position in the recording lists above.
+
+ public TaskSchedulerContext getTaskSchedulerContext(int schedulerId) {
+ return taskSchedulerContexts.get(schedulerId);
+ }
+
+ public String getTaskSchedulerName(int schedulerId) {
+ return taskSchedulerNames.get(schedulerId);
+ }
+
+ public TaskScheduler getTestTaskScheduler(int schedulerId) {
+ return testTaskSchedulers.get(schedulerId);
+ }
+ }
+
+ /**
+ * Do-nothing {@link TaskScheduler} used as the custom scheduler plugin in these tests. Every
+ * operation is a no-op; methods with a return value return null, 0 or false.
+ */
+ public static class FakeTaskScheduler extends TaskScheduler {
+
+ public FakeTaskScheduler(TaskSchedulerContext taskSchedulerContext) {
+ super(taskSchedulerContext);
+ }
+
+ // ---- Cluster and resource queries: fixed placeholder values ----
+
+ @Override
+ public Resource getAvailableResources() {
+ return null;
+ }
+
+ @Override
+ public Resource getTotalResources() {
+ return null;
+ }
+
+ @Override
+ public int getClusterNodeCount() {
+ return 0;
+ }
+
+ // ---- Node blacklisting: ignored ----
+
+ @Override
+ public void blacklistNode(NodeId nodeId) {
+ // Intentionally a no-op.
+ }
+
+ @Override
+ public void unblacklistNode(NodeId nodeId) {
+ // Intentionally a no-op.
+ }
+
+ // ---- Allocation and deallocation: ignored ----
+
+ @Override
+ public void allocateTask(Object task, Resource capability, String[] hosts, String[] racks,
+ Priority priority, Object containerSignature, Object clientCookie) {
+ // Intentionally a no-op.
+ }
+
+ @Override
+ public void allocateTask(Object task, Resource capability, ContainerId containerId,
+ Priority priority, Object containerSignature, Object clientCookie) {
+ // Intentionally a no-op.
+ }
+
+ @Override
+ public boolean deallocateTask(Object task, boolean taskSucceeded,
+ TaskAttemptEndReason endReason) {
+ return false;
+ }
+
+ @Override
+ public Object deallocateContainer(ContainerId containerId) {
+ return null;
+ }
+
+ // ---- DAG completion and unregistration: ignored ----
+
+ @Override
+ public void dagComplete() {
+ // Intentionally a no-op.
+ }
+
+ @Override
+ public void setShouldUnregister() {
+ // Intentionally a no-op.
+ }
+
+ @Override
+ public boolean hasUnregistered() {
+ return false;
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerHelpers.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerHelpers.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerHelpers.java
index 59ab00a..0746507 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerHelpers.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerHelpers.java
@@ -42,6 +42,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.service.AbstractService;
@@ -138,7 +139,8 @@ class TestTaskSchedulerHelpers {
ContainerSignatureMatcher containerSignatureMatcher,
UserPayload defaultPayload) {
super(appContext, null, eventHandler, containerSignatureMatcher, null,
- new LinkedList<NamedEntityDescriptor>(), defaultPayload, false);
+ Lists.newArrayList(new NamedEntityDescriptor("FakeScheduler", null)),
+ false);
this.amrmClientAsync = amrmClientAsync;
this.containerSignatureMatcher = containerSignatureMatcher;
this.defaultPayload = defaultPayload;
[2/2] tez git commit: TEZ-2126. Add unit tests for verifying multiple
schedulers, launchers, communicators. (sseth)
Posted by ss...@apache.org.
TEZ-2126. Add unit tests for verifying multiple schedulers, launchers,
communicators. (sseth)
Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/bd6fcf95
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/bd6fcf95
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/bd6fcf95
Branch: refs/heads/TEZ-2003
Commit: bd6fcf95d74fa3e9da8fd1d1f5c994afaf37919d
Parents: 0026ebe
Author: Siddharth Seth <ss...@apache.org>
Authored: Thu Aug 6 01:04:31 2015 -0700
Committer: Siddharth Seth <ss...@apache.org>
Committed: Thu Aug 6 01:04:31 2015 -0700
----------------------------------------------------------------------
TEZ-2003-CHANGES.txt | 1 +
.../tez/dag/api/NamedEntityDescriptor.java | 7 +
.../org/apache/tez/dag/app/DAGAppMaster.java | 163 ++++----
.../dag/app/TaskAttemptListenerImpTezDag.java | 94 ++---
.../apache/tez/dag/app/dag/impl/VertexImpl.java | 9 +-
.../app/launcher/ContainerLauncherRouter.java | 126 ++++---
.../dag/app/rm/TaskSchedulerEventHandler.java | 137 +++----
.../apache/tez/dag/app/MockDAGAppMaster.java | 3 +-
.../apache/tez/dag/app/TestDAGAppMaster.java | 300 +++++++++++++++
.../app/TestTaskAttemptListenerImplTezDag.java | 44 ++-
.../app/TestTaskAttemptListenerImplTezDag2.java | 6 +-
.../dag/app/TestTaskCommunicatorManager.java | 369 +++++++++++++++++++
.../tez/dag/app/dag/impl/TestVertexImpl2.java | 279 ++++++++++++--
.../launcher/TestContainerLauncherRouter.java | 361 ++++++++++++++++++
.../app/rm/TestTaskSchedulerEventHandler.java | 330 ++++++++++++++++-
.../dag/app/rm/TestTaskSchedulerHelpers.java | 4 +-
16 files changed, 1907 insertions(+), 326 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/TEZ-2003-CHANGES.txt
----------------------------------------------------------------------
diff --git a/TEZ-2003-CHANGES.txt b/TEZ-2003-CHANGES.txt
index c7a3dcc..f921739 100644
--- a/TEZ-2003-CHANGES.txt
+++ b/TEZ-2003-CHANGES.txt
@@ -42,5 +42,6 @@ ALL CHANGES:
TEZ-2441. Add tests for TezTaskRunner2.
TEZ-2657. Add tests for client side changes - specifying plugins, etc.
TEZ-2626. Fix log lines with DEBUG in messages, consolidate TEZ-2003 TODOs.
+ TEZ-2126. Add unit tests for verifying multiple schedulers, launchers, communicators.
INCOMPATIBLE CHANGES:
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-api/src/main/java/org/apache/tez/dag/api/NamedEntityDescriptor.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/NamedEntityDescriptor.java b/tez-api/src/main/java/org/apache/tez/dag/api/NamedEntityDescriptor.java
index 723d43f..17c8c6c 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/NamedEntityDescriptor.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/NamedEntityDescriptor.java
@@ -35,4 +35,11 @@ public class NamedEntityDescriptor<T extends NamedEntityDescriptor<T>> extends E
super.setUserPayload(userPayload);
return (T) this;
}
+
+ /**
+ * Returns a short description holding the entity name, the class name and whether a
+ * non-null payload is attached. The payload contents are not included.
+ */
+ @Override
+ public String toString() {
+ // True only when both the UserPayload wrapper and the buffer inside it are present.
+ boolean hasPayload = getUserPayload() != null && getUserPayload().getPayload() != null;
+ return "EntityName=" + entityName
+ + ", ClassName=" + getClassName()
+ + ", hasPayload=" + hasPayload;
+ }
}
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
index 9ed14d7..767c55c 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
@@ -59,6 +59,7 @@ import java.util.regex.Pattern;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
+import com.google.common.collect.Lists;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.Options;
@@ -388,42 +389,16 @@ public class DAGAppMaster extends AbstractService {
this.isLocal = conf.getBoolean(TezConfiguration.TEZ_LOCAL_MODE,
TezConfiguration.TEZ_LOCAL_MODE_DEFAULT);
- List<NamedEntityDescriptor> taskSchedulerDescriptors;
- List<NamedEntityDescriptor> containerLauncherDescriptors;
- List<NamedEntityDescriptor> taskCommunicatorDescriptors;
- boolean tezYarnEnabled = true;
- boolean uberEnabled = false;
-
- if (!isLocal) {
- if (amPluginDescriptorProto == null) {
- tezYarnEnabled = true;
- uberEnabled = false;
- } else {
- tezYarnEnabled = amPluginDescriptorProto.getContainersEnabled();
- uberEnabled = amPluginDescriptorProto.getUberEnabled();
- }
- } else {
- tezYarnEnabled = false;
- uberEnabled = true;
- }
+ UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(amConf);
- taskSchedulerDescriptors = parsePlugin(taskSchedulers,
- (amPluginDescriptorProto == null || amPluginDescriptorProto.getTaskSchedulersCount() == 0 ?
- null :
- amPluginDescriptorProto.getTaskSchedulersList()),
- tezYarnEnabled, uberEnabled);
+ List<NamedEntityDescriptor> taskSchedulerDescriptors = Lists.newLinkedList();
+ List<NamedEntityDescriptor> containerLauncherDescriptors = Lists.newLinkedList();
+ List<NamedEntityDescriptor> taskCommunicatorDescriptors = Lists.newLinkedList();
- containerLauncherDescriptors = parsePlugin(containerLaunchers,
- (amPluginDescriptorProto == null ||
- amPluginDescriptorProto.getContainerLaunchersCount() == 0 ? null :
- amPluginDescriptorProto.getContainerLaunchersList()),
- tezYarnEnabled, uberEnabled);
+ parseAllPlugins(taskSchedulerDescriptors, taskSchedulers, containerLauncherDescriptors,
+ containerLaunchers, taskCommunicatorDescriptors, taskCommunicators, amPluginDescriptorProto,
+ isLocal, defaultPayload);
- taskCommunicatorDescriptors = parsePlugin(taskCommunicators,
- (amPluginDescriptorProto == null ||
- amPluginDescriptorProto.getTaskCommunicatorsCount() == 0 ? null :
- amPluginDescriptorProto.getTaskCommunicatorsList()),
- tezYarnEnabled, uberEnabled);
LOG.info(buildPluginComponentLog(taskSchedulerDescriptors, taskSchedulers, "TaskSchedulers"));
@@ -493,12 +468,11 @@ public class DAGAppMaster extends AbstractService {
jobTokenSecretManager.addTokenForJob(
appAttemptID.getApplicationId().toString(), sessionToken);
- UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(amConf);
+
//service to handle requests to TaskUmbilicalProtocol
taskAttemptListener = createTaskAttemptListener(context,
- taskHeartbeatHandler, containerHeartbeatHandler, taskCommunicatorDescriptors,
- defaultPayload, isLocal);
+ taskHeartbeatHandler, containerHeartbeatHandler, taskCommunicatorDescriptors);
addIfService(taskAttemptListener, true);
containerSignatureMatcher = createContainerSignatureMatcher();
@@ -548,7 +522,7 @@ public class DAGAppMaster extends AbstractService {
this.taskSchedulerEventHandler = new TaskSchedulerEventHandler(context,
clientRpcServer, dispatcher.getEventHandler(), containerSignatureMatcher, webUIService,
- taskSchedulerDescriptors, defaultPayload, isLocal);
+ taskSchedulerDescriptors, isLocal);
addIfService(taskSchedulerEventHandler, true);
if (enableWebUIService()) {
@@ -566,7 +540,7 @@ public class DAGAppMaster extends AbstractService {
taskSchedulerEventHandler);
addIfServiceDependency(taskSchedulerEventHandler, clientRpcServer);
- this.containerLauncherRouter = createContainerLauncherRouter(defaultPayload, containerLauncherDescriptors, isLocal);
+ this.containerLauncherRouter = createContainerLauncherRouter(containerLauncherDescriptors, isLocal);
addIfService(containerLauncherRouter, true);
dispatcher.register(NMCommunicatorEventType.class, containerLauncherRouter);
@@ -1076,12 +1050,9 @@ public class DAGAppMaster extends AbstractService {
protected TaskAttemptListener createTaskAttemptListener(AppContext context,
TaskHeartbeatHandler thh,
ContainerHeartbeatHandler chh,
- List<NamedEntityDescriptor> entityDescriptors,
- UserPayload defaultUserPayload,
- boolean isLocal) {
+ List<NamedEntityDescriptor> entityDescriptors) {
TaskAttemptListener lis =
- new TaskAttemptListenerImpTezDag(context, thh, chh,
- entityDescriptors, defaultUserPayload, isLocal);
+ new TaskAttemptListenerImpTezDag(context, thh, chh, entityDescriptors);
return lis;
}
@@ -1102,11 +1073,10 @@ public class DAGAppMaster extends AbstractService {
return chh;
}
- protected ContainerLauncherRouter createContainerLauncherRouter(UserPayload defaultPayload,
- List<NamedEntityDescriptor> containerLauncherDescriptors,
+ protected ContainerLauncherRouter createContainerLauncherRouter(List<NamedEntityDescriptor> containerLauncherDescriptors,
boolean isLocal) throws
UnknownHostException {
- return new ContainerLauncherRouter(defaultPayload, context, taskAttemptListener, workingDirectory,
+ return new ContainerLauncherRouter(context, taskAttemptListener, workingDirectory,
containerLauncherDescriptors, isLocal);
}
@@ -2373,41 +2343,106 @@ public class DAGAppMaster extends AbstractService {
TezConfiguration.TEZ_AM_WEBSERVICE_ENABLE_DEFAULT);
}
- private static List<NamedEntityDescriptor> parsePlugin(
- BiMap<String, Integer> pluginMap, List<TezNamedEntityDescriptorProto> namedEntityDescriptorProtos,
- boolean tezYarnEnabled, boolean uberEnabled) {
- int index = 0;
+ /**
+ * Fills the three descriptor lists (task schedulers, container launchers, task communicators)
+ * and their name-to-index maps.
+ *
+ * Whether the built-in YARN and uber plugins are enabled is decided first: in local mode only
+ * uber is enabled; otherwise YARN only when no plugin descriptor proto is given, or whatever
+ * the proto specifies. Each list is then filled by parsePlugin, and the scheduler list is
+ * additionally passed through processSchedulerDescriptors.
+ */
+ @VisibleForTesting
+ static void parseAllPlugins(
+ List<NamedEntityDescriptor> taskSchedulerDescriptors, BiMap<String, Integer> taskSchedulerPluginMap,
+ List<NamedEntityDescriptor> containerLauncherDescriptors, BiMap<String, Integer> containerLauncherPluginMap,
+ List<NamedEntityDescriptor> taskCommDescriptors, BiMap<String, Integer> taskCommPluginMap,
+ AMPluginDescriptorProto amPluginDescriptorProto, boolean isLocal, UserPayload defaultPayload) {
+
+ boolean tezYarnEnabled;
+ boolean uberEnabled;
+ if (!isLocal) {
+ if (amPluginDescriptorProto == null) {
+ tezYarnEnabled = true;
+ uberEnabled = false;
+ } else {
+ tezYarnEnabled = amPluginDescriptorProto.getContainersEnabled();
+ uberEnabled = amPluginDescriptorProto.getUberEnabled();
+ }
+ } else {
+ tezYarnEnabled = false;
+ uberEnabled = true;
+ }
+
+ // A missing proto, or one with an empty plugin list, is passed to parsePlugin as null.
+ parsePlugin(taskSchedulerDescriptors, taskSchedulerPluginMap,
+ (amPluginDescriptorProto == null || amPluginDescriptorProto.getTaskSchedulersCount() == 0 ?
+ null :
+ amPluginDescriptorProto.getTaskSchedulersList()),
+ tezYarnEnabled, uberEnabled, defaultPayload);
+ processSchedulerDescriptors(taskSchedulerDescriptors, isLocal, defaultPayload, taskSchedulerPluginMap);
- List<NamedEntityDescriptor> resultList = new LinkedList<>();
+ parsePlugin(containerLauncherDescriptors, containerLauncherPluginMap,
+ (amPluginDescriptorProto == null ||
+ amPluginDescriptorProto.getContainerLaunchersCount() == 0 ? null :
+ amPluginDescriptorProto.getContainerLaunchersList()),
+ tezYarnEnabled, uberEnabled, defaultPayload);
+
+ parsePlugin(taskCommDescriptors, taskCommPluginMap,
+ (amPluginDescriptorProto == null ||
+ amPluginDescriptorProto.getTaskCommunicatorsCount() == 0 ? null :
+ amPluginDescriptorProto.getTaskCommunicatorsList()),
+ tezYarnEnabled, uberEnabled, defaultPayload);
+ }
+
+
+ /**
+ * Appends descriptors for one plugin kind to resultList and registers each entity name with
+ * its list index in pluginMap. Order: the built-in YARN descriptor if tezYarnEnabled, the
+ * built-in uber descriptor if uberEnabled (both carrying defaultPayload), then one descriptor
+ * per entry of namedEntityDescriptorProtos, which may be null.
+ */
+ @VisibleForTesting
+ static void parsePlugin(List<NamedEntityDescriptor> resultList,
+ BiMap<String, Integer> pluginMap, List<TezNamedEntityDescriptorProto> namedEntityDescriptorProtos,
+ boolean tezYarnEnabled, boolean uberEnabled, UserPayload defaultPayload) {
if (tezYarnEnabled) {
// Default classnames will be populated by individual components
NamedEntityDescriptor r = new NamedEntityDescriptor(
- TezConstants.getTezYarnServicePluginName(), null);
- resultList.add(r);
- pluginMap.put(TezConstants.getTezYarnServicePluginName(), index);
- index++;
+ TezConstants.getTezYarnServicePluginName(), null).setUserPayload(defaultPayload);
+ addDescriptor(resultList, pluginMap, r);
}
if (uberEnabled) {
// Default classnames will be populated by individual components
NamedEntityDescriptor r = new NamedEntityDescriptor(
- TezConstants.getTezUberServicePluginName(), null);
- resultList.add(r);
- pluginMap.put(TezConstants.getTezUberServicePluginName(), index);
- index++;
+ TezConstants.getTezUberServicePluginName(), null).setUserPayload(defaultPayload);
+ addDescriptor(resultList, pluginMap, r);
}
if (namedEntityDescriptorProtos != null) {
for (TezNamedEntityDescriptorProto namedEntityDescriptorProto : namedEntityDescriptorProtos) {
- resultList.add(DagTypeConverters
- .convertNamedDescriptorFromProto(namedEntityDescriptorProto));
- pluginMap.put(resultList.get(index).getEntityName(), index);
- index++;
+ NamedEntityDescriptor namedEntityDescriptor = DagTypeConverters
+ .convertNamedDescriptorFromProto(namedEntityDescriptorProto);
+ addDescriptor(resultList, pluginMap, namedEntityDescriptor);
+ }
+ }
+ }
+
+ /**
+ * Appends the descriptor to the list and maps its entity name to the index it now occupies,
+ * i.e. the last position of the list.
+ */
+ @VisibleForTesting
+ static void addDescriptor(List<NamedEntityDescriptor> list, BiMap<String, Integer> pluginMap,
+ NamedEntityDescriptor namedEntityDescriptor) {
+ list.add(namedEntityDescriptor);
+ int lastIndex = list.size() - 1;
+ NamedEntityDescriptor appended = list.get(lastIndex);
+ pluginMap.put(appended.getEntityName(), lastIndex);
+ }
+
+ /**
+ * Post-processes the scheduler descriptor list. In local mode it checks (via checkState) that
+ * the list holds exactly one descriptor and that it is the uber plugin. Otherwise it makes
+ * sure a YARN scheduler descriptor is present, appending one with defaultPayload through
+ * addDescriptor when none is found.
+ */
+ @VisibleForTesting
+ static void processSchedulerDescriptors(List<NamedEntityDescriptor> descriptors, boolean isLocal,
+ UserPayload defaultPayload,
+ BiMap<String, Integer> schedulerPluginMap) {
+ if (isLocal) {
+ Preconditions.checkState(descriptors.size() == 1 &&
+ descriptors.get(0).getEntityName().equals(TezConstants.getTezUberServicePluginName()));
+ } else {
+ boolean foundYarn = false;
+ for (int i = 0; i < descriptors.size(); i++) {
+ if (descriptors.get(i).getEntityName().equals(TezConstants.getTezYarnServicePluginName())) {
+ foundYarn = true;
+ }
+ }
+ if (!foundYarn) {
+ NamedEntityDescriptor yarnDescriptor =
+ new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
+ .setUserPayload(defaultPayload);
+ addDescriptor(descriptors, schedulerPluginMap, yarnDescriptor);
}
}
- return resultList;
}
String buildPluginComponentLog(List<NamedEntityDescriptor> namedEntityDescriptors, BiMap<String, Integer> map,
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
index 462befe..7d92988 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
@@ -27,7 +27,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Lists;
+import com.google.common.base.Preconditions;
import org.apache.commons.collections4.ListUtils;
import org.apache.tez.dag.api.NamedEntityDescriptor;
import org.apache.tez.dag.api.TezConstants;
@@ -103,35 +103,19 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
public TaskAttemptListenerImpTezDag(AppContext context,
TaskHeartbeatHandler thh, ContainerHeartbeatHandler chh,
- List<NamedEntityDescriptor> taskCommunicatorDescriptors,
- UserPayload defaultUserPayload,
- boolean isPureLocalMode) {
+ List<NamedEntityDescriptor> taskCommunicatorDescriptors) {
super(TaskAttemptListenerImpTezDag.class.getName());
this.context = context;
this.taskHeartbeatHandler = thh;
this.containerHeartbeatHandler = chh;
- if (taskCommunicatorDescriptors == null || taskCommunicatorDescriptors.isEmpty()) {
- if (isPureLocalMode) {
- taskCommunicatorDescriptors = Lists.newArrayList(new NamedEntityDescriptor(
- TezConstants.getTezUberServicePluginName(), null).setUserPayload(defaultUserPayload));
- } else {
- taskCommunicatorDescriptors = Lists.newArrayList(new NamedEntityDescriptor(
- TezConstants.getTezYarnServicePluginName(), null).setUserPayload(defaultUserPayload));
- }
- }
+ Preconditions.checkArgument(
+ taskCommunicatorDescriptors != null && !taskCommunicatorDescriptors.isEmpty(),
+ "TaskCommunicators must be specified");
this.taskCommunicators = new TaskCommunicator[taskCommunicatorDescriptors.size()];
this.taskCommunicatorContexts = new TaskCommunicatorContext[taskCommunicatorDescriptors.size()];
this.taskCommunicatorServiceWrappers = new ServicePluginLifecycleAbstractService[taskCommunicatorDescriptors.size()];
for (int i = 0 ; i < taskCommunicatorDescriptors.size() ; i++) {
- UserPayload userPayload;
- if (taskCommunicatorDescriptors.get(i).getEntityName()
- .equals(TezConstants.getTezYarnServicePluginName()) ||
- taskCommunicatorDescriptors.get(i).getEntityName()
- .equals(TezConstants.getTezUberServicePluginName())) {
- userPayload = defaultUserPayload;
- } else {
- userPayload = taskCommunicatorDescriptors.get(i).getUserPayload();
- }
+ UserPayload userPayload = taskCommunicatorDescriptors.get(i).getUserPayload();
taskCommunicatorContexts[i] = new TaskCommunicatorContextImpl(context, this, userPayload, i);
taskCommunicators[i] = createTaskCommunicator(taskCommunicatorDescriptors.get(i), i);
taskCommunicatorServiceWrappers[i] = new ServicePluginLifecycleAbstractService(taskCommunicators[i]);
@@ -155,36 +139,54 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
}
}
- private TaskCommunicator createTaskCommunicator(NamedEntityDescriptor taskCommDescriptor, int taskCommIndex) {
+ @VisibleForTesting
+ TaskCommunicator createTaskCommunicator(NamedEntityDescriptor taskCommDescriptor,
+ int taskCommIndex) {
if (taskCommDescriptor.getEntityName().equals(TezConstants.getTezYarnServicePluginName())) {
- LOG.info("Using Default Task Communicator");
- return createTezTaskCommunicator(taskCommunicatorContexts[taskCommIndex]);
- } else if (taskCommDescriptor.getEntityName().equals(TezConstants.getTezUberServicePluginName())) {
- LOG.info("Using Default Local Task Communicator");
- return new TezLocalTaskCommunicatorImpl(taskCommunicatorContexts[taskCommIndex]);
+ return createDefaultTaskCommunicator(taskCommunicatorContexts[taskCommIndex]);
+ } else if (taskCommDescriptor.getEntityName()
+ .equals(TezConstants.getTezUberServicePluginName())) {
+ return createUberTaskCommunicator(taskCommunicatorContexts[taskCommIndex]);
} else {
- LOG.info("Using TaskCommunicator {}:{} " + taskCommDescriptor.getEntityName(), taskCommDescriptor.getClassName());
- Class<? extends TaskCommunicator> taskCommClazz = (Class<? extends TaskCommunicator>) ReflectionUtils
- .getClazz(taskCommDescriptor.getClassName());
- try {
- Constructor<? extends TaskCommunicator> ctor = taskCommClazz.getConstructor(TaskCommunicatorContext.class);
- ctor.setAccessible(true);
- return ctor.newInstance(taskCommunicatorContexts[taskCommIndex]);
- } catch (NoSuchMethodException e) {
- throw new TezUncheckedException(e);
- } catch (InvocationTargetException e) {
- throw new TezUncheckedException(e);
- } catch (InstantiationException e) {
- throw new TezUncheckedException(e);
- } catch (IllegalAccessException e) {
- throw new TezUncheckedException(e);
- }
+ return createCustomTaskCommunicator(taskCommunicatorContexts[taskCommIndex],
+ taskCommDescriptor);
}
}
@VisibleForTesting
- protected TezTaskCommunicatorImpl createTezTaskCommunicator(TaskCommunicatorContext context) {
- return new TezTaskCommunicatorImpl(context);
+ TaskCommunicator createDefaultTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) {
+ LOG.info("Using Default Task Communicator");
+ return new TezTaskCommunicatorImpl(taskCommunicatorContext);
+ }
+
+ @VisibleForTesting
+ TaskCommunicator createUberTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) {
+ LOG.info("Using Default Local Task Communicator");
+ return new TezLocalTaskCommunicatorImpl(taskCommunicatorContext);
+ }
+
+ /**
+ * Instantiates a user-specified TaskCommunicator via reflection.
+ *
+ * The class named by the descriptor is loaded and its constructor taking a single
+ * {@link TaskCommunicatorContext} is invoked.
+ *
+ * @param taskCommunicatorContext the context handed to the plugin constructor
+ * @param taskCommDescriptor descriptor supplying the entity name and the class name to load
+ * @return the newly created TaskCommunicator
+ * @throws TezUncheckedException if the constructor is missing or instantiation fails
+ */
+ @VisibleForTesting
+ TaskCommunicator createCustomTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext,
+ NamedEntityDescriptor taskCommDescriptor) {
+ // Both values are passed as parameters so that each {} placeholder is actually substituted.
+ LOG.info("Using TaskCommunicator {}:{} ", taskCommDescriptor.getEntityName(),
+ taskCommDescriptor.getClassName());
+ Class<? extends TaskCommunicator> taskCommClazz =
+ (Class<? extends TaskCommunicator>) ReflectionUtils
+ .getClazz(taskCommDescriptor.getClassName());
+ try {
+ Constructor<? extends TaskCommunicator> ctor =
+ taskCommClazz.getConstructor(TaskCommunicatorContext.class);
+ ctor.setAccessible(true);
+ return ctor.newInstance(taskCommunicatorContext);
+ } catch (NoSuchMethodException e) {
+ throw new TezUncheckedException(e);
+ } catch (InvocationTargetException e) {
+ throw new TezUncheckedException(e);
+ } catch (InstantiationException e) {
+ throw new TezUncheckedException(e);
+ } catch (IllegalAccessException e) {
+ throw new TezUncheckedException(e);
+ }
}
public TaskHeartbeatResponse heartbeat(TaskHeartbeatRequest request)
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
index e7c209d..50a5377 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
@@ -231,9 +231,12 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex, EventHandl
private final boolean isSpeculationEnabled;
- private final int taskSchedulerIdentifier;
- private final int containerLauncherIdentifier;
- private final int taskCommunicatorIdentifier;
+ @VisibleForTesting
+ final int taskSchedulerIdentifier;
+ @VisibleForTesting
+ final int containerLauncherIdentifier;
+ @VisibleForTesting
+ final int taskCommunicatorIdentifier;
//fields initialized in init
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherRouter.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherRouter.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherRouter.java
index 2d56bfe..57b4aee 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherRouter.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherRouter.java
@@ -20,7 +20,7 @@ import java.net.UnknownHostException;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Lists;
+import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.service.AbstractService;
import org.apache.hadoop.yarn.event.EventHandler;
@@ -48,8 +48,10 @@ public class ContainerLauncherRouter extends AbstractService
static final Logger LOG = LoggerFactory.getLogger(ContainerLauncherImpl.class);
- private final ContainerLauncher containerLaunchers[];
- private final ContainerLauncherContext containerLauncherContexts[];
+ @VisibleForTesting
+ final ContainerLauncher containerLaunchers[];
+ @VisibleForTesting
+ final ContainerLauncherContext containerLauncherContexts[];
protected final ServicePluginLifecycleAbstractService[] containerLauncherServiceWrappers;
private final AppContext appContext;
@@ -64,7 +66,7 @@ public class ContainerLauncherRouter extends AbstractService
}
// Accepting conf to setup final parameters, if required.
- public ContainerLauncherRouter(UserPayload defaultUserPayload, AppContext context,
+ public ContainerLauncherRouter(AppContext context,
TaskAttemptListener taskAttemptListener,
String workingDirectory,
List<NamedEntityDescriptor> containerLauncherDescriptors,
@@ -72,79 +74,91 @@ public class ContainerLauncherRouter extends AbstractService
super(ContainerLauncherRouter.class.getName());
this.appContext = context;
- if (containerLauncherDescriptors == null || containerLauncherDescriptors.isEmpty()) {
- if (isPureLocalMode) {
- containerLauncherDescriptors = Lists.newArrayList(new NamedEntityDescriptor(
- TezConstants.getTezUberServicePluginName(), null).setUserPayload(defaultUserPayload));
- } else {
- containerLauncherDescriptors = Lists.newArrayList(new NamedEntityDescriptor(
- TezConstants.getTezYarnServicePluginName(), null).setUserPayload(defaultUserPayload));
- }
- }
+ Preconditions.checkArgument(
+ containerLauncherDescriptors != null && !containerLauncherDescriptors.isEmpty(),
+ "ContainerLauncherDescriptors must be specified");
containerLauncherContexts = new ContainerLauncherContext[containerLauncherDescriptors.size()];
containerLaunchers = new ContainerLauncher[containerLauncherDescriptors.size()];
containerLauncherServiceWrappers = new ServicePluginLifecycleAbstractService[containerLauncherDescriptors.size()];
for (int i = 0; i < containerLauncherDescriptors.size(); i++) {
- UserPayload userPayload;
- if (containerLauncherDescriptors.get(i).getEntityName()
- .equals(TezConstants.getTezYarnServicePluginName()) ||
- containerLauncherDescriptors.get(i).getEntityName()
- .equals(TezConstants.getTezUberServicePluginName())) {
- userPayload = defaultUserPayload;
- } else {
- userPayload = containerLauncherDescriptors.get(i).getUserPayload();
- }
+ UserPayload userPayload = containerLauncherDescriptors.get(i).getUserPayload();
ContainerLauncherContext containerLauncherContext =
new ContainerLauncherContextImpl(context, taskAttemptListener, userPayload);
containerLauncherContexts[i] = containerLauncherContext;
containerLaunchers[i] = createContainerLauncher(containerLauncherDescriptors.get(i), context,
- containerLauncherContext, taskAttemptListener, workingDirectory, isPureLocalMode);
+ containerLauncherContext, taskAttemptListener, workingDirectory, i, isPureLocalMode);
containerLauncherServiceWrappers[i] = new ServicePluginLifecycleAbstractService(containerLaunchers[i]);
}
}
- private ContainerLauncher createContainerLauncher(NamedEntityDescriptor containerLauncherDescriptor,
- AppContext context,
- ContainerLauncherContext containerLauncherContext,
- TaskAttemptListener taskAttemptListener,
- String workingDirectory,
- boolean isPureLocalMode) throws
+ @VisibleForTesting
+ ContainerLauncher createContainerLauncher(
+ NamedEntityDescriptor containerLauncherDescriptor,
+ AppContext context,
+ ContainerLauncherContext containerLauncherContext,
+ TaskAttemptListener taskAttemptListener,
+ String workingDirectory,
+ int containerLauncherIndex,
+ boolean isPureLocalMode) throws
UnknownHostException {
if (containerLauncherDescriptor.getEntityName().equals(
TezConstants.getTezYarnServicePluginName())) {
- LOG.info("Creating DefaultContainerLauncher");
- return new ContainerLauncherImpl(containerLauncherContext);
+ return createYarnContainerLauncher(containerLauncherContext);
} else if (containerLauncherDescriptor.getEntityName()
.equals(TezConstants.getTezUberServicePluginName())) {
- LOG.info("Creating LocalContainerLauncher");
- // TODO Post TEZ-2003. LocalContainerLauncher is special cased, since it makes use of
- // extensive internals which are only available at runtime. Will likely require
- // some kind of runtime binding of parameters in the payload to work correctly.
- return
- new LocalContainerLauncher(containerLauncherContext, context, taskAttemptListener, workingDirectory, isPureLocalMode);
+ return createUberContainerLauncher(containerLauncherContext, context, taskAttemptListener,
+ workingDirectory, isPureLocalMode);
} else {
- LOG.info("Creating container launcher {}:{} ", containerLauncherDescriptor.getEntityName(), containerLauncherDescriptor.getClassName());
- Class<? extends ContainerLauncher> containerLauncherClazz =
- (Class<? extends ContainerLauncher>) ReflectionUtils.getClazz(
- containerLauncherDescriptor.getClassName());
- try {
- Constructor<? extends ContainerLauncher> ctor = containerLauncherClazz
- .getConstructor(ContainerLauncherContext.class);
- ctor.setAccessible(true);
- return ctor.newInstance(containerLauncherContext);
- } catch (NoSuchMethodException e) {
- throw new TezUncheckedException(e);
- } catch (InvocationTargetException e) {
- throw new TezUncheckedException(e);
- } catch (InstantiationException e) {
- throw new TezUncheckedException(e);
- } catch (IllegalAccessException e) {
- throw new TezUncheckedException(e);
- }
+ return createCustomContainerLauncher(containerLauncherContext, containerLauncherDescriptor);
+ }
+ }
+
+ @VisibleForTesting
+ ContainerLauncher createYarnContainerLauncher(ContainerLauncherContext containerLauncherContext) {
+ LOG.info("Creating DefaultContainerLauncher");
+ return new ContainerLauncherImpl(containerLauncherContext);
+ }
+
+ @VisibleForTesting
+ ContainerLauncher createUberContainerLauncher(ContainerLauncherContext containerLauncherContext,
+ AppContext context,
+ TaskAttemptListener taskAttemptListener,
+ String workingDirectory,
+ boolean isPureLocalMode) throws
+ UnknownHostException {
+ LOG.info("Creating LocalContainerLauncher");
+ // TODO Post TEZ-2003. LocalContainerLauncher is special cased, since it makes use of
+ // extensive internals which are only available at runtime. Will likely require
+ // some kind of runtime binding of parameters in the payload to work correctly.
+ return
+ new LocalContainerLauncher(containerLauncherContext, context, taskAttemptListener,
+ workingDirectory, isPureLocalMode);
+ }
+
+ @VisibleForTesting
+ ContainerLauncher createCustomContainerLauncher(ContainerLauncherContext containerLauncherContext,
+ NamedEntityDescriptor containerLauncherDescriptor) {
+ LOG.info("Creating container launcher {}:{} ", containerLauncherDescriptor.getEntityName(),
+ containerLauncherDescriptor.getClassName());
+ Class<? extends ContainerLauncher> containerLauncherClazz =
+ (Class<? extends ContainerLauncher>) ReflectionUtils.getClazz(
+ containerLauncherDescriptor.getClassName());
+ try {
+ Constructor<? extends ContainerLauncher> ctor = containerLauncherClazz
+ .getConstructor(ContainerLauncherContext.class);
+ ctor.setAccessible(true);
+ return ctor.newInstance(containerLauncherContext);
+ } catch (NoSuchMethodException e) {
+ throw new TezUncheckedException(e);
+ } catch (InvocationTargetException e) {
+ throw new TezUncheckedException(e);
+ } catch (InstantiationException e) {
+ throw new TezUncheckedException(e);
+ } catch (IllegalAccessException e) {
+ throw new TezUncheckedException(e);
}
- // TODO TEZ-2118 Handle routing to multiple launchers
}
@Override
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerEventHandler.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerEventHandler.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerEventHandler.java
index c86f638..7c36232 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerEventHandler.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerEventHandler.java
@@ -22,7 +22,6 @@ import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
-import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
@@ -34,10 +33,8 @@ import java.util.concurrent.atomic.AtomicBoolean;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
-import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.tez.dag.api.NamedEntityDescriptor;
import org.apache.tez.dag.api.TezConstants;
-import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.app.ServicePluginLifecycleAbstractService;
import org.apache.tez.serviceplugins.api.TaskScheduler;
import org.apache.tez.serviceplugins.api.TaskSchedulerContext;
@@ -126,9 +123,8 @@ public class TaskSchedulerEventHandler extends AbstractService implements
private final boolean isPureLocalMode;
// If running in non local-only mode, the YARN task scheduler will always run to take care of
// registration with YARN and heartbeats to YARN.
- // Splitting registration and heartbeats is not straigh-forward due to the taskScheduler being
+ // Splitting registration and heartbeats is not straight-forward due to the taskScheduler being
// tied to a ContainerRequestType.
- private final int yarnTaskSchedulerIndex;
// Custom AppIds to avoid container conflicts if there's multiple sources
private final long SCHEDULER_APP_ID_BASE = 111101111;
private final long SCHEDULER_APP_ID_INCREMENT = 111111111;
@@ -153,9 +149,10 @@ public class TaskSchedulerEventHandler extends AbstractService implements
public TaskSchedulerEventHandler(AppContext appContext,
DAGClientServer clientService, EventHandler eventHandler,
ContainerSignatureMatcher containerSignatureMatcher, WebUIService webUI,
- List<NamedEntityDescriptor> schedulerDescriptors, UserPayload defaultPayload,
- boolean isPureLocalMode) {
+ List<NamedEntityDescriptor> schedulerDescriptors, boolean isPureLocalMode) {
super(TaskSchedulerEventHandler.class.getName());
+ Preconditions.checkArgument(schedulerDescriptors != null && !schedulerDescriptors.isEmpty(),
+ "TaskSchedulerDescriptors must be specified");
this.appContext = appContext;
this.eventHandler = eventHandler;
this.clientService = clientService;
@@ -168,50 +165,8 @@ public class TaskSchedulerEventHandler extends AbstractService implements
this.webUI.setHistoryUrl(this.historyUrl);
}
- // Override everything for pure local mode
- if (isPureLocalMode) {
- this.taskSchedulerDescriptors = new NamedEntityDescriptor[]{
- new NamedEntityDescriptor(TezConstants.getTezUberServicePluginName(), null)
- .setUserPayload(defaultPayload)};
- this.yarnTaskSchedulerIndex = -1;
- } else {
- if (schedulerDescriptors == null || schedulerDescriptors.isEmpty()) {
- this.taskSchedulerDescriptors = new NamedEntityDescriptor[]{
- new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
- .setUserPayload(defaultPayload)};
- this.yarnTaskSchedulerIndex = 0;
- } else {
- // Ensure the YarnScheduler will be setup and note it's index. This will be responsible for heartbeats and YARN registration.
- int foundYarnTaskSchedulerIndex = -1;
-
- List<NamedEntityDescriptor> schedulerDescriptorList = new LinkedList<>();
- for (int i = 0 ; i < schedulerDescriptors.size() ; i++) {
- if (schedulerDescriptors.get(i).getEntityName().equals(
- TezConstants.getTezYarnServicePluginName())) {
- schedulerDescriptorList.add(
- new NamedEntityDescriptor(schedulerDescriptors.get(i).getEntityName(), null)
- .setUserPayload(
- defaultPayload));
- foundYarnTaskSchedulerIndex = i;
- } else if (schedulerDescriptors.get(i).getEntityName().equals(
- TezConstants.getTezUberServicePluginName())) {
- schedulerDescriptorList.add(
- new NamedEntityDescriptor(schedulerDescriptors.get(i).getEntityName(), null)
- .setUserPayload(
- defaultPayload));
- } else {
- schedulerDescriptorList.add(schedulerDescriptors.get(i));
- }
- }
- if (foundYarnTaskSchedulerIndex == -1) {
- schedulerDescriptorList.add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null).setUserPayload(
- defaultPayload));
- foundYarnTaskSchedulerIndex = schedulerDescriptorList.size() -1;
- }
- this.taskSchedulerDescriptors = schedulerDescriptorList.toArray(new NamedEntityDescriptor[schedulerDescriptorList.size()]);
- this.yarnTaskSchedulerIndex = foundYarnTaskSchedulerIndex;
- }
- }
+ this.taskSchedulerDescriptors = schedulerDescriptors.toArray(new NamedEntityDescriptor[schedulerDescriptors.size()]);
+
taskSchedulers = new TaskScheduler[this.taskSchedulerDescriptors.length];
taskSchedulerServiceWrappers = new ServicePluginLifecycleAbstractService[this.taskSchedulerDescriptors.length];
}
@@ -239,7 +194,8 @@ public class TaskSchedulerEventHandler extends AbstractService implements
private ExecutorService createAppCallbackExecutorService() {
return Executors.newSingleThreadExecutor(
- new ThreadFactoryBuilder().setNameFormat("TaskSchedulerAppCallbackExecutor #%d").setDaemon(true)
+ new ThreadFactoryBuilder().setNameFormat("TaskSchedulerAppCallbackExecutor #%d")
+ .setDaemon(true)
.build());
}
@@ -428,7 +384,8 @@ public class TaskSchedulerEventHandler extends AbstractService implements
event);
}
- private TaskScheduler createTaskScheduler(String host, int port, String trackingUrl,
+ @VisibleForTesting
+ TaskScheduler createTaskScheduler(String host, int port, String trackingUrl,
AppContext appContext,
NamedEntityDescriptor taskSchedulerDescriptor,
long customAppIdIdentifier,
@@ -436,32 +393,57 @@ public class TaskSchedulerEventHandler extends AbstractService implements
TaskSchedulerContext rawContext =
new TaskSchedulerContextImpl(this, appContext, schedulerId, trackingUrl,
customAppIdIdentifier, host, port, taskSchedulerDescriptor.getUserPayload());
- TaskSchedulerContext wrappedContext = new TaskSchedulerContextImplWrapper(rawContext, appCallbackExecutor);
+ TaskSchedulerContext wrappedContext = wrapTaskSchedulerContext(rawContext);
String schedulerName = taskSchedulerDescriptor.getEntityName();
if (schedulerName.equals(TezConstants.getTezYarnServicePluginName())) {
- LOG.info("Creating TaskScheduler: YarnTaskSchedulerService");
- return new YarnTaskSchedulerService(wrappedContext);
+ return createYarnTaskScheduler(wrappedContext, schedulerId);
} else if (schedulerName.equals(TezConstants.getTezUberServicePluginName())) {
- LOG.info("Creating TaskScheduler: Local TaskScheduler");
- return new LocalTaskSchedulerService(wrappedContext);
+ return createUberTaskScheduler(wrappedContext, schedulerId);
} else {
- LOG.info("Creating custom TaskScheduler {}:{}", taskSchedulerDescriptor.getEntityName(), taskSchedulerDescriptor.getClassName());
- Class<? extends TaskScheduler> taskSchedulerClazz =
- (Class<? extends TaskScheduler>) ReflectionUtils.getClazz(taskSchedulerDescriptor.getClassName());
- try {
- Constructor<? extends TaskScheduler> ctor = taskSchedulerClazz
- .getConstructor(TaskSchedulerContext.class);
- ctor.setAccessible(true);
- return ctor.newInstance(wrappedContext);
- } catch (NoSuchMethodException e) {
- throw new TezUncheckedException(e);
- } catch (InvocationTargetException e) {
- throw new TezUncheckedException(e);
- } catch (InstantiationException e) {
- throw new TezUncheckedException(e);
- } catch (IllegalAccessException e) {
- throw new TezUncheckedException(e);
- }
+ return createCustomTaskScheduler(wrappedContext, taskSchedulerDescriptor, schedulerId);
+ }
+ }
+
+ @VisibleForTesting
+ TaskSchedulerContext wrapTaskSchedulerContext(TaskSchedulerContext rawContext) {
+ return new TaskSchedulerContextImplWrapper(rawContext, appCallbackExecutor);
+ }
+
+ @VisibleForTesting
+ TaskScheduler createYarnTaskScheduler(TaskSchedulerContext taskSchedulerContext,
+ int schedulerId) {
+ LOG.info("Creating TaskScheduler: YarnTaskSchedulerService");
+ return new YarnTaskSchedulerService(taskSchedulerContext);
+ }
+
+ @VisibleForTesting
+ TaskScheduler createUberTaskScheduler(TaskSchedulerContext taskSchedulerContext,
+ int schedulerId) {
+ LOG.info("Creating TaskScheduler: Local TaskScheduler");
+ return new LocalTaskSchedulerService(taskSchedulerContext);
+ }
+
+ /**
+ * Instantiates a user-specified TaskScheduler via reflection.
+ *
+ * The class named by the descriptor is loaded and its constructor taking a single
+ * {@link TaskSchedulerContext} is invoked.
+ *
+ * @param taskSchedulerContext the context handed to the plugin constructor
+ * @param taskSchedulerDescriptor descriptor supplying the entity name and the class name to load
+ * @param schedulerId identifier of the scheduler; not read by this method body
+ * @return the newly created TaskScheduler
+ * @throws TezUncheckedException if the constructor is missing or instantiation fails
+ */
+ @VisibleForTesting
+ TaskScheduler createCustomTaskScheduler(TaskSchedulerContext taskSchedulerContext,
+ NamedEntityDescriptor taskSchedulerDescriptor,
+ int schedulerId) {
+ LOG.info("Creating custom TaskScheduler {}:{}", taskSchedulerDescriptor.getEntityName(),
+ taskSchedulerDescriptor.getClassName());
+ Class<? extends TaskScheduler> taskSchedulerClazz =
+ (Class<? extends TaskScheduler>) ReflectionUtils
+ .getClazz(taskSchedulerDescriptor.getClassName());
+ try {
+ Constructor<? extends TaskScheduler> ctor = taskSchedulerClazz
+ .getConstructor(TaskSchedulerContext.class);
+ ctor.setAccessible(true);
+ return ctor.newInstance(taskSchedulerContext);
+ } catch (NoSuchMethodException e) {
+ throw new TezUncheckedException(e);
+ } catch (InvocationTargetException e) {
+ throw new TezUncheckedException(e);
+ } catch (InstantiationException e) {
+ throw new TezUncheckedException(e);
+ } catch (IllegalAccessException e) {
+ throw new TezUncheckedException(e);
}
}
@@ -797,9 +779,4 @@ public class TaskSchedulerEventHandler extends AbstractService implements
return historyUrl;
}
- @VisibleForTesting
- @InterfaceAudience.Private
- ExecutorService getContextExecutorService() {
- return appCallbackExecutor;
- }
}
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java b/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
index 99406dd..2770182 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
@@ -509,8 +509,7 @@ public class MockDAGAppMaster extends DAGAppMaster {
// use mock container launcher for tests
@Override
- protected ContainerLauncherRouter createContainerLauncherRouter(final UserPayload defaultUserPayload,
- List<NamedEntityDescriptor> containerLauncherDescirptors,
+ protected ContainerLauncherRouter createContainerLauncherRouter(List<NamedEntityDescriptor> containerLauncherDescirptors,
boolean isLocal)
throws UnknownHostException {
return new ContainerLauncherRouter(containerLauncher, getContext());
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java
new file mode 100644
index 0000000..fa5d87c
--- /dev/null
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.LinkedList;
+import java.util.List;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.BiMap;
+import com.google.common.collect.HashBiMap;
+import com.google.common.collect.Lists;
+import com.google.protobuf.ByteString;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.NamedEntityDescriptor;
+import org.apache.tez.dag.api.TezConstants;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.api.records.DAGProtos;
+import org.apache.tez.dag.api.records.DAGProtos.AMPluginDescriptorProto;
+import org.apache.tez.dag.api.records.DAGProtos.TezNamedEntityDescriptorProto;
+import org.apache.tez.dag.api.records.DAGProtos.TezUserPayloadProto;
+import org.junit.Test;
+
+public class TestDAGAppMaster {
+
+ private static final String TEST_KEY = "TEST_KEY";
+ private static final String TEST_VAL = "TEST_VAL";
+ private static final String TS_NAME = "TS";
+ private static final String CL_NAME = "CL";
+ private static final String TC_NAME = "TC";
+ private static final String CLASS_SUFFIX = "_CLASS";
+
+ @Test(timeout = 5000)
+ public void testPluginParsing() throws IOException {
+ BiMap<String, Integer> pluginMap = HashBiMap.create();
+ Configuration conf = new Configuration(false);
+ conf.set("testkey", "testval");
+ UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
+
+ List<TezNamedEntityDescriptorProto> entityDescriptors = new LinkedList<>();
+ List<NamedEntityDescriptor> entities;
+
+ // Test empty descriptor list, yarn enabled: expect only the Yarn plugin, at index 0, with the default payload
+ pluginMap.clear();
+ entities = new LinkedList<>();
+ DAGAppMaster.parsePlugin(entities, pluginMap, null, true, false, defaultPayload);
+ assertEquals(1, pluginMap.size());
+ assertEquals(1, entities.size());
+ assertTrue(pluginMap.containsKey(TezConstants.getTezYarnServicePluginName()));
+ assertTrue(0 == pluginMap.get(TezConstants.getTezYarnServicePluginName()));
+ assertEquals("testval",
+ TezUtils.createConfFromUserPayload(entities.get(0).getUserPayload()).get("testkey"));
+
+ // Test empty descriptor list, uber enabled: expect only the uber plugin, at index 0, with the default payload
+ pluginMap.clear();
+ entities = new LinkedList<>();
+ DAGAppMaster.parsePlugin(entities, pluginMap, null, false, true, defaultPayload);
+ assertEquals(1, pluginMap.size());
+ assertEquals(1, entities.size());
+ assertTrue(pluginMap.containsKey(TezConstants.getTezUberServicePluginName()));
+ assertTrue(0 == pluginMap.get(TezConstants.getTezUberServicePluginName()));
+ assertEquals("testval",
+ TezUtils.createConfFromUserPayload(entities.get(0).getUserPayload()).get("testkey"));
+
+ // Test empty descriptor list, yarn enabled, uber enabled: expect Yarn at index 0 and uber at index 1
+ pluginMap.clear();
+ entities = new LinkedList<>();
+ DAGAppMaster.parsePlugin(entities, pluginMap, null, true, true, defaultPayload);
+ assertEquals(2, pluginMap.size());
+ assertEquals(2, entities.size());
+ assertTrue(pluginMap.containsKey(TezConstants.getTezYarnServicePluginName()));
+ assertTrue(0 == pluginMap.get(TezConstants.getTezYarnServicePluginName()));
+ assertTrue(pluginMap.containsKey(TezConstants.getTezUberServicePluginName()));
+ assertTrue(1 == pluginMap.get(TezConstants.getTezUberServicePluginName()));
+
+ // Build a custom descriptor "d1" (class "d1Class", 4-byte payload) used by the remaining scenarios.
+ String pluginName = "d1";
+ ByteBuffer bb = ByteBuffer.allocate(4);
+ bb.putInt(0, 3);
+ TezNamedEntityDescriptorProto d1 =
+ TezNamedEntityDescriptorProto.newBuilder().setName(pluginName).setEntityDescriptor(
+ DAGProtos.TezEntityDescriptorProto.newBuilder().setClassName("d1Class")
+ .setTezUserPayload(
+ TezUserPayloadProto.newBuilder()
+ .setUserPayload(ByteString.copyFrom(bb)))).build();
+ entityDescriptors.add(d1);
+
+ // Test descriptor, no yarn, no uber: expect only the custom plugin, at index 0
+ pluginMap.clear();
+ entities = new LinkedList<>();
+ DAGAppMaster.parsePlugin(entities, pluginMap, entityDescriptors, false, false, defaultPayload);
+ assertEquals(1, pluginMap.size());
+ assertEquals(1, entities.size());
+ assertTrue(pluginMap.containsKey(pluginName));
+ assertTrue(0 == pluginMap.get(pluginName));
+
+ // Test descriptor, yarn and uber: expect Yarn (0), uber (1), then the custom plugin (2)
+ pluginMap.clear();
+ entities = new LinkedList<>();
+ DAGAppMaster.parsePlugin(entities, pluginMap, entityDescriptors, true, true, defaultPayload);
+ assertEquals(3, pluginMap.size());
+ assertEquals(3, entities.size());
+ assertTrue(pluginMap.containsKey(TezConstants.getTezYarnServicePluginName()));
+ assertTrue(0 == pluginMap.get(TezConstants.getTezYarnServicePluginName()));
+ assertTrue(pluginMap.containsKey(TezConstants.getTezUberServicePluginName()));
+ assertTrue(1 == pluginMap.get(TezConstants.getTezUberServicePluginName()));
+ assertTrue(pluginMap.containsKey(pluginName));
+ assertTrue(2 == pluginMap.get(pluginName));
+ entityDescriptors.clear();
+ }
+
+
+ @Test(timeout = 5000)
+ public void testParseAllPluginsNoneSpecified() throws IOException {
+ Configuration conf = new Configuration(false);
+ conf.set(TEST_KEY, TEST_VAL);
+ UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
+
+ List<NamedEntityDescriptor> tsDescriptors;
+ BiMap<String, Integer> tsMap;
+ List<NamedEntityDescriptor> clDescriptors;
+ BiMap<String, Integer> clMap;
+ List<NamedEntityDescriptor> tcDescriptors;
+ BiMap<String, Integer> tcMap;
+
+ // Fresh lists and maps are created before each scenario so results cannot leak between them.
+ // No plugins. Non local: every plugin type is expected to hold just the Yarn plugin with the default payload
+ tsDescriptors = Lists.newLinkedList();
+ tsMap = HashBiMap.create();
+ clDescriptors = Lists.newLinkedList();
+ clMap = HashBiMap.create();
+ tcDescriptors = Lists.newLinkedList();
+ tcMap = HashBiMap.create();
+ DAGAppMaster.parseAllPlugins(tsDescriptors, tsMap, clDescriptors, clMap, tcDescriptors, tcMap,
+ null, false, defaultPayload);
+ verifyDescAndMap(tsDescriptors, tsMap, 1, true, TezConstants.getTezYarnServicePluginName());
+ verifyDescAndMap(clDescriptors, clMap, 1, true, TezConstants.getTezYarnServicePluginName());
+ verifyDescAndMap(tcDescriptors, tcMap, 1, true, TezConstants.getTezYarnServicePluginName());
+
+ // No plugins. Local: every plugin type is expected to hold just the uber plugin with the default payload
+ tsDescriptors = Lists.newLinkedList();
+ tsMap = HashBiMap.create();
+ clDescriptors = Lists.newLinkedList();
+ clMap = HashBiMap.create();
+ tcDescriptors = Lists.newLinkedList();
+ tcMap = HashBiMap.create();
+ DAGAppMaster.parseAllPlugins(tsDescriptors, tsMap, clDescriptors, clMap, tcDescriptors, tcMap,
+ null, true, defaultPayload);
+ verifyDescAndMap(tsDescriptors, tsMap, 1, true, TezConstants.getTezUberServicePluginName());
+ verifyDescAndMap(clDescriptors, clMap, 1, true, TezConstants.getTezUberServicePluginName());
+ verifyDescAndMap(tcDescriptors, tcMap, 1, true, TezConstants.getTezUberServicePluginName());
+ }
+
+ @Test(timeout = 5000)
+ public void testParseAllPluginsOnlyCustomSpecified() throws IOException {
+ Configuration conf = new Configuration(false);
+ conf.set(TEST_KEY, TEST_VAL);
+ UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
+ TezUserPayloadProto payloadProto = TezUserPayloadProto.newBuilder()
+ .setUserPayload(ByteString.copyFrom(defaultPayload.getPayload())).build();
+
+ AMPluginDescriptorProto proto = createAmPluginDescriptor(false, false, true, payloadProto);
+
+ List<NamedEntityDescriptor> tsDescriptors;
+ BiMap<String, Integer> tsMap;
+ List<NamedEntityDescriptor> clDescriptors;
+ BiMap<String, Integer> clMap;
+ List<NamedEntityDescriptor> tcDescriptors;
+ BiMap<String, Integer> tcMap;
+
+ // The proto enables neither containers nor uber; it only carries the custom TS/CL/TC descriptors.
+ // Only custom plugins, non-local. The scheduler list is still expected to end with the Yarn plugin.
+ tsDescriptors = Lists.newLinkedList();
+ tsMap = HashBiMap.create();
+ clDescriptors = Lists.newLinkedList();
+ clMap = HashBiMap.create();
+ tcDescriptors = Lists.newLinkedList();
+ tcMap = HashBiMap.create();
+ DAGAppMaster.parseAllPlugins(tsDescriptors, tsMap, clDescriptors, clMap, tcDescriptors, tcMap,
+ proto, false, defaultPayload);
+ verifyDescAndMap(tsDescriptors, tsMap, 2, true, TS_NAME,
+ TezConstants.getTezYarnServicePluginName());
+ verifyDescAndMap(clDescriptors, clMap, 1, true, CL_NAME);
+ verifyDescAndMap(tcDescriptors, tcMap, 1, true, TC_NAME);
+ assertEquals(TS_NAME + CLASS_SUFFIX, tsDescriptors.get(0).getClassName());
+ assertEquals(CL_NAME + CLASS_SUFFIX, clDescriptors.get(0).getClassName());
+ assertEquals(TC_NAME + CLASS_SUFFIX, tcDescriptors.get(0).getClassName());
+ }
+
+ @Test(timeout = 5000)
+ public void testParseAllPluginsCustomAndYarnSpecified() throws IOException {
+ Configuration conf = new Configuration(false);
+ conf.set(TEST_KEY, TEST_VAL);
+ UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
+ TezUserPayloadProto payloadProto = TezUserPayloadProto.newBuilder()
+ .setUserPayload(ByteString.copyFrom(defaultPayload.getPayload())).build();
+
+ AMPluginDescriptorProto proto = createAmPluginDescriptor(true, false, true, payloadProto);
+
+ List<NamedEntityDescriptor> tsDescriptors;
+ BiMap<String, Integer> tsMap;
+ List<NamedEntityDescriptor> clDescriptors;
+ BiMap<String, Integer> clMap;
+ List<NamedEntityDescriptor> tcDescriptors;
+ BiMap<String, Integer> tcMap;
+
+ // The proto enables containers (Yarn) and also carries the custom TS/CL/TC descriptors.
+ // Custom plugins plus Yarn, non-local: Yarn is expected at index 0 (null class name), custom at index 1.
+ tsDescriptors = Lists.newLinkedList();
+ tsMap = HashBiMap.create();
+ clDescriptors = Lists.newLinkedList();
+ clMap = HashBiMap.create();
+ tcDescriptors = Lists.newLinkedList();
+ tcMap = HashBiMap.create();
+ DAGAppMaster.parseAllPlugins(tsDescriptors, tsMap, clDescriptors, clMap, tcDescriptors, tcMap,
+ proto, false, defaultPayload);
+ verifyDescAndMap(tsDescriptors, tsMap, 2, true, TezConstants.getTezYarnServicePluginName(),
+ TS_NAME);
+ verifyDescAndMap(clDescriptors, clMap, 2, true, TezConstants.getTezYarnServicePluginName(),
+ CL_NAME);
+ verifyDescAndMap(tcDescriptors, tcMap, 2, true, TezConstants.getTezYarnServicePluginName(),
+ TC_NAME);
+ assertNull(tsDescriptors.get(0).getClassName());
+ assertNull(clDescriptors.get(0).getClassName());
+ assertNull(tcDescriptors.get(0).getClassName());
+ assertEquals(TS_NAME + CLASS_SUFFIX, tsDescriptors.get(1).getClassName());
+ assertEquals(CL_NAME + CLASS_SUFFIX, clDescriptors.get(1).getClassName());
+ assertEquals(TC_NAME + CLASS_SUFFIX, tcDescriptors.get(1).getClassName());
+ }
+
+ // Checks that descriptors and map list exactly expectedNames, in order (plus each payload when verifyPayload).
+ private void verifyDescAndMap(List<NamedEntityDescriptor> descriptors, BiMap<String, Integer> map,
+ int numExpected, boolean verifyPayload,
+ String... expectedNames) throws IOException {
+ Preconditions.checkArgument(expectedNames.length == numExpected);
+ assertEquals(numExpected, descriptors.size());
+ assertEquals(numExpected, map.size());
+ for (int i = 0; i < numExpected; i++) {
+ assertEquals(expectedNames[i], descriptors.get(i).getEntityName());
+ if (verifyPayload) {
+ assertEquals(TEST_VAL, // check the i-th descriptor's payload, not only the first one
+ TezUtils.createConfFromUserPayload(descriptors.get(i).getUserPayload()).get(TEST_KEY));
+ }
+ assertEquals(Integer.valueOf(i), map.get(expectedNames[i])); // fails cleanly (no NPE) if the name is missing
+ assertEquals(expectedNames[i], map.inverse().get(i)); // value equality; == on Strings compares references
+ }
+ }
+
+ private AMPluginDescriptorProto createAmPluginDescriptor(boolean enableYarn, boolean enableUber,
+ boolean addCustom,
+ TezUserPayloadProto payloadProto) {
+ AMPluginDescriptorProto.Builder builder = AMPluginDescriptorProto.newBuilder() // test fixture builder
+ .setUberEnabled(enableUber)
+ .setContainersEnabled(enableYarn);
+ if (addCustom) { // add one custom scheduler (TS), launcher (CL) and communicator (TC), all sharing payloadProto
+ builder.addTaskSchedulers(
+ TezNamedEntityDescriptorProto.newBuilder()
+ .setName(TS_NAME)
+ .setEntityDescriptor(
+ DAGProtos.TezEntityDescriptorProto.newBuilder()
+ .setClassName(TS_NAME + CLASS_SUFFIX)
+ .setTezUserPayload(payloadProto)))
+ .addContainerLaunchers(
+ TezNamedEntityDescriptorProto.newBuilder()
+ .setName(CL_NAME)
+ .setEntityDescriptor(
+ DAGProtos.TezEntityDescriptorProto.newBuilder()
+ .setClassName(CL_NAME + CLASS_SUFFIX)
+ .setTezUserPayload(payloadProto)))
+ .addTaskCommunicators(
+ TezNamedEntityDescriptorProto.newBuilder()
+ .setName(TC_NAME)
+ .setEntityDescriptor(
+ DAGProtos.TezEntityDescriptorProto.newBuilder()
+ .setClassName(TC_NAME + CLASS_SUFFIX)
+ .setTezUserPayload(payloadProto)));
+ }
+ return builder.build();
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
index 59e486d..982790b 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
@@ -34,6 +34,7 @@ import java.util.List;
import java.util.Map;
import java.util.Random;
+import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.security.Credentials;
@@ -52,7 +53,9 @@ import org.apache.tez.common.security.JobTokenIdentifier;
import org.apache.tez.common.security.JobTokenSecretManager;
import org.apache.tez.common.security.TokenCache;
import org.apache.tez.dag.api.NamedEntityDescriptor;
+import org.apache.tez.dag.api.TaskCommunicator;
import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezConstants;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.serviceplugins.api.ContainerEndReason;
@@ -143,7 +146,10 @@ public class TestTaskAttemptListenerImplTezDag {
throw new TezUncheckedException(e);
}
taskAttemptListener = new TaskAttemptListenerImplForTest(appContext,
- mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), null, defaultPayload, false);
+ mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class),
+ Lists.newArrayList(
+ new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
+ .setUserPayload(defaultPayload)));
taskSpec = mock(TaskSpec.class);
doReturn(taskAttemptID).when(taskSpec).getTaskAttemptID();
@@ -296,7 +302,7 @@ public class TestTaskAttemptListenerImplTezDag {
// TODO TEZ-2003 Move this into TestTezTaskCommunicator. Potentially other tests as well.
@Test (timeout= 5000)
- public void testPortRange_NotSpecified() {
+ public void testPortRange_NotSpecified() throws IOException {
Configuration conf = new Configuration();
JobTokenIdentifier identifier = new JobTokenIdentifier(new Text(
"fakeIdentifier"));
@@ -304,14 +310,11 @@ public class TestTaskAttemptListenerImplTezDag {
new JobTokenSecretManager());
sessionToken.setService(identifier.getJobId());
TokenCache.setSessionToken(sessionToken, credentials);
- UserPayload userPayload = null;
- try {
- userPayload = TezUtils.createUserPayloadFromConf(conf);
- } catch (IOException e) {
- throw new TezUncheckedException(e);
- }
+ UserPayload userPayload = TezUtils.createUserPayloadFromConf(conf);
taskAttemptListener = new TaskAttemptListenerImpTezDag(appContext,
- mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), null, userPayload, false);
+ mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), Lists.newArrayList(
+ new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
+ .setUserPayload(userPayload)));
// no exception happen, should started properly
taskAttemptListener.init(conf);
taskAttemptListener.start();
@@ -330,14 +333,12 @@ public class TestTaskAttemptListenerImplTezDag {
TokenCache.setSessionToken(sessionToken, credentials);
conf.set(TezConfiguration.TEZ_AM_TASK_AM_PORT_RANGE, port + "-" + port);
- UserPayload userPayload = null;
- try {
- userPayload = TezUtils.createUserPayloadFromConf(conf);
- } catch (IOException e) {
- throw new TezUncheckedException(e);
- }
+ UserPayload userPayload = TezUtils.createUserPayloadFromConf(conf);
+
taskAttemptListener = new TaskAttemptListenerImpTezDag(appContext,
- mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), null, userPayload, false);
+ mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), Lists
+ .newArrayList(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
+ .setUserPayload(userPayload)));
taskAttemptListener.init(conf);
taskAttemptListener.start();
int resultedPort = taskAttemptListener.getTaskCommunicator(0).getAddress().getPort();
@@ -393,16 +394,13 @@ public class TestTaskAttemptListenerImplTezDag {
public TaskAttemptListenerImplForTest(AppContext context,
TaskHeartbeatHandler thh,
ContainerHeartbeatHandler chh,
- List<NamedEntityDescriptor> taskCommDescriptors,
- UserPayload userPayload,
- boolean isPureLocalMode) {
- super(context, thh, chh, taskCommDescriptors, userPayload,
- isPureLocalMode);
+ List<NamedEntityDescriptor> taskCommDescriptors) {
+ super(context, thh, chh, taskCommDescriptors);
}
@Override
- protected TezTaskCommunicatorImpl createTezTaskCommunicator(TaskCommunicatorContext context) {
- return new TezTaskCommunicatorImplForTest(context);
+ TaskCommunicator createDefaultTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) {
+ return new TezTaskCommunicatorImplForTest(taskCommunicatorContext);
}
}
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag2.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag2.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag2.java
index 1c82bd8..abb5e42 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag2.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag2.java
@@ -26,6 +26,7 @@ import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
+import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.yarn.api.records.ApplicationAccessType;
@@ -37,7 +38,9 @@ import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.event.Event;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.NamedEntityDescriptor;
import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezConstants;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
@@ -83,7 +86,8 @@ public class TestTaskAttemptListenerImplTezDag2 {
UserPayload userPayload = TezUtils.createUserPayloadFromConf(conf);
TaskAttemptListenerImpTezDag taskAttemptListener =
new TaskAttemptListenerImpTezDag(appContext, mock(TaskHeartbeatHandler.class),
- mock(ContainerHeartbeatHandler.class), null, userPayload, false);
+ mock(ContainerHeartbeatHandler.class), Lists.newArrayList(new NamedEntityDescriptor(
+ TezConstants.getTezYarnServicePluginName(), null).setUserPayload(userPayload)));
TaskSpec taskSpec1 = mock(TaskSpec.class);
TezTaskAttemptID taskAttemptId1 = mock(TezTaskAttemptID.class);
http://git-wip-us.apache.org/repos/asf/tez/blob/bd6fcf95/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskCommunicatorManager.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskCommunicatorManager.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskCommunicatorManager.java
new file mode 100644
index 0000000..c76aa50
--- /dev/null
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskCommunicatorManager.java
@@ -0,0 +1,369 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.api.records.NodeId;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.NamedEntityDescriptor;
+import org.apache.tez.dag.api.TaskCommunicator;
+import org.apache.tez.dag.api.TaskCommunicatorContext;
+import org.apache.tez.dag.api.TezConstants;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.api.event.VertexStateUpdate;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.impl.TaskSpec;
+import org.apache.tez.serviceplugins.api.ContainerEndReason;
+import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestTaskCommunicatorManager {
+
+ @Before
+ @After
+ public void reset() {
+ TaskCommManagerForMultipleCommTest.reset(); // clear the static bookkeeping before and after every test
+ }
+
+ @Test(timeout = 5000)
+ public void testNoTaskCommSpecified() throws IOException {
+
+ AppContext appContext = mock(AppContext.class);
+ TaskHeartbeatHandler thh = mock(TaskHeartbeatHandler.class);
+ ContainerHeartbeatHandler chh = mock(ContainerHeartbeatHandler.class);
+
+ try {
+ new TaskCommManagerForMultipleCommTest(appContext, thh, chh, null);
+ fail("Initialization should have failed without a TaskComm specified");
+ } catch (IllegalArgumentException e) {
+ // Expected: a null descriptor list must be rejected by the constructor; nothing else to verify.
+ }
+
+
+ }
+
+ @Test(timeout = 5000)
+ public void testCustomTaskCommSpecified() throws IOException {
+ // A single custom descriptor should yield exactly one task communicator and no Yarn/uber default.
+ AppContext appContext = mock(AppContext.class);
+ TaskHeartbeatHandler thh = mock(TaskHeartbeatHandler.class);
+ ContainerHeartbeatHandler chh = mock(ContainerHeartbeatHandler.class);
+
+ String customTaskCommName = "customTaskComm";
+ List<NamedEntityDescriptor> taskCommDescriptors = new LinkedList<>();
+ ByteBuffer bb = ByteBuffer.allocate(4);
+ bb.putInt(0, 3);
+ UserPayload customPayload = UserPayload.create(bb);
+ taskCommDescriptors.add(
+ new NamedEntityDescriptor(customTaskCommName, FakeTaskComm.class.getName())
+ .setUserPayload(customPayload));
+
+ TaskCommManagerForMultipleCommTest tcm =
+ new TaskCommManagerForMultipleCommTest(appContext, thh, chh, taskCommDescriptors);
+
+ try {
+ tcm.init(new Configuration(false));
+ tcm.start();
+
+ assertEquals(1, tcm.getNumTaskComms());
+ assertFalse(tcm.getYarnTaskCommCreated());
+ assertFalse(tcm.getUberTaskCommCreated());
+
+ assertEquals(customTaskCommName, tcm.getTaskCommName(0));
+ assertEquals(bb, tcm.getTaskCommContext(0).getInitialUserPayload().getPayload());
+
+ } finally {
+ tcm.stop();
+ }
+ }
+
+ @Test(timeout = 5000)
+ public void testMultipleTaskComms() throws IOException {
+ // A custom plus a Yarn descriptor should yield two communicators, in descriptor order, each with its own payload.
+ AppContext appContext = mock(AppContext.class);
+ TaskHeartbeatHandler thh = mock(TaskHeartbeatHandler.class);
+ ContainerHeartbeatHandler chh = mock(ContainerHeartbeatHandler.class);
+ Configuration conf = new Configuration(false);
+ conf.set("testkey", "testvalue");
+ UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
+
+ String customTaskCommName = "customTaskComm";
+ List<NamedEntityDescriptor> taskCommDescriptors = new LinkedList<>();
+ ByteBuffer bb = ByteBuffer.allocate(4);
+ bb.putInt(0, 3);
+ UserPayload customPayload = UserPayload.create(bb);
+ taskCommDescriptors.add(
+ new NamedEntityDescriptor(customTaskCommName, FakeTaskComm.class.getName())
+ .setUserPayload(customPayload));
+ taskCommDescriptors
+ .add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null).setUserPayload(defaultPayload));
+
+ TaskCommManagerForMultipleCommTest tcm =
+ new TaskCommManagerForMultipleCommTest(appContext, thh, chh, taskCommDescriptors);
+
+ try {
+ tcm.init(new Configuration(false));
+ tcm.start();
+
+ assertEquals(2, tcm.getNumTaskComms());
+ assertTrue(tcm.getYarnTaskCommCreated());
+ assertFalse(tcm.getUberTaskCommCreated());
+
+ assertEquals(customTaskCommName, tcm.getTaskCommName(0));
+ assertEquals(bb, tcm.getTaskCommContext(0).getInitialUserPayload().getPayload());
+
+ assertEquals(TezConstants.getTezYarnServicePluginName(), tcm.getTaskCommName(1));
+ Configuration confParsed = TezUtils
+ .createConfFromUserPayload(tcm.getTaskCommContext(1).getInitialUserPayload());
+ assertEquals("testvalue", confParsed.get("testkey"));
+ } finally {
+ tcm.stop();
+ }
+ }
+
+ @Test(timeout = 5000)
+ public void testEventRouting() throws Exception {
+ // Lifecycle calls should reach every communicator, and container registration should be routed by index.
+ AppContext appContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
+ NodeId nodeId = NodeId.newInstance("host1", 3131);
+ when(appContext.getAllContainers().get(any(ContainerId.class)).getContainer().getNodeId())
+ .thenReturn(nodeId);
+ TaskHeartbeatHandler thh = mock(TaskHeartbeatHandler.class);
+ ContainerHeartbeatHandler chh = mock(ContainerHeartbeatHandler.class);
+ Configuration conf = new Configuration(false);
+ conf.set("testkey", "testvalue");
+ UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
+
+ String customTaskCommName = "customTaskComm";
+ List<NamedEntityDescriptor> taskCommDescriptors = new LinkedList<>();
+ ByteBuffer bb = ByteBuffer.allocate(4);
+ bb.putInt(0, 3);
+ UserPayload customPayload = UserPayload.create(bb);
+ taskCommDescriptors.add(
+ new NamedEntityDescriptor(customTaskCommName, FakeTaskComm.class.getName())
+ .setUserPayload(customPayload));
+ taskCommDescriptors
+ .add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null).setUserPayload(defaultPayload));
+
+ TaskCommManagerForMultipleCommTest tcm =
+ new TaskCommManagerForMultipleCommTest(appContext, thh, chh, taskCommDescriptors);
+
+ try {
+ tcm.init(new Configuration(false));
+ tcm.start();
+
+ assertEquals(2, tcm.getNumTaskComms());
+ assertTrue(tcm.getYarnTaskCommCreated());
+ assertFalse(tcm.getUberTaskCommCreated());
+
+ verify(tcm.getTestTaskComm(0)).initialize();
+ verify(tcm.getTestTaskComm(0)).start();
+ verify(tcm.getTestTaskComm(1)).initialize();
+ verify(tcm.getTestTaskComm(1)).start();
+
+ // The deep-stubbed appContext reports nodeId (host1:3131) for any container; registration should forward it.
+ ContainerId containerId1 = mock(ContainerId.class);
+ tcm.registerRunningContainer(containerId1, 0);
+ verify(tcm.getTestTaskComm(0)).registerRunningContainer(eq(containerId1), eq("host1"),
+ eq(3131));
+
+ ContainerId containerId2 = mock(ContainerId.class);
+ tcm.registerRunningContainer(containerId2, 1);
+ verify(tcm.getTestTaskComm(1)).registerRunningContainer(eq(containerId2), eq("host1"),
+ eq(3131));
+
+ } finally {
+ tcm.stop();
+ verify(tcm.getTaskCommunicator(0)).shutdown();
+ verify(tcm.getTaskCommunicator(1)).shutdown();
+ }
+ }
+
+
+ static class TaskCommManagerForMultipleCommTest extends TaskAttemptListenerImpTezDag {
+ // Test double that records which task communicators get created, in creation order, for later assertions.
+ // All variables setup as static since methods being overridden are invoked by the superclass
+ // (TaskAttemptListenerImpTezDag) ctor, and regular variables will not be initialized at this point.
+ private static final AtomicInteger numTaskComms = new AtomicInteger(0);
+ private static final Set<Integer> taskCommIndices = new HashSet<>();
+ private static final TaskCommunicator yarnTaskComm = mock(TaskCommunicator.class);
+ private static final TaskCommunicator uberTaskComm = mock(TaskCommunicator.class);
+ private static final AtomicBoolean yarnTaskCommCreated = new AtomicBoolean(false);
+ private static final AtomicBoolean uberTaskCommCreated = new AtomicBoolean(false);
+
+ private static final List<TaskCommunicatorContext> taskCommContexts =
+ new LinkedList<>();
+ private static final List<String> taskCommNames = new LinkedList<>();
+ private static final List<TaskCommunicator> testTaskComms = new LinkedList<>();
+
+ public static void reset() {
+ numTaskComms.set(0);
+ taskCommIndices.clear();
+ yarnTaskCommCreated.set(false);
+ uberTaskCommCreated.set(false);
+ taskCommContexts.clear();
+ taskCommNames.clear();
+ testTaskComms.clear();
+ org.mockito.Mockito.reset(yarnTaskComm, uberTaskComm); // shared mocks: drop invocations recorded by earlier tests
+ }
+
+ public TaskCommManagerForMultipleCommTest(AppContext context,
+ TaskHeartbeatHandler thh,
+ ContainerHeartbeatHandler chh,
+ List<NamedEntityDescriptor> taskCommunicatorDescriptors) {
+ super(context, thh, chh, taskCommunicatorDescriptors);
+ }
+
+ @Override
+ TaskCommunicator createTaskCommunicator(NamedEntityDescriptor taskCommDescriptor,
+ int taskCommIndex) {
+ numTaskComms.incrementAndGet();
+ boolean added = taskCommIndices.add(taskCommIndex);
+ assertTrue("Cannot add multiple taskComms with the same index", added);
+ taskCommNames.add(taskCommDescriptor.getEntityName());
+ return super.createTaskCommunicator(taskCommDescriptor, taskCommIndex);
+ }
+
+ @Override
+ TaskCommunicator createDefaultTaskCommunicator(
+ TaskCommunicatorContext taskCommunicatorContext) {
+ taskCommContexts.add(taskCommunicatorContext);
+ yarnTaskCommCreated.set(true);
+ testTaskComms.add(yarnTaskComm);
+ return yarnTaskComm;
+ }
+
+ @Override
+ TaskCommunicator createUberTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) {
+ taskCommContexts.add(taskCommunicatorContext);
+ uberTaskCommCreated.set(true);
+ testTaskComms.add(uberTaskComm);
+ return uberTaskComm;
+ }
+
+ @Override
+ TaskCommunicator createCustomTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext,
+ NamedEntityDescriptor taskCommDescriptor) {
+ taskCommContexts.add(taskCommunicatorContext);
+ TaskCommunicator spyComm = // spy on the real custom instance so tests can verify calls made to it
+ spy(super.createCustomTaskCommunicator(taskCommunicatorContext, taskCommDescriptor));
+ testTaskComms.add(spyComm);
+ return spyComm;
+ }
+
+ public static int getNumTaskComms() {
+ return numTaskComms.get();
+ }
+
+ public static boolean getYarnTaskCommCreated() {
+ return yarnTaskCommCreated.get();
+ }
+
+ public static boolean getUberTaskCommCreated() {
+ return uberTaskCommCreated.get();
+ }
+
+ public static TaskCommunicatorContext getTaskCommContext(int taskCommIndex) {
+ return taskCommContexts.get(taskCommIndex);
+ }
+
+ public static String getTaskCommName(int taskCommIndex) {
+ return taskCommNames.get(taskCommIndex);
+ }
+
+ public static TaskCommunicator getTestTaskComm(int taskCommIndex) {
+ return testTaskComms.get(taskCommIndex);
+ }
+ }
+
+ public static class FakeTaskComm extends TaskCommunicator {
+ // Minimal TaskCommunicator stub for these tests: every callback is a no-op and the getters return null.
+ public FakeTaskComm(TaskCommunicatorContext taskCommunicatorContext) {
+ super(taskCommunicatorContext);
+ }
+
+ @Override
+ public void registerRunningContainer(ContainerId containerId, String hostname, int port) {
+ // no-op
+ }
+
+ @Override
+ public void registerContainerEnd(ContainerId containerId, ContainerEndReason endReason) {
+ // no-op
+ }
+
+ @Override
+ public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec,
+ Map<String, LocalResource> additionalResources,
+ Credentials credentials, boolean credentialsChanged,
+ int priority) {
+ // no-op
+ }
+
+ @Override
+ public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID,
+ TaskAttemptEndReason endReason) {
+ // no-op
+ }
+
+ @Override
+ public InetSocketAddress getAddress() {
+ return null;
+ }
+
+ @Override
+ public void onVertexStateUpdated(VertexStateUpdate stateUpdate) throws Exception {
+ // no-op
+ }
+
+ @Override
+ public void dagComplete(String dagName) {
+ // no-op
+ }
+
+ @Override
+ public Object getMetaInfo() {
+ return null;
+ }
+ }
+}