You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2018/10/01 05:55:12 UTC

[02/21] ignite git commit: IGNITE-9706: [ML] Update ignite-tensorflow to support TensorFlow standalone client mode

IGNITE-9706: [ML] Update ignite-tensorflow to support
TensorFlow standalone client mode

this closes #4847


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/5aef8813
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/5aef8813
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/5aef8813

Branch: refs/heads/ignite-gg-14206
Commit: 5aef8813269f7e7b3e3d175a4343a9fd72b68325
Parents: 66acc56
Author: Anton Dmitriev <dm...@gmail.com>
Authored: Fri Sep 28 11:49:08 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Fri Sep 28 11:49:08 2018 +0300

----------------------------------------------------------------------
 .../TensorFlowServerScriptFormatter.java        | 51 ++++++++++++--------
 .../util/TensorFlowUserScriptRunner.java        | 15 ++----
 2 files changed, 37 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/5aef8813/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java
----------------------------------------------------------------------
diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java
index 7cfa1c6..18854ab 100644
--- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java
+++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java
@@ -34,10 +34,16 @@ public class TensorFlowServerScriptFormatter {
     public String format(TensorFlowServer srv, boolean join, Ignite ignite) {
         StringBuilder builder = new StringBuilder();
 
+        builder.append("from __future__ import absolute_import").append("\n");
+        builder.append("from __future__ import division").append("\n");
+        builder.append("from __future__ import print_function").append("\n");
+
         builder.append("from threading import Thread").append("\n");
         builder.append("from time import sleep").append("\n");
         builder.append("import os, signal").append("\n");
+
         builder.append("\n");
+
         builder.append("def check_pid(pid):").append("\n");
         builder.append("    try:").append("\n");
         builder.append("        os.kill(pid, 0)").append("\n");
@@ -45,24 +51,23 @@ public class TensorFlowServerScriptFormatter {
         builder.append("        return False").append("\n");
         builder.append("    else:").append("\n");
         builder.append("        return True").append("\n");
+
         builder.append("\n");
+
         builder.append("def threaded_function(pid):").append("\n");
         builder.append("    while check_pid(pid):").append("\n");
         builder.append("        sleep(1)").append("\n");
         builder.append("    os.kill(os.getpid(), signal.SIGUSR1)").append("\n");
+
         builder.append("\n");
+
         builder.append("Thread(target = threaded_function, args = (int(os.environ['PPID']), )).start()")
             .append("\n");
         builder.append("\n");
 
         builder.append("import tensorflow as tf").append('\n');
-        builder.append("from tensorflow.contrib.ignite import IgniteDataset").append("\n");
-        builder.append("\n");
-        builder.append("cluster = tf.train.ClusterSpec(")
-            .append(srv.getClusterSpec().format(ignite))
-            .append(')')
-            .append('\n');
-        builder.append("");
+        builder.append("fto_import_contrib_ops = tf.contrib.resampler").append("\n");
+        builder.append("import tensorflow.contrib.igfs.python.ops.igfs_ops").append("\n");
 
         builder.append("print('job:%s task:%d' % ('")
             .append(srv.getJobName())
@@ -74,22 +79,30 @@ public class TensorFlowServerScriptFormatter {
         builder.append("print('IGNITE_DATASET_PORT = ', os.environ.get('IGNITE_DATASET_PORT'))").append("\n");
         builder.append("print('IGNITE_DATASET_PART = ', os.environ.get('IGNITE_DATASET_PART'))").append("\n");
 
-        builder.append("server = tf.train.Server(cluster");
-
-        if (srv.getJobName() != null)
-            builder.append(", job_name=\"").append(srv.getJobName()).append('"');
-
-        if (srv.getTaskIdx() != null)
-            builder.append(", task_index=").append(srv.getTaskIdx());
-
-        if (srv.getProto() != null)
-            builder.append(", protocol=\"").append(srv.getProto()).append('"');
-
-        builder.append(')').append('\n');
+        builder.append("os.environ['TF_CONFIG'] = '").append(formatTfConfigVar(srv, ignite)).append("'\n");
+        builder.append("server = tf.contrib.distribute.run_standard_tensorflow_server()").append("\n");
 
         if (join)
             builder.append("server.join()").append('\n');
 
         return builder.toString();
     }
+
+    /**
+     * Formats the "TF_CONFIG" JSON-like value that the generated server script assigns to os.environ['TF_CONFIG'].
+     *
+     * @param srv Server description supplying the cluster spec, job name (task type) and task index.
+     * @param ignite Ignite instance passed to the cluster spec formatter.
+     * @return "TF_CONFIG" value with "cluster" and "task" entries; newlines in the cluster spec become spaces.
+     */
+    private String formatTfConfigVar(TensorFlowServer srv, Ignite ignite) {
+        return "{\"cluster\" : " +
+            srv.getClusterSpec().format(ignite).replace('\n', ' ') +
+            ", " +
+            "\"task\": {\"type\" : \"" +
+            srv.getJobName() +
+            "\", \"index\": " +
+            srv.getTaskIdx() +
+            "}}";
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/5aef8813/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java
----------------------------------------------------------------------
diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java
index 17e63bb..d9ed9b2 100644
--- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java
+++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java
@@ -124,7 +124,7 @@ public class TensorFlowUserScriptRunner extends AsyncNativeProcessRunner {
 
         Map<String, String> env = procBuilder.environment();
         env.put("PYTHONPATH", workingDir.getAbsolutePath());
-        env.put("TF_CONFIG", formatTfConfigVar());
+        env.put("TF_CLUSTER", formatTfClusterVar());
         env.put("TF_WORKERS", formatTfWorkersVar());
         env.put("TF_CHIEF_SERVER", formatTfChiefServerVar());
 
@@ -132,17 +132,12 @@ public class TensorFlowUserScriptRunner extends AsyncNativeProcessRunner {
     }
 
     /**
-     * Formats "TF_CONFIG" variable to be passed into user script.
+     * Formats "TF_CLUSTER" variable (the bare cluster specification) to be passed into user script.
      *
-     * @return Formatted "TF_CONFIG" variable to be passed into user script.
+     * @return Result of {@code clusterSpec.format(Ignition.ignite())}, used as the "TF_CLUSTER" value.
      */
-    private String formatTfConfigVar() {
-        return new StringBuilder()
-            .append("{\"cluster\" : ")
-            .append(clusterSpec.format(Ignition.ignite()))
-            .append(", ")
-            .append("\"task\": {\"type\" : \"" + TensorFlowClusterResolver.CHIEF_JOB_NAME + "\", \"index\": 0}}")
-            .toString();
+    private String formatTfClusterVar() {
+        return clusterSpec.format(Ignition.ignite());
     }
 
     /**