This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 92f0c51  [MXNET-531] GAN MNIST Examples for Scala new API (#11547)
92f0c51 is described below

commit 92f0c512336db4219fdcd97fccf74dcdeb56dbf4
Author: Lanking <[email protected]>
AuthorDate: Tue Jul 3 22:19:10 2018 -0700

    [MXNET-531] GAN MNIST Examples for Scala new API (#11547)
    
    * add gan base file and example suite
---
 .../org/apache/mxnetexamples/gan/GanMnist.scala    | 227 ++++++++++-----------
 .../org/apache/mxnetexamples/gan/Module.scala      |   3 -
 .../scala/org/apache/mxnetexamples/gan/README.md   |  18 ++
 .../apache/mxnetexamples/gan/GanExampleSuite.scala |  60 ++++++
 4 files changed, 188 insertions(+), 120 deletions(-)

diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
index 4fa96f6..6186989 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
@@ -17,66 +17,52 @@
 
 package org.apache.mxnetexamples.gan
 
+import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, Shape, 
Symbol, Xavier}
+import org.apache.mxnet.optimizer.Adam
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
+
 import scala.collection.JavaConverters._
-import Viz._
-import org.apache.mxnet.Context
-import org.apache.mxnet.Shape
-import org.apache.mxnet.IO
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.CustomMetric
-import org.apache.mxnet.Xavier
-import org.apache.mxnet.optimizer.Adam
-import org.apache.mxnet.DataBatch
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.Shape
 
-/**
- * @author Depeng Liang
- */
 object GanMnist {
 
   private val logger = LoggerFactory.getLogger(classOf[GanMnist])
 
-    // a deconv layer that enlarges the feature map
+  // a deconv layer that enlarges the feature map to the last two dims of oShape (iShape unused)
   def deconv2D(data: Symbol, iShape: Shape, oShape: Shape,
-    kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): Symbol = {
-    val targetShape = (oShape(oShape.length - 2), oShape(oShape.length - 1))
-    val net = Symbol.Deconvolution(name)()(Map(
-                                           "data" -> data,
-                                           "kernel" -> s"$kShape",
-                                           "stride" -> s"$stride",
-                                           "target_shape" -> s"$targetShape",
-                                           "num_filter" -> oShape(0),
-                                           "no_bias" -> true))
+               kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): 
Symbol = {
+    val targetShape = Shape(oShape(oShape.length - 2), oShape(oShape.length - 
1))
+    val net = Symbol.api.Deconvolution(data = Some(data), kernel = 
Shape(kShape._1, kShape._2),
+      stride = Some(Shape(stride._1, stride._2)), target_shape = 
Some(targetShape),
+      num_filter = oShape(0), no_bias = Some(true), name = name)
     net
   }
 
   def deconv2DBnRelu(data: Symbol, prefix: String, iShape: Shape,
-      oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 1e-12f): Symbol 
= {
+                     oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 
1e-12f): Symbol = {
     var net = deconv2D(data, iShape, oShape, kShape, name = 
s"${prefix}_deconv")
-    net = Symbol.BatchNorm(s"${prefix}_bn")()(Map("data" -> net, "fix_gamma" 
-> true, "eps" -> eps))
-    net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type" 
-> "relu"))
+    net = Symbol.api.BatchNorm(name = s"${prefix}_bn", data = Some(net),
+      fix_gamma = Some(true), eps = Some(eps))
+    net = Symbol.api.Activation(data = Some(net), act_type = "relu", name = 
s"${prefix}_act")
     net
   }
 
   def deconv2DAct(data: Symbol, prefix: String, actType: String,
-    iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
+                  iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
     var net = deconv2D(data, iShape, oShape, kShape, name = 
s"${prefix}_deconv")
-    net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type" 
-> actType))
+    net = Symbol.api.Activation(data = Some(net), act_type = actType, name = 
s"${prefix}_act")
     net
   }
 
   def makeDcganSym(oShape: Shape, ngf: Int = 128, finalAct: String = "sigmoid",
-      eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {
+                   eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {
 
     val code = Symbol.Variable("rand")
-    var net = Symbol.FullyConnected("g1")()(Map("data" -> code,
-      "num_hidden" -> 4 * 4 * ngf * 4, "no_bias" -> true))
-    net = Symbol.Activation("gact1")()(Map("data" -> net, "act_type" -> 
"relu"))
+    var net = Symbol.api.FullyConnected(data = Some(code), num_hidden = 4 * 4 
* ngf * 4,
+      no_bias = Some(true), name = "g1")
+    net = Symbol.api.Activation(data = Some(net), act_type = "relu", name = 
"gact1")
     // 4 x 4
-    net = Symbol.Reshape()()(Map("data" -> net, "shape" -> s"(-1, ${ngf * 4}, 
4, 4)"))
+    net = Symbol.api.Reshape(data = Some(net), shape = Some(Shape(-1, ngf * 4, 
4, 4)))
     // 8 x 8
     net = deconv2DBnRelu(net, prefix = "g2",
       iShape = Shape(ngf * 4, 4, 4), oShape = Shape(ngf * 2, 8, 8), kShape = 
(3, 3))
@@ -89,22 +75,22 @@ object GanMnist {
 
     val data = Symbol.Variable("data")
     // 28 x 28
-    val conv1 = Symbol.Convolution("conv1")()(Map("data" -> data,
-      "kernel" -> "(5,5)", "num_filter" -> 20))
-    val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> 
"tanh"))
-    val pool1 = Symbol.Pooling()()(Map("data" -> tanh1,
-      "pool_type" -> "max", "kernel" -> "(2,2)", "stride" -> "(2,2)"))
+    val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5),
+      num_filter = 20, name = "conv1")
+    val tanh1 = Symbol.api.Activation(data = Some(conv1), act_type = "tanh")
+    val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
+      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
     // second conv
-    val conv2 = Symbol.Convolution("conv2")()(Map("data" -> pool1,
-      "kernel" -> "(5,5)", "num_filter" -> 50))
-    val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> 
"tanh"))
-    val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
-      "kernel" -> "(2,2)", "stride" -> "(2,2)"))
-    var d5 = Symbol.Flatten()()(Map("data" -> pool2))
-    d5 = Symbol.FullyConnected("fc1")()(Map("data" -> d5, "num_hidden" -> 500))
-    d5 = Symbol.Activation()()(Map("data" -> d5, "act_type" -> "tanh"))
-    d5 = Symbol.FullyConnected("fc_dloss")()(Map("data" -> d5, "num_hidden" -> 
1))
-    val dloss = Symbol.LogisticRegressionOutput("dloss")()(Map("data" -> d5))
+    val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 
5),
+      num_filter = 50, name = "conv2")
+    val tanh2 = Symbol.api.Activation(data = Some(conv2), act_type = "tanh")
+    val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
+      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
+    var d5 = Symbol.api.Flatten(data = Some(pool2))
+    d5 = Symbol.api.FullyConnected(data = Some(d5), num_hidden = 500, name = 
"fc1")
+    d5 = Symbol.api.Activation(data = Some(d5), act_type = "tanh")
+    d5 = Symbol.api.FullyConnected(data = Some(d5), num_hidden = 1, name = 
"fc_dloss")
+    val dloss = Symbol.api.LogisticRegressionOutput(data = Some(d5), name = 
"dloss")
 
     (gout, dloss)
   }
