This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 92f0c51 [MXNET-531] GAN MNIST Examples for Scala new API (#11547)
92f0c51 is described below
commit 92f0c512336db4219fdcd97fccf74dcdeb56dbf4
Author: Lanking <[email protected]>
AuthorDate: Tue Jul 3 22:19:10 2018 -0700
[MXNET-531] GAN MNIST Examples for Scala new API (#11547)
* add gan base file and example suite
---
.../org/apache/mxnetexamples/gan/GanMnist.scala | 227 ++++++++++-----------
.../org/apache/mxnetexamples/gan/Module.scala | 3 -
.../scala/org/apache/mxnetexamples/gan/README.md | 18 ++
.../apache/mxnetexamples/gan/GanExampleSuite.scala | 60 ++++++
4 files changed, 188 insertions(+), 120 deletions(-)
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
index 4fa96f6..6186989 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
@@ -17,66 +17,52 @@
package org.apache.mxnetexamples.gan
+import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, Shape,
Symbol, Xavier}
+import org.apache.mxnet.optimizer.Adam
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory
+
import scala.collection.JavaConverters._
-import Viz._
-import org.apache.mxnet.Context
-import org.apache.mxnet.Shape
-import org.apache.mxnet.IO
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.CustomMetric
-import org.apache.mxnet.Xavier
-import org.apache.mxnet.optimizer.Adam
-import org.apache.mxnet.DataBatch
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.Shape
-/**
- * @author Depeng Liang
- */
object GanMnist {
private val logger = LoggerFactory.getLogger(classOf[GanMnist])
- // a deconv layer that enlarges the feature map
+ // a deconv layer that enlarges the feature map
+  /**
+   * Builds a Deconvolution symbol named `name` with `oShape(0)` filters and no
+   * bias term; `target_shape` pins the output height/width to the last two
+   * dimensions of `oShape`. Note: `iShape` is not read by this function.
+   */
def deconv2D(data: Symbol, iShape: Shape, oShape: Shape,
- kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): Symbol = {
- val targetShape = (oShape(oShape.length - 2), oShape(oShape.length - 1))
- val net = Symbol.Deconvolution(name)()(Map(
- "data" -> data,
- "kernel" -> s"$kShape",
- "stride" -> s"$stride",
- "target_shape" -> s"$targetShape",
- "num_filter" -> oShape(0),
- "no_bias" -> true))
+ kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)):
Symbol = {
+ val targetShape = Shape(oShape(oShape.length - 2), oShape(oShape.length -
1))
+ val net = Symbol.api.Deconvolution(data = Some(data), kernel =
Shape(kShape._1, kShape._2),
+ stride = Some(Shape(stride._1, stride._2)), target_shape =
Some(targetShape),
+ num_filter = oShape(0), no_bias = Some(true), name = name)
net
}
+  /**
+   * Deconvolution (`${prefix}_deconv`) -> BatchNorm with fix_gamma = true and
+   * the given `eps` (`${prefix}_bn`) -> ReLU activation (`${prefix}_act`).
+   */
def deconv2DBnRelu(data: Symbol, prefix: String, iShape: Shape,
- oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 1e-12f): Symbol
= {
+ oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f +
1e-12f): Symbol = {
var net = deconv2D(data, iShape, oShape, kShape, name =
s"${prefix}_deconv")
- net = Symbol.BatchNorm(s"${prefix}_bn")()(Map("data" -> net, "fix_gamma"
-> true, "eps" -> eps))
- net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type"
-> "relu"))
+ net = Symbol.api.BatchNorm(name = s"${prefix}_bn", data = Some(net),
+ fix_gamma = Some(true), eps = Some(eps))
+ net = Symbol.api.Activation(data = Some(net), act_type = "relu", name =
s"${prefix}_act")
net
}
+  /**
+   * Deconvolution (`${prefix}_deconv`) followed by an activation of type
+   * `actType` (`${prefix}_act`), without batch norm.
+   */
def deconv2DAct(data: Symbol, prefix: String, actType: String,
- iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
+ iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
var net = deconv2D(data, iShape, oShape, kShape, name =
s"${prefix}_deconv")
- net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type"
-> actType))
+    // Honor the caller's actType, as the pre-migration code did; hard-coding
+    // "relu" here silently ignored the actType argument.
+ net = Symbol.api.Activation(data = Some(net), act_type = actType, name =
s"${prefix}_act")
net
}
def makeDcganSym(oShape: Shape, ngf: Int = 128, finalAct: String = "sigmoid",
- eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {
+ eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {
val code = Symbol.Variable("rand")
- var net = Symbol.FullyConnected("g1")()(Map("data" -> code,
- "num_hidden" -> 4 * 4 * ngf * 4, "no_bias" -> true))
- net = Symbol.Activation("gact1")()(Map("data" -> net, "act_type" ->
"relu"))
+    // Layer name is "g1" (no leading space), matching the name used before the
+    // migration to the type-safe Symbol.api.
+ var net = Symbol.api.FullyConnected(data = Some(code), num_hidden = 4 * 4
* ngf * 4,
+ no_bias = Some(true), name = "g1")
+ net = Symbol.api.Activation(data = Some(net), act_type = "relu", name =
"gact1")
// 4 x 4
- net = Symbol.Reshape()()(Map("data" -> net, "shape" -> s"(-1, ${ngf * 4},
4, 4)"))
+ net = Symbol.api.Reshape(data = Some(net), shape = Some(Shape(-1, ngf * 4,
4, 4)))
// 8 x 8
net = deconv2DBnRelu(net, prefix = "g2",
iShape = Shape(ngf * 4, 4, 4), oShape = Shape(ngf * 2, 8, 8), kShape =
(3, 3))
@@ -89,22 +75,22 @@ object GanMnist {
val data = Symbol.Variable("data")
// 28 x 28
- val conv1 = Symbol.Convolution("conv1")()(Map("data" -> data,
- "kernel" -> "(5,5)", "num_filter" -> 20))
- val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" ->
"tanh"))
- val pool1 = Symbol.Pooling()()(Map("data" -> tanh1,
- "pool_type" -> "max", "kernel" -> "(2,2)", "stride" -> "(2,2)"))
+ val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5),
+ num_filter = 20, name = "conv1")
+ val tanh1 = Symbol.api.Activation(data = Some(conv1), act_type = "tanh")
+ val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
+ kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
// second conv
- val conv2 = Symbol.Convolution("conv2")()(Map("data" -> pool1,
- "kernel" -> "(5,5)", "num_filter" -> 50))
- val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" ->
"tanh"))
- val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
- "kernel" -> "(2,2)", "stride" -> "(2,2)"))
- var d5 = Symbol.Flatten()()(Map("data" -> pool2))
- d5 = Symbol.FullyConnected("fc1")()(Map("data" -> d5, "num_hidden" -> 500))
- d5 = Symbol.Activation()()(Map("data" -> d5, "act_type" -> "tanh"))
- d5 = Symbol.FullyConnected("fc_dloss")()(Map("data" -> d5, "num_hidden" ->
1))
- val dloss = Symbol.LogisticRegressionOutput("dloss")()(Map("data" -> d5))
+ val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5,
5),
+ num_filter = 50, name = "conv2")
+ val tanh2 = Symbol.api.Activation(data = Some(conv2), act_type = "tanh")
+ val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
+ kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
+ var d5 = Symbol.api.Flatten(data = Some(pool2))
+ d5 = Symbol.api.FullyConnected(data = Some(d5), num_hidden = 500, name =
"fc1")
+ d5 = Symbol.api.Activation(data = Some(d5), act_type = "tanh")
+ d5 = Symbol.api.FullyConnected(data = Some(d5), num_hidden = 1, name =
"fc_dloss")
+ val dloss = Symbol.api.LogisticRegressionOutput(data = Some(d5), name =
"dloss")
(gout, dloss)
}
@@ -116,6 +102,79 @@ object GanMnist {
labelArr.zip(predArr).map { case (l, p) => Math.abs(l - p) }.sum /
label.shape(0)
}
+  /**
+   * Trains the DCGAN generator/discriminator pair on MNIST.
+   *
+   * @param dataPath   directory containing train-images-idx3-ubyte and
+   *                   train-labels-idx1-ubyte
+   * @param context    device the GANModule is bound to
+   * @param outputPath directory passed to Viz.imSave for the gout/diff/data snapshots
+   * @param numEpoch   number of passes over the training iterator
+   * @return the "ferr" metric value reported by the CustomMetric at the end of
+   *         the last non-empty epoch, or 0.0f if no batch was processed
+   */
+  def runTraining(dataPath : String, context : Context,
+                  outputPath : String, numEpoch : Int): Float = {
+    val lr = 0.0005f
+    val beta1 = 0.5f
+    val batchSize = 100
+    val randShape = Shape(batchSize, 100)
+    val dataShape = Shape(batchSize, 1, 28, 28)
+
+    val (symGen, symDec) =
+      makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")
+
+    val gMod = new GANModule(
+      symGen,
+      symDec,
+      context = context,
+      dataShape = dataShape,
+      codeShape = randShape)
+
+    gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
+    gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))
+
+    gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))
+
+    val params = Map(
+      "image" -> s"$dataPath/train-images-idx3-ubyte",
+      "label" -> s"$dataPath/train-labels-idx1-ubyte",
+      "input_shape" -> s"(1, 28, 28)",
+      "batch_size" -> s"$batchSize",
+      "shuffle" -> "True"
+    )
+
+    val mnistIter = IO.MNISTIter(params)
+
+    val metricAcc = new CustomMetric(ferr, "ferr")
+
+    var ferrValue = 0.0f
+    for (epoch <- 0 until numEpoch) {
+      mnistIter.reset()
+      metricAcc.reset()
+      var t = 0
+      while (mnistIter.hasNext) {
+        val dataBatch = mnistIter.next()
+        gMod.update(dataBatch)
+        gMod.dLabel.set(0f)
+        metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
+        gMod.dLabel.set(1f)
+        metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)
+
+        if (t % 50 == 0) {
+          val (_, value) = metricAcc.get
+          logger.info(s"epoch: $epoch, iter $t, metric=${value.mkString(" ")}")
+          saveSnapshots(gMod, dataBatch, outputPath)
+        }
+
+        t += 1
+      }
+      // Report the metric over the whole epoch instead of the last logged
+      // checkpoint; skip empty epochs, where the metric has no samples.
+      if (t > 0) {
+        ferrValue = metricAcc.get._2(0)
+      }
+    }
+    ferrValue
+  }
+
+  /**
+   * Writes the generator output, the standardized discriminator input
+   * gradient (shifted by +0.5) and the real input batch as images.
+   */
+  private def saveSnapshots(gMod: GANModule, dataBatch: DataBatch,
+                            outputPath: String): Unit = {
+    Viz.imSave("gout", outputPath, gMod.tempOutG(0), flip = true)
+    val diff = gMod.tempDiffD
+    val arr = diff.toArray
+    val mean = arr.sum / arr.length
+    val std = Math.sqrt(arr.map(a => (a - mean) * (a - mean)).sum / arr.length).toFloat
+    // A constant gradient (std == 0) would otherwise fill the image with NaN.
+    val scale = if (std > 0f) std else 1f
+    // Each NDArray arithmetic op allocates native memory that the JVM GC does
+    // not reclaim, so the temporaries are disposed explicitly.
+    val centered = diff - mean
+    val normalized = centered / scale
+    val shifted = normalized + 0.5f
+    diff.set(shifted)
+    shifted.dispose()
+    normalized.dispose()
+    centered.dispose()
+    Viz.imSave("diff", outputPath, diff, flip = true)
+    Viz.imSave("data", outputPath, dataBatch.data(0), flip = true)
+  }
+
def main(args: Array[String]): Unit = {
val anst = new GanMnist
val parser: CmdLineParser = new CmdLineParser(anst)
@@ -123,78 +182,12 @@ object GanMnist {
parser.parseArgument(args.toList.asJava)
val dataPath = if (anst.mnistDataPath == null)
System.getenv("MXNET_DATA_DIR")
- else anst.mnistDataPath
+ else anst.mnistDataPath
assert(dataPath != null)
-
- val lr = 0.0005f
- val beta1 = 0.5f
- val batchSize = 100
- val randShape = Shape(batchSize, 100)
- val numEpoch = 100
- val dataShape = Shape(batchSize, 1, 28, 28)
val context = if (anst.gpu == -1) Context.cpu() else
Context.gpu(anst.gpu)
- val (symGen, symDec) =
- makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")
-
- val gMod = new GANModule(
- symGen,
- symDec,
- context = context,
- dataShape = dataShape,
- codeShape = randShape)
-
- gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
- gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))
-
- gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))
-
- val params = Map(
- "image" -> s"${dataPath}/train-images-idx3-ubyte",
- "label" -> s"${dataPath}/train-labels-idx1-ubyte",
- "input_shape" -> s"(1, 28, 28)",
- "batch_size" -> s"$batchSize",
- "shuffle" -> "True"
- )
-
- val mnistIter = IO.MNISTIter(params)
-
- val metricAcc = new CustomMetric(ferr, "ferr")
-
- var t = 0
- var dataBatch: DataBatch = null
- for (epoch <- 0 until numEpoch) {
- mnistIter.reset()
- metricAcc.reset()
- t = 0
- while (mnistIter.hasNext) {
- dataBatch = mnistIter.next()
- gMod.update(dataBatch)
- gMod.dLabel.set(0f)
- metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
- gMod.dLabel.set(1f)
- metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)
-
- if (t % 50 == 0) {
- val (name, value) = metricAcc.get
- logger.info(s"epoch: $epoch, iter $t, metric=$value")
- Viz.imSave("gout", anst.outputPath, gMod.tempOutG(0), flip = true)
- val diff = gMod.tempDiffD
- val arr = diff.toArray
- val mean = arr.sum / arr.length
- val std = {
- val tmpA = arr.map(a => (a - mean) * (a - mean))
- Math.sqrt(tmpA.sum / tmpA.length).toFloat
- }
- diff.set((diff - mean) / std + 0.5f)
- Viz.imSave("diff", anst.outputPath, diff, flip = true)
- Viz.imSave("data", anst.outputPath, dataBatch.data(0), flip = true)
- }
-
- t += 1
- }
- }
+ runTraining(dataPath, context, anst.outputPath, 100)
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
index faab945..55b5296 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/Module.scala
@@ -26,9 +26,6 @@ import org.apache.mxnet.Initializer
import org.apache.mxnet.DataBatch
import org.apache.mxnet.Random
-/**
- * @author Depeng Liang
- */
class GANModule(
symbolGenerator: Symbol,
symbolEncoder: Symbol,
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/README.md
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/README.md
new file mode 100644
index 0000000..40db092
--- /dev/null
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/README.md
@@ -0,0 +1,18 @@
+# GAN MNIST Example for Scala
+This is the GAN MNIST training example implemented with the Scala type-safe API.
+
+This example is for illustration only and is not tuned to achieve the best accuracy.
+## Setup
+### Download the source file
+```bash
+wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/mnist/mnist.zip
+```
+### Unzip the file
+```bash
+unzip mnist.zip
+```
+### Argument configuration
+Then pass the directory that contains the unzipped MNIST files to the example:
+```bash
+--mnist-data-path <directory containing the unzipped MNIST files>
+```
\ No newline at end of file
diff --git
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
new file mode 100644
index 0000000..12459fb
--- /dev/null
+++
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.gan
+
+import java.io.File
+import java.net.URL
+
+import org.apache.commons.io.FileUtils
+import org.apache.mxnet.Context
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+import scala.sys.process.Process
+
+/**
+ * CI check for the GAN MNIST example: downloads the MNIST archive, trains for a
+ * few epochs on GPU 0 and checks that a non-negative error metric comes back.
+ * Runs only when the SCALA_TEST_ON_GPU environment variable equals 1.
+ */
+class GanExampleSuite extends FunSuite with BeforeAndAfterAll {
+  private val logger = LoggerFactory.getLogger(classOf[GanExampleSuite])
+
+  test("Example CI: Test GAN MNIST") {
+    if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+      System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+      logger.info("Downloading mnist model")
+      val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
+      val tempDirPath = System.getProperty("java.io.tmpdir")
+      logger.info("tempDirPath: %s".format(tempDirPath))
+      // Cache the archive outside the extraction directory: the cleanup below
+      // deletes that directory, which previously deleted the archive too and
+      // made the exists() check useless (every run re-downloaded it).
+      val zipFile = new File(tempDirPath, "mnist.zip")
+      if (!zipFile.exists()) {
+        FileUtils.copyURLToFile(new URL(baseUrl + "/mnist/mnist.zip"), zipFile)
+      }
+      val modelDir = new File(tempDirPath, "mnist")
+      val modelDirPath = modelDir.getPath + File.separator
+      try {
+        // TODO: Need to confirm with Windows (relies on an external `unzip` binary).
+        // -o overwrites leftovers from an interrupted run instead of prompting;
+        // passing arguments as a Seq keeps paths containing spaces intact.
+        val unzipExitCode =
+          Process(Seq("unzip", "-o", zipFile.getPath, "-d", modelDirPath)).!
+        assert(unzipExitCode == 0, s"unzip failed with exit code $unzipExitCode")
+
+        val context = Context.gpu()
+        val output = GanMnist.runTraining(modelDirPath, context, modelDirPath, 5)
+        assert(output >= 0.0f)
+      } finally {
+        // Portable replacement for `rm -rf`; also runs when training throws.
+        FileUtils.deleteDirectory(modelDir)
+      }
+    } else {
+      logger.info("GPU test only, skipped...")
+    }
+  }
+}