This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ed80ff2 [MXNET-62] add test against spark integration (#10462)
ed80ff2 is described below
commit ed80ff2c01ff54e82215bf03e8df942ea729a15e
Author: Nan Zhu <[email protected]>
AuthorDate: Mon Jun 11 18:22:59 2018 -0700
[MXNET-62] add test against spark integration (#10462)
* fix bug
* temp
* temp
* temp
* update
* update
* update
* remove debugging stubs
* remove unused
* stylistic fix
* fix typo
* Pulled down update to submodule_dir
* add test
* retrigger it
* sync 3rd party
---
3rdparty/ps-lite | 2 +-
include/mxnet/kvstore.h | 1 +
.../scala/org/apache/mxnet/optimizer/SGD.scala | 7 +-
scala-package/pom.xml | 2 +-
scala-package/spark/bin/run-mnist-example.sh | 9 +-
scala-package/spark/pom.xml | 39 +++++-
.../main/scala/org/apache/mxnet/spark/MXNet.scala | 7 +-
.../scala/org/apache/mxnet/spark/MXNetParams.scala | 6 +-
.../org/apache/mxnet/spark/ParameterServer.scala | 6 +-
.../spark/example/ClassificationExample.scala | 1 +
.../org/apache/mxnet/spark/MXNetGeneralSuite.scala | 69 ++++++++++
.../apache/mxnet/spark/SharedSparkContext.scala | 146 +++++++++++++++++++++
src/kvstore/kvstore_dist.h | 3 +-
src/kvstore/kvstore_dist_server.h | 3 +-
14 files changed, 282 insertions(+), 19 deletions(-)
diff --git a/3rdparty/ps-lite b/3rdparty/ps-lite
index a6dda54..8a76389 160000
--- a/3rdparty/ps-lite
+++ b/3rdparty/ps-lite
@@ -1 +1 @@
-Subproject commit a6dda54604a07d1fb21b016ed1e3f4246b08222a
+Subproject commit 8a763892a973afc1acd3d4b469d05bb338a83a6e
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index 4e99a9c..9e92207 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -229,6 +229,7 @@ class KVStore {
CHECK(updater) << "invalid updater";
updater_ = updater;
}
+
/*!
* \brief set an updater with string keys
*
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
index c1b7259..e228e72 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
@@ -41,14 +41,15 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
*/
override def update(index: Int, weight: NDArray, grad: NDArray, state:
AnyRef): Unit = {
// TODO(bing) implement wd_bias, wd_gamma, wd_beta (copy from python
package)
- var lr =
- (if (lrScheduler != null) {
+ var lr = {
+ if (lrScheduler != null) {
val scheduledLr = lrScheduler(numUpdate)
updateCount(index)
scheduledLr
} else {
this.learningRate
- })
+ }
+ }
lr = getLr(index, lr)
val wd = getWd(index, this.wd)
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 9dcfa7c..cd5dba8 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -242,7 +242,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
- <version>2.7</version>
+ <version>2.19</version>
<configuration>
<skipTests>true</skipTests>
</configuration>
diff --git a/scala-package/spark/bin/run-mnist-example.sh b/scala-package/spark/bin/run-mnist-example.sh
index 962c337..392d6c6 100755
--- a/scala-package/spark/bin/run-mnist-example.sh
+++ b/scala-package/spark/bin/run-mnist-example.sh
@@ -17,6 +17,8 @@
# specific language governing permissions and limitations
# under the License.
+set -x
+
CURR_DIR=$(cd `dirname $0`; pwd)
SPARK_MODULE_DIR=$(cd $CURR_DIR/../; pwd)
SCALA_PKG_DIR=$(cd $CURR_DIR/../../; pwd)
@@ -35,10 +37,7 @@ SPARK_JAR=`find ${SPARK_MODULE_DIR}/target -name "*.jar" -type f -exec ls "{}" +
SCALA_JAR=`find ${SCALA_PKG_DIR}/assembly/$OS/target -maxdepth 1 -name "*.jar"
-type f -exec ls "{}" + | grep -v -E '(javadoc|sources)'`
SPARK_OPTS+=" --name mxnet-spark-mnist"
-SPARK_OPTS+=" --driver-memory 1g"
-SPARK_OPTS+=" --executor-memory 1g"
-SPARK_OPTS+=" --num-executors 2"
-SPARK_OPTS+=" --executor-cores 1"
+SPARK_OPTS+=" --driver-memory 2g"
SPARK_OPTS+=" --jars ${SCALA_JAR}"
# Download training and test set
@@ -72,7 +71,7 @@ fi
HOST=`hostname`
-$SPARK_HOME/bin/spark-submit --master spark://$HOST:7077 \
+$SPARK_HOME/bin/spark-submit --master local[*] \
--class org.apache.mxnet.spark.example.ClassificationExample \
${SPARK_OPTS} \
${SPARK_JAR} \
diff --git a/scala-package/spark/pom.xml b/scala-package/spark/pom.xml
index 281fad4..43ff1f7 100644
--- a/scala-package/spark/pom.xml
+++ b/scala-package/spark/pom.xml
@@ -16,7 +16,44 @@
<properties>
<spark.version>1.6.3</spark.version>
</properties>
-
+ <profiles>
+ <profile>
+ <id>osx-x86_64-cpu</id>
+ <properties>
+ <platform>osx-x86_64-cpu</platform>
+ </properties>
+ </profile>
+ <profile>
+ <id>linux-x86_64-cpu</id>
+ <properties>
+ <platform>linux-x86_64-cpu</platform>
+ </properties>
+ </profile>
+ <profile>
+ <id>linux-x86_64-gpu</id>
+ <properties>
+ <platform>linux-x86_64-gpu</platform>
+ </properties>
+ </profile>
+ </profiles>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <configuration>
+ <argLine>
+
-Djava.library.path=${project.parent.basedir}/native/${platform}/target \
+
-Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
+ </argLine>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.scalastyle</groupId>
+ <artifactId>scalastyle-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
<dependencies>
<dependency>
<groupId>org.apache.mxnet</groupId>
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala
index 9720038..4952ca2 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala
@@ -127,7 +127,8 @@ class MXNet extends Serializable {
logger.info("Starting server ...")
val server = new ParameterServer(params.runtimeClasspath,
role = "server",
- rootUri = schedulerIP, rootPort = schedulerPort,
+ rootUri = schedulerIP,
+ rootPort = schedulerPort,
numServer = params.numServer,
numWorker = params.numWorker,
timeout = params.timeout,
@@ -241,7 +242,9 @@ class MXNet extends Serializable {
def fit(data: RDD[LabeledPoint]): MXNetModel = {
val sc = data.context
// distribute native jars
- params.jars.foreach(jar => sc.addFile(jar))
+ if (params.jars != null) {
+ params.jars.foreach(jar => sc.addFile(jar))
+ }
val trainData = {
if (params.numWorker != data.partitions.length) {
logger.info("repartitioning training set to {} partitions",
params.numWorker)
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala
index 47e6cd4..f72e56e 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala
@@ -61,7 +61,11 @@ private[mxnet] class MXNetParams extends Serializable {
// jars on executors for running mxnet application
var jars: Array[String] = null
def runtimeClasspath: String = {
- jars.map(jar => SparkFiles.get(new File(jar).getName)).mkString(":")
+ if (jars != null) {
+ jars.map(jar => SparkFiles.get(new File(jar).getName)).mkString(":")
+ } else {
+ ""
+ }
}
// java binary
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala
index 907d3de..45033d4 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala
@@ -51,7 +51,7 @@ private[mxnet] object ParameterServer {
def buildEnv(role: String, rootUri: String, rootPort: Int,
numServer: Int, numWorker: Int): Map[String, String] = {
- val envs: mutable.Map[String, String] = mutable.HashMap.empty[String,
String]
+ val envs = mutable.HashMap.empty[String, String]
envs.put("DMLC_ROLE", role)
envs.put("DMLC_PS_ROOT_URI", rootUri)
envs.put("DMLC_PS_ROOT_PORT", rootPort.toString)
@@ -127,9 +127,9 @@ class ParameterServer(
val inputStream = psProcess.get().getInputStream
val errorStream = psProcess.get().getErrorStream
logger.info(s"Starting InputStream-Redirecter Thread for
$rootUri:$rootPort")
- new RedirectThread(inputStream, System.out, "InputStream-Redirecter",
true).start()
+ new RedirectThread(inputStream, System.out, "InputStream-Redirecter",
false).start()
logger.info(s"Starting ErrorStream-Redirecter Thread for
$rootUri:$rootPort")
- new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter",
true).start()
+ new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter",
false).start()
}
def startProcess(): Int = {
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala
index ce49302..2026bde 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala
@@ -103,6 +103,7 @@ object ClassificationExample {
sc.stop()
} catch {
case e: Throwable =>
+ e.printStackTrace()
logger.error(e.getMessage, e)
sys.exit(-1)
}
diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala
new file mode 100644
index 0000000..74bc1db
--- /dev/null
+++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.spark
+
+import java.io.{BufferedReader, File, InputStreamReader}
+import java.nio.file.Files
+
+import scala.sys.process.Process
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+
+class MXNetGeneralSuite extends SharedSparkContext {
+
+ private var testDataDir: String = _
+
+ private def parseRawData(sc: SparkContext, path: String): RDD[LabeledPoint]
= {
+ val raw = sc.textFile(path)
+ raw.map { s =>
+ val parts = s.split(' ')
+ val label = java.lang.Double.parseDouble(parts(0))
+ val features =
Vectors.dense(parts(1).trim().split(',').map(java.lang.Double.parseDouble))
+ LabeledPoint(label, features)
+ }
+ }
+
+ private def downloadTestData(): Unit = {
+ Process("wget
http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon" +
+ "/dataset/mxnet-spark-test/train.txt" + " -P " + testDataDir + " -q") !
+ }
+
+ override def beforeAll(): Unit = {
+ val tempDirFile =
Files.createTempDirectory(s"mxnet-spark-test-${System.currentTimeMillis()}").
+ toFile
+ testDataDir = tempDirFile.getPath
+ tempDirFile.deleteOnExit()
+ downloadTestData()
+ }
+
+
+ test("run spark with MLP") {
+ val trainData = parseRawData(sc, s"$testDataDir/train.txt")
+ val model = buildMlp().fit(trainData)
+ assert(model != null)
+ }
+
+ test("run spark with LeNet") {
+ val trainData = parseRawData(sc, s"$testDataDir/train.txt")
+ val model = buildLeNet().fit(trainData)
+ assert(model != null)
+ }
+}
diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
new file mode 100644
index 0000000..2efd181
--- /dev/null
+++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.spark
+
+import java.io.{File, FileFilter}
+
+import org.apache.mxnet.{Context, Shape, Symbol}
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+
+trait SharedSparkContext extends FunSuite with BeforeAndAfterEach with
BeforeAndAfterAll {
+
+ protected var sc: SparkContext = _
+
+ protected val numWorkers: Int =
math.min(Runtime.getRuntime.availableProcessors(), 2)
+
+ override def beforeEach() {
+ sc = new SparkContext(new
SparkConf().setMaster("local[*]").setAppName("mxnet-spark-test"))
+ }
+
+ override def afterEach(): Unit = {
+ if (sc != null) {
+ sc.stop()
+ }
+ }
+
+ private def getMlp: Symbol = {
+ val data = Symbol.Variable("data")
+ val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data,
"num_hidden" -> 128))
+ val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1,
"act_type" -> "relu"))
+ val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1,
"num_hidden" -> 64))
+ val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2,
"act_type" -> "relu"))
+ val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2,
"num_hidden" -> 10))
+ val mlp = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc3))
+ mlp
+ }
+
+ def getLenet: Symbol = {
+ val data = Symbol.Variable("data")
+ // first conv
+ val conv1 = Symbol.Convolution()()(
+ Map("data" -> data, "kernel" -> "(5, 5)", "num_filter" -> 20))
+ val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" ->
"tanh"))
+ val pool1 = Symbol.Pooling()()(Map("data" -> tanh1, "pool_type" -> "max",
+ "kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
+ // second conv
+ val conv2 = Symbol.Convolution()()(
+ Map("data" -> pool1, "kernel" -> "(5, 5)", "num_filter" -> 50))
+ val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" ->
"tanh"))
+ val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
+ "kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
+ // first fullc
+ val flatten = Symbol.Flatten()()(Map("data" -> pool2))
+ val fc1 = Symbol.FullyConnected()()(Map("data" -> flatten, "num_hidden" ->
500))
+ val tanh3 = Symbol.Activation()()(Map("data" -> fc1, "act_type" -> "tanh"))
+ // second fullc
+ val fc2 = Symbol.FullyConnected()()(Map("data" -> tanh3, "num_hidden" ->
10))
+ // loss
+ val lenet = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2))
+ lenet
+ }
+
+ private def composeWorkingDirPath: String = {
+ System.getProperty("user.dir")
+ }
+
+ private def getJarFilePath(root: String): String = {
+ for (platform <- List("linux-x86_64-cpu", "linux-x86_64-gpu",
"osx-x86_64-cpu")) {
+ val jarFiles = new File(s"$root/$platform/target/").listFiles(new
FileFilter {
+ override def accept(pathname: File) = {
+ pathname.getAbsolutePath.endsWith(".jar") &&
+ !pathname.getAbsolutePath.contains("javadoc") &&
+ !pathname.getAbsolutePath.contains("sources")
+ }
+ })
+ if (jarFiles != null && jarFiles.nonEmpty) {
+ return jarFiles.head.getAbsolutePath
+ }
+ }
+ null
+ }
+
+ private def getSparkJar: String = {
+ val jarFiles = new File(s"$composeWorkingDirPath/target/").listFiles(new
FileFilter {
+ override def accept(pathname: File) = {
+ pathname.getAbsolutePath.endsWith(".jar") &&
+ !pathname.getAbsolutePath.contains("javadoc") &&
+ !pathname.getAbsolutePath.contains("sources")
+ }
+ })
+ if (jarFiles != null && jarFiles.nonEmpty) {
+ jarFiles.head.getAbsolutePath
+ } else {
+ null
+ }
+ }
+
+ protected def buildLeNet(): MXNet = {
+ val workingDir = composeWorkingDirPath
+ val assemblyRoot = s"$workingDir/../assembly"
+ new MXNet()
+ .setBatchSize(128)
+ .setLabelName("softmax_label")
+ .setContext(Array(Context.cpu(0), Context.cpu(1)))
+ .setDimension(Shape(1, 28, 28))
+ .setNetwork(getLenet)
+ .setNumEpoch(10)
+ .setNumServer(1)
+ .setNumWorker(numWorkers)
+ .setExecutorJars(s"${getJarFilePath(assemblyRoot)},$getSparkJar")
+ .setJava("java")
+ }
+
+ protected def buildMlp(): MXNet = {
+ val workingDir = composeWorkingDirPath
+ val assemblyRoot = s"$workingDir/../assembly"
+ new MXNet()
+ .setBatchSize(128)
+ .setLabelName("softmax_label")
+ .setContext(Array(Context.cpu(0), Context.cpu(1)))
+ .setDimension(Shape(784))
+ .setNetwork(getMlp)
+ .setNumEpoch(10)
+ .setNumServer(1)
+ .setNumWorker(numWorkers)
+ .setExecutorJars(s"${getJarFilePath(assemblyRoot)},$getSparkJar")
+ .setJava("java")
+ .setTimeout(0)
+ }
+}
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 373081b..dd3464b 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -61,6 +61,7 @@ class KVStoreDist : public KVStoreLocal {
virtual ~KVStoreDist() {
Engine::Get()->WaitForAll();
+ customer_id_ = 0;
if (IsWorkerNode()) {
if (barrier_before_exit_) {
Barrier();
@@ -183,7 +184,7 @@ class KVStoreDist : public KVStoreLocal {
for (size_t i = 0; i < keys.size(); ++i) {
comm_->Init(keys[i], values[i].storage_type(), values[i].shape(),
values[i].dtype());
}
- if (get_rank() == 0) {
+ if (get_rank() == 0 && this->ps_worker_->get_customer()->customer_id() ==
0) {
Push_(keys, values, 0, false);
// wait until the push is finished
for (const int key : keys) {
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index 421de27..a150ff4 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -103,7 +103,8 @@ class Executor {
lk.unlock();
if (blk.f) {
- blk.f(); blk.p->set_value();
+ blk.f();
+ blk.p->set_value();
} else {
blk.p->set_value(); break;
}
--
To stop receiving notification emails like this one, please contact
[email protected].