yzhliu closed pull request #10462: [MXNET-62] add test against spark integration
URL: https://github.com/apache/incubator-mxnet/pull/10462
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/3rdparty/ps-lite b/3rdparty/ps-lite
index a6dda54604a..8a763892a97 160000
--- a/3rdparty/ps-lite
+++ b/3rdparty/ps-lite
@@ -1 +1 @@
-Subproject commit a6dda54604a07d1fb21b016ed1e3f4246b08222a
+Subproject commit 8a763892a973afc1acd3d4b469d05bb338a83a6e
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index 4e99a9c861f..9e92207fb8d 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -229,6 +229,7 @@ class KVStore {
CHECK(updater) << "invalid updater";
updater_ = updater;
}
+
/*!
* \brief set an updater with string keys
*
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
index c1b72591952..e228e7273d8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
@@ -41,14 +41,15 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float
= 0.0f,
*/
override def update(index: Int, weight: NDArray, grad: NDArray, state:
AnyRef): Unit = {
// TODO(bing) implement wd_bias, wd_gamma, wd_beta (copy from python
package)
- var lr =
- (if (lrScheduler != null) {
+ var lr = {
+ if (lrScheduler != null) {
val scheduledLr = lrScheduler(numUpdate)
updateCount(index)
scheduledLr
} else {
this.learningRate
- })
+ }
+ }
lr = getLr(index, lr)
val wd = getWd(index, this.wd)
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 9dcfa7ca27e..cd5dba85dfd 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -242,7 +242,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
- <version>2.7</version>
+ <version>2.19</version>
<configuration>
<skipTests>true</skipTests>
</configuration>
diff --git a/scala-package/spark/bin/run-mnist-example.sh
b/scala-package/spark/bin/run-mnist-example.sh
index 962c3375a9d..392d6c6a7cf 100755
--- a/scala-package/spark/bin/run-mnist-example.sh
+++ b/scala-package/spark/bin/run-mnist-example.sh
@@ -17,6 +17,8 @@
# specific language governing permissions and limitations
# under the License.
+set -x
+
CURR_DIR=$(cd `dirname $0`; pwd)
SPARK_MODULE_DIR=$(cd $CURR_DIR/../; pwd)
SCALA_PKG_DIR=$(cd $CURR_DIR/../../; pwd)
@@ -35,10 +37,7 @@ SPARK_JAR=`find ${SPARK_MODULE_DIR}/target -name "*.jar"
-type f -exec ls "{}" +
SCALA_JAR=`find ${SCALA_PKG_DIR}/assembly/$OS/target -maxdepth 1 -name "*.jar"
-type f -exec ls "{}" + | grep -v -E '(javadoc|sources)'`
SPARK_OPTS+=" --name mxnet-spark-mnist"
-SPARK_OPTS+=" --driver-memory 1g"
-SPARK_OPTS+=" --executor-memory 1g"
-SPARK_OPTS+=" --num-executors 2"
-SPARK_OPTS+=" --executor-cores 1"
+SPARK_OPTS+=" --driver-memory 2g"
SPARK_OPTS+=" --jars ${SCALA_JAR}"
# Download training and test set
@@ -72,7 +71,7 @@ fi
HOST=`hostname`
-$SPARK_HOME/bin/spark-submit --master spark://$HOST:7077 \
+$SPARK_HOME/bin/spark-submit --master local[*] \
--class org.apache.mxnet.spark.example.ClassificationExample \
${SPARK_OPTS} \
${SPARK_JAR} \
diff --git a/scala-package/spark/pom.xml b/scala-package/spark/pom.xml
index 281fad4056f..43ff1f78fe1 100644
--- a/scala-package/spark/pom.xml
+++ b/scala-package/spark/pom.xml
@@ -16,7 +16,44 @@
<properties>
<spark.version>1.6.3</spark.version>
</properties>
-
+ <profiles>
+ <profile>
+ <id>osx-x86_64-cpu</id>
+ <properties>
+ <platform>osx-x86_64-cpu</platform>
+ </properties>
+ </profile>
+ <profile>
+ <id>linux-x86_64-cpu</id>
+ <properties>
+ <platform>linux-x86_64-cpu</platform>
+ </properties>
+ </profile>
+ <profile>
+ <id>linux-x86_64-gpu</id>
+ <properties>
+ <platform>linux-x86_64-gpu</platform>
+ </properties>
+ </profile>
+ </profiles>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <configuration>
+ <argLine>
+
-Djava.library.path=${project.parent.basedir}/native/${platform}/target \
+
-Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
+ </argLine>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.scalastyle</groupId>
+ <artifactId>scalastyle-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
<dependencies>
<dependency>
<groupId>org.apache.mxnet</groupId>
diff --git
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala
index 9720038afac..4952ca2626d 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala
@@ -127,7 +127,8 @@ class MXNet extends Serializable {
logger.info("Starting server ...")
val server = new ParameterServer(params.runtimeClasspath,
role = "server",
- rootUri = schedulerIP, rootPort = schedulerPort,
+ rootUri = schedulerIP,
+ rootPort = schedulerPort,
numServer = params.numServer,
numWorker = params.numWorker,
timeout = params.timeout,
@@ -241,7 +242,9 @@ class MXNet extends Serializable {
def fit(data: RDD[LabeledPoint]): MXNetModel = {
val sc = data.context
// distribute native jars
- params.jars.foreach(jar => sc.addFile(jar))
+ if (params.jars != null) {
+ params.jars.foreach(jar => sc.addFile(jar))
+ }
val trainData = {
if (params.numWorker != data.partitions.length) {
logger.info("repartitioning training set to {} partitions",
params.numWorker)
diff --git
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala
index 47e6cd49113..f72e56e9efb 100644
---
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala
+++
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala
@@ -61,7 +61,11 @@ private[mxnet] class MXNetParams extends Serializable {
// jars on executors for running mxnet application
var jars: Array[String] = null
def runtimeClasspath: String = {
- jars.map(jar => SparkFiles.get(new File(jar).getName)).mkString(":")
+ if (jars != null) {
+ jars.map(jar => SparkFiles.get(new File(jar).getName)).mkString(":")
+ } else {
+ ""
+ }
}
// java binary
diff --git
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala
index 907d3decde5..45033d48c6a 100644
---
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala
+++
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala
@@ -51,7 +51,7 @@ private[mxnet] object ParameterServer {
def buildEnv(role: String, rootUri: String, rootPort: Int,
numServer: Int, numWorker: Int): Map[String, String] = {
- val envs: mutable.Map[String, String] = mutable.HashMap.empty[String,
String]
+ val envs = mutable.HashMap.empty[String, String]
envs.put("DMLC_ROLE", role)
envs.put("DMLC_PS_ROOT_URI", rootUri)
envs.put("DMLC_PS_ROOT_PORT", rootPort.toString)
@@ -127,9 +127,9 @@ class ParameterServer(
val inputStream = psProcess.get().getInputStream
val errorStream = psProcess.get().getErrorStream
logger.info(s"Starting InputStream-Redirecter Thread for
$rootUri:$rootPort")
- new RedirectThread(inputStream, System.out, "InputStream-Redirecter",
true).start()
+ new RedirectThread(inputStream, System.out, "InputStream-Redirecter",
false).start()
logger.info(s"Starting ErrorStream-Redirecter Thread for
$rootUri:$rootPort")
- new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter",
true).start()
+ new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter",
false).start()
}
def startProcess(): Int = {
diff --git
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala
index ce49302fd88..2026bdee9fe 100644
---
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala
+++
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala
@@ -103,6 +103,7 @@ object ClassificationExample {
sc.stop()
} catch {
case e: Throwable =>
+ e.printStackTrace()
logger.error(e.getMessage, e)
sys.exit(-1)
}
diff --git
a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala
b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala
new file mode 100644
index 00000000000..74bc1dbb71f
--- /dev/null
+++
b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.spark
+
+import java.io.{BufferedReader, File, InputStreamReader}
+import java.nio.file.Files
+
+import scala.sys.process.Process
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+
+class MXNetGeneralSuite extends SharedSparkContext {
+
+ private var testDataDir: String = _
+
+ private def parseRawData(sc: SparkContext, path: String): RDD[LabeledPoint]
= {
+ val raw = sc.textFile(path)
+ raw.map { s =>
+ val parts = s.split(' ')
+ val label = java.lang.Double.parseDouble(parts(0))
+ val features =
Vectors.dense(parts(1).trim().split(',').map(java.lang.Double.parseDouble))
+ LabeledPoint(label, features)
+ }
+ }
+
+ private def downloadTestData(): Unit = {
+ Process("wget
http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon" +
+ "/dataset/mxnet-spark-test/train.txt" + " -P " + testDataDir + " -q") !
+ }
+
+ override def beforeAll(): Unit = {
+ val tempDirFile =
Files.createTempDirectory(s"mxnet-spark-test-${System.currentTimeMillis()}").
+ toFile
+ testDataDir = tempDirFile.getPath
+ tempDirFile.deleteOnExit()
+ downloadTestData()
+ }
+
+
+ test("run spark with MLP") {
+ val trainData = parseRawData(sc, s"$testDataDir/train.txt")
+ val model = buildMlp().fit(trainData)
+ assert(model != null)
+ }
+
+ test("run spark with LeNet") {
+ val trainData = parseRawData(sc, s"$testDataDir/train.txt")
+ val model = buildLeNet().fit(trainData)
+ assert(model != null)
+ }
+}
diff --git
a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
new file mode 100644
index 00000000000..2efd1814bc9
--- /dev/null
+++
b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.spark
+
+import java.io.{File, FileFilter}
+
+import org.apache.mxnet.{Context, Shape, Symbol}
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+
+trait SharedSparkContext extends FunSuite with BeforeAndAfterEach with
BeforeAndAfterAll {
+
+ protected var sc: SparkContext = _
+
+ protected val numWorkers: Int =
math.min(Runtime.getRuntime.availableProcessors(), 2)
+
+ override def beforeEach() {
+ sc = new SparkContext(new
SparkConf().setMaster("local[*]").setAppName("mxnet-spark-test"))
+ }
+
+ override def afterEach(): Unit = {
+ if (sc != null) {
+ sc.stop()
+ }
+ }
+
+ private def getMlp: Symbol = {
+ val data = Symbol.Variable("data")
+ val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data,
"num_hidden" -> 128))
+ val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1,
"act_type" -> "relu"))
+ val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1,
"num_hidden" -> 64))
+ val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2,
"act_type" -> "relu"))
+ val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2,
"num_hidden" -> 10))
+ val mlp = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc3))
+ mlp
+ }
+
+ def getLenet: Symbol = {
+ val data = Symbol.Variable("data")
+ // first conv
+ val conv1 = Symbol.Convolution()()(
+ Map("data" -> data, "kernel" -> "(5, 5)", "num_filter" -> 20))
+ val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" ->
"tanh"))
+ val pool1 = Symbol.Pooling()()(Map("data" -> tanh1, "pool_type" -> "max",
+ "kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
+ // second conv
+ val conv2 = Symbol.Convolution()()(
+ Map("data" -> pool1, "kernel" -> "(5, 5)", "num_filter" -> 50))
+ val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" ->
"tanh"))
+ val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
+ "kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
+ // first fullc
+ val flatten = Symbol.Flatten()()(Map("data" -> pool2))
+ val fc1 = Symbol.FullyConnected()()(Map("data" -> flatten, "num_hidden" ->
500))
+ val tanh3 = Symbol.Activation()()(Map("data" -> fc1, "act_type" -> "tanh"))
+ // second fullc
+ val fc2 = Symbol.FullyConnected()()(Map("data" -> tanh3, "num_hidden" ->
10))
+ // loss
+ val lenet = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2))
+ lenet
+ }
+
+ private def composeWorkingDirPath: String = {
+ System.getProperty("user.dir")
+ }
+
+ private def getJarFilePath(root: String): String = {
+ for (platform <- List("linux-x86_64-cpu", "linux-x86_64-gpu",
"osx-x86_64-cpu")) {
+ val jarFiles = new File(s"$root/$platform/target/").listFiles(new
FileFilter {
+ override def accept(pathname: File) = {
+ pathname.getAbsolutePath.endsWith(".jar") &&
+ !pathname.getAbsolutePath.contains("javadoc") &&
+ !pathname.getAbsolutePath.contains("sources")
+ }
+ })
+ if (jarFiles != null && jarFiles.nonEmpty) {
+ return jarFiles.head.getAbsolutePath
+ }
+ }
+ null
+ }
+
+ private def getSparkJar: String = {
+ val jarFiles = new File(s"$composeWorkingDirPath/target/").listFiles(new
FileFilter {
+ override def accept(pathname: File) = {
+ pathname.getAbsolutePath.endsWith(".jar") &&
+ !pathname.getAbsolutePath.contains("javadoc") &&
+ !pathname.getAbsolutePath.contains("sources")
+ }
+ })
+ if (jarFiles != null && jarFiles.nonEmpty) {
+ jarFiles.head.getAbsolutePath
+ } else {
+ null
+ }
+ }
+
+ protected def buildLeNet(): MXNet = {
+ val workingDir = composeWorkingDirPath
+ val assemblyRoot = s"$workingDir/../assembly"
+ new MXNet()
+ .setBatchSize(128)
+ .setLabelName("softmax_label")
+ .setContext(Array(Context.cpu(0), Context.cpu(1)))
+ .setDimension(Shape(1, 28, 28))
+ .setNetwork(getLenet)
+ .setNumEpoch(10)
+ .setNumServer(1)
+ .setNumWorker(numWorkers)
+ .setExecutorJars(s"${getJarFilePath(assemblyRoot)},$getSparkJar")
+ .setJava("java")
+ }
+
+ protected def buildMlp(): MXNet = {
+ val workingDir = composeWorkingDirPath
+ val assemblyRoot = s"$workingDir/../assembly"
+ new MXNet()
+ .setBatchSize(128)
+ .setLabelName("softmax_label")
+ .setContext(Array(Context.cpu(0), Context.cpu(1)))
+ .setDimension(Shape(784))
+ .setNetwork(getMlp)
+ .setNumEpoch(10)
+ .setNumServer(1)
+ .setNumWorker(numWorkers)
+ .setExecutorJars(s"${getJarFilePath(assemblyRoot)},$getSparkJar")
+ .setJava("java")
+ .setTimeout(0)
+ }
+}
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 373081bc7b1..dd3464bf6db 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -61,6 +61,7 @@ class KVStoreDist : public KVStoreLocal {
virtual ~KVStoreDist() {
Engine::Get()->WaitForAll();
+ customer_id_ = 0;
if (IsWorkerNode()) {
if (barrier_before_exit_) {
Barrier();
@@ -183,7 +184,7 @@ class KVStoreDist : public KVStoreLocal {
for (size_t i = 0; i < keys.size(); ++i) {
comm_->Init(keys[i], values[i].storage_type(), values[i].shape(),
values[i].dtype());
}
- if (get_rank() == 0) {
+ if (get_rank() == 0 && this->ps_worker_->get_customer()->customer_id() ==
0) {
Push_(keys, values, 0, false);
// wait until the push is finished
for (const int key : keys) {
diff --git a/src/kvstore/kvstore_dist_server.h
b/src/kvstore/kvstore_dist_server.h
index 421de27b39d..a150ff42f57 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -103,7 +103,8 @@ class Executor {
lk.unlock();
if (blk.f) {
- blk.f(); blk.p->set_value();
+ blk.f();
+ blk.p->set_value();
} else {
blk.p->set_value(); break;
}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services