nswamy closed pull request #11547: [MXNET-531] GAN MNIST Examples for Scala new 
API
URL: https://github.com/apache/incubator-mxnet/pull/11547
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
index 4fa96f62875..6186989b74f 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
@@ -17,66 +17,52 @@
 
 package org.apache.mxnetexamples.gan
 
+import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, Shape, 
Symbol, Xavier}
+import org.apache.mxnet.optimizer.Adam
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
+
 import scala.collection.JavaConverters._
-import Viz._
-import org.apache.mxnet.Context
-import org.apache.mxnet.Shape
-import org.apache.mxnet.IO
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.CustomMetric
-import org.apache.mxnet.Xavier
-import org.apache.mxnet.optimizer.Adam
-import org.apache.mxnet.DataBatch
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.Shape
 
-/**
- * @author Depeng Liang
- */
 object GanMnist {
 
   private val logger = LoggerFactory.getLogger(classOf[GanMnist])
 
-    // a deconv layer that enlarges the feature map
+  // a deconv layer that enlarges the feature map
   def deconv2D(data: Symbol, iShape: Shape, oShape: Shape,
-    kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): Symbol = {
-    val targetShape = (oShape(oShape.length - 2), oShape(oShape.length - 1))
-    val net = Symbol.Deconvolution(name)()(Map(
-                                           "data" -> data,
-                                           "kernel" -> s"$kShape",
-                                           "stride" -> s"$stride",
-                                           "target_shape" -> s"$targetShape",
-                                           "num_filter" -> oShape(0),
-                                           "no_bias" -> true))
+               kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): 
Symbol = {
+    val targetShape = Shape(oShape(oShape.length - 2), oShape(oShape.length - 
1))
+    val net = Symbol.api.Deconvolution(data = Some(data), kernel = 
Shape(kShape._1, kShape._2),
+      stride = Some(Shape(stride._1, stride._2)), target_shape = 
Some(targetShape),
+      num_filter = oShape(0), no_bias = Some(true), name = name)
     net
   }
 
   def deconv2DBnRelu(data: Symbol, prefix: String, iShape: Shape,
-      oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 1e-12f): Symbol 
= {
+                     oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 
1e-12f): Symbol = {
     var net = deconv2D(data, iShape, oShape, kShape, name = 
s"${prefix}_deconv")
-    net = Symbol.BatchNorm(s"${prefix}_bn")()(Map("data" -> net, "fix_gamma" 
-> true, "eps" -> eps))
-    net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type" 
-> "relu"))
+    net = Symbol.api.BatchNorm(name = s"${prefix}_bn", data = Some(net),
+      fix_gamma = Some(true), eps = Some(eps))
+    net = Symbol.api.Activation(data = Some(net), act_type = "relu", name = 
s"${prefix}_act")
     net
   }
 
   def deconv2DAct(data: Symbol, prefix: String, actType: String,
-    iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
+                  iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
     var net = deconv2D(data, iShape, oShape, kShape, name = 
s"${prefix}_deconv")
-    net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type" 
-> actType))
+    net = Symbol.api.Activation(data = Some(net), act_type = actType, name = 
s"${prefix}_act") // use the caller-supplied activation, as the old API call did
     net
   }
 
   def makeDcganSym(oShape: Shape, ngf: Int = 128, finalAct: String = "sigmoid",
-      eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {
+                   eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {
 
     val code = Symbol.Variable("rand")
-    var net = Symbol.FullyConnected("g1")()(Map("data" -> code,
-      "num_hidden" -> 4 * 4 * ngf * 4, "no_bias" -> true))
-    net = Symbol.Activation("gact1")()(Map("data" -> net, "act_type" -> 
"relu"))
+    var net = Symbol.api.FullyConnected(data = Some(code), num_hidden = 4 * 4 
* ngf * 4,
+      no_bias = Some(true), name = "g1") // same layer name as before the API migration
+    net = Symbol.api.Activation(data = Some(net), act_type = "relu", name = 
"gact1")
     // 4 x 4
-    net = Symbol.Reshape()()(Map("data" -> net, "shape" -> s"(-1, ${ngf * 4}, 
4, 4)"))
+    net = Symbol.api.Reshape(data = Some(net), shape = Some(Shape(-1, ngf * 4, 
4, 4)))
     // 8 x 8
     net = deconv2DBnRelu(net, prefix = "g2",
       iShape = Shape(ngf * 4, 4, 4), oShape = Shape(ngf * 2, 8, 8), kShape = 
(3, 3))
@@ -89,22 +75,22 @@ object GanMnist {
 
     val data = Symbol.Variable("data")
     // 28 x 28
-    val conv1 = Symbol.Convolution("conv1")()(Map("data" -> data,
-      "kernel" -> "(5,5)", "num_filter" -> 20))
-    val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> 
"tanh"))
-    val pool1 = Symbol.Pooling()()(Map("data" -> tanh1,
-      "pool_type" -> "max", "kernel" -> "(2,2)", "stride" -> "(2,2)"))
+    val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5),
+      num_filter = 20, name = "conv1")
+    val tanh1 = Symbol.api.Activation(data = Some(conv1), act_type = "tanh")
+    val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
+      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
     // second conv
-    val conv2 = Symbol.Convolution("conv2")()(Map("data" -> pool1,
-      "kernel" -> "(5,5)", "num_filter" -> 50))
-    val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> 
"tanh"))
-    val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
-      "kernel" -> "(2,2)", "stride" -> "(2,2)"))
-    var d5 = Symbol.Flatten()()(Map("data" -> pool2))
-    d5 = Symbol.FullyConnected("fc1")()(Map("data" -> d5, "num_hidden" -> 500))
-    d5 = Symbol.Activation()()(Map("data" -> d5, "act_type" -> "tanh"))
-    d5 = Symbol.FullyConnected("fc_dloss")()(Map("data" -> d5, "num_hidden" -> 
1))
-    val dloss = Symbol.LogisticRegressionOutput("dloss")()(Map("data" -> d5))
+    val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 
5),
+      num_filter = 50, name = "conv2")
+    val tanh2 = Symbol.api.Activation(data = Some(conv2), act_type = "tanh")
+    val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
+      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
+    var d5 = Symbol.api.Flatten(data = Some(pool2))
+    d5 = Symbol.api.FullyConnected(data = Some(d5), num_hidden = 500, name = 
"fc1")
+    d5 = Symbol.api.Activation(data = Some(d5), act_type = "tanh")
+    d5 = Symbol.api.FullyConnected(data = Some(d5), num_hidden = 1, name = 
"fc_dloss")
+    val dloss = Symbol.api.LogisticRegressionOutput(data = Some(d5), name = 
"dloss")
 
     (gout, dloss)
   }
