nswamy closed pull request #12639: [MXNET-947] Expand scala imclassification 
example with resnet
URL: https://github.com/apache/incubator-mxnet/pull/12639
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
index 5141f441b1e..cec750acdc9 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
@@ -1,17 +1,36 @@
-# MNIST Example for Scala
-This is the MNIST Training Example implemented for Scala type-safe api
+# Image Classification Models
+
+This example contains a number of image classification models that can be run on various datasets.
+
+## Models
+
+Currently, the following models are supported:
+- MultiLayerPerceptron
+- Lenet
+- Resnet
+
+## Datasets
+
+Currently, the following datasets are supported:
+- MNIST
+
+### Synthetic Benchmark Data
+
+Additionally, the datasets can be replaced by randomly generated data for benchmarking
+(enable this with the `--benchmark` flag).
+Data is produced to match the shapes of the supported datasets above.
+
+The following additional dataset image shapes are also defined for use with the benchmark synthetic data (select one with `--dataset`):
+- imagenet
+
+
+
 ## Setup
-### Download the source File
+
+### MNIST
+
+For this dataset, the data must be downloaded and extracted, either from the original MNIST source or from this mirror:
 ```$xslt
 https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/mnist/mnist.zip
 ```
-### Unzip the file
-```$xslt
-unzip mnist.zip
-```
-### Arguement Configuration
-Then you need to define the arguments that you would like to pass in the model:
-```$xslt
---data-dir <location of your downloaded file>
-```
-You can find more information 
[here](https://github.com/apache/incubator-mxnet/blob/scala-package/examples/src/main/scala/org/apache/mxnet/examples/imclassification/TrainMnist.scala#L169-L207)
\ No newline at end of file
+
+Afterwards, the location of the data folder must be passed in through the `--data-dir` argument (default: `mnist/`).
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
similarity index 55%
rename from 
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
rename to 
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
index 2f024fd039b..608e191e019 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
@@ -17,109 +17,106 @@
 
 package org.apache.mxnetexamples.imclassification
 
+import java.util.concurrent._
+
+import org.apache.mxnetexamples.imclassification.models._
+import org.apache.mxnetexamples.imclassification.util.Trainer
 import org.apache.mxnet._
+import org.apache.mxnetexamples.imclassification.datasets.{MnistIter, 
SyntheticDataIter}
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
-object TrainMnist {
-  private val logger = LoggerFactory.getLogger(classOf[TrainMnist])
-
-  // multi-layer perceptron
-  def getMlp: Symbol = {
-    val data = Symbol.Variable("data")
-
-    // val fc1 = Symbol.FullyConnected(name = "relu")()(Map("data" -> data, 
"act_type" -> "relu"))
-    val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, 
name = "fc1")
-    val act1 = Symbol.api.Activation (data = Some(fc1), "relu", name = "relu")
-    val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name = 
"fc2")
-    val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2")
-    val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, 10, name = 
"fc3")
-    val mlp = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc3))
-    mlp
-  }
-
-  def getLenet: Symbol = {
-    val data = Symbol.Variable("data")
-    // first conv
-    val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 
5), num_filter = 20)
-    val tanh1 = Symbol.api.tanh(data = Some(conv1))
-    val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
-      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
-    // second conv
-    val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 
5), num_filter = 50)
-    val tanh2 = Symbol.api.tanh(data = Some(conv2))
-    val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
-      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
-    // first fullc
-    val flatten = Symbol.api.Flatten(data = Some(pool2))
-    val fc1 = Symbol.api.FullyConnected(data = Some(flatten), num_hidden = 500)
-    val tanh3 = Symbol.api.tanh(data = Some(fc1))
-    // second fullc
-    val fc2 = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden = 10)
-    // loss
-    val lenet = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc2))
-    lenet
-  }
-
-  def getIterator(dataShape: Shape)
-    (dataDir: String, batchSize: Int, kv: KVStore): (DataIter, DataIter) = {
-    val flat = if (dataShape.size == 3) "False" else "True"
-
-    val train = IO.MNISTIter(Map(
-      "image" -> (dataDir + "train-images-idx3-ubyte"),
-      "label" -> (dataDir + "train-labels-idx1-ubyte"),
-      "label_name" -> "softmax_label",
-      "input_shape" -> dataShape.toString,
-      "batch_size" -> batchSize.toString,
-      "shuffle" -> "True",
-      "flat" -> flat,
-      "num_parts" -> kv.numWorkers.toString,
-      "part_index" -> kv.`rank`.toString))
-
-    val eval = IO.MNISTIter(Map(
-      "image" -> (dataDir + "t10k-images-idx3-ubyte"),
-      "label" -> (dataDir + "t10k-labels-idx1-ubyte"),
-      "label_name" -> "softmax_label",
-      "input_shape" -> dataShape.toString,
-      "batch_size" -> batchSize.toString,
-      "flat" -> flat,
-      "num_parts" -> kv.numWorkers.toString,
-      "part_index" -> kv.`rank`.toString))
-
-    (train, eval)
-  }
-
-  def test(dataPath : String) : Float = {
+object TrainModel {
+  private val logger = LoggerFactory.getLogger(classOf[TrainModel])
+
+  /**
+    * Simple model training and execution on the MNIST image shape, returning the final
+    * validation accuracy.
+    *
+    * The run uses a fixed configuration: batch size 128, a single CPU context
+    * (`Context.cpu(0)`) and a "local" KVStore, all inside an `NDArrayCollector.auto()` scope.
+    *
+    * @param model The model identifying string ("mlp", "lenet" or "resnet")
+    * @param dataPath Path to location of the MNIST image data (not read when `benchmark` is true)
+    * @param numExamples Number of image data examples
+    * @param numEpochs Number of epochs to train for
+    * @param benchmark Whether to use benchmark synthetic data instead of real image data
+    * @return The final validation accuracy (first value of the returned accuracy metric)
+    */
+  def test(model: String, dataPath: String, numExamples: Int = 60000,
+           numEpochs: Int = 10, benchmark: Boolean = false): Float = {
     NDArrayCollector.auto().withScope {
-      val (dataShape, net) = (Shape(784), getMlp)
       val devs = Array(Context.cpu(0))
       val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, 
String]
-      val Acc = ModelTrain.fit(dataDir = dataPath,
-        batchSize = 128, numExamples = 60000, devs = devs,
-        network = net, dataLoader = getIterator(dataShape),
-        kvStore = "local", numEpochs = 10)
+      val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath,
+        numExamples = numExamples, benchmark = benchmark)
+      val Acc = Trainer.fit(batchSize = 128, numExamples, devs = devs,
+        network = net, dataLoader = dataLoader,
+        kvStore = "local", numEpochs = numEpochs)
       logger.info("Finish test fit ...")
       val (_, num) = Acc.get
       num(0)
     }
   }
 
