This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new eed7a34  Reduced test to 3 epochs and made gpu only (#11863)
eed7a34 is described below

commit eed7a34aa8c8145950fd282cdfe3ab16a358dc5c
Author: Andrew Ayres <[email protected]>
AuthorDate: Wed Aug 1 13:22:04 2018 -0700

    Reduced test to 3 epochs and made gpu only (#11863)
    
    * Reduced test to 3 epochs and made GPU only
    
    * Moved logger variable so that it's accessible
---
 .../mxnetexamples/multitask/MultiTaskSuite.scala   | 25 ++++++++++++----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala
index dab9770..b86f675 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala
@@ -44,21 +44,24 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
   * This will run as a part of "make scalatest"
   */
 class MultiTaskSuite extends FunSuite {
-
   test("Multitask Test") {
     val logger = LoggerFactory.getLogger(classOf[MultiTaskSuite])
-    logger.info("Multitask Test...")
+    if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+      System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+      logger.info("Multitask Test...")
 
-    val batchSize = 100
-    val numEpoch = 10
-    val ctx = Context.cpu()
+      val batchSize = 100
+      val numEpoch = 3
+      val ctx = Context.gpu()
 
-    val modelPath = ExampleMultiTask.getTrainingData
-    val (executor, evalMetric) = ExampleMultiTask.train(batchSize, numEpoch, ctx, modelPath)
-    evalMetric.get.foreach { case (name, value) =>
-      assert(value >= 0.95f)
+      val modelPath = ExampleMultiTask.getTrainingData
+      val (executor, evalMetric) = ExampleMultiTask.train(batchSize, numEpoch, ctx, modelPath)
+      evalMetric.get.foreach { case (name, value) =>
+        assert(value >= 0.95f)
+      }
+      executor.dispose()
+    } else {
+      logger.info("GPU test only, skipped...")
     }
-    executor.dispose()
   }
-
 }

Reply via email to