Repository: opennlp-sandbox Updated Branches: refs/heads/master fe2b1d920 -> 6f0659f2a
removed useless state update, minor fixes Project: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/repo Commit: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/commit/6f0659f2 Tree: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/tree/6f0659f2 Diff: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/diff/6f0659f2 Branch: refs/heads/master Commit: 6f0659f2ad2f3186ff0b266203ae960659ad1d98 Parents: fe2b1d9 Author: Tommaso Teofili <[email protected]> Authored: Sat Jul 1 14:12:48 2017 +0200 Committer: Tommaso Teofili <[email protected]> Committed: Sat Jul 1 14:12:48 2017 +0200 ---------------------------------------------------------------------- .../src/main/java/opennlp/tools/dl/StackedRNN.java | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/6f0659f2/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java ---------------------------------------------------------------------- diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java index 889fac1..e9a5f7e 100644 --- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java +++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java @@ -29,6 +29,7 @@ import java.util.List; import org.apache.commons.math3.distribution.EnumeratedDistribution; import org.apache.commons.math3.util.Pair; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.ReplaceNans; import org.nd4j.linalg.api.ops.impl.transforms.SoftMax; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; @@ -118,8 +119,10 @@ public class StackedRNN extends RNN { // sample from the model every now and then if (n % 1000 == 0 && n > 0) { - String txt = sample(inputs.getInt(0)); - System.out.printf("\n---\n %s \n----\n", txt); + for (int i = 0; i < 3; i++) { + String txt = sample(inputs.getInt(0)); + System.out.printf("\n---\n %s \n----\n", txt); + } } INDArray dWxh = Nd4j.zerosLike(wxh); @@ -171,7 +174,6 @@ public class StackedRNN extends RNN { by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps))); } else { // perform parameter update with Adagrad - mWxh.addi(dWxh.mul(dWxh)); wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps))); @@ -244,13 +246,13 @@ public class StackedRNN extends RNN { } ys.putRow(t, yst); - INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars + INDArray pst = Nd4j.getExecutioner().execAndReturn(new ReplaceNans(Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)), 0d)); // probabilities for next chars if (ps == null) { ps = init(seqLength, pst.shape()); } ps.putRow(t, pst); - loss += -Math.log(pst.getDouble(targets.getInt(t))); // softmax (cross-entropy loss) + loss += -Math.log(pst.getDouble(targets.getInt(t),0)); // softmax (cross-entropy loss) } // backward pass: compute gradients going backwards @@ -284,9 +286,6 @@ public class StackedRNN extends RNN { dhNext = whh.transpose().mmul(dhraw); } - this.hPrev = hs.getRow(seqLength - 1); - this.hPrev2 = hs2.getRow(seqLength - 1); - return loss; } @@ -298,7 +297,7 @@ public class StackedRNN extends RNN { INDArray x = Nd4j.zeros(vocabSize, 1); x.putScalar(seedIx, 1); - int sampleSize = seqLength * 2; + int sampleSize = 100; INDArray ixes = Nd4j.create(sampleSize); INDArray h = hPrev.dup();
