Author: tommaso Date: Sat Oct 29 06:25:35 2016 New Revision: 1767095 URL: http://svn.apache.org/viewvc?rev=1767095&view=rev Log: using SoftMax TransformOp for more compact code
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1767095&r1=1767094&r2=1767095&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Sat Oct 29 06:25:35 2016 @@ -23,6 +23,7 @@ import org.apache.commons.math3.util.Pai import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.SetRange; +import org.nd4j.linalg.api.ops.impl.transforms.SoftMax; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; @@ -203,7 +204,7 @@ public class RNN { /** * inputs, targets are both list of integers * hprev is Hx1 array of initial hidden state - * returns the loss, gradients on model parameters and last hidden state + * returns the modified loss, gradients on model parameters */ private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhy, INDArray dbh, INDArray dby) { @@ -234,9 +235,7 @@ public class RNN { ys = init(inputs.length(), yst); } ys.putRow(t, yst); - INDArray exp = Transforms.exp(yst); - Number sumExp = exp.sumNumber(); - INDArray pst = exp.div(sumExp); // probabilities for next chars + INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars if (ps == null) { ps = init(inputs.length(), pst); } @@ -298,8 +297,7 @@ public class RNN { for (int t = 0; t < sampleSize; t++) { h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(h)).add(bh))); INDArray y = (why.mmul(h)).add(by); - INDArray exp = Transforms.exp(y); - INDArray pm = exp.div(Nd4j.sum(exp)).ravel(); + INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel(); 
List<Pair<Integer, Double>> d = new LinkedList<>(); for (int pi = 0; pi < vocabSize; pi++) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org