This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 15deabc  [MXNET-947] Expand scala imclassification example with resnet 
(#12639)
15deabc is described below

commit 15deabc689eb5810656307da6d3f295f0de4e242
Author: Zach Kimberg <[email protected]>
AuthorDate: Wed Oct 10 13:38:39 2018 -0700

    [MXNET-947] Expand scala imclassification example with resnet (#12639)
    
    * [MXNET-947] Scala imclassification example with Resnet
---
 .../mxnetexamples/imclassification/README.md       |  45 ++++--
 .../{TrainMnist.scala => TrainModel.scala}         | 175 ++++++++++----------
 .../imclassification/datasets/MnistIter.scala      |  59 +++++++
 .../datasets/SyntheticDataIter.scala               |  70 ++++++++
 .../imclassification/models/Lenet.scala            |  52 ++++++
 .../models/MultiLayerPerceptron.scala              |  41 +++++
 .../imclassification/models/Resnet.scala           | 178 +++++++++++++++++++++
 .../{ModelTrain.scala => util/Trainer.scala}       |  38 +++--
 ...te.scala => IMClassificationExampleSuite.scala} |  42 ++---
 9 files changed, 575 insertions(+), 125 deletions(-)

diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
index 5141f44..cec750a 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
@@ -1,17 +1,36 @@
-# MNIST Example for Scala
-This is the MNIST Training Example implemented for Scala type-safe api
+# Image Classification Models
+
+This example contains a number of image classification models that can be trained on various datasets.
+
+## Models
+
+Currently, the following models are supported:
+- MultiLayerPerceptron
+- Lenet
+- Resnet
+
+## Datasets
+
+Currently, the following datasets are supported:
+- MNIST
+
+### Synthetic Benchmark Data
+
+Additionally, the datasets can be replaced by randomly generated data for benchmarking by passing the `--benchmark` flag.
+Data is produced to match the shapes of the supported datasets above.
+
+The following additional dataset image shapes are also defined (select one with `--dataset`); they can only be used with the synthetic benchmark data:
+- imagenet
+
+
+
 ## Setup
-### Download the source File
+
+### MNIST
+
+For this dataset, download the archive from the following location and extract it (e.g. with `unzip mnist.zip`):
 ```$xslt
 https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/mnist/mnist.zip
 ```
-### Unzip the file
-```$xslt
-unzip mnist.zip
-```
-### Arguement Configuration
-Then you need to define the arguments that you would like to pass in the model:
-```$xslt
---data-dir <location of your downloaded file>
-```
-You can find more information 
[here](https://github.com/apache/incubator-mxnet/blob/scala-package/examples/src/main/scala/org/apache/mxnet/examples/imclassification/TrainMnist.scala#L169-L207)
\ No newline at end of file
+
+Afterwards, pass the location of the extracted data folder through the `--data-dir` argument (default: `mnist/`).
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
similarity index 55%
rename from 
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
rename to 
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
index 2f024fd..608e191 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
@@ -17,109 +17,106 @@
 
 package org.apache.mxnetexamples.imclassification
 
+import java.util.concurrent._
+
+import org.apache.mxnetexamples.imclassification.models._
+import org.apache.mxnetexamples.imclassification.util.Trainer
 import org.apache.mxnet._
+import org.apache.mxnetexamples.imclassification.datasets.{MnistIter, 
SyntheticDataIter}
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
-object TrainMnist {
-  private val logger = LoggerFactory.getLogger(classOf[TrainMnist])
-
-  // multi-layer perceptron
-  def getMlp: Symbol = {
-    val data = Symbol.Variable("data")
-
-    // val fc1 = Symbol.FullyConnected(name = "relu")()(Map("data" -> data, 
"act_type" -> "relu"))
-    val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, 
name = "fc1")
-    val act1 = Symbol.api.Activation (data = Some(fc1), "relu", name = "relu")
-    val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name = 
"fc2")
-    val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2")
-    val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, 10, name = 
"fc3")
-    val mlp = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc3))
-    mlp
-  }
-
-  def getLenet: Symbol = {
-    val data = Symbol.Variable("data")
-    // first conv
-    val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 
5), num_filter = 20)
-    val tanh1 = Symbol.api.tanh(data = Some(conv1))
-    val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
-      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
-    // second conv
-    val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 
5), num_filter = 50)
-    val tanh2 = Symbol.api.tanh(data = Some(conv2))
-    val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
-      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
-    // first fullc
-    val flatten = Symbol.api.Flatten(data = Some(pool2))
-    val fc1 = Symbol.api.FullyConnected(data = Some(flatten), num_hidden = 500)
-    val tanh3 = Symbol.api.tanh(data = Some(fc1))
-    // second fullc
-    val fc2 = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden = 10)
-    // loss
-    val lenet = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc2))
-    lenet
-  }
-
-  def getIterator(dataShape: Shape)
-    (dataDir: String, batchSize: Int, kv: KVStore): (DataIter, DataIter) = {
-    val flat = if (dataShape.size == 3) "False" else "True"
-
-    val train = IO.MNISTIter(Map(
-      "image" -> (dataDir + "train-images-idx3-ubyte"),
-      "label" -> (dataDir + "train-labels-idx1-ubyte"),
-      "label_name" -> "softmax_label",
-      "input_shape" -> dataShape.toString,
-      "batch_size" -> batchSize.toString,
-      "shuffle" -> "True",
-      "flat" -> flat,
-      "num_parts" -> kv.numWorkers.toString,
-      "part_index" -> kv.`rank`.toString))
-
-    val eval = IO.MNISTIter(Map(
-      "image" -> (dataDir + "t10k-images-idx3-ubyte"),
-      "label" -> (dataDir + "t10k-labels-idx1-ubyte"),
-      "label_name" -> "softmax_label",
-      "input_shape" -> dataShape.toString,
-      "batch_size" -> batchSize.toString,
-      "flat" -> flat,
-      "num_parts" -> kv.numWorkers.toString,
-      "part_index" -> kv.`rank`.toString))
-
-    (train, eval)
-  }
-
-  def test(dataPath : String) : Float = {
+object TrainModel {
+  private val logger = LoggerFactory.getLogger(classOf[TrainModel])
+
+  /**
+    * Simple model training and execution on the "mnist" dataset (real data, or MNIST-shaped
+    * synthetic data when benchmark is set), using a single CPU context, a "local" KVStore
+    * and a batch size of 128
+    * @param model The model identifying string ("mlp", "lenet" or "resnet")
+    * @param dataPath Path to location of image data
+    * @param numExamples Number of image data examples
+    * @param numEpochs Number of epochs to train for
+    * @param benchmark Whether to use benchmark synthetic data instead of real image data
+    * @return The final accuracy: the first metric value of the Accuracy from Trainer.fit
+    */
+  def test(model: String, dataPath: String, numExamples: Int = 60000,
+           numEpochs: Int = 10, benchmark: Boolean = false): Float = {
     NDArrayCollector.auto().withScope {
-      val (dataShape, net) = (Shape(784), getMlp)
       val devs = Array(Context.cpu(0))
+      // NOTE(review): envs is never read in this method; it looks like a leftover.
       val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, 
String]
-      val Acc = ModelTrain.fit(dataDir = dataPath,
-        batchSize = 128, numExamples = 60000, devs = devs,
-        network = net, dataLoader = getIterator(dataShape),
-        kvStore = "local", numEpochs = 10)
+      // The dataset is fixed to "mnist"; benchmark swaps in MNIST-shaped random data.
+      val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath,
+        numExamples = numExamples, benchmark = benchmark)
+      val Acc = Trainer.fit(batchSize = 128, numExamples, devs = devs,
+        network = net, dataLoader = dataLoader,
+        kvStore = "local", numEpochs = numEpochs)
       logger.info("Finish test fit ...")
+      // Acc.get returns a pair whose second element holds the metric value(s).
       val (_, num) = Acc.get
       num(0)
     }
   }
 
