This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8271005 Added context object to run TestCharRnn example (#12841)
8271005 is described below
commit 8271005af5753f35ed413457e9d267e435d94d4e
Author: Piyush Ghai <[email protected]>
AuthorDate: Wed Oct 17 09:45:39 2018 -0700
Added context object to run TestCharRnn example (#12841)
---
.../src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
index 0fbdf7d..25bf479 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
@@ -123,9 +123,14 @@ class TestCharRnn(CLIParser: CLIParser) extends InferBase {
val numLstmLayer = 3
val (_, argParams, _) = Model.loadCheckpoint(CLIParser.modelPrefix, 75)
this.vocab = Utils.buildVocab(CLIParser.dataPath)
+ var ctx = Context.cpu()
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+ ctx = Context.gpu()
+ }
val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
numHidden = numHidden, numEmbed = numEmbed,
- numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
+ numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f, ctx =
ctx)
model
}