Author: tommaso
Date: Mon Oct 24 13:14:59 2016
New Revision: 1766403

URL: http://svn.apache.org/viewvc?rev=1766403&view=rev
Log:
minor fixes: name serialized weights after the input file, print 10 samples after learning, sample and log progress every 100 iterations, extract the Adagrad smoothing constant, tie sample length to seqLength, move hPrev updates after the backward pass in StackedRNN, upgrade maven-surefire-plugin to 2.18.1

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
    labs/yay/trunk/pom.xml

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java?rev=1766403&r1=1766402&r2=1766403&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java Mon Oct 24 13:14:59 2016
@@ -48,8 +48,10 @@ public class NNRunner {
           boolean useChars = true;
           String text = "";
 
+          String name = "";
           if (args.length > 1 && args[1] != null) {
             Path path = Paths.get(args[1]);
+            name = path.getFileName().toString();
             try {
               byte[] bytes = Files.readAllBytes(path);
               text = new String(bytes);
@@ -81,10 +83,13 @@ public class NNRunner {
           }
 
           rnn.learn();
-          int seed = random.nextInt(rnn.vocabSize);
-          System.out.println(rnn.sample(seed));
+
+          for (int i = 0; i < 10; i++) {
+            int seed = random.nextInt(rnn.vocabSize);
+            System.out.println(rnn.sample(seed));
+          }
           try {
-            rnn.serialize("weights-");
+            rnn.serialize(name + "-weights-");
           } catch (IOException e) {
             throw new RuntimeException("cannot serialize weights", e);
           }
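
For readers following along: the new prefix comes from java.nio, where Path.getFileName() returns the last element of a path, so the serialized weight files are now named after the training input. A minimal standalone sketch of the same idea (the path below is hypothetical):

    import java.nio.file.Path;
    import java.nio.file.Paths;

    public class PrefixSketch {
      public static void main(String[] args) {
        // hypothetical training input; any path works the same way
        Path path = Paths.get("/data/corpora/shakespeare.txt");
        // getFileName() drops the directory part, keeping "shakespeare.txt"
        String name = path.getFileName().toString();
        System.out.println(name + "-weights-"); // shakespeare.txt-weights-
      }
    }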

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1766403&r1=1766402&r2=1766403&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Mon Oct 24 13:14:59 2016
@@ -58,6 +58,7 @@ public class RNN {
   protected final Map<String, Integer> charToIx;
   protected final Map<Integer, String> ixToChar;
   protected final List<String> data;
+  private final static double reg = 1e-8; // Adagrad smoothing term (epsilon)
 
   // model parameters
   private final INDArray wxh; // input to hidden
@@ -144,7 +145,7 @@ public class RNN {
       INDArray targets = getSequence(p + 1);
 
       // sample from the model every now and then
-      if (n % 1000 == 0 && n > 0) {
+      if (n % 100 == 0 && n > 0) {
         String txt = sample(inputs.getInt(0));
         System.out.printf("\n---\n %s \n----\n", txt);
       }
@@ -163,25 +164,25 @@ public class RNN {
        System.out.println("loss is NaN (over/underflow occurred, try adjusting hyperparameters)");
         break;
       }
-      if (n % 1000 == 0) {
+      if (n % 100 == 0) {
        System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
       }
 
       // perform parameter update with Adagrad
       mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(1e-8))));
+      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));
 
       mWhh.addi(dWhh.mul(dWhh));
-      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(1e-8))));
+      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
 
       mWhy.addi(dWhy.mul(dWhy));
-      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(1e-8))));
+      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(reg))));
 
       mbh.addi(dbh.mul(dbh));
-      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(1e-8))));
+      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
 
       mby.addi(dby.mul(dby));
-      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(1e-8))));
+      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
 
       p += seqLength; // move data pointer
       n++; // iteration counter
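
The block above is plain Adagrad: each parameter keeps a running sum of squared gradients (mWxh, mWhh, ...), and the effective step size shrinks as that sum grows; the small constant added under the square root only guards against division by zero on the first steps. A scalar sketch of the same update rule, with illustrative names and a toy loss:

    public class AdagradSketch {
      public static void main(String[] args) {
        double eps = 1e-8;         // same smoothing term the commit extracts
        double learningRate = 0.1;
        double w = 0.5;            // one parameter, toy loss = w^2
        double m = 0.0;            // accumulated squared gradient
        for (int n = 0; n < 5; n++) {
          double g = 2 * w;                            // dLoss/dw
          m += g * g;                                  // mWxh.addi(dWxh.mul(dWxh))
          w -= learningRate * g / Math.sqrt(m + eps);  // wxh.subi(...)
          System.out.printf("iter %d, w: %f%n", n, w);
        }
      }
    }
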
@@ -260,8 +261,8 @@ public class RNN {
       INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
       dWhh.addi(dhraw.mmul(hsRow.transpose()));
       dhNext = whh.transpose().mmul(dhraw);
-
     }
+
     // clip exploding gradients
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5));
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -5, 5));
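
SetRange is used here to clip each gradient element into [-5, 5], the usual guard against exploding gradients in vanilla RNNs. Element-wise this is nothing more than a clamp, as the following sketch shows:

    public class ClipSketch {
      // element-wise equivalent of clipping a gradient into [-5, 5]
      static double clip(double g) {
        return Math.max(-5.0, Math.min(5.0, g));
      }

      public static void main(String[] args) {
        System.out.println(clip(12.3)); // 5.0
        System.out.println(clip(-7.1)); // -5.0
        System.out.println(clip(0.4));  // 0.4
      }
    }
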
@@ -289,7 +290,7 @@ public class RNN {
 
     INDArray x = Nd4j.zeros(vocabSize, 1);
     x.putScalar(seedIx, 1);
-    int sampleSize = 200;
+    int sampleSize = 2 * seqLength;
     INDArray ixes = Nd4j.create(sampleSize);
 
     INDArray h = hPrev.dup();
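
Sampling proceeds one character at a time: the one-hot x is pushed through the network, the softmax output is read as a categorical distribution over the vocabulary, and the drawn index becomes both an output character and the next input. Tying sampleSize to seqLength keeps samples proportional to the training window instead of a fixed 200 characters. A standalone sketch of drawing an index from such a distribution (the probabilities are illustrative):

    import java.util.Random;

    public class DrawSketch {
      // draw index i with probability p[i]; p is assumed to sum to 1
      static int sampleIndex(double[] p, Random random) {
        double r = random.nextDouble();
        double cumulative = 0;
        for (int i = 0; i < p.length; i++) {
          cumulative += p[i];
          if (r < cumulative) {
            return i;
          }
        }
        return p.length - 1; // guard against floating-point rounding
      }

      public static void main(String[] args) {
        double[] p = {0.1, 0.6, 0.3}; // toy softmax output
        System.out.println(sampleIndex(p, new Random(42)));
      }
    }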

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1766403&r1=1766402&r2=1766403&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Mon Oct 24 13:14:59 2016
@@ -220,9 +220,6 @@ public class StackedRNN extends RNN {
      loss += -Transforms.log(ps.getRow(t).getRow(targets.getInt(t)), true).sumNumber().doubleValue(); // softmax (cross-entropy loss)
     }
 
-    this.hPrev = hs.getRow(inputs.length() - 1);
-    this.hPrev2 = hs2.getRow(inputs.length() - 1);
-
     // backward pass: compute gradients going backwards
     INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
     INDArray dh2Next = Nd4j.zerosLike(hs2.getRow(0));
@@ -241,7 +238,7 @@ public class StackedRNN extends RNN {
 
      INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // dh2 = np.dot(Wh2y.T, dy) + dh2next # backprop into h2
 
-      INDArray dhraw2 = (Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t))).mul(dh2); // dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity
+      INDArray dhraw2 = (Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t))).mul(dh2); // backprop through tanh nonlinearity
 //      INDArray dhraw2 = Nd4j.getExecutioner().execAndReturn(new SetRange(hs2t, 0, Double.MAX_VALUE)).mul(dh2); // backprop through relu nonlinearity
       dbh2.addi(dhraw2); // dbh += dhraw
       INDArray hst = hs.getRow(t);
@@ -254,12 +251,16 @@ public class StackedRNN extends RNN {
 //      INDArray dhraw = Nd4j.getExecutioner().execAndReturn(new SetRange(hst, 0, Double.MAX_VALUE)).mul(dh); // backprop through relu nonlinearity
       dbh.addi(dhraw);
 
-      dWxh.addi(dhraw.mmul(xs.getRow(t)));
+      dWxh.addi(dhraw.mmul(xs.getRow(t))); // dWxh += np.dot(dhraw, xs[t].T)
       INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
-      dWhh.addi(dhraw.mmul(hsRow.transpose()));
-      dhNext = whh.transpose().mmul(dhraw);
+      dWhh.addi(dhraw.mmul(hsRow.transpose())); // dWhh += np.dot(dhraw, hs[t-1].T)
+      dhNext = whh.transpose().mmul(dhraw); // dhnext = np.dot(Whh.T, dhraw)
 
     }
+
+    this.hPrev = hs.getRow(inputs.length() - 1);
+    this.hPrev2 = hs2.getRow(inputs.length() - 1);
+
     // clip exploding gradients
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5));
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -5, 5));
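
The dhraw/dhraw2 lines above implement backprop through tanh: with h = tanh(z), dL/dz = (1 - h*h) * dL/dh, which is exactly what (ones - h*h).mul(dh) computes. A quick numeric check of that identity:

    public class TanhGradSketch {
      public static void main(String[] args) {
        double z = 0.7;
        double h = Math.tanh(z);
        double analytic = 1 - h * h;  // the (1 - hs[t]^2) factor in dhraw
        double delta = 1e-6;          // finite-difference step
        double numeric = (Math.tanh(z + delta) - Math.tanh(z - delta)) / (2 * delta);
        System.out.printf("analytic %.8f vs numeric %.8f%n", analytic, numeric);
      }
    }
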
@@ -280,7 +281,7 @@ public class StackedRNN extends RNN {
 
     INDArray x = Nd4j.zeros(vocabSize, 1);
     x.putScalar(seedIx, 1);
-    int sampleSize = 200;
+    int sampleSize = seqLength * 2;
     INDArray ixes = Nd4j.create(sampleSize);
 
     INDArray h = hPrev.dup();

Modified: labs/yay/trunk/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/pom.xml?rev=1766403&r1=1766402&r2=1766403&view=diff
==============================================================================
--- labs/yay/trunk/pom.xml (original)
+++ labs/yay/trunk/pom.xml Mon Oct 24 13:14:59 2016
@@ -187,7 +187,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-surefire-plugin</artifactId>
-        <version>2.12</version>
+        <version>2.18.1</version>
       </plugin>
     </plugins>
     <resources>


