Author: tommaso
Date: Tue Oct 18 09:35:05 2016
New Revision: 1765407

URL: http://svn.apache.org/viewvc?rev=1765407&view=rev
Log:
bugfixes to bptt in srnn, minor fixes

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java?rev=1765407&r1=1765406&r2=1765407&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java Tue Oct 18 09:35:05 2016
@@ -84,7 +84,7 @@ public class NNRunner {
           int seed = random.nextInt(rnn.vocabSize);
           System.out.println(rnn.sample(seed));
           try {
-            rnn.serialize("weights.txt");
+            rnn.serialize("weights-");
           } catch (IOException e) {
             throw new RuntimeException("cannot serialize weights", e);
           }

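Note: the new "weights-" prefix lines up with how RNN#serialize builds the output path further down in this commit (prefix + current date + ".txt"). A minimal sketch of the resulting file name under that assumption; the class name is made up for illustration:

    import java.util.Date;

    public class WeightsFileNameSketch {
      public static void main(String[] args) {
        String prefix = "weights-"; // value NNRunner now passes to serialize()
        // mirrors the path built in RNN#serialize after this commit: prefix + date + ".txt"
        String fileName = prefix + new Date().toString() + ".txt";
        System.out.println(fileName); // e.g. weights-Tue Oct 18 09:35:05 CEST 2016.txt
      }
    }
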
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1765407&r1=1765406&r2=1765407&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Tue Oct 18 09:35:05 2016
@@ -255,7 +255,7 @@ public class RNN {
       dWhy.addi(dy.mmul(hst.transpose()));
       dby.addi(dy);
       INDArray dh = why.transpose().mmul(dy).add(dhNext); // backprop into h
-      INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst).mul(hst)).mul(dh); // backprop through tanh nonlinearity
+      INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst.mul(hst))).mul(dh); // backprop through tanh nonlinearity
       dbh.addi(dhraw);
       dWxh.addi(dhraw.mmul(xs.getRow(t)));
       INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
@@ -357,7 +357,7 @@ public class RNN {
   }
 
   public void serialize(String prefix) throws IOException {
-    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(prefix + new Date().toString() + ".csv")));
+    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(prefix + new Date().toString() + ".txt")));
     bufferedWriter.write("wxh");
     bufferedWriter.write(wxh.toString());
     bufferedWriter.write("whh");

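The dhraw change above is the core bptt fix: since hst holds tanh activations, the local derivative is 1 - hst^2, not (1 - hst) * hst (the latter is the logistic-sigmoid derivative written in terms of its output). A minimal, self-contained sketch of the corrected step with Nd4j, using made-up values for the hidden state and incoming gradient:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class TanhBackpropSketch {
      public static void main(String[] args) {
        INDArray raw = Nd4j.create(new double[]{-1.0, 0.0, 0.5}); // hypothetical pre-activations
        INDArray hst = Transforms.tanh(raw);                      // hidden state, as in the forward pass
        INDArray dh = Nd4j.create(new double[]{0.1, 0.2, 0.3});   // hypothetical gradient flowing into h

        // corrected step from this commit: d tanh(x)/dx = 1 - tanh(x)^2, expressed via the output
        INDArray dhraw = Nd4j.ones(hst.shape()).sub(hst.mul(hst)).mul(dh);
        System.out.println(dhraw);
      }
    }

For example, with hst = -0.5 the old expression even flips the sign of the gradient ((1 + 0.5) * -0.5 = -0.75 versus 1 - 0.25 = 0.75), which is consistent with the "bugfixes to bptt" log entry.
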
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1765407&r1=1765406&r2=1765407&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Tue Oct 18 09:35:05 2016
@@ -96,6 +96,7 @@ public class StackedRNN extends RNN {
     while (true) {
      // prepare inputs (we're sweeping from left to right in steps seqLength long)
       if (p + seqLength + 1 >= data.size() || n == 0) {
+//        hPrev2 = hPrev.dup(); // reset RNN memory to previous lower layer memory
         hPrev = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
         hPrev2 = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
         p = 0; // go from start of data
@@ -110,7 +111,7 @@ public class StackedRNN extends RNN {
       INDArray targets = getSequence(p + 1);
 
       // sample from the model every now and then
-      if (n % 1000 == 0 && n > 0) {
+      if (n % 100 == 0 && n > 0) {
         String txt = sample(inputs.getInt(0));
         System.out.printf("\n---\n %s \n----\n", txt);
       }
@@ -131,7 +132,7 @@ public class StackedRNN extends RNN {
         System.out.println("loss is NaN (over/underflow occured, try adjusting 
hyperparameters)");
         break;
       }
-      if (n % 1000 == 0) {
+      if (n % 100 == 0) {
         System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print 
progress
       }
 
@@ -223,28 +224,55 @@ public class StackedRNN extends RNN {
     }
 
     this.hPrev = hs.getRow(inputs.length() - 1);
+    this.hPrev2 = hs2.getRow(inputs.length() - 1);
 
     // backward pass: compute gradients going backwards
     INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
     INDArray dh2Next = Nd4j.zerosLike(hs2.getRow(0));
     for (int t = inputs.length() - 1; t >= 0; t--) {
-      INDArray dy = ps.getRow(t).dup();
-      dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y
-
-      INDArray hst2 = hs2.getRow(t);
-      dWh2y.addi(dy.mmul(hst2.transpose()));
-      dby.addi(dy);
-      INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // backprop into h2
-      INDArray dhraw2 = (Nd4j.ones(hst2.shape()).sub(hst2).mul(hst2)).mul(dh2); // backprop through tanh nonlinearity
-//      INDArray dhraw2 = Nd4j.getExecutioner().execAndReturn(new SetRange(hst2, 0, Double.MAX_VALUE));; // backprop through relu nonlinearity
-      dbh2.addi(dhraw2);
 
+      /*
+       dy = np.copy(ps[t])
+        dy[targets[t]] -= 1 # backprop into y
+        dWhy += np.dot(dy, hs2[t].T)
+        dby += dy
+        dh2 = np.dot(Why.T, dy) + dhnext # backprop into h2
+        dhraw2 = (1 - hs2[t] * hs2[t]) * dh2 # backprop through tanh nonlinearity
+        dbh2 += dhraw2
+        dWhh += np.dot(dhraw2, hs[t].T)
+        dWhh2 += np.dot(dhraw2, hs2[t-1].T)
+        dhnext2 = np.dot(Whh2.T, dhraw2)
+        dh = np.dot(Whh2.T, dh2) + dhnext2 # backprop into h
+        dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity
+        dbh += dhraw
+        dWxh += np.dot(dhraw, xs[t].T)
+        dWhh += np.dot(dhraw, hs[t-1].T)
+        dhnext = np.dot(Whh.T, dhraw)
+       */
+
+      INDArray dy = ps.getRow(t).dup(); // dy = np.copy(ps[t])
+//      dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y
+      dy.getRow(targets.getInt(t)).subi(1); // dy[targets[t]] -= 1 # backprop into y
+
+      INDArray hs2t = hs2.getRow(t);
+      INDArray hs2tm1 = t == 0 ? hs12 : hs2.getRow(t - 1);
+
+      dWh2y.addi(dy.mmul(hs2t.transpose())); // dWhy += np.dot(dy, hs[t].T)
+      dby.addi(dy); // dby += dy
+
+      INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // dh = np.dot(Why.T, dy) + dhnext # backprop into h2
+
+      INDArray dhraw2 = (Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t))).mul(dh2); // dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity
+//      INDArray dhraw2 = Nd4j.getExecutioner().execAndReturn(new SetRange(hst2, 0, Double.MAX_VALUE)).mul(dh2); // backprop through relu nonlinearity
+      dbh2.addi(dhraw2); // dbh += dhraw
       INDArray hst = hs.getRow(t);
-      dWhh2.addi(dh2.mmul(hst.transpose()));
-      dbh.addi(dh2);
+      dWhh.addi(dhraw2.mmul(hst.transpose())); // dWxh += np.dot(dhraw, xs[t].T)
+      dWhh2.addi(dhraw2.mmul(hs2tm1.transpose())); // dWhh += np.dot(dhraw, hs[t-1].T)
+      dh2Next = whh2.transpose().mmul(dhraw2); // dhnext = np.dot(Whh.T, dhraw)
+
       INDArray dh = whh2.transpose().mmul(dh2).add(dhNext); // backprop into h
-      INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst).mul(hst)).mul(dh); // backprop through tanh nonlinearity
-//      INDArray dhraw = Nd4j.getExecutioner().execAndReturn(new SetRange(hst, 0, Double.MAX_VALUE));; // backprop through relu nonlinearity
+      INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst.mul(hst))).mul(dh); // backprop through tanh nonlinearity
+//      INDArray dhraw = Nd4j.getExecutioner().execAndReturn(new SetRange(hst, 0, Double.MAX_VALUE)).mul(dh); // backprop through relu nonlinearity
       dbh.addi(dhraw);
 
       dWxh.addi(dhraw.mmul(xs.getRow(t)));
@@ -252,6 +280,7 @@ public class StackedRNN extends RNN {
       dWhh.addi(dhraw.mmul(hsRow.transpose()));
       dhNext = whh.transpose().mmul(dhraw);
 
+
     }
     // clip exploding gradients
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5));
@@ -259,6 +288,7 @@ public class StackedRNN extends RNN {
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh2, -5, 5));
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWh2y, -5, 5));
     Nd4j.getExecutioner().execAndReturn(new SetRange(dbh, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dbh2, -5, 5));
     Nd4j.getExecutioner().execAndReturn(new SetRange(dby, -5, 5));
 
     return loss;
@@ -281,9 +311,12 @@ public class StackedRNN extends RNN {
      INDArray h2 = Transforms.tanh((whh.mmul(h)).add((whh2.mmul(hPrev2)).add(bh2)));
 //      INDArray h2 = Transforms.relu((whh.mmul(h)).add((whh2.mmul(hPrev2)).add(bh2)));
       INDArray y = (wh2y.mmul(h2)).add(by);
-      INDArray exp = Transforms.exp(y);
-      INDArray pm = exp.div(Nd4j.sum(exp)).ravel();
+      INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
+//      INDArray exp = Transforms.exp(y);
+//      INDArray pm = exp.div(Nd4j.sum(exp)).ravel();
+
 
+//      Nd4j.getExecutioner().exec(new ReplaceNans(pm, 0.0000001));
       List<Pair<Integer, Double>> d = new LinkedList<>();
       for (int pi = 0; pi < vocabSize; pi++) {
         d.add(new Pair<>(pi, pm.getDouble(0, pi)));
@@ -302,7 +335,7 @@ public class StackedRNN extends RNN {
 
   @Override
   public void serialize(String prefix) throws IOException {
-    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(prefix + new Date().toString() + ".csv")));
+    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(prefix + new Date().toString() + ".txt")));
     bufferedWriter.write("wxh");
     bufferedWriter.write(wxh.toString());
     bufferedWriter.write("whh");

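The rewritten backward loop above follows the two-layer numpy reference quoted in the comment: backprop from the softmax output into the top hidden layer (dh2, dhraw2), accumulate dWhh and dWhh2 from dhraw2, push the gradient through whh2 into the lower layer, and repeat the tanh step there; dbh2 is now clipped alongside the other gradients, and sample() draws from Nd4j's SoftMax op instead of a hand-rolled exp/normalize. A minimal sketch of one backward timestep under those assumptions; the names echo the diff, but the matrices are random placeholders rather than the committed code:

    import java.util.Arrays;

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class StackedBackpropStepSketch {
      public static void main(String[] args) {
        int hidden = 4, vocab = 6;
        INDArray hst = Nd4j.rand(hidden, 1);       // lower-layer hidden state at t (placeholder)
        INDArray hs2t = Nd4j.rand(hidden, 1);      // upper-layer hidden state at t (placeholder)
        INDArray hs2tm1 = Nd4j.zeros(hidden, 1);   // upper-layer state at t-1
        INDArray dy = Nd4j.rand(vocab, 1);         // output gradient after dy[target] -= 1
        INDArray wh2y = Nd4j.rand(vocab, hidden);  // hidden2 -> output weights
        INDArray whh2 = Nd4j.rand(hidden, hidden); // upper-layer recurrent weights
        INDArray dh2Next = Nd4j.zeros(hidden, 1);
        INDArray dhNext = Nd4j.zeros(hidden, 1);

        // backprop into the upper hidden layer, then through its tanh
        INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next);
        INDArray dhraw2 = Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t)).mul(dh2);

        // weight gradients for the lower-to-upper connection and the upper-layer recurrence
        INDArray dWhh = dhraw2.mmul(hst.transpose());
        INDArray dWhh2 = dhraw2.mmul(hs2tm1.transpose());
        dh2Next = whh2.transpose().mmul(dhraw2);

        // backprop into the lower hidden layer and through its tanh
        INDArray dh = whh2.transpose().mmul(dh2).add(dhNext);
        INDArray dhraw = Nd4j.ones(hst.shape()).sub(hst.mul(hst)).mul(dh);

        System.out.println(Arrays.toString(dWhh.shape()) + " " + Arrays.toString(dWhh2.shape())
            + " " + Arrays.toString(dhraw.shape()));
      }
    }
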
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java?rev=1765407&r1=1765406&r2=1765407&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java Tue Oct 18 09:35:05 2016
@@ -41,7 +41,7 @@ public class RNNCrossValidationTest {
   private int hiddenLayerSize;
   private Random r = new Random();
   private String text;
-  private final int epochs = 2;
+  private final int epochs = 5;
   private List<String> words;
 
  public RNNCrossValidationTest(float learningRate, int seqLength, int hiddenLayerSize) {
@@ -61,12 +61,12 @@ public class RNNCrossValidationTest {
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][]{
-            {3e-1f, 50, 5},
-            {3e-1f, 50, 10},
-            {3e-1f, 50, 15},
-            {3e-1f, 50, 25},
-            {3e-1f, 50, 50},
-            {3e-1f, 50, 100},
+            {1e-1f, 50, 5},
+            {1e-1f, 50, 10},
+            {1e-1f, 50, 15},
+            {1e-1f, 50, 25},
+            {1e-1f, 50, 50},
+            {1e-1f, 50, 100},
     });
   }
 


