This is an automated email from the ASF dual-hosted git repository. nswamy pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 777089d adding context parameter to infer api- imageclassifier and objectdetector (#10252) 777089d is described below commit 777089dd771735aa2c8efb4ae088a4a68ce896a4 Author: Roshani Nagmote <roshaninagmo...@gmail.com> AuthorDate: Mon Mar 26 16:27:17 2018 -0700 adding context parameter to infer api- imageclassifier and objectdetector (#10252) * adding context parameter * parameter description added --- .../ml/dmlc/mxnet/infer/ImageClassifier.scala | 18 +++++++---- .../scala/ml/dmlc/mxnet/infer/ObjectDetector.scala | 37 ++++++++++++++-------- .../ml/dmlc/mxnet/infer/ImageClassifierSuite.scala | 26 ++++++++------- .../ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala | 11 ++++--- 4 files changed, 56 insertions(+), 36 deletions(-) diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala index 45c4e76..070b0bf 100644 --- a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala +++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala @@ -17,7 +17,7 @@ package ml.dmlc.mxnet.infer -import ml.dmlc.mxnet.{DataDesc, NDArray, Shape} +import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape} import scala.collection.mutable.ListBuffer @@ -37,13 +37,15 @@ import javax.imageio.ImageIO * file://model-dir/synset.txt * @param inputDescriptors Descriptors defining the input node names, shape, * layout and Type parameters + * @param contexts Device Contexts on which you want to run Inference, defaults to CPU. + * @param epoch Model epoch to load, defaults to 0. */ class ImageClassifier(modelPathPrefix: String, - inputDescriptors: IndexedSeq[DataDesc]) + inputDescriptors: IndexedSeq[DataDesc], + contexts: Array[Context] = Context.cpu(), + epoch: Option[Int] = Some(0)) extends Classifier(modelPathPrefix, - inputDescriptors) { - - val classifier: Classifier = getClassifier(modelPathPrefix, inputDescriptors) + inputDescriptors, contexts, epoch) { protected[infer] val inputLayout = inputDescriptors.head.layout @@ -108,8 +110,10 @@ class ImageClassifier(modelPathPrefix: String, result } - def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]): Classifier = { - new Classifier(modelPathPrefix, inputDescriptors) + def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc], + contexts: Array[Context] = Context.cpu(), + epoch: Option[Int] = Some(0)): Classifier = { + new Classifier(modelPathPrefix, inputDescriptors, contexts, epoch) } } diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala index 2d83caf..30e1432 100644 --- a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala +++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala @@ -16,12 +16,14 @@ */ package ml.dmlc.mxnet.infer + // scalastyle:off import java.awt.image.BufferedImage // scalastyle:on -import ml.dmlc.mxnet.NDArray -import ml.dmlc.mxnet.DataDesc + +import ml.dmlc.mxnet.{Context, DataDesc, NDArray} import scala.collection.mutable.ListBuffer + /** * A class for object detection tasks * @@ -32,11 +34,16 @@ import scala.collection.mutable.ListBuffer * file://model-dir/synset.txt * @param inputDescriptors Descriptors defining the input node names, shape, * layout and Type parameters + * @param contexts Device Contexts on which you want to run Inference, defaults to CPU. + * @param epoch Model epoch to load, defaults to 0. */ class ObjectDetector(modelPathPrefix: String, - inputDescriptors: IndexedSeq[DataDesc]) { + inputDescriptors: IndexedSeq[DataDesc], + contexts: Array[Context] = Context.cpu(), + epoch: Option[Int] = Some(0)) { - val imgClassifier: ImageClassifier = getImageClassifier(modelPathPrefix, inputDescriptors) + val imgClassifier: ImageClassifier = + getImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch) val inputShape = imgClassifier.inputShape @@ -54,7 +61,7 @@ class ObjectDetector(modelPathPrefix: String, * To Detect bounding boxes and corresponding labels * * @param inputImage : PathPrefix of the input image - * @param topK : Get top k elements with maximum probability + * @param topK : Get top k elements with maximum probability * @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax]) */ def imageObjectDetect(inputImage: BufferedImage, @@ -71,9 +78,10 @@ class ObjectDetector(modelPathPrefix: String, /** * Takes input images as NDArrays. Useful when you want to perform multiple operations on * the input Array, or when you want to pass a batch of input images. + * * @param input : Indexed Sequence of NDArrays - * @param topK : (Optional) How many top_k(sorting will be based on the last axis) - * elements to return. If not passed, returns all unsorted output. + * @param topK : (Optional) How many top_k(sorting will be based on the last axis) + * elements to return. If not passed, returns all unsorted output. * @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax]) */ def objectDetectWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int]) @@ -90,10 +98,10 @@ class ObjectDetector(modelPathPrefix: String, batchResult.toIndexedSeq } - private def sortAndReformat(predictResultND : NDArray, topK: Option[Int]) + private def sortAndReformat(predictResultND: NDArray, topK: Option[Int]) : IndexedSeq[(String, Array[Float])] = { val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]() - val accuracy : ListBuffer[Float] = ListBuffer[Float]() + val accuracy: ListBuffer[Float] = ListBuffer[Float]() // iterating over the all the predictions val length = predictResultND.shape(0) @@ -110,7 +118,7 @@ class ObjectDetector(modelPathPrefix: String, handler.execute(r.dispose()) } var result = IndexedSeq[(String, Array[Float])]() - if(topK.isDefined) { + if (topK.isDefined) { var sortedIndices = accuracy.zipWithIndex.sortBy(-_._1).map(_._2) sortedIndices = sortedIndices.take(topK.get) // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax @@ -127,8 +135,9 @@ class ObjectDetector(modelPathPrefix: String, /** * To classify batch of input images according to the provided model + * * @param inputBatch Input batch of Buffered images - * @param topK Get top k elements with maximum probability + * @param topK Get top k elements with maximum probability * @return List of list of tuples of (class, probability) */ def imageBatchObjectDetect(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None): @@ -148,9 +157,11 @@ class ObjectDetector(modelPathPrefix: String, result } - def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]): + def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc], + contexts: Array[Context] = Context.cpu(), + epoch: Option[Int] = Some(0)): ImageClassifier = { - new ImageClassifier(modelPathPrefix, inputDescriptors) + new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch) } } diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala index 96fc800..85059be 100644 --- a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala +++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala @@ -17,11 +17,10 @@ package ml.dmlc.mxnet.infer -import ml.dmlc.mxnet.{DType, DataDesc, Shape, NDArray} - +import ml.dmlc.mxnet._ import org.mockito.Matchers._ import org.mockito.Mockito -import org.scalatest.{BeforeAndAfterAll} +import org.scalatest.BeforeAndAfterAll // scalastyle:off import java.awt.image.BufferedImage @@ -33,7 +32,7 @@ import java.awt.image.BufferedImage class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { class MyImageClassifier(modelPathPrefix: String, - inputDescriptors: IndexedSeq[DataDesc]) + inputDescriptors: IndexedSeq[DataDesc]) extends ImageClassifier(modelPathPrefix, inputDescriptors) { override def getPredictor(): MyClassyPredictor = { @@ -41,7 +40,8 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { } override def getClassifier(modelPathPrefix: String, inputDescriptors: - IndexedSeq[DataDesc]): Classifier = { + IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(), + epoch: Option[Int] = Some(0)): Classifier = { Mockito.mock(classOf[Classifier]) } @@ -84,7 +84,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { val synset = testImageClassifier.synset - val predictExpectedOp : List[(String, Float)] = + val predictExpectedOp: List[(String, Float)] = List[(String, Float)]((synset(1), .98f), (synset(2), .97f), (synset(3), .96f), (synset(0), .99f)) @@ -93,13 +93,14 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor) .predictWithNDArray(any(classOf[IndexedSeq[NDArray]])) - Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier) + Mockito.doReturn(IndexedSeq(predictExpectedOp)) + .when(testImageClassifier.getClassifier(modelPath, inputDescriptor)) .classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt())) val predictResult: IndexedSeq[IndexedSeq[(String, Float)]] = testImageClassifier.classifyImage(inputImage, Some(4)) - for(i <- predictExpected.indices) { + for (i <- predictExpected.indices) { assertResult(predictExpected(i).sortBy(-_)) { predictResult(i).map(_._2).toArray } @@ -119,15 +120,15 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { val predictExpected: IndexedSeq[Array[Array[Float]]] = IndexedSeq[Array[Array[Float]]](Array(Array(.98f, 0.97f, 0.96f, 0.99f), - Array(.98f, 0.97f, 0.96f, 0.99f))) + Array(.98f, 0.97f, 0.96f, 0.99f))) val synset = testImageClassifier.synset - val predictExpectedOp : List[List[(String, Float)]] = + val predictExpectedOp: List[List[(String, Float)]] = List[List[(String, Float)]](List((synset(1), .98f), (synset(2), .97f), (synset(3), .96f), (synset(0), .99f)), List((synset(1), .98f), (synset(2), .97f), - (synset(3), .96f), (synset(0), .99f))) + (synset(3), .96f), (synset(0), .99f))) val predictExpectedND: NDArray = NDArray.array(predictExpected.flatten.flatten.toArray, Shape(2, 4)) @@ -135,7 +136,8 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor) .predictWithNDArray(any(classOf[IndexedSeq[NDArray]])) - Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier) + Mockito.doReturn(IndexedSeq(predictExpectedOp)) + .when(testImageClassifier.getClassifier(modelPath, inputDescriptor)) .classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt())) val result: IndexedSeq[IndexedSeq[(String, Float)]] = diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala index a691aa3..5e6f32f 100644 --- a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala +++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala @@ -23,7 +23,7 @@ import java.awt.image.BufferedImage // scalastyle:on import ml.dmlc.mxnet.Context import ml.dmlc.mxnet.DataDesc -import ml.dmlc.mxnet.{NDArray, Shape} +import ml.dmlc.mxnet.{Context, NDArray, Shape} import org.mockito.Matchers.any import org.mockito.Mockito import org.scalatest.BeforeAndAfterAll @@ -36,7 +36,8 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll { extends ObjectDetector(modelPathPrefix, inputDescriptors) { override def getImageClassifier(modelPathPrefix: String, inputDescriptors: - IndexedSeq[DataDesc]): ImageClassifier = { + IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(), + epoch: Option[Int] = Some(0)): ImageClassifier = { new MyImageClassifier(modelPathPrefix, inputDescriptors) } @@ -44,13 +45,15 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll { class MyImageClassifier(modelPathPrefix: String, protected override val inputDescriptors: IndexedSeq[DataDesc]) - extends ImageClassifier(modelPathPrefix, inputDescriptors) { + extends ImageClassifier(modelPathPrefix, inputDescriptors, Context.cpu(), Some(0)) { override def getPredictor(): MyClassyPredictor = { Mockito.mock(classOf[MyClassyPredictor]) } - override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]): + override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc], + contexts: Array[Context] = Context.cpu(), + epoch: Option[Int] = Some(0)): Classifier = { new MyClassifier(modelPathPrefix, inputDescriptors) } -- To stop receiving notification emails like this one, please contact nsw...@apache.org.