@@ -116,6 +102,79 @@ object GanMnist {
     labelArr.zip(predArr).map { case (l, p) => Math.abs(l - p) }.sum / 
label.shape(0)
   }
 
+  def runTraining(dataPath : String, context : Context,
+                  outputPath : String, numEpoch : Int): Float = {
+    val lr = 0.0005f
+    val beta1 = 0.5f
+    val batchSize = 100
+    val randShape = Shape(batchSize, 100)
+    val dataShape = Shape(batchSize, 1, 28, 28)
+
+    val (symGen, symDec) =
+      makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")
+
+    val gMod = new GANModule(
+      symGen,
+      symDec,
+      context = context,
+      dataShape = dataShape,
+      codeShape = randShape)
+
+    gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
+    gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))
+
+    gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))
+
+    val params = Map(
+      "image" -> s"$dataPath/train-images-idx3-ubyte",
+      "label" -> s"$dataPath/train-labels-idx1-ubyte",
+      "input_shape" -> s"(1, 28, 28)",
+      "batch_size" -> s"$batchSize",
+      "shuffle" -> "True"
+    )
+
+    val mnistIter = IO.MNISTIter(params)
+
+    val metricAcc = new CustomMetric(ferr, "ferr")
+
+    var t = 0
+    var dataBatch: DataBatch = null
+    var acc = 0.0f
+    for (epoch <- 0 until numEpoch) {
+      mnistIter.reset()
+      metricAcc.reset()
+      t = 0
+      while (mnistIter.hasNext) {
+        dataBatch = mnistIter.next()
+        gMod.update(dataBatch)
+        gMod.dLabel.set(0f)
+        metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
+        gMod.dLabel.set(1f)
+        metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)
+
+        if (t % 50 == 0) {
+          val (name, value) = metricAcc.get
+          acc = value(0)
+          logger.info(s"epoch: $epoch, iter $t, metric=${value.mkString(" ")}")
+          Viz.imSave("gout", outputPath, gMod.tempOutG(0), flip = true)
+          val diff = gMod.tempDiffD
+          val arr = diff.toArray
+          val mean = arr.sum / arr.length
+          val std = {
+            val tmpA = arr.map(a => (a - mean) * (a - mean))
+            Math.sqrt(tmpA.sum / tmpA.length).toFloat
+          }
+          diff.set((diff - mean) / std + 0.5f)
+          Viz.imSave("diff", outputPath, diff, flip = true)
+          Viz.imSave("data", outputPath, dataBatch.data(0), flip = true)
+        }
+
+        t += 1
+      }
+    }
+    acc
+  }
+
   def main(args: Array[String]): Unit = {
     val anst = new GanMnist
     val parser: CmdLineParser = new CmdLineParser(anst)
@@ -123,78 +182,12 @@ object GanMnist {
       parser.parseArgument(args.toList.asJava)
 
       val dataPath = if (anst.mnistDataPath == null) 
System.getenv("MXNET_DATA_DIR")
-        else anst.mnistDataPath
+      else anst.mnistDataPath
 
       assert(dataPath != null)
-
-      val lr = 0.0005f
-      val beta1 = 0.5f
-      val batchSize = 100
-      val randShape = Shape(batchSize, 100)
-      val numEpoch = 100
-      val dataShape = Shape(batchSize, 1, 28, 28)
       val context = if (anst.gpu == -1) Context.cpu() else 
Context.gpu(anst.gpu)
 
-      val (symGen, symDec) =
-      makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")
-
-      val gMod = new GANModule(
-          symGen,
-          symDec,
-          context = context,
-          dataShape = dataShape,
-          codeShape = randShape)
-
-      gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
-      gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))
-
-      gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))
-
-      val params = Map(
-        "image" -> s"${dataPath}/train-images-idx3-ubyte",
-        "label" -> s"${dataPath}/train-labels-idx1-ubyte",
-        "input_shape" -> s"(1, 28, 28)",
-        "batch_size" -> s"$batchSize",
-        "shuffle" -> "True"
-      )
-
-      val mnistIter = IO.MNISTIter(params)
-
-      val metricAcc = new CustomMetric(ferr, "ferr")
-
-      var t = 0
-      var dataBatch: DataBatch = null
-      for (epoch <- 0 until numEpoch) {
-        mnistIter.reset()
-        metricAcc.reset()
-        t = 0
-        while (mnistIter.hasNext) {
-          dataBatch = mnistIter.next()
-          gMod.update(dataBatch)
-          gMod.dLabel.set(0f)
-          metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
-          gMod.dLabel.set(1f)
-          metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)
-
-          if (t % 50 == 0) {
-            val (name, value) = metricAcc.get
-            logger.info(s"epoch: $epoch, iter $t, metric=$value")
-            Viz.imSave("gout", anst.outputPath, gMod.tempOutG(0), flip = true)
-            val diff = gMod.tempDiffD
-            val arr = diff.toArray
-            val mean = arr.sum / arr.length
-            val std = {
-              val tmpA = arr.map(a => (a - mean) * (a - mean))
-              Math.sqrt(tmpA.sum / tmpA.length).toFloat
-            }
-            diff.set((diff - mean) / std + 0.5f)
-            Viz.imSave("diff", anst.outputPath, diff, flip = true)
-            Viz.imSave("data", anst.outputPath, dataBatch.data(0), flip = true)
-          }
-
-          t += 1
-        }
-      }
+      runTraining(dataPath, context, anst.outputPath, 100)
     } catch {
       case ex: Exception => {
         logger.error(ex.getMessage, ex)
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
index faab945067a..55b52965230 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
@@ -26,9 +26,6 @@ import org.apache.mxnet.Initializer
 import org.apache.mxnet.DataBatch
 import org.apache.mxnet.Random
 
-/**
- * @author Depeng Liang
- */
 class GANModule(
               symbolGenerator: Symbol,
               symbolEncoder: Symbol,
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/README.md 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/README.md
new file mode 100644
index 00000000000..40db092727c
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/README.md
@@ -0,0 +1,18 @@
+# GAN MNIST Example for Scala
+This is the GAN MNIST training example implemented with the Scala type-safe API.
+
+This example is only for Illustration and not modeled to achieve the best 
accuracy.
+## Setup
+### Download the source File
+```$xslt
+https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/mnist/mnist.zip
+```
+### Unzip the file
+```$xslt
+unzip mnist.zip
+```
+### Argument Configuration
+Then you need to define the arguments that you would like to pass in the model:
+```$xslt
+--mnist-data-path <location of your downloaded file>
+```
\ No newline at end of file
diff --git 
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
 
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
new file mode 100644
index 00000000000..12459fb1cc1
--- /dev/null
+++ 
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.gan
+
+import java.io.File
+import java.net.URL
+
+import org.apache.commons.io.FileUtils
+import org.apache.mxnet.Context
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+import scala.sys.process.Process
+
+class GanExampleSuite extends FunSuite with BeforeAndAfterAll{
+  private val logger = LoggerFactory.getLogger(classOf[GanExampleSuite])
+
+  test("Example CI: Test GAN MNIST") {
+    if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+      System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+      logger.info("Downloading mnist model")
+      val baseUrl = 
"https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci";
+      val tempDirPath = System.getProperty("java.io.tmpdir")
+      val modelDirPath = tempDirPath + File.separator + "mnist/"
+      logger.info("tempDirPath: %s".format(tempDirPath))
+      val tmpFile = new File(tempDirPath + "/mnist/mnist.zip")
+      if (!tmpFile.exists()) {
+        FileUtils.copyURLToFile(new URL(baseUrl + "/mnist/mnist.zip"),
+          tmpFile)
+      }
+      // TODO: Need to confirm with Windows
+      Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
+        + tempDirPath + "/mnist/") !
+
+      val context = Context.gpu()
+
+      val output = GanMnist.runTraining(modelDirPath, context, modelDirPath, 5)
+      Process("rm -rf " + modelDirPath) !
+
+      assert(output >= 0.0f)
+    } else {
+      logger.info("GPU test only, skipped...")
+    }
+  }
+}


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to