This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8271005  Added context object to run TestCharRnn example (#12841)
8271005 is described below

commit 8271005af5753f35ed413457e9d267e435d94d4e
Author: Piyush Ghai <[email protected]>
AuthorDate: Wed Oct 17 09:45:39 2018 -0700

    Added context object to run TestCharRnn example (#12841)
---
 .../src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala  | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
index 0fbdf7d..25bf479 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
@@ -123,9 +123,14 @@ class TestCharRnn(CLIParser: CLIParser) extends InferBase {
     val numLstmLayer = 3
     val (_, argParams, _) = Model.loadCheckpoint(CLIParser.modelPrefix, 75)
     this.vocab = Utils.buildVocab(CLIParser.dataPath)
+    var ctx = Context.cpu()
+    if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+      System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+      ctx = Context.gpu()
+    }
     val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
       numHidden = numHidden, numEmbed = numEmbed,
-      numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
+      numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f, ctx = 
ctx)
     model
   }
 

Reply via email to