This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 9ed171e [scala-package][spark] Resources running PS (role = server) should be explicit to Spark (#7571) 9ed171e is described below commit 9ed171e35bf994b250e4a3ba9db5d792ee71c418 Author: Nan Zhu <coding...@users.noreply.github.com> AuthorDate: Tue Aug 29 20:29:52 2017 -0700 [scala-package][spark] Resources running PS (role = server) should be explicit to Spark (#7571) * temp * resources running PS (role = server) should be explicit to Spark * address the comments --- .../main/scala/ml/dmlc/mxnet/KVStoreServer.scala | 4 - .../src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala | 91 ++++++++++++++-------- .../ml/dmlc/mxnet/spark/ParameterServer.scala | 69 ++++++++-------- 3 files changed, 91 insertions(+), 73 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStoreServer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStoreServer.scala index 22f9269..d3c8691 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStoreServer.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStoreServer.scala @@ -20,10 +20,6 @@ package ml.dmlc.mxnet import ml.dmlc.mxnet.Base._ import org.slf4j.{Logger, LoggerFactory} -/** - * Server node for the key value store - * @author Yizhi Liu - */ private[mxnet] class KVStoreServer(private val kvStore: KVStore) { private val logger: Logger = LoggerFactory.getLogger(classOf[KVStoreServer]) private val handle: KVStoreHandle = kvStore.handle diff --git a/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala b/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala index 27dd99f..cc77342 100644 --- a/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala +++ b/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/MXNet.scala @@ -27,14 +27,24 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -/** - * MXNet Training On 
Spark - * @author Yizhi Liu - */ class MXNet extends Serializable { + + class MXNetControllingThread( + schedulerIP: String, + schedulerPort: Int, + sparkContext: SparkContext, + triggerOfComponent: (String, Int, SparkContext) => Unit) extends Thread { + override def run() { + triggerOfComponent(schedulerIP, schedulerPort, sparkContext) + } + } + private val logger: Logger = LoggerFactory.getLogger(classOf[MXNet]) private val params: MXNetParams = new MXNetParams + @transient private var psServerThread: MXNetControllingThread = _ + @transient private var psSchedulerThread: MXNetControllingThread = _ + def setBatchSize(batchSize: Int): this.type = { params.batchSize = batchSize this @@ -105,30 +115,51 @@ class MXNet extends Serializable { this } - private def startParameterServers( + private def startPSServers( schedulerIP: String, schedulerPort: Int, - sc: SparkContext): ParameterServer = { - // TODO: check ip & port available - logger.info("Starting scheduler on {}:{}", schedulerIP, schedulerPort) - val scheduler = new ParameterServer(params.runtimeClasspath, role = "scheduler", - rootUri = schedulerIP, rootPort = schedulerPort, - numServer = params.numServer, numWorker = params.numWorker, - timeout = params.timeout, java = params.javabin) - require(scheduler.startProcess(), "Failed to start ps scheduler process") - - sc.parallelize(1 to params.numServer, params.numServer).foreachPartition { p => - logger.info("Starting server ...") - val server = new ParameterServer(params.runtimeClasspath, - role = "server", + sc: SparkContext) = { + def startPSServersInner( + schedulerIP: String, + schedulerPort: Int, + sc: SparkContext): Unit = { + sc.parallelize(1 to params.numServer, params.numServer).foreachPartition { p => + logger.info("Starting server ...") + val server = new ParameterServer(params.runtimeClasspath, + role = "server", + rootUri = schedulerIP, rootPort = schedulerPort, + numServer = params.numServer, + numWorker = params.numWorker, + timeout = 
params.timeout, + java = params.javabin) + val exitCode = server.startProcess() + require(exitCode == 0, s"ps server process quit with exit code $exitCode") + } + } + psServerThread = new MXNetControllingThread(schedulerIP, schedulerPort, sc, startPSServersInner) + psServerThread.start() + } + + private def startPSScheduler( + schedulerIP: String, + schedulerPort: Int, + sc: SparkContext) = { + def startPSSchedulerInner( + schedulerIP: String, + schedulerPort: Int, + sc: SparkContext): Unit = { + // TODO: check ip & port available + logger.info("Starting scheduler on {}:{}", schedulerIP, schedulerPort) + val scheduler = new ParameterServer(params.runtimeClasspath, role = "scheduler", rootUri = schedulerIP, rootPort = schedulerPort, - numServer = params.numServer, - numWorker = params.numWorker, - timeout = params.timeout, - java = params.javabin) - require(server.startProcess(), "Failed to start ps server process") + numServer = params.numServer, numWorker = params.numWorker, + timeout = params.timeout, java = params.javabin) + val exitCode = scheduler.startProcess() + require(exitCode == 0, s"Failed to start ps scheduler process with exit code $exitCode") } - scheduler + psSchedulerThread = new MXNetControllingThread(schedulerIP, schedulerPort, sc, + startPSSchedulerInner) + psSchedulerThread.start() } private def setFeedForwardModel( @@ -212,23 +243,21 @@ class MXNet extends Serializable { // distribute native jars params.jars.foreach(jar => sc.addFile(jar)) val trainData = { - if (params.numWorker > data.partitions.length) { + if (params.numWorker != data.partitions.length) { logger.info("repartitioning training set to {} partitions", params.numWorker) data.repartition(params.numWorker) - } else if (params.numWorker < data.partitions.length) { - logger.info("repartitioning training set to {} partitions", params.numWorker) - data.coalesce(params.numWorker) } else { data } } val schedulerIP = utils.Network.ipAddress val schedulerPort = utils.Network.availablePort 
- val scheduler = startParameterServers(schedulerIP, schedulerPort, sc) - // simply the first model + startPSScheduler(schedulerIP, schedulerPort, sc) + startPSServers(schedulerIP, schedulerPort, sc) val mxModel = trainModel(trainData, schedulerIP, schedulerPort) logger.info("Waiting for scheduler ...") - scheduler.waitFor() + psSchedulerThread.join() + psServerThread.join() mxModel } } diff --git a/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/ParameterServer.scala b/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/ParameterServer.scala index 60e1c69..7ed4512 100644 --- a/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/ParameterServer.scala +++ b/scala-package/spark/src/main/scala/ml/dmlc/mxnet/spark/ParameterServer.scala @@ -84,17 +84,19 @@ private[mxnet] object ParameterServer { } } -class ParameterServer(private val classpath: String, - private val role: String, - private val rootUri: String, - private val rootPort: Int, - private val numServer: Int = 1, - private val numWorker: Int = 1, - private val timeout: Int = 0, - private val java: String = "java", - private val jvmOpts: String = "") { +class ParameterServer( + classpath: String, + role: String, + rootUri: String, + rootPort: Int, + numServer: Int = 1, + numWorker: Int = 1, + timeout: Int = 0, + java: String = "java", + jvmOpts: String = "") { + private val logger: Logger = LoggerFactory.getLogger(classOf[ParameterServer]) - private val trackerProcess: AtomicReference[Process] = new AtomicReference[Process] + private val psProcess: AtomicReference[Process] = new AtomicReference[Process] /** * A utility class to redirect the child process's stdout or stderr. 
@@ -121,47 +123,38 @@ class ParameterServer(private val classpath: String, } } - def startProcess(): Boolean = { + private def startLoggingThreads(rootUri: String, rootPort: Int): Unit = { + val inputStream = psProcess.get().getInputStream + val errorStream = psProcess.get().getErrorStream + logger.info(s"Starting InputStream-Redirecter Thread for $rootUri:$rootPort") + new RedirectThread(inputStream, System.out, "InputStream-Redirecter", true).start() + logger.info(s"Starting ErrorStream-Redirecter Thread for $rootUri:$rootPort") + new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter", true).start() + } + + def startProcess(): Int = { val cp = if (classpath == null) "" else s"-cp $classpath" val cmd = s"$java $jvmOpts $cp $runningClass " + s"--role=$role --root-uri=$rootUri --root-port=$rootPort " + s"--num-server=$numServer --num-worker=$numWorker --timeout=$timeout" - logger.info(s"Start process: $cmd") try { val childProcess = Runtime.getRuntime.exec(cmd) - trackerProcess.set(childProcess) - val inputStream = childProcess.getInputStream - val errorStream = childProcess.getErrorStream - logger.info("Starting InputStream-Redirecter Thread") - new RedirectThread(inputStream, System.out, "InputStream-Redirecter", true).start() - logger.info("Starting ErrorStream-Redirecter Thread") - new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter", true).start() - true + logger.info(s"Started process: $cmd at $rootUri:$rootPort") + psProcess.set(childProcess) + startLoggingThreads(rootUri, rootPort) + psProcess.get().waitFor() } catch { case ioe: IOException => ioe.printStackTrace() - false + 1 + } finally { + stop() } } def stop() { - if (trackerProcess.get != null) { - trackerProcess.get.destroy() - } - } - - def waitFor(): Int = { - try { - trackerProcess.get.waitFor() - val returnVal: Int = trackerProcess.get.exitValue - logger.info("Process ends with exit code " + returnVal) - stop() - returnVal - } catch { - case e: InterruptedException => 
- e.printStackTrace() - logger.error("Process terminated unexpectedly") - 1 + if (psProcess.get != null && psProcess.get().isAlive) { + psProcess.get.destroy() } } -- To stop receiving notification emails like this one, please contact ['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].