You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by zh...@apache.org on 2019/11/19 02:47:48 UTC
[submarine] branch master updated: SUBMARINE-66. Improve TF config
env JSON generator + tests
This is an automated email from the ASF dual-hosted git repository.
zhouquan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new aae0380 SUBMARINE-66. Improve TF config env JSON generator + tests
aae0380 is described below
commit aae0380151f79c7e93f8a05ab1b979d3c4c5ef9b
Author: Adam Antal <ad...@cloudera.com>
AuthorDate: Mon Nov 18 11:05:31 2019 +0100
SUBMARINE-66. Improve TF config env JSON generator + tests
### What is this PR for?
* The goal is to update the TensorFlow config generation. Some of the tests were also handling JSON manually, which is not very maintainable; they have been changed to a more flexible approach using jackson-databind.
### What type of PR is it?
* Refactoring | Test
### Todos
* [x] - Travis checks
* [x] - New tests should pass
### What is the Jira issue?
* [SUBMARINE-66](https://issues.apache.org/jira/browse/SUBMARINE-66)
### How should this be tested?
* Only the new UTs should pass.
### Screenshots (if appropriate)
### Questions:
* Do the license files need an update? Yes
* We should probably add jackson-module-jaxb-annotations?
* Are there breaking changes for older versions? No
* Does this need documentation? No
Author: Adam Antal <ad...@cloudera.com>
Closes #91 from adamantal/SUBMARINE-66 and squashes the following commits:
268e7e4 [Adam Antal] Renaming setUp to setup in TensorFlowConfigEnvGeneratorTest
ed3b220 [Adam Antal] SUBMARINE-66. Improve TF config env JSON generator + tests
---
pom.xml | 1 +
submarine-all/pom.xml | 22 +++
.../server-submitter/submitter-yarn/pom.xml | 4 +
.../server-submitter/submitter-yarnservice/pom.xml | 43 ++++++
.../yarnservice/tensorflow/TensorFlowCommons.java | 55 --------
.../tensorflow/TensorFlowConfigEnvGenerator.java | 102 ++++++++++++++
.../command/TensorFlowLaunchCommand.java | 3 +-
.../yarnservice/TestTFConfigGenerator.java | 73 ----------
.../TensorFlowConfigEnvGeneratorTest.java | 149 +++++++++++++++++++++
9 files changed, 323 insertions(+), 129 deletions(-)
diff --git a/pom.xml b/pom.xml
index 82234c7..5465cd7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -72,6 +72,7 @@
<gson.version>2.8.1</gson.version>
<jackson-databind.version>2.9.10</jackson-databind.version>
<jackson-annotations.version>2.9.10</jackson-annotations.version>
+ <jackson-module-jaxb-annotations.version>2.9.10</jackson-module-jaxb-annotations.version>
<commons-configuration.version>1.10</commons-configuration.version>
<commons-httpclient.version>3.1</commons-httpclient.version>
diff --git a/submarine-all/pom.xml b/submarine-all/pom.xml
index 952709a..900a544 100644
--- a/submarine-all/pom.xml
+++ b/submarine-all/pom.xml
@@ -95,6 +95,12 @@
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs-client</artifactId>
<version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
@@ -105,6 +111,10 @@
<groupId>io.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
</exclusions>
</dependency>
</dependencies>
@@ -128,6 +138,12 @@
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs-client</artifactId>
<version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
</dependencies>
</profile>
@@ -142,6 +158,12 @@
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs-client</artifactId>
<version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
</dependencies>
</profile>
diff --git a/submarine-server/server-submitter/submitter-yarn/pom.xml b/submarine-server/server-submitter/submitter-yarn/pom.xml
index 77da805..31d0040 100644
--- a/submarine-server/server-submitter/submitter-yarn/pom.xml
+++ b/submarine-server/server-submitter/submitter-yarn/pom.xml
@@ -171,6 +171,10 @@
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
diff --git a/submarine-server/server-submitter/submitter-yarnservice/pom.xml b/submarine-server/server-submitter/submitter-yarnservice/pom.xml
index 88b8625..60e0217 100644
--- a/submarine-server/server-submitter/submitter-yarnservice/pom.xml
+++ b/submarine-server/server-submitter/submitter-yarnservice/pom.xml
@@ -104,6 +104,10 @@
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
</exclusions>
</dependency>
@@ -142,6 +146,14 @@
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-mapper-asl</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.module</groupId>
+ <artifactId>jackson-module-jaxb-annotations</artifactId>
+ </exclusion>
</exclusions>
</dependency>
@@ -225,6 +237,10 @@
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-mapper-asl</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
@@ -256,6 +272,10 @@
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-mapper-asl</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
@@ -273,8 +293,31 @@
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.module</groupId>
+ <artifactId>jackson-module-jaxb-annotations</artifactId>
+ </exclusion>
</exclusions>
</dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ <version>${jackson-annotations.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <version>${jackson-databind.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.module</groupId>
+ <artifactId>jackson-module-jaxb-annotations</artifactId>
+ <version>${jackson-module-jaxb-annotations.version}</version>
+ </dependency>
</dependencies>
<build>
diff --git a/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowCommons.java b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowCommons.java
index 424fc6d..32d2640 100644
--- a/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowCommons.java
+++ b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowCommons.java
@@ -24,7 +24,6 @@ import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.submarine.commons.runtime.conf.Envs;
import org.apache.submarine.commons.runtime.api.Role;
-import org.apache.submarine.server.submitter.yarnservice.YarnServiceUtils;
import java.util.Map;
@@ -55,58 +54,4 @@ public final class TensorFlowCommons {
public static String getScriptFileName(Role role) {
return "run-" + role.getName() + ".sh";
}
-
- public static String getTFConfigEnv(String componentName, int nWorkers,
- int nPs, String serviceName, String userName, String domain) {
- String commonEndpointSuffix = YarnServiceUtils
- .getDNSNameCommonSuffix(serviceName, userName, domain, 8000);
-
- String json = "{\\\"cluster\\\":{";
-
- String master = getComponentArrayJson("master", 1, commonEndpointSuffix)
- + ",";
- String worker = getComponentArrayJson("worker", nWorkers - 1,
- commonEndpointSuffix) + ",";
- String ps = getComponentArrayJson("ps", nPs, commonEndpointSuffix) + "},";
-
- StringBuilder sb = new StringBuilder();
- sb.append("\\\"task\\\":{");
- sb.append(" \\\"type\\\":\\\"");
- sb.append(componentName);
- sb.append("\\\",");
- sb.append(" \\\"index\\\":");
- sb.append('$');
- sb.append(Envs.TASK_INDEX_ENV + "},");
- String task = sb.toString();
- String environment = "\\\"environment\\\":\\\"cloud\\\"}";
-
- sb = new StringBuilder();
- sb.append(json);
- sb.append(master);
- sb.append(worker);
- sb.append(ps);
- sb.append(task);
- sb.append(environment);
- return sb.toString();
- }
-
- private static String getComponentArrayJson(String componentName, int count,
- String endpointSuffix) {
- String component = "\\\"" + componentName + "\\\":";
- StringBuilder array = new StringBuilder();
- array.append("[");
- for (int i = 0; i < count; i++) {
- array.append("\\\"");
- array.append(componentName);
- array.append("-");
- array.append(i);
- array.append(endpointSuffix);
- array.append("\\\"");
- if (i != count - 1) {
- array.append(",");
- }
- }
- array.append("]");
- return component + array.toString();
- }
}
diff --git a/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGenerator.java b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGenerator.java
new file mode 100644
index 0000000..45de9e7
--- /dev/null
+++ b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGenerator.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.submitter.yarnservice.tensorflow;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializationFeature;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import com.fasterxml.jackson.databind.type.TypeFactory;
+import com.fasterxml.jackson.module.jaxb.JaxbAnnotationIntrospector;
+import org.apache.submarine.commons.runtime.conf.Envs;
+import org.apache.submarine.server.submitter.yarnservice.YarnServiceUtils;
+
+public class TensorFlowConfigEnvGenerator {
+
+ private static final ObjectMapper OBJECT_MAPPER = createObjectMapper();
+
+ private static ObjectMapper createObjectMapper() {
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.setAnnotationIntrospector(
+ new JaxbAnnotationIntrospector(TypeFactory.defaultInstance()));
+ mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
+ mapper.configure(SerializationFeature.FLUSH_AFTER_WRITE_VALUE, false);
+ mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+ return mapper;
+ }
+
+ public static String getTFConfigEnv(String componentName, int nWorkers,
+ int nPs, String serviceName, String userName, String domain) {
+ String commonEndpointSuffix = YarnServiceUtils
+ .getDNSNameCommonSuffix(serviceName, userName, domain, 8000);
+
+ TFConfigEnv tfConfigEnv =
+ new TFConfigEnv(nWorkers, nPs, componentName, commonEndpointSuffix);
+ return tfConfigEnv.toJson();
+ }
+
+ private static class TFConfigEnv {
+ private final int nWorkers;
+ private final int nPS;
+ private final String componentName;
+ private final String endpointSuffix;
+
+ TFConfigEnv(int nWorkers, int nPS, String componentName,
+ String endpointSuffix) {
+ this.nWorkers = nWorkers;
+ this.nPS = nPS;
+ this.componentName = componentName;
+ this.endpointSuffix = endpointSuffix;
+ }
+
+ String toJson() {
+ ObjectNode rootNode = OBJECT_MAPPER.createObjectNode();
+
+ ObjectNode cluster = rootNode.putObject("cluster");
+ createComponentArray(cluster, "master", 1);
+ createComponentArray(cluster, "worker", nWorkers - 1);
+ createComponentArray(cluster, "ps", nPS);
+
+ ObjectNode task = rootNode.putObject("task");
+ task.put("type", componentName);
+ task.put("index", "$" + Envs.TASK_INDEX_ENV);
+ task.put("environment", "cloud");
+ try {
+ return OBJECT_MAPPER.writeValueAsString(rootNode);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException("Failed to serialize TF config env JSON!",
+ e);
+ }
+ }
+
+ private void createComponentArray(ObjectNode cluster, String name,
+ int count) {
+ ArrayNode array = cluster.putArray(name);
+ for (int i = 0; i < count; i++) {
+ String componentValue = String.format("%s-%d%s", name, i,
+ endpointSuffix);
+ array.add(componentValue);
+ }
+ }
+ }
+}
diff --git a/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java
index b88ea24..13c83cc 100644
--- a/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java
+++ b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java
@@ -27,6 +27,7 @@ import org.apache.submarine.server.submitter.yarnservice.HadoopEnvironmentSetup;
import org.apache.submarine.server.submitter.yarnservice.command.AbstractLaunchCommand;
import org.apache.submarine.server.submitter.yarnservice.command.LaunchScriptBuilder;
import org.apache.submarine.server.submitter.yarnservice.tensorflow.TensorFlowCommons;
+import org.apache.submarine.server.submitter.yarnservice.tensorflow.TensorFlowConfigEnvGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -78,7 +79,7 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
// When distributed training is required
if (distributed) {
- String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(
+ String tfConfigEnvValue = TensorFlowConfigEnvGenerator.getTFConfigEnv(
role.getComponentName(), numberOfWorkers,
numberOfPS, name,
TensorFlowCommons.getUserName(),
diff --git a/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/TestTFConfigGenerator.java b/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/TestTFConfigGenerator.java
deleted file mode 100644
index 072eaec..0000000
--- a/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/TestTFConfigGenerator.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.submarine.server.submitter.yarnservice;
-
-import org.apache.submarine.server.submitter.yarnservice.tensorflow.TensorFlowCommons;
-import org.codehaus.jettison.json.JSONException;
-import org.junit.Assert;
-import org.junit.Test;
-
-/**
- * Class to test some functionality of {@link TensorFlowCommons}.
- */
-public class TestTFConfigGenerator {
- @Test
- public void testSimpleDistributedTFConfigGenerator() throws JSONException {
- String json = TensorFlowCommons.getTFConfigEnv("worker", 5, 3, "wtan",
- "tf-job-001", "example.com");
- String expected =
- "{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"]," +
- "\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\"," +
- "\\\"worker-1.wtan.tf-job-001.example.com:8000\\\"," +
- "\\\"worker-2.wtan.tf-job-001.example.com:8000\\\"," +
- "\\\"worker-3.wtan.tf-job-001.example.com:8000\\\"]," +
- "\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\"," +
- "\\\"ps-1.wtan.tf-job-001.example.com:8000\\\"," +
- "\\\"ps-2.wtan.tf-job-001.example.com:8000\\\"]}," +
- "\\\"task\\\":{ \\\"type\\\":\\\"worker\\\", \\\"index\\\":$_TASK_INDEX}," +
- "\\\"environment\\\":\\\"cloud\\\"}";
- Assert.assertEquals(expected, json);
-
- json = TensorFlowCommons.getTFConfigEnv("ps", 5, 3, "wtan", "tf-job-001",
- "example.com");
- expected =
- "{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"]," +
- "\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\"," +
- "\\\"worker-1.wtan.tf-job-001.example.com:8000\\\"," +
- "\\\"worker-2.wtan.tf-job-001.example.com:8000\\\"," +
- "\\\"worker-3.wtan.tf-job-001.example.com:8000\\\"]," +
- "\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\"," +
- "\\\"ps-1.wtan.tf-job-001.example.com:8000\\\"," +
- "\\\"ps-2.wtan.tf-job-001.example.com:8000\\\"]}," +
- "\\\"task\\\":{ \\\"type\\\":\\\"ps\\\", \\\"index\\\":$_TASK_INDEX}," +
- "\\\"environment\\\":\\\"cloud\\\"}";
- Assert.assertEquals(expected, json);
-
- json = TensorFlowCommons.getTFConfigEnv("master", 2, 1, "wtan", "tf-job-001",
- "example.com");
- expected =
- "{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"]," +
- "\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\"]," +
- "\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\"]}," +
- "\\\"task\\\":{ \\\"type\\\":\\\"master\\\", \\\"index\\\":$_TASK_INDEX}," +
- "\\\"environment\\\":\\\"cloud\\\"}";
- Assert.assertEquals(expected, json);
- }
-}
diff --git a/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGeneratorTest.java b/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGeneratorTest.java
new file mode 100644
index 0000000..de1135b
--- /dev/null
+++ b/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGeneratorTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.submitter.yarnservice.tensorflow;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+import com.fasterxml.jackson.databind.node.JsonNodeType;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * Class to test some functionality of {@link TensorFlowConfigEnvGenerator}.
+ */
+public class TensorFlowConfigEnvGeneratorTest {
+ private ObjectMapper objectMapper;
+
+ @Before
+ public void setup() {
+ objectMapper = new ObjectMapper();
+ }
+
+ private void verifyCommonJsonData(JsonNode node, String taskType) {
+ JsonNode task = node.get("task");
+ assertNotNull(task);
+ assertEquals(taskType, task.get("type").asText());
+ assertEquals("$_TASK_INDEX", task.get("index").asText());
+
+ JsonNode environment = task.get("environment");
+ assertNotNull(environment);
+ assertEquals("cloud", environment.asText());
+ }
+
+ private void verifyArrayElements(JsonNode node, String childName,
+ String... elements) {
+ JsonNode master = node.get(childName);
+ assertNotNull(master);
+ assertEquals(JsonNodeType.ARRAY, master.getNodeType());
+ ArrayNode masterArray = (ArrayNode) master;
+ verifyArray(masterArray, elements);
+ }
+
+ private void verifyArray(ArrayNode array, String... elements) {
+ int arraySize = array.size();
+ assertEquals(elements.length, arraySize);
+
+ for (int i = 0; i < arraySize; i++) {
+ JsonNode arrayElement = array.get(i);
+ assertEquals(elements[i], arrayElement.asText());
+ }
+ }
+
+ @Test
+ public void testSimpleDistributedTFConfigGeneratorWorker()
+ throws IOException {
+ String json = TensorFlowConfigEnvGenerator.getTFConfigEnv("worker", 5, 3,
+ "wtan", "tf-job-001", "example.com");
+
+ JsonNode jsonNode = objectMapper.readTree(json);
+ assertNotNull(jsonNode);
+ JsonNode cluster = jsonNode.get("cluster");
+ assertNotNull(cluster);
+
+ verifyArrayElements(cluster, "master",
+ "master-0.wtan.tf-job-001.example.com:8000");
+ verifyArrayElements(cluster, "worker",
+ "worker-0.wtan.tf-job-001.example.com:8000",
+ "worker-1.wtan.tf-job-001.example.com:8000",
+ "worker-2.wtan.tf-job-001.example.com:8000",
+ "worker-3.wtan.tf-job-001.example.com:8000");
+
+ verifyArrayElements(cluster, "ps",
+ "ps-0.wtan.tf-job-001.example.com:8000",
+ "ps-1.wtan.tf-job-001.example.com:8000",
+ "ps-2.wtan.tf-job-001.example.com:8000");
+
+ verifyCommonJsonData(jsonNode, "worker");
+ }
+
+ @Test
+ public void testSimpleDistributedTFConfigGeneratorMaster()
+ throws IOException {
+ String json = TensorFlowConfigEnvGenerator.getTFConfigEnv("master", 2, 1,
+ "wtan", "tf-job-001", "example.com");
+
+ JsonNode jsonNode = objectMapper.readTree(json);
+ assertNotNull(jsonNode);
+ JsonNode cluster = jsonNode.get("cluster");
+ assertNotNull(cluster);
+
+ verifyArrayElements(cluster, "master",
+ "master-0.wtan.tf-job-001.example.com:8000");
+ verifyArrayElements(cluster, "worker",
+ "worker-0.wtan.tf-job-001.example.com:8000");
+
+ verifyArrayElements(cluster, "ps",
+ "ps-0.wtan.tf-job-001.example.com:8000");
+
+ verifyCommonJsonData(jsonNode, "master");
+ }
+
+ @Test
+ public void testSimpleDistributedTFConfigGeneratorPS() throws IOException {
+ String json = TensorFlowConfigEnvGenerator.getTFConfigEnv("ps", 5, 3,
+ "wtan", "tf-job-001", "example.com");
+
+ JsonNode jsonNode = objectMapper.readTree(json);
+ assertNotNull(jsonNode);
+ JsonNode cluster = jsonNode.get("cluster");
+ assertNotNull(cluster);
+
+ verifyArrayElements(cluster, "master",
+ "master-0.wtan.tf-job-001.example.com:8000");
+ verifyArrayElements(cluster, "worker",
+ "worker-0.wtan.tf-job-001.example.com:8000",
+ "worker-1.wtan.tf-job-001.example.com:8000",
+ "worker-2.wtan.tf-job-001.example.com:8000",
+ "worker-3.wtan.tf-job-001.example.com:8000");
+
+ verifyArrayElements(cluster, "ps",
+ "ps-0.wtan.tf-job-001.example.com:8000",
+ "ps-1.wtan.tf-job-001.example.com:8000",
+ "ps-2.wtan.tf-job-001.example.com:8000");
+
+ verifyCommonJsonData(jsonNode, "ps");
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org