This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ed80ff2  [MXNET-62] add test against spark integration (#10462)
ed80ff2 is described below

commit ed80ff2c01ff54e82215bf03e8df942ea729a15e
Author: Nan Zhu <coding...@users.noreply.github.com>
AuthorDate: Mon Jun 11 18:22:59 2018 -0700

    [MXNET-62] add test against spark integration (#10462)
    
    * fix bug
    
    * temp
    
    * temp
    
    * temp
    
    * update
    
    * update
    
    * update
    
    * remove debugging stubs
    
    * remove unused
    
    * stylistic fix
    
    * fix typo
    
    * Pulled down update to submodule_dir
    
    * add test
    
    * retrigger it
    
    * sync 3rd party
---
 3rdparty/ps-lite                                   |   2 +-
 include/mxnet/kvstore.h                            |   1 +
 .../scala/org/apache/mxnet/optimizer/SGD.scala     |   7 +-
 scala-package/pom.xml                              |   2 +-
 scala-package/spark/bin/run-mnist-example.sh       |   9 +-
 scala-package/spark/pom.xml                        |  39 +++++-
 .../main/scala/org/apache/mxnet/spark/MXNet.scala  |   7 +-
 .../scala/org/apache/mxnet/spark/MXNetParams.scala |   6 +-
 .../org/apache/mxnet/spark/ParameterServer.scala   |   6 +-
 .../spark/example/ClassificationExample.scala      |   1 +
 .../org/apache/mxnet/spark/MXNetGeneralSuite.scala |  69 ++++++++++
 .../apache/mxnet/spark/SharedSparkContext.scala    | 146 +++++++++++++++++++++
 src/kvstore/kvstore_dist.h                         |   3 +-
 src/kvstore/kvstore_dist_server.h                  |   3 +-
 14 files changed, 282 insertions(+), 19 deletions(-)

diff --git a/3rdparty/ps-lite b/3rdparty/ps-lite
index a6dda54..8a76389 160000
--- a/3rdparty/ps-lite
+++ b/3rdparty/ps-lite
@@ -1 +1 @@
-Subproject commit a6dda54604a07d1fb21b016ed1e3f4246b08222a
+Subproject commit 8a763892a973afc1acd3d4b469d05bb338a83a6e
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index 4e99a9c..9e92207 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -229,6 +229,7 @@ class KVStore {
     CHECK(updater) << "invalid updater";
     updater_ = updater;
   }
+
   /*!
    * \brief set an updater with string keys
    *
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
index c1b7259..e228e72 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
@@ -41,14 +41,15 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
    */
   override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = {
     // TODO(bing) implement wd_bias, wd_gamma, wd_beta (copy from python package)
-    var lr =
-      (if (lrScheduler != null) {
+    var lr = {
+      if (lrScheduler != null) {
         val scheduledLr = lrScheduler(numUpdate)
         updateCount(index)
         scheduledLr
       } else {
         this.learningRate
-      })
+      }
+    }
     lr = getLr(index, lr)
 
     val wd = getWd(index, this.wd)
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 9dcfa7c..cd5dba8 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -242,7 +242,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-surefire-plugin</artifactId>
-        <version>2.7</version>
+        <version>2.19</version>
         <configuration>
           <skipTests>true</skipTests>
         </configuration>
diff --git a/scala-package/spark/bin/run-mnist-example.sh b/scala-package/spark/bin/run-mnist-example.sh
index 962c337..392d6c6 100755
--- a/scala-package/spark/bin/run-mnist-example.sh
+++ b/scala-package/spark/bin/run-mnist-example.sh
@@ -17,6 +17,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+set -x
+
 CURR_DIR=$(cd `dirname $0`; pwd)
 SPARK_MODULE_DIR=$(cd $CURR_DIR/../; pwd)
 SCALA_PKG_DIR=$(cd $CURR_DIR/../../; pwd)
@@ -35,10 +37,7 @@ SPARK_JAR=`find ${SPARK_MODULE_DIR}/target -name "*.jar" -type f -exec ls "{}" +
 SCALA_JAR=`find ${SCALA_PKG_DIR}/assembly/$OS/target -maxdepth 1 -name "*.jar" -type f -exec ls "{}" + | grep -v -E '(javadoc|sources)'`
 
 SPARK_OPTS+=" --name mxnet-spark-mnist"
-SPARK_OPTS+=" --driver-memory 1g"
-SPARK_OPTS+=" --executor-memory 1g"
-SPARK_OPTS+=" --num-executors 2"
-SPARK_OPTS+=" --executor-cores 1"
+SPARK_OPTS+=" --driver-memory 2g"
 SPARK_OPTS+=" --jars ${SCALA_JAR}"
 
 # Download training and test set
@@ -72,7 +71,7 @@ fi
 
 HOST=`hostname`
 
-$SPARK_HOME/bin/spark-submit --master spark://$HOST:7077 \
+$SPARK_HOME/bin/spark-submit --master local[*] \
   --class org.apache.mxnet.spark.example.ClassificationExample \
   ${SPARK_OPTS} \
   ${SPARK_JAR} \
diff --git a/scala-package/spark/pom.xml b/scala-package/spark/pom.xml
index 281fad4..43ff1f7 100644
--- a/scala-package/spark/pom.xml
+++ b/scala-package/spark/pom.xml
@@ -16,7 +16,44 @@
   <properties>
     <spark.version>1.6.3</spark.version>
   </properties>
-
+  <profiles>
+    <profile>
+      <id>osx-x86_64-cpu</id>
+      <properties>
+        <platform>osx-x86_64-cpu</platform>
+      </properties>
+    </profile>
+    <profile>
+      <id>linux-x86_64-cpu</id>
+      <properties>
+        <platform>linux-x86_64-cpu</platform>
+      </properties>
+    </profile>
+    <profile>
+      <id>linux-x86_64-gpu</id>
+      <properties>
+        <platform>linux-x86_64-gpu</platform>
+      </properties>
+    </profile>
+  </profiles>
+  <build>
+    <plugins>
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+        <configuration>
+          <argLine>
+            -Djava.library.path=${project.parent.basedir}/native/${platform}/target \
+            -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
+          </argLine>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.scalastyle</groupId>
+        <artifactId>scalastyle-maven-plugin</artifactId>
+      </plugin>
+    </plugins>
+  </build>
   <dependencies>
     <dependency>
       <groupId>org.apache.mxnet</groupId>
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala
index 9720038..4952ca2 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNet.scala
@@ -127,7 +127,8 @@ class MXNet extends Serializable {
           logger.info("Starting server ...")
           val server = new ParameterServer(params.runtimeClasspath,
             role = "server",
-            rootUri = schedulerIP, rootPort = schedulerPort,
+            rootUri = schedulerIP,
+            rootPort = schedulerPort,
             numServer = params.numServer,
             numWorker = params.numWorker,
             timeout = params.timeout,
@@ -241,7 +242,9 @@ class MXNet extends Serializable {
   def fit(data: RDD[LabeledPoint]): MXNetModel = {
     val sc = data.context
     // distribute native jars
-    params.jars.foreach(jar => sc.addFile(jar))
+    if (params.jars != null) {
+      params.jars.foreach(jar => sc.addFile(jar))
+    }
     val trainData = {
       if (params.numWorker != data.partitions.length) {
         logger.info("repartitioning training set to {} partitions", 
params.numWorker)
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala
index 47e6cd4..f72e56e 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetParams.scala
@@ -61,7 +61,11 @@ private[mxnet] class MXNetParams extends Serializable {
   // jars on executors for running mxnet application
   var jars: Array[String] = null
   def runtimeClasspath: String = {
-    jars.map(jar => SparkFiles.get(new File(jar).getName)).mkString(":")
+    if (jars != null) {
+      jars.map(jar => SparkFiles.get(new File(jar).getName)).mkString(":")
+    } else {
+      ""
+    }
   }
 
   // java binary
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala
index 907d3de..45033d4 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/ParameterServer.scala
@@ -51,7 +51,7 @@ private[mxnet] object ParameterServer {
 
   def buildEnv(role: String, rootUri: String, rootPort: Int,
                numServer: Int, numWorker: Int): Map[String, String] = {
-    val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
+    val envs = mutable.HashMap.empty[String, String]
     envs.put("DMLC_ROLE", role)
     envs.put("DMLC_PS_ROOT_URI", rootUri)
     envs.put("DMLC_PS_ROOT_PORT", rootPort.toString)
@@ -127,9 +127,9 @@ class ParameterServer(
     val inputStream = psProcess.get().getInputStream
     val errorStream = psProcess.get().getErrorStream
     logger.info(s"Starting InputStream-Redirecter Thread for 
$rootUri:$rootPort")
-    new RedirectThread(inputStream, System.out, "InputStream-Redirecter", 
true).start()
+    new RedirectThread(inputStream, System.out, "InputStream-Redirecter", 
false).start()
     logger.info(s"Starting ErrorStream-Redirecter Thread for 
$rootUri:$rootPort")
-    new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter", 
true).start()
+    new RedirectThread(errorStream, System.err, "ErrorStream-Redirecter", 
false).start()
   }
 
   def startProcess(): Int = {
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala
index ce49302..2026bde 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/example/ClassificationExample.scala
@@ -103,6 +103,7 @@ object ClassificationExample {
       sc.stop()
     } catch {
       case e: Throwable =>
+        e.printStackTrace()
         logger.error(e.getMessage, e)
         sys.exit(-1)
     }
diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala
new file mode 100644
index 0000000..74bc1db
--- /dev/null
+++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.spark
+
+import java.io.{BufferedReader, File, InputStreamReader}
+import java.nio.file.Files
+
+import scala.sys.process.Process
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+
+class MXNetGeneralSuite extends SharedSparkContext {
+
+  private var testDataDir: String = _
+
+  private def parseRawData(sc: SparkContext, path: String): RDD[LabeledPoint] = {
+    val raw = sc.textFile(path)
+    raw.map { s =>
+      val parts = s.split(' ')
+      val label = java.lang.Double.parseDouble(parts(0))
+    val features = Vectors.dense(parts(1).trim().split(',').map(java.lang.Double.parseDouble))
+      LabeledPoint(label, features)
+    }
+  }
+
+  private def downloadTestData(): Unit = {
+    Process("wget 
http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon"; +
+      "/dataset/mxnet-spark-test/train.txt" + " -P " + testDataDir + " -q") !
+  }
+
+  override def beforeAll(): Unit = {
+    val tempDirFile = Files.createTempDirectory(s"mxnet-spark-test-${System.currentTimeMillis()}").
+      toFile
+    testDataDir = tempDirFile.getPath
+    tempDirFile.deleteOnExit()
+    downloadTestData()
+  }
+
+
+  test("run spark with MLP") {
+    val trainData = parseRawData(sc, s"$testDataDir/train.txt")
+    val model = buildMlp().fit(trainData)
+    assert(model != null)
+  }
+
+  test("run spark with LeNet") {
+    val trainData = parseRawData(sc, s"$testDataDir/train.txt")
+    val model = buildLeNet().fit(trainData)
+    assert(model != null)
+  }
+}
diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
new file mode 100644
index 0000000..2efd181
--- /dev/null
+++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.spark
+
+import java.io.{File, FileFilter}
+
+import org.apache.mxnet.{Context, Shape, Symbol}
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+
+trait SharedSparkContext extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
+
+  protected var sc: SparkContext = _
+
+  protected val numWorkers: Int = math.min(Runtime.getRuntime.availableProcessors(), 2)
+
+  override def beforeEach() {
+    sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("mxnet-spark-test"))
+  }
+
+  override def afterEach(): Unit = {
+    if (sc != null) {
+      sc.stop()
+    }
+  }
+
+  private def getMlp: Symbol = {
+    val data = Symbol.Variable("data")
+    val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 128))
+    val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
+    val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 64))
+    val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
+    val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2, "num_hidden" -> 10))
+    val mlp = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc3))
+    mlp
+  }
+
+  def getLenet: Symbol = {
+    val data = Symbol.Variable("data")
+    // first conv
+    val conv1 = Symbol.Convolution()()(
+      Map("data" -> data, "kernel" -> "(5, 5)", "num_filter" -> 20))
+    val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh"))
+    val pool1 = Symbol.Pooling()()(Map("data" -> tanh1, "pool_type" -> "max",
+      "kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
+    // second conv
+    val conv2 = Symbol.Convolution()()(
+      Map("data" -> pool1, "kernel" -> "(5, 5)", "num_filter" -> 50))
+    val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh"))
+    val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
+      "kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
+    // first fullc
+    val flatten = Symbol.Flatten()()(Map("data" -> pool2))
+    val fc1 = Symbol.FullyConnected()()(Map("data" -> flatten, "num_hidden" -> 500))
+    val tanh3 = Symbol.Activation()()(Map("data" -> fc1, "act_type" -> "tanh"))
+    // second fullc
+    val fc2 = Symbol.FullyConnected()()(Map("data" -> tanh3, "num_hidden" -> 10))
+    // loss
+    val lenet = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2))
+    lenet
+  }
+
+  private def composeWorkingDirPath: String = {
+    System.getProperty("user.dir")
+  }
+
+  private def getJarFilePath(root: String): String = {
+    for (platform <- List("linux-x86_64-cpu", "linux-x86_64-gpu", "osx-x86_64-cpu")) {
+      val jarFiles = new File(s"$root/$platform/target/").listFiles(new FileFilter {
+        override def accept(pathname: File) = {
+          pathname.getAbsolutePath.endsWith(".jar") &&
+            !pathname.getAbsolutePath.contains("javadoc") &&
+            !pathname.getAbsolutePath.contains("sources")
+        }
+      })
+      if (jarFiles != null && jarFiles.nonEmpty) {
+        return jarFiles.head.getAbsolutePath
+      }
+    }
+    null
+  }
+
+  private def getSparkJar: String = {
+    val jarFiles = new File(s"$composeWorkingDirPath/target/").listFiles(new FileFilter {
+      override def accept(pathname: File) = {
+        pathname.getAbsolutePath.endsWith(".jar") &&
+          !pathname.getAbsolutePath.contains("javadoc") &&
+          !pathname.getAbsolutePath.contains("sources")
+      }
+    })
+    if (jarFiles != null && jarFiles.nonEmpty) {
+      jarFiles.head.getAbsolutePath
+    } else {
+      null
+    }
+  }
+
+  protected def buildLeNet(): MXNet = {
+    val workingDir = composeWorkingDirPath
+    val assemblyRoot = s"$workingDir/../assembly"
+    new MXNet()
+      .setBatchSize(128)
+      .setLabelName("softmax_label")
+      .setContext(Array(Context.cpu(0), Context.cpu(1)))
+      .setDimension(Shape(1, 28, 28))
+      .setNetwork(getLenet)
+      .setNumEpoch(10)
+      .setNumServer(1)
+      .setNumWorker(numWorkers)
+      .setExecutorJars(s"${getJarFilePath(assemblyRoot)},$getSparkJar")
+      .setJava("java")
+  }
+
+  protected def buildMlp(): MXNet = {
+    val workingDir = composeWorkingDirPath
+    val assemblyRoot = s"$workingDir/../assembly"
+    new MXNet()
+      .setBatchSize(128)
+      .setLabelName("softmax_label")
+      .setContext(Array(Context.cpu(0), Context.cpu(1)))
+      .setDimension(Shape(784))
+      .setNetwork(getMlp)
+      .setNumEpoch(10)
+      .setNumServer(1)
+      .setNumWorker(numWorkers)
+      .setExecutorJars(s"${getJarFilePath(assemblyRoot)},$getSparkJar")
+      .setJava("java")
+      .setTimeout(0)
+  }
+}
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 373081b..dd3464b 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -61,6 +61,7 @@ class KVStoreDist : public KVStoreLocal {
 
   virtual ~KVStoreDist() {
     Engine::Get()->WaitForAll();
+    customer_id_ = 0;
     if (IsWorkerNode()) {
       if (barrier_before_exit_) {
         Barrier();
@@ -183,7 +184,7 @@ class KVStoreDist : public KVStoreLocal {
     for (size_t i = 0; i < keys.size(); ++i) {
       comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
     }
-    if (get_rank() == 0) {
+    if (get_rank() == 0 && this->ps_worker_->get_customer()->customer_id() == 0) {
       Push_(keys, values, 0, false);
       // wait until the push is finished
       for (const int key : keys) {
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index 421de27..a150ff4 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -103,7 +103,8 @@ class Executor {
       lk.unlock();
 
       if (blk.f) {
-        blk.f(); blk.p->set_value();
+        blk.f();
+        blk.p->set_value();
       } else {
         blk.p->set_value(); break;
       }

-- 
To stop receiving notification emails like this one, please contact
liuyi...@apache.org.
