Author: tommaso
Date: Fri Mar  3 10:27:37 2017
New Revision: 1785257

URL: http://svn.apache.org/viewvc?rev=1785257&view=rev
Log:
minor tweaks

Modified:
    labs/yay/trunk/core/pom.xml
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java

Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1785257&r1=1785256&r2=1785257&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Fri Mar  3 10:27:37 2017
@@ -26,7 +26,7 @@
         <relativePath>../</relativePath>
     </parent>
     <properties>
-        <nd4j.version>0.6.0</nd4j.version>
+        <dl4j.version>0.7.2</dl4j.version>
     </properties>
     <name>Yay core</name>
     <dependencies>
@@ -58,7 +58,12 @@
         <dependency>
             <groupId>org.nd4j</groupId>
             <artifactId>nd4j-native-platform</artifactId>
-            <version>${nd4j.version}</version>
+            <version>${dl4j.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>deeplearning4j-nlp</artifactId>
+            <version>${dl4j.version}</version>
         </dependency>
 
     </dependencies>

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1785257&r1=1785256&r2=1785257&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Fri Mar  3 10:27:37 2017
@@ -226,25 +226,23 @@ public class RNN {
       INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
      INDArray hst = Transforms.tanh((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state
       if (hs == null) {
-        hs = init(inputs.length(), hst);
+        hs = init(inputs.length(), hst.shape());
       }
       hs.putRow(t, hst);
 
      INDArray yst = (why.mmul(hst)).add(by); // unnormalized log probabilities for next chars
       if (ys == null) {
-        ys = init(inputs.length(), yst);
+        ys = init(inputs.length(), yst.shape());
       }
       ys.putRow(t, yst);
      INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars
       if (ps == null) {
-        ps = init(inputs.length(), pst);
+        ps = init(inputs.length(), pst.shape());
       }
       ps.putRow(t, pst);
      loss += -Math.log(pst.getDouble(targets.getInt(t))); // softmax (cross-entropy loss)
     }
 
-    this.hPrev = hs.getRow(inputs.length() - 1);
-
     // backward pass: compute gradients going backwards
     INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
     for (int t = inputs.length() - 1; t >= 0; t--) {
@@ -269,12 +267,13 @@ public class RNN {
     Nd4j.getExecutioner().exec(new SetRange(dbh, -5, 5));
     Nd4j.getExecutioner().exec(new SetRange(dby, -5, 5));
 
+    this.hPrev = hs.getRow(inputs.length() - 1);
+
     return loss;
   }
 
-  protected INDArray init(int t, INDArray ast) {
+  protected INDArray init(int t, int[] aShape) {
     INDArray as;
-    int[] aShape = ast.shape();
     int[] shape = new int[1 + aShape.length];
     shape[0] = t;
     System.arraycopy(aShape, 0, shape, 1, aShape.length);
@@ -292,9 +291,11 @@ public class RNN {
     int sampleSize = 2 * seqLength;
     INDArray ixes = Nd4j.create(sampleSize);
 
+    INDArray h = hPrev.dup();
+
     for (int t = 0; t < sampleSize; t++) {
-      hPrev = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(hPrev)).add(bh)));
-      INDArray y = (why.mmul(hPrev)).add(by);
+      h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(h)).add(bh)));
+      INDArray y = (why.mmul(h)).add(by);
      INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
 
       List<Pair<Integer, Double>> d = new LinkedList<>();

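For reference, the refactored init helper in RNN.java now receives the per-step row shape directly instead of a template INDArray. A minimal sketch of how the complete method could look after this change; only the shape-construction lines come from the diff, and the final Nd4j.create call is an assumption here, not part of the commit:

    // allocate a zeroed buffer with t rows, one per time step,
    // each row shaped like a single forward-pass result
    protected INDArray init(int t, int[] aShape) {
      int[] shape = new int[1 + aShape.length];
      shape[0] = t;                                     // prepend the time dimension
      System.arraycopy(aShape, 0, shape, 1, aShape.length);
      return Nd4j.create(shape);                        // assumed allocation; not shown in the diff
    }
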
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1785257&r1=1785256&r2=1785257&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Fri Mar  3 10:27:37 2017
@@ -66,10 +66,10 @@ public class StackedRNN extends RNN {
  public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) {
     super(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
 
-    wxh = Nd4j.randn(hiddenLayerSize, vocabSize).mul(0.01);
-    whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01);
-    whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01);
-    wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).mul(0.01);
+    wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize));
+    whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
+    whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
+    wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).div(Math.sqrt(vocabSize));
     bh = Nd4j.zeros(hiddenLayerSize, 1);
     bh2 = Nd4j.zeros(hiddenLayerSize, 1);
     by = Nd4j.zeros(vocabSize, 1);
@@ -129,8 +129,8 @@ public class StackedRNN extends RNN {
       // forward seqLength characters through the net and fetch gradient
      double loss = lossFun(inputs, targets, dWxh, dWhh, dWhh2, dWh2y, dbh, dbh2, dby);
       smoothLoss = smoothLoss * 0.999 + loss * 0.001;
-      if (Double.isNaN(smoothLoss)) {
-        System.out.println("loss is NaN (over/underflow occured, try adjusting 
hyperparameters)");
+      if (Double.isNaN(smoothLoss) || Double.isInfinite(smoothLoss)) {
+        System.out.println("loss is " + smoothLoss + " (over/underflow 
occured, try adjusting hyperparameters)");
         break;
       }
       if (n % 100 == 0) {
@@ -139,7 +139,7 @@ public class StackedRNN extends RNN {
 
       // perform parameter update with Adagrad
       mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));
+      wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
 
       mWhh.addi(dWhh.mul(dWhh));
       whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
@@ -185,30 +185,29 @@ public class StackedRNN extends RNN {
       int tIndex = inputs.getScalar(t).getInt(0);
       xs.putScalar(t, tIndex, 1); // encode in 1-of-k representation
 
-      INDArray hsRow = t == 0 ? hPrev : hs.getRow(t - 1);
       INDArray xst = xs.getRow(t);
-      INDArray hst = Transforms.tanh((wxh.mmul(xst.transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state
+
+      hPrev = Transforms.tanh((wxh.mmul(xst.transpose()).add(whh.mmul(hPrev)).add(bh))); // hidden state
       if (hs == null) {
-        hs = init(seqLength, hst);
+        hs = init(seqLength, hPrev.shape());
       }
-      hs.putRow(t, hst);
+      hs.putRow(t, hPrev.dup());
 
-      INDArray hs2Row = t == 0 ? hPrev2 : hs2.getRow(t - 1);
-      INDArray hst2 = Transforms.tanh((whh.mmul(hst)).add((whh2.mmul(hs2Row)).add(bh2))); // hidden state 2
+      hPrev2 = Transforms.tanh((whh.mmul(hs.getRow(t)).add(whh2.mmul(hPrev2)).add(bh2))); // hidden state 2
       if (hs2 == null) {
-        hs2 = init(seqLength, hst2);
+        hs2 = init(seqLength, hPrev2.shape());
       }
-      hs2.putRow(t, hst2);
+      hs2.putRow(t, hPrev2.dup());
 
-      INDArray yst = (wh2y.mmul(hst2)).add(by); // unnormalized log probabilities for next chars
+      INDArray yst = wh2y.mmul(hs2.getRow(t)).add(by); // unnormalized log probabilities for next chars
       if (ys == null) {
-        ys = init(seqLength, yst);
+        ys = init(seqLength, yst.shape());
       }
       ys.putRow(t, yst);
 
      INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars
       if (ps == null) {
-        ps = init(seqLength, pst);
+        ps = init(seqLength, pst.shape());
       }
       ps.putRow(t, pst);
 
@@ -258,7 +257,6 @@ public class StackedRNN extends RNN {
     Nd4j.getExecutioner().exec(new SetRange(dbh, -clip, clip));
     Nd4j.getExecutioner().exec(new SetRange(dbh2, -clip, clip));
     Nd4j.getExecutioner().exec(new SetRange(dby, -clip, clip));
-
     return loss;
   }
 
@@ -273,26 +271,29 @@ public class StackedRNN extends RNN {
     int sampleSize = seqLength * 2;
     INDArray ixes = Nd4j.create(sampleSize);
 
-    INDArray h = hPrev;
-    INDArray h2 = hPrev2;
+    INDArray h = hPrev.dup();
+    INDArray h2 = hPrev2.dup();
 
     for (int t = 0; t < sampleSize; t++) {
-      h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(h)).add(bh)));
-      h2 = Transforms.tanh((whh.mmul(h)).add((whh2.mmul(h2)).add(bh2)));
-      INDArray y = (wh2y.mmul(h2)).add(by);
+      h = Transforms.tanh(((wxh.mmul(x)).add((whh.mmul(h)).add(bh))));
+      h2 = Transforms.tanh(((whh.mmul(h)).add((whh2.mmul(h2)).add(bh2))));
+      INDArray y = wh2y.mmul(h2).add(by);
      INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
 
       List<Pair<Integer, Double>> d = new LinkedList<>();
       for (int pi = 0; pi < vocabSize; pi++) {
         d.add(new Pair<>(pi, pm.getDouble(0, pi)));
       }
-      EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);
+      try {
+        EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);
 
-      int ix = distribution.sample();
+        int ix = distribution.sample();
 
-      x = Nd4j.zeros(vocabSize, 1);
-      x.putScalar(ix, 1);
-      ixes.putScalar(t, ix);
+        x = Nd4j.zeros(vocabSize, 1);
+        x.putScalar(ix, 1);
+        ixes.putScalar(t, ix);
+      } catch (Exception e) {
+      }
     }
 
     return getSampleString(ixes);
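
Each weight matrix touched in the training loop above is updated with the same per-parameter Adagrad rule: accumulate squared gradients, then scale the step by the learning rate over the square root of that accumulator. A hedged standalone sketch of that rule; the helper name and signature are illustrative and not part of the commit:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.ops.transforms.Transforms;

    // cache += grad * grad; param -= lr * grad / sqrt(cache + reg)
    static void adagradStep(INDArray param, INDArray grad, INDArray cache,
                            float learningRate, double reg) {
      cache.addi(grad.mul(grad));                        // accumulate squared gradients
      param.subi(grad.mul(learningRate)
          .div(Transforms.sqrt(cache.add(reg))));        // per-element scaled step
    }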


