You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2018/10/01 05:55:12 UTC
[02/21] ignite git commit: IGNITE-9706: [ML] Update ignite-tensorflow to support TensorFlow standalone client mode
IGNITE-9706: [ML] Update ignite-tensorflow to support
TensorFlow standalone client mode
This closes #4847.
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/5aef8813
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/5aef8813
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/5aef8813
Branch: refs/heads/ignite-gg-14206
Commit: 5aef8813269f7e7b3e3d175a4343a9fd72b68325
Parents: 66acc56
Author: Anton Dmitriev <dm...@gmail.com>
Authored: Fri Sep 28 11:49:08 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Fri Sep 28 11:49:08 2018 +0300
----------------------------------------------------------------------
.../TensorFlowServerScriptFormatter.java | 51 ++++++++++++--------
.../util/TensorFlowUserScriptRunner.java | 15 ++----
2 files changed, 37 insertions(+), 29 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/5aef8813/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java
----------------------------------------------------------------------
diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java
index 7cfa1c6..18854ab 100644
--- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java
+++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java
@@ -34,10 +34,16 @@ public class TensorFlowServerScriptFormatter {
public String format(TensorFlowServer srv, boolean join, Ignite ignite) {
StringBuilder builder = new StringBuilder();
+ builder.append("from __future__ import absolute_import").append("\n");
+ builder.append("from __future__ import division").append("\n");
+ builder.append("from __future__ import print_function").append("\n");
+
builder.append("from threading import Thread").append("\n");
builder.append("from time import sleep").append("\n");
builder.append("import os, signal").append("\n");
+
builder.append("\n");
+
builder.append("def check_pid(pid):").append("\n");
builder.append(" try:").append("\n");
builder.append(" os.kill(pid, 0)").append("\n");
@@ -45,24 +51,23 @@ public class TensorFlowServerScriptFormatter {
builder.append(" return False").append("\n");
builder.append(" else:").append("\n");
builder.append(" return True").append("\n");
+
builder.append("\n");
+
builder.append("def threaded_function(pid):").append("\n");
builder.append(" while check_pid(pid):").append("\n");
builder.append(" sleep(1)").append("\n");
builder.append(" os.kill(os.getpid(), signal.SIGUSR1)").append("\n");
+
builder.append("\n");
+
builder.append("Thread(target = threaded_function, args = (int(os.environ['PPID']), )).start()")
.append("\n");
builder.append("\n");
builder.append("import tensorflow as tf").append('\n');
- builder.append("from tensorflow.contrib.ignite import IgniteDataset").append("\n");
- builder.append("\n");
- builder.append("cluster = tf.train.ClusterSpec(")
- .append(srv.getClusterSpec().format(ignite))
- .append(')')
- .append('\n');
- builder.append("");
+ builder.append("fto_import_contrib_ops = tf.contrib.resampler").append("\n");
+ builder.append("import tensorflow.contrib.igfs.python.ops.igfs_ops").append("\n");
builder.append("print('job:%s task:%d' % ('")
.append(srv.getJobName())
@@ -74,22 +79,30 @@ public class TensorFlowServerScriptFormatter {
builder.append("print('IGNITE_DATASET_PORT = ', os.environ.get('IGNITE_DATASET_PORT'))").append("\n");
builder.append("print('IGNITE_DATASET_PART = ', os.environ.get('IGNITE_DATASET_PART'))").append("\n");
- builder.append("server = tf.train.Server(cluster");
-
- if (srv.getJobName() != null)
- builder.append(", job_name=\"").append(srv.getJobName()).append('"');
-
- if (srv.getTaskIdx() != null)
- builder.append(", task_index=").append(srv.getTaskIdx());
-
- if (srv.getProto() != null)
- builder.append(", protocol=\"").append(srv.getProto()).append('"');
-
- builder.append(')').append('\n');
+ builder.append("os.environ['TF_CONFIG'] = '").append(formatTfConfigVar(srv, ignite)).append("'\n");
+ builder.append("server = tf.contrib.distribute.run_standard_tensorflow_server()").append("\n");
if (join)
builder.append("server.join()").append('\n');
return builder.toString();
}
+
+ /**
+ * Formats "TF_CONFIG" variable (cluster spec plus this server's task) to be passed into user script.
+ * The result is a single line and is escaped so it can be embedded into a single-quoted Python string literal.
+ *
+ * @param srv Server description; job name and task index must be specified.
+ * @param ignite Ignite instance.
+ * @return Formatted "TF_CONFIG" variable to be passed into user script.
+ */
+ private String formatTfConfigVar(TensorFlowServer srv, Ignite ignite) {
+ if (srv.getJobName() == null || srv.getTaskIdx() == null)
+ throw new IllegalArgumentException("TF_CONFIG requires job name and task index [jobName=" + srv.getJobName() + ", taskIdx=" + srv.getTaskIdx() + "]");
+ // TF_CONFIG is placed into a one-line Python literal, so no line breaks are allowed.
+ String cluster = srv.getClusterSpec().format(ignite).replace('\r', ' ').replace('\n', ' ');
+ String tfCfg = "{\"cluster\" : " + cluster + ", \"task\": {\"type\" : \"" + srv.getJobName() + "\", \"index\": " + srv.getTaskIdx() + "}}";
+ // Escape characters that would terminate or alter the single-quoted Python literal.
+ return tfCfg.replace("\\", "\\\\").replace("'", "\\'");
+ }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/5aef8813/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java
----------------------------------------------------------------------
diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java
index 17e63bb..d9ed9b2 100644
--- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java
+++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java
@@ -124,7 +124,7 @@ public class TensorFlowUserScriptRunner extends AsyncNativeProcessRunner {
Map<String, String> env = procBuilder.environment();
env.put("PYTHONPATH", workingDir.getAbsolutePath());
- env.put("TF_CONFIG", formatTfConfigVar());
+ env.put("TF_CLUSTER", formatTfClusterVar());
env.put("TF_WORKERS", formatTfWorkersVar());
env.put("TF_CHIEF_SERVER", formatTfChiefServerVar());
@@ -132,17 +132,12 @@ public class TensorFlowUserScriptRunner extends AsyncNativeProcessRunner {
}
/**
- * Formats "TF_CONFIG" variable to be passed into user script.
+ * Formats "TF_CLUSTER" variable to be passed into user script.
*
- * @return Formatted "TF_CONFIG" variable to be passed into user script.
+ * @return Formatted "TF_CLUSTER" variable to be passed into user script.
*/
- private String formatTfConfigVar() {
- return new StringBuilder()
- .append("{\"cluster\" : ")
- .append(clusterSpec.format(Ignition.ignite()))
- .append(", ")
- .append("\"task\": {\"type\" : \"" + TensorFlowClusterResolver.CHIEF_JOB_NAME + "\", \"index\": 0}}")
- .toString();
+ private String formatTfClusterVar() {
+ return clusterSpec.format(Ignition.ignite());
}
/**