You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by zh...@apache.org on 2019/11/19 02:47:48 UTC

[submarine] branch master updated: SUBMARINE-66. Improve TF config env JSON generator + tests

This is an automated email from the ASF dual-hosted git repository.

zhouquan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new aae0380  SUBMARINE-66. Improve TF config env JSON generator + tests
aae0380 is described below

commit aae0380151f79c7e93f8a05ab1b979d3c4c5ef9b
Author: Adam Antal <ad...@cloudera.com>
AuthorDate: Mon Nov 18 11:05:31 2019 +0100

    SUBMARINE-66. Improve TF config env JSON generator + tests
    
    ### What is this PR for?
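    * The goal is to update the TensorFlow config (TF_CONFIG) generation so that the JSON is built with jackson-databind instead of by hand-concatenating escaped strings. Some of the tests also compared against hand-written JSON strings, which was hard to maintain; they now parse the output with jackson-databind and check individual fields.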
    
    ### What type of PR is it?
    * Refactoring | Test
    
    ### Todos
    * [x] - Travis checks
    * [x] - New tests should pass
    
    ### What is the Jira issue?
    * [SUBMARINE-66](https://issues.apache.org/jira/browse/SUBMARINE-66)
    
    ### How should this be tested?
    * Run the unit tests; the new UTs (TensorFlowConfigEnvGeneratorTest) should pass.
    
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need an update? Yes
      * Should jackson-module-jaxb-annotations be added to them?
    * Are there breaking changes for older versions? No
    * Does this need documentation? No
    
    Author: Adam Antal <ad...@cloudera.com>
    
    Closes #91 from adamantal/SUBMARINE-66 and squashes the following commits:
    
    268e7e4 [Adam Antal] Renaming setUp to setup in TensorFlowConfigEnvGeneratorTest
    ed3b220 [Adam Antal] SUBMARINE-66. Improve TF config env JSON generator + tests
---
 pom.xml                                            |   1 +
 submarine-all/pom.xml                              |  22 +++
 .../server-submitter/submitter-yarn/pom.xml        |   4 +
 .../server-submitter/submitter-yarnservice/pom.xml |  43 ++++++
 .../yarnservice/tensorflow/TensorFlowCommons.java  |  55 --------
 .../tensorflow/TensorFlowConfigEnvGenerator.java   | 102 ++++++++++++++
 .../command/TensorFlowLaunchCommand.java           |   3 +-
 .../yarnservice/TestTFConfigGenerator.java         |  73 ----------
 .../TensorFlowConfigEnvGeneratorTest.java          | 149 +++++++++++++++++++++
 9 files changed, 323 insertions(+), 129 deletions(-)

diff --git a/pom.xml b/pom.xml
index 82234c7..5465cd7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -72,6 +72,7 @@
     <gson.version>2.8.1</gson.version>
     <jackson-databind.version>2.9.10</jackson-databind.version>
     <jackson-annotations.version>2.9.10</jackson-annotations.version>
+    <jackson-module-jaxb-annotations.version>2.9.10</jackson-module-jaxb-annotations.version>
     <commons-configuration.version>1.10</commons-configuration.version>
     <commons-httpclient.version>3.1</commons-httpclient.version>
 
diff --git a/submarine-all/pom.xml b/submarine-all/pom.xml
index 952709a..900a544 100644
--- a/submarine-all/pom.xml
+++ b/submarine-all/pom.xml
@@ -95,6 +95,12 @@
           <groupId>org.apache.hadoop</groupId>
           <artifactId>hadoop-hdfs-client</artifactId>
           <version>${hadoop.version}</version>
+          <exclusions>
+            <exclusion>
+              <groupId>com.fasterxml.jackson.core</groupId>
+              <artifactId>jackson-databind</artifactId>
+            </exclusion>
+          </exclusions>
         </dependency>
         <dependency>
           <groupId>org.apache.hadoop</groupId>
@@ -105,6 +111,10 @@
               <groupId>io.netty</groupId>
               <artifactId>netty</artifactId>
             </exclusion>
+            <exclusion>
+              <groupId>com.fasterxml.jackson.core</groupId>
+              <artifactId>jackson-databind</artifactId>
+            </exclusion>
           </exclusions>
         </dependency>
       </dependencies>
@@ -128,6 +138,12 @@
           <groupId>org.apache.hadoop</groupId>
           <artifactId>hadoop-hdfs-client</artifactId>
           <version>${hadoop.version}</version>
+          <exclusions>
+            <exclusion>
+              <groupId>com.fasterxml.jackson.core</groupId>
+              <artifactId>jackson-databind</artifactId>
+            </exclusion>
+          </exclusions>
         </dependency>
       </dependencies>
     </profile>
@@ -142,6 +158,12 @@
           <groupId>org.apache.hadoop</groupId>
           <artifactId>hadoop-hdfs-client</artifactId>
           <version>${hadoop.version}</version>
+          <exclusions>
+            <exclusion>
+              <groupId>com.fasterxml.jackson.core</groupId>
+              <artifactId>jackson-databind</artifactId>
+            </exclusion>
+          </exclusions>
         </dependency>
       </dependencies>
     </profile>
diff --git a/submarine-server/server-submitter/submitter-yarn/pom.xml b/submarine-server/server-submitter/submitter-yarn/pom.xml
index 77da805..31d0040 100644
--- a/submarine-server/server-submitter/submitter-yarn/pom.xml
+++ b/submarine-server/server-submitter/submitter-yarn/pom.xml
@@ -171,6 +171,10 @@
           <groupId>commons-codec</groupId>
           <artifactId>commons-codec</artifactId>
         </exclusion>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.core</groupId>
+          <artifactId>jackson-databind</artifactId>
+        </exclusion>
       </exclusions>
     </dependency>
     <dependency>
diff --git a/submarine-server/server-submitter/submitter-yarnservice/pom.xml b/submarine-server/server-submitter/submitter-yarnservice/pom.xml
index 88b8625..60e0217 100644
--- a/submarine-server/server-submitter/submitter-yarnservice/pom.xml
+++ b/submarine-server/server-submitter/submitter-yarnservice/pom.xml
@@ -104,6 +104,10 @@
           <groupId>commons-io</groupId>
           <artifactId>commons-io</artifactId>
         </exclusion>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.core</groupId>
+          <artifactId>jackson-databind</artifactId>
+        </exclusion>
       </exclusions>
     </dependency>
 
@@ -142,6 +146,14 @@
           <groupId>org.codehaus.jackson</groupId>
           <artifactId>jackson-mapper-asl</artifactId>
         </exclusion>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.core</groupId>
+          <artifactId>jackson-databind</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.module</groupId>
+          <artifactId>jackson-module-jaxb-annotations</artifactId>
+        </exclusion>
       </exclusions>
     </dependency>
 
@@ -225,6 +237,10 @@
           <groupId>org.codehaus.jackson</groupId>
           <artifactId>jackson-mapper-asl</artifactId>
         </exclusion>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.core</groupId>
+          <artifactId>jackson-databind</artifactId>
+        </exclusion>
       </exclusions>
     </dependency>
     <dependency>
@@ -256,6 +272,10 @@
           <groupId>org.codehaus.jackson</groupId>
           <artifactId>jackson-mapper-asl</artifactId>
         </exclusion>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.core</groupId>
+          <artifactId>jackson-databind</artifactId>
+        </exclusion>
       </exclusions>
     </dependency>
     <dependency>
@@ -273,8 +293,31 @@
           <groupId>commons-io</groupId>
           <artifactId>commons-io</artifactId>
         </exclusion>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.core</groupId>
+          <artifactId>jackson-databind</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.module</groupId>
+          <artifactId>jackson-module-jaxb-annotations</artifactId>
+        </exclusion>
       </exclusions>
     </dependency>
+    <dependency>
+      <groupId>com.fasterxml.jackson.core</groupId>
+      <artifactId>jackson-annotations</artifactId>
+      <version>${jackson-annotations.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>com.fasterxml.jackson.core</groupId>
+      <artifactId>jackson-databind</artifactId>
+      <version>${jackson-databind.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>com.fasterxml.jackson.module</groupId>
+      <artifactId>jackson-module-jaxb-annotations</artifactId>
+      <version>${jackson-module-jaxb-annotations.version}</version>
+    </dependency>
   </dependencies>
 
   <build>
diff --git a/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowCommons.java b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowCommons.java
index 424fc6d..32d2640 100644
--- a/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowCommons.java
+++ b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowCommons.java
@@ -24,7 +24,6 @@ import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
 import org.apache.hadoop.yarn.service.api.records.Component;
 import org.apache.submarine.commons.runtime.conf.Envs;
 import org.apache.submarine.commons.runtime.api.Role;
-import org.apache.submarine.server.submitter.yarnservice.YarnServiceUtils;
 
 import java.util.Map;
 
@@ -55,58 +54,4 @@ public final class TensorFlowCommons {
   public static String getScriptFileName(Role role) {
     return "run-" + role.getName() + ".sh";
   }
-
-  public static String getTFConfigEnv(String componentName, int nWorkers,
-      int nPs, String serviceName, String userName, String domain) {
-    String commonEndpointSuffix = YarnServiceUtils
-        .getDNSNameCommonSuffix(serviceName, userName, domain, 8000);
-
-    String json = "{\\\"cluster\\\":{";
-
-    String master = getComponentArrayJson("master", 1, commonEndpointSuffix)
-        + ",";
-    String worker = getComponentArrayJson("worker", nWorkers - 1,
-        commonEndpointSuffix) + ",";
-    String ps = getComponentArrayJson("ps", nPs, commonEndpointSuffix) + "},";
-
-    StringBuilder sb = new StringBuilder();
-    sb.append("\\\"task\\\":{");
-    sb.append(" \\\"type\\\":\\\"");
-    sb.append(componentName);
-    sb.append("\\\",");
-    sb.append(" \\\"index\\\":");
-    sb.append('$');
-    sb.append(Envs.TASK_INDEX_ENV + "},");
-    String task = sb.toString();
-    String environment = "\\\"environment\\\":\\\"cloud\\\"}";
-
-    sb = new StringBuilder();
-    sb.append(json);
-    sb.append(master);
-    sb.append(worker);
-    sb.append(ps);
-    sb.append(task);
-    sb.append(environment);
-    return sb.toString();
-  }
-
-  private static String getComponentArrayJson(String componentName, int count,
-      String endpointSuffix) {
-    String component = "\\\"" + componentName + "\\\":";
-    StringBuilder array = new StringBuilder();
-    array.append("[");
-    for (int i = 0; i < count; i++) {
-      array.append("\\\"");
-      array.append(componentName);
-      array.append("-");
-      array.append(i);
-      array.append(endpointSuffix);
-      array.append("\\\"");
-      if (i != count - 1) {
-        array.append(",");
-      }
-    }
-    array.append("]");
-    return component + array.toString();
-  }
 }
diff --git a/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGenerator.java b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGenerator.java
new file mode 100644
index 0000000..45de9e7
--- /dev/null
+++ b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGenerator.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.submitter.yarnservice.tensorflow;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializationFeature;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import com.fasterxml.jackson.databind.type.TypeFactory;
+import com.fasterxml.jackson.module.jaxb.JaxbAnnotationIntrospector;
+import org.apache.submarine.commons.runtime.conf.Envs;
+import org.apache.submarine.server.submitter.yarnservice.YarnServiceUtils;
+/** Builds the JSON value of the TF_CONFIG env for distributed TensorFlow jobs. */
+public class TensorFlowConfigEnvGenerator {
+  // Mapper shared by all calls; configured once in createObjectMapper().
+  private static final ObjectMapper OBJECT_MAPPER = createObjectMapper();
+  // NOTE(review): only ObjectNode trees are written here, so JAXB/deser settings look unused.
+  private static ObjectMapper createObjectMapper() {
+    ObjectMapper mapper = new ObjectMapper();
+    mapper.setAnnotationIntrospector(
+        new JaxbAnnotationIntrospector(TypeFactory.defaultInstance()));
+    mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
+    mapper.configure(SerializationFeature.FLUSH_AFTER_WRITE_VALUE, false);
+    mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+    return mapper;
+  }
+  /** Returns the TF_CONFIG JSON for task type {@code componentName}; endpoints use port 8000. */
+  public static String getTFConfigEnv(String componentName, int nWorkers,
+      int nPs, String serviceName, String userName, String domain) {
+    String commonEndpointSuffix = YarnServiceUtils
+        .getDNSNameCommonSuffix(serviceName, userName, domain, 8000);
+    // NOTE(review): returns raw JSON; the removed impl emitted \"-escaped quotes; check caller.
+    TFConfigEnv tfConfigEnv =
+        new TFConfigEnv(nWorkers, nPs, componentName, commonEndpointSuffix);
+    return tfConfigEnv.toJson();
+  }
+  // Immutable description of the cluster; toJson() renders it as TF_CONFIG JSON.
+  private static class TFConfigEnv {
+    private final int nWorkers;
+    private final int nPS;
+    private final String componentName;
+    private final String endpointSuffix;
+
+    TFConfigEnv(int nWorkers, int nPS, String componentName,
+          String endpointSuffix) {
+      this.nWorkers = nWorkers;
+      this.nPS = nPS;
+      this.componentName = componentName;
+      this.endpointSuffix = endpointSuffix;
+    }
+    // Shape: {"cluster":{"master":[..],"worker":[..],"ps":[..]},"task":{type,index,environment}}
+    String toJson() {
+      ObjectNode rootNode = OBJECT_MAPPER.createObjectNode();
+      // One master; presumably nWorkers counts it, hence nWorkers - 1 workers (confirm).
+      ObjectNode cluster = rootNode.putObject("cluster");
+      createComponentArray(cluster, "master", 1);
+      createComponentArray(cluster, "worker", nWorkers - 1);
+      createComponentArray(cluster, "ps", nPS);
+      // NOTE(review): removed impl put "environment" at the JSON root and index unquoted.
+      ObjectNode task = rootNode.putObject("task");
+      task.put("type", componentName);
+      task.put("index", "$" + Envs.TASK_INDEX_ENV);
+      task.put("environment", "cloud");
+      try {
+        return OBJECT_MAPPER.writeValueAsString(rootNode);
+      } catch (JsonProcessingException e) {
+        throw new RuntimeException("Failed to serialize TF config env JSON!",
+            e);
+      }
+    }
+    // Adds name -> ["<name>-0<suffix>", ..., "<name>-<count-1><suffix>"]; empty if count <= 0.
+    private void createComponentArray(ObjectNode cluster, String name,
+          int count) {
+      ArrayNode array = cluster.putArray(name);
+      for (int i = 0; i < count; i++) {
+        String componentValue = String.format("%s-%d%s", name, i,
+            endpointSuffix);
+        array.add(componentValue);
+      }
+    }
+  }
+}
diff --git a/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java
index b88ea24..13c83cc 100644
--- a/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java
+++ b/submarine-server/server-submitter/submitter-yarnservice/src/main/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java
@@ -27,6 +27,7 @@ import org.apache.submarine.server.submitter.yarnservice.HadoopEnvironmentSetup;
 import org.apache.submarine.server.submitter.yarnservice.command.AbstractLaunchCommand;
 import org.apache.submarine.server.submitter.yarnservice.command.LaunchScriptBuilder;
 import org.apache.submarine.server.submitter.yarnservice.tensorflow.TensorFlowCommons;
+import org.apache.submarine.server.submitter.yarnservice.tensorflow.TensorFlowConfigEnvGenerator;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -78,7 +79,7 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
 
     // When distributed training is required
     if (distributed) {
-      String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(
+      String tfConfigEnvValue = TensorFlowConfigEnvGenerator.getTFConfigEnv(
           role.getComponentName(), numberOfWorkers,
           numberOfPS, name,
           TensorFlowCommons.getUserName(),
diff --git a/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/TestTFConfigGenerator.java b/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/TestTFConfigGenerator.java
deleted file mode 100644
index 072eaec..0000000
--- a/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/TestTFConfigGenerator.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.submarine.server.submitter.yarnservice;
-
-import org.apache.submarine.server.submitter.yarnservice.tensorflow.TensorFlowCommons;
-import org.codehaus.jettison.json.JSONException;
-import org.junit.Assert;
-import org.junit.Test;
-
-/**
- * Class to test some functionality of {@link TensorFlowCommons}.
- */
-public class TestTFConfigGenerator {
-  @Test
-  public void testSimpleDistributedTFConfigGenerator() throws JSONException {
-    String json = TensorFlowCommons.getTFConfigEnv("worker", 5, 3, "wtan",
-        "tf-job-001", "example.com");
-    String expected =
-        "{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"]," +
-            "\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\"," +
-            "\\\"worker-1.wtan.tf-job-001.example.com:8000\\\"," +
-            "\\\"worker-2.wtan.tf-job-001.example.com:8000\\\"," +
-            "\\\"worker-3.wtan.tf-job-001.example.com:8000\\\"]," +
-            "\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\"," +
-            "\\\"ps-1.wtan.tf-job-001.example.com:8000\\\"," +
-            "\\\"ps-2.wtan.tf-job-001.example.com:8000\\\"]}," +
-            "\\\"task\\\":{ \\\"type\\\":\\\"worker\\\", \\\"index\\\":$_TASK_INDEX}," +
-            "\\\"environment\\\":\\\"cloud\\\"}";
-    Assert.assertEquals(expected, json);
-
-    json = TensorFlowCommons.getTFConfigEnv("ps", 5, 3, "wtan", "tf-job-001",
-        "example.com");
-    expected =
-        "{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"]," +
-            "\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\"," +
-            "\\\"worker-1.wtan.tf-job-001.example.com:8000\\\"," +
-            "\\\"worker-2.wtan.tf-job-001.example.com:8000\\\"," +
-            "\\\"worker-3.wtan.tf-job-001.example.com:8000\\\"]," +
-            "\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\"," +
-            "\\\"ps-1.wtan.tf-job-001.example.com:8000\\\"," +
-            "\\\"ps-2.wtan.tf-job-001.example.com:8000\\\"]}," +
-            "\\\"task\\\":{ \\\"type\\\":\\\"ps\\\", \\\"index\\\":$_TASK_INDEX}," +
-            "\\\"environment\\\":\\\"cloud\\\"}";
-    Assert.assertEquals(expected, json);
-
-    json = TensorFlowCommons.getTFConfigEnv("master", 2, 1, "wtan", "tf-job-001",
-        "example.com");
-    expected =
-        "{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"]," +
-            "\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\"]," +
-            "\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\"]}," +
-            "\\\"task\\\":{ \\\"type\\\":\\\"master\\\", \\\"index\\\":$_TASK_INDEX}," +
-            "\\\"environment\\\":\\\"cloud\\\"}";
-    Assert.assertEquals(expected, json);
-  }
-}
diff --git a/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGeneratorTest.java b/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGeneratorTest.java
new file mode 100644
index 0000000..de1135b
--- /dev/null
+++ b/submarine-server/server-submitter/submitter-yarnservice/src/test/java/org/apache/submarine/server/submitter/yarnservice/tensorflow/TensorFlowConfigEnvGeneratorTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.submitter.yarnservice.tensorflow;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+import com.fasterxml.jackson.databind.node.JsonNodeType;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * Class to test some functionality of {@link TensorFlowConfigEnvGenerator}.
+ */
+public class TensorFlowConfigEnvGeneratorTest {
+  private ObjectMapper objectMapper;
+
+  @Before
+  public void setup() {
+    objectMapper = new ObjectMapper();
+  }
+  // Checks "task": given type, literal "$_TASK_INDEX" index and environment "cloud".
+  private void verifyCommonJsonData(JsonNode node, String taskType) {
+    JsonNode task = node.get("task");
+    assertNotNull(task);
+    assertEquals(taskType, task.get("type").asText());
+    assertEquals("$_TASK_INDEX", task.get("index").asText());
+    // NOTE(review): the deleted TestTFConfigGenerator expected "environment" at the root.
+    JsonNode environment = task.get("environment");
+    assertNotNull(environment);
+    assertEquals("cloud", environment.asText());
+  }
+  // Checks node.<childName> is an array holding exactly the given strings, in order.
+  private void verifyArrayElements(JsonNode node, String childName,
+        String... elements) {
+    JsonNode master = node.get(childName);
+    assertNotNull(master);
+    assertEquals(JsonNodeType.ARRAY, master.getNodeType());
+    ArrayNode masterArray = (ArrayNode) master;
+    verifyArray(masterArray, elements);
+  }
+  // Compares array size, then each element's text, against the expected values.
+  private void verifyArray(ArrayNode array, String... elements) {
+    int arraySize = array.size();
+    assertEquals(elements.length, arraySize);
+
+    for (int i = 0; i < arraySize; i++) {
+      JsonNode arrayElement = array.get(i);
+      assertEquals(elements[i], arrayElement.asText());
+    }
+  }
+  // worker task, 5 workers / 3 PS: 1 master, 4 worker and 3 ps endpoints expected.
+  @Test
+  public void testSimpleDistributedTFConfigGeneratorWorker()
+      throws IOException {
+    String json = TensorFlowConfigEnvGenerator.getTFConfigEnv("worker", 5, 3,
+            "wtan", "tf-job-001", "example.com");
+
+    JsonNode jsonNode = objectMapper.readTree(json);
+    assertNotNull(jsonNode);
+    JsonNode cluster = jsonNode.get("cluster");
+    assertNotNull(cluster);
+
+    verifyArrayElements(cluster, "master",
+        "master-0.wtan.tf-job-001.example.com:8000");
+    verifyArrayElements(cluster, "worker",
+        "worker-0.wtan.tf-job-001.example.com:8000",
+        "worker-1.wtan.tf-job-001.example.com:8000",
+        "worker-2.wtan.tf-job-001.example.com:8000",
+        "worker-3.wtan.tf-job-001.example.com:8000");
+
+    verifyArrayElements(cluster, "ps",
+        "ps-0.wtan.tf-job-001.example.com:8000",
+        "ps-1.wtan.tf-job-001.example.com:8000",
+        "ps-2.wtan.tf-job-001.example.com:8000");
+
+    verifyCommonJsonData(jsonNode, "worker");
+  }
+  // master task, 2 workers / 1 PS: one endpoint per cluster group expected.
+  @Test
+  public void testSimpleDistributedTFConfigGeneratorMaster()
+      throws IOException {
+    String json = TensorFlowConfigEnvGenerator.getTFConfigEnv("master", 2, 1,
+        "wtan", "tf-job-001", "example.com");
+
+    JsonNode jsonNode = objectMapper.readTree(json);
+    assertNotNull(jsonNode);
+    JsonNode cluster = jsonNode.get("cluster");
+    assertNotNull(cluster);
+
+    verifyArrayElements(cluster, "master",
+        "master-0.wtan.tf-job-001.example.com:8000");
+    verifyArrayElements(cluster, "worker",
+        "worker-0.wtan.tf-job-001.example.com:8000");
+
+    verifyArrayElements(cluster, "ps",
+        "ps-0.wtan.tf-job-001.example.com:8000");
+
+    verifyCommonJsonData(jsonNode, "master");
+  }
+  // ps task with the same 5 workers / 3 PS shape as the worker test.
+  @Test
+  public void testSimpleDistributedTFConfigGeneratorPS() throws IOException {
+    String json = TensorFlowConfigEnvGenerator.getTFConfigEnv("ps", 5, 3,
+        "wtan", "tf-job-001", "example.com");
+
+    JsonNode jsonNode = objectMapper.readTree(json);
+    assertNotNull(jsonNode);
+    JsonNode cluster = jsonNode.get("cluster");
+    assertNotNull(cluster);
+
+    verifyArrayElements(cluster, "master",
+        "master-0.wtan.tf-job-001.example.com:8000");
+    verifyArrayElements(cluster, "worker",
+        "worker-0.wtan.tf-job-001.example.com:8000",
+        "worker-1.wtan.tf-job-001.example.com:8000",
+        "worker-2.wtan.tf-job-001.example.com:8000",
+        "worker-3.wtan.tf-job-001.example.com:8000");
+
+    verifyArrayElements(cluster, "ps",
+        "ps-0.wtan.tf-job-001.example.com:8000",
+        "ps-1.wtan.tf-job-001.example.com:8000",
+        "ps-2.wtan.tf-job-001.example.com:8000");
+
+    verifyCommonJsonData(jsonNode, "ps");
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org