+  /**
+    * Gets the dataset iterator factory and the model symbol.
+    *
+    * @param dataset The dataset identifying string: "mnist" or "imagenet" (the latter is
+    *                only available with `benchmark = true`)
+    * @param model The model identifying string: "mlp", "lenet" or "resnet"
+    * @param dataDir Path to location of image data (not used when benchmarking)
+    * @param numLayers The number of model layers (resnet only)
+    * @param numExamples The number of examples in the dataset (only used by the synthetic
+    *                    benchmark iterator)
+    * @param benchmark Whether to use benchmark synthetic data instead of real image data
+    * @return A function taking (batchSize, kvStore) and returning the (training, validation)
+    *         iterators, paired with the model symbol
+    * @throws Exception for an unknown dataset or model name, or when real data is requested
+    *                   for a dataset that only has the synthetic iterator
+    */
+  def dataLoaderAndModel(dataset: String, model: String, dataDir: String = "",
+                         numLayers: Int = 50, numExamples: Int = 60000,
+                         benchmark: Boolean = false
+                        ): ((Int, KVStore) => (DataIter, DataIter), Symbol) = {
+    val (imageShape, numClasses) = dataset match {
+      case "mnist" => (List(1, 28, 28), 10)
+      case "imagenet" => (List(3, 224, 224), 1000)
+      case _ => throw new Exception(s"Invalid image data collection: $dataset")
+    }
+
+    val List(channels, height, width) = imageShape
+    val dataSize: Int = channels * height * width
+    // The MLP consumes flattened images; the convolutional models keep (c, h, w).
+    val (datumShape, net) = model match {
+      case "mlp" => (List(dataSize), MultiLayerPerceptron.getSymbol(numClasses))
+      case "lenet" => (List(channels, height, width), Lenet.getSymbol(numClasses))
+      case "resnet" => (List(channels, height, width),
+        Resnet.getSymbol(numClasses, numLayers, imageShape))
+      case _ => throw new Exception(s"Invalid model name: $model")
+    }
 
+    val dataLoader: (Int, KVStore) => (DataIter, DataIter) = if (benchmark) {
+      (batchSize: Int, kv: KVStore) => {
+        // The same synthetic iterator serves as both the training and the validation data.
+        val iter = new SyntheticDataIter(numClasses, batchSize, datumShape, List(), numExamples)
+        (iter, iter)
+      }
+    } else {
+      dataset match {
+        case "mnist" => MnistIter.getIterator(Shape(datumShape), dataDir)
+        case _ => throw new Exception("This image data collection only supports the "
+          + "synthetic benchmark iterator.  Use --benchmark to enable")
+      }
+    }
+    (dataLoader, net)
+  }
+
+  /**
+    * Runs image classification training from CLI with various options
+    * @param args CLI args
+    */
   def main(args: Array[String]): Unit = {
-    val inst = new TrainMnist
+    val inst = new TrainModel
     val parser: CmdLineParser = new CmdLineParser(inst)
     try {
       parser.parseArgument(args.toList.asJava)
 
       val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
-        else inst.dataDir
+      else inst.dataDir
 
-      val (dataShape, net) =
-        if (inst.network == "mlp") (Shape(784), getMlp)
-        else (Shape(1, 28, 28), getLenet)
+      val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, 
dataPath,
+        inst.numLayers, inst.numExamples, inst.benchmark)
 
       val devs =
         if (inst.gpus != null) inst.gpus.split(',').map(id => 
Context.gpu(id.trim.toInt))
@@ -144,9 +141,8 @@ object TrainMnist {
         logger.info("Start KVStoreServer for scheduler & servers")
         KVStoreServer.start()
       } else {
-        ModelTrain.fit(dataDir = inst.dataDir,
-          batchSize = inst.batchSize, numExamples = inst.numExamples, devs = 
devs,
-          network = net, dataLoader = getIterator(dataShape),
+        Trainer.fit(batchSize = inst.batchSize, numExamples = 
inst.numExamples, devs = devs,
+          network = net, dataLoader = dataLoader,
           kvStore = inst.kvStore, numEpochs = inst.numEpochs,
           modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
           lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = 
inst.lrFactorEpoch,
@@ -163,11 +159,19 @@ object TrainMnist {
   }
 }
 
-class TrainMnist {
-  @Option(name = "--network", usage = "the cnn to use: ['mlp', 'lenet']")
+class TrainModel {
+  @Option(name = "--network", usage = "the cnn to use: ['mlp', 'lenet', 
'resnet']")
   private val network: String = "mlp"
+  @Option(name = "--num-layers", usage = "the number of resnet layers to use")
+  private val numLayers: Int = 50
   @Option(name = "--data-dir", usage = "the input data directory")
   private val dataDir: String = "mnist/"
+
+  @Option(name = "--dataset", usage = "the images to classify: ['mnist', 
'imagenet']")
+  private val dataset: String = "mnist"
+  @Option(name = "--benchmark", usage = "Benchmark to use synthetic data to 
measure performance")
+  private val benchmark: Boolean = false
+
   @Option(name = "--gpus", usage = "the gpus will be used, e.g. '0,1,2,3'")
   private val gpus: String = null
   @Option(name = "--cpus", usage = "the cpus will be used, e.g. '0,1,2,3'")
@@ -187,7 +191,7 @@ class TrainMnist {
   @Option(name = "--kv-store", usage = "the kvstore type")
   private val kvStore = "local"
   @Option(name = "--lr-factor",
-          usage = "times the lr with a factor for every lr-factor-epoch epoch")
+    usage = "times the lr with a factor for every lr-factor-epoch epoch")
   private val lrFactor: Float = 1f
   @Option(name = "--lr-factor-epoch", usage = "the number of epoch to factor 
the lr, could be .5")
   private val lrFactorEpoch: Float = 1f
@@ -205,3 +209,4 @@ class TrainMnist {
   @Option(name = "--num-server", usage = "# of servers")
   private val numServer: Int = 1
 }
+
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/MnistIter.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/MnistIter.scala
new file mode 100644
index 00000000000..9e6e1c2a326
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/MnistIter.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.datasets
+
+import org.apache.mxnet._
+
+object MnistIter {
+  /**
+    * Returns training and evaluation iterators over the MNIST dataset.
+    *
+    * The iterators read `train-images-idx3-ubyte` / `train-labels-idx1-ubyte` (training) and
+    * `t10k-images-idx3-ubyte` / `t10k-labels-idx1-ubyte` (evaluation) from `dataDir`.
+    * Only the training iterator is shuffled. Both read the partition selected by the
+    * KVStore's worker count and rank.
+    *
+    * @param dataShape Image size: a 3-dimensional (channels, height, width) shape yields
+    *                  unflattened images, any other shape (e.g. Shape(784)) flattened ones
+    * @param dataDir The directory containing the image data; a "/" is appended when the
+    *                non-empty path does not already end with a separator
+    * @param batchSize Number of images per batch
+    * @param kv KVStore whose `numWorkers` and `rank` select the data partition
+    * @return (training iterator, evaluation iterator)
+    */
+  def getIterator(dataShape: Shape, dataDir: String)
+                 (batchSize: Int, kv: KVStore): (DataIter, DataIter) = {
+    require(dataDir != null, "MNIST data directory must not be null")
+    val dir = withTrailingSeparator(dataDir)
+
+    // Settings shared by the training and evaluation iterators.
+    val common = Map(
+      "label_name" -> "softmax_label",
+      "input_shape" -> dataShape.toString,
+      "batch_size" -> batchSize.toString,
+      "flat" -> (if (dataShape.size == 3) "False" else "True"),
+      "num_parts" -> kv.numWorkers.toString,
+      "part_index" -> kv.`rank`.toString)
+
+    val train = IO.MNISTIter(common ++ Map(
+      "image" -> (dir + "train-images-idx3-ubyte"),
+      "label" -> (dir + "train-labels-idx1-ubyte"),
+      "shuffle" -> "True"))
+
+    val eval = IO.MNISTIter(common ++ Map(
+      "image" -> (dir + "t10k-images-idx3-ubyte"),
+      "label" -> (dir + "t10k-labels-idx1-ubyte")))
+
+    (train, eval)
+  }
+
+  /** Appends "/" to a non-empty directory path that does not already end with a separator. */
+  private def withTrailingSeparator(dir: String): String =
+    if (dir.isEmpty || dir.endsWith("/") || dir.endsWith(java.io.File.separator)) dir
+    else dir + "/"
+
+}
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
new file mode 100644
index 00000000000..9421f102161
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.datasets
+
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
+
+import scala.collection.immutable.ListMap
+import scala.util.Random
+
+/**
+  * A DataIter serving randomly generated batches, for benchmarking without a real dataset.
+  *
+  * One data array and one label array are drawn when the iterator is constructed; every
+  * batch returned by `next()` reuses those same arrays.
+  *
+  * @param numClasses Upper bound of the random labels when `labelShape` is empty
+  *                   (otherwise the upper bound is 1)
+  * @param batchSize Number of examples per batch (leading dimension of data and label)
+  * @param datumShape Shape of a single example, without the batch dimension
+  * @param labelShape Shape of a single label, without the batch dimension (empty for scalars)
+  * @param maxIter Batches are served while the running example count is below `maxIter - 1`;
+  *                the count advances by `batchSize` per batch
+  * @param dtype Element type recorded in the data/label descriptors; the generated arrays
+  *              themselves are created without an explicit type
+  */
+class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[Int],
+                        labelShape: List[Int], maxIter: Int, dtype: DType = DType.Float32
+                       ) extends DataIter {
+  // Number of examples served since the last reset.
+  var curIter = 0
+  val random = new Random()
+  val shape = Shape(batchSize :: datumShape)
+  val batchLabelShape = Shape(batchSize :: labelShape)
+
+  val maxLabel = if (labelShape.nonEmpty) 1f else numClasses.toFloat
+  // Labels are drawn before data, as in the original construction order.
+  var label: IndexedSeq[NDArray] = {
+    val labelBatch = NDArray.api.random_uniform(Some(0f), Some(maxLabel),
+      shape = Some(batchLabelShape))
+    IndexedSeq(labelBatch)
+  }
+  var data: IndexedSeq[NDArray] = {
+    val dataBatch = NDArray.api.random_uniform(shape = Some(shape))
+    IndexedSeq(dataBatch)
+  }
+
+  val provideDataDesc: IndexedSeq[DataDesc] =
+    IndexedSeq(new DataDesc("data", shape, dtype, Layout.UNDEFINED))
+  val provideLabelDesc: IndexedSeq[DataDesc] =
+    IndexedSeq(new DataDesc("softmax_label", batchLabelShape, dtype, Layout.UNDEFINED))
+  val getPad: Int = 0
+
+  override def getData(): IndexedSeq[NDArray] = data
+
+  override def getIndex: IndexedSeq[Long] = IndexedSeq(curIter.toLong)
+
+  override def getLabel: IndexedSeq[NDArray] = label
+
+  override def hasNext: Boolean = maxIter - 1 > curIter
+
+  override def next(): DataBatch = {
+    if (!hasNext) {
+      throw new NoSuchElementException
+    }
+    curIter += batchSize
+    new DataBatch(data, label, getIndex, getPad, null, null, null)
+  }
+
+  override def reset(): Unit = { curIter = 0 }
+
+  override def provideData: ListMap[String, Shape] = ListMap("data" -> shape)
+
+  override def provideLabel: ListMap[String, Shape] =
+    ListMap("softmax_label" -> batchLabelShape)
+}
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
new file mode 100644
index 00000000000..76fb7bb6602
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.models
+
+import org.apache.mxnet._
+
+object Lenet {
+
+  /**
+    * Gets the LeNet model symbol.
+    *
+    * Two 5x5 convolution stages (20 and 50 filters), each followed by tanh and 2x2/stride-2
+    * max pooling, then a 500-unit tanh fully connected layer and a `numClasses`-unit output
+    * layer feeding a SoftmaxOutput named "softmax".
+    *
+    * @param numClasses Number of classes to classify into
+    * @return model symbol
+    */
+  def getSymbol(numClasses: Int): Symbol = {
+    val input = Symbol.Variable("data")
+    // The inner call builds the first stage, so symbols are created in the same order
+    // (and therefore receive the same auto-generated names) as a straight-line definition.
+    val features = convTanhPool(convTanhPool(input, numFilter = 20), numFilter = 50)
+    val flattened = Symbol.api.Flatten(data = Some(features))
+    val hidden = Symbol.api.tanh(
+      data = Some(Symbol.api.FullyConnected(data = Some(flattened), num_hidden = 500)))
+    val logits = Symbol.api.FullyConnected(data = Some(hidden), num_hidden = numClasses)
+    Symbol.api.SoftmaxOutput(name = "softmax", data = Some(logits))
+  }
+
+  /** Builds a 5x5 convolution with `numFilter` filters, a tanh and 2x2/stride-2 max pooling. */
+  private def convTanhPool(input: Symbol, numFilter: Int): Symbol = {
+    val conv = Symbol.api.Convolution(data = Some(input), kernel = Shape(5, 5),
+      num_filter = numFilter)
+    val activated = Symbol.api.tanh(data = Some(conv))
+    Symbol.api.Pooling(data = Some(activated), pool_type = Some("max"),
+      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
+  }
+
+}
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
new file mode 100644
index 00000000000..5d880bbe061
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.models
+
+import org.apache.mxnet._
+
+object MultiLayerPerceptron {
+
+  /**
+    * Gets the multi-layer perceptron model symbol: two ReLU hidden layers of 128 and 64 units,
+    * then a fully connected output layer with one unit per class and a SoftmaxOutput named
+    * "softmax".
+    *
+    * @param numClasses Number of classes to classify into
+    * @return model symbol
+    */
+  def getSymbol(numClasses: Int): Symbol = {
+    // (width, fully connected layer name, activation name) of each hidden layer, in order.
+    val hiddenLayers = List((128, "fc1", "relu"), (64, "fc2", "relu2"))
+    val hidden = hiddenLayers.foldLeft(Symbol.Variable("data")) {
+      case (input, (width, fcName, actName)) =>
+        val fc = Symbol.api.FullyConnected(data = Some(input), num_hidden = width, name = fcName)
+        Symbol.api.Activation(data = Some(fc), act_type = "relu", name = actName)
+    }
+    val output = Symbol.api.FullyConnected(data = Some(hidden), num_hidden = numClasses,
+      name = "fc3")
+    Symbol.api.SoftmaxOutput(name = "softmax", data = Some(output))
+  }
+
+}
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
new file mode 100644
index 00000000000..c3f43d97e89
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.models
+
+import org.apache.mxnet._
+
+object Resnet {
+  /**
+    * Builds one residual unit: the input is pre-activated (BatchNorm -> ReLU), passed through
+    * the convolution branch, and added to the shortcut.
+    *
+    * @param data Input symbol
+    * @param numFilter Number of output channels of the unit
+    * @param stride Stride of the downsampling 3x3 convolution and of the projection shortcut
+    * @param dimMatch If true the input itself is the shortcut; otherwise a 1x1 convolution of
+    *                 the pre-activated input (with `stride`) is used
+    * @param name Prefix of the names of all symbols created for this unit
+    * @param bottleNeck Use 1x1 -> 3x3 -> 1x1 convolutions with `(numFilter * 0.25).toInt`
+    *                   inner channels instead of two 3x3 convolutions
+    * @param bnMom Momentum of the BatchNorm layers
+    * @param workspace Maximum temporary workspace (MB) of the convolutions
+    * @param memonger Currently unused
+    * @return Sum of the convolution branch and the shortcut
+    */
+  def residualUnit(data: Symbol, numFilter: Int, stride: Shape, dimMatch: Boolean,
+                   name: String = "", bottleNeck: Boolean = true, bnMom: Float = 0.9f,
+                   workspace: Int = 256, memonger: Boolean = false): Symbol = {
+    // BatchNorm followed by ReLU, named <name>_bn<i> and <name>_relu<i>.
+    def bnRelu(input: Symbol, i: Int): Symbol = {
+      val bn = Symbol.api.BatchNorm(data = Some(input), fix_gamma = Some(false),
+        eps = Some(2e-5), momentum = Some(bnMom), name = s"${name}_bn$i")
+      Symbol.api.Activation(data = Some(bn), act_type = "relu", name = s"${name}_relu$i")
+    }
+    // Bias-free convolution named <name>_conv<i>.
+    def conv(input: Symbol, filters: Int, kernel: Shape, convStride: Shape, pad: Shape,
+             i: Int): Symbol =
+      Symbol.api.Convolution(data = Some(input), num_filter = filters, kernel = kernel,
+        stride = Some(convStride), pad = Some(pad), no_bias = Some(true),
+        workspace = Some(workspace), name = s"${name}_conv$i")
+
+    val act1 = bnRelu(data, 1)
+    val operated = if (bottleNeck) {
+      val innerFilters = (numFilter * 0.25).toInt
+      val conv1 = conv(act1, innerFilters, Shape(1, 1), Shape(1, 1), Shape(0, 0), 1)
+      val conv2 = conv(bnRelu(conv1, 2), innerFilters, Shape(3, 3), stride, Shape(1, 1), 2)
+      conv(bnRelu(conv2, 3), numFilter, Shape(1, 1), Shape(1, 1), Shape(0, 0), 3)
+    } else {
+      val conv1 = conv(act1, numFilter, Shape(3, 3), stride, Shape(1, 1), 1)
+      conv(bnRelu(conv1, 2), numFilter, Shape(3, 3), Shape(1, 1), Shape(1, 1), 2)
+    }
+    val shortcut = if (dimMatch) {
+      data
+    } else {
+      Symbol.api.Convolution(Some(act1), num_filter = numFilter, kernel = Shape(1, 1),
+        stride = Some(stride), no_bias = Some(true), workspace = Some(workspace),
+        name = name + "_sc")
+    }
+    operated + shortcut
+  }
+
+  /**
+    * Assembles a residual network from per-stage unit counts and filter sizes.
+    *
+    * @param units Number of residual units in each stage (must have `numStages` entries)
+    * @param numStages Number of stages
+    * @param filterList Channel count of the stem convolution followed by one entry per stage
+    * @param numClasses Number of output classes
+    * @param imageShape Image shape as List(channels, height, width); images of height <= 32
+    *                   get a single 3x3 stem convolution, larger ones a 7x7/2 convolution,
+    *                   BatchNorm, ReLU and 3x3/2 max pooling
+    * @param bottleNeck Whether residual units use the bottleneck layout
+    * @param bnMom Momentum of all BatchNorm layers
+    * @param workspace Maximum temporary workspace (MB) of the convolutions
+    * @param dtype "float32" or "float16"; "float16" casts the input to float16 and the logits
+    *              back to float32
+    * @param memonger Passed through to each residual unit (currently unused there)
+    * @return Network symbol ending in a SoftmaxOutput named "softmax"
+    */
+  def resnet(units: List[Int], numStages: Int, filterList: List[Int], numClasses: Int,
+             imageShape: List[Int], bottleNeck: Boolean = true, bnMom: Float = 0.9f,
+             workspace: Int = 256, dtype: String = "float32", memonger: Boolean = false): Symbol = {
+    assert(units.size == numStages,
+      s"Expected $numStages unit counts but got ${units.size}")
+    // zip below would silently drop stages if filter sizes were missing.
+    assert(filterList.size > numStages,
+      s"Expected at least ${numStages + 1} filter sizes but got ${filterList.size}")
+    val List(_, height, _) = imageShape
+
+    // NOTE(review): the batch dimension of this shape hint is fixed at 4; presumably the
+    // shapes supplied at bind time take precedence - confirm before relying on it.
+    val rawData = Symbol.Variable("data", shape = Shape(List(4) ::: imageShape),
+      dType = DType.Float32)
+    val typedData = dtype match {
+      case "float32" => Symbol.api.identity(Some(rawData), "id")
+      case "float16" => Symbol.api.cast(Some(rawData), "float16")
+      case _ => rawData
+    }
+    val data = Symbol.api.BatchNorm(Some(typedData), fix_gamma = Some(true), eps = Some(2e-5),
+      momentum = Some(bnMom), name = "bn_data")
+
+    val stem = if (height <= 32) {
+      // Small images (e.g. MNIST/CIFAR): one 3x3 convolution that keeps the resolution.
+      Symbol.api.Convolution(Some(data), num_filter = filterList.head, kernel = Shape(3, 3),
+        stride = Some(Shape(1, 1)), pad = Some(Shape(1, 1)), no_bias = Some(true),
+        name = "conv0", workspace = Some(workspace))
+    } else {
+      // Large images (e.g. ImageNet): 7x7/2 convolution, BN, ReLU and 3x3/2 max pooling.
+      val conv0 = Symbol.api.Convolution(Some(data), num_filter = filterList.head,
+        kernel = Shape(7, 7), stride = Some(Shape(2, 2)), pad = Some(Shape(3, 3)),
+        no_bias = Some(true), name = "conv0", workspace = Some(workspace))
+      val bn0 = Symbol.api.BatchNorm(Some(conv0), fix_gamma = Some(false), eps = Some(2e-5),
+        momentum = Some(bnMom), name = "bn0")
+      val relu0 = Symbol.api.Activation(Some(bn0), act_type = "relu", name = "relu0")
+      Symbol.api.Pooling(Some(relu0), kernel = Some(Shape(3, 3)), stride = Some(Shape(2, 2)),
+        pad = Some(Shape(1, 1)), pool_type = Some("max"))
+    }
+
+    val body = filterList.tail.zip(units).zipWithIndex.foldLeft(stem) {
+      case (input, ((filter, numUnits), stage)) =>
+        // The first stage keeps the resolution; later stages halve it in their first unit.
+        val firstStride = if (stage == 0) Shape(1, 1) else Shape(2, 2)
+        val first = residualUnit(input, filter, firstStride, dimMatch = false,
+          name = s"stage${stage + 1}_unit1", bottleNeck = bottleNeck, bnMom = bnMom,
+          workspace = workspace, memonger = memonger)
+        (2 to numUnits).foldLeft(first) { (acc, unit) =>
+          residualUnit(acc, filter, Shape(1, 1), dimMatch = true,
+            name = s"stage${stage + 1}_unit$unit", bottleNeck = bottleNeck, bnMom = bnMom,
+            workspace = workspace, memonger = memonger)
+        }
+    }
+
+    val bn1 = Symbol.api.BatchNorm(Some(body), fix_gamma = Some(false), eps = Some(2e-5),
+      momentum = Some(bnMom), name = "bn1")
+    val relu1 = Symbol.api.Activation(Some(bn1), act_type = "relu", name = "relu1")
+    val pool1 = Symbol.api.Pooling(Some(relu1), global_pool = Some(true),
+      kernel = Some(Shape(7, 7)), pool_type = Some("avg"), name = "pool1")
+    val flat = Symbol.api.Flatten(Some(pool1))
+    val fc1 = Symbol.api.FullyConnected(Some(flat), num_hidden = numClasses, name = "fc1")
+    // In float16 mode, cast the logits back to float32 before the softmax output.
+    val logits = if (dtype == "float16") Symbol.api.cast(Some(fc1), "float32") else fc1
+    Symbol.api.SoftmaxOutput(Some(logits), name = "softmax")
+  }
+
+  /**
+    * Gets the resnet model symbol.
+    *
+    * Images with height <= 28 use a 3-stage network, and `numLayers` must satisfy one of:
+    *  - (numLayers - 2) % 6 == 0 with numLayers < 164 (basic units)
+    *  - (numLayers - 2) % 9 == 0 with numLayers >= 164 (bottleneck units, e.g. ResNet-164)
+    * Larger images use a 4-stage network, and `numLayers` must be one of 18, 34, 50, 101, 152,
+    * 200 or 269.
+    *
+    * @param numClasses Number of classes to classify into
+    * @param numLayers Number of residual layers
+    * @param imageShape The image shape as List(channels, height, width)
+    * @param convWorkspace Maximum temporary workspace allowed (MB) in convolutions
+    * @param dtype Type of data (float16, float32, etc) to use during computation
+    * @return Model symbol
+    * @throws Exception if `numLayers` is not supported for the image size
+    */
+  def getSymbol(numClasses: Int, numLayers: Int, imageShape: List[Int], convWorkspace: Int = 256,
+                dtype: String = "float32"): Symbol = {
+    val List(_, height, _) = imageShape
+    val (numStages, units, filterList, bottleNeck): (Int, List[Int], List[Int], Boolean) =
+      if (height <= 28) {
+        val numStages = 3
+        val (perStage, filterList, bottleNeck) =
+          if ((numLayers - 2) % 9 == 0 && numLayers >= 164) {
+            ((numLayers - 2) / 9, List(16, 64, 128, 256), true)
+          } else if ((numLayers - 2) % 6 == 0 && numLayers < 164) {
+            ((numLayers - 2) / 6, List(16, 16, 32, 64), false)
+          } else {
+            throw new Exception(s"Invalid number of layers: ${numLayers}")
+          }
+        (numStages, List.fill(numStages)(perStage), filterList, bottleNeck)
+      } else {
+        val (filterList, bottleNeck) = if (numLayers >= 50) {
+          (List(64, 256, 512, 1024, 2048), true)
+        } else {
+          (List(64, 64, 128, 256, 512), false)
+        }
+        val units: List[Int] = Map(
+          18 -> List(2, 2, 2, 2),
+          34 -> List(3, 4, 6, 3),
+          50 -> List(3, 4, 6, 3),
+          101 -> List(3, 4, 23, 3),
+          152 -> List(3, 8, 36, 3),
+          200 -> List(3, 24, 36, 3),
+          269 -> List(3, 30, 48, 8)
+        ).getOrElse(numLayers, throw new Exception(s"Invalid number of layers: ${numLayers}"))
+        (4, units, filterList, bottleNeck)
+      }
+    resnet(units, numStages, filterList, numClasses, imageShape, bottleNeck,
+      workspace = convWorkspace, dtype = dtype)
+  }
+}
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/ModelTrain.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
similarity index 73%
rename from 
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/ModelTrain.scala
rename to 
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
index 1a77775b985..9a54e58b653 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/ModelTrain.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
@@ -15,19 +15,37 @@
  * limitations under the License.
  */
 
-package org.apache.mxnetexamples.imclassification
+package org.apache.mxnetexamples.imclassification.util
 
 import org.apache.mxnet.Callback.Speedometer
 import org.apache.mxnet._
 import org.apache.mxnet.optimizer.SGD
 import org.slf4j.LoggerFactory
 
-object ModelTrain {
-  private val logger = LoggerFactory.getLogger(classOf[ModelTrain])
+object Trainer {
+  private val logger = LoggerFactory.getLogger(classOf[Trainer])
 
+  /**
+    * Fits a model
+    * @param batchSize Number of images per training batch
+    * @param numExamples Total number of image examples
+    * @param devs List of device contexts to use
+    * @param network The model to train
+    * @param dataLoader Function to get data loaders for training and 
validation data
+    * @param kvStore KVStore to use
+    * @param numEpochs Number of times to train on each image
+    * @param modelPrefix Prefix to model identification
+    * @param loadEpoch Loads a saved checkpoint at this epoch when set
+    * @param lr The learning rate
+    * @param lrFactor Learning rate factor (see FactorScheduler)
+    * @param lrFactorEpoch Learning rate factor epoch (see FactorScheduler)
+    * @param clipGradient Maximum gradient during optimization
+    * @param monitorSize (See Monitor)
+    * @return Final accuracy
+    */
   // scalastyle:off parameterNum
-  def fit(dataDir: String, batchSize: Int, numExamples: Int, devs: 
Array[Context],
-          network: Symbol, dataLoader: (String, Int, KVStore) => (DataIter, 
DataIter),
+  def fit(batchSize: Int, numExamples: Int, devs: Array[Context],
+          network: Symbol, dataLoader: (Int, KVStore) => (DataIter, DataIter),
           kvStore: String, numEpochs: Int, modelPrefix: String = null, 
loadEpoch: Int = -1,
           lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f,
           clipGradient: Float = 0f, monitorSize: Int = -1): Accuracy = {
@@ -60,7 +78,7 @@ object ModelTrain {
       }
 
     // data
-    val (train, validation) = dataLoader(dataDir, batchSize, kv)
+    val (train, validation) = dataLoader(batchSize, kv)
 
     // train
     val epochSize =
@@ -75,8 +93,8 @@ object ModelTrain {
         null
       }
     val optimizer: Optimizer = new SGD(learningRate = lr,
-        lrScheduler = lrScheduler, clipGradient = clipGradient,
-        momentum = 0.9f, wd = 0.00001f)
+      lrScheduler = lrScheduler, clipGradient = clipGradient,
+      momentum = 0.9f, wd = 0.00001f)
 
     // disable kvstore for single device
     if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType 
!= "gpu")) {
@@ -108,7 +126,9 @@ object ModelTrain {
     }
     acc
   }
+
   // scalastyle:on parameterNum
 }
 
-class ModelTrain
+class Trainer
+
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
similarity index 53%
rename from scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
rename to scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
index 0fd3af02d9c..6e9667abe9c 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
@@ -18,9 +18,7 @@
 package org.apache.mxnetexamples.imclassification
 
 import java.io.File
-import java.net.URL
 
-import org.apache.commons.io.FileUtils
 import org.apache.mxnet.Context
 import org.apache.mxnetexamples.Util
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -31,27 +29,35 @@ import scala.sys.process.Process
 /**
   * Integration test for MNIST example.
   */
-class MNISTExampleSuite extends FunSuite with BeforeAndAfterAll {
-  private val logger = LoggerFactory.getLogger(classOf[MNISTExampleSuite])
+class IMClassificationExampleSuite extends FunSuite with BeforeAndAfterAll {
+  private val logger = 
LoggerFactory.getLogger(classOf[IMClassificationExampleSuite])
 
   test("Example CI: Test MNIST Training") {
 
-      logger.info("Downloading mnist model")
-      val baseUrl = 
"https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci";
-      val tempDirPath = System.getProperty("java.io.tmpdir")
-      val modelDirPath = tempDirPath + File.separator + "mnist/"
-      logger.info("tempDirPath: %s".format(tempDirPath))
-      Util.downloadUrl(baseUrl + "/mnist/mnist.zip",
-        tempDirPath + "/mnist/mnist.zip")
-      // TODO: Need to confirm with Windows
-      Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
-        + tempDirPath + "/mnist/") !
+    logger.info("Downloading mnist model")
+    val baseUrl = 
"https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci";
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    val modelDirPath = tempDirPath + File.separator + "mnist/"
+    logger.info("tempDirPath: %s".format(tempDirPath))
+    Util.downloadUrl(baseUrl + "/mnist/mnist.zip",
+      tempDirPath + "/mnist/mnist.zip")
+    // TODO: Need to confirm with Windows
+    Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
+      + tempDirPath + "/mnist/") !
 
-      var context = Context.cpu()
+    var context = Context.cpu()
 
-      val output = TrainMnist.test(modelDirPath)
-      Process("rm -rf " + modelDirPath) !
+    val valAccuracy = TrainModel.test("mlp", modelDirPath)
+    Process("rm -rf " + modelDirPath) !
 
-      assert(output >= 0.95f)
+    assert(valAccuracy >= 0.95f)
   }
+
+  for(model <- List("mlp", "lenet", "resnet")) {
+    test(s"Example CI: Test Image Classification Model ${model}") {
+      var context = Context.cpu()
+      val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true)
+    }
+  }
+
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to