Author: tommaso
Date: Fri Oct 14 11:57:34 2016
New Revision: 1764880

URL: http://svn.apache.org/viewvc?rev=1764880&view=rev
Log:
fixed wrong Adagrad update, added more CLI params for the NN runner

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java?rev=1764880&r1=1764879&r2=1764880&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java Fri Oct 14 11:57:34 2016
@@ -41,7 +41,7 @@ public class NNRunner {
         case "recurrent": {
           // recurrent neural network
           // e.g. bin/nn recurrent core/src/test/resources/word2vec/sentences.txt true 100 25 100 stacked
-          float learningRate = 1e-2f;
+          float learningRate = 1e-1f;
           int seqLength = 25;
           int hiddenLayerSize = 30;
           int epochs = 20;
@@ -69,9 +69,13 @@ public class NNRunner {
           if (args.length > 5 && args[5] != null) {
             seqLength = Integer.valueOf(args[5]);
           }
-          RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
           if (args.length > 6 && args[6] != null) {
-            if ("stacked".equals(args[6])) {
+            learningRate = Float.valueOf(args[6]);
+          }
+
+          RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
+          if (args.length > 7 && args[7] != null) {
+            if ("stacked".equals(args[7])) {
               rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
             }
           }
@@ -79,6 +83,11 @@ public class NNRunner {
           rnn.learn();
           int seed = random.nextInt(rnn.vocabSize);
           System.out.println(rnn.sample(seed));
+          try {
+            rnn.serialize("weights.txt");
+          } catch (IOException e) {
+            throw new RuntimeException("cannot serialize weights", e);
+          }
           break;
         }
         case "multi": {

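A note on the NNRunner change above: the default learning rate moves from 1e-2 to 1e-1 and, when supplied, is now read from args[6], which shifts the optional "stacked" selector to args[7]; after training and sampling, the weights are also written out via rnn.serialize("weights.txt"). Under that argument order, an invocation exercising the new parameter could look like the following (the numeric values are only illustrative):

    bin/nn recurrent core/src/test/resources/word2vec/sentences.txt true 100 25 100 1e-1 stacked
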
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1764880&r1=1764879&r2=1764880&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Fri Oct 14 11:57:34 2016
@@ -170,19 +170,19 @@ public class RNN {
 
       // perform parameter update with Adagrad
       mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.addi(1e-8))));
+      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(1e-8))));
 
       mWhh.addi(dWhh.mul(dWhh));
-      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.addi(1e-8))));
+      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(1e-8))));
 
       mWhy.addi(dWhy.mul(dWhy));
-      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.addi(1e-8))));
+      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(1e-8))));
 
       mbh.addi(dbh.mul(dbh));
-      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.addi(1e-8))));
+      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(1e-8))));
 
       mby.addi(dby.mul(dby));
-      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.addi(1e-8))));
+      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(1e-8))));
 
       p += seqLength; // move data pointer
       n++; // iteration counter

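The Adagrad change above is the core of this commit: addi(1e-8) added the epsilon into the Adagrad memory arrays (mWxh, mWhh, mWhy, mbh, mby) in place, so the cache permanently drifted upward by 1e-8 on every iteration on top of the accumulated squared gradients; add(1e-8) operates on a copy, so the epsilon only guards the square root and never leaks into the accumulator. A minimal sketch of the corrected update for a single parameter, using the same ND4J calls as the patch (the class and method names here are illustrative, not part of the code base):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.ops.transforms.Transforms;

    class AdagradStepSketch {
      /**
       * One Adagrad update: the squared gradient is accumulated into the cache
       * in place (addi), while the 1e-8 epsilon is added on a copy (add) so it
       * never compounds inside the cache across iterations.
       */
      static void step(INDArray w, INDArray dW, INDArray cache, float learningRate) {
        cache.addi(dW.mul(dW));                             // cache += dW^2 (in place)
        INDArray denom = Transforms.sqrt(cache.add(1e-8));  // sqrt(cache + eps), cache untouched
        w.subi(dW.mul(learningRate).div(denom));            // w -= lr * dW / denom (in place)
      }
    }
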
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1764880&r1=1764879&r2=1764880&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Fri Oct 14 11:57:34 2016
@@ -22,6 +22,7 @@ import org.apache.commons.math3.distribu
 import org.apache.commons.math3.util.Pair;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
+import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.ops.transforms.Transforms;
 
@@ -74,7 +75,7 @@ public class StackedRNN extends RNN {
 
   public void learn() {
 
-    int currentEpoch = 0;
+    int currentEpoch = -1;
 
     int n = 0;
     int p = 0;
@@ -136,25 +137,25 @@ public class StackedRNN extends RNN {
 
       // perform parameter update with Adagrad
       mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.addi(1e-8))));
+      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(1e-8))));
 
       mWhh.addi(dWhh.mul(dWhh));
-      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.addi(1e-8))));
+      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(1e-8))));
 
       mWhh2.addi(dWhh2.mul(dWhh2));
-      whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.addi(1e-8))));
+      whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.add(1e-8))));
 
       mbh2.addi(dbh2.mul(dbh2));
-      bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.addi(1e-8))));
+      bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.add(1e-8))));
 
       mWh2y.addi(dWh2y.mul(dWh2y));
-      wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.addi(1e-8))));
+      wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.add(1e-8))));
 
       mbh.addi(dbh.mul(dbh));
-      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.addi(1e-8))));
+      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(1e-8))));
 
       mby.addi(dby.mul(dby));
-      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.addi(1e-8))));
+      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(1e-8))));
 
       p += seqLength; // move data pointer
       n++; // iteration counter
@@ -190,6 +191,7 @@ public class StackedRNN extends RNN {
 
       INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
       INDArray hst = Transforms.tanh((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state
+//      INDArray hst = Transforms.relu((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state
       if (hs == null) {
         hs = init(inputs.length(), hst);
       }
@@ -197,6 +199,7 @@ public class StackedRNN extends RNN {
 
       INDArray hs2Row = t == 0 ? hs12 : hs2.getRow(t - 1);
       INDArray hst2 = Transforms.tanh((whh.mmul(hs.getRow(t))).add((whh2.mmul(hs2Row)).add(bh2))); // hidden state 2
+//      INDArray hst2 = Transforms.relu((whh.mmul(hs.getRow(t))).add((whh2.mmul(hs2Row)).add(bh2))); // hidden state 2
       if (hs2 == null) {
         hs2 = init(inputs.length(), hst2);
       }
@@ -207,9 +210,11 @@ public class StackedRNN extends RNN {
         ys = init(inputs.length(), yst);
       }
       ys.putRow(t, yst);
-      INDArray exp = Transforms.exp(yst);
-      Number sumExp = exp.sumNumber();
-      INDArray pst = exp.div(sumExp); // probabilities for next chars
+
+//      INDArray exp = Transforms.exp(yst);
+//      Number sumExp = exp.sumNumber();
+//      INDArray pst = exp.div(sumExp); // probabilities for next chars
+      INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst));
       if (ps == null) {
         ps = init(inputs.length(), pst);
       }
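The probability computation above switches from the hand-rolled exp/sum/div to ND4J's SoftMax transform op run through the executioner, collapsing three intermediate steps into a single op call (whether that op also applies the usual max-subtraction for numerical stability depends on the ND4J version, so treat that part as an assumption). A small equivalence sketch, using only the calls that already appear in the patch (the array values are illustrative):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    class SoftmaxSketch {
      public static void main(String[] args) {
        INDArray yst = Nd4j.create(new double[]{1.0, 2.0, 3.0});

        // Hand-rolled softmax, as the code did before this commit.
        INDArray exp = Transforms.exp(yst);
        INDArray manual = exp.div(exp.sumNumber());

        // Softmax via the executioner, as the code does now; dup() because
        // single-argument transform ops generally write into their input.
        INDArray viaOp = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst.dup()));

        System.out.println(manual + " ~ " + viaOp);
      }
    }
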
@@ -272,7 +277,9 @@ public class StackedRNN extends RNN {
 
     for (int t = 0; t < sampleSize; t++) {
       INDArray h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(hPrev)).add(bh)));
+//      INDArray h = Transforms.relu((wxh.mmul(x)).add((whh.mmul(hPrev)).add(bh)));
       INDArray h2 = Transforms.tanh((whh.mmul(h)).add((whh2.mmul(hPrev2)).add(bh2)));
+//      INDArray h2 = Transforms.relu((whh.mmul(h)).add((whh2.mmul(hPrev2)).add(bh2)));
       INDArray y = (wh2y.mmul(h2)).add(by);
       INDArray exp = Transforms.exp(y);
       INDArray pm = exp.div(Nd4j.sum(exp)).ravel();
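Two side notes on the StackedRNN hunks: the commented-out lines keep a ReLU variant of both hidden-state activations around for experimentation, and sampling (above) still uses the manual exp/sum softmax; only the training pass was moved to the SoftMax op. Assuming Transforms.relu is available in the ND4J version in use, the activation swap being tried amounts to the following sketch (hypothetical helper names, not part of the patch):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.ops.transforms.Transforms;

    class HiddenStateSketch {
      // tanh activation, as currently used for the hidden state.
      static INDArray tanhState(INDArray wxh, INDArray x, INDArray whh, INDArray hPrev, INDArray bh) {
        return Transforms.tanh(wxh.mmul(x).add(whh.mmul(hPrev).add(bh)));
      }

      // ReLU activation, as in the commented-out alternative.
      static INDArray reluState(INDArray wxh, INDArray x, INDArray whh, INDArray hPrev, INDArray bh) {
        return Transforms.relu(wxh.mmul(x).add(whh.mmul(hPrev).add(bh)));
      }
    }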


