gigasquid closed pull request #12721: [MXNET-716] Adding Scala Inference Benchmarks URL: https://github.com/apache/incubator-mxnet/pull/12721
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/scala-package/examples/scripts/benchmark/run_image_inference_bm.sh b/scala-package/examples/scripts/benchmark/run_image_inference_bm.sh new file mode 100755 index 00000000000..82aa9f62234 --- /dev/null +++ b/scala-package/examples/scripts/benchmark/run_image_inference_bm.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +set -e + +echo $OSTYPE + +hw_type=cpu +if [ "$1" = "gpu" ] +then + hw_type=gpu +fi + +platform=linux-x86_64 + +if [ "$OSTYPE" == "darwin"* ] +then + platform=osx-x86_64 +fi + +MXNET_ROOT=$(cd "$(dirname $0)/../../../.."; pwd) +CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/* + +MODEL_NAME=$2 + +RUNS=$3 + +BATCHSIZE=$4 + +# model dir +MODEL_PATH_PREFIX=$5 +# input image +INPUT_IMG=$6 +# which input image dir +INPUT_DIR=$7 + +java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp $CLASS_PATH \ + org.apache.mxnetexamples.benchmark.ScalaInferenceBenchmark \ + --example $MODEL_NAME \ + --count $RUNS \ + --batchSize $BATCHSIZE \ + --model-path-prefix $MODEL_PATH_PREFIX \ + --input-image $INPUT_IMG \ + --input-dir $INPUT_DIR \ + diff --git a/scala-package/examples/scripts/infer/imageclassifier/get_resnet_18_data.sh b/scala-package/examples/scripts/infer/imageclassifier/get_resnet_18_data.sh new file mode 100755 index 00000000000..4ba9fd5ac4c --- /dev/null +++ b/scala-package/examples/scripts/infer/imageclassifier/get_resnet_18_data.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e + +MXNET_ROOT=$(cd "$(dirname $0)/../../.."; pwd) + +data_path=$MXNET_ROOT/scripts/infer/models/resnet-18/ + +image_path=$MXNET_ROOT/scripts/infer/images/ + +if [ ! -d "$data_path" ]; then + mkdir -p "$data_path" +fi + +if [ ! -d "$image_path" ]; then + mkdir -p "$image_path" +fi + +if [ ! -f "$data_path" ]; then + wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-symbol.json -P $data_path + wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-0000.params -P $data_path + wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/synset.txt -P $data_path + wget https://s3.amazonaws.com/model-server/inputs/kitten.jpg -P $image_path +fi diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/InferBase.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/InferBase.scala new file mode 100644 index 00000000000..85d5c85329b --- /dev/null +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/InferBase.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnetexamples + +import org.apache.mxnet._ + +trait InferBase { + + def loadModel(context: Array[Context]): Any + def loadSingleData(): Any + def loadBatchFileList(batchSize: Int): List[Any] + def loadInputBatch(source: Any): Any + def runSingleInference(loadedModel: Any, input: Any): Any + def runBatchInference(loadedModel: Any, input: Any): Any +} diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/README.md b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/README.md new file mode 100644 index 00000000000..cfff93397be --- /dev/null +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/README.md @@ -0,0 +1,49 @@ +# Benchmarking Scala Inference APIs + +This folder contains a base class [ScalaInferenceBenchmark](https://github.com/apache/incubator-mxnet/tree/master/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/) and provides a mechanism for benchmarking [MXNet Inference APIs]((https://github.com/apache/incubator-mxnet/tree/master/scala-package/infer)) in Scala. +The benchmarking scripts provided runs an experiment for single inference calls and batch inference calls. It collects the time taken to perform an inference operation and emits the P99, P50 and Average values for these metrics. One can easily add/modify any new/existing examples to the ScalaInferenceBenchmark framework in order to get the benchmark numbers for inference calls. +Currently the ScalaInferenceBenchmark script supports three Scala examples : +1. [ImageClassification using ResNet-152](https://github.com/apache/incubator-mxnet/blob/master/scala-package/mxnet-demo/src/main/scala/sample/ImageClassificationExample.scala) +2. [Object Detection Example](https://github.com/apache/incubator-mxnet/blob/master/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala) +3. 
[Text Generation through RNNs](https://github.com/apache/incubator-mxnet/blob/master/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala) + +This script can be easily placed in an automated environment to run benchmark regressions on the Scala APIs. The inference device is selected by the `SCALA_TEST_ON_GPU` environment variable: set it to `1` to run on a GPU; otherwise the CPU is used. The first argument of the shell script (`cpu`/`gpu`) only selects the matching MXNet assembly on the classpath. + +## Contents + +1. [Prerequisites](#prerequisites) +2. [Scripts](#scripts) + +## Prerequisites + +1. MXNet +2. MXNet Scala Package +3. [IntelliJ IDE (or alternative IDE) project setup](http://mxnet.incubator.apache.org/tutorials/scala/mxnet_scala_on_intellij.html) with the MXNet Scala Package +4. Model files and datasets for the model you want to benchmark + +## Scripts +To help you run the benchmarks easily, a starter shell script is provided for each example wired into the benchmark (currently the ImageClassification example). The scripts can be found [here](https://github.com/apache/incubator-mxnet/blob/master/scala-package/examples/scripts/benchmark). +Each script takes some parameters as inputs; details can be found either in the bash script or in the corresponding example class. + +* *ImageClassification Example* +<br> The following shows an example of running ImageClassifier under the benchmark script. The script takes as parameters the platform type (cpu/gpu), the example name, the number of iterations for single inference calls, the batch size for batch inference calls, the model path prefix, the input image, and the input image directory. +For more details on running ImageClassifierExample as a standalone program, refer to the [README](https://github.com/apache/incubator-mxnet/blob/master/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/README.md) for ImageClassifierExample. +You may need to run ```chmod u+x run_image_inference_bm.sh``` before running this script.
+ ```bash + cd <Path-To-MXNET-Repo>/scala-package/examples/scripts/infer/imageclassifier + ./get_resnet_data.sh + cd <Path-To-MXNET-Repo>/scala-package/examples/scripts/benchmark + SCALA_TEST_ON_GPU=1 ./run_image_inference_bm.sh gpu ImageClassifierExample 100 10 ../infer/models/resnet-152/resnet-152 ../infer/images/kitten.jpg ../infer/images/ + ./run_image_inference_bm.sh cpu ImageClassifierExample 100 10 ../infer/models/resnet-152/resnet-152 ../infer/images/kitten.jpg ../infer/images/ + ``` + Upon running this script, you might see output like this: + ``` + [main] INFO org.apache.mxnetexamples.benchmark.CLIParserBase - + single_inference_p99 1663, single_inference_p50 729, single_inference_average 755.17 + ... + + INFO org.apache.mxnetexamples.benchmark.CLIParserBase - + batch_inference_p99 4241, batch_inference_p50 4241, batch_inference_average 4241.00 + ``` + +More examples to be added soon. \ No newline at end of file diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmark.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmark.scala new file mode 100644 index 00000000000..9ae50dc9d12 --- /dev/null +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmark.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnetexamples.benchmark + +import org.apache.mxnetexamples.InferBase +import org.apache.mxnetexamples.infer.imageclassifier.ImageClassifierExample +import org.apache.mxnet._ +import org.kohsuke.args4j.{CmdLineParser, Option} +import org.slf4j.LoggerFactory + +import scala.collection.JavaConverters._ + +object ScalaInferenceBenchmark { + + private val logger = LoggerFactory.getLogger(classOf[CLIParserBase]) + + def loadModel(objectToRun: InferBase, context: Array[Context]): + Any = { + objectToRun.loadModel(context) + } + + def loadDataSet(objectToRun: InferBase): + Any = { + objectToRun.loadSingleData() + } + + def loadBatchDataSet(objectToRun: InferBase, batchSize: Int): + List[Any] = { + objectToRun.loadBatchFileList(batchSize) + } + + def runInference(objectToRun: InferBase, loadedModel: Any, dataSet: Any, totalRuns: Int): + List[Long] = { + var inferenceTimes: List[Long] = List() + for (i <- 1 to totalRuns) { + NDArrayCollector.auto().withScope { + val startTimeSingle = System.currentTimeMillis() + objectToRun.runSingleInference(loadedModel, dataSet) + val estimatedTimeSingle = System.currentTimeMillis() - startTimeSingle + inferenceTimes = estimatedTimeSingle :: inferenceTimes + logger.info("Inference time at iteration: %d is : %d \n".format(i, estimatedTimeSingle)) + } + } + + inferenceTimes + } + + def runBatchInference(objecToRun: InferBase, loadedModel: Any, dataSetBatches: List[Any]): + List[Long] = { + + var inferenceTimes: List[Long] = List() + for (batch <- dataSetBatches) { + NDArrayCollector.auto().withScope { 
+ val loadedBatch = objecToRun.loadInputBatch(batch) + val startTimeSingle = System.currentTimeMillis() + objecToRun.runBatchInference(loadedModel, loadedBatch) + val estimatedTimeSingle = System.currentTimeMillis() - startTimeSingle + inferenceTimes = estimatedTimeSingle :: inferenceTimes + logger.info("Batch Inference time is : %d \n".format(estimatedTimeSingle)) + } + } + + inferenceTimes + } + + def percentile(p: Int, seq: Seq[Long]): Long = { + val sorted = seq.sorted + val k = math.ceil((seq.length - 1) * (p / 100.0)).toInt + sorted(k) + } + + def printStatistics(inferenceTimes: List[Long], metricsPrefix: String) { + + val times: Seq[Long] = inferenceTimes + val p50 = percentile(50, times) + val p99 = percentile(99, times) + val average = times.sum / (times.length * 1.0) + + logger.info("\n%s_p99 %d, %s_p50 %d, %s_average %1.2f".format(metricsPrefix, + p99, metricsPrefix, p50, metricsPrefix, average)) + + } + + def main(args: Array[String]): Unit = { + + var context = Context.cpu() + if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && + System.getenv("SCALA_TEST_ON_GPU").toInt == 1) { + context = Context.gpu() + } + var baseCLI : CLIParserBase = null + try { + val exampleName = args(1) + val exampleToBenchmark : InferBase = exampleName match { + case "ImageClassifierExample" => { + val imParser = new org.apache.mxnetexamples.infer.imageclassifier.CLIParser + baseCLI = imParser + val parsedVals = new CmdLineParser(imParser).parseArgument(args.toList.asJava) + new ImageClassifierExample(imParser) + } + case _ => throw new Exception("Invalid example name to run") + } + + logger.info("Running single inference call") + // Benchmarking single inference call + NDArrayCollector.auto().withScope { + val loadedModel = loadModel(exampleToBenchmark, context) + val dataSet = loadDataSet(exampleToBenchmark) + val inferenceTimes = runInference(exampleToBenchmark, loadedModel, dataSet, baseCLI.count) + printStatistics(inferenceTimes, "single_inference") + } + + if 
(baseCLI.batchSize != 0) { + logger.info("Running for batch inference call") + // Benchmarking batch inference call + NDArrayCollector.auto().withScope { + val loadedModel = loadModel(exampleToBenchmark, context) + val batchDataSet = loadBatchDataSet(exampleToBenchmark, baseCLI.batchSize) + val inferenceTimes = runBatchInference(exampleToBenchmark, loadedModel, batchDataSet) + printStatistics(inferenceTimes, "batch_inference") + } + } + + } catch { + case ex: Exception => { + logger.error(ex.getMessage, ex) + new CmdLineParser(baseCLI).printUsage(System.err) + sys.exit(1) + } + } + } + +} + +class CLIParserBase { + @Option(name = "--example", usage = "The scala example to benchmark") + val exampleName: String = "ImageClassifierExample" + @Option(name = "--count", usage = "number of times to run inference") + val count: Int = 1000 + @Option(name = "--batchSize", usage = "BatchSize to run batchinference calls", required = false) + val batchSize: Int = 0 +} diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala index f6e4fe0941d..9b6f19a93a9 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala @@ -20,11 +20,18 @@ package org.apache.mxnetexamples.infer.imageclassifier import org.apache.mxnet._ import org.kohsuke.args4j.{CmdLineParser, Option} import org.slf4j.LoggerFactory -import org.apache.mxnet.infer.ImageClassifier +import org.apache.mxnet.infer.{Classifier, ImageClassifier} import scala.collection.JavaConverters._ import java.io.File +import org.apache.mxnetexamples.benchmark.CLIParserBase +// scalastyle:off +import java.awt.image.BufferedImage +// scalastyle:on + +import 
org.apache.mxnetexamples.InferBase + import scala.collection.mutable.ListBuffer // scalastyle:off @@ -108,7 +115,7 @@ object ImageClassifierExample { } def main(args: Array[String]): Unit = { - val inst = new ImageClassifierExample + val inst = new CLIParser val parser: CmdLineParser = new CmdLineParser(inst) var context = Context.cpu() @@ -157,11 +164,71 @@ object ImageClassifierExample { } } -class ImageClassifierExample { +class CLIParser extends CLIParserBase{ @Option(name = "--model-path-prefix", usage = "the input model directory") - private val modelPathPrefix: String = "/resnet-152/resnet-152" + val modelPathPrefix: String = "/resnet-152/resnet-152" @Option(name = "--input-image", usage = "the input image") - private val inputImagePath: String = "/images/kitten.jpg" + val inputImagePath: String = "/images/kitten.jpg" @Option(name = "--input-dir", usage = "the input batch of images directory") - private val inputImageDir: String = "/images/" + val inputImageDir: String = "/images/" +} + +class ImageClassifierExample(CLIParser: CLIParser) extends InferBase{ + + override def loadModel(context: Array[Context]): Classifier = { + val dType = DType.Float32 + val inputShape = Shape(1, 3, 224, 224) + + val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW")) + + // Create object of ImageClassifier class + val imgClassifier: ImageClassifier = new ImageClassifier(CLIParser.modelPathPrefix, + inputDescriptor, context) + imgClassifier + } + + override def loadSingleData(): Any = { + val img = ImageClassifier.loadImageFromFile(CLIParser.inputImagePath) + img + } + + override def loadBatchFileList(batchSize: Int): List[Any] = { + val dir = new File(CLIParser.inputImageDir) + require(dir.exists && dir.isDirectory, + "input image directory: %s not found".format(CLIParser.inputImageDir)) + val output = ListBuffer[List[String]]() + var batch = ListBuffer[String]() + for (imgFile: File <- dir.listFiles()){ + batch += imgFile.getPath + if (batch.length == 
batchSize) { + output += batch.toList + batch = ListBuffer[String]() + } + } + if (batch.length > 0) { + output += batch.toList + } + output.toList + } + + override def loadInputBatch(inputPaths: Any): Any = { + val batchFile = inputPaths.asInstanceOf[List[String]] + ImageClassifier.loadInputBatch(batchFile) + } + + override def runSingleInference(loadedModel: Any, input: Any): Any = { + // Running inference on single image + val imageModel = loadedModel.asInstanceOf[ImageClassifier] + val imgInput = input.asInstanceOf[BufferedImage] + val output = imageModel.classifyImage(imgInput, Some(5)) + output + } + + override def runBatchInference(loadedModel: Any, input: Any): Any = { + val imageModel = loadedModel.asInstanceOf[ImageClassifier] + val imgInput = input.asInstanceOf[Traversable[BufferedImage]] + val output = imageModel.classifyImageBatch(imgInput, Some(5)) + output + } + } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala index 0edde9e6516..53c4d367048 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala @@ -193,4 +193,4 @@ object SSDClassifierExample { exist } -} +} \ No newline at end of file diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala index bd064dbd351..5eb3ab41805 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala @@ -23,9 +23,9 @@ import org.slf4j.LoggerFactory import scala.collection.JavaConverters._ /** - * 
Follows the demo, to test the char rnn: - * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb - */ + * Follows the demo, to test the char rnn: + * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb + */ object TestCharRnn { private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn]) diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala new file mode 100644 index 00000000000..8786e63efbb --- /dev/null +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mxnetexamples.benchmark + +import java.io.File + +import org.apache.mxnetexamples.Util +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.slf4j.LoggerFactory + +import scala.sys.process.Process + +class ScalaInferenceBenchmarkSuite extends FunSuite with BeforeAndAfterAll { + private val logger = LoggerFactory.getLogger(classOf[ScalaInferenceBenchmarkSuite]) + override def beforeAll(): Unit = { + } + + test("Testing Benchmark -- Image Classification") { + logger.info("Downloading resnet-18 model") + val tempDirPath = System.getProperty("java.io.tmpdir") + logger.info("tempDirPath: %s".format(tempDirPath)) + val baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models" + Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json", + tempDirPath + "/resnet18/resnet-18-symbol.json") + Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params", + tempDirPath + "/resnet18/resnet-18-0000.params") + Util.downloadUrl(baseUrl + "/resnet-18/synset.txt", + tempDirPath + "/resnet18/synset.txt") + Util.downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg", + tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg") + val modelDirPath = tempDirPath + File.separator + "resnet18/" + val inputImagePath = tempDirPath + File.separator + + "inputImages/resnet18/Pug-Cookie.jpg" + val inputImageDir = tempDirPath + File.separator + "inputImages/resnet18/" + val args = Array( + "--example", "ImageClassifierExample", + "--count", "1", + "--batchSize", "10", + "--model-path-prefix", s"$modelDirPath/resnet-18", + "--input-image", inputImagePath, + "--input-dir", inputImageDir + ) + ScalaInferenceBenchmark.main(args) + } + +} diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala index 34d3bc97a00..d8631df5405 100644 --- 
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala @@ -64,7 +64,7 @@ class ImageClassifierExampleSuite extends FunSuite with BeforeAndAfterAll { } val output = ImageClassifierExample.runInferenceOnSingleImage(modelDirPath + "resnet-18", - inputImagePath, context) + inputImagePath, context) val outputList = ImageClassifierExample.runInferenceOnBatchOfImage(modelDirPath + "resnet-18", inputImageDir, context) ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
