Author: tommaso Date: Fri Mar 10 09:01:11 2017 New Revision: 1786303 URL: http://svn.apache.org/viewvc?rev=1786303&view=rev Log: improved sRNN
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1786303&r1=1786302&r2=1786303&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Fri Mar 10 09:01:11 2017 @@ -50,6 +50,7 @@ public class StackedRNN extends RNN { private final INDArray whh; // hidden to hidden private final INDArray whh2; // hidden to hidden2 private final INDArray wh2y; // hidden2 to output + private final INDArray wxh2; private final INDArray bh; // hidden bias private final INDArray bh2; // hidden2 bias private final INDArray by; // output bias @@ -69,6 +70,7 @@ public class StackedRNN extends RNN { wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize)); whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize)); whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize)); + wxh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize)); wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).div(Math.sqrt(vocabSize)); bh = Nd4j.zeros(hiddenLayerSize, 1); bh2 = Nd4j.zeros(hiddenLayerSize, 1); @@ -84,6 +86,7 @@ public class StackedRNN extends RNN { // memory variables for Adagrad INDArray mWxh = Nd4j.zerosLike(wxh); + INDArray mWxh2 = Nd4j.zerosLike(wxh2); INDArray mWhh = Nd4j.zerosLike(whh); INDArray mWhh2 = Nd4j.zerosLike(whh2); INDArray mWh2y = Nd4j.zerosLike(wh2y); @@ -118,6 +121,7 @@ public class StackedRNN extends RNN { } INDArray dWxh = Nd4j.zerosLike(wxh); + INDArray dWxh2 = Nd4j.zerosLike(wxh2); INDArray dWhh = Nd4j.zerosLike(whh); INDArray dWhh2 = Nd4j.zerosLike(whh2); INDArray dWh2y = Nd4j.zerosLike(wh2y); @@ -127,7 +131,7 @@ public class StackedRNN extends RNN { INDArray dby = Nd4j.zerosLike(by); // forward seqLength characters through the net and fetch gradient - double loss = lossFun(inputs, targets, dWxh, dWhh, dWhh2, dWh2y, dbh, dbh2, dby); + double loss = lossFun(inputs, targets, dWxh, dWhh, dWxh2, dWhh2, dWh2y, dbh, dbh2, dby); smoothLoss = smoothLoss * 0.999 + loss * 0.001; if (Double.isNaN(smoothLoss) || Double.isInfinite(smoothLoss)) { System.out.println("loss is " + smoothLoss + " (over/underflow occured, try adjusting hyperparameters)"); @@ -141,6 +145,9 @@ public class StackedRNN extends RNN { mWxh.addi(dWxh.mul(dWxh)); wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg)))); + mWxh2.addi(dWxh2.mul(dWxh2)); + wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg)))); + mWhh.addi(dWhh.mul(dWhh)); whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg)))); @@ -169,7 +176,7 @@ public class StackedRNN extends RNN { * hprev is Hx1 array of initial hidden state * returns the loss, gradients on model parameters and last hidden state */ - private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhh2, INDArray dWh2y, + private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWxh2, INDArray dWhh2, INDArray dWh2y, INDArray dbh, INDArray dbh2, INDArray dby) { INDArray xs = Nd4j.zeros(seqLength, vocabSize); @@ -193,13 +200,13 @@ public class StackedRNN extends RNN { } hs.putRow(t, hPrev.dup()); - hPrev2 = Transforms.tanh((whh.mmul(hs.getRow(t)).add(whh2.mmul(hPrev2)).add(bh2))); // hidden state 2 + hPrev2 = Transforms.tanh((wxh2.mmul(hPrev).add(whh2.mmul(hPrev2)).add(bh2))); // hidden state 2 if (hs2 == null) { hs2 = init(seqLength, hPrev2.shape()); } hs2.putRow(t, hPrev2.dup()); - INDArray yst = wh2y.mmul(hs2.getRow(t)).add(by); // unnormalized log probabilities for next chars + INDArray yst = wh2y.mmul(hPrev2).add(by); // unnormalized log probabilities for next chars if (ys == null) { ys = init(seqLength, yst.shape()); } @@ -231,11 +238,11 @@ public class StackedRNN extends RNN { INDArray dhraw2 = (Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t))).mul(dh2); // backprop through tanh nonlinearity dbh2.addi(dhraw2); INDArray hst = hs.getRow(t); - dWhh.addi(dhraw2.mmul(hst.transpose())); + dWxh2.addi(dhraw2.mmul(hst.transpose())); dWhh2.addi(dhraw2.mmul(hs2tm1.transpose())); dh2Next = whh2.transpose().mmul(dhraw2); - INDArray dh = dh2Next.add(dhNext); // backprop into h + INDArray dh = wxh2.transpose().mmul(dhraw2).add(dhNext); // backprop into h INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst.mul(hst))).mul(dh); // backprop through tanh nonlinearity dbh.addi(dhraw); dWxh.addi(dhraw.mmul(xs.getRow(t))); @@ -250,13 +257,14 @@ public class StackedRNN extends RNN { // clip exploding gradients int clip = 5; - Nd4j.getExecutioner().exec(new SetRange(dWxh, -clip, clip)); - Nd4j.getExecutioner().exec(new SetRange(dWhh, -clip, clip)); - Nd4j.getExecutioner().exec(new SetRange(dWhh2, -clip, clip)); - Nd4j.getExecutioner().exec(new SetRange(dWh2y, -clip, clip)); - Nd4j.getExecutioner().exec(new SetRange(dbh, -clip, clip)); - Nd4j.getExecutioner().exec(new SetRange(dbh2, -clip, clip)); - Nd4j.getExecutioner().exec(new SetRange(dby, -clip, clip)); + dWxh = Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -clip, clip)); + dWxh2 = Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh2, -clip, clip)); + dWhh = Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -clip, clip)); + dWhh2 = Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh2, -clip, clip)); + dWh2y = Nd4j.getExecutioner().execAndReturn(new SetRange(dWh2y, -clip, clip)); + dbh = Nd4j.getExecutioner().execAndReturn(new SetRange(dbh, -clip, clip)); + dbh2 = Nd4j.getExecutioner().execAndReturn(new SetRange(dbh2, -clip, clip)); + dby = Nd4j.getExecutioner().execAndReturn(new SetRange(dby, -clip, clip)); return loss; } @@ -275,8 +283,8 @@ public class StackedRNN extends RNN { INDArray h2 = hPrev2.dup(); for (int t = 0; t < sampleSize; t++) { - h = Transforms.tanh(((wxh.mmul(x)).add((whh.mmul(h)).add(bh)))); - h2 = Transforms.tanh(((whh.mmul(h)).add((whh2.mmul(h2)).add(bh2)))); + h = Transforms.tanh((wxh.mmul(x)).add(whh.mmul(h)).add(bh)); + h2 = Transforms.tanh((wxh2.mmul(h)).add(whh2.mmul(h2)).add(bh2)); INDArray y = wh2y.mmul(h2).add(by); INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel(); Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java?rev=1786303&r1=1786302&r2=1786303&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java Fri Mar 10 09:01:11 2017 @@ -61,11 +61,7 @@ public class RNNCrossValidationTest { @Parameterized.Parameters public static Collection<Object[]> data() { return Arrays.asList(new Object[][]{ - {1e-1f, 50, 15}, - {1e-1f, 50, 25}, - {1e-1f, 50, 50}, - {1e-1f, 50, 100}, - {1e-1f, 50, 150}, + {1e-1f, 25, 100}, }); } @@ -83,20 +79,6 @@ public class RNNCrossValidationTest { rnn.serialize("target/wrnn-weights-"); } - @Test - public void testStackedCharRNNLearn() throws Exception { - RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text); - evaluate(rnn, true); - rnn.serialize("target/scrnn-weights-"); - } - - @Test - public void testStackedWordRNNLearn() throws Exception { - RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); - evaluate(rnn, false); - rnn.serialize("target/swrnn-weights-"); - } - private void evaluate(RNN rnn, boolean checkRatio) { System.out.println(rnn); rnn.learn(); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org