This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0038473 [MXNET-1222] Scala Inference enable different shapes input
(#13330)
0038473 is described below
commit 0038473e799bccd77f57718eb5f8af28b81c8284
Author: Lanking <[email protected]>
AuthorDate: Thu Nov 29 16:16:45 2018 -0800
[MXNET-1222] Scala Inference enable different shapes input (#13330)
* init commit with Predictor Improvement
* add predictor Example
* change into dArr
* add img config
* add new line and fix code style
important bug fixes
---
.../src/main/scala/org/apache/mxnet/Executor.scala | 4 +-
.../infer/predictor/PredictorExample.scala | 92 ++++++++++++++++++++++
.../ImageClassifierExampleSuite.scala | 5 +-
.../ObjectDetectorExampleSuite.scala | 5 +-
.../PredictorExampleSuite.scala} | 67 +++++++++-------
.../scala/org/apache/mxnet/infer/Predictor.scala | 31 ++++++--
.../org/apache/mxnet/infer/javaapi/Predictor.scala | 2 +-
7 files changed, 160 insertions(+), 46 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index b342a96..85f45bc 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -106,9 +106,9 @@ class Executor private[mxnet](private[mxnet] val handle:
ExecutorHandle,
"is more efficient than the reverse." +
"If you really want to up size, set allowUpSizing =
true " +
"to enable allocation of new arrays.")
- newArgDict = newArgDict + (name -> NDArray.empty(newShape,
arr.context))
+ newArgDict = newArgDict + (name -> NDArray.empty(newShape,
arr.context, arr.dtype))
if (dArr != null) {
- newGradDict = newGradDict + (name -> NDArray.empty(newShape,
dArr.context))
+ newGradDict = newGradDict + (name -> NDArray.empty(newShape,
dArr.context, dArr.dtype))
}
} else {
newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/predictor/PredictorExample.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/predictor/PredictorExample.scala
new file mode 100644
index 0000000..be90936
--- /dev/null
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/predictor/PredictorExample.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.infer.predictor
+
+import java.io.File
+
+import scala.io
+import org.apache.mxnet._
+import org.apache.mxnet.infer.Predictor
+import org.apache.mxnetexamples.benchmark.CLIParserBase
+import org.kohsuke.args4j.{CmdLineParser, Option}
+
+import scala.collection.JavaConverters._
+
+object PredictorExample {
+
+ def loadModel(modelPathPrefix : String, inputDesc : IndexedSeq[DataDesc],
+ context : Context, epoch : Int): Predictor = {
+ new Predictor(modelPathPrefix, inputDesc, context, Some(epoch))
+ }
+
+ def doInference(predictor : Predictor, imageND : NDArray):
IndexedSeq[NDArray] = {
+ predictor.predictWithNDArray(IndexedSeq(imageND))
+ }
+
+ def preProcess(imagePath: String, h: Int, w: Int) : NDArray = {
+ var img = Image.imRead(imagePath)
+ img = Image.imResize(img, h, w)
+ // HWC -> CHW
+ img = NDArray.api.transpose(img, Some(Shape(2, 0, 1)))
+ img = NDArray.api.expand_dims(img, 0)
+ img.asType(DType.Float32)
+ }
+
+ def postProcess(modelPathPrefix : String, result : Array[Float]) : String = {
+ val dirPath = modelPathPrefix.substring(0, 1 +
modelPathPrefix.lastIndexOf(File.separator))
+ val d = new File(dirPath)
+ require(d.exists && d.isDirectory, s"directory: $dirPath not found")
+ val f = io.Source.fromFile(dirPath + "synset.txt")
+ val s = f.getLines().toIndexedSeq
+ val maxIdx = result.zipWithIndex.maxBy(_._1)._2
+ printf(s"Predict Result ${s(maxIdx)} with prob ${result(maxIdx)}\n")
+ s(maxIdx)
+ }
+
+ def main(args : Array[String]): Unit = {
+ val inst = new CLIParser
+ val parser: CmdLineParser = new CmdLineParser(inst)
+
+ parser.parseArgument(args.toList.asJava)
+
+ var context = Context.cpu()
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+ context = Context.gpu()
+ }
+
+ val imgWidth = 224
+ val imgHeight = 224
+
+ val inputDesc = IndexedSeq(new DataDesc("data", Shape(1, 3, imgHeight,
imgWidth),
+ DType.Float32, Layout.NCHW))
+
+ val predictor = loadModel(inst.modelPathPrefix, inputDesc, context, 0)
+ val img = preProcess(inst.inputImagePath, imgHeight, imgWidth)
+ val result = doInference(predictor, img)(0).toArray
+ postProcess(inst.modelPathPrefix, result)
+ }
+
+}
+
+class CLIParser extends CLIParserBase{
+ @Option(name = "--model-path-prefix", usage = "the input model directory")
+ val modelPathPrefix: String = "/resnet-152/resnet-152"
+ @Option(name = "--input-image", usage = "the input image")
+ val inputImagePath: String = "/images/kitten.jpg"
+}
diff --git
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
index d8631df..27d9bb4 100644
---
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
+++
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
@@ -20,10 +20,7 @@ package org.apache.mxnetexamples.infer.imageclassifier
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
import java.io.File
-import java.net.URL
-
-import org.apache.commons.io.FileUtils
-import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util
import sys.process.Process
diff --git
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
index addc837..bd960bd 100644
---
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
+++
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
@@ -18,10 +18,7 @@
package org.apache.mxnetexamples.infer.objectdetector
import java.io.File
-import java.net.URL
-
-import org.apache.commons.io.FileUtils
-import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
diff --git
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/predictor/PredictorExampleSuite.scala
similarity index 51%
copy from
scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
copy to
scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/predictor/PredictorExampleSuite.scala
index d8631df..97ca33e 100644
---
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
+++
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/predictor/PredictorExampleSuite.scala
@@ -15,27 +15,22 @@
* limitations under the License.
*/
-package org.apache.mxnetexamples.infer.imageclassifier
+package org.apache.mxnetexamples.infer.predictor
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-import org.slf4j.LoggerFactory
import java.io.File
-import java.net.URL
-import org.apache.commons.io.FileUtils
-import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnet._
import org.apache.mxnetexamples.Util
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
-import sys.process.Process
-
-/**
- * Integration test for imageClassifier example.
- * This will run as a part of "make scalatest"
- */
-class ImageClassifierExampleSuite extends FunSuite with BeforeAndAfterAll {
- private val logger =
LoggerFactory.getLogger(classOf[ImageClassifierExampleSuite])
+class PredictorExampleSuite extends FunSuite with BeforeAndAfterAll {
+ private val logger = LoggerFactory.getLogger(classOf[PredictorExampleSuite])
+ private var modelDirPrefix = ""
+ private var inputImagePath = ""
+ private var context = Context.cpu()
- test("testImageClassifierExample") {
+ override def beforeAll(): Unit = {
logger.info("Downloading resnet-18 model")
val tempDirPath = System.getProperty("java.io.tmpdir")
@@ -52,27 +47,41 @@ class ImageClassifierExampleSuite extends FunSuite with
BeforeAndAfterAll {
Util.downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg")
- val modelDirPath = tempDirPath + File.separator + "resnet18/"
- val inputImagePath = tempDirPath + File.separator +
+ modelDirPrefix = tempDirPath + File.separator + "resnet18/resnet-18"
+ inputImagePath = tempDirPath + File.separator +
"inputImages/resnet18/Pug-Cookie.jpg"
- val inputImageDir = tempDirPath + File.separator + "inputImages/resnet18/"
- var context = Context.cpu()
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
context = Context.gpu()
}
+ val props = System.getProperties
+ props.setProperty("mxnet.disableShapeCheck", "true")
+ }
- val output = ImageClassifierExample.runInferenceOnSingleImage(modelDirPath
+ "resnet-18",
- inputImagePath, context)
-
- val outputList =
ImageClassifierExample.runInferenceOnBatchOfImage(modelDirPath + "resnet-18",
- inputImageDir, context)
-
- Process("rm -rf " + modelDirPath + " " + inputImageDir) !
-
- assert(output(0).toList.head._1 === "n02110958 pug, pug-dog")
- assert(outputList(0).toList.head._1 === "n02110958 pug, pug-dog")
+ override def afterAll(): Unit = {
+ val props = System.getProperties
+ props.setProperty("mxnet.disableShapeCheck", "false")
+ }
+ test("test Predictor With Fixed Shape and random shape") {
+ val inputDesc = IndexedSeq(new DataDesc("data", Shape(1, 3, 224, 224),
+ DType.Float32, Layout.NCHW))
+ val predictor = PredictorExample.loadModel(modelDirPrefix, inputDesc,
context, 0)
+ // fix size
+ var img = PredictorExample.preProcess(inputImagePath, 224, 224)
+ var result = PredictorExample.doInference(predictor, img)(0)
+ var top1 = PredictorExample.postProcess(modelDirPrefix, result.toArray)
+ assert(top1 === "n02110958 pug, pug-dog")
+ // random size 512
+ img = PredictorExample.preProcess(inputImagePath, 512, 512)
+ result = PredictorExample.doInference(predictor, img)(0)
+ top1 = PredictorExample.postProcess(modelDirPrefix, result.toArray)
+ assert(top1 === "n02110958 pug, pug-dog")
+ // original size
+ img = PredictorExample.preProcess(inputImagePath, 1024, 576)
+ result = PredictorExample.doInference(predictor, img)(0)
+ top1 = PredictorExample.postProcess(modelDirPrefix, result.toArray)
+ assert(top1 === "n02110958 pug, pug-dog")
}
}
diff --git
a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
index e2a0e7c..d4bce9f 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
@@ -22,8 +22,10 @@ import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
import org.apache.mxnet.module.Module
import scala.collection.mutable.ListBuffer
+import scala.util.Try
import org.slf4j.LoggerFactory
+
/**
* Base Trait for MXNet Predictor classes.
*/
@@ -76,6 +78,21 @@ class Predictor(modelPathPrefix: String,
private val logger = LoggerFactory.getLogger(classOf[Predictor])
+ /*
+  Setting -Dmxnet.disableShapeCheck=true disables the data shape
+  check of the predictor. Some models may allow different lengths of the data,
+  such as Seq2Seq; however, there may be a risk of crashes if the lengths are
+  beyond the acceptable range of the model
+ */
+ private val traceProperty = "mxnet.disableShapeCheck"
+ private lazy val shapeCheckDisabled = {
+ val value =
Try(System.getProperty(traceProperty).toBoolean).getOrElse(false)
+ if (value) {
+ logger.warn("Shape check is disabled (property {} is set)",
traceProperty)
+ }
+ value
+ }
+
require(inputDescriptors.head.layout.size != 0, "layout size should not be
zero")
protected[infer] var batchIndex = inputDescriptors(0).layout.indexOf('N')
@@ -172,9 +189,11 @@ class Predictor(modelPathPrefix: String,
for((i, d) <- inputBatch.zip(iDescriptors)) {
require(inputBatch(0).shape(batchIndex) == i.shape(batchIndex),
"All inputs should be of same batch size")
- require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
- s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
- s"shape: ${d.shape} except batchSize")
+ if (!shapeCheckDisabled) {
+ require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
+ s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
+ s"shape: ${d.shape} except batchSize")
+ }
}
val inputBatchSize = inputBatch(0).shape(batchIndex)
@@ -182,8 +201,8 @@ class Predictor(modelPathPrefix: String,
// rebind with the new batchSize
if (batchSize != inputBatchSize) {
logger.info(s"Latency increased due to batchSize mismatch $batchSize vs
$inputBatchSize")
- val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name,
- Shape(f.shape.toVector.patch(batchIndex, Vector(inputBatchSize), 1)),
f.dtype, f.layout) )
+ val desc = inputBatch.zip(iDescriptors).map(f => new DataDesc(f._2.name,
+ f._1.shape, f._2.dtype, f._2.layout))
mxNetHandler.execute(mod.bind(desc, forceRebind = true,
forTraining = false))
}
@@ -200,7 +219,7 @@ class Predictor(modelPathPrefix: String,
private[infer] def loadModule(): Module = {
val mod = mxNetHandler.execute(Module.loadCheckpoint(modelPathPrefix,
epoch.get,
- contexts = contexts))
+ contexts = contexts, dataNames = inputDescriptors.map(desc =>
desc.name)))
mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false))
mod
}
diff --git
a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
index c867168..8c48742 100644
---
a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
+++
b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -93,7 +93,7 @@ class Predictor private[mxnet] (val predictor:
org.apache.mxnet.infer.Predictor)
* This method is useful when the input is a batch of data
* Note: User is responsible for managing allocation/deallocation of
input/output NDArrays.
*
- * @param input List of NDArrays
+ * @param input List of NDArrays
* @return Output of predictions as NDArrays
*/
def predictWithNDArray(input: java.util.List[NDArray]):