Posted to common-commits@hadoop.apache.org by zt...@apache.org on 2019/04/25 04:55:32 UTC
[hadoop] branch trunk updated: SUBMARINE-54. Add test coverage for
YarnServiceJobSubmitter and make it ready for extension for PyTorch.
Contributed by Szilard Nemeth.
This is an automated email from the ASF dual-hosted git repository.
ztang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/hadoop.git
The following commit(s) were added to refs/heads/trunk by this push:
new 0b3d41b SUBMARINE-54. Add test coverage for YarnServiceJobSubmitter and make it ready for extension for PyTorch. Contributed by Szilard Nemeth.
0b3d41b is described below
commit 0b3d41bdeef68afdde0fa4dca27fb582c83424c4
Author: Zhankun Tang <zt...@apache.org>
AuthorDate: Thu Apr 25 12:52:24 2019 +0800
SUBMARINE-54. Add test coverage for YarnServiceJobSubmitter and make it ready for extension for PyTorch. Contributed by Szilard Nemeth.
---
.../client/cli/param/RunJobParameters.java | 16 +
.../submarine/client/cli/TestRunJobCliParsing.java | 19 +
.../common/fs/MockRemoteDirectoryManager.java | 2 +
.../hadoop-submarine-yarnservice-runtime/pom.xml | 6 +
.../runtimes/yarnservice/AbstractComponent.java | 99 ++
.../runtimes/yarnservice/FileSystemOperations.java | 201 ++++
.../yarnservice/HadoopEnvironmentSetup.java | 161 +++
.../runtimes/yarnservice/ServiceSpec.java | 27 +
.../yarnservice/ServiceSpecFileGenerator.java | 51 +
.../runtimes/yarnservice/ServiceWrapper.java | 62 +
.../yarnservice/YarnServiceJobSubmitter.java | 860 +-------------
.../runtimes/yarnservice/YarnServiceUtils.java | 94 +-
.../yarnservice/command/AbstractLaunchCommand.java | 64 +
.../yarnservice/command/LaunchCommandFactory.java | 67 ++
.../yarnservice/command/LaunchScriptBuilder.java | 107 ++
.../runtimes/yarnservice/command/package-info.java | 19 +
.../yarnservice/tensorflow/TensorFlowCommons.java | 109 ++
.../tensorflow/TensorFlowServiceSpec.java | 203 ++++
.../command/TensorBoardLaunchCommand.java | 67 ++
.../command/TensorFlowLaunchCommand.java | 87 ++
.../command/TensorFlowPsLaunchCommand.java | 58 +
.../command/TensorFlowWorkerLaunchCommand.java | 59 +
.../tensorflow/command/package-info.java | 19 +
.../tensorflow/component/TensorBoardComponent.java | 96 ++
.../component/TensorFlowPsComponent.java | 73 ++
.../component/TensorFlowWorkerComponent.java | 82 ++
.../tensorflow/component/package-info.java | 20 +
.../yarnservice/tensorflow/package-info.java | 20 +
.../yarn/submarine/utils/ClassPathUtilities.java | 57 +
.../yarn/submarine/utils/DockerUtilities.java | 33 +
.../yarn/submarine/utils/EnvironmentUtilities.java | 120 ++
.../submarine/utils/KerberosPrincipalFactory.java | 95 ++
.../hadoop/yarn/submarine/utils/Localizer.java | 170 +++
.../submarine/utils/SubmarineResourceUtils.java | 51 +
.../hadoop/yarn/submarine/utils/ZipUtilities.java | 82 ++
.../hadoop/yarn/submarine/utils/package-info.java | 19 +
.../yarn/submarine/FileUtilitiesForTests.java | 146 +++
.../cli/yarnservice/ParamBuilderForTest.java | 139 +++
.../cli/yarnservice/TestYarnServiceRunJobCli.java | 1242 +++++---------------
.../TestYarnServiceRunJobCliCommons.java | 79 ++
.../TestYarnServiceRunJobCliLocalization.java | 599 ++++++++++
.../runtimes/yarnservice/TestServiceWrapper.java | 95 ++
.../yarnservice/TestTFConfigGenerator.java | 10 +-
.../command/AbstractLaunchCommandTestHelper.java | 190 +++
.../command/TestLaunchCommandFactory.java | 97 ++
.../command/TestTensorBoardLaunchCommand.java | 104 ++
.../command/TestTensorFlowLaunchCommand.java | 251 ++++
.../tensorflow/component/ComponentTestCommons.java | 90 ++
.../component/TestTensorBoardComponent.java | 125 ++
.../component/TestTensorFlowPsComponent.java | 166 +++
.../component/TestTensorFlowWorkerComponent.java | 215 ++++
.../submarine/utils/TestClassPathUtilities.java | 91 ++
.../submarine/utils/TestEnvironmentUtilities.java | 231 ++++
.../utils/TestKerberosPrincipalFactory.java | 156 +++
.../utils/TestSubmarineResourceUtils.java | 72 ++
55 files changed, 5629 insertions(+), 1844 deletions(-)
diff --git a/hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/RunJobParameters.java b/hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/RunJobParameters.java
index e7b1e2f..2d91a64 100644
--- a/hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/RunJobParameters.java
+++ b/hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/RunJobParameters.java
@@ -293,10 +293,18 @@ public class RunJobParameters extends RunParameters {
return psDockerImage;
}
+ public void setPsDockerImage(String psDockerImage) {
+ this.psDockerImage = psDockerImage;
+ }
+
public String getWorkerDockerImage() {
return workerDockerImage;
}
+ public void setWorkerDockerImage(String workerDockerImage) {
+ this.workerDockerImage = workerDockerImage;
+ }
+
public boolean isDistributed() {
return distributed;
}
@@ -313,6 +321,10 @@ public class RunJobParameters extends RunParameters {
return tensorboardDockerImage;
}
+ public void setTensorboardDockerImage(String tensorboardDockerImage) {
+ this.tensorboardDockerImage = tensorboardDockerImage;
+ }
+
public List<Quicklink> getQuicklinks() {
return quicklinks;
}
@@ -366,6 +378,10 @@ public class RunJobParameters extends RunParameters {
return this;
}
+ public void setDistributed(boolean distributed) {
+ this.distributed = distributed;
+ }
+
@VisibleForTesting
public static class UnderscoreConverterPropertyUtils extends PropertyUtils {
@Override
diff --git a/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java b/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java
index 4ad0227..d092693 100644
--- a/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java
+++ b/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java
@@ -178,6 +178,25 @@ public class TestRunJobCliParsing {
}
@Test
+ public void testJobWithoutName() throws Exception {
+ RunJobCli runJobCli = new RunJobCli(getMockClientContext());
+ String expectedErrorMessage =
+ "--" + CliConstants.NAME + " is absent";
+ String actualMessage = "";
+ try {
+ runJobCli.run(
+ new String[]{"--docker_image", "tf-docker:1.1.0",
+ "--num_workers", "0", "--tensorboard", "--verbose",
+ "--tensorboard_resources", "memory=2G,vcores=2",
+ "--tensorboard_docker_image", "tb_docker_image:001"});
+ } catch (ParseException e) {
+ actualMessage = e.getMessage();
+ e.printStackTrace();
+ }
+ assertEquals(expectedErrorMessage, actualMessage);
+ }
+
+ @Test
public void testLaunchCommandPatternReplace() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
diff --git a/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/common/fs/MockRemoteDirectoryManager.java b/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/common/fs/MockRemoteDirectoryManager.java
index 4334293..7ef03f5 100644
--- a/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/common/fs/MockRemoteDirectoryManager.java
+++ b/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/common/fs/MockRemoteDirectoryManager.java
@@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path;
import java.io.File;
import java.io.IOException;
+import java.util.Objects;
public class MockRemoteDirectoryManager implements RemoteDirectoryManager {
private File jobsParentDir = null;
@@ -35,6 +36,7 @@ public class MockRemoteDirectoryManager implements RemoteDirectoryManager {
@Override
public Path getJobStagingArea(String jobName, boolean create)
throws IOException {
+ Objects.requireNonNull(jobName, "Job name must not be null!");
if (jobsParentDir == null && create) {
jobsParentDir = new File(
"target/_staging_area_" + System.currentTimeMillis());
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/pom.xml b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/pom.xml
index a337c42..15dffb9 100644
--- a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/pom.xml
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/pom.xml
@@ -115,6 +115,12 @@
<artifactId>hadoop-yarn-services-core</artifactId>
<version>3.3.0-SNAPSHOT</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractComponent.java
new file mode 100644
index 0000000..903ae09
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractComponent.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getScriptFileName;
+
+/**
+ * Abstract base class for Component classes.
+ * The implementations of this class act like factories for
+ * {@link Component} instances.
+ * All dependencies are passed to the constructor so that child classes
+ * are obliged to provide matching constructors.
+ */
+public abstract class AbstractComponent {
+ private final FileSystemOperations fsOperations;
+ protected final RunJobParameters parameters;
+ protected final TaskType taskType;
+ private final RemoteDirectoryManager remoteDirectoryManager;
+ protected final Configuration yarnConfig;
+ private final LaunchCommandFactory launchCommandFactory;
+
+ /**
+ * This is only required for testing.
+ */
+ private String localScriptFile;
+
+ public AbstractComponent(FileSystemOperations fsOperations,
+ RemoteDirectoryManager remoteDirectoryManager,
+ RunJobParameters parameters, TaskType taskType,
+ Configuration yarnConfig,
+ LaunchCommandFactory launchCommandFactory) {
+ this.fsOperations = fsOperations;
+ this.remoteDirectoryManager = remoteDirectoryManager;
+ this.parameters = parameters;
+ this.taskType = taskType;
+ this.launchCommandFactory = launchCommandFactory;
+ this.yarnConfig = yarnConfig;
+ }
+
+ protected abstract Component createComponent() throws IOException;
+
+ /**
+ * Generates a command launch script on local disk, uploads it to the
+ * remote staging area and sets the uploaded script as the launch command
+ * of the given component.
+ */
+ protected void generateLaunchCommand(Component component)
+ throws IOException {
+ AbstractLaunchCommand launchCommand =
+ launchCommandFactory.createLaunchCommand(taskType, component);
+ this.localScriptFile = launchCommand.generateLaunchScript();
+
+ String remoteLaunchCommand = uploadLaunchCommand(component);
+ component.setLaunchCommand(remoteLaunchCommand);
+ }
+
+ private String uploadLaunchCommand(Component component)
+ throws IOException {
+ Objects.requireNonNull(localScriptFile, "localScriptFile should be " +
+ "set before calling this method!");
+ Path stagingDir =
+ remoteDirectoryManager.getJobStagingArea(parameters.getName(), true);
+
+ String destScriptFileName = getScriptFileName(taskType);
+ fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
+ localScriptFile, destScriptFileName, component);
+
+ return "./" + destScriptFileName;
+ }
+
+ String getLocalScriptFile() {
+ return localScriptFile;
+ }
+}
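
For illustration, a minimal sketch of how a concrete subclass can build on AbstractComponent (the class name and component wiring below are assumptions for the example, not code from this commit; the real subclasses live under runtimes/yarnservice/tensorflow/component):

    // Hypothetical subclass, reusing the imports shown in AbstractComponent above.
    public class MyWorkerComponent extends AbstractComponent {
      public MyWorkerComponent(FileSystemOperations fsOperations,
          RemoteDirectoryManager remoteDirectoryManager,
          RunJobParameters parameters, Configuration yarnConfig,
          LaunchCommandFactory launchCommandFactory) {
        super(fsOperations, remoteDirectoryManager, parameters,
            TaskType.WORKER, yarnConfig, launchCommandFactory);
      }

      @Override
      protected Component createComponent() throws IOException {
        Component component = new Component();
        component.setName(taskType.getComponentName());
        component.setNumberOfContainers((long) parameters.getNumWorkers());
        // Writes the launch script to local disk, uploads it to the job
        // staging area and sets "./run-WORKER.sh" as the launch command.
        generateLaunchCommand(component);
        return component;
      }
    }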
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/FileSystemOperations.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/FileSystemOperations.java
new file mode 100644
index 0000000..edac6ed
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/FileSystemOperations.java
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.service.api.records.ConfigFile;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.utils.ZipUtilities;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * Contains methods to perform file system operations. Almost all of the methods
+ * are regular non-static methods as the operations are performed with the help
+ * of a {@link RemoteDirectoryManager} instance passed in as a constructor
+ * dependency. Please note that some operations also need to read config
+ * settings, so the Submarine and YARN configuration objects are
+ * dependencies as well.
+ */
+public class FileSystemOperations {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(FileSystemOperations.class);
+ private final Configuration submarineConfig;
+ private final Configuration yarnConfig;
+
+ private Set<Path> uploadedFiles = new HashSet<>();
+ private RemoteDirectoryManager remoteDirectoryManager;
+
+ public FileSystemOperations(ClientContext clientContext) {
+ this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
+ this.submarineConfig = clientContext.getSubmarineConfig();
+ this.yarnConfig = clientContext.getYarnConfig();
+ }
+
+ /**
+ * May download a remote URI (file/dir) and zip it.
+ * The download is skipped for a local dir.
+ * The URI can be a local dir (won't be downloaded)
+ * or a remote HDFS dir, S3 dir/file, etc.
+ */
+ public String downloadAndZip(String remoteDir, String zipFileName,
+ boolean doZip)
+ throws IOException {
+ //Append original modification time and size to zip file name
+ String suffix;
+ String srcDir = remoteDir;
+ String zipDirPath =
+ System.getProperty("java.io.tmpdir") + "/" + zipFileName;
+ boolean needDeleteTempDir = false;
+ if (remoteDirectoryManager.isRemote(remoteDir)) {
+ //Append original modification time and size to zip file name
+ FileStatus status =
+ remoteDirectoryManager.getRemoteFileStatus(new Path(remoteDir));
+ suffix = "_" + status.getModificationTime()
+ + "-" + remoteDirectoryManager.getRemoteFileSize(remoteDir);
+ // Download them to temp dir
+ boolean downloaded =
+ remoteDirectoryManager.copyRemoteToLocal(remoteDir, zipDirPath);
+ if (!downloaded) {
+ throw new IOException("Failed to download files from "
+ + remoteDir);
+ }
+ LOG.info("Downloaded remote: {} to local: {}", remoteDir, zipDirPath);
+ srcDir = zipDirPath;
+ needDeleteTempDir = true;
+ } else {
+ File localDir = new File(remoteDir);
+ suffix = "_" + localDir.lastModified()
+ + "-" + localDir.length();
+ }
+ if (!doZip) {
+ return srcDir;
+ }
+ // zip a local dir
+ String zipFileUri =
+ ZipUtilities.zipDir(srcDir, zipDirPath + suffix + ".zip");
+ // delete downloaded temp dir
+ if (needDeleteTempDir) {
+ deleteFiles(srcDir);
+ }
+ return zipFileUri;
+ }
+
+ public void deleteFiles(String localUri) {
+ boolean success = FileUtil.fullyDelete(new File(localUri));
+ if (!success) {
+ LOG.warn("Failed to delete {}", localUri);
+ }
+ LOG.info("Deleted {}", localUri);
+ }
+
+ @VisibleForTesting
+ public void uploadToRemoteFileAndLocalizeToContainerWorkDir(Path stagingDir,
+ String fileToUpload, String destFilename, Component comp)
+ throws IOException {
+ Path uploadedFilePath = uploadToRemoteFile(stagingDir, fileToUpload);
+ locateRemoteFileToContainerWorkDir(destFilename, comp, uploadedFilePath);
+ }
+
+ private void locateRemoteFileToContainerWorkDir(String destFilename,
+ Component comp, Path uploadedFilePath)
+ throws IOException {
+ FileSystem fs = FileSystem.get(yarnConfig);
+
+ FileStatus fileStatus = fs.getFileStatus(uploadedFilePath);
+ LOG.info("Uploaded file path = " + fileStatus.getPath());
+
+ // Set it to component's files list
+ comp.getConfiguration().getFiles().add(new ConfigFile().srcFile(
+ fileStatus.getPath().toUri().toString()).destFile(destFilename)
+ .type(ConfigFile.TypeEnum.STATIC));
+ }
+
+ public Path uploadToRemoteFile(Path stagingDir, String fileToUpload) throws
+ IOException {
+ FileSystem fs = remoteDirectoryManager.getDefaultFileSystem();
+
+ // Upload to remote FS under staging area
+ File localFile = new File(fileToUpload);
+ if (!localFile.exists()) {
+ throw new FileNotFoundException(
+ "Trying to upload file=" + localFile.getAbsolutePath()
+ + " to remote, but couldn't find local file.");
+ }
+ String filename = new File(fileToUpload).getName();
+
+ Path uploadedFilePath = new Path(stagingDir, filename);
+ if (!uploadedFiles.contains(uploadedFilePath)) {
+ if (SubmarineLogs.isVerbose()) {
+ LOG.info("Copying local file=" + fileToUpload + " to remote="
+ + uploadedFilePath);
+ }
+ fs.copyFromLocalFile(new Path(fileToUpload), uploadedFilePath);
+ uploadedFiles.add(uploadedFilePath);
+ }
+ return uploadedFilePath;
+ }
+
+ public void validFileSize(String uri) throws IOException {
+ long actualSizeByte;
+ String locationType = "Local";
+ if (remoteDirectoryManager.isRemote(uri)) {
+ actualSizeByte = remoteDirectoryManager.getRemoteFileSize(uri);
+ locationType = "Remote";
+ } else {
+ actualSizeByte = FileUtil.getDU(new File(uri));
+ }
+ long maxFileSizeMB = submarineConfig
+ .getLong(SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB,
+ SubmarineConfiguration.DEFAULT_MAX_ALLOWED_REMOTE_URI_SIZE_MB);
+ LOG.info("{} fie/dir: {}, size(Byte):{},"
+ + " Allowed max file/dir size: {}",
+ locationType, uri, actualSizeByte, maxFileSizeMB * 1024 * 1024);
+
+ if (actualSizeByte > maxFileSizeMB * 1024 * 1024) {
+ throw new IOException(uri + " size(Byte): "
+ + actualSizeByte + " exceeds configured max size:"
+ + maxFileSizeMB * 1024 * 1024);
+ }
+ }
+
+ public void setPermission(Path destPath, FsPermission permission) throws
+ IOException {
+ FileSystem fs = FileSystem.get(yarnConfig);
+ fs.setPermission(destPath, new FsPermission(permission));
+ }
+
+ public static boolean needHdfs(String content) {
+ return content != null && content.contains("hdfs://");
+ }
+}
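
For illustration, a minimal sketch of a typical FileSystemOperations call sequence (clientContext, component and the paths below are assumptions for the example):

    FileSystemOperations fsOperations = new FileSystemOperations(clientContext);

    // Fail early if the resource exceeds the configured localization size limit.
    fsOperations.validFileSize("hdfs:///user/submarine/dataset");

    // Download the remote dir (local dirs are not downloaded) and zip it
    // under java.io.tmpdir; returns the path of the resulting zip file.
    String zipFile = fsOperations.downloadAndZip(
        "hdfs:///user/submarine/dataset", "dataset", true);

    // Upload the zip to the job staging area and localize it into the
    // container's working directory as "dataset.zip".
    Path stagingDir = clientContext.getRemoteDirectoryManager()
        .getJobStagingArea("my-job", true);
    fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(
        stagingDir, zipFile, "dataset.zip", component);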
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/HadoopEnvironmentSetup.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/HadoopEnvironmentSetup.java
new file mode 100644
index 0000000..461525f
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/HadoopEnvironmentSetup.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintWriter;
+
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs;
+import static org.apache.hadoop.yarn.submarine.utils.ClassPathUtilities.findFileOnClassPath;
+import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.getValueOfEnvironment;
+
+/**
+ * This class contains helper methods to fill HDFS and Java environment
+ * variables into scripts.
+ */
+public class HadoopEnvironmentSetup {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(HadoopEnvironmentSetup.class);
+ private static final String CORE_SITE_XML = "core-site.xml";
+ private static final String HDFS_SITE_XML = "hdfs-site.xml";
+
+ public static final String DOCKER_HADOOP_HDFS_HOME =
+ "DOCKER_HADOOP_HDFS_HOME";
+ public static final String DOCKER_JAVA_HOME = "DOCKER_JAVA_HOME";
+ private final RemoteDirectoryManager remoteDirectoryManager;
+ private final FileSystemOperations fsOperations;
+
+ public HadoopEnvironmentSetup(ClientContext clientContext,
+ FileSystemOperations fsOperations) {
+ this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
+ this.fsOperations = fsOperations;
+ }
+
+ public void addHdfsClassPath(RunJobParameters parameters,
+ PrintWriter fw, Component comp) throws IOException {
+ // Find envs to use HDFS
+ String hdfsHome = null;
+ String javaHome = null;
+
+ boolean hadoopEnv = false;
+
+ for (String envVar : parameters.getEnvars()) {
+ if (envVar.startsWith(DOCKER_HADOOP_HDFS_HOME + "=")) {
+ hdfsHome = getValueOfEnvironment(envVar);
+ hadoopEnv = true;
+ } else if (envVar.startsWith(DOCKER_JAVA_HOME + "=")) {
+ javaHome = getValueOfEnvironment(envVar);
+ }
+ }
+
+ boolean hasHdfsEnvs = hdfsHome != null && javaHome != null;
+ boolean needHdfs = doesNeedHdfs(parameters, hadoopEnv);
+ if (needHdfs) {
+ // HDFS is asked either in input or output, set LD_LIBRARY_PATH
+ // and classpath
+ if (hdfsHome != null) {
+ appendHdfsHome(fw, hdfsHome);
+ }
+
+ // hadoop confs will be uploaded to HDFS and localized to container's
+ // local folder, so here set $HADOOP_CONF_DIR to $WORK_DIR.
+ fw.append("export HADOOP_CONF_DIR=$WORK_DIR\n");
+ if (javaHome != null) {
+ appendJavaHome(fw, javaHome);
+ }
+
+ fw.append(
+ "export CLASSPATH=`$HADOOP_HDFS_HOME/bin/hadoop classpath --glob`\n");
+ }
+
+ if (needHdfs && !hasHdfsEnvs) {
+ LOG.error("When HDFS is being used to read/write models/data, " +
+ "the following environment variables are required: " +
+ "1) {}=<HDFS_HOME inside docker container> " +
+ "2) {}=<JAVA_HOME inside docker container>. " +
+ "You can use --env to pass these environment variables.",
+ DOCKER_HADOOP_HDFS_HOME, DOCKER_JAVA_HOME);
+ throw new IOException("Failed to detect HDFS-related environments.");
+ }
+
+ // Trying to upload core-site.xml and hdfs-site.xml
+ Path stagingDir =
+ remoteDirectoryManager.getJobStagingArea(
+ parameters.getName(), true);
+ File coreSite = findFileOnClassPath(CORE_SITE_XML);
+ File hdfsSite = findFileOnClassPath(HDFS_SITE_XML);
+ if (coreSite == null || hdfsSite == null) {
+ LOG.error("HDFS is being used, however we could not locate " +
+ "{} nor {} on classpath! " +
+ "Please double check your classpath setting and make sure these " +
+ "setting files are included!", CORE_SITE_XML, HDFS_SITE_XML);
+ throw new IOException(
+ "Failed to locate core-site.xml / hdfs-site.xml on classpath!");
+ }
+ fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
+ coreSite.getAbsolutePath(), CORE_SITE_XML, comp);
+ fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
+ hdfsSite.getAbsolutePath(), HDFS_SITE_XML, comp);
+
+ // DEBUG
+ if (SubmarineLogs.isVerbose()) {
+ appendEchoOfEnvVars(fw);
+ }
+ }
+
+ private boolean doesNeedHdfs(RunJobParameters parameters, boolean hadoopEnv) {
+ return needHdfs(parameters.getInputPath()) ||
+ needHdfs(parameters.getPSLaunchCmd()) ||
+ needHdfs(parameters.getWorkerLaunchCmd()) ||
+ hadoopEnv;
+ }
+
+ private void appendHdfsHome(PrintWriter fw, String hdfsHome) {
+ // Unset HADOOP_HOME/HADOOP_YARN_HOME to make sure host machine's envs
+ // won't pollute docker's env.
+ fw.append("export HADOOP_HOME=\n");
+ fw.append("export HADOOP_YARN_HOME=\n");
+ fw.append("export HADOOP_HDFS_HOME=" + hdfsHome + "\n");
+ fw.append("export HADOOP_COMMON_HOME=" + hdfsHome + "\n");
+ }
+
+ private void appendJavaHome(PrintWriter fw, String javaHome) {
+ fw.append("export JAVA_HOME=" + javaHome + "\n");
+ fw.append("export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"
+ + "$JAVA_HOME/lib/amd64/server\n");
+ }
+
+ private void appendEchoOfEnvVars(PrintWriter fw) {
+ fw.append("echo \"CLASSPATH:$CLASSPATH\"\n");
+ fw.append("echo \"HADOOP_CONF_DIR:$HADOOP_CONF_DIR\"\n");
+ fw.append(
+ "echo \"HADOOP_TOKEN_FILE_LOCATION:$HADOOP_TOKEN_FILE_LOCATION\"\n");
+ fw.append("echo \"JAVA_HOME:$JAVA_HOME\"\n");
+ fw.append("echo \"LD_LIBRARY_PATH:$LD_LIBRARY_PATH\"\n");
+ fw.append("echo \"HADOOP_HDFS_HOME:$HADOOP_HDFS_HOME\"\n");
+ }
+}
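
For illustration, a minimal sketch of wiring HadoopEnvironmentSetup into a launch-script writer (clientContext, fsOperations, parameters, component and launchScriptFile are assumptions for the example); note that the job has to be submitted with --env DOCKER_HADOOP_HDFS_HOME=... and --env DOCKER_JAVA_HOME=... for the HDFS path to work:

    HadoopEnvironmentSetup hadoopEnvSetup =
        new HadoopEnvironmentSetup(clientContext, fsOperations);
    try (PrintWriter scriptWriter = new PrintWriter(new OutputStreamWriter(
        new FileOutputStream(launchScriptFile), StandardCharsets.UTF_8))) {
      // Exports HADOOP_*/JAVA_HOME/CLASSPATH into the script and uploads
      // core-site.xml / hdfs-site.xml into the component's work dir.
      hadoopEnvSetup.addHdfsClassPath(parameters, scriptWriter, component);
    }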
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpec.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpec.java
new file mode 100644
index 0000000..f26d610
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpec.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import java.io.IOException;
+
+/**
+ * This interface provides a means of creating wrappers around
+ * {@link org.apache.hadoop.yarn.service.api.records.Service} instances.
+ */
+public interface ServiceSpec {
+ ServiceWrapper create() throws IOException;
+}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpecFileGenerator.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpecFileGenerator.java
new file mode 100644
index 0000000..06e36d5
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpecFileGenerator.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import org.apache.hadoop.yarn.service.api.records.Service;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.io.Writer;
+import java.nio.charset.StandardCharsets;
+
+import static org.apache.hadoop.yarn.service.utils.ServiceApiUtil.jsonSerDeser;
+
+/**
+ * This class is merely responsible for creating the JSON representation of
+ * {@link Service} instances.
+ */
+public final class ServiceSpecFileGenerator {
+ private ServiceSpecFileGenerator() {
+ throw new UnsupportedOperationException("This class should not be " +
+ "instantiated!");
+ }
+
+ static String generateJson(Service service) throws IOException {
+ File serviceSpecFile = File.createTempFile(service.getName(), ".json");
+ String buffer = jsonSerDeser.toJson(service);
+ Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),
+ StandardCharsets.UTF_8);
+ try (PrintWriter pw = new PrintWriter(w)) {
+ pw.append(buffer);
+ }
+ return serviceSpecFile.getAbsolutePath();
+ }
+}
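
For illustration (generateJson is package-private, so a caller sits in the same package; the service variable is an assumption for the example):

    // Serializes the Service to a temp file named <serviceName>*.json and
    // returns its absolute path.
    String serviceSpecJsonPath = ServiceSpecFileGenerator.generateJson(service);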
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceWrapper.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceWrapper.java
new file mode 100644
index 0000000..3891602
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceWrapper.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.service.api.records.Service;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * This class only exists because the test code needs a mapping from
+ * component name to local launch command.
+ * Once this is solved in a cleaner or different way, we can delete this class.
+ */
+public class ServiceWrapper {
+ private final Service service;
+
+ @VisibleForTesting
+ private Map<String, String> componentToLocalLaunchCommand = Maps.newHashMap();
+
+ public ServiceWrapper(Service service) {
+ this.service = service;
+ }
+
+ public void addComponent(AbstractComponent abstractComponent)
+ throws IOException {
+ Component component = abstractComponent.createComponent();
+ service.addComponent(component);
+ storeComponentName(abstractComponent, component.getName());
+ }
+
+ private void storeComponentName(
+ AbstractComponent component, String name) {
+ componentToLocalLaunchCommand.put(name,
+ component.getLocalScriptFile());
+ }
+
+ public Service getService() {
+ return service;
+ }
+
+ public String getLocalLaunchCommandPathForComponent(String componentName) {
+ return componentToLocalLaunchCommand.get(componentName);
+ }
+}
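
For illustration, a minimal sketch of how ServiceWrapper is meant to be used (parameters and workerComponent, an AbstractComponent instance, are assumptions for the example):

    Service service = new Service();
    service.setName(parameters.getName());
    ServiceWrapper serviceWrapper = new ServiceWrapper(service);

    // createComponent() is called on the AbstractComponent, the resulting
    // Component is added to the wrapped Service, and the local launch script
    // path is remembered under the component's name.
    serviceWrapper.addComponent(workerComponent);

    // Test code can then look up the generated local launch script.
    String localScript = serviceWrapper.getLocalLaunchCommandPathForComponent(
        TaskType.WORKER.getComponentName());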
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java
index 58a33cf..37445a6 100644
--- a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java
@@ -15,858 +15,59 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import com.google.common.annotations.VisibleForTesting;
-import org.apache.commons.lang3.StringUtils;
-import org.apache.hadoop.fs.FileStatus;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.FileUtil;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.client.api.AppAdminClient;
import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
-import org.apache.hadoop.yarn.service.api.records.Artifact;
-import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.service.api.records.ConfigFile;
-import org.apache.hadoop.yarn.service.api.records.Resource;
-import org.apache.hadoop.yarn.service.api.records.ResourceInformation;
import org.apache.hadoop.yarn.service.api.records.Service;
-import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
import org.apache.hadoop.yarn.service.utils.ServiceApiUtil;
-import org.apache.hadoop.yarn.submarine.client.cli.param.Localization;
-import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
-import org.apache.hadoop.yarn.submarine.common.Envs;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
-import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration;
-import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
-import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec;
+import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.FileOutputStream;
import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.io.Writer;
-import java.nio.charset.StandardCharsets;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.StringTokenizer;
-import java.util.zip.ZipEntry;
-import java.util.zip.ZipOutputStream;
-import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
-
-import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants
- .CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
-import static org.apache.hadoop.yarn.service.utils.ServiceApiUtil.jsonSerDeser;
/**
- * Submit a job to cluster
+ * Submit a job to cluster.
*/
public class YarnServiceJobSubmitter implements JobSubmitter {
- public static final String TENSORBOARD_QUICKLINK_LABEL = "Tensorboard";
+
private static final Logger LOG =
LoggerFactory.getLogger(YarnServiceJobSubmitter.class);
- ClientContext clientContext;
- Service serviceSpec;
- private Set<Path> uploadedFiles = new HashSet<>();
-
- // Used by testing
- private Map<String, String> componentToLocalLaunchScriptPath =
- new HashMap<>();
+ private ClientContext clientContext;
+ private ServiceWrapper serviceWrapper;
- public YarnServiceJobSubmitter(ClientContext clientContext) {
+ YarnServiceJobSubmitter(ClientContext clientContext) {
this.clientContext = clientContext;
}
- private Resource getServiceResourceFromYarnResource(
- org.apache.hadoop.yarn.api.records.Resource yarnResource) {
- Resource serviceResource = new Resource();
- serviceResource.setCpus(yarnResource.getVirtualCores());
- serviceResource.setMemory(String.valueOf(yarnResource.getMemorySize()));
-
- Map<String, ResourceInformation> riMap = new HashMap<>();
- for (org.apache.hadoop.yarn.api.records.ResourceInformation ri : yarnResource
- .getAllResourcesListCopy()) {
- ResourceInformation serviceRi =
- new ResourceInformation();
- serviceRi.setValue(ri.getValue());
- serviceRi.setUnit(ri.getUnits());
- riMap.put(ri.getName(), serviceRi);
- }
- serviceResource.setResourceInformations(riMap);
-
- return serviceResource;
- }
-
- private String getValueOfEnvironment(String envar) {
- // extract value from "key=value" form
- if (envar == null || !envar.contains("=")) {
- return "";
- } else {
- return envar.substring(envar.indexOf("=") + 1);
- }
- }
-
- private boolean needHdfs(String content) {
- return content != null && content.contains("hdfs://");
- }
-
- private void addHdfsClassPathIfNeeded(RunJobParameters parameters,
- PrintWriter fw, Component comp) throws IOException {
- // Find envs to use HDFS
- String hdfsHome = null;
- String javaHome = null;
-
- boolean hadoopEnv = false;
-
- for (String envar : parameters.getEnvars()) {
- if (envar.startsWith("DOCKER_HADOOP_HDFS_HOME=")) {
- hdfsHome = getValueOfEnvironment(envar);
- hadoopEnv = true;
- } else if (envar.startsWith("DOCKER_JAVA_HOME=")) {
- javaHome = getValueOfEnvironment(envar);
- }
- }
-
- boolean lackingEnvs = false;
-
- if (needHdfs(parameters.getInputPath()) || needHdfs(
- parameters.getPSLaunchCmd()) || needHdfs(
- parameters.getWorkerLaunchCmd()) || hadoopEnv) {
- // HDFS is asked either in input or output, set LD_LIBRARY_PATH
- // and classpath
- if (hdfsHome != null) {
- // Unset HADOOP_HOME/HADOOP_YARN_HOME to make sure host machine's envs
- // won't pollute docker's env.
- fw.append("export HADOOP_HOME=\n");
- fw.append("export HADOOP_YARN_HOME=\n");
- fw.append("export HADOOP_HDFS_HOME=" + hdfsHome + "\n");
- fw.append("export HADOOP_COMMON_HOME=" + hdfsHome + "\n");
- } else{
- lackingEnvs = true;
- }
-
- // hadoop confs will be uploaded to HDFS and localized to container's
- // local folder, so here set $HADOOP_CONF_DIR to $WORK_DIR.
- fw.append("export HADOOP_CONF_DIR=$WORK_DIR\n");
- if (javaHome != null) {
- fw.append("export JAVA_HOME=" + javaHome + "\n");
- fw.append("export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"
- + "$JAVA_HOME/lib/amd64/server\n");
- } else {
- lackingEnvs = true;
- }
- fw.append("export CLASSPATH=`$HADOOP_HDFS_HOME/bin/hadoop classpath --glob`\n");
- }
-
- if (lackingEnvs) {
- LOG.error("When hdfs is being used to read/write models/data. Following"
- + "envs are required: 1) DOCKER_HADOOP_HDFS_HOME=<HDFS_HOME inside"
- + "docker container> 2) DOCKER_JAVA_HOME=<JAVA_HOME inside docker"
- + "container>. You can use --env to pass these envars.");
- throw new IOException("Failed to detect HDFS-related environments.");
- }
-
- // Trying to upload core-site.xml and hdfs-site.xml
- Path stagingDir =
- clientContext.getRemoteDirectoryManager().getJobStagingArea(
- parameters.getName(), true);
- File coreSite = findFileOnClassPath("core-site.xml");
- File hdfsSite = findFileOnClassPath("hdfs-site.xml");
- if (coreSite == null || hdfsSite == null) {
- LOG.error("hdfs is being used, however we couldn't locate core-site.xml/"
- + "hdfs-site.xml from classpath, please double check you classpath"
- + "setting and make sure they're included.");
- throw new IOException(
- "Failed to locate core-site.xml / hdfs-site.xml from class path");
- }
- uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
- coreSite.getAbsolutePath(), "core-site.xml", comp);
- uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
- hdfsSite.getAbsolutePath(), "hdfs-site.xml", comp);
-
- // DEBUG
- if (SubmarineLogs.isVerbose()) {
- fw.append("echo \"CLASSPATH:$CLASSPATH\"\n");
- fw.append("echo \"HADOOP_CONF_DIR:$HADOOP_CONF_DIR\"\n");
- fw.append("echo \"HADOOP_TOKEN_FILE_LOCATION:$HADOOP_TOKEN_FILE_LOCATION\"\n");
- fw.append("echo \"JAVA_HOME:$JAVA_HOME\"\n");
- fw.append("echo \"LD_LIBRARY_PATH:$LD_LIBRARY_PATH\"\n");
- fw.append("echo \"HADOOP_HDFS_HOME:$HADOOP_HDFS_HOME\"\n");
- }
- }
-
- private void addCommonEnvironments(Component component, TaskType taskType) {
- Map<String, String> envs = component.getConfiguration().getEnv();
- envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
- envs.put(Envs.TASK_TYPE_ENV, taskType.name());
- }
-
- @VisibleForTesting
- protected String getUserName() {
- return System.getProperty("user.name");
- }
-
- private String getDNSDomain() {
- return clientContext.getYarnConfig().get("hadoop.registry.dns.domain-name");
- }
-
- /*
- * Generate a command launch script on local disk, returns patch to the script
- */
- private String generateCommandLaunchScript(RunJobParameters parameters,
- TaskType taskType, Component comp) throws IOException {
- File file = File.createTempFile(taskType.name() + "-launch-script", ".sh");
- Writer w = new OutputStreamWriter(new FileOutputStream(file),
- StandardCharsets.UTF_8);
- PrintWriter pw = new PrintWriter(w);
-
- try {
- pw.append("#!/bin/bash\n");
-
- addHdfsClassPathIfNeeded(parameters, pw, comp);
-
- if (taskType.equals(TaskType.TENSORBOARD)) {
- String tbCommand =
- "export LC_ALL=C && tensorboard --logdir=" + parameters
- .getCheckpointPath();
- pw.append(tbCommand + "\n");
- LOG.info("Tensorboard command=" + tbCommand);
- } else{
- // When distributed training is required
- if (parameters.isDistributed()) {
- // Generated TF_CONFIG
- String tfConfigEnv = YarnServiceUtils.getTFConfigEnv(
- taskType.getComponentName(), parameters.getNumWorkers(),
- parameters.getNumPS(), parameters.getName(), getUserName(),
- getDNSDomain());
- pw.append("export TF_CONFIG=\"" + tfConfigEnv + "\"\n");
- }
-
- // Print launch command
- if (taskType.equals(TaskType.WORKER) || taskType.equals(
- TaskType.PRIMARY_WORKER)) {
- pw.append(parameters.getWorkerLaunchCmd() + '\n');
-
- if (SubmarineLogs.isVerbose()) {
- LOG.info(
- "Worker command =[" + parameters.getWorkerLaunchCmd() + "]");
- }
- } else if (taskType.equals(TaskType.PS)) {
- pw.append(parameters.getPSLaunchCmd() + '\n');
-
- if (SubmarineLogs.isVerbose()) {
- LOG.info("PS command =[" + parameters.getPSLaunchCmd() + "]");
- }
- }
- }
- } finally {
- pw.close();
- }
- return file.getAbsolutePath();
- }
-
- private String getScriptFileName(TaskType taskType) {
- return "run-" + taskType.name() + ".sh";
- }
-
- private File findFileOnClassPath(final String fileName) {
- final String classpath = System.getProperty("java.class.path");
- final String pathSeparator = System.getProperty("path.separator");
- final StringTokenizer tokenizer = new StringTokenizer(classpath,
- pathSeparator);
-
- while (tokenizer.hasMoreTokens()) {
- final String pathElement = tokenizer.nextToken();
- final File directoryOrJar = new File(pathElement);
- final File absoluteDirectoryOrJar = directoryOrJar.getAbsoluteFile();
- if (absoluteDirectoryOrJar.isFile()) {
- final File target = new File(absoluteDirectoryOrJar.getParent(),
- fileName);
- if (target.exists()) {
- return target;
- }
- } else{
- final File target = new File(directoryOrJar, fileName);
- if (target.exists()) {
- return target;
- }
- }
- }
-
- return null;
- }
-
- private void uploadToRemoteFileAndLocalizeToContainerWorkDir(Path stagingDir,
- String fileToUpload, String destFilename, Component comp)
- throws IOException {
- Path uploadedFilePath = uploadToRemoteFile(stagingDir, fileToUpload);
- locateRemoteFileToContainerWorkDir(destFilename, comp, uploadedFilePath);
- }
-
- private void locateRemoteFileToContainerWorkDir(String destFilename,
- Component comp, Path uploadedFilePath)
- throws IOException {
- FileSystem fs = FileSystem.get(clientContext.getYarnConfig());
-
- FileStatus fileStatus = fs.getFileStatus(uploadedFilePath);
- LOG.info("Uploaded file path = " + fileStatus.getPath());
-
- // Set it to component's files list
- comp.getConfiguration().getFiles().add(new ConfigFile().srcFile(
- fileStatus.getPath().toUri().toString()).destFile(destFilename)
- .type(ConfigFile.TypeEnum.STATIC));
- }
-
- private Path uploadToRemoteFile(Path stagingDir, String fileToUpload) throws
- IOException {
- FileSystem fs = clientContext.getRemoteDirectoryManager()
- .getDefaultFileSystem();
-
- // Upload to remote FS under staging area
- File localFile = new File(fileToUpload);
- if (!localFile.exists()) {
- throw new FileNotFoundException(
- "Trying to upload file=" + localFile.getAbsolutePath()
- + " to remote, but couldn't find local file.");
- }
- String filename = new File(fileToUpload).getName();
-
- Path uploadedFilePath = new Path(stagingDir, filename);
- if (!uploadedFiles.contains(uploadedFilePath)) {
- if (SubmarineLogs.isVerbose()) {
- LOG.info("Copying local file=" + fileToUpload + " to remote="
- + uploadedFilePath);
- }
- fs.copyFromLocalFile(new Path(fileToUpload), uploadedFilePath);
- uploadedFiles.add(uploadedFilePath);
- }
- return uploadedFilePath;
- }
-
- private void setPermission(Path destPath, FsPermission permission) throws
- IOException {
- FileSystem fs = FileSystem.get(clientContext.getYarnConfig());
- fs.setPermission(destPath, new FsPermission(permission));
- }
-
- private void handleLaunchCommand(RunJobParameters parameters,
- TaskType taskType, Component component) throws IOException {
- // Get staging area directory
- Path stagingDir =
- clientContext.getRemoteDirectoryManager().getJobStagingArea(
- parameters.getName(), true);
-
- // Generate script file in the local disk
- String localScriptFile = generateCommandLaunchScript(parameters, taskType,
- component);
- String destScriptFileName = getScriptFileName(taskType);
- uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir, localScriptFile,
- destScriptFileName, component);
-
- component.setLaunchCommand("./" + destScriptFileName);
- componentToLocalLaunchScriptPath.put(taskType.getComponentName(),
- localScriptFile);
- }
-
- private String getLastNameFromPath(String srcFileStr) {
- return new Path(srcFileStr).getName();
- }
-
- /**
- * May download a remote uri(file/dir) and zip.
- * Skip download if local dir
- * Remote uri can be a local dir(won't download)
- * or remote HDFS dir, s3 dir/file .etc
- * */
- private String mayDownloadAndZipIt(String remoteDir, String zipFileName,
- boolean doZip)
- throws IOException {
- RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
- //Append original modification time and size to zip file name
- String suffix;
- String srcDir = remoteDir;
- String zipDirPath =
- System.getProperty("java.io.tmpdir") + "/" + zipFileName;
- boolean needDeleteTempDir = false;
- if (rdm.isRemote(remoteDir)) {
- //Append original modification time and size to zip file name
- FileStatus status = rdm.getRemoteFileStatus(new Path(remoteDir));
- suffix = "_" + status.getModificationTime()
- + "-" + rdm.getRemoteFileSize(remoteDir);
- // Download them to temp dir
- boolean downloaded = rdm.copyRemoteToLocal(remoteDir, zipDirPath);
- if (!downloaded) {
- throw new IOException("Failed to download files from "
- + remoteDir);
- }
- LOG.info("Downloaded remote: {} to local: {}", remoteDir, zipDirPath);
- srcDir = zipDirPath;
- needDeleteTempDir = true;
- } else {
- File localDir = new File(remoteDir);
- suffix = "_" + localDir.lastModified()
- + "-" + localDir.length();
- }
- if (!doZip) {
- return srcDir;
- }
- // zip a local dir
- String zipFileUri = zipDir(srcDir, zipDirPath + suffix + ".zip");
- // delete downloaded temp dir
- if (needDeleteTempDir) {
- deleteFiles(srcDir);
- }
- return zipFileUri;
- }
-
- @VisibleForTesting
- public String zipDir(String srcDir, String dstFile) throws IOException {
- FileOutputStream fos = new FileOutputStream(dstFile);
- ZipOutputStream zos = new ZipOutputStream(fos);
- File srcFile = new File(srcDir);
- LOG.info("Compressing {}", srcDir);
- addDirToZip(zos, srcFile, srcFile);
- // close the ZipOutputStream
- zos.close();
- LOG.info("Compressed {} to {}", srcDir, dstFile);
- return dstFile;
- }
-
- private void deleteFiles(String localUri) {
- boolean success = FileUtil.fullyDelete(new File(localUri));
- if (!success) {
- LOG.warn("Fail to delete {}", localUri);
- }
- LOG.info("Deleted {}", localUri);
- }
-
- private void addDirToZip(ZipOutputStream zos, File srcFile, File base)
- throws IOException {
- File[] files = srcFile.listFiles();
- if (null == files) {
- return;
- }
- FileInputStream fis = null;
- for (int i = 0; i < files.length; i++) {
- // if it's directory, add recursively
- if (files[i].isDirectory()) {
- addDirToZip(zos, files[i], base);
- continue;
- }
- byte[] buffer = new byte[1024];
- try {
- fis = new FileInputStream(files[i]);
- String name = base.toURI().relativize(files[i].toURI()).getPath();
- LOG.info(" Zip adding: " + name);
- zos.putNextEntry(new ZipEntry(name));
- int length;
- while ((length = fis.read(buffer)) > 0) {
- zos.write(buffer, 0, length);
- }
- zos.flush();
- } finally {
- if (fis != null) {
- fis.close();
- }
- zos.closeEntry();
- }
- }
- }
-
- private void addWorkerComponent(Service service,
- RunJobParameters parameters, TaskType taskType) throws IOException {
- Component workerComponent = new Component();
- addCommonEnvironments(workerComponent, taskType);
-
- workerComponent.setName(taskType.getComponentName());
-
- if (taskType.equals(TaskType.PRIMARY_WORKER)) {
- workerComponent.setNumberOfContainers(1L);
- workerComponent.getConfiguration().setProperty(
- CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true");
- } else{
- workerComponent.setNumberOfContainers(
- (long) parameters.getNumWorkers() - 1);
- }
-
- if (parameters.getWorkerDockerImage() != null) {
- workerComponent.setArtifact(
- getDockerArtifact(parameters.getWorkerDockerImage()));
- }
-
- workerComponent.setResource(
- getServiceResourceFromYarnResource(parameters.getWorkerResource()));
- handleLaunchCommand(parameters, taskType, workerComponent);
- workerComponent.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
- service.addComponent(workerComponent);
- }
-
- // Handle worker and primary_worker.
- private void addWorkerComponents(Service service, RunJobParameters parameters)
- throws IOException {
- addWorkerComponent(service, parameters, TaskType.PRIMARY_WORKER);
-
- if (parameters.getNumWorkers() > 1) {
- addWorkerComponent(service, parameters, TaskType.WORKER);
- }
- }
-
- private void appendToEnv(Service service, String key, String value,
- String delim) {
- Map<String, String> env = service.getConfiguration().getEnv();
- if (!env.containsKey(key)) {
- env.put(key, value);
- } else {
- if (!value.isEmpty()) {
- String existingValue = env.get(key);
- if (!existingValue.endsWith(delim)) {
- env.put(key, existingValue + delim + value);
- } else {
- env.put(key, existingValue + value);
- }
- }
- }
- }
-
- private void handleServiceEnvs(Service service, RunJobParameters parameters) {
- if (parameters.getEnvars() != null) {
- for (String envarPair : parameters.getEnvars()) {
- String key, value;
- if (envarPair.contains("=")) {
- int idx = envarPair.indexOf('=');
- key = envarPair.substring(0, idx);
- value = envarPair.substring(idx + 1);
- } else{
- // No "=" found so use the whole key
- key = envarPair;
- value = "";
- }
- appendToEnv(service, key, value, ":");
- }
- }
-
- // Append other configs like /etc/passwd, /etc/krb5.conf
- appendToEnv(service, "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS",
- "/etc/passwd:/etc/passwd:ro", ",");
-
- String authenication = clientContext.getYarnConfig().get(
- HADOOP_SECURITY_AUTHENTICATION);
- if (authenication != null && authenication.equals("kerberos")) {
- appendToEnv(service, "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS",
- "/etc/krb5.conf:/etc/krb5.conf:ro", ",");
- }
- }
-
- private Artifact getDockerArtifact(String dockerImageName) {
- return new Artifact().type(Artifact.TypeEnum.DOCKER).id(dockerImageName);
- }
-
- private void handleQuicklinks(RunJobParameters runJobParameters)
- throws IOException {
- List<Quicklink> quicklinks = runJobParameters.getQuicklinks();
- if (null != quicklinks && !quicklinks.isEmpty()) {
- for (Quicklink ql : quicklinks) {
- // Make sure it is a valid instance name
- String instanceName = ql.getComponentInstanceName();
- boolean found = false;
-
- for (Component comp : serviceSpec.getComponents()) {
- for (int i = 0; i < comp.getNumberOfContainers(); i++) {
- String possibleInstanceName = comp.getName() + "-" + i;
- if (possibleInstanceName.equals(instanceName)) {
- found = true;
- break;
- }
- }
- }
-
- if (!found) {
- throw new IOException(
- "Couldn't find a component instance = " + instanceName
- + " while adding quicklink");
- }
-
- String link = ql.getProtocol() + YarnServiceUtils.getDNSName(
- serviceSpec.getName(), instanceName, getUserName(), getDNSDomain(),
- ql.getPort());
- YarnServiceUtils.addQuicklink(serviceSpec, ql.getLabel(), link);
- }
- }
- }
-
- private Service createServiceByParameters(RunJobParameters parameters)
- throws IOException {
- componentToLocalLaunchScriptPath.clear();
- serviceSpec = new Service();
- serviceSpec.setName(parameters.getName());
- serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
- serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
- handleKerberosPrincipal(parameters);
-
- handleServiceEnvs(serviceSpec, parameters);
-
- handleLocalizations(parameters);
-
- if (parameters.getNumWorkers() > 0) {
- addWorkerComponents(serviceSpec, parameters);
- }
-
- if (parameters.getNumPS() > 0) {
- Component psComponent = new Component();
- psComponent.setName(TaskType.PS.getComponentName());
- addCommonEnvironments(psComponent, TaskType.PS);
- psComponent.setNumberOfContainers((long) parameters.getNumPS());
- psComponent.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
- psComponent.setResource(
- getServiceResourceFromYarnResource(parameters.getPsResource()));
-
- // Override global docker image if needed.
- if (parameters.getPsDockerImage() != null) {
- psComponent.setArtifact(
- getDockerArtifact(parameters.getPsDockerImage()));
- }
- handleLaunchCommand(parameters, TaskType.PS, psComponent);
- serviceSpec.addComponent(psComponent);
- }
-
- if (parameters.isTensorboardEnabled()) {
- Component tbComponent = new Component();
- tbComponent.setName(TaskType.TENSORBOARD.getComponentName());
- addCommonEnvironments(tbComponent, TaskType.TENSORBOARD);
- tbComponent.setNumberOfContainers(1L);
- tbComponent.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
- tbComponent.setResource(getServiceResourceFromYarnResource(
- parameters.getTensorboardResource()));
- if (parameters.getTensorboardDockerImage() != null) {
- tbComponent.setArtifact(
- getDockerArtifact(parameters.getTensorboardDockerImage()));
- }
-
- handleLaunchCommand(parameters, TaskType.TENSORBOARD, tbComponent);
-
- // Add tensorboard to quicklink
- String tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
- parameters.getName(),
- TaskType.TENSORBOARD.getComponentName() + "-" + 0, getUserName(),
- getDNSDomain(), 6006);
- LOG.info("Link to tensorboard:" + tensorboardLink);
- serviceSpec.addComponent(tbComponent);
-
- YarnServiceUtils.addQuicklink(serviceSpec, TENSORBOARD_QUICKLINK_LABEL,
- tensorboardLink);
- }
-
- // After all components added, handle quicklinks
- handleQuicklinks(parameters);
-
- return serviceSpec;
- }
-
- /**
- * Localize dependencies for all containers.
- * If remoteUri is a local directory,
- * we'll zip it and upload it to the HDFS staging dir.
- * If remoteUri is a remote directory, we'll download it,
- * zip it and upload it to HDFS.
- * If localFilePath is ".", we'll use remoteUri's file/dir name.
- */
- private void handleLocalizations(RunJobParameters parameters)
- throws IOException {
- // Handle localizations
- Path stagingDir =
- clientContext.getRemoteDirectoryManager().getJobStagingArea(
- parameters.getName(), true);
- List<Localization> locs = parameters.getLocalizations();
- String remoteUri;
- String containerLocalPath;
- RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
-
- // Check to fail fast
- for (Localization loc : locs) {
- remoteUri = loc.getRemoteUri();
- Path resourceToLocalize = new Path(remoteUri);
- // Check if remoteUri exists
- if (rdm.isRemote(remoteUri)) {
- // check if exists
- if (!rdm.existsRemoteFile(resourceToLocalize)) {
- throw new FileNotFoundException(
- "File " + remoteUri + " doesn't exists.");
- }
- } else {
- // Check if exists
- File localFile = new File(remoteUri);
- if (!localFile.exists()) {
- throw new FileNotFoundException(
- "File " + remoteUri + " doesn't exists.");
- }
- }
- // check remote file size
- validFileSize(remoteUri);
- }
- // Start download remote if needed and upload to HDFS
- for (Localization loc : locs) {
- remoteUri = loc.getRemoteUri();
- containerLocalPath = loc.getLocalPath();
- String srcFileStr = remoteUri;
- ConfigFile.TypeEnum destFileType = ConfigFile.TypeEnum.STATIC;
- Path resourceToLocalize = new Path(remoteUri);
- boolean needUploadToHDFS = true;
-
- // Special handling when remoteUri is a directory.
- boolean needDeleteTempFile = false;
- if (rdm.isDir(remoteUri)) {
- destFileType = ConfigFile.TypeEnum.ARCHIVE;
- srcFileStr = mayDownloadAndZipIt(
- remoteUri, getLastNameFromPath(srcFileStr), true);
- } else if (rdm.isRemote(remoteUri)) {
- if (!needHdfs(remoteUri)) {
- // Non-HDFS remote URI and not a directory, so no need to zip
- srcFileStr = mayDownloadAndZipIt(
- remoteUri, getLastNameFromPath(srcFileStr), false);
- needDeleteTempFile = true;
- } else {
- // HDFS file, no need to upload
- needUploadToHDFS = false;
- }
- }
-
- // Upload file to HDFS
- if (needUploadToHDFS) {
- resourceToLocalize = uploadToRemoteFile(stagingDir, srcFileStr);
- }
- if (needDeleteTempFile) {
- deleteFiles(srcFileStr);
- }
- // Remove .zip from zipped dir name
- if (destFileType == ConfigFile.TypeEnum.ARCHIVE
- && srcFileStr.endsWith(".zip")) {
- // Delete local zip file
- deleteFiles(srcFileStr);
- int suffixIndex = srcFileStr.lastIndexOf('_');
- srcFileStr = srcFileStr.substring(0, suffixIndex);
- }
- // If provided, use the name of local uri
- if (!containerLocalPath.equals(".")
- && !containerLocalPath.equals("./")) {
- // Change the YARN localized file name to what'll be used in the container
- srcFileStr = getLastNameFromPath(containerLocalPath);
- }
- String localizedName = getLastNameFromPath(srcFileStr);
- LOG.info("The file/dir to be localized is {}",
- resourceToLocalize.toString());
- LOG.info("Its localized file name will be {}", localizedName);
- serviceSpec.getConfiguration().getFiles().add(new ConfigFile().srcFile(
- resourceToLocalize.toUri().toString()).destFile(localizedName)
- .type(destFileType));
- // set mounts
- // if mount path is absolute, just use it.
- // if relative, no need to mount explicitly
- if (containerLocalPath.startsWith("/")) {
- String mountStr = getLastNameFromPath(srcFileStr) + ":"
- + containerLocalPath + ":" + loc.getMountPermission();
- LOG.info("Add bind-mount string {}", mountStr);
- appendToEnv(serviceSpec, "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS",
- mountStr, ",");
- }
- }
- }
-
- private void validFileSize(String uri) throws IOException {
- RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
- long actualSizeByte;
- String locationType = "Local";
- if (rdm.isRemote(uri)) {
- actualSizeByte = clientContext.getRemoteDirectoryManager()
- .getRemoteFileSize(uri);
- locationType = "Remote";
- } else {
- actualSizeByte = FileUtil.getDU(new File(uri));
- }
- long maxFileSizeMB = clientContext.getSubmarineConfig()
- .getLong(SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB,
- SubmarineConfiguration.DEFAULT_MAX_ALLOWED_REMOTE_URI_SIZE_MB);
- LOG.info("{} fie/dir: {}, size(Byte):{},"
- + " Allowed max file/dir size: {}",
- locationType, uri, actualSizeByte, maxFileSizeMB * 1024 * 1024);
-
- if (actualSizeByte > maxFileSizeMB * 1024 * 1024) {
- throw new IOException(uri + " size(Byte): "
- + actualSizeByte + " exceeds configured max size: "
- + maxFileSizeMB * 1024 * 1024);
- }
- }
-
- private String generateServiceSpecFile(Service service) throws IOException {
- File serviceSpecFile = File.createTempFile(service.getName(), ".json");
- String buffer = jsonSerDeser.toJson(service);
- Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),
- "UTF-8");
- PrintWriter pw = new PrintWriter(w);
- try {
- pw.append(buffer);
- } finally {
- pw.close();
- }
- return serviceSpecFile.getAbsolutePath();
- }
-
- private void handleKerberosPrincipal(RunJobParameters parameters) throws
- IOException {
- if(StringUtils.isNotBlank(parameters.getKeytab()) && StringUtils
- .isNotBlank(parameters.getPrincipal())) {
- String keytab = parameters.getKeytab();
- String principal = parameters.getPrincipal();
- if(parameters.isDistributeKeytab()) {
- Path stagingDir =
- clientContext.getRemoteDirectoryManager().getJobStagingArea(
- parameters.getName(), true);
- Path remoteKeytabPath = uploadToRemoteFile(stagingDir, keytab);
- //only the owner has read access
- setPermission(remoteKeytabPath,
- FsPermission.createImmutable((short)Integer.parseInt("400", 8)));
- serviceSpec.setKerberosPrincipal(new KerberosPrincipal().keytab(
- remoteKeytabPath.toString()).principalName(principal));
- } else {
- if(!keytab.startsWith("file")) {
- keytab = "file://" + keytab;
- }
- serviceSpec.setKerberosPrincipal(new KerberosPrincipal().keytab(
- keytab).principalName(principal));
- }
- }
- }
-
/**
* {@inheritDoc}
*/
@Override
public ApplicationId submitJob(RunJobParameters parameters)
throws IOException, YarnException {
- createServiceByParameters(parameters);
- String serviceSpecFile = generateServiceSpecFile(serviceSpec);
+ FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
+ HadoopEnvironmentSetup hadoopEnvSetup =
+ new HadoopEnvironmentSetup(clientContext, fsOperations);
- AppAdminClient appAdminClient = YarnServiceUtils.createServiceClient(
- clientContext.getYarnConfig());
+ Service serviceSpec = createTensorFlowServiceSpec(parameters,
+ fsOperations, hadoopEnvSetup);
+ String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec);
+
+ AppAdminClient appAdminClient =
+ YarnServiceUtils.createServiceClient(clientContext.getYarnConfig());
int code = appAdminClient.actionLaunch(serviceSpecFile,
serviceSpec.getName(), null, null);
- if(code != EXIT_SUCCESS) {
- throw new YarnException("Fail to launch application with exit code:" +
- code);
+ if (code != EXIT_SUCCESS) {
+ throw new YarnException(
+ "Failed to launch application with exit code: " + code);
}
String appStatus=appAdminClient.getStatusString(serviceSpec.getName());
@@ -896,13 +97,24 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
return appid;
}
- @VisibleForTesting
- public Service getServiceSpec() {
- return serviceSpec;
+ private Service createTensorFlowServiceSpec(RunJobParameters parameters,
+ FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
+ throws IOException {
+ LaunchCommandFactory launchCommandFactory =
+ new LaunchCommandFactory(hadoopEnvSetup, parameters,
+ clientContext.getYarnConfig());
+ Localizer localizer = new Localizer(fsOperations,
+ clientContext.getRemoteDirectoryManager(), parameters);
+ TensorFlowServiceSpec tensorFlowServiceSpec = new TensorFlowServiceSpec(
+ parameters, this.clientContext, fsOperations, launchCommandFactory,
+ localizer);
+
+ serviceWrapper = tensorFlowServiceSpec.create();
+ return serviceWrapper.getService();
}
@VisibleForTesting
- public Map<String, String> getComponentToLocalLaunchScriptPath() {
- return componentToLocalLaunchScriptPath;
+ public ServiceWrapper getServiceWrapper() {
+ return serviceWrapper;
}
}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceUtils.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceUtils.java
index c599fc9..352fd79 100644
--- a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceUtils.java
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceUtils.java
@@ -17,33 +17,27 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.client.api.AppAdminClient;
-import org.apache.hadoop.yarn.service.api.records.Service;
-import org.apache.hadoop.yarn.submarine.common.Envs;
-import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.HashMap;
-import java.util.Map;
import static org.apache.hadoop.yarn.client.api.AppAdminClient.DEFAULT_TYPE;
-public class YarnServiceUtils {
- private static final Logger LOG =
- LoggerFactory.getLogger(YarnServiceUtils.class);
+/**
+ * This class contains some static helper methods to query DNS data
+ * based on the provided parameters.
+ */
+public final class YarnServiceUtils {
+ private YarnServiceUtils() {
+ }
// This will be true only in UT.
private static AppAdminClient stubServiceClient = null;
- public static AppAdminClient createServiceClient(
+ static AppAdminClient createServiceClient(
Configuration yarnConfiguration) {
if (stubServiceClient != null) {
return stubServiceClient;
}
- AppAdminClient serviceClient = AppAdminClient.createAppAdminClient(
- DEFAULT_TYPE, yarnConfiguration);
- return serviceClient;
+ return AppAdminClient.createAppAdminClient(DEFAULT_TYPE, yarnConfiguration);
}
@VisibleForTesting
@@ -57,77 +51,9 @@ public class YarnServiceUtils {
domain, port);
}
- private static String getDNSNameCommonSuffix(String serviceName,
+ public static String getDNSNameCommonSuffix(String serviceName,
String userName, String domain, int port) {
return "." + serviceName + "." + userName + "." + domain + ":" + port;
}
- public static String getTFConfigEnv(String curCommponentName, int nWorkers,
- int nPs, String serviceName, String userName, String domain) {
- String commonEndpointSuffix = getDNSNameCommonSuffix(serviceName, userName,
- domain, 8000);
-
- String json = "{\\\"cluster\\\":{";
-
- String master = getComponentArrayJson("master", 1, commonEndpointSuffix)
- + ",";
- String worker = getComponentArrayJson("worker", nWorkers - 1,
- commonEndpointSuffix) + ",";
- String ps = getComponentArrayJson("ps", nPs, commonEndpointSuffix) + "},";
-
- StringBuilder sb = new StringBuilder();
- sb.append("\\\"task\\\":{");
- sb.append(" \\\"type\\\":\\\"");
- sb.append(curCommponentName);
- sb.append("\\\",");
- sb.append(" \\\"index\\\":");
- sb.append('$');
- sb.append(Envs.TASK_INDEX_ENV + "},");
- String task = sb.toString();
- String environment = "\\\"environment\\\":\\\"cloud\\\"}";
-
- sb = new StringBuilder();
- sb.append(json);
- sb.append(master);
- sb.append(worker);
- sb.append(ps);
- sb.append(task);
- sb.append(environment);
- return sb.toString();
- }
-
- public static void addQuicklink(Service serviceSpec, String label,
- String link) {
- Map<String, String> quicklinks = serviceSpec.getQuicklinks();
- if (null == quicklinks) {
- quicklinks = new HashMap<>();
- serviceSpec.setQuicklinks(quicklinks);
- }
-
- if (SubmarineLogs.isVerbose()) {
- LOG.info("Added quicklink, " + label + "=" + link);
- }
-
- quicklinks.put(label, link);
- }
-
- private static String getComponentArrayJson(String componentName, int count,
- String endpointSuffix) {
- String component = "\\\"" + componentName + "\\\":";
- StringBuilder array = new StringBuilder();
- array.append("[");
- for (int i = 0; i < count; i++) {
- array.append("\\\"");
- array.append(componentName);
- array.append("-");
- array.append(i);
- array.append(endpointSuffix);
- array.append("\\\"");
- if (i != count - 1) {
- array.append(",");
- }
- }
- array.append("]");
- return component + array.toString();
- }
}
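
For reference, a minimal sketch of the DNS names these helpers produce, with
made-up service/user/domain values (the domain normally comes from the
hadoop.registry.dns.domain-name setting):

import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;

public class DnsNameExample {
  public static void main(String[] args) {
    // Common suffix shared by every component instance of a service:
    // ".my-job.alice.example.com:8000"
    String suffix = YarnServiceUtils.getDNSNameCommonSuffix(
        "my-job", "alice", "example.com", 8000);
    // Full instance address, e.g. "worker-0.my-job.alice.example.com:8000"
    System.out.println("worker-0" + suffix);
  }
}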
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommand.java
new file mode 100644
index 0000000..cd86e40
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommand.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
+
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Abstract base class for Launch command implementations for Services.
+ * Currently we have launch command implementations
+ * for TensorFlow PS, worker and Tensorboard instances.
+ */
+public abstract class AbstractLaunchCommand {
+ private final LaunchScriptBuilder builder;
+
+ public AbstractLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
+ TaskType taskType, Component component, RunJobParameters parameters)
+ throws IOException {
+ Objects.requireNonNull(taskType, "TaskType must not be null!");
+ this.builder = new LaunchScriptBuilder(taskType.name(), hadoopEnvSetup,
+ parameters, component);
+ }
+
+ protected LaunchScriptBuilder getBuilder() {
+ return builder;
+ }
+
+ /**
+ * Subclasses need to define this method and return a valid launch script.
+ * Implementors can utilize the {@link LaunchScriptBuilder} via
+ * the getBuilder method of this class.
+ * @return The contents of a script.
+ * @throws IOException If any IO issue happens.
+ */
+ public abstract String generateLaunchScript() throws IOException;
+
+ /**
+ * Subclasses need to provide the launch command
+ * specific to their service.
+ * Please note that this method should only return the launch command
+ * but not the whole script.
+ * @return The launch command
+ */
+ public abstract String createLaunchCommand();
+
+}
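
Since the commit message calls out readiness for PyTorch, here is a hedged sketch
of how a new framework could plug into the contract above. PyTorchWorkerLaunchCommand,
its package, and its reuse of getWorkerLaunchCmd() are assumptions for illustration,
not part of this commit:

package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;

import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;

import java.io.IOException;

/** Hypothetical PyTorch worker launch command, for illustration only. */
public class PyTorchWorkerLaunchCommand extends AbstractLaunchCommand {
  private final String launchCommand;

  public PyTorchWorkerLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
      TaskType taskType, Component component, RunJobParameters parameters)
      throws IOException {
    super(hadoopEnvSetup, taskType, component, parameters);
    this.launchCommand = parameters.getWorkerLaunchCmd();
  }

  @Override
  public String generateLaunchScript() throws IOException {
    // No TF_CONFIG is needed here; the builder still adds the bash
    // header and the HDFS classpath setup.
    return getBuilder().withLaunchCommand(createLaunchCommand()).build();
  }

  @Override
  public String createLaunchCommand() {
    return launchCommand + '\n';
  }
}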
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchCommandFactory.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchCommandFactory.java
new file mode 100644
index 0000000..572e65a
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchCommandFactory.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Simple factory to create instances of {@link AbstractLaunchCommand}
+ * based on the {@link TaskType}.
+ * All dependencies are passed to this factory that could be required
+ * by any implementor of {@link AbstractLaunchCommand}.
+ */
+public class LaunchCommandFactory {
+ private final HadoopEnvironmentSetup hadoopEnvSetup;
+ private final RunJobParameters parameters;
+ private final Configuration yarnConfig;
+
+ public LaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
+ RunJobParameters parameters, Configuration yarnConfig) {
+ this.hadoopEnvSetup = hadoopEnvSetup;
+ this.parameters = parameters;
+ this.yarnConfig = yarnConfig;
+ }
+
+ public AbstractLaunchCommand createLaunchCommand(TaskType taskType,
+ Component component) throws IOException {
+ Objects.requireNonNull(taskType, "TaskType must not be null!");
+
+ if (taskType == TaskType.WORKER || taskType == TaskType.PRIMARY_WORKER) {
+ return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType,
+ component, parameters, yarnConfig);
+
+ } else if (taskType == TaskType.PS) {
+ return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component,
+ parameters, yarnConfig);
+
+ } else if (taskType == TaskType.TENSORBOARD) {
+ return new TensorBoardLaunchCommand(hadoopEnvSetup, taskType, component,
+ parameters);
+ }
+ throw new IllegalStateException("Unknown task type: " + taskType);
+ }
+}
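
A rough usage sketch of the factory (the wrapper class and method name below are
illustrative; in this commit the factory is driven by the component
implementations):

import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;

import java.io.IOException;

class LaunchCommandFactoryUsage {
  // Resolves the launch command for a worker component and writes its
  // local launch script, returning the script's absolute path.
  static String writeWorkerScript(LaunchCommandFactory factory,
      Component component) throws IOException {
    AbstractLaunchCommand command =
        factory.createLaunchCommand(TaskType.WORKER, component);
    return command.generateLaunchScript();
  }
}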
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchScriptBuilder.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchScriptBuilder.java
new file mode 100644
index 0000000..d24a0a7
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchScriptBuilder.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
+
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+/**
+ * This class is a builder to conveniently create launch scripts.
+ * All dependencies are provided via the constructor except
+ * the launch command.
+ */
+public class LaunchScriptBuilder {
+ private static final Logger LOG = LoggerFactory.getLogger(
+ LaunchScriptBuilder.class);
+
+ private final File file;
+ private final HadoopEnvironmentSetup hadoopEnvSetup;
+ private final RunJobParameters parameters;
+ private final Component component;
+ private final OutputStreamWriter writer;
+ private final StringBuilder scriptBuffer;
+ private String launchCommand;
+
+ LaunchScriptBuilder(String namePrefix,
+ HadoopEnvironmentSetup hadoopEnvSetup, RunJobParameters parameters,
+ Component component) throws IOException {
+ this.file = File.createTempFile(namePrefix + "-launch-script", ".sh");
+ this.hadoopEnvSetup = hadoopEnvSetup;
+ this.parameters = parameters;
+ this.component = component;
+ this.writer = new OutputStreamWriter(new FileOutputStream(file), UTF_8);
+ this.scriptBuffer = new StringBuilder();
+ }
+
+ public void append(String s) {
+ scriptBuffer.append(s);
+ }
+
+ public LaunchScriptBuilder withLaunchCommand(String command) {
+ this.launchCommand = command;
+ return this;
+ }
+
+ public String build() throws IOException {
+ if (launchCommand != null) {
+ append(launchCommand);
+ } else {
+ LOG.warn("LaunchScript object was null!");
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("LaunchScript's Builder object: {}", this);
+ }
+ }
+
+ try (PrintWriter pw = new PrintWriter(writer)) {
+ writeBashHeader(pw);
+ hadoopEnvSetup.addHdfsClassPath(parameters, pw, component);
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Appending command to launch script: {}", scriptBuffer);
+ }
+ pw.append(scriptBuffer);
+ }
+ return file.getAbsolutePath();
+ }
+
+ @Override
+ public String toString() {
+ return "LaunchScriptBuilder{" +
+ "file=" + file +
+ ", hadoopEnvSetup=" + hadoopEnvSetup +
+ ", parameters=" + parameters +
+ ", component=" + component +
+ ", writer=" + writer +
+ ", scriptBuffer=" + scriptBuffer +
+ ", launchCommand='" + launchCommand + '\'' +
+ '}';
+ }
+
+ private void writeBashHeader(PrintWriter pw) {
+ pw.append("#!/bin/bash\n");
+ }
+}
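
The ordering inside build() fixes the script layout: bash header first, then the
Hadoop/HDFS environment, then the buffered content ending with the launch command.
For a simple worker command the generated file looks roughly like this (environment
section abbreviated, command illustrative):

#!/bin/bash
...exports written by HadoopEnvironmentSetup.addHdfsClassPath()...
python train.py --steps 1000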
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/package-info.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/package-info.java
new file mode 100644
index 0000000..a257204
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes to produce launch commands and scripts.
+ */
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowCommons.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowCommons.java
new file mode 100644
index 0000000..ea735c9
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowCommons.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.common.Envs;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
+
+import java.util.Map;
+
+/**
+ * This class has common helper methods for TensorFlow.
+ */
+public final class TensorFlowCommons {
+ private TensorFlowCommons() {
+ throw new UnsupportedOperationException("This class should not be " +
+ "instantiated!");
+ }
+
+ public static void addCommonEnvironments(Component component,
+ TaskType taskType) {
+ Map<String, String> envs = component.getConfiguration().getEnv();
+ envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
+ envs.put(Envs.TASK_TYPE_ENV, taskType.name());
+ }
+
+ public static String getUserName() {
+ return System.getProperty("user.name");
+ }
+
+ public static String getDNSDomain(Configuration yarnConfig) {
+ return yarnConfig.get("hadoop.registry.dns.domain-name");
+ }
+
+ public static String getScriptFileName(TaskType taskType) {
+ return "run-" + taskType.name() + ".sh";
+ }
+
+ public static String getTFConfigEnv(String componentName, int nWorkers,
+ int nPs, String serviceName, String userName, String domain) {
+ String commonEndpointSuffix = YarnServiceUtils
+ .getDNSNameCommonSuffix(serviceName, userName, domain, 8000);
+
+ String json = "{\\\"cluster\\\":{";
+
+ String master = getComponentArrayJson("master", 1, commonEndpointSuffix)
+ + ",";
+ String worker = getComponentArrayJson("worker", nWorkers - 1,
+ commonEndpointSuffix) + ",";
+ String ps = getComponentArrayJson("ps", nPs, commonEndpointSuffix) + "},";
+
+ StringBuilder sb = new StringBuilder();
+ sb.append("\\\"task\\\":{");
+ sb.append(" \\\"type\\\":\\\"");
+ sb.append(componentName);
+ sb.append("\\\",");
+ sb.append(" \\\"index\\\":");
+ sb.append('$');
+ sb.append(Envs.TASK_INDEX_ENV + "},");
+ String task = sb.toString();
+ String environment = "\\\"environment\\\":\\\"cloud\\\"}";
+
+ sb = new StringBuilder();
+ sb.append(json);
+ sb.append(master);
+ sb.append(worker);
+ sb.append(ps);
+ sb.append(task);
+ sb.append(environment);
+ return sb.toString();
+ }
+
+ private static String getComponentArrayJson(String componentName, int count,
+ String endpointSuffix) {
+ String component = "\\\"" + componentName + "\\\":";
+ StringBuilder array = new StringBuilder();
+ array.append("[");
+ for (int i = 0; i < count; i++) {
+ array.append("\\\"");
+ array.append(componentName);
+ array.append("-");
+ array.append(i);
+ array.append(endpointSuffix);
+ array.append("\\\"");
+ if (i != count - 1) {
+ array.append(",");
+ }
+ }
+ array.append("]");
+ return component + array.toString();
+ }
+}
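
To make getTFConfigEnv() concrete: for a service my-job run by user alice with DNS
domain example.com, two workers and one PS (all values illustrative), the escaped
string resolves inside the container to a TF_CONFIG of roughly this shape, with the
index filled in from the variable named by Envs.TASK_INDEX_ENV:

{"cluster":{
   "master":["master-0.my-job.alice.example.com:8000"],
   "worker":["worker-0.my-job.alice.example.com:8000"],
   "ps":["ps-0.my-job.alice.example.com:8000"]},
 "task":{"type":"worker","index":0},
 "environment":"cloud"}

Note that only nWorkers - 1 entries appear under "worker": the first worker is
published as master-0.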
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowServiceSpec.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowServiceSpec.java
new file mode 100644
index 0000000..815a41a
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/TensorFlowServiceSpec.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceSpec;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowPsComponent;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
+import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
+import org.apache.hadoop.yarn.submarine.utils.Localizer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL;
+import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
+import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
+
+/**
+ * This class contains all the logic to create an instance
+ * of a {@link Service} object for TensorFlow.
+ * Worker, PS and TensorBoard components are added to the Service
+ * based on the value of the received {@link RunJobParameters}.
+ */
+public class TensorFlowServiceSpec implements ServiceSpec {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(TensorFlowServiceSpec.class);
+
+ private final RemoteDirectoryManager remoteDirectoryManager;
+
+ private final RunJobParameters parameters;
+ private final Configuration yarnConfig;
+ private final FileSystemOperations fsOperations;
+ private final LaunchCommandFactory launchCommandFactory;
+ private final Localizer localizer;
+
+ public TensorFlowServiceSpec(RunJobParameters parameters,
+ ClientContext clientContext, FileSystemOperations fsOperations,
+ LaunchCommandFactory launchCommandFactory, Localizer localizer) {
+ this.parameters = parameters;
+ this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
+ this.yarnConfig = clientContext.getYarnConfig();
+ this.fsOperations = fsOperations;
+ this.launchCommandFactory = launchCommandFactory;
+ this.localizer = localizer;
+ }
+
+ @Override
+ public ServiceWrapper create() throws IOException {
+ ServiceWrapper serviceWrapper = createServiceSpecWrapper();
+
+ if (parameters.getNumWorkers() > 0) {
+ addWorkerComponents(serviceWrapper);
+ }
+
+ if (parameters.getNumPS() > 0) {
+ addPsComponent(serviceWrapper);
+ }
+
+ if (parameters.isTensorboardEnabled()) {
+ createTensorBoardComponent(serviceWrapper);
+ }
+
+ // After all components added, handle quicklinks
+ handleQuicklinks(serviceWrapper.getService());
+
+ return serviceWrapper;
+ }
+
+ private ServiceWrapper createServiceSpecWrapper() throws IOException {
+ Service serviceSpec = new Service();
+ serviceSpec.setName(parameters.getName());
+ serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
+ serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
+
+ KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
+ .create(fsOperations, remoteDirectoryManager, parameters);
+ if (kerberosPrincipal != null) {
+ serviceSpec.setKerberosPrincipal(kerberosPrincipal);
+ }
+
+ handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
+ localizer.handleLocalizations(serviceSpec);
+ return new ServiceWrapper(serviceSpec);
+ }
+
+ private void createTensorBoardComponent(ServiceWrapper serviceWrapper)
+ throws IOException {
+ TensorBoardComponent tbComponent = new TensorBoardComponent(fsOperations,
+ remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
+ serviceWrapper.addComponent(tbComponent);
+
+ addQuicklink(serviceWrapper.getService(), TENSORBOARD_QUICKLINK_LABEL,
+ tbComponent.getTensorboardLink());
+ }
+
+ private static void addQuicklink(Service serviceSpec, String label,
+ String link) {
+ Map<String, String> quicklinks = serviceSpec.getQuicklinks();
+ if (quicklinks == null) {
+ quicklinks = new HashMap<>();
+ serviceSpec.setQuicklinks(quicklinks);
+ }
+
+ if (SubmarineLogs.isVerbose()) {
+ LOG.info("Added quicklink, " + label + "=" + link);
+ }
+
+ quicklinks.put(label, link);
+ }
+
+ private void handleQuicklinks(Service serviceSpec)
+ throws IOException {
+ List<Quicklink> quicklinks = parameters.getQuicklinks();
+ if (quicklinks != null && !quicklinks.isEmpty()) {
+ for (Quicklink ql : quicklinks) {
+ // Make sure it is a valid instance name
+ String instanceName = ql.getComponentInstanceName();
+ boolean found = false;
+
+ for (Component comp : serviceSpec.getComponents()) {
+ for (int i = 0; i < comp.getNumberOfContainers(); i++) {
+ String possibleInstanceName = comp.getName() + "-" + i;
+ if (possibleInstanceName.equals(instanceName)) {
+ found = true;
+ break;
+ }
+ }
+ }
+
+ if (!found) {
+ throw new IOException(
+ "Couldn't find component instance " + instanceName
+ + " while adding quicklink");
+ }
+
+ String link = ql.getProtocol()
+ + YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
+ getUserName(), getDNSDomain(yarnConfig), ql.getPort());
+ addQuicklink(serviceSpec, ql.getLabel(), link);
+ }
+ }
+ }
+
+ // Handle worker and primary_worker.
+
+ private void addWorkerComponents(ServiceWrapper serviceWrapper)
+ throws IOException {
+ addWorkerComponent(serviceWrapper, parameters, TaskType.PRIMARY_WORKER);
+
+ if (parameters.getNumWorkers() > 1) {
+ addWorkerComponent(serviceWrapper, parameters, TaskType.WORKER);
+ }
+ }
+
+ private void addWorkerComponent(ServiceWrapper serviceWrapper,
+ RunJobParameters parameters, TaskType taskType) throws IOException {
+ serviceWrapper.addComponent(
+ new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
+ parameters, taskType, launchCommandFactory, yarnConfig));
+ }
+
+ private void addPsComponent(ServiceWrapper serviceWrapper)
+ throws IOException {
+ serviceWrapper.addComponent(
+ new TensorFlowPsComponent(fsOperations, remoteDirectoryManager,
+ launchCommandFactory, parameters, yarnConfig));
+ }
+
+}
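
The instance-name check in handleQuicklinks() only accepts names of the form
<component>-<index>. A small sketch of what that means for a worker component with
two containers (values illustrative):

import org.apache.hadoop.yarn.service.api.records.Component;

class InstanceNameExample {
  public static void main(String[] args) {
    Component worker = new Component();
    worker.setName("worker");
    worker.setNumberOfContainers(2L);
    // Prints "worker-0" and "worker-1" -- the only instance names a
    // quicklink may reference for this component.
    for (int i = 0; i < worker.getNumberOfContainers(); i++) {
      System.out.println(worker.getName() + "-" + i);
    }
  }
}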
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorBoardLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorBoardLaunchCommand.java
new file mode 100644
index 0000000..dcd45c0
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorBoardLaunchCommand.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Launch command implementation for Tensorboard.
+ */
+public class TensorBoardLaunchCommand extends AbstractLaunchCommand {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(TensorBoardLaunchCommand.class);
+ private final String checkpointPath;
+
+ public TensorBoardLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
+ TaskType taskType, Component component, RunJobParameters parameters)
+ throws IOException {
+ super(hadoopEnvSetup, taskType, component, parameters);
+ Objects.requireNonNull(parameters.getCheckpointPath(),
+ "CheckpointPath must not be null as it is part "
+ + "of the tensorboard command!");
+ if (StringUtils.isEmpty(parameters.getCheckpointPath())) {
+ throw new IllegalArgumentException("CheckpointPath must not be empty!");
+ }
+
+ this.checkpointPath = parameters.getCheckpointPath();
+ }
+
+ @Override
+ public String generateLaunchScript() throws IOException {
+ return getBuilder()
+ .withLaunchCommand(createLaunchCommand())
+ .build();
+ }
+
+ @Override
+ public String createLaunchCommand() {
+ String tbCommand = String.format("export LC_ALL=C && tensorboard " +
+ "--logdir=%s%n", checkpointPath);
+ LOG.info("Tensorboard command=" + tbCommand);
+ return tbCommand;
+ }
+}
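
For a checkpoint path of hdfs://default/user/alice/checkpoints (an illustrative
value), createLaunchCommand() produces, once %n is expanded to a newline:

export LC_ALL=C && tensorboard --logdir=hdfs://default/user/alice/checkpoints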
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java
new file mode 100644
index 0000000..07a1811
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * Launch command implementation for
+ * TensorFlow PS and Worker Service components.
+ */
+public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(TensorFlowLaunchCommand.class);
+ private final Configuration yarnConfig;
+ private final boolean distributed;
+ private final int numberOfWorkers;
+ private final int numberOfPS;
+ private final String name;
+ private final TaskType taskType;
+
+ TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
+ TaskType taskType, Component component, RunJobParameters parameters,
+ Configuration yarnConfig) throws IOException {
+ super(hadoopEnvSetup, taskType, component, parameters);
+ this.taskType = taskType;
+ this.name = parameters.getName();
+ this.distributed = parameters.isDistributed();
+ this.numberOfWorkers = parameters.getNumWorkers();
+ this.numberOfPS = parameters.getNumPS();
+ this.yarnConfig = yarnConfig;
+ logReceivedParameters();
+ }
+
+ private void logReceivedParameters() {
+ if (this.numberOfWorkers <= 0) {
+ LOG.warn("Received number of workers: {}", this.numberOfWorkers);
+ }
+ if (this.numberOfPS <= 0) {
+ LOG.warn("Received number of PS: {}", this.numberOfPS);
+ }
+ }
+
+ @Override
+ public String generateLaunchScript() throws IOException {
+ LaunchScriptBuilder builder = getBuilder();
+
+ // When distributed training is required
+ if (distributed) {
+ String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(
+ taskType.getComponentName(), numberOfWorkers,
+ numberOfPS, name,
+ TensorFlowCommons.getUserName(),
+ TensorFlowCommons.getDNSDomain(yarnConfig));
+ String tfConfig = "export TF_CONFIG=\"" + tfConfigEnvValue + "\"\n";
+ builder.append(tfConfig);
+ }
+
+ return builder
+ .withLaunchCommand(createLaunchCommand())
+ .build();
+ }
+}
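
Combining the TF_CONFIG export with the builder's fixed layout, a distributed
worker's launch script comes out roughly as follows (environment section
abbreviated, command illustrative; $TASK_INDEX stands for the variable named by
Envs.TASK_INDEX_ENV):

#!/bin/bash
...HDFS classpath exports from HadoopEnvironmentSetup...
export TF_CONFIG="{\"cluster\":{...},\"task\":{\"type\":\"worker\",\"index\":$TASK_INDEX},\"environment\":\"cloud\"}"
python train.py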
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowPsLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowPsLaunchCommand.java
new file mode 100644
index 0000000..e1aca40
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowPsLaunchCommand.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * Launch command implementation for TensorFlow's PS component.
+ */
+public class TensorFlowPsLaunchCommand extends TensorFlowLaunchCommand {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(TensorFlowPsLaunchCommand.class);
+ private final String launchCommand;
+
+ public TensorFlowPsLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
+ TaskType taskType, Component component, RunJobParameters parameters,
+ Configuration yarnConfig) throws IOException {
+ super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
+ this.launchCommand = parameters.getPSLaunchCmd();
+
+ if (StringUtils.isEmpty(this.launchCommand)) {
+ throw new IllegalArgumentException("LaunchCommand must not be null " +
+ "or empty!");
+ }
+ }
+
+ @Override
+ public String createLaunchCommand() {
+ if (SubmarineLogs.isVerbose()) {
+ LOG.info("PS command =[" + launchCommand + "]");
+ }
+ return launchCommand + '\n';
+ }
+}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowWorkerLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowWorkerLaunchCommand.java
new file mode 100644
index 0000000..734d879
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowWorkerLaunchCommand.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * Launch command implementation for TensorFlow's Worker component.
+ */
+public class TensorFlowWorkerLaunchCommand extends TensorFlowLaunchCommand {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(TensorFlowWorkerLaunchCommand.class);
+ private final String launchCommand;
+
+ public TensorFlowWorkerLaunchCommand(
+ HadoopEnvironmentSetup hadoopEnvSetup, TaskType taskType,
+ Component component, RunJobParameters parameters,
+ Configuration yarnConfig) throws IOException {
+ super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
+ this.launchCommand = parameters.getWorkerLaunchCmd();
+
+ if (StringUtils.isEmpty(this.launchCommand)) {
+ throw new IllegalArgumentException("LaunchCommand must not be null " +
+ "or empty!");
+ }
+ }
+
+ @Override
+ public String createLaunchCommand() {
+ if (SubmarineLogs.isVerbose()) {
+ LOG.info("Worker command =[" + launchCommand + "]");
+ }
+ return launchCommand + '\n';
+ }
+}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/package-info.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/package-info.java
new file mode 100644
index 0000000..f8df3bb
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes to generate TensorFlow launch commands.
+ */
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorBoardComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorBoardComponent.java
new file mode 100644
index 0000000..2b9c1ca
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorBoardComponent.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
+import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
+import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
+
+/**
+ * Component implementation for TensorFlow's TensorBoard process.
+ */
+public class TensorBoardComponent extends AbstractComponent {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(TensorBoardComponent.class);
+
+ public static final String TENSORBOARD_QUICKLINK_LABEL = "Tensorboard";
+ private static final int DEFAULT_PORT = 6006;
+
+ //computed fields
+ private String tensorboardLink;
+
+ public TensorBoardComponent(FileSystemOperations fsOperations,
+ RemoteDirectoryManager remoteDirectoryManager,
+ RunJobParameters parameters,
+ LaunchCommandFactory launchCommandFactory,
+ Configuration yarnConfig) {
+ super(fsOperations, remoteDirectoryManager, parameters,
+ TaskType.TENSORBOARD, yarnConfig, launchCommandFactory);
+ }
+
+ @Override
+ public Component createComponent() throws IOException {
+ Objects.requireNonNull(parameters.getTensorboardResource(),
+ "TensorBoard resource must not be null!");
+
+ Component component = new Component();
+ component.setName(taskType.getComponentName());
+ component.setNumberOfContainers(1L);
+ component.setRestartPolicy(RestartPolicyEnum.NEVER);
+ component.setResource(convertYarnResourceToServiceResource(
+ parameters.getTensorboardResource()));
+
+ if (parameters.getTensorboardDockerImage() != null) {
+ component.setArtifact(
+ getDockerArtifact(parameters.getTensorboardDockerImage()));
+ }
+
+ addCommonEnvironments(component, taskType);
+ generateLaunchCommand(component);
+
+ tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
+ parameters.getName(),
+ taskType.getComponentName() + "-" + 0, getUserName(),
+ getDNSDomain(yarnConfig), DEFAULT_PORT);
+ LOG.info("Link to tensorboard:" + tensorboardLink);
+
+ return component;
+ }
+
+ public String getTensorboardLink() {
+ return tensorboardLink;
+ }
+
+}
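The computed tensorboardLink follows the YARN Service registry DNS scheme, in which a component instance resolves to {instance}.{service}.{user}.{domain}. A sketch of the resulting URL shape with illustrative values (the real pieces come from getUserName() and getDNSDomain(), so the concrete host name may differ):

    // job name "my-job", user "alice", DNS domain "example.com":
    String link = "http://" + "tensorboard-0" + "." + "my-job" + "."
        + "alice" + "." + "example.com" + ":" + 6006;
    // -> http://tensorboard-0.my-job.alice.example.com:6006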
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowPsComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowPsComponent.java
new file mode 100644
index 0000000..c70e132
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowPsComponent.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
+import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
+import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
+
+/**
+ * Component implementation for TensorFlow's PS process.
+ */
+public class TensorFlowPsComponent extends AbstractComponent {
+ public TensorFlowPsComponent(FileSystemOperations fsOperations,
+ RemoteDirectoryManager remoteDirectoryManager,
+ LaunchCommandFactory launchCommandFactory,
+ RunJobParameters parameters,
+ Configuration yarnConfig) {
+ super(fsOperations, remoteDirectoryManager, parameters, TaskType.PS,
+ yarnConfig, launchCommandFactory);
+ }
+
+ @Override
+ public Component createComponent() throws IOException {
+ Objects.requireNonNull(parameters.getPsResource(),
+ "PS resource must not be null!");
+ if (parameters.getNumPS() < 1) {
+ throw new IllegalArgumentException("Number of PS should be at least 1!");
+ }
+
+ Component component = new Component();
+ component.setName(taskType.getComponentName());
+ component.setNumberOfContainers((long) parameters.getNumPS());
+ component.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
+ component.setResource(
+ convertYarnResourceToServiceResource(parameters.getPsResource()));
+
+ // Override global docker image if needed.
+ if (parameters.getPsDockerImage() != null) {
+ component.setArtifact(
+ getDockerArtifact(parameters.getPsDockerImage()));
+ }
+ addCommonEnvironments(component, taskType);
+ generateLaunchCommand(component);
+
+ return component;
+ }
+}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowWorkerComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowWorkerComponent.java
new file mode 100644
index 0000000..7496040
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowWorkerComponent.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import java.io.IOException;
+import java.util.Objects;
+import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
+import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
+import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
+
+/**
+ * Component implementation for TensorFlow's Worker process.
+ */
+public class TensorFlowWorkerComponent extends AbstractComponent {
+ public TensorFlowWorkerComponent(FileSystemOperations fsOperations,
+ RemoteDirectoryManager remoteDirectoryManager,
+ RunJobParameters parameters, TaskType taskType,
+ LaunchCommandFactory launchCommandFactory,
+ Configuration yarnConfig) {
+ super(fsOperations, remoteDirectoryManager, parameters, taskType,
+ yarnConfig, launchCommandFactory);
+ }
+
+ @Override
+ public Component createComponent() throws IOException {
+ Objects.requireNonNull(parameters.getWorkerResource(),
+ "Worker resource must not be null!");
+ if (parameters.getNumWorkers() < 1) {
+ throw new IllegalArgumentException(
+ "Number of workers should be at least 1!");
+ }
+
+ Component component = new Component();
+ component.setName(taskType.getComponentName());
+
+ if (taskType.equals(TaskType.PRIMARY_WORKER)) {
+ component.setNumberOfContainers(1L);
+ component.getConfiguration().setProperty(
+ CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true");
+ } else {
+ component.setNumberOfContainers(
+ (long) parameters.getNumWorkers() - 1);
+ }
+
+ if (parameters.getWorkerDockerImage() != null) {
+ component.setArtifact(
+ getDockerArtifact(parameters.getWorkerDockerImage()));
+ }
+
+ component.setResource(convertYarnResourceToServiceResource(
+ parameters.getWorkerResource()));
+ component.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
+
+ addCommonEnvironments(component, taskType);
+ generateLaunchCommand(component);
+
+ return component;
+ }
+}
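The container-count logic above splits numWorkers into one PRIMARY_WORKER, whose container state is reported as the service state, and numWorkers - 1 regular WORKERs. In sketch form:

    int numWorkers = 3;                    // from --num_workers
    long primaryWorkers = 1L;              // PRIMARY_WORKER component
    long workers = (long) numWorkers - 1;  // WORKER component
    assert primaryWorkers + workers == numWorkers;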
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/package-info.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/package-info.java
new file mode 100644
index 0000000..10978b7
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes to generate
+ * TensorFlow Native Service components.
+ */
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/package-info.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/package-info.java
new file mode 100644
index 0000000..0c51485
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes to generate
+ * TensorFlow-related Native Service runtime artifacts.
+ */
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ClassPathUtilities.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ClassPathUtilities.java
new file mode 100644
index 0000000..fc8f6ea
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ClassPathUtilities.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import java.io.File;
+import java.util.StringTokenizer;
+
+/**
+ * Utilities for classpath operations.
+ */
+public final class ClassPathUtilities {
+ private ClassPathUtilities() {
+ throw new UnsupportedOperationException("This class should not be " +
+ "instantiated!");
+ }
+
+ public static File findFileOnClassPath(final String fileName) {
+ final String classpath = System.getProperty("java.class.path");
+ final String pathSeparator = System.getProperty("path.separator");
+ final StringTokenizer tokenizer = new StringTokenizer(classpath,
+ pathSeparator);
+
+ while (tokenizer.hasMoreTokens()) {
+ final String pathElement = tokenizer.nextToken();
+ final File directoryOrJar = new File(pathElement);
+ final File absoluteDirectoryOrJar = directoryOrJar.getAbsoluteFile();
+ if (absoluteDirectoryOrJar.isFile()) {
+ final File target =
+ new File(absoluteDirectoryOrJar.getParent(), fileName);
+ if (target.exists()) {
+ return target;
+ }
+ } else {
+ final File target = new File(directoryOrJar, fileName);
+ if (target.exists()) {
+ return target;
+ }
+ }
+ }
+
+ return null;
+ }
+}
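findFileOnClassPath walks every classpath entry: for a jar entry it looks in the jar's parent directory, for a directory entry it looks inside the directory itself, and it returns null when nothing matches. A short usage sketch (the file name is illustrative):

    import java.io.File;

    public class FindOnClassPathExample {
      public static void main(String[] args) {
        // Looks next to each classpath jar, or inside each classpath dir.
        File found = ClassPathUtilities.findFileOnClassPath("core-site.xml");
        System.out.println(found != null ? found.getAbsolutePath() : "not found");
      }
    }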
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/DockerUtilities.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/DockerUtilities.java
new file mode 100644
index 0000000..78cee33
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/DockerUtilities.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import org.apache.hadoop.yarn.service.api.records.Artifact;
+
+/**
+ * Utilities for Docker-related operations.
+ */
+public final class DockerUtilities {
+ private DockerUtilities() {
+ throw new UnsupportedOperationException("This class should not be " +
+ "instantiated!");
+ }
+
+ public static Artifact getDockerArtifact(String dockerImageName) {
+ return new Artifact().type(Artifact.TypeEnum.DOCKER).id(dockerImageName);
+ }
+}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/EnvironmentUtilities.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/EnvironmentUtilities.java
new file mode 100644
index 0000000..f4ef7b4
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/EnvironmentUtilities.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
+
+/**
+ * Utilities for environment variable related operations
+ * for {@link Service} objects.
+ */
+public final class EnvironmentUtilities {
+ private EnvironmentUtilities() {
+ throw new UnsupportedOperationException("This class should not be " +
+ "instantiated!");
+ }
+
+ private static final Logger LOG =
+ LoggerFactory.getLogger(EnvironmentUtilities.class);
+
+ static final String ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME =
+ "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS";
+ private static final String MOUNTS_DELIM = ",";
+ private static final String ENV_SEPARATOR = "=";
+ private static final String ETC_PASSWD_MOUNT_STRING =
+ "/etc/passwd:/etc/passwd:ro";
+ private static final String KERBEROS_CONF_MOUNT_STRING =
+ "/etc/krb5.conf:/etc/krb5.conf:ro";
+ private static final String ENV_VAR_DELIM = ":";
+
+ /**
+ * Extracts value from a string representation of an environment variable.
+ * @param envVar The environment variable in 'key=value' format.
+ * @return The value of the environment variable
+ */
+ public static String getValueOfEnvironment(String envVar) {
+ if (envVar == null || !envVar.contains(ENV_SEPARATOR)) {
+ return "";
+ } else {
+ return envVar.substring(envVar.indexOf(ENV_SEPARATOR) + 1);
+ }
+ }
+
+ public static void handleServiceEnvs(Service service,
+ Configuration yarnConfig, List<String> envVars) {
+ if (envVars != null) {
+ for (String envVarPair : envVars) {
+ String key, value;
+ if (envVarPair.contains(ENV_SEPARATOR)) {
+ int idx = envVarPair.indexOf(ENV_SEPARATOR);
+ key = envVarPair.substring(0, idx);
+ value = envVarPair.substring(idx + 1);
+ } else {
+ LOG.warn("Found environment variable with unusual format: '{}'",
+ envVarPair);
+ // No "=" found so use the whole key
+ key = envVarPair;
+ value = "";
+ }
+ appendToEnv(service, key, value, ENV_VAR_DELIM);
+ }
+ }
+ appendOtherConfigs(service, yarnConfig);
+ }
+
+ /**
+ * Appends extra Docker mounts: /etc/passwd and, with Kerberos, /etc/krb5.conf.
+ * @param service the service whose environment is extended
+ * @param yarnConfig the YARN configuration to check for Kerberos
+ */
+ private static void appendOtherConfigs(Service service,
+ Configuration yarnConfig) {
+ appendToEnv(service, ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME,
+ ETC_PASSWD_MOUNT_STRING, MOUNTS_DELIM);
+
+ String authentication = yarnConfig.get(HADOOP_SECURITY_AUTHENTICATION);
+ if (authentication != null && authentication.equals("kerberos")) {
+ appendToEnv(service, ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME,
+ KERBEROS_CONF_MOUNT_STRING, MOUNTS_DELIM);
+ }
+ }
+
+ static void appendToEnv(Service service, String key, String value,
+ String delim) {
+ Map<String, String> env = service.getConfiguration().getEnv();
+ if (!env.containsKey(key)) {
+ env.put(key, value);
+ } else {
+ if (!value.isEmpty()) {
+ String existingValue = env.get(key);
+ if (!existingValue.endsWith(delim)) {
+ env.put(key, existingValue + delim + value);
+ } else {
+ env.put(key, existingValue + value);
+ }
+ }
+ }
+ }
+}
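appendToEnv above covers three cases: a missing key stores the value directly, an existing value gets the delimiter plus the new value, and an existing value that already ends with the delimiter gets the new value appended without another delimiter. The same behavior on a plain map, as a sketch:

    java.util.Map<String, String> env = new java.util.HashMap<>();
    // Case 1: key absent, value stored as-is.
    env.put("MOUNTS", "/etc/passwd:/etc/passwd:ro");
    // Case 2: key present without trailing delimiter, append delim + value.
    env.put("MOUNTS",
        env.get("MOUNTS") + "," + "/etc/krb5.conf:/etc/krb5.conf:ro");
    // Case 3: value already ends with the delimiter, append value directly.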
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/KerberosPrincipalFactory.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/KerberosPrincipalFactory.java
new file mode 100644
index 0000000..a37f37b
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/KerberosPrincipalFactory.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Simple factory that creates a {@link KerberosPrincipal}.
+ */
+public final class KerberosPrincipalFactory {
+ private KerberosPrincipalFactory() {
+ throw new UnsupportedOperationException("This class should not be " +
+ "instantiated!");
+ }
+
+ private static final Logger LOG =
+ LoggerFactory.getLogger(KerberosPrincipalFactory.class);
+
+ public static KerberosPrincipal create(FileSystemOperations fsOperations,
+ RemoteDirectoryManager remoteDirectoryManager,
+ RunJobParameters parameters) throws IOException {
+ Objects.requireNonNull(fsOperations,
+ "FileSystemOperations must not be null!");
+ Objects.requireNonNull(remoteDirectoryManager,
+ "RemoteDirectoryManager must not be null!");
+ Objects.requireNonNull(parameters, "Parameters must not be null!");
+
+ if (StringUtils.isNotBlank(parameters.getKeytab()) && StringUtils
+ .isNotBlank(parameters.getPrincipal())) {
+ String keytab = parameters.getKeytab();
+ String principal = parameters.getPrincipal();
+ if (parameters.isDistributeKeytab()) {
+ return handleDistributedKeytab(fsOperations, remoteDirectoryManager,
+ parameters, keytab, principal);
+ } else {
+ return handleNormalKeytab(keytab, principal);
+ }
+ }
+ LOG.debug("Principal and keytab was null or empty, " +
+ "returning null KerberosPrincipal!");
+ return null;
+ }
+
+ private static KerberosPrincipal handleDistributedKeytab(
+ FileSystemOperations fsOperations,
+ RemoteDirectoryManager remoteDirectoryManager,
+ RunJobParameters parameters, String keytab, String principal)
+ throws IOException {
+ Path stagingDir = remoteDirectoryManager
+ .getJobStagingArea(parameters.getName(), true);
+ Path remoteKeytabPath =
+ fsOperations.uploadToRemoteFile(stagingDir, keytab);
+ // Only the owner has read access
+ fsOperations.setPermission(remoteKeytabPath,
+ FsPermission.createImmutable((short)Integer.parseInt("400", 8)));
+ return new KerberosPrincipal()
+ .keytab(remoteKeytabPath.toString())
+ .principalName(principal);
+ }
+
+ private static KerberosPrincipal handleNormalKeytab(String keytab,
+ String principal) {
+ if (!keytab.startsWith("file")) {
+ keytab = "file://" + keytab;
+ }
+ return new KerberosPrincipal()
+ .keytab(keytab)
+ .principalName(principal);
+ }
+}
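The permission applied to the distributed keytab is octal 400, i.e. read access for the owner only. The string parse with radix 8 is equivalent to an octal literal, as this small check illustrates:

    short fromParse = (short) Integer.parseInt("400", 8);  // 256 decimal
    short fromLiteral = (short) 0400;                      // same value
    assert fromParse == fromLiteral;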
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/Localizer.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/Localizer.java
new file mode 100644
index 0000000..c86f1a2
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/Localizer.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.yarn.service.api.records.ConfigFile;
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.apache.hadoop.yarn.submarine.client.cli.param.Localization;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.List;
+
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs;
+import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.appendToEnv;
+
+/**
+ * This class holds all dependencies in order to localize dependencies
+ * for containers.
+ */
+public class Localizer {
+ private static final Logger LOG = LoggerFactory.getLogger(Localizer.class);
+
+ private final FileSystemOperations fsOperations;
+ private final RemoteDirectoryManager remoteDirectoryManager;
+ private final RunJobParameters parameters;
+
+ public Localizer(FileSystemOperations fsOperations,
+ RemoteDirectoryManager remoteDirectoryManager,
+ RunJobParameters parameters) {
+ this.fsOperations = fsOperations;
+ this.remoteDirectoryManager = remoteDirectoryManager;
+ this.parameters = parameters;
+ }
+
+ /**
+ * Localizes dependencies for all containers.
+ * If remoteUri is a local directory, we zip it and upload it to the
+ * HDFS staging directory.
+ * If remoteUri is a remote directory, we download it, zip it and
+ * upload it to HDFS.
+ * If localFilePath is ".", we use remoteUri's file/dir name.
+ */
+ public void handleLocalizations(Service service)
+ throws IOException {
+ // Handle localizations
+ Path stagingDir =
+ remoteDirectoryManager.getJobStagingArea(
+ parameters.getName(), true);
+ List<Localization> localizations = parameters.getLocalizations();
+ String remoteUri;
+ String containerLocalPath;
+
+ // Check to fail fast
+ for (Localization loc : localizations) {
+ remoteUri = loc.getRemoteUri();
+ Path resourceToLocalize = new Path(remoteUri);
+ // Check if remoteUri exists
+ if (remoteDirectoryManager.isRemote(remoteUri)) {
+ // check if exists
+ if (!remoteDirectoryManager.existsRemoteFile(resourceToLocalize)) {
+ throw new FileNotFoundException(
+ "File " + remoteUri + " doesn't exists.");
+ }
+ } else {
+ // Check if exists
+ File localFile = new File(remoteUri);
+ if (!localFile.exists()) {
+ throw new FileNotFoundException(
+ "File " + remoteUri + " doesn't exists.");
+ }
+ }
+ // check remote file size
+ fsOperations.validFileSize(remoteUri);
+ }
+ // Start download remote if needed and upload to HDFS
+ for (Localization loc : localizations) {
+ remoteUri = loc.getRemoteUri();
+ containerLocalPath = loc.getLocalPath();
+ String srcFileStr = remoteUri;
+ ConfigFile.TypeEnum destFileType = ConfigFile.TypeEnum.STATIC;
+ Path resourceToLocalize = new Path(remoteUri);
+ boolean needUploadToHDFS = true;
+
+
+ // Special handling of remoteUri directory
+ boolean needDeleteTempFile = false;
+ if (remoteDirectoryManager.isDir(remoteUri)) {
+ destFileType = ConfigFile.TypeEnum.ARCHIVE;
+ srcFileStr = fsOperations.downloadAndZip(
+ remoteUri, getLastNameFromPath(srcFileStr), true);
+ } else if (remoteDirectoryManager.isRemote(remoteUri)) {
+ if (!needHdfs(remoteUri)) {
+ // Non-HDFS remote URI pointing to a file: download it, no need to zip
+ srcFileStr = fsOperations.downloadAndZip(
+ remoteUri, getLastNameFromPath(srcFileStr), false);
+ needDeleteTempFile = true;
+ } else {
+ // HDFS file, no need to upload
+ needUploadToHDFS = false;
+ }
+ }
+
+ // Upload file to HDFS
+ if (needUploadToHDFS) {
+ resourceToLocalize =
+ fsOperations.uploadToRemoteFile(stagingDir, srcFileStr);
+ }
+ if (needDeleteTempFile) {
+ fsOperations.deleteFiles(srcFileStr);
+ }
+ // Remove .zip from zipped dir name
+ if (destFileType == ConfigFile.TypeEnum.ARCHIVE
+ && srcFileStr.endsWith(".zip")) {
+ // Delete local zip file
+ fsOperations.deleteFiles(srcFileStr);
+ int suffixIndex = srcFileStr.lastIndexOf('_');
+ srcFileStr = srcFileStr.substring(0, suffixIndex);
+ }
+ // If provided, use the name of the local path
+ if (!containerLocalPath.equals(".")
+ && !containerLocalPath.equals("./")) {
+ // Change the YARN localized file name to what will be used in container
+ srcFileStr = getLastNameFromPath(containerLocalPath);
+ }
+ String localizedName = getLastNameFromPath(srcFileStr);
+ LOG.info("The file/dir to be localized is {}",
+ resourceToLocalize.toString());
+ LOG.info("Its localized file name will be {}", localizedName);
+ service.getConfiguration().getFiles().add(new ConfigFile().srcFile(
+ resourceToLocalize.toUri().toString()).destFile(localizedName)
+ .type(destFileType));
+ // set mounts
+ // if mount path is absolute, just use it.
+ // if relative, no need to mount explicitly
+ if (containerLocalPath.startsWith("/")) {
+ String mountStr = getLastNameFromPath(srcFileStr) + ":"
+ + containerLocalPath + ":" + loc.getMountPermission();
+ LOG.info("Add bind-mount string {}", mountStr);
+ appendToEnv(service,
+ EnvironmentUtilities.ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME,
+ mountStr, ",");
+ }
+ }
+ }
+
+ private String getLastNameFromPath(String srcFileStr) {
+ return new Path(srcFileStr).getName();
+ }
+}
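A rough sketch of how the Localizer is expected to be driven; the three collaborators are assumed to be created by the job submitter and are not constructed here:

    // Hypothetical wiring; fsOperations, remoteDirectoryManager and
    // parameters are assumed to exist in the calling code.
    Localizer localizer =
        new Localizer(fsOperations, remoteDirectoryManager, parameters);
    Service service = new Service();
    localizer.handleLocalizations(service);
    // service.getConfiguration().getFiles() now contains one ConfigFile
    // per --localization argument.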
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/SubmarineResourceUtils.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/SubmarineResourceUtils.java
new file mode 100644
index 0000000..3d1a237
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/SubmarineResourceUtils.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import org.apache.hadoop.yarn.service.api.records.Resource;
+import org.apache.hadoop.yarn.service.api.records.ResourceInformation;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Resource utilities for Submarine.
+ */
+public final class SubmarineResourceUtils {
+ private SubmarineResourceUtils() {
+ throw new UnsupportedOperationException("This class should not be " +
+ "instantiated!");
+ }
+
+ public static Resource convertYarnResourceToServiceResource(
+ org.apache.hadoop.yarn.api.records.Resource yarnResource) {
+ Resource serviceResource = new Resource();
+ serviceResource.setCpus(yarnResource.getVirtualCores());
+ serviceResource.setMemory(String.valueOf(yarnResource.getMemorySize()));
+
+ Map<String, ResourceInformation> riMap = new HashMap<>();
+ for (org.apache.hadoop.yarn.api.records.ResourceInformation ri :
+ yarnResource.getAllResourcesListCopy()) {
+ ResourceInformation serviceRi = new ResourceInformation();
+ serviceRi.setValue(ri.getValue());
+ serviceRi.setUnit(ri.getUnits());
+ riMap.put(ri.getName(), serviceRi);
+ }
+ serviceResource.setResourceInformations(riMap);
+
+ return serviceResource;
+ }
+}
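A usage sketch of the conversion, with illustrative resource values:

    org.apache.hadoop.yarn.api.records.Resource yarnResource =
        org.apache.hadoop.yarn.api.records.Resource.newInstance(2048, 2);
    org.apache.hadoop.yarn.service.api.records.Resource serviceResource =
        SubmarineResourceUtils.convertYarnResourceToServiceResource(
            yarnResource);
    // serviceResource.getMemory() -> "2048", getCpus() -> 2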
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ZipUtilities.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ZipUtilities.java
new file mode 100644
index 0000000..c75f2d3
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ZipUtilities.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipOutputStream;
+
+/**
+ * Utilities for zipping directories and adding existing directories to zips.
+ */
+public final class ZipUtilities {
+ private ZipUtilities() {
+ throw new UnsupportedOperationException("This class should not be " +
+ "instantiated!");
+ }
+
+ private static final Logger LOG = LoggerFactory.getLogger(ZipUtilities.class);
+
+ @VisibleForTesting
+ public static String zipDir(String srcDir, String dstFile)
+ throws IOException {
+ File srcFile = new File(srcDir);
+ LOG.info("Compressing directory {}", srcDir);
+ // try-with-resources closes the streams even if zipping fails
+ try (FileOutputStream fos = new FileOutputStream(dstFile);
+ ZipOutputStream zos = new ZipOutputStream(fos)) {
+ addDirToZip(zos, srcFile, srcFile);
+ }
+ LOG.info("Compressed directory {} to file: {}", srcDir, dstFile);
+ return dstFile;
+ }
+
+ private static void addDirToZip(ZipOutputStream zos, File srcFile, File base)
+ throws IOException {
+ File[] files = srcFile.listFiles();
+ if (files == null) {
+ return;
+ }
+ for (File file : files) {
+ // if it's directory, add recursively
+ if (file.isDirectory()) {
+ addDirToZip(zos, file, base);
+ continue;
+ }
+ byte[] buffer = new byte[1024];
+ try (FileInputStream fis = new FileInputStream(file)) {
+ String name = base.toURI().relativize(file.toURI()).getPath();
+ LOG.info("Adding file {} to zip", name);
+ zos.putNextEntry(new ZipEntry(name));
+ int length;
+ while ((length = fis.read(buffer)) > 0) {
+ zos.write(buffer, 0, length);
+ }
+ zos.flush();
+ } finally {
+ zos.closeEntry();
+ }
+ }
+ }
+}
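Usage is a single static call; a sketch with illustrative paths:

    String zip = ZipUtilities.zipDir("/tmp/model-dir", "/tmp/model-dir.zip");
    // Entries are named relative to the source root, so
    // /tmp/model-dir/a/b.py is stored as "a/b.py" in the archive.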
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/package-info.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/package-info.java
new file mode 100644
index 0000000..2f60d90
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains utility classes.
+ */
+package org.apache.hadoop.yarn.submarine.utils;
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/FileUtilitiesForTests.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/FileUtilitiesForTests.java
new file mode 100644
index 0000000..a5161f5
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/FileUtilitiesForTests.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.io.FileUtils;
+import org.apache.hadoop.fs.Path;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * File utilities for tests.
+ * Provides methods that can create, delete files or directories
+ * in a temp directory, or any specified directory.
+ */
+public class FileUtilitiesForTests {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(FileUtilitiesForTests.class);
+ private String tempDir;
+ private List<File> cleanupFiles;
+
+ public void setup() {
+ cleanupFiles = Lists.newArrayList();
+ tempDir = System.getProperty("java.io.tmpdir");
+ }
+
+ public void teardown() throws IOException {
+ LOG.info("About to clean up files: " + cleanupFiles);
+ List<File> dirs = Lists.newArrayList();
+ for (File cleanupFile : cleanupFiles) {
+ if (cleanupFile.isDirectory()) {
+ dirs.add(cleanupFile);
+ } else {
+ deleteFile(cleanupFile);
+ }
+ }
+
+ for (File dir : dirs) {
+ deleteFile(dir);
+ }
+ }
+
+ public File createFileInTempDir(String filename) throws IOException {
+ File file = new File(tempDir, new Path(filename).getName());
+ createFile(file);
+ return file;
+ }
+
+ public File createDirInTempDir(String dirName) {
+ File file = new File(tempDir, new Path(dirName).getName());
+ createDirectory(file);
+ return file;
+ }
+
+ public File createFileInDir(Path dir, String filename) throws IOException {
+ File dirTmp = new File(dir.toUri().getPath());
+ if (!dirTmp.exists()) {
+ createDirectory(dirTmp);
+ }
+ File file =
+ new File(dir.toUri().getPath() + "/" + new Path(filename).getName());
+ createFile(file);
+ return file;
+ }
+
+ public File createFileInDir(File dir, String filename) throws IOException {
+ if (!dir.exists()) {
+ createDirectory(dir);
+ }
+ File file = new File(dir, filename);
+ createFile(file);
+ return file;
+ }
+
+ public File createDirectory(Path parent, String dirname) {
+ File dir =
+ new File(parent.toUri().getPath() + "/" + new Path(dirname).getName());
+ createDirectory(dir);
+ return dir;
+ }
+
+ public File createDirectory(File parent, String dirname) {
+ File dir =
+ new File(parent.getPath() + "/" + new Path(dirname).getName());
+ createDirectory(dir);
+ return dir;
+ }
+
+ private void createDirectory(File dir) {
+ boolean result = dir.mkdir();
+ assertTrue("Failed to create directory " + dir.getAbsolutePath(), result);
+ assertTrue("Directory does not exist: " + dir.getAbsolutePath(),
+ dir.exists());
+ this.cleanupFiles.add(dir);
+ }
+
+ private void createFile(File file) throws IOException {
+ boolean result = file.createNewFile();
+ assertTrue("Failed to create file " + file.getAbsolutePath(), result);
+ assertTrue("File does not exist: " + file.getAbsolutePath(), file.exists());
+ this.cleanupFiles.add(file);
+ }
+
+ private static void deleteFile(File file) throws IOException {
+ if (file.isDirectory()) {
+ LOG.info("Removing directory: " + file.getAbsolutePath());
+ FileUtils.deleteDirectory(file);
+ }
+
+ if (file.exists()) {
+ LOG.info("Removing file: " + file.getAbsolutePath());
+ boolean result = file.delete();
+ assertTrue("Deletion of file " + file.getAbsolutePath()
+ + " was not successful!", result);
+ }
+ }
+
+ public File getTempFileWithName(String filename) {
+ return new File(tempDir + "/" + new Path(filename).getName());
+ }
+
+ public static File getFilename(Path parent, String filename) {
+ return new File(
+ parent.toUri().getPath() + "/" + new Path(filename).getName());
+ }
+}
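A typical test lifecycle with this helper, sketched as JUnit 4 code (the surrounding test class, its org.junit imports and the test body are illustrative):

    private final FileUtilitiesForTests fileUtils = new FileUtilitiesForTests();

    @Before
    public void setUp() {
      fileUtils.setup();
    }

    @After
    public void tearDown() throws IOException {
      fileUtils.teardown();  // deletes every file/dir created via the helper
    }

    @Test
    public void testUsesTempFile() throws IOException {
      File script = fileUtils.createFileInTempDir("run.py");
      // ... exercise code that consumes the file ...
    }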
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/ParamBuilderForTest.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/ParamBuilderForTest.java
new file mode 100644
index 0000000..8a9b7e0
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/ParamBuilderForTest.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.yarnservice;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+
+class ParamBuilderForTest {
+ private final List<String> params = Lists.newArrayList();
+
+ static ParamBuilderForTest create() {
+ return new ParamBuilderForTest();
+ }
+
+ ParamBuilderForTest withJobName(String name) {
+ params.add("--name");
+ params.add(name);
+ return this;
+ }
+
+ ParamBuilderForTest withDockerImage(String dockerImage) {
+ params.add("--docker_image");
+ params.add(dockerImage);
+ return this;
+ }
+
+ ParamBuilderForTest withInputPath(String inputPath) {
+ params.add("--input_path");
+ params.add(inputPath);
+ return this;
+ }
+
+ ParamBuilderForTest withCheckpointPath(String checkpointPath) {
+ params.add("--checkpoint_path");
+ params.add(checkpointPath);
+ return this;
+ }
+
+ ParamBuilderForTest withNumberOfWorkers(int numWorkers) {
+ params.add("--num_workers");
+ params.add(String.valueOf(numWorkers));
+ return this;
+ }
+
+ ParamBuilderForTest withNumberOfPs(int numPs) {
+ params.add("--num_ps");
+ params.add(String.valueOf(numPs));
+ return this;
+ }
+
+ ParamBuilderForTest withWorkerLaunchCommand(String launchCommand) {
+ params.add("--worker_launch_cmd");
+ params.add(launchCommand);
+ return this;
+ }
+
+ ParamBuilderForTest withPsLaunchCommand(String launchCommand) {
+ params.add("--ps_launch_cmd");
+ params.add(launchCommand);
+ return this;
+ }
+
+ ParamBuilderForTest withWorkerResources(String workerResources) {
+ params.add("--worker_resources");
+ params.add(workerResources);
+ return this;
+ }
+
+ ParamBuilderForTest withPsResources(String psResources) {
+ params.add("--ps_resources");
+ params.add(psResources);
+ return this;
+ }
+
+ ParamBuilderForTest withWorkerDockerImage(String dockerImage) {
+ params.add("--worker_docker_image");
+ params.add(dockerImage);
+ return this;
+ }
+
+ ParamBuilderForTest withPsDockerImage(String dockerImage) {
+ params.add("--ps_docker_image");
+ params.add(dockerImage);
+ return this;
+ }
+
+ ParamBuilderForTest withVerbose() {
+ params.add("--verbose");
+ return this;
+ }
+
+ ParamBuilderForTest withTensorboard() {
+ params.add("--tensorboard");
+ return this;
+ }
+
+ ParamBuilderForTest withTensorboardResources(String resources) {
+ params.add("--tensorboard_resources");
+ params.add(resources);
+ return this;
+ }
+
+ ParamBuilderForTest withTensorboardDockerImage(String dockerImage) {
+ params.add("--tensorboard_docker_image");
+ params.add(dockerImage);
+ return this;
+ }
+
+ ParamBuilderForTest withQuickLink(String quickLink) {
+ params.add("--quicklink");
+ params.add(quickLink);
+ return this;
+ }
+
+ ParamBuilderForTest withLocalization(String remoteUrl, String localUrl) {
+ params.add("--localization");
+ params.add(remoteUrl + ":" + localUrl);
+ return this;
+ }
+
+ String[] build() {
+ return params.toArray(new String[0]);
+ }
+}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCli.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCli.java
index ee6b5c1..2a568cb 100644
--- a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCli.java
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCli.java
@@ -20,26 +20,23 @@ package org.apache.hadoop.yarn.submarine.client.cli.yarnservice;
import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.fs.FileUtil;
-import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.api.records.Resource;
-import org.apache.hadoop.yarn.client.api.AppAdminClient;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.service.api.records.ConfigFile;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
-import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
-import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.common.StorageKeyConstants;
import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceJobSubmitter;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent;
+import org.apache.hadoop.yarn.submarine.utils.ZipUtilities;
import org.apache.hadoop.yarn.util.resource.Resources;
-import org.junit.Assert;
+import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -48,29 +45,41 @@ import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
-import java.util.List;
import java.util.Map;
-import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.reset;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_CHECKPOINT_PATH;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_DOCKER_IMAGE;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_INPUT_PATH;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_JOB_NAME;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_PS_DOCKER_IMAGE;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_PS_LAUNCH_CMD;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_PS_RESOURCES;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_TENSORBOARD_DOCKER_IMAGE;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_TENSORBOARD_RESOURCES;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_WORKER_DOCKER_IMAGE;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_WORKER_LAUNCH_CMD;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_WORKER_RESOURCES;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+/**
+ * Class to test YarnService with the Run job CLI action.
+ */
public class TestYarnServiceRunJobCli {
+ private TestYarnServiceRunJobCliCommons testCommons =
+ new TestYarnServiceRunJobCliCommons();
+
@Before
public void before() throws IOException, YarnException {
- SubmarineLogs.verboseOff();
- AppAdminClient serviceClient = mock(AppAdminClient.class);
- when(serviceClient.actionLaunch(any(String.class), any(String.class),
- any(Long.class), any(String.class))).thenReturn(EXIT_SUCCESS);
- when(serviceClient.getStatusString(any(String.class))).thenReturn(
- "{\"id\": \"application_1234_1\"}");
- YarnServiceUtils.setStubServiceClient(serviceClient);
+ testCommons.setup();
+ }
+
+ @After
+ public void cleanup() throws IOException {
+ testCommons.teardown();
}
@Test
@@ -81,53 +90,50 @@ public class TestYarnServiceRunJobCli {
runJobCli.printUsages();
}
- private Service getServiceSpecFromJobSubmitter(JobSubmitter jobSubmitter) {
- return ((YarnServiceJobSubmitter) jobSubmitter).getServiceSpec();
+ private ServiceWrapper getServiceWrapperFromJobSubmitter(
+ JobSubmitter jobSubmitter) {
+ return ((YarnServiceJobSubmitter) jobSubmitter).getServiceWrapper();
}
- private void commonVerifyDistributedTrainingSpec(Service serviceSpec)
- throws Exception {
- Assert.assertTrue(
- serviceSpec.getComponent(TaskType.WORKER.getComponentName()) != null);
- Assert.assertTrue(
- serviceSpec.getComponent(TaskType.PRIMARY_WORKER.getComponentName())
- != null);
- Assert.assertTrue(
- serviceSpec.getComponent(TaskType.PS.getComponentName()) != null);
+ private void commonVerifyDistributedTrainingSpec(Service serviceSpec) {
+ assertNotNull(serviceSpec.getComponent(TaskType.WORKER.getComponentName()));
+ assertNotNull(
+ serviceSpec.getComponent(TaskType.PRIMARY_WORKER.getComponentName()));
+ assertNotNull(serviceSpec.getComponent(TaskType.PS.getComponentName()));
Component primaryWorkerComp = serviceSpec.getComponent(
TaskType.PRIMARY_WORKER.getComponentName());
- Assert.assertEquals(2048, primaryWorkerComp.getResource().calcMemoryMB());
- Assert.assertEquals(2,
+ assertEquals(2048, primaryWorkerComp.getResource().calcMemoryMB());
+ assertEquals(2,
primaryWorkerComp.getResource().getCpus().intValue());
Component workerComp = serviceSpec.getComponent(
TaskType.WORKER.getComponentName());
- Assert.assertEquals(2048, workerComp.getResource().calcMemoryMB());
- Assert.assertEquals(2, workerComp.getResource().getCpus().intValue());
+ assertEquals(2048, workerComp.getResource().calcMemoryMB());
+ assertEquals(2, workerComp.getResource().getCpus().intValue());
Component psComp = serviceSpec.getComponent(TaskType.PS.getComponentName());
- Assert.assertEquals(4096, psComp.getResource().calcMemoryMB());
- Assert.assertEquals(4, psComp.getResource().getCpus().intValue());
+ assertEquals(4096, psComp.getResource().calcMemoryMB());
+ assertEquals(4, psComp.getResource().getCpus().intValue());
- Assert.assertEquals("worker.image", workerComp.getArtifact().getId());
- Assert.assertEquals("ps.image", psComp.getArtifact().getId());
+ assertEquals(DEFAULT_WORKER_DOCKER_IMAGE, workerComp.getArtifact().getId());
+ assertEquals(DEFAULT_PS_DOCKER_IMAGE, psComp.getArtifact().getId());
- Assert.assertTrue(SubmarineLogs.isVerbose());
+ assertTrue(SubmarineLogs.isVerbose());
}
private void verifyQuicklink(Service serviceSpec,
Map<String, String> expectedQuicklinks) {
Map<String, String> actualQuicklinks = serviceSpec.getQuicklinks();
if (actualQuicklinks == null || actualQuicklinks.isEmpty()) {
- Assert.assertTrue(
+ assertTrue(
expectedQuicklinks == null || expectedQuicklinks.isEmpty());
return;
}
- Assert.assertEquals(expectedQuicklinks.size(), actualQuicklinks.size());
+ assertEquals(expectedQuicklinks.size(), actualQuicklinks.size());
for (Map.Entry<String, String> expectedEntry : expectedQuicklinks
.entrySet()) {
- Assert.assertTrue(actualQuicklinks.containsKey(expectedEntry.getKey()));
+ assertTrue(actualQuicklinks.containsKey(expectedEntry.getKey()));
       // $USER could differ between environments, so replace $USER with
       // "user"
@@ -137,7 +143,7 @@ public class TestYarnServiceRunJobCli {
String userName = System.getProperty("user.name");
actualValue = actualValue.replaceAll(userName, "username");
- Assert.assertEquals(expectedValue, actualValue);
+ assertEquals(expectedValue, actualValue);
}
}
@@ -146,19 +152,27 @@ public class TestYarnServiceRunJobCli {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose"});
- Service serviceSpec = getServiceSpecFromJobSubmitter(
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
+ .withNumberOfWorkers(3)
+ .withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE)
+ .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD)
+ .withWorkerResources(DEFAULT_WORKER_RESOURCES)
+ .withNumberOfPs(2)
+ .withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE)
+ .withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD)
+ .withPsResources(DEFAULT_PS_RESOURCES)
+ .withVerbose()
+ .build();
+ runJobCli.run(params);
+ Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
runJobCli.getJobSubmitter());
- Assert.assertEquals(3, serviceSpec.getComponents().size());
+ assertEquals(3, serviceSpec.getComponents().size());
commonVerifyDistributedTrainingSpec(serviceSpec);
@@ -171,28 +185,37 @@ public class TestYarnServiceRunJobCli {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--tensorboard", "--ps_launch_cmd", "python run-ps.py",
- "--verbose"});
- Service serviceSpec = getServiceSpecFromJobSubmitter(
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
+ .withNumberOfWorkers(3)
+ .withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE)
+ .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD)
+ .withWorkerResources(DEFAULT_WORKER_RESOURCES)
+ .withNumberOfPs(2)
+ .withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE)
+ .withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD)
+ .withPsResources(DEFAULT_PS_RESOURCES)
+ .withVerbose()
+ .withTensorboard()
+ .build();
+ runJobCli.run(params);
+ ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter(
runJobCli.getJobSubmitter());
- Assert.assertEquals(4, serviceSpec.getComponents().size());
+ Service serviceSpec = serviceWrapper.getService();
+ assertEquals(4, serviceSpec.getComponents().size());
commonVerifyDistributedTrainingSpec(serviceSpec);
- verifyTensorboardComponent(runJobCli, serviceSpec,
+ verifyTensorboardComponent(runJobCli, serviceWrapper,
Resources.createResource(4096, 1));
verifyQuicklink(serviceSpec, ImmutableMap
- .of(YarnServiceJobSubmitter.TENSORBOARD_QUICKLINK_LABEL,
+ .of(TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL,
"http://tensorboard-0.my-job.username.null:6006"));
}
@@ -201,17 +224,23 @@ public class TestYarnServiceRunJobCli {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
- "--worker_resources", "memory=2G,vcores=2", "--verbose"});
-
- Service serviceSpec = getServiceSpecFromJobSubmitter(
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
+ .withNumberOfWorkers(1)
+ .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD)
+ .withWorkerResources(DEFAULT_TENSORBOARD_RESOURCES)
+ .withVerbose()
+ .build();
+ runJobCli.run(params);
+
+ Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
runJobCli.getJobSubmitter());
- Assert.assertEquals(1, serviceSpec.getComponents().size());
+ assertEquals(1, serviceSpec.getComponents().size());
commonTestSingleNodeTraining(serviceSpec);
}
@@ -221,41 +250,53 @@ public class TestYarnServiceRunJobCli {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "0", "--tensorboard", "--verbose"});
-
- Service serviceSpec = getServiceSpecFromJobSubmitter(
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
+ .withNumberOfWorkers(0)
+ .withTensorboard()
+ .withVerbose()
+ .build();
+ runJobCli.run(params);
+
+ ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter(
runJobCli.getJobSubmitter());
- Assert.assertEquals(1, serviceSpec.getComponents().size());
+ assertEquals(1, serviceWrapper.getService().getComponents().size());
- verifyTensorboardComponent(runJobCli, serviceSpec,
+ verifyTensorboardComponent(runJobCli, serviceWrapper,
Resources.createResource(4096, 1));
}
@Test
- public void testTensorboardOnlyServiceWithCustomizedDockerImageAndResourceCkptPath()
+ public void testTensorboardOnlyServiceWithCustomDockerImageAndCheckpointPath()
throws Exception {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "0", "--tensorboard", "--verbose",
- "--tensorboard_resources", "memory=2G,vcores=2",
- "--tensorboard_docker_image", "tb_docker_image:001"});
-
- Service serviceSpec = getServiceSpecFromJobSubmitter(
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
+ .withNumberOfWorkers(0)
+ .withTensorboard()
+ .withTensorboardResources(DEFAULT_TENSORBOARD_RESOURCES)
+ .withTensorboardDockerImage(DEFAULT_TENSORBOARD_DOCKER_IMAGE)
+ .withVerbose()
+ .build();
+ runJobCli.run(params);
+
+ ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter(
runJobCli.getJobSubmitter());
- Assert.assertEquals(1, serviceSpec.getComponents().size());
+ assertEquals(1, serviceWrapper.getService().getComponents().size());
- verifyTensorboardComponent(runJobCli, serviceSpec,
+ verifyTensorboardComponent(runJobCli, serviceWrapper,
Resources.createResource(2048, 2));
}
@@ -265,94 +306,92 @@ public class TestYarnServiceRunJobCli {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--num_workers", "0", "--tensorboard", "--verbose",
- "--tensorboard_resources", "memory=2G,vcores=2",
- "--tensorboard_docker_image", "tb_docker_image:001"});
-
- Service serviceSpec = getServiceSpecFromJobSubmitter(
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withNumberOfWorkers(0)
+ .withTensorboard()
+ .withTensorboardResources(DEFAULT_TENSORBOARD_RESOURCES)
+ .withTensorboardDockerImage(DEFAULT_TENSORBOARD_DOCKER_IMAGE)
+ .withVerbose()
+ .build();
+ runJobCli.run(params);
+
+ ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter(
runJobCli.getJobSubmitter());
- Assert.assertEquals(1, serviceSpec.getComponents().size());
+ assertEquals(1, serviceWrapper.getService().getComponents().size());
- verifyTensorboardComponent(runJobCli, serviceSpec,
+ verifyTensorboardComponent(runJobCli, serviceWrapper,
Resources.createResource(2048, 2));
- verifyQuicklink(serviceSpec, ImmutableMap
- .of(YarnServiceJobSubmitter.TENSORBOARD_QUICKLINK_LABEL,
+ verifyQuicklink(serviceWrapper.getService(), ImmutableMap
+ .of(TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL,
"http://tensorboard-0.my-job.username.null:6006"));
}
- private void commonTestSingleNodeTraining(Service serviceSpec)
- throws Exception {
- Assert.assertTrue(
- serviceSpec.getComponent(TaskType.PRIMARY_WORKER.getComponentName())
- != null);
+ private void commonTestSingleNodeTraining(Service serviceSpec) {
+ assertNotNull(
+ serviceSpec.getComponent(TaskType.PRIMARY_WORKER.getComponentName()));
Component primaryWorkerComp = serviceSpec.getComponent(
TaskType.PRIMARY_WORKER.getComponentName());
- Assert.assertEquals(2048, primaryWorkerComp.getResource().calcMemoryMB());
- Assert.assertEquals(2,
+ assertEquals(2048, primaryWorkerComp.getResource().calcMemoryMB());
+ assertEquals(2,
primaryWorkerComp.getResource().getCpus().intValue());
- Assert.assertTrue(SubmarineLogs.isVerbose());
+ assertTrue(SubmarineLogs.isVerbose());
}
private void verifyTensorboardComponent(RunJobCli runJobCli,
- Service serviceSpec, Resource resource) throws Exception {
- Assert.assertTrue(
- serviceSpec.getComponent(TaskType.TENSORBOARD.getComponentName())
- != null);
+ ServiceWrapper serviceWrapper, Resource resource) throws Exception {
+ Service serviceSpec = serviceWrapper.getService();
+ assertNotNull(
+ serviceSpec.getComponent(TaskType.TENSORBOARD.getComponentName()));
Component tensorboardComp = serviceSpec.getComponent(
TaskType.TENSORBOARD.getComponentName());
- Assert.assertEquals(1, tensorboardComp.getNumberOfContainers().intValue());
- Assert.assertEquals(resource.getMemorySize(),
+ assertEquals(1, tensorboardComp.getNumberOfContainers().intValue());
+ assertEquals(resource.getMemorySize(),
tensorboardComp.getResource().calcMemoryMB());
- Assert.assertEquals(resource.getVirtualCores(),
+ assertEquals(resource.getVirtualCores(),
tensorboardComp.getResource().getCpus().intValue());
- Assert.assertEquals("./run-TENSORBOARD.sh",
+ assertEquals("./run-TENSORBOARD.sh",
tensorboardComp.getLaunchCommand());
// Check docker image
if (runJobCli.getRunJobParameters().getTensorboardDockerImage() != null) {
- Assert.assertEquals(
+ assertEquals(
runJobCli.getRunJobParameters().getTensorboardDockerImage(),
tensorboardComp.getArtifact().getId());
} else {
- Assert.assertNull(tensorboardComp.getArtifact());
+ assertNull(tensorboardComp.getArtifact());
}
- YarnServiceJobSubmitter yarnServiceJobSubmitter =
- (YarnServiceJobSubmitter) runJobCli.getJobSubmitter();
-
String expectedLaunchScript =
"#!/bin/bash\n" + "echo \"CLASSPATH:$CLASSPATH\"\n"
+ "echo \"HADOOP_CONF_DIR:$HADOOP_CONF_DIR\"\n"
- + "echo \"HADOOP_TOKEN_FILE_LOCATION:$HADOOP_TOKEN_FILE_LOCATION\"\n"
+ + "echo \"HADOOP_TOKEN_FILE_LOCATION:" +
+ "$HADOOP_TOKEN_FILE_LOCATION\"\n"
+ "echo \"JAVA_HOME:$JAVA_HOME\"\n"
+ "echo \"LD_LIBRARY_PATH:$LD_LIBRARY_PATH\"\n"
+ "echo \"HADOOP_HDFS_HOME:$HADOOP_HDFS_HOME\"\n"
+ "export LC_ALL=C && tensorboard --logdir=" + runJobCli
.getRunJobParameters().getCheckpointPath() + "\n";
- verifyLaunchScriptForComponet(yarnServiceJobSubmitter, serviceSpec,
+ verifyLaunchScriptForComponent(serviceWrapper,
TaskType.TENSORBOARD, expectedLaunchScript);
}
- private void verifyLaunchScriptForComponet(
- YarnServiceJobSubmitter yarnServiceJobSubmitter, Service serviceSpec,
+ private void verifyLaunchScriptForComponent(ServiceWrapper serviceWrapper,
TaskType taskType, String expectedLaunchScriptContent) throws Exception {
- Map<String, String> componentToLocalLaunchScriptMap =
- yarnServiceJobSubmitter.getComponentToLocalLaunchScriptPath();
- String path = componentToLocalLaunchScriptMap.get(
- taskType.getComponentName());
+ String path = serviceWrapper
+ .getLocalLaunchCommandPathForComponent(taskType.getComponentName());
byte[] encoded = Files.readAllBytes(Paths.get(path));
String scriptContent = new String(encoded, Charset.defaultCharset());
- Assert.assertEquals(expectedLaunchScriptContent, scriptContent);
+ assertEquals(expectedLaunchScriptContent, scriptContent);
}
@Test
@@ -361,21 +400,28 @@ public class TestYarnServiceRunJobCli {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
- "--worker_resources", "memory=2G,vcores=2", "--tensorboard",
- "--verbose"});
- Service serviceSpec = getServiceSpecFromJobSubmitter(
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
+ .withNumberOfWorkers(1)
+ .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD)
+ .withWorkerResources(DEFAULT_TENSORBOARD_RESOURCES)
+ .withTensorboard()
+ .withVerbose()
+ .build();
+ runJobCli.run(params);
+ ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter(
runJobCli.getJobSubmitter());
+ Service serviceSpec = serviceWrapper.getService();
- Assert.assertEquals(2, serviceSpec.getComponents().size());
+ assertEquals(2, serviceSpec.getComponents().size());
commonTestSingleNodeTraining(serviceSpec);
- verifyTensorboardComponent(runJobCli, serviceSpec,
+ verifyTensorboardComponent(runJobCli, serviceWrapper,
Resources.createResource(4096, 1));
}
@@ -385,20 +431,27 @@ public class TestYarnServiceRunJobCli {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--num_workers", "1",
- "--worker_launch_cmd", "python run-job.py", "--worker_resources",
- "memory=2G,vcores=2", "--tensorboard", "--verbose"});
- Service serviceSpec = getServiceSpecFromJobSubmitter(
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withNumberOfWorkers(1)
+ .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD)
+ .withWorkerResources(DEFAULT_TENSORBOARD_RESOURCES)
+ .withTensorboard()
+ .withVerbose()
+ .build();
+ runJobCli.run(params);
+ ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter(
runJobCli.getJobSubmitter());
+ Service serviceSpec = serviceWrapper.getService();
- Assert.assertEquals(2, serviceSpec.getComponents().size());
+ assertEquals(2, serviceSpec.getComponents().size());
commonTestSingleNodeTraining(serviceSpec);
- verifyTensorboardComponent(runJobCli, serviceSpec,
+ verifyTensorboardComponent(runJobCli, serviceWrapper,
Resources.createResource(4096, 1));
}
@@ -407,20 +460,26 @@ public class TestYarnServiceRunJobCli {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
- "--worker_resources", "memory=2G,vcores=2", "--tensorboard", "true",
- "--verbose"});
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
+ .withNumberOfWorkers(1)
+ .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD)
+ .withWorkerResources(DEFAULT_TENSORBOARD_RESOURCES)
+ .withTensorboard()
+ .withVerbose()
+ .build();
+ runJobCli.run(params);
SubmarineStorage storage =
mockClientContext.getRuntimeFactory().getSubmarineStorage();
- Map<String, String> jobInfo = storage.getJobInfoByName("my-job");
- Assert.assertTrue(jobInfo.size() > 0);
- Assert.assertEquals(jobInfo.get(StorageKeyConstants.INPUT_PATH),
- "s3://input");
+ Map<String, String> jobInfo = storage.getJobInfoByName(DEFAULT_JOB_NAME);
+ assertTrue(jobInfo.size() > 0);
+ assertEquals(jobInfo.get(StorageKeyConstants.INPUT_PATH),
+ DEFAULT_INPUT_PATH);
}
@Test
@@ -428,21 +487,29 @@ public class TestYarnServiceRunJobCli {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose", "--quicklink",
- "AAA=http://master-0:8321", "--quicklink",
- "BBB=http://worker-0:1234"});
- Service serviceSpec = getServiceSpecFromJobSubmitter(
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
+ .withNumberOfWorkers(3)
+ .withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE)
+ .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD)
+ .withWorkerResources(DEFAULT_WORKER_RESOURCES)
+ .withNumberOfPs(2)
+ .withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE)
+ .withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD)
+ .withPsResources(DEFAULT_PS_RESOURCES)
+ .withQuickLink("AAA=http://master-0:8321")
+ .withQuickLink("BBB=http://worker-0:1234")
+ .withVerbose()
+ .build();
+ runJobCli.run(params);
+ Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
runJobCli.getJobSubmitter());
- Assert.assertEquals(3, serviceSpec.getComponents().size());
+ assertEquals(3, serviceSpec.getComponents().size());
commonVerifyDistributedTrainingSpec(serviceSpec);
@@ -456,765 +523,74 @@ public class TestYarnServiceRunJobCli {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose", "--quicklink",
- "AAA=http://master-0:8321", "--quicklink",
- "BBB=http://worker-0:1234", "--tensorboard"});
- Service serviceSpec = getServiceSpecFromJobSubmitter(
+ assertFalse(SubmarineLogs.isVerbose());
+
+ String[] params = ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
+ .withNumberOfWorkers(3)
+ .withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE)
+ .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD)
+ .withWorkerResources(DEFAULT_WORKER_RESOURCES)
+ .withNumberOfPs(2)
+ .withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE)
+ .withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD)
+ .withPsResources(DEFAULT_PS_RESOURCES)
+ .withQuickLink("AAA=http://master-0:8321")
+ .withQuickLink("BBB=http://worker-0:1234")
+ .withTensorboard()
+ .withVerbose()
+ .build();
+
+ runJobCli.run(params);
+ Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
runJobCli.getJobSubmitter());
- Assert.assertEquals(4, serviceSpec.getComponents().size());
+ assertEquals(4, serviceSpec.getComponents().size());
commonVerifyDistributedTrainingSpec(serviceSpec);
verifyQuicklink(serviceSpec, ImmutableMap
.of("AAA", "http://master-0.my-job.username.null:8321", "BBB",
"http://worker-0.my-job.username.null:1234",
- YarnServiceJobSubmitter.TENSORBOARD_QUICKLINK_LABEL,
+ TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL,
"http://tensorboard-0.my-job.username.null:6006"));
}
/**
- * Basic test.
- * In one hand, create local temp file/dir for hdfs URI in
- * local staging dir.
- * In the other hand, use MockRemoteDirectoryManager mock
- * implementation when check FileStatus or exists of HDFS file/dir
- * --localization hdfs:///user/yarn/script1.py:.
- * --localization /temp/script2.py:./
- * --localization /temp/script2.py:/opt/script.py
- */
- @Test
- public void testRunJobWithBasicLocalization() throws Exception {
- String remoteUrl = "hdfs:///user/yarn/script1.py";
- String containerLocal1 = ".";
- String localUrl = "/temp/script2.py";
- String containerLocal2 = "./";
- String containerLocal3 = "/opt/script.py";
- String fakeLocalDir = System.getProperty("java.io.tmpdir");
- // create local file, we need to put it under local temp dir
- File localFile1 = new File(fakeLocalDir,
- new Path(localUrl).getName());
- localFile1.createNewFile();
-
-
- MockClientContext mockClientContext =
- YarnServiceCliTestUtils.getMockClientContext();
- RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- RemoteDirectoryManager spyRdm =
- spy(mockClientContext.getRemoteDirectoryManager());
- mockClientContext.setRemoteDirectoryMgr(spyRdm);
-
- // create remote file in local staging dir to simulate HDFS
- Path stagingDir = mockClientContext.getRemoteDirectoryManager()
- .getJobStagingArea("my-job", true);
- File remoteFile1 = new File(stagingDir.toUri().getPath()
- + "/" + new Path(remoteUrl).getName());
- remoteFile1.createNewFile();
-
- Assert.assertTrue(localFile1.exists());
- Assert.assertTrue(remoteFile1.exists());
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose",
- "--localization",
- remoteUrl + ":" + containerLocal1,
- "--localization",
- localFile1.getAbsolutePath() + ":" + containerLocal2,
- "--localization",
- localFile1.getAbsolutePath() + ":" + containerLocal3});
- Service serviceSpec = getServiceSpecFromJobSubmitter(
- runJobCli.getJobSubmitter());
- Assert.assertEquals(3, serviceSpec.getComponents().size());
-
- // No remote dir and hdfs file exists. Ensure download 0 times
- verify(spyRdm, times(0)).copyRemoteToLocal(
- anyString(), anyString());
- // Ensure local original files are not deleted
- Assert.assertTrue(localFile1.exists());
-
- List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
- Assert.assertEquals(3, files.size());
- ConfigFile file = files.get(0);
- Assert.assertEquals(ConfigFile.TypeEnum.STATIC, file.getType());
- String expectedSrcLocalization = remoteUrl;
- Assert.assertEquals(expectedSrcLocalization,
- file.getSrcFile());
- String expectedDstFileName = new Path(remoteUrl).getName();
- Assert.assertEquals(expectedDstFileName, file.getDestFile());
-
- file = files.get(1);
- Assert.assertEquals(ConfigFile.TypeEnum.STATIC, file.getType());
- expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(localUrl).getName();
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
- expectedDstFileName = new Path(localUrl).getName();
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
-
- file = files.get(2);
- Assert.assertEquals(ConfigFile.TypeEnum.STATIC, file.getType());
- expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(localUrl).getName();
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
- expectedDstFileName = new Path(localUrl).getName();
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
-
- // Ensure env value is correct
- String env = serviceSpec.getConfiguration().getEnv()
- .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
- String expectedMounts = new Path(containerLocal3).getName()
- + ":" + containerLocal3 + ":rw";
- Assert.assertTrue(env.contains(expectedMounts));
-
- remoteFile1.delete();
- localFile1.delete();
- }
-
- /**
- * Non HDFS remote URI test.
- * --localization https://a/b/1.patch:.
- * --localization s3a://a/dir:/opt/mys3dir
- */
- @Test
- public void testRunJobWithNonHDFSRemoteLocalization() throws Exception {
- String remoteUri1 = "https://a/b/1.patch";
- String containerLocal1 = ".";
- String remoteUri2 = "s3a://a/s3dir";
- String containerLocal2 = "/opt/mys3dir";
-
- MockClientContext mockClientContext =
- YarnServiceCliTestUtils.getMockClientContext();
- RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- RemoteDirectoryManager spyRdm =
- spy(mockClientContext.getRemoteDirectoryManager());
- mockClientContext.setRemoteDirectoryMgr(spyRdm);
-
- // create remote file in local staging dir to simulate HDFS
- Path stagingDir = mockClientContext.getRemoteDirectoryManager()
- .getJobStagingArea("my-job", true);
- File remoteFile1 = new File(stagingDir.toUri().getPath()
- + "/" + new Path(remoteUri1).getName());
- remoteFile1.createNewFile();
-
- File remoteDir1 = new File(stagingDir.toUri().getPath()
- + "/" + new Path(remoteUri2).getName());
- remoteDir1.mkdir();
- File remoteDir1File1 = new File(remoteDir1, "afile");
- remoteDir1File1.createNewFile();
-
- Assert.assertTrue(remoteFile1.exists());
- Assert.assertTrue(remoteDir1.exists());
- Assert.assertTrue(remoteDir1File1.exists());
-
- String suffix1 = "_" + remoteDir1.lastModified()
- + "-" + mockClientContext.getRemoteDirectoryManager()
- .getRemoteFileSize(remoteUri2);
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose",
- "--localization",
- remoteUri1 + ":" + containerLocal1,
- "--localization",
- remoteUri2 + ":" + containerLocal2});
- Service serviceSpec = getServiceSpecFromJobSubmitter(
- runJobCli.getJobSubmitter());
- Assert.assertEquals(3, serviceSpec.getComponents().size());
-
- // Ensure download remote dir 2 times
- verify(spyRdm, times(2)).copyRemoteToLocal(
- anyString(), anyString());
-
- // Ensure downloaded temp files are deleted
- Assert.assertFalse(new File(System.getProperty("java.io.tmpdir")
- + "/" + new Path(remoteUri1).getName()).exists());
- Assert.assertFalse(new File(System.getProperty("java.io.tmpdir")
- + "/" + new Path(remoteUri2).getName()).exists());
-
- // Ensure zip file are deleted
- Assert.assertFalse(new File(System.getProperty("java.io.tmpdir")
- + "/" + new Path(remoteUri2).getName()
- + "_" + suffix1 + ".zip").exists());
-
- List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
- Assert.assertEquals(2, files.size());
- ConfigFile file = files.get(0);
- Assert.assertEquals(ConfigFile.TypeEnum.STATIC, file.getType());
- String expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(remoteUri1).getName();
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
- String expectedDstFileName = new Path(remoteUri1).getName();
- Assert.assertEquals(expectedDstFileName, file.getDestFile());
-
- file = files.get(1);
- Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType());
- expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(remoteUri2).getName() + suffix1 + ".zip";
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
-
- expectedDstFileName = new Path(containerLocal2).getName();
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
-
- // Ensure env value is correct
- String env = serviceSpec.getConfiguration().getEnv()
- .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
- String expectedMounts = new Path(remoteUri2).getName()
- + ":" + containerLocal2 + ":rw";
- Assert.assertTrue(env.contains(expectedMounts));
-
- remoteDir1File1.delete();
- remoteFile1.delete();
- remoteDir1.delete();
- }
-
- /**
- * Test HDFS dir localization.
- * --localization hdfs:///user/yarn/mydir:./mydir1
- * --localization hdfs:///user/yarn/mydir2:/opt/dir2:rw
- * --localization hdfs:///user/yarn/mydir:.
- * --localization hdfs:///user/yarn/mydir2:./
- */
- @Test
- public void testRunJobWithHdfsDirLocalization() throws Exception {
- String remoteUrl = "hdfs:///user/yarn/mydir";
- String containerPath = "./mydir1";
- String remoteUrl2 = "hdfs:///user/yarn/mydir2";
- String containPath2 = "/opt/dir2";
- String containerPath3 = ".";
- String containerPath4 = "./";
- MockClientContext mockClientContext =
- YarnServiceCliTestUtils.getMockClientContext();
- RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- RemoteDirectoryManager spyRdm =
- spy(mockClientContext.getRemoteDirectoryManager());
- mockClientContext.setRemoteDirectoryMgr(spyRdm);
- // create remote file in local staging dir to simulate HDFS
- Path stagingDir = mockClientContext.getRemoteDirectoryManager()
- .getJobStagingArea("my-job", true);
- File remoteDir1 = new File(stagingDir.toUri().getPath().toString()
- + "/" + new Path(remoteUrl).getName());
- remoteDir1.mkdir();
- File remoteFile1 = new File(remoteDir1.getAbsolutePath() + "/1.py");
- File remoteFile2 = new File(remoteDir1.getAbsolutePath() + "/2.py");
- remoteFile1.createNewFile();
- remoteFile2.createNewFile();
-
- File remoteDir2 = new File(stagingDir.toUri().getPath().toString()
- + "/" + new Path(remoteUrl2).getName());
- remoteDir2.mkdir();
- File remoteFile3 = new File(remoteDir1.getAbsolutePath() + "/3.py");
- File remoteFile4 = new File(remoteDir1.getAbsolutePath() + "/4.py");
- remoteFile3.createNewFile();
- remoteFile4.createNewFile();
-
- Assert.assertTrue(remoteDir1.exists());
- Assert.assertTrue(remoteDir2.exists());
-
- String suffix1 = "_" + remoteDir1.lastModified()
- + "-" + mockClientContext.getRemoteDirectoryManager()
- .getRemoteFileSize(remoteUrl);
- String suffix2 = "_" + remoteDir2.lastModified()
- + "-" + mockClientContext.getRemoteDirectoryManager()
- .getRemoteFileSize(remoteUrl2);
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose",
- "--localization",
- remoteUrl + ":" + containerPath,
- "--localization",
- remoteUrl2 + ":" + containPath2 + ":rw",
- "--localization",
- remoteUrl + ":" + containerPath3,
- "--localization",
- remoteUrl2 + ":" + containerPath4});
- Service serviceSpec = getServiceSpecFromJobSubmitter(
- runJobCli.getJobSubmitter());
- Assert.assertEquals(3, serviceSpec.getComponents().size());
-
- // Ensure download remote dir 4 times
- verify(spyRdm, times(4)).copyRemoteToLocal(
- anyString(), anyString());
-
- // Ensure downloaded temp files are deleted
- Assert.assertFalse(new File(System.getProperty("java.io.tmpdir")
- + "/" + new Path(remoteUrl).getName()).exists());
- Assert.assertFalse(new File(System.getProperty("java.io.tmpdir")
- + "/" + new Path(remoteUrl2).getName()).exists());
- // Ensure zip file are deleted
- Assert.assertFalse(new File(System.getProperty("java.io.tmpdir")
- + "/" + new Path(remoteUrl).getName()
- + suffix1 + ".zip").exists());
- Assert.assertFalse(new File(System.getProperty("java.io.tmpdir")
- + "/" + new Path(remoteUrl2).getName()
- + suffix2 + ".zip").exists());
-
- // Ensure files will be localized
- List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
- Assert.assertEquals(4, files.size());
- ConfigFile file = files.get(0);
- // The hdfs dir should be download and compress and let YARN to uncompress
- Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType());
- String expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(remoteUrl).getName() + suffix1 + ".zip";
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
-
- // Relative path in container, but not "." or "./". Use its own name
- String expectedDstFileName = new Path(containerPath).getName();
- Assert.assertEquals(expectedDstFileName, file.getDestFile());
-
- file = files.get(1);
- Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType());
- expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(remoteUrl2).getName() + suffix2 + ".zip";
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
-
- expectedDstFileName = new Path(containPath2).getName();
- Assert.assertEquals(expectedDstFileName, file.getDestFile());
-
- file = files.get(2);
- Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType());
- expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(remoteUrl).getName() + suffix1 + ".zip";
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
- // Relative path in container ".", use remote path name
- expectedDstFileName = new Path(remoteUrl).getName();
- Assert.assertEquals(expectedDstFileName, file.getDestFile());
-
- file = files.get(3);
- Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType());
- expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(remoteUrl2).getName() + suffix2 + ".zip";
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
- // Relative path in container "./", use remote path name
- expectedDstFileName = new Path(remoteUrl2).getName();
- Assert.assertEquals(expectedDstFileName, file.getDestFile());
-
- // Ensure mounts env value is correct. Add one mount string
- String env = serviceSpec.getConfiguration().getEnv()
- .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
-
- String expectedMounts =
- new Path(containPath2).getName() + ":" + containPath2 + ":rw";
- Assert.assertTrue(env.contains(expectedMounts));
-
- remoteFile1.delete();
- remoteFile2.delete();
- remoteFile3.delete();
- remoteFile4.delete();
- remoteDir1.delete();
- remoteDir2.delete();
- }
-
- /**
- * Test if file/dir to be localized whose size exceeds limit.
- * Max 10MB in configuration, mock remote will
- * always return file size 100MB.
- * This configuration will fail the job which has remoteUri
- * But don't impact local dir/file
- *
- * --localization https://a/b/1.patch:.
- * --localization s3a://a/dir:/opt/mys3dir
- * --localization /temp/script2.py:./
- */
- @Test
- public void testRunJobRemoteUriExceedLocalizationSize() throws Exception {
- String remoteUri1 = "https://a/b/1.patch";
- String containerLocal1 = ".";
- String remoteUri2 = "s3a://a/s3dir";
- String containerLocal2 = "/opt/mys3dir";
- String localUri1 = "/temp/script2";
- String containerLocal3 = "./";
-
- MockClientContext mockClientContext =
- YarnServiceCliTestUtils.getMockClientContext();
- SubmarineConfiguration submarineConf = new SubmarineConfiguration();
- RemoteDirectoryManager spyRdm =
- spy(mockClientContext.getRemoteDirectoryManager());
- mockClientContext.setRemoteDirectoryMgr(spyRdm);
- /**
- * Max 10MB, mock remote will always return file size 100MB.
- * */
- submarineConf.set(
- SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB,
- "10");
- mockClientContext.setSubmarineConfig(submarineConf);
-
- RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- // create remote file in local staging dir to simulate
- Path stagingDir = mockClientContext.getRemoteDirectoryManager()
- .getJobStagingArea("my-job", true);
- File remoteFile1 = new File(stagingDir.toUri().getPath()
- + "/" + new Path(remoteUri1).getName());
- remoteFile1.createNewFile();
- File remoteDir1 = new File(stagingDir.toUri().getPath()
- + "/" + new Path(remoteUri2).getName());
- remoteDir1.mkdir();
-
- File remoteDir1File1 = new File(remoteDir1, "afile");
- remoteDir1File1.createNewFile();
-
- String fakeLocalDir = System.getProperty("java.io.tmpdir");
- // create local file, we need to put it under local temp dir
- File localFile1 = new File(fakeLocalDir,
- new Path(localUri1).getName());
- localFile1.createNewFile();
-
- Assert.assertTrue(remoteFile1.exists());
- Assert.assertTrue(remoteDir1.exists());
- Assert.assertTrue(remoteDir1File1.exists());
-
- String suffix1 = "_" + remoteDir1.lastModified()
- + "-" + remoteDir1.length();
- try {
- runJobCli = new RunJobCli(mockClientContext);
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources",
- "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose",
- "--localization",
- remoteUri1 + ":" + containerLocal1});
- } catch (IOException e) {
- // Shouldn't have exception because it's within file size limit
- Assert.assertFalse(true);
- }
- // we should download because fail fast
- verify(spyRdm, times(1)).copyRemoteToLocal(
- anyString(), anyString());
- try {
- // reset
- reset(spyRdm);
- runJobCli = new RunJobCli(mockClientContext);
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources",
- "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose",
- "--localization",
- remoteUri1 + ":" + containerLocal1,
- "--localization",
- remoteUri2 + ":" + containerLocal2,
- "--localization",
- localFile1.getAbsolutePath() + ":" + containerLocal3});
- } catch (IOException e) {
- Assert.assertTrue(e.getMessage()
- .contains("104857600 exceeds configured max size:10485760"));
- // we shouldn't do any download because fail fast
- verify(spyRdm, times(0)).copyRemoteToLocal(
- anyString(), anyString());
- }
-
- try {
- runJobCli = new RunJobCli(mockClientContext);
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources",
- "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose",
- "--localization",
- localFile1.getAbsolutePath() + ":" + containerLocal3});
- } catch (IOException e) {
- Assert.assertTrue(e.getMessage()
- .contains("104857600 exceeds configured max size:10485760"));
- // we shouldn't do any download because fail fast
- verify(spyRdm, times(0)).copyRemoteToLocal(
- anyString(), anyString());
- }
-
- localFile1.delete();
- remoteDir1File1.delete();
- remoteFile1.delete();
- remoteDir1.delete();
- }
-
- /**
- * Test remote Uri doesn't exist.
- * */
- @Test
- public void testRunJobWithNonExistRemoteUri() throws Exception {
- String remoteUri1 = "hdfs:///a/b/1.patch";
- String containerLocal1 = ".";
- String localUri1 = "/a/b/c";
- String containerLocal2 = "./";
- MockClientContext mockClientContext =
- YarnServiceCliTestUtils.getMockClientContext();
-
- RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- try {
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources",
- "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose",
- "--localization",
- remoteUri1 + ":" + containerLocal1});
- } catch (IOException e) {
- Assert.assertTrue(e.getMessage()
- .contains("doesn't exists"));
- }
-
- try {
- runJobCli = new RunJobCli(mockClientContext);
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources",
- "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose",
- "--localization",
- localUri1 + ":" + containerLocal2});
- } catch (IOException e) {
- Assert.assertTrue(e.getMessage()
- .contains("doesn't exists"));
- }
- }
-
- /**
- * Test local dir
- * --localization /user/yarn/mydir:./mydir1
- * --localization /user/yarn/mydir2:/opt/dir2:rw
- * --localization /user/yarn/mydir2:.
- */
- @Test
- public void testRunJobWithLocalDirLocalization() throws Exception {
- String fakeLocalDir = System.getProperty("java.io.tmpdir");
- String localUrl = "/user/yarn/mydir";
- String containerPath = "./mydir1";
- String localUrl2 = "/user/yarn/mydir2";
- String containPath2 = "/opt/dir2";
- String containerPath3 = ".";
-
- MockClientContext mockClientContext =
- YarnServiceCliTestUtils.getMockClientContext();
- RunJobCli runJobCli = new RunJobCli(mockClientContext);
- Assert.assertFalse(SubmarineLogs.isVerbose());
-
- RemoteDirectoryManager spyRdm =
- spy(mockClientContext.getRemoteDirectoryManager());
- mockClientContext.setRemoteDirectoryMgr(spyRdm);
- // create local file
- File localDir1 = new File(fakeLocalDir,
- localUrl);
- localDir1.mkdirs();
- File temp1 = new File(localDir1.getAbsolutePath() + "/1.py");
- File temp2 = new File(localDir1.getAbsolutePath() + "/2.py");
- temp1.createNewFile();
- temp2.createNewFile();
-
- File localDir2 = new File(fakeLocalDir,
- localUrl2);
- localDir2.mkdirs();
- File temp3 = new File(localDir1.getAbsolutePath() + "/3.py");
- File temp4 = new File(localDir1.getAbsolutePath() + "/4.py");
- temp3.createNewFile();
- temp4.createNewFile();
-
- Assert.assertTrue(localDir1.exists());
- Assert.assertTrue(localDir2.exists());
-
- String suffix1 = "_" + localDir1.lastModified()
- + "-" + localDir1.length();
- String suffix2 = "_" + localDir2.lastModified()
- + "-" + localDir2.length();
-
- runJobCli.run(
- new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
- "--input_path", "s3://input", "--checkpoint_path", "s3://output",
- "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
- "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
- "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
- "ps.image", "--worker_docker_image", "worker.image",
- "--ps_launch_cmd", "python run-ps.py", "--verbose",
- "--localization",
- fakeLocalDir + localUrl + ":" + containerPath,
- "--localization",
- fakeLocalDir + localUrl2 + ":" + containPath2 + ":rw",
- "--localization",
- fakeLocalDir + localUrl2 + ":" + containerPath3});
-
- Service serviceSpec = getServiceSpecFromJobSubmitter(
- runJobCli.getJobSubmitter());
- Assert.assertEquals(3, serviceSpec.getComponents().size());
-
- // we shouldn't do any download
- verify(spyRdm, times(0)).copyRemoteToLocal(
- anyString(), anyString());
-
- // Ensure local original files are not deleted
- Assert.assertTrue(localDir1.exists());
- Assert.assertTrue(localDir2.exists());
-
- // Ensure zip file are deleted
- Assert.assertFalse(new File(System.getProperty("java.io.tmpdir")
- + "/" + new Path(localUrl).getName()
- + suffix1 + ".zip").exists());
- Assert.assertFalse(new File(System.getProperty("java.io.tmpdir")
- + "/" + new Path(localUrl2).getName()
- + suffix2 + ".zip").exists());
-
- // Ensure dirs will be zipped and localized
- List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
- Assert.assertEquals(3, files.size());
- ConfigFile file = files.get(0);
- Path stagingDir = mockClientContext.getRemoteDirectoryManager()
- .getJobStagingArea("my-job", true);
- Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType());
- String expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(localUrl).getName() + suffix1 + ".zip";
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
- String expectedDstFileName = new Path(containerPath).getName();
- Assert.assertEquals(expectedDstFileName, file.getDestFile());
-
- file = files.get(1);
- Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType());
- expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(localUrl2).getName() + suffix2 + ".zip";
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
- expectedDstFileName = new Path(containPath2).getName();
- Assert.assertEquals(expectedDstFileName, file.getDestFile());
-
- file = files.get(2);
- Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType());
- expectedSrcLocalization = stagingDir.toUri().getPath()
- + "/" + new Path(localUrl2).getName() + suffix2 + ".zip";
- Assert.assertEquals(expectedSrcLocalization,
- new Path(file.getSrcFile()).toUri().getPath());
- expectedDstFileName = new Path(localUrl2).getName();
- Assert.assertEquals(expectedDstFileName, file.getDestFile());
-
- // Ensure mounts env value is correct
- String env = serviceSpec.getConfiguration().getEnv()
- .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
- String expectedMounts = new Path(containPath2).getName()
- + ":" + containPath2 + ":rw";
-
- Assert.assertTrue(env.contains(expectedMounts));
-
- temp1.delete();
- temp2.delete();
- temp3.delete();
- temp4.delete();
- localDir2.delete();
- localDir1.delete();
- }
-
- /**
* Test zip function.
* A dir "/user/yarn/mydir" has two files and one subdir
* */
@Test
public void testYarnServiceSubmitterZipFunction()
throws Exception {
- MockClientContext mockClientContext =
- YarnServiceCliTestUtils.getMockClientContext();
- RunJobCli runJobCli = new RunJobCli(mockClientContext);
- YarnServiceJobSubmitter submitter =
- (YarnServiceJobSubmitter)mockClientContext
- .getRuntimeFactory().getJobSubmitterInstance();
- String fakeLocalDir = System.getProperty("java.io.tmpdir");
String localUrl = "/user/yarn/mydir";
String localSubDirName = "subdir1";
- // create local file
- File localDir1 = new File(fakeLocalDir,
- localUrl);
- localDir1.mkdirs();
- File temp1 = new File(localDir1.getAbsolutePath() + "/1.py");
- File temp2 = new File(localDir1.getAbsolutePath() + "/2.py");
- temp1.createNewFile();
- temp2.createNewFile();
+    // create a local directory containing two files
+ File localDir1 = testCommons.getFileUtils().createDirInTempDir(localUrl);
+ testCommons.getFileUtils().createFileInDir(localDir1, "1.py");
+ testCommons.getFileUtils().createFileInDir(localDir1, "2.py");
- File localSubDir = new File(localDir1.getAbsolutePath(), localSubDirName);
- localSubDir.mkdir();
- File temp3 = new File(localSubDir.getAbsolutePath(), "3.py");
- temp3.createNewFile();
-
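+    // add a nested subdirectory to verify that zipping recurses into it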
+ File localSubDir =
+ testCommons.getFileUtils().createDirectory(localDir1, localSubDirName);
+ testCommons.getFileUtils().createFileInDir(localSubDir, "3.py");
- String zipFilePath = submitter.zipDir(localDir1.getAbsolutePath(),
- fakeLocalDir + "/user/yarn/mydir.zip");
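+    // zip the directory with ZipUtilities, unzip it and check the layout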
+ String tempDir = localDir1.getParent();
+ String zipFilePath = ZipUtilities.zipDir(localDir1.getAbsolutePath(),
+ new File(tempDir, "mydir.zip").getAbsolutePath());
File zipFile = new File(zipFilePath);
- File unzipTargetDir = new File(fakeLocalDir, "unzipDir");
+ File unzipTargetDir = new File(tempDir, "unzipDir");
FileUtil.unZip(zipFile, unzipTargetDir);
- Assert.assertTrue(
- new File(fakeLocalDir + "/unzipDir/1.py").exists());
- Assert.assertTrue(
- new File(fakeLocalDir + "/unzipDir/2.py").exists());
- Assert.assertTrue(
- new File(fakeLocalDir + "/unzipDir/subdir1").exists());
- Assert.assertTrue(
- new File(fakeLocalDir + "/unzipDir/subdir1/3.py").exists());
-
- zipFile.delete();
- unzipTargetDir.delete();
- temp1.delete();
- temp2.delete();
- temp3.delete();
- localSubDir.delete();
- localDir1.delete();
+ assertTrue(
+ new File(tempDir + "/unzipDir/1.py").exists());
+ assertTrue(
+ new File(tempDir + "/unzipDir/2.py").exists());
+ assertTrue(
+ new File(tempDir + "/unzipDir/subdir1").exists());
+ assertTrue(
+ new File(tempDir + "/unzipDir/subdir1/3.py").exists());
}
}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliCommons.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliCommons.java
new file mode 100644
index 0000000..94e2c37
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliCommons.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.yarnservice;
+
+import org.apache.hadoop.yarn.client.api.AppAdminClient;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.apache.hadoop.yarn.submarine.FileUtilitiesForTests;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceJobSubmitter;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
+
+import java.io.IOException;
+
+import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Common operations shared by test classes exercising run job actions.
+ */
+public class TestYarnServiceRunJobCliCommons {
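+  // Default CLI argument values shared by the run job test classes.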
+ static final String DEFAULT_JOB_NAME = "my-job";
+ static final String DEFAULT_DOCKER_IMAGE = "tf-docker:1.1.0";
+ static final String DEFAULT_INPUT_PATH = "s3://input";
+ static final String DEFAULT_CHECKPOINT_PATH = "s3://output";
+ static final String DEFAULT_WORKER_DOCKER_IMAGE = "worker.image";
+ static final String DEFAULT_PS_DOCKER_IMAGE = "ps.image";
+ static final String DEFAULT_WORKER_LAUNCH_CMD = "python run-job.py";
+ static final String DEFAULT_PS_LAUNCH_CMD = "python run-ps.py";
+ static final String DEFAULT_TENSORBOARD_RESOURCES = "memory=2G,vcores=2";
+ static final String DEFAULT_WORKER_RESOURCES = "memory=2048M,vcores=2";
+ static final String DEFAULT_PS_RESOURCES = "memory=4096M,vcores=4";
+ static final String DEFAULT_TENSORBOARD_DOCKER_IMAGE = "tb_docker_image:001";
+
+ private FileUtilitiesForTests fileUtils = new FileUtilitiesForTests();
+
+ void setup() throws IOException, YarnException {
+ SubmarineLogs.verboseOff();
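+ // Stub the YARN service client so that job submission and status queries
+ // succeed without a real cluster.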
+ AppAdminClient serviceClient = mock(AppAdminClient.class);
+ when(serviceClient.actionLaunch(any(String.class), any(String.class),
+ any(Long.class), any(String.class))).thenReturn(EXIT_SUCCESS);
+ when(serviceClient.getStatusString(any(String.class))).thenReturn(
+ "{\"id\": \"application_1234_1\"}");
+ YarnServiceUtils.setStubServiceClient(serviceClient);
+
+ fileUtils.setup();
+ }
+
+ void teardown() throws IOException {
+ fileUtils.teardown();
+ }
+
+ FileUtilitiesForTests getFileUtils() {
+ return fileUtils;
+ }
+
+ Service getServiceSpecFromJobSubmitter(JobSubmitter jobSubmitter) {
+ return ((YarnServiceJobSubmitter) jobSubmitter).getServiceWrapper()
+ .getService();
+ }
+
+}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliLocalization.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliLocalization.java
new file mode 100644
index 0000000..9bee302
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliLocalization.java
@@ -0,0 +1,599 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.yarnservice;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.service.api.records.ConfigFile;
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli;
+import org.apache.hadoop.yarn.submarine.common.MockClientContext;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration;
+import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.*;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Class to test YarnService localization feature with the Run job CLI action.
+ */
+public class TestYarnServiceRunJobCliLocalization {
+ private static final String ZIP_EXTENSION = ".zip";
+ private TestYarnServiceRunJobCliCommons testCommons =
+ new TestYarnServiceRunJobCliCommons();
+ private MockClientContext mockClientContext;
+ private RemoteDirectoryManager spyRdm;
+
+ @Before
+ public void before() throws IOException, YarnException {
+ testCommons.setup();
+ mockClientContext = YarnServiceCliTestUtils.getMockClientContext();
+ spyRdm = setupSpyRemoteDirManager();
+ }
+
+ @After
+ public void cleanup() throws IOException {
+ testCommons.teardown();
+ }
+
+ private ParamBuilderForTest createCommonParamsBuilder() {
+ return ParamBuilderForTest.create()
+ .withJobName(DEFAULT_JOB_NAME)
+ .withDockerImage(DEFAULT_DOCKER_IMAGE)
+ .withInputPath(DEFAULT_INPUT_PATH)
+ .withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
+ .withNumberOfWorkers(3)
+ .withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE)
+ .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD)
+ .withWorkerResources(DEFAULT_WORKER_RESOURCES)
+ .withNumberOfPs(2)
+ .withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE)
+ .withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD)
+ .withPsResources(DEFAULT_PS_RESOURCES)
+ .withVerbose();
+ }
+
+ private void assertFilesAreDeleted(File... files) {
+ for (File file : files) {
+ assertFalse("File should be deleted: " + file.getAbsolutePath(),
+ file.exists());
+ }
+ }
+
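+ // Spy on the remote directory manager so that tests can verify how many
+ // copyRemoteToLocal calls (i.e. downloads) were made.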
+ private RemoteDirectoryManager setupSpyRemoteDirManager() {
+ RemoteDirectoryManager spyRdm =
+ spy(mockClientContext.getRemoteDirectoryManager());
+ mockClientContext.setRemoteDirectoryMgr(spyRdm);
+ return spyRdm;
+ }
+
+ private Path getStagingDir() throws IOException {
+ return mockClientContext.getRemoteDirectoryManager()
+ .getJobStagingArea(DEFAULT_JOB_NAME, true);
+ }
+
+ private RunJobCli createRunJobCliWithoutVerboseAssertion() {
+ return new RunJobCli(mockClientContext);
+ }
+
+ private RunJobCli createRunJobCli() {
+ RunJobCli runJobCli = new RunJobCli(mockClientContext);
+ assertFalse(SubmarineLogs.isVerbose());
+ return runJobCli;
+ }
+
+ private String getFilePath(String localUrl, Path stagingDir) {
+ return stagingDir.toUri().getPath()
+ + "/" + new Path(localUrl).getName();
+ }
+
+ private String getFilePathWithSuffix(Path stagingDir, String localUrl,
+ String suffix) {
+ return stagingDir.toUri().getPath() + "/" + new Path(localUrl).getName()
+ + suffix;
+ }
+
+ private void assertConfigFile(ConfigFile expected, ConfigFile actual) {
+ assertEquals("ConfigFile does not equal to expected!", expected, actual);
+ }
+
+ private void assertNumberOfLocalizations(List<ConfigFile> files,
+ int expected) {
+ assertEquals("Number of localizations is not the expected!", expected,
+ files.size());
+ }
+
+ private void verifyRdmCopyToRemoteLocalCalls(int expectedCalls)
+ throws IOException {
+ verify(spyRdm, times(expectedCalls)).copyRemoteToLocal(anyString(),
+ anyString());
+ }
+
+ /**
+ * Basic test.
+ * On the one hand, create a local temp file/dir for the hdfs URI in the
+ * local staging dir.
+ * On the other hand, use the MockRemoteDirectoryManager mock
+ * implementation when checking FileStatus or existence of HDFS files/dirs.
+ * --localization hdfs:///user/yarn/script1.py:.
+ * --localization /temp/script2.py:./
+ * --localization /temp/script2.py:/opt/script.py
+ */
+ @Test
+ public void testRunJobWithBasicLocalization() throws Exception {
+ String remoteUrl = "hdfs:///user/yarn/script1.py";
+ String containerLocal1 = ".";
+ String localUrl = "/temp/script2.py";
+ String containerLocal2 = "./";
+ String containerLocal3 = "/opt/script.py";
+ // Create a local file; we need to put it under the local temp dir
+ File localFile1 = testCommons.getFileUtils().createFileInTempDir(localUrl);
+
+ // create remote file in local staging dir to simulate HDFS
+ Path stagingDir = getStagingDir();
+ testCommons.getFileUtils().createFileInDir(stagingDir, remoteUrl);
+
+ String[] params = createCommonParamsBuilder()
+ .withLocalization(remoteUrl, containerLocal1)
+ .withLocalization(localFile1.getAbsolutePath(), containerLocal2)
+ .withLocalization(localFile1.getAbsolutePath(), containerLocal3)
+ .build();
+ RunJobCli runJobCli = createRunJobCli();
+ runJobCli.run(params);
+ Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
+ runJobCli.getJobSubmitter());
+ assertNumberOfServiceComponents(serviceSpec, 3);
+
+ // No real remote dir or HDFS file exists,
+ // so ensure a download never happened.
+ verifyRdmCopyToRemoteLocalCalls(0);
+ // Ensure local original files are not deleted
+ assertTrue(localFile1.exists());
+
+ List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
+ assertNumberOfLocalizations(files, 3);
+
+ ConfigFile expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC);
+ expectedConfigFile.setSrcFile(remoteUrl);
+ expectedConfigFile.setDestFile(new Path(remoteUrl).getName());
+ assertConfigFile(expectedConfigFile, files.get(0));
+
+ expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC);
+ expectedConfigFile.setSrcFile(getFilePath(localUrl, stagingDir));
+ expectedConfigFile.setDestFile(new Path(localUrl).getName());
+ assertConfigFile(expectedConfigFile, files.get(1));
+
+ expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC);
+ expectedConfigFile.setSrcFile(getFilePath(localUrl, stagingDir));
+ expectedConfigFile.setDestFile(new Path(containerLocal3).getName());
+ assertConfigFile(expectedConfigFile, files.get(2));
+
+ // Ensure env value is correct
+ String env = serviceSpec.getConfiguration().getEnv()
+ .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
+ String expectedMounts = new Path(containerLocal3).getName()
+ + ":" + containerLocal3 + ":rw";
+ assertTrue(env.contains(expectedMounts));
+ }
+
+ private void assertNumberOfServiceComponents(Service serviceSpec,
+ int expected) {
+ assertEquals(expected, serviceSpec.getComponents().size());
+ }
+
+ /**
+ * Non-HDFS remote URI test.
+ * --localization https://a/b/1.patch:.
+ * --localization s3a://a/dir:/opt/mys3dir
+ */
+ @Test
+ public void testRunJobWithNonHDFSRemoteLocalization() throws Exception {
+ String remoteUri1 = "https://a/b/1.patch";
+ String containerLocal1 = ".";
+ String remoteUri2 = "s3a://a/s3dir";
+ String containerLocal2 = "/opt/mys3dir";
+
+ // create remote file in local staging dir to simulate HDFS
+ Path stagingDir = getStagingDir();
+ testCommons.getFileUtils().createFileInDir(stagingDir, remoteUri1);
+ File remoteDir1 =
+ testCommons.getFileUtils().createDirectory(stagingDir, remoteUri2);
+ testCommons.getFileUtils().createFileInDir(remoteDir1, "afile");
+
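+ // Remote directories are zipped before localization; the staged archive
+ // name gets a "_<lastModified>-<size>" suffix, reproduced here for the
+ // assertions below.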
+ String suffix1 = "_" + remoteDir1.lastModified()
+ + "-" + mockClientContext.getRemoteDirectoryManager()
+ .getRemoteFileSize(remoteUri2);
+
+ String[] params = createCommonParamsBuilder()
+ .withLocalization(remoteUri1, containerLocal1)
+ .withLocalization(remoteUri2, containerLocal2)
+ .build();
+ RunJobCli runJobCli = createRunJobCli();
+ runJobCli.run(params);
+ Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
+ runJobCli.getJobSubmitter());
+ assertNumberOfServiceComponents(serviceSpec, 3);
+
+ // Ensure the remote resources were downloaded 2 times
+ verifyRdmCopyToRemoteLocalCalls(2);
+
+ // Ensure downloaded temp files are deleted
+ assertFilesAreDeleted(
+ testCommons.getFileUtils().getTempFileWithName(remoteUri1),
+ testCommons.getFileUtils().getTempFileWithName(remoteUri2));
+
+ // Ensure the zip file is deleted
+ assertFilesAreDeleted(
+ testCommons.getFileUtils()
+ .getTempFileWithName(remoteUri2 + "_" + suffix1 + ZIP_EXTENSION));
+
+ List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
+ assertNumberOfLocalizations(files, 2);
+
+ ConfigFile expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC);
+ expectedConfigFile.setSrcFile(getFilePath(remoteUri1, stagingDir));
+ expectedConfigFile.setDestFile(new Path(remoteUri1).getName());
+ assertConfigFile(expectedConfigFile, files.get(0));
+
+ expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
+ expectedConfigFile.setSrcFile(
+ getFilePathWithSuffix(stagingDir, remoteUri2, suffix1 + ZIP_EXTENSION));
+ expectedConfigFile.setDestFile(new Path(containerLocal2).getName());
+ assertConfigFile(expectedConfigFile, files.get(1));
+
+ // Ensure env value is correct
+ String env = serviceSpec.getConfiguration().getEnv()
+ .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
+ String expectedMounts = new Path(remoteUri2).getName()
+ + ":" + containerLocal2 + ":rw";
+ assertTrue(env.contains(expectedMounts));
+ }
+
+ /**
+ * Test HDFS dir localization.
+ * --localization hdfs:///user/yarn/mydir:./mydir1
+ * --localization hdfs:///user/yarn/mydir2:/opt/dir2:rw
+ * --localization hdfs:///user/yarn/mydir:.
+ * --localization hdfs:///user/yarn/mydir2:./
+ */
+ @Test
+ public void testRunJobWithHdfsDirLocalization() throws Exception {
+ String remoteUrl = "hdfs:///user/yarn/mydir";
+ String containerPath = "./mydir1";
+ String remoteUrl2 = "hdfs:///user/yarn/mydir2";
+ String containerPath2 = "/opt/dir2";
+ String containerPath3 = ".";
+ String containerPath4 = "./";
+
+ // Create remote dirs with files in the local staging dir to simulate HDFS
+ Path stagingDir = getStagingDir();
+ File remoteDir1 =
+ testCommons.getFileUtils().createDirectory(stagingDir, remoteUrl);
+ testCommons.getFileUtils().createFileInDir(remoteDir1, "1.py");
+ testCommons.getFileUtils().createFileInDir(remoteDir1, "2.py");
+
+ File remoteDir2 =
+ testCommons.getFileUtils().createDirectory(stagingDir, remoteUrl2);
+ testCommons.getFileUtils().createFileInDir(remoteDir2, "3.py");
+ testCommons.getFileUtils().createFileInDir(remoteDir2, "4.py");
+
+ String suffix1 = "_" + remoteDir1.lastModified()
+ + "-" + mockClientContext.getRemoteDirectoryManager()
+ .getRemoteFileSize(remoteUrl);
+ String suffix2 = "_" + remoteDir2.lastModified()
+ + "-" + mockClientContext.getRemoteDirectoryManager()
+ .getRemoteFileSize(remoteUrl2);
+
+ String[] params = createCommonParamsBuilder()
+ .withLocalization(remoteUrl, containerPath)
+ .withLocalization(remoteUrl2, containerPath2)
+ .withLocalization(remoteUrl, containerPath3)
+ .withLocalization(remoteUrl2, containerPath4)
+ .build();
+ RunJobCli runJobCli = createRunJobCli();
+ runJobCli.run(params);
+ Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
+ runJobCli.getJobSubmitter());
+ assertNumberOfServiceComponents(serviceSpec, 3);
+
+ // Ensure the remote dirs were downloaded 4 times
+ verifyRdmCopyToRemoteLocalCalls(4);
+
+ // Ensure downloaded temp files are deleted
+ assertFilesAreDeleted(
+ testCommons.getFileUtils().getTempFileWithName(remoteUrl),
+ testCommons.getFileUtils().getTempFileWithName(remoteUrl2));
+
+ // Ensure zip files are deleted
+ assertFilesAreDeleted(
+ testCommons.getFileUtils()
+ .getTempFileWithName(remoteUrl + suffix1 + ZIP_EXTENSION),
+ testCommons.getFileUtils()
+ .getTempFileWithName(remoteUrl2 + suffix2 + ZIP_EXTENSION));
+
+ // Ensure files will be localized
+ List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
+ assertNumberOfLocalizations(files, 4);
+
+ ConfigFile expectedConfigFile = new ConfigFile();
+ // The hdfs dir should be downloaded and compressed, letting YARN uncompress it
+ expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
+ expectedConfigFile.setSrcFile(
+ getFilePathWithSuffix(stagingDir, remoteUrl, suffix1 + ZIP_EXTENSION));
+ // Relative path in container, but not "." or "./". Use its own name
+ expectedConfigFile.setDestFile(new Path(containerPath).getName());
+ assertConfigFile(expectedConfigFile, files.get(0));
+
+ expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
+ expectedConfigFile.setSrcFile(
+ getFilePathWithSuffix(stagingDir, remoteUrl2, suffix2 + ZIP_EXTENSION));
+ expectedConfigFile.setDestFile(new Path(containerPath2).getName());
+ assertConfigFile(expectedConfigFile, files.get(1));
+
+ expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
+ expectedConfigFile.setSrcFile(
+ getFilePathWithSuffix(stagingDir, remoteUrl, suffix1 + ZIP_EXTENSION));
+ // Relative path in container ".", use remote path name
+ expectedConfigFile.setDestFile(new Path(remoteUrl).getName());
+ assertConfigFile(expectedConfigFile, files.get(2));
+
+ expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
+ expectedConfigFile.setSrcFile(
+ getFilePathWithSuffix(stagingDir, remoteUrl2, suffix2 + ZIP_EXTENSION));
+ // Relative path in container ".", use remote path name
+ expectedConfigFile.setDestFile(new Path(remoteUrl2).getName());
+ assertConfigFile(expectedConfigFile, files.get(3));
+
+ // Ensure mounts env value is correct. Only one mount string should be added
+ String env = serviceSpec.getConfiguration().getEnv()
+ .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
+
+ String expectedMounts =
+ new Path(containerPath2).getName() + ":" + containerPath2 + ":rw";
+ assertTrue(env.contains(expectedMounts));
+ }
+
+ /**
+ * Test a file/dir to be localized whose size exceeds the limit.
+ * The maximum is 10MB in the configuration, while the mock remote will
+ * always return a file size of 100MB.
+ * This configuration fails the job that has a remote URI,
+ * but does not impact local dirs/files.
+ *
+ * --localization https://a/b/1.patch:.
+ * --localization s3a://a/dir:/opt/mys3dir
+ * --localization /temp/script2.py:./
+ */
+ @Test
+ public void testRunJobRemoteUriExceedLocalizationSize() throws Exception {
+ String remoteUri1 = "https://a/b/1.patch";
+ String containerLocal1 = ".";
+ String remoteUri2 = "s3a://a/s3dir";
+ String containerLocal2 = "/opt/mys3dir";
+ String localUri1 = "/temp/script2";
+ String containerLocal3 = "./";
+
+ SubmarineConfiguration submarineConf = new SubmarineConfiguration();
+
+ // Max 10MB, mock remote will always return file size 100MB.
+ submarineConf.set(
+ SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB,
+ "10");
+ mockClientContext.setSubmarineConfig(submarineConf);
+
+ assertFalse(SubmarineLogs.isVerbose());
+
+ // Create remote file in local staging dir to simulate HDFS
+ Path stagingDir = getStagingDir();
+ testCommons.getFileUtils().createFileInDir(stagingDir, remoteUri1);
+ File remoteDir1 =
+ testCommons.getFileUtils().createDirectory(stagingDir, remoteUri2);
+ testCommons.getFileUtils().createFileInDir(remoteDir1, "afile");
+
+ // Create a local file; we need to put it under the local temp dir
+ File localFile1 = testCommons.getFileUtils().createFileInTempDir(localUri1);
+
+ try {
+ RunJobCli runJobCli = createRunJobCli();
+ String[] params = createCommonParamsBuilder()
+ .withLocalization(remoteUri1, containerLocal1)
+ .build();
+ runJobCli.run(params);
+ } catch (IOException e) {
+ // Shouldn't throw an exception because the file is within the size limit
+ fail();
+ }
+ // One download should have happened since the size check passed
+ verifyRdmCopyToRemoteLocalCalls(1);
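+ // Second run: a remote resource's size exceeds the configured 10MB limit,
+ // so the run should fail fast before any download.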
+ try {
+ String[] params = createCommonParamsBuilder()
+ .withLocalization(remoteUri1, containerLocal1)
+ .withLocalization(remoteUri2, containerLocal2)
+ .withLocalization(localFile1.getAbsolutePath(), containerLocal3)
+ .build();
+
+ reset(spyRdm);
+ RunJobCli runJobCli = createRunJobCliWithoutVerboseAssertion();
+ runJobCli.run(params);
+ } catch (IOException e) {
+ assertTrue(e.getMessage()
+ .contains("104857600 exceeds configured max size:10485760"));
+ // No download should happen because of the fail-fast size check
+ verifyRdmCopyToRemoteLocalCalls(0);
+ }
+
+ try {
+ String[] params = createCommonParamsBuilder()
+ .withLocalization(localFile1.getAbsolutePath(), containerLocal3)
+ .build();
+ RunJobCli runJobCli = createRunJobCliWithoutVerboseAssertion();
+ runJobCli.run(params);
+ } catch (IOException e) {
+ assertTrue(e.getMessage()
+ .contains("104857600 exceeds configured max size:10485760"));
+ // No download should happen because of the fail-fast size check
+ verifyRdmCopyToRemoteLocalCalls(0);
+ }
+ }
+
+ /**
+ * Test when the remote URI doesn't exist.
+ */
+ @Test
+ public void testRunJobWithNonExistRemoteUri() throws Exception {
+ String remoteUri1 = "hdfs:///a/b/1.patch";
+ String containerLocal1 = ".";
+ String localUri1 = "/a/b/c";
+ String containerLocal2 = "./";
+
+ try {
+ String[] params = createCommonParamsBuilder()
+ .withLocalization(remoteUri1, containerLocal1)
+ .build();
+ RunJobCli runJobCli = createRunJobCli();
+ runJobCli.run(params);
+ } catch (IOException e) {
+ assertTrue(e.getMessage().contains("doesn't exists"));
+ }
+
+ try {
+ String[] params = createCommonParamsBuilder()
+ .withLocalization(localUri1, containerLocal2)
+ .build();
+ RunJobCli runJobCli = createRunJobCliWithoutVerboseAssertion();
+ runJobCli.run(params);
+ } catch (IOException e) {
+ assertTrue(e.getMessage().contains("doesn't exists"));
+ }
+ }
+
+ /**
+ * Test local dir localization.
+ * --localization /user/yarn/mydir:./mydir1
+ * --localization /user/yarn/mydir2:/opt/dir2:rw
+ * --localization /user/yarn/mydir2:.
+ */
+ @Test
+ public void testRunJobWithLocalDirLocalization() throws Exception {
+ String localUrl = "/user/yarn/mydir";
+ String containerPath = "./mydir1";
+ String localUrl2 = "/user/yarn/mydir2";
+ String containerPath2 = "/opt/dir2";
+ String containerPath3 = ".";
+
+ // Create local dirs with files in the temp dir
+ File localDir1 = testCommons.getFileUtils().createDirInTempDir(localUrl);
+ testCommons.getFileUtils().createFileInDir(localDir1, "1.py");
+ testCommons.getFileUtils().createFileInDir(localDir1, "2.py");
+
+ File localDir2 = testCommons.getFileUtils().createDirInTempDir(localUrl2);
+ testCommons.getFileUtils().createFileInDir(localDir2, "3.py");
+ testCommons.getFileUtils().createFileInDir(localDir2, "4.py");
+
+ String suffix1 = "_" + localDir1.lastModified()
+ + "-" + localDir1.length();
+ String suffix2 = "_" + localDir2.lastModified()
+ + "-" + localDir2.length();
+
+ String[] params = createCommonParamsBuilder()
+ .withLocalization(localDir1.getAbsolutePath(), containerPath)
+ .withLocalization(localDir2.getAbsolutePath(), containerPath2)
+ .withLocalization(localDir2.getAbsolutePath(), containerPath3)
+ .build();
+ RunJobCli runJobCli = createRunJobCli();
+ runJobCli.run(params);
+
+ Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
+ runJobCli.getJobSubmitter());
+ assertNumberOfServiceComponents(serviceSpec, 3);
+
+ // we shouldn't do any download
+ verifyRdmCopyToRemoteLocalCalls(0);
+
+ // Ensure local original files are not deleted
+ assertTrue(localDir1.exists());
+ assertTrue(localDir2.exists());
+
+ // Ensure zip files are deleted
+ assertFalse(
+ testCommons.getFileUtils()
+ .getTempFileWithName(localUrl + suffix1 + ZIP_EXTENSION)
+ .exists());
+ assertFalse(
+ testCommons.getFileUtils()
+ .getTempFileWithName(localUrl2 + suffix2 + ZIP_EXTENSION)
+ .exists());
+
+ // Ensure dirs will be zipped and localized
+ List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
+ assertNumberOfLocalizations(files, 3);
+
+ Path stagingDir = getStagingDir();
+ ConfigFile expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
+ expectedConfigFile.setSrcFile(
+ getFilePathWithSuffix(stagingDir, localUrl, suffix1 + ZIP_EXTENSION));
+ expectedConfigFile.setDestFile(new Path(containerPath).getName());
+ assertConfigFile(expectedConfigFile, files.get(0));
+
+ expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
+ expectedConfigFile.setSrcFile(
+ getFilePathWithSuffix(stagingDir, localUrl2, suffix2 + ZIP_EXTENSION));
+ expectedConfigFile.setDestFile(new Path(containerPath2).getName());
+ assertConfigFile(expectedConfigFile, files.get(1));
+
+ expectedConfigFile = new ConfigFile();
+ expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
+ expectedConfigFile.setSrcFile(
+ getFilePathWithSuffix(stagingDir, localUrl2, suffix2 + ZIP_EXTENSION));
+ expectedConfigFile.setDestFile(new Path(localUrl2).getName());
+ assertConfigFile(expectedConfigFile, files.get(2));
+
+ // Ensure mounts env value is correct
+ String env = serviceSpec.getConfiguration().getEnv()
+ .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
+ String expectedMounts = new Path(containerPath2).getName()
+ + ":" + containerPath2 + ":rw";
+
+ assertTrue(env.contains(expectedMounts));
+ }
+}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestServiceWrapper.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestServiceWrapper.java
new file mode 100644
index 0000000..cd5c05c
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestServiceWrapper.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.junit.Test;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Class to test the {@link ServiceWrapper}.
+ */
+public class TestServiceWrapper {
+ private AbstractComponent createMockAbstractComponent(Component mockComponent,
+ String componentName, String localScriptFile) throws IOException {
+ when(mockComponent.getName()).thenReturn(componentName);
+
+ AbstractComponent mockAbstractComponent = mock(AbstractComponent.class);
+ when(mockAbstractComponent.createComponent()).thenReturn(mockComponent);
+ when(mockAbstractComponent.getLocalScriptFile())
+ .thenReturn(localScriptFile);
+ return mockAbstractComponent;
+ }
+
+ @Test
+ public void testWithSingleComponent() throws IOException {
+ Service mockService = mock(Service.class);
+ ServiceWrapper serviceWrapper = new ServiceWrapper(mockService);
+
+ Component mockComponent = mock(Component.class);
+ AbstractComponent mockAbstractComponent =
+ createMockAbstractComponent(mockComponent, "testComponent",
+ "testLocalScriptFile");
+ serviceWrapper.addComponent(mockAbstractComponent);
+
+ verify(mockService).addComponent(eq(mockComponent));
+
+ String launchCommand =
+ serviceWrapper.getLocalLaunchCommandPathForComponent("testComponent");
+ assertEquals("testLocalScriptFile", launchCommand);
+ }
+
+ @Test
+ public void testWithMultipleComponent() throws IOException {
+ Service mockService = mock(Service.class);
+ ServiceWrapper serviceWrapper = new ServiceWrapper(mockService);
+
+ Component mockComponent1 = mock(Component.class);
+ AbstractComponent mockAbstractComponent1 =
+ createMockAbstractComponent(mockComponent1, "testComponent1",
+ "testLocalScriptFile1");
+
+ Component mockComponent2 = mock(Component.class);
+ AbstractComponent mockAbstractComponent2 =
+ createMockAbstractComponent(mockComponent2, "testComponent2",
+ "testLocalScriptFile2");
+
+ serviceWrapper.addComponent(mockAbstractComponent1);
+ serviceWrapper.addComponent(mockAbstractComponent2);
+
+ verify(mockService).addComponent(eq(mockComponent1));
+ verify(mockService).addComponent(eq(mockComponent2));
+
+ String launchCommand1 =
+ serviceWrapper.getLocalLaunchCommandPathForComponent("testComponent1");
+ assertEquals("testLocalScriptFile1", launchCommand1);
+
+ String launchCommand2 =
+ serviceWrapper.getLocalLaunchCommandPathForComponent("testComponent2");
+ assertEquals("testLocalScriptFile2", launchCommand2);
+ }
+
+
+}
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestTFConfigGenerator.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestTFConfigGenerator.java
index d7dc874..c8b2388 100644
--- a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestTFConfigGenerator.java
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestTFConfigGenerator.java
@@ -14,26 +14,30 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons;
import org.codehaus.jettison.json.JSONException;
import org.junit.Assert;
import org.junit.Test;
+/**
+ * Class to test some functionality of {@link TensorFlowCommons}.
+ */
public class TestTFConfigGenerator {
@Test
public void testSimpleDistributedTFConfigGenerator() throws JSONException {
- String json = YarnServiceUtils.getTFConfigEnv("worker", 5, 3, "wtan",
+ String json = TensorFlowCommons.getTFConfigEnv("worker", 5, 3, "wtan",
"tf-job-001", "example.com");
String expected =
"{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"],\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\",\\\"worker-1.wtan.tf-job-001.example.com:8000\\\",\\\"worker-2.wtan.tf-job-001.example.com:8000\\\",\\\"worker-3.wtan.tf-job-001.example.com:8000\\\"],\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\",\\\"ps-1.wtan.tf-job-001.example.com:8000\\\",\\\"ps-2.wtan.tf-job-001.example.com:8000\\\"]},\\\"task\\\":{ \\\"type\\\":\ [...]
Assert.assertEquals(expected, json);
- json = YarnServiceUtils.getTFConfigEnv("ps", 5, 3, "wtan", "tf-job-001",
+ json = TensorFlowCommons.getTFConfigEnv("ps", 5, 3, "wtan", "tf-job-001",
"example.com");
expected =
"{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"],\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\",\\\"worker-1.wtan.tf-job-001.example.com:8000\\\",\\\"worker-2.wtan.tf-job-001.example.com:8000\\\",\\\"worker-3.wtan.tf-job-001.example.com:8000\\\"],\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\",\\\"ps-1.wtan.tf-job-001.example.com:8000\\\",\\\"ps-2.wtan.tf-job-001.example.com:8000\\\"]},\\\"task\\\":{ \\\"type\\\":\ [...]
Assert.assertEquals(expected, json);
- json = YarnServiceUtils.getTFConfigEnv("master", 2, 1, "wtan", "tf-job-001",
+ json = TensorFlowCommons.getTFConfigEnv("master", 2, 1, "wtan", "tf-job-001",
"example.com");
expected =
"{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"],\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\"],\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\"]},\\\"task\\\":{ \\\"type\\\":\\\"master\\\", \\\"index\\\":$_TASK_INDEX},\\\"environment\\\":\\\"cloud\\\"}";
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommandTestHelper.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommandTestHelper.java
new file mode 100644
index 0000000..5275603
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommandTestHelper.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.MockClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
+import org.junit.Rule;
+import org.junit.rules.ExpectedException;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * This class is an abstract base class for testing Tensorboard and TensorFlow
+ * launch commands.
+ */
+public abstract class AbstractLaunchCommandTestHelper {
+ private TaskType taskType;
+ private boolean useTaskTypeOverride;
+
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
+ private void assertScriptContainsExportedEnvVar(List<String> fileContents,
+ String varName) {
+ String expected = String.format("export %s=", varName);
+ assertScriptContainsLine(fileContents, expected);
+ }
+
+ public static void assertScriptContainsExportedEnvVarWithValue(
+ List<String> fileContents, String varName, String value) {
+ String expected = String.format("export %s=%s", varName, value);
+ assertScriptContainsLine(fileContents, expected);
+ }
+
+ public static void assertScriptContainsLine(List<String> fileContents,
+ String expected) {
+ String message = String.format(
+ "File does not contain expected line '%s'!" + " File contents: %s",
+ expected, Arrays.toString(fileContents.toArray()));
+ assertTrue(message, fileContents.contains(expected));
+ }
+
+ public static void assertScriptContainsLineWithRegex(
+ List<String> fileContents,
+ String regex) {
+ String message = String.format(
+ "File does not contain expected line '%s'!" + " File contents: %s",
+ regex, Arrays.toString(fileContents.toArray()));
+
+ for (String line : fileContents) {
+ if (line.matches(regex)) {
+ return;
+ }
+ }
+ fail(message);
+ }
+
+ public static void assertScriptDoesNotContainLine(List<String> fileContents,
+ String expected) {
+ String message = String.format(
+ "File contains unexpected line '%s'!" + " File contents: %s",
+ expected, Arrays.toString(fileContents.toArray()));
+ assertFalse(message, fileContents.contains(expected));
+ }
+
+
+ private AbstractLaunchCommand createLaunchCommandByTaskType(TaskType taskType,
+ RunJobParameters params) throws IOException {
+ MockClientContext mockClientContext = new MockClientContext();
+ FileSystemOperations fsOperations =
+ new FileSystemOperations(mockClientContext);
+ HadoopEnvironmentSetup hadoopEnvSetup =
+ new HadoopEnvironmentSetup(mockClientContext, fsOperations);
+ Component component = new Component();
+ Configuration yarnConfig = new Configuration();
+
+ return createLaunchCommandByTaskTypeInternal(taskType, params,
+ hadoopEnvSetup, component, yarnConfig);
+ }
+
+ private AbstractLaunchCommand createLaunchCommandByTaskTypeInternal(
+ TaskType taskType, RunJobParameters params,
+ HadoopEnvironmentSetup hadoopEnvSetup, Component component,
+ Configuration yarnConfig)
+ throws IOException {
+ if (taskType == TaskType.TENSORBOARD) {
+ return new TensorBoardLaunchCommand(
+ hadoopEnvSetup, getTaskType(taskType), component, params);
+ } else if (taskType == TaskType.WORKER
+ || taskType == TaskType.PRIMARY_WORKER) {
+ return new TensorFlowWorkerLaunchCommand(
+ hadoopEnvSetup, getTaskType(taskType), component, params, yarnConfig);
+ } else if (taskType == TaskType.PS) {
+ return new TensorFlowPsLaunchCommand(
+ hadoopEnvSetup, getTaskType(taskType), component, params, yarnConfig);
+ }
+ throw new IllegalStateException("Unknown taskType!");
+ }
+
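+ // Allows a test to force a specific (possibly null) TaskType into the
+ // launch command constructors, e.g. to verify missing-TaskType handling.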
+ public void overrideTaskType(TaskType taskType) {
+ this.taskType = taskType;
+ this.useTaskTypeOverride = true;
+ }
+
+ private TaskType getTaskType(TaskType taskType) {
+ if (useTaskTypeOverride) {
+ return this.taskType;
+ }
+ return taskType;
+ }
+
+ public void testHdfsRelatedEnvironmentIsUndefined(TaskType taskType,
+ RunJobParameters params) throws IOException {
+ AbstractLaunchCommand launchCommand =
+ createLaunchCommandByTaskType(taskType, params);
+
+ expectedException.expect(IOException.class);
+ expectedException
+ .expectMessage("Failed to detect HDFS-related environments.");
+ launchCommand.generateLaunchScript();
+ }
+
+ public List<String> testHdfsRelatedEnvironmentIsDefined(TaskType taskType,
+ RunJobParameters params) throws IOException {
+ AbstractLaunchCommand launchCommand =
+ createLaunchCommandByTaskType(taskType, params);
+
+ String result = launchCommand.generateLaunchScript();
+ assertNotNull(result);
+ File resultFile = new File(result);
+ assertTrue(resultFile.exists());
+
+ List<String> fileContents = Files.readAllLines(
+ Paths.get(resultFile.toURI()),
+ Charset.forName("UTF-8"));
+
+ assertEquals("#!/bin/bash", fileContents.get(0));
+ assertScriptContainsExportedEnvVar(fileContents, "HADOOP_HOME");
+ assertScriptContainsExportedEnvVar(fileContents, "HADOOP_YARN_HOME");
+ assertScriptContainsExportedEnvVarWithValue(fileContents,
+ "HADOOP_HDFS_HOME", "testHdfsHome");
+ assertScriptContainsExportedEnvVarWithValue(fileContents,
+ "HADOOP_COMMON_HOME", "testHdfsHome");
+ assertScriptContainsExportedEnvVarWithValue(fileContents, "HADOOP_CONF_DIR",
+ "$WORK_DIR");
+ assertScriptContainsExportedEnvVarWithValue(fileContents, "JAVA_HOME",
+ "testJavaHome");
+ assertScriptContainsExportedEnvVarWithValue(fileContents, "LD_LIBRARY_PATH",
+ "$LD_LIBRARY_PATH:$JAVA_HOME/lib/amd64/server");
+ assertScriptContainsExportedEnvVarWithValue(fileContents, "CLASSPATH",
+ "`$HADOOP_HDFS_HOME/bin/hadoop classpath --glob`");
+
+ return fileContents;
+ }
+
+}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/TestLaunchCommandFactory.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/TestLaunchCommandFactory.java
new file mode 100644
index 0000000..6351f61
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/TestLaunchCommandFactory.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
+import org.junit.Test;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+/**
+ * This class is to test the {@link LaunchCommandFactory}.
+ */
+public class TestLaunchCommandFactory {
+
+ private LaunchCommandFactory createLaunchCommandFactory(
+ RunJobParameters parameters) {
+ HadoopEnvironmentSetup hadoopEnvSetup = mock(HadoopEnvironmentSetup.class);
+ Configuration configuration = mock(Configuration.class);
+ return new LaunchCommandFactory(hadoopEnvSetup, parameters, configuration);
+ }
+
+ @Test
+ public void createLaunchCommandWorkerAndPrimaryWorker() throws IOException {
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setWorkerLaunchCmd("testWorkerLaunchCommand");
+ LaunchCommandFactory launchCommandFactory = createLaunchCommandFactory(
+ parameters);
+ Component mockComponent = mock(Component.class);
+
+ AbstractLaunchCommand launchCommand =
+ launchCommandFactory.createLaunchCommand(TaskType.PRIMARY_WORKER,
+ mockComponent);
+
+ assertTrue(launchCommand instanceof TensorFlowWorkerLaunchCommand);
+
+ launchCommand =
+ launchCommandFactory.createLaunchCommand(TaskType.WORKER,
+ mockComponent);
+ assertTrue(launchCommand instanceof TensorFlowWorkerLaunchCommand);
+
+ }
+
+ @Test
+ public void createLaunchCommandPs() throws IOException {
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setPSLaunchCmd("testPSLaunchCommand");
+ LaunchCommandFactory launchCommandFactory = createLaunchCommandFactory(
+ parameters);
+ Component mockComponent = mock(Component.class);
+
+ AbstractLaunchCommand launchCommand =
+ launchCommandFactory.createLaunchCommand(TaskType.PS,
+ mockComponent);
+
+ assertTrue(launchCommand instanceof TensorFlowPsLaunchCommand);
+ }
+
+ @Test
+ public void createLaunchCommandTensorboard() throws IOException {
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setCheckpointPath("testCheckpointPath");
+ LaunchCommandFactory launchCommandFactory =
+ createLaunchCommandFactory(parameters);
+ Component mockComponent = mock(Component.class);
+
+ AbstractLaunchCommand launchCommand =
+ launchCommandFactory.createLaunchCommand(TaskType.TENSORBOARD,
+ mockComponent);
+
+ assertTrue(launchCommand instanceof TensorBoardLaunchCommand);
+ }
+
+}
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorBoardLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorBoardLaunchCommand.java
new file mode 100644
index 0000000..b854cdf
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorBoardLaunchCommand.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.MockClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommandTestHelper;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_HADOOP_HDFS_HOME;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_JAVA_HOME;
+
+/**
+ * This class is to test the {@link TensorBoardLaunchCommand}.
+ */
+public class TestTensorBoardLaunchCommand extends
+ AbstractLaunchCommandTestHelper {
+
+ @Test
+ public void testHdfsRelatedEnvironmentIsUndefined() throws IOException {
+ RunJobParameters params = new RunJobParameters();
+ params.setInputPath("hdfs://bla");
+ params.setName("testJobname");
+ params.setCheckpointPath("something");
+
+ testHdfsRelatedEnvironmentIsUndefined(TaskType.TENSORBOARD,
+ params);
+ }
+
+ @Test
+ public void testHdfsRelatedEnvironmentIsDefined() throws IOException {
+ RunJobParameters params = new RunJobParameters();
+ params.setName("testName");
+ params.setCheckpointPath("testCheckpointPath");
+ params.setInputPath("hdfs://bla");
+ params.setEnvars(ImmutableList.of(
+ DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome",
+ DOCKER_JAVA_HOME + "=" + "testJavaHome"));
+
+ List<String> fileContents =
+ testHdfsRelatedEnvironmentIsDefined(TaskType.TENSORBOARD,
+ params);
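+ // The generated script starts TensorBoard on the same line as the LC_ALL
+ // export, so the exported value includes the tensorboard command itself.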
+ assertScriptContainsExportedEnvVarWithValue(fileContents, "LC_ALL",
+ "C && tensorboard --logdir=testCheckpointPath");
+ }
+
+ @Test
+ public void testCheckpointPathUndefined() throws IOException {
+ MockClientContext mockClientContext = new MockClientContext();
+ FileSystemOperations fsOperations =
+ new FileSystemOperations(mockClientContext);
+ HadoopEnvironmentSetup hadoopEnvSetup =
+ new HadoopEnvironmentSetup(mockClientContext, fsOperations);
+
+ Component component = new Component();
+ RunJobParameters params = new RunJobParameters();
+ params.setCheckpointPath(null);
+
+ expectedException.expect(NullPointerException.class);
+ expectedException.expectMessage("CheckpointPath must not be null");
+ new TensorBoardLaunchCommand(hadoopEnvSetup, TaskType.TENSORBOARD,
+ component, params);
+ }
+
+ @Test
+ public void testCheckpointPathEmptyString() throws IOException {
+ MockClientContext mockClientContext = new MockClientContext();
+ FileSystemOperations fsOperations =
+ new FileSystemOperations(mockClientContext);
+ HadoopEnvironmentSetup hadoopEnvSetup =
+ new HadoopEnvironmentSetup(mockClientContext, fsOperations);
+
+ Component component = new Component();
+ RunJobParameters params = new RunJobParameters();
+ params.setCheckpointPath("");
+
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("CheckpointPath must not be empty");
+ new TensorBoardLaunchCommand(hadoopEnvSetup, TaskType.TENSORBOARD,
+ component, params);
+ }
+}
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorFlowLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorFlowLaunchCommand.java
new file mode 100644
index 0000000..fa584c7
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorFlowLaunchCommand.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.MockClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommandTestHelper;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_HADOOP_HDFS_HOME;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_JAVA_HOME;
+
+/**
+ * This class is to test the subclasses of {@link TensorFlowLaunchCommand}.
+ */
+@RunWith(Parameterized.class)
+public class TestTensorFlowLaunchCommand
+ extends AbstractLaunchCommandTestHelper {
+ private TaskType taskType;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ Collection<Object[]> params = new ArrayList<>();
+ params.add(new Object[]{TaskType.WORKER });
+ params.add(new Object[]{TaskType.PS });
+ return params;
+ }
+
+ public TestTensorFlowLaunchCommand(TaskType taskType) {
+ this.taskType = taskType;
+ }
+
+
+ private void assertScriptContainsLaunchCommand(List<String> fileContents,
+ RunJobParameters params) {
+ String launchCommand = null;
+ if (taskType == TaskType.WORKER) {
+ launchCommand = params.getWorkerLaunchCmd();
+ } else if (taskType == TaskType.PS) {
+ launchCommand = params.getPSLaunchCmd();
+ }
+ assertScriptContainsLine(fileContents, launchCommand);
+ }
+
+ private void setLaunchCommandToParams(RunJobParameters params) {
+ if (taskType == TaskType.WORKER) {
+ params.setWorkerLaunchCmd("testWorkerLaunchCommand");
+ } else if (taskType == TaskType.PS) {
+ params.setPSLaunchCmd("testPsLaunchCommand");
+ }
+ }
+
+ private void setLaunchCommandToParams(RunJobParameters params, String value) {
+ if (taskType == TaskType.WORKER) {
+ params.setWorkerLaunchCmd(value);
+ } else if (taskType == TaskType.PS) {
+ params.setPSLaunchCmd(value);
+ }
+ }
+
+ private void assertTypeInJson(List<String> fileContents) {
+ String expectedType = null;
+ if (taskType == TaskType.WORKER) {
+ expectedType = "worker";
+ } else if (taskType == TaskType.PS) {
+ expectedType = "ps";
+ }
+ assertScriptContainsLineWithRegex(fileContents, String.format(".*type.*:" +
+ ".*%s.*", expectedType));
+ }
+
+ private TensorFlowLaunchCommand createTensorFlowLaunchCommandObject(
+ HadoopEnvironmentSetup hadoopEnvSetup, Configuration yarnConfig,
+ Component component, RunJobParameters params) throws IOException {
+ if (taskType == TaskType.WORKER) {
+ return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType,
+ component,
+ params, yarnConfig);
+ } else if (taskType == TaskType.PS) {
+ return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component,
+ params, yarnConfig);
+ }
+ throw new IllegalStateException("Unknown tasktype!");
+ }
+
+ @Test
+ public void testHdfsRelatedEnvironmentIsUndefined() throws IOException {
+ RunJobParameters params = new RunJobParameters();
+ params.setInputPath("hdfs://bla");
+ params.setName("testJobname");
+ setLaunchCommandToParams(params);
+
+ testHdfsRelatedEnvironmentIsUndefined(taskType, params);
+ }
+
+ @Test
+ public void testHdfsRelatedEnvironmentIsDefined() throws IOException {
+ RunJobParameters params = new RunJobParameters();
+ params.setName("testName");
+ params.setInputPath("hdfs://bla");
+ params.setEnvars(ImmutableList.of(
+ DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome",
+ DOCKER_JAVA_HOME + "=" + "testJavaHome"));
+ setLaunchCommandToParams(params);
+
+ List<String> fileContents =
+ testHdfsRelatedEnvironmentIsDefined(taskType,
+ params);
+ assertScriptContainsLaunchCommand(fileContents, params);
+ assertScriptDoesNotContainLine(fileContents, "export TF_CONFIG=");
+ }
+
+ @Test
+ public void testLaunchCommandIsNull() throws IOException {
+ MockClientContext mockClientContext = new MockClientContext();
+ FileSystemOperations fsOperations =
+ new FileSystemOperations(mockClientContext);
+ HadoopEnvironmentSetup hadoopEnvSetup =
+ new HadoopEnvironmentSetup(mockClientContext, fsOperations);
+ Configuration yarnConfig = new Configuration();
+
+ Component component = new Component();
+ RunJobParameters params = new RunJobParameters();
+ params.setName("testName");
+ setLaunchCommandToParams(params, null);
+
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("LaunchCommand must not be null or empty");
+ TensorFlowLaunchCommand launchCommand =
+ createTensorFlowLaunchCommandObject(hadoopEnvSetup, yarnConfig,
+ component,
+ params);
+ launchCommand.generateLaunchScript();
+ }
+
+ @Test
+ public void testLaunchCommandIsEmpty() throws IOException {
+ MockClientContext mockClientContext = new MockClientContext();
+ FileSystemOperations fsOperations =
+ new FileSystemOperations(mockClientContext);
+ HadoopEnvironmentSetup hadoopEnvSetup =
+ new HadoopEnvironmentSetup(mockClientContext, fsOperations);
+ Configuration yarnConfig = new Configuration();
+
+ Component component = new Component();
+ RunJobParameters params = new RunJobParameters();
+ params.setName("testName");
+ setLaunchCommandToParams(params, "");
+
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("LaunchCommand must not be null or empty");
+ TensorFlowLaunchCommand launchCommand =
+ createTensorFlowLaunchCommandObject(hadoopEnvSetup, yarnConfig,
+ component, params);
+ launchCommand.generateLaunchScript();
+ }
+
+ @Test
+ public void testDistributedTrainingMissingTaskType() throws IOException {
+ overrideTaskType(null);
+
+ RunJobParameters params = new RunJobParameters();
+ params.setDistributed(true);
+ params.setName("testName");
+ params.setInputPath("hdfs://bla");
+ params.setEnvars(ImmutableList.of(
+ DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome",
+ DOCKER_JAVA_HOME + "=" + "testJavaHome"));
+ setLaunchCommandToParams(params);
+
+ expectedException.expect(NullPointerException.class);
+ expectedException.expectMessage("TaskType must not be null");
+ testHdfsRelatedEnvironmentIsDefined(taskType, params);
+ }
+
+ @Test
+ public void testDistributedTrainingNumberOfWorkersAndPsIsZero()
+ throws IOException {
+ RunJobParameters params = new RunJobParameters();
+ params.setDistributed(true);
+ params.setNumWorkers(0);
+ params.setNumPS(0);
+ params.setName("testName");
+ params.setInputPath("hdfs://bla");
+ params.setEnvars(ImmutableList.of(
+ DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome",
+ DOCKER_JAVA_HOME + "=" + "testJavaHome"));
+ setLaunchCommandToParams(params);
+
+ List<String> fileContents =
+ testHdfsRelatedEnvironmentIsDefined(taskType, params);
+
+ assertScriptDoesNotContainLine(fileContents, "export TF_CONFIG=");
+ assertScriptContainsLineWithRegex(fileContents, ".*worker.*:\\[\\].*");
+ assertScriptContainsLineWithRegex(fileContents, ".*ps.*:\\[\\].*");
+ assertTypeInJson(fileContents);
+ }
+
+ @Test
+ public void testDistributedTrainingNumberOfWorkersAndPsIsNonZero()
+ throws IOException {
+ RunJobParameters params = new RunJobParameters();
+ params.setDistributed(true);
+ params.setNumWorkers(3);
+ params.setNumPS(2);
+ params.setName("testName");
+ params.setInputPath("hdfs://bla");
+ params.setEnvars(ImmutableList.of(
+ DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome",
+ DOCKER_JAVA_HOME + "=" + "testJavaHome"));
+ setLaunchCommandToParams(params);
+
+ List<String> fileContents =
+ testHdfsRelatedEnvironmentIsDefined(taskType, params);
+
+ // Assert we have multiple PS and workers
+ assertScriptDoesNotContainLine(fileContents, "export TF_CONFIG=");
+ assertScriptContainsLineWithRegex(fileContents, ".*worker.*:\\[.*,.*\\].*");
+ assertScriptContainsLineWithRegex(fileContents, ".*ps.*:\\[.*,.*\\].*");
+ assertTypeInJson(fileContents);
+ }
+
+}
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/ComponentTestCommons.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/ComponentTestCommons.java
new file mode 100644
index 0000000..420fe5a
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/ComponentTestCommons.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.submarine.common.Envs;
+import org.apache.hadoop.yarn.submarine.common.MockClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * This class provides helper methods and fields
+ * to make testing TensorFlow-related components easier.
+ */
+public class ComponentTestCommons {
+ String userName;
+ TaskType taskType;
+ LaunchCommandFactory mockLaunchCommandFactory;
+ FileSystemOperations fsOperations;
+ MockClientContext mockClientContext;
+ Configuration yarnConfig;
+ Resource resource;
+
+ ComponentTestCommons(TaskType taskType) {
+ this.taskType = taskType;
+ }
+
+ public void setup() throws IOException {
+ this.userName = System.getProperty("user.name");
+ this.resource = Resource.newInstance(4000, 10);
+ setupDependencies();
+ }
+
+ private void setupDependencies() throws IOException {
+ fsOperations = mock(FileSystemOperations.class);
+ mockClientContext = new MockClientContext();
+ mockLaunchCommandFactory = mock(LaunchCommandFactory.class);
+
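+ // Stub out launch script generation: every component under test receives
+ // "mockScript" back as its generated launch script content.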
+ AbstractLaunchCommand mockLaunchCommand = mock(AbstractLaunchCommand.class);
+ when(mockLaunchCommand.generateLaunchScript()).thenReturn("mockScript");
+ when(mockLaunchCommandFactory.createLaunchCommand(eq(taskType),
+ any(Component.class))).thenReturn(mockLaunchCommand);
+
+ yarnConfig = new Configuration();
+ }
+
+ void verifyCommonConfigEnvs(Component component) {
+ assertNotNull(component.getConfiguration().getEnv());
+ assertEquals(2, component.getConfiguration().getEnv().size());
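+ // COMPONENT_ID is a YARN Service template variable substituted at container
+ // launch, so the env map is expected to hold the template string verbatim
+ // as the task index.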
+ assertEquals(ServiceApiConstants.COMPONENT_ID,
+ component.getConfiguration().getEnv().get(Envs.TASK_INDEX_ENV));
+ assertEquals(taskType.name(),
+ component.getConfiguration().getEnv().get(Envs.TASK_TYPE_ENV));
+ }
+
+ void verifyResources(Component component) {
+ assertNotNull(component.getResource());
+ assertEquals(10, (int) component.getResource().getCpus());
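+ // The Service API stores memory as a String, so parse it back for comparison.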
+ assertEquals(4000,
+ (int) Integer.valueOf(component.getResource().getMemory()));
+ }
+}
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorBoardComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorBoardComponent.java
new file mode 100644
index 0000000..1c81eb7
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorBoardComponent.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.yarn.service.api.records.Artifact;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.verify;
+
+/**
+ * This class is to test {@link TensorBoardComponent}.
+ */
+public class TestTensorBoardComponent {
+
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+ private ComponentTestCommons testCommons =
+ new ComponentTestCommons(TaskType.TENSORBOARD);
+
+ @Before
+ public void setUp() throws IOException {
+ testCommons.setup();
+ }
+
+ private TensorBoardComponent createTensorBoardComponent(
+ RunJobParameters parameters) {
+ return new TensorBoardComponent(
+ testCommons.fsOperations,
+ testCommons.mockClientContext.getRemoteDirectoryManager(),
+ parameters,
+ testCommons.mockLaunchCommandFactory,
+ testCommons.yarnConfig);
+ }
+
+ @Test
+ public void testTensorBoardComponentWithNullResource() throws IOException {
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setTensorboardResource(null);
+
+ TensorBoardComponent tensorBoardComponent =
+ createTensorBoardComponent(parameters);
+
+ expectedException.expect(NullPointerException.class);
+ expectedException.expectMessage("TensorBoard resource must not be null");
+ tensorBoardComponent.createComponent();
+ }
+
+ @Test
+ public void testTensorBoardComponentWithNullJobName() throws IOException {
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setTensorboardResource(testCommons.resource);
+ parameters.setName(null);
+
+ TensorBoardComponent tensorBoardComponent =
+ createTensorBoardComponent(parameters);
+
+ expectedException.expect(NullPointerException.class);
+ expectedException.expectMessage("Job name must not be null");
+ tensorBoardComponent.createComponent();
+ }
+
+ @Test
+ public void testTensorBoardComponent() throws IOException {
+ testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
+
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setTensorboardResource(testCommons.resource);
+ parameters.setName("testJobName");
+ parameters.setTensorboardDockerImage("testTBDockerImage");
+
+ TensorBoardComponent tensorBoardComponent =
+ createTensorBoardComponent(parameters);
+
+ Component component = tensorBoardComponent.createComponent();
+
+ assertEquals(testCommons.taskType.getComponentName(), component.getName());
+ testCommons.verifyCommonConfigEnvs(component);
+
+ assertEquals(1L, (long) component.getNumberOfContainers());
+ assertEquals(RestartPolicyEnum.NEVER, component.getRestartPolicy());
+ testCommons.verifyResources(component);
+ assertEquals(
+ new Artifact().type(Artifact.TypeEnum.DOCKER).id("testTBDockerImage"),
+ component.getArtifact());
+
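+ // The TensorBoard link follows the YARN Service registry DNS scheme:
+ // <component-instance>.<service-name>.<user>.<domain>:<port>.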
+ assertEquals(String.format(
+ "http://tensorboard-0.testJobName.%s" + ".testDomain:6006",
+ testCommons.userName),
+ tensorBoardComponent.getTensorboardLink());
+
+ assertEquals("./run-TENSORBOARD.sh", component.getLaunchCommand());
+ verify(testCommons.fsOperations)
+ .uploadToRemoteFileAndLocalizeToContainerWorkDir(
+ any(Path.class), eq("mockScript"), eq("run-TENSORBOARD.sh"),
+ eq(component));
+ }
+
+}
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowPsComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowPsComponent.java
new file mode 100644
index 0000000..8027365
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowPsComponent.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.yarn.service.api.records.Artifact;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.io.IOException;
+
+import static junit.framework.TestCase.assertTrue;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.verify;
+
+/**
+ * This class is to test {@link TensorFlowPsComponent}.
+ */
+public class TestTensorFlowPsComponent {
+
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+ private ComponentTestCommons testCommons =
+ new ComponentTestCommons(TaskType.PS);
+
+ @Before
+ public void setUp() throws IOException {
+ testCommons.setup();
+ }
+
+ private TensorFlowPsComponent createPsComponent(RunJobParameters parameters) {
+ return new TensorFlowPsComponent(
+ testCommons.fsOperations,
+ testCommons.mockClientContext.getRemoteDirectoryManager(),
+ testCommons.mockLaunchCommandFactory,
+ parameters,
+ testCommons.yarnConfig);
+ }
+
+ private void verifyCommons(Component component) throws IOException {
+ assertEquals(testCommons.taskType.getComponentName(), component.getName());
+ testCommons.verifyCommonConfigEnvs(component);
+
+ assertTrue(component.getConfiguration().getProperties().isEmpty());
+
+ assertEquals(RestartPolicyEnum.NEVER, component.getRestartPolicy());
+ testCommons.verifyResources(component);
+ assertEquals(
+ new Artifact().type(Artifact.TypeEnum.DOCKER).id("testPSDockerImage"),
+ component.getArtifact());
+
+ String taskTypeUppercase = testCommons.taskType.name().toUpperCase();
+ String expectedScriptName = String.format("run-%s.sh", taskTypeUppercase);
+ assertEquals(String.format("./%s", expectedScriptName),
+ component.getLaunchCommand());
+ verify(testCommons.fsOperations)
+ .uploadToRemoteFileAndLocalizeToContainerWorkDir(
+ any(Path.class), eq("mockScript"), eq(expectedScriptName),
+ eq(component));
+ }
+
+ @Test
+ public void testPSComponentWithNullResource() throws IOException {
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setPsResource(null);
+
+ TensorFlowPsComponent psComponent =
+ createPsComponent(parameters);
+
+ expectedException.expect(NullPointerException.class);
+ expectedException.expectMessage("PS resource must not be null");
+ psComponent.createComponent();
+ }
+
+ @Test
+ public void testPSComponentWithNullJobName() throws IOException {
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setPsResource(testCommons.resource);
+ parameters.setNumPS(1);
+ parameters.setName(null);
+
+ TensorFlowPsComponent psComponent =
+ createPsComponent(parameters);
+
+ expectedException.expect(NullPointerException.class);
+ expectedException.expectMessage("Job name must not be null");
+ psComponent.createComponent();
+ }
+
+ @Test
+ public void testPSComponentZeroNumberOfPS() throws IOException {
+ testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
+
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setPsResource(testCommons.resource);
+ parameters.setName("testJobName");
+ parameters.setPsDockerImage("testPSDockerImage");
+ parameters.setNumPS(0);
+
+ TensorFlowPsComponent psComponent =
+ createPsComponent(parameters);
+
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("Number of PS should be at least 1!");
+ psComponent.createComponent();
+ }
+
+ @Test
+ public void testPSComponentNumPSIsOne() throws IOException {
+ testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
+
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setPsResource(testCommons.resource);
+ parameters.setName("testJobName");
+ parameters.setNumPS(1);
+ parameters.setPsDockerImage("testPSDockerImage");
+
+ TensorFlowPsComponent psComponent =
+ createPsComponent(parameters);
+
+ Component component = psComponent.createComponent();
+
+ assertEquals(1L, (long) component.getNumberOfContainers());
+ verifyCommons(component);
+ }
+
+ @Test
+ public void testPSComponentNumPSIsTwo() throws IOException {
+ testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
+
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setPsResource(testCommons.resource);
+ parameters.setName("testJobName");
+ parameters.setNumPS(2);
+ parameters.setPsDockerImage("testPSDockerImage");
+
+ TensorFlowPsComponent psComponent =
+ createPsComponent(parameters);
+
+ Component component = psComponent.createComponent();
+
+ assertEquals(2L, (long) component.getNumberOfContainers());
+ verifyCommons(component);
+ }
+
+}
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowWorkerComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowWorkerComponent.java
new file mode 100644
index 0000000..24bebc2
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowWorkerComponent.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.yarn.service.api.records.Artifact;
+import org.apache.hadoop.yarn.service.api.records.Component;
+import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.io.IOException;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertTrue;
+import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.verify;
+
+/**
+ * This class is to test {@link TensorFlowWorkerComponent}.
+ */
+public class TestTensorFlowWorkerComponent {
+
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+ private ComponentTestCommons testCommons =
+ new ComponentTestCommons(TaskType.WORKER);
+
+ @Before
+ public void setUp() throws IOException {
+ testCommons.setup();
+ }
+
+ private TensorFlowWorkerComponent createWorkerComponent(
+ RunJobParameters parameters) {
+ return new TensorFlowWorkerComponent(
+ testCommons.fsOperations,
+ testCommons.mockClientContext.getRemoteDirectoryManager(),
+ parameters, testCommons.taskType,
+ testCommons.mockLaunchCommandFactory,
+ testCommons.yarnConfig);
+ }
+
+ private void verifyCommons(Component component) throws IOException {
+ verifyCommonsInternal(component, ImmutableMap.of());
+ }
+
+ private void verifyCommons(Component component,
+ Map<String, String> expectedProperties) throws IOException {
+ verifyCommonsInternal(component, expectedProperties);
+ }
+
+ private void verifyCommonsInternal(Component component,
+ Map<String, String> expectedProperties) throws IOException {
+ assertEquals(testCommons.taskType.getComponentName(), component.getName());
+ testCommons.verifyCommonConfigEnvs(component);
+
+ Map<String, String> actualProperties =
+ component.getConfiguration().getProperties();
+ if (!expectedProperties.isEmpty()) {
+ assertFalse(actualProperties.isEmpty());
+ expectedProperties.forEach(
+ (k, v) -> assertEquals(v, actualProperties.get(k)));
+ } else {
+ assertTrue(actualProperties.isEmpty());
+ }
+
+ assertEquals(RestartPolicyEnum.NEVER, component.getRestartPolicy());
+ testCommons.verifyResources(component);
+ assertEquals(
+ new Artifact().type(Artifact.TypeEnum.DOCKER)
+ .id("testWorkerDockerImage"),
+ component.getArtifact());
+
+ String taskTypeUppercase = testCommons.taskType.name().toUpperCase();
+ String expectedScriptName = String.format("run-%s.sh", taskTypeUppercase);
+ assertEquals(String.format("./%s", expectedScriptName),
+ component.getLaunchCommand());
+ verify(testCommons.fsOperations)
+ .uploadToRemoteFileAndLocalizeToContainerWorkDir(
+ any(Path.class), eq("mockScript"), eq(expectedScriptName),
+ eq(component));
+ }
+
+ @Test
+ public void testWorkerComponentWithNullResource() throws IOException {
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setWorkerResource(null);
+
+ TensorFlowWorkerComponent workerComponent =
+ createWorkerComponent(parameters);
+
+ expectedException.expect(NullPointerException.class);
+ expectedException.expectMessage("Worker resource must not be null");
+ workerComponent.createComponent();
+ }
+
+ @Test
+ public void testWorkerComponentWithNullJobName() throws IOException {
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setWorkerResource(testCommons.resource);
+ parameters.setNumWorkers(1);
+ parameters.setName(null);
+
+ TensorFlowWorkerComponent workerComponent =
+ createWorkerComponent(parameters);
+
+ expectedException.expect(NullPointerException.class);
+ expectedException.expectMessage("Job name must not be null");
+ workerComponent.createComponent();
+ }
+
+ @Test
+ public void testNormalWorkerComponentZeroNumberOfWorkers()
+ throws IOException {
+ testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
+
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setWorkerResource(testCommons.resource);
+ parameters.setName("testJobName");
+ parameters.setWorkerDockerImage("testWorkerDockerImage");
+ parameters.setNumWorkers(0);
+
+ TensorFlowWorkerComponent workerComponent =
+ createWorkerComponent(parameters);
+
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("Number of workers should be at least 1!");
+ workerComponent.createComponent();
+ }
+
+ @Test
+ public void testNormalWorkerComponentNumWorkersIsOne() throws IOException {
+ testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
+
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setWorkerResource(testCommons.resource);
+ parameters.setName("testJobName");
+ parameters.setNumWorkers(1);
+ parameters.setWorkerDockerImage("testWorkerDockerImage");
+
+ TensorFlowWorkerComponent workerComponent =
+ createWorkerComponent(parameters);
+
+ Component component = workerComponent.createComponent();
+
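+ // One requested worker is modeled as the PRIMARY_WORKER component, so the
+ // regular worker component is left with numWorkers - 1 containers.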
+ assertEquals(0L, (long) component.getNumberOfContainers());
+ verifyCommons(component);
+ }
+
+ @Test
+ public void testNormalWorkerComponentNumWorkersIsTwo() throws IOException {
+ testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
+
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setWorkerResource(testCommons.resource);
+ parameters.setName("testJobName");
+ parameters.setNumWorkers(2);
+ parameters.setWorkerDockerImage("testWorkerDockerImage");
+
+ TensorFlowWorkerComponent workerComponent =
+ createWorkerComponent(parameters);
+
+ Component component = workerComponent.createComponent();
+
+ assertEquals(1L, (long) component.getNumberOfContainers());
+ verifyCommons(component);
+ }
+
+ @Test
+ public void testPrimaryWorkerComponentNumWorkersIsTwo() throws IOException {
+ testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
+ testCommons = new ComponentTestCommons(TaskType.PRIMARY_WORKER);
+ testCommons.setup();
+
+ RunJobParameters parameters = new RunJobParameters();
+ parameters.setWorkerResource(testCommons.resource);
+ parameters.setName("testJobName");
+ parameters.setNumWorkers(2);
+ parameters.setWorkerDockerImage("testWorkerDockerImage");
+
+ TensorFlowWorkerComponent workerComponent =
+ createWorkerComponent(parameters);
+
+ Component component = workerComponent.createComponent();
+
+ assertEquals(1L, (long) component.getNumberOfContainers());
+ verifyCommons(component, ImmutableMap.of(
+ CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true"));
+ }
+
+}
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestClassPathUtilities.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestClassPathUtilities.java
new file mode 100644
index 0000000..8fdb475
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestClassPathUtilities.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import org.apache.hadoop.yarn.submarine.FileUtilitiesForTests;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+
+/**
+ * This class is to test {@link ClassPathUtilities}.
+ */
+public class TestClassPathUtilities {
+
+ private static final String CLASSPATH_KEY = "java.class.path";
+ private FileUtilitiesForTests fileUtils = new FileUtilitiesForTests();
+ private static String originalClasspath;
+
+ @BeforeClass
+ public static void setUpClass() {
+ originalClasspath = System.getProperty(CLASSPATH_KEY);
+ }
+
+ @Before
+ public void setUp() {
+ fileUtils.setup();
+ }
+
+ @After
+ public void teardown() throws IOException {
+ fileUtils.teardown();
+ System.setProperty(CLASSPATH_KEY, originalClasspath);
+ }
+
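+ // Note: appends with the POSIX path separator, so these tests assume a
+ // Unix-like environment.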
+ private static void addFileToClasspath(File file) {
+ String newClasspath = originalClasspath + ":" + file.getAbsolutePath();
+ System.setProperty(CLASSPATH_KEY, newClasspath);
+ }
+
+ @Test
+ public void findFileNotInClasspath() {
+ File resultFile = ClassPathUtilities.findFileOnClassPath("bla");
+ assertNull(resultFile);
+ }
+
+ @Test
+ public void findFileOnClasspath() throws Exception {
+ File testFile = fileUtils.createFileInTempDir("testFile");
+
+ addFileToClasspath(testFile);
+ File resultFile = ClassPathUtilities.findFileOnClassPath("testFile");
+
+ assertNotNull(resultFile);
+ assertEquals(testFile.getAbsolutePath(), resultFile.getAbsolutePath());
+ }
+
+ @Test
+ public void findDirectoryOnClasspath() throws Exception {
+ File testDir = fileUtils.createDirInTempDir("testDir");
+ File testFile = fileUtils.createFileInDir(testDir, "testFile");
+
+ addFileToClasspath(testDir);
+ File resultFile = ClassPathUtilities.findFileOnClassPath("testFile");
+
+ assertNotNull(resultFile);
+ assertEquals(testFile.getAbsolutePath(), resultFile.getAbsolutePath());
+ }
+
+}
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestEnvironmentUtilities.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestEnvironmentUtilities.java
new file mode 100644
index 0000000..a52c1cf
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestEnvironmentUtilities.java
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.yarn.service.api.records.Configuration;
+import org.apache.hadoop.yarn.service.api.records.Service;
+import org.junit.Test;
+
+import java.util.Map;
+
+import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
+import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * This class is to test {@link EnvironmentUtilities}.
+ */
+public class TestEnvironmentUtilities {
+ private Service createServiceWithEmptyEnvVars() {
+ return createServiceWithEnvVars(Maps.newHashMap());
+ }
+
+ private Service createServiceWithEnvVars(Map<String, String> envVars) {
+ Service service = mock(Service.class);
+ Configuration config = mock(Configuration.class);
+ when(config.getEnv()).thenReturn(envVars);
+ when(service.getConfiguration()).thenReturn(config);
+
+ return service;
+ }
+
+ private void validateDefaultEnvVars(Map<String, String> resultEnvs) {
+ assertEquals("/etc/passwd:/etc/passwd:ro",
+ resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
+ }
+
+ private org.apache.hadoop.conf.Configuration
+ createYarnConfigWithSecurityValue(String value) {
+ org.apache.hadoop.conf.Configuration mockConfig =
+ mock(org.apache.hadoop.conf.Configuration.class);
+ when(mockConfig.get(HADOOP_SECURITY_AUTHENTICATION)).thenReturn(value);
+ return mockConfig;
+ }
+
+ @Test
+ public void testGetValueOfNullEnvVar() {
+ assertEquals("", EnvironmentUtilities.getValueOfEnvironment(null));
+ }
+
+ @Test
+ public void testGetValueOfEmptyEnvVar() {
+ assertEquals("", EnvironmentUtilities.getValueOfEnvironment(""));
+ }
+
+ @Test
+ public void testGetValueOfEnvVarJustAnEqualsSign() {
+ assertEquals("", EnvironmentUtilities.getValueOfEnvironment("="));
+ }
+
+ @Test
+ public void testGetValueOfEnvVarWithoutValue() {
+ assertEquals("", EnvironmentUtilities.getValueOfEnvironment("a="));
+ }
+
+ @Test
+ public void testGetValueOfEnvVarValidFormat() {
+ assertEquals("bbb", EnvironmentUtilities.getValueOfEnvironment("a=bbb"));
+ }
+
+ @Test
+ public void testHandleServiceEnvWithNullMap() {
+ Service service = createServiceWithEmptyEnvVars();
+ org.apache.hadoop.conf.Configuration yarnConfig =
+ mock(org.apache.hadoop.conf.Configuration.class);
+ EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, null);
+
+ Map<String, String> resultEnvs = service.getConfiguration().getEnv();
+ assertEquals(1, resultEnvs.size());
+ validateDefaultEnvVars(resultEnvs);
+ }
+
+ @Test
+ public void testHandleServiceEnvWithEmptyMap() {
+ Service service = createServiceWithEmptyEnvVars();
+ org.apache.hadoop.conf.Configuration yarnConfig =
+ mock(org.apache.hadoop.conf.Configuration.class);
+ EnvironmentUtilities.handleServiceEnvs(service, yarnConfig,
+ ImmutableList.of());
+
+ Map<String, String> resultEnvs = service.getConfiguration().getEnv();
+ assertEquals(1, resultEnvs.size());
+ validateDefaultEnvVars(resultEnvs);
+ }
+
+ @Test
+ public void testHandleServiceEnvWithYarnConfigSecurityValueNonKerberos() {
+ Service service = createServiceWithEmptyEnvVars();
+ org.apache.hadoop.conf.Configuration yarnConfig =
+ createYarnConfigWithSecurityValue("nonkerberos");
+ EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, null);
+
+ Map<String, String> resultEnvs = service.getConfiguration().getEnv();
+ assertEquals(1, resultEnvs.size());
+ validateDefaultEnvVars(resultEnvs);
+ }
+
+ @Test
+ public void testHandleServiceEnvWithYarnConfigSecurityValueKerberos() {
+ Service service = createServiceWithEmptyEnvVars();
+ org.apache.hadoop.conf.Configuration yarnConfig =
+ createYarnConfigWithSecurityValue("kerberos");
+ EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, null);
+
+ Map<String, String> resultEnvs = service.getConfiguration().getEnv();
+ assertEquals(1, resultEnvs.size());
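+ // With Kerberos enabled, /etc/krb5.conf is additionally mounted read-only.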
+ assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
+ resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
+ }
+
+ @Test
+ public void testHandleServiceEnvWithExistingEnvsAndValidNewEnvs() {
+ Map<String, String> existingEnvs = Maps.newHashMap(
+ ImmutableMap.<String, String>builder().
+ put("a", "1").
+ put("b", "2").
+ build());
+ ImmutableList<String> newEnvs = ImmutableList.of("c=3", "d=4");
+
+ Service service = createServiceWithEnvVars(existingEnvs);
+ org.apache.hadoop.conf.Configuration yarnConfig =
+ createYarnConfigWithSecurityValue("kerberos");
+ EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);
+
+ Map<String, String> resultEnvs = service.getConfiguration().getEnv();
+ assertEquals(5, resultEnvs.size());
+ assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
+ resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
+ assertEquals("1", resultEnvs.get("a"));
+ assertEquals("2", resultEnvs.get("b"));
+ assertEquals("3", resultEnvs.get("c"));
+ assertEquals("4", resultEnvs.get("d"));
+ }
+
+ @Test
+ public void testHandleServiceEnvWithExistingEnvsAndNewEnvsWithoutEquals() {
+ Map<String, String> existingEnvs = Maps.newHashMap(
+ ImmutableMap.<String, String>builder().
+ put("a", "1").
+ put("b", "2").
+ build());
+ ImmutableList<String> newEnvs = ImmutableList.of("c3", "d4");
+
+ Service service = createServiceWithEnvVars(existingEnvs);
+ org.apache.hadoop.conf.Configuration yarnConfig =
+ createYarnConfigWithSecurityValue("kerberos");
+ EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);
+
+ Map<String, String> resultEnvs = service.getConfiguration().getEnv();
+ assertEquals(5, resultEnvs.size());
+ assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
+ resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
+ assertEquals("1", resultEnvs.get("a"));
+ assertEquals("2", resultEnvs.get("b"));
+ assertEquals("", resultEnvs.get("c3"));
+ assertEquals("", resultEnvs.get("d4"));
+ }
+
+ @Test
+ public void testHandleServiceEnvWithExistingEnvVarKey() {
+ Map<String, String> existingEnvs = Maps.newHashMap(
+ ImmutableMap.<String, String>builder().
+ put("a", "1").
+ put("b", "2").
+ build());
+ ImmutableList<String> newEnvs = ImmutableList.of("a=33", "c=44");
+
+ Service service = createServiceWithEnvVars(existingEnvs);
+ org.apache.hadoop.conf.Configuration yarnConfig =
+ createYarnConfigWithSecurityValue("kerberos");
+ EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);
+
+ Map<String, String> resultEnvs = service.getConfiguration().getEnv();
+ assertEquals(4, resultEnvs.size());
+ assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
+ resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
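+ // Values for an already-defined key are appended with a ':' separator.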
+ assertEquals("1:33", resultEnvs.get("a"));
+ assertEquals("2", resultEnvs.get("b"));
+ assertEquals("44", resultEnvs.get("c"));
+ }
+
+ @Test
+ public void testHandleServiceEnvWithExistingEnvVarKeyMultipleTimes() {
+ Map<String, String> existingEnvs = Maps.newHashMap(
+ ImmutableMap.<String, String>builder().
+ put("a", "1").
+ put("b", "2").
+ build());
+ ImmutableList<String> newEnvs = ImmutableList.of("a=33", "a=44");
+
+ Service service = createServiceWithEnvVars(existingEnvs);
+ org.apache.hadoop.conf.Configuration yarnConfig =
+ createYarnConfigWithSecurityValue("kerberos");
+ EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);
+
+ Map<String, String> resultEnvs = service.getConfiguration().getEnv();
+ assertEquals(3, resultEnvs.size());
+ assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
+ resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
+ assertEquals("1:33:44", resultEnvs.get("a"));
+ assertEquals("2", resultEnvs.get("b"));
+ }
+
+}
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestKerberosPrincipalFactory.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestKerberosPrincipalFactory.java
new file mode 100644
index 0000000..74cbc85
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestKerberosPrincipalFactory.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
+import org.apache.hadoop.yarn.submarine.FileUtilitiesForTests;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.MockClientContext;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * This class is to test {@link KerberosPrincipalFactory}.
+ */
+public class TestKerberosPrincipalFactory {
+ private FileUtilitiesForTests fileUtils = new FileUtilitiesForTests();
+
+ @Before
+ public void setUp() {
+ fileUtils.setup();
+ }
+
+ @After
+ public void teardown() throws IOException {
+ fileUtils.teardown();
+ }
+
+ private File createKeytabFile(String keytabFileName) throws IOException {
+ return fileUtils.createFileInTempDir(keytabFileName);
+ }
+
+ @Test
+ public void testCreatePrincipalEmptyPrincipalAndKeytab() throws IOException {
+ MockClientContext mockClientContext = new MockClientContext();
+
+ RunJobParameters parameters = mock(RunJobParameters.class);
+ when(parameters.getPrincipal()).thenReturn("");
+ when(parameters.getKeytab()).thenReturn("");
+
+ FileSystemOperations fsOperations =
+ new FileSystemOperations(mockClientContext);
+ KerberosPrincipal result =
+ KerberosPrincipalFactory.create(fsOperations,
+ mockClientContext.getRemoteDirectoryManager(), parameters);
+
+ assertNull(result);
+ }
+
+ @Test
+ public void testCreatePrincipalEmptyPrincipalString() throws IOException {
+ MockClientContext mockClientContext = new MockClientContext();
+
+ RunJobParameters parameters = mock(RunJobParameters.class);
+ when(parameters.getPrincipal()).thenReturn("");
+ when(parameters.getKeytab()).thenReturn("keytab");
+
+ FileSystemOperations fsOperations =
+ new FileSystemOperations(mockClientContext);
+ KerberosPrincipal result =
+ KerberosPrincipalFactory.create(fsOperations,
+ mockClientContext.getRemoteDirectoryManager(), parameters);
+
+ assertNull(result);
+ }
+
+ @Test
+ public void testCreatePrincipalEmptyKeyTabString() throws IOException {
+ MockClientContext mockClientContext = new MockClientContext();
+
+ RunJobParameters parameters = mock(RunJobParameters.class);
+ when(parameters.getPrincipal()).thenReturn("principal");
+ when(parameters.getKeytab()).thenReturn("");
+
+ FileSystemOperations fsOperations =
+ new FileSystemOperations(mockClientContext);
+ KerberosPrincipal result =
+ KerberosPrincipalFactory.create(fsOperations,
+ mockClientContext.getRemoteDirectoryManager(), parameters);
+
+ assertNull(result);
+ }
+
+ @Test
+ public void testCreatePrincipalNonEmptyPrincipalAndKeytab()
+ throws IOException {
+ MockClientContext mockClientContext = new MockClientContext();
+
+ RunJobParameters parameters = mock(RunJobParameters.class);
+ when(parameters.getPrincipal()).thenReturn("principal");
+ when(parameters.getKeytab()).thenReturn("keytab");
+
+ FileSystemOperations fsOperations =
+ new FileSystemOperations(mockClientContext);
+ KerberosPrincipal result =
+ KerberosPrincipalFactory.create(fsOperations,
+ mockClientContext.getRemoteDirectoryManager(), parameters);
+
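+ // When the keytab is not distributed, the local keytab path is referenced
+ // with a file:// URI.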
+ assertNotNull(result);
+ assertEquals("file://keytab", result.getKeytab());
+ assertEquals("principal", result.getPrincipalName());
+ }
+
+ @Test
+ public void testCreatePrincipalDistributedKeytab() throws IOException {
+ MockClientContext mockClientContext = new MockClientContext();
+ String jobname = "testJobname";
+ String keytab = "testKeytab";
+ File keytabFile = createKeytabFile(keytab);
+
+ RunJobParameters parameters = mock(RunJobParameters.class);
+ when(parameters.getPrincipal()).thenReturn("principal");
+ when(parameters.getKeytab()).thenReturn(keytabFile.getAbsolutePath());
+ when(parameters.getName()).thenReturn(jobname);
+ when(parameters.isDistributeKeytab()).thenReturn(true);
+
+ FileSystemOperations fsOperations =
+ new FileSystemOperations(mockClientContext);
+
+ KerberosPrincipal result =
+ KerberosPrincipalFactory.create(fsOperations,
+ mockClientContext.getRemoteDirectoryManager(), parameters);
+
+ Path stagingDir = mockClientContext.getRemoteDirectoryManager()
+ .getJobStagingArea(parameters.getName(), true);
+ String expectedKeytabFilePath =
+ FileUtilitiesForTests.getFilename(stagingDir, keytab).getAbsolutePath();
+
+ assertNotNull(result);
+ assertEquals(expectedKeytabFilePath, result.getKeytab());
+ assertEquals("principal", result.getPrincipalName());
+ }
+
+}
\ No newline at end of file
diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestSubmarineResourceUtils.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestSubmarineResourceUtils.java
new file mode 100644
index 0000000..f22fbaa
--- /dev/null
+++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestSubmarineResourceUtils.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.utils;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.service.api.records.ResourceInformation;
+import org.apache.hadoop.yarn.util.resource.CustomResourceTypesConfigurationProvider;
+import org.apache.hadoop.yarn.util.resource.ResourceUtils;
+import org.junit.After;
+import org.junit.Test;
+
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * This class is to test {@link SubmarineResourceUtils}.
+ */
+public class TestSubmarineResourceUtils {
+ private static final String CUSTOM_RESOURCE_NAME = "a-custom-resource";
+
+ private void initResourceTypes() {
+ CustomResourceTypesConfigurationProvider.initResourceTypes(
+ ImmutableMap.<String, String>builder()
+ .put(CUSTOM_RESOURCE_NAME, "G")
+ .build());
+ }
+
+ @After
+ public void cleanup() {
+ ResourceUtils.resetResourceTypes(new Configuration());
+ }
+
+ @Test
+ public void testConvertResourceWithCustomResource() {
+ initResourceTypes();
+ Resource res = Resource.newInstance(4096, 12,
+ ImmutableMap.of(CUSTOM_RESOURCE_NAME, 20L));
+
+ org.apache.hadoop.yarn.service.api.records.Resource serviceResource =
+ SubmarineResourceUtils.convertYarnResourceToServiceResource(res);
+
+ assertEquals(12, serviceResource.getCpus().intValue());
+ assertEquals(4096, (int) Integer.valueOf(serviceResource.getMemory()));
+ Map<String, ResourceInformation> additionalResources =
+ serviceResource.getAdditional();
+
+ // Additional resources also includes vcores and memory
+ assertEquals(3, additionalResources.size());
+ ResourceInformation customResourceRI =
+ additionalResources.get(CUSTOM_RESOURCE_NAME);
+ assertEquals("G", customResourceRI.getUnit());
+ assertEquals(20L, (long) customResourceRI.getValue());
+ }
+
+}
\ No newline at end of file