This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new eed7a34 Reduced test to 3 epochs and made gpu only (#11863)
eed7a34 is described below
commit eed7a34aa8c8145950fd282cdfe3ab16a358dc5c
Author: Andrew Ayres <[email protected]>
AuthorDate: Wed Aug 1 13:22:04 2018 -0700
Reduced test to 3 epochs and made gpu only (#11863)
* Reduced test to 3 epochs and made GPU only
* Moved logger variable so that it's accessible
---
.../mxnetexamples/multitask/MultiTaskSuite.scala | 25 ++++++++++++----------
1 file changed, 14 insertions(+), 11 deletions(-)
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala
index dab9770..b86f675 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala
@@ -44,21 +44,24 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
* This will run as a part of "make scalatest"
*/
class MultiTaskSuite extends FunSuite {
-
test("Multitask Test") {
val logger = LoggerFactory.getLogger(classOf[MultiTaskSuite])
- logger.info("Multitask Test...")
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+ logger.info("Multitask Test...")
- val batchSize = 100
- val numEpoch = 10
- val ctx = Context.cpu()
+ val batchSize = 100
+ val numEpoch = 3
+ val ctx = Context.gpu()
- val modelPath = ExampleMultiTask.getTrainingData
- val (executor, evalMetric) = ExampleMultiTask.train(batchSize, numEpoch, ctx, modelPath)
- evalMetric.get.foreach { case (name, value) =>
- assert(value >= 0.95f)
+ val modelPath = ExampleMultiTask.getTrainingData
+ val (executor, evalMetric) = ExampleMultiTask.train(batchSize, numEpoch, ctx, modelPath)
+ evalMetric.get.foreach { case (name, value) =>
+ assert(value >= 0.95f)
+ }
+ executor.dispose()
+ } else {
+ logger.info("GPU test only, skipped...")
}
- executor.dispose()
}
-
}