Author: tommaso
Date: Tue Oct 18 09:35:05 2016
New Revision: 1765407

URL: http://svn.apache.org/viewvc?rev=1765407&view=rev
Log:
bugfixes to BPTT in StackedRNN, minor fixes
Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java?rev=1765407&r1=1765406&r2=1765407&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java Tue Oct 18 09:35:05 2016
@@ -84,7 +84,7 @@ public class NNRunner {
           int seed = random.nextInt(rnn.vocabSize);
           System.out.println(rnn.sample(seed));
           try {
-            rnn.serialize("weights.txt");
+            rnn.serialize("weights-");
           } catch (IOException e) {
             throw new RuntimeException("cannot serialize weights", e);
           }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1765407&r1=1765406&r2=1765407&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Tue Oct 18 09:35:05 2016
@@ -255,7 +255,7 @@ public class RNN {
       dWhy.addi(dy.mmul(hst.transpose()));
       dby.addi(dy);
       INDArray dh = why.transpose().mmul(dy).add(dhNext); // backprop into h
-      INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst).mul(hst)).mul(dh); // backprop through tanh nonlinearity
+      INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst.mul(hst))).mul(dh); // backprop through tanh nonlinearity
       dbh.addi(dhraw);
       dWxh.addi(dhraw.mmul(xs.getRow(t)));
       INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
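The RNN.java hunk above is the core fix: with h = tanh(x) the derivative is 1 - h*h, so the backprop factor must be (1 - h*h); the previous line computed (1 - h)*h instead. A minimal standalone check (plain Java, not part of this patch, names are illustrative) comparing both forms against a finite difference:

    public class TanhDerivativeCheck {
      public static void main(String[] args) {
        double eps = 1e-6;
        for (double x : new double[]{-2.0, -0.5, 0.0, 0.7, 1.3}) {
          double h = Math.tanh(x);
          double fixedForm = 1 - h * h;   // what the patched line computes
          double oldForm = (1 - h) * h;   // what the line computed before this commit
          double numeric = (Math.tanh(x + eps) - Math.tanh(x - eps)) / (2 * eps);
          // the fixed form tracks the numeric derivative; the old form does not
          System.out.printf("x=%5.2f  1-h*h=%.6f  numeric=%.6f  (1-h)*h=%.6f%n",
              x, fixedForm, numeric, oldForm);
        }
      }
    }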
@@ -357,7 +357,7 @@ public class RNN {
   }

   public void serialize(String prefix) throws IOException {
-    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(prefix + new Date().toString() + ".csv")));
+    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(prefix + new Date().toString() + ".txt")));
     bufferedWriter.write("wxh");
     bufferedWriter.write(wxh.toString());
     bufferedWriter.write("whh");

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1765407&r1=1765406&r2=1765407&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Tue Oct 18 09:35:05 2016
@@ -96,6 +96,7 @@ public class StackedRNN {
     while (true) {
       // prepare inputs (we're sweeping from left to right in steps seqLength long)
       if (p + seqLength + 1 >= data.size() || n == 0) {
+//        hPrev2 = hPrev.dup(); // reset RNN memory to previous lower layer memory
         hPrev = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
         hPrev2 = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
         p = 0; // go from start of data
@@ -110,7 +111,7 @@ public class StackedRNN {
       INDArray targets = getSequence(p + 1);

       // sample from the model every now and then
-      if (n % 1000 == 0 && n > 0) {
+      if (n % 100 == 0 && n > 0) {
        String txt = sample(inputs.getInt(0));
        System.out.printf("\n---\n %s \n----\n", txt);
       }
@@ -131,7 +132,7 @@
         System.out.println("loss is NaN (over/underflow occurred, try adjusting hyperparameters)");
         break;
       }
-      if (n % 1000 == 0) {
+      if (n % 100 == 0) {
         System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
       }
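The @@ -96 hunk documents the training sweep: the character stream is consumed left to right in chunks of seqLength, and both layers' memories (hPrev, hPrev2) are reset to zeros whenever the pointer would run past the end of the data; the commented-out line records an alternative that is left disabled (seeding the second layer's memory from the first). A rough standalone sketch of that sweeping logic (plain Java, hypothetical names, not the project's code):

    import java.util.Arrays;

    public class SequenceSweepSketch {
      public static void main(String[] args) {
        int dataSize = 1000;   // length of the encoded character stream (hypothetical)
        int seqLength = 25;    // truncated-BPTT window (hypothetical)
        int hiddenLayerSize = 10;
        double[] hPrev = new double[hiddenLayerSize];   // layer-1 memory carried across chunks
        double[] hPrev2 = new double[hiddenLayerSize];  // layer-2 memory carried across chunks
        int p = 0;             // data pointer
        for (int n = 0; n < 200; n++) {
          if (p + seqLength + 1 >= dataSize || n == 0) {
            Arrays.fill(hPrev, 0.0);   // reset RNN memory
            Arrays.fill(hPrev2, 0.0);  // reset RNN memory
            p = 0;                     // go from start of data
          }
          // ... run forward/backward on data[p, p + seqLength) here, updating hPrev/hPrev2 ...
          p += seqLength;              // move the window to the next chunk
        }
      }
    }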
@@ -223,28 +224,55 @@
       }
       this.hPrev = hs.getRow(inputs.length() - 1);
+      this.hPrev2 = hs2.getRow(inputs.length() - 1);

       // backward pass: compute gradients going backwards
       INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
       INDArray dh2Next = Nd4j.zerosLike(hs2.getRow(0));
       for (int t = inputs.length() - 1; t >= 0; t--) {
-        INDArray dy = ps.getRow(t).dup();
-        dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y
-
-        INDArray hst2 = hs2.getRow(t);
-        dWh2y.addi(dy.mmul(hst2.transpose()));
-        dby.addi(dy);
-        INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // backprop into h2
-        INDArray dhraw2 = (Nd4j.ones(hst2.shape()).sub(hst2).mul(hst2)).mul(dh2); // backprop through tanh nonlinearity
-//        INDArray dhraw2 = Nd4j.getExecutioner().execAndReturn(new SetRange(hst2, 0, Double.MAX_VALUE));; // backprop through relu nonlinearity
-        dbh2.addi(dhraw2);
+        /*
+          dy = np.copy(ps[t])
+          dy[targets[t]] -= 1 # backprop into y
+          dWhy += np.dot(dy, hs2[t].T)
+          dby += dy
+          dh2 = np.dot(Why.T, dy) + dhnext # backprop into h2
+          dhraw2 = (1 - hs2[t] * hs2[t]) * dh2 # backprop through tanh nonlinearity
+          dbh2 += dhraw2
+          dWhh += np.dot(dhraw2, hs[t].T)
+          dWhh2 += np.dot(dhraw2, hs2[t-1].T)
+          dhnext2 = np.dot(Whh2.T, dhraw2)
+          dh = np.dot(Whh2.T, dh2) + dhnext2 # backprop into h
+          dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity
+          dbh += dhraw
+          dWxh += np.dot(dhraw, xs[t].T)
+          dWhh += np.dot(dhraw, hs[t-1].T)
+          dhnext = np.dot(Whh.T, dhraw)
+        */
+
+        INDArray dy = ps.getRow(t).dup(); // dy = np.copy(ps[t])
+//        dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y
+        dy.getRow(targets.getInt(t)).subi(1); // dy[targets[t]] -= 1 # backprop into y
+
+        INDArray hs2t = hs2.getRow(t);
+        INDArray hs2tm1 = t == 0 ? hs12 : hs2.getRow(t - 1);
+
+        dWh2y.addi(dy.mmul(hs2t.transpose())); // dWhy += np.dot(dy, hs2[t].T)
+        dby.addi(dy); // dby += dy
+
+        INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // dh2 = np.dot(Why.T, dy) + dhnext # backprop into h2
+
+        INDArray dhraw2 = (Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t))).mul(dh2); // dhraw2 = (1 - hs2[t] * hs2[t]) * dh2 # backprop through tanh nonlinearity
+//        INDArray dhraw2 = Nd4j.getExecutioner().execAndReturn(new SetRange(hst2, 0, Double.MAX_VALUE)).mul(dh2); // backprop through relu nonlinearity
+        dbh2.addi(dhraw2); // dbh2 += dhraw2

         INDArray hst = hs.getRow(t);
-        dWhh2.addi(dh2.mmul(hst.transpose()));
-        dbh.addi(dh2);
+        dWhh.addi(dhraw2.mmul(hst.transpose())); // dWhh += np.dot(dhraw2, hs[t].T)
+        dWhh2.addi(dhraw2.mmul(hs2tm1.transpose())); // dWhh2 += np.dot(dhraw2, hs2[t-1].T)
+        dh2Next = whh2.transpose().mmul(dhraw2); // dhnext2 = np.dot(Whh2.T, dhraw2)
+
         INDArray dh = whh2.transpose().mmul(dh2).add(dhNext); // backprop into h
-        INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst).mul(hst)).mul(dh); // backprop through tanh nonlinearity
-//        INDArray dhraw = Nd4j.getExecutioner().execAndReturn(new SetRange(hst, 0, Double.MAX_VALUE));; // backprop through relu nonlinearity
+        INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst.mul(hst))).mul(dh); // backprop through tanh nonlinearity
+//        INDArray dhraw = Nd4j.getExecutioner().execAndReturn(new SetRange(hst, 0, Double.MAX_VALUE)).mul(dh); // backprop through relu nonlinearity
         dbh.addi(dhraw);
         dWxh.addi(dhraw.mmul(xs.getRow(t)));
@@ -252,6 +280,7 @@
         dWhh.addi(dhraw.mmul(hsRow.transpose()));

         dhNext = whh.transpose().mmul(dhraw);
+
       }
       // clip exploding gradients
       Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5));
@@ -259,6 +288,7 @@
       Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh2, -5, 5));
       Nd4j.getExecutioner().execAndReturn(new SetRange(dWh2y, -5, 5));
       Nd4j.getExecutioner().execAndReturn(new SetRange(dbh, -5, 5));
+      Nd4j.getExecutioner().execAndReturn(new SetRange(dbh2, -5, 5));
       Nd4j.getExecutioner().execAndReturn(new SetRange(dby, -5, 5));

       return loss;
@@ -281,9 +311,12 @@
       INDArray h2 = Transforms.tanh((whh.mmul(h)).add((whh2.mmul(hPrev2)).add(bh2)));
//      INDArray h2 = Transforms.relu((whh.mmul(h)).add((whh2.mmul(hPrev2)).add(bh2)));
       INDArray y = (wh2y.mmul(h2)).add(by);
-      INDArray exp = Transforms.exp(y);
-      INDArray pm = exp.div(Nd4j.sum(exp)).ravel();
+      INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
+//      INDArray exp = Transforms.exp(y);
+//      INDArray pm = exp.div(Nd4j.sum(exp)).ravel();
+
+//      Nd4j.getExecutioner().exec(new ReplaceNans(pm, 0.0000001));
       List<Pair<Integer, Double>> d = new LinkedList<>();
       for (int pi = 0; pi < vocabSize; pi++) {
         d.add(new Pair<>(pi, pm.getDouble(0, pi)));
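In sample(), the output distribution is now computed with the SoftMax op instead of the hand-rolled exp/sum (kept commented out above, along with a disabled ReplaceNans guard); softmax is typically implemented by subtracting the max logit before exponentiating, which avoids the overflow that a raw exp over large logits can produce. A minimal plain-Java sketch of the same idea, stable softmax plus drawing an index from the resulting distribution (illustrative only, not the ND4J code path):

    import java.util.Random;

    public class SoftmaxSamplingSketch {
      // numerically stable softmax: shift by the max logit before exponentiating
      static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v);
        double[] p = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
          p[i] = Math.exp(logits[i] - max);
          sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum;
        return p;
      }

      // draw one index according to the probabilities in p
      static int sampleIndex(double[] p, Random r) {
        double u = r.nextDouble(), cum = 0;
        for (int i = 0; i < p.length; i++) {
          cum += p[i];
          if (u < cum) return i;
        }
        return p.length - 1; // guard against rounding
      }

      public static void main(String[] args) {
        double[] logits = {2.0, 1.0, 0.1, -1.0};
        System.out.println(sampleIndex(softmax(logits), new Random(42)));
      }
    }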
bufferedWriter.write("whh"); Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java?rev=1765407&r1=1765406&r2=1765407&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java Tue Oct 18 09:35:05 2016 @@ -41,7 +41,7 @@ public class RNNCrossValidationTest { private int hiddenLayerSize; private Random r = new Random(); private String text; - private final int epochs = 2; + private final int epochs = 5; private List<String> words; public RNNCrossValidationTest(float learningRate, int seqLength, int hiddenLayerSize) { @@ -61,12 +61,12 @@ public class RNNCrossValidationTest { @Parameterized.Parameters public static Collection<Object[]> data() { return Arrays.asList(new Object[][]{ - {3e-1f, 50, 5}, - {3e-1f, 50, 10}, - {3e-1f, 50, 15}, - {3e-1f, 50, 25}, - {3e-1f, 50, 50}, - {3e-1f, 50, 100}, + {1e-1f, 50, 5}, + {1e-1f, 50, 10}, + {1e-1f, 50, 15}, + {1e-1f, 50, 25}, + {1e-1f, 50, 50}, + {1e-1f, 50, 100}, }); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org