This is an automated email from the ASF dual-hosted git repository. liuyizhi pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new ed80ff2 [MXNET-62] add test against spark integration (#10462) ed80ff2 is described below commit ed80ff2c01ff54e82215bf03e8df942ea729a15e Author: Nan Zhu <coding...@users.noreply.github.com> AuthorDate: Mon Jun 11 18:22:59 2018 -0700 [MXNET-62] add test against spark integration (#10462) * fix bug * temp * temp * temp * update * update * update * remove debugging stubs * remove unused * stylistic fix * fix typo * Pulled down update to submodule_dir * add test * retrigger it * sync 3rd party --- 3rdparty/ps-lite | 2 +- include/mxnet/kvstore.h | 1 + .../scala/org/apache/mxnet/optimizer/SGD.scala | 7 +- scala-package/pom.xml | 2 +- scala-package/spark/bin/run-mnist-example.sh | 9 +- scala-package/spark/pom.xml | 39 +++++- .../main/scala/org/apache/mxnet/spark/MXNet.scala | 7 +- .../scala/org/apache/mxnet/spark/MXNetParams.scala | 6 +- .../org/apache/mxnet/spark/ParameterServer.scala | 6 +- .../spark/example/ClassificationExample.scala | 1 + .../org/apache/mxnet/spark/MXNetGeneralSuite.scala | 69 ++++++++++ .../apache/mxnet/spark/SharedSparkContext.scala | 146 +++++++++++++++++++++ src/kvstore/kvstore_dist.h | 3 +- src/kvstore/kvstore_dist_server.h | 3 +- 14 files changed, 282 insertions(+), 19 deletions(-) diff --git a/3rdparty/ps-lite b/3rdparty/ps-lite index a6dda54..8a76389 160000 --- a/3rdparty/ps-lite +++ b/3rdparty/ps-lite @@ -1 +1 @@ -Subproject commit a6dda54604a07d1fb21b016ed1e3f4246b08222a +Subproject commit 8a763892a973afc1acd3d4b469d05bb338a83a6e diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index 4e99a9c..9e92207 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -229,6 +229,7 @@ class KVStore { CHECK(updater) << "invalid updater"; updater_ = updater; } + /*! * \brief set an updater with string keys * diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala index c1b7259..e228e72 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala @@ -41,14 +41,15 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f, */ override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = { // TODO(bing) implement wd_bias, wd_gamma, wd_beta (copy from python package) - var lr = - (if (lrScheduler != null) { + var lr = { + if (lrScheduler != null) { val scheduledLr = lrScheduler(numUpdate) updateCount(index) scheduledLr } else { this.learningRate - }) + } + } lr = getLr(index, lr) val wd = getWd(index, this.wd) diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 9dcfa7c..cd5dba8 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -242,7 +242,7 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-surefire-plugin</artifactId> - <version>2.7</version> + <version>2.19</version> <configuration> <skipTests>true</skipTests> </configuration> diff --git a/scala-package/spark/bin/run-mnist-example.sh b/scala-package/spark/bin/run-mnist-example.sh index 962c337..392d6c6 100755 --- a/scala-package/spark/bin/run-mnist-example.sh +++ b/scala-package/spark/bin/run-mnist-example.sh @@ -17,6 +17,8 @@ # specific language governing permissions and limitations # under the License. +set -x + CURR_DIR=$(cd `dirname $0`; pwd) SPARK_MODULE_DIR=$(cd $CURR_DIR/../; pwd) SCALA_PKG_DIR=$(cd $CURR_DIR/../../; pwd) @@ -35,10 +37,7 @@ SPARK_JAR=`find ${SPARK_MODULE_DIR}/target -name "*.jar" -type f -exec ls "{}" + SCALA_JAR=`find ${SCALA_PKG_DIR}/assembly/$OS/target -maxdepth 1 -name "*.jar" -type f -exec ls "{}" + | grep -v -E '(javadoc|sources)'` SPARK_OPTS+=" --name mxnet-spark-mnist" -SPARK_OPTS+=" --driver-memory 1g" -SPARK_OPTS+=" --executor-memory 1g" -SPARK_OPTS+=" --num-executors 2" -SPARK_OPTS+=" --executor-cores 1" +SPARK_OPTS+=" --driver-memory 2g" SPARK_OPTS+=" --jars ${SCALA_JAR}" # Download training and test set @@ -72,7 +71,7 @@ fi HOST=`hostname` -$SPARK_HOME/bin/spark-submit --master spark://$HOST:7077 \ +$SPARK_HOME/bin/spark-submit --master local[*] \ --class org.apache.mxnet.spark.example.ClassificationExample \ ${SPARK_OPTS} \ ${SPARK_JAR} \ diff --git a/scala-package/spark/pom.xml b/scala-package/spark/pom.xml index 281fad4..43ff1f7 100644 --- a/scala-package/spark/pom.xml +++ b/scala-package/spark/pom.xml @@ -16,7 +16,44 @@ <properties> <spark.version>1.6.3</spark.version> </properties> - + <profiles> + <profile> + <id>osx-x86_64-cpu</id> + <properties> + <platform>osx-x86_64-cpu</platform> + </properties> + </profile> + <profile> + <id>linux-x86_64-cpu</id> + <properties> + <platform>linux-x86_64-cpu</platform> + </properties> + </profile> + <profile> + <id>linux-x86_64-gpu</id> + <properties> + <platform>linux-x86_64-gpu</platform> + </properties> + </profile> + </profiles> + <build> + <plugins> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <configuration> + <argLine> + -Djava.library.path=${project.parent.basedir}/native/${platform}/target \ + -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties + </argLine> + </configuration> + </plugin> + <plugin> + <groupId>org.scalastyle</groupId> + <artifactId>scalastyle-maven-plugin</artifactId> + </plugin> + </plugins> + </build> <dependencies> <dependency> <groupId>org.apache.mxnet</groupId> diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala index 9720038..4952ca2 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala @@ -127,7 +127,8 @@ class MXNet extends Serializable { logger.info("Starting server ...") val server = new ParameterServer(params.runtimeClasspath, role = "server", - rootUri = schedulerIP, rootPort = schedulerPort, + rootUri = schedulerIP, + rootPort = schedulerPort, numServer = params.numServer, numWorker = params.numWorker, timeout = params.timeout, @@ -241,7 +242,9 @@ class MXNet extends Serializable { def fit(data: RDD[LabeledPoint]): MXNetModel = { val sc = data.context // distribute native jars - params.jars.foreach(jar => sc.addFile(jar)) + if (params.jars != null) { + params.jars.foreach(jar => sc.addFile(jar)) + } val trainData = { if (params.numWorker != data.partitions.length) { logger.info("repartitioning training set to {} partitions", params.numWorker) diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala index 47e6cd4..f72e56e 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala @@ -61,7 +61,11 @@ private[mxnet] class MXNetParams extends Serializable { // jars on executors for running mxnet application var jars: Array[String] = null def runtimeClasspath: String = { - jars.map(jar => SparkFiles.get(new File(jar).getName)).mkString(":") + if (jars != null) { + jars.map(jar => SparkFiles.get(new File(jar).getName)).mkString(":") + } else { + "" + } } // java binary diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala index 907d3de..45033d4 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala @@ -51,7 +51,7 @@ private[mxnet] object ParameterServer { def buildEnv(role: String, rootUri: String, rootPort: Int, numServer: Int, numWorker: Int): Map[String, String] = { - val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String] + val envs = mutable.HashMap.empty[String, String] envs.put("DMLC_ROLE", role) envs.put("DMLC_PS_ROOT_URI", rootUri) envs.put("DMLC_PS_ROOT_PORT", rootPort.toString) @@ -127,9 +127,9 @@ class ParameterServer( val inputStream = psProcess.get().getInputStream val errorStream = psProcess.get().getErrorStream logger.info(s"Starting InputStream-Redirecter Thread for $rootUri:$rootPort") - new RedirectThread(inputStream, System.out, "InputStream-Redirecter", true).start() + new RedirectThread(inputStream, System.out, "InputStream-Redirecter", false).start() logger.info(s"Starting ErrorStream-Redirecter Thread for $rootUri:$rootPort") - new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter", true).start() + new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter", false).start() } def startProcess(): Int = { diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala index ce49302..2026bde 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala @@ -103,6 +103,7 @@ object ClassificationExample { sc.stop() } catch { case e: Throwable => + e.printStackTrace() logger.error(e.getMessage, e) sys.exit(-1) } diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala new file mode 100644 index 0000000..74bc1db --- /dev/null +++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.spark + +import java.io.{BufferedReader, File, InputStreamReader} +import java.nio.file.Files + +import scala.sys.process.Process + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD + +class MXNetGeneralSuite extends SharedSparkContext { + + private var testDataDir: String = _ + + private def parseRawData(sc: SparkContext, path: String): RDD[LabeledPoint] = { + val raw = sc.textFile(path) + raw.map { s => + val parts = s.split(' ') + val label = java.lang.Double.parseDouble(parts(0)) + val features = Vectors.dense(parts(1).trim().split(',').map(java.lang.Double.parseDouble)) + LabeledPoint(label, features) + } + } + + private def downloadTestData(): Unit = { + Process("wget http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon" + + "/dataset/mxnet-spark-test/train.txt" + " -P " + testDataDir + " -q") ! + } + + override def beforeAll(): Unit = { + val tempDirFile = Files.createTempDirectory(s"mxnet-spark-test-${System.currentTimeMillis()}"). + toFile + testDataDir = tempDirFile.getPath + tempDirFile.deleteOnExit() + downloadTestData() + } + + + test("run spark with MLP") { + val trainData = parseRawData(sc, s"$testDataDir/train.txt") + val model = buildMlp().fit(trainData) + assert(model != null) + } + + test("run spark with LeNet") { + val trainData = parseRawData(sc, s"$testDataDir/train.txt") + val model = buildLeNet().fit(trainData) + assert(model != null) + } +} diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala new file mode 100644 index 0000000..2efd181 --- /dev/null +++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.spark + +import java.io.{File, FileFilter} + +import org.apache.mxnet.{Context, Shape, Symbol} + +import org.apache.spark.{SparkConf, SparkContext} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} + +trait SharedSparkContext extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { + + protected var sc: SparkContext = _ + + protected val numWorkers: Int = math.min(Runtime.getRuntime.availableProcessors(), 2) + + override def beforeEach() { + sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("mxnet-spark-test")) + } + + override def afterEach(): Unit = { + if (sc != null) { + sc.stop() + } + } + + private def getMlp: Symbol = { + val data = Symbol.Variable("data") + val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 128)) + val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu")) + val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 64)) + val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2, "act_type" -> "relu")) + val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2, "num_hidden" -> 10)) + val mlp = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc3)) + mlp + } + + def getLenet: Symbol = { + val data = Symbol.Variable("data") + // first conv + val conv1 = Symbol.Convolution()()( + Map("data" -> data, "kernel" -> "(5, 5)", "num_filter" -> 20)) + val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh")) + val pool1 = Symbol.Pooling()()(Map("data" -> tanh1, "pool_type" -> "max", + "kernel" -> "(2, 2)", "stride" -> "(2, 2)")) + // second conv + val conv2 = Symbol.Convolution()()( + Map("data" -> pool1, "kernel" -> "(5, 5)", "num_filter" -> 50)) + val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh")) + val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max", + "kernel" -> "(2, 2)", "stride" -> "(2, 2)")) + // first fullc + val flatten = Symbol.Flatten()()(Map("data" -> pool2)) + val fc1 = Symbol.FullyConnected()()(Map("data" -> flatten, "num_hidden" -> 500)) + val tanh3 = Symbol.Activation()()(Map("data" -> fc1, "act_type" -> "tanh")) + // second fullc + val fc2 = Symbol.FullyConnected()()(Map("data" -> tanh3, "num_hidden" -> 10)) + // loss + val lenet = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2)) + lenet + } + + private def composeWorkingDirPath: String = { + System.getProperty("user.dir") + } + + private def getJarFilePath(root: String): String = { + for (platform <- List("linux-x86_64-cpu", "linux-x86_64-gpu", "osx-x86_64-cpu")) { + val jarFiles = new File(s"$root/$platform/target/").listFiles(new FileFilter { + override def accept(pathname: File) = { + pathname.getAbsolutePath.endsWith(".jar") && + !pathname.getAbsolutePath.contains("javadoc") && + !pathname.getAbsolutePath.contains("sources") + } + }) + if (jarFiles != null && jarFiles.nonEmpty) { + return jarFiles.head.getAbsolutePath + } + } + null + } + + private def getSparkJar: String = { + val jarFiles = new File(s"$composeWorkingDirPath/target/").listFiles(new FileFilter { + override def accept(pathname: File) = { + pathname.getAbsolutePath.endsWith(".jar") && + !pathname.getAbsolutePath.contains("javadoc") && + !pathname.getAbsolutePath.contains("sources") + } + }) + if (jarFiles != null && jarFiles.nonEmpty) { + jarFiles.head.getAbsolutePath + } else { + null + } + } + + protected def buildLeNet(): MXNet = { + val workingDir = composeWorkingDirPath + val assemblyRoot = s"$workingDir/../assembly" + new MXNet() + .setBatchSize(128) + .setLabelName("softmax_label") + .setContext(Array(Context.cpu(0), Context.cpu(1))) + .setDimension(Shape(1, 28, 28)) + .setNetwork(getLenet) + .setNumEpoch(10) + .setNumServer(1) + .setNumWorker(numWorkers) + .setExecutorJars(s"${getJarFilePath(assemblyRoot)},$getSparkJar") + .setJava("java") + } + + protected def buildMlp(): MXNet = { + val workingDir = composeWorkingDirPath + val assemblyRoot = s"$workingDir/../assembly" + new MXNet() + .setBatchSize(128) + .setLabelName("softmax_label") + .setContext(Array(Context.cpu(0), Context.cpu(1))) + .setDimension(Shape(784)) + .setNetwork(getMlp) + .setNumEpoch(10) + .setNumServer(1) + .setNumWorker(numWorkers) + .setExecutorJars(s"${getJarFilePath(assemblyRoot)},$getSparkJar") + .setJava("java") + .setTimeout(0) + } +} diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 373081b..dd3464b 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -61,6 +61,7 @@ class KVStoreDist : public KVStoreLocal { virtual ~KVStoreDist() { Engine::Get()->WaitForAll(); + customer_id_ = 0; if (IsWorkerNode()) { if (barrier_before_exit_) { Barrier(); @@ -183,7 +184,7 @@ class KVStoreDist : public KVStoreLocal { for (size_t i = 0; i < keys.size(); ++i) { comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype()); } - if (get_rank() == 0) { + if (get_rank() == 0 && this->ps_worker_->get_customer()->customer_id() == 0) { Push_(keys, values, 0, false); // wait until the push is finished for (const int key : keys) { diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 421de27..a150ff4 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -103,7 +103,8 @@ class Executor { lk.unlock(); if (blk.f) { - blk.f(); blk.p->set_value(); + blk.f(); + blk.p->set_value(); } else { blk.p->set_value(); break; } -- To stop receiving notification emails like this one, please contact liuyi...@apache.org.