Author: tommaso
Date: Sat Oct 29 06:25:35 2016
New Revision: 1767095

URL: http://svn.apache.org/viewvc?rev=1767095&view=rev
Log:
using SoftMax TransformOp for more compact code

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1767095&r1=1767094&r2=1767095&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Sat Oct 29 06:25:35 2016
@@ -23,6 +23,7 @@ import org.apache.commons.math3.util.Pai
 import org.nd4j.linalg.api.iter.NdIndexIterator;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
+import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.ops.transforms.Transforms;
 
@@ -203,7 +204,7 @@ public class RNN {
   /**
    * inputs, targets are both list of integers
    * hprev is Hx1 array of initial hidden state
-   * returns the loss, gradients on model parameters and last hidden state
+   * returns the modified loss, gradients on model parameters
    */
   private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhy, INDArray dbh,
                          INDArray dby) {
@@ -234,9 +235,7 @@ public class RNN {
         ys = init(inputs.length(), yst);
       }
       ys.putRow(t, yst);
-      INDArray exp = Transforms.exp(yst);
-      Number sumExp = exp.sumNumber();
-      INDArray pst = exp.div(sumExp); // probabilities for next chars
+      INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars
       if (ps == null) {
         ps = init(inputs.length(), pst);
       }
@@ -298,8 +297,7 @@ public class RNN {
     for (int t = 0; t < sampleSize; t++) {
       h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(h)).add(bh)));
       INDArray y = (why.mmul(h)).add(by);
-      INDArray exp = Transforms.exp(y);
-      INDArray pm = exp.div(Nd4j.sum(exp)).ravel();
+      INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
 
       List<Pair<Integer, Double>> d = new LinkedList<>();
       for (int pi = 0; pi < vocabSize; pi++) {
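
For context, the change replaces a hand-rolled softmax (exponentiate, then
normalize by the sum) with nd4j's SoftMax TransformOp at both call sites.
Below is a minimal standalone sketch of that equivalence, written against the
nd4j 0.x API imported in the diff above; the class name SoftMaxSketch and the
sample values are illustrative only, and the y.dup() call assumes execAndReturn
writes its result back into the op's input array (the default when the op is
constructed with a single argument):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class SoftMaxSketch {
      public static void main(String[] args) {
        INDArray y = Nd4j.create(new double[]{1.0, 2.0, 3.0});

        // Manual softmax, as the code read before this revision.
        INDArray exp = Transforms.exp(y);           // element-wise e^y (returns a copy)
        INDArray manual = exp.div(exp.sumNumber()); // normalize to probabilities

        // Softmax via the TransformOp, as the code reads after this revision.
        // y.dup() guards against in-place modification of y (assumption noted above).
        INDArray viaOp = Nd4j.getExecutioner().execAndReturn(new SoftMax(y.dup()));

        System.out.println(manual); // ~[0.0900, 0.2447, 0.6652]
        System.out.println(viaOp);  // same probabilities
      }
    }

Both paths yield the same probability vector; the op form collapses three
statements into one call, which is the "more compact code" the log message
refers to.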