@@ -116,6 +102,79 @@ object GanMnist {
     labelArr.zip(predArr).map { case (l, p) => Math.abs(l - p) }.sum / 
label.shape(0)
   }
 
+  def runTraining(dataPath : String, context : Context,
+                  outputPath : String, numEpoch : Int): Float = {
+    val lr = 0.0005f
+    val beta1 = 0.5f
+    val batchSize = 100
+    val randShape = Shape(batchSize, 100)
+    val dataShape = Shape(batchSize, 1, 28, 28)
+
+    val (symGen, symDec) =
+      makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")
+
+    val gMod = new GANModule(
+      symGen,
+      symDec,
+      context = context,
+      dataShape = dataShape,
+      codeShape = randShape)
+
+    gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
+    gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))
+
+    gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))
+
+    val params = Map(
+      "image" -> s"$dataPath/train-images-idx3-ubyte",
+      "label" -> s"$dataPath/train-labels-idx1-ubyte",
+      "input_shape" -> s"(1, 28, 28)",
+      "batch_size" -> s"$batchSize",
+      "shuffle" -> "True"
+    )
+
+    val mnistIter = IO.MNISTIter(params)
+
+    val metricAcc = new CustomMetric(ferr, "ferr")
+
+    var t = 0
+    var dataBatch: DataBatch = null
+    var acc = 0.0f
+    for (epoch <- 0 until numEpoch) {
+      mnistIter.reset()
+      metricAcc.reset()
+      t = 0
+      while (mnistIter.hasNext) {
+        dataBatch = mnistIter.next()
+        gMod.update(dataBatch)
+        gMod.dLabel.set(0f)
+        metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
+        gMod.dLabel.set(1f)
+        metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)
+
+        if (t % 50 == 0) {
+          val (_, value) = metricAcc.get
+          acc = value(0)
+          logger.info(s"epoch: $epoch, iter $t, metric=${value.mkString(" ")}")
+          Viz.imSave("gout", outputPath, gMod.tempOutG(0), flip = true)
+          val diff = gMod.tempDiffD
+          val arr = diff.toArray
+          val mean = arr.sum / arr.length
+          val std = {
+            val tmpA = arr.map(a => (a - mean) * (a - mean))
+            Math.sqrt(tmpA.sum / tmpA.length).toFloat
+          }
+          diff.set((diff - mean) / std + 0.5f)
+          Viz.imSave("diff", outputPath, diff, flip = true)
+          Viz.imSave("data", outputPath, dataBatch.data(0), flip = true)
+        }
+
+        t += 1
+      }
+    }
+    acc
+  }
+
   def main(args: Array[String]): Unit = {
     val anst = new GanMnist
     val parser: CmdLineParser = new CmdLineParser(anst)
@@ -123,78 +182,12 @@ object GanMnist {
       parser.parseArgument(args.toList.asJava)
 
       val dataPath = if (anst.mnistDataPath == null) 
System.getenv("MXNET_DATA_DIR")
-        else anst.mnistDataPath
+      else anst.mnistDataPath
 
       assert(dataPath != null)
-
-      val lr = 0.0005f
-      val beta1 = 0.5f
-      val batchSize = 100
-      val randShape = Shape(batchSize, 100)
-      val numEpoch = 100
-      val dataShape = Shape(batchSize, 1, 28, 28)
       val context = if (anst.gpu == -1) Context.cpu() else 
Context.gpu(anst.gpu)
 
-      val (symGen, symDec) =
-      makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")
-
-      val gMod = new GANModule(
-          symGen,
-          symDec,
-          context = context,
-          dataShape = dataShape,
-          codeShape = randShape)
-
-      gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
-      gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))
-
-      gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))
-
-      val params = Map(
-        "image" -> s"${dataPath}/train-images-idx3-ubyte",
-        "label" -> s"${dataPath}/train-labels-idx1-ubyte",
-        "input_shape" -> s"(1, 28, 28)",
-        "batch_size" -> s"$batchSize",
-        "shuffle" -> "True"
-      )
-
-      val mnistIter = IO.MNISTIter(params)
-
-      val metricAcc = new CustomMetric(ferr, "ferr")
-
-      var t = 0
-      var dataBatch: DataBatch = null
-      for (epoch <- 0 until numEpoch) {
-        mnistIter.reset()
-        metricAcc.reset()
-        t = 0
-        while (mnistIter.hasNext) {
-          dataBatch = mnistIter.next()
-          gMod.update(dataBatch)
-          gMod.dLabel.set(0f)
-          metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
-          gMod.dLabel.set(1f)
-          metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)
-
-          if (t % 50 == 0) {
-            val (name, value) = metricAcc.get
-            logger.info(s"epoch: $epoch, iter $t, metric=$value")
-            Viz.imSave("gout", anst.outputPath, gMod.tempOutG(0), flip = true)
-            val diff = gMod.tempDiffD
-            val arr = diff.toArray
-            val mean = arr.sum / arr.length
-            val std = {
-              val tmpA = arr.map(a => (a - mean) * (a - mean))
-              Math.sqrt(tmpA.sum / tmpA.length).toFloat
-            }
-            diff.set((diff - mean) / std + 0.5f)
-            Viz.imSave("diff", anst.outputPath, diff, flip = true)
-            Viz.imSave("data", anst.outputPath, dataBatch.data(0), flip = true)
-          }
-
-          t += 1
-        }
-      }
+      runTraining(dataPath, context, anst.outputPath, 100)
     } catch {
       case ex: Exception => {
         logger.error(ex.getMessage, ex)
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
index faab945..55b5296 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
@@ -26,9 +26,6 @@ import org.apache.mxnet.Initializer
 import org.apache.mxnet.DataBatch
 import org.apache.mxnet.Random
 
-/**
- * @author Depeng Liang
- */
 class GANModule(
               symbolGenerator: Symbol,
               symbolEncoder: Symbol,
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/README.md 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/README.md
new file mode 100644
index 0000000..40db092
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/README.md
@@ -0,0 +1,18 @@
+# GAN MNIST Example for Scala
+This is the GAN MNIST training example implemented with the Scala type-safe API.
+
+This example is for illustration only and is not tuned to achieve the best 
accuracy.
+## Setup
+### Download the source file
+```
+https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/mnist/mnist.zip
+```
+### Unzip the file
+```bash
+unzip mnist.zip
+```
+### Argument Configuration
+Then define the arguments you would like to pass to the model:
+```bash
+--mnist-data-path <directory containing the unzipped MNIST files>
+```
\ No newline at end of file
diff --git 
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
 
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
new file mode 100644
index 0000000..12459fb
--- /dev/null
+++ 
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.gan
+
+import java.io.File
+import java.net.URL
+
+import org.apache.commons.io.FileUtils
+import org.apache.mxnet.Context
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+import scala.sys.process.Process
+
+class GanExampleSuite extends FunSuite with BeforeAndAfterAll {
+  private val logger = LoggerFactory.getLogger(classOf[GanExampleSuite])
+
+  test("Example CI: Test GAN MNIST") {
+    if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+      System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+      logger.info("Downloading mnist model")
+      val baseUrl = 
"https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
+      val tempDirPath = System.getProperty("java.io.tmpdir")
+      val modelDirPath = tempDirPath + File.separator + "mnist/"
+      logger.info("tempDirPath: %s".format(tempDirPath))
+      val tmpFile = new File(tempDirPath + "/mnist/mnist.zip")
+      if (!tmpFile.exists()) {
+        FileUtils.copyURLToFile(new URL(baseUrl + "/mnist/mnist.zip"),
+          tmpFile)
+      }
+      // TODO: relies on the external `unzip` tool; confirm before enabling on Windows
+      assert(Process("unzip -o " + tempDirPath + "/mnist/mnist.zip -d "
+        + tempDirPath + "/mnist/").! == 0, "failed to unzip mnist.zip")
+
+      val context = Context.gpu()
+
+      val output = GanMnist.runTraining(modelDirPath, context, modelDirPath, 5)
+      FileUtils.deleteDirectory(new File(modelDirPath))
+
+      assert(output >= 0.0f)
+    } else {
+      logger.info("GPU test only, skipped...")
+    }
+  }
+}

Reply via email to