+  /**
+    * Gets dataset iterator and model symbol
+    * @param dataset The dataset identifying string ("mnist" or "imagenet")
+    * @param model The model identifying string ("mlp", "lenet" or "resnet")
+    * @param dataDir Path to location of image data (only read for real MNIST data)
+    * @param numLayers The number of model layers (resnet only)
+    * @param numExamples The number of examples produced per epoch by the synthetic iterator
+    * @param benchmark Whether to use benchmark synthetic data instead of real image data
+    * @return Data iterator factory taking (batchSize, kvStore), and the model symbol
+    * @throws IllegalArgumentException for an unknown dataset or model, or when real data is
+    *                                  requested for a dataset that only has synthetic data
+    */
+  def dataLoaderAndModel(dataset: String, model: String, dataDir: String = "",
+                         numLayers: Int = 50, numExamples: Int = 60000,
+                         benchmark: Boolean = false
+                        ): ((Int, KVStore) => (DataIter, DataIter), Symbol) = {
+    val (imageShape, numClasses) = dataset match {
+      case "mnist" => (List(1, 28, 28), 10)
+      case "imagenet" => (List(3, 224, 224), 1000)
+      case _ => throw new IllegalArgumentException(s"Invalid image data collection: $dataset")
+    }
+
+    val List(channels, height, width) = imageShape
+    val dataSize: Int = channels * height * width
+    // The MLP consumes flattened images; the convolutional models keep (channels, h, w).
+    val (datumShape, net) = model match {
+      case "mlp" => (List(dataSize), MultiLayerPerceptron.getSymbol(numClasses))
+      case "lenet" => (imageShape, Lenet.getSymbol(numClasses))
+      case "resnet" => (imageShape, Resnet.getSymbol(numClasses, numLayers, imageShape))
+      case _ => throw new IllegalArgumentException(s"Invalid model name: $model")
+    }
+
+    val dataLoader: (Int, KVStore) => (DataIter, DataIter) = if (benchmark) {
+      // A single synthetic iterator serves as both training and validation data.
+      (batchSize: Int, _: KVStore) => {
+        val iter = new SyntheticDataIter(numClasses, batchSize, datumShape, List(), numExamples)
+        (iter, iter)
+      }
+    } else {
+      dataset match {
+        case "mnist" => MnistIter.getIterator(Shape(datumShape), dataDir)
+        case _ => throw new IllegalArgumentException(s"The $dataset image data collection " +
+          "only supports the synthetic benchmark iterator.  Use --benchmark to enable")
+      }
+    }
+    (dataLoader, net)
+  }
+
+  /**
+    * Runs image classification training from CLI with various options
+    * @param args CLI args
+    */
   def main(args: Array[String]): Unit = {
-    val inst = new TrainMnist
+    val inst = new TrainModel
     val parser: CmdLineParser = new CmdLineParser(inst)
     try {
       parser.parseArgument(args.toList.asJava)
 
       val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
-        else inst.dataDir
+      else inst.dataDir
 
-      val (dataShape, net) =
-        if (inst.network == "mlp") (Shape(784), getMlp)
-        else (Shape(1, 28, 28), getLenet)
+      val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, 
dataPath,
+        inst.numLayers, inst.numExamples, inst.benchmark)
 
       val devs =
         if (inst.gpus != null) inst.gpus.split(',').map(id => 
Context.gpu(id.trim.toInt))
@@ -144,9 +141,8 @@ object TrainMnist {
         logger.info("Start KVStoreServer for scheduler & servers")
         KVStoreServer.start()
       } else {
-        ModelTrain.fit(dataDir = inst.dataDir,
-          batchSize = inst.batchSize, numExamples = inst.numExamples, devs = 
devs,
-          network = net, dataLoader = getIterator(dataShape),
+        Trainer.fit(batchSize = inst.batchSize, numExamples = 
inst.numExamples, devs = devs,
+          network = net, dataLoader = dataLoader,
           kvStore = inst.kvStore, numEpochs = inst.numEpochs,
           modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
           lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = 
inst.lrFactorEpoch,
@@ -163,11 +159,19 @@ object TrainMnist {
   }
 }
 
-class TrainMnist {
-  @Option(name = "--network", usage = "the cnn to use: ['mlp', 'lenet']")
+class TrainModel {
+  @Option(name = "--network", usage = "the cnn to use: ['mlp', 'lenet', 
'resnet']")
   private val network: String = "mlp"
+  @Option(name = "--num-layers", usage = "the number of resnet layers to use")
+  private val numLayers: Int = 50
   @Option(name = "--data-dir", usage = "the input data directory")
   private val dataDir: String = "mnist/"
+
+  @Option(name = "--dataset", usage = "the images to classify: ['mnist', 
'imagenet']")
+  private val dataset: String = "mnist"
+  @Option(name = "--benchmark", usage = "Benchmark to use synthetic data to 
measure performance")
+  private val benchmark: Boolean = false
+
   @Option(name = "--gpus", usage = "the gpus will be used, e.g. '0,1,2,3'")
   private val gpus: String = null
   @Option(name = "--cpus", usage = "the cpus will be used, e.g. '0,1,2,3'")
@@ -187,7 +191,7 @@ class TrainMnist {
   @Option(name = "--kv-store", usage = "the kvstore type")
   private val kvStore = "local"
   @Option(name = "--lr-factor",
-          usage = "times the lr with a factor for every lr-factor-epoch epoch")
+    usage = "times the lr with a factor for every lr-factor-epoch epoch")
   private val lrFactor: Float = 1f
   @Option(name = "--lr-factor-epoch", usage = "the number of epoch to factor 
the lr, could be .5")
   private val lrFactorEpoch: Float = 1f
@@ -205,3 +209,4 @@ class TrainMnist {
   @Option(name = "--num-server", usage = "# of servers")
   private val numServer: Int = 1
 }
+
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/MnistIter.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/MnistIter.scala
new file mode 100644
index 0000000..9e6e1c2
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/MnistIter.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.datasets
+
+import org.apache.mxnet._
+
+object MnistIter {
+  /**
+    * Returns training and validation iterators over the MNIST dataset
+    * @param dataShape Shape of a single image: a 3-dimensional (channels, height, width)
+    *                  shape keeps the image layout, any other rank requests flat images
+    * @param dataDir The directory holding the extracted MNIST idx files; a path separator
+    *                is appended when it is missing
+    * @param batchSize Number of images per batch
+    * @param kv KVStore whose worker count and rank are passed as num_parts / part_index
+    * @return (training iterator, validation iterator)
+    */
+  def getIterator(dataShape: Shape, dataDir: String)
+                 (batchSize: Int, kv: KVStore): (DataIter, DataIter) = {
+    val flat = if (dataShape.size == 3) "False" else "True"
+
+    // File names are appended directly to the directory, so make sure it ends with a
+    // separator (otherwise "--data-dir mnist" would resolve to "mnisttrain-images-...").
+    val sep = java.io.File.separator
+    val dir =
+      if (dataDir.isEmpty || dataDir.endsWith("/") || dataDir.endsWith(sep)) dataDir
+      else dataDir + sep
+
+    val train = IO.MNISTIter(Map(
+      "image" -> (dir + "train-images-idx3-ubyte"),
+      "label" -> (dir + "train-labels-idx1-ubyte"),
+      "label_name" -> "softmax_label",
+      "input_shape" -> dataShape.toString,
+      "batch_size" -> batchSize.toString,
+      "shuffle" -> "True",
+      "flat" -> flat,
+      "num_parts" -> kv.numWorkers.toString,
+      "part_index" -> kv.`rank`.toString))
+
+    // Validation data is read in file order (no "shuffle" entry).
+    val eval = IO.MNISTIter(Map(
+      "image" -> (dir + "t10k-images-idx3-ubyte"),
+      "label" -> (dir + "t10k-labels-idx1-ubyte"),
+      "label_name" -> "softmax_label",
+      "input_shape" -> dataShape.toString,
+      "batch_size" -> batchSize.toString,
+      "flat" -> flat,
+      "num_parts" -> kv.numWorkers.toString,
+      "part_index" -> kv.`rank`.toString))
+
+    (train, eval)
+  }
+
+}
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
new file mode 100644
index 0000000..9421f10
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.datasets
+
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
+
+import scala.collection.immutable.ListMap
+import scala.util.Random
+
+/**
+  * Benchmark iterator that hands out the same randomly generated batch on every call,
+  * so training speed can be measured without reading real image data.
+  * @param numClasses Upper bound of the random labels when labelShape is empty
+  * @param batchSize Number of examples per batch
+  * @param datumShape Shape of one example, without the batch dimension
+  * @param labelShape Shape of one label, without the batch dimension (empty for a class id)
+  * @param maxIter Batches are produced while the example counter is below maxIter - 1
+  * @param dtype Type recorded in the data and label descriptors
+  */
+class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[Int],
+                        labelShape: List[Int], maxIter: Int, dtype: DType = DType.Float32
+                       ) extends DataIter {
+  // Running example counter; advanced by batchSize for every batch handed out.
+  var curIter = 0
+  val random = new Random()
+  val shape = Shape(batchSize :: datumShape)
+  val batchLabelShape = Shape(batchSize :: labelShape)
+
+  // Scalar labels are drawn up to numClasses; shaped labels are drawn up to 1.
+  val maxLabel = if (labelShape.nonEmpty) 1f else numClasses.toFloat
+  var label: IndexedSeq[NDArray] = IndexedSeq(
+    NDArray.api.random_uniform(Some(0f), Some(maxLabel), shape = Some(batchLabelShape)))
+  var data: IndexedSeq[NDArray] = IndexedSeq(
+    NDArray.api.random_uniform(shape = Some(shape)))
+
+  val provideDataDesc: IndexedSeq[DataDesc] =
+    IndexedSeq(new DataDesc("data", shape, dtype, Layout.UNDEFINED))
+  val provideLabelDesc: IndexedSeq[DataDesc] =
+    IndexedSeq(new DataDesc("softmax_label", batchLabelShape, dtype, Layout.UNDEFINED))
+  val getPad: Int = 0
+
+  override def getData(): IndexedSeq[NDArray] = data
+
+  override def getLabel: IndexedSeq[NDArray] = label
+
+  override def getIndex: IndexedSeq[Long] = IndexedSeq(curIter.toLong)
+
+  override def hasNext: Boolean = maxIter - 1 > curIter
+
+  override def next(): DataBatch = {
+    if (!hasNext) {
+      throw new NoSuchElementException
+    }
+    curIter += batchSize
+    new DataBatch(data, label, getIndex, getPad, null, null, null)
+  }
+
+  override def reset(): Unit = { curIter = 0 }
+
+  override def provideData: ListMap[String, Shape] = ListMap("data" -> shape)
+
+  override def provideLabel: ListMap[String, Shape] =
+    ListMap("softmax_label" -> batchLabelShape)
+}
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
new file mode 100644
index 0000000..76fb7bb
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.models
+
+import org.apache.mxnet._
+
+object Lenet {
+
+  /**
+    * Builds the LeNet symbol: two (5x5 convolution, tanh, 2x2 max-pool) stages, a
+    * 500-unit tanh layer, a numClasses-unit output layer and a SoftmaxOutput named "softmax"
+    * @param numClasses Number of classes to classify into
+    * @return model symbol
+    */
+  def getSymbol(numClasses: Int): Symbol = {
+    // One feature stage. Unnamed layers are auto-named in creation order, so the stages
+    // below are built in exactly the order conv -> tanh -> pool, first stage first.
+    def convTanhPool(input: Symbol, numFilter: Int): Symbol = {
+      val conv = Symbol.api.Convolution(data = Some(input), kernel = Shape(5, 5),
+        num_filter = numFilter)
+      val activated = Symbol.api.tanh(data = Some(conv))
+      Symbol.api.Pooling(data = Some(activated), pool_type = Some("max"),
+        kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
+    }
+
+    val input = Symbol.Variable("data")
+    val features = convTanhPool(convTanhPool(input, numFilter = 20), numFilter = 50)
+
+    val flattened = Symbol.api.Flatten(data = Some(features))
+    val hidden = Symbol.api.tanh(data = Some(
+      Symbol.api.FullyConnected(data = Some(flattened), num_hidden = 500)))
+    val logits = Symbol.api.FullyConnected(data = Some(hidden), num_hidden = numClasses)
+    Symbol.api.SoftmaxOutput(name = "softmax", data = Some(logits))
+  }
+
+}
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
new file mode 100644
index 0000000..5d880bb
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.models
+
+import org.apache.mxnet._
+
+object MultiLayerPerceptron {
+
+  /**
+    * Builds a three-layer perceptron: fc1 (128, relu) -> fc2 (64, relu) -> fc3 (numClasses)
+    * followed by a SoftmaxOutput named "softmax"
+    * @param numClasses Number of classes to classify into
+    * @return model symbol
+    */
+  def getSymbol(numClasses: Int): Symbol = {
+    val input = Symbol.Variable("data")
+
+    val hidden1 = Symbol.api.Activation(
+      data = Some(Symbol.api.FullyConnected(data = Some(input), num_hidden = 128, name = "fc1")),
+      act_type = "relu", name = "relu")
+    val hidden2 = Symbol.api.Activation(
+      data = Some(Symbol.api.FullyConnected(data = Some(hidden1), num_hidden = 64, name = "fc2")),
+      act_type = "relu", name = "relu2")
+    val logits = Symbol.api.FullyConnected(data = Some(hidden2), num_hidden = numClasses,
+      name = "fc3")
+
+    Symbol.api.SoftmaxOutput(name = "softmax", data = Some(logits))
+  }
+
+}
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
new file mode 100644
index 0000000..c3f43d9
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.models
+
+import org.apache.mxnet._
+
+object Resnet {
+  /**
+    * Builds one pre-activation residual unit: every convolution is preceded by BatchNorm
+    * and ReLU, and the output is the sum of the residual branch and a shortcut.
+    * @param data Input symbol
+    * @param numFilter Number of output channels of the unit
+    * @param stride Stride of the 3x3 convolution and of the projection shortcut
+    * @param dimMatch Whether the input already has the output shape; if so the input itself
+    *                 is the shortcut, otherwise a strided 1x1 convolution of the first
+    *                 activation is used
+    * @param name Prefix of every layer name in the unit
+    * @param bottleNeck Use 1x1 -> 3x3 -> 1x1 convolutions with numFilter / 4 inner channels
+    *                   instead of two 3x3 convolutions
+    * @param bnMom BatchNorm momentum
+    * @param workspace Maximum temporary workspace (MB) allowed in convolutions
+    * @param memonger Accepted for API compatibility; currently unused
+    * @return The unit's output symbol
+    */
+  def residualUnit(data: Symbol, numFilter: Int, stride: Shape, dimMatch: Boolean,
+                   name: String = "", bottleNeck: Boolean = true, bnMom: Float = 0.9f,
+                   workspace: Int = 256, memonger: Boolean = false): Symbol = {
+    val (act1, operated) = if (bottleNeck) {
+      val bn1 = Symbol.api.BatchNorm(data = Some(data), fix_gamma = Some(false), eps = Some(2e-5),
+        momentum = Some(bnMom), name = name + "_bn1")
+      val act1: Symbol = Symbol.api.Activation(data = Some(bn1), act_type = "relu",
+        name = name + "_relu1")
+      val conv1 = Symbol.api.Convolution(data = Some(act1), num_filter = (numFilter * 0.25).toInt,
+        kernel = Shape(1, 1), stride = Some(Shape(1, 1)), pad = Some(Shape(0, 0)),
+        no_bias = Some(true), workspace = Some(workspace), name = name + "_conv1")
+      val bn2 = Symbol.api.BatchNorm(data = Some(conv1), fix_gamma = Some(false),
+        eps = Some(2e-5), momentum = Some(bnMom), name = name + "_bn2")
+      val act2 = Symbol.api.Activation(data = Some(bn2), act_type = "relu", name = name + "_relu2")
+      val conv2 = Symbol.api.Convolution(data = Some(act2), num_filter = (numFilter * 0.25).toInt,
+        kernel = Shape(3, 3), stride = Some(stride), pad = Some(Shape(1, 1)),
+        no_bias = Some(true), workspace = Some(workspace), name = name + "_conv2")
+      val bn3 = Symbol.api.BatchNorm(data = Some(conv2), fix_gamma = Some(false),
+        eps = Some(2e-5), momentum = Some(bnMom), name = name + "_bn3")
+      val act3 = Symbol.api.Activation(data = Some(bn3), act_type = "relu", name = name + "_relu3")
+      val conv3 = Symbol.api.Convolution(data = Some(act3), num_filter = numFilter,
+        kernel = Shape(1, 1), stride = Some(Shape(1, 1)), pad = Some(Shape(0, 0)),
+        no_bias = Some(true), workspace = Some(workspace), name = name + "_conv3")
+      (act1, conv3)
+    } else {
+      val bn1 = Symbol.api.BatchNorm(data = Some(data), fix_gamma = Some(false),
+        eps = Some(2e-5), momentum = Some(bnMom), name = name + "_bn1")
+      val act1 = Symbol.api.Activation(data = Some(bn1), act_type = "relu", name = name + "_relu1")
+      val conv1 = Symbol.api.Convolution(data = Some(act1), num_filter = numFilter,
+        kernel = Shape(3, 3), stride = Some(stride), pad = Some(Shape(1, 1)),
+        no_bias = Some(true), workspace = Some(workspace), name = name + "_conv1")
+      val bn2 = Symbol.api.BatchNorm(data = Some(conv1), fix_gamma = Some(false),
+        eps = Some(2e-5), momentum = Some(bnMom), name = name + "_bn2")
+      val act2 = Symbol.api.Activation(data = Some(bn2), act_type = "relu", name = name + "_relu2")
+      val conv2 = Symbol.api.Convolution(data = Some(act2), num_filter = numFilter,
+        kernel = Shape(3, 3), stride = Some(Shape(1, 1)), pad = Some(Shape(1, 1)),
+        no_bias = Some(true), workspace = Some(workspace), name = name + "_conv2")
+      (act1, conv2)
+    }
+    // When shapes differ, project the first (pre-activated) tensor to the output shape.
+    val shortcut = if (dimMatch) {
+      data
+    } else {
+      Symbol.api.Convolution(Some(act1), num_filter = numFilter, kernel = Shape(1, 1),
+        stride = Some(stride), no_bias = Some(true), workspace = Some(workspace),
+        name = name + "_sc")
+    }
+    operated + shortcut
+  }
+
+  /**
+    * Helper for building the resnet Symbol
+    * @param units Number of residual units in each stage
+    * @param numStages Number of stages (must equal units.size)
+    * @param filterList Channels of the stem convolution followed by the channels of each stage
+    * @param numClasses Number of classes to classify into
+    * @param imageShape The image shape as List(channels, height, width)
+    * @param bottleNeck Whether residual units use the bottleneck design
+    * @param bnMom BatchNorm momentum, applied to every BatchNorm layer
+    * @param workspace Maximum temporary workspace (MB) allowed in convolutions
+    * @param dtype "float32" or "float16" computation type; other values skip the input cast
+    * @param memonger Forwarded to residualUnit (currently unused there)
+    * @return Model symbol ending in a SoftmaxOutput named "softmax"
+    */
+  def resnet(units: List[Int], numStages: Int, filterList: List[Int], numClasses: Int,
+             imageShape: List[Int], bottleNeck: Boolean = true, bnMom: Float = 0.9f,
+             workspace: Int = 256, dtype: String = "float32", memonger: Boolean = false): Symbol = {
+    assert(units.size == numStages)
+    // NOTE(review): the leading 4 is only a placeholder batch size in the shape hint;
+    // presumably the data shape supplied at bind time takes precedence — confirm.
+    var data = Symbol.Variable("data", shape = Shape(List(4) ::: imageShape), dType = DType.Float32)
+    if (dtype == "float32") {
+      data = Symbol.api.identity(Some(data), "id")
+    } else if (dtype == "float16") {
+      data = Symbol.api.cast(Some(data), "float16")
+    }
+    data = Symbol.api.BatchNorm(Some(data), fix_gamma = Some(true), eps = Some(2e-5),
+      momentum = Some(bnMom), name = "bn_data")
+    val List(_, height, _) = imageShape
+    // Stem: small images (MNIST/CIFAR sized) keep their resolution with a 3x3/1 convolution;
+    // large images (ImageNet sized) are reduced by a 7x7/2 convolution and a 3x3/2 max pool,
+    // as in the standard ResNet (the two kernels were previously swapped).
+    var body = if (height <= 32) {
+      Symbol.api.Convolution(Some(data), num_filter = filterList.head, kernel = Shape(3, 3),
+        stride = Some(Shape(1, 1)), pad = Some(Shape(1, 1)), no_bias = Some(true), name = "conv0",
+        workspace = Some(workspace))
+    } else {
+      var body0 = Symbol.api.Convolution(Some(data), num_filter = filterList.head,
+        kernel = Shape(7, 7), stride = Some(Shape(2, 2)), pad = Some(Shape(3, 3)),
+        no_bias = Some(true), name = "conv0", workspace = Some(workspace))
+      body0 = Symbol.api.BatchNorm(Some(body0), fix_gamma = Some(false), eps = Some(2e-5),
+        momentum = Some(bnMom), name = "bn0")
+      body0 = Symbol.api.Activation(Some(body0), act_type = "relu", name = "relu0")
+      Symbol.api.Pooling(Some(body0), kernel = Some(Shape(3, 3)), stride = Some(Shape(2, 2)),
+        pad = Some(Shape(1, 1)), pool_type = Some("max"))
+    }
+    for (((filter, i), unit) <- filterList.tail.zipWithIndex.zip(units)) {
+      // The first stage keeps the resolution; each later stage halves it in its first unit.
+      val s = if (i == 0) 1 else 2
+      body = residualUnit(body, filter, Shape(s, s), false, name = s"stage${i + 1}_unit1",
+        bottleNeck = bottleNeck, bnMom = bnMom, workspace = workspace, memonger = memonger)
+      for (j <- 0 until unit - 1) {
+        body = residualUnit(body, filter, Shape(1, 1), true, s"stage${i + 1}_unit${j + 2}",
+          bottleNeck, bnMom = bnMom, workspace = workspace, memonger = memonger)
+      }
+    }
+    val bn1 = Symbol.api.BatchNorm(Some(body), fix_gamma = Some(false), eps = Some(2e-5),
+      momentum = Some(bnMom), name = "bn1")
+    val relu1 = Symbol.api.Activation(Some(bn1), act_type = "relu", name = "relu1")
+    val pool1 = Symbol.api.Pooling(Some(relu1), global_pool = Some(true),
+      kernel = Some(Shape(7, 7)), pool_type = Some("avg"), name = "pool1")
+    val flat = Symbol.api.Flatten(Some(pool1))
+    var fc1 = Symbol.api.FullyConnected(Some(flat), num_hidden = numClasses, name = "fc1")
+    if (dtype == "float16") {
+      fc1 = Symbol.api.cast(Some(fc1), "float32")
+    }
+    Symbol.api.SoftmaxOutput(Some(fc1), name = "softmax")
+  }
+
+  /**
+    * Gets the resnet model symbol
+    *
+    * Valid depths: for images at most 28 pixels high, depths with (numLayers - 2) % 9 == 0
+    * and numLayers >= 164 (bottleneck units) or (numLayers - 2) % 6 == 0 and
+    * numLayers < 164 (basic units); for larger images, 18, 34, 50, 101, 152, 200 or 269.
+    * @param numClasses Number of classes to classify into
+    * @param numLayers Number of residual layers
+    * @param imageShape The image shape as List(channels, height, width)
+    * @param convWorkspace Maximum temporary workspace allowed (MB) in convolutions
+    * @param dtype Type of data (float16, float32, etc) to use during computation
+    * @return Model symbol
+    * @throws IllegalArgumentException if numLayers is not a valid depth for the image size
+    */
+  def getSymbol(numClasses: Int, numLayers: Int, imageShape: List[Int], convWorkspace: Int = 256,
+                dtype: String = "float32"): Symbol = {
+    val List(_, height, _) = imageShape
+    val (numStages, units, filterList, bottleNeck): (Int, List[Int], List[Int], Boolean) =
+      if (height <= 28) {
+        // CIFAR-style network: three stages with the same number of units each.
+        val (perUnit, filterList, bottleNeck) =
+          if ((numLayers - 2) % 9 == 0 && numLayers >= 164) {
+            ((numLayers - 2) / 9, List(16, 64, 128, 256), true)
+          } else if ((numLayers - 2) % 6 == 0 && numLayers < 164) {
+            ((numLayers - 2) / 6, List(16, 16, 32, 64), false)
+          } else {
+            throw new IllegalArgumentException(s"Invalid number of layers: ${numLayers}")
+          }
+        val numStages = 3
+        (numStages, List.fill(numStages)(perUnit), filterList, bottleNeck)
+      } else {
+        // ImageNet-style network: four stages with per-depth unit counts.
+        val (filterList, bottleNeck) = if (numLayers >= 50) {
+          (List(64, 256, 512, 1024, 2048), true)
+        } else {
+          (List(64, 64, 128, 256, 512), false)
+        }
+        val units: List[Int] = Map(
+          18 -> List(2, 2, 2, 2),
+          34 -> List(3, 4, 6, 3),
+          50 -> List(3, 4, 6, 3),
+          101 -> List(3, 4, 23, 3),
+          152 -> List(3, 8, 36, 3),
+          200 -> List(3, 24, 36, 3),
+          269 -> List(3, 30, 48, 8)
+        ).getOrElse(numLayers,
+          throw new IllegalArgumentException(s"Invalid number of layers: ${numLayers}"))
+        (4, units, filterList, bottleNeck)
+      }
+    resnet(units, numStages, filterList, numClasses, imageShape, bottleNeck,
+      workspace = convWorkspace, dtype = dtype)
+  }
+}
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/ModelTrain.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
similarity index 73%
rename from 
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/ModelTrain.scala
rename to 
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
index 1a77775..9a54e58 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/ModelTrain.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
@@ -15,19 +15,37 @@
  * limitations under the License.
  */
 
-package org.apache.mxnetexamples.imclassification
+package org.apache.mxnetexamples.imclassification.util
 
 import org.apache.mxnet.Callback.Speedometer
 import org.apache.mxnet._
 import org.apache.mxnet.optimizer.SGD
 import org.slf4j.LoggerFactory
 
-object ModelTrain {
-  private val logger = LoggerFactory.getLogger(classOf[ModelTrain])
+object Trainer {
+  private val logger = LoggerFactory.getLogger(classOf[Trainer])
 
+  /**
+    * Fits a model
+    * @param batchSize Number of images per training batch
+    * @param numExamples Total number of image examples
+    * @param devs List of device contexts to use
+    * @param network The model to train
+    * @param dataLoader Function to get data loaders for training and validation data
+    * @param kvStore KVStore to use
+    * @param numEpochs Number of times to train on each image
+    * @param modelPrefix Prefix to model identification
+    * @param loadEpoch Loads a saved checkpoint at this epoch when set
+    * @param lr The learning rate
+    * @param lrFactor Learning rate factor (see FactorScheduler)
+    * @param lrFactorEpoch Learning rate factor epoch (see FactorScheduler)
+    * @param clipGradient Maximum gradient during optimization
+    * @param monitorSize (See Monitor)
+    * @return Final accuracy
+    */
   // scalastyle:off parameterNum
-  def fit(dataDir: String, batchSize: Int, numExamples: Int, devs: Array[Context],
-          network: Symbol, dataLoader: (String, Int, KVStore) => (DataIter, DataIter),
+  def fit(batchSize: Int, numExamples: Int, devs: Array[Context],
+          network: Symbol, dataLoader: (Int, KVStore) => (DataIter, DataIter),
           kvStore: String, numEpochs: Int, modelPrefix: String = null, loadEpoch: Int = -1,
           lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f,
           clipGradient: Float = 0f, monitorSize: Int = -1): Accuracy = {
@@ -60,7 +78,7 @@ object ModelTrain {
       }
 
     // data
-    val (train, validation) = dataLoader(dataDir, batchSize, kv)
+    val (train, validation) = dataLoader(batchSize, kv)
 
     // train
     val epochSize =
@@ -75,8 +93,8 @@ object ModelTrain {
         null
       }
     val optimizer: Optimizer = new SGD(learningRate = lr,
-        lrScheduler = lrScheduler, clipGradient = clipGradient,
-        momentum = 0.9f, wd = 0.00001f)
+      lrScheduler = lrScheduler, clipGradient = clipGradient,
+      momentum = 0.9f, wd = 0.00001f)
 
     // disable kvstore for single device
     if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType != "gpu")) {
@@ -108,7 +126,9 @@ object ModelTrain {
     }
     acc
   }
+
   // scalastyle:on parameterNum
 }
 
-class ModelTrain
+class Trainer
+
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
similarity index 53%
rename from scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
rename to scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
index 0fd3af0..6e9667a 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
@@ -18,9 +18,7 @@
 package org.apache.mxnetexamples.imclassification
 
 import java.io.File
-import java.net.URL
 
-import org.apache.commons.io.FileUtils
 import org.apache.mxnet.Context
 import org.apache.mxnetexamples.Util
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -31,27 +29,35 @@ import scala.sys.process.Process
 /**
   * Integration test for MNIST example.
   */
-class MNISTExampleSuite extends FunSuite with BeforeAndAfterAll {
-  private val logger = LoggerFactory.getLogger(classOf[MNISTExampleSuite])
+class IMClassificationExampleSuite extends FunSuite with BeforeAndAfterAll {
+  private val logger = LoggerFactory.getLogger(classOf[IMClassificationExampleSuite])
 
   test("Example CI: Test MNIST Training") {
 
-      logger.info("Downloading mnist model")
-      val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
-      val tempDirPath = System.getProperty("java.io.tmpdir")
-      val modelDirPath = tempDirPath + File.separator + "mnist/"
-      logger.info("tempDirPath: %s".format(tempDirPath))
-      Util.downloadUrl(baseUrl + "/mnist/mnist.zip",
-        tempDirPath + "/mnist/mnist.zip")
-      // TODO: Need to confirm with Windows
-      Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
-        + tempDirPath + "/mnist/") !
+    logger.info("Downloading mnist model")
+    val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    val modelDirPath = tempDirPath + File.separator + "mnist/"
+    logger.info("tempDirPath: %s".format(tempDirPath))
+    Util.downloadUrl(baseUrl + "/mnist/mnist.zip",
+      tempDirPath + "/mnist/mnist.zip")
+    // TODO: Need to confirm with Windows
+    Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
+      + tempDirPath + "/mnist/") !
 
-      var context = Context.cpu()
+    var context = Context.cpu()
 
-      val output = TrainMnist.test(modelDirPath)
-      Process("rm -rf " + modelDirPath) !
+    val valAccuracy = TrainModel.test("mlp", modelDirPath)
+    Process("rm -rf " + modelDirPath) !
 
-      assert(output >= 0.95f)
+    assert(valAccuracy >= 0.95f)
   }
+
+  for(model <- List("mlp", "lenet", "resnet")) {
+    test(s"Example CI: Test Image Classification Model ${model}") {
+      var context = Context.cpu()
+      val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true)
+    }
+  }
+
 }

Reply via email to