IGNITE-9034: [ML] Add Estimator API support to TensorFlow cluster on top of Apache Ignite.
this closes #4402 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/9e884e5a Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/9e884e5a Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/9e884e5a Branch: refs/heads/ignite-8446 Commit: 9e884e5aaa3fba3a68967c51280f3a8f41514f51 Parents: c5e723c Author: Anton Dmitriev <[email protected]> Authored: Tue Jul 31 16:54:26 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Jul 31 16:54:26 2018 +0300 ---------------------------------------------------------------------- examples/pom.xml | 21 -- .../cluster/TensorFlowClusterExample.java | 118 -------- modules/tensorflow/README.txt | 55 +++- modules/tensorflow/pom.xml | 56 +++- modules/tensorflow/src/main/assembly/zip.xml | 56 ++++ .../cluster/TensorFlowClusterGateway.java | 34 ++- .../TensorFlowClusterGatewayManager.java | 95 +++++-- .../cluster/TensorFlowClusterMaintainer.java | 105 +++++--- .../cluster/TensorFlowClusterManager.java | 268 +++++++++++++------ .../cluster/TensorFlowJobArchive.java | 65 +++++ .../cluster/spec/TensorFlowClusterSpec.java | 39 +++ .../spec/TensorFlowServerAddressSpec.java | 14 + .../tfrunning/TensorFlowServerManager.java | 112 +------- .../TensorFlowServerScriptFormatter.java | 62 +++++ .../cluster/util/ClusterPortManager.java | 168 +++++++++--- .../cluster/util/TensorFlowChiefRunner.java | 80 ++++++ .../cluster/util/TensorFlowClusterResolver.java | 92 ++++--- .../util/TensorFlowUserScriptRunner.java | 236 ++++++++++++++++ .../ignite/tensorflow/core/ProcessManager.java | 5 +- .../tensorflow/core/ProcessManagerWrapper.java | 5 +- .../longrunning/LongRunningProcessManager.java | 19 +- .../task/LongRunningProcessStartTask.java | 4 +- .../core/nativerunning/NativeProcess.java | 6 +- .../nativerunning/NativeProcessManager.java | 12 +- .../task/NativeProcessStartTask.java | 81 ++---- .../PythonProcessBuilderSupplier.java | 57 ++++ 
.../pythonrunning/PythonProcessManager.java | 35 +-- .../core/util/AsyncNativeProcessRunner.java | 107 ++++++++ .../core/util/NativeProcessRunner.java | 133 +++++++++ .../tensorflow/submitter/JobSubmitter.java | 35 +++ .../submitter/command/AbstractCommand.java | 55 ++++ .../submitter/command/AttachCommand.java | 51 ++++ .../tensorflow/submitter/command/PsCommand.java | 47 ++++ .../submitter/command/RootCommand.java | 42 +++ .../submitter/command/StartCommand.java | 205 ++++++++++++++ .../submitter/command/StopCommand.java | 50 ++++ .../submitter/command/package-info.java | 22 ++ .../tensorflow/submitter/package-info.java | 23 ++ .../tensorflow/util/SerializableConsumer.java | 29 ++ .../tensorflow/util/SerializableSupplier.java | 29 ++ .../ignite/tensorflow/util/package-info.java | 22 ++ modules/tensorflow/src/main/sh/ignite-tf.sh | 19 ++ modules/tensorflow/src/main/sh/logback.xml | 36 +++ .../LongRunningProcessManagerTest.java | 10 +- 44 files changed, 2222 insertions(+), 593 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/examples/pom.xml ---------------------------------------------------------------------- diff --git a/examples/pom.xml b/examples/pom.xml index 3f6e6a8..e745beb 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -124,8 +124,6 @@ <lgpl.test.folder>src/test/java</lgpl.test.folder> <spark.folder>src/main/java</spark.folder> <spark.test.folder>src/test/java</spark.test.folder> - <tensorflow.folder>src/main/java</tensorflow.folder> - <tensorflow.test.folder>src/test/java</tensorflow.test.folder> </properties> <profiles> @@ -202,23 +200,6 @@ </dependency> </dependencies> </profile> - - <profile> - <id>tensorflow</id> - - <properties> - <tensorflow.folder>src/main/tensorflow</tensorflow.folder> - <tensorflow.test.folder>src/test/tensorflow</tensorflow.test.folder> - </properties> - - <dependencies> - <dependency> - <groupId>org.apache.ignite</groupId> - 
<artifactId>ignite-tensorflow</artifactId> - <version>${project.version}</version> - </dependency> - </dependencies> - </profile> </profiles> <build> @@ -249,7 +230,6 @@ <sources> <source>${lgpl.folder}</source> <source>${spark.folder}</source> - <source>${tensorflow.folder}</source> </sources> </configuration> </execution> @@ -264,7 +244,6 @@ <sources> <source>${lgpl.test.folder}</source> <source>${spark.test.folder}</source> - <source>${tensorflow.test.folder}</source> </sources> </configuration> </execution> http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/examples/src/main/tensorflow/org/apache/ignite/tensorflow/cluster/TensorFlowClusterExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/tensorflow/org/apache/ignite/tensorflow/cluster/TensorFlowClusterExample.java b/examples/src/main/tensorflow/org/apache/ignite/tensorflow/cluster/TensorFlowClusterExample.java deleted file mode 100644 index 3a956c9..0000000 --- a/examples/src/main/tensorflow/org/apache/ignite/tensorflow/cluster/TensorFlowClusterExample.java +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.tensorflow.cluster; - -import java.io.Serializable; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.CountDownLatch; -import java.util.function.Supplier; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.configuration.IgniteConfiguration; -import org.apache.ignite.tensorflow.cluster.tfrunning.TensorFlowServerManager; -import org.apache.ignite.tensorflow.core.longrunning.task.util.LongRunningProcessStatus; - -/** - * Prerequisites: be aware that to successfully run this example you need to have Python and TensorFlow installed on - * your machine. To find out how to install Python please take a look this page: https://www.python.org/downloads/. To - * install TensorFlow see this web site: https://www.tensorflow.org/install/. - * - * Example that shows how to use {@link TensorFlowClusterGatewayManager} and start, maintain and stop TensorFlow - * cluster. - */ -public class TensorFlowClusterExample { - /** Run example. */ - public static void main(String... 
args) throws InterruptedException { - IgniteConfiguration configuration = new IgniteConfiguration(); - configuration.setClientMode(false); - - try (Ignite ignite = Ignition.start(configuration)) { - System.out.println(">>> TensorFlow cluster example started."); - - CacheConfiguration<Integer, Integer> cacheConfiguration = new CacheConfiguration<>(); - cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); - cacheConfiguration.setName("TEST_CACHE"); - - IgniteCache<Integer, Integer> cache = ignite.getOrCreateCache(cacheConfiguration); - for (int i = 0; i < 1000; i++) - cache.put(i, i); - - System.out.println(">>> Cache created."); - - TensorFlowClusterGatewayManager mgr = new TensorFlowClusterGatewayManager(ignite); - TensorFlowClusterGateway gateway = mgr.getOrCreateCluster("TEST_CACHE"); - - System.out.println(">>> TensorFlow cluster gateway started."); - - CountDownLatch latch = new CountDownLatch(1); - - gateway.subscribe(cluster -> { - StringBuilder builder = new StringBuilder(); - builder.append("------------------- TensorFlow Cluster Service Info -------------------").append('\n'); - - builder.append("Specification : ").append('\n'); - - TensorFlowServerManager srvMgr = new TensorFlowServerManager( - (Supplier<Ignite> & Serializable)() -> ignite - ); - - String clusterSpec = srvMgr.formatClusterSpec(cluster.getSpec()); - builder.append(clusterSpec).append('\n'); - - Map<UUID, List<LongRunningProcessStatus>> statuses = srvMgr.ping(cluster.getProcesses()); - - builder.append("State : ").append('\n'); - - for (UUID nodeId : cluster.getProcesses().keySet()) { - List<UUID> pr = cluster.getProcesses().get(nodeId); - List<LongRunningProcessStatus> st = statuses.get(nodeId); - - builder.append("Node ").append(nodeId.toString().substring(0, 8)).append(" -> ").append('\n'); - for (int i = 0; i < pr.size(); i++) { - builder.append("\tProcess ") - .append(pr.get(i).toString().substring(0, 8)) - .append(" with status ") - .append(st.get(i).getState()); 
- - if (st.get(i).getException() != null) - builder.append(" (").append(st.get(i).getException()).append(")"); - - builder.append('\n'); - } - } - - builder.append("-----------------------------------------------------------------------").append('\n'); - - System.out.println(builder); - - latch.countDown(); - }); - - latch.await(); - - mgr.stopClusterIfExists("TEST_CACHE"); - - System.out.println(">>> TensorFlow cluster example completed."); - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/README.txt ---------------------------------------------------------------------- diff --git a/modules/tensorflow/README.txt b/modules/tensorflow/README.txt index 45fc3cf..21ea88d 100644 --- a/modules/tensorflow/README.txt +++ b/modules/tensorflow/README.txt @@ -7,7 +7,7 @@ will be a datasource for any TensorFlow model training. Import Apache Ignite TensorFlow Integration Module In Maven Project ------------------------------------- -If you are using Maven to manage dependencies of your project, you can add Cassandra Store module +If you are using Maven to manage dependencies of your project, you can add TensorFlow module dependency like this (replace '${ignite.version}' with actual Ignite version you are interested in): @@ -26,4 +26,55 @@ interested in): ... </dependencies> ... -</project> \ No newline at end of file +</project> +------------------------------------- + +TensorFlow integration module provides command line tool that allows to start, maintain and stop distributed deep +learning utilizing Apache Ignite infrastructure and data. This tool provides several commands that are shown here: + +Usage: ignite-tf [-hV] [-c=<cfg>] [COMMAND] +Apache Ignite and TensorFlow integration command line utility that allows to +start, maintain and stop distributed deep learning utilizing Apache Ignite +infrastructure and data. + -c, --config=<cfg> Apache Ignite client configuration. + -h, --help Show this help message and exit. 
+ -V, --version Print version information and exit. +Commands: + start Starts a new TensorFlow cluster and attaches to user script process. + stop Stops a running TensorFlow cluster. + attach Attaches to running TensorFlow cluster (user script process). + ps Prints identifiers of all running TensorFlow clusters. + +To start a TensorFlow cluster you need to specify the upstream cache that will be used as the data source for training, the folder +that contains the code that actually performs the training, and the command that should be called on this code to start the training +correctly. The "start" command has the following help output: + +Usage: ignite-tf start [-hV] [-c=<cfg>] CACHE_NAME JOB_DIR JOB_CMD [JOB_ARGS...] +Starts a new TensorFlow cluster and attaches to user script process. + CACHE_NAME Upstream cache name. + JOB_DIR Job folder (or zip archive). + JOB_CMD Job command. + [JOB_ARGS...] Job arguments. + -c, --config=<cfg> Apache Ignite client configuration. + -h, --help Show this help message and exit. + -V, --version Print version information and exit. + +To attach to a running TensorFlow cluster or to stop it you can use the "attach" and "stop" commands respectively. These +commands accept a cluster identifier as a parameter: + +Usage: ignite-tf attach [-hV] [-c=<cfg>] CLUSTER_ID +Attaches to running TensorFlow cluster (user script process). + CLUSTER_ID Cluster identifier. + -c, --config=<cfg> Apache Ignite client configuration. + -h, --help Show this help message and exit. + -V, --version Print version information and exit. + +Usage: ignite-tf stop [-hV] [-c=<cfg>] CLUSTER_ID +Stops a running TensorFlow cluster. + CLUSTER_ID Cluster identifier. + -c, --config=<cfg> Apache Ignite client configuration. + -h, --help Show this help message and exit. + -V, --version Print version information and exit. + +To find out which TensorFlow clusters are currently running on top of Apache Ignite you can use the "ps" command, which doesn't +require arguments. 
\ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/pom.xml ---------------------------------------------------------------------- diff --git a/modules/tensorflow/pom.xml b/modules/tensorflow/pom.xml index 72e7e69..b1cfe21 100644 --- a/modules/tensorflow/pom.xml +++ b/modules/tensorflow/pom.xml @@ -34,6 +34,30 @@ <version>2.7.0-SNAPSHOT</version> <url>http://ignite.apache.org</url> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-assembly-plugin</artifactId> + <executions> + <execution> + <id>dist</id> + <phase>package</phase> + <goals> + <goal>single</goal> + </goals> + <configuration> + <appendAssemblyId>false</appendAssemblyId> + <descriptors> + <descriptor>src/main/assembly/zip.xml</descriptor> + </descriptors> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> + <dependencies> <dependency> <groupId>org.apache.ignite</groupId> @@ -43,22 +67,39 @@ <dependency> <groupId>org.apache.ignite</groupId> - <artifactId>ignite-core</artifactId> + <artifactId>ignite-spring</artifactId> <version>${project.version}</version> - <type>test-jar</type> - <scope>test</scope> </dependency> <dependency> <groupId>org.apache.ignite</groupId> - <artifactId>ignite-spring</artifactId> + <artifactId>ignite-slf4j</artifactId> <version>${project.version}</version> - <scope>test</scope> </dependency> <dependency> - <groupId>log4j</groupId> - <artifactId>log4j</artifactId> + <groupId>ch.qos.logback</groupId> + <artifactId>logback-classic</artifactId> + <version>1.2.3</version> + </dependency> + + <dependency> + <groupId>commons-io</groupId> + <artifactId>commons-io</artifactId> + <version>2.6</version> + </dependency> + + <dependency> + <groupId>info.picocli</groupId> + <artifactId>picocli</artifactId> + <version>3.3.0</version> + </dependency> + + <dependency> + <groupId>org.apache.ignite</groupId> + <artifactId>ignite-core</artifactId> + 
<version>${project.version}</version> + <type>test-jar</type> <scope>test</scope> </dependency> @@ -86,5 +127,4 @@ </build> </profile> </profiles> - </project> \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/assembly/zip.xml ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/assembly/zip.xml b/modules/tensorflow/src/main/assembly/zip.xml new file mode 100644 index 0000000..0c90e98 --- /dev/null +++ b/modules/tensorflow/src/main/assembly/zip.xml @@ -0,0 +1,56 @@ +<?xml version="1.0" encoding="UTF-8"?> + +<!-- + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> +<!-- + Assembly file. 
+--> +<assembly xmlns="http://maven.apache.org/plugins/maven-assembly-plugin/assembly/1.1.2" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/plugins/maven-assembly-plugin/assembly/1.1.2 + http://maven.apache.org/xsd/assembly-1.1.2.xsd"> + <id>zip</id> + <includeBaseDirectory>true</includeBaseDirectory> + + <formats> + <format>zip</format> + </formats> + + <files> + <file> + <source>${project.basedir}/src/main/sh/ignite-tf.sh</source> + <outputDirectory>/</outputDirectory> + </file> + <file> + <source>${project.basedir}/src/main/sh/logback.xml</source> + <outputDirectory>/</outputDirectory> + </file> + <file> + <source>${project.build.directory}/${project.artifactId}-${project.version}.jar</source> + <outputDirectory>/lib</outputDirectory> + </file> + </files> + + <dependencySets> + <dependencySet> + <outputDirectory>lib</outputDirectory> + <useStrictFiltering>true</useStrictFiltering> + <useProjectArtifact>false</useProjectArtifact> + <scope>runtime</scope> + </dependencySet> + </dependencySets> +</assembly> \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterGateway.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterGateway.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterGateway.java index 5eee155..092dfcb 100644 --- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterGateway.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterGateway.java @@ -18,29 +18,43 @@ package org.apache.ignite.tensorflow.cluster; import java.util.HashSet; +import java.util.Optional; import java.util.UUID; import java.util.function.Consumer; import 
org.apache.ignite.lang.IgniteBiPredicate; +import org.apache.ignite.tensorflow.util.SerializableConsumer; /** * TensorFlow cluster gateway that allows to subscribe on changes in cluster configuration. */ -public class TensorFlowClusterGateway implements IgniteBiPredicate<UUID, TensorFlowCluster> { +public class TensorFlowClusterGateway implements IgniteBiPredicate<UUID, Optional<TensorFlowCluster>>, AutoCloseable { /** */ private static final long serialVersionUID = -540323262800791340L; + /** Callback that will be called on unsubscribe. */ + private final SerializableConsumer<TensorFlowClusterGateway> unsubscribeCb; + /** Subscribers. */ - private final HashSet<Consumer<TensorFlowCluster>> subscribers = new HashSet<>(); + private final HashSet<Consumer<Optional<TensorFlowCluster>>> subscribers = new HashSet<>(); /** Last value received from the upstream. */ - private TensorFlowCluster last; + private Optional<TensorFlowCluster> last; + + /** + * Constructs a new instance of TensorFlow cluster gateway. + * + * @param unsubscribeCb Callback that will be called on unsubscribe. + */ + public TensorFlowClusterGateway(SerializableConsumer<TensorFlowClusterGateway> unsubscribeCb) { + this.unsubscribeCb = unsubscribeCb; + } /** * Subscribers the specified subscriber on the upstream events. * * @param subscriber Subscriber. */ - public synchronized void subscribe(Consumer<TensorFlowCluster> subscriber) { + public synchronized void subscribe(Consumer<Optional<TensorFlowCluster>> subscriber) { subscribers.add(subscriber); if (last != null) @@ -52,17 +66,23 @@ public class TensorFlowClusterGateway implements IgniteBiPredicate<UUID, TensorF * * @param subscriber Subscriber. 
*/ - public synchronized void unsubscribe(Consumer<TensorFlowCluster> subscriber) { + public synchronized void unsubscribe(Consumer<Optional<TensorFlowCluster>> subscriber) { subscribers.remove(subscriber); } /** {@inheritDoc} */ - @Override public synchronized boolean apply(UUID uuid, TensorFlowCluster cluster) { - for (Consumer<TensorFlowCluster> subscriber : subscribers) + @Override public synchronized boolean apply(UUID uuid, Optional<TensorFlowCluster> cluster) { + for (Consumer<Optional<TensorFlowCluster>> subscriber : subscribers) subscriber.accept(cluster); last = cluster; return true; } + + /** {@inheritDoc} */ + @Override public void close() { + subscribers.clear(); + unsubscribeCb.accept(this); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterGatewayManager.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterGatewayManager.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterGatewayManager.java index f4b8187..a315aaf 100644 --- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterGatewayManager.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterGatewayManager.java @@ -17,8 +17,12 @@ package org.apache.ignite.tensorflow.cluster; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.function.Consumer; import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteServices; +import org.apache.ignite.IgniteLogger; /** * TensorFlow cluster manager that allows to start, maintain and stop TensorFlow cluster using @@ -34,6 +38,9 @@ public class TensorFlowClusterGatewayManager { /** Ignite instance. */ private final Ignite ignite; + /** Logger. 
*/ + private final IgniteLogger log; + /** * Constructs a new instance of TensorFlow cluster manager with maintenance. * @@ -43,37 +50,89 @@ public class TensorFlowClusterGatewayManager { assert ignite != null : "Ignite should not be null"; this.ignite = ignite; + this.log = ignite.log().getLogger(TensorFlowClusterGatewayManager.class); } /** - * Creates and starts a new TensorFlow cluster for the specified cache if it doesn't exist, otherwise returns - * existing one. + * Subscribes on changes of the specified cluster. * - * @param upstreamCacheName Upstream cache name. + * @param clusterId Cluster identifier. * @return TensorFlow cluster gateway that allows to subscribe on cluster changes. */ - public TensorFlowClusterGateway getOrCreateCluster(String upstreamCacheName) { - String svcName = String.format(SERVICE_NAME_TEMPLATE, upstreamCacheName); - String topicName = String.format(SERVICE_TOPIC_NAME_TEMPLATE, upstreamCacheName); + public TensorFlowClusterGateway getCluster(UUID clusterId) { + String topicName = String.format(SERVICE_TOPIC_NAME_TEMPLATE, clusterId); - TensorFlowClusterGateway gateway = createTensorFlowClusterGateway(topicName); + return createTensorFlowClusterGateway(topicName); + } - IgniteServices services = ignite.services(); + /** + * Creates and starts a new TensorFlow cluster for the specified cache. + * + * @param clusterId Cluster identifier. + * @param jobArchive Job archive. + * @return TensorFlow cluster gateway that allows to subscribe on cluster changes. 
+ */ + public TensorFlowClusterGateway createCluster(UUID clusterId, TensorFlowJobArchive jobArchive) { + String svcName = String.format(SERVICE_NAME_TEMPLATE, clusterId); + String topicName = String.format(SERVICE_TOPIC_NAME_TEMPLATE, clusterId); - services.deployClusterSingleton(svcName, new TensorFlowClusterMaintainer(upstreamCacheName, topicName)); + TensorFlowClusterGateway gateway = createTensorFlowClusterGateway(topicName); + + ignite.services().deployClusterSingleton( + svcName, + new TensorFlowClusterMaintainer(clusterId, jobArchive, topicName) + ); + log.info("Cluster maintainer deployed as a service [clusterId=" + clusterId + "]"); return gateway; } /** - * Stops TensorFlow cluster. + * Listens to TensorFlow cluster user script. * - * @param upstreamCacheName Upstream cache name. + * @param clusterId Cluster identifier. + * @param out Output stream consumer. + * @param err Error stream consumer. */ - public void stopClusterIfExists(String upstreamCacheName) { - IgniteServices services = ignite.services(); + public void listenToClusterUserScript(UUID clusterId, Consumer<String> out, Consumer<String> err) { + TensorFlowClusterGateway gateway = getCluster(clusterId); + + ignite.message().localListen("us_out_" + clusterId, (node, msg) -> { + out.accept(msg.toString()); + return true; + }); + + ignite.message().localListen("us_err_" + clusterId, (node, msg) -> { + err.accept(msg.toString()); + return true; + }); + + CountDownLatch latch = new CountDownLatch(1); + + Consumer<Optional<TensorFlowCluster>> subscriber = cluster -> { + if (!cluster.isPresent()) + latch.countDown(); + }; + + gateway.subscribe(subscriber); + + try { + latch.await(); + gateway.unsubscribe(subscriber); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + } - services.cancel(String.format(SERVICE_NAME_TEMPLATE, upstreamCacheName)); + /** + * Stops TensorFlow cluster. + * + * @param clusterId Cluster identifier. 
+ */ + public void stopClusterIfExists(UUID clusterId) { + ignite.services().cancel(String.format(SERVICE_NAME_TEMPLATE, clusterId)); + log.info("Cluster maintained cancelled as a service [clusterId=" + clusterId + "]"); } /** @@ -83,9 +142,13 @@ public class TensorFlowClusterGatewayManager { * @return TensorFlow cluster gateway. */ private TensorFlowClusterGateway createTensorFlowClusterGateway(String topicName) { - TensorFlowClusterGateway gateway = new TensorFlowClusterGateway(); + TensorFlowClusterGateway gateway = new TensorFlowClusterGateway(subscriber -> { + ignite.message().stopLocalListen(topicName, subscriber); + log.info("Stop listen to cluster gateway [topicName=" + topicName + "]"); + }); ignite.message().localListen(topicName, gateway); + log.info("Start listen to cluster gateway [topicName=" + topicName + "]"); return gateway; } http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterMaintainer.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterMaintainer.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterMaintainer.java index 21b9c2b..e6ca33d 100644 --- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterMaintainer.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterMaintainer.java @@ -17,17 +17,18 @@ package org.apache.ignite.tensorflow.cluster; -import java.io.Serializable; import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.UUID; -import java.util.function.Supplier; +import java.util.concurrent.locks.LockSupport; import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteMessaging; -import org.apache.ignite.Ignition; +import 
org.apache.ignite.IgniteLogger; import org.apache.ignite.cache.affinity.Affinity; import org.apache.ignite.cluster.ClusterNode; +import org.apache.ignite.resources.IgniteInstanceResource; +import org.apache.ignite.resources.LoggerResource; import org.apache.ignite.services.Service; import org.apache.ignite.services.ServiceContext; import org.apache.ignite.tensorflow.core.longrunning.task.util.LongRunningProcessState; @@ -40,74 +41,116 @@ public class TensorFlowClusterMaintainer implements Service { /** */ private static final long serialVersionUID = -3220563310643566419L; - /** Upstream cache name. */ - private final String cacheName; + /** Ignite instance. */ + @IgniteInstanceResource + private transient Ignite ignite; + + /** Logger. */ + @LoggerResource + private transient IgniteLogger log; + + /** TensorFlow cluster identifier. */ + private final UUID clusterId; + + /** Job archive. */ + private final TensorFlowJobArchive jobArchive; /** Topic name. */ private final String topicName; /** TensorFlow cluster manager. */ - private final TensorFlowClusterManager clusterMgr; + private transient TensorFlowClusterManager clusterMgr; /** Previous partition mapping. */ - private UUID[] prev; + private transient UUID[] prev; /** * Constructs a new instance of TensorFlow cluster service. * - * @param cacheName Upstream cache name. + * @param clusterId Cluster identifier. + * @param jobArchive Job archive. * @param topicName Topic name. 
*/ - public TensorFlowClusterMaintainer(String cacheName, String topicName) { - assert cacheName != null : "Cache name should not be null"; + public TensorFlowClusterMaintainer(UUID clusterId, TensorFlowJobArchive jobArchive, String topicName) { + assert clusterId != null : "Cluster identifier should not be null"; + assert jobArchive != null : "Job archive should not be null"; assert topicName != null : "Topic name should not be null"; - this.clusterMgr = new TensorFlowClusterManager((Supplier<Ignite> & Serializable)Ignition::ignite); - this.cacheName = cacheName; + this.clusterId = clusterId; + this.jobArchive = jobArchive; this.topicName = topicName; } /** {@inheritDoc} */ @Override public void cancel(ServiceContext ctx) { - clusterMgr.stopClusterIfExists(cacheName); + clusterMgr.stopClusterIfExists(clusterId); + log.debug("Cluster maintainer canceled [clusterId=" + clusterId + "]"); } /** {@inheritDoc} */ @Override public void init(ServiceContext ctx) { - clusterMgr.init(); + clusterMgr = new TensorFlowClusterManager(ignite); + log.debug("Cluster maintainer initialized [clusterId=" + clusterId + "]"); } /** {@inheritDoc} */ - @Override public void execute(ServiceContext ctx) throws Exception { + @Override public void execute(ServiceContext ctx) { while (!ctx.isCancelled()) { - Thread.sleep(1000); + LockSupport.parkNanos(1_000_000); + + boolean completed = clusterMgr.isUserScriptCompleted(clusterId); + if (completed) + break; boolean restartRequired = hasAffinityChanged(); + if (restartRequired) + log.debug("Affinity mapping changed, cluster will be restarted [clusterId=" + clusterId + "]"); + if (!restartRequired) { - TensorFlowCluster cluster = clusterMgr.getCache().get(cacheName); - Map<UUID, List<LongRunningProcessStatus>> statuses = clusterMgr.getSrvProcMgr() - .ping(cluster.getProcesses()); - - for (UUID nodeId : statuses.keySet()) { - for (LongRunningProcessStatus status : statuses.get(nodeId)) { - if 
(status.getState().equals(LongRunningProcessState.DONE)) { - restartRequired = true; - break; + try { + TensorFlowCluster cluster = clusterMgr.getCluster(clusterId); + Map<UUID, List<LongRunningProcessStatus>> statuses = clusterMgr.getSrvProcMgr() + .ping(cluster.getProcesses()); + + for (UUID nodeId : statuses.keySet()) { + for (LongRunningProcessStatus status : statuses.get(nodeId)) { + if (status.getState().equals(LongRunningProcessState.DONE)) { + restartRequired = true; + break; + } } } + } + catch (Exception e) { + log.error("Failed to check process statuses", e); + restartRequired = true; + } + + if (restartRequired) + log.debug("Fail detected, cluster will be restarted [clusterId=" + clusterId + "]"); } if (restartRequired) { - clusterMgr.stopClusterIfExists(cacheName); + clusterMgr.stopClusterIfExists(clusterId); - TensorFlowCluster cluster = clusterMgr.getOrCreateCluster(cacheName); + TensorFlowCluster cluster = clusterMgr.createCluster( + clusterId, + jobArchive, + str -> ignite.message().sendOrdered("us_out_" + clusterId, str, 60 * 1000), + str -> ignite.message().sendOrdered("us_err_" + clusterId, str, 60 * 1000) + ); - IgniteMessaging messaging = Ignition.ignite().message(); - messaging.send(topicName, cluster); + ignite.message().send(topicName, Optional.of(cluster)); } } + + clusterMgr.stopClusterIfExists(clusterId); + + ignite.message().send(topicName, Optional.empty()); + + log.debug("Cluster maintainer completed [clusterId=" + clusterId + "]"); } /** @@ -116,7 +159,7 @@ public class TensorFlowClusterMaintainer implements Service { * @return True if mapping has been changed, otherwise false. 
*/ private boolean hasAffinityChanged() { - Affinity<?> affinity = Ignition.ignite().affinity(cacheName); + Affinity<?> affinity = ignite.affinity(jobArchive.getUpstreamCacheName()); int parts = affinity.partitions(); http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterManager.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterManager.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterManager.java index 2d63195..cdbd774 100644 --- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterManager.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowClusterManager.java @@ -17,36 +17,42 @@ package org.apache.ignite.tensorflow.cluster; -import java.io.Serializable; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executors; import java.util.concurrent.locks.Lock; -import java.util.function.Supplier; +import java.util.function.Consumer; +import javax.cache.Cache; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.CacheAtomicityMode; import org.apache.ignite.cache.CacheMode; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.tensorflow.cluster.spec.TensorFlowClusterSpec; import org.apache.ignite.tensorflow.cluster.spec.TensorFlowServerAddressSpec; import org.apache.ignite.tensorflow.cluster.tfrunning.TensorFlowServer; import 
org.apache.ignite.tensorflow.cluster.tfrunning.TensorFlowServerManager; +import org.apache.ignite.tensorflow.cluster.util.TensorFlowChiefRunner; import org.apache.ignite.tensorflow.cluster.util.TensorFlowClusterResolver; +import org.apache.ignite.tensorflow.cluster.util.TensorFlowUserScriptRunner; +import org.apache.ignite.tensorflow.core.util.CustomizableThreadFactory; /** * TensorFlow cluster manager that allows to start, maintain and stop TensorFlow cluster. */ -public class TensorFlowClusterManager implements Serializable { - /** */ - private static final long serialVersionUID = -4847155592164802806L; - +public class TensorFlowClusterManager { /** TensorFlow cluster metadata cache name. */ private static final String TF_CLUSTER_METADATA_CACHE_NAME = "TF_CLUSTER_METADATA_CACHE"; - /** Ignite instance supplier. */ - private final Supplier<Ignite> igniteSupplier; + /** Ignite instance. */ + private final Ignite ignite; /** TensorFlow server manager. */ private final TensorFlowServerManager srvProcMgr; @@ -55,74 +61,61 @@ public class TensorFlowClusterManager implements Serializable { private final TensorFlowClusterResolver clusterRslvr; /** TensorFlow cluster metadata cache. */ - private transient IgniteCache<String, TensorFlowCluster> cache; + private IgniteCache<UUID, TensorFlowCluster> cache; + + /** TensorFlow chief runners. */ + private ConcurrentMap<UUID, TensorFlowChiefRunner> chiefRunners; + + /** TensorFlow user script runners. */ + private ConcurrentMap<UUID, TensorFlowUserScriptRunner> userScriptRunners; /** * Constructs a new instance of TensorFlow cluster manager. * - * @param igniteSupplier Ignite instance supplier. - * @param <T> Type of serializable supplier. + * @param ignite Ignite instance. 
*/ - public <T extends Supplier<Ignite> & Serializable> TensorFlowClusterManager(T igniteSupplier) { - this( - igniteSupplier, - new TensorFlowServerManager(igniteSupplier), - new TensorFlowClusterResolver(igniteSupplier) - ); + public TensorFlowClusterManager(Ignite ignite) { + assert ignite != null : "Ignite instance should not be null"; + + this.ignite = ignite; + this.srvProcMgr = new TensorFlowServerManager(ignite); + this.clusterRslvr = new TensorFlowClusterResolver(ignite, "TF", 10000, 1000); + this.cache = getOrCreateCache(); + this.chiefRunners = new ConcurrentHashMap<>(); + this.userScriptRunners = new ConcurrentHashMap<>(); } /** - * Constructs a new instance of TensorFlow cluster manager. + * Returns cluster by identifier. * - * @param igniteSupplier Ignite instance supplier. - * @param srvProcMgr TensorFlow server manager. - * @param clusterRslvr TensorFlow cluster resolver. + * @param clusterId Cluster identifier. + * @return TensorFlow cluster. */ - public <T extends Supplier<Ignite> & Serializable> TensorFlowClusterManager(T igniteSupplier, - TensorFlowServerManager srvProcMgr, TensorFlowClusterResolver clusterRslvr) { - assert igniteSupplier != null : "Ignite supplier should not be null"; - assert srvProcMgr != null : "TensorFlow server manager should not be null"; - assert clusterRslvr != null : "TensorFlow cluster resolver should not be null"; - - this.igniteSupplier = igniteSupplier; - this.srvProcMgr = srvProcMgr; - this.clusterRslvr = clusterRslvr; - } - - /** Initializes TensorFlow cluster manager and gets or creates correspondent caches. 
*/ - public void init() { - clusterRslvr.init(); - - CacheConfiguration<String, TensorFlowCluster> cacheConfiguration = new CacheConfiguration<>(); - cacheConfiguration.setName(TF_CLUSTER_METADATA_CACHE_NAME); - cacheConfiguration.setCacheMode(CacheMode.REPLICATED); - cacheConfiguration.setAtomicityMode(CacheAtomicityMode.TRANSACTIONAL); - - Ignite ignite = igniteSupplier.get(); - cache = ignite.getOrCreateCache(cacheConfiguration); + public TensorFlowCluster getCluster(UUID clusterId) { + return cache.get(clusterId); } /** - * Creates and starts a new TensorFlow cluster for the specified cache if it doesn't exist, otherwise returns - * existing one. + * Creates and starts a new TensorFlow cluster for the specified cache. * - * @param upstreamCacheName Upstream cache name. + * @param clusterId Cluster identifier. + * @param jobArchive Job archive. * @return TensorFlow cluster metadata. */ - public TensorFlowCluster getOrCreateCluster(String upstreamCacheName) { - checkThatInitialized(); - - Lock clusterMgrCacheLock = cache.lock(upstreamCacheName); + public TensorFlowCluster createCluster(UUID clusterId, TensorFlowJobArchive jobArchive, + Consumer<String> userScriptOut, Consumer<String> userScriptErr) { + Lock clusterMgrCacheLock = cache.lock(clusterId); clusterMgrCacheLock.lock(); try { - TensorFlowCluster cluster = cache.get(upstreamCacheName); + TensorFlowCluster cluster = cache.get(clusterId); - if (cluster == null) { - TensorFlowClusterSpec clusterSpec = clusterRslvr.resolveAndAcquirePorts(upstreamCacheName); - cluster = startCluster(clusterSpec); - cache.put(upstreamCacheName, cluster); - } + if (cluster != null) + throw new IllegalStateException("Cluster is already created [clusterId=" + clusterId + "]"); + + TensorFlowClusterSpec clusterSpec = clusterRslvr.resolveAndAcquirePorts(jobArchive.getUpstreamCacheName()); + cluster = startCluster(clusterId, clusterSpec, jobArchive, userScriptOut, userScriptErr); + cache.put(clusterId, cluster); return cluster; } @@ 
-134,21 +127,21 @@ public class TensorFlowClusterManager implements Serializable { /** * Stops TensorFlow cluster. * - * @param upstreamCacheName Upstream cache name. + * @param clusterId TensorFlow cluster identifier. */ - public void stopClusterIfExists(String upstreamCacheName) { - checkThatInitialized(); - - Lock clusterMgrCacheLock = cache.lock(upstreamCacheName); + public void stopClusterIfExists(UUID clusterId) { + Lock clusterMgrCacheLock = cache.lock(clusterId); clusterMgrCacheLock.lock(); try { - TensorFlowCluster cluster = cache.get(upstreamCacheName); + TensorFlowCluster cluster = cache.get(clusterId); if (cluster != null) { + stopChief(clusterId); + stopUserScript(clusterId); srvProcMgr.stop(cluster.getProcesses(), true); - clusterRslvr.freePorts(cluster.getSpec()); - cache.remove(upstreamCacheName); + clusterRslvr.releasePorts(cluster.getSpec()); + cache.remove(clusterId); } } finally { @@ -159,44 +152,146 @@ public class TensorFlowClusterManager implements Serializable { /** Destroys TensorFlow cluster manager and related caches. */ public void destroy() { clusterRslvr.destroy(); - - Ignite ignite = igniteSupplier.get(); ignite.destroyCache(TF_CLUSTER_METADATA_CACHE_NAME); } /** * Starts TensorFlow cluster using the specified specification and returns metadata of the started cluster. * + * @param clusterId Cluster identifier. * @param spec TensorFlow cluster specification. * @return TensorFlow cluster metadata. 
*/ - private TensorFlowCluster startCluster(TensorFlowClusterSpec spec) { - checkThatInitialized(); + private TensorFlowCluster startCluster(UUID clusterId, TensorFlowClusterSpec spec, TensorFlowJobArchive jobArchive, + Consumer<String> userScriptOut, Consumer<String> userScriptErr) { + Map<String, List<TensorFlowServerAddressSpec>> jobs = spec.getJobs(); - List<TensorFlowServer> srvs = new ArrayList<>(); + Map<UUID, List<UUID>> workerProcesses = startWorkers(spec, jobs.get(TensorFlowClusterResolver.WORKER_JOB_NAME)); - Map<String, List<TensorFlowServerAddressSpec>> jobs = spec.getJobs(); + startChief(clusterId, spec); + startUserScript(clusterId, jobArchive, spec, userScriptOut, userScriptErr); - for (String jobName : jobs.keySet()) { - List<TensorFlowServerAddressSpec> tasks = jobs.get(jobName); + return new TensorFlowCluster(spec, workerProcesses); + } + + /** + * Starts TensorFlow worker processes using the specified specification and returns identifiers of the started + * processes. + * + * @param spec TensorFlow cluster specification. + * @param tasks Worker tasks. + * @return Identifiers of the started processes. + */ + private Map<UUID, List<UUID>> startWorkers(TensorFlowClusterSpec spec, List<TensorFlowServerAddressSpec> tasks) { + List<TensorFlowServer> srvs = new ArrayList<>(); + if (tasks != null) { for (int i = 0; i < tasks.size(); i++) { - TensorFlowServer srvSpec = new TensorFlowServer(spec, jobName, i); + TensorFlowServer srvSpec = new TensorFlowServer(spec, TensorFlowClusterResolver.WORKER_JOB_NAME, i); srvs.add(srvSpec); } } - Map<UUID, List<UUID>> processes = srvProcMgr.start(srvs); + return srvProcMgr.start(srvs); + } - return new TensorFlowCluster(spec, processes); + /** + * Starts chief process using the specified cluster specification. + * + * @param clusterId Cluster identifier. + * @param spec TensorFlow cluster specification. 
+ */ + private void startChief(UUID clusterId, TensorFlowClusterSpec spec) { + TensorFlowChiefRunner chiefRunner = new TensorFlowChiefRunner( + ignite, + Executors.newSingleThreadExecutor( + new CustomizableThreadFactory("tf-ch", true) + ), + spec, + System.out::println, + System.err::println + ); + + chiefRunner.start(); + + chiefRunners.put(clusterId, chiefRunner); } /** - * Checks that the component has been initialized. + * Stops chief process. + * + * @param clusterId Cluster identifier. */ - private void checkThatInitialized() { - if (cache == null) - throw new IllegalStateException("TensorFlow Cluster Manager is not initialized"); + private void stopChief(UUID clusterId) { + TensorFlowChiefRunner runner = chiefRunners.remove(clusterId); + + if (runner != null) + runner.stop(); + } + + /** + * Starts user script processes using the specified job archive. + * + * @param clusterId Cluster identifier. + * @param jobArchive Job archive. + * @param clusterSpec Cluster specification. + */ + private void startUserScript(UUID clusterId, TensorFlowJobArchive jobArchive, TensorFlowClusterSpec clusterSpec, + Consumer<String> out, Consumer<String> err) { + TensorFlowUserScriptRunner userScriptRunner = new TensorFlowUserScriptRunner( + ignite, + Executors.newSingleThreadExecutor( + new CustomizableThreadFactory("tf-us", true) + ), + jobArchive, + clusterSpec, + out, + err + ); + + userScriptRunner.start(); + + userScriptRunners.put(clusterId, userScriptRunner); + } + + /** + * Stops user script process. + * + * @param clusterId Cluster identifier. + */ + private void stopUserScript(UUID clusterId) { + TensorFlowUserScriptRunner runner = userScriptRunners.remove(clusterId); + + if (runner != null) + runner.stop(); + } + + /** + * Checks if user script completed and returns result. + * + * @param clusterId Cluster identifier. + * @return {@code true} if user script completed, otherwise {@code false}. 
+ */ + public boolean isUserScriptCompleted(UUID clusterId) { + TensorFlowUserScriptRunner runner = userScriptRunners.get(clusterId); + + return runner != null && runner.isCompleted(); + } + + /** + * Returns list of maintained TensorFlow clusters. + * + * @return List of maintained TensorFlow clusters. + */ + public Map<UUID, TensorFlowCluster> getAllClusters() { + Map<UUID, TensorFlowCluster> res = new HashMap<>(); + + ScanQuery<UUID, TensorFlowCluster> qry = new ScanQuery<>(); + QueryCursor<Cache.Entry<UUID, TensorFlowCluster>> cursor = cache.query(qry); + for (Cache.Entry<UUID, TensorFlowCluster> e : cursor) + res.put(e.getKey(), e.getValue()); + + return res; } /** */ @@ -204,8 +299,17 @@ public class TensorFlowClusterManager implements Serializable { return srvProcMgr; } - /** */ - public IgniteCache<String, TensorFlowCluster> getCache() { - return cache; + /** + * Returns existing cluster manager cache or creates a new one. + * + * @return Cluster manager cache. + */ + private IgniteCache<UUID, TensorFlowCluster> getOrCreateCache() { + CacheConfiguration<UUID, TensorFlowCluster> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName(TF_CLUSTER_METADATA_CACHE_NAME); + cacheConfiguration.setCacheMode(CacheMode.REPLICATED); + cacheConfiguration.setAtomicityMode(CacheAtomicityMode.TRANSACTIONAL); + + return ignite.getOrCreateCache(cacheConfiguration); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowJobArchive.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowJobArchive.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowJobArchive.java new file mode 100644 index 0000000..953cf76 --- /dev/null +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/TensorFlowJobArchive.java 
@@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.tensorflow.cluster; + +import java.io.Serializable; + +/** + * TensorFlow job archive that keeps archived working directory and command to be executed. + */ +public class TensorFlowJobArchive implements Serializable { + /** */ + private static final long serialVersionUID = -5977231383594482459L; + + /** Upstream cache name. */ + private final String upstreamCacheName; + + /** Archived working directory. */ + private final byte[] data; + + /** Command to be executed with arguments. */ + private final String[] commands; + + /** + * Constructs a new instance of TensorFlow job archive. + * + * @param upstreamCacheName Upstream cache name. + * @param data Archived working directory. + * @param commands Command to be executed with arguments. 
+ */ + public TensorFlowJobArchive(String upstreamCacheName, byte[] data, String[] commands) { + this.upstreamCacheName = upstreamCacheName; + this.data = data; + this.commands = commands; + } + + /** */ + public String getUpstreamCacheName() { + return upstreamCacheName; + } + + /** */ + public byte[] getData() { + return data; + } + + /** */ + public String[] getCommands() { + return commands; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/spec/TensorFlowClusterSpec.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/spec/TensorFlowClusterSpec.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/spec/TensorFlowClusterSpec.java index a053b8e..ea813ea 100644 --- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/spec/TensorFlowClusterSpec.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/spec/TensorFlowClusterSpec.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; +import org.apache.ignite.Ignite; /** * TensorFlow cluster specification. @@ -52,6 +53,44 @@ public class TensorFlowClusterSpec implements Serializable { return this; } + /** + * Formats cluster specification so that TensorFlow accepts it. + * + * @param ignite Ignite instance. + * @return Formatted cluster specification. 
+ */ + public String format(Ignite ignite) { + StringBuilder builder = new StringBuilder(); + + builder.append("{\n"); + + for (Map.Entry<String, List<TensorFlowServerAddressSpec>> entry : jobs.entrySet()) { + builder + .append("\t\"") + .append(entry.getKey()) + .append("\" : [ "); + + for (TensorFlowServerAddressSpec address : entry.getValue()) { + builder + .append("\n\t\t\"") + .append(address.format(ignite)) + .append("\", "); + } + + if (!entry.getValue().isEmpty()) + builder.delete(builder.length() - 2, builder.length()); + + builder.append("\n\t],\n"); + } + + if (!jobs.isEmpty()) + builder.delete(builder.length() - 2, builder.length() - 1); + + builder.append('}'); + + return builder.toString(); + } + /** */ public Map<String, List<TensorFlowServerAddressSpec>> getJobs() { return jobs; http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/spec/TensorFlowServerAddressSpec.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/spec/TensorFlowServerAddressSpec.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/spec/TensorFlowServerAddressSpec.java index 196b166..727a030 100644 --- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/spec/TensorFlowServerAddressSpec.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/spec/TensorFlowServerAddressSpec.java @@ -18,7 +18,9 @@ package org.apache.ignite.tensorflow.cluster.spec; import java.io.Serializable; +import java.util.Collection; import java.util.UUID; +import org.apache.ignite.Ignite; /** * TensorFlow server address specification. @@ -47,6 +49,18 @@ public class TensorFlowServerAddressSpec implements Serializable { this.port = port; } + /** + * Formats Server Address specification so that TensorFlow accepts it. + * + * @param ignite Ignite instance. 
+ * @return Formatted server address specification. + */ + public String format(Ignite ignite) { + Collection<String> names = ignite.cluster().forNodeId(nodeId).hostNames(); + + return names.iterator().next() + ":" + port; + } + /** */ public UUID getNodeId() { return nodeId; http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerManager.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerManager.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerManager.java index 192f619..ee87089 100644 --- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerManager.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerManager.java @@ -17,36 +17,32 @@ package org.apache.ignite.tensorflow.cluster.tfrunning; -import java.io.Serializable; -import java.util.function.Supplier; -import org.apache.ignite.tensorflow.core.ProcessManager; -import org.apache.ignite.tensorflow.core.ProcessManagerWrapper; -import org.apache.ignite.tensorflow.core.pythonrunning.PythonProcess; -import org.apache.ignite.tensorflow.core.pythonrunning.PythonProcessManager; -import org.apache.ignite.tensorflow.cluster.spec.TensorFlowClusterSpec; -import org.apache.ignite.tensorflow.cluster.spec.TensorFlowServerAddressSpec; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.UUID; import org.apache.ignite.Ignite; import org.apache.ignite.Ignition; +import org.apache.ignite.tensorflow.cluster.spec.TensorFlowClusterSpec; +import org.apache.ignite.tensorflow.cluster.spec.TensorFlowServerAddressSpec; +import org.apache.ignite.tensorflow.core.ProcessManager; +import 
org.apache.ignite.tensorflow.core.ProcessManagerWrapper; +import org.apache.ignite.tensorflow.core.pythonrunning.PythonProcess; +import org.apache.ignite.tensorflow.core.pythonrunning.PythonProcessManager; /** * TensorFlow server manager that allows to start, stop and make other actions with TensorFlow servers. */ public class TensorFlowServerManager extends ProcessManagerWrapper<PythonProcess, TensorFlowServer> { - /** */ - private static final long serialVersionUID = 8355019934723445973L; + /** TensorFlow server script formatter. */ + private static final TensorFlowServerScriptFormatter scriptFormatter = new TensorFlowServerScriptFormatter(); /** * Constructs a new instance of TensorFlow server manager. * - * @param igniteSupplier Ignite instance supplier. - * @param <T> Type of serializable supplier. + * @param ignite Ignite instance. */ - public <T extends Supplier<Ignite> & Serializable> TensorFlowServerManager(T igniteSupplier) { - this(new PythonProcessManager(igniteSupplier)); + public TensorFlowServerManager(Ignite ignite) { + this(new PythonProcessManager(ignite)); } /** @@ -61,7 +57,7 @@ public class TensorFlowServerManager extends ProcessManagerWrapper<PythonProcess /** {@inheritDoc} */ @Override protected PythonProcess transformSpecification(TensorFlowServer spec) { return new PythonProcess( - formatPythonScript(spec), + scriptFormatter.format(spec, true, Ignition.ignite()), getNode(spec) ); } @@ -80,88 +76,4 @@ public class TensorFlowServerManager extends ProcessManagerWrapper<PythonProcess return addr.getNodeId(); } - - /** - * Formats TensorFlow server specification so that it's available to be passed into а python script. - * - * @param spec TensorFlow server specification. - * @return Formatted TensorFlow server specification. 
- */ - private String formatPythonScript(TensorFlowServer spec) { - StringBuilder builder = new StringBuilder(); - - builder.append("import tensorflow as tf").append('\n'); - builder.append("cluster = tf.train.ClusterSpec(") - .append(formatClusterSpec(spec.getClusterSpec())) - .append(')') - .append('\n'); - builder.append("server = tf.train.Server(cluster"); - - if (spec.getJobName() != null) - builder.append(", job_name=\"").append(spec.getJobName()).append('"'); - - if (spec.getTaskIdx() != null) - builder.append(", task_index=").append(spec.getTaskIdx()); - - if (spec.getProto() != null) - builder.append(", protocol=\"").append(spec.getProto()).append('"'); - - builder.append(')').append('\n'); - builder.append("server.join()").append('\n'); - - return builder.toString(); - } - - /** - * Formats TensorFlow cluster specification so that it's available to be passed into а python script. - * - * @param spec TensorFlow cluster specification. - * @return Formatted TensorFlow cluster specification. - */ - public String formatClusterSpec(TensorFlowClusterSpec spec) { - StringBuilder builder = new StringBuilder(); - - builder.append("{\n"); - - for (Map.Entry<String, List<TensorFlowServerAddressSpec>> entry : spec.getJobs().entrySet()) { - builder - .append("\t\"") - .append(entry.getKey()) - .append("\" : [ "); - - for (TensorFlowServerAddressSpec address : entry.getValue()) { - builder - .append("\n\t\t\"") - .append(formatAddressSpec(address)) - .append("\", "); - } - - if (!entry.getValue().isEmpty()) - builder.delete(builder.length() - 2, builder.length()); - - builder.append("\n\t],\n"); - } - - if (!spec.getJobs().isEmpty()) - builder.delete(builder.length() - 2, builder.length() - 1); - - builder.append('}'); - - return builder.toString(); - } - - /** - * Formats TensorFlow server address specification so that it's available to be passed into а python script. - * - * @param spec TensorFlow server address specification. 
- * @return Formatted TensorFlow server address specification. - */ - private String formatAddressSpec(TensorFlowServerAddressSpec spec) { - UUID nodeId = spec.getNodeId(); - - Ignite ignite = Ignition.localIgnite(); - Collection<String> names = ignite.cluster().forNodeId(nodeId).hostNames(); - - return names.iterator().next() + ":" + spec.getPort(); - } } http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java new file mode 100644 index 0000000..0645964 --- /dev/null +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/tfrunning/TensorFlowServerScriptFormatter.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.tensorflow.cluster.tfrunning; + +import org.apache.ignite.Ignite; + +/** + * Utils class that helps to format Python script that starts TensorFlow server. + */ +public class TensorFlowServerScriptFormatter { + /** + * Formats TensorFlow server specification so that it's available to be passed into а python script. + * + * @param srv Server specification. + * @param join Joins server by default or not. + * @param ignite Ignite instance. + * @return Formatted TensorFlow server script. + */ + public String format(TensorFlowServer srv, boolean join, Ignite ignite) { + StringBuilder builder = new StringBuilder(); + + builder.append("import tensorflow as tf").append('\n'); + builder.append("from tensorflow.contrib.ignite import IgniteDataset").append("\n"); + builder.append("cluster = tf.train.ClusterSpec(") + .append(srv.getClusterSpec().format(ignite)) + .append(')') + .append('\n'); + + builder.append("server = tf.train.Server(cluster"); + + if (srv.getJobName() != null) + builder.append(", job_name=\"").append(srv.getJobName()).append('"'); + + if (srv.getTaskIdx() != null) + builder.append(", task_index=").append(srv.getTaskIdx()); + + if (srv.getProto() != null) + builder.append(", protocol=\"").append(srv.getProto()).append('"'); + + builder.append(')').append('\n'); + + if (join) + builder.append("server.join()").append('\n'); + + return builder.toString(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/ClusterPortManager.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/ClusterPortManager.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/ClusterPortManager.java index 78087ab..462752c 100644 --- 
a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/ClusterPortManager.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/ClusterPortManager.java @@ -18,23 +18,33 @@ package org.apache.ignite.tensorflow.cluster.util; import java.io.Serializable; +import java.net.NetworkInterface; +import java.util.ArrayList; +import java.util.Arrays; import java.util.BitSet; +import java.util.Enumeration; +import java.util.List; import java.util.UUID; import java.util.concurrent.locks.Lock; -import java.util.function.Supplier; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; +import org.apache.ignite.IgniteLogger; import org.apache.ignite.cache.CacheAtomicityMode; import org.apache.ignite.cache.CacheMode; +import org.apache.ignite.cluster.ClusterGroup; +import org.apache.ignite.cluster.ClusterGroupEmptyException; import org.apache.ignite.configuration.CacheConfiguration; /** - * Cluster port manager that allows to reliably {@link #acquirePort(UUID)} and {@link #freePort(UUID, int)} on the + * Cluster port manager that allows to reliably {@link #acquirePort(UUID)} and {@link #releasePort(UUID, int)} on the * cluster nodes. */ -public class ClusterPortManager implements Serializable { - /** */ - private static final long serialVersionUID = -5116593574559007292L; +public class ClusterPortManager { + /** Ignite instance. */ + private final Ignite ignite; + + /** Ignite logger. */ + private final IgniteLogger log; /** Port manager cache name. */ private final String portMgrCacheName; @@ -45,11 +55,8 @@ public class ClusterPortManager implements Serializable { /** Port range size. */ private final int cnt; - /** Ignite instance supplier. */ - private final Supplier<Ignite> igniteSupplier; - /** Port manager cache */ - private transient IgniteCache<UUID, BitSet> cache; + private final IgniteCache<HostIdentifier, BitSet> cache; /** * Constructs a new instance of cluster port manager. 
@@ -58,28 +65,19 @@ public class ClusterPortManager implements Serializable { * @param from Port range from point. * @param cnt Port range size. */ - public <T extends Supplier<Ignite> & Serializable> ClusterPortManager(String poolName, int from, int cnt, - T igniteSupplier) { + public ClusterPortManager(Ignite ignite, String poolName, int from, int cnt) { + assert ignite != null : "Ignite instance should not be null"; assert poolName != null : "Pool name should not be null"; assert cnt >= 0 : "Count should not be negative"; assert from >= 0 && cnt + from <= 0xFFFF : "Port range should be between 0 and 65535"; - assert igniteSupplier != null : "Ignite supplier should not be null"; - this.portMgrCacheName = String.format("PORT_MANAGER_CACHE_%s", poolName); + this.ignite = ignite; + this.log = ignite.log().getLogger(ClusterPortManager.class); + + this.portMgrCacheName = String.format("PORT_MANAGER_%s_CACHE", poolName); this.from = from; this.cnt = cnt; - this.igniteSupplier = igniteSupplier; - } - - /** Initializes port manager and creates or gets correspondent caches. */ - public void init() { - CacheConfiguration<UUID, BitSet> cacheConfiguration = new CacheConfiguration<>(); - cacheConfiguration.setName(portMgrCacheName); - cacheConfiguration.setCacheMode(CacheMode.REPLICATED); - cacheConfiguration.setAtomicityMode(CacheAtomicityMode.TRANSACTIONAL); - - Ignite ignite = igniteSupplier.get(); - cache = ignite.getOrCreateCache(cacheConfiguration); + this.cache = getOrCreateCache(); } /** @@ -89,13 +87,16 @@ public class ClusterPortManager implements Serializable { * @return Port to be acquired. 
*/ public int acquirePort(UUID nodeId) { - checkThatInitialized(); + HostIdentifier hostId = getHostIdentifier(nodeId); - Lock lock = cache.lock(nodeId); + if (hostId == null) + throw new IllegalStateException("Can't find node [nodeId=" + nodeId + "]"); + + Lock lock = cache.lock(hostId); lock.lock(); try { - BitSet ports = cache.get(nodeId); + BitSet ports = cache.get(hostId); if (ports == null) ports = new BitSet(cnt); @@ -106,8 +107,9 @@ public class ClusterPortManager implements Serializable { throw new IllegalStateException("No free ports in range [from=" + from + ", cnt=" + cnt + "]"); ports.set(free); + log.debug("Port acquired [nodeId=" + nodeId + ", port=" + (from + free) + "]"); - cache.put(nodeId, ports); + cache.put(hostId, ports); return from + free; } @@ -117,27 +119,31 @@ public class ClusterPortManager implements Serializable { } /** - * Frees acquired port on the specified node. + * Releases acquired port on the specified node. * * @param nodeId Node identifier. * @param port Acquired port to be free. */ - public void freePort(UUID nodeId, int port) { + public void releasePort(UUID nodeId, int port) { assert port - from >= 0 && port - from < cnt : "Port not in the range"; - checkThatInitialized(); + HostIdentifier hostId = getHostIdentifier(nodeId); + + if (hostId == null) + return; - Lock lock = cache.lock(nodeId); + Lock lock = cache.lock(hostId); lock.lock(); try { - BitSet ports = cache.get(nodeId); + BitSet ports = cache.get(hostId); if (ports != null) { ports.clear(port - from); + log.debug("Port released [nodeId=" + nodeId + ", port=" + port + "]"); if (ports.isEmpty()) - cache.remove(nodeId); + cache.remove(hostId); } } finally { @@ -147,15 +153,97 @@ public class ClusterPortManager implements Serializable { /** Destroys port manager and related caches. */ public void destroy() { - Ignite ignite = igniteSupplier.get(); ignite.destroyCache(portMgrCacheName); } /** - * Checks that the component has been initialized. 
+ * Returns existed port pool cache or creates a new one. + * + * @return Port pool cache. + */ + private IgniteCache<HostIdentifier, BitSet> getOrCreateCache() { + CacheConfiguration<HostIdentifier, BitSet> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName(portMgrCacheName); + cacheConfiguration.setCacheMode(CacheMode.REPLICATED); + cacheConfiguration.setAtomicityMode(CacheAtomicityMode.TRANSACTIONAL); + + return ignite.getOrCreateCache(cacheConfiguration); + } + + /** + * Returns host identifier by node identifier. + * + * @param nodeId Node identifier. + * @return Host identifier. */ - private void checkThatInitialized() { - if (cache == null) - throw new IllegalStateException("Cluster Port Manager is not initialized"); + private HostIdentifier getHostIdentifier(UUID nodeId) { + try { + ClusterGroup grp = ignite.cluster().forNodeId(nodeId); + + return ignite.compute(grp).call(() -> { + Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces(); + + List<byte[]> macAddrs = new ArrayList<>(); + + while (interfaces.hasMoreElements()) { + NetworkInterface netItf = interfaces.nextElement(); + byte[] macAddr = netItf.getHardwareAddress(); + macAddrs.add(macAddr); + } + + return new HostIdentifier(macAddrs.toArray(new byte[macAddrs.size()][])); + }); + } + catch (ClusterGroupEmptyException e) { + return null; + } + } + + /** + * Host identifier based on arrays of mac addresses of the host machine. + */ + private static class HostIdentifier implements Serializable { + /** */ + private static final long serialVersionUID = -7060231325908935162L; + + /** Mac addresses. */ + private final byte[][] macAddrs; + + /** + * Constructs a new instance of host identifier. + * + * @param macAddrs Mac addresses. 
+ */ + public HostIdentifier(byte[][] macAddrs) { + this.macAddrs = macAddrs; + } + + /** */ + public byte[][] getMacAddrs() { + return macAddrs; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + HostIdentifier that = (HostIdentifier)o; + if (macAddrs.length != that.macAddrs.length) + return false; + + for (int i = 0; i < macAddrs.length; i++) + if (!Arrays.equals(macAddrs[i], that.macAddrs[i])) + return false; + + return true; + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + return Arrays.hashCode(macAddrs); + } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowChiefRunner.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowChiefRunner.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowChiefRunner.java new file mode 100644 index 0000000..6998681 --- /dev/null +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowChiefRunner.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.tensorflow.cluster.util; + +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; +import org.apache.ignite.Ignite; +import org.apache.ignite.tensorflow.cluster.spec.TensorFlowClusterSpec; +import org.apache.ignite.tensorflow.cluster.tfrunning.TensorFlowServer; +import org.apache.ignite.tensorflow.cluster.tfrunning.TensorFlowServerScriptFormatter; +import org.apache.ignite.tensorflow.core.pythonrunning.PythonProcessBuilderSupplier; +import org.apache.ignite.tensorflow.core.util.AsyncNativeProcessRunner; +import org.apache.ignite.tensorflow.core.util.NativeProcessRunner; + +/** + * Utils class that helps to start and stop chief process. + */ +public class TensorFlowChiefRunner extends AsyncNativeProcessRunner { + /** Ignite instance. */ + private final Ignite ignite; + + /** TensorFlow cluster specification. */ + private final TensorFlowClusterSpec spec; + + /** Output stream data consumer. */ + private final Consumer<String> out; + + /** Error stream data consumer. */ + private final Consumer<String> err; + + /** + * Constructs a new instance of TensorFlow chief runner. + * + * @param ignite Ignite instance. + * @param executor Executor to be used in {@link AsyncNativeProcessRunner}. + * @param spec TensorFlow cluster specification. + * @param out Output stream data consumer. + * @param err Error stream data consumer. 
+ */ + public TensorFlowChiefRunner(Ignite ignite, ExecutorService executor, TensorFlowClusterSpec spec, + Consumer<String> out, Consumer<String> err) { + super(ignite, executor); + this.ignite = ignite; + this.spec = spec; + this.out = out; + this.err = err; + } + + /** {@inheritDoc} */ + @Override public NativeProcessRunner doBefore() { + TensorFlowServer srv = new TensorFlowServer(spec, TensorFlowClusterResolver.CHIEF_JOB_NAME, 0); + + return new NativeProcessRunner( + new PythonProcessBuilderSupplier(true).get(), + new TensorFlowServerScriptFormatter().format(srv, true, ignite), + out, + err + ); + } + + /** {@inheritDoc} */ + @Override public void doAfter() { + // Do nothing. + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9e884e5a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowClusterResolver.java ---------------------------------------------------------------------- diff --git a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowClusterResolver.java b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowClusterResolver.java index e1d7d57..846af71 100644 --- a/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowClusterResolver.java +++ b/modules/tensorflow/src/main/java/org/apache/ignite/tensorflow/cluster/util/TensorFlowClusterResolver.java @@ -17,9 +17,7 @@ package org.apache.ignite.tensorflow.cluster.util; -import java.io.Serializable; import java.util.UUID; -import java.util.function.Supplier; import org.apache.ignite.Ignite; import org.apache.ignite.cache.affinity.Affinity; import org.apache.ignite.cluster.ClusterNode; @@ -29,32 +27,32 @@ import org.apache.ignite.tensorflow.cluster.spec.TensorFlowServerAddressSpec; /** * TensorFlow cluster resolver based on Ignite Cache affinity. 
*/ -public class TensorFlowClusterResolver implements Serializable { - /** */ - private static final long serialVersionUID = 631456775167710173L; +public class TensorFlowClusterResolver { + /** TensorFlow worker job name. */ + public static final String WORKER_JOB_NAME = "worker"; + + /** TensorFlow chief job name. */ + public static final String CHIEF_JOB_NAME = "chief"; + + /** Ignite instance. */ + private final Ignite ignite; /** Cluster port manager. */ private final ClusterPortManager portMgr; - /** Ignite instance supplier. */ - private final Supplier<Ignite> igniteSupplier; - /** * Constructs a new instance of TensorFlow cluster resolver. * - * @param igniteSupplier Ignite instance supplier. - * @param <T> Type of serializable supplier. + * @param ignite Ignite instance. */ - public <T extends Supplier<Ignite> & Serializable> TensorFlowClusterResolver(T igniteSupplier) { - assert igniteSupplier != null : "Ignite supplier should not be null"; - - this.igniteSupplier = igniteSupplier; - this.portMgr = new ClusterPortManager("TF_POOL", 10000, 100, igniteSupplier); - } - - /** Initializes TensorFlow cluster resolver. */ - public void init() { - portMgr.init(); + public TensorFlowClusterResolver(Ignite ignite, String portPoolName, int portFrom, int portCnt) { + assert ignite != null : "Ignite instance should not be null"; + assert portPoolName != null : "Port pool name should not be null"; + assert portFrom >= 0 : "Port count should not be negative"; + assert portCnt >= 0 && portCnt + portFrom <= 0xFFFF : "Port range should be between 0 and 65535"; + + this.ignite = ignite; + this.portMgr = new ClusterPortManager(ignite, portPoolName, portFrom, portCnt); } /** @@ -64,38 +62,60 @@ public class TensorFlowClusterResolver implements Serializable { * @return TensorFlow cluster specification. 
*/ public TensorFlowClusterSpec resolveAndAcquirePorts(String upstreamCacheName) { - Ignite ignite = igniteSupplier.get(); - Affinity<?> affinity = ignite.affinity(upstreamCacheName); - - int parts = affinity.partitions(); - TensorFlowClusterSpec spec = new TensorFlowClusterSpec(); - for (int part = 0; part < parts; part++) { - ClusterNode node = affinity.mapPartitionToNode(part); - UUID nodeId = node.id(); - - int port = portMgr.acquirePort(nodeId); - - spec.addTask("WORKER", nodeId, port); - } + resolveAndAcquirePortsForWorkers(spec, upstreamCacheName); + resolveAndAcquirePortsForChief(spec); return spec; } /** - * Frees ports acquired for the given cluster specification. + * Releases ports acquired for the given cluster specification. * * @param spec TensorFlow cluster specification. */ - public void freePorts(TensorFlowClusterSpec spec) { + public void releasePorts(TensorFlowClusterSpec spec) { for (String jobName : spec.getJobs().keySet()) for (TensorFlowServerAddressSpec address : spec.getJobs().get(jobName)) - portMgr.freePort(address.getNodeId(), address.getPort()); + portMgr.releasePort(address.getNodeId(), address.getPort()); } /** Destroys TensorFlow cluster resolver. */ public void destroy() { portMgr.destroy(); } + + /** + * Resolves TensorFlow cluster worker jobs and acquires ports. + * + * @param spec TensorFlow cluster specification. + * @param upstreamCacheName Upstream cache name. + */ + private void resolveAndAcquirePortsForWorkers(TensorFlowClusterSpec spec, String upstreamCacheName) { + Affinity<?> affinity = ignite.affinity(upstreamCacheName); + int parts = affinity.partitions(); + + for (int part = 0; part < parts; part++) { + ClusterNode node = affinity.mapPartitionToNode(part); + UUID nodeId = node.id(); + + int port = portMgr.acquirePort(nodeId); + + spec.addTask(WORKER_JOB_NAME, nodeId, port); + } + } + + /** + * Resolves TensorFlow cluster chief job and acquires ports. + * + * @param spec TensorFlow cluster specification. 
+ */ + private void resolveAndAcquirePortsForChief(TensorFlowClusterSpec spec) { + ClusterNode chiefNode = ignite.cluster().localNode(); + UUID chiefNodeId = chiefNode.id(); + int chiefPort = portMgr.acquirePort(chiefNodeId); + + spec.addTask(CHIEF_JOB_NAME, chiefNodeId, chiefPort); + } }
