This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ed171e  [scala-package][spark] Resources running PS (role = server) should be explicit to Spark (#7571)
9ed171e is described below

commit 9ed171e35bf994b250e4a3ba9db5d792ee71c418
Author: Nan Zhu <coding...@users.noreply.github.com>
AuthorDate: Tue Aug 29 20:29:52 2017 -0700

    [scala-package][spark] Resources running PS (role = server) should be explicit to Spark (#7571)
    
    * temp
    
    * resources running PS (role = server) should be explicit to Spark
    
    * address the comments
---
 .../main/scala/ml/dmlc/mxnet/KVStoreServer.scala   |  4 -
 .../src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala | 91 ++++++++++++++--------
 .../ml/dmlc/mxnet/spark/ParameterServer.scala      | 69 ++++++++--------
 3 files changed, 91 insertions(+), 73 deletions(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStoreServer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStoreServer.scala
index 22f9269..d3c8691 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStoreServer.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStoreServer.scala
@@ -20,10 +20,6 @@ package ml.dmlc.mxnet
 import ml.dmlc.mxnet.Base._
 import org.slf4j.{Logger, LoggerFactory}
 
-/**
- * Server node for the key value store
- * @author Yizhi Liu
- */
 private[mxnet] class KVStoreServer(private val kvStore: KVStore) {
   private val logger: Logger = LoggerFactory.getLogger(classOf[KVStoreServer])
   private val handle: KVStoreHandle = kvStore.handle
diff --git a/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala b/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala
index 27dd99f..cc77342 100644
--- a/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala
+++ b/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala
@@ -27,14 +27,24 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 
-/**
- * MXNet Training On Spark
- * @author Yizhi Liu
- */
 class MXNet extends Serializable {
+
+  class MXNetControllingThread(
+      schedulerIP: String,
+      schedulerPort: Int,
+      sparkContext: SparkContext,
+      triggerOfComponent: (String, Int, SparkContext) => Unit) extends Thread {
+    override def run() {
+      triggerOfComponent(schedulerIP, schedulerPort, sparkContext)
+    }
+  }
+
   private val logger: Logger = LoggerFactory.getLogger(classOf[MXNet])
   private val params: MXNetParams = new MXNetParams
 
+  @transient private var psServerThread: MXNetControllingThread = _
+  @transient private var psSchedulerThread: MXNetControllingThread = _
+
   def setBatchSize(batchSize: Int): this.type = {
     params.batchSize = batchSize
     this
@@ -105,30 +115,51 @@ class MXNet extends Serializable {
     this
   }
 
-  private def startParameterServers(
+  private def startPSServers(
       schedulerIP: String,
       schedulerPort: Int,
-      sc: SparkContext): ParameterServer = {
-    // TODO: check ip & port available
-    logger.info("Starting scheduler on {}:{}", schedulerIP, schedulerPort)
-    val scheduler = new ParameterServer(params.runtimeClasspath, role = 
"scheduler",
-      rootUri = schedulerIP, rootPort = schedulerPort,
-      numServer = params.numServer, numWorker = params.numWorker,
-      timeout = params.timeout, java = params.javabin)
-    require(scheduler.startProcess(), "Failed to start ps scheduler process")
-
-    sc.parallelize(1 to params.numServer, params.numServer).foreachPartition { 
p =>
-      logger.info("Starting server ...")
-      val server = new ParameterServer(params.runtimeClasspath,
-        role = "server",
+      sc: SparkContext) = {
+    def startPSServersInner(
+        schedulerIP: String,
+        schedulerPort: Int,
+        sc: SparkContext): Unit = {
+      sc.parallelize(1 to params.numServer, params.numServer).foreachPartition 
{ p =>
+          logger.info("Starting server ...")
+          val server = new ParameterServer(params.runtimeClasspath,
+            role = "server",
+            rootUri = schedulerIP, rootPort = schedulerPort,
+            numServer = params.numServer,
+            numWorker = params.numWorker,
+            timeout = params.timeout,
+            java = params.javabin)
+          val exitCode = server.startProcess()
+          require(exitCode == 0, s"ps server process quit with exit code 
$exitCode")
+        }
+    }
+    psServerThread = new MXNetControllingThread(schedulerIP, schedulerPort, 
sc, startPSServersInner)
+    psServerThread.start()
+  }
+
+  private def startPSScheduler(
+      schedulerIP: String,
+      schedulerPort: Int,
+      sc: SparkContext) = {
+    def startPSSchedulerInner(
+        schedulerIP: String,
+        schedulerPort: Int,
+        sc: SparkContext): Unit = {
+      // TODO: check ip & port available
+      logger.info("Starting scheduler on {}:{}", schedulerIP, schedulerPort)
+      val scheduler = new ParameterServer(params.runtimeClasspath, role = 
"scheduler",
         rootUri = schedulerIP, rootPort = schedulerPort,
-        numServer = params.numServer,
-        numWorker = params.numWorker,
-        timeout = params.timeout,
-        java = params.javabin)
-      require(server.startProcess(), "Failed to start ps server process")
+        numServer = params.numServer, numWorker = params.numWorker,
+        timeout = params.timeout, java = params.javabin)
+      val exitCode = scheduler.startProcess()
+      require(exitCode == 0, s"Failed to start ps scheduler process with exit 
code $exitCode")
     }
-    scheduler
+    psSchedulerThread = new MXNetControllingThread(schedulerIP, schedulerPort, 
sc,
+      startPSSchedulerInner)
+    psSchedulerThread.start()
   }
 
   private def setFeedForwardModel(
@@ -212,23 +243,21 @@ class MXNet extends Serializable {
     // distribute native jars
     params.jars.foreach(jar => sc.addFile(jar))
     val trainData = {
-      if (params.numWorker > data.partitions.length) {
+      if (params.numWorker != data.partitions.length) {
         logger.info("repartitioning training set to {} partitions", 
params.numWorker)
         data.repartition(params.numWorker)
-      } else if (params.numWorker < data.partitions.length) {
-        logger.info("repartitioning training set to {} partitions", 
params.numWorker)
-        data.coalesce(params.numWorker)
       } else {
         data
       }
     }
     val schedulerIP = utils.Network.ipAddress
     val schedulerPort = utils.Network.availablePort
-    val scheduler = startParameterServers(schedulerIP, schedulerPort, sc)
-    // simply the first model
+    startPSScheduler(schedulerIP, schedulerPort, sc)
+    startPSServers(schedulerIP, schedulerPort, sc)
     val mxModel = trainModel(trainData, schedulerIP, schedulerPort)
     logger.info("Waiting for scheduler ...")
-    scheduler.waitFor()
+    psSchedulerThread.join()
+    psServerThread.join()
     mxModel
   }
 }
diff --git a/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/ParameterServer.scala b/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/ParameterServer.scala
index 60e1c69..7ed4512 100644
--- a/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/ParameterServer.scala
+++ b/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/ParameterServer.scala
@@ -84,17 +84,19 @@ private[mxnet] object ParameterServer {
   }
 }
 
-class ParameterServer(private val classpath: String,
-                      private val role: String,
-                      private val rootUri: String,
-                      private val rootPort: Int,
-                      private val numServer: Int = 1,
-                      private val numWorker: Int = 1,
-                      private val timeout: Int = 0,
-                      private val java: String = "java",
-                      private val jvmOpts: String = "") {
+class ParameterServer(
+    classpath: String,
+    role: String,
+    rootUri: String,
+    rootPort: Int,
+    numServer: Int = 1,
+    numWorker: Int = 1,
+    timeout: Int = 0,
+    java: String = "java",
+    jvmOpts: String = "") {
+
   private val logger: Logger = 
LoggerFactory.getLogger(classOf[ParameterServer])
-  private val trackerProcess: AtomicReference[Process] = new 
AtomicReference[Process]
+  private val psProcess: AtomicReference[Process] = new 
AtomicReference[Process]
 
   /**
    * A utility class to redirect the child process's stdout or stderr.
@@ -121,47 +123,38 @@ class ParameterServer(private val classpath: String,
     }
   }
 
-  def startProcess(): Boolean = {
+  private def startLoggingThreads(rootUri: String, rootPort: Int): Unit = {
+    val inputStream = psProcess.get().getInputStream
+    val errorStream = psProcess.get().getErrorStream
+    logger.info(s"Starting InputStream-Redirecter Thread for 
$rootUri:$rootPort")
+    new RedirectThread(inputStream, System.out, "InputStream-Redirecter", 
true).start()
+    logger.info(s"Starting ErrorStream-Redirecter Thread for 
$rootUri:$rootPort")
+    new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter", 
true).start()
+  }
+
+  def startProcess(): Int = {
     val cp = if (classpath == null) "" else s"-cp $classpath"
     val cmd = s"$java $jvmOpts $cp $runningClass " +
       s"--role=$role --root-uri=$rootUri --root-port=$rootPort " +
       s"--num-server=$numServer --num-worker=$numWorker --timeout=$timeout"
-    logger.info(s"Start process: $cmd")
     try {
       val childProcess = Runtime.getRuntime.exec(cmd)
-      trackerProcess.set(childProcess)
-      val inputStream = childProcess.getInputStream
-      val errorStream = childProcess.getErrorStream
-      logger.info("Starting InputStream-Redirecter Thread")
-      new RedirectThread(inputStream, System.out, "InputStream-Redirecter", 
true).start()
-      logger.info("Starting ErrorStream-Redirecter Thread")
-      new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter", 
true).start()
-      true
+      logger.info(s"Started process: $cmd at $rootUri:$rootPort")
+      psProcess.set(childProcess)
+      startLoggingThreads(rootUri, rootPort)
+      psProcess.get().waitFor()
     } catch {
       case ioe: IOException =>
         ioe.printStackTrace()
-        false
+        1
+    } finally {
+      stop()
     }
   }
 
   def stop() {
-    if (trackerProcess.get != null) {
-      trackerProcess.get.destroy()
-    }
-  }
-
-  def waitFor(): Int = {
-    try {
-      trackerProcess.get.waitFor()
-      val returnVal: Int = trackerProcess.get.exitValue
-      logger.info("Process ends with exit code " + returnVal)
-      stop()
-      returnVal
-    } catch {
-      case e: InterruptedException =>
-        e.printStackTrace()
-        logger.error("Process terminated unexpectedly")
-        1
+    if (psProcess.get != null && psProcess.get().isAlive) {
+      psProcess.get.destroy()
     }
   }
 

-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].

Reply via email to