This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 15deabc [MXNET-947] Expand scala imclassification example with resnet
(#12639)
15deabc is described below
commit 15deabc689eb5810656307da6d3f295f0de4e242
Author: Zach Kimberg <[email protected]>
AuthorDate: Wed Oct 10 13:38:39 2018 -0700
[MXNET-947] Expand scala imclassification example with resnet (#12639)
* [MXNET-947] Scala imclassification example with Resnet
---
.../mxnetexamples/imclassification/README.md | 45 ++++--
.../{TrainMnist.scala => TrainModel.scala} | 175 ++++++++++----------
.../imclassification/datasets/MnistIter.scala | 59 +++++++
.../datasets/SyntheticDataIter.scala | 70 ++++++++
.../imclassification/models/Lenet.scala | 52 ++++++
.../models/MultiLayerPerceptron.scala | 41 +++++
.../imclassification/models/Resnet.scala | 178 +++++++++++++++++++++
.../{ModelTrain.scala => util/Trainer.scala} | 38 +++--
...te.scala => IMClassificationExampleSuite.scala} | 42 ++---
9 files changed, 575 insertions(+), 125 deletions(-)
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
index 5141f44..cec750a 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/README.md
@@ -1,17 +1,36 @@
-# MNIST Example for Scala
-This is the MNIST Training Example implemented for Scala type-safe api
+# Image Classification Models
+
+This example contains a number of image classification models that can be run
on various datasets.
+
+## Models
+
+Currently, the following models are supported:
+- MultiLayerPerceptron
+- Lenet
+- Resnet
+
+## Datasets
+
+Currently, the following datasets are supported:
+- MNIST
+
+#### Synthetic Benchmark Data
+
+Additionally, the datasets can be replaced by randomly generated data for
benchmarking.
+Data is produced to match the shapes of the supported datasets above.
+
+The following additional dataset image shapes are also defined for use with
the benchmark synthetic data:
+- imagenet
+
+
+
## Setup
-### Download the source File
+
+### MNIST
+
+For this dataset, the data must be downloaded and extracted from the source or
```$xslt
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/mnist/mnist.zip
```
-### Unzip the file
-```$xslt
-unzip mnist.zip
-```
-### Arguement Configuration
-Then you need to define the arguments that you would like to pass in the model:
-```$xslt
---data-dir <location of your downloaded file>
-```
-You can find more information
[here](https://github.com/apache/incubator-mxnet/blob/scala-package/examples/src/main/scala/org/apache/mxnet/examples/imclassification/TrainMnist.scala#L169-L207)
\ No newline at end of file
+
+Afterwards, the location of the data folder must be passed in through the
`--data-dir` argument.
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
similarity index 55%
rename from
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
rename to
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
index 2f024fd..608e191 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
@@ -17,109 +17,106 @@
package org.apache.mxnetexamples.imclassification
+import java.util.concurrent._
+
+import org.apache.mxnetexamples.imclassification.models._
+import org.apache.mxnetexamples.imclassification.util.Trainer
import org.apache.mxnet._
+import org.apache.mxnetexamples.imclassification.datasets.{MnistIter,
SyntheticDataIter}
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory
import scala.collection.JavaConverters._
import scala.collection.mutable
-object TrainMnist {
- private val logger = LoggerFactory.getLogger(classOf[TrainMnist])
-
- // multi-layer perceptron
- def getMlp: Symbol = {
- val data = Symbol.Variable("data")
-
- // val fc1 = Symbol.FullyConnected(name = "relu")()(Map("data" -> data,
"act_type" -> "relu"))
- val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128,
name = "fc1")
- val act1 = Symbol.api.Activation (data = Some(fc1), "relu", name = "relu")
- val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name =
"fc2")
- val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2")
- val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, 10, name =
"fc3")
- val mlp = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc3))
- mlp
- }
-
- def getLenet: Symbol = {
- val data = Symbol.Variable("data")
- // first conv
- val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5,
5), num_filter = 20)
- val tanh1 = Symbol.api.tanh(data = Some(conv1))
- val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
- kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
- // second conv
- val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5,
5), num_filter = 50)
- val tanh2 = Symbol.api.tanh(data = Some(conv2))
- val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
- kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
- // first fullc
- val flatten = Symbol.api.Flatten(data = Some(pool2))
- val fc1 = Symbol.api.FullyConnected(data = Some(flatten), num_hidden = 500)
- val tanh3 = Symbol.api.tanh(data = Some(fc1))
- // second fullc
- val fc2 = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden = 10)
- // loss
- val lenet = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc2))
- lenet
- }
-
- def getIterator(dataShape: Shape)
- (dataDir: String, batchSize: Int, kv: KVStore): (DataIter, DataIter) = {
- val flat = if (dataShape.size == 3) "False" else "True"
-
- val train = IO.MNISTIter(Map(
- "image" -> (dataDir + "train-images-idx3-ubyte"),
- "label" -> (dataDir + "train-labels-idx1-ubyte"),
- "label_name" -> "softmax_label",
- "input_shape" -> dataShape.toString,
- "batch_size" -> batchSize.toString,
- "shuffle" -> "True",
- "flat" -> flat,
- "num_parts" -> kv.numWorkers.toString,
- "part_index" -> kv.`rank`.toString))
-
- val eval = IO.MNISTIter(Map(
- "image" -> (dataDir + "t10k-images-idx3-ubyte"),
- "label" -> (dataDir + "t10k-labels-idx1-ubyte"),
- "label_name" -> "softmax_label",
- "input_shape" -> dataShape.toString,
- "batch_size" -> batchSize.toString,
- "flat" -> flat,
- "num_parts" -> kv.numWorkers.toString,
- "part_index" -> kv.`rank`.toString))
-
- (train, eval)
- }
-
- def test(dataPath : String) : Float = {
+object TrainModel {
+ private val logger = LoggerFactory.getLogger(classOf[TrainModel])
+
+ /**
+ * Simple model training and execution
+ * @param model The model identifying string
+ * @param dataPath Path to location of image data
+ * @param numExamples Number of image data examples
+ * @param numEpochs Number of epochs to train for
+ * @param benchmark Whether to use benchmark synthetic data instead of real
image data
+ * @return The final validation accuracy
+ */
+ def test(model: String, dataPath: String, numExamples: Int = 60000,
+ numEpochs: Int = 10, benchmark: Boolean = false): Float = {
NDArrayCollector.auto().withScope {
- val (dataShape, net) = (Shape(784), getMlp)
val devs = Array(Context.cpu(0))
val envs: mutable.Map[String, String] = mutable.HashMap.empty[String,
String]
- val Acc = ModelTrain.fit(dataDir = dataPath,
- batchSize = 128, numExamples = 60000, devs = devs,
- network = net, dataLoader = getIterator(dataShape),
- kvStore = "local", numEpochs = 10)
+ val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath,
+ numExamples = numExamples, benchmark = benchmark)
+ val Acc = Trainer.fit(batchSize = 128, numExamples, devs = devs,
+ network = net, dataLoader = dataLoader,
+ kvStore = "local", numEpochs = numEpochs)
logger.info("Finish test fit ...")
val (_, num) = Acc.get
num(0)
}
}
+ /**
+ * Gets dataset iterator and model symbol
+ * @param dataset The dataset identifying string
+ * @param model The model identifying string
+ * @param dataDir Path to location of image data
+ * @param numLayers The number of model layers (resnet only)
+ * @param numExamples The number of examples in the dataset
+ * @param benchmark Whether to use benchmark synthetic data instead of real
image data
+ * @return Data iterator (partially applied function) and model symbol
+ */
+ def dataLoaderAndModel(dataset: String, model: String, dataDir: String = "",
+ numLayers: Int = 50, numExamples: Int = 60000,
+ benchmark: Boolean = false
+ ): ((Int, KVStore) => (DataIter, DataIter), Symbol) = {
+ val (imageShape, numClasses) = dataset match {
+ case "mnist" => (List(1, 28, 28), 10)
+ case "imagenet" => (List(3, 224, 224), 1000)
+ case _ => throw new Exception("Invalid image data collection")
+ }
+
+ val List(channels, height, width) = imageShape
+ val dataSize: Int = channels * height * width
+ val (datumShape, net) = model match {
+ case "mlp" => (List(dataSize),
MultiLayerPerceptron.getSymbol(numClasses))
+ case "lenet" => (List(channels, height, width),
Lenet.getSymbol(numClasses))
+ case "resnet" => (List(channels, height, width),
Resnet.getSymbol(numClasses,
+ numLayers, imageShape))
+ case _ => throw new Exception("Invalid model name")
+ }
+ val dataLoader: (Int, KVStore) => (DataIter, DataIter) = if (benchmark) {
+ (batchSize: Int, kv: KVStore) => {
+ val iter = new SyntheticDataIter(numClasses, batchSize, datumShape,
List(), numExamples)
+ (iter, iter)
+ }
+ } else {
+ dataset match {
+ case "mnist" => MnistIter.getIterator(Shape(datumShape), dataDir)
+ case _ => throw new Exception("This image data collection only
supports the"
+        + " synthetic benchmark iterator. Use --benchmark to enable")
+ }
+ }
+ (dataLoader, net)
+ }
+
+ /**
+ * Runs image classification training from CLI with various options
+ * @param args CLI args
+ */
def main(args: Array[String]): Unit = {
- val inst = new TrainMnist
+ val inst = new TrainModel
val parser: CmdLineParser = new CmdLineParser(inst)
try {
parser.parseArgument(args.toList.asJava)
val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
- else inst.dataDir
+ else inst.dataDir
- val (dataShape, net) =
- if (inst.network == "mlp") (Shape(784), getMlp)
- else (Shape(1, 28, 28), getLenet)
+ val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network,
dataPath,
+ inst.numLayers, inst.numExamples, inst.benchmark)
val devs =
if (inst.gpus != null) inst.gpus.split(',').map(id =>
Context.gpu(id.trim.toInt))
@@ -144,9 +141,8 @@ object TrainMnist {
logger.info("Start KVStoreServer for scheduler & servers")
KVStoreServer.start()
} else {
- ModelTrain.fit(dataDir = inst.dataDir,
- batchSize = inst.batchSize, numExamples = inst.numExamples, devs =
devs,
- network = net, dataLoader = getIterator(dataShape),
+ Trainer.fit(batchSize = inst.batchSize, numExamples =
inst.numExamples, devs = devs,
+ network = net, dataLoader = dataLoader,
kvStore = inst.kvStore, numEpochs = inst.numEpochs,
modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch =
inst.lrFactorEpoch,
@@ -163,11 +159,19 @@ object TrainMnist {
}
}
-class TrainMnist {
- @Option(name = "--network", usage = "the cnn to use: ['mlp', 'lenet']")
+class TrainModel {
+ @Option(name = "--network", usage = "the cnn to use: ['mlp', 'lenet',
'resnet']")
private val network: String = "mlp"
+ @Option(name = "--num-layers", usage = "the number of resnet layers to use")
+ private val numLayers: Int = 50
@Option(name = "--data-dir", usage = "the input data directory")
private val dataDir: String = "mnist/"
+
+ @Option(name = "--dataset", usage = "the images to classify: ['mnist',
'imagenet']")
+ private val dataset: String = "mnist"
+ @Option(name = "--benchmark", usage = "Benchmark to use synthetic data to
measure performance")
+ private val benchmark: Boolean = false
+
@Option(name = "--gpus", usage = "the gpus will be used, e.g. '0,1,2,3'")
private val gpus: String = null
@Option(name = "--cpus", usage = "the cpus will be used, e.g. '0,1,2,3'")
@@ -187,7 +191,7 @@ class TrainMnist {
@Option(name = "--kv-store", usage = "the kvstore type")
private val kvStore = "local"
@Option(name = "--lr-factor",
- usage = "times the lr with a factor for every lr-factor-epoch epoch")
+ usage = "times the lr with a factor for every lr-factor-epoch epoch")
private val lrFactor: Float = 1f
@Option(name = "--lr-factor-epoch", usage = "the number of epoch to factor
the lr, could be .5")
private val lrFactorEpoch: Float = 1f
@@ -205,3 +209,4 @@ class TrainMnist {
@Option(name = "--num-server", usage = "# of servers")
private val numServer: Int = 1
}
+
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/MnistIter.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/MnistIter.scala
new file mode 100644
index 0000000..9e6e1c2
--- /dev/null
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/MnistIter.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.datasets
+
+import org.apache.mxnet._
+
+object MnistIter {
+ /**
+ * Returns an iterator over the MNIST dataset
+ * @param dataShape Image size (channels, height, width)
+ * @param dataDir The path to the image data
+ * @param batchSize Number of images per batch
+ * @param kv KVStore to use
+ * @return
+ */
+ def getIterator(dataShape: Shape, dataDir: String)
+ (batchSize: Int, kv: KVStore): (DataIter, DataIter) = {
+ val flat = if (dataShape.size == 3) "False" else "True"
+
+ val train = IO.MNISTIter(Map(
+ "image" -> (dataDir + "train-images-idx3-ubyte"),
+ "label" -> (dataDir + "train-labels-idx1-ubyte"),
+ "label_name" -> "softmax_label",
+ "input_shape" -> dataShape.toString,
+ "batch_size" -> batchSize.toString,
+ "shuffle" -> "True",
+ "flat" -> flat,
+ "num_parts" -> kv.numWorkers.toString,
+ "part_index" -> kv.`rank`.toString))
+
+ val eval = IO.MNISTIter(Map(
+ "image" -> (dataDir + "t10k-images-idx3-ubyte"),
+ "label" -> (dataDir + "t10k-labels-idx1-ubyte"),
+ "label_name" -> "softmax_label",
+ "input_shape" -> dataShape.toString,
+ "batch_size" -> batchSize.toString,
+ "flat" -> flat,
+ "num_parts" -> kv.numWorkers.toString,
+ "part_index" -> kv.`rank`.toString))
+
+ (train, eval)
+ }
+
+}
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
new file mode 100644
index 0000000..9421f10
--- /dev/null
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.datasets
+
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
+
+import scala.collection.immutable.ListMap
+import scala.util.Random
+
+class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape:
List[Int],
+ labelShape: List[Int], maxIter: Int, dtype: DType =
DType.Float32
+ ) extends DataIter {
+ var curIter = 0
+ val random = new Random()
+ val shape = Shape(batchSize :: datumShape)
+ val batchLabelShape = Shape(batchSize :: labelShape)
+
+ val maxLabel = if (labelShape.isEmpty) numClasses.toFloat else 1f
+ var label: IndexedSeq[NDArray] = IndexedSeq(
+ NDArray.api.random_uniform(Some(0f), Some(maxLabel), shape =
Some(batchLabelShape)))
+ var data: IndexedSeq[NDArray] = IndexedSeq(
+ NDArray.api.random_uniform(shape = Some(shape)))
+
+ val provideDataDesc: IndexedSeq[DataDesc] = IndexedSeq(
+ new DataDesc("data", shape, dtype, Layout.UNDEFINED))
+ val provideLabelDesc: IndexedSeq[DataDesc] = IndexedSeq(
+ new DataDesc("softmax_label", batchLabelShape, dtype, Layout.UNDEFINED))
+ val getPad: Int = 0
+
+ override def getData(): IndexedSeq[NDArray] = data
+
+ override def getIndex: IndexedSeq[Long] = IndexedSeq(curIter)
+
+ override def getLabel: IndexedSeq[NDArray] = label
+
+ override def hasNext: Boolean = curIter < maxIter - 1
+
+ override def next(): DataBatch = {
+ if (hasNext) {
+ curIter += batchSize
+ new DataBatch(data, label, getIndex, getPad, null, null, null)
+ } else {
+ throw new NoSuchElementException
+ }
+ }
+
+ override def reset(): Unit = {
+ curIter = 0
+ }
+
+ override def provideData: ListMap[String, Shape] = ListMap("data" -> shape)
+
+ override def provideLabel: ListMap[String, Shape] = ListMap("softmax_label"
-> batchLabelShape)
+}
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
new file mode 100644
index 0000000..76fb7bb
--- /dev/null
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.models
+
+import org.apache.mxnet._
+
+object Lenet {
+
+ /**
+ * Gets Lenet Model Symbol
+ * @param numClasses Number of classes to classify into
+ * @return model symbol
+ */
+ def getSymbol(numClasses: Int): Symbol = {
+ val data = Symbol.Variable("data")
+ // first conv
+ val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5,
5), num_filter = 20)
+ val tanh1 = Symbol.api.tanh(data = Some(conv1))
+ val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
+ kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
+ // second conv
+ val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5,
5), num_filter = 50)
+ val tanh2 = Symbol.api.tanh(data = Some(conv2))
+ val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
+ kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
+ // first fullc
+ val flatten = Symbol.api.Flatten(data = Some(pool2))
+ val fc1 = Symbol.api.FullyConnected(data = Some(flatten), num_hidden = 500)
+ val tanh3 = Symbol.api.tanh(data = Some(fc1))
+ // second fullc
+ val fc2 = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden =
numClasses)
+ // loss
+ val lenet = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc2))
+ lenet
+ }
+
+}
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
new file mode 100644
index 0000000..5d880bb
--- /dev/null
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.models
+
+import org.apache.mxnet._
+
+object MultiLayerPerceptron {
+
+ /**
+ * Gets MultiLayer Perceptron Model Symbol
+ * @param numClasses Number of classes to classify into
+ * @return model symbol
+ */
+ def getSymbol(numClasses: Int): Symbol = {
+ val data = Symbol.Variable("data")
+
+ val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128,
name = "fc1")
+ val act1 = Symbol.api.Activation(data = Some(fc1), "relu", name = "relu")
+ val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name =
"fc2")
+ val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2")
+ val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, numClasses,
name = "fc3")
+ val mlp = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc3))
+ mlp
+ }
+
+}
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
new file mode 100644
index 0000000..c3f43d9
--- /dev/null
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.imclassification.models
+
+import org.apache.mxnet._
+
+object Resnet {
+ /**
+ * Helper to produce individual residual unit
+ */
+ def residualUnit(data: Symbol, numFilter: Int, stride: Shape, dimMatch:
Boolean,
+ name: String = "", bottleNeck: Boolean = true, bnMom: Float
= 0.9f,
+ workspace: Int = 256, memonger: Boolean = false): Symbol = {
+ val (act1, operated) = if (bottleNeck) {
+ val bn1 = Symbol.api.BatchNorm(data = Some(data), fix_gamma =
Some(false), eps = Some(2e-5),
+ momentum = Some(bnMom), name = name + "_bn1")
+ val act1: Symbol = Symbol.api.Activation(data = Some(bn1), act_type =
"relu",
+ name = name + "_relu1")
+ val conv1 = Symbol.api.Convolution(data = Some(act1), num_filter =
(numFilter * 0.25).toInt,
+ kernel = Shape(1, 1), stride = Some(Shape(1, 1)), pad = Some(Shape(0,
0)),
+ no_bias = Some(true), workspace = Some(workspace), name = name +
"_conv1")
+ val bn2 = Symbol.api.BatchNorm(data = Some(conv1), fix_gamma =
Some(false),
+ eps = Some(2e-5), momentum = Some(bnMom), name = name + "_bn2")
+ val act2 = Symbol.api.Activation(data = Some(bn2), act_type = "relu",
name = name + "_relu2")
+ val conv2 = Symbol.api.Convolution(data = Some(act2), num_filter =
(numFilter * 0.25).toInt,
+ kernel = Shape(3, 3), stride = Some(stride), pad = Some(Shape(1, 1)),
+ no_bias = Some(true), workspace = Some(workspace), name = name +
"_conv2")
+ val bn3 = Symbol.api.BatchNorm(data = Some(conv2), fix_gamma =
Some(false),
+ eps = Some(2e-5), momentum = Some(bnMom), name = name + "_bn3")
+ val act3 = Symbol.api.Activation(data = Some(bn3), act_type = "relu",
name = name + "_relu3")
+ val conv3 = Symbol.api.Convolution(data = Some(act3), num_filter =
numFilter,
+ kernel = Shape(1, 1), stride = Some(Shape(1, 1)), pad = Some(Shape(0,
0)),
+ no_bias = Some(true), workspace = Some(workspace), name = name +
"_conv3")
+ (act1, conv3)
+ } else {
+ val bn1 = Symbol.api.BatchNorm(data = Some(data), fix_gamma =
Some(false),
+ eps = Some(2e-5), momentum = Some(bnMom), name = name + "_bn1")
+ val act1 = Symbol.api.Activation(data = Some(bn1), act_type = "relu",
name = name + "_relu1")
+ val conv1 = Symbol.api.Convolution(data = Some(act1), num_filter =
numFilter,
+ kernel = Shape(3, 3), stride = Some(stride), pad = Some(Shape(1, 1)),
+ no_bias = Some(true), workspace = Some(workspace), name = name +
"_conv1")
+ val bn2 = Symbol.api.BatchNorm(data = Some(conv1), fix_gamma =
Some(false),
+ eps = Some(2e-5), momentum = Some(bnMom), name = name + "_bn2")
+ val act2 = Symbol.api.Activation(data = Some(bn2), act_type = "relu",
name = name + "_relu2")
+ val conv2 = Symbol.api.Convolution(data = Some(act2), num_filter =
numFilter,
+ kernel = Shape(3, 3), stride = Some(Shape(1, 1)), pad = Some(Shape(1,
1)),
+ no_bias = Some(true), workspace = Some(workspace), name = name +
"_conv2")
+ (act1, conv2)
+ }
+ val shortcut = if (dimMatch) {
+ data
+ } else {
+ Symbol.api.Convolution(Some(act1), num_filter = numFilter, kernel =
Shape(1, 1),
+ stride = Some(stride), no_bias = Some(true), workspace =
Some(workspace),
+ name = name + "_sc")
+ }
+ operated + shortcut
+ }
+
+ /**
+ * Helper for building the resnet Symbol
+ */
+ def resnet(units: List[Int], numStages: Int, filterList: List[Int],
numClasses: Int,
+ imageShape: List[Int], bottleNeck: Boolean = true, bnMom: Float =
0.9f,
+ workspace: Int = 256, dtype: String = "float32", memonger:
Boolean = false): Symbol = {
+ assert(units.size == numStages)
+ var data = Symbol.Variable("data", shape = Shape(List(4) ::: imageShape),
dType = DType.Float32)
+ if (dtype == "float32") {
+ data = Symbol.api.identity(Some(data), "id")
+ } else if (dtype == "float16") {
+ data = Symbol.api.cast(Some(data), "float16")
+ }
+ data = Symbol.api.BatchNorm(Some(data), fix_gamma = Some(true), eps =
Some(2e-5),
+ momentum = Some(bnMom), name = "bn_data")
+ val List(channels, height, width) = imageShape
+ var body = if (height <= 32) {
+ Symbol.api.Convolution(Some(data), num_filter = filterList.head, kernel
= Shape(7, 7),
+ stride = Some(Shape(1, 1)), pad = Some(Shape(1, 1)), no_bias =
Some(true), name = "conv0",
+ workspace = Some(workspace))
+ } else {
+ var body0 = Symbol.api.Convolution(Some(data), num_filter =
filterList.head,
+ kernel = Shape(3, 3), stride = Some(Shape(2, 2)), pad = Some(Shape(3,
3)),
+ no_bias = Some(true), name = "conv0", workspace = Some(workspace))
+ body0 = Symbol.api.BatchNorm(Some(body0), fix_gamma = Some(false), eps =
Some(2e-5),
+ momentum = Some(bnMom), name = "bn0")
+ body0 = Symbol.api.Activation(Some(body0), act_type = "relu", name =
"relu0")
+ Symbol.api.Pooling(Some(body0), kernel = Some(Shape(3, 3)), stride =
Some(Shape(2, 2)),
+ pad = Some(Shape(1, 1)), pool_type = Some("max"))
+ }
+ for (((filter, i), unit) <- filterList.tail.zipWithIndex.zip(units)) {
+ val stride = Shape(if (i == 0) 1 else 2, if (i == 0) 1 else 2)
+ body = residualUnit(body, filter, stride, false, name = s"stage${i +
1}_unit${1}",
+ bottleNeck = bottleNeck, workspace = workspace, memonger = memonger)
+ for (j <- 0 until unit - 1) {
+ body = residualUnit(body, filter, Shape(1, 1), true, s"stage${i +
1}_unit${j + 2}",
+ bottleNeck, workspace = workspace, memonger = memonger)
+ }
+ }
+ val bn1 = Symbol.api.BatchNorm(Some(body), fix_gamma = Some(false), eps =
Some(2e-5),
+ momentum = Some(bnMom), name = "bn1")
+ val relu1 = Symbol.api.Activation(Some(bn1), act_type = "relu", name =
"relu1")
+ val pool1 = Symbol.api.Pooling(Some(relu1), global_pool = Some(true),
+ kernel = Some(Shape(7, 7)), pool_type = Some("avg"), name = "pool1")
+ val flat = Symbol.api.Flatten(Some(pool1))
+ var fc1 = Symbol.api.FullyConnected(Some(flat), num_hidden = numClasses,
name = "fc1")
+ if (dtype == "float16") {
+ fc1 = Symbol.api.cast(Some(fc1), "float32")
+ }
+ Symbol.api.SoftmaxOutput(Some(fc1), name = "softmax")
+ }
+
+ /**
+ * Gets the resnet model symbol
+ * @param numClasses Number of classes to classify into
+ * @param numLayers Number of residual layers
+ * @param imageShape The image shape as List(channels, height, width)
+ * @param convWorkspace Maximum temporary workspace allowed (MB) in
convolutions
+ * @param dtype Type of data (float16, float32, etc) to use during
computation
+ * @return Model symbol
+ */
+ def getSymbol(numClasses: Int, numLayers: Int, imageShape: List[Int],
convWorkspace: Int = 256,
+ dtype: String = "float32"): Symbol = {
+ val List(channels, height, width) = imageShape
+ val (numStages, units, filterList, bottleNeck): (Int, List[Int],
List[Int], Boolean) =
+ if (height <= 28) {
+ val (perUnit, filterList, bottleNeck) = if ((numLayers - 2) % 9 == 0
&& numLayers > 165) {
+ (List(Math.floor((numLayers - 2) / 9).toInt),
+ List(16, 64, 128, 256),
+ true)
+ } else if ((numLayers - 2) % 6 == 0 && numLayers < 164) {
+ (List(Math.floor((numLayers - 2) / 6).toInt),
+ List(16, 16, 32, 64),
+ false)
+ } else {
+ throw new Exception(s"Invalid number of layers: ${numLayers}")
+ }
+ val numStages = 3
+ val units = (1 to numStages).map(_ => perUnit.head).toList
+ (numStages, units, filterList, bottleNeck)
+ } else {
+ val (filterList, bottleNeck) = if (numLayers >= 50) {
+ (List(64, 256, 512, 1024, 2048), true)
+ } else {
+ (List(64, 64, 128, 256, 512), false)
+ }
+ val units: List[Int] = Map(
+ 18 -> List(2, 2, 2, 2),
+ 34 -> List(3, 4, 6, 3),
+ 50 -> List(3, 4, 6, 3),
+ 101 -> List(3, 4, 23, 3),
+ 152 -> List(3, 8, 36, 3),
+ 200 -> List(3, 24, 36, 3),
+ 269 -> List(3, 30, 48, 8)
+ ).get(numLayers) match {
+ case Some(x) => x
+ case None => throw new Exception(s"Invalid number of layers:
${numLayers}")
+ }
+ (4, units, filterList, bottleNeck)
+ }
+ resnet(units, numStages, filterList, numClasses, imageShape, bottleNeck,
+ workspace = convWorkspace, dtype = dtype)
+ }
+}
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/ModelTrain.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
similarity index 73%
rename from
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/ModelTrain.scala
rename to
scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
index 1a77775..9a54e58 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/ModelTrain.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
@@ -15,19 +15,37 @@
* limitations under the License.
*/
-package org.apache.mxnetexamples.imclassification
+package org.apache.mxnetexamples.imclassification.util
import org.apache.mxnet.Callback.Speedometer
import org.apache.mxnet._
import org.apache.mxnet.optimizer.SGD
import org.slf4j.LoggerFactory
-object ModelTrain {
- private val logger = LoggerFactory.getLogger(classOf[ModelTrain])
+object Trainer {
+ private val logger = LoggerFactory.getLogger(classOf[Trainer])
+ /**
+ * Fits a model
+ * @param batchSize Number of images per training batch
+ * @param numExamples Total number of image examples
+ * @param devs List of device contexts to use
+ * @param network The model to train
+ * @param dataLoader Function to get data loaders for training and
validation data
+ * @param kvStore KVStore to use
+ * @param numEpochs Number of complete passes over the training data
+ * @param modelPrefix Prefix used when saving model checkpoints
+ * @param loadEpoch Loads a saved checkpoint at this epoch when set
+ * @param lr The learning rate
+ * @param lrFactor Learning rate factor (see FactorScheduler)
+ * @param lrFactorEpoch Learning rate factor epoch (see FactorScheduler)
+ * @param clipGradient Maximum gradient during optimization
+ * @param monitorSize (See Monitor)
+ * @return Final accuracy
+ */
// scalastyle:off parameterNum
- def fit(dataDir: String, batchSize: Int, numExamples: Int, devs:
Array[Context],
- network: Symbol, dataLoader: (String, Int, KVStore) => (DataIter,
DataIter),
+ def fit(batchSize: Int, numExamples: Int, devs: Array[Context],
+ network: Symbol, dataLoader: (Int, KVStore) => (DataIter, DataIter),
kvStore: String, numEpochs: Int, modelPrefix: String = null,
loadEpoch: Int = -1,
lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f,
clipGradient: Float = 0f, monitorSize: Int = -1): Accuracy = {
@@ -60,7 +78,7 @@ object ModelTrain {
}
// data
- val (train, validation) = dataLoader(dataDir, batchSize, kv)
+ val (train, validation) = dataLoader(batchSize, kv)
// train
val epochSize =
@@ -75,8 +93,8 @@ object ModelTrain {
null
}
val optimizer: Optimizer = new SGD(learningRate = lr,
- lrScheduler = lrScheduler, clipGradient = clipGradient,
- momentum = 0.9f, wd = 0.00001f)
+ lrScheduler = lrScheduler, clipGradient = clipGradient,
+ momentum = 0.9f, wd = 0.00001f)
// disable kvstore for single device
if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType
!= "gpu")) {
@@ -108,7 +126,9 @@ object ModelTrain {
}
acc
}
+
// scalastyle:on parameterNum
}
-class ModelTrain
+class Trainer
+
diff --git
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
similarity index 53%
rename from
scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
rename to
scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
index 0fd3af0..6e9667a 100644
---
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
+++
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
@@ -18,9 +18,7 @@
package org.apache.mxnetexamples.imclassification
import java.io.File
-import java.net.URL
-import org.apache.commons.io.FileUtils
import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -31,27 +29,35 @@ import scala.sys.process.Process
/**
* Integration test for MNIST example.
*/
-class MNISTExampleSuite extends FunSuite with BeforeAndAfterAll {
- private val logger = LoggerFactory.getLogger(classOf[MNISTExampleSuite])
+class IMClassificationExampleSuite extends FunSuite with BeforeAndAfterAll {
+ private val logger =
LoggerFactory.getLogger(classOf[IMClassificationExampleSuite])
test("Example CI: Test MNIST Training") {
- logger.info("Downloading mnist model")
- val baseUrl =
"https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
- val tempDirPath = System.getProperty("java.io.tmpdir")
- val modelDirPath = tempDirPath + File.separator + "mnist/"
- logger.info("tempDirPath: %s".format(tempDirPath))
- Util.downloadUrl(baseUrl + "/mnist/mnist.zip",
- tempDirPath + "/mnist/mnist.zip")
- // TODO: Need to confirm with Windows
- Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
- + tempDirPath + "/mnist/") !
+ logger.info("Downloading mnist model")
+ val baseUrl =
"https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
+ val tempDirPath = System.getProperty("java.io.tmpdir")
+ val modelDirPath = tempDirPath + File.separator + "mnist/"
+ logger.info("tempDirPath: %s".format(tempDirPath))
+ Util.downloadUrl(baseUrl + "/mnist/mnist.zip",
+ tempDirPath + "/mnist/mnist.zip")
+ // TODO: Need to confirm with Windows
+ Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
+ + tempDirPath + "/mnist/") !
- var context = Context.cpu()
+ var context = Context.cpu()
- val output = TrainMnist.test(modelDirPath)
- Process("rm -rf " + modelDirPath) !
+ val valAccuracy = TrainModel.test("mlp", modelDirPath)
+ Process("rm -rf " + modelDirPath) !
- assert(output >= 0.95f)
+ assert(valAccuracy >= 0.95f)
}
+
+ for(model <- List("mlp", "lenet", "resnet")) {
+ test(s"Example CI: Test Image Classification Model ${model}") {
+ var context = Context.cpu()
+ val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true)
+ }
+ }
+
}