IGNITE-9706: [ML] Update ignite-tensorflow to support TensorFlow standalone client mode
this closes #4847 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/5aef8813 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/5aef8813 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/5aef8813 Branch: refs/heads/ignite-5797 Commit: 5aef8813269f7e7b3e3d175a4343a9fd72b68325 Parents: 66acc56 Author: Anton Dmitriev <[email protected]> Authored: Fri Sep 28 11:49:08 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Fri Sep 28 11:49:08 2018 +0300 ---------------------------------------------------------------------- .../TensorFlowServerScriptFormatter.java | 51 ++++++++++++-------- .../util/TensorFlowUserScriptRunner.java | 15 ++---- 2 files changed, 37 insertions(+), 29 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/5aef8813/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java index 7cfa1c6..18854ab 100644 --- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java @@ -34,10 +34,16 @@ public class TensorFlowServerScriptFormatter { public String format(TensorFlowServer srv, boolean join, Ignite ignite) { StringBuilder builder = new StringBuilder(); + builder.append("from __future__ import absolute_import").append("\n"); + builder.append("from __future__ import division").append("\n"); + builder.append("from __future__ import print_function").append("\n"); + builder.append("from threading import Thread").append("\n"); builder.append("from time import sleep").append("\n"); builder.append("import os, signal").append("\n"); + builder.append("\n"); + builder.append("def check_pid(pid):").append("\n"); builder.append(" try:").append("\n"); builder.append(" os.kill(pid, 0)").append("\n"); @@ -45,24 +51,23 @@ public class TensorFlowServerScriptFormatter { builder.append(" return False").append("\n"); builder.append(" else:").append("\n"); builder.append(" return True").append("\n"); + builder.append("\n"); + builder.append("def threaded_function(pid):").append("\n"); builder.append(" while check_pid(pid):").append("\n"); builder.append(" sleep(1)").append("\n"); builder.append(" os.kill(os.getpid(), signal.SIGUSR1)").append("\n"); + builder.append("\n"); + builder.append("Thread(target = threaded_function, args = (int(os.environ['PPID']), )).start()") .append("\n"); builder.append("\n"); builder.append("import tensorflow as tf").append('\n'); - builder.append("from tensorflow.contrib.ignite import IgniteDataset").append("\n"); - builder.append("\n"); - builder.append("cluster = tf.train.ClusterSpec(") - .append(srv.getClusterSpec().format(ignite)) - .append(')') - .append('\n'); - builder.append(""); + builder.append("fto_import_contrib_ops = tf.contrib.resampler").append("\n"); + builder.append("import tensorflow.contrib.igfs.python.ops.igfs_ops").append("\n"); builder.append("print('job:%s task:%d' % ('") .append(srv.getJobName()) @@ -74,22 +79,30 @@ public class TensorFlowServerScriptFormatter { builder.append("print('IGNITE_DATASET_PORT = ', os.environ.get('IGNITE_DATASET_PORT'))").append("\n"); builder.append("print('IGNITE_DATASET_PART = ', os.environ.get('IGNITE_DATASET_PART'))").append("\n"); - builder.append("server = tf.train.Server(cluster"); - - if (srv.getJobName() != null) - builder.append(", job_name=\"").append(srv.getJobName()).append('"'); - - if (srv.getTaskIdx() != null) - builder.append(", task_index=").append(srv.getTaskIdx()); - - if (srv.getProto() != null) - builder.append(", protocol=\"").append(srv.getProto()).append('"'); - - builder.append(')').append('\n'); + builder.append("os.environ['TF_CONFIG'] = '").append(formatTfConfigVar(srv, ignite)).append("'\n"); + builder.append("server = tf.contrib.distribute.run_standard_tensorflow_server()").append("\n"); if (join) builder.append("server.join()").append('\n'); return builder.toString(); } + + /** + * Formats "TF_CONFIG" variable to be passed into user script. + * + * @param srv Server description. + * @param ignite Ignite instance. + * @return Formatted "TF_CONFIG" variable to be passed into user script. + */ + private String formatTfConfigVar(TensorFlowServer srv, Ignite ignite) { + return "{\"cluster\" : " + + srv.getClusterSpec().format(ignite).replace('\n', ' ') + + ", " + + "\"task\": {\"type\" : \"" + + srv.getJobName() + + "\", \"index\": " + + srv.getTaskIdx() + + "}}"; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/5aef8813/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java index 17e63bb..d9ed9b2 100644 --- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowUserScriptRunner.java @@ -124,7 +124,7 @@ public class TensorFlowUserScriptRunner extends AsyncNativeProcessRunner { Map<String, String> env = procBuilder.environment(); env.put("PYTHONPATH", workingDir.getAbsolutePath()); - env.put("TF_CONFIG", formatTfConfigVar()); + env.put("TF_CLUSTER", formatTfClusterVar()); env.put("TF_WORKERS", formatTfWorkersVar()); env.put("TF_CHIEF_SERVER", formatTfChiefServerVar()); @@ -132,17 +132,12 @@ public class TensorFlowUserScriptRunner extends AsyncNativeProcessRunner { } /** - * Formats "TF_CONFIG" variable to be passed into user script. + * Formats "TF_CLUSTER" variable to be passed into user script. * - * @return Formatted "TF_CONFIG" variable to be passed into user script. + * @return Formatted "TF_CLUSTER" variable to be passed into user script. */ - private String formatTfConfigVar() { - return new StringBuilder() - .append("{\"cluster\" : ") - .append(clusterSpec.format(Ignition.ignite())) - .append(", ") - .append("\"task\": {\"type\" : \"" + TensorFlowClusterResolver.CHIEF_JOB_NAME + "\", \"index\": 0}}") - .toString(); + private String formatTfClusterVar() { + return clusterSpec.format(Ignition.ignite()); } /**
